In [1]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

import torch.optim as optim


# Data preprocessing utils : 
from torchvision.transforms import Compose
from torchvision import transforms
from torch.utils.data import DataLoader

# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

# my defined model
from utils.acdc_dataset import *
from utils.funcs import *
from utils.vqvae import *
from utils.launcher_utils import *

import json


In [3]:
data_mod = 'SEG'
L = 128
batch_size = 16

#################### dataset init ######################
dataset_path = "/home/ids/ihamdaoui-21/ACDC/database"

train_set_path = os.path.join(dataset_path, "training")
test_set_path  = os.path.join(dataset_path, "testing")


train_dataset = load_dataset(train_set_path, modality= data_mod)
test_dataset  = load_dataset(test_set_path, modality= data_mod)


if data_mod == 'SEG':
    input_transforms = Compose([
        transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
        One_hot_Transform(num_classes=4)
        ])
else : 
    input_transforms = Compose([
        transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
        PercentileClip(lower_percentile=1, upper_percentile=99),
        MinMaxNormalize(min_value=0.0, max_value=1.0),
        ])


TrainDataset = ACDC_Dataset(data = train_dataset, transforms= input_transforms) 
TestDataset  = ACDC_Dataset(data = test_dataset, transforms= input_transforms)

TrainLoader  = DataLoader(TrainDataset, batch_size = batch_size, shuffle = True)
TestLoader   = DataLoader(TestDataset , batch_size = batch_size, shuffle = False)


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:

baseline_model_path = 'saved_models/seg/301.pth'
baseline_model_params = load_model_metadata(baseline_model_path)
baseline_model = VQVAE(**baseline_model_params).to(device)
print(baseline_model_params)

{'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG', 'loss_func': None}


In [6]:
# saving the previous model encoder and decoder : 

# Load the saved model checkpoint
checkpoint = torch.load(baseline_model_path)
# Filter the encoder parameters
encoder_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k.startswith('encoder.')}
# Filter the decoder parameters
decoder_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k.startswith('decoder.')}

  checkpoint = torch.load(baseline_model_path)


In [32]:
new_K = 256
model_params = baseline_model_params.copy()
model_params['num_embeddings'] = new_K
print(model_params)
print(baseline_model_params)

{'embedding_dim': 64, 'num_embeddings': 256, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG', 'loss_func': None}
{'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG', 'loss_func': None}


In [33]:
model = VQVAE(**model_params).to(device)

In [34]:
# we are going to pass through the whole dataset, which results on 

latent_vectors = []

# Process the dataset
with torch.no_grad():  # No need to track gradients
    for batch in TrainLoader:
        # Pass the batch through the encoder
        encoded = baseline_model.encode(batch.float().to(device))[0]  # Output shape: (batch_size, 32, 32, 32)
        
        # Flatten the encoded output to (batch_size, 32*32)
        encoded_flat = encoded.view(encoded.size(0), 64, -1).permute(0, 2, 1)  # Shape: (batch_size, 1024, 64)
        
        # Now flatten across the batch and spatial dimensions to (batch_size * 1024, 64)
        encoded_flat = encoded_flat.reshape(-1, 64)
        
        # Convert the tensor to NumPy and store it
        latent_vectors.append(encoded_flat.cpu().numpy())

# Concatenate all the latent vectors into a single NumPy array
latent_vectors = np.concatenate(latent_vectors, axis=0)  # Shape: (size_of_dataset, 32*32)

# # Optionally, save the latent vectors to disk
# np.save('latent_vectors.npy', latent_vectors)
from sklearn.cluster import kmeans_plusplus

# Calculate seeds from k-means++
centers_init, indices = kmeans_plusplus(latent_vectors, n_clusters= new_K)

new_codebook = torch.from_numpy(centers_init)
print(new_codebook)


tensor([[ 4.9095e-02,  4.4974e-02,  4.4974e-02,  ...,  4.7980e-02,
          4.7980e-02,  4.7943e-02],
        [ 1.1339e-03, -3.1888e-06, -3.1835e-06,  ..., -1.6194e-04,
         -1.7578e-04, -1.6076e-04],
        [ 1.8823e-02,  1.7054e-02,  1.7182e-02,  ...,  2.0055e-02,
          2.2020e-02,  2.0360e-02],
        ...,
        [ 4.3589e-02,  4.3720e-02,  4.4053e-02,  ...,  3.6490e-02,
          3.5010e-02,  3.6336e-02],
        [ 1.3915e-02,  1.3861e-02,  1.3861e-02,  ...,  9.0801e-03,
          9.0801e-03,  5.7537e-03],
        [ 2.2129e-02,  2.2066e-02,  2.1970e-02,  ...,  3.5239e-02,
          3.4859e-02,  3.0908e-02]])


In [35]:
# Load the encoder and decoder weights into the new model
# Remove the 'encoder.' prefix from all keys in encoder_state_dict
encoder_state_dict = {k.replace('encoder.', ''): v for k, v in encoder_state_dict.items()}
model.encoder.load_state_dict(encoder_state_dict)

# Remove the 'encoder.' prefix from all keys in encoder_state_dict
decoder_state_dict = {k.replace('decoder.', ''): v for k, v in decoder_state_dict.items()}
model.decoder.load_state_dict(decoder_state_dict)

model.vq_layer.codebook = new_codebook


In [36]:
model_name = "saved_models/Refit/445.pth"
lr = 5e-4
epochs = 40
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)


In [37]:

model.train()

train_loss_values    = []
commit_loss_values   = []
val_loss_values      = []


best_val_loss = float('inf')

for epoch in range(epochs):

    train_loss  = []
    commit_loss = []

    with tqdm(enumerate(TrainLoader), unit="batch", total=len(TrainLoader)) as tepoch:
        for batch_idx, (inputs) in tepoch:
            inputs = inputs.float().to(device)  # Move data to the appropriate device (GPU/CPU)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass // args is a list containing : [output, input, vq_loss]
            output, inputs, indices, commitement_Loss = model(inputs)
            
            # Loss and backward
            all_loss = model.loss_function(output, inputs, indices, commitement_Loss)
            loss = all_loss['loss']  # Use the loss function defined in the model
            recons_loss = all_loss['Reconstruction_Loss']
            commitement_Loss = all_loss['commitement_Loss']

            loss.backward()
            optimizer.step()
                        
            # Track running loss
            train_loss.append( recons_loss.item() )
            commit_loss.append( commitement_Loss.item() )

            # tqdm bar displays the loss
            tepoch.set_postfix(loss=loss.item())

    train_loss_values.append( np.mean(train_loss))
    commit_loss_values.append( np.mean(commit_loss))

    # Validation after each epoch
    val_loss = evaluate_model(model, TestLoader, device)
    val_loss_values.append(val_loss)

    #saving model if Loss values decreases
    if val_loss < best_val_loss :
        save_model(model_name, model, epoch, train_loss_values, val_loss_values, commit_loss_values, val_loss)
        best_val_loss = val_loss

    print('Epoch {}: '.format(epoch))


print("Training complete.")

  0%|          | 0/119 [00:00<?, ?batch/s]

100%|██████████| 119/119 [00:09<00:00, 12.85batch/s, loss=0.0287]


Epoch 0: 


100%|██████████| 119/119 [00:04<00:00, 25.07batch/s, loss=0.0282]


Epoch 1: 


100%|██████████| 119/119 [00:04<00:00, 27.85batch/s, loss=0.0212]


Epoch 2: 


100%|██████████| 119/119 [00:06<00:00, 17.73batch/s, loss=0.021] 


Epoch 3: 


100%|██████████| 119/119 [00:04<00:00, 27.54batch/s, loss=0.0249]


Epoch 4: 


100%|██████████| 119/119 [00:04<00:00, 25.74batch/s, loss=0.0261]


Epoch 5: 


100%|██████████| 119/119 [00:07<00:00, 15.12batch/s, loss=0.0284]


Epoch 6: 


100%|██████████| 119/119 [00:11<00:00, 10.41batch/s, loss=0.0201]


Epoch 7: 


100%|██████████| 119/119 [00:06<00:00, 18.89batch/s, loss=0.0199]


Epoch 8: 


100%|██████████| 119/119 [00:05<00:00, 21.89batch/s, loss=0.026] 


Epoch 9: 


100%|██████████| 119/119 [00:06<00:00, 17.80batch/s, loss=0.0284]


Epoch 10: 


100%|██████████| 119/119 [00:05<00:00, 22.76batch/s, loss=0.0241]


Epoch 11: 


100%|██████████| 119/119 [00:08<00:00, 14.37batch/s, loss=0.0179]


Epoch 12: 


100%|██████████| 119/119 [00:07<00:00, 15.04batch/s, loss=0.0258]


Epoch 13: 


100%|██████████| 119/119 [00:05<00:00, 20.76batch/s, loss=0.0222]


Epoch 14: 


100%|██████████| 119/119 [00:05<00:00, 20.73batch/s, loss=0.0253]


Epoch 15: 


100%|██████████| 119/119 [00:07<00:00, 16.41batch/s, loss=0.0263]


Epoch 16: 


100%|██████████| 119/119 [00:05<00:00, 20.18batch/s, loss=0.0278]


Epoch 17: 


100%|██████████| 119/119 [00:05<00:00, 22.87batch/s, loss=0.0224]


Epoch 18: 


100%|██████████| 119/119 [00:04<00:00, 25.80batch/s, loss=0.0241]


Epoch 19: 


100%|██████████| 119/119 [00:04<00:00, 25.74batch/s, loss=0.0218]


Epoch 20: 


100%|██████████| 119/119 [00:07<00:00, 15.54batch/s, loss=0.0268]


Epoch 21: 


100%|██████████| 119/119 [00:05<00:00, 22.38batch/s, loss=0.0224]


Epoch 22: 


100%|██████████| 119/119 [00:06<00:00, 17.51batch/s, loss=0.0251]


Epoch 23: 


100%|██████████| 119/119 [00:04<00:00, 24.33batch/s, loss=0.0251]


Epoch 24: 


100%|██████████| 119/119 [00:07<00:00, 15.57batch/s, loss=0.0211]


Epoch 25: 


100%|██████████| 119/119 [00:08<00:00, 14.59batch/s, loss=0.0236]


Epoch 26: 


100%|██████████| 119/119 [00:05<00:00, 23.67batch/s, loss=0.0218]


Epoch 27: 


100%|██████████| 119/119 [00:15<00:00,  7.46batch/s, loss=0.0239]


Epoch 28: 


100%|██████████| 119/119 [00:07<00:00, 15.39batch/s, loss=0.0212]


Epoch 29: 


100%|██████████| 119/119 [00:09<00:00, 12.12batch/s, loss=0.023] 


Epoch 30: 


100%|██████████| 119/119 [00:11<00:00, 10.64batch/s, loss=0.0288]


Epoch 31: 


100%|██████████| 119/119 [00:06<00:00, 19.44batch/s, loss=0.0239]


Epoch 32: 


100%|██████████| 119/119 [00:09<00:00, 11.94batch/s, loss=0.023] 


Epoch 33: 


100%|██████████| 119/119 [00:05<00:00, 23.24batch/s, loss=0.021] 


Epoch 34: 


100%|██████████| 119/119 [00:07<00:00, 16.59batch/s, loss=0.0252]


Epoch 35: 


100%|██████████| 119/119 [00:06<00:00, 17.81batch/s, loss=0.0231]


Epoch 36: 


100%|██████████| 119/119 [00:04<00:00, 23.84batch/s, loss=0.0176]


Epoch 37: 


100%|██████████| 119/119 [00:06<00:00, 19.38batch/s, loss=0.0201]


Epoch 38: 


100%|██████████| 119/119 [00:06<00:00, 18.39batch/s, loss=0.0206]


Epoch 39: 
Training complete.


In [38]:
print(score_model(model, TestLoader, device))

0.9633495632339927


In [39]:
hist, percentage = codebook_hist_testset(model, TestLoader, device)
hist = hist/np.sum(hist)

ONLY 85 OF CODES WERE USED FROM 256, WHICH MAKE 33.203125 % OF CODES FROM THE CODE-BOOK
