# Setup

In [None]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Ask PyTorch's CUDA caching allocator to release cached blocks it is not using.
# (This does not free memory that is still referenced by live tensors.)
torch.cuda.empty_cache()

# Show how much CUDA memory is currently allocated.
print(torch.cuda.memory_allocated())

# Standard Causal MLP Mixer

In [13]:
# Define an MLP-Mixer-based causal language model using weight masking

class CausalLinear(nn.Module):
    """
    Square linear map over the last axis with a fixed upper-triangular mask.

    The learnable weight has shape (in_dim, out_dim). On every forward pass it
    is multiplied elementwise by ``torch.triu(ones)``, which zeroes ``W[i, j]``
    whenever ``i > j``. Because the input is right-multiplied by ``W``, output
    position ``j`` only receives contributions from input positions ``i <= j``.

    Raises:
        NotImplementedError: if ``in_dim != out_dim``.
    """
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim != out_dim:
            raise NotImplementedError("Only square matrices are supported.")

        # Learnable parameters: a dense weight (masked at use time) and one
        # bias value per output position.
        self.weight = nn.Parameter(torch.randn(in_dim, out_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

        # Constant 0/1 mask, registered as a buffer so it follows the module
        # across devices without being trained.
        self.register_buffer('mask', torch.triu(torch.ones(in_dim, out_dim)))

    def forward(self, x):
        """
        Apply the masked linear map along the last axis.

        x shape: (batch, embed_dim, seq_len); the output has the same shape.
        """
        batch, embed, seq = x.shape
        masked_weight = self.weight * self.mask

        # Fold batch and embedding axes together so a single 2-D matmul
        # handles every row; the bias broadcasts over those rows.
        rows = x.reshape(batch * embed, seq)
        mixed = rows @ masked_weight + self.bias

        return mixed.view(batch, embed, seq)

class MixerBlock(nn.Module):
    """
    One residual Mixer block: channel mixing followed by causal token mixing.

    Each half normalises its input with RMSNorm over ``hidden_dim``, applies a
    two-layer SiLU MLP with dropout, and adds the result back to its input.
    The token-mixing MLP is built from ``CausalLinear`` layers and runs on the
    input with axes 1 and 2 swapped.

    Args:
        hidden_dim: size of the last (channel) axis.
        seq_len: size of the axis mixed by the CausalLinear layers.
        expansion_factor: width multiplier of the channel MLP's hidden layer.
        dropout: dropout probability used inside both MLPs.
    """

    def __init__(
        self,
        hidden_dim:int,
        seq_len:int,
        expansion_factor:int=2,
        dropout:float=0.1):

        super().__init__()

        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.expansion_factor = expansion_factor

        expanded_dim = hidden_dim * expansion_factor

        # Channel half: norm + MLP acting on the last axis.
        self.channel_norm = nn.RMSNorm(hidden_dim)
        self.channel_mixing_layer = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, hidden_dim),
        )

        # Token half: norm + causal MLP acting on the sequence axis.
        self.token_norm = nn.RMSNorm(hidden_dim)
        self.token_mixing_layer = nn.Sequential(
            CausalLinear(seq_len, seq_len),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalLinear(seq_len, seq_len),
        )

    def forward(self, x):
        """Run both mixing halves; the output has the same shape as ``x``."""
        # Channel mixing with residual connection.
        channel_out = self.channel_mixing_layer(self.channel_norm(x))
        x = channel_out + x

        # Token mixing: swap axes 1 and 2 so the sequence axis is last for
        # CausalLinear, then swap back before adding the residual.
        token_in = self.token_norm(x).transpose(1, 2)
        token_out = self.token_mixing_layer(token_in).transpose(1, 2)

        return token_out + x

class MLPMixer(nn.Module):
    """
    Causal language model: embedding -> stack of MixerBlock -> tied output projection.

    Args:
        vocab_size: number of distinct token ids.
        hidden_dim: embedding / channel width.
        seq_len: sequence length the MixerBlocks are built for.
        num_blocks: number of MixerBlock layers.
    """

    def __init__(
        self,
        vocab_size:int,
        hidden_dim:int,
        seq_len:int,
        num_blocks:int):

        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.num_blocks = num_blocks

        # Token embedding.
        self.input_layer = nn.Embedding(vocab_size, hidden_dim)

        # Stack of mixer blocks.
        blocks = [MixerBlock(hidden_dim, seq_len) for _ in range(num_blocks)]
        self.mixer_blocks = nn.ModuleList(blocks)

        # Bias-free projection back to the vocabulary; it shares its weight
        # tensor with the embedding.
        self.output_layer = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.output_layer.weight = self.input_layer.weight

        # Runs after the tying above, so the shared tensor is re-initialised
        # through output_layer.
        self._init_weights()

        self.loss_fn = nn.CrossEntropyLoss()

    def _init_weights(self):
        """Kaiming-normal init for every nn.Linear / CausalLinear weight; biases set to zero."""
        for module in self.modules():
            if not isinstance(module, (nn.Linear, CausalLinear)):
                continue
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def count_params(self):
        """Return the total number of elements across parameters with requires_grad set."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x, labels=None):
        """
        Compute logits for token ids ``x`` and, when ``labels`` is given, the loss.

        Returns:
            ``logits`` when ``labels`` is None; otherwise ``(loss, logits)`` where
            the logits have been flattened to ``(-1, vocab_size)`` and the loss is
            ``nn.CrossEntropyLoss`` against the flattened labels.
        """
        hidden = self.input_layer(x)
        for block in self.mixer_blocks:
            hidden = block(hidden)
        logits = self.output_layer(hidden)

        if labels is None:
            return logits

        flat_logits = logits.view(-1, self.vocab_size)
        flat_labels = labels.view(-1)
        loss = self.loss_fn(flat_logits, flat_labels)
        return loss, flat_logits

In [14]:
# Define a function to save the model checkpoint
def save_checkpoint(model, params, optimizer, losses, filename="checkpoint.pth"):
    """
    Write model weights, optimizer state, loss history and model hyperparameters to ``filename``.

    Args:
        model: module whose ``state_dict()`` is stored.
        params: mapping that must contain 'vocab_size', 'hidden_dim', 'seq_len'
            and 'num_blocks'; only these four entries are saved.
        optimizer: optimizer whose ``state_dict()`` is stored.
        losses: loss history, stored as-is. May be empty.
        filename: destination path passed to ``torch.save``.

    Raises:
        KeyError: if ``params`` lacks any of the required keys. Nothing is
            written in that case.
    """
    required_keys = ('vocab_size', 'hidden_dim', 'seq_len', 'num_blocks')

    # Validate explicitly rather than with `assert`, which is removed under `python -O`.
    missing = [k for k in required_keys if k not in params]
    if missing:
        raise KeyError(f"params is missing required keys: {missing}")

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
        'params': {k: params[k] for k in required_keys},
    }

    torch.save(checkpoint, filename)

    # An empty history is valid (e.g. saving before the first step); do not
    # crash on losses[-1] after the file has already been written.
    if len(losses) > 0:
        print(f"Checkpoint saved with loss {losses[-1]:.4f}")
    else:
        print("Checkpoint saved (no losses recorded yet)")

In [15]:
def load_checkpoint(filename="checkpoint.pth"):
    """
    Rebuild an MLPMixer and an AdamW optimizer from a file written by save_checkpoint.

    Args:
        filename: path of the checkpoint file.

    Returns:
        ``(model, optimizer, losses)``: the model moved to the module-level
        ``device`` with its weights restored, an AdamW optimizer with its saved
        state loaded, and the stored loss history.
    """
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on CUDA can still be opened on a CPU-only machine.
    checkpoint = torch.load(filename, map_location=device, weights_only=True)

    params = checkpoint['params']
    print(params)
    model = MLPMixer(**params)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    # The optimizer is created after the model is on its final device; the
    # constructor defaults are then replaced by the saved optimizer state.
    optimizer = optim.AdamW(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    losses = checkpoint['losses']

    # The history may be empty if the checkpoint was written before any step.
    if len(losses) > 0:
        print(f"Checkpoint loaded: loss {losses[-1]:.4f}")
    else:
        print("Checkpoint loaded (no losses recorded)")

    return model, optimizer, losses

## Prepare Dataset

In [None]:
# Read the whole training corpus into memory as one string.
# Alternative corpus: 'mixer/TinyStories-train.txt' (swap the path below to use it).
with open('shakespeare/data.txt', encoding='utf-8') as f:
    text = f.read()
print(len(text))

# Character-level vocabulary: every distinct character, in sorted order.
all_chars = sorted(set(text))
vocab_size = len(all_chars)
print(f"Unique characters: {vocab_size}")

In [None]:
# Character -> id lookup; ids follow the position of each character in all_chars.
char2id = {symbol: index for index, symbol in enumerate(tqdm(all_chars))}
# Inverse lookup: id -> character.
id2char = dict(enumerate(tqdm(all_chars)))
# Encode the whole corpus as a list of integer ids.
tokens = list(map(char2id.__getitem__, tqdm(text)))

## Train

In [18]:
# Hyperparameters
hidden_dim = 128  # Embedding width / channel dimension passed to MLPMixer
seq_len = 128  # Tokens per training window; also the size of the CausalLinear token-mixing matrices
num_blocks = 6  # Number of MixerBlock layers stacked in the model
batch_size = 32  # Training windows per optimizer step
num_epochs = 2000  # Number of passes the training loop makes over all windows

In [None]:
# Build the Mixer from the notebook's hyperparameters and move it to the selected device.
model = MLPMixer(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    seq_len=seq_len,
    num_blocks=num_blocks,
).to(device)
print(f"Model has {model.count_params():,} parameters")

In [20]:
# Learning rate handed to AdamW below.
lr = 1e-4
# AdamW over every model parameter; all other optimizer settings are left at their defaults.
optimizer = optim.AdamW(model.parameters(), lr=lr)

In [None]:
# Pad the token stream with spaces so that every window that starts on a
# multiple of seq_len has seq_len + 1 tokens available (seq_len inputs plus
# the one-token shift used to form the targets).
#
# (-len(tokens)) % seq_len is the distance to the next multiple of seq_len and
# is 0 when the length is already aligned. The earlier formula
# `seq_len - (len(tokens) % seq_len) + 1` padded by seq_len + 1 in that case,
# which produced one extra training window made entirely of padding.
pad_size = (-len(tokens)) % seq_len + 1  # +1 accounts for the x/y offset
pad = [char2id[' ']] * pad_size
train_data = torch.tensor(tokens + pad, dtype=torch.long).to(device)
print(len(train_data))

In [22]:
# Per-step training loss values; the training loop appends one float per optimizer step.
losses = []
# NOTE(review): nothing in the code visible here ever appends to this list — confirm whether it is still needed.
accuracies = []

In [None]:
# Switch the model to training mode.
model.train()

# Start offset of every training window: one every seq_len tokens, beginning at 0.
idx = list(range(0, len(train_data) - 1, seq_len))

# Mini-batches per epoch, rounded up so a final partial batch is kept.
num_batches = (len(idx) + batch_size - 1) // batch_size

SHUFFLE_BATCHES = True
random.seed(42)

for epoch in tqdm(range(num_epochs)):
    print(f"EPOCH {epoch+1}/{num_epochs}")

    # Re-order the window starts in place so each epoch sees different batches.
    if SHUFFLE_BATCHES:
        random.shuffle(idx)

    for batch_start in range(0, len(idx), batch_size):
        # Slicing past the end of the list simply yields a shorter final batch.
        window_starts = idx[batch_start:batch_start + batch_size]

        # Each window holds seq_len + 1 tokens: inputs are all but the last
        # token, targets are the same window shifted left by one.
        windows = [train_data[start:start + seq_len + 1] for start in window_starts]
        batch = torch.stack(windows, dim=0)
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()

        # One optimisation step.
        optimizer.zero_grad()
        loss, output = model(x, y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Mean loss over the steps of the epoch that just finished.
    avg_loss = sum(losses[-num_batches:]) / num_batches
    print(f"Average loss: {avg_loss:.4f}")

In [None]:
# Persist the current weights, optimizer state and loss history, together with
# the constructor arguments needed to rebuild the model when loading.
save_checkpoint(model, {
    'vocab_size': vocab_size,
    'hidden_dim': hidden_dim,
    'seq_len': seq_len,
    'num_blocks': num_blocks
}, optimizer, losses, filename="mixer/checkpoint.pth")

In [None]:
# Restore from disk; this rebinds the notebook's `model`, `optimizer` and
# `losses` names to the objects returned by load_checkpoint.
model, optimizer, losses = load_checkpoint("mixer/checkpoint.pth")

In [None]:
# Plot a smoothed training-loss curve.
downsample = 1   # keep every `downsample`-th recorded step
window = 10      # width of the moving-average filter

# Skip the first two epochs' worth of steps, then subsample. Each entry of
# `losses` is a single per-step value, so np.mean only converts it to a NumPy scalar.
temp = [np.mean(step_loss) for step_loss in losses[num_batches*2::downsample]]

# Moving average over `window` consecutive values.
smoothing_kernel = np.ones(window) / window
temp = np.convolve(temp, smoothing_kernel, mode='valid')

plt.figure(figsize=(20, 6))
plt.plot(temp, label="Training Loss")
plt.title("Training Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Vector-Transition MLP Mixer

In [None]:
import torch
import torch.nn as nn



In [None]:
# Define an MLP-Mixer-based causal language model, plus helpers that build a masked weight matrix from a vector

class StaticIndexMatrix(nn.Module):
    """
    Expand a length-M vector into a fixed-pattern matrix.

    The index and validity tables are computed once in ``__init__`` and stored
    as buffers; ``forward`` gathers ``v[offset]`` and zeroes the invalid slots.
    With ``r = N // M``:

    * wide (``tall=False``), shape (M, N): entry ``(i, j)`` is ``v[j // r - i]``
      when ``j // r >= i`` and 0 otherwise;
    * tall (``tall=True``), shape (N, M): entry ``(i, j)`` is ``v[j - i // r]``
      when ``j >= i // r`` and 0 otherwise.

    Raises:
        ValueError: if ``N`` is not a multiple of ``M``.
    """
    def __init__(self, M, N, tall=False):
        super().__init__()
        if N % M != 0:
            raise ValueError("N must be a multiple of M.")
        self.M = M
        self.N = N
        self.tall = tall
        self.r = N // M

        if tall:
            # Tall variant, shape (N, M): rows are grouped into blocks of r.
            row_block = torch.arange(N).unsqueeze(1) // self.r   # (N, 1)
            col = torch.arange(M).unsqueeze(0)                   # (1, M)
            offset = col - row_block                             # broadcasts to (N, M)
            valid = (offset >= 0) & (offset < M)

            # Buffers move with the module across devices.
            self.register_buffer('offset_tall', offset)
            self.register_buffer('valid_tall', valid)
        else:
            # Wide variant, shape (M, N): columns are grouped into blocks of r.
            row = torch.arange(M).unsqueeze(1)                   # (M, 1)
            col_block = torch.arange(N).unsqueeze(0) // self.r   # (1, N)
            offset = col_block - row                             # broadcasts to (M, N)
            valid = (offset >= 0) & (offset < M)

            self.register_buffer('offset_wide', offset)
            self.register_buffer('valid_wide', valid)

    def forward(self, v):
        """Return the matrix built from ``v`` (shape ``(M,)``) for this module's variant."""
        if self.tall:
            offset, valid = self.offset_tall, self.valid_tall
        else:
            offset, valid = self.offset_wide, self.valid_wide

        # Negative offsets index from the end of v, but every such position
        # is marked invalid and replaced by zero below.
        gathered = v[offset]
        zeros = torch.zeros_like(offset, dtype=v.dtype)
        return torch.where(valid, gathered, zeros)

class CausalLinearVector(nn.Module):
    """
    Causal linear layer whose square weight matrix is generated from one vector.

    The learnable weight is a vector of length ``in_dim``. On every forward
    pass ``StaticIndexMatrix`` expands it into a matrix ``W`` with
    ``W[i, j] = weight[j - i]`` for ``j >= i`` and 0 otherwise. ``W`` is
    upper-triangular, so output position ``j`` only receives contributions
    from input positions ``i <= j``.

    Raises:
        NotImplementedError: if ``in_dim != out_dim``.
    """
    def __init__(self, in_dim: int, out_dim: int):

        super().__init__()

        if in_dim != out_dim:
            raise NotImplementedError("Only square matrices are currently supported.") #TODO

        # Vector weight (expanded to a matrix at use time) + one bias per output position.
        self.weight = nn.Parameter(torch.randn(in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

        # StaticIndexMatrix(M, N) requires N to be a multiple of M, so the
        # smaller dimension goes first (both are equal while only square
        # layers are supported).
        self.transform = StaticIndexMatrix(min(in_dim, out_dim), max(in_dim, out_dim))

    def forward(self, x):
        """
        x shape: (batch, embed_dim, seq_len); the output has the same shape.
        """
        B, E, S = x.shape
        # Build the masked weight matrix from the vector. The previous code
        # read a non-existent `self.mask` here and never used `self.transform`.
        W = self.transform(self.weight)
        x_reshaped = x.reshape(B * E, S)  # (B*E, S)
        out = x_reshaped @ W           # (B*E, S)
        out = out + self.bias          # broadcast bias
        out = out.view(B, E, S)        # reshape back

        return out

class MixerBlock(nn.Module):
    """
    Residual Mixer block: a channel-mixing MLP followed by a causal token-mixing MLP.

    NOTE(review): this redefinition matches the earlier MixerBlock in this file
    and still builds its token mixer from ``CausalLinear``; presumably
    ``CausalLinearVector`` was intended for the vector variant — confirm.

    Args:
        hidden_dim: size of the last (channel) axis.
        seq_len: size of the axis mixed by the CausalLinear layers.
        expansion_factor: width multiplier of the channel MLP's hidden layer.
        dropout: dropout probability used inside both MLPs.
    """

    def __init__(
        self,
        hidden_dim:int,
        seq_len:int,
        expansion_factor:int=2,
        dropout:float=0.1):

        super().__init__()

        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.expansion_factor = expansion_factor

        inner_dim = hidden_dim * expansion_factor

        # Channel path: RMSNorm, then Linear -> SiLU -> Dropout -> Linear.
        self.channel_norm = nn.RMSNorm(hidden_dim)
        channel_layers = [
            nn.Linear(hidden_dim, inner_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_dim),
        ]
        self.channel_mixing_layer = nn.Sequential(*channel_layers)

        # Token path: RMSNorm, then CausalLinear -> SiLU -> Dropout -> CausalLinear.
        self.token_norm = nn.RMSNorm(hidden_dim)
        token_layers = [
            CausalLinear(seq_len, seq_len),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalLinear(seq_len, seq_len),
        ]
        self.token_mixing_layer = nn.Sequential(*token_layers)

    def forward(self, x):
        """Apply channel mixing then token mixing, each with a residual connection."""
        residual = x
        x = self.channel_mixing_layer(self.channel_norm(x)) + residual

        residual = x
        # Axes 1 and 2 are swapped so the token mixer acts along the sequence
        # axis, then swapped back before the residual is added.
        x = self.token_norm(x).transpose(1, 2)
        x = self.token_mixing_layer(x).transpose(1, 2) + residual

        return x

class VectorMLPMixer(nn.Module):
    """
    Causal language model: embedding -> stack of MixerBlock -> tied output projection.

    NOTE(review): the MixerBlock defined in this file builds its token mixer
    from ``CausalLinear``, so ``CausalLinearVector`` is not used by this model
    as written — confirm the intended wiring.

    Args:
        vocab_size: number of distinct token ids.
        hidden_dim: embedding / channel width.
        seq_len: sequence length the MixerBlocks are built for.
        num_blocks: number of MixerBlock layers.
    """

    def __init__(
        self,
        vocab_size:int,
        hidden_dim:int,
        seq_len:int,
        num_blocks:int):

        # Zero-argument super(): the previous `super(MLPMixer, self)` raised
        # TypeError because this class does not inherit from MLPMixer.
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.num_blocks = num_blocks

        # Input Embedding
        self.input_layer = nn.Embedding(vocab_size, hidden_dim)

        # Mixer Blocks
        self.mixer_blocks = nn.ModuleList(
            [MixerBlock(hidden_dim, seq_len) for _ in range(num_blocks)]
        )

        # Output Layer
        self.output_layer = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Tie input and output layer weights
        self.output_layer.weight = self.input_layer.weight

        # Initialize weights (after tying, so the shared tensor is
        # re-initialised through output_layer)
        self._init_weights()

        # Define loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def _init_weights(self):
        """Kaiming-normal init for every nn.Linear / CausalLinear weight; biases set to zero."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, CausalLinear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def count_params(self):
        """Return the total number of elements across parameters with requires_grad set."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x, labels=None):
        """
        Compute logits for token ids ``x`` and, when ``labels`` is given, the loss.

        Returns:
            ``logits`` when ``labels`` is None; otherwise ``(loss, logits)`` where
            the logits have been flattened to ``(-1, vocab_size)`` and the loss is
            ``nn.CrossEntropyLoss`` against the flattened labels.
        """
        x = self.input_layer(x)
        for block in self.mixer_blocks:
            x = block(x)
        logits = self.output_layer(x)

        if labels is not None:

            logits = logits.view(-1, self.vocab_size)
            labels = labels.view(-1)
            loss = self.loss_fn(logits, labels)
            return loss, logits

        else:
            return logits