In [1]:
from typing import AbstractSet, Optional, Set, Union

import torch
from torch import nn, Tensor

from tokenizer import ABCTokenizer


# Syntax-critical ABC tokens (barlines and repeat marks) that `sample` takes
# greedily by default when the model is confident about them. A frozenset so
# the shared default can never be mutated between calls.
_DEFAULT_FORCE_TOKENS = frozenset({"|", "|:", ":|", "::", "||"})


def _pick_forced(probs: Tensor, forced_ids: list, threshold: float) -> Optional[int]:
    """Return the most probable id in `forced_ids` with probability >= `threshold`, else None.

    Taking the maximum (instead of the first hit while iterating a set) keeps
    the choice deterministic when `threshold` <= 0.5 lets several tokens qualify.
    """
    scored = [(float(probs[tid]), tid) for tid in forced_ids]
    eligible = [pair for pair in scored if pair[0] >= threshold]
    return max(eligible)[1] if eligible else None


def _nucleus_sample(probs: Tensor, top_p: float, generator: Optional[torch.Generator] = None) -> int:
    """Draw one token id from the smallest highest-probability prefix whose mass exceeds `top_p`."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=0)
    over = (cumsum > top_p).nonzero(as_tuple=False)
    # Keep tokens up to and including the first one that pushes the cumulative
    # mass past top_p; if none does (e.g. top_p >= 1), keep the whole vocabulary.
    last = int(over[0, 0]) + 1 if over.numel() else sorted_idx.numel()
    pool_probs = sorted_probs[:last]
    pool_probs = pool_probs / pool_probs.sum()  # renormalise the truncated distribution
    choice = int(torch.multinomial(pool_probs, 1, generator=generator))
    return int(sorted_idx[choice])


def sample(
    model: nn.Module,
    tokenizer: "ABCTokenizer",
    *,
    max_len: int = 1024,
    temperature: float = 1.0,
    top_p: float = 0.9,
    force_tokens: Optional[AbstractSet[str]] = _DEFAULT_FORCE_TOKENS,
    force_threshold: float = 0.55,
    device: Union[str, torch.device] = "cpu",
    prime_str: str = "L:1/16\nM:4/4\nK:Amin\n",
    stop_after: int = 800,
    generator: Optional[torch.Generator] = None,
) -> str:
    """
    Generate an ABC tune with nucleus sampling while forcing arg-max
    for certain syntax-critical tokens (e.g. barlines) whenever the model
    already assigns them >= `force_threshold` probability.

    The model is called as ``model(x, h, return_hidden=True)`` and must return
    ``(logits, h)``. The next-token distribution is read from ``logits[0, -1]``.

    Parameters
    ----------
    model : nn.Module
        Recurrent language model. It is switched to eval mode, and no
        gradients are tracked during generation.
    tokenizer : ABCTokenizer
        Must provide ``encode``, ``decode``, a ``stoi`` mapping with ``.get``
        and an indexable ``itos``.
    max_len : int
        Maximum number of new tokens to generate.
    temperature : float
        Logits are divided by this before the softmax. Must be > 0.
    top_p : float
        Nucleus probability mass, used when no token is forced.
    force_tokens : set[str] or None
        Strings that should be taken greedily when the model is confident
        enough. Strings missing from the vocabulary are ignored. None
        disables forcing.
    force_threshold : float
        Probability at or above which we override randomness and pick the
        token. If several forced tokens qualify, the most probable one wins.
    device : str or torch.device
        Device on which input tensors are created.
    prime_str : str
        Text fed to the model before generation. It is included in the
        returned string and must encode to at least one token.
    stop_after : int
        Once more than this many new tokens have been generated, generation
        stops at the next newline token.
    generator : torch.Generator, optional
        RNG for the nucleus draw, for reproducible samples. It should live on
        the same device as the model's output.

    Returns
    -------
    str
        ``tokenizer.decode`` of the prime ids followed by the generated ids.

    Raises
    ------
    ValueError
        If ``temperature <= 0`` or ``prime_str`` encodes to no tokens.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    ids = tokenizer.encode(prime_str)
    if not ids:
        raise ValueError("prime_str must encode to at least one token")

    # Resolve forced strings to vocabulary ids once, outside the sampling loop.
    forced_ids = sorted(
        {tid for tid in (tokenizer.stoi.get(tok) for tok in (force_tokens or ())) if tid is not None}
    )

    model.eval()
    generated = list(ids)
    # Without no_grad the hidden state carried across steps would keep the
    # whole autograd graph alive, growing memory with every generated token.
    with torch.no_grad():
        # Warm up on every prime token except the last; the loop's first step
        # feeds ids[-1], so each prime token is consumed exactly once.
        h = None
        if len(ids) > 1:
            prefix = torch.tensor(ids[:-1], dtype=torch.long, device=device).unsqueeze(0)
            _, h = model(prefix, None, return_hidden=True)

        cur_id = ids[-1]
        for _ in range(max_len):
            x = torch.tensor([[cur_id]], dtype=torch.long, device=device)
            logits, h = model(x, h, return_hidden=True)
            probs = torch.softmax(logits[0, -1] / temperature, dim=-1)

            # 1. FORCE-ARGMAX for confident syntax tokens, else 2. NUCLEUS SAMPLING.
            forced = _pick_forced(probs, forced_ids, force_threshold)
            cur_id = forced if forced is not None else _nucleus_sample(probs, top_p, generator)
            generated.append(cur_id)

            # Once enough new tokens exist, stop at the next newline token.
            if tokenizer.itos[cur_id] == "\n" and len(generated) - len(ids) > stop_after:
                break

    return tokenizer.decode(generated)

In [60]:
import torch, json
from torch.utils.data import random_split
from dataset   import LeadSheetDataset, collate_fn

# Inputs and split configuration for this experiment.
ABC_PATH = "leadsheets.abc"
MAPS_PATH = "cache/abc_maps.json"
TRAIN_FRAC, VAL_FRAC = 0.8, 0.1  # the test split gets the remainder
SPLIT_SEED = 42

# Read the raw ABC corpus. Use explicit UTF-8, like the maps file below, so
# decoding does not depend on the platform's default locale encoding.
with open(ABC_PATH, "r", encoding="utf-8") as f:
    raw_data = f.read()

with open(MAPS_PATH, "r", encoding="utf-8") as f:
    maps = json.load(f)

chord_map = maps["chord_map"]
header_map = maps["header_map"]
inline_map = maps["inline_map"]

# Tunes are separated by blank lines. NOTE(review): the vocabulary built here
# presumably has to match the one the saved checkpoints were trained with
# (their embedding size depends on vocab_size), so keep this split unchanged.
tunes = raw_data.strip().split("\n\n")
tokenizer = ABCTokenizer(chord_map=chord_map, header_map=header_map, inline_hdr_map=inline_map)
tokenizer.build_vocab(tunes)
vocab_size = tokenizer.vocab_size()

dataset = LeadSheetDataset(tunes, tokenizer)
train_size = int(TRAIN_FRAC * len(dataset))
val_size = int(VAL_FRAC * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible split
)

pad_idx = tokenizer.stoi[tokenizer.pad_token]


def collate(batch):
    """Collate a batch with `collate_fn`, padding with the tokenizer's pad id.

    A named function rather than an assigned lambda: it reads better and is
    picklable by reference, which a lambda is not.
    """
    return collate_fn(batch, pad_idx)

In [3]:
from models import LSTMModel, GRUModel

# Load the model architectures. Hyper-parameters must match those used when the
# checkpoints were saved; load_state_dict raises on any shape mismatch.
DEVICE = "cpu"
lstm_model = LSTMModel(embed_dim=512, hidden_dim=1024, vocab_size=vocab_size, num_layers=2, dropout=0.3).to(DEVICE)
gru_model = GRUModel(embed_dim=512, hidden_dim=1024, num_layers=2, vocab_size=vocab_size, dropout=0.2).to(DEVICE)


def unwrap_state_dict(checkpoint):
    """Return the model weights stored in `checkpoint`.

    The LSTM checkpoint nests its weights under "model_state_dict" while the
    GRU checkpoint is a bare state dict; accept either layout.
    """
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def count_parameters(model: nn.Module) -> int:
    """Total number of elements across all of `model`'s parameters."""
    return sum(p.numel() for p in model.parameters())


# Load the checkpoint files (weights_only=True restricts what the unpickler may load).
checkpoint_lstm = torch.load("saved_models/LSTM_model.pt", map_location=DEVICE, weights_only=True)
checkpoint_gru = torch.load("saved_models/GRU_model.pt", map_location=DEVICE, weights_only=True)

# Load the state dicts
lstm_model.load_state_dict(unwrap_state_dict(checkpoint_lstm))
gru_model.load_state_dict(unwrap_state_dict(checkpoint_gru))

lstm_total_params = count_parameters(lstm_model)
gru_total_params = count_parameters(gru_model)
print(lstm_model)
print(f"Number of parameters: {lstm_total_params:,}")
print()
print(gru_model)
print(f"Number of parameters: {gru_total_params:,}")

LSTMModel(
  (embed): Embedding(215, 512, padding_idx=0)
  (lstm): LSTM(512, 1024, num_layers=2, batch_first=True, dropout=0.3)
  (fc_out): Linear(in_features=1024, out_features=215, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
Number of parameters: 15,026,903

GRUModel(
  (embed): Embedding(215, 512, padding_idx=0)
  (gru): GRU(512, 1024, num_layers=2, batch_first=True, dropout=0.2)
  (fc_out): Linear(in_features=1024, out_features=215, bias=True)
)
Numbers of parameters: 11,352,791


In [59]:
# Final sample GRU. Seed torch's global RNG (used by the nucleus draw) so the
# printed tune is reproducible on re-run.
torch.manual_seed(42)
print("\n=== GRU Sample ===")
print(sample(gru_model, tokenizer, device="cpu", prime_str="L:1/4\n", temperature=0.9, top_p=0.95))


=== GRU Sample ===
L:1/4
K:C
[K:G] D/F/ |"G" G B c/B/ |"D" A2 D |"G" B G B |"D" c A A |"G" B G G |"Em" B2 c/B/ |"Am" A c B |"D" A2 D | 
"G" G B c |"G" d2 B |"Am" A c B |"D" A2 d |"Am" c B A |"D" d c B |"G" G2 B |"D" A G E | 
"Am" A2 E/G/ |"D" F D d |"G" B G G |"D" A D d |"G" B G G |"C" c2 B |"Am" c2 E |"D7" D2 G |"G" G3 |] 
 B/c/ |"G" d B/c/ d |"C" e c/e/ g |"C" c e g |"D" f e d |"Am" e c/B/ A |"C" e d/c/ B |"Am" A3 |"D" A3 | 
"D" d e f |"G" g d B |"Am" e d c |"Em" B/A/ G E |"Am" A G E |"D" D3 |"G" D G B |"Am" c A B | 
"Am" c B/A/ G/F/ |"Em" E3 |]"C" e f g |"G" d c/B/ A/G/ |"Am" A/B/ c/B/ c/d/ |"D7" e d c | 
"G" B A G |"Em" E3 |"Am" E A G |"D" F D d |"G" d B/c/ d/c/ |"Em" B G G |"Am" c e e |"D" d3 | 
"G" B G B |"Am" c A c |"D" d2 B/c/ |"G" d3 |"G" B G B |"G" d c B |"Am" A2 G/E/ |"D7" D E F |"G" G3 | 
"G" G3 |][M:3/4][K:D]"D" A d e | f d f |"A" e c A | e3 | A c e | A3 | 
 A A c | e c A | a c A |"G" B3 |"A" e f e | A c e |"G" g b g |"D" f a f |"A7" e2 g |"D" f3 |] 
 |:"D" a2 a | a b a |

In [5]:
# Final sample LSTM. Seed torch's global RNG (used by the nucleus draw) so the
# printed tune is reproducible on re-run.
torch.manual_seed(42)
print("\n=== LSTM Sample ===")
print(sample(lstm_model, tokenizer, device="cpu", prime_str="L:1/8\nM:3/4\nK:Amin\n", temperature=1.0, top_p=0.9))


=== LSTM Sample ===
L:1/8
M:3/4
K:Amin
 E |"Am" A2 A B3 | c2 B2 A2 |"E" e4 e2 | e4 e2 | e2 f3 e | g2 f2 e2 | e6 | e4 E2 | 
"Am" A2 A3 B | c3 d e2 | d2 c3 B | A2 c3 A |"E" B3 A ^G2 |"Am" A3 c e2 | e2 a3 b | c'2 ba gf | 
"Am" e2 d2 c2 | e4 c'2 | e'3 d' c'2 | a2 ^ga ba | c'2 b3 a | g2 e3 c |"Dm" B2 A2 d2 | 
"E" e4 eB ||"Am" e3 d c2 | A4 EE | c4 Ac |"E" B4 E2 | G4 AB |"Am" c4 c2 | d2 e3 ^f | g6 | 
"E" a4 g2 | f2 e2 d2 |"Am" c6- | c2 d2 e2 |"Dm" f6 | f2 e2 d2 |"Am" e4 c2 | A4 |] E2 | 
"Am" E2 A2 ^G2 | A2 c2 c2 |"E" B3 E cB |"Am" A4 E2 | A3 B c2 |"Dm" d4 c2 |"E" B6- | B4 ee | 
"E" e2 ^d2 e2 |"Am" A4 c2 |"Dm" B3 A ^GB |"E" A2 ^G2 E2 |"Am" A6- | A2 z2 |] E2 |"Am" E4 A2 | 
 c4 e2 |"E" B4 E2 | ^G4 E2 |"Am" A6- | A2 z2 ee |"Dm" d3 c B2 | A4 B2 |"Am" c6- | c2 c3 B | 
"Dm" A2 f2 d2 | B3 A B2 |"Am" c3 d c2 |"E" B4 E2 |"Am" A6 | A4 || c2 |"Am" e3 d c2 | A4 A2 | c4 e2 | 
"E" B4 e2 |"Am" c3 A E2 | c4 B2 |"Dm" A3 B c2 | d2 ^c2 d2 |"E" e4- eB |"Am" A4 || AB | 

