In [10]:
%pip install miditoolkit miditok ipywidgets transformers torch

Collecting torch
Collecting torch
  Downloading torch-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (30 kB)
  Downloading torch-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting setuptools (from torch)
Collecting setuptools (from torch)
  Downloading setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
  Downloading setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
Collecting networkx>=2.5.1 (from torch)
  Downloading networkx-3.6-py3-none-any.whl.metadata (6.8 kB)
  Downloading networkx-3.6-py3-none-any.whl.metadata (6.8 kB)
Collecting jinja2 (from torch)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting n

In [None]:
from pathlib import Path
import miditok

# Choose a vocabulary / representation
tokenizer = miditok.REMI()  # REMI is a good general format

midi_dir = Path("data/train_midis")
token_seqs = []

for midi_path in midi_dir.glob("*.mid"):
    # Pass the path directly to the tokenizer (miditok uses symusic internally)
    tokens = tokenizer.encode(midi_path)
    # tokens is a TokSequence or list of TokSequences for multi-track
    if isinstance(tokens, list):
        for track_tokens in tokens:
            token_seqs.append(track_tokens.ids)
    else:
        token_seqs.append(tokens.ids)

In [11]:
from transformers import GPT2Config, GPT2LMHeadModel

vocab_size = tokenizer.vocab_size  # from miditok

config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=2048,
    n_ctx=2048,
    n_layer=8,
    n_head=8,
    n_embd=512
)

model = GPT2LMHeadModel(config)
model.cuda()


ImportError: 
GPT2LMHeadModel requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.


## Using Hugging Face GPT2LMHeadModel-style

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

class MidiTokenDataset(Dataset):
    def __init__(self, sequences, seq_len=1024):
        self.sequences = sequences
        self.seq_len = seq_len
        self.data = []

        for seq in sequences:
            if len(seq) < 2:
                continue
            # break long seq into chunks
            for i in range(0, len(seq) - 1, seq_len):
                chunk = seq[i:i+seq_len+1]
                if len(chunk) > 1:
                    self.data.append(chunk)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        # pad if needed
        if len(seq) < self.seq_len + 1:
            pad_len = self.seq_len + 1 - len(seq)
            seq = seq + [0] * pad_len  # assume 0 is PAD if unused
        input_ids = torch.tensor(seq[:-1], dtype=torch.long)
        labels = torch.tensor(seq[1:], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}


## Training loop

In [None]:
from transformers import AdamW

model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    for batch in loader:
        input_ids = batch["input_ids"].cuda()
        labels = batch["labels"].cuda()

        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | loss: {loss.item():.4f}")


## Generating new MIDI sequences

In [None]:
import torch

def generate_tokens(model, tokenizer, max_length=1024, temperature=1.0, top_k=0, prompt=None):
    model.eval()
    if prompt is None:
        # Use some default BOS token or a small generic prompt
        prompt = [tokenizer["BOS"]] if "BOS" in tokenizer.vocab else [0]

    input_ids = torch.tensor(prompt, dtype=torch.long).unsqueeze(0).cuda()

    with torch.no_grad():
        for _ in range(max_length - len(prompt)):
            outputs = model(input_ids=input_ids)
            logits = outputs.logits[:, -1, :] / temperature

            if top_k > 0:
                values, indices = torch.topk(logits, top_k)
                probs = torch.softmax(values, dim=-1)
                next_token = indices[0, torch.multinomial(probs, 1)]
            else:
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)[0]

            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return input_ids[0].tolist()
