In [None]:
import glob
import os
from collections import defaultdict

import numpy as np
import pretty_midi
import torch

# Corpus location (input) and where the tokenized sequences are written (output).
midi_path = "./Jazz-Midi"
output_path = './tokenized_jazz'
os.makedirs(output_path, exist_ok=True)

# Tokenizer limits: the largest value a single time_shift token may carry,
# and the upper bound applied to note velocities.
max_tick = 1024
max_velocity = 127

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def extract_notes(midi_file):
    """Extract the pitched notes of a MIDI file, sorted by onset.

    Drum instruments are skipped. Each note becomes a dict with keys
    ``pitch``, ``velocity``, ``instrument`` (the track's program number) and
    ``start``/``end`` as absolute MIDI ticks at the file's own resolution.

    Args:
        midi_file: Path to a MIDI file readable by pretty_midi.

    Returns:
        A list of note dicts sorted by (start, pitch). The list is empty if
        the file has no pitched notes. Returns None if loading or extraction
        raised; the error is printed so a batch run over many files can keep
        going.
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(midi_file)

        notes = []
        for instrument in midi_data.instruments:
            if instrument.is_drum:
                continue
            for note in instrument.notes:
                notes.append({
                    'pitch': note.pitch,
                    # pretty_midi stores note times in seconds. time_to_tick maps
                    # them back to ticks through the file's tempo map. Multiplying
                    # seconds by `resolution` (ticks per beat) ignores tempo and
                    # does not yield ticks.
                    'start': int(midi_data.time_to_tick(note.start)),
                    'end': int(midi_data.time_to_tick(note.end)),
                    'velocity': note.velocity,
                    'instrument': instrument.program,
                })

        notes.sort(key=lambda n: (n['start'], n['pitch']))
        return notes
    except Exception as e:
        print(f"Error processing {midi_file}: {e}")
        return None

def tokenize_midi_file(midi_file, file_id):
    """Convert a MIDI file to a flat sequence of event-token dicts.

    Notes come from ``extract_notes`` (sorted by onset). For each note this
    emits, in order:

    * ``time_shift`` tokens (``{'type': 'time_shift', 'value': d}`` with
      ``0 < d <= max_tick``). Their values sum to the gap since the previous
      onset. None are emitted when the note starts with the previous one.
    * a ``note_on`` token with ``pitch``, ``velocity`` (capped at
      ``max_velocity``) and ``instrument``;
    * a ``note_off`` token with ``pitch`` and the note's absolute ``start``
      and ``end``.

    Args:
        midi_file: Path to the MIDI file.
        file_id: Identifier returned unchanged alongside the tokens.

    Returns:
        ``(tokens, file_id)``, or None if the file yielded no notes or failed
        to load.
    """
    notes = extract_notes(midi_file)
    if not notes:
        return None

    tokens = []
    prev_start = 0

    for note in notes:
        # Split gaps longer than max_tick across several time_shift tokens.
        # Clamping a single shift to max_tick would drop the excess time and
        # move every later note earlier.
        gap = note['start'] - prev_start
        while gap > 0:
            step = min(gap, max_tick)
            tokens.append({'type': 'time_shift', 'value': step})
            gap -= step
        prev_start = max(prev_start, note['start'])

        tokens.append({
            'type': 'note_on',
            'pitch': note['pitch'],
            'velocity': min(note['velocity'], max_velocity),
            'instrument': note['instrument'],
        })
        # NOTE(review): note_off is emitted immediately after its note_on and
        # carries absolute start/end times. It is not placed at the note's end
        # within the time_shift stream, so the relative-time encoding itself
        # carries no duration information. Consider scheduling note_off events
        # at note['end'].
        tokens.append({
            'type': 'note_off',
            'pitch': note['pitch'],
            'start': note['start'],
            'end': note['end'],
        })

    return tokens, file_id

# Collect every MIDI file under midi_path. glob returns paths in arbitrary
# order, so sort them: file_id below is the list index and must be
# reproducible across runs and machines.
midi_files = sorted(
    glob.glob(os.path.join(midi_path, '**', '*.mid'), recursive=True)
    + glob.glob(os.path.join(midi_path, '**', '*.midi'), recursive=True)
)

all_sequences = {}  # file_id -> list of token dicts

for i, midi_file in enumerate(midi_files):
    print(f"Processing {i+1}/{len(midi_files)}: {midi_file}")
    result = tokenize_midi_file(midi_file, i)
    if result is not None:
        tokens, file_id = result
        all_sequences[file_id] = tokens

# np.save wraps the dict in a 0-d object array, which is pickled. Read it back
# with np.load(path, allow_pickle=True).item(), and only for files you trust.
np.save(os.path.join(output_path, 'tokenized_jazz.npy'), all_sequences)

Processing 1/934: ./Jazz-Midi\2ndMovementOfSinisterFootwear.mid
Processing 2/934: ./Jazz-Midi\55Dive.mid
Processing 3/934: ./Jazz-Midi\5To10.mid
Processing 4/934: ./Jazz-Midi\634-5789.mid
Processing 5/934: ./Jazz-Midi\914.mid
Processing 6/934: ./Jazz-Midi\ABC.mid
Processing 7/934: ./Jazz-Midi\ACertainSmile.mid
Processing 8/934: ./Jazz-Midi\ACrush.mid
Processing 9/934: ./Jazz-Midi\Adayinalifeofafool.mid
Processing 10/934: ./Jazz-Midi\Adventage.mid
Processing 11/934: ./Jazz-Midi\AffairInSanMiguel.mid
Error processing ./Jazz-Midi\AffairInSanMiguel.mid: data byte must be in range 0..127
Processing 12/934: ./Jazz-Midi\AFifthofBeethoven.mid
Processing 13/934: ./Jazz-Midi\Afoggydayilondontown.mid
Processing 14/934: ./Jazz-Midi\AfterTheLoveHasGone.mid
Processing 15/934: ./Jazz-Midi\AfterTheRainHasFallen.mid
Processing 16/934: ./Jazz-Midi\AgeOfAquarius.mid
Processing 17/934: ./Jazz-Midi\Aintchagottired.mid
Processing 18/934: ./Jazz-Midi\AintNothingLikeRealThingBaby.mid
Processing 19/934: ./Jazz



Processing 31/934: ./Jazz-Midi\AllIEverNeedIsYou.mid
Processing 32/934: ./Jazz-Midi\AllMyLife.mid
Processing 33/934: ./Jazz-Midi\AllMyTomorrows.mid
Processing 34/934: ./Jazz-Midi\AllOfMe.mid
Processing 35/934: ./Jazz-Midi\AllofYou.mid
Processing 36/934: ./Jazz-Midi\AllOrNothingAtAll.mid
Processing 37/934: ./Jazz-Midi\AllOverYou.mid
Processing 38/934: ./Jazz-Midi\AllTheThingsYouAre.mid
Processing 39/934: ./Jazz-Midi\AllTheWay.mid
Processing 40/934: ./Jazz-Midi\AllThisTime.mid
Processing 41/934: ./Jazz-Midi\AllThumb.mid
Processing 42/934: ./Jazz-Midi\Alone.mid
Processing 43/934: ./Jazz-Midi\AloneTogether.mid
Processing 44/934: ./Jazz-Midi\Alphonso.mid
Processing 45/934: ./Jazz-Midi\Amazinggrace.mid
Processing 46/934: ./Jazz-Midi\Andy.mid
Processing 47/934: ./Jazz-Midi\Angela.mid
Processing 48/934: ./Jazz-Midi\AngelEyes.mid
Processing 49/934: ./Jazz-Midi\AnimateInanimate.mid
Processing 50/934: ./Jazz-Midi\AnotherDay.mid
Processing 51/934: ./Jazz-Midi\AnotherDimension.mid
Processing 52/934

In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table is registered as the buffer ``pe`` with shape
    (1, max_len, d_model). Being a buffer, it follows ``.to(device)`` and is
    saved in ``state_dict``, but it is not trained.

    Args:
        d_model: Embedding width. Odd widths are supported; the last sine
            column then has no cosine partner.
        max_len: Longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_len=2048):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed on dim 1 as the sequence axis, with last dim d_model.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

class JazzTransformer(nn.Module):
    """Decoder-only (causal) Transformer language model over integer token ids.

    Args:
        vocab_size: Number of token ids; inputs must lie in [0, vocab_size).
        d_model, nhead, num_layers, dim_feedforward, dropout: Transformer
            hyper-parameters, passed to each attention layer.

    ``forward`` takes a LongTensor of shape (batch, seq_len) and returns
    logits of shape (batch, seq_len, vocab_size). Position i only attends to
    positions <= i.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # A decoder-only LM has no encoder output to cross-attend to.
        # nn.TransformerDecoder requires one (it was called with memory=None,
        # which crashes), so the stack is built from self-attention-only layers
        # and causality comes from the mask in forward(). The attribute keeps
        # the name `decoder` so existing references still work.
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.decoder = nn.TransformerEncoder(layer, num_layers)
        self.output = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) float mask: 0 on and below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self, tgt):
        x = self.pos_encoder(self.embedding(tgt))
        causal_mask = self.generate_square_subsequent_mask(x.size(1), device=x.device)
        out = self.decoder(x, mask=causal_mask)
        return self.output(out)


In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim

class MIDIDataset(torch.utils.data.Dataset):
    """Next-token-prediction windows over one stream of integer token ids.

    Item ``idx`` is ``(x, y)``: ``x = data[idx : idx + seq_len]``, and ``y``
    is the same window shifted one position ahead. Both are LongTensors.

    Args:
        file_path: ``.npy`` file holding the token stream. It is loaded with
            np.load defaults, so pickled object arrays are refused.
        seq_len: Window length.

    NOTE(review): the tokenization cell saves a pickled dict of token-dict
    lists, which this loader cannot read. A token -> integer-id encoding step
    is still needed between the two.
    """

    def __init__(self, file_path, seq_len=128):
        self.data = np.load(file_path)
        self.seq_len = seq_len

    def __len__(self):
        # A stream with <= seq_len tokens has no complete (x, y) pair. Clamp at
        # zero, because a negative __len__ makes len() raise ValueError.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        # y is x advanced by one token: the target at each position is the next id.
        x = self.data[idx:idx + self.seq_len]
        y = self.data[idx + 1:idx + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

# The dataset is built in the training cell, next to the model, rather than here.
# Loading it at this point would abort the cell before `train_model` below is
# defined whenever the data file is missing.

def train_model(model, dataset, vocab_size, epochs=10, batch_size=32, lr=1e-4, device=None):
    """Train ``model`` on next-token prediction with Adam and cross-entropy.

    Args:
        model: Module whose output's last dim is ``vocab_size`` logits for each
            input position.
        dataset: Map-style dataset yielding ``(x, y)`` LongTensor pairs; ``y``
            holds the target id for every position of ``x``.
        vocab_size: Width of the logit dimension.
        epochs, batch_size, lr: Optimisation settings.
        device: Where batches are moved. Defaults to the device of the model's
            parameters, so data and weights always agree.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; nothing to train on")
    if device is None:
        device = next(model.parameters()).device

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            # reshape rather than view: also works when the logits are non-contiguous.
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        mean_loss = total_loss / len(dataloader)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")
    return epoch_losses


In [None]:
def inference(model, start_seq, length=100, temperature=1.0):
    """Autoregressively extend ``start_seq`` by ``length`` tokens.

    Args:
        model: Causal LM returning logits shaped (batch, seq_len, vocab). It
            must expose ``model.pos_encoder.pe``, whose dim 1 is used as the
            longest context fed to the model.
        start_seq: Integer tensor prompt of shape (batch, prompt_len), or
            (prompt_len,), which is treated as a batch of one.
        length: Number of tokens to append.
        temperature: Softmax temperature for sampling. Values <= 0 pick the
            arg-max token (greedy decoding) instead of dividing by zero.

    Returns:
        Prompt plus generated ids: a flat list when batch == 1, otherwise a
        list of per-sequence lists.
    """
    model.eval()
    model_device = next(model.parameters()).device
    generated = start_seq.clone().detach().to(model_device)
    if generated.dim() == 1:
        generated = generated.unsqueeze(0)
    # Crop the context to what the positional table covers.
    context_len = model.pos_encoder.pe.size(1)

    with torch.no_grad():
        for _ in range(length):
            next_logits = model(generated[:, -context_len:])[:, -1, :]
            if temperature <= 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

    # squeeze(0) only: a bare squeeze() would turn a single-token result into a scalar.
    if generated.size(0) == 1:
        return generated.squeeze(0).tolist()
    return generated.tolist()


In [None]:
# Training
# NOTE(review): this loads './tokenized_data.npy', but the tokenization cell
# writes './tokenized_jazz/tokenized_jazz.npy', a pickled dict of token-dict
# lists. MIDIDataset expects a flat array of integer token ids. A token -> id
# encoding step, producing ids in [0, VOCAB_SIZE), is still needed to create
# this file.
SEED = 0
VOCAB_SIZE = 10000  # shared by the model's output layer and the loss; keep a single source
SEQ_LEN = 128

torch.manual_seed(SEED)
model = JazzTransformer(vocab_size=VOCAB_SIZE)
model.to(device)
dataset = MIDIDataset('./tokenized_data.npy', seq_len=SEQ_LEN)
epoch_losses = train_model(model, dataset, vocab_size=VOCAB_SIZE, epochs=10, batch_size=32, lr=1e-4)



