# Project: Online Local Adaptive Model

* Prior Probability Shift (a change in the label distribution p(y) between training and test data, while the class-conditional distribution p(x|y) stays the same) is one of the common problems encountered in Machine Learning.
* There are some approaches for dealing with this problem in a 'static' scenario. However, there are situations in which we need a model that deals with sequential data as input (e.g. a server that receives input from different users, each with a different data distribution).
* In this project, we try to build a model that self-adapts its predictions based on the local label distribution.

### About notebook 2

In this notebook we address the problem of Prior Probability Shift as follows:
- we illustrate how a different test distribution affects the model's performance. We train multiple models on subsets of MNIST with different label distributions; each model is then tested on a range of test subsets that follow the distributions considered in the training phase.
- then, we test whether adjusting the model's outputs using the real priors helps.

---
## 1. Notebook setup and data preparation
---

### Notebook setup

In [None]:
from IPython.display import display, HTML, Image
display(HTML("<style>.container { width:100% !important; }</style>"))
%matplotlib inline
# %matplotlib qt
%load_ext autoreload
%autoreload 2

### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import time
from collections import deque
import os
import pickle
from training_plotter import TrainingPlotter
from dataset import MNISTDataset
import utils
from lenet5 import Lenet5
from lenet5_with_distr import Lenet5WithDistr
import PIL.Image

# numpy and pandas print options
np.set_printoptions(linewidth=150)
np.set_printoptions(edgeitems=10)
np.set_printoptions(precision=3)
pd.set_option('display.precision', 3)

### Set seed

In [None]:
# create a random generator using a constant seed in order to reproduce results
seed = 112358
nprg = np.random.RandomState(seed)

### Import MNIST dataset

In [None]:
MNIST_TRAIN_IMAGES_FILEPATH = 'MNIST_dataset/train-images.idx3-ubyte'
MNIST_TRAIN_LABELS_FILEPATH = 'MNIST_dataset/train-labels.idx1-ubyte'
MNIST_TEST_IMAGES_FILEPATH = 'MNIST_dataset/t10k-images.idx3-ubyte'
MNIST_TEST_LABELS_FILEPATH = 'MNIST_dataset/t10k-labels.idx1-ubyte'

mnist_ds = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
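`MNISTDataset` comes from the project's `dataset` module, whose internals are not shown here. For reference, the four MNIST files use the simple IDX format: a 4-byte magic number (two zero bytes, a data-type code, the number of dimensions), one big-endian 32-bit size per dimension, then the raw data. The reader below is a minimal sketch, assuming unsigned-byte data (the case for all four MNIST files); `read_idx` is a hypothetical helper, not part of the project.

```python
import struct
import numpy as np

def read_idx(filepath):
    """Read an unsigned-byte IDX file (the format of the MNIST files) into a NumPy array."""
    with open(filepath, 'rb') as f:
        # magic number: two zero bytes, data-type code (0x08 = unsigned byte), number of dimensions
        zero, dtype_code, ndim = struct.unpack('>HBB', f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError('this sketch only handles unsigned-byte IDX files')
        # one big-endian 32-bit unsigned integer per dimension
        shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
        # the payload is the raw data in row-major order
        data = np.frombuffer(f.read(), dtype=np.uint8)
    return data.reshape(shape)
```

For example, `read_idx(MNIST_TRAIN_IMAGES_FILEPATH)` should return an array of shape `(60000, 28, 28)` and `read_idx(MNIST_TRAIN_LABELS_FILEPATH)` an array of 60000 labels.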

---
## 2. Train models on subsets with different distributions
---

### Build some label distributions

In [None]:
distrs_used_for_training = []

# uniform distribution
distr = np.array([1,1,1,1,1,1,1,1,1,1])
distrs_used_for_training.append(distr/np.sum(distr))
# normal distribution centered about label 4-5
r = 2
distr = [r**1,r**2,r**3,r**4,r**5,r**5,r**4,r**3,r**2,r**1]
distrs_used_for_training.append(distr/np.sum(distr))

# skewed normal distribution centered about 2
distr = [r**3,r**4,r**5,r**4.5,r**4,r**3.5,r**3,r**2.5,r**2,r**1.5]
distrs_used_for_training.append(distr/np.sum(distr))

# skewed normal distribution centered about 7
distr = [r**1.5,r**2,r**2.5,r**3,r**3.5,r**4,r**4.5,r**5,r**4,r**3]
distrs_used_for_training.append(distr/np.sum(distr))

# bimodal normal distribution
distr = [r**1,r**2,r**3,r**2,r**1,r**1,r**2,r**3,r**2,r**1]
distrs_used_for_training.append(distr/np.sum(distr))

# bimodal skewed normal distribution
distr = [r**3.5,r**4,r**3,r**2,r**1,r**1,r**2,r**3,r**4,r**3.5]
distrs_used_for_training.append(distr/np.sum(distr))


# exponential distribution (increasing with the label)
r=1.4
distr = [r**1,r**2,r**3,r**4,r**5,r**6,r**7,r**8,r**9,r**10]
distrs_used_for_training.append(distr/np.sum(distr))

# exponential distribution (decreasing with the label)
r=1.4
distr = [r**10,r**9,r**8,r**7,r**6,r**5,r**4,r**3,r**2,r**1]
distrs_used_for_training.append(distr/np.sum(distr))

print('#distributions used for training = {}'.format(len(distrs_used_for_training)))
for idx, distr in enumerate(distrs_used_for_training):
    print('idx = {}: distr = {}'.format(idx,distr))
    plt.bar(range(10), distr)
    plt.show()
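Every distribution above is built the same way: per-label weights `r**e`, normalized to sum to one. A small helper makes the pattern explicit and shows how strong the resulting imbalance is. `geometric_distr` is a hypothetical name introduced for illustration; it reproduces the same weights as the cell above.

```python
import numpy as np

def geometric_distr(exponents, r):
    """Turn per-label exponents e into a normalized label distribution with weights r**e."""
    weights = np.power(float(r), np.asarray(exponents, dtype=float))
    return weights / weights.sum()

# increasing "exponential" distribution from the cell above (r = 1.4, exponents 1..10)
increasing = geometric_distr(range(1, 11), 1.4)
# ratio between the most and the least frequent label is 1.4**9, about 20.7
imbalance = increasing.max() / increasing.min()
print(np.round(increasing, 3), round(imbalance, 1))
```

With `r = 2` and exponents `[1, 2, 3, 4, 5, 5, 4, 3, 2, 1]`, the same helper gives the symmetric distribution, whose most frequent labels are 16 times more likely than the least frequent ones.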

### Train LeNet5 models by imposing the considered distributions on the original MNIST dataset, with fixed subset sizes

In [None]:
SUBSET_SIZE_LIST = [150, 250, 500, 1000, 5000, 10000]
global_max_weight = np.max(distrs_used_for_training)

for subset_size in SUBSET_SIZE_LIST:
    for k, distr in enumerate(distrs_used_for_training):
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
        print('\n\nk = {}: Imposed distribution: {}'.format(k, np.round(np.array(distr), decimals=3)))
        mnist_ds.impose_distribution(np.array(distr), global_max_weight, max_training_size=subset_size)
        lenet5_model = Lenet5(mnist_ds, "with_imposed_distr_{}_{}samples".format(k, subset_size), epochs=40, batch_size=100, variable_mean=0, variable_stddev=0.1, learning_rate=0.001, drop_out_keep_prob=0.75)
        lenet5_model.train()
        plt.show()
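The project's `MNISTDataset.impose_distribution` (and the later `utils.get_indices_wrt_distr`) are not shown, so the exact sampling rule, including how `global_max_weight` is used, is not reproduced here. The sketch below shows one plausible way to draw a subset whose labels follow a given distribution: per-label counts proportional to the weights, sampled without replacement. `sample_indices_wrt_distr` is a hypothetical helper, not the project's implementation.

```python
import numpy as np

def sample_indices_wrt_distr(labels, weights, subset_size, rng):
    """Draw indices (without replacement) whose labels roughly follow `weights`.

    Hypothetical helper: each label gets int(weight * subset_size) examples,
    capped by the number of examples of that label that are available.
    """
    labels = np.asarray(labels)
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    chosen = []
    for label, w in enumerate(weights):
        candidates = np.flatnonzero(labels == label)
        count = min(int(w * subset_size), len(candidates))
        if count > 0:
            chosen.append(rng.choice(candidates, size=count, replace=False))
    indices = np.concatenate(chosen) if chosen else np.array([], dtype=int)
    rng.shuffle(indices)  # mix the labels so that training batches are not sorted by class
    return indices
```

It accepts either a `np.random.RandomState` (such as `nprg` above) or a `np.random.Generator`, since both provide `choice` and `shuffle`.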

---
## 3. Test the above models and adjust outputs using real priors
Test each model:
- on a subset that follows its training distribution
- on subsets that follow the other models' training distributions
- on the entire MNIST test set

In order to summarize the results, we build a matrix, similar to a confusion matrix, through the following steps:
- test each trained model again on a subset that follows its own training distribution
- then test each model on data that follows, in turn, each of the other distributions considered
- save these results to a dictionary and then to disk, so they can be reused later
- finally, build a matrix where an element $m_{ij}$ represents the test accuracy of model $i$ evaluated on the subset with distribution $j$

Similarly, matrices of label distributions are built for correct predictions, wrong predictions (by predicted label) and wrong predictions (by actual label).
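If the per-test results were flattened into a single long-format table (one row per test run), the accuracy matrix $m_{ij}$ and its standard deviations would reduce to a pandas group-by. This is a sketch under that assumption: the notebook keeps results in a per-model dictionary instead, and the column `id_model` and the helper `build_acc_matrix` are hypothetical. Rows would first be filtered by method, e.g. `results[results['test_method'] == 'shortcut']`.

```python
import numpy as np
import pandas as pd

def build_acc_matrix(results):
    """Average per-test accuracies into matrices indexed by (model i, test distribution j).

    `results` has one row per test run, with the hypothetical columns
    'id_model', 'id_distr' and 'test_acc'.
    """
    grouped = results.groupby(['id_model', 'id_distr'])['test_acc']
    avg = grouped.mean().unstack('id_distr')       # rows: models, columns: test distributions
    std = grouped.std(ddof=1).unstack('id_distr')  # sample standard deviation over test runs
    return avg.to_numpy(), std.to_numpy()

# usage on made-up numbers: 2 models, 2 test distributions, 2 test runs each
toy = pd.DataFrame({
    'id_model': [0, 0, 0, 0, 1, 1, 1, 1],
    'id_distr': [0, 0, 1, 1, 0, 0, 1, 1],
    'test_acc': [0.90, 0.92, 0.80, 0.84, 0.70, 0.74, 0.95, 0.97],
})
avg, std = build_acc_matrix(toy)
```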

Formula for the corrected a posteriori probabilities:

\begin{equation}
 \large
 \hat{p}(\omega_{i}|\mathrm{x})=\frac{\frac{\hat{p}(\omega_{i})}{\hat{p}_{t}(\omega_{i})}\hat{p}_{t}(\omega_{i}|\mathrm{x})}{\sum_{j=1}^{n}\frac{\hat{p}(\omega_{j})}{\hat{p}_{t}(\omega_{j})}\hat{p}_{t}(\omega_{j}|\mathrm{x})}
\end{equation}

- $\hat{p}(\omega_{i}|\mathrm{x})$ = corrected a posteriori probabilities
- $\hat{p}_{t}(\omega_{i}|\mathrm{x})$ = a posteriori probabilities as provided by the trained model
- $\hat{p}(\omega_{i})$ = new priors (the label distribution of the test data)
- $\hat{p}_{t}(\omega_{i})$ = old priors (the label distribution of the training data)

The denominator ensures that the corrected a posteriori probabilities sum to one.
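The formula maps directly onto NumPy broadcasting: multiply each row of softmax outputs by the ratio of new to old priors, then divide by the row sum. A minimal sketch; `adjust_to_new_priors` is a hypothetical name, and it assumes the old priors are strictly positive (true for all training distributions used here).

```python
import numpy as np

def adjust_to_new_priors(probs, old_priors, new_priors):
    """Re-weight posteriors p_t(w_i|x) by p(w_i) / p_t(w_i) and renormalize each row.

    probs: array of shape (n_examples, n_classes); priors: arrays of shape (n_classes,).
    """
    probs = np.asarray(probs, dtype=float)
    ratio = np.asarray(new_priors, dtype=float) / np.asarray(old_priors, dtype=float)
    weighted = probs * ratio                                 # numerator, broadcast over the class axis
    return weighted / weighted.sum(axis=1, keepdims=True)    # denominator: sum over j
```

For a two-class example, posteriors `[0.6, 0.4]` from a model trained with priors `[0.5, 0.5]` become `[0.24, 0.64] / 0.88` under new priors `[0.2, 0.8]`, so the decision flips from class 0 to class 1. When old and new priors coincide, the posteriors are left unchanged.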

---




In [None]:
def get_all_files_from_dir_ending_with(directory, ending, without_file_extension = False):
    file_list = []
    files = os.listdir(directory)
    files.sort(key=lambda fn: os.path.getmtime(os.path.join(directory, fn))) # sort by date
    for file in files:
        if file.endswith(ending):
            if without_file_extension:
                file_list.append(os.path.splitext(file)[0])
            else:
                file_list.append(file)
    return file_list

In [None]:
def restore_and_test_a_model_on_a_mnist_subset(mnist_subset, ckpt_dir, ckpt_filemame, plot_filename):
    print('Restoring model from {}{}'.format(ckpt_dir, ckpt_filemame))
    restored_distr_pos = utils.restore_variable_from_checkpoint(ckpt_dir=ckpt_dir, ckpt_file=ckpt_filemame, var_name = 'distr_pos')
    if restored_distr_pos is None:
        restored_distr_pos = [False, False, False, False, False]
    restored_model = Lenet5WithDistr(mnist_dataset=mnist_ds, verbose=False, distr_pos=restored_distr_pos)
    restored_model.restore_session(ckpt_dir=ckpt_dir, ckpt_filename=ckpt_filemame)
    train_distr = restored_model.session.run(restored_model.train_distr)
    test_loss, test_acc, total_predict, total_actual, wrong_predict_images, total_softmax_output_probs = restored_model.test_data(mnist_subset.test)

    print('test_loss = {:.4f}, test_acc = {:.4f} ({}/{})'.format(test_loss, test_acc, mnist_subset.test.num_examples - len(wrong_predict_images), mnist_subset.test.num_examples))
    
    # sort wrong_predict_images by target label
    correct_predict = total_predict[total_actual == total_predict]
    wrong_predict = total_predict[total_actual != total_predict]
    wrong_predict_softmax_output_probs = total_softmax_output_probs[total_actual != total_predict]
    wrong_actual = total_actual[total_actual != total_predict]
    wrong_predict_images = np.array(wrong_predict_images)
    wrong_predict_images_sorted = wrong_predict_images[wrong_actual.argsort(), ]
    wrong_predict_images_sorted = [image for image in wrong_predict_images_sorted]

    count_figures = 6
    fig = plt.figure(figsize=(30, 3))
    fig.suptitle(y = 1.1, t = 'test_acc = {:.4f} ({}/{})'.format(test_acc, mnist_subset.test.num_examples - len(wrong_predict_images), mnist_subset.test.num_examples), fontsize=18, fontweight='bold')

    k = 1
    plt.subplot(1,count_figures, k)
    plt.bar(range(10), train_distr)
    plt.xticks(range(0, 10))
    plt.title('train label distr')
    
    k+=1
    plt.subplot(1,count_figures, k)
    plt.bar(range(10), mnist_subset.test.label_distr)
    plt.xticks(range(0, 10))
    plt.title('test label distr')

    k+=1
    plt.subplot(1,count_figures, k)
    plt.hist(correct_predict, bins=np.arange(11), rwidth=0.8, density=False)
    plt.xticks(range(0, 10))
    plt.title('correct predicted label distr')
    
    k+=1
    plt.subplot(1,count_figures, k)
    plt.hist(wrong_predict, bins=np.arange(11), rwidth=0.8, density=False)
    plt.xticks(range(0, 10))
    plt.title('wrong predicted label distr')
    
    k+=1
    plt.subplot(1,count_figures, k)
    plt.hist(wrong_actual, bins=np.arange(11), rwidth=0.8, density=False)
    plt.xticks(range(0, 10))
    plt.title('wrong actual label distr')
    
    k+=1
    plt.subplot(1,count_figures, k)
    plt.bar(range(0, 10), np.average(wrong_predict_softmax_output_probs, axis=0))
    plt.xticks(range(0, 10))
    plt.title('average softmax output probabilities (wrong predictions)')

    plt.savefig(os.path.join(ckpt_dir, plot_filename))
    plt.show()


---
### - test and save the results to file
---

In [None]:
# WORK_DIRS = [
#     './results/PriorProbabilityShift_experiment_6_[]samples/150samples/',
#     './results/PriorProbabilityShift_experiment_6_[]samples/250samples/',
#     './results/PriorProbabilityShift_experiment_6_[]samples/500samples/',
#     './results/PriorProbabilityShift_experiment_6_[]samples/1000samples/',
#     './results/PriorProbabilityShift_experiment_6_[]samples/5000samples/',
#     './results/PriorProbabilityShift_experiment_6_[]samples/10000samples/',
# ] 

WORK_DIRS = [
    './results/PriorProbabilityShift_experiment_6_[]samples_2/150samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/250samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/500samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/1000samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/5000samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/10000samples/',
]  

NO_TESTS_PER_MODEL = 50
TEST_SUBSET_SIZE = 3000
SAVE_DICT_TO_FILE = True
USE_SHORTCUT_METHOD_FOR_TESTING = True
USE_NORMAL_METHOD_FOR_TESTING = False
ADJUSTING_USING_REAL_PRIORS = True

def save_test_results_to_dict(test_dict, test_method, test_size, id_test, id_distr, test_distr, test_acc, total_actual, total_predict):
    test_dict['test_method'].append(test_method)
    test_dict['testset_size'].append(test_size)
    test_dict['id_test'].append(id_test)
    test_dict['id_distr'].append(id_distr)
    test_dict['test_distr'].append(test_distr)
    test_dict['test_acc'].append(test_acc)
    correct_predict = total_predict[total_actual == total_predict]
    wrong_predict = total_predict[total_actual != total_predict]
    wrong_actual = total_actual[total_actual != total_predict]
    test_dict['correct_predicted_counts'].append(np.bincount(correct_predict, minlength=10))
    test_dict['wrong_predicted_counts'].append(np.bincount(wrong_predict, minlength=10))
    test_dict['wrong_actual_counts'].append(np.bincount(wrong_actual, minlength=10))
        

for WORK_DIR in WORK_DIRS:
    ckpt_file_list = utils.get_all_files_from_dir_ending_with(WORK_DIR, "ckpt.meta", without_file_extension=True)

    # dictionary used for saving all the results
    full_results_dict = {}

    for idx_model, ckpt_file in enumerate(ckpt_file_list):
        print('{}: Restoring model {} from {}{}'.format(utils.now_as_str(), idx_model, WORK_DIR, ckpt_file))

        test_model = Lenet5(mnist_dataset=mnist_ds, display_summary=False)
        test_model.restore_session(ckpt_dir=WORK_DIR, ckpt_filename=ckpt_file)
        current_model_train_distr = test_model.session.run(test_model.train_distr)
        full_results_dict[ckpt_file] = {'idx_model':idx_model, 'train_distr': current_model_train_distr, 
                                       'test_results': {'test_method':[], 'testset_size':[], 'id_test':[], 'id_distr':[], 'test_distr':[], 'test_acc':[], 'correct_predicted_counts':[], 'wrong_predicted_counts':[], 'wrong_actual_counts':[]}}

        print('The restored model {} was trained using distr: {}\n'.format(idx_model, current_model_train_distr))

        # test on the entire MNIST dataset
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
        print('Testing model on the entire dataset (distr. = {})'.format(mnist_ds.test.label_distr))
        test_loss, test_acc, total_predict, total_actual, wrong_predict_images, output_probs = test_model.test_data(mnist_ds.test, use_only_one_batch=True)
        print('test_acc = {:.3f}% ({}/{})\n'.format(test_acc * 100, mnist_ds.test.num_examples - len(wrong_predict_images), mnist_ds.test.num_examples))

        # save the test results
        save_test_results_to_dict(test_dict=full_results_dict[ckpt_file]['test_results'], test_method='entire', test_size=mnist_ds.test.num_examples, id_test=-1, id_distr=-1, 
                                  test_distr=mnist_ds.test.label_distr, test_acc=test_acc, total_actual=total_actual, total_predict=total_predict)

        print('{}: start testing model on distributions\n'.format(utils.now_as_str()))
        for idx_distr, distr in enumerate(distrs_used_for_training):
            print('Testing model {} ({} times) on distribution {}: {}'.format(idx_model, NO_TESTS_PER_MODEL, idx_distr, distr))

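            # theoretical mean accuracy: reweight the per-class error counts measured on the full test set by 10 * distr (assumes the full test set is roughly uniform over the labels)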
            wrong_actual = total_actual[total_actual != total_predict]
            theoretical_wrong_count = np.sum(np.bincount(wrong_actual, minlength=10) * (10 * distr)) * (TEST_SUBSET_SIZE / mnist_ds.test.num_examples)
            print('theoretical \u03BC(acc) = {:.3f}% ({}/{})'.format((1 - theoretical_wrong_count / TEST_SUBSET_SIZE) * 100, TEST_SUBSET_SIZE - theoretical_wrong_count, TEST_SUBSET_SIZE))

            if USE_SHORTCUT_METHOD_FOR_TESTING:
                list_acc_shortcut_method = []
                list_acc_shortcut_method_adj = []
                for id_test in range(NO_TESTS_PER_MODEL):
                    indices_wrt_distr = utils.get_indices_wrt_distr(labels=total_actual, weights=distr, max_no_examples=TEST_SUBSET_SIZE)
                    num_examples_selection = len(indices_wrt_distr)
                    total_predict_selection = total_predict[indices_wrt_distr]
                    total_actual_selection = total_actual[indices_wrt_distr]
                    output_probs_selection = output_probs[indices_wrt_distr]
                    no_correct_predicted = np.sum(total_predict_selection == total_actual_selection)
                    test_acc = no_correct_predicted / num_examples_selection
    #                 print('id_test = {} : test_acc = {:.3f}% ({}/{})'.format(id_test, test_acc * 100, no_correct_predicted, num_examples_selection))

                    list_acc_shortcut_method.append(test_acc)

                    # save the test results
                    save_test_results_to_dict(test_dict=full_results_dict[ckpt_file]['test_results'], test_method='shortcut', test_size=num_examples_selection, id_test=id_test, id_distr=idx_distr, 
                                              test_distr=distr, test_acc=test_acc, total_actual=total_actual_selection, total_predict=total_predict_selection)

                    if ADJUSTING_USING_REAL_PRIORS:
                        # adjust the outputs
                        real_priors = np.bincount(total_actual_selection, minlength=10) / num_examples_selection
                        old_priors = current_model_train_distr
    #                     print('real priors = {}'.format(real_priors))
    #                     print('old priors  = {}'.format(old_priors))
                        adj_output_probs_selection = (real_priors / old_priors) * output_probs_selection / (np.sum((real_priors / old_priors) * output_probs_selection, axis = 1)[:, None])
                        adj_total_predict_selection = np.argmax(adj_output_probs_selection, axis=1)
                        count_correct_predicted = np.sum(total_predict_selection == total_actual_selection)
    #                     print('Old accuracy: {:.3f} ({}/{})'.format(count_correct_predicted / num_examples_selection, count_correct_predicted, num_examples_selection))
                        count_correct_predicted = np.sum(adj_total_predict_selection == total_actual_selection)
    #                     print('New accuracy: {:.3f} ({}/{})'.format(count_correct_predicted / num_examples_selection, count_correct_predicted, num_examples_selection))

                        # save the results into dictionary
                        adj_test_acc = np.sum(adj_total_predict_selection == total_actual_selection) / len(total_predict_selection)
                        list_acc_shortcut_method_adj.append(adj_test_acc)
                        save_test_results_to_dict(test_dict=full_results_dict[ckpt_file]['test_results'], test_method='shortcut_adj', test_size=num_examples_selection, id_test=id_test, id_distr=idx_distr, 
                                              test_distr=distr, test_acc=adj_test_acc, total_actual=total_actual_selection, total_predict=adj_total_predict_selection)


                print('\u03BC(acc_s ) = {:.3f}%, \u03C3(acc_s ) = {:.3f}%'.format(np.mean(list_acc_shortcut_method) * 100, np.std(list_acc_shortcut_method, ddof=1) * 100))
                print('\u03BC(adj_acc_s) = {:.3f}%, \u03C3(adj_acc_s) = {:.3f}%'.format(np.mean(list_acc_shortcut_method_adj) * 100, np.std(list_acc_shortcut_method_adj, ddof=1) * 100))


            if USE_NORMAL_METHOD_FOR_TESTING:
                list_acc_normal_method = []
                list_acc_normal_method_adj = []
                for id_test in range(NO_TESTS_PER_MODEL):
                    mnist_ds1 = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
                    for s in range(id_test):
                        mnist_ds1.test.shuffle()  # MNISTDataset resets its random generator when instantiated, so explicit shuffling is needed to get different subsets
                    mnist_ds1.impose_distribution(weights=distr, max_test_size=TEST_SUBSET_SIZE)
                    test_loss1, test_acc1, total_predict1, total_actual1, wrong_predict_images1, output_probs1 = test_model.test_data(mnist_ds1.test, use_only_one_batch=True)
    #                 print('id_test = {} : test_acc1 = {:.3f}% ({}/{})'.format(id_test, test_acc1 * 100, mnist_ds1.test.num_examples - len(wrong_predict_images1), mnist_ds1.test.num_examples))

                    list_acc_normal_method.append(test_acc1)

                    # save the test results
                    save_test_results_to_dict(test_dict=full_results_dict[ckpt_file]['test_results'], test_method='normal', test_size=mnist_ds1.test.num_examples, id_test=id_test, id_distr=idx_distr, 
                                              test_distr=distr, test_acc=test_acc1, total_actual=total_actual1, total_predict=total_predict1)

                    if ADJUSTING_USING_REAL_PRIORS:
                        # adjust the outputs
                        real_priors = mnist_ds1.test.label_distr
                        old_priors = current_model_train_distr
                        adj_output_probs1 = (real_priors / old_priors) * output_probs1 / (np.sum((real_priors / old_priors) * output_probs1, axis = 1)[:, None])
                        adj_total_predict1 = np.argmax(adj_output_probs1, axis=1)
                        count_correct_predicted = np.sum(total_predict1 == total_actual1)
    #                     print('Old accuracy: {:.3f} ({}/{})'.format(count_correct_predicted / mnist_ds1.test.num_examples, count_correct_predicted, mnist_ds1.test.num_examples))
                        count_correct_predicted = np.sum(adj_total_predict1 == total_actual1)
    #                     print('New accuracy: {:.3f} ({}/{})'.format(count_correct_predicted / mnist_ds1.test.num_examples, count_correct_predicted, mnist_ds1.test.num_examples))

                        # save the results into dictionary
                        adj_test_acc = np.sum(adj_total_predict1 == total_actual1) / mnist_ds1.test.num_examples
                        list_acc_normal_method_adj.append(adj_test_acc)
                        save_test_results_to_dict(test_dict=full_results_dict[ckpt_file]['test_results'], test_method='normal_adj', test_size=mnist_ds1.test.num_examples, id_test=id_test, id_distr=idx_distr, 
                                              test_distr=distr, test_acc=adj_test_acc, total_actual=total_actual1, total_predict=adj_total_predict1)

                print('\u03BC(acc_n) = {:.3f}%, \u03C3(acc_n) = {:.3f}%'.format(np.mean(list_acc_normal_method) * 100, np.std(list_acc_normal_method, ddof=1) * 100))
                print('\u03BC(adj_acc_n) = {:.3f}%, \u03C3(adj_acc_n) = {:.3f}%'.format(np.mean(list_acc_normal_method_adj) * 100, np.std(list_acc_normal_method_adj, ddof=1) * 100))


            print('\n')

        # analyze error distribution
        # restore_and_test_a_model_on_a_mnist_subset(mnist_ds, ckpt_dir=WORK_DIR, ckpt_filemame=ckpt_file, plot_filename = 'test_model')
        print('\n\n\n')


    # save the above results dictionary to file
    if SAVE_DICT_TO_FILE:
        filename = 'testing_results_{}testsPerModel.dict.pickle'.format(NO_TESTS_PER_MODEL)
        full_filepath = os.path.join(WORK_DIR, filename)
        with open(full_filepath, 'wb') as filehandler:
            pickle.dump(full_results_dict, filehandler)
        print('Results dictionary was successfully saved to: {}'.format(full_filepath))
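The "shortcut" method above runs each model only once on the full test set and then resamples the stored predictions, instead of re-running the network on every test subset. Its "theoretical μ(acc)" reweights the per-class error counts, which is only exact when the full test set is uniform over the labels (MNIST's test set is only approximately so). The sketch below computes the expected accuracy from per-class error rates instead, and reproduces the resampling loop on stored predictions. Both functions are hypothetical helpers written for illustration, not the project's `utils` code.

```python
import numpy as np

def expected_accuracy(actual, predicted, distr, num_classes=10):
    """Expected accuracy on a subset whose labels follow `distr`.

    Uses per-class error rates measured once on the full test set, so it does not
    assume a uniform test set. Every label must appear at least once in `actual`.
    """
    actual = np.asarray(actual)
    predicted = np.asarray(predicted)
    wrong = np.bincount(actual[actual != predicted], minlength=num_classes)
    totals = np.bincount(actual, minlength=num_classes)
    error_rate = wrong / totals
    distr = np.asarray(distr, dtype=float)
    return 1.0 - np.sum(error_rate * distr / distr.sum())


def shortcut_accuracies(actual, predicted, distr, subset_size, n_tests, rng, num_classes=10):
    """Accuracies of stored predictions on repeatedly resampled subsets (no model re-run)."""
    actual = np.asarray(actual)
    predicted = np.asarray(predicted)
    distr = np.asarray(distr, dtype=float)
    distr = distr / distr.sum()
    accs = []
    for _ in range(n_tests):
        # draw int(distr[c] * subset_size) examples of each label c, without replacement
        idx = np.concatenate([
            rng.choice(np.flatnonzero(actual == c), size=int(distr[c] * subset_size), replace=False)
            for c in range(num_classes)
        ])
        accs.append(np.mean(actual[idx] == predicted[idx]))
    return np.array(accs)
```

Averaged over many resamples, the output of `shortcut_accuracies` should approach `expected_accuracy`; the spread across resamples is what the notebook reports as σ(acc_s).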


---
### - restore the results saved earlier and plot them
---

In [None]:
WORK_DIRS = [
    './results/PriorProbabilityShift_experiment_6_[]samples_2/150samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/250samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/500samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/1000samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/5000samples/',
    './results/PriorProbabilityShift_experiment_6_[]samples_2/10000samples/',
]  
filename = 'testing_results_50testsPerModel.dict.pickle'

for WORK_DIR in WORK_DIRS:

    with open(os.path.join(WORK_DIR, filename), 'rb') as filehandler:
        restored_perf_dict = pickle.load(filehandler)
    print('Results dictionary was successfully restored from: {}{}'.format(WORK_DIR, filename))

    # get all distributions used for training
    distrs_used_for_training = []
    for ckpt_file in restored_perf_dict.keys():
        distrs_used_for_training.append(restored_perf_dict[ckpt_file]['train_distr'])
    L = len(distrs_used_for_training)

    # build the average accuracy matrices and distribution matrices, assuming that each model was tested on all distributions from distrs_used_for_training
    avg_acc_matrix = np.empty((L, L))
    avg_acc_matrix_adj = np.empty((L, L))
    std_matrix = np.empty((L, L))
    std_matrix_adj = np.empty((L, L))

    avg_correct_predictions_distr_matrix = np.empty((L, L, 10))
    avg_correct_predictions_distr_matrix_adj = np.empty((L, L, 10))
    avg_wrong_predictions_distr_matrix = np.empty((L, L, 10))
    std_wrong_predictions_distr_matrix = np.empty((L, L, 10))
    avg_wrong_predictions_distr_matrix_adj = np.empty((L, L, 10))
    avg_wrong_actual_predictions_distr_matrix = np.empty((L, L, 10))
    avg_wrong_actual_predictions_distr_matrix_adj = np.empty((L, L, 10))

    def std_1ddof(x):
        return np.std(x, ddof=1)

    for id_model, ckpt_file in enumerate(restored_perf_dict.keys()):

        # build a pandas dataframe from the test results dictionary
        perf_df = pd.DataFrame(restored_perf_dict[ckpt_file]['test_results'], columns=list(restored_perf_dict[ckpt_file]['test_results'].keys()))
        # aggregate only 'test_acc': newer pandas versions no longer silently drop non-numeric columns during aggregation
        temp_df = perf_df[perf_df['test_method'] == 'shortcut'].groupby('id_distr')['test_acc'].agg(['mean', std_1ddof])
        avg_acc_matrix[id_model, :] = temp_df['mean'].to_numpy()
        std_matrix[id_model, :] = temp_df['std_1ddof'].to_numpy()

        temp_df = perf_df[perf_df['test_method'] == 'shortcut_adj'].groupby('id_distr')['test_acc'].agg(['mean', std_1ddof])
        avg_acc_matrix_adj[id_model, :] = temp_df['mean'].to_numpy()
        std_matrix_adj[id_model, :] = temp_df['std_1ddof'].to_numpy()

        for t in range(L):
            avg_correct_predictions_distr_matrix[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut') & (perf_df['id_distr'] == t)]['correct_predicted_counts'], axis=0)
            avg_wrong_predictions_distr_matrix[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut') & (perf_df['id_distr'] == t)]['wrong_predicted_counts'], axis=0)
    #         std_wrong_predictions_distr_matrix[id_model, t, :] = np.std(np.array(perf_df[(perf_df['test_method'] == 'shortcut') & (perf_df['id_distr'] == t)]['wrong_predicted_counts']), axis=0, ddof=1)
            avg_wrong_actual_predictions_distr_matrix[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut') & (perf_df['id_distr'] == t)]['wrong_actual_counts'], axis=0)
            avg_correct_predictions_distr_matrix_adj[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut_adj') & (perf_df['id_distr'] == t)]['correct_predicted_counts'], axis=0)
            avg_wrong_predictions_distr_matrix_adj[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut_adj') & (perf_df['id_distr'] == t)]['wrong_predicted_counts'], axis=0)
            avg_wrong_actual_predictions_distr_matrix_adj[id_model, t, :] = np.average(perf_df[(perf_df['test_method'] == 'shortcut_adj') & (perf_df['id_distr'] == t)]['wrong_actual_counts'], axis=0)

    diff_acc_matrix = avg_acc_matrix_adj - avg_acc_matrix

    # plot the average accuracy comparison matrix
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix, use_percent_for_accuracies=True,
                                           std_matrix=std_matrix, title='accuracy matrix (avg and std)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.acc_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot distributions corresponding to correct predictions
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix, distr_matrix=avg_correct_predictions_distr_matrix,
                                           use_percent_for_accuracies=True, title='correct predictions distr. (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.correct_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot distributions corresponding to wrong predictions
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix, distr_matrix=avg_wrong_predictions_distr_matrix,
                                           use_percent_for_accuracies=True, title='wrong predictions (predicted) distr. (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.wrong_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot distributions corresponding to wrong actual predictions
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix, distr_matrix=avg_wrong_actual_predictions_distr_matrix,
                                           use_percent_for_accuracies=True, title='wrong predictions (actual) distr. (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.wrong_actual_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot the average accuracy comparison matrix (adjusted version using real priors)
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix_adj, use_percent_for_accuracies=True,
                                           std_matrix=std_matrix_adj, title='accuracy matrix after adj. using real priors (avg and std)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.adj_acc_matrix.png'), bbox_inches='tight', pad_inches=1)

    # build and plot distributions corresponding to correct predictions (adjusted version using real priors)
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix_adj, distr_matrix=avg_correct_predictions_distr_matrix_adj,
                                           use_percent_for_accuracies=True, title='correct predictions distr. after adj. using real priors (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.adj_correct_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot distributions corresponding to wrong predictions (adjusted version using real priors)
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix_adj, distr_matrix=avg_wrong_predictions_distr_matrix_adj,
                                           use_percent_for_accuracies=True, title='wrong predictions (predicted) distr. after adj. using real priors (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.adj_wrong_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot distributions corresponding to wrong actual predictions (adjusted version using real priors)
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=avg_acc_matrix_adj, distr_matrix=avg_wrong_actual_predictions_distr_matrix_adj,
                                           use_percent_for_accuracies=True, title='wrong predictions (actual) distr. after adj. using real priors (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,filename + '.adj_wrong_actual_predictions_distr_matrix.png'), bbox_inches='tight', pad_inches=1)

    # plot the accuracy difference matrix (adjusted accuracy minus standard accuracy)
    acc_matrix_plt = utils.plot_acc_matrix(train_distributions=distrs_used_for_training, acc_matrix=diff_acc_matrix, use_percent_for_accuracies=True,
                                           title='diff. accuracy matrix (before vs. after adj. using real priors) (avg)')
    acc_matrix_plt.savefig(os.path.join(WORK_DIR,'{}_{:.1f}%avgDiff.diff_acc_matrix.png'.format(filename, np.average(diff_acc_matrix) * 100)), bbox_inches='tight', pad_inches=1)
    
    plt.show()


---
### - test the models and the adjusting method on the entire MNIST testset
---

In [None]:
# WORK_DIR = "./results/PriorProbabilityShift_experiment_5_10000samples/"
# WORK_DIR = "./results/PriorProbabilityShift_experiment_5_5000samples/"
WORK_DIR = "./results/PriorProbabilityShift_experiment_6_[]samples_2/1000samples/"

ckpt_file_list = utils.get_all_files_from_dir_ending_with(WORK_DIR, "ckpt.meta", without_file_extension=True)

for idx_model, ckpt_file in enumerate(ckpt_file_list):
    print('Restoring model {}{} {}'.format(WORK_DIR, idx_model, ckpt_file))
    test_model = Lenet5(mnist_dataset=mnist_ds, display_summary=False)
    test_model.restore_session(ckpt_dir=WORK_DIR, ckpt_filename=ckpt_file)
    test_model_train_distr = test_model.session.run(test_model.train_distr)
    print('The restored model was trained using distr: {}\n'.format(test_model_train_distr))
    
    # test it on the entire MNIST test set
    mnist_ds = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
    test_loss, test_acc, total_predict, total_actual, wrong_predict_images, output_probs = test_model.test_data(mnist_ds.test, use_only_one_batch=True)
  
    # analyze error distribution
    restore_and_test_a_model_on_a_mnist_subset(mnist_ds, ckpt_dir=WORK_DIR, ckpt_filemame=ckpt_file, plot_filename = 'test_model')
    
    # adjust the outputs using the real priors
    real_priors = mnist_ds.test.label_distr
    old_priors = test_model_train_distr
    print('real priors = {}'.format(real_priors))
    print('old priors  = {}'.format(old_priors))

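    # Bayes prior correction: rescale each class by real_prior / old_prior, then renormalize each row
    prior_ratio = real_priors / old_priors
    corrected_output_probs = prior_ratio * output_probs / np.sum(prior_ratio * output_probs, axis=1, keepdims=True)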
    corrected_total_predict = np.argmax(corrected_output_probs, axis=1)
    
    count_correct_predicted = np.sum(total_predict == total_actual)
    print('Old accuracy: {:.3f} ({}/{})'.format(count_correct_predicted / mnist_ds.test.num_examples, count_correct_predicted, mnist_ds.test.num_examples))
    count_correct_predicted = np.sum(corrected_total_predict == total_actual)
    print('New accuracy: {:.3f} ({}/{}), after adjusting with the real priors'.format(count_correct_predicted / mnist_ds.test.num_examples, count_correct_predicted, mnist_ds.test.num_examples))
    
    print('\n\n\n')

---
### - analyze some examples for which adjusting with the real priors changed the decision
---
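
To find the examples worth inspecting, the next cell compares the argmax decisions before and after the adjustment and splits the changed ones into two groups: wrong predictions that the adjustment fixed, and correct predictions that it broke. A minimal NumPy sketch of that bookkeeping on toy label arrays (the arrays and the helper name `split_changed_decisions` are illustrative, not from the notebook):

```python
import numpy as np

def split_changed_decisions(actual, predicted, adj_predicted):
    """Return (fixed, broken) index arrays (hypothetical helper).

    fixed:  wrong before the adjustment, correct after it
    broken: correct before the adjustment, wrong after it
    """
    actual, predicted, adj_predicted = map(np.asarray, (actual, predicted, adj_predicted))
    fixed = np.flatnonzero((predicted != actual) & (adj_predicted == actual))
    broken = np.flatnonzero((predicted == actual) & (adj_predicted != actual))
    return fixed, broken

actual        = np.array([3, 7, 1, 1, 9, 4])
predicted     = np.array([8, 7, 7, 1, 9, 4])
adj_predicted = np.array([3, 7, 1, 7, 9, 4])
fixed, broken = split_changed_decisions(actual, predicted, adj_predicted)
print('fixed:', fixed, 'broken:', broken)

# the net accuracy change equals (#fixed - #broken) / n
delta = np.mean(adj_predicted == actual) - np.mean(predicted == actual)
print(delta, (len(fixed) - len(broken)) / len(actual))
```

This identity gives a quick sanity check on the `diff` value the cell prints: the accuracy gain is the number of fixed predictions minus the number of broken ones, divided by the test-set size (using the counts printed before the display limit is applied).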

In [None]:
MAX_IMAGES_TO_DISPLAY_PER_MODEL = 3
# WORK_DIR = "./results/PriorProbabilityShift_experiment_5_10000samples/"
# WORK_DIR = "./results/PriorProbabilityShift_experiment_5_5000samples/"
WORK_DIR = "./results/PriorProbabilityShift_experiment_6_[]samples_2/1000samples/"

ckpt_file_list = utils.get_all_files_from_dir_ending_with(WORK_DIR, "ckpt.meta", without_file_extension=True)

for idx_model, ckpt_file in enumerate(ckpt_file_list):
    print('Restoring model {}{} {}'.format(WORK_DIR, idx_model, ckpt_file))
    test_model = Lenet5(mnist_dataset=mnist_ds, display_summary=False)
    test_model.restore_session(ckpt_dir=WORK_DIR, ckpt_filename=ckpt_file)
    test_model_train_distr = test_model.session.run(test_model.train_distr)
#     print('The restored model was trained using distr: {}\n'.format(test_model_train_distr))
    
    for idx_distr, distr in enumerate(distrs_used_for_training):
        if (test_model_train_distr == distr).all():
            continue
#         print('Train distr: {}'.format(test_model_train_distr))
#         print('Test distr: {}'.format(distr))
        
        mnist_ds = MNISTDataset(MNIST_TRAIN_IMAGES_FILEPATH, MNIST_TRAIN_LABELS_FILEPATH, MNIST_TEST_IMAGES_FILEPATH, MNIST_TEST_LABELS_FILEPATH)
        mnist_ds.impose_distribution(weights=distr, global_max_weight=np.max(distrs_used_for_training))
        test_loss, test_acc, total_predict, total_actual, wrong_predict_images, output_probs = test_model.test_data(mnist_ds.test, use_only_one_batch=True)
        correct_predict = total_predict[total_actual == total_predict]
        wrong_predict = total_predict[total_actual != total_predict]
        wrong_actual = total_actual[total_actual != total_predict]

        # adjust the outputs
        real_priors = mnist_ds.test.label_distr
        old_priors = test_model_train_distr
#         print('real priors = {}'.format(real_priors))
#         print('old priors  = {}'.format(old_priors))
        prior_ratio = real_priors / old_priors
        adj_output_probs = prior_ratio * output_probs / np.sum(prior_ratio * output_probs, axis=1, keepdims=True)
        adj_total_predict = np.argmax(adj_output_probs, axis=1).astype(np.int32)
        count_correct_predicted = np.sum(total_predict == total_actual)
        acc = count_correct_predicted / mnist_ds.test.num_examples
        print('Old accuracy: {:.1f}% ({}/{})'.format(acc * 100, count_correct_predicted, mnist_ds.test.num_examples))
        adj_count_correct_predicted = np.sum(adj_total_predict == total_actual)
        adj_test_acc = adj_count_correct_predicted / mnist_ds.test.num_examples
        print('New accuracy: {:.1f}% ({}/{}), diff = {:.2f}%'.format(adj_test_acc * 100, adj_count_correct_predicted, mnist_ds.test.num_examples, (adj_test_acc - acc) * 100))
        
        adj_correct_predict = adj_total_predict[total_actual == adj_total_predict]
        adj_wrong_predict = adj_total_predict[total_actual != adj_total_predict]
        adj_wrong_actual = total_actual[total_actual != adj_total_predict]
        
        # indices where the adjustment fixed a wrong prediction, and where it broke a correct one
        indices_wrong_predicted_corrected_by_adj = np.flatnonzero((total_actual != total_predict) & (total_actual == adj_total_predict))
        indices_correct_predicted_disturbed_by_adj = np.flatnonzero((total_actual == total_predict) & (total_actual != adj_total_predict))
        print('#wrong_predicted_corrected_by_adj = {}'.format(indices_wrong_predicted_corrected_by_adj.shape[0]))
        print('#correct_predicted_disturbed_by_adj = {}'.format(indices_correct_predicted_disturbed_by_adj.shape[0]))
        # keep at most MAX_IMAGES_TO_DISPLAY_PER_MODEL examples of each kind
        indices_wrong_predicted_corrected_by_adj = indices_wrong_predicted_corrected_by_adj[:MAX_IMAGES_TO_DISPLAY_PER_MODEL]
        indices_correct_predicted_disturbed_by_adj = indices_correct_predicted_disturbed_by_adj[:MAX_IMAGES_TO_DISPLAY_PER_MODEL]
        for idx in np.concatenate((indices_wrong_predicted_corrected_by_adj, indices_correct_predicted_disturbed_by_adj)):
#             print('output probs: {}'.format(output_probs[idx]))
#             print('adj. probs: {}'.format(adj_output_probs[idx]))

            count_figures = 5
            fig = plt.figure(figsize=(20, 2))

            k = 1
            plt.subplot(1, count_figures, k)
            plt.bar(range(10), old_priors)
            plt.xticks(range(0, 10))
            plt.title('old priors (= train label distr)')

            k += 1
            plt.subplot(1, count_figures, k)
            plt.bar(range(10), real_priors)
            plt.xticks(range(0, 10))
            plt.title('real priors (= test label distr)')

            k += 1
            plt.subplot(1, count_figures, k)
            plt.imshow(mnist_ds.test.images[idx].reshape((28,28)), cmap='gray')
            plt.title('predicted {}, actual {} (after adj.: {})'.format(total_predict[idx], total_actual[idx], adj_total_predict[idx]))

            k += 1
            plt.subplot(1, count_figures, k)
            plt.bar(range(10), output_probs[idx])
            plt.xticks(range(0, 10))
            plt.title('output probs')
            axes = plt.gca()
            axes.set_ylim([0,1])

            k += 1
            plt.subplot(1, count_figures, k)
            plt.bar(range(10), adj_output_probs[idx])
            plt.xticks(range(0, 10))
            plt.title('adj. probs')
            axes = plt.gca()
            axes.set_ylim([0,1])

            plt.show()
        print('\n\n\n')

# TODO: Adaptation using estimated priors
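
One possible direction for this TODO, not implemented in the notebook: when the test priors are unknown, they can be estimated from the model's own outputs with the EM procedure of Saerens, Latinne and Decaestecker (2002). It alternates between reweighting the posteriors with the current prior estimate (the same correction used above) and re-estimating the priors as the mean of the reweighted posteriors. A minimal sketch, assuming the softmax outputs are reasonably calibrated and `train_priors` has no zero entries; the function name `estimate_priors_em`, its parameters, and the synthetic Gaussian data are all hypothetical:

```python
import numpy as np

def estimate_priors_em(output_probs, train_priors, max_iter=100, tol=1e-6):
    """Estimate test-set class priors from classifier outputs via EM (hypothetical helper).

    output_probs: (n_samples, n_classes) softmax outputs of a model trained with train_priors
    Returns (estimated_priors, adjusted_probs).
    """
    output_probs = np.asarray(output_probs, dtype=float)
    train_priors = np.asarray(train_priors, dtype=float)
    priors = train_priors.copy()  # start from the training priors
    for _ in range(max_iter):
        # E-step: reweight posteriors with the current prior estimate, renormalize rows
        weighted = output_probs * (priors / train_priors)
        adjusted = weighted / weighted.sum(axis=1, keepdims=True)
        # M-step: the new prior estimate is the average adjusted posterior
        new_priors = adjusted.mean(axis=0)
        converged = np.abs(new_priors - priors).max() < tol
        priors = new_priors
        if converged:
            break
    weighted = output_probs * (priors / train_priors)
    return priors, weighted / weighted.sum(axis=1, keepdims=True)

# synthetic check: 1-D Gaussian classes, ideal classifier trained with uniform priors
rng = np.random.default_rng(0)
means = np.array([-2.0, 0.0, 2.0])
true_test_priors = np.array([0.6, 0.3, 0.1])
labels = rng.choice(3, size=20000, p=true_test_priors)
x = rng.normal(loc=means[labels], scale=1.0)
# posteriors under uniform priors = softmax of the Gaussian log-likelihoods
log_lik = -0.5 * (x[:, None] - means[None, :]) ** 2
probs = np.exp(log_lik - log_lik.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

est_priors, adj_probs = estimate_priors_em(probs, np.full(3, 1 / 3), max_iter=1000)
print(est_priors.round(3))
print('acc before:', np.mean(probs.argmax(axis=1) == labels),
      'acc after:', np.mean(adj_probs.argmax(axis=1) == labels))
```

Starting from the training priors means the first E-step leaves the outputs unchanged. In the notebook's setting, the estimated priors could stand in for `real_priors` in the adjustment cells, removing the need to know the test label distribution; how well this works with the trained LeNet-5 models has not been tested here.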