In [None]:
%%capture
%load_ext autoreload
%autoreload 2
#Basic Imports
import os,sys
# Work from the parent of the directory the kernel started in: the project
# imports below and the relative paths used later (experiments/, ./data/)
# appear to assume it. The start directory is recorded only once so that
# re-running this cell does not keep climbing the directory tree.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, ".."))

from tqdm import tqdm,trange
import numpy as np
from sklearn.metrics import classification_report
import torch

from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
from utils import get_model_checkpoints
from utils import net_builder
from utils import plot_examples, plot_cmatrix

## Initialize parameters

In [None]:
# Run directory of the experiment to evaluate. Built with os.path.join instead
# of a hard-coded "\\"-separated string so it also resolves outside Windows
# (on Windows the resulting string is unchanged).
path = os.path.join(
    "experiments",
    "ucm_runs",
    "ucm",
    "FixMatch_archefficientnet-b2_batch16_confidence0.95_lr0.03_uratio4_wd0.00075_wu1.0_seed0_numlabels105_optSGD",
)

In [None]:
# get_model_checkpoints yields the run's checkpoints plus an indexable
# collection of saved argument sets; only the first argument set is used
# (presumably one per checkpoint -- confirm against the helper).
checkpoints, saved_args = get_model_checkpoints(path)
args = saved_args[0]

In [None]:
# Evaluation-time overrides applied on top of the saved run arguments.
eval_overrides = {
    "batch_size": 256,
    "data_dir": "./data/",
    "use_train_model": False,  # False -> the "eval_model" weights are loaded below
}
for key, value in eval_overrides.items():
    args[key] = value
# Evaluate the first entry returned by get_model_checkpoints.
args["load_path"] = checkpoints[0]

## Load the checkpoint, evaluation data and model

In [None]:
# Load the saved checkpoint and select which weight set to evaluate.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the weights are copied into the network with load_state_dict and the
# network is moved to CUDA afterwards when one is available.
checkpoint_path = args["load_path"]
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# NOTE(review): torch >= 2.6 defaults to weights_only=True; if loading fails on
# non-tensor entries, pass weights_only=False -- only for trusted checkpoints.
load_model = checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"]

# Evaluation split (train=False) of the dataset recorded in the run arguments.
_eval_dset = SSL_Dataset(
    name=args["dataset"],
    train=False,
    data_dir=args["data_dir"],
    seed=args["seed"],
)
eval_dset = _eval_dset.get_dset()

# Instantiate the architecture named by args["net"], sized from the dataset,
# and restore the selected weights into it.
_net_builder = net_builder(args["net"], None, {})
net = _net_builder(
    num_classes=_eval_dset.num_classes,
    in_channels=_eval_dset.num_channels,
)
net.load_state_dict(load_model)

# Use the GPU only when one is present; eval() puts layers such as
# Dropout/BatchNorm into inference mode.
has_cuda = torch.cuda.is_available()
if has_cuda:
    net.cuda()
net.eval()



# Loader over the evaluation split. It uses args["batch_size"], the evaluation
# batch size set in the parameter cell; args["batch"] would bypass that override.
eval_loader = get_data_loader(eval_dset, args["batch_size"], num_workers=1)

## Run inference and inspect predictions

In [None]:
# Class names (passed as target_names to the report below) and the dataset's
# inverse transform (presumably undoes input normalisation for plotting).
label_encoding, inv_transf = _eval_dset.label_encoding, _eval_dset.inv_transform

In [None]:
# Examples straight from the evaluation dataset's stored data, with their
# ground-truth labels.
raw_eval_images, raw_eval_targets = eval_dset.data, eval_dset.targets
plot_examples(raw_eval_images, raw_eval_targets, label_encoding)

In [None]:
# Run the network over the whole evaluation loader, collecting inverse-transformed
# images (for plotting), ground-truth labels and predicted class indices.
images, labels, preds = [], [], []
# Send inputs to wherever the network's weights live (CUDA if it was moved
# there, CPU otherwise) instead of assuming a GPU is present.
device = next(net.parameters()).device
with torch.no_grad():
    for image, target in tqdm(eval_loader):
        image = image.float().to(device)
        logit = net(image)
        # One device->host transfer per batch instead of one per image.
        image_cpu = image.cpu()
        for img in image_cpu:
            # Axes 0 and 2 are swapped before inv_transf and swapped back after;
            # presumably inv_transf expects channels-last input -- TODO confirm.
            images.append(inv_transf(img.transpose(0, 2).numpy()).transpose(0, 2).numpy())
        preds.append(logit.argmax(dim=1).cpu())
        labels.append(target)
labels = torch.cat(labels).numpy()
preds = torch.cat(preds).numpy()

In [None]:
# Inverse-transformed images from the inference loop, annotated with both the
# true labels and the model's predicted classes.
plot_examples(
    images,
    labels,
    label_encoding,
    preds,
)

## Classification metrics on the evaluation split (`train=False`)

In [None]:
# Per-class precision / recall / F1 for the predictions collected above.
# NOTE(review): despite its name, train_report holds metrics for the evaluation
# split (the loader is built from SSL_Dataset(train=False)).
train_report = classification_report(labels, preds, target_names=label_encoding, output_dict=True)
report_text = classification_report(labels, preds, target_names=label_encoding)
print(report_text)

In [None]:
# Confusion matrix of the evaluation predictions; the name passed as
# save_fig_name is "<dataset>_<numlabels>_cm.png".
# NOTE(review): predictions are passed before true labels -- confirm this
# matches plot_cmatrix's parameter order.
cm_filename = f"{args['dataset']}_{args['numlabels']}_cm.png"
plot_cmatrix(
    preds,
    labels,
    label_encoding,
    figsize=(8, 5),
    dpi=150,
    font_scale=1.1,
    save_fig_name=cm_filename,
)