## Transformer

In this section, we explore the use of a Transformer-based model for the task of jazz piano music generation. We train the model to predict the next musical event — a note or chord together with its duration — based on a fixed-length context of previous events. Once trained, the model can be used to generate new musical sequences by iteratively predicting events starting from an initial seed.



In [141]:
import ast
import random

import numpy as np
from numpy import e
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import KFold
import pretty_midi
from music21 import converter, note, chord
from tqdm import tqdm

root = 'data_processed/'

### Import Dataset

In [3]:
def safe_parse_all_columns_df(df):
    """
    Turn the stringified list columns of a song DataFrame back into Python objects.

    The columns 'notes', 'chords', 'velocities', 'durations', 'offsets' and
    'ordered_events' are processed in that order. Every cell that is a string is
    evaluated with ast.literal_eval; cells of any other type are left untouched.

    The DataFrame is modified in place and is also returned.
    """
    def _literal_or_passthrough(cell):
        # Only strings need evaluating; already-parsed values pass through.
        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    list_columns = ('notes', 'chords', 'velocities', 'durations', 'offsets', 'ordered_events')
    for column in list_columns:
        df[column] = df[column].apply(_literal_or_passthrough)
    return df

def load_dataframe_from_two_csvs(file1, file2):
    """
    Read two CSV files, stack them into one DataFrame and parse its list columns.

    The rows of file1 come first, followed by the rows of file2; the index of the
    result is renumbered from 0. The stacked frame is then passed through
    safe_parse_all_columns_df before being returned.
    """
    parts = [pd.read_csv(path) for path in (file1, file2)]
    combined = pd.concat(parts, ignore_index=True)
    return safe_parse_all_columns_df(combined)

def save_dataframe_to_two_csvs(df, file1, file2):
    """
    Write the first half of a DataFrame to file1 and the remainder to file2.

    The split point is len(df) // 2, so with an odd number of rows the second
    file receives the extra row. The index is not written to either file.
    """
    split_at = len(df) // 2
    first_part, second_part = df.iloc[:split_at], df.iloc[split_at:]
    first_part.to_csv(file1, index=False)
    second_part.to_csv(file2, index=False)

def load_dataframe_from_one_csv(file):
    """
    Read a single CSV file and return its contents as a DataFrame.

    No column parsing is applied beyond what pandas.read_csv does by default.
    """
    return pd.read_csv(file)

def save_dataframe_to_one_csv(df, file):
    """
    Write a DataFrame to one CSV file, including its index as the first column.
    """
    # The index is written on purpose, so it comes back as an ordinary column on reload.
    write_index = True
    df.to_csv(file, index=write_index)

def load_reconstructed_events(file):
    """
    Read the reconstructed-events CSV and parse its 'sequence' column.

    Each cell of 'sequence' is evaluated with ast.literal_eval and must yield a
    list. Elements that are lists (chords) become lists of ints; every other
    element (a single note) is converted with int().

    If a cell cannot be parsed, the offending string is printed and the
    original exception is re-raised.
    """
    def safe_parse(seq_str):
        try:
            events = ast.literal_eval(seq_str)
            if not isinstance(events, list):
                raise ValueError("Parsed sequence is not a list")

            return [
                [int(pitch) for pitch in item] if isinstance(item, list) else int(item)
                for item in events
            ]

        except Exception as err:
            # Show which raw cell failed before propagating the error.
            print(f"Error parsing sequence: {seq_str}")
            raise err

    frame = pd.read_csv(file)
    frame['sequence'] = frame['sequence'].apply(safe_parse)
    return frame

In [4]:
# The processed dataset is stored as two CSV parts under `root`.
file1 = root + 'data_part1.csv'
file2 = root + 'data_part2.csv'

# Concatenate both parts and parse the stringified list columns back into lists.
df = load_dataframe_from_two_csvs(file1, file2)

### Useful functions

In [10]:
def parse_chord_to_list(chord):
    """
    Convert a comma-separated chord string into a list of integers.

    Args:
        chord: A string such as "60,64,67". Whitespace around each part is
            ignored, so "60, 64" is accepted as well.

    Returns:
        The integer value of every part that consists only of digits, in their
        original order. Parts that are not purely digits are skipped. An empty
        list is returned when `chord` is not a string.
    """
    if not isinstance(chord, str):
        return []

    # Strip before testing: ' 64'.isdigit() is False and would drop the pitch.
    parts = (part.strip() for part in chord.split(','))
    return [int(part) for part in parts if part.isdigit()]

In [None]:
def reconstruct_ordered_events(df):
    """
    Reconstruct the ordered list of events (notes and chords) for each song.

    For every row, 'ordered_events' is read as a list of markers: 'n' takes the
    next unused entry of that row's 'notes', 'c' takes the next unused entry of
    that row's 'chords'. The entries are appended unchanged, in marker order.

    Args:
        df: DataFrame with the columns 'notes', 'chords' and 'ordered_events',
            each holding one list per row.

    Returns:
        A DataFrame with a single 'sequence' column (one reconstructed list per
        input row) whose index is named 'index'.

    Raises:
        ValueError: If a marker is neither 'n' nor 'c'.
        IndexError: If a row has more 'n'/'c' markers than notes/chords.
    """
    sequences = []

    # Iterate positionally so the function does not depend on df's index labels.
    for notes, chords, markers in zip(df['notes'], df['chords'], df['ordered_events']):
        idx_note = 0
        idx_chord = 0
        reconstructed = []

        for element in markers:
            if element == 'n':
                reconstructed.append(notes[idx_note])
                idx_note += 1
            elif element == 'c':
                reconstructed.append(chords[idx_chord])
                idx_chord += 1
            else:
                # Report the offending marker itself.
                raise ValueError(f"Unknown event type: {element}")

        sequences.append(reconstructed)

    reconstructed_dataset = pd.DataFrame({'sequence': sequences})
    reconstructed_dataset.index.name = 'index'

    return reconstructed_dataset

In [12]:
# Rebuild the interleaved note/chord sequence of every song and cache it on disk.
save_dataframe_to_one_csv(reconstruct_ordered_events(df), root + 'reconstructed_ordered_events.csv')

In [13]:
# Reload the cached sequences; the 'sequence' column is parsed back into lists of ints / int lists.
reconstructed_dataset = load_reconstructed_events(root + 'reconstructed_ordered_events.csv')

In [14]:
# Work on a copy so `reconstructed_dataset` itself is left unchanged.
ordered_events_with_durations = reconstructed_dataset.copy()
# Column assignment aligns on the index, so each row receives the durations of the df row with the same label.
ordered_events_with_durations['durations'] = df['durations']

In [15]:
# Preview the combined frame (event sequence plus durations per song).
ordered_events_with_durations

Unnamed: 0,index,sequence,durations
0,0,"[88, 38, 38, 45, 45, 45, [44, 45], 45, 45, 38,...","[0.25, 2.5, 0.25, 0.3333, 0.3333, 0.25, 0.3333..."
1,1,"[[56, 65, 68, 80, 60], [54, 61, 65, 75], 77, 7...","[1.25, 1.6667, 0.25, 1.0, 1.0, 0.5, 0.6667, 0...."
2,2,"[46, 70, [53, 60, 62], [65, 46, 53, 60], 70, [...","[2.6667, 0.25, 1.25, 0.75, 0.75, 1.6667, 1.75,..."
3,3,"[[52, 28, 40], 64, [65, 68], [53, 55, 56], [58...","[3.75, 2.5, 1.0, 2.0, 1.6667, 1.3333, 0.75, 0...."
4,4,"[59, 59, [60, 62, 38, 54], 62, [43, 62, 50, 52...","[0.75, 0.6667, 0.6667, 0.5, 1.0, 0.25, 2.0, 0...."
...,...,...,...
2770,2770,"[61, 49, [63, 60], 65, 70, [58, 61, 66, 49], [...","[1.3333, 1.6667, 1.0, 0.5, 0.5, 0.5, 0.6667, 0..."
2771,2771,"[[45, 55, 67, 69, 72, 75], [46, 74, 56, 62, 68...","[1.0, 1.0, 0.25, 1.25, 0.3333, 0.25, 0.25, 0.7..."
2772,2772,"[[45, 57], 71, 52, 61, 64, 45, 69, [45, 69], 4...","[1.6667, 0.6667, 3.3333, 2.75, 2.25, 0.75, 0.5..."
2773,2773,"[48, 41, 60, [67, 68, 72, 79], 66, [65, 68, 72...","[0.3333, 0.3333, 0.25, 0.25, 0.3333, 0.3333, 0..."


In [16]:
# Sanity check: every event needs exactly one duration, because the tokenizer
# pairs them with zip() and would silently truncate the longer list.
mismatches = ordered_events_with_durations.apply(
    lambda row: len(row['sequence']) != len(row['durations']),
    axis=1
)

# Rows where the two lists disagree in length.
invalid_rows = ordered_events_with_durations[mismatches]


if not invalid_rows.empty:
    print("Rows with mismatched 'sequence' and 'durations':")
    print(invalid_rows)
else:
    print("All rows have matching 'sequence' and 'durations' lengths.")

All rows have matching 'sequence' and 'durations' lengths.


In [None]:
# Persist the validated (sequence, durations) table for later use.
save_dataframe_to_one_csv((ordered_events_with_durations), root + 'ordered_events_with_durations.csv')

## Tokenization Strategy

Before training, each music track is represented in a structured format as a row in a DataFrame. Each row contains two key sequences:
- `sequence`: a list of events, where each event is either a single note (as an integer) or a chord (as a list of integers).
- `durations`: a list of float values representing how long each event is held.

To convert these symbolic sequences into a format suitable for Transformer training, we use a custom `TokenVocabulary` class. This class encodes each pair of `(event, duration)` into two tokens:
- A **note/chord token**: formatted as `NOTE_<pitch>` for single notes (e.g., `NOTE_60`) or `CHORD_<pitch1>_<pitch2>_...` for chords (e.g., `CHORD_60_64_67`).
- A **duration token**: formatted as `DUR_<duration>` (e.g., `DUR_0.500`).

The `encode_sequence_and_durations` method processes each song by alternating between these two types of tokens, producing sequences like:

`[NOTE_60, DUR_0.500, NOTE_62, DUR_0.250, CHORD_60_64_67, DUR_1.000, ...]`

In [18]:
class TokenVocabulary:
    """
    Bidirectional mapping between string tokens and integer ids.

    Three token families exist: 'NOTE_<pitch>' for single notes,
    'CHORD_<p1>_<p2>_...' for chords (pitches sorted ascending) and
    'DUR_<duration>' with the duration formatted to three decimals.
    Ids are the positions of the tokens in the lexicographically sorted
    token list.
    """

    def __init__(self, df):
        """
        Build a vocabulary of all unique tokens (note, chord, duration).

        `df` must provide the columns 'sequence' and 'durations'; the two
        lists of each row are walked in parallel.
        """
        seen = set()

        for events, durs in zip(df['sequence'], df['durations']):
            for event, dur in zip(events, durs):
                seen.add(self._event_token(event))
                seen.add(self._duration_token(dur))

        self.tokens = sorted(seen)
        self.token_to_idx = {tok: idx for idx, tok in enumerate(self.tokens)}
        self.idx_to_token = dict(enumerate(self.tokens))
        self.vocab_size = len(self.tokens)

    @staticmethod
    def _event_token(event):
        """Return the NOTE_/CHORD_ token string for one event."""
        if isinstance(event, list):
            return 'CHORD_' + '_'.join(map(str, sorted(event)))
        return f'NOTE_{event}'

    @staticmethod
    def _duration_token(dur):
        """Return the DUR_ token string for one duration."""
        return f'DUR_{dur:.3f}'

    @staticmethod
    def _parse_event_token(note_token):
        """Turn a NOTE_/CHORD_ token string back into an int or a list of ints."""
        if note_token.startswith('NOTE_'):
            return int(note_token.replace('NOTE_', ''))
        if note_token.startswith('CHORD_'):
            return list(map(int, note_token.replace('CHORD_', '').split('_')))
        raise ValueError(f"Invalid note token: {note_token}")

    @staticmethod
    def _parse_duration_token(dur_token):
        """Turn a DUR_ token string back into a float."""
        if dur_token.startswith('DUR_'):
            return float(dur_token.replace('DUR_', ''))
        raise ValueError(f"Invalid duration token: {dur_token}")

    def encode(self, token):
        """Return the id of a token string (KeyError if the token is unknown)."""
        return self.token_to_idx[token]

    def decode(self, index):
        """Return the token string stored under an id (KeyError if unknown)."""
        return self.idx_to_token[index]

    def encode_sequence_and_durations(self, sequence, durations):
        """
        Given a sequence of events and durations, returns a list of token IDs
        in alternating [note_token, dur_token, note_token, dur_token, ...] format.
        """
        token_strings = []
        for event, dur in zip(sequence, durations):
            token_strings.append(self._event_token(event))
            token_strings.append(self._duration_token(dur))
        return [self.encode(tok) for tok in token_strings]

    def decode_sequence_and_durations(self, token_ids):
        """
        Converts a flat list of token IDs (note/dur alternating) into
        a list of (event, duration) pairs.

        Returns:
            List of tuples: (event, duration), where event is int or list[int], duration is float
        """
        decoded = [self.decode(tid) for tid in token_ids]
        assert len(decoded) % 2 == 0, "Token sequence should be even-length (note-dur pairs)."

        pairs = []
        # Even positions hold note/chord tokens, odd positions their durations.
        for note_token, dur_token in zip(decoded[0::2], decoded[1::2]):
            event = self._parse_event_token(note_token)
            duration = self._parse_duration_token(dur_token)
            pairs.append((event, duration))

        return pairs


    

In [142]:
# Build the vocabulary from every (event, duration) pair of the dataset.
vocab = TokenVocabulary(ordered_events_with_durations)
print("Total tokens:", vocab.vocab_size)


Total tokens: 154107


In [21]:
class TokenizedMusicDataset(Dataset):
    """
    Sliding-window dataset over tokenized songs.

    Each song is encoded with `vocab.encode_sequence_and_durations`, which
    yields two tokens per event. A sample is a pair (context, target):
    `context` holds the tokens of `context_length` consecutive events and
    `target` holds the two tokens of the event that follows. Windows advance
    one event (two tokens) at a time.

    Songs with fewer than `context_length + 1` events cannot provide a sample
    and are skipped; a song with exactly `context_length + 1` events
    contributes one sample.
    """

    def __init__(self, df, vocab, context_length=20):
        """
        Args:
            df: DataFrame with 'sequence' and 'durations' columns.
            vocab: Object exposing `encode_sequence_and_durations`.
            context_length: Number of events in each context window.
        """
        self.samples = []
        self.vocab = vocab
        self.context_length = context_length

        context_tokens = 2 * context_length  # two tokens per event
        window_tokens = context_tokens + 2   # context plus the target event

        for _, row in df.iterrows():
            sequence = self.vocab.encode_sequence_and_durations(row['sequence'], row['durations'])
            n_tokens = len(sequence)

            # A song needs at least one full window; exactly one window is enough.
            if n_tokens < window_tokens:
                continue

            # Step by 2 so that every window starts on a note/chord token.
            for start in range(0, n_tokens - window_tokens + 1, 2):
                split = start + context_tokens
                context = sequence[start:split]
                target = sequence[split:split + 2]
                self.samples.append((context, target))

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample `idx` as two torch.long tensors: (context, target)."""
        context, target = self.samples[idx]
        context_tensor = torch.tensor(context, dtype=torch.long)
        target_tensor = torch.tensor(target, dtype=torch.long)
        return context_tensor, target_tensor



In [22]:
# Exhaustive sliding-window dataset: one sample per event position in every song.
dataset = TokenizedMusicDataset(ordered_events_with_durations, vocab, context_length=20)



In [37]:
# Inspect the first sample: x holds the context tokens, y the two tokens of the next event.
x, y = dataset[0]
print("Input token IDs:", x)
print("Target token ID:", y)
# Map the ids back to token strings to check the note/duration alternation by eye.
print("Decoded input tokens:", [vocab.decode(i.item()) for i in x])
print("Decoded target token:", [vocab.decode(i.item()) for i in y])


Input token IDs: tensor([154095, 153461, 154045, 153608, 154045, 153461, 154052, 153462, 154052,
        153462, 154052, 153461,  51829, 153462, 154052, 153461, 154052, 153462,
        154045, 153467, 112860, 153476,  51829, 153462, 154052, 153461, 154068,
        153470, 154053, 153462, 154053, 153462,  23927, 153479, 154064, 153461,
        154045, 153475, 154053, 153462])
Target token ID: tensor([154045, 153470])
Decoded input tokens: ['NOTE_88', 'DUR_0.250', 'NOTE_38', 'DUR_2.500', 'NOTE_38', 'DUR_0.250', 'NOTE_45', 'DUR_0.333', 'NOTE_45', 'DUR_0.333', 'NOTE_45', 'DUR_0.250', 'CHORD_44_45', 'DUR_0.333', 'NOTE_45', 'DUR_0.250', 'NOTE_45', 'DUR_0.333', 'NOTE_38', 'DUR_0.750', 'CHORD_55_60', 'DUR_1.500', 'CHORD_44_45', 'DUR_0.333', 'NOTE_45', 'DUR_0.250', 'NOTE_61', 'DUR_1.000', 'NOTE_46', 'DUR_0.333', 'NOTE_46', 'DUR_0.333', 'CHORD_38_65', 'DUR_1.750', 'NOTE_57', 'DUR_0.250', 'NOTE_38', 'DUR_1.417', 'NOTE_46', 'DUR_0.333']
Decoded target token: ['NOTE_38', 'DUR_1.000']


### Efficient Training via Sampling

Due to the large size of the dataset — and the high computational cost of training a Transformer model on all available sequences — we adopted a more efficient sampling-based strategy for training.

Instead of feeding entire tokenized tracks into the model, we use the `SampledMusicDataset` class to extract a fixed number of random training samples from each song. For every track in the dataset:
- The entire event-duration sequence is first tokenized as described earlier.
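- Then, instead of using every possible subsequence, we randomly select a small number of **valid start points** within the sequence (each one is drawn independently, so the same start point can occasionally be chosen more than once).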
- From each chosen start point, we extract:
  - A **context** of `context_length` events (i.e., `2 * context_length` tokens).
  - The **target** event — the next pair of tokens immediately following the context.

By training on these randomized subsequences, the model learns to generalize from diverse musical contexts without the overhead of processing full-length tracks.


In [None]:
class SampledMusicDataset(Dataset):
    """
    Dataset of randomly positioned (context, target) windows.

    Each song is encoded with `vocab.encode_sequence_and_durations` (two
    tokens per event). For every song that holds at least
    `context_length + 1` events, `samples_per_song` windows are drawn. Each
    window starts on an even token index, so it is aligned to an event.
    Start positions are drawn independently (with replacement), which means a
    short song may contribute the same window more than once.
    """

    def __init__(self, df, vocab, context_length=20, samples_per_song=5, seed=None):
        """
        Args:
            df: DataFrame with 'sequence' and 'durations' columns.
            vocab: Object exposing `encode_sequence_and_durations`.
            context_length: Number of events in each context window.
            samples_per_song: Number of windows drawn from every usable song.
            seed: Optional seed for a private random generator, giving the same
                samples on every run. With the default None, the global
                `random` module state is used.
        """
        self.samples = []
        self.vocab = vocab
        self.context_length = context_length

        # A private generator keeps seeded sampling independent of global state.
        rng = random.Random(seed) if seed is not None else random

        context_tokens = 2 * context_length  # two tokens per event

        for _, row in df.iterrows():
            sequence = self.vocab.encode_sequence_and_durations(row['sequence'], row['durations'])
            n_tokens = len(sequence)

            max_start = n_tokens - context_tokens - 2
            if max_start < 0:
                continue  # skip very short songs

            # All even start indices; never empty here because max_start >= 0.
            valid_starts = range(0, max_start + 1, 2)

            for _ in range(samples_per_song):
                start = rng.choice(valid_starts)
                split = start + context_tokens
                context = sequence[start:split]
                target = sequence[split:split + 2]
                self.samples.append((context, target))

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample `idx` as two torch.long tensors: (context, target)."""
        context, target = self.samples[idx]
        context_tensor = torch.tensor(context, dtype=torch.long)
        target_tensor = torch.tensor(target, dtype=torch.long)
        return context_tensor, target_tensor


In [None]:
# Draw 5 random windows per song. NOTE(review): no random seed is set before this
# call, so the sampled windows (and everything trained on them) differ between runs.
sampled_data = SampledMusicDataset(ordered_events_with_durations, vocab, context_length=20, samples_per_song=5)


In [45]:
# Inspect the first sampled window: x holds the context tokens, y the two tokens of the next event.
x, y = sampled_data[0]
print("Input token IDs:", x)
print("Target token ID:", y)
# Map the ids back to token strings for a readable view.
print("Decoded input tokens:", [vocab.decode(i.item()) for i in x])
print("Decoded target token:", [vocab.decode(i.item()) for i in y])


Input token IDs: tensor([ 95523, 153467,  17453, 153461, 154059, 153740,  77853, 153461, 154081,
        153740, 154055, 153470, 154055, 153611, 150266, 153478, 154098, 153476,
        154088, 153734, 154065, 153462, 154071, 153610, 154089, 153479, 154084,
        153466, 154067, 153602, 129955, 153464, 154065, 153467, 154065, 153470,
        144481, 153466, 146040, 153605])
Target token ID: tensor([152082, 153474])
Decoded input tokens: ['CHORD_52_58_62', 'DUR_0.750', 'CHORD_36_70', 'DUR_0.250', 'NOTE_52', 'DUR_3.500', 'CHORD_48_58_65_70', 'DUR_0.250', 'NOTE_74', 'DUR_3.500', 'NOTE_48', 'DUR_1.000', 'NOTE_48', 'DUR_2.750', 'CHORD_79_86', 'DUR_1.667', 'NOTE_91', 'DUR_1.500', 'NOTE_81', 'DUR_3.000', 'NOTE_58', 'DUR_0.333', 'NOTE_64', 'DUR_2.667', 'NOTE_82', 'DUR_1.750', 'NOTE_77', 'DUR_0.667', 'NOTE_60', 'DUR_2.000', 'CHORD_58_70', 'DUR_0.500', 'NOTE_58', 'DUR_0.750', 'NOTE_58', 'DUR_1.000', 'CHORD_67_74', 'DUR_0.667', 'CHORD_69_93_98', 'DUR_2.250']
Decoded target token: ['CHORD_89_94',

#### Model

In [None]:
class MusicTransformer(nn.Module):
    """
    Token-level Transformer encoder with learned positional embeddings.

    Token ids are embedded, summed with an embedding of their position, passed
    through a stack of `nn.TransformerEncoderLayer`s and projected back to
    vocabulary size. No attention mask is passed to the encoder.
    """

    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=2, dropout=0.1, max_seq_len=512):
        """
        Args:
            vocab_size: Number of distinct token ids (input and output size).
            d_model: Embedding / hidden width.
            nhead: Attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            dropout: Dropout used inside the encoder layers.
            max_seq_len: Number of rows in the positional embedding table;
                inputs longer than this cannot be embedded.
        """
        super().__init__()
        # Sub-modules are created in the same order as before so that seeded
        # initialisation and state_dict keys stay identical.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True),
            num_layers=num_layers,
        )

        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        x: [batch_size, seq_len] of token indices

        Returns logits of shape [batch_size, seq_len, vocab_size].
        """
        _, seq_len = x.size()
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)   # [1, T]

        # The [1, T, d_model] position embedding broadcasts over the batch dimension.
        hidden = self.token_embedding(x) + self.pos_embedding(positions)  # [B, T, d_model]
        hidden = self.transformer(hidden)                                  # [B, T, d_model]

        return self.output_layer(hidden)                                   # [B, T, vocab_size]


In [None]:
# Input embedding and output projection are both sized by the vocabulary.
vocab_size = vocab.vocab_size
# All other hyper-parameters use the class defaults.
model = MusicTransformer(vocab_size=vocab_size)


In [None]:
# Training runs on the CPU.
device = torch.device( "cpu")
model = model.to(device)

# Cross-entropy over the whole vocabulary; it is applied once for the note/chord
# token and once for the duration token in the training loop.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Mini-batches of 16 sampled windows, reshuffled every epoch.
train_loader = DataLoader(sampled_data, batch_size=16, shuffle=True)


In [None]:
# Number of full passes over train_loader.
epochs = 3

for epoch in range(epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")

    for batch_idx, (x, y) in progress_bar:
        # x: context token ids, y: the two target token ids (note/chord, duration).
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        # The second-to-last position is trained to predict the next note/chord token,
        # the last position to predict the next duration token.
        note_logits = logits[:, -2, :]
        dur_logits = logits[:, -1, :]

        loss_note = loss_fn(note_logits, y[:, 0])
        loss_dur = loss_fn(dur_logits, y[:, 1])
        # Both predictions contribute equally to the objective.
        loss = loss_note + loss_dur

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} complete | Avg Loss: {avg_loss:.4f}")



Epoch 1: 100%|██████████| 868/868 [19:47<00:00,  1.37s/it, loss=11.2]


Epoch 1 complete | Avg Loss: 9.4007


Epoch 2: 100%|██████████| 868/868 [18:02<00:00,  1.25s/it, loss=7.18]


Epoch 2 complete | Avg Loss: 8.9449


Epoch 3: 100%|██████████| 868/868 [18:29<00:00,  1.28s/it, loss=8.18]

Epoch 3 complete | Avg Loss: 8.6550





### Generating music

In [130]:
def generate_music(model, vocab, seed_sequence, seed_durations, context_length, generate_events, temperature, device):
    """
    Continue a seed with `generate_events` sampled (event, duration) pairs.

    Generation mirrors the way the model is trained in this notebook: the
    model receives an event-aligned window of the last `context_length`
    events (2 * context_length tokens). The next note/chord token is read
    from the second-to-last output position and the next duration token from
    the last output position of the same forward pass.

    Args:
        model: Module mapping a [1, T] LongTensor of token ids to logits of
            shape [1, T, vocab_size].
        vocab: Vocabulary exposing `tokens`, `encode_sequence_and_durations`
            and `decode_sequence_and_durations`.
        seed_sequence: Seed events (int notes or lists of ints for chords).
        seed_durations: One duration per seed event.
        context_length: Maximum number of most recent events fed to the model.
        generate_events: Number of new events to generate.
        temperature: Softmax temperature; logits are divided by it.
        device: Torch device on which the input tensors are created.

    Returns:
        List of (event, duration) tuples: the seed followed by the generated events.

    Raises:
        ValueError: If the seed is empty, if `context_length` is smaller than 1,
            or if every admissible token ends up with zero probability.
    """
    if context_length < 1:
        raise ValueError("context_length must be at least 1.")

    model.eval()
    generated_tokens = list(vocab.encode_sequence_and_durations(seed_sequence, seed_durations))
    if len(generated_tokens) < 2:
        raise ValueError("Seed must contain at least one (event, duration) pair.")

    # Token-type masks are built once instead of scanning the vocabulary at every step.
    is_duration = torch.tensor(
        [tok.startswith("DUR_") for tok in vocab.tokens], dtype=torch.bool, device=device
    )
    is_pitch = torch.tensor(
        [tok.startswith("NOTE_") or tok.startswith("CHORD_") for tok in vocab.tokens],
        dtype=torch.bool, device=device
    )

    def _sample(position_logits, blocked):
        """Sample one token id from `position_logits`, excluding the `blocked` ids."""
        probs = F.softmax(position_logits / temperature, dim=-1).masked_fill(blocked, 0.0)
        total = probs.sum()
        if total == 0:
            raise ValueError("All probabilities filtered out — check vocab consistency or model behavior.")
        return torch.multinomial(probs / total, num_samples=1).item()

    window_tokens = 2 * context_length  # two tokens per event

    for _ in range(generate_events):
        # The token list always has even length, so the window is event-aligned.
        context = generated_tokens[-window_tokens:]
        input_tensor = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)

        with torch.no_grad():
            logits = model(input_tensor)

        # Same output positions as the training loss: -2 → note/chord, -1 → duration.
        next_note = _sample(logits[0, -2, :], is_duration)
        next_dur = _sample(logits[0, -1, :], is_pitch)
        generated_tokens.extend([next_note, next_dur])

    return vocab.decode_sequence_and_durations(generated_tokens)



#### Trial on a single sequence

In [94]:
# Select the first track from the dataset
# Select the first track from the dataset
track_index = 0
seed_seq = ordered_events_with_durations.iloc[track_index]['sequence']
seed_durs = ordered_events_with_durations.iloc[track_index]['durations']

# Generation settings
context_length = 20
generate_events = 50
temperature = 0.8

# Generate music; only the first `context_length` events of the track are used as the seed.
generated_sequence = generate_music(
    model=model,
    vocab=vocab,
    seed_sequence=seed_seq[:context_length],
    seed_durations=seed_durs[:context_length],
    context_length=context_length,
    generate_events=generate_events,
    temperature=temperature,
    device=device
)

# Print results (the returned sequence starts with the seed events)
print("Generated sequence:")
for i, (event, dur) in enumerate(generated_sequence):
    print(f"{i+1}: {event} ({dur})")


Generated sequence:
1: 88 (0.25)
2: 38 (2.5)
3: 38 (0.25)
4: 45 (0.333)
5: 45 (0.333)
6: 45 (0.25)
7: [44, 45] (0.333)
8: 45 (0.25)
9: 45 (0.333)
10: 38 (0.75)
11: [55, 60] (1.5)
12: [44, 45] (0.333)
13: 45 (0.25)
14: 61 (1.0)
15: 46 (0.333)
16: 46 (0.333)
17: [38, 65] (1.75)
18: 57 (0.25)
19: 38 (1.417)
20: 46 (0.333)
21: 84 (0.333)
22: 48 (0.25)
23: 48 (0.5)
24: 58 (0.333)
25: 56 (0.25)
26: 52 (0.5)
27: [46, 56, 62, 70] (0.333)
28: 72 (0.167)
29: 50 (0.5)
30: 86 (0.25)
31: 37 (1.75)
32: [36, 48] (1.25)
33: 77 (0.25)
34: 58 (0.25)
35: 66 (0.333)
36: 93 (1.25)
37: 77 (0.25)
38: 85 (0.25)
39: [56, 60, 65] (2.333)
40: 88 (0.833)
41: 95 (0.5)
42: 69 (1.75)
43: 83 (1.25)
44: 97 (0.25)
45: 79 (2.0)
46: 81 (3.667)
47: 80 (0.75)
48: 35 (0.667)
49: [54, 64] (2.0)
50: 76 (1.5)
51: 82 (0.667)
52: 31 (0.25)
53: 52 (1.333)
54: 46 (0.25)
55: 51 (0.5)
56: [54, 63] (0.25)
57: 49 (0.5)
58: 52 (0.667)
59: 86 (0.5)
60: 50 (0.25)
61: 86 (0.25)
62: 101 (0.5)
63: 83 (0.25)
64: 74 (0.25)
65: 81 (0.75)
66: 5

#### Multiple generations

In [None]:
def sequence_to_midi(event_sequence, output_path="generated.mid", program=0):
    """
    Write a list of (note or chord, duration) pairs to a MIDI file.

    Events are laid out back to back: each one starts where the previous one
    ended and lasts `duration`. All pitches of a chord share the same start
    and end. Every note is written with velocity 100.

    Args:
        event_sequence: List of (note:int or chord:list[int], duration:float)
        output_path: Path to save the output MIDI file
        program: MIDI instrument program number (0 = Acoustic Grand Piano)

    Returns:
        None (writes MIDI file to disk)

    Raises:
        ValueError: If an event is neither an int nor a list.
    """
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    onset = 0.0
    for event, duration in event_sequence:
        # Normalise both event kinds to a list of pitches sounding together.
        if isinstance(event, int):
            pitches = [event]
        elif isinstance(event, list):
            pitches = event
        else:
            raise ValueError(f"Invalid event type: {event}")

        instrument.notes.extend(
            pretty_midi.Note(velocity=100, pitch=pitch, start=onset, end=onset + duration)
            for pitch in pitches
        )
        onset += duration

    midi.instruments.append(instrument)
    midi.write(output_path)
    print(f"MIDI file written to: {output_path}")


In [None]:

output_dir = "generated_music_transformer/dataset"
os.makedirs(output_dir, exist_ok=True)

context_length = 20
generate_events = 50
temperature = 0.9

# Number of dataset tracks to continue; capped so iloc never runs past the last row.
n_samples = min(10, len(ordered_events_with_durations))

for i in range(n_samples):
    row = ordered_events_with_durations.iloc[i]
    seed_seq = row['sequence'][:context_length]
    seed_durs = row['durations'][:context_length]

    # Ensure matching length
    if len(seed_seq) != len(seed_durs) or len(seed_seq) < context_length:
        print(f"Skipping track {i} due to length mismatch or insufficient seed.")
        continue

    # Generate music
    generated_sequence = generate_music(
        model=model,
        vocab=vocab,
        seed_sequence=seed_seq,
        seed_durations=seed_durs,
        context_length=context_length,
        generate_events=generate_events,
        temperature=temperature,
        device=device
    )

    # Save as MIDI
    output_path = os.path.join(output_dir, f"sample_{i}.mid")
    sequence_to_midi(generated_sequence, output_path=output_path)

    # The denominator follows the actual loop length instead of a hard-coded value.
    print(f"[{i+1}/{n_samples}] Saved: {output_path}")


MIDI file written to: generated_music_transformer/dataset\sample_0.mid
[1/10] Saved: generated_music_transformer/dataset\sample_0.mid
MIDI file written to: generated_music_transformer/dataset\sample_1.mid
[2/10] Saved: generated_music_transformer/dataset\sample_1.mid
MIDI file written to: generated_music_transformer/dataset\sample_2.mid
[3/10] Saved: generated_music_transformer/dataset\sample_2.mid
MIDI file written to: generated_music_transformer/dataset\sample_3.mid
[4/10] Saved: generated_music_transformer/dataset\sample_3.mid
MIDI file written to: generated_music_transformer/dataset\sample_4.mid
[5/10] Saved: generated_music_transformer/dataset\sample_4.mid
MIDI file written to: generated_music_transformer/dataset\sample_5.mid
[6/10] Saved: generated_music_transformer/dataset\sample_5.mid
MIDI file written to: generated_music_transformer/dataset\sample_6.mid
[7/10] Saved: generated_music_transformer/dataset\sample_6.mid
MIDI file written to: generated_music_transformer/dataset\samp

#### Generation from popular songs

In [None]:
def midi_to_event_duration_sequence(midi_path):
    """
    Convert a MIDI file to parallel lists of events and durations.

    The file is parsed with music21 and its flattened note stream is walked
    in order. A Note contributes its MIDI pitch number, a Chord contributes
    the ascending list of its MIDI pitch numbers; any other element is
    ignored. Durations are the elements' quarterLength values as floats.

    Returns:
        On success, a 2-tuple (sequence, durations).
        On any exception, a 3-tuple (error_message, [], []) — callers must
        check the tuple length to tell the two cases apart.
    """
    try:
        score = converter.parse(midi_path)

        sequence, durations = [], []
        for element in score.flat.notes:
            if isinstance(element, note.Note):
                event = element.pitch.midi
            elif isinstance(element, chord.Chord):
                event = sorted(p.midi for p in element.pitches)
            else:
                continue
            sequence.append(event)
            durations.append(float(element.quarterLength))

        return sequence, durations
    except Exception as exc:
        return f"Error processing {midi_path}: {str(exc)}", [], []


In [105]:
midi_folder = "popular_songs"
converted_data = {}

for filename in os.listdir(midi_folder):
    if filename.endswith((".mid", ".midi")):
        midi_path = os.path.join(midi_folder, filename)
        result = midi_to_event_duration_sequence(midi_path)
        # Success is a (sequence, durations) pair. The converter reports failures
        # as a longer tuple whose first item is the message, so a plain
        # isinstance(result, tuple) check cannot tell the two cases apart.
        if isinstance(result, tuple) and len(result) == 2:
            sequence, durations = result
            converted_data[filename] = {
                "sequence": sequence,
                "durations": durations
            }
        else:
            # Print the error message, whether it arrives wrapped in a tuple or bare.
            print(result[0] if isinstance(result, tuple) else result)

converted_data.keys()  # Show which files were processed successfully

  return self.iter().getElementsByClass(classFilterList)


dict_keys(['AnotherBrickInTheWall.mid', 'BackInBlack.mid', 'BillieJean.mid', 'DancingQueen.mid', 'FinalCountdown.mid', 'Hallelujah.mid', 'HappyBirthday.mid', 'HipsDontLie.mid', 'HotelCalifornia.mid', 'PokerFace.mid', 'Titanic.mid', 'Umbrella.mid'])

In [113]:
# Index of the first seed event for each popular song (one entry per converted file).
starting_points = {
    'AnotherBrickInTheWall.mid': 61,
    'BackInBlack.mid': 60,
    'BillieJean.mid': 50,
    'DancingQueen.mid': 80,
    'FinalCountdown.mid': 10,
    'Hallelujah.mid': 80,
    'HappyBirthday.mid': 40,
    'HipsDontLie.mid': 60,
    'HotelCalifornia.mid': 50,
    'PokerFace.mid': 280,
    'Titanic.mid': 60,
    'Umbrella.mid': 60,
}

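The vocabulary was built only from the training dataset, so it may not contain every note, chord, or duration that appears in the popular songs (`converted_data`). Songs whose seed window contains an unknown token are therefore skipped.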


In [None]:
def check_popular_songs_in_vocab(vocab, converted_data, starting_points, context_length):
    """
    Find the popular songs whose seed window can be encoded with `vocab`.

    For every (filename, start index) in `starting_points`, the window of
    `context_length` events beginning at the start index is inspected. A song
    is reported and left out when it is missing from `converted_data`, when
    it has too few events or durations for the window, or when any note,
    chord or duration token of the window is not a key of `vocab.token_to_idx`.

    Returns:
        The filenames that passed all checks, in the iteration order of
        `starting_points`.
    """
    known_tokens = vocab.token_to_idx
    usable = []

    for filename, start_idx in starting_points.items():
        if filename not in converted_data:
            print(f"Skipping {filename} (not in converted_data)")
            continue

        entry = converted_data[filename]
        events = entry['sequence']
        durs = entry['durations']
        stop_idx = start_idx + context_length

        if len(events) < stop_idx or len(durs) < stop_idx:
            print(f"Skipping {filename} (not enough events after start index)")
            continue

        for event, dur in zip(events[start_idx:stop_idx], durs[start_idx:stop_idx]):
            if isinstance(event, list):
                pitch_token = 'CHORD_' + '_'.join(map(str, sorted(event)))
            else:
                pitch_token = f'NOTE_{event}'
            length_token = f'DUR_{dur:.3f}'

            if pitch_token not in known_tokens or length_token not in known_tokens:
                # The first unknown token is enough to reject the song.
                print(f"Unknown tokens in {filename} — skipping this song.")
                break
        else:
            # Loop finished without a break: every token is known.
            usable.append(filename)

    return usable



In [None]:
# Keep only the songs whose 20-event seed window is fully covered by the vocabulary.
# (The former loop that merely re-read each start index had no effect and was removed.)
safe_popular_songs = check_popular_songs_in_vocab(vocab, converted_data, starting_points, 20)
 


Unknown tokens in DancingQueen.mid — skipping this song.
Unknown tokens in HipsDontLie.mid — skipping this song.
Unknown tokens in PokerFace.mid — skipping this song.


In [None]:
output_dir = "generated_music_transformer/popular"
os.makedirs(output_dir, exist_ok=True)

context_length = 20
generate_events = 50
temperature = 0.9


for filename in safe_popular_songs:
    start_idx = starting_points[filename]
    song_data = converted_data[filename]
    full_seq = song_data["sequence"]
    full_durs = song_data["durations"]

    # Seed = `context_length` events starting at the hand-picked position.
    seed_seq = full_seq[start_idx : start_idx + context_length]
    seed_durs = full_durs[start_idx : start_idx + context_length]

    if len(seed_seq) != context_length or len(seed_durs) != context_length:
        print(f"Skipping {filename} due to mismatched seed lengths")
        continue

    generated_sequence = generate_music(
        model=model,
        vocab=vocab,
        seed_sequence=seed_seq,
        seed_durations=seed_durs,
        context_length=context_length,
        generate_events=generate_events,
        temperature=temperature,
        device=device
    )

    # splitext drops the real extension; replace('.mid', '') would turn
    # 'song.midi' into 'songi'.
    base_name = os.path.splitext(filename)[0]
    output_path = os.path.join(output_dir, f"{base_name}_gen.mid")
    sequence_to_midi(generated_sequence, output_path=output_path)

    print(f"Saved: {output_path}")


MIDI file written to: generated_music_transformer/popular\AnotherBrickInTheWall_gen.mid
Saved: generated_music_transformer/popular\AnotherBrickInTheWall_gen.mid
MIDI file written to: generated_music_transformer/popular\BackInBlack_gen.mid
Saved: generated_music_transformer/popular\BackInBlack_gen.mid
MIDI file written to: generated_music_transformer/popular\BillieJean_gen.mid
Saved: generated_music_transformer/popular\BillieJean_gen.mid
MIDI file written to: generated_music_transformer/popular\FinalCountdown_gen.mid
Saved: generated_music_transformer/popular\FinalCountdown_gen.mid
MIDI file written to: generated_music_transformer/popular\Hallelujah_gen.mid
Saved: generated_music_transformer/popular\Hallelujah_gen.mid
MIDI file written to: generated_music_transformer/popular\HappyBirthday_gen.mid
Saved: generated_music_transformer/popular\HappyBirthday_gen.mid
MIDI file written to: generated_music_transformer/popular\HotelCalifornia_gen.mid
Saved: generated_music_transformer/popular\Hot