In [1]:
import os
import glob
import tarfile

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import optuna


device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


Using device: cuda


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

if not os.path.exists("./jsb"):
    with tarfile.open("jsb_chorales.tgz", "r:gz") as tar:
        tar.extractall("./jsb")
    print("Extracted jsb_chorales.tgz")
else:
    print("JSB folder already exists")


def load_chorales(path):
    chorales = []
    for f in glob.glob(os.path.join(path, "*.csv")):
        df = pd.read_csv(f)
        df = df.fillna(-1) 
        chorales.append(df.values.astype(int))
    return chorales


train_chorales = load_chorales("./jsb/jsb_chorales/train")
valid_chorales = load_chorales("./jsb/jsb_chorales/valid")
test_chorales  = load_chorales("./jsb/jsb_chorales/test")

len(train_chorales), len(valid_chorales), len(test_chorales)


JSB folder already exists


(229, 76, 77)

In [3]:
def build_windows(chorales, window=32):
    X, Y = [], []
    for chorale in chorales:
        T = len(chorale)
        if T <= window:
            continue
        for i in range(T - window):
            X.append(chorale[i:i+window])  
            Y.append(chorale[i+window])     
    return np.array(X), np.array(Y)


WINDOW = 32

X_train, Y_train = build_windows(train_chorales, window=WINDOW)
X_valid, Y_valid = build_windows(valid_chorales, window=WINDOW)
X_test,  Y_test  = build_windows(test_chorales,  window=WINDOW)

print("X_train:", X_train.shape, "Y_train:", Y_train.shape)
print("X_valid:", X_valid.shape, "Y_valid:", Y_valid.shape)
print("X_test :", X_test.shape,  "Y_test :", Y_test.shape)


n_notes = int(max(X_train.max(), Y_train.max()) + 1)
print("Number of distinct note indices:", n_notes)


X_train: (47900, 32, 4) Y_train: (47900, 4)
X_valid: (15976, 32, 4) Y_valid: (15976, 4)
X_test : (16436, 32, 4) Y_test : (16436, 4)
Number of distinct note indices: 82


In [4]:
class ChoraleDataset(Dataset):
    def __init__(self, X, Y):
        
        self.X = torch.tensor(X, dtype=torch.long)
        self.Y = torch.tensor(Y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


train_ds = ChoraleDataset(X_train, Y_train)
valid_ds = ChoraleDataset(X_valid, Y_valid)
test_ds  = ChoraleDataset(X_test,  Y_test)

train_loader_full = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_loader_full = DataLoader(valid_ds, batch_size=64, shuffle=False)
test_loader_full  = DataLoader(test_ds,  batch_size=64, shuffle=False)


In [5]:
class RNNModel(nn.Module):
    def __init__(self, n_notes, embed=128, hidden=256):
        super().__init__()
        self.embed = nn.Embedding(n_notes, embed)
        self.rnn = nn.GRU(input_size=embed*4, hidden_size=hidden, batch_first=True)
        self.fc  = nn.Linear(hidden, 4 * n_notes)

    def forward(self, x):
       
        B, T, V = x.size()  
        x = self.embed(x)  
        x = x.view(B, T, -1)  
        out, _ = self.rnn(x)  
        out = out[:, -1]      
        out = self.fc(out)   
        out = out.view(B, 4, n_notes)  
        return out


In [6]:
def train_one_epoch(model, loader, optimizer, criterion):
   
    local_device = "cuda" if torch.cuda.is_available() else "cpu"

    model.train()
    total_loss = 0.0

    for X, Y in loader:
        X, Y = X.to(local_device), Y.to(local_device)
        optimizer.zero_grad()

        logits = model(X)  
        loss = 0.0
        for v in range(4):
            loss = loss + criterion(logits[:, v, :], Y[:, v])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)


def evaluate(model, loader, criterion):
    local_device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for X, Y in loader:
            X, Y = X.to(local_device), Y.to(local_device)
            logits = model(X)
            loss = 0.0
            for v in range(4):
                loss = loss + criterion(logits[:, v, :], Y[:, v])
            total_loss += loss.item()

    return total_loss / len(loader)


In [7]:
def objective(trial):
    
    local_device = "cuda" if torch.cuda.is_available() else "cpu"

    embed  = trial.suggest_categorical("embed", [64, 128, 256])
    hidden = trial.suggest_categorical("hidden", [128, 256, 512])
    lr     = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    opt_name = trial.suggest_categorical("opt", ["adam", "rmsprop", "sgd"])

    model = RNNModel(n_notes, embed=embed, hidden=hidden).to(local_device)
    criterion = nn.CrossEntropyLoss()

    if opt_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=128, shuffle=False)

   
    train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss = evaluate(model, valid_loader, criterion)

    return val_loss


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20) 

print("Best trial:")
print("  Value:", study.best_value)
print("  Params:", study.best_params)


[I 2025-12-10 10:57:01,276] A new study created in memory with name: no-name-11f96889-6a19-42d7-bf22-22d268d7d7c5
[I 2025-12-10 10:57:14,790] Trial 0 finished with value: 3.2796129426956178 and parameters: {'embed': 64, 'hidden': 512, 'lr': 0.0003143435704089969, 'opt': 'rmsprop'}. Best is trial 0 with value: 3.2796129426956178.
[I 2025-12-10 10:57:19,126] Trial 1 finished with value: 3.721627481460571 and parameters: {'embed': 128, 'hidden': 128, 'lr': 0.0005123381715674258, 'opt': 'rmsprop'}. Best is trial 0 with value: 3.2796129426956178.
[I 2025-12-10 10:57:27,059] Trial 2 finished with value: 3.509983681678772 and parameters: {'embed': 256, 'hidden': 256, 'lr': 0.00039821517542115394, 'opt': 'rmsprop'}. Best is trial 0 with value: 3.2796129426956178.
[I 2025-12-10 10:57:31,763] Trial 3 finished with value: 3.997602501869202 and parameters: {'embed': 256, 'hidden': 128, 'lr': 0.009557267518731215, 'opt': 'adam'}. Best is trial 0 with value: 3.2796129426956178.
[I 2025-12-10 10:57:4

Best trial:
  Value: 3.053851508140564
  Params: {'embed': 64, 'hidden': 512, 'lr': 0.0008351870980315215, 'opt': 'rmsprop'}


In [8]:
best = study.best_params
print("Using best hyperparameters:", best)

model = RNNModel(
    n_notes,
    embed=best["embed"],
    hidden=best["hidden"]
).to(device)

criterion = nn.CrossEntropyLoss()

if best["opt"] == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=best["lr"])
elif best["opt"] == "rmsprop":
    optimizer = torch.optim.RMSprop(model.parameters(), lr=best["lr"])
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=best["lr"], momentum=0.9)

EPOCHS = 10

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader_full, optimizer, criterion)
    valid_loss = evaluate(model, valid_loader_full, criterion)
    print(f"Epoch {epoch:02d} | train loss: {train_loss:.4f} | valid loss: {valid_loss:.4f}")


Using best hyperparameters: {'embed': 64, 'hidden': 512, 'lr': 0.0008351870980315215, 'opt': 'rmsprop'}
Epoch 01 | train loss: 3.4429 | valid loss: 2.9497
Epoch 02 | train loss: 2.6294 | valid loss: 2.8053
Epoch 03 | train loss: 2.2869 | valid loss: 2.7034
Epoch 04 | train loss: 2.0262 | valid loss: 2.6130
Epoch 05 | train loss: 1.7978 | valid loss: 2.6153
Epoch 06 | train loss: 1.6084 | valid loss: 2.6392
Epoch 07 | train loss: 1.4513 | valid loss: 2.6911
Epoch 08 | train loss: 1.3046 | valid loss: 2.7827
Epoch 09 | train loss: 1.1652 | valid loss: 2.8701
Epoch 10 | train loss: 1.0686 | valid loss: 2.9382


In [9]:
def generate_sequence(model, seed, steps=100, window=32, temperature=1.0):
   
    model.eval()
    seq = [list(row) for row in seed] 

    with torch.no_grad():
        for _ in range(steps):
            context = torch.tensor(seq[-window:], dtype=torch.long).unsqueeze(0).to(device)

            logits = model(context)[0]  
            next_notes = []
            for v in range(4):
                logit_v = logits[v] / max(temperature, 1e-6)
                probs_v = torch.softmax(logit_v, dim=-1)
                note_v = torch.multinomial(probs_v, num_samples=1).item()
                next_notes.append(note_v)

            seq.append(next_notes)

    return np.array(seq, dtype=int)


In [10]:

seed_chorale = valid_chorales[0]
seed_length = 32
seed = seed_chorale[:seed_length]  

generated = generate_sequence(model, seed, steps=100, window=WINDOW, temperature=1.0)
print("Generated sequence shape:", generated.shape)


Generated sequence shape: (132, 4)


In [20]:
def generate_sequence(model, seed, steps=100, window=32, temperature=1.0):

    model.eval()
    seq = [list(row) for row in seed]  

    with torch.no_grad():
        for _ in range(steps):
            context = torch.tensor(seq[-window:], dtype=torch.long).unsqueeze(0).to(device)

            logits = model(context)[0]  
            next_notes = []

            for v in range(4):
                logit_v = logits[v] / max(temperature, 1e-6)
                probs_v = torch.softmax(logit_v, dim=-1)
                note_v = torch.multinomial(probs_v, 1).item()
                next_notes.append(note_v)

            seq.append(next_notes)

    return np.array(seq, dtype=int)


In [22]:
seed = valid_chorales[0][:32]  
generated = generate_sequence(model, seed, steps=100, window=32)

print("Generated shape:", generated.shape)
generated[:10] 


Generated shape: (132, 4)


array([[72, 67, 60, 48],
       [72, 67, 60, 48],
       [72, 67, 60, 48],
       [72, 67, 60, 48],
       [72, 67, 64, 48],
       [72, 67, 64, 48],
       [72, 67, 64, 50],
       [72, 67, 64, 50],
       [72, 67, 64, 52],
       [72, 67, 64, 52]])

In [24]:
def to_online_sequencer_format(seq, beat_length=1.0):

    notes = []
    T = len(seq)

    for t in range(T):
        for voice, pitch in enumerate(seq[t]):
            if pitch < 0:
                continue  

            note = {
                "pitch": int(pitch),
                "time": float(t * beat_length),
                "duration": float(beat_length),
                "instrument": voice  
            }
            notes.append(note)

    return notes

notes = to_online_sequencer_format(generated)
notes[:10]


[{'pitch': 72, 'time': 0.0, 'duration': 1.0, 'instrument': 0},
 {'pitch': 67, 'time': 0.0, 'duration': 1.0, 'instrument': 1},
 {'pitch': 60, 'time': 0.0, 'duration': 1.0, 'instrument': 2},
 {'pitch': 48, 'time': 0.0, 'duration': 1.0, 'instrument': 3},
 {'pitch': 72, 'time': 1.0, 'duration': 1.0, 'instrument': 0},
 {'pitch': 67, 'time': 1.0, 'duration': 1.0, 'instrument': 1},
 {'pitch': 60, 'time': 1.0, 'duration': 1.0, 'instrument': 2},
 {'pitch': 48, 'time': 1.0, 'duration': 1.0, 'instrument': 3},
 {'pitch': 72, 'time': 2.0, 'duration': 1.0, 'instrument': 0},
 {'pitch': 67, 'time': 2.0, 'duration': 1.0, 'instrument': 1}]

In [26]:
import json
print(json.dumps(notes))


[{"pitch": 72, "time": 0.0, "duration": 1.0, "instrument": 0}, {"pitch": 67, "time": 0.0, "duration": 1.0, "instrument": 1}, {"pitch": 60, "time": 0.0, "duration": 1.0, "instrument": 2}, {"pitch": 48, "time": 0.0, "duration": 1.0, "instrument": 3}, {"pitch": 72, "time": 1.0, "duration": 1.0, "instrument": 0}, {"pitch": 67, "time": 1.0, "duration": 1.0, "instrument": 1}, {"pitch": 60, "time": 1.0, "duration": 1.0, "instrument": 2}, {"pitch": 48, "time": 1.0, "duration": 1.0, "instrument": 3}, {"pitch": 72, "time": 2.0, "duration": 1.0, "instrument": 0}, {"pitch": 67, "time": 2.0, "duration": 1.0, "instrument": 1}, {"pitch": 60, "time": 2.0, "duration": 1.0, "instrument": 2}, {"pitch": 48, "time": 2.0, "duration": 1.0, "instrument": 3}, {"pitch": 72, "time": 3.0, "duration": 1.0, "instrument": 0}, {"pitch": 67, "time": 3.0, "duration": 1.0, "instrument": 1}, {"pitch": 60, "time": 3.0, "duration": 1.0, "instrument": 2}, {"pitch": 48, "time": 3.0, "duration": 1.0, "instrument": 3}, {"pitch