In [1]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import torch

from noteify.core.config import SAMPLE_RATE
from noteify.core.datasets import (MaestroDataset, MaestroDatasetProcessed, MusicAugmentor,
                                   MusicSegmentSampler, get_music_dataloader)
from noteify.core.training import (make_optimizer, make_scheduler, train_model,
                                   save_training_session, load_training_session,
                                   set_optimizer_lr)
from noteify.core.models import TranscriptionNN
from noteify.core.utils import plot_audio, plot_roll_info
from noteify.utils import get_rel_pkg_path

In [2]:
# Seed every global RNG the data pipeline might draw from. Presumably this
# covers MusicAugmentor and the random segment start times; confirm which RNG
# they actually use. Seeding alone does not make CUDA/cuDNN kernels
# bit-for-bit deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Prefer the GPU, falling back to the CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("Using the GPU!")
else:
    print("Warning: Could not find GPU! Using CPU only")

Using the GPU!


In [3]:
# Package-relative locations of the MAESTRO v3 data, saved weights and
# saved training sessions (resolved in this order).
data_dir, weights_dir, session_dir = (
    get_rel_pkg_path(rel_path)
    for rel_path in ("maestro/maestro-v3.0.0", "weights/", "sessions/")
)

In [4]:
# Index the training split (train=True) and the held-out split (train=False).
raw_dataset_train, raw_dataset_test = (
    MaestroDataset(data_dir, train=is_train) for is_train in (True, False)
)

 10%|███████▉                                                                        | 96/962 [00:00<00:00, 959.20it/s]

Loading audio


100%|███████████████████████████████████████████████████████████████████████████████| 962/962 [00:00<00:00, 973.59it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 177/177 [00:00<00:00, 1053.74it/s]

Loading audio





In [5]:
batch_size: int = 32  # shared by the train and test segment samplers

In [6]:
# Training pipeline: augmented audio, random segment start times, 8 workers.
train_augmentor = MusicAugmentor()
dataset_train = MaestroDatasetProcessed(raw_dataset_train, augmentor=train_augmentor)
sampler_train = MusicSegmentSampler(
    dataset_train,
    batch_size,
    num_batches=None,
    random_start_times=True,
)
dataloader_train = get_music_dataloader(
    dataset_train, sampler_train, num_workers=8, pin_memory=True
)

# Evaluation pipeline: no augmentation and the sampler's default settings.
dataset_test = MaestroDatasetProcessed(raw_dataset_test)
sampler_test = MusicSegmentSampler(dataset_test, batch_size)
dataloader_test = get_music_dataloader(
    dataset_test, sampler_test, num_workers=None, pin_memory=True
)

dataloaders = dict(train=dataloader_train, test=dataloader_test)

In [7]:
# Start from the best checkpoint of an earlier experiment.
PRETRAINED_EXPERIMENT = "Experiment 12-08-2020 03-48PM"
past_train_weights = os.path.abspath(
    os.path.join(weights_dir, PRETRAINED_EXPERIMENT, "Weights Best.pckl")
)

model = TranscriptionNN().to(device)
# map_location remaps stored tensors onto the active device. Without it, a
# GPU-saved checkpoint fails to load on the CPU fallback selected above.
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(past_train_weights, map_location=device))

Creating CQT kernels ...CQT kernels created, time used = 0.1221 seconds


<All keys matched successfully>

In [8]:
learning_rate = 5e-4  # initial LR handed to make_optimizer
optimizer = make_optimizer(model, lr=learning_rate)

In [9]:
# Optionally restore model + optimizer state from a saved session folder under
# session_dir, e.g. "Session 12-08-2020 03-18PM". Leave as None to continue
# from the pretrained weights loaded above.
RESUME_SESSION = None
if RESUME_SESSION is not None:
    _ = load_training_session(model, optimizer,
                              os.path.join(session_dir, RESUME_SESSION))

In [10]:
# Presumably a step schedule that multiplies the LR by gamma at each milestone
# epoch (here: halve it once, at epoch 60). Confirm in make_scheduler.
lr_milestones = [60]
num_epochs = 100
scheduler = make_scheduler(optimizer, lr_milestones, gamma=0.5)

In [None]:
# Checkpointing: write the training log and keep only the best weights
# (not every epoch) under weights_dir.
# NOTE(review): "best" is judged on dataloaders['test']. Confirm this split is
# not the final evaluation set.
# NOTE(review): the logged test progress bars stop at 9/374 batches each
# epoch. Confirm evaluation over the full test set is intended.
checkpoint_opts = dict(
    save_log=True,
    save_model=True,
    save_dir=weights_dir,
    save_best=True,
    save_all=False,
)
tracker = train_model(
    device=device,
    model=model,
    dataloaders=dataloaders,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    num_epochs=num_epochs,
    **checkpoint_opts,
)

  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]

Epoch 0/99
----------


Avg. Loss: 0.0849, Total Loss: 0.1052, Loss Parts: [0.0852, 0.0074, 0.0126]: 100%|█| 2964/2964 [51:05<00:00,  1.03s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0849


  2%|█▉                                                                                | 9/374 [00:47<31:49,  5.23s/it]


Test statistics: {'frame_avg_precision': 0.8155788988450773, 'reg_onset_mae': 0.12424508482217789, 'reg_offset_mae': 0.1311982274055481}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 1/99
----------


Avg. Loss: 0.0825, Total Loss: 0.0753, Loss Parts: [0.0568, 0.0075, 0.0110]: 100%|█| 2964/2964 [52:26<00:00,  1.06s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0825


  2%|█▉                                                                                | 9/374 [00:42<29:03,  4.78s/it]


Test statistics: {'frame_avg_precision': 0.8152843224994907, 'reg_onset_mae': 0.11782515048980713, 'reg_offset_mae': 0.12765178084373474}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 2/99
----------


Avg. Loss: 0.0809, Total Loss: 0.0749, Loss Parts: [0.0567, 0.0074, 0.0109]: 100%|█| 2964/2964 [52:30<00:00,  1.06s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0809


  2%|█▉                                                                                | 9/374 [00:43<29:38,  4.87s/it]


Test statistics: {'frame_avg_precision': 0.8115624711536504, 'reg_onset_mae': 0.12190597504377365, 'reg_offset_mae': 0.1302483230829239}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 3/99
----------


Avg. Loss: 0.0800, Total Loss: 0.0727, Loss Parts: [0.0545, 0.0073, 0.0109]: 100%|█| 2964/2964 [51:49<00:00,  1.05s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0800


  2%|█▉                                                                                | 9/374 [00:44<29:49,  4.90s/it]


Test statistics: {'frame_avg_precision': 0.8121575378137924, 'reg_onset_mae': 0.12281186133623123, 'reg_offset_mae': 0.1335730105638504}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 4/99
----------


Avg. Loss: 0.0790, Total Loss: 0.0733, Loss Parts: [0.0554, 0.0072, 0.0107]: 100%|█| 2964/2964 [52:34<00:00,  1.06s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0790


  2%|█▉                                                                                | 9/374 [00:44<29:55,  4.92s/it]


Test statistics: {'frame_avg_precision': 0.8227924597646905, 'reg_onset_mae': 0.12421718239784241, 'reg_offset_mae': 0.1366361528635025}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 5/99
----------


Avg. Loss: 0.0776, Total Loss: 0.0809, Loss Parts: [0.0632, 0.0072, 0.0104]: 100%|█| 2964/2964 [53:19<00:00,  1.08s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0776


  2%|█▉                                                                                | 9/374 [00:46<31:16,  5.14s/it]


Test statistics: {'frame_avg_precision': 0.7847419268341627, 'reg_onset_mae': 0.12322495877742767, 'reg_offset_mae': 0.13506537675857544}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 6/99
----------


Avg. Loss: 0.0781, Total Loss: 0.0718, Loss Parts: [0.0542, 0.0072, 0.0104]: 100%|█| 2964/2964 [52:24<00:00,  1.06s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0781


  2%|█▉                                                                                | 9/374 [00:47<31:46,  5.22s/it]


Test statistics: {'frame_avg_precision': 0.8122980983918312, 'reg_onset_mae': 0.12245731800794601, 'reg_offset_mae': 0.13799062371253967}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 7/99
----------


Avg. Loss: 0.0764, Total Loss: 0.0708, Loss Parts: [0.0534, 0.0072, 0.0102]: 100%|█| 2964/2964 [51:50<00:00,  1.05s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0764


  2%|█▉                                                                                | 9/374 [00:46<31:10,  5.12s/it]


Test statistics: {'frame_avg_precision': 0.8241874234519544, 'reg_onset_mae': 0.12373329699039459, 'reg_offset_mae': 0.14077693223953247}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 8/99
----------


Avg. Loss: 0.0759, Total Loss: 0.0706, Loss Parts: [0.0532, 0.0071, 0.0103]: 100%|█| 2964/2964 [53:17<00:00,  1.08s/it]
  0%|                                                                                          | 0/374 [00:00<?, ?it/s]

Training Loss: 0.0759


  2%|█▉                                                                                | 9/374 [00:44<29:48,  4.90s/it]


Test statistics: {'frame_avg_precision': 0.8219343761389408, 'reg_onset_mae': 0.1209413930773735, 'reg_offset_mae': 0.14012114703655243}


  0%|                                                                                         | 0/2964 [00:00<?, ?it/s]


Epoch 9/99
----------


Avg. Loss: 0.0765, Total Loss: 0.0922, Loss Parts: [0.0699, 0.0094, 0.0129]:  40%|▍| 1181/2964 [22:34<39:47,  1.34s/it]

In [None]:
# Interactive only: opens a Qt console connected to this kernel for ad-hoc
# inspection. It is not part of the pipeline and is not needed for
# Restart & Run All.
%qtconsole

In [None]:
# Persist model and optimizer state under session_dir, so it can later be
# restored with load_training_session.
# NOTE(review): the LR scheduler is not passed, so its progress (e.g. the
# epoch-60 milestone) is presumably not saved. Confirm before resuming.
save_training_session(model, optimizer, session_dir)