In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from vector_quantize_pytorch import FSQ, VectorQuantize
import math
import einops
from fancy_einsum import einsum
import numpy as np

### TRANSFORMER LENS ###
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

## Cache residual streams

In [None]:
# Load the model
PTH_LOCATION = "data/transformer_lens.pth"
model_dict = torch.load(PTH_LOCATION)
model = HookedTransformer(model_dict["config"])
model.load_state_dict(torch.load(PTH_LOCATION)["model"])

In [None]:
# Load our tensors
train_data = torch.load("data/train_data.pt")
eval_data = torch.load("data/eval_data.pt")
print(train_data.shape, eval_data.shape)

In [None]:
_, train_cache = model.run_with_cache(train_data)
_, eval_cache = model.run_with_cache(eval_data)

In [None]:
train_residual_stream, train_labels = train_cache.decompose_resid(return_labels=True)
eval_residual_stream, eval_labels = eval_cache.decompose_resid(return_labels=True)

In [None]:
train_labels

In [None]:
train_residual_stream.shape # (n_layers, n_examples, seq_len, d_model)

In [None]:
# We want each of the indices in the seq_len to be a separate example
train_residual_stream = einops.rearrange(train_residual_stream, "layers examples seq_len d_model -> layers (examples seq_len) d_model")
eval_residual_stream = einops.rearrange(eval_residual_stream, "layers examples seq_len d_model -> layers (examples seq_len) d_model")
print(train_residual_stream.shape)

In [None]:
# Save residual streams
torch.save(train_residual_stream, "data/train_residual_stream.pt")
torch.save(eval_residual_stream, "data/eval_residual_stream.pt")

## VQ-VAE

In [6]:
class TransformerAutoencoder(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, 
                 codebook_size=128, codebook_dim= 16, threshold_ema_dead_code=2, dropout=0.1):
        super(TransformerAutoencoder, self).__init__()
        self.input_dim = input_dim
        self.d_model = d_model

        # VQ Quantizer
        dim = 256
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.quantizer = VectorQuantize(
            dim=dim,  # Assuming the dimensionality to match d_model for simplicity
            codebook_size=codebook_size,  # Example codebook size
            codebook_dim=codebook_dim,  # This is an illustrative example, adjust based on your model's needs
            decay=0.8,
            commitment_weight=1.0,
            use_cosine_sim=True,  # Example, adjust as needed
            threshold_ema_dead_code = threshold_ema_dead_code
        )
        self.bottleneck_dim = dim

        # Positional Encoding
        self.positional_encoder = PositionalEncoding(d_model, dropout)

        # Encoder Layer
        encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)

        # Linear projection down - replaces torch.zeros with nn.Linear
        self.encoder_output_projection = nn.Linear(d_model, self.bottleneck_dim)

        # Linear projection up - replaces torch.zeros with nn.Linear
        self.decoder_input_projection = nn.Linear(self.bottleneck_dim, d_model)

        # Decoder Layer
        decoder_layers = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_decoder_layers)

        self.encoder_input_projection = nn.Linear(input_dim, d_model)
        self.decoder_output_projection = nn.Linear(d_model, input_dim)

    def forward(self, src):
        # Encode
        src = self.encoder_input_projection(src)
        src = self.positional_encoder(src)
        memory = self.transformer_encoder(src)
        
        # Apply the encoder output projection down
        memory = F.relu(self.encoder_output_projection(memory))

        # Vector quantize the memory
        quantized_memory, _, commit_loss = self.quantize(memory)
        #print(f"Quantised Memory shape = {quantized_memory.shape}")

        # Decode
        quantized_memory = F.relu(self.decoder_input_projection(quantized_memory))
        output = self.transformer_decoder(quantized_memory, quantized_memory)
        output = self.decoder_output_projection(output)
        return output, commit_loss

    def quantize(self, bottleneck):
        quantized, indices, commit_loss = self.quantizer(bottleneck)
        return quantized, indices, commit_loss

    def quantized_indices(self, src):
        # Encode
        src = self.encoder_input_projection(src)
        src = self.positional_encoder(src)
        memory = self.transformer_encoder(src)
        
        # Apply the encoder output projection down
        memory = F.relu(self.encoder_output_projection(memory))

        # Vector quantize the memory
        quantised, indices, _ = self.quantize(memory)

        return quantised, indices
    
    def indices_to_rep(self, indices):
        with torch.no_grad():
            low_dim_vectors = self.quantizer.get_codes_from_indices([indices])
            quantized_memory = self.quantizer.project_out(low_dim_vectors)
            # Decode
            quantized_memory = F.relu(self.decoder_input_projection(quantized_memory))
            output = self.transformer_decoder(quantized_memory, quantized_memory)
            output = self.decoder_output_projection(output)
            return output

        
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [13]:
train_streams = torch.load("data/train_residual_stream.pt")
eval_streams = torch.load("data/eval_residual_stream.pt")

In [16]:
# Example Configuration
input_dim = train_streams.shape[-1]  # Size of the input
sequence_length = train_streams.shape[0]  # Length of the sequence
d_model = 64  # The number of expected features in the encoder/decoder inputs
nhead = 8  # The number of heads in the multiheadattention models
num_encoder_layers = 1  # The number of sub-encoder-layers in the encoder
num_decoder_layers = 1  # The number of sub-decoder-layers in the decoder
dim_feedforward = 1024  # The dimension of the feedforward network model
dropout = 0.1  # The dropout value
codebook_size=32
codebook_dim=8

In [17]:
model = TransformerAutoencoder(input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, 
                               codebook_size=codebook_size, dropout=dropout)

In [18]:
x = train_streams[:, 0, :].unsqueeze(1)
x.shape

torch.Size([4, 1, 32])

In [22]:
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adjust the learning rate as needed
loss_function = nn.MSELoss()
epochs = 10

In [23]:
def train(model, epochs, optimizer, train_streams, eval_streams, verbose=True, print_epoch=None):

    train_losses, eval_losses, unique_indices_utilised = [], [], []

    for epoch in range(epochs):
        model.train()  # Set the model to training mode
        epoch_loss = 0
        
        optimizer.zero_grad()  # Zero the gradients
        outputs, commit_loss = model(train_streams)  # Forward pass: compute the model output
        loss = loss_function(outputs, train_streams)  # Compute the loss
        loss += commit_loss[0]
        
        loss.backward()  # Backward pass: compute gradient of the loss with respect to model parameters
        optimizer.step()  # Perform a single optimization step (parameter update)
        
        epoch_loss += loss.item()  # Accumulate the loss

        divider = epochs // 10 if print_epoch is None else print_epoch

        if epoch % divider == 0:
            # Calculate eval loss
            model.eval()  # Set the model to evaluation mode
            eval_loss = 0
            with torch.no_grad():  # No need to track the gradients
                eval_outputs, eval_commit_loss = model(eval_streams)  # Forward pass: compute the model output
                eval_loss = loss_function(eval_outputs, eval_streams)  # Compute the loss
                eval_loss += eval_commit_loss[0]
            # Calculate unique codes on train
            _, train_indices = model.quantized_indices(train_streams.cpu())
            unique_indices = torch.unique(train_indices)
            unique_indices = len(unique_indices)
            unique_indices_utilised.append(unique_indices)

            if verbose:
                print(f"Epoch [{epoch}/{epochs}], Train Loss: {epoch_loss:.3f}, Eval Loss: {eval_loss:.3f}, Unique Indices: {unique_indices}")
            train_losses.append(epoch_loss)
            eval_losses.append(eval_loss)

    return model, train_losses, eval_losses, unique_indices_utilised

In [24]:
model, train_losses, eval_losses, unique_indices_utilised = train(model, epochs, optimizer, train_streams, eval_streams)

RuntimeError: [enforce fail at alloc_cpu.cpp:125] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 432041713463808 bytes. Error code 12 (Cannot allocate memory)