In [4]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

import torch.optim as optim


# Data preprocessing utils : 
from torchvision.transforms import Compose
from torchvision import transforms
from torch.utils.data import DataLoader

# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

# my defined model
from utils.acdc_dataset import *
from utils.funcs import *
from utils.vqvae import *
from utils.launcher_utils import *

import json


In [5]:

def load_model_from_metadata(json_filepath):
    """
    Load model parameters from a JSON file and instantiate a new model.

    Args:
        json_filepath (str): Path to the JSON file containing the metadata.

    Returns:
        model: A new instance of the model with the saved parameters.
    """
    # Load the JSON file
    with open(json_filepath, 'r') as f:
        metadata = json.load(f)

    # Extract model parameters from the metadata
    model_params = metadata.get("model_parameters", {})
    print("Loaded model parameters:", model_params)

    # Instantiate a new model with the extracted parameters
    model = VQVAE(**model_params)

    return model

def shrink_model_from_metadata(json_filepath, new_K):
    """
    Load model parameters from a JSON file and instantiate a new model.

    Args:
        json_filepath (str): Path to the JSON file containing the metadata.

    Returns:
        model: A new instance of the model with the saved parameters.
    """
    # Load the JSON file
    with open(json_filepath, 'r') as f:
        metadata = json.load(f)

    # Extract model parameters from the metadata
    model_params = metadata.get("model_parameters", {})

    #reduce the size of K : 
    model_params['num_embeddings'] = new_K
    print("Loaded model parameters:", model_params)

    # Instantiate a new model with the extracted parameters
    model = VQVAE(**model_params)

    return model

# Example usage:
# Assuming VQVAE is your model class
# json_filepath = "./training_metadata/training_metadata_20231025_123456.json"
# model2 = load_model_from_metadata(json_filepath, VQVAE)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
json_filepath = 'saved_models/seg/random.pth'.replace('.pth', '.json')
with open(json_filepath, 'r') as f:
    metadata = json.load(f)

model_params = metadata.get("model_parameters", {})
training_params =  metadata.get("training_parameters", {})


print(model_params)
new_K = 128
#reduce the size of K : 
model_params['num_embeddings'] = new_K

print(model_params)

{'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8}
{'embedding_dim': 64, 'num_embeddings': 128, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8}


In [None]:

baseline_model_path = 'saved_models/seg/random.pth'
baseline_model_metdat = (baseline_model_path).replace('.pth', '.json')
baseline_model = load_model_from_metadata(baseline_model_metdat).to(device)



Loaded model parameters: {'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8}


In [6]:
# saving the previous model encoder and decoder : 

# Load the saved model checkpoint
checkpoint = torch.load(baseline_model_path)
# Filter the encoder parameters
encoder_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k.startswith('encoder.')}
# Filter the decoder parameters
decoder_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k.startswith('decoder.')}

  checkpoint = torch.load(baseline_model_path)


In [12]:
model = shrink_model_from_metadata(baseline_model_metdat, 128)

Loaded model parameters: {'embedding_dim': 64, 'num_embeddings': 128, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG'}


In [49]:
model.to(device)

VQVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
    )
    (4): ResidualLayer(
      (resblock): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (5): ResidualLayer(
      (resblock): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride

In [34]:
model_params['data_mod']

'SEG'

In [43]:
data_mod = training_params['data_mod']
L = training_params['L']
batch_size = training_params['batch_size']

#################### dataset init ######################
dataset_path = "/home/ids/ihamdaoui-21/ACDC/database"

train_set_path = os.path.join(dataset_path, "training")
test_set_path  = os.path.join(dataset_path, "testing")


train_dataset = load_dataset(train_set_path, modality= data_mod)
test_dataset  = load_dataset(test_set_path, modality= data_mod)


if data_mod == 'SEG':
    input_transforms = Compose([
        transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
        One_hot_Transform(num_classes=4)
        ])
else : 
    input_transforms = Compose([
        transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
        PercentileClip(lower_percentile=1, upper_percentile=99),
        MinMaxNormalize(min_value=0.0, max_value=1.0),
        ])


TrainDataset = ACDC_Dataset(data = train_dataset, transforms= input_transforms) 
TestDataset  = ACDC_Dataset(data = test_dataset, transforms= input_transforms)

TrainLoader  = DataLoader(TrainDataset, batch_size = batch_size, shuffle = True)
TestLoader   = DataLoader(TestDataset , batch_size = batch_size, shuffle = False)




In [51]:
# we are going to pass through the whole dataset, which results on 

latent_vectors = []

# Process the dataset
with torch.no_grad():  # No need to track gradients
    for batch in TrainLoader:
        # Pass the batch through the encoder
        encoded = baseline_model.encode(batch.float().to(device))[0]  # Output shape: (batch_size, 32, 32, 32)
        
        # Flatten the encoded output to (batch_size, 32*32)
        encoded_flat = encoded.view(encoded.size(0), 64, -1).permute(0, 2, 1)  # Shape: (batch_size, 1024, 64)
        
        # Now flatten across the batch and spatial dimensions to (batch_size * 1024, 64)
        encoded_flat = encoded_flat.reshape(-1, 64)
        
        # Convert the tensor to NumPy and store it
        latent_vectors.append(encoded_flat.cpu().numpy())

# Concatenate all the latent vectors into a single NumPy array
latent_vectors = np.concatenate(latent_vectors, axis=0)  # Shape: (size_of_dataset, 32*32)

# # Optionally, save the latent vectors to disk
# np.save('latent_vectors.npy', latent_vectors)
from sklearn.cluster import kmeans_plusplus

# Calculate seeds from k-means++
centers_init, indices = kmeans_plusplus(latent_vectors, n_clusters= new_K)

new_codebook = torch.from_numpy(centers_init)
print(new_codebook)


tensor([[-7.2677e-04, -7.1296e-04, -7.2102e-04,  ..., -6.5267e-04,
         -6.6954e-04, -6.2945e-04],
        [ 4.4607e-02,  4.4302e-02,  4.3896e-02,  ...,  4.7708e-02,
          4.6804e-02,  4.1788e-02],
        [ 2.6217e-02,  2.6229e-02,  2.6229e-02,  ...,  1.9907e-02,
          1.9907e-02,  2.4059e-02],
        ...,
        [-7.4128e-06, -2.0875e-05, -2.1214e-05,  ..., -1.9673e-05,
         -2.0442e-05, -2.4311e-05],
        [ 2.5740e-02,  2.7776e-02,  2.7776e-02,  ...,  1.9694e-02,
          1.9658e-02,  1.9676e-02],
        [ 1.2430e-02,  1.2300e-02,  1.2415e-02,  ...,  6.7612e-03,
          9.4946e-03,  1.0337e-02]])


In [52]:
print(model.vq_layer.codebook)

tensor([[0.7541, 0.6493, 0.3391,  ..., 0.6177, 0.7495, 0.7156],
        [0.3443, 0.3010, 0.3454,  ..., 0.1826, 0.8266, 0.8625],
        [0.7483, 0.8908, 0.5861,  ..., 0.6253, 0.8567, 0.5477],
        ...,
        [0.6016, 0.5963, 0.1018,  ..., 0.1677, 0.6526, 0.5958],
        [0.9381, 0.6252, 0.0850,  ..., 0.2351, 0.3698, 0.1681],
        [0.5812, 0.5719, 0.8438,  ..., 0.9017, 0.9298, 0.0963]],
       device='cuda:0')


In [53]:
model.vq_layer.codebook = new_codebook
print(model.vq_layer.codebook)

tensor([[-7.2677e-04, -7.1296e-04, -7.2102e-04,  ..., -6.5267e-04,
         -6.6954e-04, -6.2945e-04],
        [ 4.4607e-02,  4.4302e-02,  4.3896e-02,  ...,  4.7708e-02,
          4.6804e-02,  4.1788e-02],
        [ 2.6217e-02,  2.6229e-02,  2.6229e-02,  ...,  1.9907e-02,
          1.9907e-02,  2.4059e-02],
        ...,
        [-7.4128e-06, -2.0875e-05, -2.1214e-05,  ..., -1.9673e-05,
         -2.0442e-05, -2.4311e-05],
        [ 2.5740e-02,  2.7776e-02,  2.7776e-02,  ...,  1.9694e-02,
          1.9658e-02,  1.9676e-02],
        [ 1.2430e-02,  1.2300e-02,  1.2415e-02,  ...,  6.7612e-03,
          9.4946e-03,  1.0337e-02]], device='cuda:0')


In [54]:
print(model.encoder[0][0].weight)

Parameter containing:
tensor([[[[ 0.0244, -0.0428,  0.0740,  0.0781],
          [ 0.1232,  0.0612,  0.0532,  0.0446],
          [ 0.0044, -0.1115,  0.0124,  0.1156],
          [-0.0906,  0.0188,  0.1185,  0.0551]],

         [[-0.0590, -0.0583,  0.0029, -0.0477],
          [ 0.0908, -0.1074,  0.0509, -0.1155],
          [-0.0379, -0.0708, -0.0016, -0.0948],
          [ 0.0435, -0.1248, -0.0533, -0.1147]],

         [[-0.0910,  0.0394, -0.1198, -0.0950],
          [-0.0304, -0.0315,  0.0967, -0.0297],
          [ 0.0692, -0.1252, -0.0929, -0.0483],
          [ 0.0631,  0.1045, -0.0244, -0.0493]],

         [[ 0.0494,  0.0723, -0.0519,  0.0014],
          [-0.0068, -0.0452,  0.0789,  0.0938],
          [-0.1046, -0.0022,  0.0069, -0.0415],
          [-0.1134, -0.0927,  0.0048,  0.0249]]],


        [[[ 0.0533,  0.0018, -0.1011, -0.0924],
          [ 0.0523, -0.0805, -0.0791,  0.0508],
          [-0.0997, -0.1199,  0.1092, -0.0663],
          [-0.0874, -0.0016, -0.0824, -0.0822]],

      

In [55]:
# Load the encoder and decoder weights into the new model
# Remove the 'encoder.' prefix from all keys in encoder_state_dict
encoder_state_dict = {k.replace('encoder.', ''): v for k, v in encoder_state_dict.items()}
model.encoder.load_state_dict(encoder_state_dict)

# Remove the 'encoder.' prefix from all keys in encoder_state_dict
decoder_state_dict = {k.replace('decoder.', ''): v for k, v in decoder_state_dict.items()}
model.decoder.load_state_dict(decoder_state_dict)

<All keys matched successfully>

In [56]:
print(model.encoder[0][0].weight)

Parameter containing:
tensor([[[[ 0.0244, -0.0428,  0.0740,  0.0781],
          [ 0.1232,  0.0612,  0.0532,  0.0446],
          [ 0.0044, -0.1115,  0.0124,  0.1156],
          [-0.0906,  0.0188,  0.1185,  0.0551]],

         [[-0.0590, -0.0583,  0.0029, -0.0477],
          [ 0.0908, -0.1074,  0.0509, -0.1155],
          [-0.0379, -0.0708, -0.0016, -0.0948],
          [ 0.0435, -0.1248, -0.0533, -0.1147]],

         [[-0.0910,  0.0394, -0.1198, -0.0950],
          [-0.0304, -0.0315,  0.0967, -0.0297],
          [ 0.0692, -0.1252, -0.0929, -0.0483],
          [ 0.0631,  0.1045, -0.0244, -0.0493]],

         [[ 0.0494,  0.0723, -0.0519,  0.0014],
          [-0.0068, -0.0452,  0.0789,  0.0938],
          [-0.1046, -0.0022,  0.0069, -0.0415],
          [-0.1134, -0.0927,  0.0048,  0.0249]]],


        [[[ 0.0533,  0.0018, -0.1011, -0.0924],
          [ 0.0523, -0.0805, -0.0791,  0.0508],
          [-0.0997, -0.1199,  0.1092, -0.0663],
          [-0.0874, -0.0016, -0.0824, -0.0822]],

      