# Causal Injection into a Neural Network

In [1]:
import re
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils import load_pickle, plot_ly_by_compare

# Notebook-wide switch: passed as `save=` to the plotting calls so figures are also written to disk.
save_figs = True
if save_figs:
    # Import kaleido up front when saving is enabled; plot_ly_by_compare3 imports it
    # again right before calling fig.write_image.
    import kaleido

In [2]:
def collect_results(
    version,
    known_p = 0.2,
    folder='./results/',
    nodes = r'([\w]+)',
    seeds = r'([\w]+)',
    dsets = r'([\w.]+)',
    outfs = str(1),
    infs = str(1),
    thetas = r'([-+]?([0-9]+)(\.[0-9]+)?)',
    verbose = False):
    """Load every matching result pickle in ``folder`` into one flat DataFrame.

    A file is picked up when its name matches
    ``Nested1FoldCASTLE.Reg.Synth.{nodes}.{dsets}.{version}.pkl`` (the dots are
    regex wildcards, exactly as in the original pattern). Each entry of the
    loaded report contributes one row.

    Parameters
    ----------
    version : str
        Regex for the version tag, e.g. ``'(r_ex3|r_ex_2b)'``. It is wrapped in
        a named group internally, so it may be given with or without its own
        parentheses.
    known_p : float
        Fraction of edges treated as known; only used for ``rebase_den`` and
        ``rebased``.
    folder : str
        Directory that is scanned (non-recursively) for result files.
    nodes, dsets : str
        Regex fragments for the two filename fields in front of the version.
    seeds, outfs, infs, thetas, verbose :
        Accepted for backward compatibility; not used by this function.

    Returns
    -------
    pandas.DataFrame
        Columns ``V`` (matched version tag), ``Type``, ``N_nodes``, ``N_edges``,
        ``seed``, ``Size``, ``MSE``, ``MAE``, ``right``, ``matching`` plus the
        derived ``alpha``, ``branch``, ``branch_bin``, ``rebase_den`` and
        ``rebased``. Rows with ``alpha == 1000`` are dropped. When no file
        matches, an empty frame with these columns is returned.
    """
    base_columns = ['V', 'Type', 'N_nodes', 'N_edges', 'seed', 'Size',
                    'MSE', 'MAE', 'right', 'matching']
    derived_columns = ['alpha', 'branch', 'branch_bin', 'rebase_den', 'rebased']
    report_keys = ['theta', 'n_nodes', 'N_edges', 'seed', 'data_size',
                   'MSE', 'MAE', 'right', 'matching']

    # A named group keeps the version lookup independent of how many capture
    # groups `nodes`, `dsets` or `version` themselves contain.
    pattern = re.compile(
        f'Nested1FoldCASTLE.Reg.Synth.{nodes}.{dsets}.(?P<version>{version}).pkl$')

    count = 0
    rows = []
    # Sorted so the row order does not depend on the file system's listing order.
    for filename in sorted(os.listdir(folder)):
        match = pattern.search(filename)
        if match is None:
            continue
        count += 1

        # NOTE(review): load_pickle presumably unpickles the file; only point
        # `folder` at trusted result files.
        report = load_pickle(os.path.join(folder, filename), verbose=False)
        for value in report.values():
            rows.append([match.group('version')] + [value[key] for key in report_keys])
    print("Count=",count)

    if not rows:
        # Nothing matched: return an empty frame with the usual schema instead
        # of failing on a column-count mismatch.
        return pd.DataFrame(columns=base_columns + derived_columns)

    df = pd.DataFrame(rows, columns=base_columns)
    n_nodes = df['N_nodes'].astype(int)
    n_edges = df['N_edges'].astype(int)

    # 0.8 * Size / N_nodes, truncated to int.
    # NOTE(review): 0.8 is presumably the training split fraction - confirm.
    df['alpha'] = ((df['Size'].astype(int)*0.8)/n_nodes).astype(int)
    # Edges per node, rounded to the nearest integer and bucketed for plotting.
    df['branch'] = np.round((n_edges/n_nodes).astype(float),0).astype(int)
    bins = pd.IntervalIndex.from_tuples([(0, 1), (1, 2), (2, 5)])
    df['branch_bin'] = pd.cut(df['branch'], bins)
    # Number of edges left once the (truncated) known share is removed.
    df['rebase_den'] = n_edges - (df['N_edges']*known_p).astype(int)
    df['rebased'] = (df['matching'].astype(int)/df['rebase_den']).clip(upper=1)

    # Copy so callers can add columns without a SettingWithCopyWarning.
    # NOTE(review): the reason alpha == 1000 is excluded is not visible here.
    return df[df['alpha'] != 1000].copy()

## Main Figure - 3L Injection vs CASTLE+ with and without (20%) Noise variables - Size comparison

In [3]:
rerun_3l=False
if rerun_3l:
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k20_l3_h2_n0" --branchf=1 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 # PID 12393
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k20_l3_h2_n0" --branchf=2 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 # PID 12915
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k20_l3_h2_n0" --branchf=5 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2

In [4]:
# Set to True to regenerate the "n20" runs (--noise_p=0.2), one per branching factor (--branchf 1, 2, 5).
# NOTE(review): the version tag is presumably written into the result file names that
# collect_results later matches - confirm against main_biggerDAG_extended.py.
rerun_noise_3l=False
if rerun_noise_3l:
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k20_l3_h2_n20" --branchf=1 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 # PID 12393
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k20_l3_h2_n20" --branchf=2 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2 # PID 12915
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k20_l3_h2_n20" --branchf=5 --known_p=0.2 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2

In [7]:
# ## Versions
# Regex alternation over the six version tags to load: three "n0" and three "n20" runs.
version = '(r_ex_1b_k20_l3_h2_n0|r_ex_2b_k20_l3_h2_n0|r_ex_5b_k20_l3_h2_n0|r_ex_1b_k20_l3_h2_n20|r_ex_2b_k20_l3_h2_n20|r_ex_5b_k20_l3_h2_n20)'

df = collect_results(version=version, known_p = 0.2)

### Check How many runs you have of a given version
# v="r_ex_1b_k20_l3_h2_n20"
# print(v)
# print((df[df['V']==v].groupby(by=['seed']).size()))
# print(df[df['V']==v].groupby(by=['N_nodes']).size())
# print(df[df['V']==v].groupby(by=['alpha']).size())


# v="r_ex_2b_k20_l3_h2_n20"
# print(v)
# print((df[df['V']==v].groupby(by=['seed']).size()))
# print(df[df['V']==v].groupby(by=['N_nodes']).size())
# print(df[df['V']==v].groupby(by=['alpha']).size())

# Row counts per seed, per N_nodes and per alpha for one version tag.
v="r_ex_5b_k20_l3_h2_n20"
print(v)
print((df[df['V']==v].groupby(by=['seed']).size()))
print(df[df['V']==v].groupby(by=['N_nodes']).size())
print(df[df['V']==v].groupby(by=['alpha']).size())
# display(df)

### Group into two noise vs no noise
# Keep the original tag in V2; V becomes 1 for tags containing "n20" and 0 otherwise.
df['V2'] = df['V']
df['V'] = [int("n20" in i) for i in df['V2'] ]

# Compare the two V groups along alpha; the figure is saved when save_figs is True.
plot_ly_by_compare(df, x='alpha', x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise'],
y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compalpha_v3', version='3lvsNoise', save=save_figs, xwidth=1100)


# plot_ly_by_compare(df, x='branch_bin', x_desc=r"$b = |E|/|V| \text{ (Number of Edges / Number of Nodes)}$", 
# legend_cord=[0.5,1.15], names_list = ['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise'],
# y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compedge_v3', version='3lvsNoise', save=False, xwidth=700)


Count= 90
r_ex_5b_k20_l3_h2_n20
seed
0       30
10      30
100     30
1000    30
2000    30
3000    30
4000    30
5000    30
6000    25
dtype: int64
N_nodes
10    90
20    90
50    85
dtype: int64
alpha
50     54
100    54
200    53
300    52
500    52
dtype: int64
[0 1]


Unnamed: 0,Type,V,alpha,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,branchcount,branchmean,branchstd,rebase_dencount,rebase_denmean,rebase_denstd,rebasedcount,rebasedmean,rebasedstd,text
0,-1.0,0,50,90,26.67,17.09,90,62.59,65.36,90,...,90,2.33,1.34,90,50.48,52.16,90,0.19,0.24,0.93 (0.65)
1,-1.0,0,100,90,26.67,17.09,90,62.59,65.36,90,...,90,2.33,1.34,90,50.48,52.16,90,0.39,0.32,0.65 (0.36)
2,-1.0,0,200,90,26.67,17.09,90,62.59,65.36,90,...,90,2.33,1.34,90,50.48,52.16,90,0.58,0.31,0.59 (0.31)
3,-1.0,0,300,90,26.67,17.09,90,62.59,65.36,90,...,90,2.33,1.34,90,50.48,52.16,90,0.61,0.32,0.59 (0.31)
4,-1.0,0,500,90,26.67,17.09,90,62.59,65.36,90,...,90,2.33,1.34,90,50.48,52.16,90,0.64,0.28,0.58 (0.31)
5,-1.0,1,50,87,26.67,17.1,87,60.9,64.07,87,...,87,2.28,1.32,87,49.11,51.18,87,0.18,0.24,1.04 (0.6)
6,-1.0,1,100,87,26.67,17.1,87,60.9,64.07,87,...,87,2.28,1.32,87,49.11,51.18,87,0.41,0.31,0.76 (0.32)
7,-1.0,1,200,87,26.67,17.1,87,60.9,64.07,87,...,87,2.28,1.32,87,49.11,51.18,87,0.61,0.32,0.69 (0.28)
8,-1.0,1,300,86,26.4,17.01,86,58.93,61.75,86,...,86,2.24,1.29,86,47.55,49.33,86,0.68,0.3,0.69 (0.27)
9,-1.0,1,500,86,26.4,17.01,86,58.93,61.75,86,...,86,2.24,1.29,86,47.55,49.33,86,0.68,0.27,0.68 (0.27)


## 3L vs 1L - Nodes comparison

In [5]:
# Set to True to regenerate the runs tagged r_ex3 / r_ex_2b / r_ex_5b (--known_p=0.2, --noise_p=0.0),
# one per branching factor. No --hidden_l / --hidden_n_p is passed, so the script's own defaults apply.
rerun_experiment_20=False
if rerun_experiment_20:
    ## Run main script
    %run -i main_biggerDAG_extended.py --version="r_ex3"  --branchf=1 --known_p=0.2 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_2b"  --branchf=2 --known_p=0.2 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_5b"  --branchf=5 --known_p=0.2 --noise_p=0.0

In [6]:
# Set to True to regenerate the --hidden_l=3 --hidden_n_p=2 runs without noise (--noise_p=0.0).
# These are the same "..._k20_l3_h2_n0" version tags used by the rerun cell under "Main Figure".
rerun_multi_2=False
if rerun_multi_2:
    ## Run main script
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k20_l3_h2_n0" --branchf=1 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k20_l3_h2_n0" --branchf=2 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k20_l3_h2_n0" --branchf=5 --known_p=0.2 --noise_p=0.0 --hidden_l=3 --hidden_n_p=2

In [21]:
# ## Versions
# Regex alternation over the six version tags to load: r_ex3 / r_ex_2b / r_ex_5b and the three "..._k20_l3_h2_n0" runs.
version = '(r_ex3|r_ex_2b|r_ex_5b|r_ex_1b_k20_l3_h2_n0|r_ex_2b_k20_l3_h2_n0|r_ex_5b_k20_l3_h2_n0)'

df = collect_results(version=version, known_p = 0.2)

# Keep the original tag in V2; V becomes 0 for the "k20_l3_h2_n0" tags and 1 for the others.
df['V2'] = df['V']
df['V'] = [int("k20_l3_h2_n0" not in i) for i in df['V2'] ]

# display(df)

plot_ly_by_compare(df, x='N_nodes', x_desc=r"$|V| = \text{ Number of Nodes in } \mathcal{G}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+ M=1','CASTLE+ M=3','Injected M=3','Injected M=1'],
y1_range=[-0.9,1], y2_range= [0, 2.15], y2_ticks=[0,0.2,0.4,0.6,0.8,1], name='compnodes_v2', version='1lvs3l', xwidth = 700, sec_blue='#379f9f', sec_gray='#196363', 
save=save_figs, y1_vis=True)

# branch_bin buckets N_edges / N_nodes (see collect_results), so the axis is labelled |E|/|V|.
plot_ly_by_compare(df, x='branch_bin', x_desc=r"$b= |E|/|V| = \text{ Number of Edges/Number of Nodes}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+ M=1','CASTLE+ M=3','Injected M=3','Injected M=1'],
y1_range=[-0.9,1], y2_range= [0, 2.15], y2_ticks=[0,0.2,0.4,0.6,0.8,1], name='compedges_v2', version='1lvs3l', xwidth = 700, sec_blue='#379f9f', sec_gray='#196363', 
save=save_figs, y1_vis=True)

# plot_ly_by_compare(df, x='alpha', x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$", 
# legend_cord=[0.5,1.15], names_list = ['CASTLE+ M=1','CASTLE+ M=3','Injected M=3','Injected M=1'],
# y1_range=[-0.9,1], y2_range= [0, 2.2], name='compalpha_v2', version='1lvs3l', xwidth = 1100, sec_blue='#379f9f', sec_gray='#196363', 
# save=save_figs, y1_vis=True)

Count= 93
[0 1]


Unnamed: 0,Type,V,N_nodes,N_edgescount,N_edgesmean,N_edgesstd,seedcount,seedmean,seedstd,Sizecount,...,branchcount,branchmean,branchstd,rebase_dencount,rebase_denmean,rebase_denstd,rebasedcount,rebasedmean,rebasedstd,text
0,-1.0,0,10,150,18.5,8.24,150,2811,2477.95,150,...,150,2.0,0.82,150,15.5,6.61,150,0.62,0.26,0.56 (0.35)
1,-1.0,0,20,150,43.37,22.17,150,2811,2477.95,150,...,150,2.33,1.25,150,35.03,17.97,150,0.44,0.35,0.89 (0.44)
2,-1.0,0,50,150,125.9,76.4,150,2811,2477.95,150,...,150,2.67,1.71,150,100.9,61.05,150,0.38,0.35,0.55 (0.43)
3,-1.0,1,10,150,19.47,9.01,150,2811,2477.95,150,...,150,2.0,0.82,150,16.0,7.33,150,0.47,0.3,0.46 (0.36)
4,-1.0,1,20,150,45.6,25.61,150,2811,2477.95,150,...,150,2.33,1.25,150,36.9,20.57,150,0.35,0.28,0.55 (0.37)
5,-1.0,1,50,150,124.77,75.42,150,2811,2477.95,150,...,150,2.57,1.59,150,100.2,60.27,150,0.46,0.22,0.57 (0.34)
6,0.05,0,10,150,18.5,8.24,150,2811,2477.95,150,...,150,2.0,0.82,150,15.5,6.61,150,0.64,0.25,0.55 (0.35)
7,0.05,0,20,150,43.37,22.17,150,2811,2477.95,150,...,150,2.33,1.25,150,35.03,17.97,150,0.51,0.34,0.83 (0.36)
8,0.05,0,50,150,125.9,76.4,150,2811,2477.95,150,...,150,2.67,1.71,150,100.9,61.05,150,0.52,0.32,0.5 (0.28)
9,0.05,1,10,150,19.47,9.01,150,2811,2477.95,150,...,150,2.0,0.82,150,16.0,7.33,150,0.56,0.28,0.46 (0.36)


[0 1]


Unnamed: 0,Type,V,branch_bin,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,branchcount,branchmean,branchstd,rebase_dencount,rebase_denmean,rebase_denstd,rebasedcount,rebasedmean,rebasedstd,text
0,-1.0,0,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,1.0,0.0,150,21.33,13.64,150,0.74,0.33,0.93 (0.37)
1,-1.0,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,2.0,0.0,150,41.1,27.45,150,0.39,0.25,0.59 (0.37)
2,-1.0,0,"(2, 5]",150,26.67,17.05,150,110.67,86.56,150,...,150,4.0,0.82,150,89.0,68.91,150,0.31,0.26,0.48 (0.42)
3,-1.0,1,"(0, 1]",150,26.67,17.05,150,26.2,16.89,150,...,150,1.0,0.0,150,21.27,13.55,150,0.52,0.31,0.6 (0.39)
4,-1.0,1,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,2.0,0.0,150,41.1,27.45,150,0.42,0.25,0.56 (0.33)
5,-1.0,1,"(2, 5]",150,26.67,17.05,150,112.87,83.68,150,...,150,3.9,0.75,150,90.73,66.82,150,0.34,0.22,0.43 (0.33)
6,0.05,0,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,1.0,0.0,150,21.33,13.64,150,0.83,0.24,0.87 (0.22)
7,0.05,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,2.0,0.0,150,41.1,27.45,150,0.49,0.22,0.56 (0.34)
8,0.05,0,"(2, 5]",150,26.67,17.05,150,110.67,86.56,150,...,150,4.0,0.82,150,89.0,68.91,150,0.35,0.25,0.45 (0.36)
9,0.05,1,"(0, 1]",150,26.67,17.05,150,26.2,16.89,150,...,150,1.0,0.0,150,21.27,13.55,150,0.61,0.28,0.59 (0.38)


## 3L Partial Injection - 20% vs 50% known edges - Edges comparison

In [8]:
# Set to True to regenerate the "k50" runs (--known_p=0.5, --noise_p=0, --hidden_l=3 --hidden_n_p=2),
# one per branching factor (--branchf 1, 2, 5).
rerun_known_3l=False
if rerun_known_3l:
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k50_l3_h2_n0" --branchf=1 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k50_l3_h2_n0" --branchf=2 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k50_l3_h2_n0" --branchf=5 --known_p=0.5 --noise_p=0 --hidden_l=3 --hidden_n_p=2

In [9]:
def _compare3_bar(summary, x, y, name, color, textposition):
    """Return a semi-transparent bar trace of ``summary[y]`` over ``summary[x]``, annotated with ``summary['text']``."""
    return go.Bar(
        x=summary[x], y=summary[y], name=name,
        marker_color=color,
        marker_line_color=color,
        marker_line_width=2, opacity=0.6,
        showlegend=False,
        text=summary['text'],
        textposition=textposition,
    )


def _compare3_box(rows, x, y, name, color):
    """Return a box trace of ``rows[y]`` grouped by ``rows[x]`` (as strings), with mean/sd shown and no points."""
    return go.Box(
        y=rows[y],
        x=rows[x].astype(str),
        boxmean='sd',
        boxpoints=False,
        name=name,
        marker_color=color,
    )


def _compare3_reference(summary, x, column, name):
    """Return a scatter trace drawing one wide black horizontal tick per x value at ``summary[column]``."""
    return go.Scatter(
        x=summary[x], y=summary[column], name=name,
        mode='markers', marker_symbol='line-ew',
        marker=dict(color='Black', size=50, line=dict(color='Black', width=2)),
        showlegend=False,
    )


def _compare3_axis_label(fig, text, xshift, yshift):
    """Add ``text`` to ``fig`` as a paper-anchored annotation rotated by -90 degrees."""
    fig.add_annotation(
        # No explicit x/y position: placement is done purely through the pixel shifts.
        yshift=yshift,
        xshift=xshift,
        align="left",
        valign="top",
        text=text,
        showarrow=False,
        xref="paper",
        yref="paper",
        xanchor="left",
        yanchor="top",
        textangle=-90,
    )


def plot_ly_by_compare3(df, 
        x, x_desc, 
        y1='right', y1_desc="Reconstruction Accuracy", 
        margin_list = [10, 10, 0, 10, 0],
        y1_range=[-0.8, 1], y1_ticks=[0,0.2,0.4,0.6,0.8,1], y1_vis=True,
        y2='MSEmean', y2_desc="Mean Squared Error",
        y2_range= [0, 1.8], y2_ticks=[0,0.2,0.5,0.8], y2_vis=True,
        showleg = True,
        legend_cord = [0.8,1.05],
        save=False, name='',version='', xwidth = 1100,
        main_gray = '#262626',
        sec_gray = '#595959',
        main_blue = '#005383',
        sec_blue = '#0085CA',
        names_list = ['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise']
       ):
    """Plot one CASTLE baseline against two injected variants on a dual-axis figure.

    Rows with ``Type == -1`` feed the "CASTLE" traces and rows with
    ``Type == 0.05`` the "Injected" traces. ``V`` must take at least two
    distinct values; the first two (in sorted order) are used:

    * bars on the secondary axis: ``y2`` for CASTLE (first ``V``) and for
      Injected (first and second ``V``), each labelled "mean (std)" of MSE;
    * boxes on the primary axis: ``y1`` for the same three groups;
    * when ``y1 == 'right'``: two rows of horizontal ticks at the CASTLE means
      of ``rebased20`` and ``rebased50``.

    The aggregated table is printed via ``display`` and the figure is shown
    with ``fig.show()``; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Type``, ``V``, ``MSE``, ``x`` and ``y1``; additionally
        ``rebased20`` and ``rebased50`` when ``y1 == 'right'``. Every column
        other than the group keys is aggregated with count/mean/std.
    x, x_desc : str
        Grouping column for the x axis and its axis title.
    y1, y1_desc, y1_range, y1_ticks, y1_vis :
        Primary-axis column, rotated label text, range, tick values (formatted
        as percentages) and whether label and tick labels are shown.
    y2, y2_desc, y2_range, y2_ticks, y2_vis :
        Same for the secondary axis; ``y2`` names a column of the aggregated
        table (e.g. ``'MSEmean'``).
    margin_list : list
        Figure margins as ``[l, r, b, t, pad]``.
    showleg, legend_cord :
        Legend visibility and its ``[x, y]`` position.
    save, name, version :
        When ``save`` is true the figure is written to
        ``figures/plot_{name}_{version}.png``.
    xwidth : int
        Figure width in pixels; also drives the horizontal label offsets.
    main_gray, sec_gray, main_blue, sec_blue : str
        Trace colours. ``sec_gray`` is used for CASTLE, ``main_blue`` and
        ``sec_blue`` for the two injected groups; ``main_gray`` is not used.
    names_list : list of str
        Legend names; entries 0, 2 and 3 are used, entry 1 is not.

    Raises
    ------
    ValueError
        If ``V`` has fewer than two distinct values.
    """
    colors_list = [sec_gray,main_gray,main_blue,sec_blue]

    df = df.sort_values(by=[x],axis=0)
    # NOTE(review): this aggregates every remaining column, so non-numeric
    # columns may make 'mean'/'std' fail on newer pandas - confirm.
    mses = df.groupby(['Type','V',x], as_index=False).agg([ 'count','mean','std']).round(2).reset_index()
    # Flatten the (column, statistic) pairs into names such as 'MSEmean'.
    mses.columns = list(map(''.join, mses.columns.values))
    # Look values up by label; positional access on a labelled row is deprecated in pandas.
    mses['text'] = mses[['MSEmean','MSEstd']].apply(
        lambda row: '{} ({})'.format(row['MSEmean'], row['MSEstd']), axis=1)
    mses[x] = mses[x].astype(str)

    print(np.unique(df['V']))
    display(mses)

    summary_groups = np.unique(mses['V'])
    raw_groups = np.unique(df['V'])
    if len(summary_groups) < 2 or len(raw_groups) < 2:
        raise ValueError("plot_ly_by_compare3 needs at least two distinct values in column 'V'")

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    ########## BarCharts (secondary axis)
    msesC = mses[(mses['Type']==-1) & (mses['V']==summary_groups[0])]
    fig.add_trace(
        _compare3_bar(msesC, x, y2, "CASTLEMSE", colors_list[0], 'auto'),
        secondary_y=True,
    )
    for group, color in ((summary_groups[0], colors_list[2]),
                         (summary_groups[1], colors_list[3])):
        msesI = mses[(mses['Type']==0.05) & (mses['V']==group)]
        fig.add_trace(
            _compare3_bar(msesI, x, y2, "InjectedMSE", color, 'inside'),
            secondary_y=True,
        )

    ########## Boxplots (primary axis)
    castle_rows = df[(df['Type']==-1) & (df['V']==raw_groups[0])]
    fig.add_trace(
        _compare3_box(castle_rows, x, y1, names_list[0], colors_list[0]),
        secondary_y=False,
    )

    if y1=='right':
        # Reference ticks: CASTLE means of the rebased columns.
        for column, label in (('rebased20mean', "CASTLE+20%"),
                              ('rebased50mean', "CASTLE+50%")):
            fig.add_trace(
                _compare3_reference(msesC, x, column, label),
                secondary_y=False,
            )

    for group, label, color in ((raw_groups[0], names_list[2], colors_list[2]),
                                (raw_groups[1], names_list[3], colors_list[3])):
        injected_rows = df[(df['Type']==0.05) & (df['V']==group)]
        fig.add_trace(
            _compare3_box(injected_rows, x, y1, label, color),
            secondary_y=False,
        )

    fig.update_layout(
        boxmode='group' # group together boxes of the different traces for each value of x
        ,bargap=0.2
        ,bargroupgap=0.1
    )

    # Pixel offsets that push the rotated labels outside the plot area;
    # hand-tuned for the two common widths, proportional otherwise.
    if xwidth == 600:
        x1shift=-330
        x2shift=270
    elif xwidth == 1100:
        x1shift=-580
        x2shift=480
    else:
        x1shift = -xwidth*0.53
        x2shift = xwidth*0.42

    if y1_vis:
        _compare3_axis_label(fig, y1_desc, x1shift, 150)
    if y2_vis:
        _compare3_axis_label(fig, y2_desc, x2shift, 10)

    # Set x-axis title
    fig.update_xaxes(showgrid=True,
    title={'text':x_desc#,'font':{'size':18}
    })
    # The built-in y-axis titles stay empty; the rotated annotations replace them.
    fig.update_yaxes(showgrid=True,nticks=10,zeroline=True, title={'text':""#,'font':{'size':18}
    }, 
    range=y1_range,
    tickvals=y1_ticks,
    tickformat=".0%",
    secondary_y=False,
    showticklabels=y1_vis
    )
    fig.update_yaxes(showgrid=True,nticks=10,zeroline=True, title={'text':""#,'font':{'size':18}
    },
    range=y2_range,
    tickvals=y2_ticks,
    secondary_y=True,
    showticklabels=y2_vis)

    fig.update_layout(
        showlegend=showleg,
        title='',
        legend={
            'y':legend_cord[1],
            'x':legend_cord[0],
            'orientation':"h",
            'xanchor': 'center',
            'yanchor': 'top'},
        template='plotly_white',
        autosize=True,
        width=xwidth, height=350, 
        margin=dict(
            l=margin_list[0],
            r=margin_list[1],
            b=margin_list[2],
            t=margin_list[3],
            pad=margin_list[4],
        ),
        font=dict(
            family='Serif',#"Courier New, monospace",
            size=18,
            color="Black"
        )    
    )

    if save:
        output_folder = "figures"
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(output_folder, exist_ok=True)

        out_path = os.path.join(output_folder,f"plot_{name}_{version}.png")

        import kaleido
        fig.write_image(out_path)

    fig.show()

In [10]:
# ## Versions
# Regex alternation over the six version tags to load: three "k20" and three "k50" runs.
version = '(r_ex_1b_k20_l3_h2_n0|r_ex_2b_k20_l3_h2_n0|r_ex_5b_k20_l3_h2_n0|r_ex_1b_k50_l3_h2_n0|r_ex_2b_k50_l3_h2_n0|r_ex_5b_k50_l3_h2_n0)'

df = collect_results(version=version, known_p = 0.2)

### Check How many runs you have of a given version
# v="r_ex_1b_k50_l3_h2_n0"
# print(v)
# print((df[df['V']==v].groupby(by=['seed']).size()))
# print(df[df['V']==v].groupby(by=['N_nodes']).size())
# print(df[df['V']==v].groupby(by=['alpha']).size())


# Row counts per seed, per N_nodes and per alpha for one version tag.
v="r_ex_2b_k50_l3_h2_n0"
print(v)
print((df[df['V']==v].groupby(by=['seed']).size()))
print(df[df['V']==v].groupby(by=['N_nodes']).size())
print(df[df['V']==v].groupby(by=['alpha']).size())

# NOTE: in the recorded output this tag prints three empty Series, i.e. no rows were loaded for it.
v="r_ex_5b_k50_l3_h2_n0"
print(v)
print((df[df['V']==v].groupby(by=['seed']).size()))
print(df[df['V']==v].groupby(by=['N_nodes']).size())
print(df[df['V']==v].groupby(by=['alpha']).size())

### Group into two: "k20" vs "k50" version tags
# Keep the original tag in V2; V becomes 1 for tags containing "k20" and 0 otherwise.
df['V2'] = df['V']
df['V'] = [int("k20" in i) for i in df['V2'] ]

# Both lines repeat what the first cell already does (save_figs = True, same import).
save_figs = True
from utils import plot_ly_by_compare
# plot_ly_by_compare(df, x='alpha', x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$", 
# legend_cord=[0.5,1.15], names_list = ['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise'],
# y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compalpha_v3', version='3lvsNoise', save=save_figs, xwidth=1100)


# plot_ly_by_compare(df, x='branch_bin', x_desc=r"$b = E/|V| \text{ (Number of Edges / Number of Nodes)}$", 
# legend_cord=[0.5,1.15], names_list = ['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise'],
# y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compedge_v3', version='3lvsNoise', save=False, xwidth=700)


Count= 75
r_ex_2b_k50_l3_h2_n0
seed
0       30
10      30
100     30
1000    20
dtype: int64
N_nodes
10    40
20    40
50    30
dtype: int64
alpha
50     22
100    22
200    22
300    22
500    22
dtype: int64
r_ex_5b_k50_l3_h2_n0
Series([], dtype: int64)
Series([], dtype: int64)
Series([], dtype: int64)


In [11]:
## Versions
# Regex alternation over the six version tags to load: three "k20" and three "k50" runs.
version = '(r_ex_1b_k20_l3_h2_n0|r_ex_2b_k20_l3_h2_n0|r_ex_5b_k20_l3_h2_n0|r_ex_1b_k50_l3_h2_n0|r_ex_2b_k50_l3_h2_n0|r_ex_5b_k50_l3_h2_n0)'

# known_p=0.2 only affects the 'rebase_den'/'rebased' columns; the per-fraction
# columns computed below are the ones plot_ly_by_compare3 reads.
df = collect_results(version=version, known_p = 0.2)

# Keep the original tag in V2; V becomes 1 for tags containing "50" and 0 otherwise.
df['V2'] = df['V']
df['V'] = [int("50" in i) for i in df['V2'] ]

# Edges left after removing the (truncated) 20% / 50% share of N_edges.
df['rebase_den20'] = (df['N_edges'].astype(int)-(df['N_edges']*0.2).astype(int))
df['rebase_den50'] = (df['N_edges'].astype(int)-(df['N_edges']*0.5).astype(int))

# df.loc[df['V']==0,'rebased'] = (df['matching'].astype(int)/df['rebase_den20']).clip(upper= 1)
# df.loc[df['V']==1,'rebased'] = (df['matching'].astype(int)/df['rebase_den50']).clip(upper= 1)

# 'matching' relative to each denominator, capped at 1; computed for every row regardless of V.
df['rebased20'] = (df['matching'].astype(int)/df['rebase_den20']).clip(upper= 1)
df['rebased50'] = (df['matching'].astype(int)/df['rebase_den50']).clip(upper= 1)

# display(df)

# Same comparison along three x axes: alpha, N_nodes and branch_bin.
plot_ly_by_compare3(df, x='alpha', x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+','CASTLE+','Injected 20%','Injected 50%'],
y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compalpha_3L', version='20kvs50k', xwidth = 700,  sec_blue='#531b7d',#'#ab56b3',
save=save_figs, y1_vis=True)

plot_ly_by_compare3(df, x='N_nodes', x_desc=r"$|V| = \text{ Number of Nodes in } \mathcal{G}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+','CASTLE+','Injected 20%','Injected 50%'],
y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2], name='compnodes_3L', version='20kvs50k', xwidth = 700,  sec_blue='#531b7d',#'#ab56b3',
save=save_figs, y1_vis=True)

plot_ly_by_compare3(df, x='branch_bin', x_desc=r"$b= |E|/|V| = \text{ Number of Edges/Number of Nodes}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+','CASTLE+','Injected 20%','Injected 50%'],
y1_range=[-0.9,1], y2_range= [0, 3], y2_ticks=[0,.3,.6,.9,1.2],  name='compedges_3L', version='20kvs50k', xwidth = 700, sec_blue='#531b7d', sec_gray='#262626',
save=save_figs, y1_vis=True)

Count= 75
[0 1]


Unnamed: 0,Type,V,alpha,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,rebase_den50count,rebase_den50mean,rebase_den50std,rebased20count,rebased20mean,rebased20std,rebased50count,rebased50mean,rebased50std,text
0,-1.0,0,50,90,26.67,17.09,90,62.59,65.36,90,...,90,31.54,32.6,90,0.19,0.24,90,0.29,0.35,0.93 (0.65)
1,-1.0,0,100,90,26.67,17.09,90,62.59,65.36,90,...,90,31.54,32.6,90,0.39,0.32,90,0.54,0.39,0.65 (0.36)
2,-1.0,0,200,90,26.67,17.09,90,62.59,65.36,90,...,90,31.54,32.6,90,0.58,0.31,90,0.75,0.32,0.59 (0.31)
3,-1.0,0,300,90,26.67,17.09,90,62.59,65.36,90,...,90,31.54,32.6,90,0.61,0.32,90,0.76,0.31,0.59 (0.31)
4,-1.0,0,500,90,26.67,17.09,90,62.59,65.36,90,...,90,31.54,32.6,90,0.64,0.28,90,0.81,0.26,0.58 (0.31)
5,-1.0,1,50,41,26.1,17.01,41,31.68,24.78,41,...,41,16.05,12.39,41,0.24,0.28,41,0.36,0.4,1.26 (0.59)
6,-1.0,1,100,41,26.1,17.01,41,31.68,24.78,41,...,41,16.05,12.39,41,0.57,0.31,41,0.75,0.35,0.91 (0.19)
7,-1.0,1,200,41,26.1,17.01,41,31.68,24.78,41,...,41,16.05,12.39,41,0.82,0.25,41,0.92,0.21,0.82 (0.13)
8,-1.0,1,300,41,26.1,17.01,41,31.68,24.78,41,...,41,16.05,12.39,41,0.87,0.17,41,0.97,0.07,0.81 (0.13)
9,-1.0,1,500,41,26.1,17.01,41,31.68,24.78,41,...,41,16.05,12.39,41,0.87,0.16,41,0.98,0.06,0.8 (0.13)


[0 1]


Unnamed: 0,Type,V,N_nodes,N_edgescount,N_edgesmean,N_edgesstd,seedcount,seedmean,seedstd,Sizecount,...,rebase_den50count,rebase_den50mean,rebase_den50std,rebased20count,rebased20mean,rebased20std,rebased50count,rebased50mean,rebased50std,text
0,-1.0,0,10,150,18.5,8.24,150,2811.0,2477.95,150,...,150,9.7,4.13,150,0.62,0.26,150,0.82,0.28,0.56 (0.35)
1,-1.0,0,20,150,43.37,22.17,150,2811.0,2477.95,150,...,150,21.9,11.29,150,0.44,0.35,150,0.57,0.37,0.89 (0.44)
2,-1.0,0,50,150,125.9,76.4,150,2811.0,2477.95,150,...,150,63.03,38.17,150,0.38,0.35,150,0.5,0.41,0.55 (0.43)
3,-1.0,1,10,70,11.0,3.19,70,2087.14,2408.23,70,...,70,5.86,1.37,70,0.73,0.27,70,0.89,0.24,0.9 (0.15)
4,-1.0,1,20,70,24.86,7.74,70,2087.14,2408.23,70,...,70,12.57,4.1,70,0.66,0.37,70,0.75,0.36,1.1 (0.3)
5,-1.0,1,50,65,61.31,20.81,65,2170.77,2480.57,65,...,65,30.77,10.62,65,0.62,0.37,65,0.75,0.4,0.75 (0.44)
6,0.05,0,10,150,18.5,8.24,150,2811.0,2477.95,150,...,150,9.7,4.13,150,0.64,0.25,150,0.84,0.27,0.55 (0.35)
7,0.05,0,20,150,43.37,22.17,150,2811.0,2477.95,150,...,150,21.9,11.29,150,0.51,0.34,150,0.65,0.33,0.83 (0.36)
8,0.05,0,50,150,125.9,76.4,150,2811.0,2477.95,150,...,150,63.03,38.17,150,0.52,0.32,150,0.68,0.35,0.5 (0.28)
9,0.05,1,10,70,11.0,3.19,70,2087.14,2408.23,70,...,70,5.86,1.37,70,0.72,0.25,70,0.88,0.23,0.89 (0.16)


[0 1]


Unnamed: 0,Type,V,branch_bin,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,rebase_den50count,rebase_den50mean,rebase_den50std,rebased20count,rebased20mean,rebased20std,rebased50count,rebased50mean,rebased50std,text
0,-1.0,0,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.74,0.33,150,0.84,0.34,0.93 (0.37)
1,-1.0,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,25.63,17.12,150,0.39,0.25,150,0.6,0.36,0.59 (0.37)
2,-1.0,0,"(2, 5]",150,26.67,17.05,150,110.67,86.56,150,...,150,55.67,43.05,150,0.31,0.26,150,0.46,0.35,0.48 (0.42)
3,-1.0,1,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.74,0.33,150,0.84,0.34,0.93 (0.37)
4,-1.0,1,"(1, 2]",55,24.55,16.31,55,46.27,33.82,55,...,55,23.45,17.07,55,0.49,0.29,55,0.69,0.35,0.91 (0.23)
5,-1.0,1,"(2, 5]",0,,,0,,,0,...,0,,,0,,,0,,,nan (nan)
6,0.05,0,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.83,0.24,150,0.93,0.2,0.87 (0.22)
7,0.05,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,25.63,17.12,150,0.49,0.22,150,0.73,0.3,0.56 (0.34)
8,0.05,0,"(2, 5]",150,26.67,17.05,150,110.67,86.56,150,...,150,55.67,43.05,150,0.35,0.25,150,0.52,0.33,0.45 (0.36)
9,0.05,1,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.82,0.26,150,0.91,0.23,0.87 (0.21)


## 1L Partial Injection - 20% vs 50% known edges

In [12]:
# Set to True to regenerate the runs tagged r_ex3 / r_ex_2b / r_ex_5b (--known_p=0.2, --noise_p=0.0).
# Same flag name and commands as the rerun cell under "3L vs 1L - Nodes comparison".
rerun_experiment_20=False
if rerun_experiment_20:
    ## Run main script
    %run -i main_biggerDAG_extended.py --version="r_ex3"  --branchf=1 --known_p=0.2 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_2b"  --branchf=2 --known_p=0.2 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_5b"  --branchf=5 --known_p=0.2 --noise_p=0.0

In [13]:
# Set to True to regenerate the --known_p=0.5 runs, one per branching factor (--branchf 1, 2, 5).
# NOTE(review): the second run is tagged "r_ex_1b_50" although it uses --branchf=2. The analysis
# cell below loads that exact tag, so renaming it means renaming the stored results and that regex too.
rerun_experiment_50=False
if rerun_experiment_50:
    ## Run main script
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k50"  --branchf=1 --known_p=0.5 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_50"  --branchf=2 --known_p=0.5 --noise_p=0.0
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_50"  --branchf=5 --known_p=0.5 --noise_p=0.0

In [14]:
## Versions
# Regex alternation over the six version tags to load: the three --known_p=0.2 tags and the three --known_p=0.5 tags.
version = '(r_ex3|r_ex_2b|r_ex_5b|r_ex_1b_k50|r_ex_1b_50|r_ex_5b_50)'

# known_p=0.2 only affects the 'rebase_den'/'rebased' columns; the per-fraction
# columns computed below are the ones plot_ly_by_compare3 reads.
df = collect_results(version=version, known_p = 0.2)

# Keep the original tag in V2; V becomes 1 for tags containing "50" and 0 otherwise.
df['V2'] = df['V']
df['V'] = [int("50" in i) for i in df['V2'] ]

# Edges left after removing the (truncated) 20% / 50% share of N_edges.
df['rebase_den20'] = (df['N_edges'].astype(int)-(df['N_edges']*0.2).astype(int))
df['rebase_den50'] = (df['N_edges'].astype(int)-(df['N_edges']*0.5).astype(int))

# df.loc[df['V']==0,'rebased'] = (df['matching'].astype(int)/df['rebase_den20']).clip(upper= 1)
# df.loc[df['V']==1,'rebased'] = (df['matching'].astype(int)/df['rebase_den50']).clip(upper= 1)

# 'matching' relative to each denominator, capped at 1; computed for every row regardless of V.
df['rebased20'] = (df['matching'].astype(int)/df['rebase_den20']).clip(upper= 1)
df['rebased50'] = (df['matching'].astype(int)/df['rebase_den50']).clip(upper= 1)

# display(df)

# Repeats the import from the first cell; plot_ly_by_compare is not called in this cell.
from utils import plot_ly_by_compare

# plot_ly_by_compare3(df, x='N_nodes', x_desc=r"$|V| = \text{ Number of Nodes in } \mathcal{G}$", 
# legend_cord=[0.5,1.15], names_list = ['CASTLE+','CASTLE+','Injected 20%','Injected 50%'],
# y1_range=[-0.9,1], y2_range= [0, 2], name='compnodes_v2', version='20kvs50k', xwidth = 700,  sec_blue='#ab56b3',
# save=save_figs, y1_vis=True)

# y2_ticks is not passed here, so plot_ly_by_compare3's default tick values apply.
plot_ly_by_compare3(df, x='branch_bin', x_desc=r"$b= |E|/|V| = \text{ Number of Edges/Number of Nodes}$", 
legend_cord=[0.5,1.15], names_list = ['CASTLE+','CASTLE+','Injected 20%','Injected 50%'],
y1_range=[-0.9,1], y2_range= [0, 2], name='compedges_v2', version='20kvs50k', xwidth = 700, sec_blue='#531b7d', sec_gray='#262626',
save=save_figs, y1_vis=True)

Count= 93
[0 1]


Unnamed: 0,Type,V,branch_bin,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,rebase_den50count,rebase_den50mean,rebase_den50std,rebased20count,rebased20mean,rebased20std,rebased50count,rebased50mean,rebased50std,text
0,-1.0,0,"(0, 1]",150,26.67,17.05,150,26.2,16.89,150,...,150,13.27,8.44,150,0.52,0.31,150,0.7,0.36,0.6 (0.39)
1,-1.0,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,25.63,17.12,150,0.42,0.25,150,0.64,0.34,0.56 (0.33)
2,-1.0,0,"(2, 5]",150,26.67,17.05,150,112.87,83.68,150,...,150,56.7,41.82,150,0.34,0.22,150,0.52,0.3,0.43 (0.33)
3,-1.0,1,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.73,0.31,150,0.84,0.29,0.87 (0.21)
4,-1.0,1,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,25.63,17.12,150,0.42,0.26,150,0.62,0.36,0.56 (0.34)
5,-1.0,1,"(2, 5]",150,26.67,17.05,150,112.87,83.68,150,...,150,56.7,41.82,150,0.34,0.23,150,0.51,0.31,0.43 (0.33)
6,0.05,0,"(0, 1]",150,26.67,17.05,150,26.2,16.89,150,...,150,13.27,8.44,150,0.61,0.28,150,0.8,0.28,0.59 (0.38)
7,0.05,0,"(1, 2]",150,26.67,17.05,150,50.77,34.27,150,...,150,25.63,17.12,150,0.51,0.23,150,0.75,0.3,0.55 (0.33)
8,0.05,0,"(2, 5]",150,26.67,17.05,150,112.87,83.68,150,...,150,56.7,41.82,150,0.37,0.21,150,0.57,0.27,0.42 (0.32)
9,0.05,1,"(0, 1]",150,26.67,17.05,150,26.33,17.38,150,...,150,13.33,8.53,150,0.84,0.23,150,0.94,0.19,0.86 (0.19)


## 1L noise Experiment - add 20% noise variables

In [15]:
# Guard flag: leave False to reuse results already on disk; set True to
# regenerate the three noise runs (20% known edges, noise_p=0.2, one hidden
# layer), one per branching factor.
rerun_noise=False
if rerun_noise:
    # `%run -i` executes the script inside this notebook's namespace.
    # NOTE(review): the tags say "h3" while the flag passed is
    # --hidden_n_p=3.2 - confirm the tag is just a rounded label.
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k20_l1_h3_n20" --branchf=1 --known_p=0.2 --noise_p=0.2 --hidden_l=1 --hidden_n_p=3.2
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k20_l1_h3_n20" --branchf=2 --known_p=0.2 --noise_p=0.2 --hidden_l=1 --hidden_n_p=3.2
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k20_l1_h3_n20" --branchf=5 --known_p=0.2 --noise_p=0.2 --hidden_l=1 --hidden_n_p=3.2

In [16]:
## Versions: the three noise-free 1L tags followed by the three "n20" noise tags
version = '(r_ex_2b|r_ex3|r_ex_5b|r_ex_1b_k20_l1_h3_n20|r_ex_2b_k20_l1_h3_n20|r_ex_5b_k20_l1_h3_n20)'

df = collect_results(version=version, known_p = 0.2)

# Keep the raw version tag in V2; V becomes a 0/1 flag that is 1 whenever the
# tag contains "n20" (i.e. the runs with noise variables).
df['V2'] = df['V']
df['V'] = [1 if "n20" in tag else 0 for tag in df['V2']]

# display(df)

from utils import plot_ly_by_compare

plot_ly_by_compare(
    df,
    x='alpha',
    x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$",
    legend_cord=[0.5,1.15],
    names_list=['CASTLE+ w/noise','CASTLE+','Injected','Injected w/noise'],
    y1_range=[-0.9,1],
    y2_range=[0, 1.7],
    y2_ticks=[0,.2,.4,.6],
    name='compalpha_v2',
    version='1lvsNoise',
    save=save_figs,
    xwidth=1100,
)


Count= 93
[0 1]


Unnamed: 0,Type,V,alpha,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,branchcount,branchmean,branchstd,rebase_dencount,rebase_denmean,rebase_denstd,rebasedcount,rebasedmean,rebasedstd,text
0,-1.0,0,50,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.17,0.22,0.62 (0.43)
1,-1.0,0,100,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.34,0.23,0.53 (0.36)
2,-1.0,0,200,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.54,0.22,0.5 (0.33)
3,-1.0,0,300,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.54,0.23,0.5 (0.33)
4,-1.0,0,500,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.55,0.23,0.49 (0.33)
5,-1.0,1,50,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.12,0.19,0.65 (0.46)
6,-1.0,1,100,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.23,0.25,0.55 (0.39)
7,-1.0,1,200,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.53,0.24,0.5 (0.33)
8,-1.0,1,300,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.55,0.24,0.5 (0.33)
9,-1.0,1,500,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.54,0.24,0.5 (0.33)


## 3-way comparison - 1L vs 1L with noise vs multi-layer (3L) - 20% known

In [17]:
def _compare6_axis_shifts(xwidth):
    """Return the (left, right) pixel x-shifts used for the rotated axis captions.

    Widths 600 and 1100 use fixed offsets; 1200 and every other width scale
    with the figure width.
    """
    if xwidth == 600:
        return -330, 270
    if xwidth == 1100:
        return -580, 480
    if xwidth == 1200:
        return -xwidth*0.525, xwidth*0.436
    return -xwidth*0.53, xwidth*0.42


def _compare6_add_axis_caption(fig, text, xshift, yshift):
    """Add `text` to `fig` as a paper-anchored annotation rotated by -90 degrees."""
    fig.add_annotation(
        # No explicit y position: only the pixel shifts place the caption.
        yshift=yshift,
        xshift=xshift,
        align="left",
        valign="top",
        text=text,
        showarrow=False,
        xref="paper",
        yref="paper",
        xanchor="left",
        yanchor="top",
        textangle=-90,
    )


def plot_ly_by_compare6(df, 
        x, x_desc, 
        y1='right', y1_desc="Reconstruction Accuracy", 
        margin_list = [10, 10, 0, 10, 0],
        y1_range=[-0.8, 1], y1_ticks=[0,0.2,0.4,0.6,0.8,1], y1_vis=True,
        y2='MSEmean', y2_desc="Mean Squared Error",
        y2_range= [0, 1.8], y2_ticks=[0,0.2,0.4,0.6,0.8], y2_vis=True,
        showleg = True,
        legend_cord = [0.8,1.05],
        save=False, name='',version='', xwidth = 1100,
        main_gray = '#262626',
        sec_gray = '#595959',
        main_blue = '#005383',
        sec_blue = '#0085CA',
        main_green = '#379f9f', 
        sec_green = '#196363', 
        names_list = ['CASTLE+ M=3','CASTLE+ M=1 w/noise','CASTLE+  M=1 ','Injected  M=1','Injected  M=1 w/noise','Injected  M=3']
        ,display_tab = False
       ):
    """Plot six series (2 values of `Type` x 3 values of `V`) against `x`.

    Box plots of `y1` go on the primary y-axis and bars of the aggregated
    column `y2` on the secondary y-axis. Rows with ``Type == -1`` are drawn as
    the CASTLE+ series and rows with ``Type == 0.05`` as the Injected series.
    The three variants are the first three sorted unique values of ``df['V']``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns 'Type', 'V', `x`, `y1` and 'MSE'; 'rebased' is also
        needed when ``y1 == 'right'``.
    x, x_desc : str
        Grouping column for the x-axis and its axis title.
    y1, y1_desc : str
        Raw column shown as box plots and the caption of the primary axis.
    y2, y2_desc : str
        Aggregated column (``<col><stat>``, e.g. 'MSEmean') shown as bars and
        the caption of the secondary axis.
    margin_list : list
        Figure margins in the order [left, right, bottom, top, pad].
    y1_range, y1_ticks, y1_vis / y2_range, y2_ticks, y2_vis
        Range, tick values and visibility (tick labels + caption) per axis.
    showleg, legend_cord
        Legend visibility and its [x, y] position.
    save, name, version
        When `save` is true the figure is written to
        ``figures/plot_{name}_{version}.png``.
    xwidth : int
        Figure width in pixels (height is fixed at 350).
    main_gray, sec_gray, main_blue, sec_blue, main_green, sec_green : str
        Colors of the six series.
    names_list : list of str
        Legend names, in the order of the six series (see `slots` below).
    display_tab : bool
        When true, `display` the aggregated table before plotting.

    Raises
    ------
    ValueError
        If fewer than three distinct `V` values are available.
    """
    colors_list = [sec_green,sec_gray,main_gray,main_blue,sec_blue,main_green]
    group_keys = ['Type', 'V', x]

    df = df.sort_values(by=[x],axis=0)

    # Aggregate only the numeric, non-key columns: pandas >= 2.0 no longer drops
    # string/categorical columns (e.g. 'V2', 'branch_bin') silently and would
    # raise a TypeError when asked for their mean/std.
    value_cols = [c for c in df.select_dtypes(include='number').columns
                  if c not in group_keys]
    mses = df.groupby(group_keys)[value_cols].agg(['count','mean','std']).round(2).reset_index()
    # Flatten ('MSE', 'mean') -> 'MSEmean'; the key columns keep their plain name.
    mses.columns = list(map(''.join, mses.columns.values))
    # Bar label "mean (std)".
    mses['text'] = ['{} ({})'.format(mean, std)
                    for mean, std in zip(mses['MSEmean'], mses['MSEstd'])]
    mses[x] = mses[x].astype(str)

    raw_variants = np.unique(df['V'])
    agg_variants = np.unique(mses['V'])
    print(raw_variants)
    if len(raw_variants) < 3 or len(agg_variants) < 3:
        raise ValueError(
            "plot_ly_by_compare6 needs at least 3 distinct values in df['V'], "
            f"found {min(len(raw_variants), len(agg_variants))}"
        )

    if display_tab:
        display(mses)

    # One entry per series: (Type value, index into the sorted V labels,
    # bar trace name, bar text position). The position in this list is also the
    # index into colors_list and names_list.
    slots = [
        (-1,   2, "CASTLEMSE",   'auto'),
        (-1,   1, "CASTLEMSE",   'auto'),
        (-1,   0, "CASTLEMSE",   'auto'),
        (0.05, 0, "InjectedMSE", 'inside'),
        (0.05, 1, "InjectedMSE", 'inside'),
        (0.05, 2, "InjectedMSE", 'inside'),
    ]

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    ########## BarCharts (secondary axis)
    for slot, (type_value, v_idx, bar_name, text_pos) in enumerate(slots):
        rows = mses[(mses['Type']==type_value) & (mses['V']==agg_variants[v_idx])]
        fig.add_trace(
            go.Bar(x=rows[x], y=rows[y2], name=bar_name,
                   marker_color=colors_list[slot],
                   marker_line_color=colors_list[slot],
                   marker_line_width=2, opacity=0.6,
                   showlegend=False,
                   text=rows['text'],
                   textposition=text_pos),
            secondary_y=True,
        )

    ########## Boxplots (primary axis)
    def add_box(slot):
        # Box plot of the raw y1 values for one (Type, V) series.
        type_value, v_idx = slots[slot][:2]
        rows = df[(df['Type']==type_value) & (df['V']==raw_variants[v_idx])]
        fig.add_trace(go.Box(
            y=rows[y1],
            x=rows[x].astype(str),
            boxmean='sd', # represent mean
            boxpoints=False,
            name=names_list[slot],
            marker_color=colors_list[slot]
        ), secondary_y=False,)

    for slot in range(3):
        add_box(slot)

    if y1=='right':
        # Horizontal tick markers at the mean 'rebased' score of each CASTLE+
        # variant; the marker grows with the variant index (30, 60, 90) and
        # reuses the color of that variant's box.
        for v_idx in range(3):
            rows = mses[(mses['Type']==-1) & (mses['V']==agg_variants[v_idx])]
            marker_color = colors_list[2 - v_idx]
            fig.add_trace(
                go.Scatter(x=rows[x], y=rows['rebasedmean'], name="CASTLE+20%",
                           mode='markers', marker_symbol='line-ew',
                           marker=dict(
                               color=marker_color,
                               size=(v_idx + 1)*30,
                               line=dict(color=marker_color, width=2)),
                           showlegend=False),
                secondary_y=False,
            )

    for slot in range(3, 6):
        add_box(slot)

    fig.update_layout(
        boxmode='group' # group together boxes of the different traces for each value of x
        ,bargap=0.1
        ,bargroupgap=0.1
    )

    # Axis captions are drawn as rotated annotations instead of axis titles.
    x1shift, x2shift = _compare6_axis_shifts(xwidth)
    if y1_vis:
        _compare6_add_axis_caption(fig, y1_desc, xshift=x1shift, yshift=150)
    if y2_vis:
        _compare6_add_axis_caption(fig, y2_desc, xshift=x2shift, yshift=10)

    # Set x-axis title
    fig.update_xaxes(showgrid=True, title={'text':x_desc})
    # Primary axis: shown as percentages.
    fig.update_yaxes(showgrid=True,nticks=10,zeroline=True, title={'text':""},
                     range=y1_range,
                     tickvals=y1_ticks,
                     tickformat=".0%",
                     secondary_y=False,
                     showticklabels=y1_vis)
    # Secondary axis.
    fig.update_yaxes(showgrid=True,nticks=10,zeroline=True, title={'text':""},
                     range=y2_range,
                     tickvals=y2_ticks,
                     secondary_y=True,
                     showticklabels=y2_vis)

    fig.update_layout(
        showlegend=showleg,
        title='',
        legend={
            'y':legend_cord[1],
            'x':legend_cord[0],
            'orientation':"h",
            'xanchor': 'center',
            'yanchor': 'top'},
        template='plotly_white',
        autosize=True,
        width=xwidth, height=350, 
        margin=dict(
            l=margin_list[0],
            r=margin_list[1],
            b=margin_list[2],
            t=margin_list[3],
            pad=margin_list[4],
        ),
        font=dict(
            family='Serif',
            size=18,
            color="Black"
        )    
    )

    if save:
        output_folder = "figures"
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(output_folder, exist_ok=True)

        out_path = os.path.join(output_folder,f"plot_{name}_{version}.png")

        # Imported here so a missing static-image backend fails before writing.
        import kaleido
        fig.write_image(out_path)

    fig.show()


In [18]:
debug = True

## Versions: every tag listed in `groups` below
version = '(r_ex_2b|r_ex3|r_ex_5b|r_ex_1b_k20_l1_h3_n20|r_ex_2b_k20_l1_h3_n20|r_ex_5b_k20_l1_h3_n20|r_ex_1b_k20_l3_h2_n0|r_ex_2b_k20_l3_h2_n0|r_ex_5b_k20_l3_h2_n0)'

# (label, version tags) pairs: each row's V is replaced by the label of the
# group that contains its version tag.
groups = [
    ('20K', ["r_ex_2b","r_ex3","r_ex_5b"]),
    ('20N', ["r_ex_1b_k20_l1_h3_n20","r_ex_2b_k20_l1_h3_n20","r_ex_5b_k20_l1_h3_n20"]),
    ('3L', ["r_ex_1b_k20_l3_h2_n0","r_ex_2b_k20_l3_h2_n0","r_ex_5b_k20_l3_h2_n0"]),
]

df = collect_results(version=version, known_p = 0.2)

# Keep the raw version tag in V2 before V is overwritten with the group label.
df['V2'] = df['V']

for v, group_l in groups:
    if debug:
        print(group_l)
        print(v)

    df.loc[df['V2'].isin(group_l), 'V'] = v

display(df)

plot_ly_by_compare6(
    df,
    x='alpha',
    x_desc=r"$s = N/|V| \text{ (Dataset Size / Number of Nodes)}$",
    legend_cord=[0.48,1.15],
    display_tab=True,
    y1_range=[-0.9,1],
    y2_range=[0,2.2],
    y2_ticks=[0,.2,.4,.6,0.8],
    name='compalpha_v3',
    version='1lvsNoisevs3l',
    save=save_figs,
    xwidth=1200,
)


Count= 138
['r_ex_2b', 'r_ex3', 'r_ex_5b']
20K
['r_ex_1b_k20_l1_h3_n20', 'r_ex_2b_k20_l1_h3_n20', 'r_ex_5b_k20_l1_h3_n20']
20N
['r_ex_1b_k20_l3_h2_n0', 'r_ex_2b_k20_l3_h2_n0', 'r_ex_5b_k20_l3_h2_n0']
3L


Unnamed: 0,V,Type,N_nodes,N_edges,seed,Size,MSE,MAE,right,matching,alpha,branch,branch_bin,rebase_den,rebased,V2
8,20N,-1.00,20,37,0,1250,1.163986,0.863243,0.000000,0,50,2,"(1, 2]",30,0.000000,r_ex_2b_k20_l1_h3_n20
9,20N,0.05,20,37,0,1250,1.155092,0.852532,0.027027,1,50,2,"(1, 2]",30,0.033333,r_ex_2b_k20_l1_h3_n20
10,20N,-1.00,20,36,10,1250,0.293958,0.425193,0.000000,0,50,2,"(1, 2]",29,0.000000,r_ex_2b_k20_l1_h3_n20
11,20N,0.05,20,36,10,1250,0.240622,0.385128,0.027778,1,50,2,"(1, 2]",29,0.034483,r_ex_2b_k20_l1_h3_n20
12,20N,-1.00,20,38,100,1250,1.211399,0.874832,0.000000,0,50,2,"(1, 2]",31,0.000000,r_ex_2b_k20_l1_h3_n20
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2717,20K,0.05,50,48,5000,31250,0.970014,0.785866,0.750000,36,500,1,"(0, 1]",39,0.923077,r_ex3
2718,20K,-1.00,50,50,6000,31250,0.457001,0.520015,0.500000,25,500,1,"(0, 1]",40,0.625000,r_ex3
2719,20K,0.05,50,50,6000,31250,0.457897,0.521632,0.500000,25,500,1,"(0, 1]",40,0.625000,r_ex3
2720,20K,-1.00,50,50,7000,31250,0.347962,0.444627,0.440000,22,500,1,"(0, 1]",40,0.550000,r_ex3


['20K' '20N' '3L']


Unnamed: 0,Type,V,alpha,N_nodescount,N_nodesmean,N_nodesstd,N_edgescount,N_edgesmean,N_edgesstd,seedcount,...,branchcount,branchmean,branchstd,rebase_dencount,rebase_denmean,rebase_denstd,rebasedcount,rebasedmean,rebasedstd,text
0,-1.0,20K,50,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.17,0.22,0.62 (0.43)
1,-1.0,20K,100,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.34,0.23,0.53 (0.36)
2,-1.0,20K,200,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.54,0.22,0.5 (0.33)
3,-1.0,20K,300,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.54,0.23,0.5 (0.33)
4,-1.0,20K,500,90,26.67,17.09,90,63.28,64.64,90,...,90,2.3,1.28,90,51.03,51.69,90,0.55,0.23,0.49 (0.33)
5,-1.0,20N,50,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.12,0.19,0.65 (0.46)
6,-1.0,20N,100,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.23,0.25,0.55 (0.39)
7,-1.0,20N,200,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.53,0.24,0.5 (0.33)
8,-1.0,20N,300,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.55,0.24,0.5 (0.33)
9,-1.0,20N,500,90,26.67,17.09,90,62.92,64.99,90,...,90,2.32,1.32,90,50.76,51.92,90,0.54,0.24,0.5 (0.33)


## 3L and 50% known and 20% noise 

- Does knowing more edges protect you better against noise variables?

In [19]:
# Guard flag: leave False to reuse results already on disk; set True to
# regenerate the runs with 50% known edges, 20% noise and 3 hidden layers,
# one per branching factor.
rerun_known_3l=False
if rerun_known_3l:
    # `%run -i` executes the script inside this notebook's namespace.
    %run -i main_biggerDAG_extended.py --version="r_ex_1b_k50_l3_h2_n20" --branchf=1 --known_p=0.5 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_2b_k50_l3_h2_n20" --branchf=2 --known_p=0.5 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2
    %run -i main_biggerDAG_extended.py --version="r_ex_5b_k50_l3_h2_n20" --branchf=5 --known_p=0.5 --noise_p=0.2 --hidden_l=3 --hidden_n_p=2