In [None]:
import sys
sys.path.append('../')
import seaborn as sns 
from multiprocessing import Process
from omegaconf import OmegaConf
from src.utils.run_lib import *
import numpy as np 
from src.core.passive_learning import *
from src.core.auto_labeling import *
from src.datasets import dataset_factory 
from src.datasets.dataset_utils import * 
from src.utils.common_utils import * 
from src.utils.vis_utils import *
import copy 
import random 
from src.datasets.numpy_dataset import * 


# Paths, experiment configuration and logger for the MNIST / LeNet calibration experiment.
# Explicit imports: these names were previously only reachable through the star imports above.
import logging
import os

root_dir = '../'
# Trailing '' keeps the trailing separator, so conf_dir has the same value as before ('../configs/calib-exp/').
conf_dir = os.path.join(root_dir, 'configs', 'calib-exp', '')

base_conf_file = os.path.join(conf_dir, 'mnist_lenet_base_conf.yaml')
conf = OmegaConf.load(base_conf_file)

training_conf = OmegaConf.load(os.path.join(conf_dir, 'training_confs', 'fmfp_conf.yaml'))

# Output prefix for this experiment's runs.
# Built with os.path.join to avoid the '..//outputs' double slash that string concatenation produced.
# The trailing '' keeps a trailing separator, as the original prefix had.
run_name = 'mnist_lenet_calib-exp-runs'
root_pfx = os.path.join(root_dir, 'outputs', run_name, '')

conf['root_pfx'] = root_pfx
conf['training_conf'] = training_conf

# Log file lives under root_dir like every other path in this cell.
log_file = os.path.join(root_dir, 'temp', 'logs', 'act_lbl_test.log')
# Make sure the log directory exists; NOTE(review): get_logger may already do this — harmless if so.
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logger = get_logger(log_file, stdout_redirect=True, level=logging.DEBUG)


In [None]:
# Train a classifier with passive learning on a small labelled budget and report its test error.
import torch  # used for the weight norm below; previously only available via a star import

# Override the base config's training-point budget for this run.
SEED_TRAIN_SIZE = 10
MAX_NUM_TRAIN_PTS = 50

conf['train_pts_query_conf']['seed_train_size'] = SEED_TRAIN_SIZE
conf['train_pts_query_conf']['max_num_train_pts'] = MAX_NUM_TRAIN_PTS

set_seed(conf['random_seed'])

dm = DataManager(conf, logger, lib=conf['model_conf']['lib'])
print('train / val sizes:', len(dm.ds_std_train), len(dm.ds_std_val))

pl = PassiveLearning(conf, dm, logger)
out = pl.run()

w = pl.cur_clf.get_weights()
print('weight norm:', torch.norm(w))

test_err = get_test_error(pl.cur_clf, dm.ds_std_test, conf['inference_conf'])
print('test error:', test_err)

In [None]:
# Reliability diagram for the trained classifier on the test split.
NUM_CALIBRATION_BINS = 15

inf_out = pl.cur_clf.predict(dm.ds_std_test, conf['inference_conf'])
bin_data = compute_calibration(
    dm.ds_std_test.Y,
    inf_out['labels'],
    inf_out['confidence'],
    num_bins=NUM_CALIBRATION_BINS,
)

# plt.subplots() always creates a new figure.
# plt.subplot() would reuse whatever figure is currently active.
fig, ax = plt.subplots()
reliability_diagram_subplot(
    ax,
    bin_data,
    draw_ece=True,
    draw_bin_importance=False,
    title="Reliability Diagram",
    xlabel="Confidence",
    ylabel="Expected Accuracy",
    disable_labels=False,
);

In [None]:
# Compare the confidence distributions of correct vs. incorrect test predictions.
lenet_clf_model = pl.cur_clf.model
test_ds = dm.ds_std_test

# NOTE(review): unlike the cells above, this passes a minimal dict instead of conf['inference_conf'].
# Confirm that predict() needs no other keys (e.g. a batch size).
inference_conf = {'device': 'cpu'}
inf_out_test_1 = pl.cur_clf.predict(test_ds, inference_conf)

inf_out_test_1['true_labels'] = test_ds.Y
c_idx = inf_out_test_1['true_labels'] == inf_out_test_1['labels']
i_idx = inf_out_test_1['true_labels'] != inf_out_test_1['labels']

# Both curves share one explicit axis; labels + legend make them distinguishable.
fig, ax = plt.subplots()
sns.kdeplot(inf_out_test_1['confidence'][c_idx], ax=ax, label='correct')
sns.kdeplot(inf_out_test_1['confidence'][i_idx], ax=ax, label='incorrect')
ax.set(
    title='Test-set confidence: correct vs. incorrect predictions',
    xlabel='Confidence',
    ylabel='Density',
)
ax.legend();



In [None]:
# Run auto-labeling with the classifier trained above, then print the resulting counts.
# NOTE(review): presumably this clears marks from a previous run so re-executing the cell starts clean.
dm.unmark_auto_labeled()

auto_labeler = AutoLabeling(conf, dm, pl.cur_clf, logger)
auto_label_run_out = auto_labeler.run()

# `out` ends up holding the counts, exactly as before.
out = dm.get_auto_labeling_counts()
print(out)