In [1]:
import numpy as np
import torch
%matplotlib inline
from matplotlib import pyplot as plt
import tqdm
import torch.nn.functional as F
import torch.nn as nn
import pickle
import fb_utils as fb
from torch.utils.data import DataLoader
import random
import pandas
import os

from model import TDFilterbank, ClassifierConv
from dataset import TinySol
from losses import KappaLoss

In [2]:
# Reproducibility: seed every RNG this notebook touches (torch, Python's
# `random`, legacy NumPy global state) with one shared value.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise pick different (non-deterministic) algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [7]:
# Hyper-parameters shared by the filterbank construction, the TinySol
# datasets (the dict is passed to TinySol, which presumably reads these keys)
# and the training loop. Keep key names unchanged.
# NOTE(review): "sr", "fmin" and "fmax" are presumably in Hz — confirm.
config = dict(
    N=2**12,
    J=96,
    T=1024,
    sr=16000,
    fmin=64,
    fmax=8000,
    stride=64,
    batch_size=16,
    epochs=10,
)

# TinySOL dataset locations. Set the TINYSOL_CSV / TINYSOL_DIR environment
# variables to point at your local copy instead of editing machine-specific
# absolute paths in the notebook; the defaults reproduce the original setup.
tiny_sol_csv = os.environ.get(
    "TINYSOL_CSV", "/Users/Dane/GitHub/Random-Filterbanks/TinySOL_metadata.csv"
)
data_dir = os.environ.get(
    "TINYSOL_DIR", "/Users/Dane/GitHub/Random-Filterbanks/TinySOL2020"
)

TinySol_df = pandas.read_csv(tiny_sol_csv)

# Integer class label for every distinct instrument abbreviation. pandas'
# unique() keeps order of first appearance, so labels follow the CSV order.
instrument_classes = TinySol_df['Instrument (abbr.)'].unique()
instrument_to_int = {instrument: i for i, instrument in enumerate(instrument_classes)}
# Inverse mapping, derived from the forward one so the two can never disagree.
int_to_instrument = {i: instrument for instrument, i in instrument_to_int.items()}


# Random filterbank built by the project helper fb.random_filterbank from the
# config sizes N, J, T, plus a version passed through fb.fir_tight (eps=1.01).
# The first result is converted with .numpy() before being handed to fir_tight.
# NOTE(review): neither variable is used in the cells shown here — presumably
# consumed later in the notebook; confirm or remove.
random_filterbank = fb.random_filterbank(config["N"], config["J"], config["T"], tight=False, support_only=False)
random_filterbank_tight = fb.fir_tight(random_filterbank.numpy(), config["T"], eps=1.01)
# Target filterbank, loaded from targets/<target>.pkl. The path is resolved
# against the kernel's *current working directory* (os.path.abspath('')),
# which Jupyter usually sets to the notebook's folder — not necessarily the
# directory of any source file.
target = 'VQT'
cwd = os.path.abspath('')
target_path = os.path.join(cwd, 'targets', f'{target}.pkl')

# NOTE: pickle.load can execute arbitrary code — only load trusted, locally
# generated target files.
with open(target_path, 'rb') as fp:
    target_freqz = pickle.load(fp)["freqz"]

# The stored "freqz" array is transposed before conversion to a tensor.
# NOTE(review): axis layout of the pickled array is not visible here — TODO
# confirm which axis is frequency and which is filter index.
target_filterbank = torch.from_numpy(target_freqz.T)


In [8]:
classifier = ClassifierConv()
# Targets are batch['instrument_vec'] (argmax over dim 1 is used as the class
# index), presumably one-hot vectors; CrossEntropyLoss accepts such
# class-probability targets and averages over the batch by default.
loss = nn.CrossEntropyLoss()

train_dataset = TinySol(tiny_sol_csv, data_dir, config, target_filterbank, 'train')
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=1)

# Validation only needs a fixed pass over the data; shuffling it adds nothing
# and makes the batch order differ between epochs.
val_dataset = TinySol(tiny_sol_csv, data_dir, config, target_filterbank, 'val')
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=1)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

In [7]:
# Fetch a single batch from train_loader for a quick sanity check of the
# classifier's forward pass. Raises StopIteration if the loader is empty.
batch = next(iter(train_loader))

In [8]:
# Sanity check: per-sample correctness of the classifier (untrained when the
# notebook is run top to bottom) on one batch. No backward pass is needed, so
# run without autograd rather than keeping a graph alive in instrument_pred.
# NOTE(review): the classifier is not switched to eval mode here.
with torch.no_grad():
    x = batch['x_out']
    instrument_pred = classifier(x.unsqueeze(1))
instrument_target = batch['instrument_vec']
torch.argmax(instrument_target, dim=1) == torch.argmax(instrument_pred, dim=1)


tensor([False, False,  True, False, False, False, False, False])

In [9]:
def run_epoch(model, loader, criterion, optimizer=None):
    """Make one pass over `loader`; return (mean loss per sample, accuracy).

    If `optimizer` is given, the model is put in train mode and updated after
    every batch. Otherwise it is put in eval mode and evaluated with autograd
    disabled.

    Each batch is a dict with 'x_out' (classifier input; a channel axis is
    added via unsqueeze(1)) and 'instrument_vec' (targets whose argmax over
    dim 1 is the class index). Assumes `criterion` returns the batch *mean*
    loss (nn.CrossEntropyLoss's default reduction). Returns (nan, nan) for an
    empty loader.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    # Only build autograd graphs when we are actually going to backpropagate.
    with torch.set_grad_enabled(training):
        for batch in loader:
            x = batch['x_out']
            instrument_target = batch['instrument_vec']

            if training:
                optimizer.zero_grad()
            instrument_pred = model(x.unsqueeze(1))
            batch_loss = criterion(instrument_pred, instrument_target)
            if training:
                batch_loss.backward()
                optimizer.step()

            # Weight by the real batch size: the last batch may be smaller
            # than config["batch_size"], so dividing by that constant (or
            # averaging per-batch means) would bias the epoch metrics.
            n = instrument_target.shape[0]
            total_loss += batch_loss.item() * n
            n_seen += n
            instrument_guess = torch.argmax(instrument_pred, dim=1)
            instrument_true = torch.argmax(instrument_target, dim=1)
            n_correct += (instrument_guess == instrument_true).sum().item()

    if n_seen == 0:
        return float('nan'), float('nan')
    return total_loss / n_seen, n_correct / n_seen


fit = []        # mean training loss per epoch
fit_val = []    # mean validation loss per epoch
trues = []      # training accuracy per epoch
trues_val = []  # validation accuracy per epoch

for epoch in range(config["epochs"]):
    train_loss, train_acc = run_epoch(classifier, train_loader, loss, optimizer)
    val_loss, val_acc = run_epoch(classifier, val_loader, loss)

    fit.append(train_loss)
    trues.append(train_acc)
    fit_val.append(val_loss)
    trues_val.append(val_acc)

    print(f"Epoch {epoch+1}/{config['epochs']}:")
    print(f"\tTrain loss: {train_loss:.2f}")
    print(f"\tTrain classification accuracy: {train_acc:.2f}")
    print(f"\tValidation loss: {val_loss:.2f}")
    print(f"\tValidation classification accuracy: {val_acc:.2f}")

fig, ax = plt.subplots()
ax.plot(fit, label="train")
ax.plot(fit_val, label="validation")
ax.set(xlabel="epoch", ylabel="cross-entropy loss", title="Classifier loss per epoch")
ax.legend()
plt.show()
        

37
Epoch 1/10:
	Test loss: 2.45
	Test classification accuracy: 0.30
	Validation loss: 2.54
	Validation classification accuracy: 0.22


KeyboardInterrupt: 

In [10]:
# Number of validation batches per epoch: ceil(len(val_dataset) / batch_size),
# since val_loader is built without drop_last.
# NOTE(review): looks like a leftover debugging cell — consider removing.
len(val_loader)

37