In [None]:
# =========================================================
# DML (mediated) with RKHS 
# =========================================================

# ---- Limit BLAS/OpenMP threads BEFORE importing heavy libs ----
import os as os
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# ---- Standard libs ----
import sys
import time
import platform
from pathlib import Path

# ---- Third-party ----
import numpy as np
import matplotlib.pyplot as plt
from threadpoolctl import threadpool_limits

# Keep native libraries to 1 thread (NumPy/SciPy/BLAS/OpenMP)
threadpool_limits(1)

# ---- Local repo imports (adjust path if needed) ----
sys.path.append(str(Path.cwd() / "../../simulations"))
import dgps_mediated as dgps  

from nnpiv.rkhs import RKHSIV, RKHSIVL2, RKHS2IVL2 
from nnpiv.semiparametrics import DML_mediated  

# PyTorch presence for reproducibility reporting
try:
    import torch  
    TORCH_OK = True
except Exception:
    TORCH_OK = False


# -----------------------
# Reproducibility helpers
# -----------------------
def seed_everything(seed: int = 123) -> None:
    """Set seeds for reproducibility."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    if TORCH_OK:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

seed_everything(123)


# -----------------------
# Resource print utility
# -----------------------
def print_resources():
    """Print basic compute resource info (CPU, GPU, library versions)."""
    cpu_cores = os.cpu_count()
    pyver = sys.version.split()[0]
    npver = np.__version__
    torch_info = "not installed"
    gpu_info = "CUDA: not available"
    if TORCH_OK:
        torch_info = torch.__version__
        if torch.cuda.is_available():
            try:
                name = torch.cuda.get_device_name(0)
            except Exception:
                name = "Unknown GPU"
            gpu_info = f"CUDA: available — {name}"
    print("=== Compute resources ===")
    print(f"Python: {pyver}")
    print(f"NumPy: {npver}")
    print(f"PyTorch: {torch_info}")
    print(f"CPU cores: {cpu_cores}")
    print(gpu_info)
    print(f"Platform: {platform.platform()}")
    print("=========================\n")


# -----------------------
# Result formatter
# -----------------------
def summarize_dml_result(name: str, result, elapsed: float):
    """
    Accepts result from .dml() and prints θ, SE, 95% CI when available.
    Compatible with returns like (theta, var, ci) or (theta, var, ci, cov).
    """
    # Normalize shapes
    if isinstance(result, tuple):
        if len(result) == 3:
            theta, var, ci = result
            cov = None
        elif len(result) == 4:
            theta, var, ci, cov = result
        else:
            # Unknown structure: just print and return
            print(f"[{name}] time={elapsed:.2f}s — result={result}")
            return
    else:
        print(f"[{name}] time={elapsed:.2f}s — result={result}")
        return

    theta = np.atleast_1d(theta).astype(float)
    var = np.atleast_1d(var).astype(float)
    se = np.sqrt(var)
    ci = np.array(ci, dtype=float) if ci is not None else None

    def fmt_arr(a):  # nice scalar/array printer
        return f"{float(a[0]):.4f}" if a.size == 1 else np.array2string(a, precision=4)

    print(f"[{name}] time={elapsed:.2f}s")
    print(f"  theta: {fmt_arr(theta)}")
    print(f"  SE   : {fmt_arr(se)}")
    if ci is not None:
        if ci.ndim == 1 and ci.size == 2:
            print(f"  95% CI: [{ci[0]:.4f}, {ci[1]:.4f}]")
        else:
            print(f"  95% CI: {np.array2string(ci, precision=4)}")
    if cov is not None:
        print(f"  (cov shape: {cov.shape})")
    print("")

In [2]:
# -----------------------
# Print resources 
# -----------------------
print_resources()    

=== Compute resources ===
Python: 3.10.18
NumPy: 2.2.6
PyTorch: 2.5.0
CPU cores: 112
CUDA: not available
Platform: Linux-4.18.0-553.44.1.el8_10.x86_64-x86_64-with-glibc2.28



In [3]:
# =========================================================
# Data generation 
# =========================================================
# Function dictionary (for reference):
# {'abs': 0, '2dpoly': 1, 'sigmoid': 2, 'sin': 3, 'frequent_sin': 4, 'abs_sqrt': 5,
#  'step': 6, '3dpoly': 7, 'linear': 8, 'rand_pw': 9, 'abspos': 10, 'sqrpos': 11,
#  'band': 12, 'invband': 13, 'steplinear': 14, 'pwlinear': 15, 'exponential': 16}

fn_number = 0
tau_fn = dgps.get_tau_fn(fn_number)
tauinv_fn = dgps.get_tauinv_fn(fn_number)

W, Z, X, M, D, Y, tau_fn = dgps.get_data(2000, tau_fn)


# Ground-truth value for the target estimand (for reference in logs)
TRUE_PARAM = 4.05
print(f"=== Ground truth ===\nTrue parameter for E[Y(1,M(0))] = {TRUE_PARAM:.2f}\n")

=== Ground truth ===
True parameter for E[Y(1,M(0))] = 4.05



In [5]:
# -----------------------
# 1) Sequential estimator (MR)
# -----------------------
rkhs_model = RKHSIVL2(kernel="rbf", gamma=0.013, delta_scale="auto", delta_exp=0.4)
dml_seq = DML_mediated(
    Y, D, M, W, Z, X,
    estimator="MR",
    estimand="E[Y(1,M(0))]",
    model1=[rkhs_model, rkhs_model],
    modelq1=[rkhs_model, rkhs_model],
    nn_1=[False,False],
    nn_q1=[False,False],
    fitargs1=[None, None],
    fitargsq1=[None, None],
    n_folds=5, n_rep=1
)
t0 = time.perf_counter()
res_seq = dml_seq.dml()
t1 = time.perf_counter()
summarize_dml_result("Sequential (MR) with RKHSIVL2", res_seq, t1 - t0)


# -----------------------
# 2) Simultaneous estimator (OR)
# -----------------------
dml_sim_or = DML_mediated(
    Y, D, M, W, Z, X,
    estimator="OR",
    estimand="E[Y(1,M(0))]",
    model1=RKHS2IVL2(kernel="rbf", gamma=0.01, delta_scale="auto", delta_exp=0.4),
    nn_1=False,
    modelq1=RKHS2IVL2(kernel="rbf", gamma=0.01, delta_scale="auto", delta_exp=0.4),
    nn_q1=False,
    n_folds=5, n_rep=1
)
t0 = time.perf_counter()
res_sim_or = dml_sim_or.dml()
t1 = time.perf_counter()
summarize_dml_result("Simultaneous (OR) with RKHS2IVL2", res_sim_or, t1 - t0)


# -----------------------
# 3) Simultaneous estimator (MR) — different tuning
# -----------------------
dml_sim_mr = DML_mediated(
    Y, D, M, W, Z, X,
    estimator="MR",
    estimand="E[Y(1,M(0))]",
    model1=RKHS2IVL2(kernel="rbf", gamma=0.0013, delta_scale="auto", delta_exp=10),
    nn_1=False,
    modelq1=RKHS2IVL2(kernel="rbf", gamma=0.0013, delta_scale="auto", delta_exp=10),
    nn_q1=False,
    n_folds=5, n_rep=1
)
t0 = time.perf_counter()
res_sim_mr = dml_sim_mr.dml()
t1 = time.perf_counter()
summarize_dml_result("Simultaneous (MR) with RKHS2IVL2 (alt tuning)", res_sim_mr, t1 - t0)


Rep: 1


100%|██████████| 5/5 [00:02<00:00,  2.16it/s]


[Sequential (MR) with RKHSIVL2] time=2.31s
  theta: 4.1587
  SE   : 4.8541
  95% CI: [3.9460, 4.3715]

Rep: 1


100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


[Simultaneous (OR) with RKHS2IVL2] time=6.52s
  theta: 4.0575
  SE   : 2.7676
  95% CI: [3.9362, 4.1788]

Rep: 1


100%|██████████| 5/5 [00:13<00:00,  2.61s/it]

[Simultaneous (MR) with RKHS2IVL2 (alt tuning)] time=13.06s
  theta: 4.1157
  SE   : 4.5768
  95% CI: [3.9151, 4.3162]




