In [1]:
import torch
from glob import glob
import os
from monai.networks.nets import SwinUNETR
import numpy as np
import nibabel as nib
import random

In [2]:
# EXTRACT CORRESPONDING LATENTS

def get_random_crop_coords(crop_size=[128, 128, 128]):
    # Input shape is fixed as 240x240x155
    input_shape = [240, 240, 155]

    # Calculate the possible starting indices for the crop
    max_x = input_shape[0] - crop_size[0]  # 240 - 128
    max_y = input_shape[1] - crop_size[1]  # 240 - 128
    max_z = input_shape[2] - crop_size[2]  # 155 - 128

    # Randomly select the starting indices for the crop
    start_x = random.randint(0, max_x)
    start_y = random.randint(0, max_y)
    start_z = random.randint(0, max_z)

    # Return the starting coordinates
    return start_x, start_y, start_z

class LatentSwinUNETR(SwinUNETR):
    
    def forward(self, x_in):
        if not torch.jit.is_scripting():
            self._check_input_size(x_in.shape[2:])
        hidden_states_out = self.swinViT(x_in, self.normalize)
        dec4 = self.encoder10(hidden_states_out[4])
        return dec4

In [3]:
search_path = "/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Images"

# Get a list of all .nii.gz files in the directory and its subdirectories
nii_files = glob(os.path.join(search_path, '**', '*.nii.gz'), recursive=True)

coords = {}
for file in nii_files:
    coords[file]= get_random_crop_coords()

In [5]:
DEVICE_ID = 3
cuda_id = "cuda:" + str(DEVICE_ID)
device = torch.device(cuda_id)
torch.cuda.set_device(DEVICE_ID)    

channels = ["FLAIR", "T1", "T1c", "T2"]

model = LatentSwinUNETR(
    img_size=(128, 128, 128),
    in_channels= 1,  #len(channels)
    out_channels=1,
    feature_size=48,
    use_checkpoint=True).to(device)

load_model_path = "/mnt/disk1/hjlee/orhun/repo/thesis/models/BRATS_FLAIR_SwinUnetr/BRATS_FLAIR_SwinUnetr_checkpoint_Epoch_349.pt"

print("LOADING MODEL: ", load_model_path)
checkpoint = torch.load(load_model_path, map_location={"cuda:0":cuda_id,"cuda:1":cuda_id})
model.load_state_dict(checkpoint['model_state_dict'])

search_path = "/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Images"

# Get a list of all .nii.gz files in the directory and its subdirectories
nii_files = glob(os.path.join(search_path, '**', '*.nii.gz'), recursive=True)

output_path = "/mnt/disk1/hjlee/orhun/data/BRATS_Latents/FLAIR"

for file in nii_files:
    img = nib.load(file)
    img_array=np.array(img.dataobj)
    local_coords = coords[file]
    cropped_img = img_array[local_coords[0]:local_coords[0]+ 128, 
                              local_coords[1]:local_coords[1] + 128, 
                              local_coords[2]:local_coords[2] + 128, :]
    cropped_img = np.transpose(cropped_img, (3, 0, 1, 2))
    cropped_tensor = torch.from_numpy(cropped_img).unsqueeze(0)[:, 0:1, ...].to(device)
    #cropped_tensor = torch.from_numpy(cropped_img).unsqueeze(0).to(device)
    
    out = model(cropped_tensor)

    base_filename = os.path.basename(file).replace('.nii.gz', '.pt')
    tensor_save_path = os.path.join(output_path, base_filename)
    torch.save(out, tensor_save_path)



LOADING MODEL:  /mnt/disk1/hjlee/orhun/repo/thesis/models/BRATS_FLAIR_SwinUnetr/BRATS_FLAIR_SwinUnetr_checkpoint_Epoch_349.pt


In [2]:
img = nib.load("/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Images/BRATS_001.nii.gz")
label = nib.load("/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Labels/BRATS_001.nii.gz")
affine=img.affine
img_array=np.array(img.dataobj)
label_array=np.array(label.dataobj)

In [4]:
# EXTRACT RANDOM CROPS

class LatentSwinUNETR(SwinUNETR):
    
    def forward(self, x_in):
        if not torch.jit.is_scripting():
            self._check_input_size(x_in.shape[2:])
        hidden_states_out = self.swinViT(x_in, self.normalize)
        dec4 = self.encoder10(hidden_states_out[4])
        return dec4

def cropRandom(input_tensor, crop_size):
    # Calculate the possible starting indices for the crop
    max_x = input_tensor.shape[0] - crop_size[0]
    max_y = input_tensor.shape[1] - crop_size[1]
    max_z = input_tensor.shape[2] - crop_size[2]

    # Randomly select the starting indices for the crop
    start_x = random.randint(0, max_x)
    start_y = random.randint(0, max_y)
    start_z = random.randint(0, max_z)

    # Crop the cube from the input tensor
    cropped_tensor = input_tensor[start_x:start_x + crop_size[0], 
                                start_y:start_y + crop_size[1], 
                                start_z:start_z + crop_size[2], :]

    return cropped_tensor

DEVICE_ID = 7
cuda_id = "cuda:" + str(DEVICE_ID)
device = torch.device(cuda_id)
torch.cuda.set_device(DEVICE_ID)    

cropped_input_size = [128,128,128]
channels = ["FLAIR", "T1", "T1c", "T2"]

model = LatentSwinUNETR(
    img_size=(cropped_input_size[0], cropped_input_size[1], cropped_input_size[2]),
    in_channels=1,#in_channels=len(channels),
    out_channels=1,
    feature_size=48,
    use_checkpoint=True).to(device)


load_model_path = "/mnt/disk1/hjlee/orhun/repo/MultiUnet/models/BRATSFLAIR_latent_checkpoint_Epoch_599.pt"
print("LOADING MODEL: ", load_model_path)
checkpoint = torch.load(load_model_path, map_location={"cuda:0":cuda_id,"cuda:1":cuda_id})
model.load_state_dict(checkpoint['model_state_dict'])

search_path = "/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Images"

# Get a list of all .nii.gz files in the directory and its subdirectories
nii_files = glob(os.path.join(search_path, '**', '*.nii.gz'), recursive=True)

output_path = "/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/LatentsFLAIR"

for file in nii_files:
    img = nib.load(file)
    img_array=np.array(img.dataobj)
    cropped_img = cropRandom(img_array,cropped_input_size)
    cropped_img = np.transpose(cropped_img, (3, 0, 1, 2))
    cropped_tensor = torch.from_numpy(cropped_img).unsqueeze(0)[:, 0:1, ...].to(device)

    out = model(cropped_tensor)

    base_filename = os.path.basename(file).replace('.nii.gz', '.pt')
    tensor_save_path = os.path.join(output_path, base_filename)
    torch.save(out, tensor_save_path)

In [5]:
flair_latent = torch.load('/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/LatentsFLAIR/BRATS_001.pt', weights_only=True)
all_latent = torch.load('/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Latents/BRATS_001.pt', weights_only=True)


In [6]:
all_latent.shape

torch.Size([1, 768, 4, 4, 4])

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from nets.vqvae import LatentDiffusionModel

torch.cuda.set_device(7)   


flair_embeddings = torch.load('/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/LatentsFLAIR/BRATS_001.pt').cuda()
multimodal_embeddings = torch.load('/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/Latents/BRATS_001.pt').cuda()

# Initialize the latent diffusion model
latent_diffusion_model = LatentDiffusionModel(
    in_channels=768, embedding_dim=256, num_embeddings=512, 
    commitment_cost=0.25, num_layers=4, num_heads=8, ff_dim=1024
).cuda()

optimizer = torch.optim.Adam(latent_diffusion_model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Training loop
for epoch in range(1000):
    # Forward pass
    recon, vq_loss = latent_diffusion_model(flair_embeddings)
    
    # Compute reconstruction loss
    recon_loss = criterion(recon, multimodal_embeddings)
    total_loss = recon_loss + vq_loss
    
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss.item()}')

torch.Size([1, 768, 4, 4, 4])
torch.Size([1, 256, 1, 1, 1])
torch.Size([1, 1, 1, 1, 256])
torch.Size([1, 1, 1, 1, 256])
torch.Size([256, 1, 1])
torch.Size([256, 1, 1])


AssertionError: was expecting embedding dimension of 256, but got 1

In [None]:
# Simple LDM

from nets.vqvae import LatentUNet, NoiseScheduler, LatentDiffusionModel
from latent_dataset import LatentEmbeddingsDataset
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from datetime import datetime

torch.cuda.set_device(6)  
_date = datetime.now().strftime("%d-%H-%M")
#log_dir = f"./runs/diff_latent_{_date}"
log_dir = "./runs/test"
writer = SummaryWriter(log_dir=log_dir)

# Define paths to your latent embedding directories
latents_dir = '/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/CorrespondingLatents/All'  # Replace with the correct path
latents_single_dir = '/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/CorrespondingLatents/T1C'  # Replace with the correct path

# Hyperparameters
batch_size = 8
num_epochs = 600
learning_rate = 1e-5

# Initialize dataset and dataloader
dataset = LatentEmbeddingsDataset(latents_dir, latents_single_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the Latent Diffusion Model (with latent dimension and conditioning dimension both set to 768)
latent_dim = 768
condition_dim = 768
model = LatentDiffusionModel(latent_dim=latent_dim, condition_dim=condition_dim).cuda()

# Optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model

def train_diffusion_model(model, dataloader, num_epochs):
    noise_scheduler = NoiseScheduler()

    for epoch in range(num_epochs):
        epoch_loss = 0
        model.train()
        step = 0
        for source_embeddings, target_embeddings in dataloader:
            source_embeddings = source_embeddings.cuda()  # Move to GPU
            target_embeddings = target_embeddings.cuda()
            step+=1
            # Randomly select timesteps for each batch
            t = torch.randint(0, 1000, (target_embeddings.size(0),)).cuda()

            # Add noise to the target embeddings
            noisy_embeddings = noise_scheduler.add_noise(target_embeddings, t)

            # Forward pass through the diffusion model
            optimizer.zero_grad()
            output = model(source_embeddings, noisy_embeddings, t)

            # Compute loss (MSE between the predicted and clean embeddings)
            loss = F.mse_loss(output, target_embeddings)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(dataloader)}')
        writer.add_scalar("DiffTraining/EpochLoss", epoch_loss / len(dataloader), epoch)
    
    torch.save(model.state_dict(), "/mnt/disk1/hjlee/orhun/repo/MultiUnet/models/t1c_all_diff_latent_600_1e-5.pt")




In [None]:
train_diffusion_model(model, dataloader, num_epochs=num_epochs)

torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])


torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size([8, 768, 4, 4, 4])
torch.Size

KeyboardInterrupt: 

: 

In [1]:
from nets.attentionLDM import CosineScheduler, UNet3D, LatentDiffusionModel
from latent_dataset import LatentEmbeddingsDataset
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from datetime import datetime

cuda_id = "cuda:" + str(6)
device = torch.device(cuda_id)
_date = datetime.now().strftime("%d-%H-%M")
#log_dir = f"./runs/diff_latent_{_date}"
log_dir = "./runs/test"
writer = SummaryWriter(log_dir=log_dir)

# Define paths to your latent embedding directories
latents_dir = '/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/CorrespondingLatents/All'  # Replace with the correct path
latents_single_dir = '/mnt/disk1/hjlee/orhun/data/BRATS_Preprocessed/CorrespondingLatents/T1C'  # Replace with the correct path

# Hyperparameters
batch_size = 8
num_epochs = 600
learning_rate = 1e-4

# Initialize dataset and dataloader
dataset = LatentEmbeddingsDataset(latents_dir, latents_single_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize components
cosine_scheduler = CosineScheduler(timesteps=1000, device= device )
unet = UNet3D(in_channels=768, base_channels=128).to(device)
ldm = LatentDiffusionModel(unet, cosine_scheduler).to(device)
optimizer = torch.optim.Adam(ldm.parameters(), lr=learning_rate)


# Training loop
for epoch in range(num_epochs):
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        t = torch.randint(0, cosine_scheduler.timesteps, (src.size(0),), device=src.device)

        # Noise injection
        noise = torch.randn_like(src).to(device)
        alpha_bar_t = cosine_scheduler.get_alpha_bar(t).view(-1, 1, 1, 1, 1)
        noisy_src = torch.sqrt(alpha_bar_t) * src + torch.sqrt(1 - alpha_bar_t) * noise

        # Predict noise
        noise_pred = ldm(noisy_src, t)

        # MSE Loss
        loss = F.mse_loss(noise_pred, noise)

        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")

Epoch 1/600, Loss: 1.022239327430725
Epoch 2/600, Loss: 1.0090160369873047
Epoch 3/600, Loss: 1.0070483684539795
Epoch 4/600, Loss: 0.9973657131195068
Epoch 5/600, Loss: 0.99908047914505
Epoch 6/600, Loss: 1.0029687881469727
Epoch 7/600, Loss: 0.9978360533714294
Epoch 8/600, Loss: 0.9988768696784973
Epoch 9/600, Loss: 0.9981911182403564
Epoch 10/600, Loss: 1.001476764678955
Epoch 11/600, Loss: 1.0017756223678589
Epoch 12/600, Loss: 0.9985731244087219
Epoch 13/600, Loss: 1.0010594129562378
Epoch 14/600, Loss: 0.9968657493591309
Epoch 15/600, Loss: 1.0043799877166748
Epoch 16/600, Loss: 0.9951318502426147
Epoch 17/600, Loss: 0.9969205856323242
Epoch 18/600, Loss: 1.006028175354004
Epoch 19/600, Loss: 1.007902979850769
Epoch 20/600, Loss: 0.9992837309837341
Epoch 21/600, Loss: 1.0037096738815308
Epoch 22/600, Loss: 0.9978023767471313
Epoch 23/600, Loss: 0.9993953108787537
Epoch 24/600, Loss: 0.9959292411804199
Epoch 25/600, Loss: 0.9961526989936829
Epoch 26/600, Loss: 1.0014530420303345
E