In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import miditoolkit
from miditoolkit import MidiFile, Instrument, Note
import mido
from tqdm import tqdm

In [2]:
# Adds sinusoidal positional encoding to token embeddings, following Vaswani et al.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

In [45]:
# Transformer-based model for sequence prediction using pitch only
class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=8, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=512, max_len=1024, dropout=0.1):
        super(MusicTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.output = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        src = self.embedding(src) * np.sqrt(self.d_model)
        tgt = self.embedding(tgt) * np.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        output = output.permute(1, 0, 2)
        return self.output(output)


In [4]:
# Returns the sparse categorical cross-entropy loss
# Equivalent to nn.CrossEntropyLoss in PyTorch

def get_loss():
    return nn.CrossEntropyLoss()

def get_optimizer(model):
    return torch.optim.Adam(model.parameters(), lr=1e-4)

# Generate a mask to prevent attention to future tokens
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [46]:
# Dataset returns pitch-only sequences
class MidiPitchDataset(Dataset):
    def __init__(self, midi_files, seq_len=128):
        self.data = []
        self.seq_len = seq_len

        for path in midi_files:
            try:
                midi_obj = miditoolkit.MidiFile(path, clip=True)
                notes = []
                for inst in midi_obj.instruments:
                    if inst.is_drum and inst.notes:
                        notes.extend(inst.notes)
                if len(notes) < seq_len + 1:
                    continue
                notes = sorted(notes, key=lambda x: x.start)
                pitch_seq = [note.pitch for note in notes]
                for i in range(0, len(pitch_seq) - seq_len):
                    src_seq = pitch_seq[i:i+seq_len]
                    tgt_seq = pitch_seq[i+1:i+seq_len+1]
                    self.data.append((src_seq, tgt_seq))
            except Exception as e:
                print(f"Skipping {path.name} due to error: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        src_tensor = torch.tensor(src, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt, dtype=torch.long)
        return src_tensor, tgt_tensor

In [47]:
def train_model(model, dataloader, optimizer, loss_fn, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
        for src, tgt in progress_bar:
            src, tgt = src.to(device), tgt.to(device)
            optimizer.zero_grad()
            tgt_input = tgt[:, :-1]
            tgt_expected = tgt[:, 1:]
            tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            output = model(src, tgt_input, tgt_mask=tgt_mask)
            loss = loss_fn(output.view(-1, output.size(-1)), tgt_expected.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss=total_loss / (progress_bar.n + 1))
        print(f"Epoch {epoch + 1} completed. Avg Loss: {total_loss / len(dataloader):.4f}")

In [60]:
def generate_melody(model, seed_sequence, length=50, device='cpu'):
    model.eval()
    generated = seed_sequence[:]
    valid_drums = [36, 38, 40, 41, 42, 43, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 57, 59]
    for _ in range(length):
        src = torch.tensor([generated], dtype=torch.long).to(device)
        tgt = torch.tensor([generated], dtype=torch.long).to(device)
        if src.shape[1] > 128:
            src = src[:, -128:]
            tgt = tgt[:, -128:]
        tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(device)
        with torch.no_grad():
            output = model(src, tgt, tgt_mask=tgt_mask)
            probs = torch.softmax(output[0, -1], dim=-1)
            raw_pitch = torch.multinomial(probs, 1).item()
            next_pitch = min(valid_drums, key=lambda x: abs(x - raw_pitch))
        generated.append(next_pitch)
    return generated

In [49]:
def save_melody_as_midi(pitch_sequence, file_path="generated_melody.mid"):
    from miditoolkit import MidiFile, Instrument, Note
    midi = MidiFile()
    instrument = Instrument(program=0, is_drum=True)
    time = 0
    for pitch in pitch_sequence:
        note = Note(velocity=100, pitch=int(pitch), start=int(time), end=int(time + 120))
        instrument.notes.append(note)
        time += 120
    midi.instruments.append(instrument)
    midi.dump(file_path)
    print(f"Melody saved to {file_path}")


In [51]:
# ========================
# Data loading and training
# ========================
midi_path = Path("groove")
midi_files = list(midi_path.rglob("*.mid"))
print(f"Found {len(midi_files)} MIDI files.")
random.shuffle(midi_files)
train_files = midi_files[:500]  # Limit to 500 for faster training

# Prepare dataset and dataloader
dataset = MidiPitchDataset(train_files)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize model and training setup
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = MusicTransformer(vocab_size=128).to(device)
optimizer = get_optimizer(model)
loss_fn = get_loss()

Found 1150 MIDI files.


In [52]:
# Train the model
train_model(model, dataloader, optimizer, loss_fn, device)

Epoch 1: 100%|███████████████████| 4863/4863 [05:12<00:00, 15.58it/s, loss=1.54]


Epoch 1 completed. Avg Loss: 1.5347


Epoch 2: 100%|███████████████████| 4863/4863 [05:05<00:00, 15.90it/s, loss=1.39]


Epoch 2 completed. Avg Loss: 1.3891


Epoch 3: 100%|███████████████████| 4863/4863 [05:06<00:00, 15.88it/s, loss=1.34]


Epoch 3 completed. Avg Loss: 1.3388


Epoch 4: 100%|████████████████████| 4863/4863 [05:08<00:00, 15.79it/s, loss=1.3]


Epoch 4 completed. Avg Loss: 1.2983


Epoch 5: 100%|███████████████████| 4863/4863 [04:55<00:00, 16.43it/s, loss=1.27]


Epoch 5 completed. Avg Loss: 1.2649


Epoch 6: 100%|███████████████████| 4863/4863 [04:52<00:00, 16.63it/s, loss=1.24]


Epoch 6 completed. Avg Loss: 1.2361


Epoch 7: 100%|███████████████████| 4863/4863 [04:53<00:00, 16.58it/s, loss=1.21]


Epoch 7 completed. Avg Loss: 1.2091


Epoch 8: 100%|███████████████████| 4863/4863 [04:53<00:00, 16.59it/s, loss=1.18]


Epoch 8 completed. Avg Loss: 1.1830


Epoch 9: 100%|███████████████████| 4863/4863 [04:54<00:00, 16.51it/s, loss=1.16]


Epoch 9 completed. Avg Loss: 1.1584


Epoch 10: 100%|██████████████████| 4863/4863 [04:55<00:00, 16.46it/s, loss=1.14]

Epoch 10 completed. Avg Loss: 1.1355





In [53]:
# Save model to disk
model_path = "music_transformer.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to music_transformer.pth


In [71]:
# Get a sample directly from the dataset
seed_src, _ = dataset[6050]  # Take the first source-target pair
seed = seed_src[:32].tolist()  # Convert to list of ints
generated = generate_melody(model, seed, length=64, device=device)
save_melody_as_midi(generated, file_path="generated_output.mid")

Melody saved to generated_output.mid
