In [30]:
import miditok
import miditoolkit
import os
import json

melody_folder = 'MIDI/melody'
token_folder = 'nottingham_tokens'
os.makedirs(token_folder, exist_ok=True)

tokenizer = miditok.REMI()

def get_token_ids(tokens):
    # Case 1: tokens is a TokSequence
    if hasattr(tokens, "ids"):
        return tokens.ids
    # Case 2: tokens is a list of TokSequence
    if isinstance(tokens, list) and len(tokens) > 0 and hasattr(tokens[0], "ids"):
        return [t.ids for t in tokens]
    # Case 3: tokens is a plain list (already token IDs)
    return tokens

jig_tempos = []

for fname in os.listdir(melody_folder):
    if fname.endswith('.mid') and fname.lower().startswith("jigs"):
        midi = miditoolkit.MidiFile(os.path.join(melody_folder, fname))
        tokens = tokenizer(midi)
        token_ids = get_token_ids(tokens)
        # ---- Store tempo for average calculation ----
        if midi.tempo_changes:
            jig_tempos.append(midi.tempo_changes[0].tempo)
        else:
            jig_tempos.append(120)  # Fallback/default if not set
        # ---- Save tokenized output as before ----
        with open(os.path.join(token_folder, fname.replace('.mid', '.json')), "w") as fp:
            json.dump({'ids': token_ids}, fp)

tokenizer.save_params(token_folder)

# ---- Calculate average BPM for jigs ----
avg_bpm = sum(jig_tempos) / len(jig_tempos) if jig_tempos else 120
avg_us_per_quarter = 60_000_000 / avg_bpm
print(f"Jig average BPM: {avg_bpm:.2f}, us/quarter: {avg_us_per_quarter:.2f}")



from torch.utils.data import Dataset, DataLoader
import torch

class MIDITokenDataset(Dataset):
    def __init__(self, token_folder):
        self.files = [
            os.path.join(token_folder, f)
            for f in os.listdir(token_folder)
            if f.endswith(".json") and f != "tokenizer.json" and f.lower().startswith("jigs")
        ]
        param_path = os.path.join(token_folder, "tokenizer.json")
        self.tokenizer = miditok.REMI(params=param_path)
        self.sequences = []
        for file in self.files:
            with open(file, "r") as fp:
                tok_seq = json.load(fp)
            if 'ids' in tok_seq:
                self.sequences.append(tok_seq['ids'])
            else:
                print(f"Skipping {file} (no 'ids' key)")
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        x = torch.tensor(seq[:-1], dtype=torch.long)
        y = torch.tensor(seq[1:], dtype=torch.long)
        return x, y



dataset = MIDITokenDataset(token_folder)
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)


import torch.nn as nn

def get_positional_encoding(seq_len, d_model, device):
    position = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
    pe = torch.zeros(seq_len, d_model, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

class MusicTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=1024,
        dropout=0.1,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        # x: (batch, seq)
        batch_size, seq_len = x.shape
        pos_enc = get_positional_encoding(seq_len, self.embed.embedding_dim, x.device).unsqueeze(0)
        x = self.embed(x) + pos_enc
        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)
        return self.fc_out(x)



vocab_size = len(dataset.tokenizer.vocab)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MusicTransformer(vocab_size).to(device)


import torch.optim as optim
from tqdm import tqdm

optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for x, y in tqdm(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out.view(-1, out.size(-1)), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader):.4f}")


def generate_until_duration(model, tokenizer, device, target_seconds=120, temperature=1.0, max_tokens=2000):
    model.eval()
    start_token = list(tokenizer.vocab.values())[0]
    seq = [start_token]
    x = torch.tensor([seq], dtype=torch.long).to(device)
    elapsed_seconds = 0

    TPQ = 480
    tempo_us_per_quarter = 500_000  # 120 BPM by default

    while elapsed_seconds < target_seconds and len(seq) < max_tokens:
        with torch.no_grad():
            logits = model(x)
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
        seq.append(next_token)
        x = torch.tensor([seq], dtype=torch.long).to(device)

        if len(seq) % 32 == 0 or elapsed_seconds == 0:
            midi_obj = tokenizer.tokens_to_midi([seq])
            ticks = midi_obj.end()
            elapsed_seconds = ticks / TPQ * (tempo_us_per_quarter / 1_000_000)

    return seq

def generate_until_duration_smart(
    model, tokenizer, device, target_seconds=25, temperature=1.0, max_tokens=3000,
    tpq=480, tempo_us_per_quarter=500_000, max_tokens_per_beat=4
):
    model.eval()
    start_token = list(tokenizer.vocab.values())[0]
    seq = [start_token]
    x = torch.tensor([seq], dtype=torch.long).to(device)
    elapsed_seconds = 0
    tokens_in_current_beat = 0

    # Get time-shift token IDs for the tokenizer
    time_shift_token_ids = [
        tid for t, tid in tokenizer.vocab.items() if t.startswith("TimeShift")
    ]
    # (For REMI, "TimeShift_XX" are the time advance tokens.)

    while elapsed_seconds < target_seconds and len(seq) < max_tokens:
        with torch.no_grad():
            logits = model(x)
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        # Check if the next token is a time-shift
        if next_token in time_shift_token_ids:
            tokens_in_current_beat = 0
        else:
            tokens_in_current_beat += 1

        # If tokens in this beat exceed limit, force a time-shift
        if tokens_in_current_beat > max_tokens_per_beat:
            time_shift_token_ids = [
                tid for t, tid in tokenizer.vocab.items() if "shift" in t.lower()
            ]
            tokens_in_current_beat = 0

        seq.append(next_token)
        x = torch.tensor([seq], dtype=torch.long).to(device)

        # Periodically check duration
        if len(seq) % 32 == 0 or elapsed_seconds == 0:
            midi_obj = tokenizer.tokens_to_midi([seq])
            ticks = midi_obj.end()
            elapsed_seconds = ticks / tpq * (tempo_us_per_quarter / 1_000_000)

    return seq


# tokens = generate_until_duration(
#     model,
#     dataset.tokenizer,
#     device=device,
#     target_seconds=27,
#     temperature=1.0,
#     max_tokens=3000  # Safety upper limit for shorter jigs
# )

tokens = generate_until_duration_smart(
    model,
    dataset.tokenizer,
    device=device,
    target_seconds=25,  # e.g., 25 seconds
    temperature=1.0,
    max_tokens=3000,
    tpq=480,
    tempo_us_per_quarter=avg_us_per_quarter,
    max_tokens_per_beat=4
)

midi = dataset.tokenizer.tokens_to_midi([tokens])
midi.dump_midi("generated_jig.mid")



  tokens = tokenizer(midi)
  tokenizer.save_params(token_folder)


Jig average BPM: 120.00, us/quarter: 500000.00


100%|██████████| 42/42 [00:00<00:00, 44.06it/s]


Epoch 1, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 43.41it/s]


Epoch 2, Loss: nan


100%|██████████| 42/42 [00:01<00:00, 41.72it/s]


Epoch 3, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 42.39it/s]


Epoch 4, Loss: nan


100%|██████████| 42/42 [00:01<00:00, 40.51it/s]


Epoch 5, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 43.74it/s]


Epoch 6, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 45.83it/s]


Epoch 7, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 46.52it/s]


Epoch 8, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 44.93it/s]


Epoch 9, Loss: nan


100%|██████████| 42/42 [00:00<00:00, 44.17it/s]
  midi_obj = tokenizer.tokens_to_midi([seq])


Epoch 10, Loss: nan


  midi = dataset.tokenizer.tokens_to_midi([tokens])


In [27]:
# To verify duration
from symusic.core import ScoreTick

midi_loaded = ScoreTick.from_file("generated_jig.mid")
duration_seconds = midi_loaded.end() / midi_loaded.ticks_per_quarter * 60 / 120
print(f"Duration (seconds): {duration_seconds:.2f}")

Duration (seconds): 26.31


In [47]:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
import miditok
import miditoolkit
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# === Tokenize Jigs MIDI ===
melody_folder = 'MIDI/melody'  # Update as needed
token_folder = 'nottingham_tokens'
os.makedirs(token_folder, exist_ok=True)
tokenizer = miditok.REMI()

def get_token_ids(tokens):
    if hasattr(tokens, "ids"):
        return tokens.ids
    if isinstance(tokens, list) and len(tokens) > 0 and hasattr(tokens[0], "ids"):
        return [t.ids for t in tokens]
    return tokens

jig_tempos = []
for fname in os.listdir(melody_folder):
    if fname.endswith('.mid') and fname.lower().startswith("jigs"):
        midi = miditoolkit.MidiFile(os.path.join(melody_folder, fname))
        tokens = tokenizer(midi)
        # tokens might be a TokSequence or a list of TokSequence!
        if isinstance(tokens, list):  # If it's a list, flatten all ids
            token_ids = []
            for t in tokens:
                token_ids.extend(t.ids)
        else:  # Single TokSequence
            token_ids = tokens.ids
        # ---- Store tempo for average calculation ----
        if midi.tempo_changes:
            jig_tempos.append(midi.tempo_changes[0].tempo)
        else:
            jig_tempos.append(120)
        # ---- Save tokenized output as before ----
        with open(os.path.join(token_folder, fname.replace('.mid', '.json')), "w") as fp:
            json.dump({'ids': token_ids}, fp)
tokenizer.save_params(token_folder)

avg_bpm = sum(jig_tempos) / len(jig_tempos) if jig_tempos else 120
avg_us_per_quarter = 60_000_000 / avg_bpm
print(f"Jig average BPM: {avg_bpm:.2f}, us/quarter: {avg_us_per_quarter:.2f}")

# === PyTorch Dataset with Padding ===
class MIDITokenDataset(Dataset):
    def __init__(self, token_folder, min_length=32):
        self.files = [
            os.path.join(token_folder, f)
            for f in os.listdir(token_folder)
            if f.endswith(".json") and f != "tokenizer.json" and f.lower().startswith("jigs")
        ]
        param_path = os.path.join(token_folder, "tokenizer.json")
        self.tokenizer = miditok.REMI(params=param_path)
        self.sequences = []
        skipped = []
        for file in self.files:
            with open(file, "r") as fp:
                tok_seq = json.load(fp)
            if 'ids' in tok_seq and len(tok_seq['ids']) > 1:
                self.sequences.append(tok_seq['ids'])
            else:
                skipped.append((file, len(tok_seq['ids']) if 'ids' in tok_seq else 'NO IDS'))
        print(f"Loaded {len(self.sequences)} sequences.")
        print(f"Skipped {len(skipped)} files: {skipped[:10]} ...")  # show first 10 skipped
        if len(self.sequences) == 0:
            raise ValueError("No sequences found. Try lowering min_length or check tokenized data.")
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        x = torch.tensor(seq[:-1], dtype=torch.long)
        y = torch.tensor(seq[1:], dtype=torch.long)
        return x, y

def collate_pad(batch):
    xs, ys = zip(*batch)
    lengths = [len(x) for x in xs]
    xs_padded = pad_sequence(xs, batch_first=True, padding_value=0)
    ys_padded = pad_sequence(ys, batch_first=True, padding_value=0)
    return xs_padded, ys_padded, torch.tensor(lengths, dtype=torch.long)

dataset = MIDITokenDataset(token_folder, min_length=1)
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_pad)

# === LSTM Model ===
class MusicLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    def forward(self, x, lengths, hidden=None):
        x = self.embed(x)
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, hidden = self.lstm(packed, hidden)
        out, _ = pad_packed_sequence(packed_out, batch_first=True)
        out = self.fc(out)
        return out, hidden

vocab_size = len(dataset.tokenizer.vocab)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MusicLSTM(vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# === Training Loop ===
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for x, y, lengths in tqdm(loader, desc=f"Epoch {epoch+1}"):
        x, y, lengths = x.to(device), y.to(device), lengths.to(device)
        optimizer.zero_grad()
        out, _ = model(x, lengths)
        out = out.view(-1, vocab_size)
        y = y.view(-1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader):.4f}")

# === Generation ===
def generate_until_duration_lstm(
    model, tokenizer, device, target_seconds=27, temperature=1.0, max_tokens=3000,
    tpq=480, tempo_us_per_quarter=500_000
):
    model.eval()
    start_token = list(tokenizer.vocab.values())[0]
    seq = [start_token]
    input_seq = torch.tensor([seq], dtype=torch.long).to(device)
    hidden = None
    elapsed_seconds = 0
    with torch.no_grad():
        while elapsed_seconds < target_seconds and len(seq) < max_tokens:
            out, hidden = model(input_seq, torch.tensor([len(seq)]).to(device), hidden)
            logits = out[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            seq.append(next_token)
            input_seq = torch.tensor([seq], dtype=torch.long).to(device)
            if len(seq) % 32 == 0 or elapsed_seconds == 0:
                midi_obj = tokenizer.tokens_to_midi([seq])
                ticks = midi_obj.end()
                elapsed_seconds = ticks / tpq * (tempo_us_per_quarter / 1_000_000)
    return seq

tokens = generate_until_duration_lstm(
    model,
    dataset.tokenizer,
    device,
    target_seconds=27,
    temperature=1.0,
    max_tokens=3000,
    tpq=480,
    tempo_us_per_quarter=avg_us_per_quarter
)

midi = dataset.tokenizer.tokens_to_midi([tokens])
midi.dump_midi("generated_jig_lstm.mid")
print("Saved generated_jig_lstm.mid")


  tokens = tokenizer(midi)
  tokenizer.save_params(token_folder)


Jig average BPM: 120.00, us/quarter: 500000.00
Loaded 340 sequences.
Skipped 0 files: [] ...


Epoch 1: 100%|██████████| 42/42 [24:22<00:00, 34.83s/it] 


Epoch 1, Loss: 2.4462


Epoch 2: 100%|██████████| 42/42 [20:25<00:00, 29.17s/it] 


Epoch 2, Loss: 1.4072


Epoch 3: 100%|██████████| 42/42 [18:10<00:00, 25.97s/it] 


Epoch 3, Loss: 1.3234


Epoch 4: 100%|██████████| 42/42 [18:08<00:00, 25.91s/it]  


Epoch 4, Loss: 1.1337


Epoch 5: 100%|██████████| 42/42 [19:18<00:00, 27.58s/it] 


Epoch 5, Loss: 0.9371


Epoch 6: 100%|██████████| 42/42 [18:16<00:00, 26.10s/it] 


Epoch 6, Loss: 0.8636


Epoch 7: 100%|██████████| 42/42 [18:42<00:00, 26.73s/it]  


Epoch 7, Loss: 0.8448


Epoch 8: 100%|██████████| 42/42 [21:07<00:00, 30.19s/it] 


Epoch 8, Loss: 0.8297


Epoch 9: 100%|██████████| 42/42 [23:48<00:00, 34.00s/it] 


Epoch 9, Loss: 0.8187


Epoch 10: 100%|██████████| 42/42 [23:53<00:00, 34.13s/it] 
  midi_obj = tokenizer.tokens_to_midi([seq])


Epoch 10, Loss: 0.8106
Saved generated_jig_lstm.mid


  midi = dataset.tokenizer.tokens_to_midi([tokens])


In [48]:
# To verify duration
from symusic.core import ScoreTick

midi_loaded = ScoreTick.from_file("generated_jig_lstm.mid")
duration_seconds = midi_loaded.end() / midi_loaded.ticks_per_quarter * 60 / 120
print(f"Duration (seconds): {duration_seconds:.2f}")

Duration (seconds): 240.50
