### Imports

In [None]:
import os

import numpy as np
import pandas as pd
from pathlib import Path

import torch
import pickle
import string
import itertools

import matplotlib.pyplot as plt

import sys
sys.path.append("..")

from cdt.metrics import SHD

from simulation.simulation_tools import get_random_sim_XY
from utils import custom_binary_metrics, _from_full_to_cp

# No seed is passed, so the generator's stream is not reproducible across runs.
rng = np.random.default_rng()

# Column labels: the 26 single letters "A".."Z", followed by every ordered
# pair of two distinct letters ("AB", "AC", ..., "ZY").
COL_NAMES = list(string.ascii_uppercase)
COL_NAMES += [first + second for first, second in itertools.permutations(string.ascii_uppercase, 2)]

### Experiment 
Observe the behavior of TCS when the configuration is chosen at random (baseline method).

In [None]:
# data_path: <grandparent of the working directory>/data/cp_style/increasing_edges_cp_1
ie_custom_path = Path(".").resolve().parents[1] / "data" / "cp_style" / "increasing_edges_cp_1"


# placeholders (dicts are keyed by the time-series file name)
errors = []
scm_list = {}
shd_list = {}
scores_list = {}
det_auc_list = {}
struct_auc_list = {}

# placeholders with the "_w" suffix; they are not filled by this cell
errors_w = []
scm_list_w = {}
shd_list_w = {}
scores_list_w = {}
det_auc_list_w = {}
struct_auc_list_w = {}


# Collect the input files.
# - os.listdir returns entries in arbitrary order, so sort for a reproducible run order.
# - Hidden entries and sub-directories (e.g. ".ipynb_checkpoints") are skipped, since
#   they cannot be parsed as a time-series file.
data_dir = ie_custom_path / "data"
file_names = sorted(
    name
    for name in os.listdir(data_dir)
    if not name.startswith(".") and (data_dir / name).is_file()
)

# run loop
for fn in file_names:
    print(f"\n\n------------------------------------- {fn} -------------------------------------")

    # read the time-series and relabel its columns positionally with COL_NAMES
    X_data = pd.read_csv(data_dir / fn)
    X_data = X_data.rename(columns=dict(zip(X_data.columns, COL_NAMES[:X_data.shape[1]])))

    # read the ground truth: "<prefix>_ts..." pairs with "<prefix>_struct.pt"
    gn = fn.split("_ts")[0] + "_struct.pt"
    Y_data = torch.load(ie_custom_path / "structure" / gn)

    # simulation via get_random_sim_XY (the random-configuration baseline)
    res = get_random_sim_XY(
        true_data=X_data,
        CONFIGS=None,
        done_eval=False,
        optimal_det_config=None,
        optimal_det_func=None,
        verbose=True,
    )

    # compare: the returned structure is either a DataFrame that has to be
    # converted, or an object that already carries the tensor representation
    optimal_scm = res["optimal_scm"]
    if isinstance(optimal_scm, pd.DataFrame):
        pred_cp = _from_full_to_cp(optimal_scm)
    else:
        pred_cp = optimal_scm.causal_structure.causal_structure_cp
    true_cp = Y_data

    # Zero-pad whichever tensor is shorter along axis 2 so both have the same size there.
    # The first pair of the pad tuple applies to the last dimension
    # (assumes both tensors are 3-D, so axis 2 is the last one - TODO confirm).
    size_gap = true_cp.shape[2] - pred_cp.shape[2]
    if size_gap > 0:
        pred_cp = torch.nn.functional.pad(input=pred_cp, pad=(0, size_gap, 0, 0, 0, 0), value=0)
    elif size_gap < 0:
        true_cp = torch.nn.functional.pad(input=true_cp, pad=(0, -size_gap, 0, 0, 0, 0), value=0)

    tpr, fpr, tnr, fnr, auc = custom_binary_metrics(binary=pred_cp, A=true_cp, verbose=True)
    shd_d = SHD(target=true_cp.numpy(), pred=pred_cp.numpy(), double_for_anticausal=True)

    # store
    scm_list[fn] = pred_cp
    scores_list[fn] = res["scores"]
    det_auc_list[fn] = res["auc"]
    struct_auc_list[fn] = {
        "tpr": tpr, "fpr": fpr, "tnr": tnr, "fnr": fnr, "auc": auc, "shd": shd_d,
        "pred#": pred_cp.sum().numpy(), "true#": true_cp.sum().numpy(),
    }
    shd_list[fn] = shd_d


### Store Results

In [None]:
# merge: wrap the per-file structural metrics under the "res" key
res_both = dict(res=struct_auc_list)

# # store results (disabled)
# pickle.dump(res_both, open(list(Path(".").resolve().parents)[1] / "data" / "results" / "random_graph" / "res_cp_vs_3.p", "wb"))

### Submission Plots

In [None]:
# double figure for the submission (dense and oracle panels side by side)


def _load_results(*parts):
    """Unpickle the results stored at <grandparent of cwd>/data/results/<parts...>.

    The file is opened in a ``with`` block so the handle is closed after loading.
    """
    results_path = Path(".").resolve().parents[1].joinpath("data", "results", *parts)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    with open(results_path, "rb") as fh:
        return pickle.load(fh)


def _sort_by_edges(per_file):
    """Return two tuples (true#, shd) of the entries, ordered by ascending "true#"."""
    pairs = sorted(((v["true#"], v["shd"]) for v in per_file.values()), key=lambda pair: pair[0])
    true_counts, shd_values = zip(*pairs)
    return true_counts, shd_values


def _draw_series(ax, true_counts, shd_values, label, color, marker):
    """Draw one result series on `ax`: scatter markers plus a faint dashed connecting line."""
    ax.scatter(x=true_counts, y=shd_values, label=label, color=color, s=100, marker=marker)
    ax.plot(true_counts, shd_values, "--", color=color, alpha=0.3)


def _label_panel(ax, true_counts, title, title_font):
    """Put one tick per distinct integer edge count and set the axis labels and title."""
    edge_ticks = sorted({int(count) for count in true_counts})
    ax.set_xticks(edge_ticks)
    ax.set_xticklabels(edge_ticks)
    ax.set_xlabel("#edges", fontsize=14)
    ax.set_ylabel("SHD", fontsize=14)
    ax.set_title(title, fontdict=title_font)


# load data (an earlier dense run is stored as dense_graph/res_cp_vs_1.p)
res_dense = _load_results("dense_graph", "res_cp_vs_4n.p")
res_oracle = _load_results("oracle_graph", "res_cp_vs_1.p")
res_random = _load_results("random_graph", "res_cp_vs_3.p")

# sort results according to ground truth edge density
sorted_true_dense, sorted_shd_dense = _sort_by_edges(res_dense["res"])
sorted_true_w_dense, sorted_shd_w_dense = _sort_by_edges(res_dense["res_w"])
sorted_true_oracle, sorted_shd_oracle = _sort_by_edges(res_oracle["res"])
sorted_true_w_oracle, sorted_shd_w_oracle = _sort_by_edges(res_oracle["res_w"])
sorted_true_random, sorted_shd_random = _sort_by_edges(res_random["res"])

# plot
f, axs = plt.subplots(ncols=2, figsize=(15, 4.5), sharey=True)

# right panel: oracle graph
_draw_series(axs[1], sorted_true_oracle, sorted_shd_oracle, "w/o penalty", "darkolivegreen", "s")
_draw_series(axs[1], sorted_true_w_oracle, sorted_shd_w_oracle, "w/ penalty", "indianred", "o")
_draw_series(axs[1], sorted_true_random, sorted_shd_random, "random", "black", "p")
_label_panel(axs[1], sorted_true_oracle, "Oracle Graph", {"size": 14, "color": "darkred", "family": "serif"})

# The legend is built now, while only the oracle panel has labelled artists, so each
# series appears once. Calling it after the dense panel would duplicate the entries.
f.legend(loc=(0.37, 0.925), ncol=3, fontsize=12)

# left panel: dense graph
_draw_series(axs[0], sorted_true_dense, sorted_shd_dense, "w/o penalty", "darkolivegreen", "s")
_draw_series(axs[0], sorted_true_w_dense, sorted_shd_w_dense, "w/ penalty", "indianred", "o")
_draw_series(axs[0], sorted_true_random, sorted_shd_random, "random", "black", "p")
_label_panel(axs[0], sorted_true_dense, "Dense Graph", {"size": 14, "color": "darkred"})

plt.tight_layout()
plt.show()

In [None]:
# double figure for the submission (dense panel above the oracle panel),
# restricted to the entries with the largest ground-truth edge counts

# number of entries kept from the upper end of each series after sorting by "true#"
N_LARGEST = 10


def _read_pickle(path):
    """Unpickle `path`; the ``with`` block closes the file handle after loading."""
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _largest_by_edges(per_file, keep=N_LARGEST):
    """Sort the entries by ascending "true#" and return the last `keep` as (true#, shd) tuples."""
    ordered = sorted(((v["true#"], v["shd"]) for v in per_file.values()), key=lambda pair: pair[0])
    true_counts, shd_values = zip(*ordered)
    return true_counts[-keep:], shd_values[-keep:]


def _plot_panel(ax, series, tick_source, title, title_font):
    """Draw each (true#, shd, label, color, marker) series on `ax`, then set ticks, labels and title."""
    for true_counts, shd_values, label, color, marker in series:
        ax.scatter(x=true_counts, y=shd_values, label=label, color=color, s=100, marker=marker)
        ax.plot(true_counts, shd_values, "--", color=color, alpha=0.3)
    edge_ticks = sorted({int(count) for count in tick_source})
    ax.set_xticks(edge_ticks)
    ax.set_xticklabels(edge_ticks)
    ax.set_xlabel("#edges", fontsize=14)
    ax.set_ylabel("SHD", fontsize=14)
    ax.set_title(title, fontdict=title_font)


# load data (an earlier dense run is stored as dense_graph/res_cp_vs_1.p)
results_root = Path(".").resolve().parents[1] / "data" / "results"
res_dense = _read_pickle(results_root / "dense_graph" / "res_cp_vs_4n.p")
res_oracle = _read_pickle(results_root / "oracle_graph" / "res_cp_vs_1.p")
res_random = _read_pickle(results_root / "random_graph" / "res_cp_vs_3.p")

# sort results according to ground truth edge density and keep the upper end
sorted_true_dense, sorted_shd_dense = _largest_by_edges(res_dense["res"])
sorted_true_w_dense, sorted_shd_w_dense = _largest_by_edges(res_dense["res_w"])
sorted_true_oracle, sorted_shd_oracle = _largest_by_edges(res_oracle["res"])
sorted_true_w_oracle, sorted_shd_w_oracle = _largest_by_edges(res_oracle["res_w"])
sorted_true_random, sorted_shd_random = _largest_by_edges(res_random["res"])

# the random baseline is drawn on both panels
random_series = (sorted_true_random, sorted_shd_random, "random", "black", "p")

# plot
f, axs = plt.subplots(nrows=2, figsize=(8, 7), sharey=True)

# lower panel: oracle graph
_plot_panel(
    axs[1],
    [
        (sorted_true_oracle, sorted_shd_oracle, "w/o penalty", "darkolivegreen", "s"),
        (sorted_true_w_oracle, sorted_shd_w_oracle, "w/ penalty", "indianred", "o"),
        random_series,
    ],
    tick_source=sorted_true_oracle,
    title="Oracle Graph",
    title_font={"size": 14, "color": "darkred", "family": "serif"},
)

# The legend is built now, while only the oracle panel has labelled artists, so each
# series appears once. Calling it after the dense panel would duplicate the entries.
f.legend(loc=(0.25, 0.962), ncol=3, fontsize=12)

# upper panel: dense graph
_plot_panel(
    axs[0],
    [
        (sorted_true_dense, sorted_shd_dense, "w/o penalty", "darkolivegreen", "s"),
        (sorted_true_w_dense, sorted_shd_w_dense, "w/ penalty", "indianred", "o"),
        random_series,
    ],
    tick_source=sorted_true_dense,
    title="Dense Graph",
    title_font={"size": 14, "color": "darkred"},
)

plt.tight_layout()
plt.show()