In [18]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')

import torch.optim as optim


# Data preprocessing utils : 
import torchio as tio
from torchvision import transforms
import torchvision.transforms.functional as F
import torch.nn.functional as f
import torchvision.ops as ops
from torchvision.transforms import Compose
from torch.utils.data import Dataset, DataLoader


# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# my defined model
from vqVAE import VQVAE


## Preparing Dataset 

In [19]:
dataset_path = "/Users/kajou/OneDrive/Desktop/VQ-VAE/ACDC/database/training"

In [20]:
# The following paramerters are set according to the VQ-VAE paper.

L = 64 # image size L=W
BATCH_SIZE = 64

def get_max_BB(D3_slice):
    masks = (D3_slice > 0).float()

    # boxes = ops.masks_to_boxes(binary_masks)
    boxes = []
    for mask in masks:
        # Get coordinates of non-zero pixels
        y, x = torch.where(mask)
        
        if len(y) == 0:  # Empty mask
            continue
            
        # Compute box coordinates
        box = [
            x.min(), y.min(),
            x.max(), y.max()
        ]
        
        boxes.append(box)
        
    boxes = torch.tensor(boxes)
    largest_box = torch.tensor([
        boxes[:, 0].min(),  # smallest x1
        boxes[:, 1].min(),  # smallest y1
        boxes[:, 2].max(),  # largest x2
        boxes[:, 3].max()   # largest y2
    ])
    return largest_box

def to_square_BB(BB):
    min_col, min_row,max_col, max_row = BB
    L = max_col - min_col
    W = max_row - min_row
    d = np.abs(L-W)
    eps_left = d//2 
    eps_right = d-eps_left
    if L>W : 
        min_row -= eps_left
        max_row += eps_right
    else : 
        min_col -= eps_left
        max_col += eps_right
    return [min_col, min_row, max_col, max_row]


def add_background_BB(suqare_BB,n_p = 10):
    min_col, min_row, max_col, max_row = suqare_BB

    min_col -= n_p
    min_row -= n_p
    max_row += n_p
    max_col += n_p
    BB_bg = [min_col, min_row, max_col, max_row]
    return BB_bg

def crop_img(D3_slice):
    # BB = ops.masks_to_boxes(img.unsqueeze(0))[0].int().numpy()
    BB = get_max_BB(D3_slice)
    BB = to_square_BB(BB)
    BB = add_background_BB(BB)
    min_col, min_row,max_col, max_row = BB

    img = D3_slice
    croped = F.crop(img, min_row, min_col, max_row-min_row, max_col-min_col)
    return croped


In [21]:
dataset_path = "/Users/kajou/OneDrive/Desktop/VQ-VAE/ACDC/database"

patients = os.listdir(os.path.join(dataset_path, "training"))[1:]
print(patients)
patient_PATH  = os.path.join(dataset_path, "training", "patient002")
files = os.listdir(patient_PATH)
image_files = [f for f in files if f.endswith('.nii.gz') and not f.endswith('_gt.nii.gz') and not f.endswith('4d.nii.gz')]
frames = [f.split('_')[1].split('.')[0] for f in image_files ]

print(max(frames))

['patient001', 'patient002', 'patient003', 'patient004', 'patient005', 'patient006', 'patient007', 'patient008', 'patient009', 'patient010', 'patient011', 'patient012', 'patient013', 'patient014', 'patient015', 'patient016', 'patient017', 'patient018', 'patient019', 'patient020', 'patient021', 'patient022', 'patient023', 'patient024', 'patient025', 'patient026', 'patient027', 'patient028', 'patient029', 'patient030', 'patient031', 'patient032', 'patient033', 'patient034', 'patient035', 'patient036', 'patient037', 'patient038', 'patient039', 'patient040', 'patient041', 'patient042', 'patient043', 'patient044', 'patient045', 'patient046', 'patient047', 'patient048', 'patient049', 'patient050', 'patient051', 'patient052', 'patient053', 'patient054', 'patient055', 'patient056', 'patient057', 'patient058', 'patient059', 'patient060', 'patient061', 'patient062', 'patient063', 'patient064', 'patient065', 'patient066', 'patient067', 'patient068', 'patient069', 'patient070', 'patient071', 'pati

In [22]:
dataset_path = "/Users/kajou/OneDrive/Desktop/VQ-VAE/ACDC/database"

def load_patient_gt(patient_PATH) :
    files = os.listdir(patient_PATH)
    image_files = [f for f in files if f.endswith('.nii.gz') and not f.endswith('_gt.nii.gz') and not f.endswith('4d.nii.gz')]
    gt_files = [f for f in files if f.endswith('_gt.nii.gz')]


    # Identify ED and ES frames
    frames = [f.split('_')[1].split('.')[0] for f in image_files ]
    ed_frame = min(frames)
    es_frame = max(frames)

    # Load ground truth segmentations
    ed_gt = tio.LabelMap(os.path.join(patient_PATH, f"patient{patient_PATH[-3:]}_{ed_frame}_gt.nii.gz"))
    es_gt = tio.LabelMap(os.path.join(patient_PATH, f"patient{patient_PATH[-3:]}_{es_frame}_gt.nii.gz"))
    ed_gt_slices = crop_img(ed_gt.data.squeeze(0).permute(2,0,1)).flatten(start_dim = 0, end_dim=0) 
    es_gt_slices = crop_img(es_gt.data.squeeze(0).permute(2,0,1)).flatten(start_dim = 0, end_dim=0)


    return [tensor for tensor in ed_gt_slices] + [tensor for tensor in es_gt_slices]

def load_dataset(Path):
    dataset = []
    patients = os.listdir(Path)[1:]
    for patient in patients : 
        patient_gt = load_patient_gt(os.path.join(Path, patient)) 
        dataset  += patient_gt
    return dataset


train_dataset = load_dataset(os.path.join(dataset_path, "training"))
test_dataset = load_dataset(os.path.join(dataset_path, "testing"))

In [23]:
class One_hot_Transform:
    def __init__(self, num_classes):
        self.num_classes = num_classes
    
    def __call__(self, x):
        x = x.squeeze(0).long()
        one_hot_encoded = f.one_hot(x, num_classes=4)
        return one_hot_encoded.permute(2, 0, 1)
        
input_transforms = Compose([
    transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
    One_hot_Transform(num_classes=4)
    ])

#define the dataset

class ACDC_slices(Dataset):
    def __init__(self, data, transforms =None):
        self.data = data
        self.transforms = transforms

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        item = item.unsqueeze(0)  # Shape: (1, H, W)

        if self.transforms:
            item = self.transforms(item)

        return item
    
TrainDataset = ACDC_slices(data = train_dataset, transforms= input_transforms) 
TestDataset  = ACDC_slices(data = test_dataset, transforms= input_transforms)

TrainLoader  = DataLoader(TrainDataset, batch_size = BATCH_SIZE, shuffle = True)
TestLoader   = DataLoader(TestDataset , batch_size = BATCH_SIZE, shuffle = True)

batch = next(iter(TrainLoader))

## Prepairing the model

In [24]:
K =  512 # num_embeddings
D =  64 # embedding_dim
in_channels = 4 
img_size = 64

In [25]:
ACDC_VQVAE = VQVAE(in_channels, D, K)

input = torch.rand(16, 4, 64, 64)

In [26]:
y = ACDC_VQVAE(input)
z_e = ACDC_VQVAE.encode(input)[0]
z_q, _ = ACDC_VQVAE.vq_layer(z_e)
codeBook = ACDC_VQVAE.vq_layer.embedding

## Training the Model

In [27]:
# detect gpu ?

print(torch.cuda.is_available())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

False
cpu


In [48]:
### Learning parameters


model = ACDC_VQVAE.to(device)
batch_size = 64
lr = 1e-4
epochs = 2
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

In [44]:
print(batch.dtype)

torch.int64


In [49]:
###########################        Training ....      #################################

def evaluate_model(model, val_loader):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch.float().to(device)
           
            output, input, vq_loss = model(inputs)
            
            # Loss and backward
            loss = model.loss_function(output, input, vq_loss)['loss']
            
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_loader.dataset)
    return avg_val_loss


def save_model(model, epoch):
    checkpoint_path = os.path.join( os.getcwd() , 'vqvae_100_bestmodel.pth' )
    torch.save({'epoch' : epoch,
                'model_state_dict': model.state_dict()}, checkpoint_path)



model.train()
train_loss_values = []
val_loss_values = []
best_val_loss = float('inf')

for epoch in range(epochs):

    train_loss = 0.0
    
    with tqdm(TrainLoader, unit="batch") as tepoch:
        for batch_idx, inputs in enumerate(TrainLoader):
            inputs = inputs.float().to(device)  # Move data to the appropriate device (GPU/CPU)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass // args is a list containing : [output, input, vq_loss]
            output, input, vq_loss = model(inputs)
            
            # Loss and backward
            loss = model.loss_function(output, input, vq_loss)['loss']  # Use the loss function defined in the model
            loss.backward()
            optimizer.step()
            
            # Track running loss
            train_loss += loss.item()

            # tqdm bar displays the loss
            tepoch.set_postfix(loss=loss.item())

    epoch_loss = train_loss / len(TrainLoader.dataset)
    train_loss_values.append(epoch_loss)

    # Validation after each epoch
    val_loss = evaluate_model(model, TestLoader)
    val_loss_values.append(val_loss)

    #saving model if Loss values decreases
    if val_loss < best_val_loss :
        save_model(model, epoch)

    print('Epoch {}: Train Loss: {:.4f}'.format(epoch, train_loss/len(TrainLoader)))

print("Training complete.")

  0%|          | 0/30 [01:15<?, ?batch/s, loss=0.431]


Epoch 0: Train Loss: 3.3519


  0%|          | 0/30 [01:16<?, ?batch/s, loss=0.427]


Epoch 1: Train Loss: 0.4310
Training complete.
