In [1]:
%matplotlib inline
import torch
from sebm.models import Encoder, Decoder
from torchvision import datasets, transforms
from sebm.data import load_data
# ---- Experiment configuration (read by every later cell) ----
dataset = 'mnist'  # other options seen in this notebook: 'svhn', 'cifar10', 'flowers102'

# The MNIST-family datasets are treated as 1x28x28; everything else as 3x32x32.
input_channels, im_height, im_width = (
    (1, 28, 28) if dataset in ('mnist', 'fashionmnist') else (3, 32, 32)
)

device = torch.device('cuda:1')
data_dir = '/home/hao/Research/sebm_data/'

# Hyper-parameters that identify the trained checkpoint.
arch = 'simplenet2'  # 'mlp'
activation = 'ReLU'
latent_dim = 128
lr = 1e-3
seed = 1
heldout_class = -1
reparameterized = True

# Checkpoint tag; the load step appends it to '../weights/cp-'.
load_version = 'vae-out=%s-d=%s-seed=%s-lr=%s-zd=%s-act=%s-arch=%s' % (
    heldout_class, dataset, seed, lr, latent_dim, activation, arch)
# Build the VAE encoder/decoder for the selected dataset.
# Only the 'simplenet2' architecture is configured in this notebook; any other
# arch/dataset combination is rejected with NotImplementedError (the previous
# `NotImplementError` was an undefined name and surfaced as a NameError).
if arch != 'simplenet2':
    raise NotImplementedError('arch=%s is not supported by this notebook' % arch)

# The two dataset families differ only in the conv channel widths, the decoder
# paddings and the decoder MLP output size; everything else is shared.
if dataset in ('cifar10', 'svhn'):
    conv_channels = [64, 128, 256, 512]
    dec_paddings = [1, 1, 1, 1]
    dec_mlp_output_dim = 8192  # hand-coded; presumably 512*4*4 flattened features -- TODO confirm
elif dataset in ('mnist', 'fashionmnist'):
    conv_channels = [64, 64, 32, 32]
    dec_paddings = [0, 0, 1, 1]
    dec_mlp_output_dim = 288  # hand-coded; presumably 32*3*3 flattened features -- TODO confirm
else:
    raise NotImplementedError('dataset=%s has no simplenet2 configuration' % dataset)

# Each constructor receives its own list objects (as in the original literals),
# so neither model can be affected if the other mutates its arguments.
enc = Encoder(arch=arch,
              reparameterized=reparameterized,
              im_height=im_height,
              im_width=im_width,
              input_channels=input_channels,
              channels=list(conv_channels),
              kernels=[3, 4, 4, 4],
              strides=[1, 2, 2, 2],
              paddings=[1, 1, 1, 1],
              hidden_dim=[128],
              latent_dim=latent_dim,
              activation=activation)

dec = Decoder(arch=arch,
              device=device,
              im_height=im_height,
              im_width=im_width,
              input_channels=input_channels,
              channels=list(conv_channels),
              kernels=[3, 4, 4, 4],
              strides=[1, 2, 2, 2],
              paddings=list(dec_paddings),
              mlp_input_dim=latent_dim,  ## TODO: hand-coded for now
              hidden_dim=[128],
              mlp_output_dim=dec_mlp_output_dim,
              activation=activation)
    
# Move both networks straight to the configured device. The former
# `.cuda().to(device)` first placed them on the default CUDA device and only
# then moved them to `device`.
enc = enc.to(device)
dec = dec.to(device)
print('Loading trained models...')
# Deserialize the checkpoint once (it holds both state dicts) and map its
# tensors onto `device`, so loading does not depend on which GPU it was saved from.
checkpoint = torch.load('../weights/cp-%s' % load_version, map_location=device)
enc.load_state_dict(checkpoint['enc_state_dict'])
dec.load_state_dict(checkpoint['dec_state_dict'])
# for p in enc.parameters():
#     p.requires_grad = False
# for p in dec.parameters():
#     p.requires_grad = False

Loading trained models...


<All keys matched successfully>

In [2]:
from sebm.eval import *
# Bundle the trained encoder/decoder with the run settings; the later cells call
# methods on this object (plot_oods, test_one_batch, plot_samples).
# NOTE(review): Evaluator_VAE is presumably provided by the star import from sebm.eval -- confirm.
evaluator = Evaluator_VAE(enc, dec, arch, device, dataset, data_dir)

In [None]:
from tqdm import tqdm
# Sweep label propagation over every (num_shots, n_neighbors) pair and append
# one result line per run to 'label_propagation_accuracy.txt'.
# NOTE(review): num_shots is presumably the number of labelled examples per
# class; its exact meaning is defined by label_propagation -- confirm.
NUM_SHOTs = [1, 10, 100]
NUM_NEIGHBORs = [1, 5, 10]
# Note: this rebinds the notebook-level `seed` (same value, 1).
algo_name, seed, gamma, max_iter = 'lp', 1, 10, 30

for num_shots in NUM_SHOTs:
    for n_neighbors in tqdm(NUM_NEIGHBORs):
        accu = label_propagation(algo_name, evaluator, num_shots, seed, gamma, n_neighbors, max_iter)
        # `with` guarantees the log file is closed even if formatting or
        # writing the line raises; results are appended, never overwritten.
        with open('label_propagation_accuracy.txt', 'a+') as fout:
            print('model=vae, accuracy=%s, algo=%s, seed=%d, gamma=%s, max_iter=%d, num_shots=%d, n_neighbors=%d' % (accu, algo_name, seed, gamma, max_iter, num_shots, n_neighbors), file=fout)

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# similarity_ebm_z_space(evaluator, 
#                        train_batch_size=5000, 
#                        test_batch_size=5000,
#                        model_name='vae')

In [None]:
# ys_test, pred_ys_test = evaluator.similarity_ebm_density_space(train_batch_size=200, test_batch_size=200)

In [None]:
# paired = torch.cat((ys_test.unsqueeze(-1), pred_ys_test), -1)
# torch.save(paired, 'confusion_matrix_labels_vae_z_%s.pt' % dataset)

In [None]:
# metrics =  ['vae_10z', 'vae_50z', 'vae_128z']#, 'cebm_z', 'igebm_z']
# datasets = ['mnist', 'cifar10']
# import torch
# import numpy as np
# import matplotlib.pyplot as plt
# for d in datasets:
#     for m in metrics:
#         plot_confusion_matrix(d, m)

In [None]:
def plot_confusion_matrix(dataset, metric, num_classes=10):
    """Plot and save a row-normalised confusion matrix from saved label pairs.

    Reads ``confusion/logging/confusion_matrix_labels_<metric>_<dataset>.pt``,
    a 2-D tensor whose column 0 is used as the true class and column 1 as the
    predicted class (any further columns are ignored). Each row of the count
    matrix is divided by its total so rows sum to 1; a class with no samples
    keeps an all-zero row instead of turning into NaN.

    Args:
        dataset: dataset name used in the input and output file names.
        metric: metric/model tag used in the input and output file names.
        num_classes: number of classes (matrix is num_classes x num_classes).

    Side effects:
        Creates a matplotlib figure and writes it to
        ``confusion/figures/cm_<dataset>_<metric>.png`` (the directory is
        created if missing).
    """
    # Local imports keep the function usable regardless of which notebook
    # cells have already been executed.
    import os
    import matplotlib.pyplot as plt

    # map_location='cpu' so the counts below work even if the pairs were saved
    # from a GPU tensor.
    pairs = torch.load('confusion/logging/confusion_matrix_labels_%s_%s.pt' % (metric, dataset),
                       map_location='cpu').long()

    # Count (true, predicted) occurrences in one vectorised scatter-add.
    cm = torch.zeros((num_classes, num_classes))
    cm.index_put_((pairs[:, 0], pairs[:, 1]), torch.ones(len(pairs)), accumulate=True)

    # Row-normalise. Counts are whole numbers, so clamping the row total to a
    # minimum of 1 only affects empty rows (0/1 = 0 rather than 0/0 = NaN).
    cm = cm / cm.sum(dim=1, keepdim=True).clamp(min=1)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    im = ax.imshow(cm.numpy(), cmap='inferno', vmax=1.0, vmin=0.0)
    ax.set_title('cm_%s_%s \n average diagonals=%.2f' % (dataset, metric, torch.diag(cm).mean().item()), fontsize=14)
    plt.colorbar(im)

    ticks = [str(c) for c in range(num_classes)]
    ax.set_xticks(list(range(num_classes)))
    ax.set_yticks(list(range(num_classes)))
    ax.set_xticklabels(ticks)
    ax.set_yticklabels(ticks)
    ax.set_ylabel('True Class Labels', fontsize=14)
    ax.set_xlabel('Predicted Class Labels', fontsize=14)

    # Write each cell's value on top of the heat map; diagonal entries are
    # drawn in black, off-diagonal entries in white.
    for row in range(num_classes):
        for col in range(num_classes):
            ax.text(col, row, round(cm[row, col].item(), 3),
                    ha="center", va="center", color=("k" if row == col else 'w'))

    os.makedirs('confusion/figures', exist_ok=True)
    plt.savefig('confusion/figures/cm_%s_%s.png' % (dataset, metric))

In [None]:
# num_shots = 100
# num_runs = 10
# Accu = []
# for i in range(num_runs):
#     print('dataset=%s, run=%d / %d' % (dataset, i+1, num_runs))
#     data = torch.load('/home/hao/Research/sebm_data/fewshots/%s/%d/%d.pt' % (dataset, num_shots*10, i+1))
#     accu = train_logistic_classifier(evaluator, train_data=data)
#     Accu.append(np.array([accu]))
# Accu = np.concatenate(Accu)
# print('mean=%.4f, std=%.4f' % (Accu.mean(), Accu.std()))

In [None]:
# evaluator.oodauc(dataset_ood='svhn', score='marginal', sample_size=100, batch_size=20)

In [None]:
# Call the evaluator's plot_oods routine with 'fashionmnist' as the dataset
# argument, train=False and the 'marginal' score.
# NOTE(review): presumably this plots out-of-distribution score densities for the
# held-out split and save=True writes the figure to disk; the semantics of
# sample_size/batch_size/density live in sebm.eval and are not visible here -- confirm.
evaluator.plot_oods(dataset='fashionmnist', train=False, score='marginal', sample_size=100, batch_size=20, density=True, save=True)

In [None]:
# import numpy as np
# test_batch_size = 10
# num_random_walks = 20
# test_batch, rand_inds = draw_one_batch(dataset, 
#                                          data_dir, 
#                                          train=False, 
#                                          normalize=False, 
#                                          flatten=False,
#                                          rand_inds=np.array([413, 659, 869, 653, 538,  97, 447, 694, 319, 860]))

# images_samples = evaluator.random_walks(num_random_walks, 
#                                         test_batch, 
#                                         sample=True)
# nearestneighbours = evaluator.nn_latents(images_samples)
# plot_evolving_samples(images_samples, nearestneighbours, fs=10)

In [None]:
# Ask the evaluator for one batch of 10 items and plot the two returned tensors.
# NOTE(review): presumably `images` are inputs and `recons` their VAE
# reconstructions, shown together by plot_samples -- confirm in sebm.eval.
images, recons = evaluator.test_one_batch(batch_size=10)
evaluator.plot_samples(images, recons)

In [None]:
# train_logistic_classifier(evaluator)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


# Grouped bar chart of classification accuracy: one group per dataset, one bar
# per model. Each accuracy list is ordered like `labels`.
labels = ['MNIST', 'FMNIST', 'SVHN','CIFAR10']
cebm = [0.9755, 0.8306, 0.5758, 0.4699]
vae = [0.9409, 0.7963, 0.3434, 0.4066]
mlp = [0.9724, 0.8599, 0.2559, 0.1002]
cnn = [0.9890, 0.9071, 0.9228, 0.6246]
x = np.arange(len(labels))  # the label locations
width = 0.2  # the width of the bars

fig, ax = plt.subplots(figsize=(16, 8))
colors = ['#0077BB', '#EE7733', '#009988', '#AA3377', '#BBBBBB', '#EE3377', '#DDCC77']
# Four bars per group: offsets of -1.5, -0.5, +0.5 and +1.5 bar widths keep the
# group centred on its tick (the previous -2, -1, 0, +1 offsets shifted every
# group half a bar to the left of its label).
rects1 = ax.bar(x - 1.5 * width, cebm, width, color=colors[0], label='CEBM')
rects2 = ax.bar(x - 0.5 * width, vae, width, color=colors[1], label='VAE')
rects3 = ax.bar(x + 0.5 * width, mlp, width, color=colors[2], label='MLP_CLF')
rects4 = ax.bar(x + 1.5 * width, cnn, width, color=colors[3], label='CNN_CLF')
# Axis labels, title, per-dataset tick labels and legend.
ax.set_ylabel('Accuracy', fontsize=18)
ax.set_xlabel('Datasets', fontsize=18)
ax.set_title('Logistic Regressor', fontsize=18)
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=18)
ax.legend(fontsize=14)
def autolabel(rects, axes=None):
    """Attach a text label above each bar in *rects*, displaying its height.

    Args:
        rects: iterable of bar objects exposing ``get_height()``, ``get_x()``
            and ``get_width()`` (e.g. the container returned by ``ax.bar``).
        axes: axes to annotate. When omitted, the notebook-level ``ax`` is
            used, which is the original behaviour.
    """
    # Resolve the global lazily so the function also works when an explicit
    # axes object is supplied and no notebook-level `ax` exists.
    target = ax if axes is None else axes
    for rect in rects:
        height = rect.get_height()
        target.annotate('{}'.format(height),
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),  # 3 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=14)
# Annotate every bar group with its value, then tighten the layout and write
# the finished chart to disk.
for bar_group in (rects1, rects2, rects3, rects4):
    autolabel(bar_group)
fig.tight_layout()
fig.savefig('Logistic_Regressor.png')