# DDPM using MONAI Generative Models


In [None]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

In [None]:
# Set the working directory (and dataset location).
# The defaults point at the original author's machine; set SPARTDM_PROJECT_DIR
# and/or SPARTDM_DATA_DIR to run the notebook elsewhere without editing it.
PROJECT_DIR = os.environ.get("SPARTDM_PROJECT_DIR", "/home/kdang/projects/spartDM/")
DATA_DIR = os.environ.get(
    "SPARTDM_DATA_DIR",
    os.path.join(PROJECT_DIR, "data/starmen/output_random_noacc"),
)
# Relative paths used later (e.g. `src.data`, "workdir/test") resolve against this.
os.chdir(PROJECT_DIR)
print("Current working directory:", os.getcwd())

In [None]:
from src.data import StarmenDataset

# Keyword arguments shared by the train and validation datasets.
_dataset_kwargs = {
    "data_dir": DATA_DIR,
    "nb_subject": 10,
    "save_data": False,
    "workdir": "workdir/test",
}

train_ds = StarmenDataset(split="train", **_dataset_kwargs)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=1)

val_ds = StarmenDataset(split="val", **_dataset_kwargs)
# Validation does not benefit from random order; a fixed order keeps validation
# batches (and any figures drawn from them) reproducible between runs.
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=1)


In [None]:
# Check training data
from einops import rearrange

check_data = first(train_loader)

# rearrange data from B T ... -> (B T) ...
check_data = rearrange(check_data["x_origin"], "b t ... -> (b t) ...")
print(f"batch shape: {check_data.shape}")

# Show up to 6 images side by side. Bounding by the batch size avoids an
# IndexError when the flattened batch holds fewer than 6 images.
# NOTE(review): `[i, 0]` assumes dim 1 is the channel axis — confirm.
n_show = min(6, check_data.shape[0])
image_visualisation = torch.cat([check_data[i, 0] for i in range(n_show)], dim=1)
plt.figure("training images", (12, 6))
plt.imshow(image_visualisation, vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()

# Train the DDPM model

In [None]:
# Peek at a single training batch and check its shape after flattening B and T.
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
step, item = first(progress_bar)
# first() stops after one item, so close the bar explicitly rather than leaving
# a half-drawn, never-finished tqdm bar in the output.
progress_bar.close()
x = item["x_origin"]
x = rearrange(x, "b t ... -> (b t) ...")
x.shape

In [None]:
# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2-D UNet used as the noise predictor: one input channel, one output channel,
# three resolution levels with attention enabled on the second and third.
# nn.Module.to() moves the parameters in place and returns the same module.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=256,
).to(device)

# DDPM noise schedule with 1000 training timesteps; the inferer wraps it and
# exposes it as `inferer.scheduler`.
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)


###
# Train the model
###

n_epochs = 5
val_interval = 5
epoch_loss_list = []
val_epoch_loss_list = []

# Mixed precision (fp16 autocast + loss scaling) is only used on CUDA. On CPU,
# autocast would switch to bfloat16, for which gradient scaling is unnecessary.
use_amp = device.type == "cuda"
# torch.amp.GradScaler documents `device` as a string such as "cuda" or "cpu".
scaler = GradScaler(device=device.type, enabled=use_amp)
total_start = time.time()


for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        # Flatten (b, t, ...) -> (b*t, ...) so every frame is trained on as an
        # independent image.
        x = batch["x_origin"]
        x = rearrange(x, "b t ... -> (b t) ...").to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=device.type, enabled=use_amp):
            # Gaussian noise with the same shape, dtype and device as x.
            noise = torch.randn_like(x)

            # One random timestep per image in [0, num_train_timesteps).
            # torch.randint already returns int64.
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (x.shape[0],), device=x.device
            )

            # Get model prediction (the inferer noises x with `noise` at
            # `timesteps` and runs the model; compared against `noise` below).
            noise_pred = inferer(inputs=x, diffusion_model=model, noise=noise, timesteps=timesteps)

            # Compute the loss in float32 even when autocast produced fp16.
            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    # len(train_loader) is the number of batches iterated. Using it instead of
    # `step + 1` avoids a NameError, and max(..., 1) a ZeroDivisionError, when
    # the loader is empty.
    epoch_loss_list.append(epoch_loss / max(len(train_loader), 1))

