In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import ast
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

Symbolic, Unconditioned Generation

In [None]:
def parse_tokens(s):
    """Decode the string form of a token sequence into the Python object it spells.

    The text is evaluated with ``ast.literal_eval``, so only Python literals
    are accepted.
    """
    parsed = ast.literal_eval(s)
    return parsed


# Read the dataset, then decode the serialized 'token_sequence' column
# into a new 'token_list' column.
df_v2 = pd.read_csv('midi_df_v2.csv')
token_lists = df_v2['token_sequence'].apply(parse_tokens)
df_v2['token_list'] = token_lists

# Every distinct token id that occurs anywhere in the corpus, ascending.
all_tokens = sorted({tok for tokens in df_v2['token_list'] for tok in tokens})

# Two-way mapping between original token ids and dense indices 0..vocab_size-1.
new2orig = dict(enumerate(all_tokens))
orig2new = {orig: idx for idx, orig in new2orig.items()}
vocab_size = len(orig2new)

print(f"Number of sequences: {len(df_v2)}")
print(f"Example sequence length: {len(df_v2.loc[0, 'token_list'])}")
print(f"Vocabulary size: {vocab_size} (tokens {all_tokens[0]} ... {all_tokens[-1]})")

class TransformerNextToken(nn.Module):
    """Causally masked Transformer that emits next-token logits at every position.

    Token embeddings (scaled by sqrt(d_model)) are summed with learned
    positional embeddings, passed through a stack of encoder layers under a
    causal attention mask, and projected back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token indices.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
        dim_ff: Width of each layer's feed-forward sub-block.
        dropout: Dropout probability (embeddings and encoder layers).
        max_len: Size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size: int, d_model: int = 128, nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, dropout: float = 0.1, max_len: int = 64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
            dropout=dropout, activation="gelu", batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        self.max_len = max_len

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer tensor of token indices, shape (batch, seq_len) with
                ``seq_len <= max_len``.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size). The logits at
            position ``i`` depend only on tokens ``0..i``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len`` (the positional
                table has no entry for those positions).
        """
        _, seq_len = x.size()
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        tok_emb = self.token_emb(x) * (self.d_model ** 0.5)
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        pos_emb = self.pos_emb(positions)
        h = self.dropout(tok_emb + pos_emb)

        # Additive causal mask: -inf strictly above the diagonal blocks attention
        # to later positions. Without it, position i could read token i+1, which
        # is exactly the label it is trained to predict.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), dtype=h.dtype, device=h.device),
            diagonal=1,
        )
        h = self.transformer(h, mask=causal_mask)
        logits = self.fc_out(h)
        return logits

# Hyperparameters. T is handed to the model as max_len below.
T = 64
d_model = 128
nhead = 4
num_layers = 2
dim_ff = 256
dropout = 0.1

# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and place it on the chosen device in one step.
model = TransformerNextToken(
    vocab_size=vocab_size, d_model=d_model, nhead=nhead, num_layers=num_layers,
    dim_ff=dim_ff, dropout=dropout, max_len=T,
).to(device)

# Count only the parameters that will receive gradients.
total_params = 0
for param in model.parameters():
    if param.requires_grad:
        total_params += param.numel()

print(f"\nModel instantiated on {device}. Total trainable parameters: {total_params}\n")
print(model)

Number of sequences: 1276
Example sequence length: 7263
Vocabulary size: 1166 (tokens 4 ... 1275)

Model instantiated on cpu. Total trainable parameters: 572814

TransformerNextToken(
  (token_emb): Embedding(1166, 128)
  (pos_emb): Embedding(64, 128)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=128, out

In [None]:
class RandomWindowDataset(Dataset):
    """Dataset that serves randomly positioned (input, target) windows.

    Each item is a pair of ``max_len``-long tensors cut from one sequence,
    where the target is the input shifted one step to the right.

    Args:
        sequences: Iterable of integer token sequences. Sequences with
            ``len(seq) <= max_len`` are dropped because they cannot supply a
            window of ``max_len + 1`` tokens.
        max_len: Length of the input (and target) window.
    """

    def __init__(self, sequences, max_len=64):
        # Remember the window size so __getitem__ honours it instead of a
        # hard-coded 64.
        self.max_len = max_len
        self.seqs = [seq for seq in sequences if len(seq) > max_len]
        # Number of distinct start offsets across all kept sequences.
        self.total_windows = sum(len(seq) - max_len for seq in self.seqs)

    def __len__(self):
        """Return the number of distinct windows available (the nominal epoch size)."""
        return self.total_windows

    def __getitem__(self, idx):
        """Return a random ``(x, y)`` window; ``idx`` is ignored.

        A sequence is chosen uniformly at random, then a start offset is
        chosen uniformly within it, so two calls with the same ``idx`` will
        generally return different windows.
        """
        seq = random.choice(self.seqs)
        # Last valid start leaves room for max_len inputs plus one extra target token.
        start = random.randint(0, len(seq) - self.max_len - 1)
        window = seq[start : start + self.max_len + 1]
        x = torch.LongTensor(window[:-1])
        y = torch.LongTensor(window[1:])
        return x, y

In [None]:
# Re-express every token sequence in the dense index space given by orig2new.
df_v2['indexed_seq'] = df_v2['token_list'].apply(
    lambda tokens: [orig2new[tok] for tok in tokens]
)

# Partition by the 'split' column.
train_seqs = df_v2.loc[df_v2['split'] == 'train', 'indexed_seq'].tolist()
val_seqs = df_v2.loc[df_v2['split'] == 'validation', 'indexed_seq'].tolist()

print(f"Train sequences: {len(train_seqs)}, Val sequences: {len(val_seqs)}")

# Window length and batch size shared by both loaders.
T = 64
batch_size = 32

train_dataset = RandomWindowDataset(train_seqs, max_len=T)
val_dataset = RandomWindowDataset(val_seqs, max_len=T)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

print("Num train batches:", len(train_loader), "Num val batches:", len(val_loader))


Train sequences: 962,  Val sequences: 137
Num train batches: 271595 Num val batches: 31048


In [None]:
# Cross-entropy over the vocabulary, applied to flattened per-token logits.
criterion = nn.CrossEntropyLoss()

# AdamW over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=5e-4,
    weight_decay=1e-5,
)

# Multiply the learning rate by 0.5 when the monitored (minimised) metric plateaus.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
)

In [None]:
# Caps on how many batches one epoch may consume from each loader.
max_train_batches = 2000
max_val_batches = 500

# Epoch budget and early-stopping bookkeeping.
num_epochs = 15
patience_val = 2
best_val_loss = float("inf")
epochs_no_improve = 0

for epoch in range(1, num_epochs + 1):
    # ---------------- training pass ----------------
    model.train()
    total_train_loss = 0.0

    for i, (inputs, targets) in enumerate(train_loader):
        if i >= max_train_batches:
            break

        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # Flatten (batch, seq) so every token position is one classification example.
        step_loss = criterion(model(inputs).view(-1, vocab_size), targets.view(-1))
        step_loss.backward()
        optimizer.step()

        total_train_loss += step_loss.item()

        if i % 500 == 0:
            avg_so_far = total_train_loss / (i + 1)
            print(f"Epoch {epoch} [Train] Batch {i}/{max_train_batches}   Avg Loss: {avg_so_far:.4f}")

    avg_train_loss = total_train_loss / min(len(train_loader), max_train_batches)

    # ---------------- validation pass ----------------
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_loader):
            if i >= max_val_batches:
                break

            inputs, targets = inputs.to(device), targets.to(device)
            step_loss = criterion(model(inputs).view(-1, vocab_size), targets.view(-1))
            total_val_loss += step_loss.item()

            if i % 200 == 0:
                avg_val_so_far = total_val_loss / (i + 1)
                print(f"Epoch {epoch} [Val] Batch {i}/{max_val_batches} Avg Loss: {avg_val_so_far:.4f}")

    avg_val_loss = total_val_loss / min(len(val_loader), max_val_batches)
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch} Train Loss: {avg_train_loss:.4f} Val Loss: {avg_val_loss:.4f}")

    # ---------------- checkpointing / early stopping ----------------
    # Only a drop of more than 1e-4 counts as an improvement.
    improved = avg_val_loss < best_val_loss - 1e-4
    if improved:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_transformer_model.pt")
        continue

    epochs_no_improve += 1
    if epochs_no_improve >= patience_val:
        print(f"No improvement for {patience_val} epochs → Early stopping.")
        break

# Restore the weights from the best validation epoch (kept as the cell's last expression).
model.load_state_dict(torch.load("best_transformer_model.pt"))



Epoch 1 [Train] Batch 0/2000   Avg Loss: 4.6676
Epoch 1 [Train] Batch 500/2000   Avg Loss: 4.6825
Epoch 1 [Train] Batch 1000/2000   Avg Loss: 4.6457
Epoch 1 [Train] Batch 1500/2000   Avg Loss: 4.6191
Epoch 1 [Val]   Batch 0/500   Avg Loss: 4.5949
Epoch 1 [Val]   Batch 200/500   Avg Loss: 4.4643
Epoch 1 [Val]   Batch 400/500   Avg Loss: 4.4706
Epoch 1  Train Loss: 4.5939   Val Loss: 4.4716
Epoch 2 [Train] Batch 0/2000   Avg Loss: 4.4011
Epoch 2 [Train] Batch 500/2000   Avg Loss: 4.4871
Epoch 2 [Train] Batch 1000/2000   Avg Loss: 4.4745
Epoch 2 [Train] Batch 1500/2000   Avg Loss: 4.4583
Epoch 2 [Val]   Batch 0/500   Avg Loss: 4.2646
Epoch 2 [Val]   Batch 200/500   Avg Loss: 4.3595
Epoch 2 [Val]   Batch 400/500   Avg Loss: 4.3588
Epoch 2  Train Loss: 4.4452   Val Loss: 4.3598
Epoch 3 [Train] Batch 0/2000   Avg Loss: 4.4036
Epoch 3 [Train] Batch 500/2000   Avg Loss: 4.3921
Epoch 3 [Train] Batch 1000/2000   Avg Loss: 4.3805
Epoch 3 [Train] Batch 1500/2000   Avg Loss: 4.3700
Epoch 3 [Val]   

<All keys matched successfully>

In [None]:
def generate_sequence(model, start_seq, gen_len=200, temperature=1.0):
    """Autoregressively sample ``gen_len`` tokens that continue ``start_seq``.

    At each step the most recent tokens (at most the model's context length)
    are fed to ``model``, the logits at the last position are turned into a
    temperature-scaled softmax distribution, and one token index is drawn
    from it with ``np.random.choice``.

    Args:
        model: Module mapping a (1, seq_len) LongTensor to logits of shape
            (1, seq_len, vocab). Its ``max_len`` attribute, when present,
            bounds the context window; otherwise the notebook-level ``T`` is used.
        start_seq: Non-empty sequence of integer token indices used as the
            prompt. It is copied, not modified.
        gen_len: Number of tokens to sample.
        temperature: Positive softmax temperature; values below 1 sharpen the
            distribution, values above 1 flatten it.

    Returns:
        A new list holding the prompt followed by the ``gen_len`` sampled indices.

    Raises:
        ValueError: If ``temperature`` is not positive or ``start_seq`` is empty.
    """
    if temperature <= 0:
        # Dividing logits by 0 gives NaN probabilities; fail with a clear message instead.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    generated = list(start_seq)
    if not generated:
        raise ValueError("start_seq must contain at least one token")

    # Take the context length from the model so the function does not depend
    # on a notebook global staying in sync with it.
    max_context = getattr(model, "max_len", None)
    if max_context is None:
        max_context = T  # fall back to the notebook-level window size

    # Run on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        for _ in range(gen_len):
            context = generated[-max_context:]
            x = torch.tensor(context, dtype=torch.long, device=run_device).unsqueeze(0)
            last_logits = model(x)[0, -1, :]
            probs = torch.softmax(last_logits / temperature, dim=0).cpu().numpy()
            # Size the sample space from the logits rather than a global vocab_size.
            next_idx = np.random.choice(len(probs), p=probs)
            generated.append(int(next_idx))
    return generated

# Pick one validation sequence at random to prime the model.
random_idx = random.randrange(len(val_seqs))
seed = val_seqs[random_idx]

# Force the prompt to exactly T tokens: left-pad short sequences with their
# first token, crop long ones to their first T tokens.
shortfall = T - len(seed)
seed = ([seed[0]] * shortfall + seed) if shortfall > 0 else seed[:T]

gen_indices = generate_sequence(model, start_seq=seed, gen_len=200, temperature=1.0)
print("Generated total tokens:", len(gen_indices))

# Map dense indices back to the original token ids before writing to disk.
generated_original = []
for idx in gen_indices:
    generated_original.append(new2orig[idx])

np.save("generated_tokens.npy", np.array(generated_original, dtype=np.int16))
print("Saved generated_tokens.npy (shape=", len(generated_original), ")")

Generated total tokens: 264
Saved generated_tokens.npy (shape= 264 )
