In [1]:
import random

import numpy as np
from torch.utils.data import DataLoader
import torch
from src.model.model import MusicTransformer
from pathlib import Path

from src.utils.data.guitar_dataset import GuitarDataset
from src.utils.data.random_guitar_seq_dataset import RandomGuitarSeqDataset
from src.utils.hyperparameters import BATCH_SIZE, BLOCK_SIZE, EMBEDDING_DIM, N_LAYER, N_HEAD, DROPOUT, VOCAB_SIZE, \
    LEARNING_RATE

In [2]:
# Folder holding the tokenized *.json files, relative to the notebook's working directory.
tokenized_folder = Path("..") / "data" / "tokenized"

In [3]:
# Hyper-parameters: lower-case aliases of the shared project constants.

# -- batching / context window --
batch_size, block_size = BATCH_SIZE, BLOCK_SIZE

# -- model shape --
vocab_size, n_embd = VOCAB_SIZE, EMBEDDING_DIM
n_layer, n_head = N_LAYER, N_HEAD
dropout = DROPOUT

# -- optimisation / data split --
learning_rate = LEARNING_RATE
training_split = 0.8  # fraction of the tokenized files assigned to the training set

In [4]:
# Select the compute device in priority order: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [5]:
# AMP setup, works only on CUDA.
# `device` is a torch.device object; it does not compare equal to the plain
# string "cuda", so the check must look at the device *type*. With the string
# comparison AMP stayed disabled even when training on a CUDA GPU.
use_amp = (device.type == "cuda")

# Prefer bfloat16 when the GPU supports it; otherwise fall back to float16.
if use_amp and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16

# Gradient scaler for mixed-precision training; disabled whenever AMP is off.
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

In [6]:
# prepare datasets
all_files = sorted(tokenized_folder.glob("*.json"))
if not all_files:
    # An empty glob usually means the relative path does not match the current
    # working directory; fail here with a clear message.
    raise FileNotFoundError(f"No tokenized *.json files found in {tokenized_folder.resolve()}")

# Shuffle with a dedicated, seeded RNG so the train/validation file split is
# identical on every kernel restart (sorted() above fixes the starting order).
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(all_files)

# First `training_split` fraction of the files goes to training, the rest to validation.
split = int(training_split * len(all_files))
# train_ds = GuitarDataset(block_size=block_size, stride=block_size // 2, file_list=all_files[:split])
# val_ds = GuitarDataset(block_size=block_size, stride=block_size // 2, file_list=all_files[split:])
train_ds = RandomGuitarSeqDataset(block_size=block_size, epoch_len=2000, file_list=all_files[:split])
val_ds = RandomGuitarSeqDataset(block_size=block_size,  epoch_len=400, file_list=all_files[split:])

In [7]:
# Batch loaders for training and validation. The batch size comes from the
# `batch_size` hyper-parameter instead of a separate hard-coded literal, so the
# configuration cell is the single place to tune it.
# drop_last=True discards a final incomplete batch.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=True, drop_last=True)

In [8]:
# NOTE(review): disabled token-count printout. Presumably the dataset class in
# use no longer exposes a `.tokens` attribute — TODO confirm, then restore or delete.
# print("Training tokens count", train_dl.dataset.tokens)
# print("Validation tokens count", val_dl.dataset.tokens)

In [9]:
# Build the transformer from the hyper-parameters and move it to the selected device.
model_kwargs = dict(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout,
)
model = MusicTransformer(**model_kwargs).to(device)

In [10]:
# AdamW optimiser over all model parameters.
adamw_settings = {
    "lr": learning_rate,
    "betas": (0.9, 0.95),
    "weight_decay": 0.1,
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

In [11]:
# Summarise the run: total parameter count, target device, and whether AMP is on.
n_params = sum(p.numel() for p in model.parameters())
print("Parameter count: ", n_params)
print("Training on ", device)
print("Using amp?", use_amp)

Parameter count:  7293248
Training on  mps
Using amp? False


In [12]:
# from src.model.model import Head
# import numpy as np
# x_test, y_test = train_ds.__getitem__(56)
# x_test, y_test = x_test.view(1,-1).to(device), y_test.view(1,-1).to(device)
#
# Reference point: cross-entropy (in nats) of a uniform guess over the
# vocabulary. Derived from `vocab_size` rather than a hard-coded 8000 so it
# stays correct if the vocabulary size changes.
print(-np.log(1 / vocab_size))
#
# logits, loss = model(x_test, y_test)
# loss

8.987196820661973


In [13]:
epochs = 40
V = vocab_size
lnV = np.log(V)  # loss (in nats) of a uniform guess over the vocabulary

for epoch in range(epochs):
    # ---- train -----
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)

        # Autocast is driven by the same `use_amp` switch as the GradScaler
        # (not by torch.cuda.is_available()), so the forward pass is never run
        # in reduced precision without loss scaling, and it uses the dtype
        # selected in the AMP setup cell.
        with torch.amp.autocast('cuda', dtype=amp_dtype, enabled=use_amp):
            logits, loss = model(x, y)

        if use_amp:
            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so the 1.0 clipping threshold applies to the
            # true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    # ----- validate -----
    model.eval()
    val_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)
            # Assumes the model returns the mean per-token loss for the batch
            # (forward() is not visible here), so weight it by the number of
            # target tokens before accumulating.
            _, batch_loss = model(x, y)
            num_tokens = y.numel()
            val_loss += batch_loss.item() * num_tokens
            total_tokens += num_tokens

    # Token-weighted mean validation loss and derived metrics.
    avg_loss = val_loss / total_tokens
    ppl = np.exp(avg_loss)             # perplexity
    bpc = avg_loss / np.log(2)         # nats -> bits (per token)
    improv_ratio = V / ppl             # > 1 means better than a uniform guess
    delta_nats = lnV - avg_loss        # > 0 means better than a uniform guess

    # NOTE: "train" is the loss of the last training batch only, not an epoch average.
    print(
        f"epoch {epoch:03d} "
        f"train {loss.item():.4f} "
        f"val_loss {avg_loss:.4f}  ppl {ppl:.0f}  "
        f"bpc {bpc:.3f}  Δnats {delta_nats:.3f}  x-better {improv_ratio:.2f}x  (lnV {lnV:.3f})"
    )

epoch 000 train 7.9561 val_loss 9.1026  ppl 8978  bpc 13.132  Δnats -0.115  x-better 0.89x  (lnV 8.987)
epoch 001 train 7.6018 val_loss 9.1376  ppl 9298  bpc 13.183  Δnats -0.150  x-better 0.86x  (lnV 8.987)
epoch 002 train 6.8865 val_loss 9.1855  ppl 9754  bpc 13.252  Δnats -0.198  x-better 0.82x  (lnV 8.987)
epoch 003 train 6.1408 val_loss 9.1731  ppl 9634  bpc 13.234  Δnats -0.186  x-better 0.83x  (lnV 8.987)
epoch 004 train 5.4834 val_loss 9.2285  ppl 10184  bpc 13.314  Δnats -0.241  x-better 0.79x  (lnV 8.987)
epoch 005 train 5.0255 val_loss 9.3243  ppl 11207  bpc 13.452  Δnats -0.337  x-better 0.71x  (lnV 8.987)
epoch 006 train 4.5168 val_loss 9.4745  ppl 13024  bpc 13.669  Δnats -0.487  x-better 0.61x  (lnV 8.987)
epoch 007 train 3.7018 val_loss 9.3794  ppl 11842  bpc 13.532  Δnats -0.392  x-better 0.68x  (lnV 8.987)
epoch 008 train 3.5143 val_loss 9.6543  ppl 15588  bpc 13.928  Δnats -0.667  x-better 0.51x  (lnV 8.987)
epoch 009 train 3.1369 val_loss 9.6365  ppl 15314  bpc 13.9

In [14]:
# Sample a continuation from the model, prompted with the first tokens of one
# validation sequence.
prompt_len = 50        # number of prompt tokens taken from the validation sample
max_new_tokens = 100   # number of tokens to generate after the prompt

x, y = val_dl.dataset[5]
temp = x[:prompt_len]

# Set eval mode explicitly instead of relying on the mode left over from the
# training cell, and skip autograd bookkeeping while sampling.
model.eval()
with torch.no_grad():
    out = model.generate(temp.view(1, -1).to(device), max_new_tokens=max_new_tokens).cpu()
out

tensor([[ 136,  547,   54, 1106, 4609,  462,   54,  837, 5595,  456, 2380, 6773,
         1698,  518,  515, 6724,   68, 2852,   54, 1308, 4827,   65,  778, 4827,
         5117,  453, 2898, 7404,   66,  907,  456,  601,  130, 4454,  818, 7388,
           66,  100, 6805, 3624, 2821,  500, 5304, 3180, 1352, 7759,   47, 1352,
           66, 7894,  658,  456, 5075, 1640,  463, 7056, 3498, 6155, 1644,  515,
         5048,  626,  459,   64, 1916, 7857, 2390,  649,  641,   59, 4541,   67,
          598,  460, 3726,  178, 5831,  176, 1072,  988,  189,  465,   62,  951,
         4315, 6281,   43, 4099,  950,  513,   50,  504,  493,   58, 2752, 7246,
          520, 6768,   41, 1645,  503,   61, 2040,   60, 1271,   62, 5794, 1086,
           41, 2046, 2057,   55,  918,  501, 4058,  509,   62, 4195, 2203,   46,
          936,   57, 7006,  691,   65,  478,   50, 1039,  501,   60, 1311,   55,
         4708,  493,   60, 1109, 1689,  520,   60, 1013, 5626, 2854, 4160, 6634,
         3526,  493, 4064,  