In [None]:
%matplotlib inline
import torch
from sebm.sgld import SGLD_sampler
from sebm.models import CEBM_GMM_2ss
# ---- Experiment configuration ------------------------------------------------
# Every value below feeds either the model constructor or the checkpoint name
# (`load_version`), so it has to match the settings of the checkpoint on disk.
dataset = 'cifar10'  # other values seen here: 'svhn', 'mnist', 'flowers102' (no config below)
if dataset in ('mnist', 'fashionmnist'):
    input_channels, im_height, im_width = 1, 28, 28
else:
    input_channels, im_height, im_width = 3, 32, 32
device = torch.device('cuda:0')
arch = 'simplenet2'
ss = 2
seed = 3
# lr, reg_alpha = 1e-4, 1e-1 #
lr, reg_alpha = 5e-5, 5e-3  # 1e-1

optimize_priors = True
num_clusters = 20
# Per-dataset convolutional stack and head sizes handed to the model constructor.
if dataset in ('cifar10', 'svhn'):
    channels, kernels, strides, paddings = [64, 128, 256, 512], [3, 4, 4, 4], [1, 2, 2, 2], [1, 1, 1, 1]
    hidden_dim, latent_dim, activation = [1024], 128, 'Swish'
elif dataset in ('mnist', 'fashionmnist'):
    channels, kernels, strides, paddings = [64, 64, 32, 32], [3, 4, 4, 4], [1, 2, 2, 2], [1, 1, 1, 1]
    hidden_dim, latent_dim, activation = [128], 128, 'Swish'
else:
    # The original raised the undefined name `NotImplementError`, which
    # surfaced as a NameError instead of the intended exception.
    raise NotImplementedError('No architecture configuration for dataset %r' % dataset)

data_noise_std = 3e-2
sgld_noise_std, sgld_lr, sgld_num_steps = 7.5e-3, 2.0, 60
buffer_init, buffer_dup_allowed = True, True
data_dir = '../../../sebm_data/'
# Identifier embedded in the checkpoint file name; built from the settings above.
load_version = 'cebm_gmm_k=%d-d=%s-seed=%d-lr=%s-zd=%d-d_ns=%s-sgld-ns=%s-lr=%s-steps=%s-reg=%s-act=%s-arch=%s' % (
    num_clusters, dataset, seed, lr, latent_dim, data_noise_std,
    sgld_noise_std, sgld_lr, sgld_num_steps, reg_alpha, activation, arch)
# Build the model from the configuration above; only 'simplenet2' is wired up.
if arch == 'simplenet2':
    ebm = CEBM_GMM_2ss(
        K=num_clusters,
        arch=arch,
        optimize_priors=optimize_priors,
        device=device,
        im_height=im_height,
        im_width=im_width,
        input_channels=input_channels,
        channels=channels,
        kernels=kernels,
        strides=strides,
        paddings=paddings,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        activation=activation,
    )
else:
    # The original raised the undefined name `NotImplementError` (a NameError).
    raise NotImplementedError('Unsupported architecture: %r' % arch)
ebm = ebm.cuda().to(device)
print('Loading trained weights..')
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a different GPU index.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = torch.load('../weights/weights/1012/cp-%s' % load_version, map_location=device)
ebm.load_state_dict(weights['model_state_dict'])
# if not optimize_priors:
#     ebm.prior_mu = weights['prior_mu'].to(device)
#     ebm.prior_log_sigma = weights['prior_log_sigma'].to(device)

In [None]:
from sebm.eval import *
# Evaluation wrapper around the loaded model; reused by the cells below.
# NOTE(review): data_noise_std is the literal 1e-2 here, not the notebook-level
# `data_noise_std = 3e-2` from the setup cell -- confirm this is intentional.
evaluator = Evaluator_EBM_GMM(
    ebm,
    device,
    dataset,
    data_dir,
    data_noise_std=1e-2,
)

In [None]:
# Run the `fewshots` helper on the evaluator built above, tagged 'cebm_gmm'.
fewshots(
    evaluator=evaluator,
    model_name='cebm_gmm',
)

In [None]:
# evaluator.oodauc(dataset_ood='texture', score='energy')

In [None]:
# similarity_z_space(evaluator, train_batch_size=2000, test_batch_size=2000, model_name='cebm_gmm')

In [None]:
# train_batch_size = 1000
# for i in range(10):
#     print('processing on dataset %d' % (i+1))
#     test_data = torch.load('/home/hao/Research/sebm_data/overviewfig/%s/100/%d.pt' % (evaluator.dataset, i+1))
#     min_distances, min_labels, nns = similarity_z_space_fewshots(evaluator, train_batch_size, test_data)
#     plot_nearest_neighbors(test_data, nns, min_labels, min_distances, fs=1, save_name='%s_cebm_gmm_%d' % (dataset, (i+1)))

In [None]:
# semi_nn_clf(model_name='cebm_gmm', device=device, evaluator=evaluator, num_epochs=1)

In [1]:
# Scratch check of two-decimal '%' formatting (the recorded output is 0.53).
print('%.2f' % (0.5255,))

0.53


In [None]:
# NOTE: `plot_confusion_matrices` is defined in the following cell, so that cell
# has to be executed before this one (a plain top-to-bottom run fails here).
import torch
import numpy as np
import matplotlib.pyplot as plt

# One figure per dataset, one panel per metric.
datasets = ['cifar10']
metrics = ['pixel', 'vae_z', 'igebm_z', 'cebm_z']
plot_confusion_matrices(datasets, metrics)

In [None]:
def plot_confusion_matrices(dataset, metrics, fs=5):
    """Save one row of confusion-matrix panels per dataset.

    For every dataset a figure with ``len(metrics)`` side-by-side panels is
    created, each panel is drawn by ``plot_confusion_matrix``, and the figure
    is written to ``confusion/figures/cm_<dataset>.png``.

    Args:
        dataset: A dataset name or an iterable of dataset names.
        metrics: Sequence of metric names, one panel each.
        fs: Size of one panel; the figure is ``len(metrics) * fs`` wide and
            ``fs`` high.
    """
    # The original ignored this argument and iterated the notebook-level global
    # `datasets`, which other cells rebind; use what the caller passed in.
    dataset_names = [dataset] if isinstance(dataset, str) else list(dataset)
    for name in dataset_names:
        fig = plt.figure(figsize=(len(metrics) * fs, fs))
        for panel, metric in enumerate(metrics, start=1):
            ax = fig.add_subplot(1, len(metrics), panel)
            plot_confusion_matrix(ax, name, metric, vmax=1.0)
        # The figure is left open on purpose so the inline backend still shows it.
        plt.savefig('confusion/figures/cm_%s.png' % name)

def plot_confusion_matrix(ax, dataset, metric, vmax=1.0):
    """Draw a row-normalised 10x10 confusion matrix for one metric on `ax`.

    Label pairs are loaded from
    ``confusion/logging/confusion_matrix_labels_<metric>_<dataset>.pt``. For
    each loaded row, column 0 selects the matrix row (labelled "True Class
    Labels") and column 1 the matrix column ("Neighbor Class Labels"). Each
    matrix row is divided by its total before plotting.

    Args:
        ax: Axes object to draw on.
        dataset: Dataset name used in the label file name.
        metric: One of 'pixel', 'vae_z', 'igebm_z', 'cebm_z', 'cebm_z2' or
            'cebm_gmm_z'; selects the label file and the panel title.
        vmax: Upper limit of the colour scale.

    Raises:
        NotImplementedError: If `metric` is not one of the names above.
    """
    # The original test `metric == 'cebm_z' or 'cebm_z2'` was always true, so
    # 'cebm_gmm_z' and the error branch could never be reached; an explicit
    # lookup table avoids that class of mistake.
    titles = {
        'pixel': 'Pixel',
        'vae_z': 'VAE',
        'igebm_z': 'IGEBM',
        'cebm_z': 'CEBM',
        'cebm_z2': 'CEBM',
        'cebm_gmm_z': 'CEBMM',
    }
    if metric not in titles:
        raise NotImplementedError('Unknown metric: %r' % (metric,))
    model = titles[metric]

    label_pairs = torch.load('confusion/logging/confusion_matrix_labels_%s_%s.pt' % (metric, dataset)).long()
    cm = torch.zeros((10, 10))
    # Converting to Python ints once is cheaper than indexing with tensors per row.
    for pair in label_pairs.tolist():
        cm[pair[0], pair[1]] += 1
    # Clamp row totals to 1 so a class with no samples yields zeros, not NaN.
    row_totals = cm.sum(dim=1, keepdim=True).clamp(min=1)
    cm = cm / row_totals

    ax.imshow(cm, cmap='inferno', vmax=vmax, vmin=0.0)
    ax.set_title('%s' % model, fontsize=20)
    ax.tick_params(labelsize=12)
    ticks = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    ax.set_xticks(np.arange(len(ticks)))
    ax.set_yticks(np.arange(len(ticks)))
    ax.set_xticklabels(ticks)
    ax.set_yticklabels(ticks)
    ax.set_ylabel('True Class Labels', fontsize=16)
    ax.set_xlabel('Neighbor Class Labels', fontsize=16)
    # Annotate each cell; diagonal entries use black text, the rest white.
    for i in range(len(ticks)):
        for j in range(len(ticks)):
            ax.text(j, i, round(cm[i, j].item(), 2),
                    ha="center", va="center", color=("k" if i == j else 'w'))

In [None]:
# Query evaluator.oodauc for every (OOD dataset, score) combination and print
# the returned value. NOTE: this rebinds the notebook-level name `datasets`
# that the confusion-matrix cell also assigns.
datasets = ['svhn', 'texture', 'constant_rgb']
scores = ['energy', 'gradient']
for d, s in [(name, kind) for name in datasets for kind in scores]:
    auroc = evaluator.oodauc(dataset_ood=d, score=s)
    print('dataset=%s, score=%s, auroc=%.3f' % (d, s, auroc))

In [None]:
# logging_interval = 200 # None otherwise
# test_batch_size = 10
# init_samples, labels = draw_one_batch(1, dataset, data_dir, train=False, normalize=True, flatten=False)

# images_ebm = evaluator.uncond_sampling_ll(sample_size=3,
#                                           batch_size=test_batch_size, 
#                                           sgld_steps=2000,
#                                           sgld_lr=2.0,
#                                           sgld_noise_std=1e-3,
#                                           init_images=init_samples,
#                                           grad_clipping=False,
#                                           logging_interval=None)
# nearestneighbours = evaluator.nn_latents(images_ebm)
# plot_evolving_samples(images_ebm, nearestneighbours, fs=8, save_name=None)

In [None]:
# evaluator.plot_all_samples(list_images=images_ebm, fs=1.0)

In [None]:
# Draw a batch of unconditional samples through the evaluator and plot the
# result without saving it (save=False).
images_ebm = evaluator.uncond_sampling(
    batch_size=100,
    sgld_steps=1000,
    sgld_lr=2.0,
    sgld_noise_std=1e-3,
    grad_clipping=False,
    init_samples=None,
    logging_interval=None,
)
evaluator.plot_final_samples(images_ebm, fs=5, save=False)

In [None]:
# evaluator.plot_all_samples(images_ebm, fs=6, save=True)

In [None]:
# train_logistic_classifier(evaluator)

In [None]:
# Compute the t-SNE outputs with the evaluator and pass both straight to its
# plotting method with save=True.
# NOTE(review): presumably `zs` are the embedded points and `ys` their labels --
# confirm against Evaluator_EBM_GMM.compute_tsne.
zs, ys = evaluator.compute_tsne()
evaluator.plot_tsne(zs, ys, save=True)