# Imaging inverse problems with adversarial networks

This example shows how to train reconstruction networks with adversarial training for deblurring problems. We demonstrate training and inference with DeblurGAN, AmbientGAN and UAIR as implemented in the `deepinv` library, and show how to train your own GAN using `deepinv.training.AdversarialTrainer`. The same approach can be extended to more complicated GANs such as CycleGAN.

- Kupyn et al., [_DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks_](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kupyn_DeblurGAN_Blind_Motion_CVPR_2018_paper.pdf)
- Bora et al., [_AmbientGAN: Generative models from lossy measurements_](https://openreview.net/forum?id=Hy7fDog0b)
- Pajot et al., [_Unsupervised Adversarial Image Reconstruction_](https://openreview.net/forum?id=BJg4Z3RqF7)

Adversarial networks are characterised by the addition of an adversarial loss $\mathcal{L}_\text{adv}$ to the standard reconstruction loss:

$$\mathcal{L}_\text{adv}(\hat x,x;D)=\mathbb{E}_{x\sim p_x}\left[q(D(x))\right]+\mathbb{E}_{\hat x\sim p_{\hat x}}\left[q(1-D(\hat x))\right]$$

where $D(\cdot)$ is the discriminator model, $x$ is the reference image, $\hat x$ is the estimated reconstruction, and $q(\cdot)$ is a quality function (e.g. $q(x)=x$ for WGAN). Training alternates between updating the generator $f$ and the discriminator $D$ in a minimax game. When no ground truth is available (i.e. the unsupervised setting), the adversarial loss can instead be defined on the measurements $y$.
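As a worked example of the formula above, the following plain-Python sketch evaluates the empirical adversarial loss from a handful of discriminator scores, using the identity quality function $q(x)=x$ mentioned for WGAN. The function name `adversarial_loss` and the score values are illustrative assumptions and are not part of `deepinv`.

```python
from statistics import fmean


def adversarial_loss(d_real, d_fake, q=lambda t: t):
    """Empirical L_adv: mean of q(D(x)) plus mean of q(1 - D(x_hat)).

    d_real: discriminator scores D(x) on reference images
    d_fake: discriminator scores D(x_hat) on reconstructions
    q:      quality function (identity by default, as in the WGAN example)
    """
    real_term = fmean([q(s) for s in d_real])
    fake_term = fmean([q(1.0 - s) for s in d_fake])
    return real_term + fake_term


# Hypothetical scores: the discriminator separates the two sets fairly well.
d_real = [0.9, 0.8]
d_fake = [0.1, 0.3]
loss = adversarial_loss(d_real, d_fake)
print(f"L_adv = {loss:.2f}")
```

With this convention, a discriminator that separates references from reconstructions well yields a larger value, which is what the generator update works against.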

**DeblurGAN** forward pass: 

$$\hat x = f(y)$$

**DeblurGAN** loss: 

$$\mathcal{L}=\mathcal{L}_\text{sup}(\hat x, x)+\mathcal{L}_\text{adv}(\hat x, x;D)$$

where $\mathcal{L}_\text{sup}$ is a supervised loss such as pixel-wise MSE or VGG Perceptual Loss.

**AmbientGAN** forward pass: 

$$\hat x = f(z),\quad z\sim \mathcal{N}(\mathbf{0},\mathbf{I}_k)$$

**AmbientGAN** loss (where $A(\cdot)$ is the physics): 

$$\mathcal{L}=\mathcal{L}_\text{adv}(A(\hat x), y;D)$$

**AmbientGAN** forward pass at evaluation time:

$$\hat x = f(\hat z)\quad\text{s.t.}\quad\hat z=\operatorname*{argmin}_z \lVert A(f(z))-y\rVert_2^2$$
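The evaluation-time problem above can be sketched as gradient descent on the latent code with the generator frozen. This is a minimal illustration under stated assumptions: the linear generator `f`, the average-pooling physics `A`, the Adam optimiser, its learning rate and the number of steps are all toy choices made for this sketch, not the `deepinv` implementation.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins (assumptions): a linear "generator" f and a 2x average-pooling "physics" A.
k = 8  # latent dimension
f = torch.nn.Sequential(
    torch.nn.Linear(k, 3 * 16 * 16),
    torch.nn.Unflatten(1, (3, 16, 16)),
)
A = torch.nn.AvgPool2d(2)
for p in f.parameters():
    p.requires_grad_(False)  # the generator is frozen at evaluation time

# Simulate a measurement from a known latent so that the problem has a solution.
y = A(f(torch.randn(1, k)))

# Solve z_hat = argmin_z ||A(f(z)) - y||_2^2 by gradient descent on z only.
z = torch.zeros(1, k, requires_grad=True)
opt = torch.optim.Adam([z], lr=0.1)
initial = (A(f(z)) - y).pow(2).sum().item()
for _ in range(300):
    opt.zero_grad()
    residual = (A(f(z)) - y).pow(2).sum()
    residual.backward()
    opt.step()

z_hat = z.detach()
x_hat = f(z_hat)  # final reconstruction x_hat = f(z_hat)
final = (A(x_hat) - y).pow(2).sum().item()
print(f"residual: {initial:.4f} -> {final:.6f}")
```

Only `z` is passed to the optimiser, so the generator weights never change; the reconstruction quality is therefore bounded by what the trained generator can represent.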

**UAIR** forward pass:

$$\hat x = f(y)$$

**UAIR** loss: 

$$\mathcal{L}=\mathcal{L}_\text{adv}(\hat y, y;D)+\lVert A(f(\hat y))- \hat y\rVert^2_2,\quad\hat y=A(\hat x)$$
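To make the two UAIR terms concrete, the sketch below evaluates them once on random data. The networks `f` and `D` and the operator `A` are tiny placeholders chosen for this illustration (assumptions, not `deepinv` classes), and the adversarial term uses $q(x)=x$ as in the WGAN example above.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins (assumptions, not deepinv classes).
f = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)          # reconstruction network
A = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)   # size-preserving blur
D = torch.nn.Sequential(                                     # discriminator acting on measurements
    torch.nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Conv2d(8, 1, kernel_size=4, stride=2, padding=1),
    torch.nn.Sigmoid(),
)

y = torch.rand(2, 3, 32, 32)  # observed (blurred) measurements

x_hat = f(y)       # forward pass: x_hat = f(y)
y_hat = A(x_hat)   # re-measured reconstruction: y_hat = A(x_hat)

# Adversarial term defined on measurements, with q(x) = x.
adv = D(y).mean() + (1 - D(y_hat)).mean()

# Consistency term ||A(f(y_hat)) - y_hat||_2^2, averaged over the batch.
consistency = (A(f(y_hat)) - y_hat).pow(2).flatten(1).sum(dim=1).mean()

loss = adv + consistency
print(loss.item())
```

Note that no ground-truth image appears anywhere in the computation: both terms depend only on `y` and quantities derived from it.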

In [1]:
import deepinv as dinv
from deepinv.loss import adversarial
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, CenterCrop
from torchvision.datasets.utils import download_and_extract_archive

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load the data and apply a forward degradation to the images. For this example we use the Urban100 dataset, centre-cropped to 64x64 (see `CenterCrop(64)` below). For simplicity we apply an isotropic Gaussian blur, although the original papers deal with harder inverse problems.

In [2]:
physics = dinv.physics.Blur(dinv.physics.blur.gaussian_blur(sigma=(5, 5)))
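For intuition about what an isotropic Gaussian blur operator does, here is a minimal plain-PyTorch sketch that builds a normalised Gaussian kernel and applies it channel by channel. The helper `gaussian_kernel`, the truncation radius and the zero padding are assumptions made for this illustration; the kernel size and boundary handling used by `deepinv` may differ.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel(sigma: float, radius: int) -> torch.Tensor:
    """Normalised 2D isotropic Gaussian kernel of size (2 * radius + 1) squared."""
    coords = torch.arange(-radius, radius + 1, dtype=torch.float32)
    g = torch.exp(-coords**2 / (2 * sigma**2))
    kernel = g[:, None] * g[None, :]  # separable: outer product of two 1D Gaussians
    return kernel / kernel.sum()


kernel = gaussian_kernel(sigma=5.0, radius=10)  # 21 x 21 kernel
x = torch.rand(1, 3, 64, 64)                    # dummy RGB image batch

# Apply the same kernel to each channel independently (depthwise convolution, zero padding).
weight = kernel[None, None].repeat(3, 1, 1, 1)
y = F.conv2d(x, weight, padding=10, groups=3)
print(y.shape)
```

Normalising the kernel to sum to one keeps the mean intensity of the image unchanged away from the borders.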

In [3]:
download_and_extract_archive(
    "https://huggingface.co/datasets/eugenesiow/Urban100/resolve/main/data/Urban100_HR.tar.gz?download=true",
    "Urban100",
    filename="Urban100_HR.tar.gz",
    md5="65d9d84a34b72c6f7ca1e26a12df1e4c",
)

Using downloaded and verified file: Urban100\Urban100_HR.tar.gz
Extracting Urban100\Urban100_HR.tar.gz to Urban100


In [4]:
# Convert images to tensors and centre-crop them to 64x64, then split 80/20.
transform = Compose([ToTensor(), CenterCrop(64)])
train_dataset, test_dataset = random_split(
    ImageFolder("Urban100", transform=transform), (0.8, 0.2)
)

# Generate the blurred measurements and save the resulting dataset to disk.
dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir="Urban100",
)

train_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=True),
    shuffle=True,
)
test_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=False),
    shuffle=False,
)

Computing train measurement vectors from base dataset...


100%|██████████| 2/2 [00:03<00:00,  1.98s/it]


Computing test measurement vectors from base dataset...


100%|██████████| 5/5 [00:00<00:00, 10.14it/s]

Dataset has been saved in Urban100





Define the reconstruction network (i.e. the conditional generator) and the discriminator network used for adversarial training. For demonstration we use a simple U-Net as the reconstruction network and the discriminator from [PatchGAN](https://arxiv.org/abs/1611.07004), but these can be replaced with any architecture, e.g. transformers or unrolled networks. Further discriminator models are available in `deepinv.models.gan`.

In [5]:
def get_models(model=None, D=None, lr_g=1e-4, lr_d=1e-4):
    """Return the generator, the discriminator, and their paired optimizer and scheduler."""
    if model is None:
        # Default reconstruction network (conditional generator): a small U-Net.
        model = dinv.models.UNet(
            in_channels=3,
            out_channels=3,
            scales=2,
            circular_padding=True,
            batch_norm=False,
        )

    if D is None:
        # Default discriminator: a shallow PatchGAN.
        D = dinv.models.PatchGANDiscriminator(
            n_layers=2,
            batch_norm=False,
        )

    # TODO: make sure zero_grad_only is the right way round for training
    # One Adam optimizer per network, wrapped in a single adversarial optimizer.
    optimizer = dinv.training_utils.adversarial.AdversarialOptimizer(
        torch.optim.Adam(model.parameters(), lr=lr_g, weight_decay=1e-8),
        torch.optim.Adam(D.parameters(), lr=lr_d, weight_decay=1e-8),
    )
    # Multiply both learning rates by 0.9 every 5 scheduler steps.
    scheduler = dinv.training_utils.adversarial.AdversarialScheduler(
        torch.optim.lr_scheduler.StepLR(optimizer.G, step_size=5, gamma=0.9),
        torch.optim.lr_scheduler.StepLR(optimizer.D, step_size=5, gamma=0.9),
    )

    return model, D, optimizer, scheduler
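The PatchGAN discriminator used as the default in `get_models` scores local image patches rather than producing one score per image. The sketch below illustrates that idea with a hypothetical two-layer convolutional network written in plain PyTorch; it is not `deepinv`'s `PatchGANDiscriminator`, and the channel widths are arbitrary.

```python
import torch

# Hypothetical two-layer patch discriminator (an assumption for illustration only).
patch_d = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Conv2d(16, 1, kernel_size=4, stride=2, padding=1),
)

x = torch.rand(1, 3, 64, 64)  # dummy image batch
scores = patch_d(x)           # one score per spatial location
print(scores.shape)           # torch.Size([1, 1, 16, 16])
```

Each of the 16x16 output scores depends only on a 10x10 pixel neighbourhood of the input (receptive field $4 + 3\times 2 = 10$), so the adversarial loss penalises local texture artefacts rather than global image statistics.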

### DeblurGAN training

In [6]:
model, D, optimizer, scheduler = get_models()
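The schedulers created in `get_models` use `StepLR` with `step_size=5` and `gamma=0.9`, i.e. the learning rate is multiplied by 0.9 after every 5 scheduler steps. The closed-form rule is easy to check by hand; the helper `step_lr` below is a hypothetical name introduced for this sketch.

```python
def step_lr(lr0: float, epoch: int, step_size: int = 5, gamma: float = 0.9) -> float:
    """Learning rate after `epoch` scheduler steps under step decay."""
    return lr0 * gamma ** (epoch // step_size)


for epoch in (0, 3, 5, 10, 50):
    print(epoch, f"{step_lr(1e-4, epoch):.3e}")
```

Assuming the scheduler is stepped once per epoch, the 3 epochs used in this example end before the first decay, so the learning rates stay at their initial values throughout.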

Construct the pixel-wise and adversarial losses as defined above. For simplicity we use the MSE as the supervised pixel-wise metric, but this can easily be replaced with a perceptual loss if desired.

In [7]:
loss_g = [
    dinv.loss.SupLoss(metric=torch.nn.MSELoss()),
    adversarial.DeblurGANGeneratorLoss(device=device)
]
loss_d = adversarial.DeblurGANDiscriminatorLoss(device=device)
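The generator and discriminator losses are used in alternating updates. The following plain-PyTorch sketch shows one way such an alternating scheme can look for a supervised, DeblurGAN-style objective. It is an illustration under stated assumptions, not the trainer's actual implementation: `gen` and `disc` are toy networks, and the adversarial term uses the standard binary cross-entropy GAN objective, which may differ from the quality function used by the `deepinv` losses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (assumptions): a tiny conv generator and discriminator.
gen = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
disc = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Conv2d(8, 1, kernel_size=4, stride=2, padding=1),
)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)
bce = torch.nn.BCEWithLogitsLoss()
mse = torch.nn.MSELoss()

x = torch.rand(4, 3, 32, 32)                     # reference images
y = F.avg_pool2d(x, 3, stride=1, padding=1)      # blurred measurements

for step in range(2):
    # 1) Discriminator update: references -> 1, reconstructions -> 0.
    #    detach() stops gradients from reaching the generator.
    x_hat = gen(y).detach()
    pred_real, pred_fake = disc(x), disc(x_hat)
    loss_d = bce(pred_real, torch.ones_like(pred_real)) + bce(
        pred_fake, torch.zeros_like(pred_fake)
    )
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # 2) Generator update: supervised MSE plus a term that rewards
    #    reconstructions the discriminator scores as real.
    x_hat = gen(y)
    pred_fake = disc(x_hat)
    loss_g = mse(x_hat, x) + bce(pred_fake, torch.ones_like(pred_fake))
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()

print(loss_d.item(), loss_g.item())
```

Only `opt_g.step()` is called in the generator phase, so the discriminator weights are untouched there; any gradients accumulated on the discriminator are cleared by `opt_d.zero_grad()` at the start of its next update.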

Train the networks using `AdversarialTrainer`

In [8]:
model = dinv.training_utils.AdversarialTrainer(
    model=model,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=3,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    save_path=None,
    device=device
).train()



The model has 444867 trainable parameters


Eval epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 54.01it/s, PSNR=17.3]
Train epoch 1: 100%|██████████████████████████| 80/80 [00:05<00:00, 13.78it/s, SupLoss=0.0228, DeblurGANGeneratorLoss=0.00422, TotalLoss=1, PSNR=17.9]
Eval epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 63.27it/s, PSNR=17.9]
Train epoch 2: 100%|████████████████████████████| 80/80 [00:05<00:00, 15.23it/s, SupLoss=0.0225, DeblurGANGeneratorLoss=0.00408, TotalLoss=1, PSNR=18]
Eval epoch 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 65.26it/s, PSNR=18]
Train epoch 3: 100%|████████████████████████████| 80/80 [00:05<00:00, 15.17it/s, SupLoss=0.0224, DeblurGANGeneratorLoss=0.00415, TotalLoss=1, PSNR=18]


UNet(
  (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (Conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (Up2): Sequential(
    (0): Upsample(scale_factor=2.0, mode='nearest')
    (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
  )
  (Up_conv2): Sequential(
    (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1),

### UAIR training

In [6]:
model, D, optimizer, scheduler = get_models(lr_g=1e-4, lr_d=4e-4) # learning rates from original paper

Construct the UAIR generator and discriminator losses as defined above.

In [7]:
loss_g = adversarial.UAIRGeneratorLoss(device=device)
loss_d = adversarial.UAIRDiscriminatorLoss(device=device)

Train the networks using `AdversarialTrainer`

In [8]:
model = dinv.training_utils.AdversarialTrainer(
    model=model,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=3,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    save_path=None,
    device=device
).train()



The model has 444867 trainable parameters


Eval epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 50.45it/s, PSNR=18]
Train epoch 1: 100%|████████████████████████████████████████████████████████████████████████████| 80/80 [00:08<00:00,  9.09it/s, TotalLoss=1, PSNR=17]
Eval epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 67.42it/s, PSNR=17.2]
Train epoch 2: 100%|██████████████████████████████████████████████████████████████████████████| 80/80 [00:08<00:00,  8.92it/s, TotalLoss=1, PSNR=16.8]
Eval epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 58.31it/s, PSNR=16.9]
Train epoch 3: 100%|████████████████████████████████████████████████████████████████████████████| 80/80 [00:10<00:00,  7.63it/s, TotalLoss=1, PSNR=17]


### AmbientGAN training

In [14]:
model = dinv.models.AmbientDCGANGenerator(nz=100, ngf=32)
D = dinv.models.DCGANDiscriminator(ndf=32)
_, _, optimizer, scheduler = get_models(model=model, D=D, lr_g=2e-4, lr_d=2e-4) # learning rates from original paper

Construct the AmbientGAN generator and discriminator losses as defined above.

In [11]:
loss_g = adversarial.AmbientGANGeneratorLoss(device=device)
loss_d = adversarial.AmbientGANDiscriminatorLoss(device=device)

Train the networks using `AdversarialTrainer`

In [15]:
model = dinv.training_utils.AdversarialTrainer(
    model=model,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=3,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    save_path=None,
    device=device
).train()



The model has 1100224 trainable parameters


Eval epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 108.28it/s, PSNR=7.1]
Train epoch 1: 100%|██████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 21.61it/s, TotalLoss=1, PSNR=4.13]
Eval epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 162.89it/s, PSNR=4.16]
Train epoch 2: 100%|██████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 21.94it/s, TotalLoss=1, PSNR=4.24]
Eval epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 164.98it/s, PSNR=4.26]
Train epoch 3: 100%|██████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 21.64it/s, TotalLoss=1, PSNR=4.34]


AmbientDCGANGenerator(
  (model): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tan