In [13]:
%pip install miditoolkit miditok ipywidgets transformers torch

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [14]:
from pathlib import Path
import miditok

# Choose a vocabulary / representation
tokenizer = miditok.REMI()  # REMI is a good general format

midi_dir = Path("data/train_midis")
token_seqs = []

for midi_path in midi_dir.glob("*.mid"):
    # Pass the path directly to the tokenizer (miditok uses symusic internally)
    tokens = tokenizer.encode(midi_path)
    # tokens is a TokSequence or list of TokSequences for multi-track
    if isinstance(tokens, list):
        for track_tokens in tokens:
            token_seqs.append(track_tokens.ids)
    else:
        token_seqs.append(tokens.ids)

In [15]:
from transformers import GPT2Config, GPT2LMHeadModel

vocab_size = tokenizer.vocab_size  # from miditok

config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=2048,
    n_ctx=2048,
    n_layer=8,
    n_head=8,
    n_embd=512
)

model = GPT2LMHeadModel(config)
model.cuda()


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(284, 512)
    (wpe): Embedding(2048, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=284, bias=False)
)

In [17]:
import torch
from torch.utils.data import Dataset, DataLoader

class MidiTokenDataset(Dataset):
    def __init__(self, sequences, seq_len=1024):
        self.sequences = sequences
        self.seq_len = seq_len
        self.data = []

        for seq in sequences:
            if len(seq) < 2:
                continue
            # break long seq into chunks
            for i in range(0, len(seq) - 1, seq_len):
                chunk = seq[i:i+seq_len+1]
                if len(chunk) > 1:
                    self.data.append(chunk)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        # pad if needed
        if len(seq) < self.seq_len + 1:
            pad_len = self.seq_len + 1 - len(seq)
            seq = seq + [0] * pad_len  # assume 0 is PAD if unused
        input_ids = torch.tensor(seq[:-1], dtype=torch.long)
        labels = torch.tensor(seq[1:], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}

In [18]:
# Create dataset and dataloader
dataset = MidiTokenDataset(token_seqs, seq_len=512)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Set training parameters
num_epochs = 10

print(f"Dataset size: {len(dataset)} sequences")
print(f"Vocabulary size: {vocab_size}")

Dataset size: 203 sequences
Vocabulary size: 284


## Using Hugging Face GPT2LMHeadModel-style

In [19]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

class MidiTokenDataset(Dataset):
    def __init__(self, sequences, seq_len=1024):
        self.sequences = sequences
        self.seq_len = seq_len
        self.data = []

        for seq in sequences:
            if len(seq) < 2:
                continue
            # break long seq into chunks
            for i in range(0, len(seq) - 1, seq_len):
                chunk = seq[i:i+seq_len+1]
                if len(chunk) > 1:
                    self.data.append(chunk)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        # pad if needed
        if len(seq) < self.seq_len + 1:
            pad_len = self.seq_len + 1 - len(seq)
            seq = seq + [0] * pad_len  # assume 0 is PAD if unused
        input_ids = torch.tensor(seq[:-1], dtype=torch.long)
        labels = torch.tensor(seq[1:], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}


## Training loop

In [20]:
from torch.optim import AdamW

model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    for batch in loader:
        input_ids = batch["input_ids"].cuda()
        labels = batch["labels"].cuda()

        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch} | loss: {loss.item():.4f}")

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Epoch 0 | loss: 1.7604
Epoch 1 | loss: 1.3440
Epoch 1 | loss: 1.3440
Epoch 2 | loss: 0.9500
Epoch 2 | loss: 0.9500
Epoch 3 | loss: 0.6943
Epoch 3 | loss: 0.6943
Epoch 4 | loss: 1.5805
Epoch 4 | loss: 1.5805
Epoch 5 | loss: 1.8692
Epoch 5 | loss: 1.8692
Epoch 6 | loss: 0.9520
Epoch 6 | loss: 0.9520
Epoch 7 | loss: 1.1887
Epoch 7 | loss: 1.1887
Epoch 8 | loss: 0.9724
Epoch 8 | loss: 0.9724
Epoch 9 | loss: 1.2262
Epoch 9 | loss: 1.2262


## Generating new MIDI sequences

In [38]:
import torch

def generate_tokens(model, tokenizer, max_length=1024, temperature=1.0, top_k=0, prompt=None):
    model.eval()
    if prompt is None:
        # Use some default BOS token or a small generic prompt
        prompt = [tokenizer["BOS_None"]] if "BOS_None" in tokenizer.vocab else [0]

    input_ids = torch.tensor(prompt, dtype=torch.long).unsqueeze(0).cuda()

    with torch.no_grad():
        for _ in range(max_length - len(prompt)):
            outputs = model(input_ids=input_ids)
            logits = outputs.logits[:, -1, :] / temperature

            if top_k > 0:
                values, indices = torch.topk(logits, top_k)
                probs = torch.softmax(values, dim=-1)
                next_token_idx = torch.multinomial(probs, 1)
                next_token = indices.gather(1, next_token_idx)
            else:
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids[0].tolist()

In [41]:
# Generate new MIDI with a proper prompt from training data
from miditok import TokSequence

# Use the beginning of a real sequence as a prompt
prompt_seq = token_seqs[0][:20]  # First 20 tokens from first training sequence
print(f"Using prompt: {[tokenizer[tid] for tid in prompt_seq]}")

# Increase max_length for longer music (1536 tokens should give ~10-15 seconds)
generated_tokens = generate_tokens(
    model, 
    tokenizer, 
    max_length=1536,  # Increased from 512 for longer output
    temperature=0.9, 
    top_k=50,
    prompt=prompt_seq
)

# Convert token IDs to token strings
token_strings = [tokenizer[token_id] for token_id in generated_tokens]

# Create TokSequence with both IDs and token strings
tok_seq = TokSequence(ids=generated_tokens, tokens=token_strings)

# Convert tokens back to MIDI (wrap in a list for multi-track support)
generated_midi = tokenizer.decode([tok_seq])
generated_midi.dump_midi("output_generated.mid")
print(f"Generated MIDI saved to output_generated.mid")
print(f"Duration: {generated_midi.end() / generated_midi.ticks_per_quarter / 2:.2f} seconds (approximate)")

Using prompt: ['Bar_None', 'Bar_None', 'Position_0', 'Pitch_67', 'Velocity_63', 'Duration_1.4.8', 'Position_12', 'Pitch_66', 'Velocity_63', 'Duration_0.2.8', 'Position_14', 'Pitch_67', 'Velocity_63', 'Duration_0.2.8', 'Position_16', 'Pitch_66', 'Velocity_63', 'Duration_1.0.8', 'Position_24', 'Pitch_62']
Generated MIDI saved to output_generated.mid
Duration: 3.50 seconds (approximate)


In [44]:
# Check the generated MIDI file
print(f"Number of tracks: {len(generated_midi.tracks)}")
print(f"Time division (ticks per quarter note): {generated_midi.ticks_per_quarter}")
print(f"Duration: {generated_midi.end():.2f} ticks")
print(f"Duration in seconds (at 120 BPM): {generated_midi.end() / generated_midi.ticks_per_quarter / 2:.2f}s")

for i, track in enumerate(generated_midi.tracks):
    print(f"\nTrack {i}:")
    print(f"  Notes: {len(track.notes)}")
    print(f"  Name: {track.name}")
    if len(track.notes) > 0:
        print(f"  First 3 notes: {track.notes[:3]}")
        print(f"  Last 3 notes: {track.notes[-3:]}")
    else:
        print("  WARNING: This track has no notes!")
        
# Check the generated tokens
print(f"\nGenerated {len(generated_tokens)} tokens")
print(f"First 20 tokens: {generated_tokens[:20]}")
print(f"First 20 token strings: {token_strings[:20]}")

# Check for early stopping tokens
print(f"\nLast 20 tokens: {generated_tokens[-20:]}")
print(f"Last 20 token strings: {token_strings[-20:]}")

# Count token types
import collections
token_type_counts = collections.Counter([t.split('_')[0] for t in token_strings])
print(f"\nToken type distribution:")
for token_type, count in token_type_counts.most_common(10):
    print(f"  {token_type}: {count}")

Number of tracks: 1
Time division (ticks per quarter note): 8
Duration: 56.00 ticks
Duration in seconds (at 120 BPM): 3.50s

Track 0:
  Notes: 4
  Name: Acoustic Grand Piano
  First 3 notes: [Note(32, 12, 67, 63, 'Tick'), Note(44, 2, 66, 63, 'Tick'), Note(46, 2, 67, 63, 'Tick')]
  Last 3 notes: [Note(44, 2, 66, 63, 'Tick'), Note(46, 2, 67, 63, 'Tick'), Note(48, 8, 66, 63, 'Tick')]

Generated 1536 tokens
First 20 tokens: [4, 4, 190, 51, 109, 137, 202, 50, 109, 127, 204, 51, 109, 127, 206, 50, 109, 133, 214, 46]
First 20 token strings: ['Bar_None', 'Bar_None', 'Position_0', 'Pitch_67', 'Velocity_63', 'Duration_1.4.8', 'Position_12', 'Pitch_66', 'Velocity_63', 'Duration_0.2.8', 'Position_14', 'Pitch_67', 'Velocity_63', 'Duration_0.2.8', 'Position_16', 'Pitch_66', 'Velocity_63', 'Duration_1.0.8', 'Position_24', 'Pitch_62']

Last 20 tokens: [127, 41, 127, 41, 127, 53, 127, 4, 41, 127, 34, 127, 63, 127, 41, 127, 41, 127, 50, 127]
Last 20 token strings: ['Duration_0.2.8', 'Pitch_57', 'Duratio

In [43]:
# Check what tokens are in the training data
import collections
all_tokens = []
for seq in token_seqs[:5]:  # Check first 5 sequences
    all_tokens.extend(seq[:100])  # First 100 tokens of each

# Get token frequencies
token_counts = collections.Counter(all_tokens)
print("Most common tokens in training data:")
for token_id, count in token_counts.most_common(20):
    print(f"  {token_id}: {tokenizer[token_id]} (count: {count})")
    
# Check if we have program tokens
print("\nChecking for Program tokens in vocabulary:")
for token_name in list(tokenizer.vocab.keys())[:30]:
    if 'Program' in token_name or 'Instrument' in token_name:
        print(f"  Found: {token_name}")

Most common tokens in training data:
  109: Velocity_63 (count: 90)
  4: Bar_None (count: 54)
  127: Duration_0.2.8 (count: 39)
  190: Position_0 (count: 33)
  233: PitchDrum_38 (count: 24)
  126: Duration_0.1.8 (count: 24)
  140: Duration_1.7.8 (count: 17)
  206: Position_16 (count: 14)
  101: Velocity_31 (count: 12)
  214: Position_24 (count: 11)
  51: Pitch_67 (count: 10)
  156: Duration_3.7.8 (count: 10)
  32: Pitch_48 (count: 9)
  113: Velocity_79 (count: 9)
  129: Duration_0.4.8 (count: 8)
  38: Pitch_54 (count: 8)
  202: Position_12 (count: 7)
  133: Duration_1.0.8 (count: 7)
  46: Pitch_62 (count: 7)
  198: Position_8 (count: 7)

Checking for Program tokens in vocabulary:


## Summary and Next Steps

**Current Issues:**
- Model generated 1536 tokens but only 4 valid notes (~3.5 seconds)
- Last tokens show the model is stuck in a loop: `Pitch → Duration → Pitch → Duration`
- Missing `Position` and `Velocity` tokens needed to complete note sequences
- In REMI, a valid note needs: Position → Pitch → Velocity → Duration

**Why this happened:**
- Only trained for 10 epochs on 8 MIDI files (very small dataset)
- Model hasn't learned proper token sequence structure
- Training loss was still decreasing (1.76 → 1.23), indicating it needs more training

**How to get longer, better music:**

1. **Train longer**: 50-100 epochs instead of 10
2. **More training data**: Add more MIDI files to `data/train_midis/`
3. **Use a pre-trained model**: Start with a model that already understands music structure
4. **Post-process tokens**: Filter out invalid token sequences before decoding

For now, your setup works and can generate MIDI! The model just needs more training to create longer, more coherent sequences.