<a href="https://colab.research.google.com/github/lbinding/AT-AT/blob/main/train_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Install
!pip install monai==1.4.0
!pip install pandas==2.0.3
!pip install torchio==0.20.4



In [None]:
#Libs
import wandb
from pathlib import Path
import torchio as tio
from torch.utils.data import Dataset
import torch
import nibabel as nib
import os
import numpy as np
import random
from monai.bundle import ConfigParser
from monai.losses import PerceptualLoss
from monai.losses.adversarial_loss import PatchAdversarialLoss
from torch.amp import autocast
import wandb
from torch.nn import MSELoss
from monai.transforms import DivisiblePad
import os
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torch.optim as optim
import pandas as pd

In [None]:
# Mount Google Drive, then point `base_dir` at the "MyDrive" folder inside the mount.
from google.colab import drive

drive.mount('/content/drive')
base_dir = '/content/drive/MyDrive'

Mounted at /content/drive


In [None]:
source_folder_path_in_drive = os.path.join(base_dir, "training_data") # <--- CHANGE THIS PATH

# The destination path in the local Colab runtime.
# '/content/' is a common place for temporary files in Colab.
# You can give your copied folder a new name here if you want.
destination_folder_path_local = '/content/' # <--- OPTIONAL: Change the name of the copied folder

print(f"Source folder in Drive: '{source_folder_path_in_drive}'")
print(f"Destination in local Colab runtime: '{destination_folder_path_local}'")
print("-" * 50)

# --- Step 3: Check if the source folder exists ---
if not os.path.exists(source_folder_path_in_drive):
    print(f"Error: The source folder '{source_folder_path_in_drive}' does not exist in your Google Drive.")
    print("Please double-check the 'source_folder_path_in_drive' variable.")
else:
    # --- Step 4: Create the destination directory if it doesn't exist ---
    # This is good practice to ensure the target path is ready.
    os.makedirs(destination_folder_path_local, exist_ok=True)
    print(f"Ensured destination directory exists: '{destination_folder_path_local}'")

    # --- Step 5: Copy the folder using the 'cp' command ---
    # `cp -r` recursively copies directories and their contents.
    print(f"Copying folder from Drive to local Colab runtime...")
    !cp -r "$source_folder_path_in_drive" "$destination_folder_path_local"

    print("Copy process complete!")
    print("-" * 50)

    # --- Step 6: Verify the copy (optional) ---
    print(f"Listing contents of the copied folder in local Colab runspace:")
    !ls -l "$destination_folder_path_local"

    print("\nYou can now work with the copied folder at:")
    print(destination_folder_path_local)
    print("Files accessed from this location will typically be faster than directly from Google Drive.")



Source folder in Drive: '/content/drive/MyDrive/training_data'
Destination in local Colab runtime: '/content/'
--------------------------------------------------
Ensured destination directory exists: '/content/'
Copying folder from Drive to local Colab runtime...
Copy process complete!
--------------------------------------------------
Listing contents of the copied folder in local Colab runspace:
total 12
drwx------ 5 root root 4096 Jul  2 15:24 drive
drwxr-xr-x 1 root root 4096 Jun 26 13:35 sample_data
drwx------ 3 root root 4096 Jul  2 15:24 training_data

You can now work with the copied folder at:
/content/
Files accessed from this location will typically be faster than directly from Google Drive.


# SAVE MODEL


In [None]:
def save_model(vAE=None, LDM=None, discrim=None, model_dir=None, epoch=None):
    """
    Save the state dict of each provided model under ``<model_dir>/models``.

    Every model that is not ``None`` is written to its own file whose name
    carries the epoch number. The ``models`` sub-directory is created when it
    does not exist yet.

    Args:
        vAE (torch.nn.Module, optional): Autoencoder to save
            (``trained_vAE_epoch_<epoch>.pt``).
        LDM (torch.nn.Module, optional): Latent diffusion model to save
            (``trained_LDM_epoch_<epoch>.pt``).
        discrim (torch.nn.Module, optional): Discriminator to save
            (``trained_discriminator_epoch_<epoch>.pt``).
        model_dir (str): Directory that contains (or will contain) ``models``.
        epoch (int): Current epoch number, used in the file names.

    Raises:
        ValueError: If ``model_dir`` or ``epoch`` is not provided.
    """
    if model_dir is None or epoch is None:
        raise ValueError("Both 'model_dir' and 'epoch' must be provided.")

    # torch.save does not create parent directories, so make sure the target exists.
    checkpoint_dir = os.path.join(model_dir, "models")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # (model, file-name stem, label used in the confirmation message)
    targets = (
        (vAE, "trained_vAE_epoch", "vAE model"),
        (LDM, "trained_LDM_epoch", "LDM model"),
        (discrim, "trained_discriminator_epoch", "Discriminator"),
    )
    for model, stem, label in targets:
        if model is None:
            continue
        save_path = os.path.join(checkpoint_dir, f"{stem}_{epoch}.pt")
        torch.save(model.state_dict(), save_path)
        print(f"{label} saved at {save_path}")



# DATASET


In [None]:
#%% Dataset

class FDG_Dataset(Dataset):
    """
    A PyTorch Dataset that loads NIfTI images listed in a table and returns
    each one together with its encoded diagnosis.
    """
    def __init__(self, data, transform=None):
        """
        Args:
            data: Table-like object (e.g. a pandas DataFrame) providing the
                columns 'Linked_Files_Anon' (paths to the NIfTI files) and
                'DX_encoded' (diagnosis label per file).
            transform (callable, optional): Transform applied to the
                ``tio.Subject``. When omitted, the image is only rescaled
                to the [0, 1] intensity range.
        """
        # Store the columns as plain lists so __getitem__ indexes by position.
        # Indexing a pandas Series with an int is label-based and fails (or
        # returns the wrong row) when the index is not 0..n-1.
        self.T1_paths = list(data['Linked_Files_Anon'])
        self.diagnosis = list(data['DX_encoded'])
        self.transform = transform
        self.rescale = tio.RescaleIntensity((0, 1))

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.T1_paths)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int or tensor): Position of the sample.

        Returns:
            tuple: ``(image_tensor, diag)`` where ``image_tensor`` is the
            transformed image with the dummy trailing axis removed and
            ``diag`` is the matching 'DX_encoded' entry.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Look up path and label for this position
        image_path = self.T1_paths[idx]
        diag = self.diagnosis[idx]

        # Read the image data from disk
        image_array = nib.load(image_path).get_fdata()

        # Wrap in a Subject: prepend a channel axis and append a singleton
        # depth axis (assumes the stored image is 2D - TODO confirm)
        subject = tio.Subject(image=tio.ScalarImage(tensor=torch.as_tensor(image_array[None, :, :, None])))

        # Apply augmentation when configured, otherwise just rescale intensities
        if self.transform:
            subject = self.transform(subject)
        else:
            subject = self.rescale(subject)

        # Extract the transformed tensor and drop the dummy depth axis
        image_tensor = subject['image']['data'].squeeze(-1)

        return image_tensor, diag

def create_datasets(df, augmentations, data_dir):
    """
    Build the train / validation / test datasets from a data-key table.

    Args:
        df (pandas.DataFrame): Table with the columns 'Linked_Files_Anon'
            (file names relative to ``data_dir``), 'DX_encoded' and 'Set'
            (one of 'Train', 'Validation', 'Test').
        augmentations (callable): Transform applied to the training set only.
        data_dir (str): Directory the file names are joined onto.

    Returns:
        tuple: ``(train_dataset, valid_dataset, test_dataset)``.
    """
    # Work on a new frame: rewriting the column on the caller's DataFrame would
    # make the call non-idempotent (paths prefixed again on every re-run).
    resolved = df.assign(
        Linked_Files_Anon=df['Linked_Files_Anon'].apply(lambda name: os.path.join(data_dir, name))
    )

    def _subset(split_name):
        # Rows of one split, re-indexed from 0.
        return resolved[resolved['Set'] == split_name].reset_index(drop=True)

    train_dataset = FDG_Dataset(data=_subset('Train'), transform=augmentations)
    valid_dataset = FDG_Dataset(data=_subset('Validation'))
    test_dataset = FDG_Dataset(data=_subset('Test'))

    return train_dataset, valid_dataset, test_dataset

# LOAD

In [None]:

def load_KL_autoencoder(weights_path, config_file="train_autoencoder.json", device=None):
    """
    Build the autoencoder described by a config file and load a checkpoint into it.

    The network is instantiated from the "gnetwork" entry of
    ``<base_dir>/configs/<config_file>`` (``base_dir`` is the notebook-level
    global). Checkpoint keys that match the rename table below are mapped to
    their new names before loading; all other keys are passed through.

    Args:
        weights_path (str): Path to the checkpoint (a flat state dict).
        config_file (str): File name of the config inside ``<base_dir>/configs``.
        device (torch.device, optional): Target device. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        The model, moved to ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup paths
    config_path = os.path.join(base_dir, "configs", config_file)

    print(config_path)

    # Read config
    config = ConfigParser()
    config.read_config(str(config_path))

    # Parse model
    model = config.get_parsed_content("gnetwork")

    # Load checkpoint
    checkpoint = torch.load(weights_path, map_location=device)

    # Key remapping: attention projections move under ".attn", and the
    # up-sampling convolutions are renamed from ".conv" to ".postconv".
    key_mapping = {}
    attention_renames = {
        "to_q": "attn.to_q",
        "to_k": "attn.to_k",
        "to_v": "attn.to_v",
        "proj_attn": "attn.out_proj",
    }
    for block in ("encoder.blocks.10", "decoder.blocks.2"):
        for old_name, new_name in attention_renames.items():
            for suffix in ("weight", "bias"):
                key_mapping[f"{block}.{old_name}.{suffix}"] = f"{block}.{new_name}.{suffix}"
    for block in ("decoder.blocks.6", "decoder.blocks.9"):
        for suffix in ("weight", "bias"):
            key_mapping[f"{block}.conv.conv.{suffix}"] = f"{block}.postconv.conv.{suffix}"

    # Remap keys
    new_state_dict = {key_mapping.get(k, k): v for k, v in checkpoint.items()}

    # Load state. strict=False tolerates mismatches, so report them explicitly:
    # otherwise a failed remap silently leaves layers at their random initialisation.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", []) or [])
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []) or [])
    if missing_keys:
        print(f"Warning: {len(missing_keys)} model keys not found in checkpoint: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} checkpoint keys not used by the model: {unexpected_keys}")

    model.to(device)

    return model

# TORCHIO

In [None]:
#%% Define the Augmentations class
class Augmentations:
    """
    Random augmentation pipeline for a ``tio.Subject``.

    Every call applies a random affine and a random elastic deformation,
    followed by a randomly chosen set of additional transforms (blur,
    anisotropy, noise, bias field, motion), and finally rescales the
    intensities to [0, 1].
    """

    def __init__(self):
        # Transforms whose configuration does not change between calls
        self.random_anisotropy = tio.RandomAnisotropy(axes=(0, 1))
        self.random_affine = tio.RandomAffine()
        self.add_motion = tio.RandomMotion(num_transforms=1, image_interpolation='nearest')
        self.rescale = tio.RescaleIntensity((0, 1))

    # --- individual transformations -------------------------------------
    def _blur(self, subject):
        # Blur with the sigma torchio derives for a random down-sampling
        # factor of 2 or 3 at a spacing of 1.
        factor = random.randint(2, 3)
        sigma = tio.Resample.get_sigma(factor, 1)
        return tio.Blur(sigma)(subject)

    def _anisotropy(self, subject):
        return self.random_anisotropy(subject)

    def _affine(self, subject):
        return self.random_affine(subject)

    def _elastic(self, subject):
        # Draw the displacement first, then the control-point count.
        displacement = random.randint(1, 5)
        control_points = random.randint(5, 15)
        warp = tio.RandomElasticDeformation(
            max_displacement=displacement,
            num_control_points=control_points,
        )
        return warp(subject)

    def _noise(self, subject):
        return tio.RandomNoise(std=(np.random.rand() / 4))(subject)

    def _bias_field(self, subject):
        return tio.RandomBiasField(coefficients=(np.random.rand() / 2))(subject)

    def _motion(self, subject):
        return self.add_motion(subject)

    def __call__(self, subject):
        """Apply the randomly assembled augmentation chain to ``subject``."""
        level = random.randint(1, 3)

        # Pools the optional transforms are drawn from
        any_pool = [self._blur, self._anisotropy, self._noise, self._bias_field, self._motion]
        first_pool = [self._noise, self._bias_field, self._anisotropy]
        second_pool = [self._motion, self._blur]

        # Level 1: one transform from the full pool.
        # Level 2 / 3: one / two transforms from each of the two smaller pools.
        if level == 0:
            chosen = []
        elif level == 1:
            chosen = random.sample(any_pool, 1)
        else:
            per_pool = level - 1
            chosen = random.sample(first_pool, per_pool) + random.sample(second_pool, per_pool)

        # Geometry first, then the chosen extras, then intensity rescaling.
        for step in [self._affine, self._elastic, *chosen, self.rescale]:
            subject = step(subject)

        return subject


# LOSSES

In [None]:
#%% Paths and Config
# Module-level configuration shared by the loss helpers defined below.
# NOTE(review): `weights_path` is not referenced anywhere else in the visible
# notebook - load_discriminator() builds the network from the config without
# loading these weights. Confirm whether pretrained weights were meant to be loaded.
weights_path = os.path.join(base_dir, "models", "model_discriminator.pt")
# Parsed training config; load_discriminator() reads its "dnetwork" entry.
config = ConfigParser()
config.read_config(os.path.join(base_dir,"configs","train_autoencoder.json"))

#%% Weights for each component
# Multipliers applied inside generator_loss(); the reconstruction term has an implicit weight of 1.
adv_weight = 0.5  # scales the adversarial term of the generator loss and the whole discriminator loss
perceptual_weight = 1.0  # scales the perceptual term
kl_weight = 1e-6  # KL regularization weight

#%% Loss components
intensity_loss = torch.nn.L1Loss()  # reconstruction (L1) loss
adv_loss = PatchAdversarialLoss(criterion="least_squares")  # adversarial loss with least-squares criterion

#%% KL divergence loss
def compute_kl_loss(z_mu, z_sigma):
    """
    KL-divergence term computed from the latent mean and standard deviation.

    Sums ``mu^2 + sigma^2 - log(sigma^2) - 1`` over every dimension except
    the first, halves it, and averages the result over the first dimension.
    """
    variance = z_sigma.pow(2)
    elementwise = z_mu.pow(2) + variance - torch.log(variance) - 1
    reduce_dims = list(range(1, len(z_sigma.shape)))
    per_sample = 0.5 * torch.sum(elementwise, dim=reduce_dims)
    return torch.mean(per_sample)

#%% Perceptual loss (ResNet50)
def load_perceptual_loss(device):
    """Create the 2D perceptual loss (pretrained "resnet50" network) and move it to ``device``."""
    settings = {
        "spatial_dims": 2,
        "network_type": "resnet50",
        "pretrained": True,
    }
    criterion = PerceptualLoss(**settings)
    criterion.to(device)
    return criterion

#%% Discriminator
def load_discriminator(device):
    """Instantiate the "dnetwork" entry of the module-level ``config`` and move it to ``device``."""
    net = config.get_parsed_content("dnetwork")
    net.to(device)
    return net

#%% Generator loss function
def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual, device):
    """
    Compute the generator (autoencoder) loss and the discriminator loss for one batch.

    Generator loss = L1 reconstruction + ``kl_weight`` * KL
    + ``perceptual_weight`` * perceptual + ``adv_weight`` * adversarial.
    Discriminator loss = ``adv_weight`` * mean of the fake and real
    adversarial losses. Every component is also logged to wandb.

    Args:
        gen_images: Reconstructions produced by the autoencoder.
        real_images: The corresponding input images.
        z_mu, z_sigma: Latent mean and standard deviation from the encoder.
        disc_net: Discriminator; the last element of its output is used as logits.
        loss_perceptual: Callable perceptual loss ``(gen_images, real_images) -> loss``.
        device (torch.device): Device whose type selects the autocast backend.

    Returns:
        tuple: ``(loss_g, loss_d)``.
    """
    with autocast(device_type=device.type, enabled=True):

        recons_loss = intensity_loss(gen_images, real_images)
        wandb.log({"intensity loss": recons_loss}) # Consider logging epoch averages instead

        kl = compute_kl_loss(z_mu, z_sigma)
        wandb.log({"kl loss": kl})

        p_loss = loss_perceptual(gen_images, real_images)
        wandb.log({"perceptual loss": p_loss})

        # Base generator loss (reconstruction + KL + perceptual)
        loss_g = recons_loss + kl_weight * kl + perceptual_weight * p_loss
        wandb.log({"gen base loss": loss_g})

        # Adversarial component: the generator wants its output scored as real.
        logits_fake = disc_net(gen_images)[-1]
        gen_adv_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
        wandb.log({"adversarial loss": gen_adv_loss})

        # Out-of-place addition (an in-place += on the loss tensor breaks autograd here).
        loss_g = loss_g + adv_weight * gen_adv_loss
        wandb.log({"total generator loss": loss_g})

        # Discriminator branch. Detach the generated IMAGES and run the
        # discriminator on them again: this blocks gradients into the generator
        # while keeping the graph through the discriminator. Detaching the
        # logits instead would give the fake term no gradient at all, so the
        # discriminator would never learn to reject reconstructions.
        logits_fake_for_disc = disc_net(gen_images.contiguous().detach())[-1]
        d_loss_fake = adv_loss(logits_fake_for_disc, target_is_real=False, for_discriminator=True)
        logits_real = disc_net(real_images.contiguous().detach())[-1]
        d_loss_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
        loss_d = adv_weight * discriminator_loss
        wandb.log({"discriminator loss": loss_d})

    return loss_g, loss_d


# WANDB CONFIG

In [None]:

def load_wandb_config_vAE():
    """
    Log in to Weights & Biases, start a run and return its hyper-parameters.

    The API key is read from the ``WANDB_API_KEY`` environment variable rather
    than being stored in the notebook. When the variable is unset, ``None`` is
    passed to ``wandb.login`` so wandb uses its own credential lookup.

    NOTE: an API key was previously committed in this cell; that key should be
    revoked in the wandb account settings.

    Returns:
        dict: wandb config (learning rate, architecture, dataset, epochs, batch size).
    """
    # Never hard-code credentials in a shared notebook.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    wandb_config = {
        "learning_rate": 0.00005,
        "architecture": "Autoencoder_KL",
        "dataset": "FDG_2D_slices",
        "epochs": 100,
        "batch_size": 12,
        }

    wandb.init(
        # set the wandb project where this run will be logged
        project="LDM_FDG_vae",
        # track hyperparameters and run metadata
        config=wandb_config,
        mode="online"  # Ensure wandb is properly initialized
    )
    return wandb_config


# TRAIN AUTOENCODER

In [None]:
intensity_loss = torch.nn.L1Loss()
#%% Train vAE
def train_autoencoder(KL_autoencoder, discriminator, perceptual_loss, generator_loss, optimizer_g, optimizer_d, train_loader, val_loader, output_dir, device):
    """
    Adversarially train the autoencoder and checkpoint it on validation improvement.

    For each batch the generator (autoencoder) is updated first, then the
    discriminator. After every epoch the mean L1 reconstruction loss on
    ``val_loader`` is computed; when it improves on the best value so far, the
    autoencoder and the discriminator are saved via ``save_model``.

    Args:
        KL_autoencoder: Model returning ``(reconstruction, z_mu, z_sigma)``.
        discriminator: Discriminator network passed through to ``generator_loss``.
        perceptual_loss: Perceptual loss passed through to ``generator_loss``.
        generator_loss: Callable ``(recon, real, z_mu, z_sigma, discriminator,
            perceptual_loss, device) -> (gen_loss, disc_loss)``.
        optimizer_g: Optimizer of the autoencoder parameters.
        optimizer_d: Optimizer of the discriminator parameters.
        train_loader: Iterable of ``(images, label)`` training batches.
        val_loader: Sized iterable of ``(images, label)`` validation batches.
        output_dir (str): Directory under which ``models/`` checkpoints are written.
        device: Device the batches are moved to.

    Returns:
        The trained autoencoder (left in training mode).
    """
    # NOTE(review): this starts a wandb run; the main cell also calls
    # load_wandb_config_vAE() before training, so two runs may be created.
    wandb_config = load_wandb_config_vAE()

    # autocast needs the device *type*; derive it from `device` instead of
    # assuming CUDA so the loop also works on a CPU-only runtime.
    device_type = torch.device(device).type

    # save_model() writes into <output_dir>/models; make sure it exists.
    os.makedirs(os.path.join(output_dir, "models"), exist_ok=True)

    KL_autoencoder.train()

    best_loss = np.inf

    for epoch in range(wandb_config["epochs"]):
        print("EPOCH:", epoch)
        wandb.log({"epoch": epoch})
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_train_batches = 0

        for data_augmented, _ in train_loader:
            data_augmented = data_augmented.to(device).float()

            optimizer_g.zero_grad()
            with autocast(device_type=device_type, enabled=True):
                recon, z_mu, z_sigma = KL_autoencoder(data_augmented)
                gen_loss, disc_loss = generator_loss(recon, data_augmented, z_mu, z_sigma, discriminator, perceptual_loss, device)

            # Generator update. retain_graph is kept in case a supplied
            # generator_loss builds disc_loss on top of the generator graph.
            gen_loss.backward(retain_graph=True)
            optimizer_g.step()

            # Discriminator update. The generator backward pass also deposits
            # gradients of the adversarial term on the discriminator weights;
            # clear them here so only disc_loss drives the discriminator step.
            optimizer_d.zero_grad()
            disc_loss.backward()
            optimizer_d.step()

            total_g_loss += gen_loss.item()
            total_d_loss += disc_loss.item()
            num_train_batches += 1

        # Validation: reconstruction loss only, in eval mode and without gradients.
        KL_autoencoder.eval()
        valid_loss = 0
        with torch.no_grad():
            for data, _ in val_loader:
                with autocast(device_type=device_type, enabled=True):
                    data = data.to(device).float()
                    recon, _, _ = KL_autoencoder(data)

                    recon_loss = intensity_loss(data, recon)

                    valid_loss += recon_loss.item() # Accumulates only the Python number
        KL_autoencoder.train()

        epoch_valid_loss = valid_loss / len(val_loader)
        epoch_metrics = {'valid_loss': epoch_valid_loss}
        if num_train_batches:
            # Epoch means of the per-batch training losses.
            epoch_metrics['train_gen_loss'] = total_g_loss / num_train_batches
            epoch_metrics['train_disc_loss'] = total_d_loss / num_train_batches
        wandb.log(epoch_metrics)

        # Checkpoint only when the validation loss improves.
        if epoch_valid_loss < best_loss:
            best_loss = epoch_valid_loss

            save_model(vAE=KL_autoencoder, model_dir=output_dir, epoch=epoch)
            save_model(discrim=discriminator, model_dir=output_dir, epoch=epoch)


    return KL_autoencoder


# MAIN CODE

In [None]:
#%% Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#%% Setup paths
#Define directories
data_dir = os.path.join('/content/', "training_data")
slices_dir = os.path.join(data_dir, 'slices_40_new_anon')

#%% Load CSV
data_key = pd.read_csv(os.path.join(data_dir, 'data_key_new_anon.csv'))
#data_key = data_key[data_key['Linked_Files_Anon'].apply(lambda x: os.path.exists(os.path.join(slices_dir, x)))]
data_key = data_key.reset_index(drop=True)

#%% Input arguments (Sets training to true)
# Set default values:
autoencoder_weights         = os.path.join(base_dir, 'models', 'model_autoencoder.pt')

#%% Load vAE wandb config & Setup output directories
# NOTE(review): train_autoencoder() calls load_wandb_config_vAE() again, so
# wandb.init runs twice per execution of this cell.
wandb_config_vAE = load_wandb_config_vAE()
#Create output directory
output_dir       = os.path.join(base_dir, "output", wandb_config_vAE['architecture'])
# save_model() writes checkpoints to <output_dir>/models, so create that
# sub-directory as well; exist_ok makes the cell safe to re-run.
os.makedirs(os.path.join(output_dir, "models"), exist_ok=True)

# =================================================================================
# Setup Datasets and DataLoaders for vAE Training
# =================================================================================
# Augmentations are applied only to the training set for the vAE
aug_transforms = Augmentations()

# Create Dataset instances
train_dataset,  \
val_dataset,    \
test_dataset    = create_datasets(data_key, aug_transforms, slices_dir)

# Create DataLoader instances
# Note: shuffle=False for val and test loaders for consistent evaluation
train_loader     = DataLoader(train_dataset, batch_size=wandb_config_vAE['batch_size'], shuffle=True)
val_loader       = DataLoader(val_dataset, batch_size=wandb_config_vAE['batch_size'], shuffle=False)
# test_loader is built for completeness; it is not used by the training call below.
test_loader      = DataLoader(test_dataset, batch_size=wandb_config_vAE['batch_size'], shuffle=False)

#%% Load losses
perceptual_loss = load_perceptual_loss(device=device).float()
discriminator   = load_discriminator(device=device).float()

#%% Load the vAE model and optimizer
KL_autoencoder  = load_KL_autoencoder(autoencoder_weights,
                                        config_file="train_autoencoder.json",
                                        device=device).float()

# Separate Adam optimizers for the autoencoder (generator) and the discriminator.
optimizer_g   = optim.Adam(params=list(KL_autoencoder.parameters()), lr=wandb_config_vAE["learning_rate"])
optimizer_d   = optim.Adam(params=list(discriminator.parameters()), lr=wandb_config_vAE["learning_rate"])


#%% Train the vAE model
print("Training vAE")
# The validation loader is used inside train_autoencoder to decide when to checkpoint.
KL_autoencoder = train_autoencoder(KL_autoencoder,
                                    discriminator,
                                    perceptual_loss,
                                    generator_loss,
                                    optimizer_g,
                                    optimizer_d,
                                    train_loader, # Use the vAE specific training loader
                                    val_loader,       # Pass validation loader for evaluation
                                    output_dir=output_dir,
                                    device=device)



0,1
adversarial loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
discriminator loss,█▃▆▆▆▄▃▆▃▆▃▇▆▅▆▁▅▂▄▃▄▅▁█▃▄▅▄▄▅▂▅▃▅▆▆▆▄▅▄
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
gen base loss,██▆▄▃▄▂▅▃▄▃▂▃▂▃▂▂▃▁▂▂▂▁▃▂▂▂▂▂▂▁▁▂▂▂▃▂▁▂▂
intensity loss,▃█▄▂▃▄▅█▃▄▂▅▄▄▂▄▃▃▃▁▃▃▁▃▅▂▃▂▁▂▁▂▄▂▂▂▁▂▂▃
kl loss,▃▃▅▄▄▅▆▃▄▃▅▅▄▂▄▅▃▄▃▃▃▆▃▆▃▃▆▆▃▅▅▄▅▅█▆▇▅▁▆
perceptual loss,█▆▅▄▅▃▂▃▂▂▃▂▂▂▃▂▂▁▂▂▁▂▃▂▂▁▁▂▁▂▃▂▂▂▂▁▁▁▂▁
total generator loss,█▃▄▁▄▄▄▃▃▃▂▃▂▂▂▃▃▂▂▁▂▂▁▂▂▂▂▁▂▁▁▁▁▂▁▁▂▂▂▁
valid_loss,▄▄▅█▁▂▂▄▂▁▂▁▂▂

0,1
adversarial loss,0.00081
discriminator loss,0.24963
epoch,14.0
gen base loss,0.25779
intensity loss,0.02166
kl loss,11566.87891
perceptual loss,0.22457
total generator loss,0.2582
valid_loss,0.14244


/content/drive/MyDrive/configs/train_autoencoder.json




Training vAE


EPOCH: 0


  self.parse_free_form_transform(


vAE model saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_vAE_epoch_0.pt
Discriminator saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_discriminator_epoch_0.pt
EPOCH: 1
vAE model saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_vAE_epoch_1.pt
Discriminator saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_discriminator_epoch_1.pt
EPOCH: 2
vAE model saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_vAE_epoch_2.pt
Discriminator saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_discriminator_epoch_2.pt
EPOCH: 3
vAE model saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_vAE_epoch_3.pt
Discriminator saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_discriminator_epoch_3.pt
EPOCH: 4
EPOCH: 5
vAE model saved at /content/drive/MyDrive/output/Autoencoder_KL/models/trained_vAE_epoch_5.pt
Discriminator saved at /content/drive/MyDrive/output/