In [1]:
import random

import numpy as np
from torch.utils.data import DataLoader
import torch
from src.model.model import MusicTransformer
from pathlib import Path

from src.utils.data.guitar_dataset import GuitarDataset
from src.utils.data.random_guitar_seq_dataset import RandomGuitarSeqDataset
from src.utils.hyperparameters import BATCH_SIZE, BLOCK_SIZE, EMBEDDING_DIM, N_LAYER, N_HEAD, DROPOUT, VOCAB_SIZE, \
    LEARNING_RATE

In [2]:
data_root = Path("../data/")
midi_folder = data_root / "midi/"
augmented_folder = data_root / "augmented/"
tokenized_folder = data_root / "tokenized/"
splits_folder = data_root / "splits/"
train_tok_folder = tokenized_folder / "train-aug/"
val_tok_folder = tokenized_folder / "val/"
train_midi_folder = data_root / "train-midi/"
val_midi_folder = data_root / "val-midi/"

In [3]:
# hyper-parameters
batch_size = BATCH_SIZE

block_size = BLOCK_SIZE
n_embd = EMBEDDING_DIM
vocab_size = VOCAB_SIZE
n_layer = N_LAYER
n_head = N_HEAD
dropout = DROPOUT

learning_rate = LEARNING_RATE
training_split = 0.8

In [4]:
# setting the device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# AMP setup, works only on CUDA
use_amp = (device == "cuda")
if use_amp and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16

scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

In [6]:
# prepare datasets
train_files = sorted(train_tok_folder.glob("*.json"))
val_files = sorted(val_tok_folder.glob("*.json"))
random.shuffle(train_files)

train_ds = GuitarDataset(block_size=block_size, stride=block_size // 2, file_list=train_files)
val_ds = GuitarDataset(block_size=block_size, stride=block_size // 2, file_list=val_files)
# train_ds = RandomGuitarSeqDataset(block_size=block_size, epoch_len=2000, file_list=all_files[:split])
# val_ds = RandomGuitarSeqDataset(block_size=block_size,  epoch_len=400, file_list=all_files[split:])

In [7]:
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=True, drop_last=True)

In [8]:
print("Training tokens count", train_dl.dataset.tokens)
print("Validation tokens count", val_dl.dataset.tokens)

Training tokens count 574897
Validation tokens count 27587


In [9]:
model = MusicTransformer(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout
).to(device)

In [10]:
# using AdamW optimisation
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.1)

In [11]:
print("Parameter count: ", sum([p.numel() for p in model.parameters()]))
print("Training on ", device)
print("Using amp?", use_amp)

Parameter count:  4728248
Training on  mps
Using amp? False


In [12]:
# from src.model.model import Head
# import numpy as np
# x_test, y_test = train_ds.__getitem__(56)
# x_test, y_test = x_test.view(1,-1).to(device), y_test.view(1,-1).to(device)
#
print(-np.log(1/8000))
#
# logits, loss = model(x_test, y_test)
# loss

8.987196820661973


In [13]:
epochs = 20
V = vocab_size
lnV = np.log(V)

for epoch in range(epochs):
    # ---- train -----
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)

        # using torch.autocast here with device_type to avoid backend-specific contexts
        with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
            logits, loss = model(x, y)

        if use_amp:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    # ----- validate -----
    model.eval()
    val_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)
            _, l = model(x, y) # loss here is already mean loss per token for this batch
            # as we have already reshaped logits in forward when computing loss against targets
            num_tokens = y.numel()
            val_loss += l.item() * num_tokens
            total_tokens += num_tokens

    avg_loss = val_loss / total_tokens
    ppl = np.exp(avg_loss)
    bpc = avg_loss / np.log(2)
    improv_ratio = V / ppl
    delta_nats = lnV - avg_loss

    print(
        f"epoch {epoch:03d} "
        f"train {loss.item():.4f} "
        f"val_loss {avg_loss:.4f}  ppl {ppl:.0f}  "
        f"bpc {bpc:.3f}  Δnats {delta_nats:.3f}  x-better {improv_ratio:.2f}x  (lnV {lnV:.3f})"
    )

epoch 000 train 5.9336 val_loss 5.9081  ppl 368  bpc 8.524  Δnats 2.098  x-better 8.15x  (lnV 8.006)
epoch 001 train 4.2858 val_loss 4.7746  ppl 118  bpc 6.888  Δnats 3.232  x-better 25.33x  (lnV 8.006)
epoch 002 train 3.5772 val_loss 4.3915  ppl 81  bpc 6.336  Δnats 3.615  x-better 37.15x  (lnV 8.006)
epoch 003 train 3.2622 val_loss 4.2553  ppl 70  bpc 6.139  Δnats 3.751  x-better 42.57x  (lnV 8.006)
epoch 004 train 2.9729 val_loss 4.2022  ppl 67  bpc 6.062  Δnats 3.804  x-better 44.89x  (lnV 8.006)
epoch 005 train 2.6844 val_loss 4.2039  ppl 67  bpc 6.065  Δnats 3.802  x-better 44.81x  (lnV 8.006)
epoch 006 train 2.3972 val_loss 4.1577  ppl 64  bpc 5.998  Δnats 3.849  x-better 46.93x  (lnV 8.006)
epoch 007 train 2.2595 val_loss 4.2136  ppl 68  bpc 6.079  Δnats 3.793  x-better 44.38x  (lnV 8.006)
epoch 008 train 2.0058 val_loss 4.2436  ppl 70  bpc 6.122  Δnats 3.763  x-better 43.07x  (lnV 8.006)
epoch 009 train 1.9471 val_loss 4.2818  ppl 72  bpc 6.177  Δnats 3.725  x-better 41.45x  (

In [63]:
x, y = val_dl.dataset.__getitem__(15)
temp = x[:500]
out = model.generate(temp.view(1, -1).to(device), max_new_tokens=2500).cpu()
out

tensor([[529,  26, 518,  ..., 490, 995, 490]])

In [64]:
from miditok import REMI, TokenizerConfig

tokenizer = REMI(params=Path("../data/tokenized/config/tokenizer.json"))
print("Is trained", tokenizer.is_trained)
out_midi = tokenizer.decode(out[0])
temp_midi = tokenizer.decode(temp)

Is trained True


  super().__init__(tokenizer_config, params)


In [65]:
out_midi.dump_midi(data_root / "output" / "test_gen.mid")
temp_midi.dump_midi(data_root / "output" / "input_input.mid")