# LSTM Training 

In [2]:
import torch, json
from torch.utils.data import random_split
from tokenizer import ABCTokenizer
from dataset   import LeadSheetDataset, collate_fn

# Read raw ABC file. Decode explicitly as UTF-8 (like the JSON maps below) so the
# result does not depend on the platform's default locale encoding.
with open("leadsheets.abc", "r", encoding="utf-8") as f:
    raw_data = f.read()

# Pre-built lookup tables used by the tokenizer.
with open("cache/abc_maps.json", "r", encoding="utf-8") as f:
    maps = json.load(f)

chord_map  = maps["chord_map"]
header_map = maps["header_map"]
inline_map = maps["inline_map"]

# Tunes are separated by blank lines. Runs of 4+ newlines (or blank lines holding
# only whitespace) would otherwise yield empty "tunes", so drop those fragments.
tunes = [tune for tune in raw_data.strip().split("\n\n") if tune.strip()]
tokenizer = ABCTokenizer(chord_map=chord_map, header_map=header_map, inline_hdr_map=inline_map)
tokenizer.build_vocab(tunes)

# 80/10/10 train/val/test split with a fixed seed for reproducibility.
dataset = LeadSheetDataset(tunes, tokenizer)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42))

# Padding id used by the loss (ignore_index), the model and the collate function.
pad_idx = tokenizer.stoi[tokenizer.pad_token]
collate = lambda batch: collate_fn(batch, pad_idx)

In [3]:
import time 
import math
import torch
import torch.nn as nn
import torch.optim as optim


def train_epoch(model, loader, loss_fn, opt, device, clip=1.0, print_every=50, pad_idx=None):
    """Train ``model`` for one pass over ``loader`` and return token-level metrics.

    Logits are flattened to ``(-1, vocab)`` and targets to ``(-1,)`` before
    ``loss_fn`` is applied. Each batch's loss is treated as a mean over its
    counted targets and re-weighted by that count, so the returned loss is a
    true per-token average across batches with different amounts of padding.

    Args:
        model: module mapping ``x`` to logits whose last dim is the vocabulary.
        loader: iterable of ``(x, y)`` batches that supports ``len()``.
        loss_fn: loss callable, e.g. ``nn.CrossEntropyLoss(ignore_index=pad)``.
        opt: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        clip: max global gradient norm passed to ``clip_grad_norm_``.
        print_every: print running metrics every this many batches (and
            always on the last batch).
        pad_idx: target id excluded from the token count. Defaults to
            ``loss_fn.ignore_index`` when the loss has one, else 0, so the
            count matches the targets the loss actually averaged over.

    Returns:
        ``(avg_loss, perplexity)`` per counted token; ``(inf, inf)`` if no
        tokens were counted.
    """
    if pad_idx is None:
        pad_idx = getattr(loss_fn, "ignore_index", 0)

    model.train()
    total_loss, tokens = 0.0, 0
    total_batches = len(loader)
    start_time = time.time()

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)

        # Count only targets the loss averages over. A batch with none of them
        # would make a mean-reduced loss 0/0 (NaN), and stepping on its
        # gradients would corrupt the weights, so such a batch is skipped.
        num_tokens = (y != pad_idx).sum().item()
        if num_tokens > 0:
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss.backward()

            # Gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), clip)

            opt.step()

            # Scale the per-token mean back to a sum before accumulating.
            total_loss += loss.item() * num_tokens
            tokens += num_tokens

        # Print progress
        if (batch_idx + 1) % print_every == 0 or batch_idx == total_batches - 1:
            elapsed_time = time.time() - start_time
            current_loss = total_loss / tokens if tokens > 0 else float('inf')
            current_ppl = math.exp(current_loss)

            progress = (batch_idx + 1) / total_batches * 100
            time_per_batch = elapsed_time / (batch_idx + 1)
            eta = time_per_batch * (total_batches - batch_idx - 1)

            print(f"  Batch {batch_idx+1}/{total_batches} ({progress:.1f}%) | "
                  f"Loss: {current_loss:.4f} | PPL: {current_ppl:.2f} | "
                  f"Time: {elapsed_time:.1f}s | ETA: {eta:.1f}s")

    avg_loss = total_loss / tokens if tokens > 0 else float('inf')
    perplexity = math.exp(avg_loss)
    total_time = time.time() - start_time

    print(f"  Epoch completed in {total_time:.1f}s | Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")

    return avg_loss, perplexity


def evaluate(model, loader, loss_fn, device, pad_idx=None):
    """Compute the token-weighted loss and perplexity of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``. Each batch's loss is treated
    as a mean over its counted targets and re-weighted by that count, so the
    result is a per-token average regardless of how padding varies by batch.

    Args:
        model: module mapping ``x`` to logits whose last dim is the vocabulary.
        loader: iterable of ``(x, y)`` batches.
        loss_fn: loss callable, e.g. ``nn.CrossEntropyLoss(ignore_index=pad)``.
        device: device each batch is moved to.
        pad_idx: target id excluded from the token count. Defaults to
            ``loss_fn.ignore_index`` when the loss has one, else 0.

    Returns:
        ``(avg_loss, perplexity)``; ``(inf, inf)`` if no tokens were counted.
    """
    if pad_idx is None:
        pad_idx = getattr(loss_fn, "ignore_index", 0)

    model.eval()
    total_loss, tokens = 0.0, 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            # Count only targets the loss averages over. Skip batches with none:
            # a mean-reduced loss would be 0/0 (NaN) and NaN * 0 would still
            # poison the running sum.
            num_tokens = (y != pad_idx).sum().item()
            if num_tokens == 0:
                continue

            logits = model(x)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            # Scale the per-token mean back to a sum before accumulating.
            total_loss += loss.item() * num_tokens
            tokens += num_tokens

    avg_loss = total_loss / tokens if tokens > 0 else float('inf')
    perplexity = math.exp(avg_loss)

    return avg_loss, perplexity

In [4]:
from torch.utils.data import DataLoader

# Build one DataLoader per split. They share the batch size and the padding
# collate function; only the training split is shuffled.
batch_size = 32

_shared_loader_args = {"batch_size": batch_size, "collate_fn": collate}
train_loader = DataLoader(train_ds, shuffle=True, **_shared_loader_args)
val_loader = DataLoader(val_ds, shuffle=False, **_shared_loader_args)
test_loader = DataLoader(test_ds, shuffle=False, **_shared_loader_args)
del _shared_loader_args

In [5]:
import os

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Model hyperparameters (also written into every checkpoint for recreation)
vocab_size = tokenizer.vocab_size()
embed_dim, hidden_dim = 512, 1024
num_layers = 2
dropout = 0.30

# Instantiate the model and move it to the chosen device
from models import LSTMModel
model = LSTMModel(
    vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
    num_layers=num_layers, pad_idx=pad_idx, dropout=dropout,
)
model = model.to(device)

# Padding targets are excluded from the loss; the LR is halved after
# 2 epochs without improvement in the monitored (minimized) metric.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Architecture summary and parameter count
print(model)
total_params = sum(map(torch.numel, model.parameters()))
print(f"Total parameters: {total_params:,}")

Using device: cuda
LSTMModel(
  (embed): Embedding(215, 512, padding_idx=0)
  (lstm): LSTM(512, 1024, num_layers=2, batch_first=True, dropout=0.3)
  (fc_out): Linear(in_features=1024, out_features=215, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
Total parameters: 15,026,903


In [None]:
# Create model directory
model_save_dir = "saved_models"
os.makedirs(model_save_dir, exist_ok=True)

# Training configuration
num_epochs = 15
best_val_loss = float('inf')
patience = 5  # epochs without val-loss improvement before early stopping
no_improvement = 0
checkpoint_interval = 1  # save a checkpoint every N epochs

# The checkpoints record `bidirectional`, but no such variable was ever defined
# (NameError on the first save). Read it from the model's LSTM layer instead,
# falling back to False if the model exposes no `lstm` attribute.
bidirectional = getattr(getattr(model, "lstm", None), "bidirectional", False)

# Per-epoch history, used for plotting and stored in the final checkpoint
train_losses = []
val_losses = []
train_perplexities = []
val_perplexities = []
learning_rates = []


def _checkpoint(epoch, **metrics):
    """Return a checkpoint dict: training state, ``metrics``, and the hyperparameters needed to rebuild the model."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        **metrics,
        # Hyperparameters for model recreation
        'vocab_size': vocab_size,
        'embed_dim': embed_dim,
        'hidden_dim': hidden_dim,
        'num_layers': num_layers,
        'dropout': dropout,
        'bidirectional': bidirectional,
        'pad_idx': pad_idx,
    }


print("Starting training...")
for epoch in range(1, num_epochs + 1):
    # Track current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)
    print(f"Epoch {epoch}/{num_epochs} [LR: {current_lr:.6f}]")

    # Training
    train_loss, train_perplexity = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_perplexities.append(train_perplexity)

    # Validation
    val_loss, val_perplexity = evaluate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_perplexities.append(val_perplexity)

    # Let ReduceLROnPlateau react to the validation loss
    scheduler.step(val_loss)

    # Print stats
    print(f"Epoch {epoch}/{num_epochs} | Train Loss: {train_loss:.4f} | Train PPL: {train_perplexity:.2f} | Val Loss: {val_loss:.4f} | Val PPL: {val_perplexity:.2f}")

    # Save checkpoint at regular intervals
    if epoch % checkpoint_interval == 0:
        model_path = os.path.join(model_save_dir, f"lstm_checkpoint_epoch{epoch}.pt")
        torch.save(_checkpoint(
            epoch,
            val_loss=val_loss,
            val_perplexity=val_perplexity,
            train_loss=train_loss,
            train_perplexity=train_perplexity,
        ), model_path)
        print(f"Checkpoint saved to {model_path}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement = 0
        model_path = os.path.join(model_save_dir, "best_lstm_model.pt")
        torch.save(_checkpoint(
            epoch,
            val_loss=val_loss,
            val_perplexity=val_perplexity,
        ), model_path)
        print(f"✓ New best model saved (val_ppl={val_perplexity:.2f})")
    else:
        no_improvement += 1
        print(f"! No improvement for {no_improvement} epochs")

    # Early stopping
    if no_improvement >= patience:
        print(f"Early stopping at epoch {epoch} after {no_improvement} epochs with no improvement")
        break

# Save final model together with the full training history
final_model_path = os.path.join(model_save_dir, "final_lstm_model.pt")
torch.save(_checkpoint(
    epoch,
    val_loss=val_loss,
    val_perplexity=val_perplexity,
    train_losses=train_losses,
    val_losses=val_losses,
    train_perplexities=train_perplexities,
    val_perplexities=val_perplexities,
    learning_rates=learning_rates,
), final_model_path)
print(f"Final model saved to {final_model_path}")

In [None]:
import matplotlib.pyplot as plt


def _epoch_axis(values):
    """Return 1-based epoch numbers for a per-epoch history list."""
    return range(1, len(values) + 1)


# Epochs are plotted 1-based so the x-axis matches the "Epoch N/num_epochs"
# lines printed during training.
plt.figure(figsize=(15, 10))

# Plot loss
plt.subplot(2, 2, 1)
plt.plot(_epoch_axis(train_losses), train_losses, label='Training Loss')
plt.plot(_epoch_axis(val_losses), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)

# Plot perplexity
plt.subplot(2, 2, 2)
plt.plot(_epoch_axis(train_perplexities), train_perplexities, label='Training Perplexity')
plt.plot(_epoch_axis(val_perplexities), val_perplexities, label='Validation Perplexity')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.title('Training and Validation Perplexity')
plt.legend()
plt.grid(True)

# Plot the learning rate in effect at the start of each epoch
plt.subplot(2, 2, 3)
plt.plot(_epoch_axis(learning_rates), learning_rates)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)

# Save the plot
plt.tight_layout()
plt.savefig(os.path.join(model_save_dir, 'training_metrics.png'))
plt.show()

# Hyperparameter Tuning

In [None]:
%pip install optuna

In [10]:
import optuna 
from datetime import datetime

def objective(trial):
    """Optuna objective: briefly train an LSTM and return its best validation perplexity.

    Samples embedding/hidden sizes, depth, dropout, learning rate and AdamW
    weight decay from ``trial``. Trains for at most ``max_epochs`` on a fixed
    subset of ``train_ds`` and evaluates on ``val_ds`` after every epoch.
    Each epoch's perplexity is reported to Optuna so the study's pruner can
    stop unpromising trials (raises ``optuna.exceptions.TrialPruned``).

    Relies on notebook globals: ``tokenizer``, ``pad_idx``, ``device``,
    ``train_ds``, ``val_ds``, ``collate``, ``LSTMModel`` and ``evaluate``.

    Returns:
        The lowest validation perplexity seen (the study minimizes it).
    """
    # Search space
    embed_dim = trial.suggest_categorical('embed_dim', [128, 256, 384, 512])
    hidden_dim = trial.suggest_categorical('hidden_dim', [256, 512, 768])
    num_layers = trial.suggest_int('num_layers', 1, 3)
    dropout = trial.suggest_float('dropout', 0.0, 0.6)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)  # AdamW (decoupled) weight decay

    # Model built from this trial's hyperparameters
    model = LSTMModel(
        vocab_size=tokenizer.vocab_size(),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        pad_idx=pad_idx
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Short per-trial budget to keep the search fast
    max_epochs = 4
    patience = 2
    smaller_batch_size = 32

    # Train on a subset for speed. The fixed seed gives every trial the SAME
    # subset, so perplexities differ because of the hyperparameters rather than
    # because each trial drew different training examples.
    subset_size = min(10000, len(train_ds))
    subset_generator = torch.Generator().manual_seed(42)
    subset_indices = torch.randperm(len(train_ds), generator=subset_generator)[:subset_size].tolist()

    small_train_loader = DataLoader(
        torch.utils.data.Subset(train_ds, subset_indices),
        batch_size=smaller_batch_size,
        shuffle=True,
        collate_fn=collate,
    )
    small_val_loader = DataLoader(
        val_ds,
        batch_size=smaller_batch_size,
        shuffle=False,
        collate_fn=collate,
    )

    best_val_perplexity = float('inf')
    no_improvement = 0
    training_start = time.time()

    for epoch in range(max_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        num_batches = 0

        for x, y in small_train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        # Mean of per-batch losses (for logging only)
        avg_train_loss = train_loss / num_batches if num_batches > 0 else float('inf')

        # Validation phase
        _, val_perplexity = evaluate(model, small_val_loader, criterion, device)

        # Report the intermediate value; the pruner may stop this trial early
        trial.report(val_perplexity, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        # Log before the early-stopping check so the final epoch is reported too
        epoch_time = time.time() - training_start
        print(f"Trial {trial.number} | Epoch {epoch+1}/{max_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | Val PPL: {val_perplexity:.2f} | "
              f"Time: {epoch_time:.1f}s")

        # Track the best result and stop after `patience` epochs without improvement
        if val_perplexity < best_val_perplexity:
            best_val_perplexity = val_perplexity
            no_improvement = 0
        else:
            no_improvement += 1
            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return best_val_perplexity

# run optimization
def run_optuna_optimization(n_trials=10):
    """Run an Optuna study over ``objective`` and persist the best hyperparameters.

    The study minimizes validation perplexity and uses a MedianPruner that
    prunes only after 5 startup trials and 1 warm-up step. The best trial's
    params are written with ``torch.save`` to
    ``<model_save_dir>/best_hyperparams_<study_name>.pt``.

    Args:
        n_trials: number of trials to run.

    Returns:
        ``(study, best_params)``, where ``best_params`` is the best trial's
        parameter dict.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    study_name = f"lstm_optimization_{timestamp}"
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=1)
    study = optuna.create_study(
        study_name=study_name,
        direction="minimize",
        pruner=median_pruner,
    )

    print(f"Running Optuna optimization with {n_trials} trials...")
    study.optimize(objective, n_trials=n_trials)

    # Report the winning trial
    print("\nBest trial:")
    winner = study.best_trial
    print(f"  Value (Val Perplexity): {winner.value:.4f}")
    print("  Params:")
    for param_name, param_value in winner.params.items():
        print(f"    {param_name}: {param_value}")

    # Persist the winning hyperparameters
    best_params = winner.params
    params_file = os.path.join(model_save_dir, f"best_hyperparams_{study_name}.pt")
    torch.save(best_params, params_file)
    print(f"Best hyperparameters saved to {params_file}")

    return study, best_params

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Run the Optuna search (10 trials); returns the study and the best trial's params.
study, best_params = run_optuna_optimization(n_trials=10) 

In [12]:
#train the final model with best hyperparameters
def train_with_best_params(best_params, num_epochs=15):
    """Train a fresh LSTMModel on the full training split with tuned hyperparameters.

    After each epoch the model is evaluated on ``val_ds``. Whenever the
    validation loss improves, a checkpoint is written to
    ``<model_save_dir>/best_optuna_lstm_model.pt``. The loss curves are
    plotted and saved as ``best_model_training_curve.png``.

    Args:
        best_params: dict with keys 'embed_dim', 'hidden_dim', 'num_layers',
            'dropout', 'learning_rate' and 'weight_decay' (as returned by
            ``run_optuna_optimization``).
        num_epochs: number of full passes over the training data.

    Returns:
        ``(final_model, best_model_path)``. ``final_model`` holds the weights
        from the LAST epoch; the lowest-validation-loss weights are saved at
        ``best_model_path``. The path is ``None`` if no epoch's validation
        loss was below infinity (e.g. ``num_epochs=0`` or every loss was NaN).
    """
    # Local import so this function does not depend on the plotting cell having run.
    import matplotlib.pyplot as plt

    print("Training final model with best hyperparameters...")

    n_vocab = tokenizer.vocab_size()
    final_model = LSTMModel(
        vocab_size=n_vocab,
        embed_dim=best_params['embed_dim'],
        hidden_dim=best_params['hidden_dim'],
        num_layers=best_params['num_layers'],
        dropout=best_params['dropout'],
        pad_idx=pad_idx
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.AdamW(
        final_model.parameters(),
        lr=best_params['learning_rate'],
        weight_decay=best_params['weight_decay']
    )

    # Full train/val splits (not the tuning subset)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=collate)

    best_val_loss = float('inf')
    # Initialized up front so the return below never hits UnboundLocalError
    # when no epoch improves on the initial best.
    best_model_path = None
    train_losses = []
    val_losses = []

    # 1-based epochs, consistent with the main training loop's checkpoints
    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_perplexity = train_epoch(
            final_model, train_loader, criterion, optimizer, device
        )
        val_loss, val_perplexity = evaluate(
            final_model, val_loader, criterion, device
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Train Loss: {train_loss:.4f} | Train PPL: {train_perplexity:.2f}")
        print(f"Val Loss: {val_loss:.4f} | Val PPL: {val_perplexity:.2f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(model_save_dir, "best_optuna_lstm_model.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': final_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_perplexity': val_perplexity,
                'hyperparameters': best_params,
                # Needed (with the hyperparameters) to rebuild the model
                'vocab_size': n_vocab,
                'pad_idx': pad_idx,
            }, best_model_path)
            print(f"Best model saved to {best_model_path}")

    # Plot loss curves against 1-based epoch numbers
    epochs_axis = range(1, len(train_losses) + 1)
    plt.figure(figsize=(10, 6))
    plt.plot(epochs_axis, train_losses, label='Training Loss')
    plt.plot(epochs_axis, val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss with Best Hyperparameters')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(model_save_dir, "best_model_training_curve.png"))
    plt.show()

    return final_model, best_model_path

In [None]:
# Train the final model with the tuned hyperparameters; the returned path is
# where the checkpoint with the lowest validation loss was saved.
final_model, best_model_path = train_with_best_params(best_params)