# Online Local Adaptive Model - Notebook 3_2

* Prior Probability Shift is one of the common problems encountered in Machine Learning algorithms.   
* There are some approaches for dealing with this problem in a 'static' scenario. But there are situations in which we need a model which deals with sequential data as input (e.g. a server which gets input from different users, with different data distributions).   
* In this project, we try to build a model which self-adapts its predictions based on the local label distribution. 

### About notebook 3_2

In this notebook we address the problem of Prior Probability Shift. In Experiment 2, a modified LeNet5 model is trained in order to make it able to adapt to local distribution.

## Notebook setup and data preparation
Launch notebook in Windows from terminal:  C:\Users\diaco\Anaconda3\envs\work\python.exe -m notebook


### Notebook setup

In [None]:
from IPython.core.display import display, HTML
from IPython.display import Image
# display(HTML("<style>.container { width:100% !important; }</style>"))
% matplotlib inline
# %matplotlib qt
% load_ext autoreload
% autoreload 2

# import seaborn as sns
# sns.set_context('notebook')
# sns.set_style('white')  # workaround for displaying plot axis on white background


### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import pickle

# os.chdir(r'C:\Users\diaco\Desktop\ML\Licenta\finalproject\mnist')
from dataset import MNISTDataset
from utils import Utils
from lenet5 import Lenet5
from lenet5_with_distr import Lenet5WithDistr
import PIL.Image

# Console formatting: 150-character lines, 10 items shown at each edge of a truncated
# array, and 3 decimals for both numpy arrays and pandas objects.
np.set_printoptions(linewidth=150, edgeitems=10, precision=3)
pd.set_option('display.precision', 3)


### Set seed

In [None]:
# Constant seed, so that results can be reproduced from one run to the next.
seed = 112358
# Stand-alone generator object initialised with the seed above.
nprg = np.random.RandomState(seed=seed)

### Import MNIST dataset

In [None]:
# Paths of the four MNIST files (train/test images and labels), relative to the
# current working directory.
MNIST_TRAIN_IMG_PATH = 'MNIST_dataset/train-images.idx3-ubyte'
MNIST_TRAIN_LABELS_PATH = 'MNIST_dataset/train-labels.idx1-ubyte'
MNIST_TEST_IMG_PATH = 'MNIST_dataset/t10k-images.idx3-ubyte'
MNIST_TEST_LABELS_PATH = 'MNIST_dataset/t10k-labels.idx1-ubyte'

# Dataset object used by the cells below; several of them re-create it from the same paths.
mnist_ds = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)


---
### - build a list with distributions used in previously trained models
---

In [None]:
# WORK_DIR = "./results/PriorProbabilityShift_experiment_5_10000samples/"
WORK_DIR = "./results/PriorProbabilityShift_experiment_5_5000samples/"
ckpt_file_list = Utils.get_all_files_from_dir_ending_with(WORK_DIR, "ckpt.meta", without_file_extension=True)
# one empty list per recorded quantity
perf_dict = {
    key: []
    for key in ('idx_model', 'idx_distr', 'test_loss', 'test_acc', 'total_predict', 'total_actual',
                'train_distr', 'test_distr', 'ckpt_file')
}

# build a list with all distributions considered in the training phase:
# every checkpoint in WORK_DIR is restored and its stored `train_distr` is read back
distrs_used_for_training = []
for idx_model, ckpt_file in enumerate(ckpt_file_list):
    print('Restoring model {} from {}'.format(idx_model, ckpt_file))

    temp_model = Lenet5(mnist_dataset=mnist_ds, display_summary=False)
    temp_model.restore_session(ckpt_dir=WORK_DIR, ckpt_filename=ckpt_file)

    current_model_train_distr = temp_model.session.run(temp_model.train_distr)
    distrs_used_for_training.append(current_model_train_distr)
    print('The restored model {} was trained using distr: {}\n'.format(idx_model, current_model_train_distr))

    # bar chart of the distribution, one bar per label 0..9
    label_axis = range(0, 10)
    plt.bar(label_axis, current_model_train_distr)
    plt.xticks(label_axis)
    plt.title('current_model_train_distr')
    plt.show()


## Experiment 2. Try to adapt the standard version of LeNet5 in order to make the model able to adapt to local label distribution

---
### Section 1
Train 2 models:
- first one without using the distribution as input
- second one using the distribution as input but without explicitly building batches with respect to a distribution (i.e. just append the distribution of the current batch before the specified layers)

Check if appending the distribution has an influence on the accuracy.
***

In [None]:
# Training-subset sizes to sweep over.
SUBSET_SIZE_LIST = [150, 250, 500, 1000, 5000, 10000, 45000]

# Every entry is one `distr_pos` configuration (five boolean flags) handed to Lenet5WithDistr.
DISTR_POS_LIST = [
    [False, False, False, False, False],
    [False, False, False, False, True],
    [False, False, True, False, False],
    [False, False, True, True, True],
    [False, True, False, False, False],
    [True, False, False, False, False],
    [True, True, False, False, False],
    [True, True, True, True, True],
]

# Train one model per (subset size, distr_pos) combination.
for subset_size in SUBSET_SIZE_LIST:
    # the subsets below 500 examples are trained with a smaller batch
    batch_size = 50 if subset_size < 500 else 100

    for distr_pos in DISTR_POS_LIST:
        # start every run from a freshly loaded dataset, then impose uniform label weights
        # on a training subset of the requested size
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)
        mnist_ds.impose_distr_on_train_dataset(subset_size=subset_size, weights=[0.1] * 10)

        model_name = "randDistr_distrPos_{}_{}examples_{}batchSize".format(distr_pos, subset_size, batch_size)
        lenet5_model_with_distr = Lenet5WithDistr(
            mnist_ds,
            model_name,
            epochs=40,
            batch_size=batch_size,
            variable_mean=0,
            variable_stddev=0.1,
            learning_rate=0.001,
            drop_out_keep_prob=0.75,
            distr_pos=distr_pos,
        )
        # no list of distributions is supplied in this section
        lenet5_model_with_distr.train(distrs_list=None)
        plt.show()


***
### Section 2
Train a model by imposing only those eight distributions used in the first experiment. Use the same amount of data.  
Check its accuracy compared to the models from the first experiment (adjusting with real priors) and to the models trained with distributions appended as they are.
***

In [None]:
# Training-subset sizes to sweep over.
SUBSET_SIZE_LIST = [150, 250, 500, 1000, 5000, 10000, 45000]

# Every entry is one `distr_pos` configuration (five boolean flags) handed to Lenet5WithDistr.
DISTR_POS_LIST = [
    [False, False, False, False, False],
    [False, False, False, False, True],
    [False, False, True, False, False],
    [False, False, True, True, True],
    [False, True, False, False, False],
    [True, False, False, False, False],
    [True, True, False, False, False],
    [True, True, True, True, True],
]

# Train one model per (subset size, distr_pos) combination.
for subset_size in SUBSET_SIZE_LIST:
    # the subsets below 500 examples are trained with a smaller batch
    batch_size = 50 if subset_size < 500 else 100

    for distr_pos in DISTR_POS_LIST:
        # start every run from a freshly loaded dataset, then impose uniform label weights
        # on a training subset of the requested size
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)
        mnist_ds.impose_distr_on_train_dataset(subset_size=subset_size, weights=[0.1] * 10)

        model_name = "8distrs_distrPos_{}_{}examples_{}batchSize".format(distr_pos, subset_size, batch_size)
        lenet5_model_with_distr = Lenet5WithDistr(
            mnist_ds,
            model_name,
            epochs=40,
            batch_size=batch_size,
            variable_mean=0,
            variable_stddev=0.1,
            learning_rate=0.001,
            drop_out_keep_prob=0.75,
            distr_pos=distr_pos,
        )
        # training is given the distributions read back from the previously trained models
        lenet5_model_with_distr.train(distrs_list=distrs_used_for_training)
        plt.show()

---
### - restore and test the adapted models on subsets with different distributions; save the results to file
---

In [None]:
# WORK_DIRS = [
#     "./results/Lenet5WithDistr_8distr_10000samples/",
#     "./results/Lenet5WithDistr_8distr_10000samples_2/",
#     "./results/Lenet5WithDistr_8distr_allData/",
#     "./results/Lenet5WithDistr_randomDistr_allData/",
#     "./results/Lenet5WithDistr_randomDistr_10000samples/",
#     "./results/Lenet5WithDistr_randomDistr_10000samples_2/"]

# WORK_DIRS = ["./results/Lenet5WithDistr_8distr_10000samples_2/"]

# WORK_DIRS = [
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_150samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_250samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_500samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_1000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_5000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_10000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_45000samples/"
# ]

# WORK_DIRS = [
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_150samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_250samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_500samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_1000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_5000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_10000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_45000samples/"
# ]

# Result directories processed by the testing loop below; each one is scanned for
# "ckpt.meta" checkpoint files and receives its own results pickle.
WORK_DIRS = [
    "./results/Lenet5WithDistr_8distr_[]samples_2/150samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/250samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/500samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/1000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/5000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/10000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/45000samples/"
]

# Number of repeated tests run per model and per distribution.
NO_TESTS_PER_MODEL = 50
# Maximum number of examples in each sampled test subset.
TEST_SUBSET_SIZE = 3000
# When True, the results dictionary of each directory is pickled into that directory.
SAVE_DICT_TO_FILE = True
# "Shortcut" method: re-sample the predictions already computed on the full test set.
USE_SHORTCUT_METHOD_FOR_TESTING = True
# "Normal" method: build a new test subset and run the model on it for every repetition.
USE_NORMAL_METHOD_FOR_TESTING = False

def save_test_results_to_dict(test_dict, test_method, test_size, id_test, id_distr, test_distr, test_acc, total_actual, total_predict,
                              num_classes=10):
    """Append the record of one test run to the column lists held in ``test_dict``.

    Args:
        test_dict: dict of lists; must contain the keys 'test_method', 'testset_size',
            'id_test', 'id_distr', 'test_distr', 'test_acc', 'correct_predicted_counts',
            'wrong_predicted_counts' and 'wrong_actual_counts'. One item is appended to each.
        test_method: label of the testing procedure, stored as is.
        test_size: number of examples in the test, stored under 'testset_size'.
        id_test: index of the repetition, stored as is.
        id_distr: index of the distribution, stored as is.
        test_distr: the distribution itself, stored as is.
        test_acc: accuracy of the run, stored as is.
        total_actual: 1-D sequence of true labels (non-negative integers).
        total_predict: 1-D sequence of predicted labels, same shape as ``total_actual``.
        num_classes: minimum length of the per-class count vectors (default 10).

    Raises:
        ValueError: if the two label sequences differ in shape, or if a label is negative
            (raised by ``np.bincount``).
    """
    # Accept plain lists/tuples as well as numpy arrays; boolean masking needs arrays.
    total_actual = np.asarray(total_actual)
    total_predict = np.asarray(total_predict)
    if total_actual.shape != total_predict.shape:
        raise ValueError('total_actual and total_predict must have the same shape, got {} and {}'.format(
            total_actual.shape, total_predict.shape))

    is_wrong = total_actual != total_predict

    # Compute everything that can fail BEFORE touching test_dict, so that an error never
    # leaves the column lists with different lengths.
    correct_predicted_counts = np.bincount(total_predict[~is_wrong], minlength=num_classes)
    wrong_predicted_counts = np.bincount(total_predict[is_wrong], minlength=num_classes)
    wrong_actual_counts = np.bincount(total_actual[is_wrong], minlength=num_classes)

    test_dict['test_method'].append(test_method)
    test_dict['testset_size'].append(test_size)
    test_dict['id_test'].append(id_test)
    test_dict['id_distr'].append(id_distr)
    test_dict['test_distr'].append(test_distr)
    test_dict['test_acc'].append(test_acc)
    test_dict['correct_predicted_counts'].append(correct_predicted_counts)
    test_dict['wrong_predicted_counts'].append(wrong_predicted_counts)
    test_dict['wrong_actual_counts'].append(wrong_actual_counts)
        
# For each results directory: restore every checkpointed model, evaluate it on the full test
# set and under each distribution in `distrs_used_for_training`, then optionally pickle the
# collected results into that directory.
for WORK_DIR in WORK_DIRS:

    ckpt_file_list = Utils.get_all_files_from_dir_ending_with(WORK_DIR, "ckpt.meta", without_file_extension=True)

    # dictionary used for saving all the results: checkpoint name -> model info + per-test records
    full_results_dict = {}

    for idx_model, ckpt_file in enumerate(ckpt_file_list):
        print('{}: Restoring model {} from {}{}'.format(Utils.now_as_str(), idx_model, WORK_DIR, ckpt_file))

        # read the distr_pos stored in the checkpoint, build the model with it, then load the weights
        restored_distr_pos = Utils.restore_variable_from_checkpoint(ckpt_dir=WORK_DIR, ckpt_file=ckpt_file, var_name='distr_pos')
        test_model = Lenet5WithDistr(mnist_dataset=mnist_ds, verbose=False, distr_pos=restored_distr_pos)
        test_model.restore_session(ckpt_dir=WORK_DIR, ckpt_filename=ckpt_file)
        current_model_train_distr = test_model.session.run(test_model.train_distr)
        full_results_dict[ckpt_file] = {'idx_model': idx_model, 'distr_pos': restored_distr_pos, 'train_distr': current_model_train_distr,
                                        'test_results': {'test_method': [], 'testset_size': [], 'id_test': [], 'id_distr': [], 'test_distr': [], 'test_acc': [],
                                                         'correct_predicted_counts': [], 'wrong_predicted_counts': [], 'wrong_actual_counts': []}}
        # all records of this model are appended to the same 'test_results' dict of lists
        model_test_results = full_results_dict[ckpt_file]['test_results']

        print('The restored model {} was trained using distr: {}, but with batch distr. attached (distr_pos = {})\n'.format(idx_model, current_model_train_distr, restored_distr_pos))

        # test on the entire MNIST dataset
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)
        print('Testing model on the entire dataset, with its distr. attached ({})'.format(mnist_ds.test.label_distr))
        test_loss, test_acc, total_predict, total_actual, wrong_predict_images, output_probs = test_model.test_data(mnist_ds.test, use_only_one_batch=True)
        print('test_acc = {:.3f}% ({}/{})\n'.format(test_acc * 100, mnist_ds.test.num_examples - len(wrong_predict_images), mnist_ds.test.num_examples))

        # save the test results (id_test / id_distr = -1 mark "not a repetition / no imposed distribution")
        save_test_results_to_dict(test_dict=model_test_results, test_method='entire', test_size=mnist_ds.test.num_examples, id_test=-1, id_distr=-1,
                                  test_distr=mnist_ds.test.label_distr, test_acc=test_acc, total_actual=total_actual, total_predict=total_predict)

        print('{}: start testing model on distributions\n'.format(Utils.now_as_str()))
        for idx_distr, distr in enumerate(distrs_used_for_training):
            print('Testing model on the entire dataset, with {} attached'.format(distr))
            mnist_ds = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)
            test_loss, test_acc, total_predict, total_actual, wrong_predict_images, output_probs = test_model.test_data(mnist_ds.test, use_only_one_batch=True, distr_to_attach=distr)
            print('test_acc = {:.3f}% ({}/{})'.format(test_acc * 100, mnist_ds.test.num_examples - len(wrong_predict_images), mnist_ds.test.num_examples))

            # save the test results
            save_test_results_to_dict(test_dict=model_test_results, test_method='entire_and_distr_attached', test_size=mnist_ds.test.num_examples, id_test=-1, id_distr=idx_distr,
                                      test_distr=distr, test_acc=test_acc, total_actual=total_actual, total_predict=total_predict)

            print('Testing model {} ({} times) on distribution {}: {}'.format(idx_model, NO_TESTS_PER_MODEL, idx_distr, distr))

            # compute theoretical mean accuracy: the per-class error counts measured on the full test set
            # are re-weighted by 10 * distr and scaled down to TEST_SUBSET_SIZE examples
            # NOTE(review): the factor 10 presumably assumes a roughly uniform test set (1/10 per class)
            # and a `distr` that sums to 1 - confirm
            wrong_actual = total_actual[total_actual != total_predict]
            theoretical_wrong_count = np.sum(np.bincount(wrong_actual, minlength=10) * (10 * distr)) * (TEST_SUBSET_SIZE / mnist_ds.test.num_examples)
            print('theoretical \u03BC(acc) = {:.3f}% ({}/{})'.format((1 - theoretical_wrong_count / TEST_SUBSET_SIZE) * 100, TEST_SUBSET_SIZE - theoretical_wrong_count, TEST_SUBSET_SIZE))

            if USE_SHORTCUT_METHOD_FOR_TESTING:
                # shortcut: reuse the predictions just computed on the full test set and only
                # re-select indices with respect to `distr`, instead of running the model again
                list_acc_shortcut_method = []
                for id_test in range(NO_TESTS_PER_MODEL):
                    indices_wrt_distr = Utils.get_indices_wrt_distr(labels=total_actual, weights=distr, max_no_examples=TEST_SUBSET_SIZE)
                    num_examples_selection = len(indices_wrt_distr)
                    total_predict_selection = total_predict[indices_wrt_distr]
                    total_actual_selection = total_actual[indices_wrt_distr]
                    no_correct_predicted = np.sum(total_predict_selection == total_actual_selection)
                    test_acc = no_correct_predicted / num_examples_selection

                    list_acc_shortcut_method.append(test_acc)

                    # save the test results
                    save_test_results_to_dict(test_dict=model_test_results, test_method='shortcut', test_size=num_examples_selection, id_test=id_test, id_distr=idx_distr,
                                              test_distr=distr, test_acc=test_acc, total_actual=total_actual_selection, total_predict=total_predict_selection)

                print('\u03BC(acc_s ) = {:.3f}%, \u03C3(acc_s ) = {:.3f}%'.format(np.mean(list_acc_shortcut_method) * 100, np.std(list_acc_shortcut_method, ddof=1) * 100))

            if USE_NORMAL_METHOD_FOR_TESTING:
                # normal: impose `distr` on a freshly loaded test set and run the model on it
                list_acc_normal_method = []
                for id_test in range(NO_TESTS_PER_MODEL):
                    mnist_ds1 = MNISTDataset(MNIST_TRAIN_IMG_PATH, MNIST_TRAIN_LABELS_PATH, MNIST_TEST_IMG_PATH, MNIST_TEST_LABELS_PATH)
                    for s in range(id_test):
                        # when MNISTDataset is instantiated its random generator is reset, so the test set is
                        # explicitly shuffled id_test times to get a different subset for every repetition
                        mnist_ds1.test.shuffle()
                    mnist_ds1.impose_distribution(weights=distr, max_test_size=TEST_SUBSET_SIZE)
                    test_loss1, test_acc1, total_predict1, total_actual1, wrong_predict_images1, output_probs1 = test_model.test_data(mnist_ds1.test, use_only_one_batch=True)

                    list_acc_normal_method.append(test_acc1)

                    # save the test results
                    save_test_results_to_dict(test_dict=model_test_results, test_method='normal', test_size=mnist_ds1.test.num_examples, id_test=id_test, id_distr=idx_distr,
                                              test_distr=distr, test_acc=test_acc1, total_actual=total_actual1, total_predict=total_predict1)

                print('\u03BC(acc_n) = {:.3f}%, \u03C3(acc_n) = {:.3f}%'.format(np.mean(list_acc_normal_method) * 100, np.std(list_acc_normal_method, ddof=1) * 100))

            print('\n')

    # save the above results dictionary to file
    if SAVE_DICT_TO_FILE:
        filename = 'testing_results_{}testsPerModel.dict.pickle'.format(NO_TESTS_PER_MODEL)
        full_filepath = os.path.join(WORK_DIR, filename)
        # the context manager closes the file even when pickling fails part-way
        with open(full_filepath, 'wb') as filehandler:
            pickle.dump(full_results_dict, filehandler)
        print('Results dictionary was succesfully saved to: {}'.format(full_filepath))

    print('\n\n\n')



---
### - check if testing on the entire MNIST testset with different distributions attached affects the accuracy
- analyze the wrong predictions distribution
---

In [None]:
# WORK_DIR = "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_1000samples/"
# filename = 'testing_results_50testsPerModel.dict.pickle'
# ckpt_file = 'Lenet5_8distrs_distrPos_[False, False, True, True, True]_1000examples_2018_05_28---11_02.model.ckpt'

# NOTE(review): the testing cell writes its pickles into directories named ".../1000samples/",
# while this path ends in ".../Lenet5_8distr_1000samples/" - confirm which layout exists on disk.
WORK_DIR = "./results/Lenet5WithDistr_8distr_[]samples_2/Lenet5_8distr_1000samples/"
filename = 'testing_results_50testsPerModel.dict.pickle'
ckpt_file = 'Lenet5_8distrs_distrPos_[False, False, True, True, True]_1000examples_100batchSize_2018_06_01---11_05.model.ckpt'

# NOTE: unpickling can execute arbitrary code - only load result files produced by this project.
# The context manager closes the file even when loading fails.
with open(os.path.join(WORK_DIR, filename), 'rb') as filehandler:
    restored_full_results_dict = pickle.load(filehandler)
print('Results dictionary was succesfully restored from: {}{}'.format(WORK_DIR, filename))

print('idx_model = {}, ckpt_file = {}\n'.format(restored_full_results_dict[ckpt_file]['idx_model'], ckpt_file))
perf_df = pd.DataFrame(restored_full_results_dict[ckpt_file]['test_results'], columns=list(restored_full_results_dict[ckpt_file]['test_results'].keys()))
# keep only the runs on the entire test set with an explicit distribution attached
temp_df = perf_df[(perf_df['test_method'] == 'entire_and_distr_attached')]

fig = plt.figure(figsize=(40, 7))
fig.suptitle(t='ckpt_file = {}'.format(ckpt_file), fontsize=20, fontweight='bold')
# one column of subplots per attached distribution
no_subplots_per_row = temp_df.shape[0]

# common upper limit for the y axis of all error plots, so that the bars are comparable
# (0 when there are no rows)
global_max_wrong_predicted_count = max((np.max(counts) for counts in temp_df['wrong_predicted_counts']), default=0)

for k, (index, df_row) in enumerate(temp_df.iterrows(), start=1):
    # top row: the attached distribution, normalised to sum to 1
    plt.subplot(2, no_subplots_per_row, k)
    plt.bar(range(10), df_row['test_distr'] / sum(df_row['test_distr']))
    plt.xticks(range(0, 10))
    plt.title('acc = {:.1f}%'.format(df_row['test_acc'] * 100), fontsize=16)
    if k == 1:
        plt.ylabel('attached distr.', fontsize=20)

    # bottom row: number of wrong predictions per predicted label
    plt.subplot(2, no_subplots_per_row, k + no_subplots_per_row)
    plt.ylim([0, global_max_wrong_predicted_count])
    plt.bar(range(10), df_row['wrong_predicted_counts'])
    plt.xticks(range(0, 10))
    plt.grid()
    if k == 1:
        plt.ylabel('wrong pred. counts', fontsize=20)

plt.show()
    

---
- print results in a compact form
- compare the different distr_pos's considered
- check again if attaching the distribution has an impact on performance
- compare the results between 8distr and randomDistr
---

In [None]:
# Scratch cell showing how a rank is derived from `scores`.
# NOTE(review): `scores` is only assigned by the comparison cell further down the notebook,
# so that cell has to be executed before this one.
print(scores)
# negating the scores makes the ascending sort list them from highest to lowest
print(np.sort(-scores))
# indices of the entries, ordered from highest to lowest score
print(np.argsort(-scores))
# number of entries with a strictly higher score than entry 3, i.e. its 0-based rank
print(np.searchsorted(np.sort(-scores), -scores[3]))

In [None]:
#### WORK_DIRS = [
#     "./results/Lenet5WithDistr_8distr_allData",
#     "./results/Lenet5WithDistr_randomDistr_allData",
#     "./results/Lenet5WithDistr_randomDistr_10000samples/",
#     "./results/Lenet5WithDistr_randomDistr_10000samples_2/",
#     "./results/Lenet5WithDistr_8distr_10000samples/",
#     "./results/Lenet5WithDistr_8distr_10000samples_2/"]


# WORK_DIRS = [
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_150samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_250samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_500samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_1000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_5000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_10000samples/",
#     "./results/Lenet5WithDistr_8distr_[]samples/Lenet5_8distr_45000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_150samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_250samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_500samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_1000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_5000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_10000samples/",
#     "./results/Lenet5WithDistr_randomDistr_[]samples/Lenet5_randomDistr_45000samples/"
# ]

# Result directories whose pickled test results are summarised by the loop below.
WORK_DIRS = [
    "./results/Lenet5WithDistr_8distr_[]samples_2/150samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/250samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/500samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/1000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/5000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/10000samples/",
    "./results/Lenet5WithDistr_8distr_[]samples_2/45000samples/"
]

# Name of the pickle file read from each of the directories above.
filename = 'testing_results_50testsPerModel.dict.pickle'


def std_1ddof(x):
    """Return the standard deviation of ``x`` computed with one delta degree of freedom."""
    sample_std = np.std(x, ddof=1)
    return sample_std

# For each results directory: load the pickled test results, compute one score per model
# (i.e. per distr_pos configuration) and print a compact one-line summary per model.
for WORK_DIR in WORK_DIRS:
    # NOTE: unpickling can execute arbitrary code - only load result files produced by this project.
    # The context manager closes the file even when loading fails.
    with open(os.path.join(WORK_DIR, filename), 'rb') as filehandler:
        restored_full_results_dict = pickle.load(filehandler)
    print('Results dictionary was succesfully restored from: {}{}'.format(WORK_DIR, filename))

    if not restored_full_results_dict:
        # nothing to compare; the min/max reductions below cannot handle an empty matrix
        print('No models found in: {}{}'.format(WORK_DIR, filename))
        print('\n\n')
        continue

    # compute a score for each distr_pos in order to decide where to attach the distribution
    acc_matrix_to_compare_distr_pos = []   # rows: models, columns: id_distr -> mean accuracy
    std_matrix_to_compare_distr_pos = []   # same layout -> sample standard deviation
    list_distr_pos = []

    # iterate through models
    for ckpt_file in restored_full_results_dict:
        # build a pandas dataframe from the results dictionary
        perf_df = pd.DataFrame(restored_full_results_dict[ckpt_file]['test_results'], columns=list(restored_full_results_dict[ckpt_file]['test_results'].keys()))
        # the runs on the entire test set with a distribution attached are not part of the comparison
        compared_df = perf_df[perf_df['test_method'] != 'entire_and_distr_attached']

        temp_acc_list = []
        temp_std_list = []
        for id_distr in compared_df['id_distr'].unique():
            temp_df = compared_df[compared_df['id_distr'] == id_distr]
            temp_acc_list.append(np.average(temp_df['test_acc']))
            temp_std_list.append(std_1ddof(temp_df['test_acc']))
        acc_matrix_to_compare_distr_pos.append(temp_acc_list)
        std_matrix_to_compare_distr_pos.append(temp_std_list)
        list_distr_pos.append(restored_full_results_dict[ckpt_file]['distr_pos'])

    # min-max normalise every column (distribution) over the models, then average per model:
    # 1 means best on every distribution, 0 means worst on every distribution
    acc_matrix_to_compare_distr_pos = np.array(acc_matrix_to_compare_distr_pos)
    min_per_column = np.min(acc_matrix_to_compare_distr_pos, axis=0)
    acc_diff_matrix = acc_matrix_to_compare_distr_pos - min_per_column
    max_per_column = np.max(acc_diff_matrix, axis=0)
    # a column on which all models tie has a spread of 0: let it contribute 0 instead of 0/0 = NaN,
    # which would otherwise turn every score (and every rank) into NaN
    acc_diff_matrix = np.divide(acc_diff_matrix, max_per_column, out=np.zeros_like(acc_diff_matrix), where=max_per_column > 0)
    scores = np.sum(acc_diff_matrix, axis=1) / acc_diff_matrix.shape[1]

    # negated and sorted once; the position of -score in it is the number of better models
    sorted_negated_scores = np.sort(-scores)

    # iterate through models and print: distr_pos --- mean ± std per distribution --- score --- rank
    for idx, distr_pos in enumerate(list_distr_pos):
        print('{} --- '.format(distr_pos), end="", flush=True)
        for mean_acc, std_acc in zip(acc_matrix_to_compare_distr_pos[idx], std_matrix_to_compare_distr_pos[idx]):
            print('{:.1f}% \u00B1 {:.1f}%, '.format(mean_acc * 100, std_acc * 100), end="", flush=True)
        print(' --- score = {:.3f} --- {}'.format(scores[idx], np.searchsorted(sorted_negated_scores, -scores[idx]) + 1))

    print('\n\n')

    