In [None]:
%%capture
%load_ext autoreload
%autoreload 2
#Basic Imports
import os,sys
os.chdir("..")

from tqdm import tqdm,trange
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report,confusion_matrix, ConfusionMatrixDisplay
import torch
import pandas

from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
from utils import get_model_checkpoints
from utils import net_builder

def plot_examples(images, labels, encoding, prediction=None, max_images=32, ncols=8):
    """Draw a grid of example images, each titled with its class name(s).

    Args:
        images: sequence of images; each one is passed to ``Axes.imshow``. An image whose
            maximum value exceeds 1.5 is divided by 255 (treated as 0-255 data).
        labels: ground-truth class indices, aligned element-wise with ``images``.
        encoding: indexable lookup from class index to class name.
        prediction: optional predicted class indices aligned with ``images``. When given,
            each title shows the ground truth ("GT") and the prediction ("PR").
        max_images: maximum number of images to draw (default 32).
        ncols: number of grid columns (default 8; with the defaults the grid is 4x8).
    """
    # Enough rows to hold max_images at ncols per row (ceil division).
    nrows = max(1, (max_images + ncols - 1) // ncols)
    fig = plt.figure(figsize=(8, 5), dpi=150)
    for idx, img in enumerate(images[:max_images]):
        ax = fig.add_subplot(nrows, ncols, idx + 1, xticks=[], yticks=[])
        if np.max(img) > 1.5:
            # Heuristic: values above 1.5 mean 0-255 data; imshow expects floats in [0, 1].
            img = img / 255
        ax.imshow(img)
        if prediction is not None:
            title = "GT: " + encoding[labels[idx]] + "\n PR: " + encoding[prediction[idx]]
        else:
            title = encoding[labels[idx]]
        ax.set_title(str(title), fontsize=5)
        
def plot_cmatrix(pred, labels, encoding):
    """Plot the confusion matrix of predicted vs. ground-truth class indices.

    Args:
        pred: predicted class indices.
        labels: ground-truth class indices, same length as ``pred``.
        encoding: class names; position ``i`` names class index ``i``.
    """
    fig, ax = plt.subplots(figsize=(8, 5), dpi=150)
    # Fix the class order to 0..len(encoding)-1 so the matrix always has one row/column per
    # name in `encoding`, even if some class never occurs in `labels` or `pred`.
    cm = confusion_matrix(labels, pred, labels=list(range(len(encoding))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=encoding)
    # Draw into the axes created above; without ax=..., plot() opens a new figure of its own
    # and the sized figure would be left empty.
    disp.plot(ax=ax, xticks_rotation="vertical")

## Initialize parameters

In [None]:
# Directory passed to get_model_checkpoints() in the next cell (relative to the working
# directory set in the first cell).
path = "./saved_models/test/"

In [None]:
# get_model_checkpoints() returns the checkpoints found for `path` plus a sequence of argument
# dicts; only the first argument dict is used below.
# NOTE(review): presumably run_args[i] belongs to checkpoints[i] — verify in utils.
checkpoints, run_args = get_model_checkpoints(path)
if len(checkpoints) == 0 or len(run_args) == 0:
    # Fail with context instead of a bare IndexError from run_args[0] / checkpoints[0].
    raise FileNotFoundError(f"No model checkpoints found under {path!r}")
args = run_args[0]

In [None]:
# Notebook-specific overrides of the stored run arguments, applied in order.
for key, value in (
    ("batch_size", 256),
    ("data_dir", "./data/"),
    ("use_train_model", False),
):
    args[key] = value
# Evaluate the first checkpoint returned by get_model_checkpoints().
args["load_path"] = checkpoints[0]

## Load model and evaluation data

In [None]:
checkpoint_path = args["load_path"]
# Load tensors onto the CPU first so GPU-saved checkpoints also open on machines without
# CUDA; the model is moved to the GPU below when one is available.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Choose between the checkpoint's "train_model" and "eval_model" weights.
load_model = checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"]

_net_builder = net_builder(args["net"], None, {})

net = _net_builder(num_classes=args["num_classes"])
net.load_state_dict(load_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
net.eval()

_eval_dset = SSL_Dataset(name=args["dataset"], train=False, data_dir=args["data_dir"])
eval_dset = _eval_dset.get_dset()

eval_loader = get_data_loader(eval_dset, args["batch_size"], num_workers=1)

## Evaluate

In [None]:
# Class-name lookup (indexed by class id) and the dataset's inverse transform
# (presumably undoes input normalisation for display — confirm in SSL_Dataset).
label_encoding, inv_transf = _eval_dset.label_encoding, _eval_dset.inv_transform

In [None]:
# Preview raw evaluation images with their ground-truth class names.
plot_examples(images=eval_dset.data, labels=eval_dset.targets, encoding=label_encoding)

In [None]:
#Assemble a batch
images, labels, preds = [],[],[]
with torch.no_grad():
    for image, target in tqdm(eval_loader):
        image = image.type(torch.FloatTensor).cuda()
        logit = net(image)
        for idx,img in enumerate(image):
            images.append(inv_transf(img.transpose(0,2).cpu().numpy()).transpose(0,2).numpy())
        preds.append(logit.cpu().max(1)[1])
        labels.append(target)
labels = torch.cat(labels).numpy()
preds = torch.cat(preds).numpy()


In [None]:
# Show predictions for examples starting at index EXAMPLE_OFFSET. Images, labels and
# predictions must be sliced with the same offset so each GT/PR title matches its image.
EXAMPLE_OFFSET = 32
plot_examples(images[EXAMPLE_OFFSET:], labels[EXAMPLE_OFFSET:], label_encoding, preds[EXAMPLE_OFFSET:])

## Classification report on the evaluation (test) data

In [None]:
# Per-class metrics on the evaluation split (the dataset was loaded with train=False).
# Passing every class id keeps the report aligned with `target_names` even if some class
# is missing from both labels and preds.
class_ids = list(range(len(label_encoding)))
test_report = classification_report(
    labels, preds, labels=class_ids, target_names=label_encoding, output_dict=True
)
# Previous (misleading) name, kept so any cell still using it keeps working.
train_report = test_report
print(classification_report(labels, preds, labels=class_ids, target_names=label_encoding))

In [None]:
# Confusion matrix of predictions against ground truth on the evaluation split.
plot_cmatrix(pred=preds, labels=labels, encoding=label_encoding)

In [None]:
#Store results in a file
# Turn the report dict into a table, transposed so each top-level report key becomes a row.
import pandas as pd
test_frame = pd.DataFrame(test_report).T

In [None]:
# Store the results next to the evaluated checkpoint ("<checkpoint path>_test_results.csv").
# The original referenced FLAGS.ckpt, which is never defined in this notebook; the
# checkpoint being evaluated is args["load_path"].
results_path = os.fspath(args["load_path"]) + "_test_results.csv"
test_frame.to_csv(results_path)