In [None]:
import os
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb

import helper

sns.set_style('dark')

In [None]:
# Metric keys pulled from every run. get_plot_dataframe stores keys that
# contain 'SHD' / 'AUROC' under the short column names 'SHD' / 'AUROC'.
reqd_keys = [
    'Evaluations/SHD',
    'Evaluations/AUROC',
    'L_MSE',
]
max_steps = 20000  # not referenced by the cells below

# W&B public-API handle (timeout=40) and every run logged to the BIOLS project.
api = wandb.Api(timeout=40)
runs = api.runs("structurelearning/BIOLS")

In [None]:
def _parse_data_folder(data_folder):
    """Parse the experiment settings encoded in a BIOLS data-folder name.

    The name is split on '-' and fields are read by position, matching names like
    ``er1-<datagen>-linearproj-d020-D0100-multi-n_pairs2000-sets20-gaussianinterv``.

    Returns:
        dict with keys 'exp_edges', 'proj', 'd', 'D', 'num_intervs', 'interv_sets'.
        A projection field of '3_layer_mlpproj' is reported as 'nonlinear'.
    """
    splits = data_folder.split('-')
    # Keep every digit of the first field so that e.g. 'er10' parses as 10, not 0.
    exp_edges = int(''.join(ch for ch in splits[0] if ch.isdigit()))
    proj = splits[2][:-len('proj')]
    if proj == '3_layer_mlp':
        proj = 'nonlinear'
    return {
        'exp_edges': exp_edges,
        'proj': proj,
        'd': int(splits[3][1:]),
        'D': int(splits[4][1:]),
        'num_intervs': int(splits[6][len('n_pairs'):]),
        'interv_sets': int(splits[7][len('sets'):]),
    }


def get_plot_dataframe(data_folders, runs, reqd_keys, num_seeds=5):
    """Build a long-form DataFrame of final metric values, one row per seed per folder.

    For each data folder, the matching runs are selected with
    ``helper.get_reqd_runs`` and their metric histories fetched with
    ``helper.get_plotting_data``; the last column of each metric array is taken
    as the final value.

    Args:
        data_folders: BIOLS data-folder names (parsed by ``_parse_data_folder``).
        runs: W&B runs collection, forwarded to ``helper.get_reqd_runs``.
        reqd_keys: metric keys to extract. Keys containing 'SHD' / 'AUROC' are
            stored under the column names 'SHD' / 'AUROC'; others keep their key.
        num_seeds: forwarded to ``helper.get_reqd_runs`` (default 5).

    Returns:
        (plot_df, name): the DataFrame (metric columns, then 'Graph density',
        'Interventional Sets', 'Model', 'biols_data_folder'), and a filename stem
        ``er{edges}_d{d}_D{D}_proj{proj}`` built from the LAST folder's settings.
        Presumably every folder in one ablation shares these settings.

    Raises:
        ValueError: if ``data_folders`` or ``reqd_keys`` is empty, or if the
            metrics of one folder do not all report the same number of seeds.
    """
    if not data_folders:
        raise ValueError('data_folders must contain at least one data folder')
    if not reqd_keys:
        raise ValueError('reqd_keys must contain at least one metric key')

    plot_data_dict = defaultdict(list)
    for data_folder in data_folders:
        exp_config = {'biols_data_folder': data_folder}
        exp_run = helper.get_reqd_runs(exp_config, runs, num_seeds=num_seeds)
        plotting_data = helper.get_plotting_data(exp_run, reqd_keys)
        settings = _parse_data_folder(data_folder)

        # Final-step value of each metric. plotting_data[key] is indexed as a
        # 2-D array whose rows are assumed to be seeds -- TODO confirm in helper.
        seed_counts = set()
        for key in reqd_keys:
            final_values = plotting_data[key][:, -1]
            seed_counts.add(len(final_values))
            if 'SHD' in key:
                column = 'SHD'
            elif 'AUROC' in key:
                column = 'AUROC'
            else:
                column = key
            plot_data_dict[column] += final_values.tolist()

        # Every column must end up the same length, or pd.DataFrame fails later
        # with a far less helpful message.
        if len(seed_counts) != 1:
            raise ValueError(
                f'metrics for {data_folder!r} report differing seed counts: {sorted(seed_counts)}'
            )
        n_seeds = seed_counts.pop()

        density_label = r"$ER-{}, d={}, D={}\ $".format(
            settings['exp_edges'], settings['d'], settings['D'])
        plot_data_dict['Graph density'] += [density_label] * n_seeds
        plot_data_dict['Interventional Sets'] += [settings['interv_sets']] * n_seeds
        plot_data_dict['Model'] += ['BIOLS'] * n_seeds
        plot_data_dict['biols_data_folder'] += [data_folder] * n_seeds

    plot_df = pd.DataFrame(plot_data_dict)
    # str.format ignores the unused keys in `settings`.
    name = 'er{exp_edges}_d{d}_D{D}_proj{proj}'.format(**settings)
    return plot_df, name

In [None]:
def fetch_and_plot_num_intervs_ablation(num_nodes, num_intervs, basepath, runs, reqd_keys, box_widths=0.8, fontsize=18):
    """Plot the number-of-interventions ablation for ER-1 linear-projection data.

    One data-folder name is built per entry of ``num_intervs``. The node count is
    zero-padded to 3 digits and the ``sets`` field is ``int(num_interv / 100)``.
    The final metrics of those folders are collected with ``get_plot_dataframe``
    and handed, with ``basepath``, ``fontsize`` and ``box_widths``, to
    ``helper.plot_num_interventions_ablation``. Returns None.
    """
    zfilled_nodes = str(num_nodes).zfill(3)
    data_folders = [
        f'er1-ws_datagen_fix_noise_interv_noise-linearproj-d{zfilled_nodes}-D0100-multi-n_pairs{num_interv}-sets{int(num_interv/100)}-gaussianinterv'
        for num_interv in num_intervs
    ]

    ablation_df, file_stem = get_plot_dataframe(data_folders, runs, reqd_keys)
    helper.plot_num_interventions_ablation(
        ablation_df,
        basepath,
        file_stem,
        reqd_keys,
        fontsize,
        num_intervs,
        box_widths=box_widths,
        title=None,
    )

In [None]:
# Directory passed to the plotting helper as `basepath`. The default is the
# original author's cluster scratch path; set the BIOLS_ABLATION_DIR environment
# variable to point it elsewhere without editing the notebook.
ABLATION_BASEPATH = os.environ.get(
    'BIOLS_ABLATION_DIR',
    '/home/mila/j/jithendaraa.subramanian/scratch/biols_datasets/num_intervention_ablations',
)

fetch_and_plot_num_intervs_ablation(
    num_nodes=20,
    num_intervs=[2000, 4000, 6000, 8000, 10000],
    basepath=ABLATION_BASEPATH,
    runs=runs,
    reqd_keys=reqd_keys,
    box_widths=0.6,
)