## Description

This notebook implements a custom Extended Long Short-Term Memory (xLSTM) model to predict the next tokens given an input sequence as described in the paper [xLSTM: Extended Long Short-Term Memory](https://arxiv.org/abs/2405.04517).

We will work with the “Tiny Shakespeare” dataset, a character-level corpus of Shakespeare’s plays and sonnets, commonly used for next-character prediction. The dataset is available at [Github](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).

### Imports

In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Google Drive Setup for Checkpointing in Google Colab

In [None]:
# Mount Google Drive for checkpoint saving
from google.colab import drive
import os
from datetime import datetime

# Mount Google Drive
drive.mount('/content/drive')

# Create checkpoint directory
checkpoint_dir = '/content/drive/MyDrive/xlstm_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"Google Drive mounted successfully!")
print(f"Checkpoint directory: {checkpoint_dir}")

### **Preparing the Tokenizer and Dataloader** (1 point)

In [46]:
# Character-level tokenizer
class CharTokenizer:
    def __init__(self, text):
        self.chars = sorted(set(text))
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}
        self.vocab_size = len(self.chars)

    def encode(self, text):
        return [self.char_to_idx[ch] for ch in text]

    def decode(self, indices):
        return ''.join([self.idx_to_char[i] for i in indices])

# Dataset class
class ShakespeareDataset(Dataset):
    def __init__(self, text, tokenizer, seq_length):
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.data = tokenizer.encode(text)

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx:idx+self.seq_length], dtype=torch.long)
        y = torch.tensor(self.data[idx+1:idx+self.seq_length+1], dtype=torch.long)
        return x, y

# Load and prepare data
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

tokenizer = CharTokenizer(text)
seq_length = 64
dataset = ShakespeareDataset(text, tokenizer, seq_length)

# Split into train and validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Dataset size: {len(dataset)}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Vocabulary size: 65
Dataset size: 1115330
Training samples: 1003797
Validation samples: 111533


### **Preparing the Model** (2.5 point)

#### components

In [None]:
class BlockDiagonalProj(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(BlockDiagonalProj, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.out_head_size = input_dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, self.out_head_size, input_dim // num_heads))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        shape = x.shape
        x = x.view(*shape[:-1], self.num_heads, -1)
        x = torch.einsum("...hd,hod->...ho", x, self.weight)
        x = x.reshape(*shape[:-1], -1)
        return x

class CausalConv1d(nn.Module):
    def __init__(self, feature_dim, kernel_size, bias=True):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size -1)
        self.conv = nn.Conv1d(in_channels=feature_dim, out_channels=feature_dim, kernel_size=kernel_size, padding=self.pad, groups=feature_dim, bias=bias)
    def forward(self, x):
        y = x.transpose(2, 1)
        y = self.conv(y)
        return y[:, :, : -self.pad].transpose(2, 1)


#### mLSTM block

In [None]:
### COMPLETE THIS CLASS ####
class mLSTMCell(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(mLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        
        self.i_proj = BlockDiagonalProj(input_dim, num_heads)
        self.f_proj = BlockDiagonalProj(input_dim, num_heads)

    def forward(self, q, k, v, x_conv):  # Need conv output for gates
        B, S, _ = q.shape

        # Reshape to heads
        q = q.view(B, S, self.num_heads, self.head_dim)
        k = k.view(B, S, self.num_heads, self.head_dim)
        v = v.view(B, S, self.num_heads, self.head_dim)
        
        # Get gates from conv output
        i = self.i_proj(x_conv).view(B, S, self.num_heads, self.head_dim)
        f = self.f_proj(x_conv).view(B, S, self.num_heads, self.head_dim)
        
        # Apply exponential activation to gates (while preventing overflow)
        i = torch.exp(torch.clamp(i, max=10))
        f = torch.exp(torch.clamp(f, max=10))

        # Initialize states
        C = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim, device=q.device)
        n = torch.zeros(B, self.num_heads, self.head_dim, device=q.device)
        
        outputs = []

        for t in range(S):
            qt = q[:, t]  # [B, H, D]
            kt = k[:, t]  # [B, H, D]
            vt = v[:, t]  # [B, H, D]
            it = i[:, t]  # [B, H, D]
            ft = f[:, t]  # [B, H, D]
            
            # Update memory matrix: C = f ⊙ C + i ⊙ (v ⊗ k)
            # Expand gates for matrix operations
            ft_expanded = ft.unsqueeze(-1)  # [B, H, D, 1]
            it_expanded = it.unsqueeze(-1)  # [B, H, D, 1]
            
            vt_expanded = vt.unsqueeze(-1)  # [B, H, D, 1]
            kt_expanded = kt.unsqueeze(-2)  # [B, H, 1, D]
            
            C = ft_expanded * C + it_expanded * (vt_expanded @ kt_expanded)
            
            # Update normalizer: n = f ⊙ n + i ⊙ k
            n = ft * n + it * kt
            
            # Query processing with stabilization
            h_tilde = torch.einsum('bhij,bhj->bhi', C, qt)  # [B, H, D]
            
            # Compute normalizer term
            n_term = torch.abs(torch.einsum('bhd,bhd->bh', n, qt))  # [B, H]
            n_term = torch.maximum(n_term, torch.ones_like(n_term))  # max(|q^T n|, 1)
            
            # Normalize output
            h_t = h_tilde / n_term.unsqueeze(-1)  # [B, H, D]
            
            outputs.append(h_t)

        # Stack outputs
        output = torch.stack(outputs, dim=1)  # [B, S, H, D]
        output = output.permute(0, 2, 1, 3)  # [B, H, S, D] for consistency
        
        return output
#############################

class mLSTMLayer(nn.Module):
    def __init__(self, embedding_dim, proj_blocksize, bias=False):
        super(mLSTMLayer, self).__init__()
        self.outer_embedding_dim = embedding_dim
        self.inner_embedding_dim = 2 * embedding_dim
        self.proj_blocksize = proj_blocksize
        self.bias = bias

        self.proj_up = nn.Linear(in_features=self.outer_embedding_dim,
                                 out_features= 2 * self.inner_embedding_dim,
                                 bias=bias)
        self.num_proj_heads = self.inner_embedding_dim // proj_blocksize
        self.q_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.k_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.v_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)

        self.conv1d = CausalConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
        self.conv_swish = nn.SiLU()

        ############################     EDIT      ##################################
        self.mlstm_cell = mLSTMCell(self.inner_embedding_dim, self.num_proj_heads)
        ##############################################################

        self.ogate_swish = nn.SiLU()
        self.learnable_skip_con = nn.Parameter(torch.ones(self.inner_embedding_dim, requires_grad=True))
        self.proj_down = nn.Linear(in_features=self.inner_embedding_dim,
                                 out_features=self.outer_embedding_dim,
                                 bias=bias)



    def forward(self, x):
        B, S, _ = x.shape
        x_ = F.layer_norm(x, normalized_shape=(self.outer_embedding_dim,))
        x_inner = self.proj_up(x_)
        x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.inner_embedding_dim, dim=-1)
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = self.conv_swish(x_mlstm_conv)

        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)

        ### EDIT ####
        y_ = self.mlstm_cell(q, k, v, x_mlstm_conv_act)

        B_, NH_, S_, DH_ = y_.shape
        gn_in_1 = y_.transpose(1, 2)  # [B, S, NH, DH]
        gn_in_2 = gn_in_1.reshape(B_ * S_, NH_ * DH_)  # [B*S, NH*DH]
        gn_out = F.group_norm(gn_in_2, num_groups=NH_)
        out = gn_out.view(B, S, -1)
        #############

        mlstm_cell_skip = out + (self.learnable_skip_con * x_mlstm_conv_act)

        h_state = mlstm_cell_skip * self.ogate_swish(z)

        y = self.proj_down(h_state) + x

        return y


#### sLSTM block

In [None]:
### COMPLETE THIS CLASS ####
class sLSTMCell(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(sLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

    def forward(self, i, f, z, o):
        B, S, D = i.shape

        # Apply activations
        i = torch.exp(torch.clamp(i, max=10))  # Prevent overflow
        f = torch.exp(torch.clamp(f, max=10))  # Prevent overflow
        z = torch.tanh(z)
        o = torch.sigmoid(o)

        # Reshape to heads
        i = i.view(B, S, self.num_heads, self.head_dim)
        f = f.view(B, S, self.num_heads, self.head_dim)
        z = z.view(B, S, self.num_heads, self.head_dim)
        o = o.view(B, S, self.num_heads, self.head_dim)

        # Initialize states
        c = torch.zeros(B, self.num_heads, self.head_dim, device=i.device)
        n = torch.zeros(B, self.num_heads, self.head_dim, device=i.device)
        m = torch.zeros(B, self.num_heads, self.head_dim, device=i.device)

        outputs = []

        # Process sequence
        for t in range(S):
            # Update states
            c = f[:, t] * c + i[:, t] * z[:, t]
            n = f[:, t] * n + i[:, t]
            m = torch.maximum(f[:, t] * m + i[:, t], torch.abs(c))

            # Compute output
            h = o[:, t] * (c / (m + 1e-6))
            outputs.append(h)

        # Stack outputs and reshape properly
        output = torch.stack(outputs, dim=1)  # [B, S, H, D]
        output = output.view(B, self.num_heads, S, self.head_dim)  # [B, H, S, D]

        return output
#############################

class sLSTMLayer(nn.Module):
    def __init__(self, embedding_dim, proj_blocksize, conv_block=True, bias=False):
        super(sLSTMLayer, self).__init__()
        self.inner_embedding_dim = embedding_dim
        self.proj_blocksize = proj_blocksize
        self.conv_block = conv_block
        self.num_heads = 4

        if conv_block:
            self.conv1d = CausalConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
            self.conv_swish = nn.SiLU()

        self.i_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_heads)
        self.f_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_heads)
        self.z_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_heads)
        self.o_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_heads)

        self.slstm_cell = sLSTMCell(self.inner_embedding_dim, self.num_heads)

        self.up_proj1 = nn.Linear(in_features=self.inner_embedding_dim, out_features= int((4/3)*self.inner_embedding_dim), bias=bias)
        self.up_proj2 = nn.Linear(in_features=self.inner_embedding_dim, out_features= int((4/3)*self.inner_embedding_dim), bias=bias)
        self.up_proj2_gelu = nn.GELU()

        self.down_proj = nn.Linear(in_features=int((4/3)*self.inner_embedding_dim), out_features=self.inner_embedding_dim, bias=bias)

    def forward(self, x):
        B, S, _ = x.shape

        x_ = F.layer_norm(x, normalized_shape=(self.inner_embedding_dim,))

        if self.conv_block:
            x_conv = self.conv1d(x_)
            x_conv_act = self.conv_swish(x_conv)
        else:
            x_conv_act = x_
        i = self.i_proj(x_conv_act)
        f = self.f_proj(x_conv_act)
        z = self.z_proj(x_)
        o = self.o_proj(x_)

        y_ = self.slstm_cell(i, f, z, o)

        B_, NH_, S_, DH_ = y_.shape
        gn_in_1 = y_.transpose(1, 2)
        gn_in_2 = gn_in_1.reshape(B_ * S_, NH_ * DH_)
        gn_out = F.group_norm(gn_in_2, num_groups=NH_)
        out = gn_out.view(B, S, -1)

        skip_con = x + out
        skip_con_layer_norm = F.layer_norm(skip_con, normalized_shape=(self.inner_embedding_dim,))

        up_proj1 = self.up_proj1(skip_con_layer_norm)
        up_proj2 = self.up_proj2(skip_con_layer_norm)
        up_proj2_act = self.up_proj2_gelu(up_proj2)
        down_proj = self.down_proj(up_proj2_act * up_proj1)
        y = down_proj + skip_con
        return y

xLSTM

In [50]:
class xLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, num_layers=4, proj_blocksize=32):
        super(xLSTM, self).__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        # xLSTM layers with 1:1 ratio (TODO: 7:1 is best according to paper)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "sLSTM": sLSTMLayer(embedding_dim, proj_blocksize),
                "mLSTM": mLSTMLayer(embedding_dim, proj_blocksize)
            }) for _ in range(num_layers)
        ])

        # Output layer
        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        nn.init.normal_(self.output_layer.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # Embedding
        x = self.embedding(x)  # [B, S, D]

        # Pass through xLSTM layers
        for layer_dict in self.layers:
            # Apply sLSTM then mLSTM
            x = layer_dict["sLSTM"](x)
            x = layer_dict["mLSTM"](x)

        # Output projection
        output = self.output_layer(x)  # [B, S, vocab_size]

        return output

### Checkpoint Management Functions

In [None]:
import glob

def save_checkpoint(model, optimizer, epoch, train_losses, val_losses, train_perplexities, val_perplexities, checkpoint_dir):
    """Save model checkpoint to Google Drive"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_path = os.path.join(checkpoint_dir, f'xlstm_checkpoint_epoch_{epoch+1}_{timestamp}.pt')
    
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_perplexities': train_perplexities,
        'val_perplexities': val_perplexities,
        'vocab_size': tokenizer.vocab_size,
        'model_config': {
            'embedding_dim': model.embedding_dim,
            'vocab_size': model.vocab_size
        }
    }
    
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path

def load_checkpoint(checkpoint_path, model, optimizer):
    """Load model checkpoint from Google Drive"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    return (checkpoint['epoch'], 
            checkpoint['train_losses'], 
            checkpoint['val_losses'],
            checkpoint['train_perplexities'], 
            checkpoint['val_perplexities'])

def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
    """Keep only the N most recent checkpoints"""
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'xlstm_checkpoint_*.pt'))
    if len(checkpoint_files) > keep_last_n:
        # Sort by modification time (newest first)
        checkpoint_files.sort(key=os.path.getmtime, reverse=True)
        # Remove older checkpoints
        for old_checkpoint in checkpoint_files[keep_last_n:]:
            os.remove(old_checkpoint)
            print(f"Removed old checkpoint: {old_checkpoint}")

def find_latest_checkpoint(checkpoint_dir):
    """Find the most recent checkpoint"""
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'xlstm_checkpoint_*.pt'))
    if checkpoint_files:
        return max(checkpoint_files, key=os.path.getmtime)
    return None

# Configuration
SAVE_EVERY_N_BATCHES = 5000  # Save checkpoint every N batches
RESUME_FROM_CHECKPOINT = False

print("Checkpoint functions loaded successfully!")

### **Train the Model** (1 point)

In [None]:
import matplotlib.pyplot as plt
import math

# Initialize model
model = xLSTM(vocab_size=tokenizer.vocab_size, embedding_dim=128, num_layers=2, proj_blocksize=32)
model = model.to(device)

# Training parameters
learning_rate = 1e-3
num_epochs = 3
warmup_epochs = 1
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Create learning rate scheduler with warmup + cosine decay
def create_warmup_cosine_scheduler(optimizer, warmup_epochs, total_epochs):
    """Create scheduler with linear warmup followed by cosine decay"""
    from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
    
    # Linear warmup from 0.1 to 1.0 of base LR
    warmup_scheduler = LinearLR(
        optimizer, 
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_epochs
    )
    
    # Cosine annealing from peak to 10% of base LR
    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_epochs - warmup_epochs,
        eta_min=learning_rate * 0.1
    )
    
    # Combine schedulers
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_epochs]
    )
    
    return scheduler

scheduler = create_warmup_cosine_scheduler(optimizer, warmup_epochs, num_epochs)

# Training history
train_losses = []
val_losses = []
train_perplexities = []
val_perplexities = []
learning_rates = []  # Track learning rate changes

# Resume from checkpoint if requested
start_epoch = 0
if RESUME_FROM_CHECKPOINT:
    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f"Resuming from checkpoint: {latest_checkpoint}")
        start_epoch, train_losses, val_losses, train_perplexities, val_perplexities = load_checkpoint(
            latest_checkpoint, model, optimizer)
        start_epoch += 1  # Start from next epoch
        print(f"Resuming training from epoch {start_epoch + 1}")
    else:
        print("No checkpoint found, starting fresh training")

def calculate_perplexity(loss):
    return math.exp(loss)

def evaluate_model(model, dataloader):
    model.eval()
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)

            # Reshape for loss calculation
            outputs = outputs.view(-1, tokenizer.vocab_size)
            y = y.view(-1)

            loss = criterion(outputs, y)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches

# Training loop
print("Starting training...")
for epoch in range(start_epoch, start_epoch + num_epochs):
    model.train()
    total_train_loss = 0
    num_train_batches = 0

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward pass
        outputs = model(x)

        # Reshape for loss calculation
        outputs = outputs.view(-1, tokenizer.vocab_size)
        y = y.view(-1)

        loss = criterion(outputs, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_train_loss += loss.item()
        num_train_batches += 1

        if (batch_idx + 1) % 100 == 0:
            current_lr = scheduler.get_last_lr()[0]
            print(f'Epoch {epoch+1}/{start_epoch + num_epochs}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}, LR: {current_lr:.6f}')
        
        # Save checkpoint every N batches
        if (batch_idx + 1) % SAVE_EVERY_N_BATCHES == 0:
            checkpoint_path = save_checkpoint(
                model, optimizer, epoch, train_losses, val_losses, 
                train_perplexities, val_perplexities, checkpoint_dir
            )
            # Clean up old checkpoints
            cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3)

    # Step the scheduler after each epoch
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    learning_rates.append(current_lr)

    print(f'Finished epoch {epoch+1}')
    # Calculate average training loss
    avg_train_loss = total_train_loss / num_train_batches
    train_losses.append(avg_train_loss)
    train_perplexities.append(calculate_perplexity(avg_train_loss))

    # Evaluate on validation set
    avg_val_loss = evaluate_model(model, val_loader)
    val_losses.append(avg_val_loss)
    val_perplexities.append(calculate_perplexity(avg_val_loss))

    print(f'Epoch {epoch+1}/{start_epoch + num_epochs}:')
    print(f'  Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexities[-1]:.4f}')
    print(f'  Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexities[-1]:.4f}')
    print(f'  Learning Rate: {current_lr:.6f}')
    print('-' * 60)

print("Training completed!")

### **Showcasing plots and few input & output examples** (0.5 point)

In [None]:
# Plot training curves
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 5))

# Plot loss curves
epochs = range(1, len(train_losses) + 1)
ax1.plot(epochs, train_losses, 'b-', label='Training Loss')
ax1.plot(epochs, val_losses, 'r-', label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True)

# Plot perplexity curves
ax2.plot(epochs, train_perplexities, 'b-', label='Training Perplexity')
ax2.plot(epochs, val_perplexities, 'r-', label='Validation Perplexity')
ax2.set_title('Training and Validation Perplexity')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Perplexity')
ax2.legend()
ax2.grid(True)

# Plot learning rate schedule
if learning_rates:
    ax3.plot(epochs, learning_rates, 'g-', label='Learning Rate')
    ax3.set_title('Learning Rate Schedule')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Learning Rate')
    ax3.legend()
    ax3.grid(True)
    ax3.set_yscale('log')  # Log scale for better visualization

plt.tight_layout()
plt.show()

# Text generation function
def generate_text(model, tokenizer, prompt, max_length=200, temperature=1.0):
    model.eval()

    # Encode the prompt
    tokens = tokenizer.encode(prompt)
    input_ids = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    generated = []

    with torch.no_grad():
        for _ in range(max_length):
            # Get model predictions
            outputs = model(input_ids)

            # Get the last token's logits
            next_token_logits = outputs[0, -1, :] / temperature

            # Sample from the distribution
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Add the new token to the sequence
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Keep only the last seq_length tokens to avoid memory issues
            if input_ids.size(1) > seq_length:
                input_ids = input_ids[:, -seq_length:]

    # Decode the generated text
    generated_text = tokenizer.decode(generated)
    return prompt + generated_text

# Generate some example texts
print("=== Text Generation Examples ===\n")

prompts = [
    "ROMEO:",
    "To be or not to be,",
    "HAMLET:",
    "Fair is foul and"
]

for i, prompt in enumerate(prompts, 1):
    print(f"Example {i}:")
    print(f"Prompt: '{prompt}'")
    generated = generate_text(model, tokenizer, prompt, max_length=150, temperature=0.8)
    print(f"Generated: {generated}")
    print("-" * 80)
    print()

print("=== Training Summary ===")
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Final Validation Loss: {val_losses[-1]:.4f}")
print(f"Final Training Perplexity: {train_perplexities[-1]:.4f}")
print(f"Final Validation Perplexity: {val_perplexities[-1]:.4f}")
if learning_rates:
    print(f"Final Learning Rate: {learning_rates[-1]:.6f}")
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Vocabulary Size: {tokenizer.vocab_size}")
print(f"Sequence Length: {seq_length}")
print(f"Number of Training Samples: {len(train_dataset)}")
print(f"Number of Validation Samples: {len(val_dataset)}")