In [None]:
# imports
import os
import wget
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# reproducibility: keep cuDNN from auto-tuning kernels, request its
# deterministic algorithms, and fix the seed of torch's default RNG
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

# misc
def parameter_count(model):
    # Count the scalar entries of every parameter that takes part in training
    # (requires_grad=True); frozen parameters are skipped.
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class VGG(nn.Module):
    # Small VGG-style fully-convolutional network.
    #
    # Layout of `self.features` (an nn.Sequential):
    #   * four Conv3x3 > BatchNorm > ReLU blocks (32, 32, 64, 64 channels),
    #     with a 2x2 max-pool after the 2nd and the 4th block,
    #   * a Conv1x1 > BatchNorm > ReLU block (64 channels),
    #   * a final Conv1x1 with 2 output channels.
    #
    # The position of every layer inside `self.features` matters: the
    # pre-trained weights are loaded into `features` directly and the
    # perceptual-loss hooks are attached by index (2, 5 and 9).
    def __init__(self):
        super().__init__()

        in_ch, out_ch = 3, 32
        stack = []
        for block_idx in range(4):
            stack += [
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
            in_ch = out_ch

            # every second block closes a resolution stage: widen the next
            # stage and halve the spatial size
            if block_idx % 2 == 1:
                out_ch *= 2
                stack.append(nn.MaxPool2d(2))

        # 1x1 head: per-pixel mixing followed by the 2-channel output layer
        stack += [
            nn.Conv2d(in_ch, in_ch, 1, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(),
            nn.Conv2d(in_ch, 2, 1),
        ]

        self.features = nn.Sequential(*stack)

    def forward(self, x):
        # run the input through the whole layer stack
        return self.features(x)

In [None]:
# load train/test dataset (this may take a minute or two)
data_url = 'https://cloud.imi.uni-luebeck.de/s/Z7LWqcbopfwzmKn/download'
data_file = 'tcia_pancreas_data.pth'

if not os.path.isfile(data_file):
    print('downloading tcia pancreas data from ' + data_url)
    # Pin the output name: without `out`, wget picks the file name from the
    # server response / URL, which need not match the name checked above and
    # loaded below.
    wget.download(data_url, out=data_file)

# NOTE(review): torch.load unpickles the file; only use it on downloads from a
# source you trust.
# Scale by 1/128, shift by -1 and clamp, so every value lies in [-1, 1]
# (the range of the tanh the decoder is asked to end with in Task 1).
imgs = torch.clamp(torch.load(data_file).float()/128-1, -1, 1)

### Task 0: Familiarize yourself with the data set and visualize some CT image slices

In [None]:
# visualise some data
## TODO ##

### Task 1: Build a variational autoencoder architecture and train with KLD loss

1. Create an encoder that takes Bx3x256x256 input images and produces two 512-dimensional latent vectors (μ and log σ²). Use nine blocks of Conv2d > BNorm > LeakyReLU with kernel size 3x3 and an increasing number of filter channels from 16 to 64. Use five stride=2 convolutions (every other layer) and padding=1. Add fully-connected layers with 1024 and then 512 channels (both implemented as Conv2d); the first one requires a kernel size of 8x8.
2. Implement the decoder, which takes a Bx512x1x1 input and should generate a full-sized (Bx3x256x256) output image. Start with one fully-connected 1x1 Conv2d layer (512>1024 channels) followed by a ConvTranspose2d with kernel=8, channel-out=64 and no padding. Then again use blocks of Conv2d with kernel size 3x3, but alternate them with ConvTranspose2d of size 4x4 with stride 2. All these conv layers should have appropriate padding and be followed by BNorm > LeakyReLU. Finish the architecture with a Conv2d with 3 output channels and a tanh.
3. Build a VAE with your encoder-decoder architecture; you can directly follow the example from https://github.com/pytorch/examples/blob/master/vae/main.py.
Note: mu and log_var have to be returned by the forward pass, because they are needed for the loss calculation.
4. Train this network for 160 epochs with batch size = 32 and learning rate = 0.0025. Use the L1Loss for the reconstruction of images and the Kullback-Leibler divergence (KLD) as shown in the example. Every 15 epochs, visualise some reconstructed examples (given a training image as input) as well as some synthetically created examples. For the latter you can simply pass a torch.randn(B, 512, 1, 1) tensor to the decode function of your VAE model.

In [None]:
# definition of vae model
# Convolutional VAE skeleton for Task 1. `encoder` and `decoder` are exercise
# placeholders (## TODO ##); this cell is not valid Python until they are
# filled in.
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        
        # Shared feature extractor. Its output must have 1024 channels, because
        # both heads below are 1x1 convolutions with 1024 input channels.
        self.encoder = ## TODO ##
        # Two parallel 1x1-conv heads (1024 -> 512 channels each): one for the
        # mean, one for the log-variance of the latent distribution.
        self.encoder_mu = nn.Conv2d(1024, 512, 1)
        self.encoder_logvar = nn.Conv2d(1024, 512, 1)
        
        # Maps a latent sample back to an image (Task 1 asks for a Bx512x1x1
        # input and a Bx3x256x256 output).
        self.decoder = ## TODO ##
        
    # Returns the pair (mu, logvar), both computed from the same encoder output.
    def encode(self, x):
        x = self.encoder(x)
        return self.encoder_mu(x), self.encoder_logvar(x)

    # Reparameterisation trick: z = mu + eps * std with eps ~ N(0, I) and
    # std = exp(0.5 * logvar), so the sample stays differentiable w.r.t. mu and
    # logvar. Noise is drawn on every call; there is no train/eval switch.
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    # Thin wrapper around the decoder so latent samples can be decoded directly.
    def decode(self, z):
        return self.decoder(z)

    # Returns (reconstruction, mu, logvar); mu and logvar are passed on because
    # the KLD term of the loss needs them.
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# instantiate once just to report the number of trainable parameters
print('parameter count:', parameter_count(VAE()))

In [None]:
# train vae with L1 reconstruction + KLD loss (Task 1, no perceptual loss yet)

# hyper-parameters (values prescribed by the task description)
num_epochs = 160    # passes over the training set
batch_size = 32     # images per optimisation step
init_lr = 2.5e-3    # learning rate handed to the optimizer

# data
class TCIAPancreasDataset(Dataset):
    # Map-style dataset over an already loaded image tensor: sample `idx` is
    # simply `imgs[idx]`, and the first dimension gives the number of samples.
    def __init__(self, imgs):
        self.imgs = imgs

    def __len__(self):
        num_samples = self.imgs.shape[0]
        return num_samples

    def __getitem__(self, idx):
        return self.imgs[idx]
    
# wrap the image tensor and batch it
data_set = TCIAPancreasDataset(imgs)
data_loader = DataLoader(
    data_set,
    batch_size,
    shuffle=True,     # new sample order every epoch
    drop_last=True,   # skip the final, incomplete batch
    num_workers=4,
    pin_memory=True,
)

# model (created on the CPU, then moved to the GPU)
vae = VAE().cuda()

# optimizer
optimizer = optim.Adam(vae.parameters(), lr=init_lr)

# for num_epochs
# Training-loop skeleton: the body is an exercise placeholder, so this cell is
# not runnable until it is filled in (Task 1, step 4: L1 reconstruction loss
# plus KLD, with visualisations every 15 epochs).
for epoch in range(num_epochs):
    ## TODO ##

### Task 2: Implement a perceptual loss and add this to your VAE training

A pre-trained fully-convolutional network which roughly follows the VGG architecture that has been trained for CT segmentation is provided for this task. 

1. Implement a forward hook function to extract perceptual features after the ReLUs in layers 2, 5 and 9. You can use the following snippet to store the output during a forward pass: 

```
def get_output():
    def hook(model, input, output):
        model.output = output
    return hook
    
```

2. Next you need to call the method register_forward_hook for layers 2, 5 and 9, passing the hook returned by get_output() as argument. After a forward pass of your ground-truth (training) image through the VGG model you can copy the stored output tensors to a list of tensors. Repeat the same procedure after passing the reconstructed image through the same VGG model. Use the L1Loss between those two feature tensors for all three layers as an additional perceptual loss and retrain the network from scratch. Visualise the outputs to see whether there are any improvements.

In [None]:
# train vae with perceptual loss (Task 2)

# hyper-parameters (same values as in the Task 1 run)
init_lr = 2.5e-3    # learning rate handed to the optimizer
batch_size = 32     # images per optimisation step
num_epochs = 160    # passes over the training set

# data
class TCIAPancreasDataset(Dataset):
    # Thin Dataset view of a pre-loaded tensor; indexing and length both refer
    # to the tensor's first dimension.
    def __init__(self, imgs):
        # keep a reference only, the data is neither copied nor transformed
        self.imgs = imgs

    def __getitem__(self, idx):
        return self.imgs[idx]

    def __len__(self):
        return self.imgs.shape[0]
    
# data pipeline: shuffled mini-batches, incomplete last batch dropped
data_set = TCIAPancreasDataset(imgs)
loader_options = dict(shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
data_loader = DataLoader(data_set, batch_size, **loader_options)

# model: fresh VAE instance, so training starts from scratch; lives on the GPU
vae = VAE()
vae = vae.cuda()

# optimizer
optimizer = optim.Adam(params=vae.parameters(), lr=init_lr)

# criterion
# Fixed, pre-trained VGG-style network that supplies the perceptual features.
vgg = VGG()

vgg_url = 'https://cloud.imi.uni-luebeck.de/s/xJ8edbQ3mmZK7Tg/download'
vgg_file = 'tcia_pancreas_vgg.pth'

if not os.path.isfile(vgg_file):
    print('downloading tcia pancreas vgg model from ' + vgg_url)
    # Pin the output name: without `out`, wget picks the file name from the
    # server response / URL, which need not match the name checked above and
    # loaded below.
    wget.download(vgg_url, out=vgg_file)

# NOTE(review): torch.load unpickles the file; only use it on downloads from a
# source you trust.
vgg_state_dict = torch.load(vgg_file)

# strict=False tolerates checkpoints whose keys do not match exactly, but it
# would also hide a checkpoint that matches nothing at all, so report any
# mismatch instead of ignoring it.
load_result = vgg.features.load_state_dict(vgg_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('note: vgg checkpoint key mismatch:', load_result)

vgg.cuda()
vgg.eval()
# The VGG is only a feature extractor: freeze its weights so backward passes
# do not compute and accumulate gradients for them. Gradients still flow
# through the network to its input.
vgg.requires_grad_(False)

def get_output():
    # Factory for a forward hook that stashes a module's most recent output on
    # the module itself, as attribute `output`.
    def store_output(module, inputs, result):
        # returns None, so the module's output is passed on unchanged
        module.output = result

    return store_output

# indices (inside vgg.features) of the ReLU layers whose activations are used
# as perceptual features
layer = [2, 5, 9]
for i in layer:
    hooked_module = vgg.features[i]
    hooked_module.register_forward_hook(get_output())

# mean absolute error, used for the pixel loss and for the feature loss
l1_loss = nn.L1Loss()
    
def perceptual_loss(img_batch, recon):
    # Sum of L1 distances between the VGG activations of `recon` and of
    # `img_batch`, taken at the hooked layers listed in `layer`.
    #
    # The forward hooks overwrite `vgg.features[i].output` on every pass, so
    # the target activations are collected before the second pass runs.

    # targets: no autograd graph is needed for the ground-truth images
    with torch.no_grad():
        vgg(img_batch)
    features_img_batch = [vgg.features[i].output for i in layer]

    # reconstruction: keep the graph so the loss can be back-propagated
    # to `recon`
    vgg(recon)
    features_recon = [vgg.features[i].output for i in layer]

    return sum(
        (l1_loss(feat_recon, feat_target)
         for feat_recon, feat_target in zip(features_recon, features_img_batch)),
        0.0,
    )

def kld_loss(mu, logvar):
    # KL divergence of N(mu, exp(logvar)) from the standard normal, written in
    # closed form and averaged (not summed) over all elements of the inputs.
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.mean()

def criterion(img_batch, recon, mu, logvar):
    # Total training loss: pixel-wise L1 + KLD + perceptual term, all weighted
    # equally.
    reconstruction_term = l1_loss(recon, img_batch)
    kld_term = kld_loss(mu, logvar)
    perceptual_term = perceptual_loss(img_batch, recon)
    return reconstruction_term + kld_term + perceptual_term

# statistics
# Loss histories, both created empty. Presumably the training loop below is
# meant to append to them; nothing in this notebook does so yet.
# NOTE(review): no separate test split is created in this cell; confirm what
# `test_losses` should hold.
train_losses = []
test_losses = []

# for num_epochs
# Training-loop skeleton: the body is an exercise placeholder, so this cell is
# not runnable until it is filled in.
for epoch in range(num_epochs):
    ## TODO ##

### Task 3: Address the checkerboard pattern in the synthetically generated images
1. As described in the blog article https://distill.pub/2016/deconv-checkerboard/, the checkerboard pattern can be reduced by replacing transpose convolutions with bilinear upsampling. Explore whether, and how, you can visually improve your synthetic image generation in this way.

In [None]:
# definition of vae model
# Task 3 variant of the VAE skeleton. Running this cell redefines the class
# name `VAE`, replacing the Task 1 definition for all later cells. `encoder`
# and `decoder` are exercise placeholders (## TODO ##), so the cell is not
# valid Python until they are filled in.
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        
        # Must produce 1024 channels: both heads below are 1x1 convolutions
        # with 1024 input channels.
        self.encoder = ## TODO ##
        # mean head and log-variance head (1024 -> 512 channels each)
        self.encoder_mu = nn.Conv2d(1024, 512, 1)
        self.encoder_logvar = nn.Conv2d(1024, 512, 1)
        
        # Task 3 suggests building this one with bilinear upsampling instead of
        # transpose convolutions.
        self.decoder = ## TODO ##
        
    # Encodes x once and returns (mu, logvar) from the two heads.
    def encode(self, x):
        x = self.encoder(x)
        return self.encoder_mu(x), self.encoder_logvar(x)

    # Draws z = mu + eps * exp(0.5 * logvar) with eps sampled from a standard
    # normal of the same shape; a fresh eps is drawn on every call.
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    # Decodes a latent tensor; also the entry point for generating synthetic
    # images from random latents.
    def decode(self, z):
        return self.decoder(z)

    # Full pass: encode, sample, decode. Returns (reconstruction, mu, logvar).
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# report the number of trainable parameters of a freshly built model
print('parameter count:', parameter_count(VAE()))

In [None]:
# train vae with perceptual loss and address the checkerboard pattern (Task 3)

# hyper-parameters (unchanged from the previous runs, to keep them comparable)
batch_size = 32     # images per optimisation step
num_epochs = 160    # passes over the training set
init_lr = 2.5e-3    # learning rate handed to the optimizer

# data
class TCIAPancreasDataset(Dataset):
    # Serves the slices of an in-memory tensor one by one: item `idx` is
    # `imgs[idx]`, and there are `imgs.shape[0]` items.
    def __init__(self, imgs):
        self.imgs = imgs

    def __len__(self):
        first_dim = self.imgs.shape[0]
        return first_dim

    def __getitem__(self, idx):
        sample = self.imgs[idx]
        return sample
    
# data: shuffled batches from 4 worker processes, pinned host memory,
# incomplete final batch dropped
data_set = TCIAPancreasDataset(imgs)
data_loader = DataLoader(
    dataset=data_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

# model: built from whichever VAE class was defined last, then moved to the GPU
vae = VAE()
vae.cuda()

# optimizer
optimizer = optim.Adam(vae.parameters(), lr=init_lr)

# for num_epochs
# Training-loop skeleton for the Task 3 model: the body is an exercise
# placeholder, so this cell is not runnable until it is filled in.
for epoch in range(num_epochs):
    ## TODO ##