In [8]:
import torch
import os
import yaml
from omegaconf import OmegaConf

os.chdir("/lustre/scratch/bakerh/cvpr/CVPR2025/") 

from src.datamodules.Datamodules_eval import Brats21
from src.datamodules.Datamodules_train import IXI
from src.models.DDPM_2D import DDPM_2D
# Cross-validation fold whose checkpoint is evaluated (also forwarded to the datamodule config below).
fold = 1
t1_ddpm_checkpoint = (
    '/lustre/scratch/bakerh/cvpr/diffusion_models_checkpoint_gaussian/'
    f'last_fold-{fold}.ckpt'
)
# NOTE(review): the model is bound to `self`, presumably so DDPM_2D method bodies can be
# pasted into later cells unchanged. strict=False tolerates missing/unexpected state-dict keys.
self = DDPM_2D.load_from_checkpoint(t1_ddpm_checkpoint, strict=False)


def remove_underscore_keys(d):
    """Return a new dict holding the items of ``d`` whose key does not start with "_".

    Only top-level keys are filtered; nested values are kept as-is (shallow copy).
    NOTE(review): presumably strips Hydra-style metadata such as ``_target_`` from the
    loaded yaml before the ``cfg`` section is extracted.
    """
    kept = {}
    for key, value in d.items():
        if key.startswith("_"):
            continue
        kept[key] = value
    return kept

# Build the datamodule config from the IXI yaml (minus "_"-prefixed top-level keys),
# keep only its `cfg` section, and override the fields used for this evaluation run.
with open("./configs/datamodule/IXI.yaml", "r") as yaml_file:
    config = remove_underscore_keys(yaml.safe_load(yaml_file))['cfg']

config.update(
    mode='t1',
    data_dir='/lustre/scratch/bakerh/cvpr/data/Data/',
    sample_set=True,
    num_workers=0,
    rescaleFactor=2,
    imageDim=[192, 192, 100],
    # NOTE(review): this stores the fold *index* under `num_folds`; confirm Brats21 expects that.
    num_folds=fold,
)

# Round-trip through a plain container so ${...} interpolations are resolved to concrete values.
config = OmegaConf.create(OmegaConf.to_container(OmegaConf.create(config), resolve=True))

lightningDataModule = Brats21(config)
lightningDataModule.setup()
dataloader = lightningDataModule.test_dataloader()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.eval()


DDPM_2D(
  (diffusion): GaussianDiffusion(
    (model): UNetModel(
      (time_embed): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): SiLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
      )
      (input_blocks): ModuleList(
        (0): TimestepEmbedSequential(
          (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1-3): 3 x TimestepEmbedSequential(
          (0): ResBlock(
            (in_layers): Sequential(
              (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            )
            (h_upd): Identity()
            (x_upd): Identity()
            (emb_layers): Sequential(
              (0): SiLU()
              (1): Linear(in_features=512, out_features=256, bias=True)
            )
            (out_layers): Sequential(
              (0): GroupNorm3

In [None]:
import torchio as tio
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Run settings used by the inference loop:
#   noise_type        -> family passed to gen_noise
#   testing_timesteps -> diffusion timestep `t` given to the model
#   dilation_size     -> square kernel side (pixels) for dilate_masks
noise_type, testing_timesteps, dilation_size = 'gaussian', 300, 24

def gen_noise(type, shape):
    """Sample a noise tensor of the given shape.

    Args:
        type: noise family name. Only ``'gaussian'`` (standard normal via
            ``torch.randn``) is implemented.
        shape: desired output shape, e.g. ``input.shape``.

    Returns:
        Tensor of shape ``shape`` created with torch's default dtype/device;
        callers move it to their device.

    Raises:
        ValueError: if ``type`` is not a supported noise family. An unsupported
            type would otherwise yield ``None`` and fail later at ``.to(device)``.
    """
    if type == 'gaussian':
        return torch.randn(shape)
    raise ValueError(f"Unsupported noise type: {type!r} (expected 'gaussian')")


def visualize_predictions(x_hat, x, seg, dilated_seg, step=5, margin=10):
    """Show, per batch item, a grid comparing input, reconstruction, masks and error.

    Tensors are unpacked as ``(B, C, W, D)``. For each index ``i`` in
    ``range(margin, D - margin, step)``, the profile ``[b, 0, :, i]`` is taken from
    ``x``, ``x_hat``, ``seg``, ``dilated_seg`` and ``|x - x_hat|``. Each grid row is one
    of those five tensors and each column is one index ``i``.

    NOTE(review): each tile is a 1-pixel-high strip of length W. If the intent was to
    show 2-D slices of a 5-D volume, the indexing must change — confirm against the model.

    Args:
        x_hat: reconstruction; resized (bilinear) to x's spatial size if needed.
        x: input batch.
        seg: ground-truth mask; resized (nearest) to x's spatial size if needed.
        dilated_seg: dilated mask; resized (nearest) to x's spatial size if needed.
        step: stride between displayed indices along the last axis.
        margin: number of indices skipped at each end of the last axis.

    Raises:
        ValueError: if ``margin``/``step`` leave no index to display.
    """
    # Plot from detached float CPU copies: grad-requiring or CUDA tensors cannot be
    # mixed with CPU masks or handed to matplotlib.
    x_hat, x, seg, dilated_seg = (
        t.detach().cpu().float() for t in (x_hat, x, seg, dilated_seg)
    )
    B, _, W, D = x.shape
    indices = list(range(margin, D - margin, step))
    if not indices:
        raise ValueError(
            f"No indices to display: range({margin}, {D - margin}, {step}) is empty for D={D}"
        )

    if x_hat.shape[2:] != x.shape[2:]:
        x_hat = F.interpolate(x_hat, size=(W, D), mode='bilinear', align_corners=False)
    # Masks use nearest-neighbour so their values are not blended across edges.
    if seg.shape[2:] != x.shape[2:]:
        seg = F.interpolate(seg, size=(W, D), mode='nearest')
    if dilated_seg.shape[2:] != x.shape[2:]:
        dilated_seg = F.interpolate(dilated_seg, size=(W, D), mode='nearest')

    for b in range(B):
        tiles = torch.stack([
            torch.stack([
                x[b, 0, :, i],
                x_hat[b, 0, :, i],
                seg[b, 0, :, i],
                dilated_seg[b, 0, :, i],
                (x[b, 0, :, i] - x_hat[b, 0, :, i]).abs(),
            ])
            for i in indices
        ])  # (N, 5, W)
        # Order tiles by (tensor kind, index) so each grid row holds one tensor kind.
        tiles = tiles.permute(1, 0, 2).reshape(-1, 1, 1, W)  # (5*N, 1, 1, W)
        grid = make_grid(tiles, nrow=len(indices), normalize=True, pad_value=1)

        fig, ax = plt.subplots(figsize=(15, 6))
        # make_grid replicates 1-channel tiles into 3 identical channels (3, H, W);
        # imshow needs a 2-D array for a grayscale colormap, so take one channel.
        ax.imshow(grid[0], cmap='gray', aspect='auto')
        ax.set_title(f"item {b}: rows = x, x_hat, seg, dilated_seg, |x - x_hat|")
        ax.axis('off')
        plt.show()
        plt.close(fig)

    
def dilate_masks(masks, step):
    """Dilate each sample's mask with a square ``step x step`` kernel.

    For every batch index ``i``, channel 0 (``masks[i][0]``) is dilated with OpenCV.
    The result is written to all channels of sample ``i``. Samples whose mask sums
    to less than 1 are copied through unchanged.

    Args:
        masks: tensor with batch in dim 0 and channel in dim 1.
            NOTE(review): presumably (B, 1, H, W) or (B, 1, H, W, D) — confirm. For a
            5-D mask OpenCV treats the trailing axis as channels (limited to 512).
        step: side length in pixels of the square structuring element.

    Returns:
        Tensor with the same shape, dtype and device as ``masks``.

    Raises:
        ValueError: if ``step`` < 1. OpenCV would silently substitute a default
            3x3 kernel for an empty one.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    kernel = np.ones((step, step), np.uint8)

    dilated_masks = torch.zeros_like(masks)
    for i in range(masks.shape[0]):
        mask = masks[i][0].detach().cpu().numpy()
        if np.sum(mask) < 1:
            dilated_masks[i] = masks[i]
            continue
        # cv2.dilate rejects e.g. bool and int64 arrays; float32 is exact for label masks.
        if mask.dtype not in (np.uint8, np.uint16, np.int16, np.float32, np.float64):
            mask = mask.astype(np.float32)
        dilated = cv2.dilate(mask, kernel, iterations=1)
        # OpenCV drops a trailing singleton channel axis ((H, W, 1) -> (H, W)); restore it.
        dilated = dilated.reshape(mask.shape)
        dilated_mask = torch.from_numpy(dilated).to(device=masks.device, dtype=masks.dtype)
        dilated_masks[i] = dilated_mask.unsqueeze(dim=0)

    return dilated_masks
    
# Run the model on every test batch and plot its output next to the (dilated) masks.
# The model is not moved to `device` anywhere above, while the noise and `t` passed
# below are created on `device`, so place the model and its inputs there explicitly.
self.to(device)

with torch.no_grad():  # pure inference: don't build autograd graphs
    for batch in dataloader:
        self.dataset = batch['Dataset']
        input = batch['vol'][tio.DATA].to(device)
        data_orig = batch['vol_orig'][tio.DATA]
        # Batches without a segmentation fall back to an all-zero mask.
        data_seg = batch['seg_orig'][tio.DATA] if batch['seg_available'][0] else torch.zeros_like(data_orig)
        data_mask = batch['mask_orig'][tio.DATA]
        ID = batch['ID'][0]
        age = batch['age'][0]
        self.stage = batch['stage'][0]
        label = batch['label'][0]

        dilated_seg = dilate_masks(data_seg, dilation_size).to(device)
        _, x_hat = self(
            input,
            cond=None,
            t=torch.tensor([testing_timesteps], device=device),
            noise=gen_noise(noise_type, input.shape).to(device),
            occlusion_mask=dilated_seg,
        )
        # Plot from CPU copies so every tensor handed to the plot shares one device.
        visualize_predictions(x_hat.detach().cpu(), input.cpu(), data_seg, dilated_seg.cpu())

