In [1]:
from time import perf_counter
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
from joblib import Parallel, delayed
import os
import json
import dag_utils as utils
from Baselines import Nonneg_dagma, MetMulDagma
from Baselines import colide_ev
from Baselines import DAGMA_linear
from Baselines import notears_linear
from BUILD import BUILD
from utils import *
# Notebook-level configuration.
# Relative directory for sample-size experiment results; not referenced by the
# cells visible here -- presumably consumed by later cells (TODO confirm).
PATH = './results/samples/'
# Presumably toggles writing results to disk in later cells -- TODO confirm usage.
SAVE = True 
# Seed handed to NumPy's global random state below.
SEED = 10
# Value reported by os.cpu_count(); note that call may return None on some platforms.
N_CPUS = os.cpu_count() 
# Seed NumPy's legacy global RNG once, at notebook start.
np.random.seed(SEED)

In [1]:

def _to_jsonable(obj):
    """Recursively convert objects to JSON-serializable forms."""
    import numpy as np
    from pathlib import Path

    # basic types
    if obj is None or isinstance(obj, (bool, int, float, str)):
        # cast numpy scalars to Python
        if isinstance(obj, (np.bool_, np.integer, np.floating)):
            return obj.item()
        return obj

    # numpy arrays -> lists (careful for huge arrays: we won't dump arrays here)
    if isinstance(obj, np.ndarray):
        return obj.tolist()

    # tuples -> lists
    if isinstance(obj, tuple):
        return [_to_jsonable(x) for x in obj]

    # lists
    if isinstance(obj, list):
        return [_to_jsonable(x) for x in obj]

    # dicts
    if isinstance(obj, dict):
        return {str(k): _to_jsonable(v) for k, v in obj.items()}

    # Path
    if isinstance(obj, Path):
        return str(obj)

    # dataclasses
    try:
        from dataclasses import is_dataclass, asdict
        if is_dataclass(obj):
            return _to_jsonable(asdict(obj))
    except Exception:
        pass

    # callables (functions/classes)
    if callable(obj):
        # try to capture module + name; fall back to repr
        name = getattr(obj, "__name__", obj.__class__.__name__)
        mod  = getattr(obj, "__module__", None)
        return f"{mod+'.' if mod else ''}{name}"

    # objects with __dict__
    if hasattr(obj, "__dict__"):
        return _to_jsonable(vars(obj))

    # fallback
    return repr(obj)


In [2]:

from __future__ import annotations
import os, json, math, uuid, shutil, datetime as dt
from dataclasses import dataclass, asdict, field
from pathlib import Path
from time import perf_counter
from typing import Callable, Dict, Any, List, Optional, Tuple

import numpy as np
from sklearn.metrics import f1_score
from Baselines import GOLEM_Torch
import matplotlib.pyplot as plt

def get_lambda_value(n_nodes: int, n_samples: int, times: float = 1.0) -> float:
    """Regularization heuristic ``times * sqrt(log(p) / n)``.

    ``p`` is ``n_nodes`` and ``n`` is ``n_samples``; each is floored at 2, and
    the logarithm is floored at 1e-12 before the division.
    """
    node_count = max(2, n_nodes)
    sample_count = max(2, n_samples)
    log_term = max(1e-12, np.log(node_count))
    base_rate = math.sqrt(log_term / sample_count)
    return base_rate * times


@dataclass
class BaselineSpec:
    """Description of one baseline method evaluated by ``ExperimentRunner``."""

    # Either a plain callable without a ``fit`` attribute (invoked directly on the
    # data) or a class/factory that is instantiated and then ``.fit(X, ...)``-ed.
    model: Any
    # Keyword arguments for constructing ``model`` (class-style baselines only).
    init: Dict[str, Any] = field(default_factory=dict)
    # Keyword arguments forwarded to ``fit`` or to the function-style call.
    args: Dict[str, Any] = field(default_factory=dict)
    # Label used in progress output and in the saved manifest.
    name: str = "baseline"
    # When True the runner feeds this baseline ``X_std`` instead of the raw ``X``.
    standardize: bool = False
    # Asks ``ExperimentRunner.run`` to rescale the "lamb" / "lambda1" entries of
    # ``args`` with ``get_lambda_value`` for each sample size.
    adapt_lambda: bool = False
    # When True the estimated adjacency (and its binarized form) is transposed
    # before metrics are computed.
    topo_transpose: bool = False
    # Function-style baseline called as ``model(X, emp_cov, **args)`` whose result
    # is read through the "A_est" and "prec" keys.
    is_topogreedy_refresh: bool = False


@dataclass
class ExperimentConfig:
    """Settings for one ``ExperimentRunner`` sweep over graphs and sample sizes."""

    # Number of outer iterations (one simulated problem per graph index).
    n_graphs: int
    # Number of nodes; injected into the simulation kwargs as "n_nodes" and used
    # to size the stored N x N estimate arrays.
    n_nodes: int
    # Sample sizes to sweep; each is injected as "n_samples" for the simulation.
    n_samples_list: List[int]
    # Threshold handed to ``utils.to_bin`` when binarizing adjacency matrices.
    edge_threshold: float
    # Extra keyword arguments for ``utils.simulate_sem``; its "standarize" key
    # (spelling as used by the code) also controls whether ``X`` is standardized.
    data_params: Dict[str, Any]
    # Baselines evaluated on every (graph, sample size) pair.
    baselines: List[BaselineSpec]
    # Root directory under which a per-run subdirectory is created.
    out_dir: str = "./exp_results"
    # Optional run identifier; a timestamp + random suffix is used when falsy.
    run_tag: Optional[str] = None
    # When True an ``.npz`` block is written after each (graph, sample size) pair.
    save_intermediate: bool = True
    # Stored in the manifest. NOTE(review): the runner does not visibly apply it
    # to any random generator -- confirm how seeding is meant to work.
    seed_offset: int = 0 


class ExperimentRunner:
    """Evaluate every configured baseline on simulated SEM data and persist results.

    For each graph index and each sample size, data is simulated with
    ``utils.simulate_sem``, every baseline in ``cfg.baselines`` is run on it, and
    accuracy/runtime metrics are collected into arrays of shape
    ``(n_graphs, len(n_samples_list), len(baselines))``.

    Creating an instance has side effects: it creates ``<out_dir>/<run_id>/`` and
    writes ``config.json`` there.
    """

    def __init__(self, cfg: ExperimentConfig):
        """Store ``cfg``, create the run directory and write the manifest.

        The run id is ``cfg.run_tag`` when truthy, otherwise
        ``YYYYmmdd_HHMMSS_<6 hex chars>``.
        """
        self.cfg = cfg
        self.run_id = cfg.run_tag or (
            dt.datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]
        )
        self.out_root = Path(cfg.out_dir) / self.run_id
        self.out_root.mkdir(parents=True, exist_ok=True)
        self._save_manifest()

    def run(self, verbose: bool = True):
        """Run the full sweep and save ``final_results.npz`` in the run directory.

        Parameters
        ----------
        verbose : bool
            Print per-graph / per-baseline progress lines (including the
            condition number of the true precision matrix).

        Returns
        -------
        tuple
            ``(final, data)`` where ``final`` maps metric names to arrays of
            shape ``(G, S, B)`` (plus ``W_est_all`` / ``Theta_est_all`` of shape
            ``(G, S, B, N, N)``) and ``data`` is the dict returned by the last
            ``_save_block`` call, or ``None`` if no block was saved.
        """
        B = len(self.cfg.baselines)
        S = len(self.cfg.n_samples_list)
        N = self.cfg.n_nodes
        G = self.cfg.n_graphs

        # Metrics tensors: (G, S, B)
        shd = np.zeros((G, S, B))
        tpr = np.zeros((G, S, B))
        fdr = np.zeros((G, S, B))
        fscore = np.zeros((G, S, B))
        err = np.zeros((G, S, B))
        runtime = np.zeros((G, S, B))
        dag_count = np.zeros((G, S, B))
        theta_diff = np.zeros((G, S, B))

        W_est_all = np.zeros((G, S, B, N, N))
        Theta_est_all = np.zeros((G, S, B, N, N))

        # Last saved block. Must exist even when save_intermediate is False or
        # there is nothing to iterate over, otherwise the final return fails.
        data = None

        for g in range(G):
            if verbose:
                print(f"\n=== Graph {g+1}/{G} ===")

            # NOTE(review): cfg.seed_offset is not applied to any RNG here; data
            # generation relies on whatever random state the caller has set up.
            data_p = dict(self.cfg.data_params)
            data_p["n_nodes"] = self.cfg.n_nodes

            for si, n_samples in enumerate(self.cfg.n_samples_list):
                data_p_this = dict(data_p)
                data_p_this["n_samples"] = int(n_samples)

                # simulate SEM
                W_true, _, X, Theta_true = utils.simulate_sem(**data_p_this)

                X_std = utils.standarize(X) if data_p_this.get("standarize", False) else X
                W_true_bin = utils.to_bin(W_true, self.cfg.edge_threshold)
                norm_W_true = np.linalg.norm(W_true)
                # emp_cov = (X_std.T @ X_std) / float(X_std.shape[0])
                emp_cov = np.cov(X_std, rowvar = False)
                if verbose:
                    print(f"cond: {np.linalg.cond(Theta_true)}")
                    print(f"- samples={n_samples}, edges≈{np.count_nonzero(W_true_bin)}")

                for bi, base in enumerate(self.cfg.baselines):
                    X_in = X_std if base.standardize else X
                    # Work on a copy so the spec's own args are never mutated.
                    args_call = dict(base.args)

                    # adaptive λ scheduling
                    if base.adapt_lambda:
                        if "lamb" in args_call:
                            args_call["lamb"] = get_lambda_value(self.cfg.n_nodes, n_samples, args_call["lamb"])
                        if "lambda1" in args_call:
                            args_call["lambda1"] = get_lambda_value(self.cfg.n_nodes, n_samples, args_call["lambda1"])

                    t0 = perf_counter()
                    # Pass the (possibly rescaled) arguments explicitly; without
                    # this the adaptive λ computed above would be discarded.
                    W_est, Theta_est = self._run_one_baseline(
                        base=base,
                        X=X_in,
                        X_std=X_std,
                        emp_cov=emp_cov,
                        edge_thr=self.cfg.edge_threshold,
                        args=args_call,
                    )
                    t1 = perf_counter()

                    # A diverged estimate (any NaN) is scored as the empty graph.
                    if np.isnan(W_est).any():
                        W_est = np.zeros_like(W_est)
                        W_bin = np.zeros_like(W_est)
                    else:
                        W_bin = utils.to_bin(W_est, self.cfg.edge_threshold)

                    if base.topo_transpose:
                        W_est = W_est.T
                        W_bin = W_bin.T

                    # ---------- metrics ----------
                    shd[g, si, bi], tpr[g, si, bi], fdr[g, si, bi] = utils.count_accuracy(W_true_bin, W_bin)
                    fscore[g, si, bi] = f1_score(W_true_bin.flatten(), W_bin.flatten())
                    err[g, si, bi] = utils.compute_norm_sq_err(W_true, W_est, norm_W_true)
                    runtime[g, si, bi] = (t1 - t0)
                    dag_count[g, si, bi] = 1.0 if utils.is_dag(W_bin) else 0.0

                    # Baselines without a precision estimate get 0.0 here.
                    if Theta_est is None:
                        theta_diff[g, si, bi] = 0.0
                    else:
                        Theta_norm = np.linalg.norm(Theta_true, "fro")
                        theta_diff[g, si, bi] = utils.compute_norm_sq_err(Theta_true, Theta_est, Theta_norm)

                    # store estimates
                    W_est_all[g, si, bi] = W_est
                    if Theta_est is not None:
                        Theta_est_all[g, si, bi] = Theta_est

                    if verbose:
                        print(f"  · {base.name:<18s} | SHD {shd[g,si,bi]:.1f} | F1 {fscore[g,si,bi]:.3f} | "
                              f"ΘΔ {theta_diff[g,si,bi]:.3f} | {runtime[g,si,bi]:.2f}s")

                if self.cfg.save_intermediate:
                    data = self._save_block(
                        g=g, si=si,
                        W_true=W_true, Theta_true=Theta_true,
                        W_est_all=W_est_all[g, si],
                        Theta_est_all=Theta_est_all[g, si],
                        shd=shd[g, si], tpr=tpr[g, si], fdr=fdr[g, si],
                        f1=fscore[g, si], err=err[g, si], rt=runtime[g, si],
                        dags=dag_count[g, si], theta_diff=theta_diff[g, si]
                    )

        final = dict(
            shd=shd, tpr=tpr, fdr=fdr, f1=fscore, err=err,
            runtime=runtime, dag_count=dag_count, theta_diff=theta_diff,
            W_est_all=W_est_all, Theta_est_all=Theta_est_all
        )
        np.savez_compressed(self.out_root / "final_results.npz", **final)
        if verbose:
            print(f"\nSaved results to: {self.out_root}")

        return final, data

    def _run_one_baseline(
        self,
        base: BaselineSpec,
        X: np.ndarray,
        X_std: np.ndarray,
        emp_cov: np.ndarray,
        edge_thr: float,
        args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Fit a single baseline and return ``(W_est, Theta_est)``.

        Parameters
        ----------
        base : BaselineSpec
            Baseline description; ``model``, ``init``, ``args`` and
            ``is_topogreedy_refresh`` are read.
        X : np.ndarray
            Data matrix passed to the baseline; ``X.shape[1]`` sizes the
            all-zero fallback estimate.
        X_std, edge_thr :
            Accepted for interface compatibility; not used by this method.
        emp_cov : np.ndarray
            Forwarded only to baselines flagged ``is_topogreedy_refresh``.
        args : dict, optional
            Keyword arguments for the fit/call. Defaults to ``base.args`` when
            omitted, so callers can supply per-run overrides (e.g. a rescaled
            λ) without mutating the spec.

        Returns
        -------
        tuple
            Estimated adjacency matrix (zeros if the baseline produced none)
            and the estimated precision matrix or ``None``.
        """
        call_args = base.args if args is None else args

        # Function-style baselines: callable, but not a class exposing ``fit``.
        if callable(base.model) and not hasattr(base.model, "fit"):
            if not base.is_topogreedy_refresh:
                return base.model(X, **call_args), None

            # TopoGreedy_refresh signature: takes the covariance too and returns
            # a mapping read through the "A_est" / "prec" keys.
            out = base.model(X, emp_cov, **call_args)
            W_est = out.get("A_est", None)
            Theta_est = out.get("prec", None)
            if W_est is None:
                W_est = np.zeros((X.shape[1], X.shape[1]))
            return W_est, Theta_est

        # Class-style baselines: construct, fit, then read the fitted attributes.
        model = base.model(**base.init)
        model.fit(X, **call_args)

        W_est = getattr(model, "W_est", None)
        if W_est is None:
            W_est = np.zeros((X.shape[1], X.shape[1]))

        # "Theta_est" takes precedence over "prec" whenever the attribute exists.
        if hasattr(model, "Theta_est"):
            Theta_est = model.Theta_est
        elif hasattr(model, "prec"):
            Theta_est = model.prec
        else:
            Theta_est = None

        return W_est, Theta_est

    def _save_manifest(self):
        """Write ``config.json`` (run id, creation time, sanitized config) to the run directory."""
        sanitized_cfg = {
            "n_graphs": self.cfg.n_graphs,
            "n_nodes": self.cfg.n_nodes,
            "n_samples_list": list(self.cfg.n_samples_list),
            "edge_threshold": self.cfg.edge_threshold,
            "data_params": _to_jsonable(self.cfg.data_params),
            "baselines": [
                {
                    "name": b.name,
                    "model": _to_jsonable(b.model),
                    "init": _to_jsonable(b.init),
                    "args": _to_jsonable(b.args),
                    "standardize": b.standardize,
                    "adapt_lambda": b.adapt_lambda,
                    "topo_transpose": b.topo_transpose,
                    "is_topogreedy_refresh": b.is_topogreedy_refresh,
                }
                for b in self.cfg.baselines
            ],
            "out_dir": str(self.cfg.out_dir),
            "run_tag": self.cfg.run_tag,
            "save_intermediate": self.cfg.save_intermediate,
            "seed_offset": self.cfg.seed_offset,
        }

        manifest = {
            "run_id": self.run_id,
            "created_at": dt.datetime.now().isoformat(),
            "config": sanitized_cfg,
            "python": {"numpy_version": np.__version__},
        }

        # The whole manifest goes through _to_jsonable once more so any value
        # not sanitized above is still converted before encoding.
        (self.out_root / "config.json").write_text(json.dumps(_to_jsonable(manifest), indent=2))

    def _save_block(
        self, g: int, si: int,
        W_true: np.ndarray, Theta_true: np.ndarray,
        W_est_all: np.ndarray,
        Theta_est_all: np.ndarray,
        shd: np.ndarray, tpr: np.ndarray, fdr: np.ndarray,
        f1: np.ndarray, err: np.ndarray, rt: np.ndarray,
        dags: np.ndarray, theta_diff: np.ndarray
    ):
        """Persist the results of one (graph, sample size) pair.

        Writes ``graph_<g>/samples_<n>/block.npz`` under the run directory and
        returns the same content as a dict, with the file path under
        ``'filename'``.
        """
        sub = self.out_root / f"graph_{g:03d}" / f"samples_{self.cfg.n_samples_list[si]}"
        sub.mkdir(parents=True, exist_ok=True)
        filename = sub / "block.npz"
        np.savez_compressed(
            filename,
            W_true=W_true, Theta_true=Theta_true,
            W_est_all=W_est_all, Theta_est_all=Theta_est_all,
            shd=shd, tpr=tpr, fdr=fdr, f1=f1, err=err, runtime=rt,
            dag_count=dags, theta_diff=theta_diff
        )

        return {
            'filename': str(filename),
            'W_true': W_true,
            'Theta_true': Theta_true,
            'W_est_all': W_est_all,
            'Theta_est_all': Theta_est_all,
            'shd': shd,
            'tpr': tpr,
            'fdr': fdr,
            'f1': f1,
            'err': err,
            'runtime': rt,
            'dag_count': dags,
            'theta_diff': theta_diff
        }

