In [None]:
import sys
sys.path.append("..")
import numpy as np
import pandas as pd
import torch
import torchmetrics
from pathlib import Path
from src.modules.lcm_module import LCMModule
from src.utils.misc_utils import count_params, run_evaluation_experiments, run_cdml_evaluation_experiments
#from src.benchmarks.CausalPretraining.model.model_wrapper import Architecture_PL

import warnings
warnings.filterwarnings("ignore")

# Binary ROC-curve and AUROC metric objects from torchmetrics.
# NOTE(review): neither name is referenced in the cells visible here — presumably
# used further down or in ad-hoc analysis; confirm before removing.
roc = torchmetrics.classification.BinaryROC()
auroc = torchmetrics.classification.BinaryAUROC()

# Lift pandas' display limits so DataFrames are rendered without row/column truncation.
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)

# for reproducibility
# Seeds NumPy's global RNG and PyTorch's default generator with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

### Model Loading

In [None]:
""" Paths. Adjust them as needed. """
par_dir = Path.cwd().parent
models_path = Path("/media/nikolas/KINGSTON/LCM/logs")
# `par_dir / ...` already yields a Path, so no extra Path(...) wrapping is needed.
cp_path = par_dir / "src/benchmarks/CausalPretraining/pretrained_weights"
out_path = par_dir / "outputs"

# Create the output directory (and any missing parents); report which case occurred.
try:
    out_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
    print(f"{out_path} already exists.")
else:
    print(f"Created: {out_path}")

# Registry of methods to evaluate, keyed by display name.
# Each LCM entry is loaded from the checkpoint "<key>.ckpt" under `models_path`;
# the remaining entries carry None instead of a loaded module.
models = {
    #"CP_trf": Architecture_PL.load_from_checkpoint(Path(cp_path / "transformer.ckpt")),
    **{
        variant: LCMModule.load_from_checkpoint(models_path / f"{variant}.ckpt")
        for variant in ("LCM_2.5M", "LCM_9.4M", "LCM_12.2M", "LCM_24M")
    },
    # No checkpoint is loaded for these.
    # NOTE(review): presumably resolved by name inside the evaluation helpers — confirm.
    "PCMCI": None,
    "DYNOTEARS": None,
    "VARLINGAM": None,
}

# Print the parameter count of every entry that holds a loaded module.
for model_name, model in models.items():
    if model is None:
        # Entries without a loaded module have nothing to count.
        continue
    print(f"\n___{model_name}___")
    print(count_params(model, pretty=False))

### Test Synthetic (S_Joint)

In [None]:
# Evaluation run on the S_joint synthetic test split; `models` and `out_path`
# come from the model-loading cell.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/synthetic_data/S_/S_joint/test"),
    out_dir=out_path,
)

### Test Synthetic (Synth_230K)

In [None]:
# Evaluation run on the Synth_230K test split; `models` and `out_path`
# come from the model-loading cell.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/synth_230k_pt/test"),
    out_dir=out_path,
)

### Test Mixture of Synthetic and Simulated Data (Synth_230K_Sim_45K)

In [None]:
# Evaluation run on the Synth_230K + Sim_45K mixture. The dataset root is passed
# (no `/test` suffix) together with `sharded_data=True`.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/synth_230k_sim_45k_pt"),
    out_dir=out_path,
    sharded_data=True,
)

### Test Simulated (Sim_45K)

In [None]:
# Evaluation run on the Sim_45K data. The dataset root is passed (no `/test`
# suffix) together with `sharded_data=True`.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/simulated_45k_pt"),
    out_dir=out_path,
    sharded_data=True,
)

### fMRI_5

In [None]:
# fMRI_5 data is read from the repository-relative `data/` folder under `par_dir`.
fmri_path = par_dir / "data/fMRI_5"

run_evaluation_experiments(
    models=models,
    cpd_path=fmri_path,
    out_dir=out_path,
    fmri_data=True,
)

### fMRI_10

In [None]:
# NOTE(review): the section is titled fMRI_10 but the directory is `data/fMRI` —
# confirm this folder holds the intended dataset.
fmri_path = par_dir / "data/fMRI"

run_evaluation_experiments(
    models=models,
    cpd_path=fmri_path,
    out_dir=out_path,
    fmri_data=True,
)

### Kuramoto_5V_1L

In [None]:
# Evaluation run on the Kuramoto_5V_1L test split, passed with
# `sharded_data=False` and `kuramoto_data=True`.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/kuramoto_5V_1L/test"),
    out_dir=out_path,
    sharded_data=False,
    kuramoto_data=True,
)

### Kuramoto_10V_1L

In [None]:
# Evaluation run on the Kuramoto_10V_1L test split, passed with
# `sharded_data=False` and `kuramoto_data=True`.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/kuramoto_10V_1L/test"),
    out_dir=out_path,
    sharded_data=False,
    kuramoto_data=True,
)

### Air_quality_mini (AirQualityMS)

In [None]:
# Evaluation run on the air_quality_mini dataset, passed as a non-sharded dataset root.
run_evaluation_experiments(
    models=models,
    cpd_path=Path("/media/nikolas/KINGSTON/Datasets/air_quality_mini"),
    out_dir=out_path,
    sharded_data=False,
)

### CDML

In [None]:
# CDML evaluation inputs/outputs. They are defined once here and passed by name
# below, so the call cannot drift from these values (the original assigned both
# variables and then repeated the same literals in the call).
cdml_path = Path("/media/nikolas/KINGSTON/Datasets/cdml")
# NOTE(review): every other evaluation cell passes `out_path`; this hard-coded
# directory is presumably the same location on the author's machine — confirm and unify.
out_dir = Path("/home/nikolas/LCM/outputs")

# NOTE(review): MAX_VAR / MAX_LAG look like upper bounds on variable count and lag —
# verify against run_cdml_evaluation_experiments.
run_cdml_evaluation_experiments(
    models=models,
    cdml_path=cdml_path,
    out_dir=out_dir,
    MAX_VAR=12,
    MAX_LAG=3,
)