In [1]:
!pip install muspy torch transformers datasets


Defaulting to user installation because normal site-packages is not writeable
Collecting muspy
  Downloading muspy-0.5.0-py3-none-any.whl.metadata (5.5 kB)
Collecting bidict>=0.21 (from muspy)
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting pretty-midi>=0.2 (from muspy)
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m102.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting pypianoroll>=1.0 (from muspy)
  Downloading pypianoroll-1.0.4-py3-none-any.whl.metadata (3.8 kB)
Downloading muspy-0.5.0-py3-none-any.whl (119 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bidict-0.23.1-py3-none-any.whl (32 kB)
Downloading pypianoroll-1.0.4-py3-none-any.whl (26 kB)
Building wheels for collected packages: pretty-midi
  Building whe

In [2]:
import random
from collections import defaultdict

import muspy
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Download and load MetaMIDI
# NOTE(review): the recorded output of this cell is
# "AttributeError: module 'muspy' has no attribute 'META_MIDI'", so this line
# fails and nothing below it in the cell has actually been executed.
# TODO: replace with a dataset class that the installed MusPy really provides.
dataset = muspy.META_MIDI(root="data/", download=True)

# Filter only those with genre metadata and not too long
# NOTE(review): assumes each item exposes `metadata.genre` — confirm that the
# dataset's metadata objects actually carry a genre field.
filtered = [song for song in dataset if song.metadata and song.metadata.genre]
# NOTE(review): assumes `get_note_sequence().get_end_time()` exists and returns
# seconds (the 60 threshold is meant as "one minute") — verify the method names
# and the time unit against the MusPy API.
filtered = [s for s in filtered if s.get_note_sequence().get_end_time() < 60]  # intended: shorter than 1 min

# Convert genre names to tokens
# Sorting the distinct genre names makes the name -> index mapping deterministic.
genre_set = sorted(set(s.metadata.genre for s in filtered))
# genre2idx maps each genre name to its position in the sorted list (0-based).
genre2idx = {g: i for i, g in enumerate(genre_set)}

def encode_piece(piece):
    """Encode one piece as a genre id plus a flat list of event tokens.

    Args:
        piece: A MusPy music object whose ``metadata.genre`` is a key of the
            notebook-level ``genre2idx`` mapping.

    Returns:
        dict: ``{"genre": int, "tokens": list[int]}`` where ``tokens`` is the
        event-based encoding of ``piece`` (velocity events included).

    Raises:
        KeyError: If the piece's genre is not present in ``genre2idx``.
    """
    genre = piece.metadata.genre
    # `to_event_representation` has no `encode_program` argument; passing it
    # raises TypeError, so only the supported `encode_velocity` flag is used.
    events = muspy.to_event_representation(piece, encode_velocity=True)
    return {
        "genre": genre2idx[genre],
        # The encoder returns an (M, 1) integer array. Downstream code slices,
        # concatenates and takes max() over the tokens as a plain sequence, so
        # store a flat list of Python ints.
        "tokens": events.flatten().tolist(),
    }

def _has_enough_events(piece, minimum=10):
    """Return True when the piece's default event encoding has more than `minimum` events."""
    return len(muspy.to_event_representation(piece)) > minimum


# Encode every sufficiently long piece; shorter ones are dropped.
encoded = [encode_piece(piece) for piece in filtered if _has_enough_events(piece)]


AttributeError: module 'muspy' has no attribute 'META_MIDI'

In [None]:
class MusicDataset(Dataset):
    """Next-token-prediction dataset over genre-conditioned event sequences.

    Each element of ``data`` is a dict with:
      * ``"genre"``  - integer genre id, used as the first input token;
      * ``"tokens"`` - event token ids, as a list of ints or an integer
        array of any shape (it is flattened).

    Args:
        data: Sequence of encoded pieces as described above.
        max_len: Upper bound on the sequence length; at most ``max_len - 1``
            event tokens are kept per piece.
    """

    def __init__(self, data, max_len=512):
        self.data = data
        self.max_len = max_len

    def __len__(self):
        """Number of encoded pieces."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of equal length for piece ``idx``.

        ``y`` is the (truncated) event sequence and ``x`` is the same sequence
        shifted right by one with the genre id in front, so ``x[t]`` is the
        input from which ``y[t]`` is predicted.
        """
        item = self.data[idx]
        # Normalise to a flat 1-D long tensor. With a NumPy array (e.g. an
        # (M, 1) event array) the expression `[genre] + tokens` would perform
        # elementwise addition instead of concatenation.
        seq = torch.tensor(item["tokens"], dtype=torch.long).reshape(-1)
        tokens = seq[: self.max_len - 1]
        genre = torch.tensor([item["genre"]], dtype=torch.long)
        # Prepend the genre and drop the last token; slicing by the target
        # length keeps x and y aligned even when the piece has no tokens.
        x = torch.cat([genre, tokens])[: tokens.numel()]
        y = tokens
        return x, y

def pad_collate(batch, pad_id=0, ignore_index=-100):
    """Collate variable-length ``(x, y)`` pairs into two padded batch tensors.

    The default DataLoader collation stacks samples and therefore fails when
    the sequences in a batch have different lengths.

    Args:
        batch: List of ``(x, y)`` 1-D tensor pairs from the dataset.
        pad_id: Value used to pad the input sequences.
        ignore_index: Value used to pad the targets. -100 is the default
            ``ignore_index`` of ``nn.CrossEntropyLoss``, so padded positions
            do not contribute to the loss.

    Returns:
        Tuple ``(x, y)`` of tensors shaped (batch, longest sequence in batch).
    """
    inputs, targets = zip(*batch)
    x = torch.nn.utils.rnn.pad_sequence(list(inputs), batch_first=True, padding_value=pad_id)
    y = torch.nn.utils.rnn.pad_sequence(list(targets), batch_first=True, padding_value=ignore_index)
    return x, y


train_data = MusicDataset(encoded)
train_loader = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    collate_fn=pad_collate,  # pieces differ in length, so batches must be padded
)


In [None]:
import torch.nn as nn

class MusicTransformer(nn.Module):
    """Causal Transformer language model over integer tokens.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
        max_len: Longest sequence the learned positional table supports.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, max_len=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead), num_layers
        )
        # Learned absolute positional embeddings, one row per position.
        self.pos_enc = nn.Parameter(torch.randn(max_len, d_model))
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Long tensor of token ids shaped (batch, seq).

        Returns:
            Float tensor of logits shaped (batch, seq, vocab_size).

        Raises:
            ValueError: If ``seq`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_enc.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the model's maximum of {self.pos_enc.size(0)}"
            )
        h = self.embed(x) + self.pos_enc[:seq_len]
        h = h.transpose(0, 1)  # the decoder layers expect (seq, batch, d_model)
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        causal = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=h.device), diagonal=1
        )
        # The sequence is also used as the decoder "memory". Without a
        # memory_mask the cross-attention would see every position, including
        # future tokens, so the same causal mask is applied there too.
        h = self.transformer(h, h, tgt_mask=causal, memory_mask=causal)
        h = h.transpose(0, 1)  # back to (batch, seq, d_model)
        return self.out(h)


In [None]:
# Train on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Largest token id found in any encoded piece, plus one.
vocab_size = 1 + max(max(entry["tokens"]) for entry in encoded)

# The model is sized for the event vocabulary plus one slot per genre.
model = MusicTransformer(vocab_size + len(genre2idx))
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        pred = model(x)
        # Merge the batch and time axes so every position is one classification example.
        flat_pred = pred.view(-1, pred.size(-1))
        flat_target = y.view(-1)
        loss = criterion(flat_pred, flat_target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
    # Reported value is the sum of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.2f}")


In [None]:
def generate(genre_idx, model, max_len=300):
    """Greedily generate ``max_len`` tokens conditioned on a genre id.

    Args:
        genre_idx: Integer genre id used as the first (conditioning) token.
        model: Module mapping a (1, seq) long tensor to (1, seq, vocab) logits.
            It is switched to eval mode.
        max_len: Number of tokens to generate.

    Returns:
        list[int]: The generated token ids, without the leading genre token.
    """
    model.eval()
    # Run on the model's own device rather than a notebook-level `device` global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    # A learned positional table limits how long an input the model accepts.
    pos_table = getattr(model, "pos_enc", None)
    max_context = pos_table.size(0) if isinstance(pos_table, torch.Tensor) else None

    tokens = [genre_idx]
    with torch.no_grad():
        for _ in range(max_len):
            context = tokens
            if max_context is not None and len(tokens) > max_context:
                # Keep the genre token first and fill the rest of the window
                # with the most recent tokens.
                context = [genre_idx] + tokens[len(tokens) - (max_context - 1):]
            inp = torch.tensor([context], dtype=torch.long, device=run_device)
            out = model(inp)
            # Greedy decoding: take the highest-scoring token at the last position.
            next_token = out[0, -1].argmax().item()
            tokens.append(next_token)
    return tokens[1:]  # strip genre token

# Example: Generate Jazz
if "Jazz" not in genre2idx:
    # Fail with a readable message instead of a bare KeyError.
    raise KeyError(f"Genre 'Jazz' is not in the training data; available genres: {sorted(genre2idx)}")
jazz_idx = genre2idx["Jazz"]
tokens = generate(jazz_idx, model)

# Convert back to MIDI
# MusPy has no `EventSequence` class; `from_event_representation` takes an
# integer NumPy array of event ids (one event per row) and returns a Music object.
event_seq = np.asarray(tokens, dtype=int).reshape(-1, 1)
midi = muspy.from_event_representation(event_seq)
midi.write("generated_jazz.mid")
