## Loading and Preprocessing the ABC Music Dataset

In [1]:
import os

data_dir = "data/"
abc_files = [f for f in os.listdir(data_dir) if f.endswith(".abc")]

# Each .abc file contains multiple songs. We will split each file into a list of individual tunes.
all_tunes = []
for filename in abc_files:
    with open(os.path.join(data_dir, filename), 'r') as file:
        content = file.read()
        # In a file, each song is separated by a blank line
        tunes = content.split('\n\n')
        for i, tune in enumerate(tunes):
            all_tunes.append(tune)

print(f"Total tunes extracted: {len(all_tunes)}")
print("Example tune:")
print(all_tunes[0])

Total tunes extracted: 1049
Example tune:

X: 1
T:A and A's Waltz
% Nottingham Music Database
S:Mick Peat
M:3/4
L:1/4
K:G
e|:"G"d2B|"D"A3/2B/2c|"G"B2G|"D"A2e|"G"d2B|"D"A3/2B/2c|
M:2/4
"F"B=F|
M:3/4
"G"G2e:||:
"C"g2e|"Bb"=f2d|"F"c2A|=F2e|"C"g2e|"Bb"=f2d|
M:2/4
"F"cA|
M:3/4
 [1 "G"G2e:| [2"G"G2z||


Some tunes may be incomplete or junk. A simple heuristic:
- Keep only tunes with a key signature K:
- Maybe also require a time signature M:

In [2]:
clean_tunes = [t for t in all_tunes if 'K:' in t and 'M:' in t]
print(f"Tunes after cleaning: {len(clean_tunes)}")

Tunes after cleaning: 1034


In [3]:
def extract_melody_with_context(tune):
    """
    Extracts the essential musical information and melody from an ABC tune.
    Automatically fills in a default L: value if missing, based on M: (meter).

    Parameters:
        tune (str): A raw ABC tune as a multiline string.

    Returns:
        str: A string containing only the essential musical context and the melody.

    Notes:
        In ABC notation, tunes begin with a header section that may contain many fields.
        Only some of these fields are musically relevant to interpreting the notes.

        We keep:
            - M: Meter (e.g., "M:4/4")
            - L: Default note length (e.g., "L:1/8")
            - K: Key signature (e.g., "K:Cmaj")

        These fields are needed to interpret rhythm and pitch correctly.

        We discard:
            - X: Tune number (just an ID)
            - T: Title
            - N:, C:, Z:, Q:, etc. — any fields that are comments, composer names,
              tempo hints, etc., which are often inconsistent or irrelevant for modeling.

        Once we reach the 'K:' line (the key signature), we start including melody lines,
        which contain the actual note sequences.

        Rules for default L:
        - If no L: line is provided, compute decimal value of the meter (M:)
        - If meter >= 0.75 → default to L:1/8
        - If meter <  0.75 → default to L:1/16
    """
    lines = tune.strip().splitlines()

    meter_line = None
    length_line = None
    key_line = None
    melody_lines = []

    header_done = False

    for line in lines:
        line = line.strip()
        if not line:
            continue

        if not header_done:
            if line.startswith("M:") and meter_line is None:
                meter_line = line
            elif line.startswith("L:") and length_line is None:
                length_line = line
            elif line.startswith("K:") and key_line is None:
                key_line = line
                header_done = True  # everything after this is melody
        else:
            melody_lines.append(line)

    # Synthesize L: if missing
    if length_line is None:
        default_length = "1/8"  # fallback
        if meter_line:
            try:
                meter_value = meter_line[2:].strip()
                num, denom = map(int, meter_value.split('/'))
                meter_decimal = num / denom
                default_length = "1/8" if meter_decimal >= 0.75 else "1/16"
            except Exception:
                pass
        length_line = f"L:{default_length}"

    header = []
    if meter_line:
        header.append(meter_line)
    if length_line:
        header.append(length_line)
    if key_line:
        header.append(key_line)

    return ' '.join(header + melody_lines)


melodies = [extract_melody_with_context(tune) for tune in clean_tunes]
print("Example melody:")
print(melodies[10])

Example melody:
M:6/8 L:1/8 K:D e|:"D"dcd "A"FAB|"D"dcd "G"BAG|"D"FAd "G"GBd|"Em"Ged "A"cBc| "D"dcd "A"FAB|"D"dcd "G"B2d|"Em"Bed "A"cag|"D"fdc d2e:| |:"D"fdA "G"g3|"C"e=cG "F"=f3|"C"e=cG Gce|"C"=ceg "G"Bdg| "D"fdA "G"g3|"A"ecA "D"f3|"Bm"def "Em"ged|"A"cBc "D"d3:|


In [4]:
import re

def normalize_slash_to_half(abc_str):
    """
    Replaces note durations written with a dangling '/' (e.g., 'b/') 
    with '/2' to standardize fractional durations.
    """
    return re.sub(r'([=^_]*[A-Ga-grz][\',]?)\/(?!\d)', r'\1/2', abc_str)

def normalize_chord_formatting(abc_str):
    """
    Cleans up chord notation:
    - Removes parentheses around chords (e.g., "(A7)" → "A7")
    - Fixes space typos (e.g., "D m" → "Dm")
    - Strips leading/trailing whitespace inside quotes (e.g., '" Em"' → '"Em"')
    """

    def clean_chord(match):
        chord = match.group(0)
        chord = chord.strip('"() ')       # remove outer symbols and spaces
        chord = chord.replace(" ", "")    # squash space typos like "D m"
        return f'"{chord}"'

    return re.sub(r'"[^"]+"', clean_chord, abc_str)

def preprocess_abc(abc_str):
    abc_str = normalize_slash_to_half(abc_str)
    abc_str = normalize_chord_formatting(abc_str)
    return abc_str

normalized_melodies = [preprocess_abc(melody) for melody in melodies]

## Tokenization

In [5]:
def tokenize_abc(abc_str):
    """
    Tokenizes an ABC notation string into musically meaningful symbols.

    The tokenizer extracts:
        - Bar lines and repeat symbols: '|', '||', ':|', '|:', '::'
        - Repeat brackets: '[1', '[2'
        - Chords: enclosed in double quotes, e.g., "G", "Dmin"
        - Notes: including accidentals (=, ^, _), octave markers (',), rests (z), and durations (e.g., A3/2)
        - Headers: 
            - M: (time signature), e.g., M:4/4
            - L: (default note length), e.g., L:1/8
            - K: (key signature), e.g., K:D#

    Parameters:
        abc_str (str): A string in ABC notation.

    Returns:
        list[str]: A list of tokens representing the music in ABC format.
    """

    tokens = re.findall(r'''
          \[1|\[2                          # repeat brackets [1, [2
        | :\|\|                            # :|| (repeat + end)
        | :\|                              # :|  (end of repeat)
        | \|:                              # |:  (start of repeat)
        | ::                               # ::  (double repeat)
        | \|\|                             # ||  (end of section)
        | \|                               # |   (single bar)
        | \|\||\|                          # barlines
        | "[^"]+"                          # chords, e.g., "G", "Dmin"
        | [=^_]*[A-Ga-grz][\',]?\d*\/?\d*  # notes (d2, A3/2, z, etc)
        | M:\d+\/\d+                       # time signature
        | L:\d+\/\d+                       # default note length
        | K:[A-G][#b]?m?                   # key signature
        ''',
        abc_str,
        re.VERBOSE
    )
    return tokens

tokenized_melodies = [tokenize_abc(melody) for melody in normalized_melodies]
tokens = tokenize_abc(melodies[0])
print(tokens)

['M:3/4', 'L:1/4', 'K:G', 'e', '|:', '"G"', 'd2', 'B', '|', '"D"', 'A3/2', 'B/2', 'c', '|', '"G"', 'B2', 'G', '|', '"D"', 'A2', 'e', '|', '"G"', 'd2', 'B', '|', '"D"', 'A3/2', 'B/2', 'c', '|', 'M:2/4', '"F"', 'B', '=F', '|', 'M:3/4', '"G"', 'G2', 'e', ':||', '"C"', 'g2', 'e', '|', '"Bb"', '=f2', 'd', '|', '"F"', 'c2', 'A', '|', '=F2', 'e', '|', '"C"', 'g2', 'e', '|', '"Bb"', '=f2', 'd', '|', 'M:2/4', '"F"', 'c', 'A', '|', 'M:3/4', '[1', '"G"', 'G2', 'e', ':|', '[2', '"G"', 'G2', 'z', '||']


In [6]:
def run_tokenizer_tests():
    test_cases = [
        {
            "input": 'M:4/4 L:1/8 K:C CDEF GABc|',
            "expected": ['M:4/4', 'L:1/8', 'K:C', 'C', 'D', 'E', 'F', 'G', 'A', 'B', 'c', '|']
        },
        {
            "input": 'K:G "D7" D2 G2 |: B4 :|',
            "expected": ['K:G', '"D7"', 'D2', 'G2', '|:', 'B4', ':|']
        },
        {
            "input": 'L:1/16 M:6/8 K:D z3 [1 A2 B2 :||',
            "expected": ['L:1/16', 'M:6/8', 'K:D', 'z3', '[1', 'A2', 'B2', ':||']
        },
        {
            "input": 'K:D | A3/2 B/2 c\'/ | "G" G,2 |',
            "expected": ['K:D', '|', 'A3/2', 'B/2', "c'/", '|', '"G"', 'G,2', '|']
        }
    ]

    for i, case in enumerate(test_cases, 1):
        result = tokenize_abc(case['input'])
        if result != case['expected']:
            print(f"[FAIL] Test {i}")
            print("Input:   ", case['input'])
            print("Expected:", case['expected'])
            print("Got:     ", result)
        else:
            print(f"[PASS] Test {i}")

run_tokenizer_tests()

[PASS] Test 1
[PASS] Test 2
[PASS] Test 3
[PASS] Test 4


## Building a vocabulary

In [7]:
from collections import Counter

def build_vocab(tokenized_melodies):
    """
    Builds a vocabulary (token-to-index mapping) from a list of tokenized ABC melodies.

    Parameters:
        tokenized_melodies (List[List[str]]): A list where each melody is a list of ABC tokens.

    Returns:
        vocab (Dict[str, int]): Mapping from tokens to unique integer indices.
        indexed_melodies (List[List[int]]): Each melody as a list of integer token indices.
        token_freq (Counter): Frequency count of all tokens.
    """
    # Flatten all tokenized melodies into a single list of tokens
    all_tokens = [token for melody in tokenized_melodies for token in melody]

    # Count how often each token appears
    token_freq = Counter(all_tokens)

    # Create vocabulary: map each unique token to a unique index
    # Start from 1 so that 0 can be reserved for padding
    vocab = {token: idx for idx, (token, _) in enumerate(token_freq.items(), start=1)}
    vocab["<PAD>"] = 0  # Add padding token (index 0)

    # Convert each tokenized melody to a list of token indices
    indexed_melodies = [[vocab[token] for token in melody] for melody in tokenized_melodies]

    return vocab, indexed_melodies, token_freq


# Build vocab from your tokenized melodies
vocab, indexed_melodies, token_freq = build_vocab(tokenized_melodies)
inv_vocab = {v: k for k, v in vocab.items()}
print(f"Unique tokens: {len(vocab)}")

Unique tokens: 440


In [38]:
from collections import defaultdict

# Group tokens into categories
grouped_tokens = defaultdict(list)

for token in vocab:
    if token == "<PAD>":
        grouped_tokens["Special"].append(token)
    elif re.match(r'^K:[A-G][#b]?m?$', token):
        grouped_tokens["Key"].append(token)
    elif re.match(r'^M:\d+/\d+$', token):
        grouped_tokens["TimeSig"].append(token)
    elif re.match(r'^L:\d+/\d+$', token):
        grouped_tokens["NoteLength"].append(token)
    elif re.match(r'^"[^"]+"$', token):
        grouped_tokens["Chords"].append(token)
    elif re.match(r'^\[1|\[2$', token):
        grouped_tokens["RepeatBrackets"].append(token)
    elif re.match(r'^::|:\|\||\|:|:\||\|\||\|$', token):
        grouped_tokens["Barlines"].append(token)
    elif re.match(r'^[=^_]*[A-Ga-grz][\',]?\d*/?\d*$', token):
        grouped_tokens["Notes"].append(token)
    else:
        grouped_tokens["Other/Weird"].append(token)

# Print grouped tokens
for group, tokens in grouped_tokens.items():
    print(f"\n--- {group} ({len(tokens)} tokens) ---")
    for token in sorted(tokens):
        print(token)


--- TimeSig (9 tokens) ---
M:2/2
M:2/4
M:3/2
M:3/4
M:4/4
M:5/4
M:6/4
M:6/8
M:9/8

--- NoteLength (2 tokens) ---
L:1/4
L:1/8

--- Key (14 tokens) ---
K:A
K:Am
K:B
K:Bb
K:Bm
K:C
K:Cm
K:D
K:Dm
K:E
K:Em
K:F
K:G
K:Gm

--- Notes (298 tokens) ---
=A
=A/2
=B
=B,
=B/2
=B2
=B3
=C
=C/2
=D
=D/2
=E
=E/2
=E/4
=F
=F/2
=F2
=F3
=G
=G/2
=G/4
=G2
=G3/2
=a/2
=a2
=b/2
=c
=c'2
=c/2
=c/4
=c2
=c3
=c3/2
=c3/4
=c4
=d
=d/2
=e
=e/2
=f
=f/2
=f2
=f3
=f3/2
=f4
=g
=g/2
=g2
A
A,
A,/2
A,2
A,3
A,4
A/2
A/4
A2
A3
A3/2
A3/4
A4
A6
B
B,
B,/2
B,/4
B,2
B,3
B,3/2
B,6
B/2
B/4
B2
B3
B3/2
B3/4
B4
B5
B6
C
C/2
C/4
C2
C3
C3/2
C4
C6
D
D/2
D/4
D2
D3
D3/2
D3/4
D4
D6
E
E/2
E/4
E2
E3
E3/2
E3/4
E4
E6
F
F/2
F/4
F2
F3
F3/2
F3/4
F4
F6
G
G,
G,/2
G,2
G,3
G/2
G/4
G2
G3
G3/2
G3/4
G4
G6
^A
^A,
^A/2
^A/4
^A2
^A3
^A3/2
^B
^B,/2
^B/2
^B2
^C
^C/2
^C2
^C3
^D
^D/2
^D2
^D3
^D3/2
^E
^F
^F/2
^F/4
^F2
^F3
^G
^G/2
^G/4
^G2
^G3
^G3/2
^G3/4
^a
^a/2
^c
^c/2
^c/4
^c2
^c3
^c3/2
^c3/4
^d
^d/2
^d/4
^d2
^d3
^d3/2
^e
^e/2
^e2
^e3
^f
^f/2
^f2
^f3
^f3/4
^g
^g/2
^g/4
^

## Creating a PyTorch Dataset

In [8]:
import torch
from torch.utils.data import Dataset

class NottinghamDataset(Dataset):
    def __init__(self, indexed_sequences, vocab, inv_vocab, window_size=32):
        """
        Args:
            indexed_sequences (List[List[int]]): List of tokenized melody sequences (as integers).
            vocab (Dict[str, int]): Token-to-index mapping.
            inv_vocab (Dict[int, str]): Index-to-token mapping.
            window_size (int): Number of tokens to include in each input window.
        """
        self.window_size = window_size
        self.vocab = vocab
        self.inv_vocab = inv_vocab
        self.inputs = []
        self.labels = []

        # Go through each melody and turn it into sliding input windows
        for seq in indexed_sequences:
            context = self.extract_context(seq)  # this grabs the first 3 tokens (assumed M:/L:/K:)
            body = seq[3:]  # remove M:/L:/K: from melody body
            
            windows, labels = self.generate_windows(body, context)
            self.inputs.extend(windows)
            self.labels.extend(labels)

    def extract_context(self, seq):
        return seq[:3]  # assuming M:/L:/K: are always at the beginning

    def generate_windows(self, seq, context):
        windows = []
        labels = []
        for i in range(len(seq) - self.window_size):
            input_window = seq[i:i+self.window_size]
            target_token = seq[i+self.window_size]
            full_input = context + input_window
            windows.append(torch.tensor(full_input, dtype=torch.long))
            labels.append(torch.tensor(target_token, dtype=torch.long))
        return windows, labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

WINDOW_SIZE = 16
dataset = NottinghamDataset(indexed_melodies, vocab, inv_vocab, WINDOW_SIZE)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

In [9]:
print(len(dataset))

141585


In [41]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_ds, batch_size=1024, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1024)
test_loader = DataLoader(test_ds, batch_size=1024)

## Training the model

We’ll use a simple feedforward neural network:
- Inputs are token IDs (integers), so we first use an embedding layer to convert each token into a vector.
- Then we flatten the embedded input (which is a 2D sequence) into a 1D vector.
- Then feed it through one fully connected layers.
- Final output is a vector of length equal to vocab_size, representing logits for each possible next token.

In [42]:
import torch.nn as nn

class MelodyModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear((WINDOW_SIZE + 3) * embed_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)  # output logits for all tokens

    def forward(self, x):
        # x: (batch_size, seq_len)
        x = self.embedding(x)  # -> (batch_size, seq_len, embed_dim)
        x = x.view(x.size(0), -1)  # flatten: (batch_size, seq_len * embed_dim)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)  # logits (no softmax here since crossentropy expects logits)
        return x

In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MelodyModel(vocab_size=len(vocab), embed_dim=128, hidden_dim=512, dropout=0.5)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()

In [44]:
def train_one_epoch(model, dataloader, loss_fn, optimizer):
    model.train()
    total_loss = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [45]:
def evaluate(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss = loss_fn(logits, targets)
            total_loss += loss.item()

    return total_loss / len(dataloader)

In [46]:
epochs = 100
for epoch in range(epochs):
    train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
    val_loss = evaluate(model, val_loader, loss_fn)
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

Epoch 1: Train Loss = 3.1084, Val Loss = 2.5784
Epoch 2: Train Loss = 2.5194, Val Loss = 2.3888
Epoch 3: Train Loss = 2.3320, Val Loss = 2.2881
Epoch 4: Train Loss = 2.2109, Val Loss = 2.2246
Epoch 5: Train Loss = 2.1210, Val Loss = 2.1710
Epoch 6: Train Loss = 2.0478, Val Loss = 2.1369
Epoch 7: Train Loss = 1.9913, Val Loss = 2.1034
Epoch 8: Train Loss = 1.9373, Val Loss = 2.0779
Epoch 9: Train Loss = 1.9032, Val Loss = 2.0613
Epoch 10: Train Loss = 1.8659, Val Loss = 2.0486
Epoch 11: Train Loss = 1.8344, Val Loss = 2.0355
Epoch 12: Train Loss = 1.8151, Val Loss = 2.0198
Epoch 13: Train Loss = 1.7920, Val Loss = 2.0115
Epoch 14: Train Loss = 1.7763, Val Loss = 1.9970
Epoch 15: Train Loss = 1.7609, Val Loss = 1.9908
Epoch 16: Train Loss = 1.7499, Val Loss = 1.9892
Epoch 17: Train Loss = 1.7341, Val Loss = 1.9806
Epoch 18: Train Loss = 1.7276, Val Loss = 1.9737
Epoch 19: Train Loss = 1.7179, Val Loss = 1.9700
Epoch 20: Train Loss = 1.7125, Val Loss = 1.9657
Epoch 21: Train Loss = 1.7056

In [47]:
test_loss = evaluate(model, test_loader, loss_fn)
print(f"Test Loss: {test_loss:.4f}")

Test Loss: 1.8893


In [51]:
def generate_music_sample(model, loader, vocab, inv_vocab, num_tokens=100):
    model.eval()
    with torch.no_grad():
        for batch_x, _ in loader:
            batch_x = batch_x.to(device)
            break

        input_seq = batch_x[0].unsqueeze(0)  # shape: [1, seq_len]
        generated = input_seq.clone()

        context_tokens = input_seq[:, :3]  # M:, L:, K:
        rolling_window = input_seq[:, 3:]  # actual melody

        for _ in range(num_tokens):
            rolling_window = generated[:, -WINDOW_SIZE:]  # keep last notes only
            input_window = torch.cat((context_tokens, rolling_window), dim=1)
            
            output = model(input_window)
            next_token = torch.argmax(output, dim=-1)
            next_token = next_token.unsqueeze(1)
            generated = torch.cat((generated, next_token), dim=1)

        generated_tokens = [inv_vocab[idx.item()] for idx in generated[0]]
        abc_sequence = ''.join(generated_tokens)

    return abc_sequence

In [55]:
abc = generate_music_sample(model, test_loader, vocab, inv_vocab)
print(abc)

M:4/4L:1/4K:DD/2|F/2AF/2A3/2A/2|"G"B/2A/2B/2c/2"D"d/2B/2A/2G/2|"Em"FE"A7"EF/2G/2|"D"A/2B/2A/2F/2Ad|"D"Add/2e/2f/2g/2|"D"afdf|"A7"edcB|"D"AFAd|"A7"cAAA|"D"FA"A7"AF|"D"FAAA|"D"dAFA|"D"dfdf|"Em"efed|"A7"cBAG|"D"FA"A7"AF|"D"FAdA
