# Causal Discovery and Injection for Feed Forward Neural Networks

Notebook collecting the synthetic-experiment results for the causal injection paper.

In [1]:
import sys
sys.path.insert(0,'..')
import re
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import load_pickle, plot_ly_by_compare_auto

# Input/output locations, relative to this notebook's directory.
result_folder = '../results'
output_folder = '../figures'
# Fail fast with a clear message if the experiment results are missing.
if not os.path.isdir(result_folder):
    raise FileNotFoundError(f"Result folder not found: {os.path.abspath(result_folder)}")

debug = False      # print/display intermediate relabelling information in the cells below
save_figs = True   # passed as `save=` to the plotting helper
if save_figs:
    # Imported only so a missing dependency fails here, before any plotting;
    # kaleido is plotly's static-image export engine.
    import kaleido
    # Make sure the figure destination exists before anything is written to it.
    os.makedirs(output_folder, exist_ok=True)

# Colour palette used by the figures.
main_gray = '#262626'
sec_gray = '#595959'
main_blue = '#005383'
sec_blue = '#0085CA'
main_green = '#379f9f'
sec_green = '#196363'
sec_green = '#196363' 

In [2]:
def collect_results(
    version,
    known_p = 0.2,
    folder='./results/',
    nodes = r'([\w]+)',
    seeds = r'([\w]+)',
    dsets = r'([\w.]+)',
    outfs = str(1),
    infs = str(1),
    thetas = r'([-+]?([0-9]+)(\.[0-9]+)?)',
    verbose = False):
    """Collect synthetic-experiment reports from pickle files into one DataFrame.

    Scans ``folder`` (non-recursively, in sorted filename order) for files matching
    ``Nested1FoldCASTLE.Reg.Synth.<nodes>.<dsets>.<version>.pkl``. The dots in this
    template are unescaped, so they match any character. Each file is loaded with
    ``load_pickle``, and every value of the loaded report becomes one row.

    Parameters
    ----------
    version : str
        Regex for the version tag, e.g. an alternation such as ``'(r_ex3|r_ex_2b)'``.
        The matched tag is stored in column ``V``.
    known_p : float
        Fraction used to compute ``rebase_den`` and ``rebased`` (see below).
    folder : str
        Directory to scan.
    nodes, dsets : str
        Regex fragments for the two filename fields preceding the version tag.
    seeds, outfs, infs, thetas, verbose :
        Currently unused; kept for backward compatibility.

    Returns
    -------
    pandas.DataFrame
        Columns ``V, Type, N_nodes, N_edges, seed, Size, MSE, MAE, right, matching``
        (``Type`` holds the report's ``theta``). The following derived columns are added:

        - ``alpha = int(Size * 0.8 / N_nodes)``. The 0.8 factor is presumably the
          training fraction — TODO confirm.
        - ``branch = round(N_edges / N_nodes)``.
        - ``branch_bin``: "1" for (0, 1], "2" for (1, 2], "5" for (2, 5]; NaN otherwise.
        - ``rebase_den = N_edges - int(N_edges * known_p)``.
        - ``rebased = matching / rebase_den``, clipped at 1.

        Rows with ``alpha == 1000`` are dropped. If no file matches, the result is an
        empty frame with the same columns instead of an error.
    """
    columns = ['V', 'Type', 'N_nodes', 'N_edges', 'seed', 'Size', 'MSE', 'MAE', 'right', 'matching']
    report_keys = ['theta', 'n_nodes', 'N_edges', 'seed', 'data_size', 'MSE', 'MAE', 'right', 'matching']
    # A named group locates the version tag no matter how many capturing groups
    # `nodes`, `dsets` or `version` themselves contain.
    pattern = re.compile(
        f'Nested1FoldCASTLE.Reg.Synth.{nodes}.{dsets}.(?P<version>{version}).pkl$')

    count = 0
    rows = []
    # Sorted so the row order does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(folder)):
        match = pattern.search(filename)
        if match is None:
            continue
        count += 1
        report = load_pickle(os.path.join(folder, filename), verbose=False)
        for value in report.values():
            rows.append([match.group('version')] + [value[k] for k in report_keys])
    print("Count=", count)

    # Passing `columns` here also works when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)

    size = df['Size'].astype(int)
    n_nodes = df['N_nodes'].astype(int)
    n_edges = df['N_edges'].astype(int)
    df['alpha'] = ((size * 0.8) / n_nodes).astype(int)
    df['branch'] = np.round((n_edges / n_nodes).astype(float), 0).astype(int)
    bins = pd.IntervalIndex.from_tuples([(0, 1), (1, 2), (2, 5)])
    df['branch_bin'] = pd.cut(df['branch'], bins).map(dict(zip(bins, ["1", "2", "5"])))
    # Denominator excludes the int(N_edges * known_p) edges — presumably those given as known.
    df['rebase_den'] = n_edges - (df['N_edges'] * known_p).astype(int)
    df['rebased'] = (df['matching'].astype(int) / df['rebase_den']).clip(upper=1)
    # NOTE(review): alpha == 1000 runs are excluded; the reason is not visible here.
    df = df[df['alpha'] != 1000]

    return df

## Main Figure – Injection vs. CASTLE+, with and without 20% noise variables

- Size comparison

In [3]:
rerun_3l=False
if rerun_3l:
    %run -i main_synth.py --version="r_ex_1b_k20_l3_h2_n0" --branchf=1 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 
    %run -i main_synth.py --version="r_ex_2b_k20_l3_h2_n0" --branchf=2 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 
    %run -i main_synth.py --version="r_ex_5b_k20_l3_h2_n0" --branchf=5 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2

In [4]:
# Re-generate the 3-hidden-layer runs WITH 20% noise variables
# (version tag "_n20" <-> --noise_p=0.2; --hidden_l=3 --hidden_n_p=2; 20% known edges)
# for --branchf 1, 2 and 5. Disabled by default.
# `%run -i` executes main_synth.py inside this kernel's namespace.
rerun_noise_3l=False
if rerun_noise_3l:
    %run -i main_synth.py --version="r_ex_1b_k20_l3_h2_n20" --branchf=1 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 
    %run -i main_synth.py --version="r_ex_2b_k20_l3_h2_n20" --branchf=2 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 
    %run -i main_synth.py --version="r_ex_5b_k20_l3_h2_n20" --branchf=5 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2

In [24]:
# ## Versions
# Legend label -> result versions aggregated under that label.
names = ['', ' w/noise']

groups = [
    (names[1], ["r_ex_1b_k20_l3_h2_n20", "r_ex_2b_k20_l3_h2_n20", "r_ex_5b_k20_l3_h2_n20"]),
    (names[0], ["r_ex_1b_k20_l3_h2_n0", "r_ex_2b_k20_l3_h2_n0", "r_ex_5b_k20_l3_h2_n0"]),
]

# Build the version alternation from `groups` so the file filter and the
# relabelling below cannot drift apart.
version = '(' + '|'.join(re.escape(v) for _, members in groups for v in members) + ')'

df = collect_results(version=version, known_p=0.2, folder=result_folder)

# Keep the raw version tag in V2 and replace V with its group label.
df['V2'] = df['V']
for label, members in groups:
    if debug:
        print(members)
        print(label)
    df.loc[df['V2'].isin(members), 'V'] = label

plot_ly_by_compare_auto(
    df, x='alpha', x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$",
    legend_cord=[0.5, 1.15], names_list=names, comparison_lines_width=260,
    y1_range=[-0.9, 1], y2_range=[0, 2.55], y2_ticks=[0, .3, .6, .9, 1.2], xwidth=1100,
    name='compalpha_3L', version='3lvsNoise', save=save_figs, output_folder=output_folder)

Count= 90
['' ' w/noise']


In [26]:
# Summary statistics per Type (theta) and noise setting (V). Two tables are shown:
# one split at alpha (s) <= 100 vs. > 100, and one overall.
# This cell works on a copy, so the `df` built by the previous cell stays intact.
df_summary = df.sort_values(by=['alpha'], axis=0)[['Type', 'V', 'MSE', 'right', 'alpha']].copy()
df_summary['V2'] = np.where(df_summary['alpha'] <= 100, 'le100', 'gt100')

# Aggregate only numeric columns: pandas >= 2.0 no longer silently drops
# string columns such as V2 when computing mean/std.
stat_cols = ['MSE', 'right', 'alpha']
stats = ['count', 'mean', 'std']
display(df_summary.groupby(['Type', 'V', 'V2'])[stat_cols].agg(stats).round(2).reset_index())
df_summary.groupby(['Type', 'V'])[stat_cols].agg(stats).round(2).reset_index()

Unnamed: 0_level_0,Type,V,MSE,MSE,MSE,right,right,right,alpha,alpha,alpha
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,count,mean,std,count,mean,std,count,mean,std
0,-1.0,,450,0.67,0.43,450,0.4,0.29,450,230,160.18
1,-1.0,w/noise,450,0.76,0.4,450,0.42,0.29,450,230,160.18
2,0.05,,450,0.63,0.36,450,0.47,0.27,450,230,160.18
3,0.05,w/noise,450,0.73,0.34,450,0.49,0.27,450,230,160.18


## Appendix - Percentage of Known Edges

- Edge-density comparison ($e = |E|/|V|$)
- 10% vs. 20% vs. 50% of edges known


In [6]:
# Re-generate the noise-free 3-hidden-layer runs with 50% of edges known
# (version tag "k50" <-> --known_p=0.5; --hidden_l=3 --hidden_n_p=2) for --branchf 1, 2 and 5.
# Disabled by default; `%run -i` executes main_synth.py inside this kernel's namespace.
rerun_known_50=False
if rerun_known_50:
    %run -i main_synth.py --version="r_ex_1b_k50_l3_h2_n0" --branchf=1 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_2b_k50_l3_h2_n0" --branchf=2 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_5b_k50_l3_h2_n0" --branchf=5 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2

In [7]:
# Re-generate the noise-free 3-hidden-layer runs with 10% of edges known
# (version tag "k10" <-> --known_p=0.1; --hidden_l=3 --hidden_n_p=2) for --branchf 1, 2 and 5.
# Disabled by default; `%run -i` executes main_synth.py inside this kernel's namespace.
rerun_known_10=False
if rerun_known_10:
    %run -i main_synth.py --version="r_ex_1b_k10_l3_h2_n0" --branchf=1 --known_p=0.1 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_2b_k10_l3_h2_n0" --branchf=2 --known_p=0.1 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_5b_k10_l3_h2_n0" --branchf=5 --known_p=0.1 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2

In [8]:
## Versions
# Legend label -> result versions aggregated under that label (percentage of known edges).
names = [' 10%', ' 20%', ' 50%']

groups = [
    (names[0], ["r_ex_1b_k10_l3_h2_n0", "r_ex_2b_k10_l3_h2_n0", "r_ex_5b_k10_l3_h2_n0"]),
    (names[1], ["r_ex_1b_k20_l3_h2_n0", "r_ex_2b_k20_l3_h2_n0", "r_ex_5b_k20_l3_h2_n0"]),
    (names[2], ["r_ex_1b_k50_l3_h2_n0", "r_ex_2b_k50_l3_h2_n0", "r_ex_5b_k50_l3_h2_n0"]),
]

# Build the version alternation from `groups` so the file filter and the
# relabelling below cannot drift apart.
version = '(' + '|'.join(re.escape(v) for _, members in groups for v in members) + ')'

# Note: the `rebased` column from collect_results uses known_p=0.2 for every row.
# The per-level rebase_den*/rebased* columns below use 10%, 20% and 50% respectively.
df = collect_results(version=version, known_p=0.2, folder=result_folder)

known_levels = {10: 0.1, 20: 0.2, 50: 0.5}
# Two passes keep the original column order: all rebase_den* columns, then all rebased*.
for pct, frac in known_levels.items():
    df[f'rebase_den{pct}'] = df['N_edges'].astype(int) - (df['N_edges'] * frac).astype(int)
for pct in known_levels:
    df[f'rebased{pct}'] = (df['matching'].astype(int) / df[f'rebase_den{pct}']).clip(upper=1)

# Keep the raw version tag in V2 and replace V with its group label.
df['V2'] = df['V']
for label, members in groups:
    if debug:
        print(members)
        print(label)
        print(df['V2'].isin(members).tolist())
    df.loc[df['V2'].isin(members), 'V'] = label

if debug:
    display(df)

plot_ly_by_compare_auto(
    df, x='branch_bin', x_desc=r"$e = |E|/|V| = \text{ Number of Edges/Number of Nodes}$",
    baseline=names[0], display_tab=False,
    legend_cord=[0.5, 1.15], names_list=names, comparison_lines_width=100,
    injection_levels=[10, 20, 50], colors_list=[main_gray, '#9454c4', main_blue, '#441469'],
    y1_range=[-0.9, 1], y2_range=[0, 2.2], y2_ticks=[0, .2, .4, .6, .8], xwidth=700,
    name='compedges_3L', version='10kvs20kvs50k', save=save_figs, output_folder=output_folder)

Count= 135
[' 10%' ' 20%' ' 50%']


## Appendix - Network Size Comparison

- Nodes Comparison
- Per-layer hidden widths: $[3.2\,n_{\text{input}}]$ vs. $[2\,n_{\text{input}},\ \tfrac{2}{3}\,n_{\text{input}},\ 2\,n_{\text{input}}]$

In [9]:
# Re-generate the runs tagged r_ex3 / r_ex_2b / r_ex_5b (20% known edges, no noise)
# for --branchf 1, 2 and 5. No --hidden_l / --hidden_n_p flags are passed, so
# main_synth.py's defaults apply. These form the " M=1" group in the network-size
# comparison below, presumably a single hidden layer — confirm against main_synth.py.
# Disabled by default; `%run -i` executes main_synth.py inside this kernel's namespace.
rerun_1L=False
if rerun_1L:
    ## Run main script
    %run -i main_synth.py --version="r_ex3"  --branchf=1 --known_p=0.2 --noise_p=0.0
    %run -i main_synth.py --version="r_ex_2b"  --branchf=2 --known_p=0.2 --noise_p=0.0
    %run -i main_synth.py --version="r_ex_5b"  --branchf=5 --known_p=0.2 --noise_p=0.0

In [10]:
# Re-generate the noise-free 3-hidden-layer runs (version tag "l3_h2" <->
# --hidden_l=3 --hidden_n_p=2; "_n0" <-> --noise_p=0.0; 20% known edges) for --branchf 1, 2 and 5.
# Disabled by default; `%run -i` executes main_synth.py inside this kernel's namespace.
rerun_multi_3=False
if rerun_multi_3:
    ## Run main script
    %run -i main_synth.py --version="r_ex_1b_k20_l3_h2_n0" --branchf=1 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_2b_k20_l3_h2_n0" --branchf=2 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_synth.py --version="r_ex_5b_k20_l3_h2_n0" --branchf=5 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2

In [None]:
rerun_multi_2=False
if rerun_multi_2:
    ## Run main script
    %run -i main_synth.py --version="r_ex_1b_k20_l2_h1_n0" --branchf=1 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=1
    %run -i main_synth.py --version="r_ex_2b_k20_l2_h1_n0" --branchf=2 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=1
    %run -i main_synth.py --version="r_ex_5b_k20_l2_h1_n0" --branchf=5 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=1

In [11]:
# ## Versions
# Legend label -> result versions aggregated under that label.
# " M=3" are the l3 runs (--hidden_l=3). " M=1" are the runs without hidden-layer flags.
names = [' M=3', ' M=1']

groups = [
    (names[1], ["r_ex_2b", "r_ex3", "r_ex_5b"]),
    (names[0], ["r_ex_1b_k20_l3_h2_n0", "r_ex_2b_k20_l3_h2_n0", "r_ex_5b_k20_l3_h2_n0"]),
]

# Build the version alternation from `groups` so the file filter and the
# relabelling below cannot drift apart.
version = '(' + '|'.join(re.escape(v) for _, members in groups for v in members) + ')'

df = collect_results(version=version, known_p=0.2, folder=result_folder)

# Keep the raw version tag in V2 and replace V with its group label.
df['V2'] = df['V']
for label, members in groups:
    if debug:
        print(members)
        print(label)
    df.loc[df['V2'].isin(members), 'V'] = label

plot_ly_by_compare_auto(
    df, x='N_nodes', x_desc=r"$|V| \text{ (Number of Nodes)}$", display_tab=False,
    legend_cord=[0.48, 1.15], names_list=names, comparison_lines_width=160,
    colors_list=[sec_green, main_gray, main_blue, main_green],
    y1_range=[-0.9, 1], y2_range=[0, 2.2], y2_ticks=[0, .2, .4, .6, 0.8], xwidth=720,
    name='compnodes_3L', version='3lvs1l', save=save_figs, output_folder=output_folder)

Count= 93
[' M=1' ' M=3']
