## Re-fit (Re-build and Fine-Tune) method

This method tries to combat codebook collapse (the dead-codes phenomenon, where the encoder uses only a small subset of the codes in the codebook). It was first introduced in the paper: https://arxiv.org/pdf/2112.01799

Ideally, we want the codebook to be initialized with a prior on our dataset, so that the codes and the encoder outputs are better aligned and the distances between them stay small. The idea proposed in the paper is simple: take a pre-trained VQ-VAE that was trained with a randomly initialized codebook, pass the whole dataset through its encoder, and collect all the encoder outputs into one large matrix of latent vectors. Then apply K-means to these vectors, keep the centroids, and train a new model whose codebook is initialized with those centroids.
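
As a rough formalization (the notation here is introduced for illustration and is not taken from the paper): write $z_1, \dots, z_N \in \mathbb{R}^D$ for the encoder outputs collected over the dataset and $K'$ for the size of the new codebook. K-means looks for centroids that minimize

$$
\min_{c_1, \dots, c_{K'}} \; \sum_{i=1}^{N} \min_{1 \le k \le K'} \lVert z_i - c_k \rVert_2^2 ,
$$

and the new codebook is then initialized with $e_k \leftarrow c_k$ for $k = 1, \dots, K'$.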

We can note the following:
*   The new model still depends on the previous codebook, so any improvement to the previous model can noticeably improve the new one. Conversely, if we start from a bad model, this method may not help much.
*   We can greatly reduce the number of codes used in the new model, $K' \ll K$, without losing much information. After all, experiments show that only $20\%$ of the codes of a randomly initialized codebook are used.
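
The usage figure quoted above can be checked on any trained model by assigning each latent vector to its nearest code and counting how many distinct codes are ever selected. The following is a minimal NumPy sketch of that measurement; the function name `codebook_usage` and the toy data are hypothetical and not part of this notebook's model code.

```python
import numpy as np

def codebook_usage(latents: np.ndarray, codebook: np.ndarray) -> float:
    """Fraction of codes that are the nearest code of at least one latent vector."""
    # Squared Euclidean distances, shape (N, K), via ||z||^2 - 2 z.c + ||c||^2
    d2 = (
        (latents ** 2).sum(axis=1, keepdims=True)
        - 2.0 * latents @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    assigned = d2.argmin(axis=1)            # index of the nearest code per latent
    return np.unique(assigned).size / codebook.shape[0]

# Toy example: 8 well-separated codes, but the data only lives near 2 of them
rng = np.random.default_rng(0)
toy_codebook = np.eye(8) * 10.0
toy_latents = np.repeat(toy_codebook[:2], 50, axis=0) + 0.1 * rng.normal(size=(100, 8))
usage = codebook_usage(toy_latents, toy_codebook)
print(f"codes used: {usage:.0%}")
```

On a real dataset the `(N, K)` distance matrix can become large, so the latents would probably need to be processed in chunks.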

In [11]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')

import torch.optim as optim


# Data preprocessing utils : 
from acdc_dataset import ACDC_Dataset, One_hot_Transform, load_dataset
from torchvision.transforms import Compose
from torchvision import transforms

from torch.utils.data import DataLoader


# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# my defined model
from vqVAE import VQVAE
# from vqVAE_custom import VQVAE


import warnings
warnings.filterwarnings("ignore")

In [None]:
in_channels = 4 
L = 128 # image size L=W
BATCH_SIZE = 16

new_K = 128 # We divided the number of codes by 4
D = 64


### Loading the dataset

In [17]:
dataset_path = "/home/ids/ihamdaoui-21/ACDC/database"

train_set_path = os.path.join(dataset_path, "training")
test_set_path  = os.path.join(dataset_path, "testing")


In [None]:
train_dataset = load_dataset(train_set_path)
test_dataset  = load_dataset(test_set_path)


input_transforms = Compose([
    transforms.Resize(size=(L,L), interpolation=transforms.InterpolationMode.NEAREST),
    One_hot_Transform(num_classes=4)
    ])


TrainDataset = ACDC_Dataset(data = train_dataset, transforms= input_transforms) 
# TestDataset  = ACDC_Dataset(data = test_dataset, transforms= input_transforms)

TrainLoader  = DataLoader(TrainDataset, batch_size = BATCH_SIZE, shuffle = True)
# TestLoader   = DataLoader(TestDataset , batch_size = BATCH_SIZE, shuffle = True)

In [19]:
# Check whether a GPU is available

print(torch.cuda.is_available())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

True
cuda:0


In [20]:
model_path = 'saved_models/vqvae_100_bestmodel.pth'

model = VQVAE(in_channels, 64, 512)
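# map_location lets the checkpoint load even if it was saved on a different device
model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])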
model = model.to(device)


In [36]:
model.vq_layer.embedding.weight

Parameter containing:
tensor([[ 4.9889e-04,  1.1660e-03, -3.9475e-04,  ...,  8.1173e-04,
          1.2089e-03, -1.5192e-03],
        [-4.6987e-03,  4.6395e-01, -8.3338e-03,  ..., -4.3272e-03,
          1.5351e+00, -7.0349e-03],
        [ 6.7910e-05, -6.4830e-04, -1.5788e-03,  ..., -2.7956e-04,
          1.1307e-03,  4.6125e-04],
        ...,
        [-1.1818e-03,  5.8600e-04, -1.7582e-03,  ...,  3.3327e-04,
          5.1822e-04, -1.0754e-03],
        [-1.9042e-03,  1.3782e-03, -1.3591e-03,  ..., -5.5735e-04,
          1.9091e-03, -8.8509e-04],
        [-4.5487e-04,  1.5970e-03, -5.1831e-04,  ..., -1.2115e-03,
         -2.3396e-04, -1.4929e-04]], device='cuda:0', requires_grad=True)

In [26]:
# Pass the whole training set through the encoder and collect one
# 64-dimensional latent vector per spatial position of every image.

latent_vectors = []

model.eval()  # inference mode for layers such as batch norm or dropout, if the model has any

# Process the dataset
with torch.no_grad():  # No need to track gradients
    for batch in TrainLoader:
        # Pass the batch through the encoder
        encoded = model.encode(batch.float().to(device))[0]  # Output shape: (batch_size, 64, 32, 32)
        
        # Flatten the spatial dimensions and move the channels last
        encoded_flat = encoded.view(encoded.size(0), 64, -1).permute(0, 2, 1)  # Shape: (batch_size, 1024, 64)
        
        # Merge the batch and spatial dimensions: (batch_size * 1024, 64)
        encoded_flat = encoded_flat.reshape(-1, 64)
        
        # Convert the tensor to NumPy and store it
        latent_vectors.append(encoded_flat.cpu().numpy())

# Concatenate all the latent vectors into a single NumPy array
latent_vectors = np.concatenate(latent_vectors, axis=0)  # Shape: (size_of_dataset * 1024, 64)

# # Optionally, save the latent vectors to disk
# np.save('latent_vectors.npy', latent_vectors)

In [None]:

import matplotlib.pyplot as plt
from sklearn.cluster import kmeans_plusplus

# Calculate seeds from k-means++.
# Note: these are only the initial seeds, not converged K-means centroids.
centers_init, indices = kmeans_plusplus(latent_vectors, n_clusters=new_K, random_state=0)
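
The k-means++ call above only returns seeds. To obtain actual centroids, the seeds still need to be refined with Lloyd iterations. The sketch below writes that step out in plain NumPy so it is explicit; the function `lloyd_refine` and the toy data are hypothetical, and `sklearn.cluster.KMeans` with `init=centers_init` and `n_init=1` would be a ready-made alternative.

```python
import numpy as np

def nearest_center(X: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Index of the closest center for every row of X."""
    d2 = (
        (X ** 2).sum(axis=1, keepdims=True)
        - 2.0 * X @ centers.T
        + (centers ** 2).sum(axis=1)
    )
    return d2.argmin(axis=1)

def lloyd_refine(X: np.ndarray, seeds: np.ndarray, n_iter: int = 10):
    """Refine seed centers with a fixed number of Lloyd (K-means) iterations."""
    centers = seeds.astype(np.float64, copy=True)
    for _ in range(n_iter):
        labels = nearest_center(X, centers)
        for k in range(centers.shape[0]):
            members = X[labels == k]
            if len(members) > 0:          # leave empty clusters where they are
                centers[k] = members.mean(axis=0)
    return centers, nearest_center(X, centers)

# Toy example: two blobs around (0, 0) and (10, 10), one seed taken from each
rng = np.random.default_rng(0)
X_toy = np.concatenate([
    0.1 * rng.normal(size=(50, 2)),
    10.0 + 0.1 * rng.normal(size=(50, 2)),
])
centroids, labels = lloyd_refine(X_toy, X_toy[[0, -1]])
print(centroids.round(1))
```

In this notebook the call would presumably be `lloyd_refine(latent_vectors, centers_init)`, followed by copying the result into the embedding of a new model, for instance with `new_model.vq_layer.embedding.weight.data.copy_(torch.from_numpy(centroids).float())`. Here `new_model` is an assumed name for a `VQVAE` built with `new_K` codes, and the full distance matrix may need chunking on the complete set of latent vectors.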


In [None]:
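import umap
import numpy as np

# UMAP on every latent vector can be very slow, so fit on a random subsample.
# The subsample size below is an arbitrary choice, not a tuned value.
rng = np.random.default_rng(42)
n_plot = min(20_000, len(latent_vectors))
subset = latent_vectors[rng.choice(len(latent_vectors), size=n_plot, replace=False)]

reducer = umap.UMAP(n_components=2, random_state=42)

# Fit on the subsample, then project the k-means++ centers with the same mapping
d2_latent_vectors = reducer.fit_transform(subset)
d2_centers_init   = reducer.transform(centers_init)

# Plot the results
plt.scatter(d2_latent_vectors[:, 0], d2_latent_vectors[:, 1], s=10, alpha=0.3, label='latent vectors')
plt.scatter(d2_centers_init[:, 0], d2_centers_init[:, 1], c="r", s=50, label='k-means++ centers')
plt.title('UMAP projection of the latent vectors and the initial centers')
plt.xlabel('UMAP-1')
plt.ylabel('UMAP-2')
plt.legend()
plt.show()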


