In [1]:
# Automatically reload edited modules (e.g. the project's datasets/ and models/
# packages) before each cell executes, so code changes are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys, os

# Make the project-local packages (`datasets`, `models`) importable. Assumes the
# notebook sits one directory below the repository root -- TODO confirm.
# The root goes at the FRONT of sys.path: with `append`, installed distributions
# are searched first, so a pip-installed package with the same name (e.g. the
# Hugging Face `datasets` package) would shadow the local one and the
# `datasets.nsynth_datamodule` import below would fail. Removing any previous
# entry first keeps re-runs of this cell from piling up duplicates.
_repo_root = os.path.abspath('..')
if _repo_root in sys.path:
    sys.path.remove(_repo_root)
sys.path.insert(0, _repo_root)

import numpy as np
import librosa as lr
import torch
import IPython.display as ipd
import matplotlib.pyplot as plt
import pytorch_lightning as pl

from scipy.signal.windows import hann
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping

from datasets.nsynth_datamodule import NsynthDataModule
from models.cvae_resnet import CvaeResnet
from models.cvae_inception import CvaeInception

# Seed the Python, NumPy and torch RNGs for reproducible runs; the returned seed
# is the cell's displayed output.
pl.seed_everything(42)

Global seed set to 42


42

In [3]:
### CONFIGS

model_type = 'resnet'   # 'resnet' or 'incept' (see the model-selection cell)
num_workers = 8
batch_size = 16
max_epochs = 10000
patience = 200          # EarlyStopping patience (validation checks without improvement)
gpus = 1

ds_configs = {
    # Dataset root. Can be overridden with the NSYNTH_DATASET_PATH environment
    # variable so the notebook runs on other machines; defaults to the original path.
    'dataset_path': os.environ.get('NSYNTH_DATASET_PATH', '/data/riccardo_datasets'),
    'feature': 'spec',
    'feature_params': {
        'win_length': 256,
        'hop_length': 64,
        'window': hann(256)   # kept the same length as win_length
    },
    # 510 // 2 + 1 = 256 frequency bins for a one-sided STFT, matching
    # m_configs_resnet['input_height'] below -- presumably intentional; TODO confirm.
    'n_fft': 510,
    'ds_kwargs': {
        'pitches': [60, 61, 62],
        'instrument_families': [0],
        'sr': 16000,
        # 1.02 s * 16000 Hz = 16320 samples; with hop_length=64 and a centred STFT
        # that is 1 + 16320 // 64 = 256 frames (assumption -- confirm in NsynthDataModule).
        'duration': 1.02
    }
}

# Hyper-parameters passed verbatim to CvaeResnet; their exact meaning is defined
# by the model class, not here.
m_configs_resnet = {
    'lr': 1e-3,
    'lr_scheduler': {
        'factor': 0.5624, 
        'patience': 50,
    },
    'c_labels': ['pitch'],
    'kl_coeff': 1e-4,
    'db_coeff': 1e-3,
    'latent_size': 32,
    'channel_size': 2,
    'input_height': 256,
    'enc_type': 'resnet18',
    'first_conv': False,
    'maxpool1': False,
    'enc_out_dim': 512,
}

In [4]:
# Build the NSynth data module from `ds_configs`, with loader settings taken
# from the CONFIGS cell.
dm = NsynthDataModule(
    ds_configs,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [5]:
# pick model class and its hyper-parameter dict for `model_type`.
# Branches are used instead of one dict literal: a dict literal evaluates ALL of
# its values, so referencing `m_configs_incept` (not defined in the CONFIGS cell)
# raised NameError even when model_type == 'resnet'.
if model_type == 'resnet':
    ModelClass, m_configs = CvaeResnet, m_configs_resnet
elif model_type == 'incept':
    # NOTE: define `m_configs_incept` in the CONFIGS cell before selecting this model.
    ModelClass, m_configs = CvaeInception, m_configs_incept
else:
    # Fail here with a clear message instead of a confusing `None(...)` call later.
    raise ValueError(
        "unknown model_type {!r}; expected 'resnet' or 'incept'".format(model_type))

In [6]:
# init model: instantiate the class chosen in the model-selection cell with its
# hyper-parameter dict. NOTE: the training cells below all re-use this same
# instance, so later runs start from whatever weights earlier runs left behind.
model = ModelClass(m_configs)

In [7]:
# TensorBoard logger: runs are grouped under logs/<model_name>_test_overfit/.
log_name = f'{ModelClass.model_name}_test_overfit'
logger = TensorBoardLogger(save_dir='logs', name=log_name)

In [None]:
# callbacks
# Overfit sanity check: train on a single batch (overfit_batches=1).
early_stop = EarlyStopping(monitor='val_loss', patience=patience)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# lr_monitor was previously built but never attached, so the learning rate was
# never logged even though the model config includes an `lr_scheduler` entry.
# Early stopping stays off by default (as before); set the flag to enable it.
use_early_stop = False
callbacks = [lr_monitor] + ([early_stop] if use_early_stop else [])

# train!
trainer = pl.Trainer(
    weights_summary='full',
    max_epochs=max_epochs,
    overfit_batches=1,
    callbacks=callbacks,
    terminate_on_nan=False,
    # gradient_clip_val=0.5,  # uncomment to enable gradient clipping
    logger=logger,
    gpus=gpus)
trainer.fit(model=model, datamodule=dm)

In [None]:
# Longer overfit run. NOTE: this re-uses the `model` instance from the cells
# above, so it continues from any weights already trained rather than from a
# fresh init (re-run the "init model" cell first for a clean run).
long_max_epochs = 20000

# callbacks
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Without an explicit logger, the Trainer falls back to the default
# `lightning_logs/` directory, away from the first run's logs. A fresh logger
# with the same name gets its own version directory next to that run.
long_run_logger = TensorBoardLogger('logs', name=log_name)

# train!
trainer = pl.Trainer(
    max_epochs=long_max_epochs,
    overfit_batches=1,
    callbacks=[lr_monitor],
    logger=long_run_logger,
    gpus=gpus)
trainer.fit(model=model, datamodule=dm)

In [None]:
# Fit the existing trainer/model pair again.
# NOTE(review): the trainer built in the cell directly above already has
# max_epochs=20000, so this assignment changes nothing when run in order; if that
# run completed all 20000 epochs, this fit() may stop immediately. Depending on
# the pytorch_lightning version, `max_epochs` may also be a read-only property --
# confirm. To really train longer, raise the value or build a new Trainer.
trainer.max_epochs = 20000
trainer.fit(model=model, datamodule=dm)