## Data Loading

In [1]:
import os

def build_vocab_from_files(root_dir):
    all_tokens = set()
    for root, _, files in os.walk(root_dir):
        for fname in files:
            if fname.endswith(".txt"):
                with open(os.path.join(root, fname), 'r') as f:
                    tokens = list(map(int, f.read().strip().split()))
                    all_tokens.update(tokens)
    sorted_vocab = sorted(all_tokens)
    itos = {i: tok for i, tok in enumerate(sorted_vocab)}
    stoi = {tok: i for i, tok in itos.items()}
    return stoi, itos


In [2]:
import torch
from torch.utils.data import Dataset
import glob

class OutpaintingDataset(Dataset):
    def __init__(self, token_dir, stoi, context_len=64, target_len=64):
        self.file_paths = glob.glob(os.path.join(token_dir, '**/*.txt'), recursive=True)
        self.stoi = stoi
        self.context_len = context_len
        self.target_len = target_len
        self.samples = []

        for path in self.file_paths:
            with open(path) as f:
                raw_tokens = list(map(int, f.read().strip().split()))
                token_indices = [stoi[t] for t in raw_tokens if t in stoi]
                if len(token_indices) >= context_len + target_len:
                    for i in range(0, len(token_indices) - context_len - target_len + 1, context_len):
                        ctx = token_indices[i:i+context_len]
                        tgt = token_indices[i+context_len:i+context_len+target_len]
                        self.samples.append((ctx, tgt))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        context, target = self.samples[idx]
        return torch.tensor(context), torch.tensor(target)


In [3]:
def collate_fn(batch, pad_idx):
    contexts, targets = zip(*batch)
    context_pad = torch.nn.utils.rnn.pad_sequence(contexts, batch_first=True, padding_value=pad_idx)
    target_pad = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=pad_idx)
    return context_pad, target_pad


In [4]:
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from torch.utils.data import random_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

data_path = "../data/tokenized"
stoi, itos = build_vocab_from_files(data_path)
pad_idx = stoi.get(0, 0)

context_len, target_len = 64, 64 # not being passed in as they are same as default vals
train_dataset = OutpaintingDataset(data_path, stoi)
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

Using device: cuda


## Model

In [5]:
import torch.nn as nn

class MusicOutpaintingTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(1, 1024, d_model))
        layer = nn.TransformerDecoderLayer(d_model, nhead)
        self.decoder = nn.TransformerDecoder(layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.embedding(src) + self.pos_encoding[:, :src.size(1), :]
        tgt = self.embedding(tgt) + self.pos_encoding[:, :tgt.size(1), :]
        src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
        out = self.decoder(tgt, memory=src)
        return self.fc_out(out.transpose(0, 1))


## Hyperparameters

In [6]:
# Hyperparameters:
batch_size = 8
epochs = 10

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_idx))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda b: collate_fn(b, pad_idx))

model = MusicOutpaintingTransformer(vocab_size=len(stoi)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [7]:
@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        y_input = y[:, :-1]
        y_target = y[:, 1:]

        out = model(x, y_input)
        loss = criterion(out.reshape(-1, out.shape[-1]), y_target.reshape(-1))
        total_loss += loss.item()
    return total_loss / len(val_loader)


In [None]:
for epoch in range(epochs):
    model.train()
    total_loss = 0
    train_progress = tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]")

    for x, y in train_progress:
        x, y = x.to(device), y.to(device)
        y_input = y[:, :-1]
        y_target = y[:, 1:]

        out = model(x, y_input)
        loss = criterion(out.reshape(-1, out.shape[-1]), y_target.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        train_progress.set_postfix(loss=loss.item())

    avg_train_loss = total_loss / len(train_loader)

    val_loss = evaluate(model, val_loader, criterion, device)

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss  : {val_loss:.4f}\n")



Epoch 1 [Training]:  38%|███▊      | 7400/19541 [04:24<07:39, 26.40it/s, loss=0.0602]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 1 [Training]: 100%|██████████| 19541/19541 [11:53<00:00, 27.37it/s, loss=0.0632]



Epoch 1 Summary:
  Train Loss: 0.1029
  Val Loss  : 0.0504



Epoch 2 [Training]:  85%|████████▌ | 16624/19541 [10:25<01:52, 26.00it/s, loss=0.0511]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 3 [Training]:  89%|████████▊ | 17341/19541 [10:48<01:22, 26.51it/s, loss=0.0404]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)




Epoch 4 Summary:
  Train Loss: 0.0489
  Val Loss  : 0.0489



Epoch 5 [Training]: 100%|██████████| 19541/19541 [12:27<00:00, 26.13it/s, loss=0.0429]



Epoch 5 Summary:
  Train Loss: 0.0486
  Val Loss  : 0.0487



Epoch 6 [Training]:  39%|███▉      | 7664/19541 [04:48<07:26, 26.61it/s, loss=0.0569]

## Outpainting

In [45]:
# Define special token IDs
PAD_TOKEN_ID = stoi.get(0, 0)
BOS_TOKEN_ID = stoi.get(1, 1)
EOS_TOKEN_ID = stoi.get(2, 2)

In [46]:
from music21 import stream, note, chord

def tokens_to_midi(tokens, output_path):
    s = stream.Stream()
    i = 0
    while i < len(tokens):
        tok = tokens[i]

        # Single Note
        if 1000 <= tok < 2000:
            pitch_midi = tok - 1000
            if i + 1 < len(tokens) and 2000 <= tokens[i + 1] < 3000:
                dur = (tokens[i + 1] - 2000) / 4.0
                s.append(note.Note(pitch_midi, quarterLength=dur))
                i += 2
            else:
                s.append(note.Note(pitch_midi))
                i += 1

        # Chord
        elif 1000 <= tok < 2000:
            pitches = []
            while i < len(tokens) and 1000 <= tokens[i] < 2000:
                pitches.append(tokens[i] - 1000)
                i += 1
            if i < len(tokens) and 3000 <= tokens[i] < 4000:
                dur = (tokens[i] - 3000) / 4.0
                s.append(chord.Chord(pitches, quarterLength=dur))
                i += 1
            else:
                s.append(chord.Chord(pitches))
        
        # Skip other tokens like PAD, BOS, EOS
        else:
            i += 1

    s.write("midi", fp=output_path)
    print(f"Saved MIDI to: {output_path}")


In [53]:
def generate_continuation(model, context, max_length=64, temperature = 0.75, device=device):
    model.eval()
    generated = []
    input_tgt = torch.tensor([[BOS_TOKEN_ID]], device=device)
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(context, input_tgt)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_tgt = torch.cat([input_tgt, next_token], dim=1)
            generated.append(next_token.item())
    return generated


In [54]:
# Pick a random sample from val set
sample_context, _ = val_dataset[0]

# Generate continuation from a context
sample_context_tensor = torch.tensor(sample_context, dtype=torch.long, device=device).unsqueeze(0)
generated_indices = generate_continuation(model, sample_context_tensor, max_length=64, device=device)

# Convert model output indices back to your original token IDs
generated_tokens = [itos[i] for i in generated_indices]

# Convert to MIDI and save
tokens_to_midi(generated_tokens, "generated_sample.mid")

  sample_context_tensor = torch.tensor(sample_context, dtype=torch.long, device=device).unsqueeze(0)


Saved MIDI to: generated_sample.mid
