# Nottingham Dataset

In [1]:
from scipy.io import loadmat
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

class PianoRollDataset(Dataset):
    """Next-step prediction windows cut from piano-roll sequences.

    Every sequence longer than ``sequence_length`` is sliced along its first
    axis with a stride of one step into ``(input, target)`` pairs, where the
    target is the input window shifted forward by one step.

    Args:
        sequences: Iterable of tensors whose first axis is time. All windows
            are stacked, so the remaining axes must agree across sequences.
        sequence_length: Number of time steps in each window.

    Raises:
        ValueError: If no sequence is longer than ``sequence_length``, i.e.
            not a single window can be built.
    """

    def __init__(self, sequences, sequence_length=100):
        inputs, targets = [], []
        for seq in sequences:
            # range() is empty for sequences that are too short, so they are
            # skipped without an explicit length check.
            for start in range(seq.shape[0] - sequence_length):
                stop = start + sequence_length
                inputs.append(seq[start:stop])
                targets.append(seq[start + 1:stop + 1])
        if not inputs:
            # torch.stack([]) would fail with an opaque error; report the
            # actual cause instead.
            raise ValueError(
                f"No sequence is longer than {sequence_length} time steps; "
                "cannot build any training window."
            )
        self.inputs = torch.stack(inputs, dim=0)
        self.targets = torch.stack(targets, dim=0)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def get_nottingham_dataloaders(data_path, batch_size=64):
    """Build train/validation/test loaders from a Nottingham ``.mat`` file.

    The file is read with ``scipy.io.loadmat``; the piano rolls are taken from
    the first row of its ``traindata``, ``validdata`` and ``testdata`` entries
    and wrapped in ``PianoRollDataset`` with its default window length.

    Args:
        data_path: Path of the ``.mat`` file.
        batch_size: Batch size shared by all three loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``. Only the
        training loader shuffles; validation and test keep a fixed order so
        evaluation is reproducible from run to run.
    """
    mat = loadmat(data_path)

    def _load_split(key):
        # Build a fresh list of tensors instead of writing them back into the
        # loaded object array.
        return [torch.Tensor(roll.astype(np.float64)) for roll in mat[key][0]]

    X_train_nottingham = _load_split('traindata')
    X_val_nottingham = _load_split('validdata')
    X_test_nottingham = _load_split('testdata')

    train_dataset = PianoRollDataset(X_train_nottingham)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Evaluation loaders are not shuffled, consistent with the other datasets
    # in this notebook.
    val_dataset = PianoRollDataset(X_val_nottingham)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    test_dataset = PianoRollDataset(X_test_nottingham)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader

# Build the three Nottingham loaders from the Kaggle copy of the dataset.
(
    nottingham_train_dataloader,
    nottingham_val_dataloader,
    nottingham_test_dataloader,
) = get_nottingham_dataloaders(
    data_path="/kaggle/input/nottingham-music/Nottingham.mat",
    batch_size=64,
)

# Muse Dataset

In [2]:
!pip install music21



In [3]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from music21 import corpus, note, chord
import random

# Convert a music21 score to a binary piano roll matrix
def score_to_pianoroll(score, time_step=0.25, pitch_range=(21, 109), max_length=500):
    """Rasterise a music21 score into a fixed-size binary piano roll.

    Args:
        score: music21 stream whose notes and chords are rasterised.
        time_step: Grid resolution, in the same quarter-length units as
            ``offset`` and ``quarterLength``.
        pitch_range: Half-open MIDI pitch interval ``(lower, upper)``; pitches
            outside it are ignored.
        max_length: Number of time steps in the returned roll. Notes starting
            at or after this step are dropped and longer notes are clipped.

    Returns:
        ``float32`` array of shape ``(max_length, upper - lower)`` with 1.0
        where a pitch is sounding and 0.0 elsewhere. The roll always has
        ``max_length`` rows, so shorter pieces end in silent steps.
    """
    lower, upper = pitch_range
    pr = np.zeros((max_length, upper - lower), dtype=np.float32)

    # Stream.flat is deprecated in recent music21 releases in favour of
    # flatten(); fall back to .flat for versions that predate it.
    flat_score = score.flatten() if hasattr(score, "flatten") else score.flat

    for n in flat_score.notes:
        if isinstance(n, note.Note):
            pitches = [n.pitch.midi]
        elif isinstance(n, chord.Chord):
            pitches = [p.midi for p in n.pitches]
        else:
            continue

        onset = int(n.offset / time_step)
        if onset >= max_length:
            continue

        duration = int(n.quarterLength / time_step)
        if duration == 0 and n.quarterLength > 0:
            # A note shorter than one grid step would otherwise be truncated
            # to zero length and vanish from the roll.
            duration = 1
        end = min(onset + duration, max_length)

        for p in pitches:
            if lower <= p < upper:
                pr[onset:end, p - lower] = 1.0
    return pr

class MuseDataset(Dataset):
    """Next-step prediction windows cut from NumPy piano rolls.

    Every 2-D roll longer than ``sequence_length`` is sliced along its first
    axis with a stride of one step into ``(input, target)`` pairs, where the
    target is the input window shifted forward by one step.

    Args:
        sequences: Iterable of 2-D array-likes whose first axis is time. All
            windows are stacked, so the second axis must agree across rolls.
        sequence_length: Number of time steps in each window.

    Raises:
        ValueError: If no roll is longer than ``sequence_length``.
    """

    def __init__(self, sequences, sequence_length=100):
        inputs, targets = [], []
        for seq in sequences:
            seq = np.asarray(seq)
            T, _ = seq.shape
            # range() is empty when T <= sequence_length, so short rolls are
            # skipped without an explicit check.
            for i in range(T - sequence_length):
                inputs.append(seq[i : i + sequence_length])
                targets.append(seq[i + 1 : i + sequence_length + 1])
        if not inputs:
            raise ValueError(f"No sequences ≥ {sequence_length+1} timesteps found.")
        # Stack in NumPy first: torch.tensor() on a list of ndarrays converts
        # element by element, which is extremely slow and emits a UserWarning.
        self.inputs = torch.from_numpy(np.stack(inputs))
        self.targets = torch.from_numpy(np.stack(targets))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def load_musedata_pianorolls(sequence_length=100, seed=42):
    """Parse music21 corpus pieces by a fixed set of composers into piano rolls.

    The global ``random`` generator is seeded with ``seed`` and the piece list
    is shuffled before parsing. A piece whose parsing or conversion raises is
    reported on stdout and skipped.

    Args:
        sequence_length: Only rolls with more rows than this are kept.
        seed: Seed passed to ``random.seed`` before shuffling.

    Returns:
        List of piano rolls as produced by ``score_to_pianoroll``.

    Raises:
        RuntimeError: If no piano roll could be collected at all.
    """
    random.seed(seed)

    pieces = []
    for composer in ('beethoven', 'corelli', 'haydn', 'handel', 'scarlatti'):
        pieces.extend(corpus.getComposer(composer))
    random.shuffle(pieces)

    piano_rolls = []
    for piece in pieces:
        try:
            roll = score_to_pianoroll(corpus.parse(piece))
            long_enough = roll.shape[0] > sequence_length
            if long_enough:
                piano_rolls.append(roll)
        except Exception as err:
            # Best effort: one unreadable piece must not abort the whole load.
            print(f"Skipping {piece} (parse error): {err}")

    if not piano_rolls:
        raise RuntimeError("No valid piano rolls extracted. Try lowering sequence_length.")

    return piano_rolls

def get_musedata_dataloaders(batch_size=64, sequence_length=100):
    """Create train/validation/test loaders over the music21 piano rolls.

    The rolls returned by ``load_musedata_pianorolls`` are split 80/10/10 by
    position in that list. Only the training loader shuffles.

    Args:
        batch_size: Batch size shared by all three loaders.
        sequence_length: Window length forwarded to the loader and datasets.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    piano_rolls = load_musedata_pianorolls(sequence_length=sequence_length)
    total = len(piano_rolls)
    train_end = int(0.8 * total)
    val_end = int(0.9 * total)

    # Build all datasets first (train, val, test), then wrap them in loaders.
    partitions = (
        piano_rolls[:train_end],
        piano_rolls[train_end:val_end],
        piano_rolls[val_end:],
    )
    train_set, val_set, test_set = [
        MuseDataset(part, sequence_length) for part in partitions
    ]

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader)

# Build the loaders for the music21-based dataset.
(
    muse_train_dataloader,
    muse_val_dataloader,
    muse_test_dataloader,
) = get_musedata_dataloaders(sequence_length=100, batch_size=32)

# Sanity check: report the number of training batches and one batch's shapes.
print("Train batches:", len(muse_train_dataloader))
x, y = next(iter(muse_train_dataloader))
print("Example batch shapes:", x.shape, y.shape)

  return self.iter().getElementsByClass(classFilterList)
  self.inputs = torch.tensor(self.inputs)


Train batches: 363
Example batch shapes: torch.Size([32, 100, 88]) torch.Size([32, 100, 88])


# JSB Chorales Dataset

In [4]:
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

# ─── 1. Load JSB Chorales Dataset ──────────────────────────────────────────
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source.
# NOTE(review): encoding="latin1" is presumably needed because the file was
# pickled under Python 2 — confirm.
with open("/kaggle/input/nottingham-music/jsb-chorales-16th.pkl", "rb") as f:
    jsb_data = pickle.load(f, encoding="latin1")
# jsb_data is a dict with the keys 'train', 'valid' and 'test'.
# Each value is a list of sequences; each sequence is a list of time steps,
# and each time step is a list of the MIDI pitches (integers) active at that step.

# ─── 2. Preprocess: Convert to piano roll (binary vector per time step) ────
# Collect every MIDI pitch that occurs in any split, then give each one a
# column index (ascending pitch order) for the piano-roll representation.
all_pitches = set()
for split in ['train', 'valid', 'test']:
    for seq in jsb_data[split]:
        # The loop variable is deliberately not called `chord`: at module level
        # that name would overwrite the `chord` module imported from music21
        # and break score_to_pianoroll if it is called after this cell.
        for step_pitches in seq:
            all_pitches.update(step_pitches)
all_pitches = sorted(all_pitches)
pitch2idx = {p: i for i, p in enumerate(all_pitches)}
num_pitches = len(all_pitches)

def seq_to_pianoroll(seq):
    """Turn a sequence of pitch lists into a binary piano roll.

    Args:
        seq: List of time steps, each an iterable of pitches that are keys of
            the module-level ``pitch2idx`` mapping.

    Returns:
        ``float32`` array of shape ``(len(seq), num_pitches)`` holding 1.0 at
        ``[t, pitch2idx[p]]`` for every pitch ``p`` listed at step ``t``.
    """
    n_steps = len(seq)
    roll = np.zeros((n_steps, num_pitches), dtype=np.float32)
    for step_idx, active_pitches in enumerate(seq):
        columns = [pitch2idx[pitch] for pitch in active_pitches]
        if columns:
            roll[step_idx, columns] = 1.0
    return roll

def make_dataset(split, seq_len=32):
    """Build stacked next-step prediction windows for one split of ``jsb_data``.

    Args:
        split: Key of ``jsb_data`` to read.
        seq_len: Number of time steps in each window.

    Returns:
        Pair ``(X, Y)`` of stacked arrays where ``Y[k]`` is window ``X[k]``
        shifted forward by one time step. Sequences that are not longer than
        ``seq_len`` contribute no windows.
    """
    inputs = []
    targets = []
    for chorale in jsb_data[split]:
        roll = seq_to_pianoroll(chorale)
        # range() is empty when the roll is not longer than seq_len.
        for start in range(len(roll) - seq_len):
            stop = start + seq_len
            inputs.append(roll[start:stop])
            targets.append(roll[start + 1:stop + 1])
    stacked_inputs = np.stack(inputs)
    stacked_targets = np.stack(targets)
    return stacked_inputs, stacked_targets

class ChoraleDataset(Dataset):
    """Dataset over the windows that ``make_dataset`` builds for one split.

    Windows are stored as NumPy arrays in ``self.X`` / ``self.Y`` and copied
    into new tensors each time an item is requested.
    """

    def __init__(self, split, seq_len=32):
        windows = make_dataset(split, seq_len)
        self.X = windows[0]
        self.Y = windows[1]

    def __len__(self):
        n_windows = len(self.X)
        return n_windows

    def __getitem__(self, idx):
        window = torch.tensor(self.X[idx])
        shifted = torch.tensor(self.Y[idx])
        return window, shifted

# ─── 3. DataLoaders ────────────────────────────────────────────────────────
# Window length and batch size shared by all JSB loaders.
seq_len, batch_size = 32, 64

train_ds, val_ds, test_ds = (
    ChoraleDataset(split_name, seq_len)
    for split_name in ('train', 'valid', 'test')
)

# Only the training loader shuffles; validation and test keep dataset order.
jsb_train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
jsb_val_dataloader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
jsb_test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Shallow RNN

In [5]:
class ShallowRNN(nn.Module):
    """Single recurrent layer followed by a linear readout at every time step.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Size of the recurrent state.
        output_size: Feature size of each output step.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``; any other value
            raises ``KeyError``.
    """

    def __init__(self, input_size, hidden_size, output_size, rnn_type='RNN'):
        super().__init__()
        recurrent_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        recurrent_cls = recurrent_classes[rnn_type]
        self.rnn_type = rnn_type
        self.rnn = recurrent_cls(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a batch-first input sequence to one output vector per step."""
        hidden_states = self.rnn(x)[0]
        projected = self.fc(hidden_states)
        return projected

# DT(S)-RNN

In [6]:
class DeepTransitionRNNCell(nn.Module):
    """Recurrent cell whose state update is a stack of linear layers.

    The input is projected to the hidden size and added to the previous
    state; the activated sum is then passed through
    ``hidden -> transition -> ... -> hidden`` linear layers, each followed by
    the activation.

    Args:
        input_size: Feature size of the input at one time step.
        hidden_size: Size of the recurrent state.
        transition_size: Width of the intermediate transition layers.
        depth: Number of transition layers; values below 2 still produce two.
        nonlinearity: ``'sigmoid'`` selects a sigmoid; anything else uses ReLU.
    """

    def __init__(self, input_size, hidden_size, transition_size, depth, nonlinearity):
        super().__init__()
        self.input_layer = nn.Linear(input_size, hidden_size)

        # Layer widths along the transition path: hidden, transition (one or
        # more times), hidden.
        n_inner = max(depth - 2, 0)
        widths = [hidden_size] + [transition_size] * (n_inner + 1) + [hidden_size]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

        self.activation = nn.Sigmoid() if nonlinearity == 'sigmoid' else nn.ReLU()

    def forward(self, x, h_prev):
        """Return the next hidden state for input ``x`` and state ``h_prev``."""
        # The previous state is added directly to the projected input.
        state = self.activation(self.input_layer(x) + h_prev)
        for transform in self.hidden_layers:
            state = transform(state)
            state = self.activation(state)
        return state

class DTRNN(nn.Module):
    """Deep-transition RNN: a ``DeepTransitionRNNCell`` unrolled over time
    with a linear readout applied to the state at every step.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Size of the recurrent state.
        output_size: Feature size of each output step.
        transition_size: Width of the cell's transition layers.
        depth: Transition depth forwarded to the cell.
        nonlinearity: Activation name forwarded to the cell.
    """

    def __init__(self, input_size, hidden_size, output_size, transition_size, depth, nonlinearity='sigmoid'):
        super().__init__()
        self.cell = DeepTransitionRNNCell(
            input_size, hidden_size, transition_size, depth, nonlinearity
        )
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the cell over a batch-first 3-D input; return one output per step."""
        n_batch, n_steps, _ = x.shape
        state_size = self.cell.input_layer.out_features
        state = torch.zeros(n_batch, state_size, device=x.device)

        per_step = []
        for step in range(n_steps):
            frame = x[:, step, :]
            state = self.cell(frame, state)
            per_step.append(self.output(state))
        return torch.stack(per_step, dim=1)

# DOT(S)-RNN

In [7]:
class DOTSRNN(nn.Module):
    """Deep-transition RNN with a deep (multi-layer) output function.

    A ``DeepTransitionRNNCell`` is unrolled over time, and every hidden state
    is passed through a stack of linear layers:
    ``hidden -> intermediate -> ... -> hidden -> output``. The activation is
    applied between the layers but not after the last one, so the model
    returns raw logits (suitable for ``BCEWithLogitsLoss`` and for
    thresholding ``sigmoid(logits)``).

    Args:
        input_size: Feature size of each input step.
        hidden_size: Size of the recurrent state.
        output_size: Feature size of each output step.
        transition_size: Width of the cell's transition layers.
        depth: Transition depth forwarded to the cell.
        intermediate_output_size: Width of the intermediate output layers.
        output_depth: Controls the number of intermediate-to-intermediate
            output layers (``output_depth - 2`` of them).
        nonlinearity: Activation name forwarded to the cell.
        intermediate_output_nonlinearity: ``'sigmoid'`` selects a sigmoid for
            the output stack; anything else uses ReLU.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        transition_size,
        depth,
        intermediate_output_size,
        output_depth=2,
        nonlinearity='sigmoid',
        intermediate_output_nonlinearity='sigmoid'):
        super().__init__()
        self.cell = DeepTransitionRNNCell(input_size, hidden_size, transition_size, depth, nonlinearity)
        self.output_layers = nn.ModuleList()
        self.output_layers.append(nn.Linear(hidden_size, intermediate_output_size))
        for _ in range(output_depth - 2):
            self.output_layers.append(nn.Linear(intermediate_output_size, intermediate_output_size))
        self.output_layers.append(nn.Linear(intermediate_output_size, hidden_size))
        self.output_layers.append(nn.Linear(hidden_size, output_size))

        if intermediate_output_nonlinearity == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            self.activation = nn.ReLU()

    def forward(self, x):
        """Return per-step logits for a batch-first 3-D input."""
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.cell.input_layer.out_features, device=x.device)
        *intermediate_layers, final_layer = self.output_layers
        outputs = []
        for t in range(seq_len):
            h = self.cell(x[:, t], h)
            out = h
            for layer in intermediate_layers:
                out = self.activation(layer(out))
            # No activation after the final projection: squashing the logits
            # here would make the loss apply a second sigmoid and would push
            # every thresholded prediction to the positive class.
            outputs.append(final_layer(out))
        return torch.stack(outputs, dim=1)

# sRNN

In [8]:
class StackedRNN(nn.Module):
    """Multi-layer recurrent network with a linear readout at every time step.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Size of the recurrent state in every layer.
        output_size: Feature size of each output step.
        num_layers: Number of stacked recurrent layers.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``; any other value
            raises ``KeyError``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, rnn_type='RNN'):
        super().__init__()
        recurrent_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        recurrent_cls = recurrent_classes[rnn_type]
        self.rnn_type = rnn_type
        self.rnn = recurrent_cls(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a batch-first input sequence to one output vector per step."""
        top_layer_states = self.rnn(x)[0]
        projected = self.output(top_layer_states)
        return projected

In [15]:
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import optim

def lr_schedule(step, initial_lr, beta):
    """Exponentially decayed learning rate.

    The rate shrinks by a factor of ten every ``beta`` units of ``step``:
    ``initial_lr * 0.1 ** (step / beta)``.
    """
    decades_elapsed = step / beta
    decay_factor = 0.1 ** decades_elapsed
    return initial_lr * decay_factor

def _count_correct(logits, targets):
    """Count element-wise matches between thresholded predictions and targets.

    Predictions are ``sigmoid(logits) > 0.5``. Gradient tracking is disabled
    because the count is only used for reporting.
    """
    with torch.no_grad():
        preds = (torch.sigmoid(logits) > 0.5).float()
        return (preds == targets).float().sum().item()


def train_and_eval(
    model, train_loader, val_loader, vocab_size=None, initial_lr=0.1, beta = 2330, epochs=10, device=None,
    criterion=None, optimizer_class=None, model_name="model", dataset="dataset"):
    """Train ``model``, validate after every epoch, and plot the history.

    Args:
        model: Module mapping an input batch to logits shaped like the targets.
        train_loader: Iterable of ``(x, y)`` training batches; must support ``len``.
        val_loader: Iterable of ``(x, y)`` validation batches; must support ``len``.
        vocab_size: Unused; kept so existing callers keep working.
        initial_lr: Learning rate given to the optimizer and to ``lr_schedule``.
        beta: Decay constant passed to ``lr_schedule``.
        epochs: Number of passes over ``train_loader``.
        device: Target device; defaults to CUDA when available, else CPU.
        criterion: Loss called as ``criterion(logits, y)``. Required.
        optimizer_class: Optimizer constructor; defaults to ``optim.Adam``.
        model_name: Used in the file name of the saved figure.
        dataset: Dataset name. For "Nottingham", "MuseDataset" and
            "JSBDataset" gradient clipping and the learning-rate schedule are
            enabled. Also used in the figure file name.

    Returns:
        Tuple ``(model, train_loss, train_acc, val_loss, val_acc)`` holding the
        metrics of the last epoch. Losses are means over batches; accuracies
        are the fraction of target elements matched by ``sigmoid(logits) > 0.5``.

    Raises:
        ValueError: If ``criterion`` is None.
    """
    # Validate arguments before touching the model or the device.
    if criterion is None:
        raise ValueError("You must provide a loss function as 'criterion'.")
    if optimizer_class is None:
        optimizer_class = optim.Adam

    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = optimizer_class(model.parameters(), lr=initial_lr)

    is_polyphonic_dataset = dataset in ["Nottingham", "MuseDataset", "JSBDataset"]

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    for ep in range(1, epochs+1):
        # === Training ===
        model.train()
        total_train, correct_train, total_train_tokens = 0, 0, 0
        train_loop = tqdm(train_loader, desc=f"Epoch {ep}/{epochs} [Train]", leave=False)
        for x, y in train_loop:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            correct_train += _count_correct(logits, y)
            total_train_tokens += y.numel()

            optimizer.zero_grad()
            loss.backward()
            if is_polyphonic_dataset:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            if is_polyphonic_dataset:
                # NOTE(review): the schedule is fed the epoch number, not a
                # per-batch step counter, so the rate only changes between
                # epochs. Confirm this matches the intended meaning of beta.
                new_lr = lr_schedule(ep, initial_lr, beta)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr

            batch_loss = loss.item()
            total_train += batch_loss
            train_loop.set_postfix(loss=batch_loss)
        train_losses.append(total_train / len(train_loader))
        train_accs.append(correct_train / total_train_tokens)

        # === Validation ===
        model.eval()
        total_val, correct_val, total_val_tokens = 0, 0, 0
        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f"Epoch {ep}/{epochs} [Val]", leave=False)
            for x, y in val_loop:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = criterion(logits, y)
                correct_val += _count_correct(logits, y)
                total_val_tokens += y.numel()
                batch_loss = loss.item()
                total_val += batch_loss
                val_loop.set_postfix(val_loss=batch_loss)
        val_losses.append(total_val / len(val_loader))
        val_accs.append(correct_val / total_val_tokens)

        print(f"Epoch {ep}/{epochs}  "
              f"Train: {train_losses[-1]:.4f} (Acc {train_accs[-1]*100:.2f}%)  "
              f"Val: {val_losses[-1]:.4f} (Acc {val_accs[-1]*100:.2f}%)")

    # === Plotting ===
    fig = plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses,   label='Val')
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.title("Loss")
    plt.subplot(1,2,2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs,   label='Val Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend(); plt.title("Accuracy")
    plt.savefig(model_name + "_" + dataset + "_loss_acc.pdf")
    plt.show()
    # Close explicitly: this function is called once per grid-search config,
    # and unclosed figures would otherwise pile up in memory.
    plt.close(fig)
    return model, train_losses[-1], train_accs[-1], val_losses[-1], val_accs[-1]

In [16]:
from itertools import product

def get_config_combinations(model_type, config_dict):
    """Yield one dict per point of the Cartesian product of the option lists.

    Args:
        model_type: Unused by the function body.
        config_dict: Mapping from hyper-parameter name to an iterable of
            candidate values.

    Yields:
        Dicts mapping every name in ``config_dict`` to one chosen value, in
        ``itertools.product`` order (the last name varies fastest).
    """
    names = list(config_dict)
    option_lists = [config_dict[name] for name in names]
    for chosen in product(*option_lists):
        combination = {name: value for name, value in zip(names, chosen)}
        yield combination

# Candidate values for every hyper-parameter explored by the grid search.
hidden_sizes = [10, 50, 100, 150, 200, 400, 600]
transition_sizes = [200]
depths = [2]
intermediate_output_sizes = [200]
output_depths = [2]
num_layers = [2]
rnn_types = ['RNN', 'LSTM', 'GRU']

# Per-architecture search space. The order of the keys fixes the order in
# which the combinations are enumerated.
model_configs = {}
model_configs["RNN"] = {
    "hidden_size": hidden_sizes,
    "rnn_type": rnn_types,
}
model_configs["DT(S)-RNN"] = {
    "hidden_size": hidden_sizes,
    "transition_size": transition_sizes,
    "depth": depths,
}
model_configs["DOT(S)-RNN"] = {
    "hidden_size": hidden_sizes,
    "transition_size": transition_sizes,
    "intermediate_output_size": intermediate_output_sizes,
    "depth": depths,
    "output_depth": output_depths,
}
model_configs["sRNN"] = {
    "hidden_size": hidden_sizes,
    "num_layers": num_layers,
    "rnn_type": rnn_types,
}

# Learning-rate decay constant used for each dataset.
beta_values = dict(
    Nottingham=2330,
    MuseDataset=1475,
    JSBDataset=100,
)

import matplotlib.pyplot as plt

def _build_model(model_type, config, input_size, output_size):
    """Instantiate the architecture named by ``model_type`` from one config dict."""
    if model_type == "RNN":
        return ShallowRNN(input_size=input_size,
                          hidden_size=config["hidden_size"],
                          rnn_type=config["rnn_type"],
                          output_size=output_size)
    if model_type == "DT(S)-RNN":
        return DTRNN(input_size=input_size,
                     hidden_size=config["hidden_size"],
                     transition_size=config["transition_size"],
                     depth=config["depth"],
                     output_size=output_size)
    if model_type == "DOT(S)-RNN":
        return DOTSRNN(input_size=input_size,
                       hidden_size=config["hidden_size"],
                       transition_size=config["transition_size"],
                       intermediate_output_size=config["intermediate_output_size"],
                       depth=config["depth"],
                       output_depth=config["output_depth"],
                       output_size=output_size)
    if model_type == "sRNN":
        return StackedRNN(input_size=input_size,
                          hidden_size=config["hidden_size"],
                          num_layers=config["num_layers"],
                          rnn_type=config["rnn_type"],
                          output_size=output_size)
    raise ValueError(f"Unknown model type: {model_type!r}")


def _plot_hidden_size_comparison(curves, model_type, dataset):
    """Plot final loss and accuracy against hidden size and save the figure.

    ``curves`` maps a variant label (empty when there is a single variant) to
    a dict of equally long lists: ``hidden_size``, ``train_loss``,
    ``val_loss``, ``train_acc`` and ``val_acc``.
    """
    panels = [
        ("Loss", "Losses", "train_loss", "val_loss"),
        ("Accuracy", "Accuracies", "train_acc", "val_acc"),
    ]
    fig = plt.figure(figsize=(14, 6))
    for position, (metric, plural, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for variant, series in curves.items():
            suffix = f" ({variant})" if variant else ""
            plt.plot(series["hidden_size"], series[train_key], marker='o',
                     label=f"Train {metric}{suffix}")
            plt.plot(series["hidden_size"], series[val_key], marker='o',
                     label=f"Val {metric}{suffix}")
        plt.xlabel("Hidden Size")
        plt.ylabel(metric)
        plt.title(f"{model_type} {plural} on {dataset}")
        plt.legend()

    plt.tight_layout()
    plt.savefig(model_type + "_" + dataset + "_hidden_sizes_" + "comparison.pdf")
    plt.show()
    plt.close(fig)


def grid_search(train_dataloader, val_dataloader, input_size, output_size, dataset, num_epochs=10):
    """Train every configuration in ``model_configs`` and plot the results.

    For each architecture, all hyper-parameter combinations are trained with
    SGD and ``BCEWithLogitsLoss`` via ``train_and_eval``. The last-epoch
    metrics are then plotted against the hidden size, with one pair of curves
    per combination of the remaining hyper-parameters (e.g. per ``rnn_type``),
    and the figure is saved as a PDF.

    Args:
        train_dataloader: Training batches.
        val_dataloader: Validation batches.
        input_size: Feature size of the model inputs.
        output_size: Feature size of the model outputs.
        dataset: Key into ``beta_values``; also used in titles and file names.
        num_epochs: Epochs per configuration.

    Raises:
        ValueError: If ``model_configs`` names an unknown architecture.
    """
    for model_type, config_options in model_configs.items():
        # Hyper-parameters other than the hidden size that take more than one
        # value. Each combination of them gets its own curve; plotting all
        # results as one series would give x and y different lengths.
        varying = [key for key, options in config_options.items()
                   if key != "hidden_size" and len(options) > 1]
        curves = {}

        for config in get_config_combinations(model_type, config_options):
            print(f"\nTraining {model_type} with config: {config}")

            model = _build_model(model_type, config, input_size, output_size)

            _, last_train_loss, last_train_acc, last_val_loss, last_val_acc = train_and_eval(
                model=model,
                train_loader=train_dataloader,
                val_loader=val_dataloader,
                initial_lr=0.1,
                beta=beta_values[dataset],
                epochs=num_epochs,
                model_name=f"{model_type} {config}",
                dataset=dataset,
                criterion=nn.BCEWithLogitsLoss(),
                optimizer_class=optim.SGD
            )

            variant = ", ".join(f"{key}={config[key]}" for key in varying)
            series = curves.setdefault(variant, {
                "hidden_size": [],
                "train_loss": [],
                "val_loss": [],
                "train_acc": [],
                "val_acc": [],
            })
            series["hidden_size"].append(config["hidden_size"])
            series["train_loss"].append(last_train_loss)
            series["val_loss"].append(last_val_loss)
            series["train_acc"].append(last_train_acc)
            series["val_acc"].append(last_val_acc)

        _plot_hidden_size_comparison(curves, model_type, dataset)


In [None]:
# Run the full grid search on each of the three datasets in turn.
grid_search(
    nottingham_train_dataloader,
    nottingham_val_dataloader,
    input_size=88,
    output_size=88,
    dataset="Nottingham",
    num_epochs=10,
)
grid_search(
    muse_train_dataloader,
    muse_val_dataloader,
    input_size=88,
    output_size=88,
    dataset="MuseDataset",
    num_epochs=10,
)
grid_search(
    jsb_train_dataloader,
    jsb_val_dataloader,
    input_size=num_pitches,
    output_size=num_pitches,
    dataset="JSBDataset",
    num_epochs=10,
)