# T1 + T2 MRI → Multi-tracer PET Synthesis

## 1) Imports & Reproducibility

In [1]:
import os
import time
import random
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import torchio as tio
from monai.utils import set_determinism
from torch.cuda.amp import GradScaler, autocast
from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure
from tqdm import tqdm

from generative.inferers import DiffusionInferer
from generative.losses import RelativisticPatchAdversarialLoss
from generative.networks.nets import DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

  from .autonotebook import tqdm as notebook_tqdm
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
# Reproducibility: fix the seed of every RNG source this notebook draws from
SEED = 42
set_determinism(SEED)
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)

# Prefer the GPU whenever PyTorch can see one
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cuda


## 2) Data paths & mask

In [3]:
from pathlib import Path
import os
import random

import nibabel as nib
import torch
import torchio as tio

# -------------------------
# Data loading (TorchIO) — matches the original notebook logic
# -------------------------
# Defaults are the original author's local paths; set NFLLONG_DATA_ROOT / NFLLONG_MASK_PATH
# in the environment to run on another machine without editing the notebook.
DATA_ROOT = Path(os.environ.get("NFLLONG_DATA_ROOT", "F:/Dataset/NFLLONG new normalised/all modalities"))
MASK_PATH = Path(os.environ.get("NFLLONG_MASK_PATH", "F:/Dataset/NFLLONG new normalised/resized_mask_181_217_181.nii.gz"))

# Load brain mask as a numpy array via get_fdata (float); presumably (X, Y, Z) = 181x217x181
# per the file name — TODO confirm. It is multiplied into every volume when samples are built.
mask_img = nib.load(str(MASK_PATH))
mask = mask_img.get_fdata()


## 3) Build TorchIO Subjects (subject-level split; tracer-conditioned targets)

In [4]:
# List subject folders. Sorted so the seeded held-out split below picks the same
# subjects regardless of the order the filesystem returns directory entries in.
sbj_ids = sorted(os.listdir(DATA_ROOT))

# 1) Filter subjects with all required modalities (exactly one file per pattern)
REQUIRED_GLOBS = {
    "source_1": "*_T1_CS.nii",     # T1
    "source_2": "w*.nii",          # auxiliary MRI (e.g., T2-FLAIR)
    "modality_2": "*_PBR_CS.nii",  # PBR
    "modality_3": "*_PIB_CS.nii",  # PIB
    "modality_4": "*_TAU_CS.nii",  # TAU
}

subjects_with_complete_modality = []
complete_sbj_ids = []  # folder name of each entry of subjects_with_complete_modality, same order
for sbj_id in sbj_ids:
    base_path = DATA_ROOT / sbj_id
    matches = {key: list(base_path.glob(pattern)) for key, pattern in REQUIRED_GLOBS.items()}
    if all(len(found) == 1 for found in matches.values()):
        subjects_with_complete_modality.append(
            tio.Subject(**{key: tio.ScalarImage(found[0]) for key, found in matches.items()})
        )
        complete_sbj_ids.append(sbj_id)

dataset_with_complete_modality = tio.SubjectsDataset(subjects_with_complete_modality)
print("Complete-modality dataset size:", len(dataset_with_complete_modality), "subjects")

# 2) Choose a held-out subset of physical subjects (optional).
# A dedicated SPLIT_SEED avoids silently overwriting the global SEED set at the top.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)
val_size = 10
indexes = random.sample(range(len(complete_sbj_ids)), min(val_size, len(complete_sbj_ids)))
# Held-out IDs are folder names, i.e. exactly what the per-tracer loop compares against
# (no image loading needed, unlike reading each subject's file stem).
heldout_sbjs = {complete_sbj_ids[i] for i in indexes}

# 3) Build per-tracer samples (each subject contributes up to 3 samples: PBR/PIB/TAU)
subjects_train, subjects_sample = [], []

def append_subject(sbj_id, src_path, tgt_path, t2f_path, label, is_sample):
    """Create one tracer-specific masked sample and file it under train or sampling.

    Parameters
    ----------
    sbj_id : subject folder name (accepted for the caller's convenience; unused here).
    src_path : path of the T1 image, stored as ``source_1``.
    tgt_path : path of the target PET image, stored as ``target_modality``.
    t2f_path : path of the auxiliary MRI, stored as ``source_2``.
    label : integer tracer class, stored on the subject as ``label``.
    is_sample : when True the subject goes to ``subjects_sample``, else ``subjects_train``.

    Each image is multiplied by the global ``mask`` so background voxels become zero.
    NOTE(review): reading ``["data"]`` here presumably loads every volume into memory
    at build time rather than lazily — TODO confirm memory footprint.
    """
    subject = tio.Subject(
        source_1=tio.ScalarImage(src_path),
        source_2=tio.ScalarImage(t2f_path),
        target_modality=tio.ScalarImage(tgt_path),
        label=label,
    )

    # Zero the background of all three images with the same brain mask
    for image_key in ("source_1", "source_2", "target_modality"):
        subject[image_key]["data"] *= mask

    destination = subjects_sample if is_sample else subjects_train
    destination.append(subject)

def _first_match(folder, pattern):
    """Return the alphabetically first path in ``folder`` matching ``pattern``, or None."""
    found = sorted(folder.glob(pattern))
    return found[0] if found else None


# Position in this list is the class label: 0=PBR, 1=PIB, 2=TAU
TRACER_GLOBS = ["*_PBR_CS.nii", "*_PIB_CS.nii", "*_TAU_CS.nii"]

for sbj_id in sbj_ids:
    base_path = DATA_ROOT / sbj_id
    t1 = _first_match(base_path, "*_T1_CS.nii")
    # Same pattern as the complete-modality filter above, so both steps agree on the T2F file
    t2f = _first_match(base_path, "w*.nii")
    if t1 is None or t2f is None:
        # Subject lacks one of the MRI inputs; skip it explicitly instead of catching
        # IndexError (which would also hide errors raised inside append_subject).
        continue
    is_sample = sbj_id in heldout_sbjs
    for label, target_glob in enumerate(TRACER_GLOBS):
        target = _first_match(base_path, target_glob)
        if target is not None:
            append_subject(sbj_id, t1, target, t2f, label, is_sample)

dataset_train = tio.SubjectsDataset(subjects_train)
dataset_sample = tio.SubjectsDataset(subjects_sample)
print("Train samples:", len(dataset_train))
print("Sampling samples:", len(dataset_sample))

Complete-modality dataset size: 80 subjects
Train samples: 341
Sampling samples: 30


## 4) Preprocessing transforms & dataloaders

In [5]:
# Rescale intensities to [-1, 1], then crop. For 181x217x181 inputs (the mask size) the crop
# (x: 11/10, y: 20/17, z: 0/21 voxels) already yields 160x180x160, so Resize mainly
# guarantees that shape for any differently-sized input.
transform = tio.Compose([
    tio.RescaleIntensity(out_min_max=(-1, 1)),
    tio.Crop([11, 10, 20, 17, 0, 21]),
    tio.Resize((160, 180, 160)),
])

# Wrap the Subject lists directly. Wrapping dataset_train / dataset_sample (themselves
# SubjectsDatasets) would send every item through two SubjectsDataset.__getitem__ calls.
training_set = tio.SubjectsDataset(subjects_train, transform=transform)
sampling_set = tio.SubjectsDataset(subjects_sample, transform=transform)

train_loader = torch.utils.data.DataLoader(training_set, batch_size=1, shuffle=True, num_workers=0)
sample_loader = torch.utils.data.DataLoader(sampling_set, batch_size=1, shuffle=False, num_workers=0)

## 5) Model, discriminator, scheduler, inferer

In [6]:
# 3D conditional diffusion UNet. in_channels=3 presumably = 1 noisy PET channel + the two
# concatenated MRI condition channels — TODO confirm against the custom inferer.
# NOTE(review): num_class_embeds=4 although only labels 0..2 are produced above.
unet_config = dict(
    spatial_dims=3,
    in_channels=3,
    out_channels=1,
    num_channels=[16, 32, 64],
    attention_levels=[False, False, True],
    num_head_channels=[0, 0, 64],
    num_res_blocks=2,
    norm_num_groups=8,
    use_flash_attention=True,
    with_conditioning=True,
    cross_attention_dim=64,
    num_class_embeds=4,
)
model = DiffusionModelUNet(**unet_config).to(device)

# Patch discriminator on single-channel volumes (built after the UNet to keep the RNG order)
discriminator_config = dict(spatial_dims=3, num_layers_d=2, num_channels=4, in_channels=1, out_channels=1)
discriminator = PatchDiscriminator(**discriminator_config).to(device)

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    schedule='scaled_linear_beta',
    beta_start=5e-4,
    beta_end=1.95e-2,
)
inferer = DiffusionInferer(scheduler)

adv_loss = RelativisticPatchAdversarialLoss(discriminator)

# The discriminator learns 10x slower than the generator
generator_lr, discriminator_lr = 5e-5, 5e-6
optimizer = torch.optim.Adam(model.parameters(), lr=generator_lr)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=discriminator_lr)

## 6) Checkpoint helpers

In [7]:
def save_checkpoint(model, save_path, epoch):
    """Persist ``model``'s state dict together with ``epoch`` using ``torch.save``.

    Any missing parent directories of ``save_path`` are created first. The written
    object is a dict with the keys ``'model_state_dict'`` and ``'epoch'``.
    """
    destination = Path(save_path)
    destination.parent.mkdir(exist_ok=True, parents=True)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, destination)

def load_checkpoint(model, load_path):
    """Restore weights written by ``save_checkpoint`` into ``model`` in place.

    Tensors are mapped to CPU on load; ``load_state_dict`` then copies them into the
    model's existing parameters. Returns the stored ``'epoch'`` value, or 0 when the
    checkpoint has no ``'epoch'`` entry.
    """
    state = torch.load(load_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    if 'epoch' not in state:
        return 0
    return state['epoch']

## 7) Training loop (DDPM noise loss + x0 L1; adversarial after warm-up)

In [8]:
n_epochs = 2  # example run: 1 epoch of warm-up and 1 epoch of adversarial learning
autoencoder_warm_up_n_epochs = 1  # epochs trained with noise + x0 losses only (no adversarial term)
adv_weight = 0.1  # multiplier on both the generator's and the discriminator's adversarial loss

scaler = GradScaler()
total_start = time.time()
for epoch_ in range(n_epochs):
    epoch = epoch_ + 1
    use_adversarial = epoch > autoencoder_warm_up_n_epochs
    model.train()
    epoch_loss = 0
    epoch_noise_loss = 0
    epoch_x0_pred_loss = 0
    gen_epoch_loss = 0
    disc_epoch_loss = 0
    epoch_psnr = 0
    epoch_ssim = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=160)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        source_1 = batch['source_1']['data'].to(device)
        source_2 = batch['source_2']['data'].to(device)
        images = batch['target_modality']['data'].to(device)
        image_class = batch["label"].to(device)

        # The two MRI inputs, concatenated along the channel axis, form the conditioning signal
        condition = torch.cat((source_1, source_2), 1)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # randn_like already allocates on the same device as `images`
            noise = torch.randn_like(images)
            # Buffer for the per-sample x0 estimates; every slot is overwritten in the loop below
            x0_pred = torch.randn_like(images)

            # One random diffusion timestep per batch element
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
            ).long()

            # Model's noise prediction for the noised target, conditioned on MRI and tracer label
            noise_pred = inferer(inputs=images, diffusion_model=model, label=image_class, noise=noise, timesteps=timesteps, condition=condition)
            noised_image = scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

            # scheduler.step takes one timestep at a time, so it is called per sample;
            # its second return value is used as the x0 estimate
            for n in range(len(noise_pred)):
                _, x0_pred[n] = scheduler.step(torch.unsqueeze(noise_pred[n, :, :, :, :], 0), timesteps[n], torch.unsqueeze(noised_image[n, :, :, :, :], 0))
                # NOTE(review): labels built in this notebook are 0..2 (three tracer globs), so this
                # branch is currently unreachable; it only applies if a 4th class is introduced
                if image_class[n] == 3:
                    x0_pred[n] = x0_pred[n] * 0.1
                    images[n] = images[n] * 0.1

            noise_loss = F.mse_loss(noise_pred.float(), noise.float())
            x0_pred_loss = F.l1_loss(x0_pred, images)
            loss = noise_loss + x0_pred_loss

            if use_adversarial:
                logits_real = discriminator(images.contiguous().float())[-1]
                logits_fake = discriminator(x0_pred.contiguous().float())[-1]
                generator_loss = adv_loss(logits_real, logits_fake, images, x0_pred, for_discriminator=False)
                loss += adv_weight * generator_loss

            PSNR_mid = peak_signal_noise_ratio(x0_pred, images)
            SSIM_mid = structural_similarity_index_measure(x0_pred, images)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if use_adversarial:
            # Discriminator part. Everything it sees is detached: the generator graph was
            # already freed by the backward pass above, and this update must not reach the UNet.
            optimizer_d.zero_grad(set_to_none=True)
            x0_pred_d = x0_pred.contiguous().detach()
            images_d = images.contiguous().detach()
            logits_fake = discriminator(x0_pred_d)[-1]
            logits_real = discriminator(images_d)[-1]
            discriminator_loss = adv_loss(logits_real, logits_fake, images_d, x0_pred_d, for_discriminator=True)
            loss_d = adv_weight * discriminator_loss

            loss_d.backward()
            optimizer_d.step()

        epoch_loss += loss.item()
        epoch_noise_loss += noise_loss.item()
        epoch_x0_pred_loss += x0_pred_loss.item()
        if use_adversarial:
            gen_epoch_loss += generator_loss.item()
            disc_epoch_loss += discriminator_loss.item()
        epoch_psnr += PSNR_mid.item()
        epoch_ssim += SSIM_mid.item()
        progress_bar.set_postfix({"loss": epoch_loss / (step + 1),
                                  "noise_loss": epoch_noise_loss / (step + 1),
                                  "x0_pred_loss": epoch_x0_pred_loss / (step + 1),
                                  "gen_loss": gen_epoch_loss / (step + 1),
                                  "disc_loss": disc_epoch_loss / (step + 1),
                                  "PSNR": epoch_psnr / (step + 1),
                                  "SSIM": epoch_ssim / (step + 1)})

    save_path = 'checkpoints/epoch' + str(epoch) + '_checkpoint.pt'
    save_checkpoint(model, save_path, epoch)

    save_path_discriminator = 'discriminator_checkpoints/epoch' + str(epoch) + '_checkpoint.pt'
    save_checkpoint(discriminator, save_path_discriminator, epoch)

total_time = time.time() - total_start
print(f"Training completed in {total_time / 60:.1f} min.")

Epoch 1: 100%|██████████████| 341/341 [15:21<00:00,  2.70s/it, loss=1.04, noise_loss=0.635, x0_pred_loss=0.409, gen_loss=0, disc_loss=0, PSNR=12.3, SSIM=0.0967]
Epoch 2: 100%|██████| 341/341 [16:07<00:00,  2.84s/it, loss=0.374, noise_loss=0.126, x0_pred_loss=0.177, gen_loss=0.713, disc_loss=0.933, PSNR=18.1, SSIM=0.245]


Training completed.


## 8) Sampling / inference

In [9]:
# Sampling / inference + saving synthesized volumes

from pathlib import Path
from torch.cuda.amp import autocast
from torchmetrics.functional.image import peak_signal_noise_ratio, structural_similarity_index_measure

# NOTE(review): epoch 99 refers to a previously trained checkpoint; the example training
# run above only writes epochs 1..n_epochs, so adjust this on a fresh run.
epoch_to_load = 99
ckpt_path = f"checkpoints/epoch{epoch_to_load}_checkpoint.pt"
epoch_loaded = load_checkpoint(model, ckpt_path)
model.eval()
print(f"Loaded checkpoint from epoch {epoch_loaded}")

# Index -> tracer name; must follow the label order used when the per-tracer samples
# were built (0=PBR, 1=PIB, 2=TAU). Names become directory names, so no glob characters.
tracer = ["PBR", "PIB", "TAU"]

n_sample_batches = 1  # run one example; set to len(sample_loader) to synthesize every sample
n_mc_seeds = 1  # repeated samplings per input for MC sampling

scheduler.set_timesteps(num_inference_steps=1000)

saved_dirs = set()
for step, batch in enumerate(sample_loader):
    if step >= n_sample_batches:
        break  # stop instead of loading the remaining batches for nothing
    bsz = len(batch['source_1']['path'])

    source_1 = batch['source_1']['data'].to(device)
    source_2 = batch['source_2']['data'].to(device)
    condition = torch.cat((source_1, source_2), dim=1)

    ground_truth = batch['target_modality']['data'].to(device)
    image_class = batch['label'].to(device)

    for seed in range(n_mc_seeds):
        SEED = seed
        torch.manual_seed(SEED)

        input_noise = torch.randn((bsz, 1, 160, 180, 160), device=device)

        with torch.no_grad():
            with autocast(enabled=True):
                pred_PET = inferer.sample(
                    input_noise=input_noise,
                    diffusion_model=model,
                    label=image_class,
                    scheduler=scheduler,
                    save_intermediates=False,
                    intermediate_steps=100,
                    conditioning=condition,
                )

        # Save one synthesized volume per subject per seed
        for i in range(bsz):
            subject_id = str(batch['target_modality']['stem'][i])

            # Score and save the synthesized PET (not the input noise), in float32
            preds = pred_PET[i].detach().float().cpu()
            target = ground_truth[i].detach().float().cpu()

            # Restore the batch axis so torchmetrics treats each volume as one 3D image
            # (a 4D tensor would be read as a 2D batch with 160 channels), matching training
            PSNR = peak_signal_noise_ratio(preds.unsqueeze(0), target.unsqueeze(0))
            SSIM = structural_similarity_index_measure(preds.unsqueeze(0), target.unsqueeze(0))
            print(f"{subject_id} seed{SEED}: PSNR={PSNR.item():.3f} SSIM={SSIM.item():.4f}")

            # Save synthesized 3D PET volume
            out_dir = Path(f"synth_{tracer[int(image_class[i].item())]}/")
            out_dir.mkdir(parents=True, exist_ok=True)
            out_pt = out_dir / f"syn_{subject_id}_seed{SEED}.pt"
            torch.save(preds, out_pt)
            saved_dirs.add(str(out_dir))

print(f"Saved synthesized outputs to: {', '.join(sorted(saved_dirs)) or '(nothing saved)'}")

100%|██████████| 1000/1000 [09:55<00:00,  1.68it/s]


Saved synthesized outputs to: synth_PBR
