In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import fastai
fastai.__version__

In [None]:
from fastai import *
from fastai_audio import *

In [None]:
# Root directory for all local data.
DATA = Path('data')
# NSynth audio root; expected to hold the 'train', 'valid' and 'test' audio folders.
NSYNTH_AUDIO = DATA.joinpath('nsynth_audio')

In [None]:
# Per-split label CSVs. The file names suggest a cleaned guitar subset restricted
# to pitches 40-88 — NOTE(review): confirm against whatever produced these files.
LABEL_DIR = Path('tmp_labels')
GUITAR_TRN = LABEL_DIR.joinpath('train_guitar_clean_40_88.csv')
GUITAR_VAL = LABEL_DIR.joinpath('valid_guitar_clean_40_88.csv')
GUITAR_TST = LABEL_DIR.joinpath('test_guitar_clean_40_88.csv')

In [None]:
# Load one labels DataFrame per split (train, valid, test), in that order.
trn_df, val_df, tst_df = map(pd.read_csv, (GUITAR_TRN, GUITAR_VAL, GUITAR_TST))
trn_df.head(2)

In [None]:
tuple(map(len, (trn_df, val_df, tst_df)))  # rows per split: (train, valid, test)

In [None]:
# One AudioItemList per split: each DataFrame is paired with its audio
# sub-folder under NSYNTH_AUDIO, and files are looked up with a '.wav' suffix.
trn_list, val_list, tst_list = (
    AudioItemList.from_df(split_df, path=NSYNTH_AUDIO, folder=split_folder, suffix='.wav')
    for split_df, split_folder in ((trn_df, 'train'), (val_df, 'valid'), (tst_df, 'test'))
)
tuple(map(len, (trn_list, val_list, tst_list)))

In [None]:
def get_frame(x, start=1024, frame_len=1024):
    """Return the slice of `x` covering indices [start, start + frame_len).

    Plain slicing on the first axis: if `x` is shorter than `start + frame_len`,
    the returned frame is shorter than `frame_len`.
    NOTE(review): presumably `x` is a 1-D waveform; confirm all clips are long enough.
    """
    stop = start + frame_len
    return x[start:stop]


def batch_fft(inputs):
    """Batch transform: map a batch of frames to their power spectrum; labels pass through.

    Args:
        inputs: an `(xs, ys)` pair of frames and labels.
    Returns:
        `(power, ys)`, where `power` is the squared magnitude of the normalized
        real FFT taken over the last axis.
    NOTE(review): torch.rfft is one-sided by default, so a 1024-sample frame yields
    513 bins, while SimpleModel's default input width is 1024 — confirm the shapes match.
    """
    frames, labels = inputs
    # Legacy torch.rfft appends a trailing (real, imag) axis; square it and sum over that axis.
    spectrum = torch.rfft(frames, 1, normalized=True)
    power = spectrum.pow_(2.0).sum(dim=-1)
    return power, labels

In [None]:

# Mel-spectrogram settings — currently unused by this notebook, kept for reference:
# n_fft = 512
# n_hop = 256
# n_mels = 64
# sample_rate = 16000
# ref = 'max'
# top_db = 50.0

bs = 32  # batch size

# Per-item transform: crop every clip to one frame; the same list is used for train and valid.
tfm_list = [get_frame]
tfms = (tfm_list,) * 2  # (train tfms, valid tfms)

# Per-batch transform passed to databunch: frames -> power spectrum.
batch_tfms = [batch_fft]

data = (
    ItemLists(NSYNTH_AUDIO, trn_list, val_list)
    .label_from_df('pitch')  # targets come from the 'pitch' column
    .add_test(tst_list)
    .transform(tfms)
    .databunch(bs=bs, tfms=batch_tfms)
)
xs, ys = data.one_batch()
(xs.shape, ys.shape, xs.min(), xs.max())

In [None]:
class SimpleModel(nn.Module):
    """Fully connected classifier: (BatchNorm, Linear, ReLU) hidden blocks plus a
    (BatchNorm, Linear) output head that returns raw logits.

    Args:
        n_classes: number of output logits.
        n_in: width of each input feature vector. The default of 1024 matches the
            original hard-coded value.
            NOTE(review): `batch_fft` applies a one-sided rfft to 1024-sample frames,
            which yields 513 bins — confirm against `data.one_batch()[0].shape[-1]`
            and pass that value here if they differ.
        n_hidden: widths of the hidden layers, in order.

    With the default arguments, the layer sequence (and therefore the state_dict
    keys of saved checkpoints) is identical to the original fixed 1024 -> 1024 -> 512
    -> n_classes model.
    """
    def __init__(self, n_classes, n_in=1024, n_hidden=(1024, 512)):
        super().__init__()
        sizes = [n_in, *n_hidden]
        layers = []
        for size_in, size_out in zip(sizes[:-1], sizes[1:]):
            layers += bn_drop_lin(size_in, size_out, actn=nn.ReLU(inplace=True))
        # Output head has no activation, so forward() returns logits.
        layers += bn_drop_lin(sizes[-1], n_classes, actn=None)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [None]:
# Classifier with one output per class in `data.c`, wrapped in a Learner tracking accuracy.
model = SimpleModel(n_classes=data.c)
learn = Learner(data, model, metrics=[accuracy])
learn.summary()

In [None]:
model_name = 'pitch_frame_v1'
# Stage 1: one-cycle training from the learner's current weights
# (random init on a fresh kernel), with the default learning rate.
learn.fit_one_cycle(4)
learn.save(f'{model_name}-stage-1')

In [None]:
# Reload stage 1 and run the LR finder to choose max_lr for the next stage.
learn.load(f'{model_name}-stage-1')
learn.lr_find()
learn.recorder.plot()

In [None]:
# Stage 2: resume from stage 1 and train with max_lr=1e-3.
learn.load(f'{model_name}-stage-1')
learn.fit_one_cycle(8, max_lr=1e-3)
learn.save(f'{model_name}-stage-2')

In [None]:
# Reload stage 2 and re-run the LR finder.
learn.load(f'{model_name}-stage-2')
learn.lr_find()
learn.recorder.plot()

In [None]:
# Stage 3: resume from stage 2 and train with max_lr=1e-3.
learn.load(f'{model_name}-stage-2')
learn.fit_one_cycle(8, max_lr=1e-3)
learn.save(f'{model_name}-stage-3')

In [None]:
# Reload stage 3 and sweep the LR finder over a wider range (1e-9 to 1).
learn.load(f'{model_name}-stage-3')
learn.lr_find(start_lr=1e-9, end_lr=1)
learn.recorder.plot()

In [None]:
# Stage 4, candidate with max_lr=1e-4, resumed from stage 3.
# Saved under its own name: the max_lr=1e-5 cell that follows also starts from
# stage 3 and writes '-stage-4', which would otherwise overwrite this run's
# checkpoint on Restart & Run All.
learn.load(model_name + '-stage-3')
learn.fit_one_cycle(8, max_lr=1e-4)
learn.save(model_name + '-stage-4-lr1e-4')

In [None]:
# Stage 4 with max_lr=1e-5, resumed from stage 3 and saved as '-stage-4'.
# The evaluation cells below use the learner state this cell leaves behind.
learn.load(f'{model_name}-stage-3')
learn.fit_one_cycle(8, max_lr=1e-5)
learn.save(f'{model_name}-stage-4')

In [None]:
accuracy(*learn.get_preds(ds_type=DatasetType.Valid))  # accuracy on the validation set

In [None]:
# Count misclassified training items directly from the predictions.
# The train DataLoader may shuffle and drop the last partial batch, so the number
# of evaluated items can be smaller than len(trn_list). Rescaling accuracy by
# len(trn_list) (and rounding a float) therefore only estimates the count;
# report the exact count and its denominator instead.
trn_preds, trn_targs = learn.get_preds(DatasetType.Train)
n_evaluated = len(trn_targs)
n_errors = int((trn_preds.argmax(dim=-1) != trn_targs).sum())
print(n_errors, 'errors out of', n_evaluated, 'training items')