# 11785 Fall 2024 Project: GAN-based Cross-Modality Medical Image Synthesis for Prostate Cancer 

# Group XV (Fifteen)

# Install Libraries

In [6]:
# install libraries
!pip install torch torchmetrics --q
!pip install wandb --quiet
!pip install pytorch-msssim --q
!pip install adabelief-pytorch --q

!pip install torchsummary -q


# Import Libraries

In [None]:
# import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path

import random
from sklearn.model_selection import train_test_split
from pytorch_msssim import ssim
import random

from adabelief_pytorch import AdaBelief

import random
from tqdm import tqdm

import torchvision.models as models


# Check and Set Device

In [None]:
!nvidia-smi
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

# Unzip Dataset


In [None]:
# unzip dataset
!unzip -q /content/ProstateCancerDataset.zip
!ls

# Config

In [None]:
# configs
config ={
    "run_name" : 'bgan-sa-perc-gradpen-v11',
    "dataset_dir" : 'ProstateCancerDataset',
    "output_dir" : 'output_bgan',

    "epochs" : 300,
    "img_size": 256,

    "lr": 0.0001,

    #if using SGD (TTUR) vary the learning rates according to this:
    "lr_g": 0.0001,         # Generator (slowest)
    "lr_e": 0.0002,         # Encoder (between G and D rates)
    "lr_d": 0.0004,         # Discriminator (fastest)

    "batch_size" : 1,       # typically 1
    "num_workers" : 4,
    "random_seed": 0,

    "latent_dim": 256,      # dimension of the latent vector, adjust based on desired diversity and complexity

    "lambda_L1": 10,        # weight for L1 loss, adjust based on image sharpness
    "lambda_KL": 0.001,     # Weight for KL divergence, adjust based on latent space regularization
    "grad_penalty_weight": 10.0,
    "perceptual_loss_weight": 1.0,

    "max_grad_norm": 0.1,

    "transforms": "default", # "default", "augmentation"

    }

# Data Loading

In [None]:

class MRIDataset(Dataset):
    def __init__(self, dataset_dir, transform=None, img_size=256):
        self.root_dir = dataset_dir
        self.transform = transform
        self.img_size = img_size
        # create resize transform
        self.resize = transforms.Compose([
            transforms.Resize((img_size, img_size), antialias=True),
        ])

        # collect all slices
        self.slices = []

        # Scan directories
        adc_dir = os.path.join(dataset_dir, "ADC")
        for patient_dir in sorted(os.listdir(adc_dir)):
            patient_adc_dir = os.path.join(adc_dir, patient_dir)
            patient_t2w_dir = os.path.join(dataset_dir, "T2w_Type",
                                         patient_dir.replace('_adc', '_t2w'))

            if os.path.exists(patient_t2w_dir):
                adc_files = sorted([f for f in os.listdir(patient_adc_dir)
                                  if os.path.isfile(os.path.join(patient_adc_dir, f))])
                t2w_files = sorted([f for f in os.listdir(patient_t2w_dir)
                                  if os.path.isfile(os.path.join(patient_t2w_dir, f))])

                min_slices = min(len(adc_files), len(t2w_files))

                for i in range(min_slices):
                    self.slices.append({
                        'adc_path': os.path.join(patient_adc_dir, adc_files[i]),
                        't2w_path': os.path.join(patient_t2w_dir, t2w_files[i])
                    })

    def __getitem__(self, idx):
        slice_paths = self.slices[idx]

        # load images
        adc_image = Image.open(slice_paths['adc_path']).convert('L')
        t2w_image = Image.open(slice_paths['t2w_path']).convert('L')

        # resize images
        adc_image = self.resize(adc_image)
        t2w_image = self.resize(t2w_image)

        # apply other transforms
        if self.transform:
            adc_image = self.transform(adc_image)
            t2w_image = self.transform(t2w_image)

        return {
            'A': adc_image,  # source domain (ADC)
            'B': t2w_image   # target domain (T2w)
        }


    def __len__(self):
        return len(self.slices)



In [None]:
def create_data_loaders(config):
    """
    Create train, validation, and test data loaders
    """
    # define transformations
    if config["transforms"] == "default":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
    elif config["transforms"] == "augmentation":

        transform = transforms.Compose([
            transforms.RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])


    # create full dataset
    full_dataset = MRIDataset(dataset_dir=config['dataset_dir'], transform=transform)

    # calculate lengths for splits
    total_size = len(full_dataset)
    train_size = int(0.78 * total_size)
    val_size = int(0.12 * total_size)
    test_size = total_size - train_size - val_size

    # create splits
    train_indices, temp_indices = train_test_split(
        range(total_size), test_size=(val_size + test_size),
        random_state=config['random_seed']
    )
    val_indices, test_indices = train_test_split(
        temp_indices, test_size=test_size,
        random_state=config['random_seed']
    )

    # create subset datasets based on split ratios
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)
    test_dataset = Subset(full_dataset, test_indices)

    # create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    print(f"Total slices: {total_size}")
    print(f"Training slices: {len(train_dataset)}")
    print(f"Validation slices: {len(val_dataset)}")
    print(f"Testing slices: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

In [None]:
# create data loaders
train_loader, val_loader, test_loader = create_data_loaders(config)

# Custom Loss Functions

In [None]:

# define a feature extractor 
class FeatureExtractor(nn.Module):
    def __init__(self, layers):
        super(FeatureExtractor, self).__init__()
        # use a pre-trained VGG model
        vgg = models.vgg19(pretrained=True).features
        self.selected_layers = layers
        self.features = nn.ModuleList([vgg[i] for i in layers]).eval()

    def forward(self, x):
        outputs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.selected_layers:
                outputs.append(x)
        return outputs

# convert single-channel images to 3-channel
def to_three_channels(img):
    return img.repeat(1, 3, 1, 1)  # Repeat along the channel dimension

def perceptual_loss(real, fake, feature_extractor):
    # convert grayscale images to 3-channel
    real = to_three_channels(real)
    fake = to_three_channels(fake)

    # move images to the same device as the feature extractor
    real = real.to(next(feature_extractor.parameters()).device)
    fake = fake.to(next(feature_extractor.parameters()).device)

    # compute feature maps
    real_features = feature_extractor(real)
    fake_features = feature_extractor(fake)

    # compute L1 loss between feature maps
    criterion = nn.L1Loss()
    loss = 0
    for r, f in zip(real_features, fake_features):
        loss += criterion(r, f)
    return loss


feature_extractor = FeatureExtractor(layers=[0, 5, 10, 19]).to(device)  
feature_extractor.eval()  


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 212MB/s]  


FeatureExtractor(
  (features): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
# gradient penalty
def gradient_penalty(discriminator, real_samples, fake_samples):
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_samples.device)

    # split concatenated samples back into source and target
    # each sample has 2 channels (concatenated source and target)
    source = real_samples[:, :1]  # first channel
    interpolated_target = alpha * real_samples[:, 1:] + (1 - alpha) * fake_samples[:, 1:]
    interpolated_target.requires_grad_(True)

    # forward pass
    d_interpolates = discriminator(source, interpolated_target)

    # compute gradients
    fake = torch.ones_like(d_interpolates).to(real_samples.device)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolated_target,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # calculate gradient penalty
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()

    return gradient_penalty

# Model Definition

## Encoder

In [None]:

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # adjust projection dimensions to match input channels
        self.query = nn.Conv2d(in_channels, in_channels, 1)
        self.key = nn.Conv2d(in_channels, in_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # project queries, keys and values
        proj_query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # B x HW x C
        proj_key = self.key(x).view(batch_size, -1, height * width)  # B x C x HW
        proj_value = self.value(x).view(batch_size, -1, height * width)  # B x C x HW

        # attention map
        energy = torch.bmm(proj_query, proj_key)  # B x HW x HW
        attention = F.softmax(energy, dim=-1)

        # apply attention to values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x HW
        out = out.view(batch_size, channels, height, width)

        return self.gamma * out + x

class Encoder(nn.Module):
    def __init__(self, latent_dim=8):
        super().__init__()

        # initial layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        )

        # downsampling layers with residual connections
        self.down_blocks = nn.ModuleList([
            self._make_down_block(64, 128),    # 128x128 -> 64x64
            self._make_down_block(128, 256),   # 64x64 -> 32x32
            self._make_down_block(256, 512),   # 32x32 -> 16x16
            self._make_down_block(512, 512),   # 16x16 -> 8x8
            self._make_down_block(512, 512)    # 8x8 -> 4x4
        ])

        # attention modules - match channel dimensions
        self.attention_blocks = nn.ModuleList([
            SelfAttention(512),  # after 3rd down block
            SelfAttention(512),  # after 4th down block
            SelfAttention(512)   # after 5th down block
        ])

        # output layers for mu and logvar with dropout
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Sequential(
            nn.Linear(512 * 4 * 4, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, latent_dim)
        )
        self.fc_logvar = nn.Sequential(
            nn.Linear(512 * 4 * 4, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, latent_dim)
        )

    def _make_down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True)
        )

    def forward(self, x):
        # initial convolution
        x = self.conv1(x)

        # down blocks with attention
        features = []
        for i, block in enumerate(self.down_blocks):
            x = block(x)
            features.append(x)
            # apply attention only after reaching 512 channels (i >= 2)
            if i >= 2:
                x = self.attention_blocks[i-2](x)

        # flatten and get latent parameters
        x = self.dropout(self.flatten(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

## Generator

In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim=8):
        super().__init__()

        # initial processing of source image
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),  # B x 64 x 128 x 128
            nn.LeakyReLU(0.3, True),
            nn.Dropout(0.25)
        )

        # downsample blocks
        self.down_blocks = nn.ModuleList([
            self._make_down_block(64, 128),    # B x 128 x 64 x 64
            self._make_down_block(128, 256),   # B x 256 x 32 x 32
            self._make_down_block(256, 512),   # B x 512 x 16 x 16
            self._make_down_block(512, 512),   # B x 512 x 8 x 8
            self._make_down_block(512, 512),   # B x 512 x 4 x 4
        ])

        # inject latent code
        self.latent_projection = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4 * 4),
            nn.ReLU(True),
        )

        # upsample blocks
        self.up_blocks = nn.ModuleList([
            self._make_up_block(1024, 512),    # B x 512 x 8 x 8
            self._make_up_block(1024, 512),    # B x 512 x 16 x 16
            self._make_up_block(1024, 256),    # B x 256 x 32 x 32
            self._make_up_block(512, 128),     # B x 128 x 64 x 64
            self._make_up_block(256, 64),      # B x 64 x 128 x 128
        ])

        # final output layer
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1),  # B x 1 x 256 x 256
            nn.Tanh()
        )

    def _make_down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.3, True),
        )

    def _make_up_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x, z):
        # process input image through downsample blocks
        features = [self.conv1(x)]
        for block in self.down_blocks:
            features.append(block(features[-1]))

        # process latent code
        z = self.latent_projection(z)
        z = z.view(z.size(0), -1, 4, 4)

        # combine with last feature map
        x = torch.cat([features[-1], z], dim=1)

        # upsample with skip connections
        for i, block in enumerate(self.up_blocks):
            x = block(x)
            if i < len(features) - 1:  # Skip connection
                x = torch.cat([x, features[-(i+2)]], dim=1)

        # final output
        return self.final(x)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        # process paired images (source + target)
        self.model = nn.Sequential(
            # input: B x 2 x 256 x 256 (concatenated source and target images)
            nn.Conv2d(2, 64, 4, stride=2, padding=1),  # B x 64 x 128 x 128
            nn.LeakyReLU(0.3, True),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # B x 128 x 64 x 64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.3, True),


            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # B x 256 x 32 x 32
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.3, True),

            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # B x 512 x 16 x 16
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.3, True),

            nn.Conv2d(512, 1, 4, stride=1, padding=1),  # B x 1 x 15 x 15
        )

    def forward(self, source, target):
        # concatenate source and target images along channel dimension
        x = torch.cat([source, target], dim=1)
        return self.model(x)



## Bicycle GAN

In [None]:
# class that manages all models (E, G, D) and their interactions
class BicycleGAN:
    def __init__(self):
        self.latent_dim = config['latent_dim']
        self.input_size = config['img_size']
        self.lambda_L1 = config['lambda_L1']
        self.lambda_KL = config['lambda_KL']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # initialize networks with correct input size
        self.E = Encoder(latent_dim=config['latent_dim']).to(self.device)
        self.G = Generator(latent_dim=config['latent_dim']).to(self.device)
        self.D = Discriminator().to(self.device)

        # initialize weights
        self.E.apply(self.weights_init)
        self.G.apply(self.weights_init)
        self.D.apply(self.weights_init)

        # initialize optimizers
        self.opt_G = AdaBelief(
            self.G.parameters(),
            lr=config['lr_g'],
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=False,
            weight_decay=0.01
        )

        self.opt_E = AdaBelief(
            self.E.parameters(),
            lr=config['lr_e'],
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=False,
            weight_decay=0.01
        )

        self.opt_D = AdaBelief(
            self.D.parameters(),
            lr=config['lr_d'],
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=False,
            weight_decay=0.01
        )


        # initialize schedulers
        self.scheduler_G = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt_G,
            mode='min',
            factor=0.8,
            patience=4,
            verbose=True,
            min_lr=1e-7
        )

        self.scheduler_E = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt_E,
            mode='min',
            factor=0.4,
            patience=7,
            verbose=True,
            min_lr=1e-7
        )

        self.scheduler_D = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt_D,
            mode='min',
            factor=0.8,
            patience=5,
            verbose=True,
            min_lr=1e-7
        )

        # define loss functions
        self.criterion_GAN = nn.MSELoss()
        self.criterion_L1 = nn.L1Loss()
        self.criterion_KL = lambda mu, logvar: -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


    @staticmethod
    def weights_init(m):
        """Initialize network weights."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def train_step(self, source, target, lambda_L1=10, lambda_KL=0.01):
        """Single training step"""
        # add gradient clipping value
        max_grad_norm = config['max_grad_norm']

        real_label = torch.ones((source.size(0), 1, 15, 15)).to(self.device)
        fake_label = torch.zeros((source.size(0), 1, 15, 15)).to(self.device)

        # Fforward cycle (cVAE-GAN)
        self.opt_E.zero_grad()
        self.opt_G.zero_grad()

        # get latent encoding of target image
        mu, logvar = self.E(target)
        z = self.E.reparameterize(mu, logvar)

        # generate fake image
        fake_B = self.G(source, z)

        # discriminator loss for generated image
        pred_fake = self.D(source, fake_B)
        loss_G_GAN = self.criterion_GAN(pred_fake, real_label) * 0.5  # Scale down GAN loss

        # l1 loss between generated and target
        loss_G_L1 = self.criterion_L1(fake_B, target) * lambda_L1

        # add perceptual loss
        # loss_G_perceptual = config['perceptual_loss_weight'] * perceptual_loss(target, fake_B)
        loss_G_perceptual = config['perceptual_loss_weight'] * perceptual_loss(target, fake_B, feature_extractor)


        # KL loss with scaling
        loss_KL = self.criterion_KL(mu, logvar) * lambda_KL

        # backward cycle (cLR-GAN)
        z_random = torch.randn(source.size(0), self.latent_dim).to(self.device)
        fake_B_random = self.G(source, z_random)
        mu2, logvar2 = self.E(fake_B_random)

        # latent regression loss
        loss_z_L1 = self.criterion_L1(mu2, z_random) * lambda_L1

        # total generator and encoder loss
        loss_G = loss_G_GAN + loss_G_L1 + loss_KL + loss_z_L1 + loss_G_perceptual

        # check for NaN loss value
        if not torch.isnan(loss_G):
            loss_G.backward()
            # clip gradients
            torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.E.parameters(), max_grad_norm)
            self.opt_E.step()
            self.opt_G.step()

        # discriminator update
        self.opt_D.zero_grad()

        # real loss
        pred_real = self.D(source, target)
        loss_D_real = self.criterion_GAN(pred_real, real_label)

        # fake loss (cVAE-GAN)
        pred_fake1 = self.D(source, fake_B.detach())
        loss_D_fake1 = self.criterion_GAN(pred_fake1, fake_label)

        # fake loss (cLR-GAN)
        pred_fake2 = self.D(source, fake_B_random.detach())
        loss_D_fake2 = self.criterion_GAN(pred_fake2, fake_label)

        # add gradient penalty
        gp = gradient_penalty(
            self.D,
            torch.cat([source, target], dim=1),
            torch.cat([source, fake_B.detach()], dim=1)
        )

        # total discriminator loss
        loss_D = (loss_D_real + loss_D_fake1 + loss_D_fake2) * 0.5 + config['grad_penalty_weight'] * gp

        # check for NaN
        if not torch.isnan(loss_D):
            loss_D.backward()
            # clip gradients
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_grad_norm)
            self.opt_D.step()

        return {
            'G_loss': loss_G.item() if not torch.isnan(loss_G) else 0,
            'D_loss': loss_D.item() if not torch.isnan(loss_D) else 0,
            'KL': loss_KL.item() if not torch.isnan(loss_KL) else 0,
            'L1': loss_G_L1.item() if not torch.isnan(loss_G_L1) else 0,
            'z_L1': loss_z_L1.item() if not torch.isnan(loss_z_L1) else 0,
            'Grad_Penalty': gp.item() if not torch.isnan(gp) else 0,
            'Perceptual_Loss': loss_G_perceptual.item() if not torch.isnan(loss_G_perceptual) else 0
        }


    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=False for networks to avoid unnecessary computations"""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


    def encode(self, input_image):
        mu, logvar = self.E(input_image)
        z = self.E.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate(self, input_image, z):
        return self.G(input_image, z)


    def visualize_results(self, source, target, epoch, num_samples=4):
        """
        Visualize multiple generation results from the same input image
        Args:
            source: Source image (ADC)
            target: Ground truth target image (T2w)
            epoch: Current epoch number
            num_samples: Number of different results to generate
        """
        self.E.eval()
        self.G.eval()
        with torch.no_grad():
            # generate multiple outputs using different random latent codes
            outputs = []
            for _ in range(num_samples):
                z = torch.randn(source.size(0), self.latent_dim).to(self.device)
                fake = self.G(source, z)
                outputs.append(fake)

            # create visualization grid
            # first row: [source, target]
            # following rows: different generated outputs
            vis_images = [source[0], target[0]]  # First row
            for out in outputs:
                vis_images.append(out[0])

            # convert from [-1, 1] to [0, 1] range for visualization
            vis_images = [(img + 1) / 2 for img in vis_images]

            # create grid
            image_grid = vutils.make_grid(vis_images, nrow=2, padding=2, normalize=False)

            # plot
            plt.figure(figsize=(10, 10))
            plt.axis('off')
            plt.imshow(image_grid.cpu().numpy()[0], cmap='gray')
            plt.title(f'Epoch {epoch}\nTop: [Source | Target]\nBottom: Generated Samples')
            plt.show()
            plt.close()



    def visualize_interpolation(self, source, target, epoch, num_steps=5):
        """
        Visualize interpolation between two random latent codes
        """
        self.E.eval()
        self.G.eval()
        with torch.no_grad():
            # get two random latent codes
            z1 = torch.randn(1, self.latent_dim).to(self.device)
            z2 = torch.randn(1, self.latent_dim).to(self.device)

            # interpolate between latent codes
            alphas = torch.linspace(0, 1, num_steps)
            interpolated = []

            for alpha in alphas:
                z = alpha * z1 + (1 - alpha) * z2
                fake = self.G(source, z)
                interpolated.append(fake[0])

            # create visualization grid
            vis_images = [source[0], target[0]] + interpolated
            vis_images = [(img + 1) / 2 for img in vis_images]

            # create grid
            image_grid = vutils.make_grid(vis_images, nrow=len(vis_images), padding=2, normalize=False)

            # plot
            plt.figure(figsize=(15, 5))
            plt.axis('off')
            plt.imshow(image_grid.cpu().numpy()[0], cmap='gray')
            plt.title(f'Epoch {epoch}\nLeft: [Source | Target | Interpolated Samples]')
            plt.show()
            plt.close()

    @staticmethod
    def calculate_psnr(fake, real):
        """Calculate PSNR between fake and real images"""
        mse = F.mse_loss(fake, real)
        psnr = 20 * torch.log10(2.0 / torch.sqrt(mse))
        return psnr.item()

    @staticmethod
    def calculate_ssim(fake, real):
        """Calculate SSIM between fake and real images"""
        # implementation of SSIM calculation
        return ssim(fake, real, data_range=2.0).item()  # 2.0 for [-1,1] range


    def visualize_batch(self, source, target, generated, title, save_path=None):
        """
        Visualize a batch of images: source, target, and generated
        Args:
            source: source images
            target: target images
            generated: list of generated images (for multiple samples)
            title: title for the plot
            save_path: if provided, save the plot to this path
        """
        # convert tensors to PIL images
        def tensor_to_numpy(tensor):
            #  convert to [0, 1] (for tensor in [-1, 1] range)
            return ((tensor.cpu().detach() + 1) / 2.0).numpy()

        # create figure
        num_samples = len(generated)
        fig, axes = plt.subplots(2, num_samples + 2, figsize=(3*(num_samples + 2), 6))
        plt.suptitle(title)

        # plot source and target in first row
        axes[0, 0].imshow(tensor_to_numpy(source[0])[0], cmap='gray')
        axes[0, 0].set_title('Source (ADC)')
        axes[0, 0].axis('off')

        axes[0, 1].imshow(tensor_to_numpy(target[0])[0], cmap='gray')
        axes[0, 1].set_title('Target (T2)')
        axes[0, 1].axis('off')

        # plot generated samples
        for i, gen in enumerate(generated):
            axes[0, i+2].imshow(tensor_to_numpy(gen[0])[0], cmap='gray')
            axes[0, i+2].set_title(f'Generated {i+1}')
            axes[0, i+2].axis('off')

        # plot differences in second row
        diff_target = np.abs(tensor_to_numpy(target[0])[0] - tensor_to_numpy(source[0])[0])
        axes[1, 0].imshow(diff_target, cmap='hot')
        axes[1, 0].set_title('Diff: Target-Source')
        axes[1, 0].axis('off')

        axes[1, 1].axis('off')  # empty plot for alignment

        for i, gen in enumerate(generated):
            diff_gen = np.abs(tensor_to_numpy(gen[0])[0] - tensor_to_numpy(target[0])[0])
            axes[1, i+2].imshow(diff_gen, cmap='hot')
            axes[1, i+2].set_title(f'Diff: Gen{i+1}-Target')
            axes[1, i+2].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
        plt.show()
        plt.close()



    def validate(self, val_loader, epoch, save_dir, vis=False):
        """Run validation with visualization"""
        self.E.eval()
        self.G.eval()
        self.D.eval()

        val_losses = []

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                source = batch['A'].to(self.device)
                target = batch['B'].to(self.device)

                # generate multiple samples
                generated_samples = []
                for _ in range(3):  # generate 3 different samples
                    z = torch.randn(source.size(0), self.latent_dim).to(self.device)
                    fake_B = self.G(source, z)
                    generated_samples.append(fake_B)

                # calculate validation losses
                loss_G_L1 = self.criterion_L1(generated_samples[0], target) * self.lambda_L1

                val_losses.append({
                    'val_L1': loss_G_L1.item(),
                })

                # visualize first few batches
                if i == 5 and vis:  # show validation samples
                    save_path = f"{save_dir}/validation_epoch{epoch}_batch{i}.png"
                    self.visualize_batch(
                        source, target, generated_samples,
                        f"Validation - Epoch {epoch}, Batch {i}",
                        save_path
                    )

        avg_losses = {k: sum(d[k] for d in val_losses) / len(val_losses)
                     for k in val_losses[0].keys()}

        return avg_losses



    def test(self, test_loader, num_samples=5, save_dir=None, vis = False):
        """Run testing with visualization"""
        self.E.eval()
        self.G.eval()

        test_metrics = []

        with torch.no_grad():
            for i, batch in enumerate(test_loader):
                source = batch['A'].to(self.device)
                target = batch['B'].to(self.device)

                # generate multiple samples
                generated_samples = []
                for _ in range(num_samples):
                    z = torch.randn(source.size(0), self.latent_dim).to(self.device)
                    fake_B = self.G(source, z)
                    generated_samples.append(fake_B)

                # calculate metrics
                metrics = {
                    'PSNR': BicycleGAN.calculate_psnr(generated_samples[0], target),
                    'SSIM': BicycleGAN.calculate_ssim(generated_samples[0], target),
                    'L1': self.criterion_L1(generated_samples[0], target).item()
                }
                test_metrics.append(metrics)

                # visualize
                if save_dir and vis and (i == random.choice([s for s in range(len(test_loader))])) :  # Show 2 test samples

                    random_index = random.randint(0, num_samples - 1)  # random index between 0 and num_samples - 1
                    random_sample = generated_samples[random_index]

                    save_path = f"{save_dir}/test_sample{i}_random.png"
                    self.visualize_batch(
                        source, target, [random_sample],  # pass the random sample
                        f"Test Sample {i} (Random Selection)",
                        save_path
                    )

        avg_metrics = {k: sum(d[k] for d in test_metrics) / len(test_metrics)
                      for k in test_metrics[0].keys()}

        return avg_metrics

    def get_current_lrs(self):
        return {
            'lr_G': self.opt_G.param_groups[0]['lr'],
            'lr_E': self.opt_E.param_groups[0]['lr'],
            'lr_D': self.opt_D.param_groups[0]['lr']
        }



In [None]:
def print_model_summaries(model):
    """
    Print detailed summaries for all components of the BicycleGAN model in torchsummary style
    """
    def calculate_conv_output_shape(input_size, kernel_size, stride, padding):
        return ((input_size + 2 * padding - kernel_size) // stride) + 1

    def calculate_convt_output_shape(input_size, kernel_size, stride, padding):
        return (input_size - 1) * stride - 2 * padding + kernel_size

    def print_layer_summary(layer_name, input_shape, output_shape, params):
        print(f"| {layer_name:<20} | {str(input_shape):<15} | {str(output_shape):<15} | {params:>8} |")

    def print_network_header(name):
        print(f"\n{name} Network:")
        print("=" * 71)
        print(f"| {'Layer':<20} | {'Input Shape':<15} | {'Output Shape':<15} | {'Params':>8} |")
        print("=" * 71)

    print("\n===================== BicycleGAN Model Summary =========================")

    # encoder Summary
    print_network_header("Encoder")
    curr_size = 256
    in_channels = 1

    for name, layer in model.E.named_modules():
        if isinstance(layer, nn.Conv2d):
            output_size = calculate_conv_output_shape(curr_size, layer.kernel_size[0],
                                                    layer.stride[0], layer.padding[0])
            params = layer.in_channels * layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1]
            if layer.bias is not None:
                params += layer.out_channels

            print_layer_summary(
                f"Conv2d-{layer.out_channels}",
                (in_channels, curr_size, curr_size),
                (layer.out_channels, output_size, output_size),
                params
            )
            curr_size = output_size
            in_channels = layer.out_channels
        elif isinstance(layer, nn.BatchNorm2d):
            params = 2 * layer.num_features 
            print_layer_summary(
                "BatchNorm2d",
                (in_channels, curr_size, curr_size),
                (in_channels, curr_size, curr_size),
                params
            )

    # linear layers for mu and logvar
    fc_input_size = 512 * 4 * 4
    print_layer_summary(
        "Linear-mu",
        (fc_input_size,),
        (model.latent_dim,),
        fc_input_size * model.latent_dim + model.latent_dim
    )
    print_layer_summary(
        "Linear-logvar",
        (fc_input_size,),
        (model.latent_dim,),
        fc_input_size * model.latent_dim + model.latent_dim
    )

    # generator summary
    print_network_header("Generator")
    curr_size = 256
    in_channels = 1

    # generator's encoder path
    for block in model.G.down_blocks:
        for layer in block:
            if isinstance(layer, nn.Conv2d):
                output_size = calculate_conv_output_shape(curr_size, layer.kernel_size[0],
                                                        layer.stride[0], layer.padding[0])
                params = layer.in_channels * layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1]
                if layer.bias is not None:
                    params += layer.out_channels

                print_layer_summary(
                    f"Conv2d-{layer.out_channels}",
                    (in_channels, curr_size, curr_size),
                    (layer.out_channels, output_size, output_size),
                    params
                )
                curr_size = output_size
                in_channels = layer.out_channels

    # generator's decoder path
    for block in model.G.up_blocks:
        for layer in block:
            if isinstance(layer, nn.ConvTranspose2d):
                output_size = calculate_convt_output_shape(curr_size, layer.kernel_size[0],
                                                         layer.stride[0], layer.padding[0])
                params = layer.in_channels * layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1]
                if layer.bias is not None:
                    params += layer.out_channels

                print_layer_summary(
                    f"ConvTranspose2d-{layer.out_channels}",
                    (in_channels, curr_size, curr_size),
                    (layer.out_channels, output_size, output_size),
                    params
                )
                curr_size = output_size
                in_channels = layer.out_channels

    # discriminator summary
    print_network_header("Discriminator")
    curr_size = 256
    in_channels = 2  # concatenated source and target

    for name, layer in model.D.named_modules():
        if isinstance(layer, nn.Conv2d):
            output_size = calculate_conv_output_shape(curr_size, layer.kernel_size[0],
                                                    layer.stride[0], layer.padding[0])
            params = layer.in_channels * layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1]
            if layer.bias is not None:
                params += layer.out_channels

            print_layer_summary(
                f"Conv2d-{layer.out_channels}",
                (in_channels, curr_size, curr_size),
                (layer.out_channels, output_size, output_size),
                params
            )
            curr_size = output_size
            in_channels = layer.out_channels

    # Total parameters
    total_params = (sum(p.numel() for p in model.E.parameters()) +
                   sum(p.numel() for p in model.G.parameters()) +
                   sum(p.numel() for p in model.D.parameters()))
    print("\n" + "=" * 71)
    print(f"Total Parameters: {total_params:,}")


In [None]:
# model summary
model = BicycleGAN()
print_model_summaries(model)

# Saving paths

In [22]:
save_path = config["output_dir"]
os.makedirs(save_path, exist_ok=True)
save_checkp_path = f'{config["output_dir"]}/checkpoints'
os.makedirs(save_checkp_path, exist_ok=True)

# Wandb Config

In [None]:
import wandb
wandb.login(key="")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [24]:
run = wandb.init(
    name = config["run_name"],
    reinit = True,
    # id = ###
    # resume = "must"
    project = "project-ablations",
    config = config
)

[34m[1mwandb[0m: Currently logged in as: [33mfrancistate[0m ([33mGroupXV[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Training

In [None]:
def train(model, train_loader, val_loader, test_loader, num_epochs, save_path):
    """Complete training process with visualizations"""
    losses_history = []
    val_history = []
    test_metrics = None
    best_val_loss = float('inf')
    best_psnr = 2.6692
    patience = 15 
    #epochs until early stop if no improvement in
    patience_counter = 0

    # Create directories for saving visualizations
    vis_dir = os.path.join(save_path, 'visualizations')
    val_dir = os.path.join(vis_dir, 'validation')
    test_dir = os.path.join(vis_dir, 'test')
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        epoch_losses = []
        model.G.train()
        model.E.train()
        model.D.train()

        all_batches = []

        with tqdm(total=len(train_loader), desc=f"Epoch {epoch}/{num_epochs}", unit="batch") as pbar:
            for batch in train_loader:
                source = batch['A'].to(model.device)
                target = batch['B'].to(model.device)

                # store batch for potential random visualization
                all_batches.append((source, target))

                # training step
                losses = model.train_step(source, target)

                # skip this batch if NaN losses were computed
                if losses['G_loss'] == 0 and losses['D_loss'] == 0:
                    print(f"Skipping batch  due to NaN losses")
                    pbar.update(1)
                    continue

                epoch_losses.append(losses)
                pbar.set_postfix(**losses)
                pbar.update(1)


        # visualize a random batch after completing the epoch
        if epoch % 5 == 0 and all_batches:
            random_batch = random.choice(all_batches)  # Select a random batch from the epoch
            source, target = random_batch
            model.visualize_results(source, target, epoch)

        # average epoch losses
        if epoch_losses:  # Only if we have valid losses
            avg_losses = {k: sum(d[k] for d in epoch_losses) / len(epoch_losses)
                          for k in epoch_losses[0].keys()}
            wandb.log(avg_losses)
            losses_history.append(avg_losses)

        # print epoch summary
        print(f"Epoch [{epoch}/{num_epochs}] - Avg G_loss: {avg_losses['G_loss']:.4f}, "
              f"Avg D_loss: {avg_losses['D_loss']:.4f}, LRs: {model.get_current_lrs()}")

        # run validation
        val_losses = model.validate(val_loader, epoch, val_dir, vis = False) #vis=(epoch % 10 == 0))
        val_history.append(val_losses)
        wandb.log(val_losses)

        # # Periodic test evaluation
        # if epoch % 5 == 0:
        #     test_metrics = model.test(test_loader, num_samples=1, save_dir=test_dir, vis=True) # change t
        # else:
        #     test_metrics = model.test(test_loader, num_samples=1, save_dir=test_dir)

        test_metrics = model.test(test_loader, num_samples=1, save_dir=test_dir)
        # print test metrics summary
        print(f"Test Metrics - Epoch {epoch}:")
        for k, v in test_metrics.items():
            print(f"{k}: {v:.4f}")
        wandb.log(test_metrics)

        # update learning rate schedulers using test metrics (e.g., PSNR)
        model.scheduler_G.step(-test_metrics['PSNR'])
        model.scheduler_E.step(-test_metrics['PSNR'])
        model.scheduler_D.step(-test_metrics['PSNR'])

        # save the best model based on PSNR
        if test_metrics['PSNR'] > best_psnr:
            best_psnr = test_metrics['PSNR']
            checkpoint = {
                'epoch': epoch,
                'G_state_dict': model.G.state_dict(),
                'E_state_dict': model.E.state_dict(),
                'D_state_dict': model.D.state_dict(),
                'opt_G_state_dict': model.opt_G.state_dict(),
                'opt_E_state_dict': model.opt_E.state_dict(),
                'opt_D_state_dict': model.opt_D.state_dict(),
                'losses': losses_history,
                'val_losses': val_history,
                'test_metrics': test_metrics
            }
            torch.save(checkpoint, f"{save_path}/bicycle_gan_best_psnr.pt")
            wandb.save(f"{save_path}/bicycle_gan_best_psnr.pt")
            print("Saved best PSNR model")

            patience_counter = 0
            

        # # Early stopping check
        # if val_losses['val_L1'] < best_val_loss:
        #     best_val_loss = val_losses['val_L1']
        #     patience_counter = 0

            # # Save best validation model
            # checkpoint = {
            #     'epoch': epoch,
            #     'G_state_dict': model.G.state_dict(),
            #     'E_state_dict': model.E.state_dict(),
            #     'D_state_dict': model.D.state_dict(),
            #     'opt_G_state_dict': model.opt_G.state_dict(),
            #     'opt_E_state_dict': model.opt_E.state_dict(),
            #     'opt_D_state_dict': model.opt_D.state_dict(),
            #     'losses': losses_history,
            #     'val_losses': val_history
            # }
            # torch.save(checkpoint, f"{save_path}/bicycle_gan_best.pt")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

        # Save regular checkpoint every 10 epochs
        # if epoch % 10 == 0:
        #     checkpoint = {
        #         'epoch': epoch,
        #         'G_state_dict': model.G.state_dict(),
        #         'E_state_dict': model.E.state_dict(),
        #         'D_state_dict': model.D.state_dict(),
        #         'opt_G_state_dict': model.opt_G.state_dict(),
        #         'opt_E_state_dict': model.opt_E.state_dict(),
        #         'opt_D_state_dict': model.opt_D.state_dict(),
        #         'losses': losses_history,
        #         'val_losses': val_history
        #     }
        #     torch.save(checkpoint, f"{save_path}/bicycle_gan_epoch_{epoch}.pt")

    # return losses_history, val_history



# Train


In [26]:
# losses, val_losses = train(model, train_loader, val_loader, test_loader,
#                          num_epochs=config['epochs'], save_path=save_path)

In [None]:
train(model, train_loader, val_loader, test_loader,
      num_epochs=config['epochs'], save_path=save_path)