In [1]:
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

import torch
import numpy as np 

from preprocess import get_mnist, get_webcam
from train import TrainerVaDE

In [2]:
class Args:
    """Hyper-parameters for this run, used in place of an argparse.Namespace.

    Values are read as plain attributes (``args.batch_size`` etc.) by the
    data-loading cell and passed on to ``TrainerVaDE``.
    """

    # --- data -------------------------------------------------------------
    dataset = 'webcam'   # the loader cell branches on 'mnist' vs. the webcam set
    batch_size = 128

    # --- optimisation (consumed by TrainerVaDE, not visible here) ----------
    lr = 1e-5
    epochs = 200
    patience = 50        # presumably early-stopping patience — TODO confirm in TrainerVaDE

    # --- pre-training -----------------------------------------------------
    pretrain = True
    pretrained_path = 'weights/pretrained_parameter.pth'


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs this run relies on so "Restart & Run All" is reproducible.
# torch.manual_seed also seeds every CUDA device.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

args = Args()  # stand-in for command-line arguments (nothing is parsed here)

# Pick the loader and the number of classes for the configured dataset.
# An unknown name fails loudly instead of silently falling back to webcam.
if args.dataset == 'mnist':
    dataloader = get_mnist(batch_size=args.batch_size)
    n_classes = 10
elif args.dataset == 'webcam':
    dataloader = get_webcam(batch_size=args.batch_size)
    n_classes = 31
else:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'mnist' or 'webcam'")

In [3]:
# Build the VaDE trainer from the run config, the target device, the data loader and the class count.
vade = TrainerVaDE(args, device, dataloader, n_classes)

In [4]:
# Human-readable class names, indexed by label id, used to annotate the t-SNE colour bar.
# Branch on 'mnist' exactly like the data-loading cell does, so the names always
# describe the same dataset the loader actually returns.
if args.dataset == 'mnist':
    classes = [str(digit) for digit in range(10)]
else:
    # Webcam categories in label-index order — presumably the Office-31 "webcam"
    # domain in the folder order used by get_webcam (TODO confirm).
    classes = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle',
               'calculator', 'desk_chair', 'desk_lamp', 'desktop_computer',
               'file_cabinet', 'headphones', 'keyboard', 'laptop_computer',
               'letter_tray', 'mobile_phone', 'monitor', 'mouse', 'mug',
               'paper_notebook', 'pen', 'phone', 'printer', 'projector',
               'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker',
               'stapler', 'tape_dispenser', 'trash_can']

# One name per label id: the colour-bar labels are meaningless if these disagree.
assert len(classes) == n_classes, (len(classes), n_classes)


def get_latent_space(dataloader, z_dim, model, device, ftr_ext=None):
    """Encode every sample in ``dataloader`` into a latent code.

    Args:
        dataloader: iterable of ``(img, label)`` batches.
        z_dim: latent dimensionality; only used to shape the empty result when
            ``dataloader`` yields no batches.
        model: object exposing ``encode(x) -> (mu, log_var)`` and
            ``reparameterize(mu, log_var) -> z``.
        device: device the batches are moved to; the results live there too.
        ftr_ext: optional callable applied to each image batch before encoding.

    Returns:
        ``(z, y)``: the stacked ``reparameterize`` outputs (one row per sample)
        and the matching labels as a long tensor, both on ``device``.

    Any ``torch.nn.Module`` among ``model``/``ftr_ext`` is switched to eval mode
    for the pass and every submodule's previous mode is restored afterwards.
    NOTE(review): if ``reparameterize`` is mode-dependent, eval mode may make
    the codes deterministic — confirm against the VaDE implementation.
    """
    # Inference must not run BatchNorm/Dropout in training mode (BatchNorm would
    # even update its running stats under no_grad). Remember each submodule's
    # exact mode so mixed train/eval states are restored faithfully.
    saved_modes = [(sub, sub.training)
                   for mod in (model, ftr_ext) if isinstance(mod, torch.nn.Module)
                   for sub in mod.modules()]
    for mod in (model, ftr_ext):
        if isinstance(mod, torch.nn.Module):
            mod.eval()

    z_batches, y_batches = [], []
    try:
        with torch.no_grad():
            for img, label in dataloader:
                img, label = img.to(device).float(), label.to(device).long()
                if ftr_ext is not None:
                    img = ftr_ext(img)
                mu, log_var = model.encode(img)
                z_batches.append(model.reparameterize(mu, log_var))
                y_batches.append(label)
    finally:
        for sub, was_training in saved_modes:
            sub.training = was_training

    if not z_batches:
        return (torch.zeros((0, z_dim), dtype=torch.float, device=device),
                torch.zeros((0,), dtype=torch.long, device=device))
    # Concatenate once: growing a tensor with torch.cat per batch re-copies all
    # previous rows each iteration (quadratic in dataset size).
    return torch.cat(z_batches, dim=0), torch.cat(y_batches, dim=0)


def plot_tsne(X_embedded, y, ticks):
    """Scatter a 2-D embedding coloured by class, with a labelled colour bar.

    Args:
        X_embedded: array-like of shape (n_samples, 2), e.g. the output of
            ``TSNE(n_components=2).fit_transform``.
        y: integer class index per sample; a torch tensor (any device) or
            anything ``np.asarray`` accepts.
        ticks: sequence of class names; ``ticks[k]`` labels class index ``k``.
            Its length sets the number of colours, so it must cover every class
            index present in ``y``.
    """
    X_embedded = np.asarray(X_embedded)
    labels = y.detach().cpu().numpy() if torch.is_tensor(y) else np.asarray(y)
    n_classes = len(ticks)

    f, ax1 = plt.subplots(1, 1, figsize=(15, 5))

    # One discrete colour per class (was hard-coded to 31). Pinning vmin/vmax to
    # the full index range keeps colour k == class k even when some classes are
    # missing from ``labels``.
    cmap = plt.get_cmap('jet', n_classes)
    cax = ax1.scatter(X_embedded[:, 0], X_embedded[:, 1], c=labels,
                      s=15, cmap=cmap, vmin=-0.5, vmax=n_classes - 0.5)

    # Tick at the centre of each colour band, so the tick count always equals
    # the label count (a fixed 31 ticks broke the 10-class MNIST case).
    cbar = f.colorbar(cax, ticks=np.arange(n_classes))
    cbar.ax.set_yticklabels(ticks)

    ax1.xaxis.set_visible(False)
    ax1.yaxis.set_visible(False)

    plt.show()

In [5]:
# Run the trainer's full training loop; it logs per-epoch train loss and test loss / Acc (see output).
# NOTE(review): in the logged run the test Acc drifts down (~39 at epoch 0 -> ~34 at epoch 199)
# while both losses keep falling, i.e. the objective is not improving cluster assignment here.
vade.train()

Training VaDE...




Training VaDE... Epoch: 0, Loss: 13285.23353794643
Testing VaDE... Epoch: 0, Loss: 13223.004185267857, Acc: 39.119496855345915
Training VaDE... Epoch: 1, Loss: 13232.789481026786
Testing VaDE... Epoch: 1, Loss: 13080.526785714286, Acc: 39.49685534591195
Training VaDE... Epoch: 2, Loss: 13095.9482421875
Testing VaDE... Epoch: 2, Loss: 13018.60267857143, Acc: 39.49685534591195
Training VaDE... Epoch: 3, Loss: 13031.666573660714
Testing VaDE... Epoch: 3, Loss: 12960.36216517857, Acc: 39.119496855345915
Training VaDE... Epoch: 4, Loss: 12859.229910714286
Testing VaDE... Epoch: 4, Loss: 12798.50892857143, Acc: 38.86792452830189
Training VaDE... Epoch: 5, Loss: 12747.818777901786
Testing VaDE... Epoch: 5, Loss: 12743.317103794643, Acc: 39.24528301886793
Training VaDE... Epoch: 6, Loss: 12712.742466517857
Testing VaDE... Epoch: 6, Loss: 12685.273716517857, Acc: 38.86792452830189
Training VaDE... Epoch: 7, Loss: 12639.769949776786
Testing VaDE... Epoch: 7, Loss: 12514.07017299107, Acc: 39.1194

Training VaDE... Epoch: 65, Loss: 8926.201171875
Testing VaDE... Epoch: 65, Loss: 8892.917410714286, Acc: 37.735849056603776
Training VaDE... Epoch: 66, Loss: 8874.125279017857
Testing VaDE... Epoch: 66, Loss: 8894.3525390625, Acc: 37.23270440251572
Training VaDE... Epoch: 67, Loss: 8809.151646205357
Testing VaDE... Epoch: 67, Loss: 8856.907924107143, Acc: 37.61006289308176
Training VaDE... Epoch: 68, Loss: 8830.72349330357
Testing VaDE... Epoch: 68, Loss: 8859.548130580357, Acc: 37.9874213836478
Training VaDE... Epoch: 69, Loss: 8833.32798549107
Testing VaDE... Epoch: 69, Loss: 8800.64662388393, Acc: 37.10691823899371
Training VaDE... Epoch: 70, Loss: 8772.373325892857
Testing VaDE... Epoch: 70, Loss: 8722.0322265625, Acc: 36.477987421383645
Training VaDE... Epoch: 71, Loss: 8739.1005859375
Testing VaDE... Epoch: 71, Loss: 8735.564174107143, Acc: 37.484276729559745
Training VaDE... Epoch: 72, Loss: 8738.771205357143
Testing VaDE... Epoch: 72, Loss: 8702.98842075893, Acc: 37.4842767295

Training VaDE... Epoch: 130, Loss: 8057.105538504465
Testing VaDE... Epoch: 130, Loss: 8093.705636160715, Acc: 34.71698113207547
Training VaDE... Epoch: 131, Loss: 8053.224051339285
Testing VaDE... Epoch: 131, Loss: 8039.231515066965, Acc: 34.21383647798742
Training VaDE... Epoch: 132, Loss: 8049.927594866072
Testing VaDE... Epoch: 132, Loss: 8051.276785714285, Acc: 35.22012578616352
Training VaDE... Epoch: 133, Loss: 8076.154994419643
Testing VaDE... Epoch: 133, Loss: 8040.512276785715, Acc: 34.71698113207547
Training VaDE... Epoch: 134, Loss: 7985.143484933035
Testing VaDE... Epoch: 134, Loss: 8052.878138950893, Acc: 33.9622641509434
Training VaDE... Epoch: 135, Loss: 8030.164132254465
Testing VaDE... Epoch: 135, Loss: 8038.315848214285, Acc: 35.22012578616352
Training VaDE... Epoch: 136, Loss: 8066.943289620535
Testing VaDE... Epoch: 136, Loss: 8038.985212053572, Acc: 34.59119496855346
Training VaDE... Epoch: 137, Loss: 8070.670828683035
Testing VaDE... Epoch: 137, Loss: 8049.733468

Training VaDE... Epoch: 194, Loss: 7856.3837890625
Testing VaDE... Epoch: 194, Loss: 7864.360491071428, Acc: 32.83018867924528
Training VaDE... Epoch: 195, Loss: 7865.822684151785
Testing VaDE... Epoch: 195, Loss: 7867.193708147322, Acc: 33.9622641509434
Training VaDE... Epoch: 196, Loss: 7869.461774553572
Testing VaDE... Epoch: 196, Loss: 7822.260602678572, Acc: 33.9622641509434
Training VaDE... Epoch: 197, Loss: 7883.801339285715
Testing VaDE... Epoch: 197, Loss: 7833.176339285715, Acc: 32.83018867924528
Training VaDE... Epoch: 198, Loss: 7851.275809151785
Testing VaDE... Epoch: 198, Loss: 7879.449079241072, Acc: 34.59119496855346
Training VaDE... Epoch: 199, Loss: 7891.806780133928
Testing VaDE... Epoch: 199, Loss: 7808.944475446428, Acc: 33.83647798742138


In [None]:
# Embed the learned latent codes in 2-D with t-SNE and colour them by true class.
z_dim = 10       # latent size of vade.VaDE — TODO confirm it matches TrainerVaDE's config
TSNE_SEED = 0    # fixes t-SNE's random initialisation so the figure is reproducible

ftr_ext = vade.feature_extractor
model = vade.VaDE
z, y = get_latent_space(dataloader, z_dim, model, device, ftr_ext)
z, y = z.cpu(), y.cpu()
z_embedded = TSNE(n_components=2, random_state=TSNE_SEED).fit_transform(z.detach().numpy())
plot_tsne(z_embedded, y, classes)