# Data Pooling Experiments

In this notebook we analyze and present the results of the data-pooling experiments described in Section 3.2.

1. [Preliminaries](#Preliminaries)
2. [Natural Partitions](#Natural-Partitions)
    - [Table 3.2](#Natural-Partitions)
3. [Synthetic Partitions](#Synthetic-Partitions)
    - [Table 3.3 Summary](#Synthetic-Partitions)
    - [Table 3.4 By Params](#Group-by-partition-parameters)

## Preliminaries

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import inspect,sys
import os

from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score,KFold
from sklearn.model_selection import train_test_split

import datasets,utils

In [2]:
def local_index(df):
    return min([(i,len(i.split(","))) for i in df.index],key=lambda x:x[1])[0]

def global_index(df):
    return max([(i,len(i.split(','))) for i in df.index],key=lambda x:x[1])[0]

In [3]:
def get_n(partition,subset):
    """Total number of samples held by the islands named in ``subset``.

    ``subset`` is a subset label string (parsed with ``get_islands_from_str``);
    ``partition[island][1]`` is an island's sample vector (presumably its
    targets), of which only the length is used.
    """
    total=0
    for island in get_islands_from_str(subset):
        total+=len(partition[island][1])
    return total

In [4]:
def get_islands_from_str(s):
    """Parse a subset label such as "(0, 1)" or "('a', 'b')" into a list of islands.

    Parentheses and single quotes are dropped and empty pieces (e.g. from the
    trailing comma of a one-element tuple, "(0,)") are ignored. When the label
    contains no single quote the islands are returned as ``int``s, otherwise as
    strings.
    """
    cleaned=s
    for ch in "()'":
        cleaned=cleaned.replace(ch,'')
    parts=[p.strip() for p in cleaned.split(',')]
    islands=[p for p in parts if p]
    if "'" in s:
        return islands
    return list(map(int,islands))

In [5]:
def get_0s(partition,subset,exclude=''):
    """Fraction of the samples in ``subset`` whose label is the smallest label present.

    Parameters
    ----------
    partition : mapping island -> pair whose second element is a label vector
        (pandas Series presumably holding the targets — TODO confirm).
    subset : str
        Subset label string, parsed with ``get_islands_from_str``.
    exclude : island, optional
        Island to leave out of BOTH the numerator and the denominator.

    Returns
    -------
    float
        Proportion in [0, 1].

    Raises
    ------
    ValueError
        If no islands remain after applying ``exclude``.
    """
    islands=[i for i in get_islands_from_str(subset) if i!=exclude]
    if not islands:
        raise ValueError(f'No islands left in {subset!r} after excluding {exclude!r}')
    labels=[partition[island][1] for island in islands]
    # Smallest label across all included islands: taking it from the first island
    # alone picks the wrong class when that island has no samples of class "0".
    zero=min(y.min() for y in labels)
    # Denominator over the same (post-exclusion) islands as the numerator.
    n=sum(len(y) for y in labels)
    return sum((y==zero).sum() for y in labels)/n

In [6]:
# Input locations: per-partition `.pkl` result files and the matching partition files
RESULTS_DIR,PARTITIONS_DIR='results/data_experiments','partitions/data_experiments'

In [7]:
# Scoring metrics per task type; construct_table reports the first entry of each
# list (read from the 'test_<metric>' column of the results).
metrics=dict(
    regression=['r2','neg_mean_squared_error'],
    classification=['balanced_accuracy','accuracy'],
)

In [35]:
def _subset_label(subset,local_idx,global_idx):
    """Readable name for a subset label: 'Local', 'Global', or the set of its islands.

    'Local' takes precedence when a subset is both the local and the global one.
    """
    if subset==local_idx:
        return 'Local'
    if subset==global_idx:
        return 'Global'
    return set(get_islands_from_str(subset))

def construct_table(results_info):
    """Build the per-island summary table for a collection of partitions.

    Parameters
    ----------
    results_info : dict
        Partition ID -> (results file, partition file), both read with
        ``utils.loadPickeObj``. The partition file holds a dict with keys
        'partition' and 'dataset' (a callable whose result has a ``task``
        attribute used to key ``metrics``). The results file holds a dict
        island -> scores table (presumably a DataFrame) indexed by subset label
        strings, with a ``'test_<metric>'`` column.

    Returns
    -------
    (table, perf_by_n)
        table : DataFrame with one row per (partition, island), sorted by
            'Partition' then 'Island'.
        perf_by_n : dict partition ID -> island -> list of
            (subset size / global size, subset score / best score) pairs; the
            list stays empty when every subset of that island scores the same.
    """
    rows=[]
    negative_scores=0
    perf_by_n={}

    for partition_name,(results_file,partition_file) in results_info.items():
        # NOTE(review): inputs are .pkl files; if loadPickeObj unpickles them,
        # only load files from a trusted source.
        partition_info=utils.loadPickeObj(partition_file)
        results_by_island=utils.loadPickeObj(results_file)
        partition,dataset=partition_info['partition'],partition_info['dataset']()

        perf_by_n[partition_name]={island:[] for island in results_by_island}

        for island,results_df in results_by_island.items():
            # Primary metric for the task, rounded to 2 decimals; deltas and the
            # best/worst picks below all use the rounded scores.
            scores=results_df['test_'+metrics[dataset.task][0]].round(2)
            negative_scores+=(scores<0).sum()

            local_idx=local_index(scores)
            global_idx=global_index(scores)

            # Scores
            local_score=scores[local_idx]
            global_score=scores[global_idx]
            max_score=scores.max()
            min_score=scores.min()
            best_subset=scores.idxmax()
            worst_subset=scores.idxmin()

            # N: sample counts relative to the global (all-islands) subset
            total_n=get_n(partition,global_idx)
            prop_n_local=get_n(partition,local_idx)/total_n
            prop_n_best=get_n(partition,best_subset)/total_n
            prop_n_worst=get_n(partition,worst_subset)/total_n

            # 0s: only computed for classification tasks
            prop_0s_local=prop_0s_global=prop_0s_best=prop_0s_worst=np.nan
            if dataset.task=='classification':
                prop_0s_local=get_0s(partition,local_idx)
                prop_0s_global=get_0s(partition,global_idx)
                prop_0s_best=get_0s(partition,best_subset)
                prop_0s_worst=get_0s(partition,worst_subset)

            rows.append([
                partition_name,
                dataset.task,
                island,
                _subset_label(best_subset,local_idx,global_idx),
                _subset_label(worst_subset,local_idx,global_idx),
                max_score-local_score,
                max_score-global_score,
                100*prop_n_local,
                100*prop_n_best,
                100*prop_n_worst,
                100*len(get_islands_from_str(best_subset))/len(get_islands_from_str(global_idx)),
                prop_0s_local,
                prop_0s_global,
                prop_0s_best,
                prop_0s_worst,
            ])

            # (relative size, relative score) points for every subset. Scores are
            # normalised by the best score (min-max scaling was an alternative).
            if max_score!=min_score:
                for island_set in scores.index:
                    prop_n=get_n(partition,island_set)/total_n
                    prop_perf=scores[island_set]/max_score
                    perf_by_n[partition_name][island].append((prop_n,prop_perf))

    if negative_scores>0:
        print(f'Found {negative_scores} negative scores!')

    table=pd.DataFrame(
        rows,
        columns=[
            'Partition','Task','Island','Best Model','Worst Model','Local Delta',
            'Global Delta','Size of Local (%)',"Best's dataset size (%)",
            'Size of Worst (%)',"Best's set size (%)",'Prop 0s local','Prop 0s global','Prop 0s best','Prop 0s worst',
        ]
    ).sort_values(['Partition','Island'])
    return table,perf_by_n

In [36]:
# Load the results and the corresponding partition file:
# every results file is paired with the partition file of the same name.

# ID string -> (result file, partition file)
results_info={}

# List the partitions directory once rather than once per results file
available_partitions=set(os.listdir(PARTITIONS_DIR))

# Sorted because os.listdir order is arbitrary; keeps the run reproducible
for fname in sorted(os.listdir(RESULTS_DIR)):
    # Only genuine .pkl files (a substring test would also accept e.g. 'x.pkl.bak')
    if not fname.endswith('.pkl'): continue
    id_str=fname[:-len('.pkl')]

    # Check that partition file exists
    if fname not in available_partitions:
        print(f'Partition file not found for {id_str} in {PARTITIONS_DIR}. Skipping.')
        continue

    # Save
    results_file=os.path.join(RESULTS_DIR,fname)
    partition_file=os.path.join(PARTITIONS_DIR,fname)
    results_info[id_str]=(results_file,partition_file)


In [40]:
# Create the directory the graphs and tables below are saved to
# (no error if it already exists)
os.makedirs(name='tmp/data_experiments',exist_ok=True)

## Natural Partitions

In [41]:
# Keep only the natural partitions (IDs containing 'natural')
natural_results_info={k:v for k,v in results_info.items() if 'natural' in k}

# Construct the table (as a DataFrame)
natural_table,perf_by_n_natural=construct_table(natural_results_info)

# 'Best Model' and 'Worst Model' hold 'Local', 'Global' or a Python set of islands;
# strip the repr quotes from BOTH so the two columns read the same in the CSV
for col in ['Best Model','Worst Model']:
    natural_table[col]=natural_table[col].astype('str').str.replace("'","",regex=False)

# Save it as a csv - Table 3.2.
natural_table.to_csv('tmp/data_experiments/natural_table.csv',index=False)

In [42]:
# Count, mean, std, min, quartiles and max of every numeric column of Table 3.2
natural_table.describe(percentiles=[0.25,0.5,0.75])

Unnamed: 0,Local Delta,Global Delta,Size of Local (%),Best's dataset size (%),Size of Worst (%),Best's set size (%),Prop 0s local,Prop 0s global,Prop 0s best,Prop 0s worst
count,33.0,33.0,33.0,33.0,33.0,33.0,15.0,15.0,15.0,15.0
mean,0.092727,0.020606,15.151515,48.826846,29.812599,44.023569,0.786124,0.782073,0.781852,0.807119
std,0.081481,0.018865,17.45668,22.399718,25.782817,18.229943,0.062371,0.011169,0.022901,0.044437
min,0.0,0.0,1.897698,6.335292,4.437594,16.666667,0.697124,0.760492,0.754165,0.741385
25%,0.02,0.01,6.748351,32.31352,15.346977,33.333333,0.736733,0.787468,0.771309,0.779736
50%,0.08,0.01,9.217667,52.102989,17.808091,44.444444,0.780289,0.787468,0.779654,0.811749
75%,0.13,0.03,14.366007,63.420514,32.360234,58.333333,0.821536,0.787468,0.780736,0.824808
max,0.26,0.07,93.664708,98.102302,100.0,75.0,0.921129,0.787468,0.854029,0.921129


## Synthetic Partitions

In [48]:
# Keep only the synthetic partitions (IDs not containing 'natural')
synth_results_info={k:v for k,v in results_info.items() if 'natural' not in k}

# Construct the table (as a DataFrame)
synth_table,perf_by_n_synth=construct_table(synth_results_info)

# Save it as a csv, inside the same output directory as the other tables
synth_table.to_csv('tmp/data_experiments/synth_table.csv',index=False)

# Construct extra columns by splitting the partition ID on '_' into 5 fields;
# the third field looks like '<name>=<value>' and only the value is kept.
# NOTE(review): assumes no underscores inside individual fields — confirm naming scheme
synth_table[['Data','P. Method','Param','Clients','Run ID']]=synth_table['Partition'].str.split(pat='_',expand=True)
synth_table['Param']=synth_table['Param'].str.split(pat='=',expand=True)[1]

# Summary statistics for every numeric column - Table 3.3 (computed once, saved and shown)
synth_summary=synth_table.describe()
synth_summary.to_csv('tmp/data_experiments/synth_summary.csv')
synth_summary

Unnamed: 0,Island,Local Delta,Global Delta,Size of Local (%),Best's dataset size (%),Size of Worst (%),Best's set size (%),Prop 0s local,Prop 0s global,Prop 0s best,Prop 0s worst
count,1750.0,1750.0,1750.0,1750.0,1750.0,1750.0,1750.0,1125.0,1125.0,1125.0,1125.0
mean,2.0,0.058417,0.013531,20.0,71.154791,25.558977,60.617143,0.565811,0.565792,0.564796,0.566099
std,1.414618,0.051615,0.037775,16.058105,23.788675,20.676113,24.367745,0.251265,0.183382,0.205023,0.225389
min,0.0,0.0,0.0,2.064009,2.256742,2.111405,20.0,0.000236,0.144515,0.00221,0.000754
25%,1.0,0.02,0.0,12.0684,59.933912,12.599938,40.0,0.356495,0.427836,0.426906,0.414251
50%,2.0,0.04,0.0,19.986804,79.553622,20.0,60.0,0.598654,0.536224,0.525213,0.577806
75%,3.0,0.09,0.01,20.0,89.490179,33.229056,80.0,0.780036,0.759206,0.777334,0.77709
max,4.0,0.34,0.28,85.500031,100.0,100.0,100.0,0.999362,0.852423,0.999362,0.997055


## Group by partition parameters

In [49]:
# Mean and std of the key metrics per (partitioning method, parameter) pair
by_param=synth_table.groupby(['P. Method','Param'])[
    ['Local Delta','Global Delta','Best\'s set size (%)','Best\'s dataset size (%)']
]
by_param_mean=by_param.mean()
by_param_std=by_param.std()

# Used to make Table 3.4
by_param_mean.to_csv('tmp/data_experiments/synth_by_param_mean.csv')
by_param_std.to_csv('tmp/data_experiments/synth_by_param_std.csv')
by_param_mean

Unnamed: 0_level_0,Unnamed: 1_level_0,Local Delta,Global Delta,Best's set size (%),Best's dataset size (%)
P. Method,Param,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
dirY,0.5,0.07976,0.06544,52.32,52.32
dirY,0.75,0.0908,0.05448,59.84,59.84
dirY,1.0,0.09776,0.04824,54.88,54.88
dirY,10.0,0.10888,0.00472,79.2,79.2
powN,0.1,0.0656,0.00316,46.48,84.210697
powN,0.25,0.04836,0.00144,57.44,80.066621
powN,0.5,0.03844,0.001,64.32,75.347766
powN,0.75,0.0346,0.0016,67.52,69.877587
powN,1.0,0.03332,0.00108,65.44,65.460863
