In [1]:
import wandb, pdb
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('dark')

# !pip install latex


In [2]:
# Handle to the W&B public API, created with timeout=30 (presumably seconds — confirm in the wandb docs).
api = wandb.Api(timeout=30)
# Run collection fetched for the path "structurelearning/structure-learning";
# get_reqd_runs() iterates this module-level name to filter runs by config.
runs = api.runs("structurelearning/structure-learning")
# Forwarded as the second positional argument of run.scan_history() in get_plotting_data().
max_steps = 20000
# Grid dimensions derived from num_subplots: h is the integer square root and
# w the matching column count (3 x 3 for 9 subplots).
# NOTE(review): h and w are not referenced by the functions visible in this
# part of the notebook (plot_metrics_for_nodes computes its own grid) — confirm
# whether they are still needed.
num_subplots = 9
h = int(np.sqrt(num_subplots))
w = int(num_subplots / h)


In [3]:
def get_reqd_runs(exp_config, expected_num_runs=20):
    """Return the runs whose config matches every entry of ``exp_config``.

    Iterates the module-level ``runs`` collection and keeps a run only when,
    for every ``(key, value)`` pair in ``exp_config``, ``run.config`` contains
    ``key`` and the stored value equals ``value``. A run whose config lacks one
    of the keys is treated as a non-match instead of raising ``KeyError``.

    Args:
        exp_config: dict mapping config keys to the values a run must have.
        expected_num_runs: number of matching runs that must be found
            (defaults to 20, the previously hard-coded count). Pass ``None``
            to skip the check.

    Returns:
        list of the matching run objects, in iteration order of ``runs``.

    Raises:
        AssertionError: if ``expected_num_runs`` is not ``None`` and the number
            of matching runs differs from it.
    """
    def _matches(run):
        # `key in cfg` guards against runs that never logged this config key.
        cfg = run.config
        return all(k in cfg and cfg[k] == v for k, v in exp_config.items())

    reqd_runs = [run for run in runs if _matches(run)]

    if expected_num_runs is not None:
        assert len(reqd_runs) == expected_num_runs, (
            f"Expected {expected_num_runs} runs matching {exp_config}, "
            f"found {len(reqd_runs)}"
        )
    print(f"Fetched {len(reqd_runs)} runs")
    return reqd_runs

def get_plotting_data(reqd_runs, reqd_keys):
    """Collect, per requested key, one value series for every run.

    For each run, ``run.scan_history(reqd_keys, max_steps)`` is called once and
    its rows are materialised into a list, so every key reads the same rows.
    (Iterating the returned object once per key would leave all keys after the
    first empty if it were a one-shot iterator, and could re-read the history
    otherwise.)

    Args:
        reqd_runs: iterable of run objects exposing ``scan_history``.
        reqd_keys: list of history keys to extract from every row.

    Returns:
        dict mapping each key to a list of per-run series (lists of values).
        Runs that yield no rows contribute no series.
    """
    seed_data = {key: [] for key in reqd_keys}

    for run in reqd_runs:
        # NOTE(review): max_steps (module-level) is passed positionally exactly
        # as before; confirm against the wandb API which scan_history parameter
        # occupies the second position (page size vs. a step bound).
        history_rows = list(run.scan_history(reqd_keys, max_steps))

        for key in reqd_keys:
            series = [row[key] for row in history_rows]
            if series:  # skip runs without any logged rows
                seed_data[key].append(series)

    return seed_data

def plot_data(key, seed_data, ax, steps, label = None, color='blue', marker='x', linestyle='-', markersize=10):
    """Draw the across-series mean of ``seed_data[key]`` with a +/- 1 std band.

    ``seed_data[key]`` is converted to an array and reduced along axis 0; the
    mean is drawn as a line against ``steps`` and the band spans mean - std to
    mean + std. The axes grid is toggled via ``ax.grid()`` afterwards.

    Args:
        key: entry of ``seed_data`` to plot.
        seed_data: mapping from key to a sequence of equally long value series.
        ax: axes-like object providing ``plot``, ``fill_between`` and ``grid``.
        steps: x coordinates, one per point of the reduced series.
        label: legend label for the mean line; ``None`` becomes ``''``.
        color, marker, linestyle, markersize: forwarded to ``ax.plot``
            (``color`` is also used for the band).
    """
    legend_text = '' if label is None else label

    xs = np.array(steps)
    per_seed = np.array(seed_data[key])
    center = np.mean(per_seed, axis=0)
    spread = np.std(per_seed, axis=0)

    line_style = dict(
        label=legend_text,
        color=color,
        marker=marker,
        linestyle=linestyle,
        markersize=markersize,
    )
    ax.plot(xs, center, **line_style)
    ax.fill_between(xs, center - spread, center + spread, alpha=0.3, color=color)
    ax.grid()

In [4]:
def plot_metrics_for_nodes(reqd_keys, exp_config, num_nodes_list, exp_edges_list, 
                            title, x_steps, fname, exps_per_subplot, color_list, marker='x', 
                            linestyle='-', markersize=10, figsize=(10, 10), common_x_label='',
                            common_y_label=''):
    """Plot one metric over a (edge degree x node count) grid and save the figure.

    One subplot is created per pair of ``exp_edges_list`` entry (rows) and
    ``num_nodes_list`` entry (columns). In each subplot, every experiment in
    ``exps_per_subplot`` is fetched with ``get_reqd_runs`` / ``get_plotting_data``
    and drawn with ``plot_data`` using the matching entry of ``color_list``.
    Only ``reqd_keys[0]`` is plotted. A shared legend built from the last
    subplot is placed below the grid, the figure is written to ``fname`` and
    all open figures are closed.

    Args:
        reqd_keys: history keys to fetch; the first one is the plotted metric.
        exp_config: base config filter. It is copied, so the caller's dict is
            left untouched; ``num_nodes``, ``exp_edges`` and ``exp_name`` are
            overridden per subplot / experiment on the copy.
        num_nodes_list: node counts, one subplot column each.
        exp_edges_list: expected edge degrees, one subplot row each.
        title: figure suptitle.
        x_steps: x coordinates handed to ``plot_data``.
        fname: output path passed to ``Figure.savefig``.
        exps_per_subplot: experiment names drawn in every subplot.
        color_list: colours, indexed in parallel with ``exps_per_subplot``.
        marker, linestyle, markersize: line styling forwarded to ``plot_data``.
        figsize: figure size forwarded to ``plt.subplots``.
        common_x_label, common_y_label: figure-level axis labels.

    Raises:
        ValueError: if ``num_nodes_list``, ``exp_edges_list`` or
            ``exps_per_subplot`` is empty.
    """
    # Legend text for the known experiment names; any other name is shown
    # verbatim instead of leaving the label unbound or stale.
    legend_labels = {
        'Decoder BCD observational learn L (linear projection)': 'Observational data',
        'Decoder BCD single interventional learn L (linear projection)': 'Random single node interventions',
        'Decoder BCD multi interventional learn L (linear projection)': 'Random multi node interventions',
    }

    len_nodes = len(num_nodes_list)
    len_edge_degrees = len(exp_edges_list)
    num_exps = len(exps_per_subplot)
    if len_nodes == 0 or len_edge_degrees == 0 or num_exps == 0:
        raise ValueError(
            'num_nodes_list, exp_edges_list and exps_per_subplot must all be non-empty'
        )

    # Filter on a private copy so repeated calls do not leak state through
    # the caller's exp_config dict.
    run_filter = dict(exp_config)
    key = reqd_keys[0]

    # squeeze=False always yields a 2-D axes array, so a single row, a single
    # column and a single subplot are all indexed the same way.
    f, axes = plt.subplots(len_edge_degrees, len_nodes, figsize=figsize, squeeze=False)

    for i, edge_degree in enumerate(exp_edges_list):
        for j, nodes in enumerate(num_nodes_list):
            ax = axes[i, j]
            run_filter['num_nodes'] = nodes
            run_filter['exp_edges'] = edge_degree

            for exp_idx, exp in enumerate(exps_per_subplot):
                run_filter['exp_name'] = exp
                reqd_runs = get_reqd_runs(run_filter)
                plotting_data = get_plotting_data(reqd_runs, reqd_keys)

                label = legend_labels.get(exp, exp)
                print(label)
                print()

                plot_data(  key,
                            plotting_data,
                            ax,
                            x_steps,
                            label=label,
                            color=color_list[exp_idx],
                            marker=marker,
                            linestyle=linestyle,
                            markersize=markersize
                        )

    # Every subplot draws the same experiments, so the handles of the last
    # subplot are enough for one figure-level legend.
    handles, labels = ax.get_legend_handles_labels()
    f.legend(handles[:num_exps],
            labels[:num_exps],
            loc='lower center',
            ncol=num_exps,
            bbox_to_anchor = [0.5,-0.1])

    f.suptitle(title)
    f.supxlabel(common_x_label)
    f.supylabel(common_y_label)
    plt.tight_layout()
    f.savefig(fname, bbox_inches='tight', dpi=300)
    print(f'Saved figure: {fname}')
    plt.close('all')


### Decoder BCD runs (learn L) on observational, single-interventional and multi-interventional data, for d={6, 20, 50} and exp_edges={1.0, 2.0, 4.0}

In [5]:
# Base run filter; the plotting function overrides 'exp_name', 'num_nodes'
# and 'exp_edges' for every subplot / experiment it draws.
exp_config = dict(
    exp_name='Decoder BCD observational learn L (linear projection)',
    num_nodes=6,
    proj_dims=100,
    exp_edges=1.0,
)

# History keys to fetch. Only 'Evaluations/SHD' is enabled; other keys that
# were listed here but disabled:
#   'Evaluations/AUROC', 'Evaluations/SHD_C', 'Evaluations/MCC',
#   'Evaluations/AUPRC_W', 'Evaluations/AUPRC_G', 'L_MSE',
#   'true_obs_KL_term_Z', 'Z_MSE'
reqd_keys = ['Evaluations/SHD']

# x axis for the plots: 0, 200, ..., 5000.
steps = np.arange(0, 5200, step=200)

In [6]:
# Experiments drawn in every subplot, in legend order.
exp_names = [
    'Decoder BCD observational learn L (linear projection)',
    'Decoder BCD single interventional learn L (linear projection)',
    'Decoder BCD multi interventional learn L (linear projection)',
]

# One colour per entry of exp_names.
color_list = [
    '#fbbc05',  # yellow
    '#34a853',  # green
    '#ea4335',  # red/cinnabar
]

# SHD vs. iterations for d = 6, 20, 50 at exp_edges = 1.0.
plot_metrics_for_nodes(
    reqd_keys=['Evaluations/SHD'],
    exp_config=exp_config,
    num_nodes_list=[6, 20, 50],
    exp_edges_list=[1.0],
    title='',
    x_steps=steps,
    fname='exp_edge_1_lin_dbcd_learn_L_shd_vs_iterations.pdf',
    exps_per_subplot=exp_names,
    color_list=color_list,
    marker='X',
    linestyle=(0, (12, 12)),
    markersize=5,
    figsize=(12, 4),
    common_x_label='Iterations',
    common_y_label=r'$\mathbb{E}-SHD$',
)
                        


Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Saved figure: exp_edge_1_lin_dbcd_learn_L_shd_vs_iterations.pdf


In [7]:
# SHD vs. iterations for d = 6, 20, 50 at exp_edges = 2.0.
plot_metrics_for_nodes(
    reqd_keys=['Evaluations/SHD'],
    exp_config=exp_config,
    num_nodes_list=[6, 20, 50],
    exp_edges_list=[2.0],
    title='',
    x_steps=steps,
    fname='exp_edge_2_lin_dbcd_learn_L_shd_vs_iterations.pdf',
    exps_per_subplot=exp_names,
    color_list=color_list,
    marker='X',
    linestyle=(0, (12, 12)),
    markersize=5,
    figsize=(12, 4),
    common_x_label='Iterations',
    common_y_label=r'$\mathbb{E}-SHD$',
)

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Saved figure: exp_edge_2_lin_dbcd_learn_L_shd_vs_iterations.pdf


In [8]:
# SHD vs. iterations for d = 6, 20, 50 at exp_edges = 4.0.
plot_metrics_for_nodes(
    reqd_keys=['Evaluations/SHD'],
    exp_config=exp_config,
    num_nodes_list=[6, 20, 50],
    exp_edges_list=[4.0],
    title='',
    x_steps=steps,
    fname='exp_edge_4_lin_dbcd_learn_L_shd_vs_iterations.pdf',
    exps_per_subplot=exp_names,
    color_list=color_list,
    marker='X',
    linestyle=(0, (12, 12)),
    markersize=5,
    figsize=(12, 4),
    common_x_label='Iterations',
    common_y_label=r'$\mathbb{E}-SHD$',
)

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Fetched 20 runs
Observational data

Fetched 20 runs
Random single node interventions

Fetched 20 runs
Random multi node interventions

Saved figure: exp_edge_4_lin_dbcd_learn_L_shd_vs_iterations.pdf
