In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import fedci

No GPU automatically detected. Setting SETTINGS.GPU to 0, and SETTINGS.NJOBS to cpu_count.
INFO:rpy2.situation:cffi mode is CFFI_MODE.ANY
INFO:rpy2.situation:R home found: /opt/homebrew/Caskroom/miniforge/base/envs/promotion/lib/R
INFO:rpy2.situation:R library path: 
INFO:rpy2.situation:LD_LIBRARY_PATH: 
INFO:rpy2.rinterface_lib.embedded:Default options to initialize R: rpy2, --quiet, --no-save
INFO:rpy2.rinterface_lib.embedded:R is already initialized. No need to initialize.


In [3]:
import polars as pl
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

In [4]:
import numpy as np
from tqdm.notebook import tqdm
from itertools import chain, combinations
import os

In [5]:
from scipy import stats
from pgmpy.estimators import CITests

In [187]:
class EmptyLikelihoodRatioTest(fedci.LikelihoodRatioTest):
    def __init__(self, y_label, x_label, s_labels, p_val):
        self.y_label = y_label
        self.x_label = x_label
        self.s_labels = s_labels
        self.p_val = p_val
        
class CategoricalLikelihoodRatioTest(fedci.LikelihoodRatioTest):
    def __init__(self, y_label, t0s, t1s, num_cats):
        assert len(t0s) > 0
        assert len(t1s) > 0
        assert len(t0s[0].X_labels) + 1 == len(t1s[0].X_labels)
        # TODO: assert more data integrity
        #assert t0s[0].y_label == t1s[0].y_label
        
        self.y_label = y_label
        self.x_label = (set(t1s[0].X_labels) - set(t0s[0].X_labels)).pop()
        self.s_labels = t0s[0].X_labels
        self.p_val = self._run_likelihood_test(t0s, t1s, num_cats)
        self.p_val = round(self.p_val, 4)
        
    def _run_likelihood_test(self, t0s, t1s, num_cats):
        
        # t1 should always encompass more regressors -> less client can fulfill this
        #assert len(self.t1.providing_clients) < len(self.t0.providing_clients)
        
        providing_clients = t1s[0].providing_clients
        
        t0_llf = sum([t.get_fit_stats(providing_clients)['llf'] for t in t0s])
        t1_llf = sum([t.get_fit_stats(providing_clients)['llf'] for t in t1s])
        
        # d_y = num cats
        # DOF Z = size cond set
        # DOF X = 1
        t0_dof = (num_cats-1)*(len(self.s_labels)+1) # (d_y - 1)*(DOF(Z)+1)
        t1_dof = (num_cats-1)*(len(self.s_labels)+2) # (d_y - 1)*(DOF(Z)+DOF(X)+1)
        t = -2*(t0_llf - t1_llf)
        
        p_val = stats.chi2.sf(t, t1_dof-t0_dof)
        
        return p_val
    
class OrdinalLikelihoodRatioTest(fedci.LikelihoodRatioTest):
    def __init__(self, y_label, t0s, t1s, num_cats):
        assert len(t0s) > 0
        assert len(t1s) > 0
        assert len(t0s[0].X_labels) + 1 == len(t1s[0].X_labels)
        # TODO: assert more data integrity
        #assert t0s[0].y_label == t1s[0].y_label
        
        t0s = sorted(t0s, key=lambda x: int(x.y_label.split('__ord__')[-1]))
        t1s = sorted(t1s, key=lambda x: int(x.y_label.split('__ord__')[-1]))
        
        self.y_label = y_label
        self.x_label = (set(t1s[0].X_labels) - set(t0s[0].X_labels)).pop()
        self.s_labels = t0s[0].X_labels
        self.p_val = self._run_likelihood_test(t0s, t1s, num_cats)
        self.p_val = round(self.p_val, 4)
        
    def _run_likelihood_test(self, t0s, t1s, num_cats):
        
        # t1 should always encompass more regressors -> less client can fulfill this
        #assert len(self.t1.providing_clients) < len(self.t0.providing_clients)
        
        providing_clients = t1s[0].providing_clients
        
        t0_llf = sum([t.get_fit_stats(providing_clients)['llf'] for t in t0s])
        t1_llf = sum([t.get_fit_stats(providing_clients)['llf'] for t in t1s])
        
        # d_y = num cats
        # DOF Z = size cond set
        # DOF X = 1
        t0_dof = (num_cats-1)*(len(self.s_labels)+1) # (d_y - 1)*(DOF(Z)+1)
        t1_dof = (num_cats-1)*(len(self.s_labels)+2) # (d_y - 1)*(DOF(Z)+DOF(X)+1)
        t = -2*(t0_llf - t1_llf)
        
        p_val = stats.chi2.sf(t, t1_dof-t0_dof)
        
        return p_val

In [188]:
TOTAL_SAMPLES = 1000

#TOTAL_FEATURES = 4
#FEATURES_PER_CLIENT = 4

possible_dags = [
    "pdsep_g",
    "collider",
    "fork",
    "chain4",
    "descColl",
    "2descColl",
    "iv"
]

# TODO: possible_dags to dict or at least store num of vars for each one
chosen_dag = possible_dags[3]


server_id_pattern = 'dag_{}_{}c'

client_configurations = [1,3,5,10]

max_regressors = None


alpha_comparisons = [0.01, 0.05, 0.1]
equality_tolerance = 1e-4


log_filepattern = './log-{}.csv'


In [189]:
real_independence_tests_collider = [
    EmptyLikelihoodRatioTest('A', 'B', [], 1),
    EmptyLikelihoodRatioTest('A', 'C', [], 0),
    EmptyLikelihoodRatioTest('B', 'C', [], 0),
    EmptyLikelihoodRatioTest('A', 'B', ['C'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B'], 0),
    EmptyLikelihoodRatioTest('B', 'C', ['A'], 0),
]

real_independence_tests_fork = [
    EmptyLikelihoodRatioTest('A', 'B', [], 0),
    EmptyLikelihoodRatioTest('A', 'C', [], 0),
    EmptyLikelihoodRatioTest('B', 'C', [], 0),
    EmptyLikelihoodRatioTest('A', 'B', ['C'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B'], 0),
    EmptyLikelihoodRatioTest('B', 'C', ['A'], 1),
]

real_independence_tests_diamond = [
    # cond set 0
    EmptyLikelihoodRatioTest('A', 'B', [], 0),
    EmptyLikelihoodRatioTest('A', 'C', [], 0),
    EmptyLikelihoodRatioTest('A', 'D', [], 0),
    EmptyLikelihoodRatioTest('B', 'C', [], 0),
    EmptyLikelihoodRatioTest('B', 'D', [], 0),
    EmptyLikelihoodRatioTest('C', 'D', [], 0),
    # cond set 1
    # start a
    EmptyLikelihoodRatioTest('A', 'B', ['C'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B'], 0),
    EmptyLikelihoodRatioTest('A', 'D', ['B'], 0),
    EmptyLikelihoodRatioTest('A', 'B', ['D'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['D'], 0),
    EmptyLikelihoodRatioTest('A', 'D', ['C'], 0),
    # start b
    EmptyLikelihoodRatioTest('B', 'C', ['A'], 1),
    EmptyLikelihoodRatioTest('B', 'D', ['A'], 0),
    EmptyLikelihoodRatioTest('B', 'C', ['D'], 0),
    EmptyLikelihoodRatioTest('B', 'D', ['C'], 0),
    # start c
    EmptyLikelihoodRatioTest('C', 'D', ['A'], 0),
    EmptyLikelihoodRatioTest('C', 'D', ['B'], 0),
    # cond set 2
    EmptyLikelihoodRatioTest('A', 'B', ['C', 'D'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B', 'D'], 0),
    EmptyLikelihoodRatioTest('A', 'D', ['B', 'C'], 1),
    EmptyLikelihoodRatioTest('B', 'C', ['A', 'D'], 0),
    EmptyLikelihoodRatioTest('B', 'D', ['A', 'C'], 0),
    EmptyLikelihoodRatioTest('C', 'D', ['A', 'B'], 0),
]

real_independence_tests_chain = [
    # cond set 0
    EmptyLikelihoodRatioTest('A', 'B', [], 0),
    EmptyLikelihoodRatioTest('A', 'C', [], 0),
    EmptyLikelihoodRatioTest('A', 'D', [], 0),
    EmptyLikelihoodRatioTest('B', 'C', [], 0),
    EmptyLikelihoodRatioTest('B', 'D', [], 0),
    EmptyLikelihoodRatioTest('C', 'D', [], 0),
    # cond set 1
    # start a
    EmptyLikelihoodRatioTest('A', 'B', ['C'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B'], 1),
    EmptyLikelihoodRatioTest('A', 'D', ['B'], 1),
    EmptyLikelihoodRatioTest('A', 'B', ['D'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['D'], 0),
    EmptyLikelihoodRatioTest('A', 'D', ['C'], 1),
    # start b
    EmptyLikelihoodRatioTest('B', 'C', ['A'], 0),
    EmptyLikelihoodRatioTest('B', 'D', ['A'], 0),
    EmptyLikelihoodRatioTest('B', 'C', ['D'], 0),
    EmptyLikelihoodRatioTest('B', 'D', ['C'], 1),
    # start c
    EmptyLikelihoodRatioTest('C', 'D', ['A'], 0),
    EmptyLikelihoodRatioTest('C', 'D', ['B'], 0),
    # cond set 2
    EmptyLikelihoodRatioTest('A', 'B', ['C', 'D'], 0),
    EmptyLikelihoodRatioTest('A', 'C', ['B', 'D'], 0),
    EmptyLikelihoodRatioTest('A', 'D', ['B', 'C'], 1),
    EmptyLikelihoodRatioTest('B', 'C', ['A', 'D'], 0),
    EmptyLikelihoodRatioTest('B', 'D', ['A', 'C'], 0),
    EmptyLikelihoodRatioTest('C', 'D', ['A', 'B'], 0),
]

In [190]:
import dgp

# fork
node1 = dgp.GenericNode('A')
node2 = dgp.GenericNode('B', parents=[node1])
node3 = dgp.GenericNode('C', parents=[node1])
nc1 = dgp.NodeCollection([node1, node2, node3])

# collider
node1 = dgp.GenericNode('A')
node2 = dgp.GenericNode('B')
node3 = dgp.GenericNode('C', parents=[node1, node2])
nc2 = dgp.NodeCollection([node1, node2, node3])

# diamond
node1 = dgp.GenericNode('A')
node2 = dgp.GenericNode('B', parents=[node1])
node3 = dgp.GenericNode('C', parents=[node1])
node4 = dgp.GenericNode('D', parents=[node2, node3])
nc3 = dgp.NodeCollection([node1, node2, node3, node4])

# chain
node1 = dgp.GenericNode('A')
node2 = dgp.GenericNode('B', parents=[node1])
node3 = dgp.GenericNode('C', parents=[node2])
node4 = dgp.GenericNode('D', parents=[node3])
nc4 = dgp.NodeCollection([node1, node2, node3, node4])


node1 = dgp.Node('A')
node2 = dgp.Node('B', parents=[node1])
node3 = dgp.Node('C', parents=[node1])
nc51 = dgp.NodeCollection([node1, node2, node3])

node1 = dgp.CategoricalNode('A')
node2 = dgp.CategoricalNode('B', parents=[node1])
node3 = dgp.CategoricalNode('C', parents=[node1])
nc52 = dgp.NodeCollection([node1, node2, node3])

node1 = dgp.OrdinalNode('A')
node2 = dgp.OrdinalNode('B', parents=[node1])
node3 = dgp.OrdinalNode('C', parents=[node1])
nc53 = dgp.NodeCollection([node1, node2, node3])


ncs = {
    1: nc1,
    2: nc2,
    3: nc3,
    4: nc4,
    51: nc51,
    52: nc52,
    53: nc53
    }

ncs_independences = {
    1: real_independence_tests_fork,
    2: real_independence_tests_collider,
    3: real_independence_tests_diamond,
    4: real_independence_tests_chain,
    51: real_independence_tests_fork,
    52: real_independence_tests_fork,
    53: real_independence_tests_fork
}

def get_sample_data(node_collection, num_samples):
    node_collection.reset()
    return node_collection.get(num_samples)

In [10]:
def get_servers(client_configurations, data):
    servers = {}    

    for splits in client_configurations:
        clients = {i:fedci.Client(pl.from_pandas(chunk)) for i,chunk in enumerate(np.array_split(data.to_pandas(), splits))}
        servers[server_id_pattern.format(chosen_dag, splits)] = fedci.Server(clients, max_regressors=max_regressors)
    return servers

In [11]:
def get_possible_tests(available_data):

    possible_tests = []
    max_conditioning_set_size = min(len(available_data), max_regressors) if max_regressors is not None else len(available_data)

    for y_var in available_data:
        set_of_regressors = available_data - {y_var}
        for x_var in set_of_regressors:
            set_of_conditioning_variables = set_of_regressors - {x_var}
            conditioning_sets = chain.from_iterable(combinations(set_of_conditioning_variables, r) for r in range(0,max_conditioning_set_size))
            possible_tests.extend([(y_var, x_var, sorted(list(s_labels))) for s_labels in conditioning_sets])
            
    return possible_tests


In [12]:
import polars.selectors as cs
import pandas as pd

In [27]:
from pycit import citest, itest

In [28]:
def test_mixed_independence(continuous, categorical):
    # ANOVA
    categories = np.unique(categorical)
    groups = [continuous[categorical == category] for category in categories]
    _, p_value = stats.f_oneway(*groups)
    #print(f"ANOVA F-statistic: {f_statistic}, p-value: {p_value}")

    # If categorical is binary, you can also use point-biserial correlation
    #if len(categories) == 2:
    #    point_biserial_corr, p_value = stats.pointbiserialr(categorical, continuous)
    #    print(f"Point-biserial correlation: {point_biserial_corr}, p-value: {p_value}")
    return p_value

In [73]:
def get_ground_truth_tests(data, possible_tests):  
    ground_truth_tests = []

    for test in possible_tests:
        if len(test[2]) > 0:
            X = data[test[0]].to_numpy()
            Y = data[test[1]].to_numpy()
            Z = data[test[2]].to_numpy()
            pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 8})
            # if data.schema[test[0]] == pl.String and data.schema[test[1]] == pl.String:
            #     #print('A')
            #     X = data[test[0]].to_numpy()
            #     Y = data[test[1]].to_numpy()
            #     Z = data[test[2]].to_numpy()
            #     pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 8})
            # elif data.schema[test[0]] == pl.String and data.schema[test[1]] == pl.Float64:
            #     #print('B')
            #     X = data[test[0]].to_numpy()
            #     Y = data[test[1]].to_numpy()
            #     Z = data[test[2]].to_numpy()
            #     pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 8})
            # elif data.schema[test[0]] == pl.Float64 and data.schema[test[1]] == pl.String:
            #     #print('C')
            #     X = data[test[0]].to_numpy()
            #     Y = data[test[1]].to_numpy()
            #     Z = data[test[2]].to_numpy()
            #     pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 8})
            # elif data.schema[test[0]] == pl.Float64 and data.schema[test[1]] == pl.Float64:
            #     #print('D')
            #     _, pvalue = CITests.pearsonr(test[1], test[0], list(test[2]), data.cast(pl.Float64).to_pandas(), boolean=False)
            # else:
            #     X = data[test[0]].to_numpy()
            #     Y = data[test[1]].to_numpy()
            #     Z = data[test[2]].to_numpy()
            #     pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 8})
            #     #assert False, 'no fitting test'
        else:
            X = data[test[0]].to_numpy().astype(float)
            Y = data[test[1]].to_numpy().astype(float)
            pvalue = itest(X, Y, test_args={'statistic': 'mixed_mi', 'n_jobs': 8})
            
            # if data.schema[test[0]] == pl.String and data.schema[test[1]] == pl.String:
            #     crosstab = pd.crosstab(data.to_pandas()[test[0]], data.to_pandas()[test[1]])
            #     _, pvalue, _, _ = stats.chi2_contingency(crosstab)
            # elif data.schema[test[0]] == pl.String and data.schema[test[1]] == pl.Float64:
            #     #print('E')
            #     X = data[test[0]].to_numpy()
            #     Y = data[test[1]].to_numpy().astype(float)
            #     pvalue = test_mixed_independence(Y, X)
            # elif data.schema[test[0]] == pl.Float64 and data.schema[test[1]] == pl.String:
            #     #print('F')
            #     X = data[test[0]].to_numpy().astype(float)
            #     Y = data[test[1]].to_numpy()
            #     pvalue = test_mixed_independence(X, Y)
            # elif data.schema[test[0]] == pl.Float64 and data.schema[test[1]] == pl.Float64:
            #     #print('G')
            #     v0 = data[test[0]]
            #     v1 = data[test[1]]
            #     _, pvalue = stats.pearsonr(v0, v1)
            # #elif data.schema[test[0]] == pl.Int32 and data.schema[test[1]] == pl.Float64:
            # else:
            #     X = data[test[0]].to_numpy().astype(float)
            #     Y = data[test[1]].to_numpy().astype(float)
            #     pvalue = itest(X, Y, test_args={'statistic': 'mixed_mi', 'n_jobs': 8})
            #     #assert False, 'no fitting test w/o conditiong set'
        pvalue = round(pvalue,4)

        #print(test, pvalue)
                
        ground_truth_tests.append(EmptyLikelihoodRatioTest(test[0], test[1], list(test[2]), pvalue))
    return ground_truth_tests
# TODO: with and without conditioning set

In [82]:
def get_ground_truth_tests_old(data, possible_tests):    
    ground_truth_tests = []

    for test in possible_tests:
        print(test)

        if len(test[2]) > 0:
            X = data[test[0]].to_numpy()
            Y = data[test[1]].to_numpy()
            Z = data[test[2]].to_numpy()
            pvalue = citest(X, Y, Z, test_args={'statistic': 'mixed_cmi', 'n_jobs': 2})
        else:
            X = data[test[0]].to_numpy()
            Y = data[test[1]].to_numpy().astype(float)
            pvalue = test_mixed_independence(X, Y)

        pvalue = round(pvalue,4)
        
        ground_truth_tests.append(EmptyLikelihoodRatioTest(test[0], test[1], list(test[2]), pvalue))
    return ground_truth_tests

In [58]:
def get_ground_truth_tests_old(data, possible_tests):
    ground_truth_tests = []

    for test in possible_tests:
        if len(test[2]) > 0:
            #v0 = data[test[0]].values
            #v1 = data[test[1]].values
            #s = data[list(test[2])].values
            #p0 = test[3]
            #p1 = citest(v0, v1, s, test_args={'statistic': 'ksg_cmi', 'n_jobs': 8})
            
            _, p1 = CITests.pearsonr(test[1], test[0], list(test[2]), data.cast(pl.Float64).to_pandas(), boolean=False)
        else:
            
            #dummied_data = data.to_dummies(cs.string(), separator='__cat__', drop_first=True).cast(pl.Float64).to_pandas()
            #v0 = data[test[0]].cast(pl.Float64).to_pandas()
            #v1 = data[test[1]].cast(pl.Float64).to_pandas()
            
            d0 = data[test[0]]
            d1 = data[test[1]]

            
            #v0 = d0.to_dummies(cs.string(), separator='__cat__', drop_first=True).cast(pl.Float64).to_pandas()
            #v1 = d1.to_dummies(cs.string(), separator='__cat__', drop_first=True).cast(pl.Float64).to_pandas()
            
            v0 = d0.to_dummies(separator='__cat__', drop_first=True).cast(pl.Float64).to_pandas()
            v1 = d1.to_dummies(separator='__cat__', drop_first=True).cast(pl.Float64).to_pandas()
            
            
            _, p1 = stats.pearsonr(v0, v1)
            
        p1 = round(p1,4)
        
        ground_truth_tests.append(EmptyLikelihoodRatioTest(test[0], test[1], list(test[2]), p1))
    return ground_truth_tests

In [109]:
def join_categories_in_regression_sets(tests, reversed_category_expressions):
    #updated_tests = []
    for test in tests:
        test.X_labels = sorted(list(set([reversed_category_expressions[l] if l in reversed_category_expressions else l for l in test.X_labels])))
    return tests

def group_categorical_likelihood_tests(tests, category_expressions, reversed_category_expressions):
    #category_expressions = servers['dag_chain4_1c'].category_expressions
    #reversed_category_expressions = servers['dag_chain4_1c'].reversed_category_expressions
    #tests = server_ci_tests['dag_chain4_1c']

    updated_tests = []
    for test in tests:
        if test.y_label not in reversed_category_expressions:
            updated_tests.append(test)
            continue
        
        category_label = reversed_category_expressions[test.y_label]
        
        # Only run if the current test is the first category. This avoids duplicate tests
        if category_expressions[category_label][0] != test.y_label:
            continue
        
        categorical_test_group = []
        for test_lookup in tests:
            if test_lookup.y_label in category_expressions[category_label] and test_lookup.x_label == test.x_label and sorted(test_lookup.s_labels) == sorted(test.s_labels):
                categorical_test_group.append(test_lookup)
                
        lrt = CategoricalLikelihoodRatioTest(category_label, [t.t0 for t in categorical_test_group], [t.t1 for t in categorical_test_group], len(category_expressions[category_label]))
        updated_tests.append(lrt)
        
    return updated_tests


def group_ordinal_likelihood_tests(tests, ordinal_expressions, reversed_ordinal_expressions):
    #category_expressions = servers['dag_chain4_1c'].category_expressions
    #reversed_category_expressions = servers['dag_chain4_1c'].reversed_category_expressions
    #tests = server_ci_tests['dag_chain4_1c']

    updated_tests = []
    for test in tests:
        if test.y_label not in reversed_ordinal_expressions:
            updated_tests.append(test)
            continue
        
        category_label = reversed_ordinal_expressions[test.y_label]
        #print(category_label)
        
        # Only run if the current test is the first category. This avoids duplicate tests
        if ordinal_expressions[category_label][0] != test.y_label:
            continue
        
        categorical_test_group = []
        for test_lookup in tests:
            if test_lookup.y_label in ordinal_expressions[category_label] and test_lookup.x_label == test.x_label and sorted(test_lookup.s_labels) == sorted(test.s_labels):
                categorical_test_group.append(test_lookup)
                
        lrt = OrdinalLikelihoodRatioTest(category_label, [t.t0 for t in categorical_test_group], [t.t1 for t in categorical_test_group], len(ordinal_expressions[category_label]))
        updated_tests.append(lrt)
        
    return updated_tests


def get_server_test_results(servers):
    testing_rounds = {k:v.testing_engine.finished_rounds for k,v in servers.items()}
    testing_rounds = {k:join_categories_in_regression_sets(v, servers[k].reversed_category_expressions) for k,v in testing_rounds.items()}
    likelihood_tests = {k:fedci.get_likelihood_tests(v) for k,v in testing_rounds.items()}
    # fix up categorical tests
    likelihood_tests = {k:group_categorical_likelihood_tests(v, servers[k].category_expressions, servers[k].reversed_category_expressions) for k,v in likelihood_tests.items()}
    
    likelihood_tests = {k:group_ordinal_likelihood_tests(v, servers[k].ordinal_expressions, servers[k].reversed_ordinal_expressions) for k,v in likelihood_tests.items()}
    
    return likelihood_tests

In [110]:
def prepare_server_evaluation(ground_truth_tests, server_ci_tests):
    p_value_comparison = {k:[] for k in server_ci_tests.keys()}
    missing_test = {k:0 for k in server_ci_tests.keys()}
    
    for test in ground_truth_tests:
        for k in server_ci_tests.keys():
            matching_test = [t for t in server_ci_tests[k] if t.y_label == test.y_label and t.x_label == test.x_label and sorted(t.s_labels) == sorted(test.s_labels)]
            if len(matching_test) == 0:
                print(f'No matching test in {k} for {test}')
                missing_test[k] += 1
                continue
            assert len(matching_test) == 1
            matching_test = matching_test[0]          
            p_value_comparison[k].append((matching_test.p_val, test.p_val))
        
    missing_test = {k:v/len(server_ci_tests[k]) if len(server_ci_tests[k]) > 0 else 0 for k,v in missing_test.items()}
    return p_value_comparison, missing_test

In [111]:
def count_correct_alpha_thresholdings(data, alpha):
    c = sum([1 for a,b in data if (a < alpha and b < alpha) or (a > alpha and b > alpha)]) / len(data)
    return c

def count_correct_pval(data, tolerance=1e-4):
    c = sum([1 for a,b in data if abs(a-b)<tolerance]) / len(data)
    return c

def evaluate_results(p_value_comparison, alphas, tolerance):
    result_alpha = {}
    result_equality = {}
    for k,v in p_value_comparison.items():
        result_alpha[k] = {}
        result_equality[k] = count_correct_pval(v, tolerance)
        for alpha in alphas:
            result_alpha[k][alpha] = count_correct_alpha_thresholdings(v,alpha)
            
    return result_alpha, result_equality

In [112]:
def get_records(servers, alpha_tests, equality_tests, missed_tests, total_features, features_per_client, comparison_category):
    results = []
    for server_id in servers.keys():
        server = servers[server_id]
        alpha_test = alpha_tests[server_id]
        
        r = {
            'chosen_dag': chosen_dag,
            'num_clients': len(server.clients),
            'num_samples': TOTAL_SAMPLES,
            'comparison_category': comparison_category,
            'same_p_val': equality_tests[server_id],
            'missed_tests': missed_tests[server_id],
            'total_features': total_features,
            'features_per_client': features_per_client
        }
        for alpha, alpha_result in alpha_test.items():
            r[f'correctness_alpha_{alpha}'] = alpha_result
        results.append(r)

    return results

In [113]:
def csv_add_row(data, file):
    with open(file, 'a') as f:
        row = ','.join([str(d) for d in data]) + '\n'
        f.write(row)
            

def write_records(i, file, data):
    if len(data) == 0:
        return
    curr_file = file.format(i)
    if not os.path.exists(curr_file):
        csv_add_row(list(data[0].keys()), curr_file)
    for entry in data:
        csv_add_row(entry.values(), curr_file)

In [114]:
import polars.selectors as cs

In [161]:
def process(i):
    #print('Step 1/6 --> Setup')
    #data = pl.read_parquet(f'./fedci/testdata-{i}.parquet')
    #TOTAL_SAMPLES = len(data)
    data = get_sample_data(ncs[i], TOTAL_SAMPLES)
    servers = get_servers(client_configurations, data)

    #print('Step 2/6 --> Run Tests')
    for server in servers.values(): server.run_tests()

    #print('Step 3/6 --> Collect Results')
    possible_tests = get_possible_tests(set(data.columns))
    server_ci_tests = get_server_test_results(servers) 
    
    comparison_tests_collection = []
    ground_truth_tests = get_ground_truth_tests(data, possible_tests)
    comparison_tests_collection.append(('ground_truth', ground_truth_tests))
    
    if i in ncs_independences:
        real_independences = ncs_independences[i]
        comparison_tests_collection.append(('real', real_independences))
    #ground_truth_tests = real_indep3 # todo: maybe add addition call of prepare_server_evaluation with prefix for different types of ground truth tests
    
    for comparison_name, comparison_tests in comparison_tests_collection:

        #print('Step 4/6 --> Prepare Evaluation')
        p_val_comparisons, missed_tests = prepare_server_evaluation(comparison_tests, server_ci_tests)

        #print('Step 5/6 --> Run Evaluation')
        alpha_tests, equality_tests = evaluate_results(p_val_comparisons, alpha_comparisons, equality_tolerance)

        #print('Step 6/6 --> Log Results')
        records = get_records(servers, alpha_tests, equality_tests, missed_tests, len(ncs[i].nodes), len(ncs[i].nodes), comparison_name)
        
        write_records(i, log_filepattern, records)

In [162]:
# TODO: remove non-zero correctness.

In [191]:
for i in range(20):
    process(4)

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.


In [148]:
i = 5

In [149]:
data = get_sample_data(ncs[i], TOTAL_SAMPLES)
servers = get_servers(client_configurations, data)

#print('Step 2/6 --> Run Tests')
for server in servers.values(): server.run_tests()

#print('Step 3/6 --> Collect Results')
possible_tests = get_possible_tests(set(data.columns))
server_ci_tests = get_server_test_results(servers) 

comparison_tests_collection = []
ground_truth_tests = get_ground_truth_tests(data, possible_tests)
comparison_tests_collection.append(('ground_truth', ground_truth_tests))

if i in ncs_independences:
    real_independences = ncs_independences[i]
    comparison_tests_collection.append(('real', real_independences))

'DataFrame.swapaxes' is deprecated and will be removed in a future version. Please use 'DataFrame.transpose' instead.


In [150]:
server_ci_tests

{'dag_chain4_1c': [LikelihoodRatioTest - y: B, x: C, S: [], p: 0.0,
  LikelihoodRatioTest - y: B, x: D, S: [], p: 0.0011,
  LikelihoodRatioTest - y: B, x: A, S: [], p: 0.0,
  LikelihoodRatioTest - y: D, x: B, S: [], p: 0.0011,
  LikelihoodRatioTest - y: D, x: C, S: [], p: 0.0408,
  LikelihoodRatioTest - y: D, x: A, S: [], p: 0.3085,
  LikelihoodRatioTest - y: C, x: B, S: [], p: 0.0,
  LikelihoodRatioTest - y: C, x: D, S: [], p: 0.0408,
  LikelihoodRatioTest - y: C, x: A, S: [], p: 0.0,
  LikelihoodRatioTest - y: A, x: C, S: [], p: 0.0,
  LikelihoodRatioTest - y: A, x: B, S: [], p: 0.0,
  LikelihoodRatioTest - y: A, x: D, S: [], p: 0.3085,
  LikelihoodRatioTest - y: B, x: D, S: ['C'], p: 0.0,
  LikelihoodRatioTest - y: B, x: C, S: ['D'], p: 0.0,
  LikelihoodRatioTest - y: B, x: C, S: ['A'], p: 0.0281,
  LikelihoodRatioTest - y: B, x: A, S: ['C'], p: 0.0004,
  LikelihoodRatioTest - y: B, x: D, S: ['A'], p: 0.0,
  LikelihoodRatioTest - y: B, x: A, S: ['D'], p: 0.0,
  LikelihoodRatioTest -