## Diffusion Model for FastMRI - Training

This notebook trains a denoising diffusion probabilistic model (DDPM) that generates brain MR images conditioned on the image contrast.

In [None]:
from tqdm import tqdm
from datetime import datetime
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt

from fastmri_brain_data_builder import DataGenImagesOnly, DataGenImagesDownsampled
from unet import real_to_complex, complex_to_real, ContextUnet, BigContextUnet
from ddpm import DDPM

from skimage.metrics import peak_signal_noise_ratio

# Prefer the GPU; fall back to the CPU when CUDA is unavailable instead of failing
# later on the first `.to(device)`. Print the choice so a silent CPU fallback is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#### Initialize training settings

In [None]:
# --- Optimisation schedule ---
n_epoch = 500          # number of training epochs in this run
batch_sz = 8           # images per minibatch (also the number of progress samples drawn)

# --- Data ---
num_sub = 20           # number of subjects loaded for each contrast
res = 160              # image resolution after downsampling
complex_in = False     # False -> magnitude images ("mag"), True -> complex input ("comp")

# --- Model ---
n_T = 500              # number of diffusion noise levels
num_feat = 64          # passed to ContextUnet as n_feat

# --- Resuming from a previous run ---
load_save = True                                      # load weights before training
load_save_dir = './cond_ddpm_models/May15_1510_mag/'  # run directory holding the checkpoint
chkpoint = 500                                        # loads ddpm_ep{chkpoint}.pth from load_save_dir

#### Load dataset

This notebook uses the fastMRI multicoil brain data, which contains images of four contrasts: FLAIR, T1POST, T1, and T2.

In [None]:
data_path = '../LSS/lss_jcb/aniket/FastMRI brain data/'
ds = DataGenImagesDownsampled(start_sub=0, num_sub=num_sub, device=device, res=res, complex_in=complex_in, data_path=data_path)

num_classes = len(ds.labels_key)  # one conditioning class per entry of labels_key
num_slices = len(ds)

# Inspect a single example. Fetch it once: each ds[idx] call may re-read or
# re-transform the item (TODO confirm cost in DataGenImagesDownsampled).
idx = 0
sample_img = ds[idx][0]  # item is (image, label); the label is used for conditioning in training
img_shape = sample_img.shape
print("Num images:", num_slices)
print(f"Image {idx} shape:", sample_img.shape)
print(f"Image {idx} min/max:", sample_img.abs().min(), "/", sample_img.abs().max())

### Setup before training

#### Set up save directory

In [None]:
# Each run writes into its own timestamped directory,
# e.g. ./cond_ddpm_models/May15_1510_mag.
# NOTE: the name has minute resolution, so two runs started within the same
# minute share (and overwrite files in) the same directory.
timestamp = datetime.now()
using_mag = "comp" if complex_in else "mag"
save_dir = f"./cond_ddpm_models/{timestamp.strftime('%b%d_%H%M')}_{using_mag}"
# exist_ok avoids the check-then-create race and is a no-op if the directory exists.
os.makedirs(save_dir, exist_ok=True)

print(f'\nSaving to {save_dir}\n')

#### Initialize score matching network (e.g. a U-net)

In [None]:
# Noise-prediction network: a U-Net conditioned on one of `num_classes` contrast labels.
eps_model = ContextUnet(
    complex_in=complex_in,
    n_feat=num_feat,
    n_classes=num_classes,
)

# Helper functions to log progress and save model weights
def log(s, mute=True):
    with open(save_dir+'/log.txt', 'a', encoding='utf8') as f: f.write(s+"\n")
    if not mute: print(s)
        
def save_checkpoint(states, path, filename='model_best.pth.tar'):
    """Serialize ``states`` with ``torch.save`` to the file ``path/filename``.

    Not called in the visible cells; the training loop saves
    ``ddpm.state_dict()`` with ``torch.save`` directly.
    """
    torch.save(states, os.path.join(path, filename))

#### Initialize diffusion model with score matching network

In [None]:
# Wrap the noise-prediction network in the DDPM. betas=(1e-4, 0.02) are the two
# end points of the noise schedule handed to DDPM (schedule shape defined in ddpm.py).
ddpm = DDPM(eps_model=eps_model, betas=(1e-4, 0.02), n_T=n_T, complex_in=complex_in, device=device)
ddpm.to(device)

if load_save:
    weights_fname = os.path.join(load_save_dir, f'ddpm_ep{chkpoint}.pth')
    # map_location lets weights saved on one device (e.g. a GPU) load onto `device`.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    ddpm.load_state_dict(torch.load(weights_fname, map_location=device))
    log(f"Loaded saved weights from {weights_fname}", False)

# A fresh Adam optimizer is created even when resuming: only model weights are
# checkpointed, so Adam's moment estimates restart from zero.
optim = torch.optim.Adam(ddpm.parameters(), lr=1e-4)


### Training

In [None]:
dataloader = DataLoader(ds, batch_size=batch_sz, shuffle=True, num_workers=1)
if len(dataloader) == 0:
    # Without at least one batch the per-epoch loss below would be undefined.
    raise ValueError("Dataset is empty; nothing to train on.")

# Train for n_epoch epochs. On the first epoch and every 50th epoch: log the loss,
# save a grid of samples, and save the model weights.
# NOTE(review): epoch numbering restarts at 1 even when resuming from a checkpoint
# (load_save=True), so ddpm_ep{i}.pth in save_dir counts epochs of this run only.
with tqdm(total=n_epoch, position=0, leave=True) as pbar:
    for i in range(1, n_epoch + 1):
        ddpm.train()
        loss_ema = None  # exponential moving average of the batch loss, reset each epoch

        for x, label in dataloader:
            x, label = x.to(device), label.to(device)

            optim.zero_grad()
            loss = ddpm(x, label)  # scalar loss tensor for this batch
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss.item()

            optim.step()

        summary = f"[Epoch {i}]\tloss: {loss_ema:.6f}"
        pbar.set_description(summary)
        pbar.update(1)

        if i == 1 or i % 50 == 0:
            log(summary)

            ddpm.eval()
            with torch.no_grad():
                # Draw a batch of samples as a visual progress check.
                xh = ddpm.sample(batch_sz, (1, img_shape[-2], img_shape[-1]))
                if complex_in:
                    xh = xh.abs()
                grid = make_grid(xh, nrow=4)
                save_image(grid, save_dir + f"/ddpm_sample_{i:03d}.png")

            # Save model weights only (the optimizer state is not saved).
            torch.save(ddpm.state_dict(), save_dir + f"/ddpm_ep{i}.pth")

### Sampling

#### Generate samples

In [None]:
# Run ddpm.sample_all with a leading 1 followed by the image's last two dimensions.
# It returns a list of sample batches whose last entry is the final set of samples,
# plus nT_list (presumably the matching timesteps; TODO confirm in ddpm.py).
sample_shape = (1, *img_shape[-2:])
ddpm.eval()
with torch.no_grad():
    out_samples_list, nT_list = ddpm.sample_all(sample_shape)
out_samples = out_samples_list[-1]

#### Display images

In [None]:
# Show one final sample per contrast class, side by side.
# squeeze=False keeps the Axes container indexable even when num_classes == 1;
# `ax` is then the 1-D row of Axes, as in the multi-class case.
fig, axes = plt.subplots(1, num_classes, figsize=(20, 10), squeeze=False)
ax = axes[0]
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.01)
final_samples = out_samples_list[-1]
for i in range(num_classes):
    # NOTE(review): assumes row i of the final batch was generated for class i,
    # i.e. sample_all orders its samples by class -- confirm in ddpm.py.
    img = final_samples[i, 0]
    if complex_in:
        img = np.abs(img)
    ax[i].imshow(img, cmap='gray', vmin=0., vmax=1.)
    ax[i].axis('off')
    ax[i].set_title(f'{ds.labels_key[i]} sample')
fig.suptitle('Final samples')

# NOTE(review): the file name uses `chkpoint` (the checkpoint loaded at start-up),
# not the number of epochs trained in this run.
plt.savefig(save_dir + f'/sample_all_contrasts_{chkpoint}.png')
plt.show()