In [1]:
# 3rd party
import pandas as pd
import numpy as np


In [2]:
def concat_csv(num_PCs, trials, stat_func,
               run_root='/nobackup/khauskne/kdd/explainability_runs'):
    """Load the per-PC explainability CSVs of each trial and evaluate them.

    For every trial ``t`` in ``1..trials`` the files
    ``<run_root>/exp_<num_PCs>_PCs_trial_<t>/<split>_top_<i>.csv`` are read for
    ``split`` in (train, val, test) and ``i`` in ``0..num_PCs-1``, then combined
    per branch with ``apply_statistics``.

    NOTE(review): statistics are not accumulated across trials. Each trial
    overwrites the previous one, so the returned values describe trial number
    ``trials`` only. That behaviour is preserved here because callers may rely
    on it to select a single trial -- confirm whether cross-trial aggregation
    was intended.

    :param num_PCs: int, number of representative PCs used (must be >= 1)
    :param trials: int, number of trials performed (must be >= 1)
    :param stat_func: callable ``stat_func(statistics, branches)`` applied to
        the statistics DataFrame of each split
    :param run_root: str, directory that contains the ``exp_*`` run folders
    :return: tuple, (train, val, test) results of ``stat_func``
    :raises ValueError: if ``num_PCs`` or ``trials`` is smaller than 1
    """

    # Without at least one trial there is nothing to return; fail with a clear
    # message instead of an UnboundLocalError at the return statement.
    if trials < 1:
        raise ValueError(f'trials must be >= 1, got {trials}')
    if num_PCs < 1:
        raise ValueError(f'num_PCs must be >= 1, got {num_PCs}')

    branches = ['Global Flux', 'Local Flux', 'Centroid', 'Odd Even', 'Secondary', 'Stellar', 'DV']
    splits = ('train', 'val', 'test')
    statistics = {}

    for trial in range(1, trials + 1):

        run_trial = f'{run_root}/exp_{num_PCs}_PCs_trial_{trial}/'

        # Read every split first (train, then val, then test), one CSV per PC.
        groups = {
            split: [pd.read_csv(f'{run_trial}{split}_top_{index}.csv') for index in range(num_PCs)]
            for split in splits
        }

        # Per-branch statistics across the PCs of this trial.
        statistics = {
            split: apply_statistics(branches, groups[split], num_PCs)
            for split in splits
        }

    return tuple(stat_func(statistics[split], branches) for split in splits)

In [3]:
def apply_statistics(branches, files, num_PCs):
    """Combine the per-PC score tables into per-branch summary statistics.

    The result starts as a copy of ``files[0]`` (so every non-branch column is
    taken from the first table). For each branch the ``<branch>`` column is
    replaced by the element-wise mean across all tables, and the columns
    ``<branch> Min``, ``<branch> Max`` and ``<branch> Med`` (minimum, maximum,
    median across tables) are inserted directly after it.

    :param branches: list, groupings of features; each must be a column of
        every table in ``files``
    :param files: list of pandas DataFrames with the same number of rows, one
        per representative PC
    :param num_PCs: int, number of representative PCs used. Kept for interface
        compatibility; the statistics are computed over ``len(files)`` tables.
    :return: pandas DataFrame, copy of ``files[0]`` with aggregated branch columns
    :raises ValueError: if ``files`` is empty
    """

    if len(files) == 0:
        raise ValueError('files must contain at least one DataFrame')

    # 'Unnamed: 0' is the index column pandas writes by default; drop it when
    # present but do not fail on CSVs that were saved without it.
    new_csv = files[0].drop(columns='Unnamed: 0', errors='ignore')

    for branch in branches:

        # Shape (len(files), n_rows); built from the actual tables so that no
        # zero-filled placeholder rows can leak into the statistics.
        all_scores = np.stack([np.asarray(group[branch], dtype=float) for group in files])

        col_index = new_csv.columns.get_loc(branch)
        new_csv[branch] = np.mean(all_scores, axis=0)
        new_csv.insert(col_index + 1, branch + ' Min', np.min(all_scores, axis=0))
        # ' Max' and ' Med' are the columns read by evaluate_max / evaluate_med.
        new_csv.insert(col_index + 2, branch + ' Max', np.max(all_scores, axis=0))
        new_csv.insert(col_index + 3, branch + ' Med', np.median(all_scores, axis=0))

    return new_csv


def _explain_by_threshold(csv, branches, suffix, threshold):
    """Build the explanation dict by thresholding the ``<branch><suffix>`` columns.

    :param csv: pandas DataFrame holding the identification columns and one
        ``<branch><suffix>`` column per branch
    :param branches: list, groupings of features
    :param suffix: str, appended to each branch name to select the column
    :param threshold: float, a branch is flagged when its value is strictly below it
    :return: dict with the ``target_id``, ``tce_plnt_num``, ``original_label``
        and ``full score`` columns of ``csv`` (the latter under ``full_score``)
        and ``branch_explanations``, a float array of shape
        (len(csv), len(branches)) holding 1.0 where the branch is flagged
    """

    branch_explanation = np.zeros([len(csv), len(branches)])
    for position, branch in enumerate(branches):
        branch_explanation[:, position] = np.asarray(csv[branch + suffix] < threshold)

    return {'target_id': csv['target_id'], 'tce_plnt_num': csv['tce_plnt_num'],
            'original_label': csv['original_label'], 'full_score': csv['full score'],
            'branch_explanations': branch_explanation}


def evaluate_mean(csv, branches, threshold=-0.5):
    """Flag the branches whose mean score (``<branch>`` column) is below ``threshold``.

    :param csv: pandas DataFrame of aggregated statistics
    :param branches: list, groupings of features
    :param threshold: float, strict upper bound for a branch to be flagged
    :return: dict, identification columns plus ``branch_explanations``
    """

    return _explain_by_threshold(csv, branches, '', threshold)


def evaluate_min(csv, branches, threshold=-0.5):
    """Flag the branches whose ``<branch> Min`` value is below ``threshold``.

    :param csv: pandas DataFrame of aggregated statistics
    :param branches: list, groupings of features
    :param threshold: float, strict upper bound for a branch to be flagged
    :return: dict, identification columns plus ``branch_explanations``
    """

    return _explain_by_threshold(csv, branches, ' Min', threshold)


def evaluate_max(csv, branches, threshold=-0.5):
    """Flag the branches whose ``<branch> Max`` value is below ``threshold``.

    :param csv: pandas DataFrame of aggregated statistics
    :param branches: list, groupings of features
    :param threshold: float, strict upper bound for a branch to be flagged
    :return: dict, identification columns plus ``branch_explanations``
    """

    return _explain_by_threshold(csv, branches, ' Max', threshold)


def evaluate_med(csv, branches, threshold=-0.5):
    """Flag the branches whose ``<branch> Med`` value is below ``threshold``.

    :param csv: pandas DataFrame of aggregated statistics
    :param branches: list, groupings of features
    :param threshold: float, strict upper bound for a branch to be flagged
    :return: dict, identification columns plus ``branch_explanations``
    """

    return _explain_by_threshold(csv, branches, ' Med', threshold)



In [22]:
def find_unexplained_FPs(dictionary):
    """Return the examples scored below 0.5 for which no branch is flagged.

    An example is selected when its ``full_score`` is strictly below 0.5 and
    its row of ``branch_explanations`` sums to zero.

    :param dictionary: dict with ``full_score``, ``target_id`` and
        ``tce_plnt_num`` (1-D, one entry per example) and
        ``branch_explanations`` (2-D, examples x branches)
    :return: tuple, (target_id array, tce_plnt_num array) of the selected
        examples, in their original order
    """

    # Work on plain arrays so selection is positional. Indexing a pandas
    # Series with positions performs a label lookup, which only coincides
    # with positions for a default RangeIndex.
    full_score = np.asarray(dictionary['full_score'])
    branch_explanations = np.asarray(dictionary['branch_explanations'])
    target_ids = np.asarray(dictionary['target_id'])
    planet_nums = np.asarray(dictionary['tce_plnt_num'])

    unexplained = (full_score < 0.5) & (np.sum(branch_explanations, axis=1) == 0)

    return target_ids[unexplained], planet_nums[unexplained]


In [23]:
def find_explained_PCs(dictionary):
    """Return the examples scored above 0.5 for which at least one branch is flagged.

    An example is selected when its row of ``branch_explanations`` sums to
    more than zero and its ``full_score`` is strictly above 0.5.

    :param dictionary: dict with ``full_score``, ``target_id`` and
        ``tce_plnt_num`` (1-D, one entry per example) and
        ``branch_explanations`` (2-D, examples x branches)
    :return: tuple, (target_id array, tce_plnt_num array) of the selected
        examples, in their original order
    """

    full_score = np.asarray(dictionary['full_score'])
    branch_explanations = np.asarray(dictionary['branch_explanations'])
    target_ids = np.asarray(dictionary['target_id'])
    planet_nums = np.asarray(dictionary['tce_plnt_num'])

    # Both conditions are evaluated over the full set of examples, so the
    # resulting mask indexes the id arrays directly. Indices computed on the
    # "explained" subset must not be reused on the full arrays.
    explained_pcs = (np.sum(branch_explanations, axis=1) > 0) & (full_score > 0.5)

    return target_ids[explained_pcs], planet_nums[explained_pcs]

In [46]:
# Collect the unexplained test-set FPs found in each of trials 1..10 of the
# 10-PC runs, using the per-branch minimum as the explanation statistic.
ids = []
plnts = []
for i in range(1, 11):
    # concat_csv returns the statistics of its last trial only, so passing
    # ``i`` as the trial count yields the results of trial ``i``.
    csvs = concat_csv(10, i, evaluate_min)
    # csvs[2] is the test split; evaluate it once and unpack both arrays.
    trial_ids, trial_plnts = find_unexplained_FPs(csvs[2])
    ids.append(trial_ids)
    plnts.append(trial_plnts)
ids = np.concatenate(ids)
plnts = np.concatenate(plnts)

In [52]:
# One row per collected (target_id, tce_plnt_num) pair, rows ordered by target_id.
mistakes = np.stack([ids, plnts], axis=-1)[np.argsort(ids)]

In [58]:
# Distinct (target_id, tce_plnt_num) rows and how many times each one occurs in `mistakes`.
unique_vals=np.unique(mistakes, axis=0, return_counts=True)

In [60]:
# Rows of `mistakes` that occur exactly 9 times.
unique_vals[0][unique_vals[1] == 9]

array([[3246984,       1],
       [5309353,       1],
       [9777793,       1]])

In [49]:
# Occurrence count per target_id; np.unique returns the sorted unique values itself.
np.unique(ids, return_counts=True)

(array([ 2708286,  3221310,  3246984,  3353679,  3429707,  3560301,
         4357985,  4386607,  4845555,  5265699,  5299861,  5302881,
         5309353,  5353738,  5524881,  5716330,  5881893,  5983410,
         6289897,  6365321,  6367260,  6381309,  6386784,  6387819,
         6425135,  6522242,  6780367,  6890040,  6963171,  7045685,
         7465661,  7523340,  7708418,  7839814,  7955708,  7971540,
         8381693,  8692983,  8823893,  9025557,  9216810,  9334893,
         9777793,  9824928, 10135362, 10154994, 10221153, 10485250,
        10614845, 10989859, 10990092, 11071278, 11358392, 11390941,
        11494130, 11656840, 11673686, 11811140, 12217403, 12783196]),
 array([ 1,  1,  9,  1,  2,  3,  2,  7,  5,  5,  1,  2,  9,  1,  8,  4,  4,
         3,  8,  8,  3,  5,  5,  1,  5,  5,  8,  5,  1,  1,  2,  4,  9,  5,
         2, 17,  2,  9,  8,  5,  2,  6,  9,  3,  2,  4,  1,  2,  1,  4,  3,
         1,  3,  2,  3,  1,  1,  8,  2,  4]))

In [24]:
# (train, val, test) evaluation for the 15-PC runs with the per-branch minimum statistic.
# NOTE(review): concat_csv overwrites its statistics on every trial, so this holds trial 9 only.
csvs=concat_csv(15, 9, evaluate_min)

In [37]:
# Unexplained FPs of the test split (csvs[2]) as a (target_id, tce_plnt_num) pair of arrays.
unexp=find_unexplained_FPs(csvs[2])
unexp

(array([ 9777256, 10661976, 10717591, 11912932,  7599004, 11494130,
         6292162,  8748659, 11071278,  6119608, 10000547,  7971540,
        11390838,  7871315,  6115025, 10472112,  8108551,  7465661,
         3560301,  8057693,  9824928,  3955026,  4274480,  8174821,
        11501492, 11673686,  9216810, 11253627,  5299861, 11656840,
        11442793,  6600515,  9138680,  5036761,  6685533, 11254601,
         8939843, 10614845,  5982073, 11358392,  8544169,  6963171,
         8023238,  8298725,  6387819,  3555678,  9142053,  3336476,
         3120397,  8044016,  8175925, 10990092,  6367260, 11811140,
         6381309,  7523340,  6369539,  4845555,  8264623, 11390941,
         8702874,  6701459,  8885132,  6367260,  1435448,  7765677,
        10154994,  9895004,  8240890,  9340460,  8160316, 10978737,
        11082830,  5309353,  8950952, 10221153,  6890040, 10959320,
         5473535,  7621172,  7708418,  7839814,  7819674,  7293769,
         6425135, 10337859,  8106610,  7971540, 

In [29]:
# tce_plnt_num values reordered so that they follow ascending target_id.
unexp[1][np.argsort(unexp[0])]

array([1, 1, 1, 1, 3, 2])

In [52]:
# NOTE(review): `all_test` is not defined in any cell visible here; this cell relies on
# kernel state created elsewhere and will fail on a fresh Restart & Run All.
np.shape(all_test['even_se_oot_norm'])

(3027, 1)

In [46]:
# NOTE(review): `all_test` is not defined in any cell visible here (hidden kernel state).
type(all_test)

numpy.ndarray

In [39]:
# tce_plnt_num of every unexplained FP, in the order returned by find_unexplained_FPs.
unexp[1]

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 7, 4, 1, 3, 1, 1, 4, 4, 2, 1, 5,
       3, 2, 1, 3, 2, 1, 1, 1, 1, 2, 1, 3, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 2, 1, 2, 7, 8, 6, 3, 3, 4, 1, 2, 1, 8, 1, 4, 1, 1,
       5, 1, 1, 1, 1, 1, 8, 1, 2, 1, 1, 5, 1, 1, 3, 2, 1, 3, 5, 1, 3, 6,
       7, 8, 1, 6, 4, 1, 1, 1, 6, 2, 4, 1, 1, 1, 1, 1, 1, 2, 3, 1, 1, 3,
       1, 8, 3, 1, 2, 4, 6, 2, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 3, 1,
       1, 3, 1, 1, 3, 1, 1, 1, 1, 2, 3, 1, 2, 1, 2, 1, 4, 1, 1, 5, 4, 1,
       1, 1, 2, 8, 1, 1, 1, 2, 1, 1, 1, 1, 4, 2, 3, 1, 1, 1, 3, 1, 1, 6,
       6, 1, 4, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 9, 1, 1, 3, 1, 3, 2, 6, 2,
       3, 1, 1, 1, 1, 2, 1, 1, 6, 2, 1, 2, 1, 4, 1, 3, 1, 1, 1, 1, 2, 1,
       4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 5, 1, 1, 1, 3, 3, 1, 1, 3, 5,
       1, 2, 1, 1, 9, 1, 1, 1, 1, 2, 1, 1, 6, 1, 2, 4, 6])