# Lesson 5: Transformers

*Teachers:* Fares Schulz, Lina Campanella

In this lesson we will cover:
1. Building our first transformer
2. Visualizing the attention

## Bach Chorale Generation with Transformers

We will use the same dataset as last time, the JSB Chorales, and preprocess it the same way.

In [None]:
import urllib.request
import tarfile
from pathlib import Path
import pandas as pd

# Download the dataset using urllib and extract with tarfile
download_link = "https://github.com/iCorv/jsb-chorales-dataset/raw/main/jsb_chorales.tar"
data_dir = Path('resources/_data/jsb_chorales')
tar_path = data_dir / 'jsb_chorales.tar'

# Create directory if it doesn't exist
data_dir.mkdir(parents=True, exist_ok=True)

# Download the file if it doesn't already exist
if not tar_path.exists():
    print(f"Downloading dataset from {download_link}")
    urllib.request.urlretrieve(download_link, tar_path)
    print(f"Downloaded to {tar_path}")
else:
    print(f"Dataset already exists at {tar_path}")

# Extract the tar file
if tar_path.exists() and not (data_dir / 'jsb_chorales').exists():
    print(f"Extracting {tar_path}")
    with tarfile.open(tar_path, 'r') as tar:
        tar.extractall(path=data_dir)
    print(f"Extracted to {data_dir}")

filepath = str(tar_path)
print(f"Dataset available at: {filepath}")

In [None]:
jsb_chorales_dir = Path(filepath).parent
train_files = sorted(jsb_chorales_dir.glob("train/chorale_*.csv"))
valid_files = sorted(jsb_chorales_dir.glob("valid/chorale_*.csv"))
test_files = sorted(jsb_chorales_dir.glob("test/chorale_*.csv"))

In [None]:
def load_chorales(filepaths):
    return [pd.read_csv(filepath).values.tolist() for filepath in filepaths]

train_chorales = load_chorales(train_files)
valid_chorales = load_chorales(valid_files)
test_chorales = load_chorales(test_files)

In [None]:
notes = set()
for chorales in (train_chorales, valid_chorales, test_chorales):
    for chorale in chorales:
        for chord in chorale:
            notes |= set(chord)

n_notes = len(notes)
min_note = min(notes - {0})
max_note = max(notes)

assert min_note == 36
assert max_note == 81
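Because `preprocess` keeps 0 for silence and shifts every pitch so that `min_note` becomes token 1, the model's vocabulary has `max_note - min_note + 2 = 47` tokens. That is the vocabulary size `n_notes` used in the model configuration further down. A plain-Python sketch of the mapping and its inverse (the helper names and the example chord are illustrative, not part of the notebook):

```python
# Plain-Python sketch of the token mapping used by `preprocess` (forward) and
# by `generate_chorale` (inverse). MIN_NOTE/MAX_NOTE are the values asserted above.
MIN_NOTE, MAX_NOTE = 36, 81


def encode_note(note: int) -> int:
    """Map a MIDI note to a token: 0 (silence) stays 0, pitches start at 1."""
    return 0 if note == 0 else note - MIN_NOTE + 1


def decode_token(token: int) -> int:
    """Invert encode_note: token 0 -> silence, token t -> MIDI t + MIN_NOTE - 1."""
    return 0 if token == 0 else token + MIN_NOTE - 1


# One silence token plus one token per pitch in [MIN_NOTE, MAX_NOTE]
vocab_size = (MAX_NOTE - MIN_NOTE + 1) + 1

chord = [74, 70, 65, 58]                  # an example four-voice chord (made-up values)
tokens = [encode_note(n) for n in chord]
print(vocab_size, tokens)                 # 47 [39, 35, 30, 23]
```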

In [None]:
from resources._code.synthesizer import SimpleSynth

baroque_synth = SimpleSynth(tempo=160, amplitude=0.1, sample_rate=44100, baroque_tuning=True)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

def preprocess(window):
    # Shift values: keep 0 as 0 (silence), shift other notes to start from 1
    window = torch.where(window == 0, window, window - min_note + 1)
    return window.reshape(-1)  # convert to arpeggio (flatten to 1D sequence)

class BachDataset(Dataset):
    def __init__(self, chorales, window_size=64, window_shift=32):
        self.chorales = chorales
        self.window_size = window_size
        self.window_shift = window_shift
        self.windows = self._create_windows()
    
    def _create_windows(self):
        windows = []
        for chorale in self.chorales:
            chorale_tensor = torch.tensor(chorale, dtype=torch.long)
            
            # Create sliding windows
            for i in range(0, len(chorale) - self.window_size, self.window_shift):
                window = chorale_tensor[i:i + self.window_size + 1]
                if len(window) == self.window_size + 1:  # Ensure full window
                    windows.append(window)
        
        return windows
    
    def __len__(self):
        return len(self.windows)
    
    def __getitem__(self, idx):
        window = self.windows[idx]
        # Preprocess: shift note values and flatten
        preprocessed = preprocess(window)
        
        # Create input/target pairs 
        X = preprocessed[:-1]
        Y = preprocessed[1:] # predict next note in each arpeggio, at each step
        
        return X, Y

def bach_dataloader(chorales, batch_size=32, shuffle=False, window_size=32, window_shift=16):
    
    dataset = BachDataset(chorales, window_size, window_shift)
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    
    return dataloader
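With the defaults used by `bach_dataloader` (`window_size=32`, `window_shift=16`), each sample covers 33 chords, i.e. 132 notes once flattened, so `X` and `Y` are 131 tokens long and `Y` is `X` shifted by one position. A quick plain-Python check of that arithmetic for a hypothetical 100-chord chorale (the helper names are illustrative):

```python
def window_starts(n_chords, window_size=32, window_shift=16):
    """Start indices produced by BachDataset._create_windows for one chorale."""
    return list(range(0, n_chords - window_size, window_shift))


def tokens_per_sample(window_size=32, voices=4):
    """Each window holds window_size + 1 chords; X and Y each drop one token."""
    n_tokens = (window_size + 1) * voices
    return n_tokens, n_tokens - 1   # (flattened window length, len(X) == len(Y))


starts = window_starts(n_chords=100)
print(starts)                       # [0, 16, 32, 48, 64]
print(tokens_per_sample())          # (132, 131)

# Every window is full: the last start s satisfies s + window_size + 1 <= n_chords
assert all(s + 32 + 1 <= 100 for s in starts)

# Next-token targets: Y is X shifted left by one token
arpeggio = list(range(8))           # stand-in for a flattened window
X, Y = arpeggio[:-1], arpeggio[1:]
assert all(X[t + 1] == Y[t] for t in range(len(X) - 1))
```

Because the `range` stop is exclusive, every window produced this way is already full, so the length check in `_create_windows` acts only as a safeguard. The 131-token length is also why a `context_window` of 128 tokens in `generate_chorale` stays within the sequence lengths seen during training.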


In [None]:
# load the datasets
train_set = bach_dataloader(train_chorales, shuffle=True)
valid_set = bach_dataloader(valid_chorales)
test_set = bach_dataloader(test_chorales)

Now we will build our Transformer from scratch: a decoder-only Transformer (like a GPT model), trained on the Bach chorales so that we can compare it with the RNN approach. We use the same JSB Chorales dataset, but rely on the Transformer's attention mechanism, which relates any two positions in the sequence directly, to capture long-range dependencies.

In [None]:
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism from "Attention is All You Need" (Vaswani et al., 2017).
    
    Splits the input into multiple attention heads, computes scaled dot-product attention
    in parallel for each head, then concatenates and projects the results.
    
    The attention mechanism allows the model to focus on different positions in the sequence
    when processing each position.  Multiple heads enable attending to different representation
    subspaces simultaneously. 
    
    Args:
        d_model: Dimensionality of the model (embedding dimension)
        num_heads: Number of parallel attention heads
    """
    def __init__(self, d_model=256, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Ensure model dimension is divisible by number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Linear projections for queries, keys, and values
        self.Wv = nn.Linear(d_model, d_model, bias=False) # the Value parameters
        self.Wk = nn.Linear(d_model, d_model, bias=False) # the Key parameters
        self.Wq = nn.Linear(d_model, d_model, bias=False) # the Query parameters
        self.Wo = nn.Linear(d_model, d_model, bias=False) # the output parameters


    def scaled_dot_product_attention(self, query, key, value, attention_mask=None):  
        """
        Compute scaled dot-product attention. 
        
        Attention formula: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
        
        The scaling by sqrt(d_k) prevents the dot products from growing too large,
        which would push the softmax into regions with very small gradients.
        
        Args:
            query: Query tensor of shape (batch_size, num_heads, seq_len, head_dim)
            key: Key tensor of shape (batch_size, num_heads, seq_len, head_dim)
            value: Value tensor of shape (batch_size, num_heads, seq_len, head_dim)
            attention_mask: Optional mask to prevent attention to certain positions
            
        Returns:
            Tuple of (attention_output, attention_weights) where:
            - attention_output: Weighted sum of values
            - attention_weights: Attention probability distribution
        """      
        d_k = query.size(-1)
        tgt_len, src_len = query.size(-2), key.size(-2)

        # logits = query * key^T / sqrt(d_k)
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # (batch_size, num_heads, tgt_len, src_len)


        # Apply attention mask if provided 
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # Verify mask dimensions match sequence lengths
                assert attention_mask.size() == (tgt_len, src_len)
                # Add batch dimension to broadcast across all batches
                attention_mask = attention_mask.unsqueeze(0)
                # Add mask to logits (positions with -inf will have ~0 probability after softmax)
                logits = logits + attention_mask
            else:
                raise ValueError(
                    f"Expected a 2D attention mask of shape ({tgt_len}, {src_len}), "
                    f"got {tuple(attention_mask.size())}"
                )
                
        
        attention = torch.softmax(logits, dim=-1)
        output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, head_dim)
        
        return output, attention

    
    def split_into_heads(self, x):
        """
        Split the input tensor into multiple attention heads.
        
        Reshapes from (batch_size, seq_len, d_model) to 
        (batch_size, num_heads, seq_len, head_dim) where head_dim = d_model / num_heads
        
        This allows parallel computation of attention for each head.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            
        Returns:
            Reshaped tensor of shape (batch_size, num_heads, seq_len, head_dim)
        """
        batch_size, seq_length, _ = x.size()
        # Reshape to separate heads: (batch_size, seq_len, num_heads, head_dim)
        x = x.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return x.transpose(1, 2) # (batch_size, num_heads, seq_length, head_dim)

    def combine_heads(self, x):
        """
        Combine multiple attention heads back into a single tensor.
        
        Inverse operation of split_into_heads.  Reshapes from 
        (batch_size, num_heads, seq_len, head_dim) to (batch_size, seq_len, d_model)
        
        Args:
            x: Tensor of shape (batch_size, num_heads, seq_len, head_dim)
            
        Returns:
            Combined tensor of shape (batch_size, seq_len, d_model)
        """
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    
    def forward(self, q, k, v, attention_mask=None):

        q = self.Wq(q) # Shape: (batch_size, seq_len, d_model)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q) # Shape: (batch_size, num_heads, seq_len, head_dim)
        k = self.split_into_heads(k)
        v = self.split_into_heads(v)
        
        # Compute attention for all heads in parallel
        attention_values, attention_weights  = self.scaled_dot_product_attention(query=q, key=k, value=v, attention_mask=attention_mask)

        # Combine heads back into single tensor
        grouped = self.combine_heads(attention_values)

        # Apply output projection
        output = self.Wo(grouped)
        
        # Store attention weights for visualization
        self.attention_weights = attention_weights
        
        return output
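As a sanity check, the masked scaled dot-product attention computed inside `MultiHeadAttention` can be compared with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later; this sketch assumes such a version). The shapes follow the `(batch_size, num_heads, seq_len, head_dim)` layout produced by `split_into_heads`:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 6, 8                       # batch, heads, sequence length, head_dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Additive causal mask: 0 on and below the diagonal, -inf above it
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Same computation as MultiHeadAttention.scaled_dot_product_attention
logits = q @ k.transpose(-2, -1) / math.sqrt(D) + mask
weights = torch.softmax(logits, dim=-1)       # (B, H, T, T)
manual = weights @ v                          # (B, H, T, D)

# PyTorch 2.x reference implementation with built-in causal masking
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, reference, atol=1e-5))            # True
print(torch.allclose(weights.sum(-1), torch.ones(B, H, T)))    # each row is a distribution
print(bool((weights.triu(diagonal=1) == 0).all()))             # no weight on future positions
```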

Since the Transformer has no recurrence or convolution, it has no inherent notion of token position.  Positional encodings are added to the input embeddings to inject information about the relative or absolute position of tokens in the sequence.

In [None]:
class PositionalEncoding(nn.Module):
    """
    Uses sinusoidal functions of different frequencies:
    PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    
    These allow the model to learn to attend by relative positions, as for any fixed
    offset k, PE(pos+k) can be represented as a linear function of PE(pos).
    
    Args:
        d_model: Dimensionality of the model embeddings
        dropout: Dropout probability to apply after adding positional encodings
        max_len: Maximum sequence length to pre-compute encodings for
    """
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)
        
        # Create a matrix of shape (max_len, d_model) to hold positional encodings
        pe = torch.zeros(max_len, d_model)

        # Create position indices: [0, 1, 2, ..., max_len-1]
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)

        # Create division term for the sinusoidal functions creating different frequencies for different dimensions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices (0, 2, 4, ...)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Add batch dimension: shape becomes (1, max_len, d_model)
        pe = pe.unsqueeze(0)
            
        self.register_buffer('pe', pe) # This ensures it moves to GPU with the model but is not trained

    def forward(self, x):
        # Add the encodings for the first seq_len positions, then apply dropout
        x = self.dropout(x + self.pe[:, :x.size(1), :])
        return x
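The docstring's claim that PE(pos+k) is a linear function of PE(pos) follows from the angle-addition formulas: each (sin, cos) pair is rotated by an angle k·w_i that depends only on the offset k, not on pos. A plain-Python check (the helper names are illustrative; the frequency w_i = 10000^(-2i/d_model) is the one `div_term` computes):

```python
import math


def sinusoidal_pe(pos, d_model):
    """PE(pos, 2i) = sin(pos * w_i), PE(pos, 2i+1) = cos(pos * w_i), w_i = 10000^(-2i/d_model)."""
    pe = []
    for i in range(d_model // 2):
        w = 10000.0 ** (-2 * i / d_model)
        pe += [math.sin(pos * w), math.cos(pos * w)]
    return pe


def shift_by(pe, k, d_model):
    """Rotate each (sin, cos) pair by k * w_i; the rotation matrix depends on k only."""
    out = []
    for i in range(d_model // 2):
        w = 10000.0 ** (-2 * i / d_model)
        s, c = pe[2 * i], pe[2 * i + 1]
        out += [s * math.cos(k * w) + c * math.sin(k * w),   # sin((pos + k) * w)
                c * math.cos(k * w) - s * math.sin(k * w)]   # cos((pos + k) * w)
    return out


print([round(x, 4) for x in sinusoidal_pe(1, 4)])   # [0.8415, 0.5403, 0.01, 1.0]

d_model, pos, k = 8, 5, 3
direct = sinusoidal_pe(pos + k, d_model)
rotated = shift_by(sinusoidal_pe(pos, d_model), k, d_model)
print(max(abs(a - b) for a, b in zip(direct, rotated)) < 1e-9)   # True
```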

The position-wise feed-forward network (FFN) is applied after attention in each transformer block. It consists of two linear transformations with a ReLU activation in between:
$$
\mathrm{FFN}(x) = \mathrm{ReLU}(x W_1 + b_1)\, W_2 + b_2
$$

This is applied independently to each position (hence "position-wise"), allowing the model to process and transform the attended information. Standard practice is to use a hidden dimension of 4 * d_model, providing additional model capacity for learning complex transformations.
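With `d_model = 128` (the value in the configuration used later), the FFN expands each 128-dimensional vector to 512 dimensions and back. The sketch below uses an `nn.Sequential` stand-in with the same layer structure as `PositionWiseFeedForward` to count its parameters and to check the "position-wise" claim numerically:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 128
hidden_dim = 4 * d_model                      # 512, the default in PositionWiseFeedForward

# Same layer structure as PositionWiseFeedForward: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(d_model, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, d_model))

n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)   # 131712 = (128*512 + 512) + (512*128 + 128)

# Position-wise: processing the whole sequence at once equals processing each position alone
x = torch.randn(2, 10, d_model)               # (batch, seq_len, d_model)
whole = ffn(x)
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_position, atol=1e-5))   # True
```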

In [None]:
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network from "Attention is All You Need" (Vaswani et al., 2017).
    
    Args:
        d_model: Input and output dimensionality
        hidden_dim: Hidden layer dimensionality (default: 4 * d_model)
    """
    def __init__(self, d_model, hidden_dim=None):
        super(PositionWiseFeedForward, self).__init__()
        if hidden_dim is None:
            hidden_dim = 4 * d_model 
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.relu = nn.ReLU()


    def forward(self, x):        
        return self.fc2(self.relu(self.fc1(x)))

In [None]:
class DecoderBlock(nn.Module):
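    """
    Single decoder block for the Transformer.

    The block contains two masked self-attention sub-layers followed by a position-wise
    feed-forward network. (In the original encoder-decoder Transformer, the second attention
    slot holds cross-attention over the encoder output; GPT-style decoder-only blocks usually
    keep a single self-attention sub-layer.)

    Each sub-layer (attention and feed-forward) has:
    - Pre-LayerNorm: normalization applied before the sub-layer
    - Residual connection: the sub-layer input is added to its output
    - Dropout: applied to the sub-layer output for regularization

    Args:
        d_model: Model dimensionality
        dropout: Dropout probability for regularization
        num_heads: Number of attention heads
    """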
    def __init__(self, d_model, dropout, num_heads):
        super(DecoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attention1 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.self_attention2 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = PositionWiseFeedForward(d_model)
        
        # Dropout layers for regularization after each sub-layer
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout) 
        self.dropout3 = nn.Dropout(dropout)
        
        
    def forward(self, tgt, tgt_mask=None):
        # Sub-layer 1: masked self-attention (pre-LayerNorm, residual connection)
        h = self.norm1(tgt)
        attn_output = self.self_attention1(q=h, k=h, v=h, attention_mask=tgt_mask)
        x = tgt + self.dropout1(attn_output)

        # Sub-layer 2: a second masked self-attention
        h = self.norm2(x)
        attn_output = self.self_attention2(q=h, k=h, v=h, attention_mask=tgt_mask)
        x = x + self.dropout2(attn_output)

        # Sub-layer 3: position-wise feed-forward network
        ff_output = self.ff(self.norm3(x))
        output = x + self.dropout3(ff_output)

        return output
    

    
class Decoder(nn.Module):
    """
    Decoder module consisting of multiple stacked decoder blocks.
    
    Processes input sequences through:
    1. Token embedding: maps token indices to dense vectors
    2. Positional encoding: adds position information
    3. Stack of decoder blocks: applies self-attention and feed-forward transformations
    
    Args:
        d_model: Model dimensionality
        dropout: Dropout probability
        num_decoder_blocks: Number of decoder blocks to stack
        num_heads: Number of attention heads per block
        shared_embedding: Shared embedding layer (shared with output layer for weight tying)
    """
    def __init__(self, d_model, dropout, num_decoder_blocks, num_heads, shared_embedding):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.embedding = shared_embedding
        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout)
          
        self.decoder_blocks = nn.ModuleList([DecoderBlock(d_model, dropout, num_heads) for _ in range(num_decoder_blocks)])
        
        
    def forward(self, tgt, tgt_mask=None):
        x = self.embedding(tgt) 
        x = self.positional_encoding(x)

        for block in self.decoder_blocks:
            x = block(x, tgt_mask=tgt_mask)
            
        return x
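For comparison, PyTorch's built-in `nn.TransformerEncoderLayer` with `norm_first=True` and a causal mask behaves like a standard GPT-style block: pre-LayerNorm, one masked self-attention sub-layer and a feed-forward network, each wrapped in a residual connection. Unlike `DecoderBlock` it has only one attention sub-layer (and uses biases in its projections), so it is a reference point rather than a drop-in replacement. The sketch uses the notebook's `d_model` and head count and checks the output shape and that earlier positions ignore later tokens:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, T = 128, 8, 12

# GPT-style block from PyTorch built-ins: pre-LayerNorm, self-attention + FFN, no cross-attention
block = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
    dropout=0.0, batch_first=True, norm_first=True,
)
causal_mask = nn.Transformer.generate_square_subsequent_mask(T)   # float mask, -inf above the diagonal

x = torch.randn(2, T, d_model)            # (batch, seq_len, d_model), e.g. embedded tokens
y = block(x, src_mask=causal_mask)
print(y.shape)                             # torch.Size([2, 12, 128])

# Causality check: changing the last token leaves all earlier outputs unchanged
x2 = x.clone()
x2[:, -1] = torch.randn(d_model)
y2 = block(x2, src_mask=causal_mask)
print(torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-5))   # True
```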

In [None]:
class Transformer(nn.Module):
    """
    Decoder-only Transformer model for autoregressive sequence generation.
    
    This architecture is similar to GPT (Generative Pre-trained Transformer), consisting
    of only a decoder without an encoder. It generates sequences autoregressively by
    predicting the next token given all previous tokens.
    
    Key features:
    - Token embedding with weight sharing (embedding weights are shared with output projection)
    - Positional encoding to provide sequence position information
    - Stack of decoder blocks with multi-head self-attention
    - Output projection to vocabulary size
    
    Args:
        n_notes: Vocabulary size (number of possible note values)
        d_model: Model dimensionality (embedding size)
        dropout: Dropout probability for regularization
        n_decoder_layers: Number of stacked decoder blocks
        n_heads: Number of attention heads per block
        batch_size: Batch size for training (used for organization, not computation)
    """
    def __init__(self, **kwargs):
        super(Transformer, self).__init__()

        # Extract hyperparameters from kwargs
        self.n_notes = kwargs.get('n_notes')
        self.d_model = kwargs.get('d_model') 
        self.dropout = kwargs.get('dropout')
        self.n_decoder_layers = kwargs.get('n_decoder_layers')
        self.n_heads = kwargs.get('n_heads')
        self.batch_size = kwargs.get('batch_size')

        # Embedding layer maps token indices to dense vectors
        self.shared_embedding = nn.Embedding(self.n_notes, self.d_model)

        # Decoder processes the embedded and positionally encoded sequence
        self.decoder = Decoder(self.d_model, self.dropout, self.n_decoder_layers, self.n_heads, self.shared_embedding)

        # Output projection maps decoder output back to vocabulary size
        self.fc = nn.Linear(self.d_model, self.n_notes)

        # From the paper "Using the Output Embedding to Improve Language Models" (Press & Wolf, 2017)
        self.fc.weight = self.shared_embedding.weight 
        

    @staticmethod    
    def generate_square_subsequent_mask(size, device=None):
        """
        Generate a causal mask for autoregressive training.
        
        Creates a lower-triangular matrix where:
        - mask[i, j] = 0 if j <= i (can attend to current and past positions)
        - mask[i, j] = -inf if j > i (cannot attend to future positions)
        
        This prevents the model from "cheating" during training by looking at future tokens.
        
        Args:
            size: Sequence length
            device: Device to create mask on (CPU or GPU)
            
        Returns:
            Causal mask of shape (size, size)
        """
        mask = (1 - torch.triu(torch.ones(size, size, device=device), diagonal=1)).bool()
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
        
    def forward(self, x ) -> torch.Tensor:
        # Generate causal mask to prevent attending to future positions
        tgt_mask = self.generate_square_subsequent_mask(x.size(1), device=x.device)

        decoder_output = self.decoder(tgt=x, tgt_mask=tgt_mask)   # Shape: (batch_size, seq_len, d_model)
        output = self.fc(decoder_output) # Shape: (batch_size, seq_len, n_notes)

        return output
    
    def predict(self, x) -> torch.Tensor:
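        # Use the same causal mask as in training: each position attends only to itself and
        # earlier positions, so every layer's hidden states match what the model was trained on
        tgt_mask = self.generate_square_subsequent_mask(x.size(1), device=x.device)
        decoder_output = self.decoder(tgt=x, tgt_mask=tgt_mask)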
        output = self.fc(decoder_output)

        return output
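The additive mask built by `generate_square_subsequent_mask` can also be written in one line with `torch.triu`; the sketch below checks that the two constructions agree. It also shows why the same mask matters at inference time in a multi-layer model: with the mask applied in every layer, earlier outputs do not change when a later token changes, but removing the mask alters even the last position's output once two or more layers are stacked, because the earlier hidden states it attends to were computed with access to later tokens. `attend` and `two_layers` are hypothetical, deliberately stripped-down helpers (no learned projections) used only for illustration:

```python
import math

import torch

torch.manual_seed(0)
T, D = 6, 16

# Construction used in Transformer.generate_square_subsequent_mask
allowed = (1 - torch.triu(torch.ones(T, T), diagonal=1)).bool()
mask_a = allowed.float().masked_fill(allowed == 0, float("-inf")).masked_fill(allowed == 1, 0.0)

# A more compact equivalent: -inf strictly above the diagonal, 0 on and below it
mask_b = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
print(torch.equal(mask_a, mask_b))  # True


def attend(x, mask=None):
    """Projection-free single-head self-attention, for illustration only."""
    logits = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    if mask is not None:
        logits = logits + mask
    return torch.softmax(logits, dim=-1) @ x


def two_layers(x, mask=None):
    """Two stacked attention layers, each using the same (optional) mask."""
    return attend(attend(x, mask), mask)


x = torch.randn(1, T, D)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(D)  # alter only the last token

masked = two_layers(x, mask_b)
masked_changed = two_layers(x_changed, mask_b)
# With the mask in every layer, outputs before the last position are unaffected
print(torch.allclose(masked[:, :-1], masked_changed[:, :-1]))  # True

# Without the mask, even the last position's output differs from the masked result
unmasked = two_layers(x)
print(torch.allclose(masked[:, -1], unmasked[:, -1]))  # False
```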

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, criterion, device):
    """
    Train the model for one epoch.
    
    Performs a complete pass through the training dataset, computing loss and gradients
    for each batch and updating model parameters.
    
    Args:
        model: Transformer model to train
        train_loader: DataLoader providing batches of training data
        optimizer: Optimizer for updating model parameters
        lr_scheduler: Optional learning rate scheduler
        criterion: Loss function (CrossEntropyLoss for classification)
        device: Device to run computations on (CPU or GPU)
        
    Returns:
        Tuple of (average_loss, accuracy) for the epoch
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (src, tgt) in enumerate(train_loader): 
        src, tgt = src.to(device), tgt.to(device)
       
        # Clear gradients from the previous iteration
        optimizer.zero_grad()
    
        # Forward pass: compute model predictions
        outputs = model(src)
        
        # Reshape for cross entropy loss
        outputs_flat = outputs.reshape(-1, outputs.size(-1))
        tgt_flat = tgt.reshape(-1)
        
        # Compute cross-entropy loss between predictions and targets
        loss = criterion(outputs_flat, tgt_flat)
        
        # Backward pass: compute gradients
        loss.backward()

        # Update model parameters using computed gradients
        optimizer.step()

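        # The scheduler (if given) is stepped once per batch, so its schedule is measured in batches
        if lr_scheduler is not None:
            lr_scheduler.step()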
        
        total_loss += loss.item()        
 
        # Calculate accuracy
        _, predicted = outputs_flat.max(dim=1)
        total += tgt_flat.size(0)
        correct += predicted.eq(tgt_flat).sum().item()
    
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

def validate(model, valid_loader, criterion, device):
    """
    Evaluate the model on validation data.
    
    Performs inference on the validation set without updating model parameters,
    used to monitor generalization during training.
    
    Args:
        model: Transformer model to evaluate
        valid_loader: DataLoader providing batches of validation data
        criterion: Loss function for computing validation loss
        device: Device to run computations on
        
    Returns:
        Tuple of (average_loss, accuracy) on validation set
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (src, tgt) in enumerate(valid_loader):  
            src, tgt = src.to(device), tgt.to(device)
        
            # Forward pass only (no gradient computation)
            outputs = model(src)
            
            # Reshape for loss computation
            outputs_flat = outputs.reshape(-1, outputs.size(-1))
            tgt_flat = tgt.reshape(-1)
            
            # Compute loss
            loss = criterion(outputs_flat, tgt_flat)
            total_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs_flat.max(1)
            total += tgt_flat.size(0)
            correct += predicted.eq(tgt_flat).sum().item()
        
        avg_loss = total_loss / len(valid_loader)
        accuracy = 100. * correct / total
    
    return avg_loss, accuracy
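`train_epoch` and `validate` flatten the logits from `(batch, seq_len, n_notes)` to `(batch * seq_len, n_notes)` before calling `nn.CrossEntropyLoss`. The sketch below, on random stand-in tensors shaped like one batch (131 tokens, 47 classes), checks that this gives the same mean loss as passing the logits with the class dimension moved to position 1, a layout `CrossEntropyLoss` also accepts, and computes the note-level accuracy the same way as the notebook:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, V = 4, 131, 47                          # batch, tokens per sample, vocabulary size
outputs = torch.randn(B, T, V)                # stand-in for model(src)
tgt = torch.randint(0, V, (B, T))             # stand-in for the target tokens

criterion = nn.CrossEntropyLoss()

# Flattened form used in train_epoch / validate
loss_flat = criterion(outputs.reshape(-1, V), tgt.reshape(-1))
# Equivalent: CrossEntropyLoss also accepts (N, C, d1) logits with targets of shape (N, d1)
loss_seq = criterion(outputs.transpose(1, 2), tgt)
print(torch.allclose(loss_flat, loss_seq))    # True

# Token-level accuracy, computed as in the notebook
predicted = outputs.reshape(-1, V).argmax(dim=1)
accuracy = 100.0 * predicted.eq(tgt.reshape(-1)).sum().item() / tgt.numel()
```

The accuracy counts every predicted token, silence included. Averaging per-batch mean losses (as `total_loss / len(loader)` does) gives each sample of a smaller final batch slightly more weight than the others.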

In [None]:
import matplotlib.pyplot as plt

# Training loop 
def train_model(model, train_loader, valid_loader, optimizer, lr_scheduler, criterion, epochs=15, device='cpu'):
    """
    Train the model for multiple epochs with validation and loss visualization.
    
    Performs the complete training loop: training for specified number of epochs,
    validating after each epoch, and plotting the learning curves.
    
    Args:
        model: Transformer model to train
        train_loader: DataLoader for training data
        valid_loader: DataLoader for validation data
        optimizer: Optimizer for parameter updates
        lr_scheduler: Optional learning rate scheduler
        criterion: Loss function
        epochs: Number of training epochs
        device: Device to train on
    """
    print("Starting training...")
    
    train_losses = []
    valid_losses = []

    for epoch in range(epochs):
        # Train for one epoch
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, lr_scheduler, criterion, device)
        train_losses.append(train_loss)

        # Evaluate on validation set
        val_loss, val_acc = validate(model, valid_loader, criterion, device)
        valid_losses.append(val_loss)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 50)

    print("Training completed!")

    plt.figure(figsize=(10, 4))
    epoch_numbers = range(1, epochs + 1)  # match the 1-based epoch numbers printed above
    plt.plot(epoch_numbers, train_losses, label='Train Loss')
    plt.plot(epoch_numbers, valid_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:
# Define model hyperparameters
config = {
    'n_notes': max_note - min_note + 2,  # 46 pitches (MIDI 36 to 81) plus the silence token 0 = 47
    'd_model': 128, 
    'dropout': 0,
    'n_decoder_layers': 4,
    'n_heads': 8,
    'batch_size': 32
}

num_epochs = 15

# Initialize the Transformer model
model = Transformer(**config)

# Prefer CUDA, then Apple's MPS backend, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = model.to(device) 

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")
print(f"Training on device: {device}")

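criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# train_epoch steps the scheduler after every batch, so total_iters is counted in batches
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=num_epochs * len(train_set))

# Run training (lr_scheduler=None keeps the learning rate constant; pass lr_scheduler=lr_scheduler to decay it linearly)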
train_model(model=model, train_loader=train_set, valid_loader=valid_set,  optimizer=optimizer, lr_scheduler=None, criterion=criterion, epochs=num_epochs, device=device)
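`LinearLR` counts calls to `scheduler.step()`, not epochs. Because `train_epoch` calls `lr_scheduler.step()` after every batch, `total_iters` has to be given in batches (for example `num_epochs * len(train_set)`); a value of `num_epochs` would make the decay finish after the first `num_epochs` batches. A small check with a dummy parameter and plain SGD (the parameter and step count are arbitrary):

```python
import torch

base_lr, steps = 1e-3, 10
param = torch.nn.Parameter(torch.zeros(1))        # dummy parameter, never updated
optimizer = torch.optim.SGD([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.1, total_iters=steps
)

lrs = []
for _ in range(steps + 5):
    optimizer.step()       # no gradients here, so this leaves the parameter unchanged
    scheduler.step()       # each call advances the schedule by one "iteration"
    lrs.append(scheduler.get_last_lr()[0])

# After `steps` calls the learning rate has reached base_lr * end_factor and stays there
print(f"{lrs[steps - 1]:.2e}", f"{lrs[-1]:.2e}")  # 1.00e-04 1.00e-04
```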

In [None]:
def evaluate_model(model, test_loader, criterion, device):
    """
    Evaluate the trained model on the test set.
    
    Computes final metrics on held-out test data to assess model performance
    on unseen examples.
    
    Args:
        model: Trained Transformer model
        test_loader: DataLoader for test data
        criterion: Loss function
        device: Device to run evaluation on
        
    Returns:
        Tuple of (test_loss, test_accuracy)
    """
    # Set model to evaluation mode
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    # Disable gradient computation for efficiency
    with torch.no_grad():
        for batch_idx, (src, tgt) in enumerate(test_loader): 
            src, tgt = src.to(device), tgt.to(device)
            
            # Forward pass
            outputs = model(src)
            
            # Reshape for cross entropy loss
            outputs_flat = outputs.reshape(-1, outputs.size(-1))
            tgt_flat = tgt.reshape(-1)
            
            # Compute loss
            loss = criterion(outputs_flat, tgt_flat)
            total_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs_flat.max(1)
            total += tgt_flat.size(0)
            correct += predicted.eq(tgt_flat).sum().item()
        
        avg_loss = total_loss / len(test_loader)
        accuracy = 100. * correct / total
    
        print('Test Results:')
        print(f'Test Loss: {avg_loss:.4f}')
        print(f'Test Accuracy: {accuracy:.2f}%')
        
    return avg_loss, accuracy

    
# Evaluate the model on test set
test_loss, test_acc = evaluate_model(model, test_set, criterion, device)

In [None]:
def generate_chorale(model, seed_chords, length=32, context_window=128):
    """
    Generate a Bach chorale continuation using the trained model.
    
    Performs autoregressive generation: predicts one note at a time, appending each
    prediction to the input sequence and using it to predict the next note.
    
    Uses a context window to limit the sequence length seen by the model during generation,
    so that the model is never asked to handle positions beyond the sequence lengths it was
    trained on.
    
    Args:
        model: Trained Transformer model
        seed_chords: Initial chords to condition generation on (list of 4-note chords)
        length: Number of chords to generate (default: 32)
        context_window: Maximum number of tokens fed to the model (default: 128).
                        Should not exceed the training sequence length (131 tokens for
                        bach_dataloader's default window_size=32)
        
    Returns:
        Generated chorale as numpy array of shape (num_chords, 4)
    """
    model.eval()
    
    with torch.no_grad():  # Disable gradient computation for inference
        # Convert seed chords to tensor and preprocess
        seed_tensor = torch.tensor(seed_chords, dtype=torch.long)
        arpeggio = preprocess(seed_tensor)
        device = next(model.parameters()).device  # Run on the same device as the model
        arpeggio = arpeggio.unsqueeze(0).to(device)  # Add batch dimension and move to device
        
        # Generate new notes
        for chord in range(length):
            for note in range(4):
                context = arpeggio[:, -context_window:]
            
                # Get model prediction for the current sequence
                outputs = model.predict(context)  # Shape: (1, seq_len, n_notes)
                
                
                # Get the prediction for the last timestep
                last_output = outputs[0, -1, :]  # Shape: (n_notes,)
                
                # Get the most likely next note
                next_note = torch.argmax(last_output, dim=-1, keepdim=True)  # Shape: (1,)
                
                # Append the predicted note to the sequence
                arpeggio = torch.cat([arpeggio, next_note.unsqueeze(0)], dim=1)

        # Convert back to original note range (reverse the preprocessing)
        arpeggio = torch.where(arpeggio == 0, arpeggio, arpeggio + min_note - 1)
        
        # Reshape to chord format (group every 4 notes)
        arpeggio_flat = arpeggio.squeeze(0)  # Remove batch dimension
        n_total_notes = len(arpeggio_flat)
        n_complete_chords = n_total_notes // 4
        
        # Take only complete chords and reshape
        chorale = arpeggio_flat[:n_complete_chords * 4].reshape(-1, 4)
        
        return chorale.cpu().numpy()  # Convert back to numpy for compatibility
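`generate_chorale` decodes greedily: it always appends the single most likely note, which can lead to repetitive continuations. A common alternative is temperature sampling, where the next note is drawn from `softmax(logits / temperature)`. The helper below is a hypothetical sketch of that idea, not part of the notebook; it returns a tensor of shape `(1,)`, like the `torch.argmax(..., keepdim=True)` call it would replace:

```python
import torch


def sample_next_note(logits: torch.Tensor, temperature: float = 1.0, generator=None) -> torch.Tensor:
    """Sample one token id from 1-D logits of shape (n_notes,).

    temperature < 1 sharpens the distribution (closer to argmax),
    temperature > 1 flattens it (more varied choices).
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)   # shape (1,)


logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
gen = torch.Generator().manual_seed(0)
cold = [sample_next_note(logits, 0.01, gen).item() for _ in range(20)]
hot = [sample_next_note(logits, 5.0, gen).item() for _ in range(200)]
print(set(cold))            # {0}: near-greedy
print(len(set(hot)) > 1)    # True: a high temperature spreads the choices
```

To try it, one could add a `temperature` argument to `generate_chorale` and replace the `torch.argmax` line with `next_note = sample_next_note(last_output, temperature)`.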

In [None]:
# Use the first 12 chords of a test-set chorale as the seed for generation
seed_chords = test_chorales[2][:12]
baroque_synth.play_chorale(seed_chords)   
print("Seed Chorale:")
print(seed_chords)         

In [None]:
# Create a new chorale continuation
new_chorale = generate_chorale(model, seed_chords, length=32)
baroque_synth.play_chorale(new_chorale)
print("Generated continuation (seed chords omitted):")
print(new_chorale[len(seed_chords):])

## Analyzing Transformer vs RNN Performance

### Key Differences to Highlight:

1. **Attention Mechanism**: 
   - Transformer can attend directly to any earlier position within its context window
   - RNN processes the sequence step by step and must carry all context in its hidden state, so long-term dependencies can fade

2. **Parallelization**: 
   - Transformer training can be parallelized (all positions at once)
   - RNN training is inherently sequential

3. **Musical Structure**:
   - Transformer might better capture harmonic relationships across time
   - Can potentially learn chord progressions and voice leading patterns

4. **Generation Quality**:
   - Compare coherence of generated chorales
   - Look at harmonic consistency and voice independence

### Exercises for Students:

1. **Experiment with attention heads**: Try different numbers of attention heads and see how it affects generation quality

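2. **Temperature sampling**: `generate_chorale` always picks the most likely next note (`torch.argmax`). Replace this with sampling from `softmax(logits / temperature)` and adjust the temperature to control randomness vs structure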

3. **Seed analysis**: Try different seed sequences and observe how the model continues them

4. **Attention visualization**: Plot attention weights to see what the model focuses on (advanced)
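For the attention-visualization exercise, recall that every `MultiHeadAttention` layer stores the weights of its last forward pass in `self.attention_weights`, with shape `(batch_size, num_heads, seq_len, seq_len)`. The sketch below plots a few heads for one sequence and computes, per head, the average distance between a query position and the positions it attends to. It runs on random causal weights as a stand-in; the function names are illustrative:

```python
import matplotlib.pyplot as plt
import torch


def plot_attention_heads(attn: torch.Tensor, max_heads: int = 4):
    """Plot attention maps for one sequence.

    attn: weights of shape (num_heads, tgt_len, src_len), e.g. one batch element of
    MultiHeadAttention.attention_weights, which has shape (batch, heads, tgt, src).
    """
    n = min(max_heads, attn.size(0))
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for h in range(n):
        ax = axes[0, h]
        ax.imshow(attn[h].detach().cpu().numpy(), cmap="viridis", aspect="auto")
        ax.set_title(f"head {h}")
        ax.set_xlabel("attended position (key)")
        ax.set_ylabel("query position")
    fig.tight_layout()
    return fig


def mean_attention_distance(attn: torch.Tensor) -> torch.Tensor:
    """Average |query - key| distance per head, weighted by the attention weights."""
    T = attn.size(-1)
    pos = torch.arange(T, dtype=attn.dtype, device=attn.device)
    dist = (pos[:, None] - pos[None, :]).abs()       # (T, T)
    return (attn * dist).sum(dim=-1).mean(dim=-1)    # (num_heads,)


# Stand-in causal attention weights: random logits plus a causal mask
torch.manual_seed(0)
H, T = 8, 16
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
attn = torch.softmax(torch.randn(H, T, T) + mask, dim=-1)

fig = plot_attention_heads(attn)
print(len(fig.axes), mean_attention_distance(attn).shape)   # 4 torch.Size([8])
```

With the trained model, run a forward pass (for example `model.predict(context)`) and pass `model.decoder.decoder_blocks[0].self_attention1.attention_weights[0]` to either function. Since each chord occupies four consecutive tokens, weight concentrated at offsets of 4, 8, ... would suggest a head tracks the same voice in earlier chords; whether any head actually does this is something to check.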