<a href="https://colab.research.google.com/github/chetxn04/Discrete_KV_Bottleneck/blob/main/NUS_Bigger_Experiment_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# This code is not yet complete. I paused this task for the time being because the resources required to train this model with the specified parameters were unavailable.
# This notebook serves to show my progress so far.
# Refer to the file titled NUS_Final.ipynb in my GitHub repository for a smaller-scale implementation of the Discrete Key-Value Bottleneck.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader


# Pick the compute device up front: CUDA when PyTorch can see a GPU
# (this notebook was moved to a T4 GPU runtime), otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [None]:
# The next step involves preparing the datasets

# CIFAR 10 will be the main dataset used for classification
# CIFAR 100 will be used for initializing the keys in the KV bottleneck

import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

#
def show_images(images, labels, classes, nrow=8, mean=0.5, std=0.5):
  """Display a batch of images in a grid, titled with their class names.

  Args:
    images: Sequence (or batch tensor) of image tensors. Each image is
      transposed with axes (1, 2, 0) before display, i.e. it is treated as
      channel-first.
    labels: Per-image labels used to index ``classes``.
    classes: Mapping/sequence from label to the title shown above the image.
    nrow: Requested number of grid rows. It is reduced when there are fewer
      images than rows so the grid never has zero columns.
    mean: Value(s) added back when undoing normalisation. A scalar or a
      per-channel sequence.
    std: Value(s) multiplied back when undoing normalisation. A scalar or a
      per-channel sequence.

  The defaults (``mean=0.5, std=0.5``) reproduce the previous hard-coded
  ``img / 2 + 0.5``. Pass the statistics given to ``transforms.Normalize``
  to recover the true colours of images normalised with other values.

  Returns:
    None. The figure is shown with ``plt.show()``.
  """
  num_images = len(images)
  if num_images == 0:
    # Nothing to draw; plt.subplots would reject an empty grid.
    return

  # Never ask for more rows than images, and round the column count up so
  # that a batch size not divisible by nrow still shows every image.
  nrow = max(1, min(nrow, num_images))
  ncol = -(-num_images // nrow)

  # squeeze=False guarantees a 2-D array of axes even for a 1x1 grid.
  fig, axs = plt.subplots(nrow, ncol, figsize=(12, 6), squeeze=False)
  axs = axs.flatten()

  # Shape (C or 1, 1, 1) so the statistics broadcast over channel-first images.
  mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
  std_arr = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)

  for img, label, ax in zip(images, labels, axs):
    npimg = img.detach().cpu().numpy() * std_arr + mean_arr
    # imshow expects float data in [0, 1]; clip to avoid out-of-range warnings.
    ax.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0))
    # Tensor labels are converted to plain Python values before indexing.
    key = label.item() if hasattr(label, "item") else label
    ax.set_title(classes[key])
    ax.axis('off')

  # Hide the grid cells that did not receive an image.
  for ax in axs[num_images:]:
    ax.axis('off')

  plt.tight_layout()
  plt.show()


# Normalisation step shared by the training and evaluation pipelines.
_cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random flip + padded random crop, then tensor conversion and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        _cifar_normalize,
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transforms = transforms.Compose([transforms.ToTensor(), _cifar_normalize])

# 2.2 Loading datasets (downloaded into ./data when not already present)
_data_root = './data'

print("Loading CIFAR-10 dataset....")
train_dataset = datasets.CIFAR10(root=_data_root, train=True, transform=train_transforms, download=True)
test_dataset = datasets.CIFAR10(root=_data_root, train=False, transform=test_transforms, download=True)

print("Loading CIFAR-100 dataset for key initialization..")
cifar100_dataset = datasets.CIFAR100(root=_data_root, train=True, transform=train_transforms, download=True)

# Every loader uses the same batch size and worker count; only shuffling differs.
print("Creating data loaders...")
_loader_options = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **_loader_options)
cifar100_loader = DataLoader(dataset=cifar100_dataset, shuffle=True, **_loader_options)

# Pull one batch from each training-side loader as a shape sanity check.
print("Checking batch dimensions...")
train_batch = next(iter(train_loader))
cifar100_batch = next(iter(cifar100_loader))
print(f"CIFAR-10 training batch shape (images): {train_batch[0].shape}, labels shape: {train_batch[1].shape}")
print(f"CIFAR-100 batch shape (images): {cifar100_batch[0].shape}, labels shape: {cifar100_batch[1].shape}")

Loading CIFAR-10 dataset...
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:18<00:00, 9172418.42it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Loading CIFAR-100 dataset for key initialization...
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:12<00:00, 13023655.18it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Creating data loaders...
Checking batch dimensions...




CIFAR-10 training batch shape (images): torch.Size([32, 3, 32, 32]), labels shape: torch.Size([32])
CIFAR-100 batch shape (images): torch.Size([32, 3, 32, 32]), labels shape: torch.Size([32])


In [None]:
from torchvision import models

class Encoder(nn.Module):
    """Pretrained ResNet-18 whose final ``fc`` layer is replaced by a 512-unit linear layer."""

    def __init__(self):
        super().__init__()
        # ResNet-18 was picked because it is the lighter backbone.
        backbone = models.resnet18(pretrained=True)
        # Swap the classification head for a fresh linear layer with 512 outputs.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, 512)
        self.resnet18 = backbone

    def forward(self, x):
        """Run ``x`` through the ResNet-18 and return its output."""
        return self.resnet18(x)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import numpy as np
from sklearn.cluster import KMeans

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch, and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to all three random number generators.
    """
    # Same seed for every RNG, applied in the order random -> numpy -> torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def generate_random_projection_matrices(input_dim, output_dim, num_codebooks):
    """Draw one Gaussian projection matrix per codebook.

    Args:
        input_dim: Number of columns of each matrix.
        output_dim: Number of rows of each matrix.
        num_codebooks: How many matrices to draw.

    Returns:
        A list of ``num_codebooks`` tensors of shape ``(output_dim, input_dim)``
        sampled with ``torch.randn``.
    """
    projections = []
    for _ in range(num_codebooks):
        projections.append(torch.randn(output_dim, input_dim))
    return projections

class DiscreteKeyValueBottleneck(nn.Module):
    """Multi-codebook discrete key-value bottleneck.

    Each codebook owns a fixed random projection, ``num_keys`` frozen keys in
    the projected space, and ``num_keys`` learnable values. ``forward`` looks
    up, per codebook, the value whose key is nearest to the projected input
    and averages the selected values over codebooks.

    Compared with the earlier version of this class:
      * ``ema_keys`` and the projection matrices are registered as buffers, so
        they move with ``.to(device)`` and are saved/restored by
        ``state_dict`` / ``load_state_dict``. Previously a reloaded model got
        freshly drawn projections that no longer matched its saved keys.
        Checkpoints written by the earlier version lack these entries and need
        ``strict=False`` to load.
      * The first ``update_keys`` call for a codebook adopts the supplied keys
        directly. Previously the EMA started from zeros, so the first update
        stored only 0.1 x the supplied keys.

    Args:
        num_codebooks: Number of independent codebooks.
        num_keys: Number of key/value pairs per codebook.
        key_dim: Dimensionality of the incoming feature vectors.
        value_dim: Dimensionality of each stored value (and of the output).
        ema_decay: Weight kept from the previous EMA state in ``update_keys``;
            the default 0.9 matches the previously hard-coded 0.9 / 0.1 split.

    Raises:
        ValueError: If ``key_dim // num_codebooks`` is smaller than 1.
    """

    def __init__(self, num_codebooks=4, num_keys=15, key_dim=512, value_dim=512, ema_decay=0.9):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.num_keys = num_keys
        self.ema_decay = ema_decay

        # Each codebook works in a lower-dimensional projected space.
        reduced_dim = key_dim // num_codebooks
        if reduced_dim < 1:
            raise ValueError(
                f"key_dim ({key_dim}) must be at least num_codebooks ({num_codebooks})"
            )

        # Keys are frozen (no gradient); only the values are trained.
        self.keys = nn.Parameter(torch.randn(num_codebooks, num_keys, reduced_dim), requires_grad=False)
        self.values = nn.Parameter(torch.randn(num_codebooks, num_keys, value_dim))

        # EMA state for the keys plus a per-codebook flag marking whether the
        # EMA has received its first update yet.
        self.register_buffer("ema_keys", torch.zeros(num_codebooks, num_keys, reduced_dim))
        self.register_buffer("ema_initialized", torch.zeros(num_codebooks, dtype=torch.bool))

        # Fixed Gaussian projections, stacked to (num_codebooks, reduced_dim, key_dim)
        # and stored as a buffer so they are checkpointed with the keys.
        projections = generate_random_projection_matrices(key_dim, reduced_dim, num_codebooks)
        self.register_buffer("projection_stack", torch.stack(projections))

    @property
    def projection_matrices(self):
        """List of per-codebook projection matrices as CPU tensors.

        Kept for callers that index ``projection_matrices[i]`` and convert the
        result to NumPy; the entries reflect the current buffer contents.
        """
        return [matrix.detach().cpu() for matrix in self.projection_stack]

    def forward(self, features):
        """Quantise ``features`` through every codebook.

        Args:
            features: Tensor of shape ``(batch, key_dim)`` on the module's device.

        Returns:
            Tuple ``(combined_values, indices)`` where ``combined_values`` has
            shape ``(batch, value_dim)`` (mean of the selected values over
            codebooks) and ``indices`` has shape ``(batch, num_codebooks)``
            (index of the nearest key in each codebook).
        """
        features = F.normalize(features, dim=1)

        per_codebook_distances = []
        for i in range(self.num_codebooks):
            # (batch, reduced_dim): the buffer already lives on the module's device.
            projected = features @ self.projection_stack[i].t()
            # (batch, 1, reduced_dim) vs (1, num_keys, reduced_dim) -> (batch, num_keys)
            per_codebook_distances.append(
                torch.cdist(projected.unsqueeze(1), self.keys[i].unsqueeze(0)).squeeze(1)
            )
        distances = torch.stack(per_codebook_distances, dim=1)  # (batch, num_codebooks, num_keys)

        # Nearest key per codebook.
        indices = distances.argmin(dim=2)

        # Gather values[c, indices[b, c]] for every (b, c) in one indexing op.
        codebook_ids = torch.arange(self.num_codebooks, device=indices.device)
        selected_values = self.values[codebook_ids, indices]  # (batch, num_codebooks, value_dim)

        # Average the selected values from all codebooks.
        combined_values = selected_values.mean(dim=1)

        return combined_values, indices

    def update_keys(self, new_keys, codebook_idx):
        """Fold ``new_keys`` into the EMA of one codebook and copy the EMA into ``keys``.

        Args:
            new_keys: Tensor of shape ``(num_keys, reduced_dim)``; it is moved
                to the EMA buffer's device and dtype.
            codebook_idx: Index of the codebook to update.
        """
        new_keys = new_keys.detach().to(device=self.ema_keys.device, dtype=self.ema_keys.dtype)

        with torch.no_grad():
            if bool(self.ema_initialized[codebook_idx]):
                self.ema_keys[codebook_idx].mul_(self.ema_decay).add_(new_keys, alpha=1.0 - self.ema_decay)
            else:
                # First update: start the EMA at the supplied keys instead of
                # blending them with the all-zero initial state.
                self.ema_keys[codebook_idx].copy_(new_keys)
                self.ema_initialized[codebook_idx] = True

            self.keys.data[codebook_idx].copy_(self.ema_keys[codebook_idx])  # Copy EMA to keys

    def print_gaussian_matrices(self):
        """Print every projection matrix for visual verification."""
        for idx, matrix in enumerate(self.projection_matrices):
            print(f"Gaussian Projection Matrix for Codebook {idx}:")
            print(matrix)

def initialize_keys_with_cifar100(encoder, bottleneck, num_codebooks=4, num_keys=15, device='cpu'):
    """Initialise the bottleneck keys by clustering encoder features of CIFAR-100.

    CIFAR-100 training images are encoded, L2-normalised, projected with each
    codebook's projection matrix and clustered with KMeans; the cluster
    centres are handed to ``bottleneck.update_keys``.

    Changes relative to the earlier version:
      * Features are L2-normalised before projection, matching what
        ``DiscreteKeyValueBottleneck.forward`` does to its input, so keys and
        lookup queries live on the same scale.
      * The encoder is put in eval mode while features are extracted (and its
        previous mode restored afterwards), so BatchNorm layers use their
        running statistics and are not modified by this pass.
      * Images get the same ``Normalize`` statistics the notebook's
        training/test pipelines use.
      * Projection matrices are moved to CPU before the NumPy conversion and
        the cluster centres are cast to float32.

    Args:
        encoder: Module mapping an image batch to feature vectors.
        bottleneck: Object exposing ``projection_matrices`` and ``update_keys``.
        num_codebooks: Number of codebooks to initialise.
        num_keys: Number of clusters (keys) per codebook.
        device: Device on which the encoder is run.

    Returns:
        The same ``bottleneck`` object, with its keys updated.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Same statistics as the pipelines that feed the encoder during training.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    cifar100_data = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    cifar100_loader = DataLoader(cifar100_data, batch_size=32, shuffle=True, num_workers=2)

    features_list = []

    # Extract features in eval mode without tracking gradients.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            for images, _ in tqdm(cifar100_loader, desc="Extracting Features"):
                images = images.to(device)
                # Normalise exactly as the bottleneck's forward pass does.
                features = F.normalize(encoder(images), dim=1).cpu()
                features_list.append(features)
    finally:
        encoder.train(was_training)

    all_features = torch.cat(features_list, dim=0).numpy()  # Convert to NumPy for KMeans

    # Project with each codebook's matrix and cluster to obtain that codebook's keys.
    for i in range(num_codebooks):
        projection = bottleneck.projection_matrices[i].detach().cpu().t().numpy()
        projected_features = all_features @ projection

        kmeans = KMeans(n_clusters=num_keys, random_state=42 + i)
        kmeans.fit(projected_features)
        initial_keys = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=device)

        bottleneck.update_keys(initial_keys, i)

    return bottleneck

# Driver for key initialization: seed, build the encoder and bottleneck, then
# derive the bottleneck keys from CIFAR-100 features.

# NOTE: this rebinds `device` to a plain string ('cuda' / 'cpu').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before any module is constructed so the random initial values drawn
# below (keys, values, projection matrices) come from the seeded generators.
set_seed(42)

encoder = Encoder().to(device)
bottleneck = DiscreteKeyValueBottleneck(num_codebooks=4, num_keys=15, key_dim=512, value_dim=512).to(device)
# Returns the same bottleneck object after its keys have been updated.
bottleneck = initialize_keys_with_cifar100(encoder, bottleneck, num_codebooks=4, num_keys=15, device=device)

# Visual check of the projection matrices (prints one matrix per codebook).
bottleneck.print_gaussian_matrices()

print("Keys after initialization with CIFAR-100 using clustering:")
print(bottleneck.keys)




Files already downloaded and verified


Extracting Features: 100%|██████████| 1563/1563 [00:18<00:00, 84.36it/s]


Gaussian Projection Matrix for Codebook 0:
tensor([[-0.7105, -0.0198, -0.0777,  ...,  2.2460,  3.0764,  0.0519],
        [-0.6180, -0.2051,  1.4381,  ..., -0.4459,  0.5250, -1.2413],
        [-1.0603, -0.3556,  0.9289,  ...,  0.1741, -0.9815,  1.9251],
        ...,
        [ 0.3850, -0.8261, -1.4157,  ..., -0.2275, -0.5870, -1.1628],
        [-0.5510,  0.4117, -0.5996,  ...,  0.4430,  0.0566, -0.6701],
        [-0.6224,  1.6981, -0.4885,  ...,  0.7364, -1.0163, -1.0466]])
Gaussian Projection Matrix for Codebook 1:
tensor([[ 0.1659,  0.6571,  1.5753,  ...,  0.6519, -0.3043, -2.1921],
        [-0.7356, -1.1633, -0.1480,  ...,  0.8037, -0.3714, -0.1970],
        [-0.5265, -0.3033, -0.1243,  ...,  0.2786, -0.8021, -1.7504],
        ...,
        [-2.2782, -1.0480, -0.4143,  ...,  1.2378,  0.4139, -1.2492],
        [-2.5349,  0.6819,  0.5568,  ...,  0.7243,  0.3141,  1.2299],
        [-0.8147, -1.2754,  1.7540,  ...,  0.7978, -1.7152, -0.1408]])
Gaussian Projection Matrix for Codebook 2:
ten

In [None]:
# Visual verification code (not important)
# import pandas as pd

# def display_keys_and_values(bottleneck):
#     # Convert keys and values to NumPy arrays
#     keys_np = bottleneck.keys.detach().cpu().numpy()
#     values_np = bottleneck.values.detach().cpu().numpy()

#     # Reshape keys and values for easier display
#     num_codebooks, num_keys, key_dim = keys_np.shape
#     _, _, value_dim = values_np.shape

#     # Flatten the keys and values while keeping the codebook index
#     keys_flattened = keys_np.reshape(num_codebooks * num_keys, key_dim)
#     values_flattened = values_np.reshape(num_codebooks * num_keys, value_dim)

#     # Create DataFrames for keys and values
#     keys_df = pd.DataFrame(keys_flattened, columns=[f'Key_Dim_{i}' for i in range(key_dim)])
#     values_df = pd.DataFrame(values_flattened, columns=[f'Value_Dim_{i}' for i in range(value_dim)])
#     codebook_df = pd.DataFrame({'Codebook': [f'Codebook_{i // num_keys}' for i in range(num_codebooks * num_keys)]})

#     # Concatenate the DataFrames along the columns
#     df = pd.concat([codebook_df, keys_df, values_df], axis=1)

#     # Display the DataFrame
#     print("Keys and Corresponding Values in Tabular Format:")
#     print(df)

# # Example usage
# display_keys_and_values(bottleneck)


Keys and Corresponding Values in Tabular Format:
      Codebook  Key_Dim_0  Key_Dim_1  Key_Dim_2  Key_Dim_3  Key_Dim_4  \
0   Codebook_0   0.524784  -0.798257   0.059302   1.957101   4.585316   
1   Codebook_0  -0.900870   0.304977  -1.355916  -1.029992   2.378527   
2   Codebook_0   0.262564  -0.348699  -0.146293  -1.154292   3.075608   
3   Codebook_0  -0.440521  -0.908660  -0.933542   0.570259   1.705456   
4   Codebook_0   0.526088  -2.328063  -1.483508   1.050314   1.978939   
5   Codebook_0  -0.423469   0.048299  -0.646480   1.040738   2.162914   
6   Codebook_0   0.139213  -0.881782  -1.381922  -0.757783   0.144770   
7   Codebook_0  -2.110939  -0.358238  -0.654760   0.604697   1.642229   
8   Codebook_0  -1.522163  -0.688197  -0.870066  -0.722033   1.638394   
9   Codebook_0  -0.062502  -0.486753  -0.429995  -0.080553   1.425506   
10  Codebook_0   0.050209  -2.076406  -1.023960  -0.318207   1.209471   
11  Codebook_0  -0.884760  -0.797551  -0.839556   0.469589   1.770886   
12

In [None]:
class Decoder(nn.Module):
    """Single linear layer that turns bottleneck outputs into class scores.

    Args:
        input_dim: Size of each input vector.
        num_classes: Number of output scores per input.
    """

    def __init__(self, input_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the class scores produced by the linear layer for ``x``."""
        return self.fc(x)


In [None]:
import torch
import torch.nn as nn

class CompleteModel(nn.Module):
    """Chains an encoder, a bottleneck and a decoder into one module.

    Args:
        encoder: Module applied to the raw input.
        bottleneck: Module returning ``(combined_values, indices)`` for the
            encoder output.
        decoder: Module applied to the bottleneck's combined values.
    """

    def __init__(self, encoder, bottleneck, decoder):
        super().__init__()
        self.encoder, self.bottleneck, self.decoder = encoder, bottleneck, decoder

    def forward(self, x):
        """Return ``(logits, indices)`` for the input batch ``x``."""
        # Encoder features go straight into the bottleneck, which yields the
        # combined values together with the selected key indices.
        combined_values, indices = self.bottleneck(self.encoder(x))
        # The decoder output is the unnormalised class scores (logits).
        return self.decoder(combined_values), indices



In [None]:
# This code is not yet complete. I paused this task for the time being because the resources required to train this model with the specified parameters were unavailable.
# This notebook serves to show my progress so far.
# Refer to the file titled NUS_Final.ipynb in my GitHub repository for a smaller-scale implementation of the Discrete Key-Value Bottleneck.


# Training driver: trains the bottleneck values and the decoder on CIFAR-10.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reuse the encoder and the bottleneck from the key-initialization cell.
# Building new ones here would throw away the CIFAR-100-initialized keys and
# pair the bottleneck with an encoder whose randomly initialized `fc` layer
# differs from the one the keys were computed with.
encoder = encoder.to(device)
bottleneck = bottleneck.to(device)
decoder = Decoder(input_dim=512, num_classes=10).to(device)  # input_dim must equal the bottleneck's value_dim

complete_model = CompleteModel(encoder, bottleneck, decoder).to(device)

# The bottleneck selects values by nearest-key index, so no gradient reaches
# the encoder. Freeze it explicitly so no graph is built for it and so the
# optimizer only tracks parameters that can actually change.
for param in complete_model.encoder.parameters():
    param.requires_grad_(False)

num_epochs = 2000
learning_rate = 0.005
criterion = nn.CrossEntropyLoss()
trainable_params = [param for param in complete_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

for epoch in range(num_epochs):
    complete_model.train()
    # Keep the frozen encoder in eval mode so its BatchNorm running statistics
    # (and therefore the features the keys were fitted to) do not drift.
    complete_model.encoder.eval()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        logits, indices = complete_model(images)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    average_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')

    # Checkpoint after every epoch so an interrupted run keeps its progress;
    # after the last epoch the file holds the final weights.
    torch.save(complete_model.state_dict(), 'complete_model.pth')


Epoch [1/200], Loss: 2.3557
Epoch [2/200], Loss: 2.3006
Epoch [3/200], Loss: 2.2775


KeyboardInterrupt: 