## MIDI generation with transformers

In [1]:
!pip install torch numpy matplotlib pretty_midi miditoolkit tqdm



In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from miditoolkit import MidiFile
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
import pretty_midi
import math

In [3]:
DATA_PATH = "clean_midi"
OUTPUT_PATH = "generated_midi"
os.makedirs(OUTPUT_PATH, exist_ok=True)

Model params

In [4]:
# Embeddings size
d_model = 256

# Attention heads number
n_heads = 8

# Transformer layers number
num_layers = 6

# Feed-forward layer size
d_ff = 4096

# Dropout
dropout = 0.1 


Learning params

In [5]:
batch_size = 16
seq_len = 256 # Length of a sequence
epochs = 20
learning_rate = 3e-5

## Tokenize the MIDI

In [6]:
class MIDITokenizer:
    def __init__(self):
        self.event2idx = {}
        self.idx2event = {}
        self._build_vocab()


    def _build_vocab(self):
        # Events: notes, pauses, control events
        note_events = [f'NOTE_{i}' for i in range(128)]
        velocity_events = [f'VELOCITY_{i}' for i in range(128)]
        duration_events = [f'DURATION_{i}' for i in range(1, 7681)] # Ticks is in interval of 1--128
        control_events = [f'TIME_SHIFT', 'END_OF_TRACK']

        # Collecting all events
        all_events = note_events + velocity_events + duration_events + control_events

        # Create dicts
        self.event2idx = {event: idx for idx, event in enumerate(all_events)}
        self.idx2event = {idx: event for idx, event in enumerate(all_events)}

        self.vocab_size = len(all_events)


    def encode(self, midi_file):
        # Upload a .mid file
        midi = MidiFile(midi_file)
        notes = midi.instruments[0].notes

        # Sort notes by start time
        notes.sort(key=lambda x: (x.start, x.pitch))

        events = []
        current_time = 0

        for note in notes:
            # Time between previous and current events
            time_shift = note.start - current_time
            if time_shift > 0:
                events.append('TIME_SHIFT')
        
            # Add note, velocity & duration
            events.append(f'NOTE_{note.pitch}')
            events.append(f'VELOCITY_{note.velocity}')
            events.append(f'DURATION_{note.duration}')

            current_time = note.start

        events.append('END_OF_TRACK')

        result = []
        for event in events:
            result.append(self.event2idx[event])
        return result


    def decode(self, indices):
        events = [self.idx2event[event] for event in indices]

        # Create new .mid file
        midi = pretty_midi.PrettyMIDI()
        instrument = pretty_midi.Instrument(program=0)

        current_time = 0
        current_note = None
        current_velocity = None

        for event in events:
            if event.startswith('NOTE_'):
                current_note = int(event.split('_')[1])
            elif event.startswith('VELOCITY_'):
                current_velocity = int(event.split('_')[1])
            elif event.startswith('DURATION_'):
                if current_note is not None and current_velocity is not None:
                    duration = int(event.split('_')[1])
                    note = pretty_midi.Note(
                        velocity=current_velocity,
                        pitch=current_note,
                        start=current_time,
                        end=current_time + duration/100
                    )
                    instrument.notes.append(note)
                    current_time += duration/100
                    current_note = None
                    current_velocity = None
            elif event == 'TIME_SHIFT':
                current_time += 0.1 # Fixed time shift

        midi.instruments.append(instrument)
        return midi

tokenizer = MIDITokenizer()
print(f"Vocab size: {tokenizer.vocab_size}")

Vocab size: 7938


In [7]:
torch.save(tokenizer, 'tokenizer.pt')

## Create dataset & dataloader

In [8]:
class MIDIDataset(Dataset):
    def __init__(self, dataset_path, tokenizer, seq_len=512):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.file_paths = []

        # Collect all .mid (.midi) files, skip everything else
        for root, _, files in os.walk(DATA_PATH):
            for file in files:
                if file.endswith('.mid') or file.endswith('.midi'):
                    self.file_paths.append(os.path.join(root, file))

        # Tokenize all collected files
        self.sequences = []
        for path in tqdm(self.file_paths[:150]): # Последние 150
            try:
                tokens = self.tokenizer.encode(path)
                # Split into fixed length sequences
                for i in range(0, len(tokens) - seq_len, seq_len//2):
                    self.sequences.append(tokens[i:i+seq_len])
            except Exception as e:
                # print(e)
                continue


    def __len__(self):
        return len(self.sequences)


    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        # Input sequence and shifted by 1 target sequence
        x = torch.tensor(sequence[:-1], dtype=torch.long)
        y = torch.tensor(sequence[1:], dtype=torch.long)

        return x, y


dataset = MIDIDataset(DATA_PATH, tokenizer, seq_len)
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

0it [00:00, ?it/s]

0





ValueError: num_samples should be a positive integer value, but got num_samples=0

## Positional encoder and Transformer

In [11]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, num_layers, d_ff, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        encoder_layers = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        
        self.fc_out = nn.Linear(d_model, vocab_size)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        
        output = self.transformer(src, src_mask, src_key_padding_mask)
        output = self.fc_out(output)
        return output

model = MusicTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    dropout=dropout
).to('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 11,962,626




## Learning

### Generate MIDI file

In [12]:
def generate_sample(model, tokenizer, temperature=1.0, max_len=512):
    model.eval()
    with torch.no_grad():
        # Random note as starting token
        current_token = torch.tensor([[np.random.randint(0, 128)]], device=model.device)
        generated = [current_token.item()]
        
        for _ in range(max_len):
            # Убираем лишнюю размерность (unsqueeze(0) уже есть в создании тензора)
            output = model(current_token)  # теперь форма [1, 1]
            
            # Получаем последний выход (форма [1, vocab_size])
            output = output[:, -1, :] / temperature
            probs = torch.softmax(output, dim=-1)
            
            next_token = torch.multinomial(probs, num_samples=1)
            generated.append(next_token.item())
            
            if next_token.item() == tokenizer.event2idx['END_OF_TRACK']:
                break
                
            current_token = next_token
        
        # Decode & save as file
        midi = tokenizer.decode(generated)
        output_path = os.path.join(OUTPUT_PATH, f"generated_{temperature}.mid")
        midi.write(output_path)
        print(f"Generated MIDI saved to {output_path}")

generate_sample(model, tokenizer, temperature=0.7)

Generated MIDI saved to generated_midi\generated_0.7.mid


In [49]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    
    for batch_idx, (src, tgt) in enumerate(tqdm(dataloader)):
        src, tgt = src.to(model.device), tgt.to(model.device)
        
        optimizer.zero_grad()
        output = model(src.transpose(0, 1))  # (seq_len, batch, vocab_size)
        
        loss = criterion(output.view(-1, tokenizer.vocab_size), tgt.transpose(0, 1).reshape(-1))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

# Try to use cuda
model.device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_losses = []
for epoch in range(epochs):
    epoch_loss = train_epoch(model, dataloader, optimizer, criterion)
    scheduler.step()
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")
    
    # Saving model epochs
    torch.save(model.state_dict(), f"transformer_epoch_{epoch+1}.pt")
    generate_sample(model, tokenizer, temperature=0.8)

# Loss plot
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

100%|██████████| 265/265 [01:36<00:00,  2.75it/s]


Epoch 1/20, Loss: 2.5032
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [01:26<00:00,  3.06it/s]


Epoch 2/20, Loss: 2.1806
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [02:03<00:00,  2.15it/s]


Epoch 3/20, Loss: 1.9900
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [01:17<00:00,  3.44it/s]


Epoch 4/20, Loss: 1.8635
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [01:24<00:00,  3.12it/s]


Epoch 5/20, Loss: 1.7697
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [01:48<00:00,  2.44it/s]


Epoch 6/20, Loss: 1.7099
Generated MIDI saved to generated_midi\generated_0.8.mid


100%|██████████| 265/265 [02:15<00:00,  1.96it/s]


Epoch 7/20, Loss: 1.6780
Generated MIDI saved to generated_midi\generated_0.8.mid


  9%|▉         | 24/265 [00:13<02:12,  1.82it/s]


KeyboardInterrupt: 

### Params

-- 1 --

d_model = 512

n_heads = 8

num_layers = 6 

d_ff = 2048 

dropout = 0.1


batch_size = 16

seq_len = 512

epochs = 20

learning_rate = 0.0001




## Example of loading and using trained model

In [17]:
def load_model(model_path):
    model = MusicTransformer(
        vocab_size=tokenizer.vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=dropout
    ).to('cuda' if torch.cuda.is_available() else 'cpu')
    
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model

# Example
trained_model = load_model("transformer_epoch_1.pt")
generate_sample(trained_model, tokenizer, temperature=1)

Generated MIDI saved to generated_midi\generated_1.mid
