In [1]:
import torch
import numpy as np
import gc
import torch.nn as nn
# from datasets import load_dataset
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import LinearLR

from tqdm import tqdm


#vqvae libs
from vqvae import VQVAE
from utils import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F

import numpy as np

# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build first the dataset :

The dataset to be train on will be the sequences of indices outputed by the quantization layer of the VQ-VAE of each image.

In [3]:
dataset = []
,k,
# 1. Load and Preprocess the Dataset
transform = transforms.Compose([
    transforms.ToTensor(),               # Convert images to PyTorch tensors
    # transforms.Normalize((0.5,), (0.5,)) # Normalize the images to [-1, 1]
])

# Download and load the Fashion-MNIST training and test dataset
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transform, download=True)

TrainLoader = DataLoader(train_dataset, batch_size=32, shuffle=True)
TestLoader  = DataLoader(test_dataset, batch_size=32, shuffle=True)

# Class labels for reference
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']



In [255]:
K = 64
D =  32 # dimension of each embedding vector
in_channels = 1 # gray scale image  = 1 color channel
downsampling_factor = 4 # two stages of dwonsampling the image (28x28) --> (7x7)

model_path = 'saved_models/model_Refit.pth'

model_vq = VQVAE(in_channels, D, K)
model_vq.load_state_dict(torch.load(model_path)['model_state_dict'])
model_vq = model_vq.to(device)



  model_vq.load_state_dict(torch.load(model_path)['model_state_dict'])


In [7]:
# we are going to pass through the whole dataset, which results on 

dataset = []

# Process the dataset
with torch.no_grad():  # No need to track gradients
    for (batch,_) in TrainLoader:
        # Pass the batch through the encoder
        indices = model_vq.get_indices(batch.to(device))  # Output shape: [B, HW = 7x7 = 49]
        dataset.append(indices.cpu().numpy())

# Concatenate all the latent vectors into a single NumPy array
dataset = np.concatenate(dataset, axis=0)  # Shape: (size_of_dataset, 32*32)

# # Optionally, save the latent vectors to disk
# np.save('dataset.npy', dataset)

In [79]:
dataset[100]

array([44,  0, 37, 54, 49, 60, 57, 21, 17, 32, 45, 53, 17, 44, 12, 53, 14,
        7,  3, 38, 21, 47,  3,  7,  7, 53, 14, 57, 12, 27,  7, 41, 14, 37,
       63, 57,  7,  7, 14, 53, 53, 33, 44, 43, 37,  3, 24,  0, 57])

In [8]:
# Optionally, save the latent vectors to disk
np.save('sequences_dataset.npy', dataset)

In [84]:
import torch
from torch.utils.data import Dataset

class VQVAECodebookDataset(Dataset):
    def __init__(self, codebook_sequences, start_token_idx=K+1, end_token_idx=K+2):
        """
        :param codebook_sequences: A list of sequences where each sequence is a list of integers 
                                   (indices from the VQ-VAE codebook, range 0 - K-1).
        :param start_token_idx: Integer representing the [START] token (default: K).
        :param end_token_idx: Integer representing the [END] token (default: K+1).
        """
        self.codebook_sequences = codebook_sequences
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx

    def __len__(self):
        return len(self.codebook_sequences)

    def __getitem__(self, idx):
        sequence = self.codebook_sequences[idx]
        
        # Add [START] and [END] tokens to the sequence
        input_sequence = [self.start_token_idx] + sequence.tolist()  # [START] token at the beginning
        target_sequence = sequence.tolist() + [self.end_token_idx]   # [END] token at the end
        
        # Return as PyTorch tensors (1D, dtype long)
        return torch.tensor(input_sequence, dtype=torch.long), torch.tensor(target_sequence, dtype=torch.long)


In [85]:
Dataset = VQVAECodebookDataset(dataset)
TrainLoader = DataLoader(Dataset, batch_size=32, shuffle=True)

----------
### Defining the transformer

In [144]:
import torch.nn.functional as F

class SmallTransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_heads, n_layers, sequence_length):
        super(SmallTransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # self.pos_embedding = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))  # Positional Encoding
        self.register_buffer("pos_embedding", self.create_linear_positional_encoding(sequence_length, embedding_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads, dim_feedforward=hidden_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.fc_out = nn.Linear(embedding_dim, vocab_size)
        
    def create_linear_positional_encoding(self, sequence_length, embedding_dim):
        """
        Create a simple linear positional encoding where each position is scaled linearly.
        """
        # Generate positions [0, 1, ..., sequence_length - 1]
        positions = torch.arange(0, sequence_length).unsqueeze(1).float()  # Shape: [sequence_length, 1]
        # Scale positions linearly to the embedding dimension
        # Normalize by dividing by sequence_length to keep values small
        encoding = positions / sequence_length  # Shape: [sequence_length, 1]
        # Expand encoding to match the embedding dimension
        encoding = encoding * torch.linspace(0, 1, embedding_dim).unsqueeze(0)  # Shape: [sequence_length, embedding_dim]
        return encoding.unsqueeze(0)

    def forward(self, x):
        # Embed tokens + positions
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]

        # Next time decoment this, to leverage the autoregressiveness
        # # Create a causal mask to prevent attention to future tokens
        # causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        # # Transformer forward pass with the mask
        # x = self.transformer(x, src_mask=causal_mask)
        
        # Transformer forward pass
        x = self.transformer(x)

        # Output logits for each token position
        out = self.fc_out(x)
        return out



In [166]:
vocab_size = K+3 # num_embeddings of codebook plus the two special_tokens (either : [START] or [END] )
embedding_dim = 32
hidden_dim = embedding_dim*2
n_heads = 4
n_layers = 2
sequence_length = 50 # since my image are downsampled 2 times, 28x28 --> 7x7 = 49, plus the [END] or [START]

lr = 1e-4
num_epochs = 10
# Model, loss, optimizer
model = SmallTransformerModel(vocab_size, embedding_dim, hidden_dim, n_heads, n_layers, sequence_length)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)





In [208]:

lr = 3e-5

# Training loop
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    model.train()
    total_loss = 0

    with tqdm(enumerate(TrainLoader), unit="batch", total=len(TrainLoader)) as tepoch:
        for bacth in tepoch:
    # Wrap the TrainLoader with tqdm for progress tracking within each epoch
    # for batch in tqdm(TrainLoader, desc=f"Training Epoch {epoch+1}", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)


            optimizer.zero_grad()
            
            
            # Forward pass
            logits = model(x)
            
            # Reshape logits and targets for loss calculation
            logits = logits.view(-1, vocab_size)  # Flatten for all time steps
            y = y.view(-1)  # Flatten targets to match logits
            
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()

    avg_loss = total_loss / len(TrainLoader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    

100%|██████████| 1875/1875 [00:11<00:00, 163.64batch/s]
Epochs:  10%|█         | 1/10 [00:11<01:43, 11.46s/it]

Epoch 1/10, Loss: 0.6688


100%|██████████| 1875/1875 [00:11<00:00, 164.08batch/s]
Epochs:  20%|██        | 2/10 [00:22<01:31, 11.45s/it]

Epoch 2/10, Loss: 0.6682


100%|██████████| 1875/1875 [00:11<00:00, 163.13batch/s]
Epochs:  30%|███       | 3/10 [00:34<01:20, 11.47s/it]

Epoch 3/10, Loss: 0.6671


100%|██████████| 1875/1875 [00:11<00:00, 162.85batch/s]
Epochs:  40%|████      | 4/10 [00:45<01:08, 11.49s/it]

Epoch 4/10, Loss: 0.6661


100%|██████████| 1875/1875 [00:11<00:00, 168.79batch/s]
Epochs:  50%|█████     | 5/10 [00:57<00:56, 11.35s/it]

Epoch 5/10, Loss: 0.6652


100%|██████████| 1875/1875 [00:10<00:00, 175.13batch/s]
Epochs:  60%|██████    | 6/10 [01:07<00:44, 11.13s/it]

Epoch 6/10, Loss: 0.6646


100%|██████████| 1875/1875 [00:10<00:00, 176.30batch/s]
Epochs:  70%|███████   | 7/10 [01:18<00:32, 10.97s/it]

Epoch 7/10, Loss: 0.6641


100%|██████████| 1875/1875 [00:11<00:00, 168.08batch/s]
Epochs:  80%|████████  | 8/10 [01:29<00:22, 11.03s/it]

Epoch 8/10, Loss: 0.6631


100%|██████████| 1875/1875 [00:11<00:00, 165.07batch/s]
Epochs:  90%|█████████ | 9/10 [01:40<00:11, 11.14s/it]

Epoch 9/10, Loss: 0.6624


100%|██████████| 1875/1875 [00:12<00:00, 154.90batch/s]
Epochs: 100%|██████████| 10/10 [01:53<00:00, 11.30s/it]

Epoch 10/10, Loss: 0.6623





In [149]:

# #from : https://medium.com/@ikim1994914/understanding-the-modern-llm-part-3-using-pytorch-built-in-function-to-build-an-autoregressive-3feeb14496e9
# ########################## define transformer function ##########################
# #################################################################################
# class fullTransformer(nn.Module):
#     def __init__(self, device, input_size, max_length_src, max_length_tgt, d_model, nhead = 8,
#                  num_encoder_layers = 4, num_decoder_layers = 4,
#                  dim_feedforward = 1024, dropout = 0.1, pad_idx = 50257):
#         super(fullTransformer, self).__init__()
#         self.device = device
#         self.input_size = input_size # this is the # of the vocabularies in the source (how mnay tokens the tokenizer has)
#         self.output_size = input_size # input tokenizer and output tokenizer are the same, so input_size = output_size
#         self.d_model = d_model # this is the hidden, or the embedding dimension
#         self.nhead = nhead # number of multihead-attention 
#         self.enc_layer = num_encoder_layers
#         self.dec_layer = num_decoder_layers
#         self.dim_forward = dim_feedforward
#         self.dropout = dropout
#         self.max_length_src = max_length_src
#         self.max_length_tgt = max_length_tgt
#         self.pad_idx = pad_idx

        
#         # define the transformer module
#         self.transformer = nn.Transformer(d_model = self.d_model, nhead = self.nhead, num_encoder_layers = self.enc_layer,
#                                           num_decoder_layers = self.dec_layer, dim_feedforward=self.dim_forward,
#                                           dropout = self.dropout, batch_first = True, bias = True, device = self.device)
        
#         # define the embedding for the ids and the position
#         self.src_embedding = nn.Embedding(num_embeddings = self.input_size, embedding_dim = self.d_model)
#         self.tgt_embedding = nn.Embedding(num_embeddings = self.output_size, embedding_dim = self.d_model)
#         self.src_posembedding = nn.Embedding(num_embeddings = self.max_length_src, embedding_dim = self.d_model)
#         self.tgt_posembedding = nn.Embedding(num_embeddings = self.max_length_tgt, embedding_dim = self.d_model)
        
#         # expand the hidden to the output size (same as the input vocabulary)
#         self.deco_final_layer  = nn.Linear(self.d_model , self.input_size)
        
#     def forward(self, src, tgt, src_key_mask, tgt_key_mask):
#         # embed the inputs
#         src_embed = self.src_embedding(src) # src[N x T] -> [N x T x H]
#         src_pos_embed = self.src_posembedding(torch.arange(self.max_length_src).to(self.device))
#         src_total_embed = src_embed + src_pos_embed # add position embed
        
#         tgt_embed = self.tgt_embedding(tgt) # tgt [N x T] -> [N x T x H]
#         tgt_pos_embed = self.tgt_embedding(torch.arange(self.max_length_tgt).to(self.device))
#         tgt_total_embed = tgt_embed + tgt_pos_embed
        
#         # feed the embedding into the transformer
#         # this is mostly used for autoregression, but never the less, we will set it. 
#         tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz = self.max_length_tgt)
#         tgt_seq_mask = (tgt_mask == float('-inf')).to(self.device)
#         #src_seq_mask = (tgt_mask == float('-inf')).to(self.device)
        
#         # the mask provided by the hugging face is [1,1,1,0,0,0,0] -> [False, False, True]
#         # TRUE values cannot participate in attention -  this seems to be correct
#         src_key_mask = (src_key_mask.bool() != True).to(self.device)
#         tgt_key_mask = (tgt_key_mask.bool() != True).to(self.device)
        
#         # run the transformer: this will have the output of [N, T, H]
#         transformer_out = self.transformer(src = src_total_embed, tgt = tgt_total_embed,
#                                            tgt_mask = tgt_seq_mask,
#                                            src_key_padding_mask = src_key_mask , tgt_key_padding_mask = tgt_key_mask)
        
#         # run the transformer output through the final layer [N, T, H] -> [N, T, vocab]
#         final_output = self.deco_final_layer(transformer_out)
        
#         return final_output

In [209]:
# Save the state dictionary
torch.save(model.state_dict(), "saved_models/prior_model.pth")


# # # Load model
# model = SmallTransformerModel(vocab_size, embedding_dim, hidden_dim, n_heads, n_layers, sequence_length)
# model.load_state_dict(torch.load("model.pth"))
# model.to(device)


In [222]:
# def generate_sequence(model, start_token_idx = K+1, end_token_idx = K+2, max_len=50):
#     model.eval()
#     device = next(model.parameters()).device
#     generated_sequence = [start_token_idx]  # Start with the [START] token
    
#     for _ in range(max_len - 1):  # Generate up to max_len tokens
#         input_seq = torch.tensor([generated_sequence], dtype=torch.long).to(device)
#         logits = model(input_seq)
        
#         # Get the most likely next token (argmax)
#         next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        
#         generated_sequence.append(next_token)
        
#         # Optionally, stop generation if [END] token is predicted
#         if next_token == end_token_idx:
#             break
    
#     return generated_sequence

In [262]:
def generate_sequence(
    model, 
    start_token_idx=K+1, 
    end_token_idx=K+2, 
    max_len=50, 
    temperature=0.8, 
    top_k=None
):
    model.eval()
    device = next(model.parameters()).device
    generated_sequence = [start_token_idx]  # Start with the [START] token

    for _ in range(max_len - 1):  # Generate up to max_len tokens
        input_seq = torch.tensor([generated_sequence], dtype=torch.long).to(device)
        logits = model(input_seq)
        
        # Extract logits for the last token in the sequence
        next_token_logits = logits[:, -1, :]

        # Apply temperature scaling
        next_token_logits = next_token_logits / temperature

        if top_k is not None:
            # Apply top-k sampling
            next_token = top_k_sampling(next_token_logits, k=top_k)
        else:
            # Default to greedy decoding (argmax)
            next_token = torch.argmax(next_token_logits, dim=-1).item()
        
        generated_sequence.append(next_token)
        
        # Stop generation if [END] token is predicted
        # if next_token == end_token_idx:
        #     break

    # Return the sequence excluding [START] and [END] tokens
    return generated_sequence[1:] if generated_sequence[-1] == end_token_idx else generated_sequence[1:]


In [263]:
generated_seq = generate_sequence(model)
print(generated_seq)

[33, 0, 23, 1, 18, 2, 66, 36, 28, 3, 21, 0, 20, 12, 19, 20, 12, 19, 20, 12, 19, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54]


# Conclusion:

>> the model fails to converge :P