In [3]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

import torch.optim as optim


# Data preprocessing utils : 
from torchvision.transforms import Compose
from torchvision import transforms
from torch.utils.data import DataLoader

# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

# my defined model
from utils.acdc_dataset import *
from utils.funcs import *
from utils.vqvae import *
from utils.launcher_utils import *

In [5]:
import json
import torch

def _read_model_parameters(json_filepath):
    """
    Read the ``"model_parameters"`` entry of a training-metadata JSON file.

    Args:
        json_filepath (str): Path to the JSON file containing the metadata.

    Returns:
        dict: The stored model parameters, or an empty dict if the key is absent.
    """
    with open(json_filepath, 'r') as f:
        metadata = json.load(f)
    return metadata.get("model_parameters", {})


def load_model_from_metadata(json_filepath):
    """
    Load model parameters from a JSON file and instantiate a new, untrained VQVAE.

    Only the architecture is rebuilt from ``metadata["model_parameters"]``; no weights
    are loaded here.

    Args:
        json_filepath (str): Path to the JSON file containing the metadata.

    Returns:
        model: A new ``VQVAE(**model_parameters)`` instance.
    """
    model_params = _read_model_parameters(json_filepath)
    print("Loaded model parameters:", model_params)
    return VQVAE(**model_params)


def shrink_model_from_metadata(json_filepath, new_K):
    """
    Instantiate a new VQVAE from saved metadata, but with a codebook of ``new_K`` entries.

    All hyper-parameters are taken from ``metadata["model_parameters"]`` except
    ``num_embeddings``, which is overridden by ``new_K``. The JSON file is not modified.

    Args:
        json_filepath (str): Path to the JSON file containing the metadata.
        new_K (int): Number of codebook entries (``num_embeddings``) for the new model.

    Returns:
        model: A new ``VQVAE`` instance with ``num_embeddings=new_K``.

    Raises:
        ValueError: If ``new_K`` is not strictly positive.
    """
    # Fail fast: a codebook needs at least one entry.
    if new_K <= 0:
        raise ValueError(f"new_K must be a positive integer, got {new_K!r}")

    model_params = _read_model_parameters(json_filepath)
    # Override the codebook size; every other hyper-parameter is kept as saved.
    model_params['num_embeddings'] = new_K
    print("Loaded model parameters:", model_params)
    return VQVAE(**model_params)

# Example usage (VQVAE must be in scope, presumably via the `utils.vqvae` star import):
# json_filepath = "./training_metadata/training_metadata_20231025_123456.json"
# model2 = load_model_from_metadata(json_filepath)
# small_model = shrink_model_from_metadata(json_filepath, new_K=128)

In [17]:
# Peek at the stored hyper-parameters, then preview them with a reduced codebook size K.
json_filepath = 'saved_models/seg/random.pth'.replace('.pth', '.json')
with open(json_filepath, 'r') as f:
    metadata = json.load(f)
model_params = metadata.get("model_parameters", {})
print(model_params)  # as saved

new_K = 128
model_params.update(num_embeddings=new_K)  # shrink the codebook size (mutates model_params)
print(model_params)  # after shrinking

{'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG'}
{'embedding_dim': 64, 'num_embeddings': 128, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG'}


In [11]:

# Rebuild the baseline architecture from the metadata JSON stored next to its checkpoint.
baseline_model_path = 'saved_models/seg/random.pth'
baseline_model_metdat = baseline_model_path.replace('.pth', '.json')
baseline_model = load_model_from_metadata(baseline_model_metdat)

Loaded model parameters: {'embedding_dim': 64, 'num_embeddings': 512, 'downsampling_factor': 8, 'residual': False, 'num_quantizers': 2, 'shared_codebook': False, 'beta': 0.25, 'decay': 0.8, 'data_mod': 'SEG'}


In [13]:
# Extract the baseline model's encoder and decoder weights from its checkpoint.
# Keys keep their 'encoder.' / 'decoder.' prefixes.

# map_location='cpu' lets a checkpoint saved on GPU be opened on any machine; the tensors
# are copied to the target model's device when they are loaded into it.
# NOTE(review): recent torch versions warn that `weights_only` will default to True; if this
# checkpoint contains only tensors/plain containers, passing weights_only=True is safer.
checkpoint = torch.load(baseline_model_path, map_location='cpu')
baseline_state_dict = checkpoint['model_state_dict']

encoder_state_dict, decoder_state_dict = (
    {k: v for k, v in baseline_state_dict.items() if k.startswith(prefix)}
    for prefix in ('encoder.', 'decoder.')
)

# An empty dict here means the key prefixes differ (e.g. 'module.' from DataParallel);
# fail loudly instead of silently loading nothing later.
assert encoder_state_dict, "no 'encoder.' keys found in model_state_dict"
assert decoder_state_dict, "no 'decoder.' keys found in model_state_dict"

  checkpoint = torch.load(baseline_model_path)


In [None]:
# TODO(review): this cell was left unfinished (`new_params = ` with no value is a SyntaxError
# that aborts "Restart & Run All"). Presumably meant to hold the shrunk-codebook parameters;
# fill in the intended value.
new_params = None

In [None]:
# Encode the whole training set and collect every spatial latent vector as one row, giving a
# (num_vectors, C) matrix — presumably used to initialise a smaller codebook.
# NOTE(review): `model`, `TrainLoader` and `device` are not defined in any visible cell, and
# `centers_init` (used on the last line) is never computed here — presumably it comes from a
# clustering step (e.g. k-means on `latent_vectors`) that is missing. Confirm before "Run All".

latent_chunks = []

# Inference mode: layers such as BatchNorm/Dropout (if the encoder has any) must not update
# running statistics or drop activations during extraction. The previous mode is restored after.
was_training = model.training
model.eval()
try:
    with torch.no_grad():  # no gradients needed for feature extraction
        for batch in TrainLoader:
            # assumes encode(...)[0] is (batch, C, H, W) with C == embedding_dim — TODO confirm
            encoded = model.encode(batch.float().to(device))[0]
            num_channels = encoded.size(1)

            # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C) -> (B*H*W, C): one row per position.
            # Uses the real channel count rather than a hard-coded 64, and reshape rather than
            # view so a non-contiguous encoder output does not raise.
            encoded_flat = (
                encoded.reshape(encoded.size(0), num_channels, -1)
                .permute(0, 2, 1)
                .reshape(-1, num_channels)
            )

            # Move to host memory as NumPy so the GPU memory can be freed per batch.
            latent_chunks.append(encoded_flat.cpu().numpy())
finally:
    model.train(was_training)

assert latent_chunks, "TrainLoader yielded no batches"
# Stack all batches: shape (total number of spatial positions, C)
latent_vectors = np.concatenate(latent_chunks, axis=0)

# # Optionally, save the latent vectors to disk
# np.save('latent_vectors.npy', latent_vectors)

# NOTE(review): `centers_init` is undefined in the visible notebook (see note above). If it is
# a float64 array (e.g. k-means output), this tensor is float64 too — check the dtype expected
# by the codebook it is copied into.
new_codebook = torch.from_numpy(centers_init)