In [None]:
import numpy as np

import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn

from tqdm import tqdm

from utilities import gpu_util as gp
from utilities import visualize as vs
from utilities import distros as ds
from utilities import base_models as bm

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials

from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LogisticRegression
from lightgbm import LGBMClassifier
from sklearn.metrics import accuracy_score, classification_report, \
                            confusion_matrix, roc_auc_score

from matplotlib import pyplot as plt

In [None]:
batch_size = 64

# CIFAR-10 images are upsampled to 64 px before being converted to tensors.
# NOTE(review): presumably 64 px is the input size bm.Encoder_Faces expects — confirm.
transform = transforms.Compose([transforms.Resize(64),
                                transforms.ToTensor()])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# The test loader is only used to extract features for evaluation, so shuffling
# adds nothing but run-to-run nondeterminism; keep it in dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Draw one random batch of 25 training images and show them on a 5x5 grid with
# their ground-truth labels (no predictions yet).
preview_loader = torch.utils.data.DataLoader(train_dataset, batch_size=25, shuffle=True)
data, labels = next(iter(preview_loader))
vs.plot_labels(
    data,
    labels,
    pred=None,
    lbl_dict=train_dataset.class_to_idx,
    fig_shape=(5, 5),
    figsize=(6, 6),
    up_fctr=2,
)

In [None]:
# Size arguments for bm.Encoder_Faces; they must agree with the tensor shapes
# stored in the checkpoint, or load_state_dict will reject it. num_dim is also
# passed to bm.apply_to_loader when encoding the datasets.
num_feat, num_dim = 512, 400

In [None]:
# Compute devices. `cpu` is used as map_location when loading checkpoints;
# `gpu` is where the encoder is moved. It falls back to CPU when CUDA is not
# available, so the notebook does not crash on CPU-only machines.
# NOTE(review): confirm bm.apply_to_loader does not hard-code CUDA internally.
cpu = torch.device('cpu')
use_gpu = torch.cuda.is_available()
gpu = torch.device('cuda') if use_gpu else cpu

if use_gpu:
    # get_device_name(0) raises without a CUDA device, so only query it here.
    print("Using GPU: {0:s} - {1:0.3f} GB".format(torch.cuda.get_device_name(0),
                                                  gp.get_gpu_memory_total() / 1000))
else:
    print("CUDA not available - using CPU")

In [None]:
# Rebuild the encoder, restore its trained weights (loaded onto the CPU first),
# then switch to inference mode and move it to the compute device.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# NOTE(review): the class name suggests a faces encoder; confirm this checkpoint
# is the one intended for CIFAR-10 features.
encoder = bm.Encoder_Faces(num_feat, num_dim)
encoder.load_state_dict(
    torch.load('../storage/aae_encoder_checkpoint_0300.pth', map_location=cpu)
)
encoder = encoder.eval().to(gpu)

In [None]:
# Run the frozen encoder over both splits. Each bm.apply_to_loader call returns a
# pair that is unpacked as (encoded features, true labels) — presumably in matching
# row order; confirm against bm.apply_to_loader.
(trn_enc, trn_real), (tst_enc, tst_real) = [
    bm.apply_to_loader(encoder, split_loader, num_dim)
    for split_loader in (train_loader, test_loader)
]

In [None]:
# Fit a multiclass LightGBM classifier on the encoder features.
clf = LGBMClassifier(objective='multiclass')
clf.fit(trn_enc, trn_real)

# One-hot encoder used to turn label vectors into indicator matrices for
# roc_auc_score. Categories are pinned to clf.classes_ so the indicator columns
# are in the same order as the columns of predict_proba.
# scikit-learn 1.2 renamed `sparse` to `sparse_output` (the old name was removed
# in 1.4); try the new keyword first and fall back for older versions.
try:
    ohe = OneHotEncoder(categories=[clf.classes_], sparse_output=False)
except TypeError:
    ohe = OneHotEncoder(categories=[clf.classes_], sparse=False)
ohe.fit(trn_real.reshape(-1, 1))

trn_pred = clf.predict(trn_enc)
trn_proba = clf.predict_proba(trn_enc)

tst_pred = clf.predict(tst_enc)
tst_proba = clf.predict_proba(tst_enc)

print(clf)

In [None]:
# Training-set metrics for the classifier on encoder features.
# Explicit class ids keep the report and confusion matrix aligned with the
# class-name tick labels even if some class is absent from y_true/y_pred.
class_ids = list(range(len(train_dataset.classes)))

print("{} RAW - TRAINING {}".format("="*15, "="*15))
print("Accuracy: {:5.2f} %".format(accuracy_score(trn_real, trn_pred) * 100))
print("AUROC: {:5.2f} %".format(roc_auc_score(ohe.transform(trn_real.reshape(-1,1)), trn_proba) * 100))
print(classification_report(trn_real, trn_pred, labels=class_ids,
                            target_names=train_dataset.classes))

fig, ax = plt.subplots(figsize = (12, 12))
im, _ = vs.heatmap(confusion_matrix(trn_real, trn_pred, labels=class_ids),
                   train_dataset.classes, train_dataset.classes, ax=ax, cmap="YlGnBu")
vs.annotate_heatmap(im, valfmt="{x:.0f}", fontdict = vs.font)
ax.set_title("Confusion matrix - training")

fig.tight_layout()
plt.show()

In [None]:
# Test-set metrics for the classifier on encoder features.
# Explicit class ids keep the report and confusion matrix aligned with the
# class-name tick labels even if some class is absent from y_true/y_pred.
class_ids = list(range(len(train_dataset.classes)))

print("{} RAW - TESTING {}".format("="*15, "="*15))
print("Accuracy: {:5.2f} %".format(accuracy_score(tst_real, tst_pred) * 100))
print("AUROC: {:5.2f} %".format(roc_auc_score(ohe.transform(tst_real.reshape(-1,1)), tst_proba) * 100))
print(classification_report(tst_real, tst_pred, labels=class_ids,
                            target_names=train_dataset.classes))

fig, ax = plt.subplots(figsize = (12, 12))
im, _ = vs.heatmap(confusion_matrix(tst_real, tst_pred, labels=class_ids),
                   train_dataset.classes, train_dataset.classes, ax=ax, cmap="YlGnBu")
vs.annotate_heatmap(im, valfmt="{x:.0f}", fontdict = vs.font)
ax.set_title("Confusion matrix - testing")

fig.tight_layout()
plt.show()

In [None]:
# Sample 25 random test images, classify them, and plot each image with its
# true label and the predicted label.
tmploader = torch.utils.data.DataLoader(test_dataset, batch_size=25, shuffle=True)
data, labels = next(iter(tmploader))

# Wrap the sampled batch so it can go through bm.apply_to_loader. shuffle MUST
# be False: the predictions are plotted against `data`/`labels` in their original
# order, and a shuffling loader would permute the samples and misalign them.
mini_Dataset = torch.utils.data.TensorDataset(data, labels)
mini_loader = torch.utils.data.DataLoader(mini_Dataset, batch_size=len(mini_Dataset), shuffle=False)

mini_enc, mini_real = bm.apply_to_loader(encoder, mini_loader, num_dim)
# Guard the alignment the plot relies on: encoded rows follow the batch order.
assert np.array_equal(np.asarray(mini_real).ravel(), labels.numpy()), \
    "encoded mini-batch is not in the same order as the sampled images"
mini_pred = clf.predict(mini_enc)
mini_proba = clf.predict_proba(mini_enc)

vs.plot_labels(data, labels, pred = mini_pred, lbl_dict = train_dataset.class_to_idx, fig_shape = (5,5), figsize = (6,6), up_fctr = 2)