In [1]:
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture

import torch
import numpy as np 

from preprocess import get_mnist, get_webcam
from train import TrainerVaDE

In [2]:
class Args:
    batch_size = 128
    dataset = 'mnist'
    n_shots = 1


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    
args = Args() # Parsing all the arguments for the training
if args.dataset == 'mnist':
    dataloader_sup, dataloader_unsup, dataloader_test = get_mnist(args)
    n_classes = 10
else:
    dataloader_sup, dataloader_unsup, dataloader_test = get_webcam(args)
    n_classes = 31

In [3]:
if args.dataset == 'webcam':
    from models_office import Autoencoder, feature_extractor, VaDE
    VaDE = VaDE().to(device)
    autoencoder = Autoencoder().to(device)
    autoencoder.load_state_dict(torch.load('weights/autoencoder_parameters_webcam.pth.tar',
                                    map_location=device)['state_dict'])
    
    checkpoint = torch.load('weights/feature_extractor_params.pth.tar',
                             map_location=device)
    feature_extractor = feature_extractor().to(device)
    feature_extractor.load_state_dict(checkpoint['state_dict'])
    
elif args.dataset == 'mnist':
    from models import Autoencoder, VaDE
    VaDE = VaDE().to(device)
    autoencoder = Autoencoder().to(device)
    autoencoder.load_state_dict(torch.load('weights/autoencoder_parameters_mnist.pth.tar',
                                    map_location=device)['state_dict'])

In [4]:
def get_latent_space(dataloader, z_dim, model, device, ftr_ext=None):
    z = torch.zeros((1, z_dim)).float().to(device)
    y = torch.zeros((1)).long().to(device)
    with torch.no_grad():
        for img, label in dataloader:
            img, label = img.to(device).float(), label.to(device).long()
            if ftr_ext is not None:
                img = ftr_ext(img); img = img.detach()

            z_l = model.encode(img)
            y = torch.cat((y, label), dim=0)
            z = torch.cat((z, z_l), dim=0)
    return z[1:], y[1:]

In [5]:
z_dim = 10

z, _ = get_latent_space(dataloader_unsup, z_dim, autoencoder, device)
z = z.cpu()
gmm = GaussianMixture(n_components=n_classes, covariance_type='diag')
gmm.fit(z.cpu().detach().numpy())

GaussianMixture(covariance_type='diag', init_params='kmeans', max_iter=100,
        means_init=None, n_components=10, n_init=1, precisions_init=None,
        random_state=None, reg_covar=1e-06, tol=0.001, verbose=0,
        verbose_interval=10, warm_start=False, weights_init=None)

In [6]:
z_dim = 10
z, y = get_latent_space(dataloader_sup, z_dim, autoencoder, device)
z = z.cpu()

In [7]:
z = z[np.argsort(y.cpu())]
y = y[np.argsort(y.cpu())]

In [8]:
probas = gmm.predict_proba(z.cpu().detach().numpy())

In [9]:
mean_probas = []
if args.n_shots>1:
    for i in range(n_classes):
        ixs = np.where(y.cpu()==i)
        print(np.mean(probas[ixs], axis=0))
        mean_probas.append(np.mean(probas[ixs], axis=0))
    probas = np.array(mean_probas)
    

In [10]:
assignation = []
possibilities = np.arange(n_classes)
index = 0
toselect = 1
while len(possibilities)>0:
    sorted_ = np.argsort(probas[index])
    max_ = sorted_[-toselect]
    print('class {} has a top {} prob of {} in index {}'.format(index, toselect, probas[index][max_], max_))
    if max_ in possibilities:
        assignation.append(max_)
        possibilities = np.setdiff1d(possibilities, max_)
        index+=1
        toselect=1
    else:
        toselect+=1

print(assignation)

class 0 has a top 1 prob of 0.7557658816968944 in index 3
class 1 has a top 1 prob of 0.8884203770174495 in index 4
class 2 has a top 1 prob of 0.9480040803017324 in index 5
class 3 has a top 1 prob of 0.7680914726217992 in index 7
class 4 has a top 1 prob of 0.694569128063975 in index 5
class 4 has a top 2 prob of 0.22937955771153504 in index 2
class 5 has a top 1 prob of 0.7824726577986791 in index 1
class 6 has a top 1 prob of 0.8990052425637904 in index 8
class 7 has a top 1 prob of 0.7202281136228281 in index 6
class 8 has a top 1 prob of 0.9774442556105314 in index 0
class 9 has a top 1 prob of 0.9935317281901505 in index 2
class 9 has a top 2 prob of 0.006027929744999055 in index 6
class 9 has a top 3 prob of 0.0001330048348857446 in index 7
class 9 has a top 4 prob of 0.00012585562999016775 in index 3
class 9 has a top 5 prob of 0.0001197818686216407 in index 5
class 9 has a top 6 prob of 3.6687613794110633e-05 in index 0
class 9 has a top 7 prob of 2.3413853821174337e-05 in in

In [11]:
"""
print('Saving weights.')
state_dict = autoencoder.state_dict()

VaDE.load_state_dict(state_dict, strict=False)
VaDE.pi_prior.data = torch.from_numpy(gmm.weights_[assignation]
                                          ).float().to(device)
VaDE.mu_prior.data = torch.from_numpy(gmm.means_[assignation]
                                          ).float().to(device)
VaDE.log_var_prior.data = torch.log(torch.from_numpy(gmm.covariances_[assignation]
                                        )).float().to(device)
torch.save(VaDE.state_dict(), 'weights/pretrained_parameters_{}.pth'.format(args.dataset))
"""

"\nprint('Saving weights.')\nstate_dict = autoencoder.state_dict()\n\nVaDE.load_state_dict(state_dict, strict=False)\nVaDE.pi_prior.data = torch.from_numpy(gmm.weights_[assignation]\n                                          ).float().to(device)\nVaDE.mu_prior.data = torch.from_numpy(gmm.means_[assignation]\n                                          ).float().to(device)\nVaDE.log_var_prior.data = torch.log(torch.from_numpy(gmm.covariances_[assignation]\n                                        )).float().to(device)\ntorch.save(VaDE.state_dict(), 'weights/pretrained_parameters_{}.pth'.format(args.dataset))\n"

In [15]:
np.argmax(probas, axis=0)

array([8, 5, 9, 0, 1, 2, 7, 3, 6, 0])

In [23]:
probas[0]

array([4.51196284e-11, 5.77842629e-09, 2.23419953e-12, 7.55765882e-01,
       2.07838035e-44, 5.24618197e-04, 2.56798625e-14, 5.95797379e-04,
       4.43854928e-08, 2.43113653e-01])

In [24]:
probas[4]

array([9.67202338e-12, 3.29614253e-11, 2.29379558e-01, 7.60512946e-02,
       3.93501937e-35, 6.94569128e-01, 4.21843734e-09, 2.97154530e-09,
       1.23684190e-08, 5.43901515e-16])

In [25]:
probas[2]

array([2.75766209e-12, 6.42434183e-08, 2.23527836e-09, 5.19921139e-02,
       2.09831709e-24, 9.48004080e-01, 2.18216338e-10, 2.45080829e-07,
       3.35991252e-06, 1.34111296e-07])