In [1]:
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

import torch
import numpy as np 

from preprocess import get_mnist, get_webcam
from train import TrainerVaDE

In [2]:
class Args:
    batch_size = 128
    dataset = 'mnist'
    pretrained_path = 'weights/pretrained_parameter.pth'


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


args = Args() # Parsing all the arguments for the training
if args.dataset == 'mnist':
    dataloader_train, dataloader_test = get_mnist(args)
    n_classes = 10
else:
    dataloader_train, dataloader_test = get_webcam(args)
    n_classes = 31

In [3]:
if args.dataset == 'webcam':
    from models_office import Autoencoder, feature_extractor, VaDE
    VaDE = VaDE().to(device)
    autoencoder = Autoencoder().to(device)
    autoencoder.load_state_dict(torch.load('weights/autoencoder_parameters_webcam.pth.tar',
                                    map_location=device)['state_dict'])
    
    checkpoint = torch.load('weights/feature_extractor_params.pth.tar',
                             map_location=device)
    feature_extractor = feature_extractor().to(device)
    feature_extractor.load_state_dict(checkpoint['state_dict'])
    
elif args.dataset == 'mnist':
    from models import Autoencoder, VaDE
    VaDE = VaDE().to(device)
    autoencoder = Autoencoder().to(device)
    autoencoder.load_state_dict(torch.load('weights/autoencoder_parameters_mnist.pth.tar',
                                    map_location=device)['state_dict'])

In [4]:
import numpy as np 


if args.dataset == 'webcam':
    classes = ['back_pack',
                'bike',
                'bike_helmet',
                'bookcase',
                'bottle',
                'calculator',
                'desk_chair',
                'desk_lamp',
                'desktop_computer',
                'file_cabinet',
                'headphones',
                'keyboard',
                'laptop_computer',
                'letter_tray',
                'mobile_phone',
                'monitor',
                'mouse',
                'mug',
                'paper_notebook',
                'pen',
                'phone',
                'printer',
                'projector',
                'punchers',
                'ring_binder',
                'ruler',
                'scissors',
                'speaker',
                'stapler',
                'tape_dispenser',
                'trash_can']
else:
    classes = ['0',
               '1',
               '2',
               '3',
               '4',
               '5',
               '6',
               '7',
               '8',
               '9']


def get_latent_space(dataloader, z_dim, model, device, ftr_ext=None):
    z = torch.zeros((1, z_dim)).float().to(device)
    y = torch.zeros((1)).long().to(device)
    with torch.no_grad():
        for img, label in dataloader:
            img, label = img.to(device).float(), label.to(device).long()
            if ftr_ext is not None:
                img = ftr_ext(img); img = img.detach()

            z_l = model.encode(img)
            y = torch.cat((y, label), dim=0)
            z = torch.cat((z, z_l), dim=0)
    return z[1:], y[1:]


def plot_tsne(X_embedded, y, ticks):
    f, ax1 = plt.subplots(1, 1, sharey=True, figsize=(15,10))

    cmap = plt.get_cmap('jet', 31)


    cax = ax1.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y.numpy(),
                      s=15, cmap=cmap)

    cbar = f.colorbar(cax, ticks=np.linspace(0,30,31))
    cbar.ax.set_yticklabels(ticks)

    ax1.xaxis.set_visible(False)
    ax1.yaxis.set_visible(False)

    plt.show()

In [5]:
z_dim = 10

z, y = get_latent_space(dataloader_train, z_dim, autoencoder, device)
z, y = z.cpu(), y.cpu()
#z_embedded = TSNE(n_components=2).fit_transform(z.detach().numpy())

In [6]:
#plot_tsne(z_embedded, y, classes)

In [7]:
means = []
var = []
proportion = []
for i in range(n_classes):
    ixs = np.where(y.cpu().numpy() == i)
    means.append(torch.mean(z[ixs].detach(), dim=0))
    var.append(torch.std(z[ixs].detach(), dim=0)**2)
    proportion.append(len(ixs))

In [8]:
means = torch.stack(means)
means

tensor([[-2.7225,  1.5110, -1.6059,  0.1572, -1.4409,  2.5070, -0.2138, -1.1498,
         -0.4085, -3.0455],
        [ 0.5270,  0.3365, -1.3417, -2.1833, -0.0318, -0.1525,  0.3218,  2.1061,
         -0.4019,  0.5672],
        [-1.2752,  1.0957, -1.7303, -0.7026, -1.9445,  0.0119, -0.8125,  2.1270,
          1.2704, -0.2549],
        [-0.1403,  0.9838, -0.9453, -0.3247, -0.3575,  2.7426, -1.7609,  1.1172,
          0.0941,  0.0744],
        [-0.2172, -0.0385, -0.9906, -0.2127, -0.1067,  0.8831,  2.7958,  0.5324,
          1.3305, -0.3235],
        [ 0.5567, -0.6318, -1.9012, -0.0334, -0.7016,  2.5153, -0.1153,  0.2810,
         -0.2717, -1.1620],
        [-1.7894, -0.4245, -0.3524, -2.4896, -0.7695,  1.4614,  0.7426,  0.2212,
          0.7004, -1.5768],
        [-0.3239,  2.3009, -1.5938,  0.5262, -0.7228,  1.1730,  1.8686,  1.1445,
         -1.1222,  0.7909],
        [-0.3210,  0.4511, -1.7597,  0.0295,  0.9434,  0.9218, -0.1502,  1.6866,
          0.5826, -0.9346],
        [-0.3713,  

In [9]:
var = torch.stack(var)
var

tensor([[1.4790, 1.3188, 3.9344, 2.2020, 3.5966, 1.1384, 1.1770, 1.3883, 1.2251,
         1.8249],
        [0.2681, 0.3172, 2.7721, 0.4057, 0.3317, 0.3890, 0.2415, 0.4374, 0.3806,
         0.3090],
        [2.2401, 1.3323, 2.4983, 1.8782, 2.0342, 1.0204, 1.2757, 1.8976, 1.5967,
         1.6970],
        [1.0718, 0.8199, 2.5741, 1.0362, 1.1774, 1.5073, 1.1107, 0.8607, 0.8168,
         0.8794],
        [0.7077, 0.5010, 2.9294, 1.2037, 1.0465, 0.9659, 0.7297, 0.7875, 0.9520,
         1.2278],
        [1.7126, 1.1105, 4.1666, 1.0752, 1.6680, 1.4329, 0.9751, 0.9790, 0.7626,
         0.9162],
        [1.6163, 1.4314, 2.0344, 1.4047, 1.4621, 0.9662, 0.5629, 0.9303, 1.3322,
         1.2190],
        [0.9252, 1.3466, 2.9539, 0.9515, 1.0446, 1.5429, 1.2181, 1.2913, 0.7892,
         0.8701],
        [0.6719, 0.8895, 2.7315, 0.9546, 0.7133, 0.7882, 0.6740, 0.9194, 0.7710,
         0.7292],
        [0.6304, 0.3315, 2.5566, 0.6841, 0.7056, 0.8292, 0.8711, 0.4241, 0.3767,
         1.1552]])

In [10]:
proportion = torch.Tensor(proportion)/torch.sum(torch.Tensor(proportion))
proportion

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])

In [11]:
state_dict = autoencoder.state_dict()

VaDE.load_state_dict(state_dict=state_dict, strict=False)
VaDE.pi_prior.data = proportion.float().to(device)
VaDE.mu_prior.data = means.float().to(device)
VaDE.log_var_prior.data = torch.log(var).float().to(device)
torch.save(VaDE.state_dict(), 'weights/pretrained_parameters_{}.pth'.format(args.dataset)) 

In [12]:
proportion

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])