## Latent Diffusion Model

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torchvision
from torch.utils.data import DataLoader

# library for loading and manipulating medical images
import torchio as tio

# import custom src code
from src_latent.architecture import UNet
from src_latent.autoencoder import AutoEncoder, VAE, train_ae, train_vae, load_ae, test_model
from src_latent.database_toy import LightSourceDB, sample_batch_toy
from src_latent.testing import sample_latent, load_unet
from src_latent.training import train_model, get_noise_scheduler

# reflect changes in src code immediately without restarting kernel
%load_ext autoreload
%autoreload 2

### Set Device

In [None]:
# choose the compute backend: Apple MPS first (first mps index in case of multiple),
# then CUDA, otherwise fall back to the CPU
device = torch.device(
    "mps:0" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

### Create and Train Autoencoders
Train separate autoencoders for the input and the target data. The input encoder and the target decoder are later reused for the latent diffusion model (LDM).

In [None]:
# --- hyperparameters ---
latent_chan = 8         # latent dimension of embeddings produced by encoder
batch_size = 64
epochs = 200
learning_rate = 1e-4

# --- model, placed on the selected device ---
model = AutoEncoder(latent_chan=latent_chan)
model.to(device)

# --- optimisation: Adam (seems to perform better than SGD) with an MSE loss ---
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.MSELoss()

In [None]:
# dynamically set number of workers to optimize use of cores
def get_num_workers(reserve: int = 2, fallback: int = 4) -> int:
    """Pick a DataLoader worker count from the number of CPU cores.

    Leaves ``reserve`` cores free, but never returns fewer than one worker
    when the core count is known.

    Args:
        reserve: Number of cores to leave unused.
        fallback: Value returned when the core count cannot be determined.

    Returns:
        The number of worker processes to use.
    """
    num_cpus = os.cpu_count()
    # os.cpu_count() returns None when the count is undetermined; handle that
    # case explicitly instead of letting a bare ``except`` swallow the
    # resulting TypeError (and any other error along with it)
    if num_cpus is None:
        return fallback
    return max(1, num_cpus - reserve)


num_workers = get_num_workers()
print(num_workers)

In [None]:
# Disabled toy-data loader, kept for reference only: the whole cell body is a bare
# string literal, so none of the statements inside it are executed.
'''
# create dataset and dataloader
train_set = LightSourceDB(num_samples=2048, method="random") 
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, 
                            num_workers=num_workers,
                            pin_memory=torch.cuda.is_available(),  # speeds up GPU transfer
                            persistent_workers=True)

# sanity check: print first batch of data
sample_batch_toy(train_loader)
'''

In [None]:
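# import custom src code
# NOTE(review): UNet, train_model and get_noise_scheduler are also imported from src_latent in the
# first cell; the imports below rebind those names to the src_diffusion versions -- confirm that
# the later cells expect these versions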
from src_diffusion.architecture import UNet
from src_diffusion.database import MRImagesDB
from src_diffusion.training import sample_batch, train_model, get_noise_scheduler
from src_diffusion.testing import load_model, sample

# locations of the training images and the diffusion metadata files
data_dir_path = '/cs/student/projects3/cgvi/2024/morrison/DWsynth_project/'
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')
img_dir_path = os.path.join(data_dir_path, 'train')

# volume dimensions [1, H, W, D], read from one reference T1 volume
volume_dims = tio.ScalarImage(
    os.path.join(data_dir_path, 'train/sub-051-01/anat/sub-051-01_t1.nii.gz')
).data.shape

# axis to slice along: 0 = sagittal, 1 = coronal, 2 = horizontal
slice_axis = 2

# training dataset and the multi-worker loader that feeds the training loop
train_set = MRImagesDB(
    img_dir_path, bvals_path, bvecs_path, volume_dims,
    num_samples=1000, slice_axis=slice_axis,
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # speeds up GPU transfer
    persistent_workers=True,
)

# display a sample from a batch (drawn through a separate default loader) to verify the data
sample_batch(DataLoader(dataset=train_set, batch_size=batch_size))

In [None]:
# select data type (blob or shadow)
# NOTE(review): train_ae is a project helper; presumably data_type decides which image of each
# sample is reconstructed and how the saved weights are named -- confirm in src_latent.autoencoder
data_type="blob"

# train the model, saves weights in model folder, and plots loss curve
train_ae(model, device, train_loader, loss_fn, optimizer, epochs, batch_size, learning_rate, data_type=data_type)

In [None]:
# The cell above trains an autoencoder for the blobs, i.e. the anatomical images.
# This cell trains a second autoencoder for the shadow / diffusion images,
# so that both can then be used for latent diffusion on the real data.

# --- hyperparameters ---
latent_chan = 8         # latent dimension of embeddings produced by encoder
batch_size = 64
epochs = 300
learning_rate = 1e-4
data_type = "shadow"    # data type (blob or shadow)

# --- fresh model, placed on the selected device ---
model = AutoEncoder(latent_chan=latent_chan)
model.to(device)

# --- optimisation: Adam (seems to perform better than SGD) with an MSE loss ---
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.MSELoss()

# --- data ---
num_workers = get_num_workers()
train_set = MRImagesDB(
    img_dir_path, bvals_path, bvecs_path, volume_dims,
    num_samples=2048, slice_axis=slice_axis,
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # speeds up GPU transfer
    persistent_workers=True,
)

# display a sample from a batch to verify the data
sample_batch(DataLoader(dataset=train_set, batch_size=batch_size))

# train the model; saves weights in the model folder and plots the loss curve
train_ae(model, device, train_loader, loss_fn, optimizer, epochs, batch_size, learning_rate, data_type=data_type)

In [None]:
# added parameters to testing function

num_samples = 5                             # number of test samples passed to test_model below
data_type = "shadow"
model_path = "models_diffusion/ae_dw.pth"   # weights file handed to load_ae

# load_ae is a project helper; presumably it rebuilds the autoencoder and loads these weights -- verify
loaded_model = load_ae(model_path, device)

In [None]:
# test model on 5 random unseen samples
inputs, preds = test_model(loaded_model, device, num_samples, num_workers, super_title="", data_type=data_type)
# compare the value range (max, then min) of the first input with that of the first prediction
print(inputs[0].max(), preds[0].max())
print(inputs[0].min(), preds[0].min())

### Separate the Encoder and Decoder

In [None]:
# number of toy samples to encode, decode and plot (used everywhere below)
n_blobs = 5

# take the encoder half of the loaded autoencoder
# NOTE(review): loaded_model is read from "models_diffusion/ae_dw.pth" in an earlier cell, i.e. the
# shadow/DW autoencoder rather than the blob one, yet it is probed with toy LightSourceDB samples
# here -- confirm this is the intended pairing
encoder = loaded_model.encoder
encoder.to(device)

# sample one batch of toy images
test_set = LightSourceDB(num_samples=n_blobs, method="random")
test_loader = DataLoader(dataset=test_set, batch_size=n_blobs)
sample_blobs = next(iter(test_loader))[0]

# input the images and get their latent representation
encoder.eval()
with torch.no_grad():
    latent_blobs = encoder(sample_blobs.to(device))

# test putting the latent codes back into the decoder
decoder = loaded_model.decoder
decoder.to(device)

# reconstruct the images from the latent codes
decoder.eval()
with torch.no_grad():
    reconstructed_blobs = decoder(latent_blobs)

# convert to numpy for plotting; squeeze only dim 1 so the batch dimension
# survives even when n_blobs == 1 (a bare .squeeze() would drop it too)
sample_blobs = sample_blobs.squeeze(1).cpu().numpy()
reconstructed_blobs = reconstructed_blobs.squeeze(1).cpu().numpy()

# display: originals on the top row, reconstructions underneath
# (squeeze=False keeps `axes` two-dimensional for any n_blobs)
fig, axes = plt.subplots(2, n_blobs, figsize=(3 * n_blobs, 6), squeeze=False)

for i in range(n_blobs):
    axes[0, i].imshow(sample_blobs[i])
    axes[0, i].axis('off')

    axes[1, i].imshow(reconstructed_blobs[i])
    axes[1, i].axis('off')

# row labels go on the first column only
axes[0, 0].set_title('Ground Truth')
axes[1, 0].set_title('Reconstructed')

plt.tight_layout()
plt.show()

### Let's try latent diffusion
- incorporate the encoder and decoder into the training code
- then set up diffusion as normal?

In [None]:
# choose the compute backend: Apple MPS first, then CUDA, otherwise fall back to the CPU
device = torch.device(
    "mps:0" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

In [None]:
# load the two trained autoencoders (shadow and blob) from their weight files
shadow_ae = load_ae("models_diffusion/ae_dw.pth", device)
blob_ae = load_ae("models_diffusion/ae_anat.pth", device)

# extract the components used below and move each one to the device:
# both halves of the shadow autoencoder, but only the encoder of the blob one
encoder_shadow = shadow_ae.encoder.to(device)
decoder_shadow = shadow_ae.decoder.to(device)
encoder_blob = blob_ae.encoder.to(device)

In [None]:
# --- conditioning dimension ---
# 3 for toy data:   1 for timestep, 2 for angle (x, y) on unit circle
# 5 for MRI data:   1 for timestep, 3 for (x,y,z) bvec and 1 for scalar bval
# 8 for better MRI: 1 for timestep, 6 for transformed bvec and 1 for scalar bval
cond_dim = 8

# --- channel layout ---
# input = shape image + noisy output img (shadow/dwi), each in latent space, hence 2 * latent_chan;
# output has latent_chan channels, the decoder is then used to get back to the original channels
latent_chan = 8
in_chan = 2 * latent_chan
out_chan = latent_chan

# --- training hyperparameters ---
batch_size = 100
epochs = 3000
learning_rate = 1e-4

# --- diffusion hyperparameters (from DDPM paper) ---
timesteps = 1000
beta_start, beta_end = 1e-4, 0.02

# axis to slice along: 0 = sagittal, 1 = coronal, 2 = horizontal
slice_axis = 2

# --- model on the device, optimizer (Adam seemed to work better than SGD) and loss ---
model = UNet(in_chan=in_chan, out_chan=out_chan, cond_dim=cond_dim)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.MSELoss()

In [None]:
# or test model on toy data
#train_set = LightSourceDB(num_samples=1000, method="random") 
#train_loader = DataLoader(dataset=train_set, batch_size=batch_size)

# display sample from batch to verify data
#sample_batch_toy(DataLoader(dataset=train_set, batch_size=batch_size))

In [None]:
# locations of the training images and the diffusion metadata files
data_dir_path = '../DWsynth_project/'
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')
img_dir_path = os.path.join(data_dir_path, 'train')

# volume dimensions [1, H, W, D], read from one reference T1 volume
volume_dims = tio.ScalarImage(
    os.path.join(data_dir_path, 'train/sub-051-01/anat/sub-051-01_t1.nii.gz')
).data.shape

# axis to slice along: 0 = sagittal, 1 = coronal, 2 = horizontal
slice_axis = 2

# training dataset and its loader; workers are kept alive between epochs only when CUDA is available
train_set = MRImagesDB(
    img_dir_path, bvals_path, bvecs_path, volume_dims,
    num_samples=1000, slice_axis=slice_axis,
)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # speeds up GPU transfer
    persistent_workers=torch.cuda.is_available(),
)

# display a sample from a batch (drawn through a separate default loader) to verify the data
sample_batch(DataLoader(dataset=train_set, batch_size=batch_size))

In [None]:
# decide signal type for training
data_type = 'shadows'             # 'blobs' or 'shadows'

# train model, saves weights in model folder, and plots loss curve
# NOTE(review): train_model is imported from src_latent.training at the top of the notebook and
# again from src_diffusion.training in a later cell, so whichever import ran last is the one
# called here -- confirm it is the version that accepts the two encoder arguments
train_model(model, device, train_set, train_loader, loss_fn, optimizer,
            epochs, batch_size, learning_rate, timesteps, beta_start, beta_end, data_type, encoder_shadow, encoder_blob)

In [None]:
# same unet dims as above
model_path = "models_diffusion/best_latent_diffusion_model.pth"
# load_unet is a project helper; presumably it rebuilds the UNet with these dims and loads the weights -- verify
loaded_model = load_unet(model_path, device, in_chan, out_chan, cond_dim)

In [None]:
# shape of a single toy image: build a one-sample dataset, take the first batch from its loader,
# then the first element of that batch and the first sample within it
img_shape = next(iter(DataLoader(dataset=LightSourceDB(num_samples=1, method="random"), batch_size=1)))[0][0].shape
print(img_shape)

In [None]:
# number of images to sample from the diffusion model
n_samples = 5

# linear noise schedule
timesteps = 1000
beta_start, beta_end = 1e-4, 0.02
betas, alphas, alphas_bar = get_noise_scheduler('linear', timesteps, beta_start, beta_end, device)

# image shape (C, H, W) for MRI data
img_shape = volume_dims[:3]
# toy-data alternative:
#img_shape = next(iter(DataLoader(dataset=LightSourceDB(num_samples=1, method="random"), batch_size=1)))[0][0].shape    #(C, H, W)

# test data, served as one batch holding all n_samples items
# toy-data alternative:
#test_loader = DataLoader(dataset=LightSourceDB(num_samples=n_samples, method="random"), batch_size=n_samples)
img_dir_path = os.path.join(data_dir_path, 'test')
test_set = MRImagesDB(
    img_dir_path, bvals_path, bvecs_path, volume_dims,
    num_samples=n_samples, slice_axis=slice_axis,
)
test_loader = DataLoader(dataset=test_set, batch_size=n_samples)

In [None]:
# sampler name handed to sample_latent
sampler='ddpm'
# previous non-latent sampling call, kept for reference:
#ddpm_samples = sample(loaded_model, test_loader, n_samples, timesteps, betas, alphas, alphas_bar, device, sampler, 
#                      encoder_blob, decoder_shadow)

# run the latent sampler; the result is unpacked as (decoded samples, ground-truth shadows, blobs)
decoded_samples, gt_shadows, blobs = sample_latent(loaded_model, test_loader, n_samples, timesteps, beta_start, beta_end, device,
           sampler, encoder_blob, decoder_shadow, ddim_steps=None, ddim_eta=0.0)

In [None]:
import matplotlib.pyplot as plt

def plot_samples(anat, gt, preds):
    """
    Show anatomical, ground-truth and predicted images side by side, one sample per row.

    anat, gt, preds: tensors of shape [b, 1, 96, 96]
    """
    n_rows = anat.shape[0]

    # squeeze=False always yields a 2D grid of axes, so a batch of one needs no special case
    fig, axes = plt.subplots(n_rows, 3, figsize=(9, 3 * n_rows), squeeze=False)

    columns = (("Anatomical", anat), ("Ground Truth", gt), ("Prediction", preds))

    for col, (title, images) in enumerate(columns):
        # column titles go on the first row only
        axes[0, col].set_title(title)

        for row in range(n_rows):
            ax = axes[row, col]
            # channel 0 of each sample, moved to the CPU for matplotlib
            ax.imshow(images[row, 0].detach().cpu().numpy(), cmap="gray", origin="lower")
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# plot_samples expects (anat, gt, preds): pass the anatomical blobs first and the decoded
# diffusion samples last so that every image lands under its matching column title
plot_samples(blobs, gt_shadows, decoded_samples)

### SSIM and PSNR

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import os
import numpy as np

from real_src.utils import load_model, get_num_workers, set_device, batch_metrics
from src_latent.testing import sample_latent
from real_src.database_mri import MRImagesDB
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

In [None]:
# set the device depending on available GPU
device = set_device()
# worker count for the DataLoaders (get_num_workers is imported from real_src.utils in the cell above)
num_workers = get_num_workers()

In [None]:
# set unet dims
cond_dim = 8        # 1 for timestep, 6 for bvec, 1 for bval
latent_chan = 8
# two inputs (noisy img + shape image), each with latent_chan channels
in_chan = 2 * latent_chan
out_chan = latent_chan  # out_chan should be same as latent_chan, then we use decoder to get back to original chan

model_path = "models_diffusion/best_latent_diffusion_model.pth"
loaded_model = load_model(model_path, device, cond_dim, in_chan, out_chan)

In [None]:
# load encoder blob and decoder shadow
from src_latent.autoencoder import load_ae

# load the two trained autoencoders (shadow and blob) from their weight files
shadow_ae = load_ae("models_diffusion/ae_dw.pth", device)
blob_ae = load_ae("models_diffusion/ae_anat.pth", device)

# extract the needed components and move each one to the device:
# both halves of the shadow autoencoder, but only the encoder of the blob one
encoder_shadow = shadow_ae.encoder.to(device)
decoder_shadow = shadow_ae.decoder.to(device)
encoder_blob = blob_ae.encoder.to(device)

In [None]:
# data paths
data_dir_path = '../DWsynth_project/'
img_dir_path = os.path.join(data_dir_path, 'test')
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')

# load bvals and bvecs
bvals = np.loadtxt(bvals_path)
bvecs = np.loadtxt(bvecs_path)

# volume dimensions  [1, H, W, D]
volume_dims = [1, 96, 96, 70]

# choose method of generating images
method = 'latent'

# os.walk yields nothing for a missing directory, which would surface below as a bare
# StopIteration; fail with a clear message instead
if not os.path.isdir(img_dir_path):
    raise FileNotFoundError(f"test image directory not found: {img_dir_path!r}")

# list of test subjects = immediate sub-directories of the test folder.
# os.walk returns them in arbitrary filesystem order, so sort them to keep the
# per-subject results reproducible between runs and machines
subjects = sorted(next(os.walk(img_dir_path))[1])

# indices where bvals non zero (dw)
dw_inds = np.where(bvals > 0)[0]

# num samples (number of non-zero dw volumes = 105 out of 118 total volumes)
num_samples = len(dw_inds)

# slice index: middle slice along the depth axis (35 for D = 70), derived from
# volume_dims so it follows the volume size instead of being hard-coded
slice_idx = volume_dims[3] // 2

In [None]:
# per-subject metric results, collected here and turned into a pandas df afterwards
ssim_results_list = []
psnr_results_list = []

# only the latent method is implemented; stop before the slow sampling loop instead of
# reaching batch_metrics with undefined (or stale) preds/targets
if method != "latent":
    raise ValueError(f"unsupported generation method: {method!r}")

# sampling settings are identical for every subject, so set them once
sampler = 'ddpm'
timesteps = 1000
beta_start, beta_end = 1e-4, 0.02

# iterate through each subject (ctr counts processed subjects, starting at 1)
for ctr, subject in enumerate(subjects, start=1):

    # create dataloader - just needs subject and slice idx; one batch holds all dw volumes
    test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                          subject=subject, slice_idx=slice_idx)
    test_loader = DataLoader(dataset=test_set, batch_size=num_samples)

    # sample_latent is unpacked into three values elsewhere in this notebook (samples, ground
    # truth, blobs); star-unpacking accepts both the two- and three-value forms and discards
    # anything after the first two
    ddpm_sample, gt_shadows, *_ = sample_latent(loaded_model, test_loader, num_samples, timesteps, beta_start, beta_end,
                                                device, sampler, encoder_blob, decoder_shadow)

    # get rid of channel dim so just [B, 96, 96]
    preds = ddpm_sample.detach().cpu().numpy()[:, 0, :, :]
    targets = gt_shadows.detach().cpu().numpy()[:, 0, :, :]

    # progress indicator
    print(ctr)

    # get ssim and psnr results for the batch (subject) and add to master list
    ssim_results, psnr_results = batch_metrics(preds, targets)
    ssim_results_list.append(ssim_results)
    psnr_results_list.append(psnr_results)

In [None]:
import pandas as pd 


all_dfs = []

# match bvals (the non-zero bvals, in file order)
# NOTE(review): assumes each subject's metric arrays follow this same dw_inds order -- confirm in MRImagesDB
matched_bvals = bvals[dw_inds]

# one dataframe per evaluated subject; iterating over the result lists themselves (rather than a
# hard-coded range(50)) keeps this correct for any number of subjects
for subj, (ssim_vals, psnr_vals) in enumerate(zip(ssim_results_list, psnr_results_list)):

    # store as df
    df = pd.DataFrame({
        "subject": subj,            # position of the subject in the evaluation loop
        "bval": matched_bvals,
        "ssim": ssim_vals,
        "psnr": psnr_vals,
    })
    all_dfs.append(df)

# concatenate all subjects into one big dataframe
df_all = pd.concat(all_dfs, ignore_index=True)

# mean SSIM and PSNR per bval across all subjects
mean_metrics = df_all.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# SAVE DATAFRAME!!!

# write the combined metrics dataframe to a CSV file, leaving out the row-index column
df_all.to_csv("metrics_df_latent.csv", index=False)