In [101]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [102]:
import math
import random

import pickle
import os
import datetime as dt

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

SEQ_LENGTH = 4


This notebook allows to define any explicit model and to produce its prediction on virtual choices sequences, for comparison with exploratory model

In [None]:
1 br : 3/1 -> - 3 eps vs 3*1 eps 
2 br : 2/2 -> -1 eps vs 2*1 eps 
3 br : 1/3 -> 1/3 eps vs 1*1

In [317]:
def explicit_prediction_1(x_seq, eps = 0.005): 
    x_seq = x_seq.squeeze(0)
    rewards = x_seq[:, -1]
    best_rewards = np.where(rewards == np.max(rewards))[0]
    best_choices = np.unique(x_seq[best_rewards, :-1].argmax(-1))
    next_choice_probs = np.zeros(4)
    next_choice_probs[best_choices] = 1/len(best_choices) - (4-len(best_choices))/len(best_choices) * eps
    next_choice_probs[next_choice_probs==0] = eps 
    return next_choice_probs

## simulating with custom inputs

In [318]:
def get_last_seq(person_df, choice_number, seq_len, debug_prints=False):
    """
    Gets a person's data and a choice number and returns a sequence of previous 4 (or seq_length) choices made (maybe not 4 (seq_length) consecutives)
    """
    if choice_number < seq_len:
        if debug_prints: print("must be {} and more to get a sequence".format(seq_len))
        return -2
    ind = np.where(person_df["orig_choice_num"] == choice_number)
    if len(ind[0]) == 0:
        if debug_prints: print("choice number {} not found".format(choice_number))
        return -9
    elif len(ind[0]) > 1:
        if debug_prints: print("FOR SOME REASON THERE'S MORE THAN ONE CHOICE NUMBER LIKE {}".format(choice_number))
        return -1
    # got here so I found the choice number , now lets check if we have 4 previous choices
    seq = []
    for i in range(ind[0][0] - seq_len + 1, ind[0][0] + 1, 1):
        seq.append(i)
    return person_df.iloc[seq]

In [319]:
def create_data_sequence(seq_df, seq_len):
    """
    gets a SINGLE sequence in df form and returns it in a train/test ready form (numpy) X and y
    """
    cur_data = seq_df.copy()
    cur_data['choice'] = cur_data.choice.apply(lambda x: x - 1)
    cur_data['prev_choice'] = cur_data.prev_choice.apply(lambda x: x - 1)
    X = cur_data.drop(
        columns=['choice', 'user', 'time', 'reward', 'payoff_structure', 'reward_1', 'reward_2', 'reward_3', 'reward_4',
                 'orig_choice_num'])
    X_prev = F.one_hot(X.prev_choice, num_classes=4, dtype='int64')
    y = cur_data.choice
    num_of_classes = len(y.unique())
    y = F.one_hot(y, num_classes=4, dtype='int64')

    new_X = []
    for prev_choice, prev_reward in zip(X_prev, X.prev_reward):
        new_i = np.append(prev_choice, prev_reward)
        new_X.append(new_i)
    new_X = np.array(new_X)

    # reshape X to be [samples, time steps, features]
    X_reshaped = np.reshape(new_X, (1, seq_len, new_X.shape[1]))
    y_reshaped = np.reshape(y, (1, seq_len, y.shape[1]))

    return X_reshaped, y_reshaped

In [320]:
def create_fake_sequence(seq, rewards, seq_len = SEQ_LENGTH):
    """
    seq - is the input : array with 4 ints in range 0-3
    rewards - are the corresponding rewards array at each time
    returns x, y in a sequence format
    """

    X_prev = F.one_hot(torch.tensor(seq), num_classes=4)

    new_X = []
    for prev_choice, prev_reward in zip(X_prev, rewards):
        new_i = np.append(prev_choice, prev_reward)
        new_X.append(new_i)
    new_X = np.array(new_X)

    # reshape X to be [samples, time steps, features]
    X_reshaped = np.reshape(new_X, (1, seq_len, new_X.shape[1]))

    return X_reshaped

In [321]:
def make_prediction(input_seq, input_rewards, print_to_screen=True, print_to_log=True, model = explicit_prediction_1):
    """
    input: a sequence and rewards , both as lists of ints
    returns: a prediction
    """
    fake_x = create_fake_sequence(seq=input_seq,rewards=input_rewards)

    probs = model(fake_x)
    print(probs)
    class_pred = np.random.choice([0,1,2,3], p=probs)
    if print_to_screen:
        print("predicted class: ",class_pred)
        print(probs)

    return (class_pred, probs)

In [323]:
def run_fake_seq():
    """
    wrapper for running a new sequence
    if input is "exit" at any phase, exiting the outer loop
    """
    g = input("Enter a sequence of choices: ")
    if g == "exit":
        return 1
    input_seq = [int(x) for x in g.split()]
    r = input("Enter the corresponding rewards: ")
    if g == "exit":
        return 1
    input_rewards = [int(x) for x in r.split()]

    make_prediction(input_seq, input_rewards)

    return 0

In [324]:
def continuous_run():
    halt = 0
    while not halt:
        halt = run_fake_seq()

In [325]:
def split_prediction_argmax(cell):
    splitted_cell = cell.split()
    return int(splitted_cell[0]), float(splitted_cell[1].replace('[','').replace(']',''))

In [326]:
def produce_comparison_df(output_with_argmax, generated_output_with_argmax):
    predict_compare = []
    argmax_compare = []
    for gen_row, orig_row in zip(generated_output_with_argmax.to_numpy(), output_with_argmax.to_numpy()):
        row_predictions_compare = []
        row_argmax_compare = []
        for gen_out, orig_out in zip(gen_row, orig_row):
            gen_pred, gen_argmax = split_prediction_argmax(gen_out)
            orig_pred, orig_argmax = split_prediction_argmax(orig_out)
            row_predictions_compare.append(int(gen_pred==orig_pred))
            row_argmax_compare.append(np.abs(gen_argmax-orig_argmax))
        predict_compare.append(row_predictions_compare)
        argmax_compare.append(row_argmax_compare)

    predict_compare_df = pd.DataFrame(np.asmatrix(predict_compare), index=generated_output_with_argmax.index, columns=generated_output_with_argmax.columns)
    argmax_compare_df = pd.DataFrame(np.asmatrix(argmax_compare), index=generated_output_with_argmax.index, columns=generated_output_with_argmax.columns)
    return predict_compare_df, argmax_compare_df

In [327]:
################## V5.1 FULL #############################
# Patterns:
constant_group = ['1 1 1 1', '2 2 2 2', '3 3 3 3', '4 4 4 4']
one_different = ['2 2 2 1', '4 2 2 2', '2 2 2 4', '1 2 2 2', '3 2 2 2', '2 3 3 3', '2 1 1 1']
repeating_two = ['1 2 1 2', '2 1 2 1', '2 3 2 3', '3 2 3 2', '4 3 4 3', '3 4 3 4']
all_different = ['1 2 3 4', '4 3 2 1', '2 3 4 1', '1 4 3 2', '3 4 1 2', '2 1 4 3', '4 1 2 3', '3 2 1 4']

all_patterns_groups = [constant_group, one_different, repeating_two, all_different]


# Rewards:
constant_group_rew = ['10 10 10 10', '15 15 15 15', '20 20 20 20', '25 25 25 25', '30 30 30 30', '35 35 35 35',
                      '40 40 40 40', '45 45 45 45', '50 50 50 50', '55 55 55 55', '60 60 60 60', '65 65 65 65',
                      '70 70 70 70', '75 75 75 75', '80 80 80 80', '85 85 85 85', '90 90 90 90']

ascending_rew = ['10 20 30 40', '15 20 25 30', '10 30 50 70', '10 40 70 90', '30 35 40 45', '45 50 55 65',
                 '50 55 65 75', '55 65 75 85', '65 75 85 95', '75 80 85 90', '10 15 20 25', '25 45 65 85',
                 '40 60 80 90', '60 70 80 90', '20 30 70 80', '40 50 60 70', '50 60 70 80']

descending_rew = ['90 80 70 60', '80 70 60 50', '70 60 50 40', '60 50 40 30', '50 40 30 20', '40 30 20 10',
                  '90 70 50 30', '70 50 30 10', '95 75 55 35', '75 55 35 15', '80 60 40 20', '85 65 45 25',
                  '90 70 40 10', '90 60 30 10', '90 50 30 10', '90 40 30 20', '90 30 20 10', '60 30 20 10']


one_different_rew_good = ['10 10 10 90', '20 20 20 90', '30 30 30 90', '40 40 40 90', '50 50 50 90', '60 60 60 90',
                          '10 10 90 10', '20 20 90 20', '30 30 90 30', '40 40 90 40', '50 50 90 50', '20 80 20 20',
                          '30 90 30 30', '60 90 60 60', '80 20 20 20', '90 30 30 30', '90 40 40 40', '90 50 50 50',
                          '20 20 20 80', '40 40 40 80', '60 60 60 80', '20 20 20 40', '40 40 40 60', '20 20 20 60',
                          '20 20 80 20', '40 40 80 40', '60 60 80 40', '20 20 40 20', '40 40 60 40', '20 20 60 20',
                          '20 80 20 20', '40 80 40 40', '60 80 60 60', '20 40 20 20', '40 60 40 40', '20 60 20 20',
                          '80 20 20 20', '80 40 40 40', '80 60 60 60', '40 20 20 20', '60 40 40 40', '60 20 20 20']


one_different_rew_bad = ['90 90 90 10', '90 90 90 20', '90 90 90 30', '90 90 90 40', '90 90 90 50', '90 90 90 60',
                         '90 90 10 90', '90 90 20 90', '90 90 30 90', '90 90 40 90', '90 90 50 90', '80 20 80 80',
                         '90 30 90 90', '90 60 90 90', '20 80 80 80', '30 90 90 90', '40 90 90 90', '50 90 90 90',
                         '80 80 80 20', '80 80 80 40', '80 80 80 60', '40 40 40 20', '60 60 60 40', '60 60 60 20',
                         '80 80 20 80', '80 80 40 80', '80 80 60 80', '40 40 20 40', '60 60 40 60', '60 60 20 60',
                         '80 20 80 80', '80 40 80 80', '80 60 80 80', '40 20 40 40', '60 40 60 60', '60 20 60 60',
                         '20 80 80 80', '40 80 80 80', '60 80 80 80', '20 40 40 40', '40 60 60 60', '20 60 60 60']


all_rewards_groups = [constant_group_rew, ascending_rew, descending_rew, one_different_rew_good, one_different_rew_bad]

In [328]:
all_rewards = []
for i in all_rewards_groups:
    for j in i:
        all_rewards.append(j)

all_patterns = []
for i in all_patterns_groups:
    for j in i:
        all_patterns.append(j)

In [329]:
new_check_df = pd.DataFrame(columns=all_rewards,index=all_patterns)

In [330]:
output_df = new_check_df.copy()
softmax_df = new_check_df.copy()
output_with_argmax = new_check_df.copy()
output_with_argmin = new_check_df.copy()

In [331]:

model = explicit_prediction_1 # put any explicit model here

for index, row in new_check_df.iterrows():
    raw_index = [int(x)-1 for x in index.split()]    # since actual model gets 0-3
    raw_cols = [r.split() for r in list(row.index)]

    for orig_col in list(row.index):
        # convert current column to array
        cur_col = np.fromstring(orig_col, dtype=int, sep=' ')
        print(cur_col)
        row_col_pred, row_col_softmax = make_prediction(raw_index, cur_col, model = model)
        output_df.loc[index, orig_col] = row_col_pred+1
        softmax_df.loc[index, orig_col] = np.array2string(np.ravel(row_col_softmax), precision=15, separator=',', suppress_small=True)
        output_with_argmax.loc[index, orig_col] =  str((row_col_pred+1)) + " " + np.array2string(np.ravel(row_col_softmax)[row_col_pred], precision=3, separator=',', suppress_small=True)
        output_with_argmin.loc[index, orig_col] = str(np.argmin(np.ravel(row_col_softmax))+1) + " " + np.array2string(np.min(np.ravel(row_col_softmax)), precision=3, separator=',', suppress_small=True)
#     raw_cur_cols = [int(x) for x in raw_cols[0]]
#     make_prediction(raw_index, raw_cur_cols)


[10 10 10 10]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[15 15 15 15]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 20 20]
[0.985 0.005 0.005 0.005]
predicted class:  1
[0.985 0.005 0.005 0.005]
[25 25 25 25]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[30 30 30 30]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[35 35 35 35]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 40 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[45 45 45 45]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[50 50 50 50]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[55 55 55 55]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 60 60 60]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[65 65 65 65]
[0.985 0.005 0.005 0.005]
predicted clas

[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 50 40 30]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[50 40 30 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 30 20 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 70 50 30]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[70 50 30 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[95 75 55 35]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[75 55 35 15]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 60 40 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[85 65 45 25]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 70 40 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 60 30 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0

[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 30 20 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[10 10 10 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 20 20 90]
[0.005 0.005 0.985 0.005]
predicted class:  1
[0.005 0.005 0.985 0.005]
[30 30 30 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 40 40 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[50 50 50 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 60 60 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[10 10 90 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 20 90 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[30 30 90 30]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 40 90 40]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0

predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 20 60]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 80 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 40 80 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[60 60 80 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 40 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 40 60 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 60 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 80 20 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 80 40 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[60 80 60 60]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 40 20 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 60 4

[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 40 40 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 60 60 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 40 40 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 30]
[0.005 0.985 0.005 0.005]
predicted class:  0
[0.005 0.985 0.005 0.005]
[90 90 90 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 50]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0

[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[80 80 80 40]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[80 80 80 60]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[40 40 40 20]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[60 60 60 40]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[60 60 60 20]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[80 80 20 80]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[80 80 40 80]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[80 80 60 80]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[40 40 20 40]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[60 60 40 60]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[60 60 20 60]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0

[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[60 20 60 60]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[20 80 80 80]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[40 80 80 80]
[0.005 0.495 0.005 0.495]
predicted class:  2
[0.005 0.495 0.005 0.495]
[60 80 80 80]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[20 40 40 40]
[0.005 0.495 0.005 0.495]
predicted class:  1
[0.005 0.495 0.005 0.495]
[40 60 60 60]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[20 60 60 60]
[0.005 0.495 0.005 0.495]
predicted class:  3
[0.005 0.495 0.005 0.495]
[10 10 10 10]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[15 15 15 15]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[20 20 20 20]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[25 25 25 25]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0

[45 45 45 45]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[50 50 50 50]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[55 55 55 55]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[60 60 60 60]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[65 65 65 65]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[70 70 70 70]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[75 75 75 75]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[80 80 80 80]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[85 85 85 85]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[90 90 90 90]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[10 20 30 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[15 20 25 30]
[0.005 0.985 0.005 0.005]
predicted clas

[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 70 80 90]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 30 70 80]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 50 60 70]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[50 60 70 80]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[90 80 70 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 70 60 50]
[0.005 0.985 0.005 0.005]
predicted class:  2
[0.005 0.985 0.005 0.005]
[70 60 50 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 50 40 30]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[50 40 30 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 30 20 10]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 70 50 30]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0

[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[30 30 30 90]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 40 90]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[50 50 50 90]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 60 60 90]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[10 10 90 10]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 90 20]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[30 30 90 30]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 90 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[50 50 90 50]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 80 20 20]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[30 90 30 30]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0

[0.985 0.005 0.005 0.005]
[90 40 40 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[90 50 50 50]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 20 80]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 40 40 80]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 60 60 80]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[20 20 20 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 40 40 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[20 20 20 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[20 20 80 20]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 80 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 60 80 40]
[0.985 0.005 0.005 0.005]
predicted class:  1
[0.985 0.005 0.005 0.005]
[20 20 40 20]
[0.985 0.005 0

[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 10]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[90 90 90 20]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 90 30]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 90 40]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[90 90 90 50]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 90 60]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 10 90]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[90 90 20 90]
[0.495 0.495 0.005 0.005]
predicted class:  1
[0.495 0.495 0.005 0.005]
[90 90 30 90]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 40 90]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0.495 0.005 0.005]
[90 90 50 90]
[0.495 0.495 0.005 0.005]
predicted class:  0
[0.495 0

[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[80 80 40 80]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[80 80 60 80]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[40 40 20 40]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[60 60 40 60]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[60 60 20 60]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[80 20 80 80]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[80 40 80 80]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[80 60 80 80]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[40 20 40 40]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0.495 0.495 0.005]
[60 40 60 60]
[0.005 0.495 0.495 0.005]
predicted class:  1
[0.005 0.495 0.495 0.005]
[60 20 60 60]
[0.005 0.495 0.495 0.005]
predicted class:  2
[0.005 0

[20 20 20 20]
[0.005 0.005 0.495 0.495]
predicted class:  2
[0.005 0.005 0.495 0.495]
[25 25 25 25]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[30 30 30 30]
[0.005 0.005 0.495 0.495]
predicted class:  2
[0.005 0.005 0.495 0.495]
[35 35 35 35]
[0.005 0.005 0.495 0.495]
predicted class:  2
[0.005 0.005 0.495 0.495]
[40 40 40 40]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[45 45 45 45]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[50 50 50 50]
[0.005 0.005 0.495 0.495]
predicted class:  2
[0.005 0.005 0.495 0.495]
[55 55 55 55]
[0.005 0.005 0.495 0.495]
predicted class:  2
[0.005 0.005 0.495 0.495]
[60 60 60 60]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[65 65 65 65]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[70 70 70 70]
[0.005 0.005 0.495 0.495]
predicted class:  3
[0.005 0.005 0.495 0.495]
[75 75 75 75]
[0.005 0.005 0.495 0.495]
predicted clas

[30 35 40 45]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[45 50 55 65]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[50 55 65 75]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[55 65 75 85]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[65 75 85 95]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[75 80 85 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[10 15 20 25]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[25 45 65 85]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 60 80 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[60 70 80 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 30 70 80]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 50 60 70]
[0.005 0.005 0.005 0.985]
predicted clas

[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[90 40 30 20]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[90 30 20 10]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 30 20 10]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[10 10 10 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 20 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[30 30 30 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 40 40 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[50 50 50 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[60 60 60 90]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[10 10 90 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 20 90 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0

[60 90 60 60]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[80 20 20 20]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[90 30 30 30]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[90 40 40 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[90 50 50 50]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 20 20 80]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 40 80]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 60 60 80]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 20 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[40 40 40 60]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 20 60]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[20 20 80 20]
[0.005 0.985 0.005 0.005]
predicted clas

[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 80 60 60]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 40 20 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 60 40 40]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[20 60 20 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[80 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 40 40 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[80 60 60 60]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 40 40 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 90 90 10]
[0.005      0.33166667 0.33166667 0.33166667]
predicte

[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[75 80 85 90]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[10 15 20 25]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[25 45 65 85]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[40 60 80 90]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[60 70 80 90]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[20 30 70 80]
[0.005 0.985 0.005 0.005]
predicted class:  3
[0.005 0.985 0.005 0.005]
[40 50 60 70]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[50 60 70 80]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 80 70 60]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[80 70 60 50]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[70 60 50 40]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0

[80 70 60 50]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[70 60 50 40]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 50 40 30]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[50 40 30 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 30 20 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[90 70 50 30]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[70 50 30 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[95 75 55 35]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[75 55 35 15]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[80 60 40 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[85 65 45 25]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[90 70 40 10]
[0.005 0.005 0.985 0.005]
predicted clas

[30 30 90 30]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[40 40 90 40]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[50 50 90 50]
[0.005 0.005 0.005 0.985]
predicted class:  3
[0.005 0.005 0.005 0.985]
[20 80 20 20]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[30 90 30 30]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[60 90 60 60]
[0.985 0.005 0.005 0.005]
predicted class:  0
[0.985 0.005 0.005 0.005]
[80 20 20 20]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 30 30 30]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 40 40 40]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[90 50 50 50]
[0.005 0.985 0.005 0.005]
predicted class:  1
[0.005 0.985 0.005 0.005]
[20 20 20 80]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 40 40 80]
[0.005 0.005 0.985 0.005]
predicted clas

predicted class:  2
[0.25 0.25 0.25 0.25]
[25 25 25 25]
[0.25 0.25 0.25 0.25]
predicted class:  3
[0.25 0.25 0.25 0.25]
[30 30 30 30]
[0.25 0.25 0.25 0.25]
predicted class:  3
[0.25 0.25 0.25 0.25]
[35 35 35 35]
[0.25 0.25 0.25 0.25]
predicted class:  2
[0.25 0.25 0.25 0.25]
[40 40 40 40]
[0.25 0.25 0.25 0.25]
predicted class:  2
[0.25 0.25 0.25 0.25]
[45 45 45 45]
[0.25 0.25 0.25 0.25]
predicted class:  0
[0.25 0.25 0.25 0.25]
[50 50 50 50]
[0.25 0.25 0.25 0.25]
predicted class:  1
[0.25 0.25 0.25 0.25]
[55 55 55 55]
[0.25 0.25 0.25 0.25]
predicted class:  1
[0.25 0.25 0.25 0.25]
[60 60 60 60]
[0.25 0.25 0.25 0.25]
predicted class:  0
[0.25 0.25 0.25 0.25]
[65 65 65 65]
[0.25 0.25 0.25 0.25]
predicted class:  3
[0.25 0.25 0.25 0.25]
[70 70 70 70]
[0.25 0.25 0.25 0.25]
predicted class:  3
[0.25 0.25 0.25 0.25]
[75 75 75 75]
[0.25 0.25 0.25 0.25]
predicted class:  0
[0.25 0.25 0.25 0.25]
[80 80 80 80]
[0.25 0.25 0.25 0.25]
predicted class:  2
[0.25 0.25 0.25 0.25]
[85 85 85 85]
[0.25 0.

[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[80 70 60 50]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[70 60 50 40]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[60 50 40 30]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[50 40 30 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[40 30 20 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[90 70 50 30]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[70 50 30 10]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[95 75 55 35]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[75 55 35 15]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[80 60 40 20]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0.005 0.985 0.005]
[85 65 45 25]
[0.005 0.005 0.985 0.005]
predicted class:  2
[0.005 0

[0.33166667 0.005      0.33166667 0.33166667]
predicted class:  0
[0.33166667 0.005      0.33166667 0.33166667]
[80 60 80 80]
[0.33166667 0.005      0.33166667 0.33166667]
predicted class:  3
[0.33166667 0.005      0.33166667 0.33166667]
[40 20 40 40]
[0.33166667 0.005      0.33166667 0.33166667]
predicted class:  3
[0.33166667 0.005      0.33166667 0.33166667]
[60 40 60 60]
[0.33166667 0.005      0.33166667 0.33166667]
predicted class:  2
[0.33166667 0.005      0.33166667 0.33166667]
[60 20 60 60]
[0.33166667 0.005      0.33166667 0.33166667]
predicted class:  2
[0.33166667 0.005      0.33166667 0.33166667]
[20 80 80 80]
[0.33166667 0.33166667 0.005      0.33166667]
predicted class:  0
[0.33166667 0.33166667 0.005      0.33166667]
[40 80 80 80]
[0.33166667 0.33166667 0.005      0.33166667]
predicted class:  0
[0.33166667 0.33166667 0.005      0.33166667]
[60 80 80 80]
[0.33166667 0.33166667 0.005      0.33166667]
predicted class:  3
[0.33166667 0.33166667 0.005      0.33166667]
[20 40

In [332]:
saving_dir = "explain/explicit_1_full_data_training/"
if not os.path.exists(saving_dir):
    os.makedirs(saving_dir)
    
output_df.to_csv(os.path.join(saving_dir,'output_prediction_V5.1.csv'))
softmax_df.to_csv(os.path.join(saving_dir,'output_softmax_V5.1.csv'))
output_with_argmax.to_csv(os.path.join(saving_dir,'output_prediction_with_max_softmax_V5.1.csv'))
output_with_argmin.to_csv(os.path.join(saving_dir,'argmin_prediction_with_min_softmax_V5.1.csv'))
