In [3]:
%load_ext autoreload
%autoreload 2

from os import listdir
from os.path import isfile, join

import numpy as np
import matplotlib.pyplot as plt
import madmom

import sys
sys.path.append('../src')
from preprocessing import spectro_mini_db, spectro_mini_db_patches
from models import OLSPatchRegressor

# Short alias used when inserting new axes in index expressions, e.g. arr[:, na].
na = np.newaxis

# text.usetex delegates text rendering to LaTeX (a LaTeX install must be available);
# all figure text uses a serif font family.
plt.rcParams.update({'text.usetex': True, 'font.family': 'serif'})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Build a mini database of spectrogram patches and train a simple OLS regressor on it

**TODO:**

- Expose STFT options for the spectrogram (window size, etc.).
- Allow several option sets to be used at the same time by stacking them along an extra depth dimension; check whether this causes problems with the resulting array shape.

In [None]:
# Input folders with the music and speech wav files.
music_dir = '../data/music_speech/music_wav/'
speech_dir = '../data/music_speech/speech_wav/'

# Forwarded unchanged to both loaders; -1 presumably means "no limit" — TODO confirm in preprocessing.
max_samples = -1

# Full-spectrogram training set.
X, Y = spectro_mini_db(music_dir, speech_dir, max_samples=max_samples)

# Patch-level training set. The positional 10 and the pooling sizes are passed through to
# spectro_mini_db_patches; their exact meaning is defined in preprocessing (TODO document there).
patch_options = dict(hpool=16, wpool=15, shuffle=False, max_samples=max_samples)
X_patched, Y_patched = spectro_mini_db_patches(music_dir, speech_dir, 10, **patch_options)

# Report the shapes of both datasets.
for title, arrays in (('Train Set Shape', (X, Y)),
                      ('Patched Train Set Shape', (X_patched, Y_patched))):
    print(title)
    print(*(arr.shape for arr in arrays))

In [None]:
# Fit the linear OLS patch regressor. Per the original author's note it has no bias term.
regressor = OLSPatchRegressor()
regressor.fit(X_patched, Y_patched)

# Patch-level training accuracy: the sign of each per-patch prediction is compared with its label.
patch_scores = regressor.predict(X_patched, patch_mode=True)
patch_accuracy = np.mean(np.sign(patch_scores) == Y_patched)
print('Train Accuracy (Patched): {}'.format(patch_accuracy))

# Whole-spectrogram accuracy: predictions are averaged over axis 1 and the sign of the mean
# is compared with the label (presumably predict applies the patch weights across the full
# spectrogram, hence "Conv" — TODO confirm in models).
ypred = regressor.predict(X)
conv_accuracy = np.mean(np.sign(np.mean(ypred, axis=1)) == Y)
print('Train Accuracy (Conv)   : {}'.format(conv_accuracy))

In [None]:
def show_in_grid(input_3d, instant_output=True, figsize=(20, 20), save_path=None, hpad=1, wpad=1):
    """Tile a stack of equally sized 2-D images into one near-square grid and plot it.

    Parameters
    ----------
    input_3d : array_like, shape (N, H, W)
        Stack of N images, each H x W (e.g. spectrogram patches).
    instant_output : bool, default True
        If True, call ``plt.show()`` after drawing.
    figsize : tuple of float, default (20, 20)
        Matplotlib figure size in inches.
    save_path : str or None, default None
        If given, the figure is also saved to this path at 300 dpi.
    hpad, wpad : int, default 1
        Border added above/below (``hpad``) and left/right (``wpad``) of every tile.
        The border is filled with the minimum of ``input_3d`` so that tiles are visually separated.

    The grid has ``floor(sqrt(N))`` rows and as many columns as needed to hold all N
    images. Unused cells in the last row are filled with the padding value.

    Raises
    ------
    ValueError
        If ``input_3d`` is not 3-D or contains no images.
    """
    input_3d = np.asarray(input_3d)
    if input_3d.ndim != 3:
        raise ValueError('input_3d must have shape (N, H, W), got {}'.format(input_3d.shape))
    N, H, W = input_3d.shape
    if N == 0:
        raise ValueError('input_3d must contain at least one image')

    # Near-square layout: floor(sqrt(N)) rows and ceil(N / rows) columns, so no image is dropped.
    N_h = int(np.floor(N ** .5))
    N_w = -(-N // N_h)
    n_blank = N_h * N_w - N

    pad_val = np.min(input_3d)

    # Pad the borders of every tile and append blank tiles to complete the last grid row.
    padded_input = np.pad(input_3d, [[0, n_blank], [hpad, hpad], [wpad, wpad]],
                          mode='constant', constant_values=pad_val)
    H_padded = H + 2 * hpad
    W_padded = W + 2 * wpad

    # (rows, cols, H, W) -> (rows, H, cols, W) so tiles line up correctly when flattened to 2-D.
    spectro_grid = (padded_input
                    .reshape(N_h, N_w, H_padded, W_padded)
                    .transpose(0, 2, 1, 3)
                    .reshape(N_h * H_padded, N_w * W_padded))

    # origin='lower' draws array row 0 at the bottom, so the first tiles appear in the bottom row.
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(spectro_grid, origin='lower')
    ax.axis('off')

    if save_path is not None:
        fig.savefig(save_path, dpi=300)

    if instant_output:
        plt.show()