In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from noteify.core.config import SAMPLE_RATE
from noteify.core.datasets import (MusicNetDataset, MusicNetDatasetProcessed, MusicAugmentor,
                                   MusicNetSampler, get_musicnet_dataloader)
from noteify.core.training import make_optimizer, make_scheduler, train_model
from noteify.core.models import TranscriptionNN
from noteify.core.utils import plot_audio, plot_roll_info
from noteify.utils import get_rel_pkg_path

In [2]:
# Select the compute device: CUDA when available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print("Using the GPU!" if cuda_available else "Warning: Could not find GPU! Using CPU only")

Using the GPU!


In [3]:
# Package-relative directories:
#   data_dir -> MusicNet data (MusicNetDataset is created with download=True against it)
#   save_dir -> checkpoints (the warm-start weights are read from it, and train_model saves into it)
data_dir, save_dir = map(get_rel_pkg_path, ("musicnet/", "weights/"))

In [4]:
# Options shared by both splits so train and test are always loaded the same way.
# For a quick debug run, pass e.g. filter_records=[1727] (train) / filter_records=[1759] (test)
# to restrict loading to specific recordings.
musicnet_opts = dict(download=True, numpy_cache=True, piano_only=True)
raw_dataset_train = MusicNetDataset(data_dir, train=True, **musicnet_opts)
raw_dataset_test = MusicNetDataset(data_dir, train=False, **musicnet_opts)

  2%|█▉                                                                              | 15/640 [00:00<00:04, 136.80it/s]

Loading audio


100%|███████████████████████████████████████████████████████████████████████████████| 640/640 [00:01<00:00, 338.24it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 1700.82it/s]

Loading audio





In [5]:
batch_size: int = 32  # examples per batch; shared by the train and test samplers

In [6]:
# Training pipeline: augmented examples, sampler with random_start_times enabled.
dataset_train = MusicNetDatasetProcessed(raw_dataset_train, augmentor=MusicAugmentor())
sampler_train = MusicNetSampler(dataset_train, batch_size,
                                num_batches=None, random_start_times=True)
# Hand-tuned loader worker count: 3 for the piano-only subset, 2 otherwise.
if raw_dataset_train.piano_only:
    workers = 3
else:
    workers = 2
dataloader_train = get_musicnet_dataloader(
    dataset_train, sampler_train, num_workers=workers, pin_memory=True)

# Evaluation pipeline: no augmentor, default sampler settings.
dataset_test = MusicNetDatasetProcessed(raw_dataset_test)
sampler_test = MusicNetSampler(dataset_test, batch_size)
# NOTE(review): num_workers=None is passed straight to get_musicnet_dataloader —
# confirm what worker count it resolves to.
dataloader_test = get_musicnet_dataloader(
    dataset_test, sampler_test, num_workers=None, pin_memory=True)

# Phase name -> loader, as consumed by train_model.
dataloaders = dict(train=dataloader_train, test=dataloader_test)

In [7]:
# Build the transcription network on the selected device and warm-start it from a
# previous experiment's best checkpoint.
model = TranscriptionNN()
model = model.to(device)

# Experiment folder (under save_dir) to fine-tune from; change to resume another run.
past_experiment = "Experiment 12-07-2020 11-51PM"
past_train_weights = os.path.abspath(
    os.path.join(save_dir, past_experiment, "Weights Best.pckl"))
# map_location remaps the stored tensors onto `device`. If the checkpoint was saved
# from a CUDA run, torch.load would otherwise fail on the CPU fallback chosen in the
# device-selection cell. torch.load unpickles the file: only load checkpoints you trust.
model.load_state_dict(torch.load(past_train_weights, map_location=device))

Creating CQT kernels ...CQT kernels created, time used = 0.0947 seconds


<All keys matched successfully>

In [8]:
learning_rate = 1e-3  # initial LR for fine-tuning from the loaded checkpoint
optimizer = make_optimizer(model, lr=learning_rate)

In [9]:
num_epochs = 100
# NOTE(review): presumably a multi-step schedule (LR multiplied by gamma at each listed
# epoch) — confirm against make_scheduler.
lr_milestones = [60]
scheduler = make_scheduler(optimizer, lr_milestones, gamma=0.5)

In [None]:
# Fine-tune for num_epochs. Each epoch reports the training loss and test statistics.
# Checkpointing is presumably limited to the best weights (save_best=True, save_all=False).
tracker = train_model(
    # what to train, on which device, with which data
    model=model,
    device=device,
    dataloaders=dataloaders,
    # optimisation schedule
    optimizer=optimizer,
    lr_scheduler=scheduler,
    num_epochs=num_epochs,
    # checkpointing
    save_model=True,
    save_dir=save_dir,
    save_best=True,
    save_all=False,
)

  0%|                                                                                          | 0/307 [00:00<?, ?it/s]

Epoch 0/99
----------


Avg. Loss: 0.1225, Total Loss: 0.1237, Loss Parts: [0.0921, 0.0141, 0.0175]: 100%|███| 307/307 [06:59<00:00,  1.37s/it]


Training Loss: 0.1225
Test statistics: {'frame_avg_precision': 0.509887255344955, 'reg_onset_mae': 0.1485602855682373, 'reg_offset_mae': 0.11657305806875229}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 1/99
----------


Avg. Loss: 0.1217, Total Loss: 0.1464, Loss Parts: [0.1086, 0.0175, 0.0203]: 100%|███| 307/307 [06:48<00:00,  1.33s/it]


Training Loss: 0.1217
Test statistics: {'frame_avg_precision': 0.5284877545540163, 'reg_onset_mae': 0.1309359222650528, 'reg_offset_mae': 0.11286318302154541}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 2/99
----------


Avg. Loss: 0.1209, Total Loss: 0.0997, Loss Parts: [0.0736, 0.0118, 0.0143]: 100%|███| 307/307 [06:53<00:00,  1.35s/it]


Training Loss: 0.1209
Test statistics: {'frame_avg_precision': 0.5244165615058743, 'reg_onset_mae': 0.13597600162029266, 'reg_offset_mae': 0.11644430458545685}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 3/99
----------


Avg. Loss: 0.1204, Total Loss: 0.1157, Loss Parts: [0.0924, 0.0104, 0.0129]: 100%|███| 307/307 [06:39<00:00,  1.30s/it]


Training Loss: 0.1204
Test statistics: {'frame_avg_precision': 0.5270861640343865, 'reg_onset_mae': 0.14436075091362, 'reg_offset_mae': 0.11538197845220566}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 4/99
----------


Avg. Loss: 0.1196, Total Loss: 0.1107, Loss Parts: [0.0832, 0.0127, 0.0148]: 100%|███| 307/307 [06:37<00:00,  1.30s/it]


Training Loss: 0.1196
Test statistics: {'frame_avg_precision': 0.5434216151888948, 'reg_onset_mae': 0.14640818536281586, 'reg_offset_mae': 0.12432495504617691}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 5/99
----------


Avg. Loss: 0.1190, Total Loss: 0.1164, Loss Parts: [0.0858, 0.0143, 0.0163]: 100%|███| 307/307 [06:37<00:00,  1.30s/it]


Training Loss: 0.1190
Test statistics: {'frame_avg_precision': 0.5363471424798778, 'reg_onset_mae': 0.14301566779613495, 'reg_offset_mae': 0.1181500107049942}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 6/99
----------


Avg. Loss: 0.1190, Total Loss: 0.1059, Loss Parts: [0.0819, 0.0110, 0.0130]: 100%|███| 307/307 [07:00<00:00,  1.37s/it]


Training Loss: 0.1190
Test statistics: {'frame_avg_precision': 0.5618583737292893, 'reg_onset_mae': 0.1463242620229721, 'reg_offset_mae': 0.12019853293895721}


  0%|                                                                                          | 0/307 [00:00<?, ?it/s]


Epoch 7/99
----------


Avg. Loss: 0.1176, Total Loss: 0.1135, Loss Parts: [0.0840, 0.0138, 0.0158]:  50%|█▍ | 153/307 [04:30<02:00,  1.28it/s]

In [None]:
# IPython magic: opens a Qt console attached to this notebook's kernel, presumably for
# interactively inspecting kernel state. Interactive convenience only — it is not needed
# to reproduce the run and should be removed before sharing the notebook.
%qtconsole