In [1]:
# Automatically reload edited project modules (trainer, model, utils, ...) before each cell runs.
%reload_ext autoreload
%autoreload 2
from trainer import Trainer, TrainerConfig
from mingpt_utils import set_seed
from model import GPT, GPTConfig
import torch
# NOTE(review): star import — it appears to supply names used later that are not imported
# anywhere else (e.g. np, TokenDatasetMidi, load_model, save_model); TODO import them explicitly.
# Its position matters: the imports after this line override any same-named exports from utils.
from utils import *

from torch.utils.tensorboard import SummaryWriter
from mingpt_utils import sample
# Release cached-but-unused GPU memory held by PyTorch's caching allocator
# (presumably left over from an earlier run in this kernel).
torch.cuda.empty_cache()

In [2]:
# Log the visible GPU count and the torch / cuDNN / CUDA build versions so the
# run's environment is recorded next to its results.
_env_report = [
    ("Available devices: ", torch.cuda.device_count()),
    ("torch version:", torch.__version__),
    ("cudnn version:", torch.backends.cudnn.version()),
    ("cuda version:", torch.version.cuda),
]
for _label, _value in _env_report:
    print(_label, _value)

Available devices:  2
torch version: 2.2.1+cu121
cudnn version: 8902
cuda version: 12.1


In [3]:

# Explicit import: without it, `np` appears to be in scope only via `from utils import *`.
import numpy as np

# Sequence length handed to TokenDatasetMidi when the datasets are built.
max_length = 2048
# NOTE(review): `id` shadows the builtin id(). It is not read in the cells shown, but it is
# kept in case later cells use it — TODO rename (e.g. `piece_id`) once callers are checked.
id = 0

# Token vocabulary (its length becomes the GPT vocab size) plus the pre-shuffled train/test
# token sequences and their MIDI arrays.
# allow_pickle=True lets np.load unpickle object arrays, which can execute code — only load
# trusted local files here.
tokens = np.load('../data/formatted/tokens.npy', allow_pickle=True)
train = np.load('../data/shuffled/dataset_train.npy', allow_pickle=True)
test = np.load('../data/shuffled/dataset_test.npy', allow_pickle=True)
midi_train = np.load('../data/shuffled/midi_train.npy', allow_pickle=True)
midi_test = np.load('../data/shuffled/midi_test.npy', allow_pickle=True)

In [4]:
# Cast both MIDI splits to the platform integer dtype. They were loaded with
# allow_pickle=True, so the on-disk dtype may not be integer — TODO confirm.
midi_train, midi_test = (arr.astype(int) for arr in (midi_train, midi_test))

In [5]:
# Shapes of the token splits and their MIDI counterparts (the MIDI arrays carry an extra
# trailing axis in the logged run).
print(train.shape, test.shape, midi_train.shape, midi_test.shape)

# Tokens and MIDI are fed together to TokenDatasetMidi, so they must agree on
# (pieces, sequence length); fail fast here instead of deep inside training.
assert train.shape[:2] == midi_train.shape[:2], "train tokens and MIDI arrays are misaligned"
assert test.shape[:2] == midi_test.shape[:2], "test tokens and MIDI arrays are misaligned"

dataset = TokenDatasetMidi(train, midi_train, max_length, tokens)
validation = TokenDatasetMidi(test, midi_test, max_length, tokens)

(43272, 2048) (4800, 2048) (43272, 2048, 8) (4800, 2048, 8)
data has 43272 pieces, 195 unique tokens.
data has 4800 pieces, 195 unique tokens.


In [6]:
import wandb

# Hyperparameters and metadata stored with the W&B run.
# NOTE(review): "learning_rate" and "epochs" duplicate the literals set in the training
# cell (learning_rate / epochs); keep them in sync by hand.
run_config = {
    "learning_rate": 3e-5,
    "architecture": "Transformer - minGPT",
    "dataset": "chords from iRealPro",
    "epochs": 250,
}

# Assumes W&B credentials are already cached; run wandb.login() once if they are not.
wandb.init(project="music_gpt_new_voicing", config=run_config)

03/12/2024 12:58:29 - ERROR - wandb.jupyter -   Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdazzid[0m ([33mmusic_gpt[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Hyperparameters for the model and the training run.
epochs = 250
embedding = 192       # n_embd passed to GPTConfig
heads = 4
layers = 4
batch_size = 32
learning_rate = 3e-5
num_workers = 4
midi_vocab = 128      # MIDI vocabulary size passed to GPTConfig
SEED = 42

# Seed the RNGs before building the model so weight init (and, presumably, the Trainer's
# data shuffling) is reproducible on Restart & Run All.
set_seed(SEED)

mconf = GPTConfig(len(tokens), dataset.block_size, midi_vocab,
                  n_layer=layers, n_head=heads, n_embd=embedding)
session_model = GPT(mconf)

# NOTE(review): the checkpoint path encodes epochs/heads/embd/batch only. Changing `layers`,
# `learning_rate` or `midi_vocab` keeps the same path, so a checkpoint trained with other
# settings would be handed to load_model. Left unchanged so existing checkpoints still resolve.
MODEL_NAME = (
    f"../models/model_epochs->{epochs}_heads->{heads}"
    f"_embd->{embedding}_batch->{batch_size}_new_midi_embeddings"
)
print(MODEL_NAME)

# load_model appears to return None when it cannot load a checkpoint; the branch below
# relies on that to decide whether to train.
session_model = load_model(MODEL_NAME, session_model)

if session_model is None:
    # Train from scratch on a freshly built model rather than reusing the instance handed to
    # load_model — TODO confirm whether load_model can leave that instance partially modified.
    session_model = GPT(mconf)
    tconf = TrainerConfig(
        max_epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        num_workers=num_workers,
    )
    # Record the values actually used for this run; overrides any differing values that
    # were passed to wandb.init.
    wandb.config.update(
        {
            "learning_rate": learning_rate,
            "epochs": epochs,
            "embedding": embedding,
            "heads": heads,
            "layers": layers,
            "batch_size": batch_size,
            "num_workers": num_workers,
            "midi_vocab": midi_vocab,
            "seed": SEED,
        },
        allow_val_change=True,
    )
    writer = SummaryWriter(log_dir='../runs/logs')
    try:
        trainer = Trainer(session_model, dataset, validation, tconf, writer)
        trainer.train()
    finally:
        # Flush pending TensorBoard events and release the event file even if training
        # is interrupted.
        writer.close()
    save_model(MODEL_NAME, session_model)

# Finish the W&B run whether or not training happened, so a session that only loads a
# checkpoint does not leave the run open.
wandb.finish()

03/12/2024 12:58:32 - INFO - model -   number of parameters: 1.861376e+06
03/12/2024 12:58:32 - INFO - model -   number of parameters: 1.861376e+06


../models/model_epochs->250_heads->4_embd->192_batch->32_new_midi_embeddings


epoch 1 iter 1352: train loss 1.69695. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.24it/s]
03/12/2024 13:03:52 - INFO - trainer -   epoch train loss: 2.799219


train loss: 2.799219166095928


03/12/2024 13:04:06 - INFO - trainer -   test loss: 1.646985
epoch 2 iter 1352: train loss 1.45854. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 13:09:25 - INFO - trainer -   epoch train loss: 1.584151


train loss: 1.5841505897547348


03/12/2024 13:09:39 - INFO - trainer -   test loss: 1.481287
epoch 3 iter 1352: train loss 1.39143. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 13:14:57 - INFO - trainer -   epoch train loss: 1.444626


train loss: 1.444626359435601


03/12/2024 13:15:11 - INFO - trainer -   test loss: 1.427665
epoch 4 iter 1352: train loss 1.27422. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.24it/s]
03/12/2024 13:20:30 - INFO - trainer -   epoch train loss: 1.338538


train loss: 1.3385379384908513


03/12/2024 13:20:44 - INFO - trainer -   test loss: 1.271491
epoch 5 iter 1352: train loss 1.19974. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 13:26:03 - INFO - trainer -   epoch train loss: 1.256074


train loss: 1.2560737678410825


03/12/2024 13:26:17 - INFO - trainer -   test loss: 1.227735
epoch 6 iter 1352: train loss 1.17343. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.26it/s]
03/12/2024 13:31:35 - INFO - trainer -   epoch train loss: 1.220359


train loss: 1.2203588897179254


03/12/2024 13:31:49 - INFO - trainer -   test loss: 1.205947
epoch 7 iter 1352: train loss 1.13213. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.26it/s]
03/12/2024 13:37:07 - INFO - trainer -   epoch train loss: 1.195299


train loss: 1.195299013314737


03/12/2024 13:37:21 - INFO - trainer -   test loss: 1.184742
epoch 8 iter 1352: train loss 1.12346. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.26it/s]
03/12/2024 13:42:39 - INFO - trainer -   epoch train loss: 1.173387


train loss: 1.1733866373169273


03/12/2024 13:42:53 - INFO - trainer -   test loss: 1.161737
epoch 9 iter 1352: train loss 1.10216. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 13:48:12 - INFO - trainer -   epoch train loss: 1.153956


train loss: 1.1539562235738468


03/12/2024 13:48:26 - INFO - trainer -   test loss: 1.145953
epoch 10 iter 1352: train loss 1.08275. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 13:53:44 - INFO - trainer -   epoch train loss: 1.137876


train loss: 1.1378758466428769


03/12/2024 13:53:58 - INFO - trainer -   test loss: 1.122994
epoch 11 iter 1352: train loss 1.06621. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.25it/s]
03/12/2024 13:59:17 - INFO - trainer -   epoch train loss: 1.122520


train loss: 1.1225195069358864


03/12/2024 13:59:30 - INFO - trainer -   test loss: 1.100284
epoch 12 iter 1352: train loss 1.05141. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 14:04:49 - INFO - trainer -   epoch train loss: 1.105544


train loss: 1.1055436424652207


03/12/2024 14:05:03 - INFO - trainer -   test loss: 1.077704
epoch 13 iter 1352: train loss 1.02099. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.26it/s]
03/12/2024 14:10:21 - INFO - trainer -   epoch train loss: 1.087544


train loss: 1.0875439282179056


03/12/2024 14:10:35 - INFO - trainer -   test loss: 1.059534
epoch 14 iter 1352: train loss 1.00255. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:17<00:00,  4.26it/s]
03/12/2024 14:15:53 - INFO - trainer -   epoch train loss: 1.072636


train loss: 1.0726359817863655


03/12/2024 14:16:07 - INFO - trainer -   test loss: 1.049797
epoch 15 iter 1352: train loss 0.98366. lr 3.000000e-05: 100%|██████████| 1353/1353 [05:18<00:00,  4.25it/s]
03/12/2024 14:21:25 - INFO - trainer -   epoch train loss: 1.060054


train loss: 1.060053651498497


03/12/2024 14:21:39 - INFO - trainer -   test loss: 1.036180
epoch 16 iter 898: train loss 1.08910. lr 3.000000e-05:  66%|██████▋   | 899/1353 [03:30<01:46,  4.25it/s]