### Imports

In [None]:
import os
import json
import torch
import pickle
import string
import itertools
import functools
import numpy as np
import pandas as pd
from tqdm import trange
from pathlib import Path
from copy import deepcopy

from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import seaborn as sns

import sys
sys.path.append("..")

from tempogen.temporal_scm import TempSCM
from tempogen.functional_utils import (_torch_tanh, _torch_arctan, _torch_sin, _torch_cos,
                                       _torch_sigmoid, _torch_pow, _torch_identity, _torch_sqrt)

from tempogen.temporal_random_generation import get_p_edge, get_funcs, get_z_distribution 
from simulation.simulation_utils import simulate

from cdt.metrics import SHD
from simulation.simulation_metrics import mmd_torch

from simulation.simulation_tools import get_optimal_sim_XY, run_detection_metrics_XY, prepare_det_data, run_detection_metrics
from utils import custom_binary_metrics, _from_full_to_cp

from CausalTime.tools import generate_CT
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import TimeSeriesDataLoader

from sdv.metadata import Metadata
from sdv.sequential import PARSynthesizer


# Shared NumPy Generator used by get_n_vars / get_n_lags below; it is created
# without a seed, so the sampled graph sizes differ from run to run.
rng = np.random.default_rng()

def prind(di):
    """Print ``di`` as JSON indented by 4 spaces, keeping the original key order."""
    rendered = json.dumps(di, sort_keys=False, indent=4)
    print(rendered)

# Column labels: the 26 single uppercase letters followed by every ordered pair
# of two distinct uppercase letters ("AB", "AC", ..., "ZY").
COL_NAMES = [
    *string.ascii_uppercase,
    *map("".join, itertools.permutations(string.ascii_uppercase, 2)),
]

# Grandparent (parents[1]) of the current working directory as a POSIX-style
# string; the same assignment is repeated in the Data cell.
par_dir = Path(os.getcwd()).parents[1].as_posix()

### Data

In [None]:
# plotting function
def plot_intervened_ts(data, i_step=600):
    """Plot an intervened time series with its pre- and post-intervention parts.

    The bottom panel shows every column of ``data`` with a dashed vertical line
    at ``i_step``. The top-left panel shows the rows up to ``i_step`` and the
    top-right panel the rows from ``i_step`` onwards, so the split always lines
    up with the dashed marker regardless of how many rows ``data`` has.

    Parameters
    ----------
    data : pandas.DataFrame
        One column per series. Rows are selected with label-based ``.loc``
        slicing, so an integer index (e.g. after ``reset_index(drop=True)``)
        is assumed -- TODO confirm for other index types.
    i_step : int, default 600
        Index label at which the intervention starts.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(layout="constrained", figsize=(9, 3))

    # NOTE(review): the width ratio is still scaled by a fixed 1000, which only
    # reflects the row counts of the two upper panels for specific data lengths;
    # left untouched so existing figures keep their layout.
    gs = GridSpec(nrows=2, ncols=2, figure=fig, width_ratios=(1, i_step/1000))
    ax_full = fig.add_subplot(gs[1, :])
    ax_before = fig.add_subplot(gs[0, 0])
    ax_after = fig.add_subplot(gs[0, 1])

    label_font = {"size": 12, "weight": "roman"}

    # ``.loc`` label slices include both end points, so the row at ``i_step``
    # appears in both upper panels.
    panels = (
        (ax_full, data, "Intervened Time-series"),
        (ax_before, data.loc[:i_step], "Data before intervention"),
        (ax_after, data.loc[i_step:], "Data after intervention"),
    )
    for ax, frame, label in panels:
        sns.lineplot(
            data=frame,
            ax=ax,
            legend=False,
            palette=sns.color_palette("pastel6", n_colors=data.shape[1]),
        )
        ax.set_yticks(ticks=[], labels=[])
        ax.set_xlabel(label, fontdict=label_font)

    # Mark the intervention point on the full series.
    ax_full.axvline(x=i_step, color="black", linestyle="--", alpha=0.6)

    plt.show()

# paths
# Grandparent (parents[1]) of the current working directory as a POSIX-style
# string. This repeats the assignment from the Imports cell; `par_dir` is not
# referenced again in the visible cells.
par_dir = Path(os.getcwd()).parents[1].as_posix()

# (hyper) parameters
FN = "cp_0"     # not referenced in the visible cells
SF = "L_1L+-"   # not referenced in the visible cells
NS = 500        # rows generated per SCM for the pre/post comparison
N_SCMS = 20     # number of random SCMs evaluated by the main loop
FNS = 1         # number of nodes that receive the soft intervention per SCM
LOC = 5.0       # mean of the Gaussian shift in _torch_noisy_f (also part of the results filename)
SCALE = 1.0     # std of the Gaussian shift in _torch_noisy_f (also part of the results filename)
# Candidate lengths of the observational series: integers in [0.5*NS, 0.75*NS).
IPR = range(int(NS * 0.5), int(NS * 0.75))

# Candidate numbers of variables (3..12) and the probability of drawing each
# one during random SCM generation; currently uniform.
n_vars_space = {
    "a": list(range(3, 13)),
    "p": [0.1] * 10
    # alternative, peaked around mid-sized graphs:
    # "p": [0.05, 0.05, 0.125, 0.125, 0.15, 0.15, 0.125, 0.125, 0.05, 0.05]
}
# - dynamic version of get_n_vars, able to favor specific sizes of graphs
def get_n_vars(a, p):
    """Draw one number of variables from ``a`` using the probabilities ``p``.

    Both arguments are forwarded unchanged to ``rng.choice`` on the shared
    module-level generator.
    """
    n_vars = rng.choice(a=a, p=p)
    return n_vars

# Candidate numbers of lags and the probability of drawing each one during
# random SCM generation; with these weights the draw is always 1.
n_lags_space = dict(
    a=[1, 2, 3],
    p=[1.0, 0.0, 0.0],
)
# - dynamic version of get_n_lags, able to favor specific numbers of lags
def get_n_lags(a, p):
    """Draw one number of lags from ``a`` using the probabilities ``p``.

    Both arguments are forwarded unchanged to ``rng.choice`` on the shared
    module-level generator.
    """
    n_lags = rng.choice(a=a, p=p)
    return n_lags

# the space for the edge probability during random generation
# - dynamic version, depending on the # vars & # lags, to keep large graphs sparser; based on preconfigured options; still testing this
def p_edge_space(n_vars, n_lags):
    """Build the edge-probability options for a graph of the given size.

    The number of possible edges is ``n_vars**2 * n_lags``. Depending on
    whether that total is below 100, below 200, or larger, one of three
    preconfigured triples of edge counts is chosen and each count is divided
    by the total to give a candidate probability.

    Returns a dict with keys ``"c"`` (always ``None``), ``"values"`` (the three
    candidate probabilities) and ``"weights"`` (``[0.6, 0.3, 0.1]``); in this
    file it is unpacked into ``get_p_edge``.
    """
    total_edges = (n_vars**2)*n_lags

    # Edge-count triple for the size bracket this graph falls into.
    if total_edges < 100:
        edge_counts = (3, 5, 7)
    elif total_edges < 200:
        edge_counts = (5, 7, 9)
    else:
        edge_counts = (9, 12, 15)

    return {
        "c": None,
        "values": [count/total_edges for count in edge_counts],
        "weights": [0.6, 0.3, 0.1],
    }

# the space for functional dependencies during random generation
# (unpacked into get_funcs; weights are presumably paired with the functions
# by position, in which case sin, cos and arctan are never drawn -- confirm)
funcs_space = dict(
    c=None,
    functions=[
        _torch_identity,
        _torch_sqrt,
        _torch_sigmoid,
        _torch_sin,
        _torch_cos,
        _torch_arctan,
    ],
    # alternative weights: [0.2, 0.2, 0.1, 0.1, 0.1, 0.3]
    weights=[0.4, 0.4, 0.2, 0.0, 0.0, 0.0],
)

# the space for noise distribution during random generation
# (unpacked into get_z_distribution): a narrow zero-mean Normal and a
# Uniform on [-0.15, 0.15), listed with weights 0.01 and 0.99
z_distribution_space = dict(
    c=None,
    functions=[
        torch.distributions.Normal(loc=0, scale=0.005),
        torch.distributions.Uniform(low=-0.15, high=0.15),
    ],
    weights=[0.01, 0.99],
)

def _torch_noisy_f(x, loc=LOC, scale=SCALE, func=_torch_identity):
    """Apply ``func`` to ``x`` and add one draw from ``Normal(loc, scale)``.

    A new Normal distribution is built on every call and sampled once, so the
    same draw is added to the whole of ``func(x)``. In the main loop this is
    wrapped with ``functools.partial`` around a node's original function to
    implement the soft intervention.
    """
    gaussian = torch.distributions.Normal(loc=loc, scale=scale)
    transformed = func(x)
    return transformed + gaussian.sample()

# Length of the observational series, drawn once from IPR and therefore shared
# by every SCM generated in the main loop.
i_step = np.random.choice(IPR)

# Per-SCM AUC collectors, filled in by the main loop.
auc_ori_list, auc_pre_list, auc_post_list = [], [], []

# main loop for random SCM generation
def _draw_scm_with_series(n_samples):
    """Sample one random TempSCM and generate ``n_samples`` observational rows.

    Returns ``(n_vars, n_lags, p_edge, scm, ts)`` where ``ts`` is the generated
    series with its index reset, so first draws and re-draws are treated alike.
    """
    n_vars = get_n_vars(**n_vars_space)
    n_lags = get_n_lags(**n_lags_space)
    p_edge = get_p_edge(**p_edge_space(n_vars=n_vars, n_lags=n_lags))
    scm = TempSCM(
        method="C",
        n_vars=n_vars,
        n_lags=n_lags,
        p_edge=p_edge,
        funcs=get_funcs(**funcs_space),
        z_distributions=get_z_distribution(**z_distribution_space)
    )
    ts = scm.generate_time_series(n_samples=n_samples).reset_index(drop=True)
    return n_vars, n_lags, p_edge, scm, ts


for ctr in range(N_SCMS):

    # DFs & SCMs + observational data.
    # Resample in case of infinity values. The check is element-wise because
    # a series holding both +inf and -inf sums to NaN, which a test on the
    # total alone would let through.
    n_vars, n_lags, p_edge, scm, ts_pre = _draw_scm_with_series(n_samples=i_step)
    while np.isinf(ts_pre.values).any():
        n_vars, n_lags, p_edge, scm, ts_pre = _draw_scm_with_series(n_samples=i_step)
    print(f"len(ts_pre): {len(ts_pre)}")

    # Simulate with TCS
    true_data = ts_pre.copy()
    results_tcs = get_optimal_sim_XY(true_data=true_data, sparsity_penalty=True)

    # Make independent copies of the true and predicted SCMs: one pair stays
    # untouched, the other pair (suffix _i) receives the intervention.
    true_scm = deepcopy(scm)
    true_scm_i = deepcopy(scm)
    tcs_scm = deepcopy(results_tcs["optimal_scm"])
    tcs_scm_i = deepcopy(results_tcs["optimal_scm"])

    true_scm._reset_time_series()
    true_scm_i._reset_time_series()
    tcs_scm._reset_time_series()
    tcs_scm_i._reset_time_series()

    # Define intervened nodes & perform the soft intervention: the same node
    # positions are wrapped with _torch_noisy_f in both the true and the
    # predicted SCM.
    n2is = np.random.choice(a=range(len(scm.temp_nodes)), size=FNS, replace=False)
    for n2i in n2is:
        true_scm_i.temp_nodes[n2i].func = functools.partial(_torch_noisy_f, func=true_scm_i.temp_nodes[n2i].func)
        tcs_scm_i.temp_nodes[n2i].func = functools.partial(_torch_noisy_f, func=tcs_scm_i.temp_nodes[n2i].func)

    # Generate observational and interventional data from the true SCM & the predicted SCM
    ts_pre_true = true_scm.generate_time_series(n_samples=NS)
    ts_pre_tcs = tcs_scm.generate_time_series(n_samples=NS)
    ts_post_i_true = true_scm_i.generate_time_series(n_samples=NS)
    ts_post_i_tcs = tcs_scm_i.generate_time_series(n_samples=NS)

    # Evaluate with the detector (and its config) selected by the TCS search;
    # the first element of its return value is used as the AUC.
    # auc_pre = run_detection_metrics(real=ts_pre_true, synthetic=ts_pre_tcs)['auc']
    # auc_post = run_detection_metrics(real=ts_post_i_true, synthetic=ts_post_i_tcs)['auc']
    auc_pre, _, _ = results_tcs["optimal_detector"](real=ts_pre_true, synthetic=ts_pre_tcs, **results_tcs["optimal_detector_config"])
    auc_post, _, _ = results_tcs["optimal_detector"](real=ts_post_i_true, synthetic=ts_post_i_tcs, **results_tcs["optimal_detector_config"])
    print(f"AUC ori : {results_tcs['auc']}")
    print(f"AUC pre : {auc_pre}")
    print(f"AUC post : {auc_post}")
    auc_ori_list.append(round(results_tcs["auc"], 2))
    auc_pre_list.append(round(auc_pre, 2))
    auc_post_list.append(round(auc_post, 2))


In [None]:
# Store results
# NOTE(review): the 'pre' entry holds auc_ori_list (the AUC reported by the TCS
# search), not auc_pre_list -- confirm this is intended.
pre_post_auc = dict(pre=auc_ori_list, post=auc_post_list)

# <grandparent of cwd>/data/results/interventions, created on demand
interventions_dir = Path.cwd().parents[1] / 'data' / 'results' / 'interventions'
interventions_dir.mkdir(parents=True, exist_ok=True)
with (interventions_dir / f"pre_post_auc_{LOC}L_{SCALE}S.json").open("w") as f:
    json.dump(pre_post_auc, f)

In [None]:
# Check results: reload the stored AUC lists and print mean AUCs (2 decimals)
with (interventions_dir / f"pre_post_auc_{LOC}L_{SCALE}S.json").open("r") as f:
    pre_post_auc = json.load(f)

print(np.round(np.mean(pre_post_auc["pre"]), 2))   # stored 'pre' list
print(np.round(np.mean(auc_pre_list), 2))          # in-memory auc_pre_list, not part of the file
print(np.round(np.mean(pre_post_auc["post"]), 2))  # stored 'post' list

### Debugging

In [None]:
# # - proof that, on average, the TCS discrimination score is consistent with the run_detection_metrics output
# dl = []
# for _ in trange(100):
#     tcs_scm._reset_time_series()
#     true_scm._reset_time_series()
#     tcs_ts = tcs_scm.generate_time_series(n_samples=NS)
#     true_ts = true_scm.generate_time_series(n_samples=NS)
#     dl.append(run_detection_metrics(real=true_ts, synthetic=tcs_ts, verbose=False)["auc"])

In [None]:
# print(f"TCS auc : {round(results_tcs['auc'], 2)}")
# print(f"RDM auc : {np.mean(dl).round(2)}")