Skip to content

Commit

Permalink
multi res support
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Jan 12, 2024
1 parent 07b341d commit 052b328
Show file tree
Hide file tree
Showing 7 changed files with 1,709 additions and 101 deletions.
15 changes: 13 additions & 2 deletions diffusion_models/utils/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Optional, Tuple
from torchvision.datasets import MNIST, CIFAR10, ImageNet
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop, InterpolationMode
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop, InterpolationMode, RandomCrop
from torchvision.transforms.functional import resize, center_crop
from typing import Any
from torch.utils.data import Dataset
Expand Down Expand Up @@ -113,7 +113,10 @@ def __init__(self, root: str, size: int=128) -> None:
for i in range(slices):
self.imgs.append({"file_name":file_name, "index":i})
file.close()
self.transform = Compose([ToTensor(), Resize((size, size), antialias=True)])
if size == 320:
self.transform = ToTensor()
else:
self.transform = Compose([ToTensor(), Resize((size, size), antialias=True)])

def __len__(self):
return len(self.imgs)
Expand All @@ -129,6 +132,14 @@ def __getitem__(self, index) -> Any:
x = x * (1 / x.max())
return (x, )

class FastMRIRandCrop(FastMRIBrainTrain):
def __init__(self, root: str, size: int = 320, crop_size = 128) -> None:
super().__init__(root, size)
self.transform = Compose([ToTensor(), RandomCrop((crop_size,crop_size))])

class FastMRIRandCropDebug(FastMRIRandCrop):
__len__ = lambda x: 500

class FastMRIDebug(FastMRIBrainTrain):
def __len__(self):
return 128
Expand Down
8 changes: 4 additions & 4 deletions eval_scripts/direct_freq_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def freq_replacement(
target_img = sampler._to_imgspace(target_kspace)
target_img, _ = sampler.model.fwd_diff(target_img, t)
target_kspace = sampler._to_kspace(target_img)
pred_kspace = pred_kspace + target_kspace * mask - pred_kspace * mask
return sampler._to_imgspace(pred_kspace)
pred_img = pred_img + sampler._to_imgspace(target_kspace * mask) - sampler._to_imgspace(pred_kspace * mask)
return pred_img
else:
raise ValueError("no such noising process")

Expand Down Expand Up @@ -63,10 +63,10 @@ def reconstruction2(sampler, corrupted_kspace, mask, process=Literal["kspace","i
kspace = sampler._to_kspace(samples)
kspace = kspace * mask

processes = ["kspace","imgspace"]
processes = ["imgspace"]
for process in processes:
kspace, mask = kspace.to(device), mask.to(device)
res = reconstruction2(sampler, kspace, mask, "imgspace")
res = reconstruction(sampler, kspace, mask, process)
save_image(make_grid(res, nrow=4), f"reconstruction_{process}.png")

save_image(make_grid(samples, nrow=4), "samples.png")
Expand Down
36 changes: 22 additions & 14 deletions eval_scripts/filter_schedule.ipynb

Large diffs are not rendered by default.

1,644 changes: 1,596 additions & 48 deletions eval_scripts/tests.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/job.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#SBATCH --account=student
#SBATCH --output=log/%j.out
#SBATCH --error=log/%j.err
#SBATCH --gres=gpu:2
#SBATCH --gres=gpu:4
#SBATCH --mem=64G
#SBATCH --job-name=mnist_double
#SBATCH --job-name=fmri320
#SBATCH --constraint='titan_xp|geforce_gtx_titan_x'

source /scratch_net/biwidl311/peerli/conda/etc/profile.d/conda.sh
Expand Down
56 changes: 48 additions & 8 deletions tests/noising_process.ipynb

Large diffs are not rendered by default.

47 changes: 24 additions & 23 deletions tests/train_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,51 @@
import torch.multiprocessing as mp
import os
from utils.mp_setup import DDP_Proc_Group
from utils.datasets import FastMRIBrainKSpaceDebug, FastMRIBrainKSpace, MNISTTrainDataset, MNISTDebugDataset, MNISTKSpace
from utils.datasets import FastMRIBrainKSpaceDebug, FastMRIBrainKSpace, MNISTTrainDataset, MNISTDebugDataset, MNISTKSpace, FastMRIRandCrop, FastMRIRandCropDebug
from utils.helpers import dotdict
import wandb
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts

config = dotdict(
world_size = 2,
total_epochs = 70,
world_size = 4,
total_epochs = 100,
log_wandb = True,
project = "mnist_gen_trials",
#data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
data_path = "/itet-stor/peerli/net_scratch",
project = "fastMRI_gen_trials",
data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
#data_path = "/itet-stor/peerli/net_scratch",
checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
wandb_dir = "/itet-stor/peerli/net_scratch",
from_checkpoint = "/itet-stor/peerli/net_scratch/curious-river-16/checkpoint1.pt",
#from_checkpoint = "/itet-stor/peerli/net_scratch/curious-river-16/checkpoint1.pt",
from_checkpoint = False,
loss_func = F.mse_loss,
mixed_precision = True,
optimizer = torch.optim.AdamW,
lr_scheduler = "cosine_ann_warm",
cosine_ann_T_0 = 3,
cosine_ann_T_mult = 2,
k_space = True,
optimizer = torch.optim.Adam,
lr_scheduler = None,
#cosine_ann_T_0 = 3,
#cosine_ann_T_mult = 2,
k_space = False,
save_every = 1,
num_samples = 9,
batch_size = 64,
gradient_accumulation_rate = 8,
learning_rate = 0.001,
img_size = 32,
num_samples = 4,
batch_size = 48,
gradient_accumulation_rate = 4,
learning_rate = 0.0001,
img_size = 320,
device_type = "cuda",
in_channels = 2,
dataset = MNISTKSpace,
in_channels = 1,
dataset = FastMRIRandCrop,
architecture = DiffusionModel,
backbone = UNet,
attention = True,
attention = False,
attention_heads = 4,
attention_ff_dim = None,
unet_init_channels = 128,
unet_init_channels = 64,
activation = nn.SiLU,
backbone_enc_depth = 4,
backbone_enc_depth = 5,
kernel_size = 3,
dropout = 0.0,
forward_diff = ForwardDiffusion,
max_timesteps = 1000,
max_timesteps = 4000,
t_start = 0.0001,
t_end = 0.02,
offset = 0.008,
Expand Down

0 comments on commit 052b328

Please sign in to comment.