In [1]:
import inspect
import os
import os.path as osp
import sys
from itertools import combinations
sys.path.append(osp.abspath('..'))

import numpy as np
import torch
from tqdm import tqdm_notebook
from scipy.stats import kurtosis, skew
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC
from sklearn.preprocessing import scale, normalize
from sklearn.metrics import accuracy_score

import config
from datasets.gtzan import GTZAN_SPEC
from cdbn import CDBN

%load_ext autoreload
%autoreload 2

In [2]:
# Random seeds: seed every RNG library the notebook imports so results are
# reproducible after Restart & Run All (torch was previously left unseeded).
SEED = 1234
np.random.seed(SEED)
torch.manual_seed(SEED)

# Passed to GTZAN_SPEC as `min_segments` (segments requested per track).
MIN_SEGMENTS = 1

# Trained CDBN checkpoint (layer 0, epoch 90 per the filename).
CDBN_CHECKPOINT = 'cdbn_checkpoints/checkpoint_layer_0_epoch_90.pt'

In [3]:
# Load every GTZAN track (phase='all'); the cross-validation cell further down
# does the train/test splitting, so this is not a training split.
dataset = GTZAN_SPEC(phase='all', min_segments=MIN_SEGMENTS, randomized=True)
assert len(dataset) > 0, 'GTZAN_SPEC returned an empty dataset'
print('Samples:', len(dataset))
print('Shape:', dataset[0][0].shape)

Train: 1000
Shape: (1, 221, 101)


In [4]:
# Load the precomputed PCA-whitening matrix that populate_features applies
# (via .dot) to every spectrogram segment before it reaches the CDBN.
with open('pca_whiten_mat.npy', 'rb') as f:
    Ewhiten = np.load(f)
# np.load never returns None, so a `None` check would be vacuous; check the
# property the pipeline actually relies on instead.
assert Ewhiten.ndim == 2, Ewhiten.shape

# Sanity check on one sample: each segment must be left-multipliable by
# Ewhiten; print the resulting whitened shape.
for x_ in dataset[0][0]:
    assert Ewhiten.shape[1] == x_.shape[0], (Ewhiten.shape, x_.shape)
    print(Ewhiten.dot(x_).shape)

(80, 101)


In [9]:
# aggregate over all frames, calculate simple summary statistics
# such as average, max or standard deviation for each channel
# Load the trained CDBN. The checkpoint holds a fully pickled model object
# (its `.crbms[0].v2h` is called below), not a state_dict. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which rejects such objects, so opt
# out explicitly when the argument exists (older torch has no such argument).
# Unpickling can execute arbitrary code: only load trusted checkpoints.
_load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
with open(CDBN_CHECKPOINT, 'rb') as f:
    cdbn = torch.load(f, **_load_kwargs)

def populate_features(dataset, whiten=None, model=None, device=0):
    """Build summary-statistic features from first-layer CDBN activations.

    Every segment of every sample is whitened, passed through the first CRBM
    of the CDBN, and its activations are aggregated along axis 1 (presumably
    time/frames -- TODO confirm) into five per-channel statistics: mean, max,
    standard deviation, kurtosis and skewness.

    Args:
        dataset: sized iterable of ``(segments, label)`` pairs, where
            iterating ``segments`` yields 2-D arrays.
        whiten: matrix applied as ``whiten.dot(segment)``; defaults to the
            notebook-level ``Ewhiten``.
        model: object whose ``crbms[0].v2h`` computes the activations;
            defaults to the notebook-level ``cdbn``.
        device: torch device the input is moved to before ``v2h``. The default
            ``0`` keeps the original behavior (first CUDA device).

    Returns:
        ``(X, Y)``: ``X`` has one row per segment, with the five statistics
        concatenated in the order listed above; ``Y`` holds the matching
        labels.
    """
    # Resolve the notebook globals at call time, exactly as before.
    if whiten is None:
        whiten = Ewhiten
    if model is None:
        model = cdbn
    crbm = model.crbms[0]  # loop-invariant

    X = []
    Y = []
    # Pure inference: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for x, y in tqdm_notebook(dataset, total=len(dataset)):
            for segment in x:
                v = torch.from_numpy(whiten.dot(segment)).type(torch.FloatTensor)
                # Add (batch, channel) dims and drop the first column along the
                # last axis. NOTE(review): reason for [..., 1:] not visible
                # here; presumably matches the CDBN's training input width.
                v = v[None, None, :, 1:].to(device)
                # Element [1] of v2h's output is used as the activation map
                # (presumably hidden probabilities -- TODO confirm).
                h = crbm.v2h(v)[1].squeeze_().cpu().numpy()
                X.append(np.hstack((
                    np.mean(h, axis=1),
                    np.max(h, axis=1),
                    np.std(h, axis=1),
                    kurtosis(h, axis=1),
                    skew(h, axis=1),
                )))
                Y.append(y)
    return np.array(X), np.array(Y)


# Build the feature matrix (one row per segment) and its label vector, then
# show both shapes, one per line.
X, Y = populate_features(dataset)
print(X.shape, Y.shape, sep='\n')

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))


False
(1000, 1500)
(1000,)


In [14]:
# 10-fold cross-validated score of an SVC (C=2) trained on row-wise
# L2-normalized features; report per-fold scores, their mean and spread.
scores = cross_validate(
    estimator=SVC(C=2),
    X=normalize(X),
    y=Y,
    cv=10,
    n_jobs=-1,
)
fold_scores = scores['test_score']
print('SCORES', fold_scores)
print('AVG:', fold_scores.mean())
print('STD:', fold_scores.std())



SCORES [0.44 0.49 0.5  0.45 0.48 0.41 0.49 0.56 0.44 0.42]
AVG: 0.46799999999999997
STD: 0.04261455150532505
