In [1]:
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

import torch
import numpy as np 

from preprocess import get_mnist, get_webcam
from train import TrainerVaDE

In [2]:
class Args:
    """Hard-coded run configuration used in place of parsed CLI arguments.

    An instance is handed to the data-loading helpers (``get_mnist`` /
    ``get_webcam``) and to ``TrainerVaDE``; how each field is consumed is
    decided inside those project modules, which are not visible here.
    """

    # --- data -------------------------------------------------------------
    # 'mnist' selects the MNIST loaders; any other value selects the webcam ones.
    dataset = 'webcam'
    batch_size = 128
    # presumably the number of labelled samples per class -- TODO confirm in preprocess
    n_shots = 10

    # --- optimisation -----------------------------------------------------
    lr = 1e-5
    epochs = 200
    # presumably toggles a pre-training stage in TrainerVaDE -- TODO confirm
    pretrain = True

    # --- loss weighting ---------------------------------------------------
    # presumably multipliers for supervised / clustering loss terms -- TODO confirm in train
    sup_mul = 0.9
    cl_mul = 100


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Configuration comes from the hard-coded Args class; nothing is parsed here.
args = Args()

# 'mnist' picks the MNIST loaders (10 classes); every other dataset value
# falls through to the webcam loaders (31 classes).
dataloader_sup, dataloader_unsup, dataloader_test = (
    get_mnist(args) if args.dataset == 'mnist' else get_webcam(args)
)
n_classes = 10 if args.dataset == 'mnist' else 31

In [3]:
# Build the trainer from the configuration, the target device, the three
# data loaders and the number of classes chosen in the previous cell.
vade = TrainerVaDE(
    args,
    device,
    dataloader_sup,
    dataloader_unsup,
    dataloader_test,
    n_classes,
)

In [4]:
# Human-readable class names, used as colorbar tick labels in the t-SNE plot.
#
# The branch mirrors the loader selection earlier in the notebook, where
# 'mnist' means 10 classes and *every other* value means the 31-class webcam
# data. Previously this cell tested for 'webcam' instead, so any third dataset
# name produced 31 classes but only 10 digit labels.
if args.dataset == 'mnist':
    classes = [str(digit) for digit in range(10)]
else:
    # Category names in alphabetical order.
    # NOTE(review): assumes label index i corresponds to classes[i] -- confirm
    # against the label encoding used by get_webcam.
    classes = [
        'back_pack',
        'bike',
        'bike_helmet',
        'bookcase',
        'bottle',
        'calculator',
        'desk_chair',
        'desk_lamp',
        'desktop_computer',
        'file_cabinet',
        'headphones',
        'keyboard',
        'laptop_computer',
        'letter_tray',
        'mobile_phone',
        'monitor',
        'mouse',
        'mug',
        'paper_notebook',
        'pen',
        'phone',
        'printer',
        'projector',
        'punchers',
        'ring_binder',
        'ruler',
        'scissors',
        'speaker',
        'stapler',
        'tape_dispenser',
        'trash_can',
    ]

# Guard the invariant the plotting code relies on: one label per class.
assert len(classes) == n_classes, 'class label list does not match n_classes'


def get_latent_space(dataloader, z_dim, model, device, ftr_ext=None):
    """Encode every batch of ``dataloader`` and collect latent codes and labels.

    Args:
        dataloader: iterable yielding ``(img, label)`` tensor batches.
        z_dim: width of the latent codes; fixes the shape of the result (also
            for an empty dataloader) and must match the second dimension of
            what ``model.reparameterize`` returns.
        model: object exposing ``encode(img) -> (mu, log_var)`` and
            ``reparameterize(mu, log_var)``.
        device: torch device the batches are moved to before encoding.
        ftr_ext: optional callable applied to each image batch before
            ``model.encode``; its output is detached.

    Returns:
        Tuple ``(z, y)``: the codes of all batches concatenated along dim 0
        and the matching labels as a 1-D long tensor, both on ``device``.
        An empty dataloader yields shapes ``(0, z_dim)`` and ``(0,)``.

    NOTE(review): the codes are whatever ``model.reparameterize`` returns; if
    that samples noise, repeated calls give different results -- confirm.
    """
    # Zero-length seeds pin dtype, device and trailing shape of the result
    # without needing a dummy first row that has to be sliced off afterwards.
    z_chunks = [torch.zeros((0, z_dim), dtype=torch.float, device=device)]
    y_chunks = [torch.zeros((0,), dtype=torch.long, device=device)]

    with torch.no_grad():
        for img, label in dataloader:
            img, label = img.to(device).float(), label.to(device).long()
            if ftr_ext is not None:
                img = ftr_ext(img).detach()

            mu, log_var = model.encode(img)
            z_chunks.append(model.reparameterize(mu, log_var))
            y_chunks.append(label)

        # One concatenation at the end instead of re-copying the accumulated
        # tensors on every batch.
        return torch.cat(z_chunks, dim=0), torch.cat(y_chunks, dim=0)


def plot_tsne(X_embedded, y, ticks, dataset):
    """Scatter-plot a 2-D embedding coloured by class, save it and show it.

    Args:
        X_embedded: array indexable as ``[:, 0]`` / ``[:, 1]`` holding the
            2-D coordinates of each point.
        y: per-point class indices; must provide ``.numpy()`` (e.g. a CPU
            torch tensor). Assumed to lie in ``[0, len(ticks))`` -- TODO
            confirm against the dataset's label encoding.
        ticks: sequence of class names; its length sets the number of
            discrete colours and of colorbar ticks.
        dataset: dataset name inserted into the output file name.

    The figure is written to ``weights/vade_tsne_<dataset>_ss.png``.
    """
    import os

    # Derive the class count from the labels instead of hard-coding 31, so the
    # colorbar ticks and tick labels always agree (e.g. 10 for MNIST).
    n_classes = len(ticks)

    fig, ax = plt.subplots(1, 1, figsize=(15, 10))

    # One discrete colour per class.
    cmap = plt.get_cmap('jet', n_classes)

    # Pin the colour range to the full class range so a class keeps its colour
    # even when some classes are absent from ``y``.
    points = ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y.numpy(),
                        s=15, cmap=cmap, vmin=0, vmax=n_classes - 1)

    cbar = fig.colorbar(points, ticks=np.arange(n_classes))
    cbar.ax.set_yticklabels(ticks)

    # The embedding axes carry no meaning, so hide them.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    # savefig does not create missing directories.
    os.makedirs('weights', exist_ok=True)
    fig.savefig('weights/vade_tsne_{}_ss.png'.format(dataset))
    plt.show()

In [None]:
# Run training; the recorded output below shows one "Training VaDE..." and one
# "Testing VaDE..." line per epoch.
# NOTE(review): in that log the test accuracy drops from ~74% (epoch 0) to
# ~34% (epoch 167) while the test loss rises from ~21 to ~47 -- worth
# investigating before relying on the final-epoch model.
vade.train()

Training VaDE...




Testing VaDE... Epoch: -1, Loss: 21.04513692855835, Acc: 73.26345915841584
Training VaDE... Epoch: 0, Loss: 4065.2872314453125, Acc: 70.0
Testing VaDE... Epoch: 0, Loss: 21.212724685668945, Acc: 74.20134591584159
Training VaDE... Epoch: 1, Loss: 4009.570556640625, Acc: 70.64516129032258
Testing VaDE... Epoch: 1, Loss: 21.345197200775146, Acc: 73.70629641089108
Training VaDE... Epoch: 2, Loss: 3963.1649169921875, Acc: 69.6774193548387
Testing VaDE... Epoch: 2, Loss: 21.65846872329712, Acc: 74.09692141089108
Training VaDE... Epoch: 3, Loss: 3910.8416748046875, Acc: 73.87096774193549
Testing VaDE... Epoch: 3, Loss: 21.803311824798584, Acc: 73.02947091584159
Training VaDE... Epoch: 4, Loss: 3863.6785888671875, Acc: 74.51612903225806
Testing VaDE... Epoch: 4, Loss: 22.179669857025146, Acc: 72.92504641089108
Training VaDE... Epoch: 5, Loss: 3814.556640625, Acc: 70.96774193548387
Testing VaDE... Epoch: 5, Loss: 22.33392906188965, Acc: 73.26345915841584
Training VaDE... Epoch: 6, Loss: 3757.98

Training VaDE... Epoch: 54, Loss: 2705.3489990234375, Acc: 55.80645161290323
Testing VaDE... Epoch: 54, Loss: 35.151594161987305, Acc: 57.80669863861387
Training VaDE... Epoch: 55, Loss: 2699.88623046875, Acc: 56.451612903225815
Testing VaDE... Epoch: 55, Loss: 35.23458003997803, Acc: 55.64472462871287
Training VaDE... Epoch: 56, Loss: 2689.6129150390625, Acc: 57.41935483870968
Testing VaDE... Epoch: 56, Loss: 35.367919921875, Acc: 54.91568688118812
Training VaDE... Epoch: 57, Loss: 2694.8812255859375, Acc: 54.516129032258064
Testing VaDE... Epoch: 57, Loss: 35.786593437194824, Acc: 54.838335396039604
Training VaDE... Epoch: 58, Loss: 2671.472900390625, Acc: 51.935483870967744
Testing VaDE... Epoch: 58, Loss: 36.06625747680664, Acc: 55.56737314356435
Training VaDE... Epoch: 59, Loss: 2682.850830078125, Acc: 55.80645161290323
Testing VaDE... Epoch: 59, Loss: 36.06271553039551, Acc: 55.553836633663366
Training VaDE... Epoch: 60, Loss: 2681.6112060546875, Acc: 52.903225806451616
Testing V

Testing VaDE... Epoch: 108, Loss: 42.56789779663086, Acc: 43.32650061881188
Training VaDE... Epoch: 109, Loss: 2553.79736328125, Acc: 47.41935483870968
Testing VaDE... Epoch: 109, Loss: 42.694539070129395, Acc: 41.80267636138614
Training VaDE... Epoch: 110, Loss: 2568.9383544921875, Acc: 37.41935483870968
Testing VaDE... Epoch: 110, Loss: 42.65861892700195, Acc: 41.62090037128713
Training VaDE... Epoch: 111, Loss: 2562.2939453125, Acc: 42.25806451612903
Testing VaDE... Epoch: 111, Loss: 42.680277824401855, Acc: 43.61270111386139
Training VaDE... Epoch: 112, Loss: 2565.1160888671875, Acc: 39.03225806451613
Testing VaDE... Epoch: 112, Loss: 42.80706214904785, Acc: 41.95931311881188
Training VaDE... Epoch: 113, Loss: 2567.2843017578125, Acc: 43.225806451612904
Testing VaDE... Epoch: 113, Loss: 42.84471797943115, Acc: 43.37871287128713
Training VaDE... Epoch: 114, Loss: 2561.0894775390625, Acc: 41.29032258064516
Testing VaDE... Epoch: 114, Loss: 43.110740661621094, Acc: 40.57858910891089
T

Training VaDE... Epoch: 162, Loss: 2521.3897705078125, Acc: 34.516129032258064
Testing VaDE... Epoch: 162, Loss: 47.03949546813965, Acc: 34.35372834158416
Training VaDE... Epoch: 163, Loss: 2528.0032958984375, Acc: 35.16129032258065
Testing VaDE... Epoch: 163, Loss: 46.8110237121582, Acc: 33.220529084158414
Training VaDE... Epoch: 164, Loss: 2525.6617431640625, Acc: 36.12903225806451
Testing VaDE... Epoch: 164, Loss: 46.7675666809082, Acc: 33.32495358910891
Training VaDE... Epoch: 165, Loss: 2522.4698486328125, Acc: 34.193548387096776
Testing VaDE... Epoch: 165, Loss: 46.938920974731445, Acc: 32.686803836633665
Training VaDE... Epoch: 166, Loss: 2515.173095703125, Acc: 35.483870967741936
Testing VaDE... Epoch: 166, Loss: 46.9022331237793, Acc: 33.57247834158416
Training VaDE... Epoch: 167, Loss: 2519.9840087890625, Acc: 32.903225806451616
Testing VaDE... Epoch: 167, Loss: 47.05932140350342, Acc: 34.45815284653465
Training VaDE... Epoch: 168, Loss: 2523.195068359375, Acc: 35.80645161290

In [None]:
# Latent width passed to get_latent_space.
# NOTE(review): magic number -- must equal the latent size of vade.VaDE;
# confirm against the model definition.
z_dim = 10
model = vade.VaDE

# Encode the test set (through the trainer's feature extractor) and move the
# results to the CPU so they can be converted to NumPy.
z, y = get_latent_space(dataloader_test, z_dim, model, device, vade.feature_extractor)
z, y = z.cpu(), y.cpu()

# Fixed seed so the 2-D layout is repeatable for a given set of latent codes.
# NOTE(review): the codes themselves may still vary between runs if
# model.reparameterize samples noise -- confirm.
z_embedded = TSNE(n_components=2, random_state=0).fit_transform(z.detach().numpy())


In [None]:
# Draw and save the t-SNE scatter of the test-set latent codes.
plot_tsne(X_embedded=z_embedded, y=y, ticks=classes, dataset=args.dataset)

In [None]:
# Convert the metric histories recorded on the trainer into NumPy arrays
# (the ``_t`` suffix marks the test-side series).
acc, acc_t, rec, rec_t, dkl, dkl_t = (
    np.array(getattr(vade, history))
    for history in ('acc', 'acc_t', 'rec', 'rec_t', 'dkl', 'dkl_t')
)

def plot_loss(values, values_t, metric, dataset):
    """Plot a train and a test metric curve on one figure and save it.

    Args:
        values: sequence of training values, plotted in black against its index.
        values_t: sequence of test values, plotted in blue against its index.
        metric: metric name; used for the title, the y-axis label and the
            output file name.
        dataset: dataset name inserted into the output file name.

    The figure is written to ``weights/vade_<metric>_<dataset>_ss`` (the file
    extension is left to matplotlib's default save format).

    NOTE(review): each curve is plotted against its own index. The training
    log shows a test evaluation at "Epoch: -1"; if that entry is stored in the
    test series, the two curves are offset by one epoch -- confirm.
    """
    import os

    # Use a dedicated figure so repeated calls never draw on top of a previous
    # plot that happens to be the current pyplot figure.
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(values)), values, c='k', label='train')
    ax.plot(np.arange(len(values_t)), values_t, c='b', label='test')
    ax.set_title('VaDE {}'.format(metric))
    ax.set_ylabel(metric)
    ax.set_xlabel('Epoch')
    ax.legend(loc='best')
    ax.grid(True)

    # savefig does not create missing directories.
    os.makedirs('weights', exist_ok=True)
    fig.savefig('weights/vade_{}_{}_ss'.format(metric, dataset))

In [None]:
# Train vs. test accuracy per recorded step.
plot_loss(values=acc, values_t=acc_t, metric='Accuracy', dataset=args.dataset)

In [None]:
# Train vs. test reconstruction term per recorded step.
plot_loss(values=rec, values_t=rec_t, metric='Reconstruction', dataset=args.dataset)

In [None]:
# Train vs. test DKL term; the first two entries of each series are dropped
# before plotting (reason not shown here -- presumably to hide early
# outliers; confirm).
plot_loss(values=dkl[2:], values_t=dkl_t[2:], metric='DKL', dataset=args.dataset)