In [None]:
%%capture
%load_ext autoreload
%autoreload 2
#Basic Imports
import os,sys
import re
os.chdir("..")

from tqdm import tqdm,trange
import numpy as np
from sklearn.metrics import classification_report
import torch

from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
from utils import get_model_checkpoints
from utils import net_builder
from utils import plot_examples, plot_cmatrix

import pandas as pd
from termcolor import colored
from copy import deepcopy

## Initialize parameters

In [None]:
# Common root directory of the FixMatch training runs.
FIXMATCH_RESULTS_ROOT = "/scratch/fixmatch_results"

# Run directories (one per dataset); pick one of them as `path` in the next cell.
eurosat_rgb_path = (
    FIXMATCH_RESULTS_ROOT
    + "/new_runs/nr_of_labels/eurosat_rgb/"
    + "FixMatch_archefficientnet-b2_batch32_confidence0.95_lr0.03_uratio7_wd0.00075_wu1.0_seed0_numlabels50_optSGD"
)
ucm_path = (
    FIXMATCH_RESULTS_ROOT
    + "/runs_new_paper_version/nr_of_labels/ucm/"
    + "FixMatch_archefficientnet-b2_batch16_confidence0.95_lr0.03_uratio4_wd0.00075_wu1.0_seed0_numlabels105_optSGD"
)

In [None]:
# Run directory to evaluate (replace with ucm_path to evaluate the other run).
path = eurosat_rgb_path
# When True, the "Evaluate all test data" section re-evaluates the seed0/1/2 checkpoints of this run.
use_all_seeds = True

In [None]:
# get_model_checkpoints yields the checkpoint paths under `path` plus the stored run arguments;
# presumably one args dict per checkpoint — only the first one is used from here on.
checkpoints, run_args = get_model_checkpoints(path)
args = run_args[0]

In [None]:
# Evaluation-time overrides of the stored training arguments.
EVAL_BATCH_SIZE = 256
# The data loaders below read args["batch"]; overriding only "batch_size" (as before) left the
# training batch size in effect. Set both keys so either consumer sees the eval batch size.
args["batch"] = EVAL_BATCH_SIZE
args["batch_size"] = EVAL_BATCH_SIZE
args["data_dir"] = "./data/"
# Load checkpoint["eval_model"] rather than checkpoint["train_model"].
args["use_train_model"] = False
# Evaluate the first checkpoint found for the selected run.
args["load_path"] = checkpoints[0]

## Load the model and the test set

In [None]:
# Load the selected checkpoint and rebuild the network and test-set loader.
# Map tensors onto the available device so a GPU-saved checkpoint also loads on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_path = args["load_path"]
checkpoint = torch.load(checkpoint_path, map_location=device)
load_model = checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"]

_eval_dset = SSL_Dataset(name=args["dataset"], train=False, data_dir=args["data_dir"], seed=args["seed"])
eval_dset = _eval_dset.get_dset()

_net_builder = net_builder(args["net"], None, {})
net = _net_builder(num_classes=_eval_dset.num_classes, in_channels=_eval_dset.num_channels)
net.load_state_dict(load_model)
net.to(device)
net.eval()

eval_loader = get_data_loader(eval_dset, args["batch"], num_workers=1)

## Inspect predictions on the test set

In [None]:
# Inverse of the dataset's input transform (used below to turn model inputs back into plottable images)
inv_transf = _eval_dset.inv_transform
# Class names (later passed as target_names to classification_report)
label_encoding = _eval_dset.label_encoding

In [None]:
# Show sample test images with their ground-truth targets, using plot_examples' default layout.
plot_examples(
    eval_dset.data,
    eval_dset.targets,
    label_encoding,
)

In [None]:
# Run the model over the whole test loader (not just one batch), keeping inverse-transformed
# images for plotting together with the true labels and predicted class indices.
# Follow the device the network lives on instead of hard-coding .cuda().
device = next(net.parameters()).device
images, labels, preds = [], [], []
with torch.no_grad():
    for image, target in tqdm(eval_loader):
        image = image.float().to(device)
        logit = net(image)
        for img in image:
            # NOTE(review): assumes img is CHW and inv_transf returns a tensor (the outer
            # transpose/.numpy() rely on that) — confirm against SSL_Dataset.inv_transform.
            images.append(inv_transf(img.transpose(0, 2).cpu().numpy()).transpose(0, 2).numpy())
        preds.append(logit.cpu().max(1)[1])
        labels.append(target)
labels = torch.cat(labels).numpy()
preds = torch.cat(preds).numpy()

In [None]:
# Plot test images with true labels and model predictions; the figure file is named "<dataset>.png".
examples_fig_name = args["dataset"] + ".png"
plot_examples(
    images, labels, label_encoding,
    (10, 6), 160, 6,
    preds,
    examples_fig_name,
)

## Evaluate the full test set for every seed

In [None]:
# Re-evaluate the run once per training seed and collect per-seed predictions.
# After this cell `labels` and `preds` are lists holding one numpy array per seed.
SEEDS = [0, 1, 2]

if use_all_seeds:
    # Fall back to CPU instead of failing on map_location='cuda:0' when no GPU is present.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The builder does not depend on the seed, so create it once.
    _net_builder = net_builder(args["net"], False, {})
    labels, preds = [], []

    for seed in SEEDS:
        print("Processing seed:", colored(seed, "red"))
        # Swap every "seed<digits>" token for the current seed. The old version read only the single
        # character after the first "seed", which breaks for multi-digit seeds such as seed10.
        checkpoint_path = re.sub(r"seed\d+", "seed" + str(seed), checkpoints[0])
        print(checkpoint_path)

        checkpoint = torch.load(checkpoint_path, map_location=device)
        load_model = checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"]
        _eval_dset = SSL_Dataset(name=args["dataset"], train=False, data_dir=args["data_dir"], seed=seed)
        eval_dset = _eval_dset.get_dset()
        net = _net_builder(num_classes=_eval_dset.num_classes, in_channels=_eval_dset.num_channels)
        net.load_state_dict(load_model)
        net.to(device)
        net.eval()
        eval_loader = get_data_loader(eval_dset, args["batch"], num_workers=1)

        # Predict over the whole test loader for this seed.
        labels_seed, preds_seed = [], []
        with torch.no_grad():
            for image, target in tqdm(eval_loader):
                logit = net(image.float().to(device))
                preds_seed.append(logit.cpu().max(1)[1])
                labels_seed.append(target)

        preds.append(torch.cat(preds_seed).numpy())
        labels.append(torch.cat(labels_seed).numpy())
    

In [None]:
# Per-class precision / recall / F1 / support table, written to <dataset>_<numlabels>_test_results.csv.
# With use_all_seeds, each metric is the mean over the per-seed reports.
# Aggregate entries of classification_report are dropped by name so that only per-class rows remain.
_REPORT_SUMMARY_KEYS = {"accuracy", "macro avg", "weighted avg", "micro avg"}
_REPORT_METRICS = ("precision", "recall", "f1-score", "support")


def _per_class_report(y_true, y_pred):
    """Return classification_report as a dict restricted to its per-class entries."""
    report = classification_report(y_true, y_pred, target_names=label_encoding, output_dict=True)
    return {name: scores for name, scores in report.items() if name not in _REPORT_SUMMARY_KEYS}


if use_all_seeds:
    # `labels` / `preds` hold one array per seed here.
    test_report_list = [
        _per_class_report(labels_seed, preds_seed) for labels_seed, preds_seed in zip(labels, preds)
    ]
    test_report = {
        name: {
            metric: float(np.mean([seed_report[name][metric] for seed_report in test_report_list]))
            for metric in _REPORT_METRICS
        }
        for name in test_report_list[0]
    }
else:
    # Previously sliced the returned dict with [:-3], which raises TypeError.
    test_report = _per_class_report(labels, preds)

df = pd.DataFrame(test_report)
display(df)
df.to_csv("./" + str(args["dataset"]) + "_" + str(args["numlabels"]) + "_test_results.csv")


In [None]:
# Confusion matrix of predictions against true labels; file name passed as save_fig_name.
# NOTE(review): after the all-seeds cell, `preds` / `labels` are lists with one array per seed —
# confirm plot_cmatrix accepts that, or concatenate them first.
cm_fig_name = str(args["dataset"]) + "_" + str(args["numlabels"]) + "_cm.png"
plot_cmatrix(
    preds,
    labels,
    label_encoding,
    figsize=(10, 8),
    dpi=150,
    class_names_font_scale=1.6,
    matrix_font_size=12,
    save_fig_name=cm_fig_name,
)