diff --git a/.gitignore b/.gitignore index f6b4615..723fe68 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ benchopt.ini .DS_Store coverage.xml + +tmp +data/ \ No newline at end of file diff --git a/benchmark_utils/custom_models.py b/benchmark_utils/custom_models.py new file mode 100644 index 0000000..98ebd65 --- /dev/null +++ b/benchmark_utils/custom_models.py @@ -0,0 +1,22 @@ +from deepinv.models import UNet + + +class MRIUNet(UNet): + def __init__(self, in_channels, out_channels, scales=3, batch_norm=False): + self.name = "MRIUNet" + self.in_channels = in_channels + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + scales=scales, + batch_norm=batch_norm + ) + + def forward(self, x, sigma=None, **kwargs): + # Reshape for MRI specific processing + x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4]) + + x = super().forward(x, sigma=sigma, **kwargs) + + return x diff --git a/benchmark_utils/denoiser_2c.py b/benchmark_utils/denoiser_2c.py new file mode 100644 index 0000000..82ec22f --- /dev/null +++ b/benchmark_utils/denoiser_2c.py @@ -0,0 +1,26 @@ +import torch +from deepinv.models import DRUNet +from deepinv.models import Denoiser + + +class Denoiser_2c(Denoiser): + def __init__(self, device): + super(Denoiser_2c, self).__init__() + self.model_c1 = DRUNet( + in_channels=1, out_channels=1, + pretrained="download", device=device + ) + self.model_c2 = DRUNet( + in_channels=1, out_channels=1, + pretrained="download", device=device + ) + + def forward(self, y, sigma): + y1, y2 = torch.split(y, 1, dim=1) + + x_hat_1 = self.model_c1(y1, sigma=sigma) + x_hat_2 = self.model_c2(y2, sigma=sigma) + + x_hat = torch.cat([x_hat_1, x_hat_2], dim=1) + + return x_hat diff --git a/benchmark_utils/fastmri_dataset.py b/benchmark_utils/fastmri_dataset.py new file mode 100644 index 0000000..1a180b7 --- /dev/null +++ b/benchmark_utils/fastmri_dataset.py @@ -0,0 +1,49 @@ +import torch +from torch.utils.data import Dataset +from deepinv.datasets 
import FastMRISliceDataset +import torch.nn.functional as F + + +class FastMRIDataset(Dataset): + def __init__(self, dataset: FastMRISliceDataset, mask, max_coils=32): + self.dataset = dataset + self.max_coils = max_coils + self.mask = mask + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + x, y = self.dataset[idx] + x, y = x.to(device=self.mask.device), y.to(device=self.mask.device) + + # Pad the width + target_width = 400 + pad_total = target_width - y.shape[3] + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + y = F.pad(y, (pad_left, pad_right, 0, 0), mode='constant', value=0) + + # Pad the height + target_height = 700 + pad_total = target_height - y.shape[2] + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + y = F.pad(y, (0, 0, pad_left, pad_right), mode='constant', value=0) + + # Transform the mask to match the kspace shape + mask = self.mask.repeat(y.shape[0], y.shape[1], 1, 1) + + # Apply the mask to the k-space data + y = y * mask + + # Add an imaginary part of zeros + x = torch.cat([x, torch.zeros_like(x)], dim=0) + + # Pad the coil dimension if necessary + coil_dim = y.shape[1] + if coil_dim < self.max_coils: + pad_size = self.max_coils - coil_dim + y = F.pad(y, (0, 0, 0, 0, 0, pad_size)) + + return x, y diff --git a/benchmark_utils/hugging_face_torch_dataset.py b/benchmark_utils/hugging_face_torch_dataset.py index 95edc10..9f251b9 100644 --- a/benchmark_utils/hugging_face_torch_dataset.py +++ b/benchmark_utils/hugging_face_torch_dataset.py @@ -2,19 +2,26 @@ class HuggingFaceTorchDataset(torch.utils.data.Dataset): - def __init__(self, hf_dataset, key, transform=None): + def __init__(self, hf_dataset, key, physics, device, transform=None): self.hf_dataset = hf_dataset self.transform = transform self.key = key + self.device = device + self.physics = physics def __len__(self): return len(self.hf_dataset) def __getitem__(self, idx): sample = self.hf_dataset[idx] - image = sample[self.key] # Image 
PIL + x = sample[self.key] # Image PIL if self.transform: - image = self.transform(image) + x = self.transform(x) - return image + x = x.to(self.device) + + y = self.physics(x.unsqueeze(0)) + y = y.squeeze(0) + + return x, y diff --git a/benchmark_utils/metrics.py b/benchmark_utils/metrics.py new file mode 100644 index 0000000..4fa48cc --- /dev/null +++ b/benchmark_utils/metrics.py @@ -0,0 +1,33 @@ +import deepinv as dinv + + +class CustomMSE(dinv.metric.MSE): + + transform = lambda x: x # noqa: E731 + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + + +class CustomPSNR(dinv.metric.PSNR): + + transform = lambda x: x # noqa: E731 + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + + +class CustomSSIM(dinv.metric.SSIM): + + transform = lambda x: x # noqa: E731 + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) + + +class CustomLPIPS(dinv.metric.LPIPS): + + transform = lambda x: x # noqa: E731 + + def forward(self, x_net=None, x=None, *args, **kwargs): + return super().forward(self.transform(x_net), x, *args, **kwargs) diff --git a/config.yml b/config.yml new file mode 100644 index 0000000..c1d4eb0 --- /dev/null +++ b/config.yml @@ -0,0 +1,3 @@ +data_paths: + fastmri_train: /data/parietal/store3/data/fastMRI-multicoil/multicoil_train + fastmri_test: /data/parietal/store3/data/fastMRI-multicoil/multicoil_val diff --git a/datasets/bsd500_cbsd68.py b/datasets/bsd500_cbsd68.py index bdc5e8a..beeb3d9 100644 --- a/datasets/bsd500_cbsd68.py +++ b/datasets/bsd500_cbsd68.py @@ -9,7 +9,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling + from deepinv.physics import ( + Denoising, + GaussianNoise, + Downsampling, + Demosaicing + ) from 
deepinv.physics.generator import MotionBlurGenerator @@ -21,7 +26,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -32,23 +39,24 @@ def get_data(self): device = ( dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) + if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -57,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=img_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(img_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=img_size, + device=device) else: raise Exception("Unknown task") @@ -78,41 +90,31 @@ def get_data(self): ]) path = get_data_path("BSD500") - train_dataset = dinv.datasets.BSDS500( + bsd500_dataset = dinv.datasets.BSDS500( path, download=True, transform=transform ) + train_dataset = HuggingFaceTorchDataset( + 
bsd500_dataset, + key=..., + physics=physics, + device=device, + transform=transforms.Resize((self.img_size, self.img_size)) + ) dataset_cbsd68 = load_dataset("deepinv/CBSD68") test_dataset = HuggingFaceTorchDataset( - dataset_cbsd68["train"], key="png", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_cbsd68["train"], + key="png", physics=physics, - save_dir=get_data_path("bsd500_cbsd68"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True + device=device, + transform=transform ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False - ) - - x, y = train_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size ) diff --git a/datasets/bsd500_imnet100.py b/datasets/bsd500_imnet100.py index d0f6959..642e671 100644 --- a/datasets/bsd500_imnet100.py +++ b/datasets/bsd500_imnet100.py @@ -8,7 +8,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Downsampling, Denoising, GaussianNoise + from deepinv.physics import ( + Downsampling, + Denoising, + GaussianNoise, + Demosaicing + ) from deepinv.physics.generator import MotionBlurGenerator from datasets import load_dataset @@ -21,7 +26,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -32,23 +39,24 @@ def get_data(self): device = ( dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu" + n_channels = 3 + img_size = (n_channels, self.img_size, self.img_size) + 
if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -57,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=img_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=img_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(img_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=img_size, + device=device) else: raise Exception("Unknown task") @@ -78,43 +90,31 @@ def get_data(self): ]) path = get_data_path("BSD500") - train_dataset = dinv.datasets.BSDS500( + bsd500_dataset = dinv.datasets.BSDS500( path, download=True, transform=transform ) + train_dataset = HuggingFaceTorchDataset( + bsd500_dataset, + key=..., + physics=physics, + device=device, + transform=transforms.Resize((self.img_size, self.img_size)) + ) dataset_miniImnet100 = load_dataset("mterris/miniImnet100") test_dataset = HuggingFaceTorchDataset( dataset_miniImnet100["validation"], key="image", - transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, 
physics=physics, - save_dir=get_data_path("bsd500_imnet100"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=True - ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, train=False + device=device, + transform=transform ) - x, y = train_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="BSD68", - task_name=self.task + task_name=self.task, + image_size=img_size ) diff --git a/datasets/cbsd68_set3c.py b/datasets/cbsd68_set3c.py index fadd716..4c36197 100644 --- a/datasets/cbsd68_set3c.py +++ b/datasets/cbsd68_set3c.py @@ -1,5 +1,4 @@ from benchopt import BaseDataset, safe_import_context -from benchopt.config import get_data_path with safe_import_context() as import_ctx: import deepinv as dinv @@ -9,7 +8,12 @@ from benchmark_utils.hugging_face_torch_dataset import ( HuggingFaceTorchDataset ) - from deepinv.physics import Denoising, GaussianNoise, Downsampling + from deepinv.physics import ( + Denoising, + GaussianNoise, + Downsampling, + Demosaicing, + ) from deepinv.physics.generator import MotionBlurGenerator @@ -21,7 +25,9 @@ class Dataset(BaseDataset): 'task': ['denoising', 'gaussian-debluring', 'motion-debluring', - 'SRx4'], + 'SRx4', + 'inpainting', + 'demosaicing'], 'img_size': [256], } @@ -33,23 +39,24 @@ def get_data(self): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) + n_channels = 3 + image_size = (n_channels, self.img_size, self.img_size) + if self.task == "denoising": - noise_level_img = 0.03 + noise_level_img = 0.1 physics = Denoising(GaussianNoise(sigma=noise_level_img)) elif self.task == "gaussian-debluring": filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3)) noise_level_img = 0.03 - n_channels = 3 physics = dinv.physics.BlurFFT( 
- img_size=(n_channels, self.img_size, self.img_size), + img_size=image_size, filter=filter_torch, noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img), device=device ) elif self.task == "motion-debluring": psf_size = 31 - n_channels = 3 motion_generator = MotionBlurGenerator( (psf_size, psf_size), device=device @@ -58,18 +65,22 @@ def get_data(self): filters = motion_generator.step(batch_size=1) physics = dinv.physics.BlurFFT( - img_size=(n_channels, self.img_size, self.img_size), + img_size=image_size, filter=filters["filter"], device=device ) elif self.task == "SRx4": - n_channels = 3 - physics = Downsampling(img_size=(n_channels, - self.img_size, - self.img_size), + physics = Downsampling(img_size=image_size, filter="bicubic", factor=4, device=device) + elif self.task == "inpainting": + physics = dinv.physics.Inpainting(image_size, + mask=0.7, + device=device) + elif self.task == "demosaicing": + physics = Demosaicing(img_size=image_size, + device=device) else: raise Exception("Unknown task") @@ -80,42 +91,27 @@ def get_data(self): dataset_CBSD68 = load_dataset("deepinv/CBSD68") train_dataset = HuggingFaceTorchDataset( - dataset_CBSD68["train"], key="png", transform=transform + dataset_CBSD68["train"], + key="png", + physics=physics, + device=device, + transform=transform ) dataset_Set3c = load_dataset("deepinv/set3c") test_dataset = HuggingFaceTorchDataset( - dataset_Set3c["train"], key="image", transform=transform - ) - - dinv_dataset_path = dinv.datasets.generate_dataset( - train_dataset=train_dataset, - test_dataset=test_dataset, + dataset_Set3c["train"], + key="image", physics=physics, - save_dir=get_data_path("cbsd68_set3c"), - dataset_filename=self.task, - device=device - ) - - train_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=True + device=device, + transform=transform ) - test_dataset = dinv.datasets.HDF5Dataset( - path=dinv_dataset_path, - train=False - ) - - x, y = train_dataset[0] - 
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) - - x, y = test_dataset[0] - dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)]) return dict( train_dataset=train_dataset, test_dataset=test_dataset, physics=physics, dataset_name="Set3c", - task_name=self.task + task_name=self.task, + image_size=image_size ) diff --git a/datasets/fastmri.py b/datasets/fastmri.py new file mode 100644 index 0000000..00ca8b8 --- /dev/null +++ b/datasets/fastmri.py @@ -0,0 +1,59 @@ +from benchopt import BaseDataset, safe_import_context, config + +with safe_import_context() as import_ctx: + import deepinv as dinv + import torch + from benchmark_utils.fastmri_dataset import FastMRIDataset + +MAX_COILS = 32 # Maximum number of coils to pad to +KSPACE_PADDED_SIZE = (700, 400) # K-space size for FastMRI dataset + + +class Dataset(BaseDataset): + name = "FastMRI" + + parameters = {} + + def get_data(self): + device = "cpu" + if torch.cuda.is_available(): + device = dinv.utils.get_freer_gpu() + rng = torch.Generator(device=device).manual_seed(0) + + physics_generator = dinv.physics.generator.GaussianMaskGenerator( + img_size=KSPACE_PADDED_SIZE, acceleration=4, rng=rng, device=device + ) + mask = physics_generator.step( + batch_size=1, img_size=KSPACE_PADDED_SIZE + )["mask"] + + train_dataset = FastMRIDataset(dinv.datasets.FastMRISliceDataset( + config.get_data_path(key="fastmri_train"), slice_index="middle" + ), mask, MAX_COILS) + + test_dataset = FastMRIDataset(dinv.datasets.FastMRISliceDataset( + config.get_data_path(key="fastmri_test"), slice_index="middle" + ), mask, MAX_COILS) + + x, y = train_dataset[0] + + img_size, kspace_shape = x.shape[-2:], KSPACE_PADDED_SIZE + + physics = dinv.physics.MultiCoilMRI( + img_size=img_size, + mask=mask, + coil_maps=torch.ones( + (MAX_COILS,) + kspace_shape, + dtype=torch.complex64 + ), + device=device, + ) + + return dict( + train_dataset=train_dataset, + test_dataset=test_dataset, + physics=physics, + dataset_name="FastMRI", + task_name="MRI", + 
image_size=y.shape + ) diff --git a/datasets/simulated.py b/datasets/simulated.py index 27befd3..c4c7fb5 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -25,5 +25,6 @@ def get_data(self): test_dataset=test_dataset, physics=Denoising(GaussianNoise(sigma=0.03)), dataset_name="simulated", - task_name="test" + task_name="test", + image_size=(3, 32, 32) ) diff --git a/objective.py b/objective.py index 5ed8c3b..132f96a 100644 --- a/objective.py +++ b/objective.py @@ -7,6 +7,11 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + import torchvision + import torch.nn.functional as F + from benchmark_utils.metrics import CustomPSNR, CustomSSIM + from tqdm import tqdm + import time # The benchmark objective must be named `Objective` and @@ -43,7 +48,8 @@ def set_data(self, test_dataset, physics, dataset_name, - task_name): + task_name, + image_size): # The keyword arguments of this function are the keys of the dictionary # returned by `Dataset.get_data`. This defines the benchmark's # API to pass data. This is customizable for each benchmark. @@ -52,6 +58,7 @@ def set_data(self, self.physics = physics self.dataset_name = dataset_name self.task_name = task_name + self.image_size = image_size def evaluate_result(self, model, model_name, device): # The keyword arguments of this function are the keys of the @@ -59,49 +66,105 @@ def evaluate_result(self, model, model_name, device): # benchmark's API to pass solvers' result. This is customizable for # each benchmark. 
- batch_size = 2 + batch_size = 1 test_dataloader = DataLoader( self.test_dataset, batch_size=batch_size, shuffle=False ) - if isinstance(model, dinv.models.DeepImagePrior): - psnr = [] - ssim = [] - lpips = [] + # DeepImagePrior use images one by one, thus we can't use dinv.test + # if isinstance(model, dinv.models.DeepImagePrior): + psnr = [] + ssim = [] + lpips = [] + times = [] - for x, y in test_dataloader: - x, y = x.to(device), y.to(device) - x_hat = torch.cat([ + for x, y in tqdm(test_dataloader, desc=f"Evaluating {model_name}"): + x, y = x.to(device), y.to(device) + + if isinstance(model, dinv.models.DeepImagePrior): + start = time.time() + x_hat = [ model(y_i[None], self.physics) for y_i in y - ]) + ] + exec_time = time.time() - start + x_hat = torch.cat(x_hat) + else: + if ( + type(self.physics) is dinv.physics.blur.Downsampling + and model_name.startswith('U-Net') + ): + _, _, x_h, x_w = x.shape + _, _, y_h, y_w = y.shape + + diff_h = x_h - y_h + diff_w = x_w - y_w + + pad_top = diff_h // 2 + pad_bottom = diff_h - pad_top + pad_left = diff_w // 2 + pad_right = diff_w - pad_left + + y = F.pad( + y, + pad=(pad_left, pad_right, pad_top, pad_bottom), + value=0 + ) + + start = time.time() + x_hat = model(y, self.physics) + exec_time = time.time() - start + + times.append(exec_time) + + if (self.dataset_name == 'FastMRI'): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) + + CustomPSNR.transform = transform + + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + ] + ) + + CustomSSIM.transform = transform + + psnr.append(CustomPSNR()(x_hat, x)) + else: psnr.append(dinv.metric.PSNR()(x_hat, x)) ssim.append(dinv.metric.SSIM()(x_hat, x)) lpips.append(dinv.metric.LPIPS(device=device)(x_hat, x)) - psnr = torch.mean(torch.cat(psnr)).item() + psnr = torch.mean(torch.cat(psnr)).item() + times = torch.mean(torch.tensor(times)).item()
+ + results = dict(PSNR=psnr) + + if self.dataset_name != 'FastMRI': ssim = torch.mean(torch.cat(ssim)).item() lpips = torch.mean(torch.cat(lpips)).item() + results['SSIM'] = ssim + results['LPIPS'] = lpips + + results['Time'] = times - results = dict(PSNR=psnr, SSIM=ssim, LPIPS=lpips) - else: - results = dinv.test( - model, - test_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), - dinv.metric.SSIM(), - dinv.metric.LPIPS(device=device)], - device=device - ) - - # This method can return many metrics in a dictionary. One of these - # metrics needs to be `value` for convergence detection purposes. - return dict( + values = dict( value=results["PSNR"], - ssim=results["SSIM"], - lpips=results["LPIPS"] ) + if self.dataset_name != 'FastMRI': + values['ssim'] = results["SSIM"] + values['lpips'] = results["LPIPS"] + + values['time'] = results["Time"] + + return values + def get_one_result(self): # Return one solution. The return value should be an object compatible # with `self.evaluate_result`. This is mainly for testing purposes. @@ -114,4 +177,8 @@ def get_objective(self): # for `Solver.set_objective`. This defines the # benchmark's API for passing the objective to the solver. # It is customizable for each benchmark. 
- return dict(train_dataset=self.train_dataset, physics=self.physics) + + return dict(train_dataset=self.train_dataset, + physics=self.physics, + image_size=self.image_size, + dataset_name=self.dataset_name,) diff --git a/solvers/ddrm.py b/solvers/ddrm.py new file mode 100644 index 0000000..8b248f0 --- /dev/null +++ b/solvers/ddrm.py @@ -0,0 +1,48 @@ +from benchopt import BaseSolver, safe_import_context + +with safe_import_context() as import_ctx: + import torch + from torch.utils.data import DataLoader + import deepinv as dinv + import numpy as np + + +class Solver(BaseSolver): + name = 'DDRM' + + parameters = {} + + sampling_strategy = 'run_once' + + requirements = [] + + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 2 + self.train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=False + ) + self.device = ( + dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" + ) + self.physics = physics + + def run(self, n_iter): + denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) + + sigmas = (np.linspace(1, 0, 100) + if torch.cuda.is_available() + else np.linspace(1, 0, 10)) + + self.model = dinv.sampling.DDRM( + denoiser=denoiser, + etab=1.0, + sigmas=sigmas, + verbose=True + ) + self.model.eval() + + def get_result(self): + return dict(model=self.model, model_name="DDRM", device=self.device) + + def skip(self, train_dataset, physics, image_size, dataset_name): + return True, "Not yet implemented."
diff --git a/solvers/diffpir.py b/solvers/diffpir.py index a3bcaf2..8c800e7 100644 --- a/solvers/diffpir.py +++ b/solvers/diffpir.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader import deepinv as dinv + from benchmark_utils.denoiser_2c import Denoiser_2c class Solver(BaseSolver): @@ -15,7 +16,7 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): + def set_objective(self, train_dataset, physics, image_size, dataset_name): batch_size = 2 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False @@ -25,14 +26,23 @@ def set_objective(self, train_dataset, physics): ) self.physics = physics + self.image_size = image_size + def run(self, n_iter): - denoiser = dinv.models.DRUNet(pretrained="download").to(self.device) + if self.image_size[0] == 2: + denoiser = Denoiser_2c(device=self.device) + else: + denoiser = dinv.models.DRUNet( + pretrained="download", + device=self.device + ) self.model = dinv.sampling.DiffPIR( model=denoiser, data_fidelity=dinv.optim.data_fidelity.L2(), device=self.device ) + self.model.eval() def get_result(self): diff --git a/solvers/dip.py b/solvers/dip.py index 3a41c91..f77260e 100644 --- a/solvers/dip.py +++ b/solvers/dip.py @@ -16,9 +16,9 @@ class Solver(BaseSolver): requirements = ["optuna"] - def set_objective(self, train_dataset, physics): + def set_objective(self, train_dataset, physics, image_size, dataset_name): self.train_dataset = train_dataset - batch_size = 32 + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -26,24 +26,24 @@ def set_objective(self, train_dataset, physics): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics.to(self.device) + self.image_size = image_size def run(self, n_iter): def objective(trial): lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True) iterations = trial.suggest_int('iterations', 50, 500, log=True) - # TODO: 
Remove - # iterations = 5 - model = self.get_model(lr, iterations) psnr = [] for x, y in self.train_dataloader: x, y = x.to(self.device), y.to(self.device) + x_hat = torch.cat([ model(y_i[None], self.physics) for y_i in y ]) + psnr.append(dinv.metric.PSNR()(x_hat, x)) psnr = torch.mean(torch.cat(psnr)).item() @@ -51,13 +51,11 @@ def objective(trial): return psnr study = optuna.create_study(direction='maximize') - study.optimize(objective, n_trials=1) + study.optimize(objective, n_trials=3) best_trial = study.best_trial best_params = best_trial.params - # TODO : replace 5 by best_params['iterations']) - # self.model = self.get_model(best_params['lr'], 5) self.model = self.get_model( best_params['lr'], best_params['iterations'] diff --git a/solvers/dpir.py b/solvers/dpir.py index 5669eed..33727ea 100644 --- a/solvers/dpir.py +++ b/solvers/dpir.py @@ -5,6 +5,15 @@ from torch.utils.data import DataLoader import deepinv as dinv import numpy as np + import torchvision + from deepinv.optim import BaseOptim + from deepinv.optim.prior import PnP + from deepinv.optim.data_fidelity import L2 + from deepinv.optim.optimizers import create_iterator + from deepinv.optim.dpir import get_DPIR_params + from benchmark_utils.denoiser_2c import Denoiser_2c + from benchmark_utils.metrics import CustomPSNR + from tqdm import tqdm class Solver(BaseSolver): @@ -16,8 +25,8 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -25,27 +34,73 @@ def set_objective(self, train_dataset, physics): dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics + self.image_size = image_size + self.dataset_name = dataset_name def run(self, n_iter): best_sigma = 0 best_psnr = 0 + + # If the number of channels is 2 we use a 
custom DPIR solver + if self.image_size[0] == 2: + model_class = DPIR_2C + else: + model_class = dinv.optim.DPIR + + # If the number of channels is different from 1 or 3 + # then we can't use pretrained DRUNet for sigma in np.linspace(0.01, 0.1, 10): - model = dinv.optim.DPIR(sigma=sigma, device=self.device) + model = model_class(sigma=sigma, device=self.device) + + psnr = [] - results = dinv.test( - model, + bar = tqdm( self.train_dataloader, - self.physics, - metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()], - device=self.device + desc="DPIR : Looking for the best sigma" ) + for x, y in bar: + x, y = x.to(self.device), y.to(self.device) + + x_hat = model(y, self.physics) + + if (self.dataset_name == 'FastMRI'): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) + + CustomPSNR.transform = transform - if results["PSNR"] > best_psnr: + psnr.append(CustomPSNR()(x_hat, x)) + else: + psnr.append(dinv.metric.PSNR()(x_hat, x)) + + psnr = torch.mean(torch.cat(psnr)).item() + + if psnr > best_psnr: best_sigma = sigma - best_psnr = results["PSNR"] + best_psnr = psnr - self.model = dinv.optim.DPIR(sigma=best_sigma, device=self.device) + self.model = model_class(sigma=best_sigma, device=self.device) self.model.eval() def get_result(self): return dict(model=self.model, model_name="DPIR", device=self.device) + + +# Custom DPIR solver with 2 channels +class DPIR_2C(BaseOptim): + def __init__(self, sigma=0.1, device="cuda"): + prior = PnP(denoiser=Denoiser_2c(device=device)) + sigma_denoiser, stepsize, max_iter = get_DPIR_params(sigma) + params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser} + super(DPIR_2C, self).__init__( + create_iterator("HQS", prior=prior, F_fn=None, g_first=False), + max_iter=max_iter, + data_fidelity=L2(), + prior=prior, + early_stop=False, + params_algo=params_algo, + ) diff --git a/solvers/ifft2.py b/solvers/ifft2.py new file mode 100644 index 
0000000..c368b96 --- /dev/null +++ b/solvers/ifft2.py @@ -0,0 +1,46 @@ +from benchopt import BaseSolver, safe_import_context + +with safe_import_context() as import_ctx: + import torch + from torch.utils.data import DataLoader + import deepinv as dinv + + +class Solver(BaseSolver): + name = 'IFFT2' + + parameters = {} + + sampling_strategy = 'run_once' + + requirements = [] + + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 2 + self.train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=False + ) + self.device = ( + dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" + ) + self.physics = physics + self.image_size = image_size + self.dataset_name = dataset_name + + def run(self, n_iter): + def model(y, physics): + return physics.A_dagger(y) + + self.model = model + + def get_result(self): + return dict(model=self.model, model_name="IFFT2", device=self.device) + + def skip(self, **objective_dict): + if isinstance( + objective_dict['physics'], + dinv.physics.mri.MultiCoilMRI + ): + return False, None + + return True, "This solver is only available for MRI dataset" diff --git a/solvers/u-net.py b/solvers/u-net.py index b3e9534..4bd5e4b 100644 --- a/solvers/u-net.py +++ b/solvers/u-net.py @@ -3,8 +3,12 @@ with safe_import_context() as import_ctx: import torch + import torch.nn.functional as F from torch.utils.data import DataLoader import deepinv as dinv + import torchvision + from benchmark_utils.metrics import CustomMSE + from benchmark_utils.custom_models import MRIUNet class Solver(BaseSolver): @@ -19,8 +23,8 @@ class Solver(BaseSolver): requirements = [] - def set_objective(self, train_dataset, physics): - batch_size = 2 + def set_objective(self, train_dataset, physics, image_size, dataset_name): + batch_size = 1 self.train_dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=False ) @@ -28,43 +32,96 @@ def set_objective(self, train_dataset, physics): 
dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu" ) self.physics = physics.to(self.device) + self.image_size = image_size + self.dataset_name = dataset_name def run(self, n_iter): epochs = 4 - model = dinv.models.UNet( - in_channels=3, out_channels=3, scales=3, batch_norm=False - ).to(self.device) + x, y = next(iter(self.train_dataloader)) - verbose = True # print training information - wandb_vis = False # plot curves and images in Weight&Bias + if self.dataset_name == 'FastMRI': + model = MRIUNet( + in_channels=y.shape[1] * y.shape[2], + out_channels=x.shape[1], + scales=3, + batch_norm=False + ).to(self.device) + else: + model = dinv.models.UNet( + in_channels=y.shape[1], out_channels=x.shape[1], scales=4, + batch_norm=False + ).to(self.device) - # choose training losses - losses = dinv.loss.SupLoss(metric=dinv.metric.MSE()) - - # choose optimizer and scheduler optimizer = torch.optim.Adam( model.parameters(), lr=self.lr, weight_decay=1e-8 ) scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, step_size=int(epochs * 0.8) - ) - trainer = dinv.Trainer( - model, - device=self.device, - verbose=verbose, - wandb_vis=wandb_vis, - physics=self.physics, - epochs=epochs, - scheduler=scheduler, - losses=losses, - optimizer=optimizer, - show_progress_bar=True, - train_dataloader=self.train_dataloader, + optimizer, step_size=int(epochs * 0.7) ) - self.model = trainer.train() - self.model.eval() + # choose training losses + if self.dataset_name == 'FastMRI': + criterion = dinv.loss.SupLoss(metric=CustomMSE()) + else: + criterion = dinv.loss.SupLoss(metric=dinv.metric.MSE()) + + for epoch in range(epochs): + model.train() + running_loss = 0.0 + + for x, y in self.train_dataloader: + x, y = x.to(self.device), y.to(self.device) + + if type(self.physics) is dinv.physics.blur.Downsampling: + _, _, x_h, x_w = x.shape + _, _, y_h, y_w = y.shape + + diff_h = x_h - y_h + diff_w = x_w - y_w + + pad_top = diff_h // 2 + pad_bottom = diff_h - pad_top + pad_left = 
diff_w // 2 + pad_right = diff_w - pad_left + + y = F.pad( + y, + pad=(pad_left, pad_right, pad_top, pad_bottom), + value=0 + ) + + x_hat = model(y, self.physics) + + if self.dataset_name == 'FastMRI': + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.CenterCrop(x.shape[-2:]), + dinv.metric.functional.complex_abs, + ] + ) + criterion.metric.transform = transform + + loss = criterion(x_hat, x) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + + avg_loss = running_loss / len(self.train_dataloader) + print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}") + + scheduler.step() + + model.eval() + + self.model = model def get_result(self): - return dict(model=self.model, model_name="U-Net", device=self.device) + return dict( + model=self.model, + model_name=f"U-Net_{self.lr}", + device=self.device + )