In [1]:
import os
import os.path as osp
import sys
sys.path.append(osp.abspath('..'))

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.cuda as cuda
import torch.nn.functional as F
from tqdm import tqdm_notebook, tqdm
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda

import config
from datasets.gtzan import GTZAN_SPEC
from dbn import DBN

%load_ext autoreload
%autoreload 2

In [2]:
# Reproducibility: seed every RNG used in this notebook with one shared value.
SEED = 1234
torch.manual_seed(SEED)
cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# --- Data / PCA-whitening settings ---
NUM_SEGMENTS = 10             # passed to GTZAN_SPEC as `min_segments`
NUM_PRINCIPAL_FEATURES = 80   # number of leading principal directions kept
# Constant added to the eigenvalues before taking the square root when the
# whitening matrix is built.
SIGMA_PC = 3  # 1e-5

# --- Training settings ---
EPOCHS = 100
LR = 1e-3                     # not referenced by the cells visible in this notebook
INITIAL_MOMENTUM = 0.5
# The training loop uses INITIAL_MOMENTUM while `epoch < EPOCHS_FOR_INITIAL_MOMENTUM`.
# NOTE(review): with EPOCHS == 100 the final momentum is only applied on the
# last epoch -- confirm this is intended.
EPOCHS_FOR_INITIAL_MOMENTUM = 100
FINAL_MOMENTUM = 0.8
SPARSITY = 0.04               # not referenced by the cells visible in this notebook

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda:0' if cuda.is_available() else 'cpu')

In [3]:
# Both splits are built with the same `test_size` and `min_segments`; only the
# training split additionally passes `randomized=True`.
split_options = {'test_size': 0.5, 'min_segments': NUM_SEGMENTS}
train_set = GTZAN_SPEC(phase='train', randomized=True, **split_options)
test_set = GTZAN_SPEC(phase='test', **split_options)

# Report the split sizes and the shape of the first training sample.
print('Train: {}'.format(len(train_set)))
print('Test: {}'.format(len(test_set)))
print('Shape: {}'.format(train_set[0][0].shape))

Train: 500
Test: 500
Shape: (10, 221, 301)


In [4]:
def concatenate_data(*datasets):
    """Join every segment of every sample in `datasets` along axis 1.

    Each dataset item is unpacked as ``(x, _)``; the second element is
    ignored. Every element obtained by iterating ``x`` is collected, and the
    collected arrays are concatenated along axis 1 in visiting order.

    Args:
        *datasets: Sized iterables yielding ``(x, _)`` pairs.

    Returns:
        The array produced by ``np.concatenate(..., axis=1)``.

    Raises:
        ValueError: From ``np.concatenate`` when no segment was collected.
    """
    segments = []
    for dataset in datasets:
        # The progress bar wraps the dataset itself; `total` is given
        # explicitly from len(dataset).
        for sample, _ in tqdm_notebook(dataset, total=len(dataset)):
            segments.extend(sample)
    return np.concatenate(segments, axis=1)

def _pca_whitening(X, num_components, sigma):
    """Build PCA (un)whitening matrices from the columns of `X`.

    Args:
        X: 2-D array; the covariance is formed as ``X.dot(X.T) / X.shape[1]``.
        num_components: Number of leading principal directions to keep.
        sigma: Constant added to the kept eigenvalues before the square root.

    Returns:
        Tuple ``(E, Ewhiten, Eunwhiten)`` where ``E`` holds the kept
        eigenvectors as columns, ``Ewhiten`` is ``E.T`` with each row divided
        by ``sqrt(eigenvalue + sigma)`` and ``Eunwhiten`` is ``E`` with each
        column multiplied by the same factor.
    """
    # NOTE(review): X is not mean-centred before the covariance is formed,
    # exactly as in the original cell -- confirm this is intended.
    Xcov = X.dot(X.T) / X.shape[1]
    eigen_values, eigen_vectors = np.linalg.eigh(Xcov)

    # eigh returns eigenvalues in ascending order: reverse first, then keep the
    # leading `num_components`. Unlike a single reversed slice with a computed
    # stop index, this also works when every component is kept.
    E = eigen_vectors[:, ::-1][:, :num_components]
    scale = np.sqrt(eigen_values[::-1][:num_components] + sigma)

    Ewhiten = (1. / scale)[:, None] * E.T
    Eunwhiten = E * scale[None, :]
    return E, Ewhiten, Eunwhiten


Ewhiten = None
if osp.exists('pca_whiten_mat.npy'):
    with open('pca_whiten_mat.npy', 'rb') as f:
        Ewhiten = np.load(f)
    # A cache written with a different NUM_PRINCIPAL_FEATURES would silently
    # feed wrongly sized inputs downstream, so discard it and recompute.
    if Ewhiten.ndim != 2 or Ewhiten.shape[0] != NUM_PRINCIPAL_FEATURES:
        print('Cached pca_whiten_mat.npy has shape {}, expected {} rows; recomputing.'.format(
            Ewhiten.shape, NUM_PRINCIPAL_FEATURES))
        Ewhiten = None

if Ewhiten is None:
    X = concatenate_data(train_set)

    # PCA whiten
    E, Ewhiten, Eunwhiten = _pca_whitening(X, NUM_PRINCIPAL_FEATURES, SIGMA_PC)

    # Sanity-check plots. Only the first 301 columns are displayed, so the
    # reconstructions are computed on that slice instead of on all of X.
    # (301 equals the last dimension of a sample in the recorded 'Shape:' output.)
    preview = X[:, :301]
    preview_rec = E.dot(E.T.dot(preview))
    preview_rec2 = Eunwhiten.dot(Ewhiten).dot(preview)

    plt.figure(figsize=(16, 32))
    plt.subplot(1, 3, 1)
    plt.imshow(preview, cmap='hot')
    plt.title('Original')
    plt.subplot(1, 3, 2)
    plt.imshow(preview_rec, cmap='hot')
    plt.title('PCA reconstruction')
    plt.subplot(1, 3, 3)
    plt.imshow(preview_rec2, cmap='hot')
    plt.title('Whiten + unwhiten')
    plt.show()

    with open('pca_whiten_mat.npy', 'wb') as f:
        np.save(f, Ewhiten)
assert Ewhiten is not None

In [5]:
# No batch_size is given, so each batch holds a single sample (the DataLoader
# default); the training loop relies on this when it reads `x[0]`.
loader_options = dict(num_workers=9, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
test_loader = DataLoader(test_set, **loader_options)

In [6]:
def _whiten_segment(segment, whiten_mat, device):
    """Project one segment with `whiten_mat` and flatten it to a (1, N) float tensor on `device`."""
    whitened = whiten_mat.dot(segment.squeeze())
    return torch.from_numpy(whitened).float().reshape(1, -1).to(device)


def _mean_segment_error(loader, error_fn, whiten_mat, device, epoch, label):
    """Average `error_fn` over every segment served by `loader`, with a progress bar.

    `error_fn` receives one whitened segment and must return an object with
    `.item()`. The running mean is shown in the progress bar under `label`.
    Returns NaN when the loader yields no segments.
    """
    total, count = 0.0, 0
    with tqdm_notebook(iterable=loader, total=len(loader)) as progress_batch:
        for x, _ in progress_batch:
            # Only x[0] is read: the loop assumes one sample per batch.
            for segment in x[0]:
                total += error_fn(_whiten_segment(segment, whiten_mat, device)).item()
                count += 1
            if count:
                progress_batch.set_postfix(
                    **{'epoch': epoch, label: '{:.3f}'.format(total / count)})
    return total / count if count else float('nan')


def train_dbn_layer(dbn, layer, epochs, train_loader, test_loader, device=None,
                    initial_momentum=INITIAL_MOMENTUM, final_momentum=FINAL_MOMENTUM,
                    epochs_for_initial_momentum=EPOCHS_FOR_INITIAL_MOMENTUM,
                    whiten_mat=None, checkpoint_dir='dbn_checkpoints'):
    """Train one layer of `dbn` and evaluate it after every epoch.

    For each epoch the momentum of ``dbn.rbms[layer]`` is set, ``dbn.train``
    is called once per whitened segment of the training loader, and
    ``dbn.reconstruct(x, layer + 1)[1]`` is averaged over the test loader.
    The whole `dbn` object is saved every 10th epoch and one line per epoch is
    appended to ``dbn_train.log``.

    Args:
        dbn: Object exposing ``rbms``, ``train`` and ``reconstruct``.
        layer: Index of the layer (RBM) to train.
        epochs: Number of epochs; epochs are numbered from 1.
        train_loader: Loader yielding ``(x, _)`` batches for training.
        test_loader: Loader yielding ``(x, _)`` batches for evaluation.
        device: Target passed to ``Tensor.to`` for every input.
        initial_momentum: Momentum used while ``epoch < epochs_for_initial_momentum``.
        final_momentum: Momentum used from that epoch on.
        epochs_for_initial_momentum: Epoch at which the momentum switches.
        whiten_mat: Whitening matrix applied to each segment; defaults to the
            notebook-level ``Ewhiten``.
        checkpoint_dir: Directory receiving the checkpoint files.
    """
    if whiten_mat is None:
        whiten_mat = Ewhiten

    # Create the checkpoint directory up front: otherwise the first save (at
    # epoch 10) fails with FileNotFoundError after ten epochs of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    with tqdm_notebook(range(1, 1 + epochs), total=epochs) as progress_epoch:
        for epoch in progress_epoch:
            if epoch < epochs_for_initial_momentum:
                dbn.rbms[layer].momentum = initial_momentum
            else:
                dbn.rbms[layer].momentum = final_momentum

            # Train
            loss = _mean_segment_error(
                train_loader,
                lambda v: dbn.train(v, layer, k=1, epoch=epoch),
                whiten_mat, device, epoch, 'loss')

            # Test: reconstruction only, so gradient tracking is not needed.
            with torch.no_grad():
                test_loss = _mean_segment_error(
                    test_loader,
                    lambda v: dbn.reconstruct(v, layer + 1)[1],
                    whiten_mat, device, epoch, 'test_loss')

            progress_epoch.set_postfix(
                loss='{:.3f}'.format(loss),
                test_loss='{:.3f}'.format(test_loss))

            if epoch % 10 == 0:
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    'checkpoint_layer_{}_epoch_{}.pt'.format(layer, epoch))
                with open(checkpoint_path, 'wb') as f:
                    torch.save(dbn, f)
            with open('dbn_train.log', 'a') as f:
                f.write('Layer: {} Epoch: {:4d} Loss: {:.3f} Test Loss: {:.3f}\n'.format(
                    layer, epoch, loss, test_loss))

In [7]:
# The first RBM receives one whitened segment flattened to a vector, i.e.
# (rows of Ewhiten) * (frames per segment). With the recorded configuration
# that is 80 * 301 = 24080, the value previously hard-coded here; deriving it
# keeps the network in sync with the whitening matrix actually in use.
SEGMENT_FRAMES = 301  # last dimension of a sample in the recorded 'Shape:' output
VISIBLE_UNITS = Ewhiten.shape[0] * SEGMENT_FRAMES
dbn = DBN(VISIBLE_UNITS, [512, 256], device=DEVICE)

In [8]:
# Train the first RBM (layer index 0) for the configured number of epochs.
train_dbn_layer(dbn=dbn, layer=0, epochs=EPOCHS,
                train_loader=train_loader, test_loader=test_loader,
                device=DEVICE)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))




# Train the second RBM (layer index 1). This cell previously called
# `train_cdbn_layer(cdbn, ...)`; neither name is defined in this notebook, so
# it raised NameError. The function and model defined here are
# `train_dbn_layer` and `dbn`.
train_dbn_layer(dbn, 1, EPOCHS, train_loader, test_loader, device=DEVICE)

In [16]:
# Push the first segment of the first test sample through the same
# preprocessing as the training loop: whiten, cast to float, flatten to (1, N).
segment = test_set[0][0][0]
whitened = Ewhiten.dot(segment.squeeze())
x_ = torch.from_numpy(whitened).type(torch.FloatTensor).view(-1)[None, ...].to(DEVICE)

# Value range of the whitened input.
print(x_.max())
print(x_.min())

# Second element returned by the first RBM's `v2h` for this input.
# NOTE(review): the recorded output below is all zeros -- worth checking
# whether the hidden units ever activate.
dbn.rbms[0].v2h(x_)[1]

tensor(2.1037, device='cuda:0')
tensor(-2.4018, device='cuda:0')


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0.,