### Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
import torch.utils.data as data
import matplotlib.pyplot as plt
import IPython.display
from IPython.display import Audio
import torch.optim as optim
from types import SimpleNamespace
import scipy.signal as sc
import time

from trainDataset import TrainDataset
from testDataset import TestDataset
#from trainDatasetNew import TrainDatasetNew
#from testDatasetNew import TestDatasetNew
from validation_split import get_dataloaders
from math_utils import logMagStft, ffts
from SpectrogramCNN import SpectrogramCNN
from train_utils import train, test

### Parameters

In [None]:
# Fraction of the training data handed to get_dataloaders as the validation share.
validation_split = .2
# Gates the data-exploration cells below (class counts, plots, audio playback).
do_plots = False
# Hyperparameters plus the network class; passed as one object to train() and test().
args = SimpleNamespace(
    batch_size=64,
    test_batch_size=64,
    epochs=3,
    lr=0.01,
    momentum=0.5,
    seed=1,
    log_interval=200,
    net=SpectrogramCNN,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

path_test = './../data/test/kaggle-test/'
# adapt those paths on other machine
if device.type == "cuda":
    print('with cuda')
    path_train = './../data/kaggle-train/'
else:
    print('no cuda')
    path_train = './../data/train-small/'

# Playback rate given to the Audio widget in the exploration cells.
sample_rate = 16000

### Original Dataset

In [None]:
# todo add in the classes the features and the fft data

# Scale every sample by the largest int16 value.
# NOTE(review): presumably the raw audio is int16 PCM, so this maps to about [-1, 1] - confirm.
toFloat = transforms.Lambda(lambda waveform: waveform / np.iinfo(np.int16).max)

# Build each dataset with the scaling transform and report its size.
trainDataset = TrainDataset(path_train, transform=toFloat)
print(len(trainDataset))

testDataset = TestDataset(path_test, transform=toFloat)
print(len(testDataset))

In [None]:
# Length of the first element of the first training item.
input_size = len(trainDataset[0][0])
print('input size: ', input_size)

### Look at Original Data

In [None]:
if do_plots:
    # how many instruments are there?
    # Read the label (second element) of every training item. int() makes the
    # value a plain Python integer whatever numeric type the dataset returns.
    all_labels = [int(sample[1]) for sample in trainDataset]

    # np.unique returns the distinct labels in ascending order together with
    # their occurrence counts, so no upper bound on the label value is assumed.
    _, per_class_counts = np.unique(all_labels, return_counts=True)

    # One count per class that actually occurs, ordered by label value.
    labels_count = per_class_counts.tolist()

    print(labels_count)

In [None]:
if do_plots:
    # Number of classes that occur at least once in the training set.
    nmbr_classes = len(labels_count)
    print('nmbr_classes: ', nmbr_classes)

In [None]:
if do_plots:
    # Per-class sample counts, drawn as star markers.
    plt.plot(labels_count, '*')

In [None]:
if do_plots:
    # plot one of each
    # Gather the first training item seen for every class, in order of first
    # appearance. A set of seen labels is used so that label values do not
    # have to form a contiguous 0..nmbr_classes-1 range.
    seen_labels = set()
    examples = []

    for sample in trainDataset:
        label = int(sample[1])
        if label in seen_labels:
            continue
        seen_labels.add(label)
        examples.append(sample)
        # Every class has an example now; skip reading the rest of the dataset.
        if len(examples) == nmbr_classes:
            break

In [None]:
if do_plots:
    # Draw the waveform of each collected example in its own subplot.
    # The grid keeps the 4x3 layout and only grows when there are more than
    # 12 examples, so any number of classes can be shown.
    n_cols = 3
    n_rows = max(4, -(-len(examples) // n_cols))  # ceil division

    for position, example in enumerate(examples, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.plot(example[0])

    plt.show()

In [None]:
if do_plots:

    # plot one of each in FFT
    # Same grid as the waveform plot, with ffts() applied to each example.
    # The layout stays 4x3 and only grows when there are more than 12 examples.
    n_cols = 3
    n_rows = max(4, -(-len(examples) // n_cols))  # ceil division

    for position, example in enumerate(examples, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.plot(ffts(example[0]))

    plt.show()

In [None]:
if do_plots:

    # Render an audio player for every collected example.
    # display is taken from the imported IPython.display module explicitly,
    # so this does not rely on the name being injected by the notebook.
    for sample in examples:
        IPython.display.display(Audio(sample[0], rate=sample_rate))

### Dataloaders

In [None]:
# validation split is done here
# NOTE(review): args.seed is defined but not used here; random_seed=None is
# passed instead - confirm whether the split is meant to be reproducible.
train_loader, validation_loader = get_dataloaders(
    trainDataset,
    batch_size=args.batch_size,
    validation_split=validation_split,
    shuffle_dataset=True,
    random_seed=None,
)

# Sanity check: for the first batch of each loader print the tensor shapes,
# the raw targets, the sample value range and the target names.
for loader in (train_loader, validation_loader):
    for samples, instrument_family_target in loader:
        print(samples.shape, instrument_family_target.shape,
              instrument_family_target.data)
        print(torch.min(samples), torch.max(samples))
        print(trainDataset.transformInstrumentsFamilyToString(instrument_family_target.data))
        break

In [None]:
# Loader over the test set. It uses the dedicated args.test_batch_size setting
# rather than the training batch size.
#!!! shuffle should be false
test_loader = data.DataLoader(testDataset, batch_size=args.test_batch_size, shuffle=False)

# Sanity check: shape and value range of the first test batch.
for samples in test_loader:
    print(samples.shape)
    print(torch.min(samples), torch.max(samples))
    break

### Main

In [None]:
# Main
# Instantiate the configured network class and move it to the selected device.
model = args.net(device)
model = model.to(device)

# Plain SGD with the learning rate and momentum from args.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
)

# Epochs are numbered from 1; each one runs train() and then test().
# train() is given the wall-clock time at which its epoch starts.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch,
          start_time=time.time())
    test(args, model, device, test_loader, epoch, trainDataset, testDataset)