In [23]:
import fedci

In [24]:
import polars as pl
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

In [25]:
import numpy as np
from tqdm.notebook import tqdm
from itertools import chain, combinations
import os

In [26]:
from scipy import stats
from pgmpy.estimators import CITests

In [27]:
class EmptyLikelihoodRatioTest(fedci.LikelihoodRatioTest):
    def __init__(self, y_label, x_label, s_labels, p_val):
        self.y_label = y_label
        self.x_label = x_label
        self.s_labels = s_labels
        self.p_val = p_val

In [28]:
TOTAL_SAMPLES = 1_000

TOTAL_FEATURES = 4
FEATURES_PER_CLIENT = 4

possible_dags = [
    "pdsep_g",
    "collider",
    "fork",
    "chain4",
    "descColl",
    "2descColl",
    "iv"
]

# TODO: possible_dags to dict or at least store num of vars for each one
chosen_dag = possible_dags[3]


server_id_pattern = 'dag_{}_{}c'

client_configurations = [1,3, 5]

max_regressors = None


alpha_comparisons = [0.01, 0.05, 0.1]
equality_tolerance = 1e-4


log_filepattern = './log-{}.csv'


In [29]:
def get_sample_data(dag_type, num_samples, num_vars):
    with (ro.default_converter + pandas2ri.converter).context():
        ro.r['source']('./app/scripts/example_data.r')
        get_example_data_f = ro.globalenv['get_example_data']

        result = get_example_data_f(dag_type, 1, num_samples, num_vars)
        
    return list(result.items())[0][1]

In [30]:
def get_servers(client_configurations, data):
    servers = {}    

    for splits in client_configurations:
        clients = {i:fedci.Client(chunk) for i,chunk in enumerate(np.array_split(data, splits))}
        servers[server_id_pattern.format(chosen_dag, splits)] = fedci.Server(clients, max_regressors=max_regressors)
    return servers

In [31]:
def get_possible_tests(available_data):

    possible_tests = []
    max_conditioning_set_size = min(len(available_data), max_regressors) if max_regressors is not None else len(available_data)

    for y_var in available_data:
        set_of_regressors = available_data - {y_var}
        for x_var in set_of_regressors:
            set_of_conditioning_variables = set_of_regressors - {x_var}
            conditioning_sets = chain.from_iterable(combinations(set_of_conditioning_variables, r) for r in range(0,max_conditioning_set_size))
            possible_tests.extend([(y_var, x_var, sorted(list(s_labels))) for s_labels in conditioning_sets])
            
    return possible_tests


In [32]:
def get_ground_truth_tests(data, possible_tests):
    ground_truth_tests = []

    for test in possible_tests:
        if len(test[2]) > 0:
            #v0 = data[test[0]].values
            #v1 = data[test[1]].values
            #s = data[list(test[2])].values
            #p0 = test[3]
            #p1 = citest(v0, v1, s, test_args={'statistic': 'ksg_cmi', 'n_jobs': 8})
            
            _, p1 = CITests.pearsonr(test[1], test[0], list(test[2]), data, boolean=False)
        else:
            v0 = data[test[0]]
            v1 = data[test[1]]
            _, p1 = stats.pearsonr(v0, v1)
            
        p1 = round(p1,4)
        
        ground_truth_tests.append(EmptyLikelihoodRatioTest(test[0], test[1], list(test[2]), p1))
    return ground_truth_tests

In [33]:
def get_server_test_results(servers):
    testing_rounds = {k:v.testing_engine.finished_rounds for k,v in servers.items()}
    likelihood_tests = {k:fedci.get_likelihood_tests(v) for k,v in testing_rounds.items()}
    return likelihood_tests

In [34]:
def prepare_server_evaluation(ground_truth_tests, server_ci_tests):
    p_value_comparison = {k:[] for k in server_ci_tests.keys()}
    missing_test = {k:0 for k in server_ci_tests.keys()}
    
    for test in ground_truth_tests:
        for k in server_ci_tests.keys():
            matching_test = [t for t in server_ci_tests[k] if t.y_label == test.y_label and t.x_label == test.x_label and sorted(t.s_labels) == sorted(test.s_labels)]
            if len(matching_test) == 0:
                print(f'No matching test in {k} for {test}')
                missing_test[k] += 1
            assert len(matching_test) == 1
            matching_test = matching_test[0]          
            p_value_comparison[k].append((matching_test.p_val, test.p_val))
        
    missing_test = {k:v/len(server_ci_tests[k]) for k,v in missing_test.items()}
    return p_value_comparison, missing_test

In [35]:
def count_correct_alpha_thresholdings(data, alpha):
    c = sum([1 for a,b in data if (a < alpha and b < alpha) or (a > alpha and b > alpha)]) / len(data)
    return c

def count_correct_pval(data, tolerance=1e-4):
    c = sum([1 for a,b in data if abs(a-b)<tolerance]) / len(data)
    return c

def evaluate_results(p_value_comparison, alphas, tolerance):
    result_alpha = {}
    result_equality = {}
    for k,v in p_value_comparison.items():
        result_alpha[k] = {}
        result_equality[k] = count_correct_pval(v, tolerance)
        for alpha in alphas:
            result_alpha[k][alpha] = count_correct_alpha_thresholdings(v,alpha)
    return result_alpha, result_equality

In [36]:
def get_records(servers, alpha_tests, equality_tests, missed_tests, total_features, features_per_client):
    results = []
    for server_id in servers.keys():
        server = servers[server_id]
        alpha_test = alpha_tests[server_id]
        
        r = {
            'chosen_dag': chosen_dag,
            'num_clients': len(server.clients),
            'num_samples': TOTAL_SAMPLES,
            'same_p_val': equality_tests[server_id],
            'missed_tests': missed_tests[server_id],
            'total_features': total_features,
            'features_per_client': features_per_client
        }
        for alpha, alpha_result in alpha_test.items():
            r[f'correctness_alpha_{alpha}'] = alpha_result
        results.append(r)

    return results

In [37]:
def csv_add_row(data, file):
    with open(file, 'a') as f:
        row = ','.join([str(d) for d in data]) + '\n'
        f.write(row)
            

def write_records(i, file, data):
    if len(data) == 0:
        return
    curr_file = file.format(i)
    if not os.path.exists(curr_file):
        csv_add_row(list(data[0].keys()), curr_file)
    for entry in data:
        csv_add_row(entry.values(), curr_file)

In [38]:
def process(i):
    data = get_sample_data(chosen_dag, TOTAL_SAMPLES, TOTAL_FEATURES)
    servers = get_servers(client_configurations, data)

    for server in servers.values(): server.run_tests()

    possible_tests = get_possible_tests(set(data.columns))
    ground_truth_tests = get_ground_truth_tests(data, possible_tests)
    server_ci_tests = get_server_test_results(servers) 

    p_val_comparisons, missed_tests = prepare_server_evaluation(ground_truth_tests, server_ci_tests)

    alpha_tests, equality_tests = evaluate_results(p_val_comparisons, alpha_comparisons, equality_tolerance)

    records = get_records(servers, alpha_tests, equality_tests, missed_tests, TOTAL_FEATURES, FEATURES_PER_CLIENT)
    
    write_records(i, log_filepattern, records)

In [39]:
for i in range(33):
    process(1)

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.
'DataFrame.swapaxes' is deprecated and will be removed i