### Imports

In [1]:
import os
import json
import torch
import pickle
import string
import itertools
import numpy as np
import pandas as pd
from tqdm import trange
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns

import sys
sys.path.append("..")

from tempogen.temporal_scm import TempSCM
from tempogen.functional_utils import (_torch_tanh, _torch_arctan, _torch_sin, _torch_cos,
                                       _torch_sigmoid, _torch_pow, _torch_identity, _torch_sqrt)

from tempogen.temporal_random_generation import get_p_edge, get_funcs, get_z_distribution 
from simulation.simulation_utils import simulate

from cdt.metrics import SHD
from simulation.simulation_metrics import mmd_torch

from simulation.simulation_tools import get_optimal_sim_XY, run_detection_metrics_XY, prepare_det_data, run_detection_metrics
from utils import custom_binary_metrics, _from_full_to_cp

from CausalTime.tools import generate_CT
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import TimeSeriesDataLoader

from sdv.metadata import Metadata
from sdv.sequential import PARSynthesizer


# Seed the random sources used by this notebook so that "Restart & Run All"
# regenerates the same collection. Change SEED to draw a different collection.
# NOTE(review): project code may use further RNGs (e.g. Python's `random`) that
# are not covered here - confirm.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)  # the noise candidates further down are torch.distributions objects

def prind(di):
    """Pretty-print ``di`` as JSON (4-space indent, keys kept in insertion order)."""
    rendered = json.dumps(di, indent=4, sort_keys=False)
    print(rendered)

# Column labels: "A".."Z", then every ordered pair of two distinct letters ("AB", "AC", ..., "ZY").
COL_NAMES = list(string.ascii_uppercase)
COL_NAMES += [first + second for first, second in itertools.permutations(string.ascii_uppercase, 2)]

# Directory two levels above the current working directory, as a POSIX-style string.
par_dir = Path.cwd().parents[1].as_posix()
# Identifiers used to build the dataset folder name "{FN}_{SF}" and the result file names.
FN, SF = "cp_0", "L_1L+-"

TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.


Detecting 1 CUDA device(s).





### Data Generation

In [None]:
# paths: directory two levels above the working directory + dataset identifiers
par_dir = Path.cwd().parents[1].as_posix()
FN, SF = "cp_0", "L_1L+-"
# number of random SCMs generated by the loop at the end of this cell
N_SCMS = 10

# Keyword arguments for get_n_vars: candidate numbers of variables (10..19)
# and the probability of drawing each candidate (uniform here).
n_vars_space = dict(
    a=list(range(10, 20)),
    p=[1 / 10] * 10,
    # p=[0.05, 0.05, 0.125, 0.125, 0.15, 0.15, 0.125, 0.125, 0.05, 0.05],
)
# - sampler for the number of variables; a non-uniform `p` favors specific graph sizes
def get_n_vars(a, p):
    """Draw one entry of ``a`` with probabilities ``p`` using the module-level ``rng``."""
    n_vars = rng.choice(a=a, p=p)
    return n_vars

# Keyword arguments for get_n_lags: candidate numbers of lags and their
# probabilities (all of the mass is on a single lag here).
n_lags_space = dict(
    a=[1, 2, 3],
    p=[1.0, 0.0, 0.0],
)
# - sampler for the number of lags; a non-uniform `p` favors specific lag counts
def get_n_lags(a, p):
    """Draw one entry of ``a`` with probabilities ``p`` using the module-level ``rng``."""
    n_lags = rng.choice(a=a, p=p)
    return n_lags

# the space for the edge probability during random generation 
# - dynamic version, depending on the # vars & # lags, to keep large graphs sparser; based on preconfigured options; still testing this
def p_edge_space(n_vars, n_lags):
    """Build the keyword arguments for ``get_p_edge`` for a graph of the given size.

    The candidate values are ``k / total_edges`` where ``total_edges`` is
    ``n_vars**2 * n_lags`` and the triple of ``k`` depends on the size bucket
    (< 100, < 200, otherwise). The weights are the same in every bucket.
    """
    total_edges = (n_vars**2)*n_lags

    # numerators of the candidate values for the size bucket of this graph
    if total_edges < 100:
        edge_counts = (3, 5, 7)
    elif total_edges < 200:
        edge_counts = (5, 7, 9)
    else:
        edge_counts = (9, 12, 15)

    return {
        "c": None,
        "values": [count / total_edges for count in edge_counts],
        "weights": [0.6, 0.3, 0.1],
    }

# Keyword arguments for get_funcs: candidate functional dependencies and the
# weight attached to each candidate (same order as `functions`).
funcs_space = dict(
    c=None,
    functions=[
        _torch_identity,
        _torch_sqrt,
        _torch_sigmoid,
        _torch_sin,
        _torch_cos,
        _torch_arctan,
    ],
    # weights=[0.2, 0.2, 0.1, 0.1, 0.1, 0.3],
    weights=[0.6, 0.4, 0.0, 0.0, 0.0, 0.0],
)

# Keyword arguments for get_z_distribution: candidate noise distributions
# (a Normal with scale 0.005 and a Uniform between -0.15 and 0.15) and their weights.
z_distribution_space = dict(
    c=None,
    functions=[
        torch.distributions.Normal(loc=0, scale=0.005),
        torch.distributions.Uniform(low=-0.15, high=0.15),
    ],
    weights=[0.01, 0.99],
)


# main loop for random SCM generation
N_SAMPLES = 500          # time steps generated per SCM
MAX_RESAMPLES = 100      # stop with an error instead of resampling forever
STORE_DATASETS = False   # set to True to (over)write the collection under {par_dir}/data/cp_style/


def _sample_scm_and_data():
    """Draw one random SCM from the configured spaces and generate its time series.

    Returns:
        (scm, df): the ``TempSCM`` instance and the output of
        ``scm.generate_time_series(n_samples=N_SAMPLES)``.
    """
    n_vars = get_n_vars(**n_vars_space)
    n_lags = get_n_lags(**n_lags_space)
    scm = TempSCM(
        method="C",
        n_vars=n_vars,
        n_lags=n_lags,
        p_edge=get_p_edge(**p_edge_space(n_vars=n_vars, n_lags=n_lags)),
        funcs=get_funcs(**funcs_space),
        z_distributions=get_z_distribution(**z_distribution_space),
    )
    return scm, scm.generate_time_series(n_samples=N_SAMPLES)


for ctr in range(N_SCMS):

    # DFs & SCMs
    scm, df = _sample_scm_and_data()

    # Resample DFs & SCMs while the series contains non-finite values.
    # np.isfinite is checked element-wise: summing first (as np.isinf(sum) would)
    # misses NaN entries and series containing both +inf and -inf (their sum is NaN).
    attempts = 0
    while not np.isfinite(df.values).all():
        attempts += 1
        if attempts > MAX_RESAMPLES:
            raise RuntimeError(
                f"SCM {ctr}: no finite time series after {MAX_RESAMPLES} resampling attempts"
            )
        scm, df = _sample_scm_and_data()

    # store (opt-in, see STORE_DATASETS)
    if STORE_DATASETS:
        Path(f"{par_dir}/data/cp_style/{FN}_{SF}/data/").mkdir(parents=True, exist_ok=True)
        Path(f"{par_dir}/data/cp_style/{FN}_{SF}/structure/").mkdir(parents=True, exist_ok=True)
        df.to_csv(f"{par_dir}/data/cp_style/{FN}_{SF}/data/cp_collection_data_{ctr}.csv", index=False)
        torch.save(
            scm.causal_structure.causal_structure_cp,
            f"{par_dir}/data/cp_style/{FN}_{SF}/structure/cp_collection_struct_{ctr}.pt",
        )

    # plot
    scm.causal_structure.plot_structure()

### Runs

In [None]:
# data structure as such for convenient comparison with CausalTime
# One entry per CSV file of the collection, keyed by the file name without ".csv".
_data_dir = f"{par_dir}/data/cp_style/{FN}_{SF}/data/"
_struct_dir = f"{par_dir}/data/cp_style/{FN}_{SF}/structure/"

DATA_DICT = {
    filename.split(".csv")[0]: {
        'data_path': _data_dir,
        'data_type': 'fmri',
        'task': filename,
        'straight_path': _data_dir + filename,
        # "cp_collection_data_<i>.csv" -> "cp_collection_struct_<i>.pt"
        'struct_path': _struct_dir + filename.replace('data', 'struct').replace('csv', 'pt'),
    }
    # os.listdir returns entries in arbitrary order and includes anything that
    # lives in the folder (checkpoint dirs, hidden files, ...): sort for a
    # deterministic run order and keep only the CSV datasets.
    for filename in sorted(os.listdir(_data_dir))
    if filename.endswith(".csv")
}

# CausalTime Parameters (each entry is forwarded to generate_CT under the same keyword)
PARAMS = dict(
    batch_size=32,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    seq_length=20,
    test_size=0.2,
    learning_rate=0.0001,
    n_epochs=1,
    flow_length=4,
    gen_n=20,
    n=2000,
    arch_type="MLP",
    save_path="outputs/",
    log_dir="log/",
)

# placeholders, filled by the run loop and keyed by dataset file name
# - structural Hamming distance of the TCS graph
shd_dict = {}

# - detection AUC per method (tcs / CausalTime / SDV / tvae)
auc_dict_tcs, auc_dict_ct, auc_dict_sdv, auc_dict_tvae = {}, {}, {}, {}

# - generated data per method
data_dict_tcs, data_dict_ct, data_dict_sdv, data_dict_tvae = {}, {}, {}, {}

# - MMD per method
mmd_dict_tcs, mmd_dict_ct, mmd_dict_sdv, mmd_dict_tvae = {}, {}, {}, {}

# run
MAX_TIMESTEPS = 2000   # datasets longer than this are cut to a random window of this length
SDV_SEQ_LEN = 20       # sequence length used for PAR; same value as PARAMS["seq_length"]


def _align_lengths(real, synthetic):
    """Trim the longer of the two inputs to the number of rows of the shorter one."""
    n_rows = min(real.shape[0], synthetic.shape[0])
    if real.shape[0] > n_rows:
        real = real[:n_rows]
    if synthetic.shape[0] > n_rows:
        synthetic = synthetic[:n_rows]
    return real, synthetic


def _score_against_real(real, synthetic):
    """Return ``(detection AUC, MMD)`` of ``synthetic`` against ``real``."""
    auc = run_detection_metrics(real=real, synthetic=synthetic, bias_correction=False, verbose=False)['auc']
    mmd = mmd_torch(synthetic=synthetic, real=real)
    return auc, mmd


for v in DATA_DICT.values():

    try:

        # info
        filename = v['task']
        print(f" \n------------- {filename} ---------------\n ")

        # data (columns are relabelled "A", "B", ... in file order)
        true_data = pd.read_csv(v["straight_path"])
        true_data = true_data.rename(columns=dict(zip(true_data.columns, COL_NAMES[:true_data.shape[1]])))
        true_graph = torch.load(v["struct_path"])

        # adjust timesteps for computation time (MAX_TIMESTEPS max)
        print(f"true data length: {true_data.shape[0]}")

        if true_data.shape[0] > MAX_TIMESTEPS:
            # integer window start + positional slicing: exactly MAX_TIMESTEPS rows,
            # re-indexed from 0 so later positional/label operations agree
            anchor = int(rng.integers(low=0, high=true_data.shape[0] - MAX_TIMESTEPS + 1))
            true_data = true_data.iloc[anchor:anchor + MAX_TIMESTEPS].reset_index(drop=True)
            print(f"true data length (adjusted): {true_data.shape[0]}")


        print("""\n ____________________________________ Simulate w/ CausalTime ____________________________________ \n""")

        # PARAMS holds exactly the CausalTime keyword arguments; the dataset is
        # located through data_path / task rather than passed as a frame.
        true_pd, pro_true_pd, skimmed_pd, pro_gen_pd = generate_CT(
            **PARAMS,
            data_path=v["data_path"],
            data_type=v["data_type"],
            task=v["task"],
        )
        ct_data = pro_gen_pd.copy()

        # Fix potential length mismatches
        true_data, ct_data = _align_lengths(true_data, ct_data)

        # Evaluate
        print(f"LOG : true shape - {true_data.shape} VS ct shape - {ct_data.shape}")
        ct_auc, ct_mmd = _score_against_real(true_data, ct_data)

        # Store
        mmd_dict_ct[filename] = ct_mmd
        auc_dict_ct[filename] = ct_auc
        data_dict_ct[filename] = ct_data.copy()


        print("""\n ____________ Simulate w/ SDV ____________ \n""")

        # Creating same conditions as CausalTime: keep only whole sequences of
        # SDV_SEQ_LEN rows. The remainder is taken modulo the sequence length
        # (not modulo the number of sequences), so the 'id' column below always
        # has exactly as many entries as the frame has rows.
        n_sequences = true_data.shape[0] // SDV_SEQ_LEN
        true_data_sdv = true_data.iloc[:n_sequences * SDV_SEQ_LEN].copy()

        # Sequence key: 0,0,...,0,1,1,...  (SDV_SEQ_LEN rows per sequence)
        true_data_sdv['id'] = np.repeat(np.arange(n_sequences), SDV_SEQ_LEN)

        # Metadata
        metadata = Metadata.detect_from_dataframe(data=true_data_sdv)
        metadata.tables["table"].columns["id"]["sdtype"] = "id"
        metadata.set_sequence_key(column_name='id')

        # Synthesizer
        synthesizer = PARSynthesizer(metadata)
        synthesizer.fit(data=true_data_sdv)
        synthetic_data = synthesizer.sample(num_sequences=n_sequences + 1)

        # Fix potential length mismatches (positional slice, independent of the index labels)
        sdv_data = synthetic_data.iloc[:len(true_data)].drop(columns=["id"])
        true_data, sdv_data = _align_lengths(true_data, sdv_data)

        # Evaluate
        print(f"LOG : true shape - {true_data.shape} VS sdv shape - {sdv_data.shape}")
        sdv_auc, sdv_mmd = _score_against_real(true_data, sdv_data)

        # Store
        mmd_dict_sdv[filename] = sdv_mmd
        auc_dict_sdv[filename] = sdv_auc
        data_dict_sdv[filename] = sdv_data.copy()


        print("""\n _____________ Simulate w/ TCS _____________ \n""")

        results_tcs = get_optimal_sim_XY(true_data=true_data)
        tcs_data = results_tcs["optimal_data"]
        pred_graph = results_tcs["optimal_scm"].causal_structure.causal_structure_cp

        # Fix potential length mismatches
        true_data, tcs_data = _align_lengths(true_data, tcs_data)

        # Evaluate (the AUC is recomputed with the same detector settings as the other methods)
        print(f"LOG : true shape - {true_data.shape} VS tcs shape - {tcs_data.shape}")
        tcs_auc, tcs_mmd = _score_against_real(true_data, tcs_data)

        # Store
        mmd_dict_tcs[filename] = tcs_mmd
        auc_dict_tcs[filename] = tcs_auc
        data_dict_tcs[filename] = tcs_data.copy()
        # cast to a plain float so the value can be written with json.dump
        # whatever NumPy scalar type SHD returns
        shd_dict[filename] = float(SHD(true_graph.numpy(), pred_graph.numpy()))


        print("""\n _____________ Simulate w/ TimeVAE _____________ \n""")

        # Prepare TimeVAE Data
        # After the renaming above every column is an upper-case label, so there
        # is never a 'target' column: the whole frame is the temporal data and
        # there is no outcome.
        dat = true_data.copy()
        n_samples = dat.shape[0]

        # Initialize the TimeSeriesDataLoader (a single sequence, observed at the row index)
        X_loader = TimeSeriesDataLoader(
            temporal_data=[dat],
            observation_times=[dat.index.to_numpy()],
            outcome=None,
            static_data=None,
            train_size=1.0,
            test_size=0.0
        )

        # Define plugin kwargs for TimeVAE
        plugin_kwargs = dict(
            n_iter=30,
            batch_size=64,
            lr=0.001,
            encoder_n_layers_hidden=2,
            decoder_n_layers_hidden=2,
            encoder_dropout=0.05,
            decoder_dropout=0.05
        )

        # Initialize the generative model for TimeVAE
        # NOTE(review): this section is labelled TimeVAE but the plugin id is
        # "tvae" - confirm that this is the intended synthcity plugin.
        test_plugin = Plugins().get("tvae", **plugin_kwargs)

        # Fit the model
        test_plugin.fit(X_loader)

        # Generate synthetic data
        generated_data = test_plugin.generate(count=n_samples)

        # Extract the generated time-series data
        generated_data = generated_data.data["seq_data"]

        # Drop unnecessary columns like "seq_id", "seq_time_id"
        generated_data = generated_data.drop(columns=["seq_id", "seq_time_id"])

        # Fix potential length mismatches
        true_data, generated_data = _align_lengths(true_data, generated_data)

        # Evaluate TimeVAE generated data; both metrics use the length-aligned
        # real data, like the other methods
        print(f"LOG : true shape - {true_data.shape} VS generated shape - {generated_data.shape}")
        tvae_auc, tvae_mmd = _score_against_real(true_data, generated_data)

        # Store results for TimeVAE
        mmd_dict_tvae[filename] = tvae_mmd
        auc_dict_tvae[filename] = tvae_auc
        data_dict_tvae[filename] = generated_data.copy()

    except Exception as err:
        # Skip the dataset but say why; KeyboardInterrupt / SystemExit are no
        # longer swallowed, so the run can be interrupted.
        print(" -------------------------- OUPS -------------------------- ")
        print(f"{v['task']} failed with {type(err).__name__}: {err}")
        continue

### Store and visualize results

In [None]:
# Reformat
save_path = Path(os.getcwd()).parents[1] / "data" / "results" / "vs"
save_path.mkdir(parents=True, exist_ok=True)  # json.dump below fails if the folder is missing


def _merge_by_method(tcs, ct, sdv, tvae):
    """Combine per-method ``{dataset: score}`` dicts into ``{dataset: {method: score}}``.

    Only datasets with a TCS score are kept. A method without a score for such
    a dataset gets the placeholder 0 (rows containing a 0 are dropped again
    before plotting).
    """
    per_method = {"tcs": tcs, "ct": ct, "cpar": sdv, "tvae": tvae}
    return {
        k: {method: scores.get(k, 0) for method, scores in per_method.items()}
        for k in tcs
    }


auc_dict = _merge_by_method(auc_dict_tcs, auc_dict_ct, auc_dict_sdv, auc_dict_tvae)
mmd_dict = _merge_by_method(mmd_dict_tcs, mmd_dict_ct, mmd_dict_sdv, mmd_dict_tvae)

# Store as JSON
# - `with` closes (and flushes) every file before the next cell reads it back
# - `default=float` is only called for values json cannot encode itself
#   (e.g. NumPy or 0-dim tensor scalars) and converts them to plain floats
for _metric, _payload in (("auc", auc_dict), ("mmd", mmd_dict), ("shd", shd_dict)):
    with open(save_path / f"{FN}_{SF}_{_metric}.json", "w") as fh:
        json.dump(_payload, fh, default=float)

In [None]:
# Load JSON
save_path = Path(os.getcwd()).parents[1] / "data" / "results" / "vs"


def _load_json(path):
    """Read one JSON file and close it again."""
    with open(path, "r") as fh:
        return json.load(fh)


auc_dict = _load_json(save_path / f"{FN}_{SF}_auc.json")
mmd_dict = _load_json(save_path / f"{FN}_{SF}_mmd.json")
shd_dict = _load_json(save_path / f"{FN}_{SF}_shd.json")

# rows = datasets, columns = methods
auc_df = pd.DataFrame(auc_dict).T
mmd_df = pd.DataFrame(mmd_dict).T
# drop every dataset for which at least one method has the value 0
# (0 is the placeholder written when a method produced no score)
auc_df = auc_df.loc[~(auc_df==0).any(axis=1)]
mmd_df = mmd_df.loc[~(mmd_df==0).any(axis=1)]

# Plot
f, axs = plt.subplots(ncols=2, figsize=(12, 6))
sns.boxplot(data=auc_df, palette="pastel", ax=axs[0])
axs[0].set_title(r"AUC$_D$ comparison")
sns.boxplot(data=mmd_df, palette="pastel", ax=axs[1])
axs[1].set_title("MMD comparison")
plt.show()

prind(shd_dict)
# last expression of the cell: rendered as the cell output
auc_df