In [None]:
# -*- coding: utf-8 -*-
"""
IFA + CF causal pipeline — advanced, CF-based alignment & metrics, bootstrap, Graphviz plotting.
- ورودی: B_manual (m x m), M_manual (m x k_true)
- خروجی: results JSON, PNG تصاویر گراف واقعی و تخمینی، معیارها چاپ شده
"""

!pip install arabic_reshaper python-bidi
import os
import json
import subprocess
import tempfile
from dataclasses import dataclass, field
from scipy.optimize import linear_sum_assignment
from scipy.fft import fft
from sklearn.linear_model import Ridge
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import networkx as nx
import arabic_reshaper
from bidi.algorithm import get_display

#!pip install arabic_reshaper python-bidi
# ---------------------------
# Config dataclasses
# ---------------------------
@dataclass
class SCMConfig:
    """Ground-truth linear SCM parameters (presumably consumed by generate_data — confirm in run_pipeline_advanced)."""
    B: np.ndarray  # (m x m) structural matrix; per project_B_to_dag's convention B[i, j] weights the edge j -> i
    M: np.ndarray  # (m x k) mixing matrix from latent sources to observed variables
    sigma2: float = 0.05  # additive noise variance (presumably forwarded to generate_data's `sigma2`)
    seed: int = 42  # RNG seed (presumably forwarded to generate_data's `seed`)

@dataclass
class IFAConfig:
    """Settings for the IFA learner (the consumer is not visible here — presumably run_pipeline_advanced)."""
    k_est: int = 8  # number of latent factors to estimate
    max_iter: int = 1000  # iteration cap for the learner
    tol: float = 1e-6  # convergence tolerance
    reg_latent: float = 1e-3  # latent regularization strength — TODO confirm where it is applied

@dataclass
class BootstrapConfig:
    """Bootstrap-stability and pruning settings; names and defaults mirror bootstrap_stability / prune_B_advanced."""
    repeats: int = 40  # number of bootstrap resamples (bootstrap_stability's `repeats`)
    tau_A: float = 0.12  # |A| threshold for counting a loading as present (bootstrap_stability's `tau_A`)
    tau_B: float = 0.06  # |B| threshold for counting an edge as present (bootstrap_stability's `tau_B`)
    keep_prob: float = 0.6  # minimum bootstrap selection frequency to keep an entry (prune_B_advanced's `keep_prob`)

@dataclass
class MOGConfig:
    """Mixture-of-Gaussians settings for the latent sources drawn by ``generate_data``.

    NOTE: no trailing commas after the defaults — ``x: int = 3,`` would make the
    default the 1-tuple ``(3,)``. A tuple breaks ``np.linspace(..., n_components)``
    and never matches the strategy strings.
    """
    # number of mixture components per latent source
    n_components: int = 3
    # 'spread', 'clustered' or 'asymmetric' (generate_data treats any other value as 'asymmetric')
    means_strategy: str = 'spread'
    # 'low', 'moderate' or 'high' (generate_data treats any other value as 'high')
    variance_strategy: str = 'moderate'
    # 'balanced' or 'imbalanced' (generate_data treats any other value as 'imbalanced')
    weights_strategy: str = 'balanced'


@dataclass
class RunConfig:
    """Top-level settings for one pipeline run, grouping the per-stage config objects."""
    n: int = 4000  # presumably the number of samples to simulate (run_comprehensive_data_tests passes 10000)
    scm: "SCMConfig | None" = None  # ground-truth SCM; defaults to None, so the caller presumably must supply it
    mog:MOGConfig=field(default_factory=MOGConfig)  # latent-source mixture settings
    ifa: IFAConfig = field(default_factory=IFAConfig)  # IFA learner settings
    boot: BootstrapConfig = field(default_factory=BootstrapConfig)  # bootstrap / pruning settings
    senario: str = "default"  # scenario label; misspelled field name kept because callers pass `senario=`
    outdir: str = "./ifa_cf_result"  # output directory (presumably for the JSON/PNG results named in the module docstring)
    use_graphviz: bool = True  # presumably toggles Graphviz plotting — confirm in run_pipeline_advanced

#-------------------------
# Functions for tests
#-------------------------
def run_comprehensive_data_tests():
    """Run the full pipeline over a fixed list of data-generation scenarios.

    For every scenario the function:
      1. builds a 5-node / 4-latent ground truth with ``generate_different_structures``;
      2. rescales it with ``optimize_matrix_scaling``;
      3. runs ``run_pipeline_advanced`` with a ``RunConfig``.

    A scenario whose pipeline run raises is recorded as ``{'error': <message>}``,
    and the remaining scenarios still run.

    Returns:
        dict mapping scenario name -> pipeline result (or error dict).
    """
    def _scenario(name, n_components, means, variances, weights, scale, structure):
        # Compact constructor for one scenario description.
        return {
            'name': name,
            'mog_config': {'n_components': n_components, 'means_strategy': means,
                           'variance_strategy': variances, 'weights_strategy': weights},
            'scale_strategy': scale,
            'structure_type': structure,
        }

    scenarios = [
        _scenario('Complex_Structure', 3, 'spread', 'moderate', 'balanced', 'moderate', 'hierarchical'),
        _scenario('Baseline', 2, 'spread', 'moderate', 'balanced', 'moderate', 'chain'),
        _scenario('Strong_Signal', 3, 'spread', 'low', 'balanced', 'strong', 'chain'),
        _scenario('Challenging_MoG', 4, 'clustered', 'high', 'imbalanced', 'moderate', 'chain'),
    ]

    outcomes = {}
    for spec in scenarios:
        label = spec['name']
        print(f"Running scenario: {label}")

        # Ground-truth structure for this scenario, then rescaled.
        B, M = generate_different_structures(5, 4, spec['structure_type'])
        B, M = optimize_matrix_scaling(B, M, spec['scale_strategy'])

        cfg = RunConfig(
            n=10000,
            scm=SCMConfig(B=B, M=M, sigma2=4.3, seed=42),
            mog=MOGConfig(**spec['mog_config']),
            ifa=IFAConfig(k_est=3, max_iter=300, tol=1e-6, reg_latent=1e-4),
            boot=BootstrapConfig(repeats=100, tau_A=0.2, tau_B=0.1, keep_prob=0.5),
            outdir=f"./ifa_cf_result/{label}",
            senario=label,
            use_graphviz=True,
        )

        # Harness boundary: report the failure and continue with the next scenario.
        try:
            outcomes[label] = run_pipeline_advanced(B, M, cfg)
        except Exception as e:
            print(f"Scenario {label} failed: {e}")
            outcomes[label] = {'error': str(e)}

    return outcomes
#---------------------------
# Generate structure
#--------------------------
def generate_different_structures(m, k, structure_type='chain'):
    """Build a ground-truth (B, M) pair for robustness tests.

    Convention (as in ``project_B_to_dag``): ``B[i, j]`` is the weight of the edge
    j -> i. ``generate_data`` samples ``X = (I - B)^{-1} (M Q + E)``.

    Args:
        m: number of observed variables (B is m x m).
        k: number of latent sources (M is m x k).
        structure_type: one of
            - 'chain';
            - 'hierarchical' (requires m >= 4);
            - 'random' (uses the global ``np.random`` state).

    Returns:
        (B, M) float arrays of shapes (m, m) and (m, k).

    Raises:
        ValueError: for an unknown ``structure_type``, or 'hierarchical' with m < 4.
            Previously these cases fell through and raised UnboundLocalError.
    """
    if structure_type == 'chain':
        # B[i, i+1] = 0.8: consecutive nodes linked by the edge i+1 -> i.
        B = np.zeros((m, m))
        idx = np.arange(m - 1)
        B[idx, idx + 1] = 0.8
        M = np.eye(m, k)

    elif structure_type == 'hierarchical':
        if m < 4:
            raise ValueError(f"'hierarchical' structure needs m >= 4 (got m={m})")
        B = np.zeros((m, m))
        # Edges from nodes {2, 3} into nodes {0, 1}.
        B[0:2, 2:4] = 0.7
        # Edges from nodes {4, ..., m-1} into nodes {2, 3}.
        B[2:4, 4:m] = 0.6

        # Latents 0..3 load one-to-one on nodes 0..3 (block-diagonal eye(2), eye(2)).
        # With k < 4 only the first k of those loadings exist, instead of a broadcast error.
        M = np.zeros((m, k))
        kk = min(4, k)
        M[:4, :kk] = np.eye(4, kk)

    elif structure_type == 'random':
        # Random edges with controlled density, drawn from the global NumPy RNG
        # (call np.random.seed beforehand for reproducibility).
        # NOTE: acyclicity is not guaranteed — both (i, j) and (j, i) may be drawn.
        B = np.zeros((m, m))
        density = 0.3
        indices = np.random.choice(m * m, int(m * m * density), replace=False)
        for idx in indices:
            i, j = divmod(idx, m)
            if i != j:  # no self-loops
                B[j, i] = np.random.uniform(0.5, 0.9)

        M = np.random.randn(m, k) * 0.5
        M[np.abs(M) < 0.3] = 0  # sparsify small loadings

    else:
        raise ValueError(f"unknown structure_type {structure_type!r}; "
                         "expected 'chain', 'hierarchical' or 'random'")

    return B, M

#---------------------------
# Align Column
#---------------------------
def align_columns(A_ref, A_b):
    """Permute and sign-flip the columns of ``A_b`` to best match ``A_ref``.

    Args:
        A_ref: reference matrix (m x k).
        A_b: matrix estimated on a bootstrap sample (m x k_b). k_b >= k is required,
            because the permuted matrix is multiplied elementwise with A_ref.

    Returns:
        ``A_b`` with its columns reordered and sign-corrected:
          - ordering: Hungarian matching on the absolute Pearson correlation;
          - signs: each column is flipped so its inner product with the matched
            ``A_ref`` column is non-negative.
    """
    k = A_ref.shape[1]

    # 1) Similarity = |corr| between columns. A constant column yields NaN
    #    correlations, which linear_sum_assignment rejects; score those as 0.
    with np.errstate(invalid='ignore', divide='ignore'):
        corr = np.corrcoef(A_ref.T, A_b.T)[:k, k:]
    sim = np.nan_to_num(np.abs(corr), nan=0.0, posinf=0.0, neginf=0.0)

    # 2) Hungarian algorithm on -sim maximizes the total similarity.
    _, col_ind = linear_sum_assignment(-sim)

    # 3) Reorder the columns of A_b.
    A_perm = A_b[:, col_ind]

    # 4) Sign alignment. A zero inner product would give sign 0 and wipe the
    #    column, so such columns are left unflipped.
    signs = np.sign(np.sum(A_ref * A_perm, axis=0))
    signs[signs == 0] = 1.0
    return A_perm * signs



#----------------------------
# sign alignment
#-----------------------------
# --- sign alignment to a stable reference across bootstraps ---
def _col_ref_signs(A):
    # return sign = +1 or -1 for each column based on largest-abs entry
    k = A.shape[1]
    signs = np.ones(k)
    for j in range(k):
        idx = np.argmax(np.abs(A[:, j]))
        s = np.sign(A[idx, j])
        if s == 0: s = 1.0
        signs[j] = s
    return signs

# final sign correction relative to A_true (if available)
def final_sign_correction(A_est, M_est, A_ref):
    """Flip column signs of ``A_est`` and, in lockstep, the same columns of ``M_est``.

    Each column j of A_est is handled as follows:
      - if ``A_ref`` has a column j: flip when the inner product with ``A_ref[:, j]`` is negative;
      - otherwise (or when ``A_ref`` is None): flip when the column's largest-magnitude entry is negative.

    Inputs are not modified.

    Returns:
        (A_corrected, M_corrected) copies.
    """
    A_out = A_est.copy()
    M_out = M_est.copy()
    n_ref_cols = 0 if A_ref is None else A_ref.shape[1]

    for col in range(A_est.shape[1]):
        if col < n_ref_cols:
            negative = np.dot(A_out[:, col], A_ref[:, col]) < 0
        else:
            pivot = np.argmax(np.abs(A_out[:, col]))
            negative = A_out[pivot, col] < 0
        if negative:
            A_out[:, col] *= -1.0
            M_out[:, col] *= -1.0
    return A_out, M_out


def normalize_cols(A):
    """Scale every column of A to unit Euclidean length; all-zero columns are left as zeros.

    Returns a new array; A itself is not modified.
    """
    norms = np.linalg.norm(A, axis=0)
    norms[norms == 0] = 1.0  # divide zero columns by 1 instead of 0
    result = A.copy()
    result /= norms
    return result
#----------------------------
#  Optimize B,M
#----------------------------
def optimize_matrix_scaling(B, M, scale_strategy='moderate'):
    """Rescale (B, M) by ``scale_strategy`` and keep B's spectral radius below 1.

    Strategies:
      - 'strong': B x 1.5, M x 1.2 (stronger signal);
      - 'weak': B x 0.7, M x 0.8 (robustness test);
      - 'moderate' or any other value: both unchanged.

    If any eigenvalue of the rescaled B has modulus >= 1, B is shrunk so that the
    largest modulus becomes 0.9. Inputs are not modified.

    Returns:
        (B_scaled, M_scaled)
    """
    factors = {'strong': (1.5, 1.2), 'weak': (0.7, 0.8)}.get(scale_strategy)

    B_out, M_out = B.copy(), M.copy()
    if factors is not None:
        b_factor, m_factor = factors
        B_out = B_out * b_factor
        M_out = M_out * m_factor

    # Stability guard on (I - B)^{-1}: keep every eigenvalue modulus below 1.
    moduli = np.abs(np.linalg.eigvals(B_out))
    if np.any(moduli >= 1):
        B_out = B_out * (0.9 / np.max(moduli))

    return B_out, M_out
# ---------------------------
# 1) Generate data (SCM-MOG)
# ---------------------------

def generate_data(B, M, n=4000, sigma2=0.05, seed=0, mog_config=None):
    """Sample from the linear SCM  X = (I - B)^{-1} (M Q + E).

    Each latent row of Q is drawn i.i.d. from a 1-D Gaussian mixture. Its means,
    variances and weights are chosen by the strategy strings of ``mog_config``.
    E is i.i.d. N(0, sigma2) noise.

    Args:
        B: (m x m) structural matrix; ``B[i, j]`` weights the edge j -> i.
        M: (m x k) mixing matrix from latents to observed variables.
        n: number of samples.
        sigma2: noise variance.
        seed: seed for ``np.random.default_rng``.
        mog_config: object with ``n_components``, ``means_strategy``,
            ``variance_strategy`` and ``weights_strategy`` attributes (e.g. a
            ``MOGConfig``). ``None`` means 3 components with 'spread', 'moderate'
            and 'balanced'; previously ``None`` raised AttributeError.

    Returns:
        (X, Q, H):
          - X: (m x n) observations;
          - Q: (k x n) latent samples;
          - H: (I - B)^{-1}.
    """
    rng = np.random.default_rng(seed)
    m, k = M.shape

    if mog_config is None:
        C, means_strategy, variance_strategy, weights_strategy = 3, 'spread', 'moderate', 'balanced'
    else:
        C = mog_config.n_components
        means_strategy = mog_config.means_strategy
        variance_strategy = mog_config.variance_strategy
        weights_strategy = mog_config.weights_strategy

    # Component means; unrecognized strategies are treated as 'asymmetric'.
    if means_strategy == 'spread':
        mog_means = np.linspace(-2.5, 2.5, C)
    elif means_strategy == 'clustered':
        mog_means = np.linspace(-1.0, 1.0, C)
    else:  # asymmetric
        mog_means = np.linspace(-1.5, 2.0, C)

    # Component variances; unrecognized strategies are treated as 'high'.
    if variance_strategy == 'low':
        mog_vars = 0.1 * np.ones(C)
    elif variance_strategy == 'moderate':
        mog_vars = 0.3 * np.ones(C)
    else:  # high
        mog_vars = 0.5 * np.ones(C)

    # Component weights; anything other than 'balanced' draws Dirichlet(0.5) weights.
    if weights_strategy == 'balanced':
        mog_weights = np.ones(C) / C
    else:  # imbalanced
        mog_weights = rng.dirichlet(np.ones(C) * 0.5)
        # renormalize against round-off so rng.choice accepts p
        mog_weights = mog_weights / mog_weights.sum()

    H = np.linalg.inv(np.eye(m) - B)
    mog_std = np.sqrt(mog_vars)

    # Per-latent sampling. The draw order (choice, then one normal draw per used
    # component) is kept so results stay reproducible for a given seed.
    Q = np.zeros((k, n))
    for i in range(k):
        comps = rng.choice(C, size=n, p=mog_weights)
        for c in range(C):
            mask = comps == c
            count = int(mask.sum())
            if count:
                Q[i, mask] = rng.normal(loc=mog_means[c], scale=mog_std[c], size=count)

    E = rng.normal(0.0, np.sqrt(sigma2), size=(m, n))
    X = H @ (M @ Q + E)

    return X, Q, H


#---------------------------
# Learn
#---------------------------

# VB-IFA (Attias-like) — full, practical implementation for model selection by ELBO
import numpy as np
from scipy.linalg import solve, pinv
from scipy.stats import multivariate_normal
from scipy.optimize import linear_sum_assignment

def _safe_log(x):
    return np.log(np.maximum(x, 1e-300))

def _normalize_columns(A):
    norms = np.linalg.norm(A, axis=0) + 1e-12
    return A / norms, norms

def _init_mixture_params(k, R=3, rng=None):
    # initialize mixture params per latent: weights, means, variances
    if rng is None: rng = np.random.default_rng(0)
    pis = np.ones((k, R)) / R
    mus = rng.normal(0, 1.0, size=(k, R))
    vars_ = np.ones((k, R)) * 1.0
    return pis, mus, vars_

def _compute_posterior_S(X, A, sigma2, pis, mus, vars_):
    # compute q(S) approx as Gaussian N(S_mean, S_vardiag) assuming factorized across latents
    # This is an approximation: we use diagonal posterior covariances for speed (common in VB-IFA practical)
    # Returns S_mean (k x n), S_var (k x n), and per-latent posterior precision diag entries (k,)
    m, n = X.shape
    k = A.shape[1]
    # compute Sigma_s_approx diagonal: diag = 1 / (alpha_j + (A_j^T A_j) / sigma2)
    # but alpha_j will be computed from mixture responsibilities below; here we use expected precision per latent:
    # we'll compute responsibilities given a prior guess of S_mean (iterate)
    # initialize: simple least squares to get S_mean
    S_mean = np.linalg.pinv(A) @ X  # k x n
    # We'll compute S_var diag from A and sigma2 using posterior formula for factor analysis with diag approx
    A2sum = np.sum(A**2, axis=0)  # shape (k,)
    S_var_diag = (sigma2) / (A2sum[:, None] + 1e-12)  # k x 1 -> k x n after tile
    S_var = np.tile(np.maximum(S_var_diag, 1e-12), (1, n))
    return S_mean, S_var

def _update_mixture_responsibilities(S_mean, S_var, pis, mus, vars_):
    """Posterior component responsibilities for every latent and sample.

    For latent j, component r and sample n:
        resp[j, r, n] ∝ pis[j, r] * N(S_mean[j, n] | mus[j, r], vars_[j, r] + S_var[j, n]),
    normalized over r in log space (max-shifted for stability).

    Args:
        S_mean, S_var: (k, n) posterior means / variances of the latents.
        pis, mus, vars_: (k, R) mixture weights, means and variances.

    Returns:
        resp of shape (k, R, n).
    """
    n_latent, n_samples = S_mean.shape
    n_comp = pis.shape[1]
    resp = np.zeros((n_latent, n_comp, n_samples))
    for j in range(n_latent):
        # All components of latent j at once: arrays of shape (R, n).
        total_var = vars_[j][:, None] + S_var[j][None, :]
        sq_dev = (S_mean[j][None, :] - mus[j][:, None]) ** 2
        log_gauss = -0.5 * (_safe_log(2 * np.pi * total_var) + sq_dev / total_var)
        log_w = _safe_log(pis[j])[:, None] + log_gauss
        shifted = np.exp(log_w - log_w.max(axis=0)[None, :])
        resp[j] = shifted / (shifted.sum(axis=0) + 1e-300)[None, :]
    return resp

def _update_mixture_params_from_resp(S_mean, S_var, resp):
    # resp: k x R x n
    k, R, n = resp.shape
    pis = np.zeros((k, R))
    mus = np.zeros((k, R))
    vars_ = np.zeros((k, R))
    for j in range(k):
        Njr = np.sum(resp[j, :, :], axis=1) + 1e-12  # R
        pis[j, :] = Njr / np.sum(Njr)
        # means: weighted average of posterior means
        for r in range(R):
            numer = np.sum(resp[j, r, :] * S_mean[j, :])
            mu_r = numer / Njr[r]
            mus[j, r] = mu_r
            # var: weighted of (S_var + (S_mean - mu)^2)
            sq = np.sum(resp[j, r, :] * (S_var[j, :] + (S_mean[j, :] - mu_r)**2))
            vars_[j, r] = sq / Njr[r]
            vars_[j, r] = max(vars_[j, r], 1e-8)
    return pis, mus, vars_

def _compute_elbo(X, A, sigma2, S_mean, S_var, resp, pis, mus, vars_):
    """Approximate evidence lower bound of the factorized VB-IFA posterior.

    ELBO = E_q[log p(X|S,A)] + E_q[log p(S|z)] + E_q[log p(z)] + H[q(S)] + H[q(z)]

    Model pieces:
      - Gaussian noise with variance ``sigma2``;
      - diagonal Gaussian q(S) given by ``S_mean`` / ``S_var`` (k x n);
      - categorical q(z) given by ``resp`` (k x R x n).

    Returns:
        (elbo as a Python float, dict holding the five individual terms).
    """
    m, n = X.shape
    n_latent = A.shape[1]
    n_comp = pis.shape[1]

    # Expected squared residual: ||X - A S_mean||^2 + sum_{i,j,n} A_ij^2 * S_var_jn.
    sq_err = np.sum((X - A @ S_mean) ** 2)
    sq_err += np.sum((A ** 2) @ S_var)
    E_log_like = -0.5 * m * n * np.log(2 * np.pi * sigma2) - 0.5 * sq_err / sigma2

    def _component_term(j, r):
        # sum_n resp[j, r, n] * E_q[log N(s_jn | mu_jr, var_jr)], where
        # E_q[(s - mu)^2] = S_var + (S_mean - mu)^2
        expected_sq = S_var[j, :] + (S_mean[j, :] - mus[j, r]) ** 2
        log_density = -0.5 * (_safe_log(2 * np.pi * vars_[j, r]) + expected_sq / vars_[j, r])
        return np.sum(resp[j, r, :] * log_density)

    E_log_pS = sum((_component_term(j, r) for j in range(n_latent) for r in range(n_comp)), 0.0)

    E_log_pz = np.sum(resp * _safe_log(pis)[:, :, None])
    # Entropy of a Gaussian with variance v is 0.5 * log(2*pi*e*v).
    ent_S = 0.5 * np.sum(_safe_log(2 * np.pi * np.e * S_var))
    ent_z = -np.sum(resp * _safe_log(resp))

    elbo = E_log_like + E_log_pS + E_log_pz + ent_S + ent_z
    return float(elbo), dict(E_log_like=E_log_like, E_log_pS=E_log_pS, E_log_pz=E_log_pz, ent_S=ent_S, ent_z=ent_z)


#--------------------------------------------------------------------------------
def learn_ifa_vb_full_with_ard(X, k_candidates, R=3, max_iter=500, tol=1e-6,
                              sigma2_init=None, rng_seed=0, verbose=False,
                              init_from_pca=True,
                              # ARD hyperparams:
                              a0=1e-6, b0=1e-6,    # weak gamma prior on alpha
                              ard_prune_thresh=1e6, # if alpha_j > thresh -> inactive
                              norm_prune_thresh=1e-6 # or if ||A_j|| < thresh -> inactive
                             ):
    """Fit VB-IFA (mixture-of-Gaussians latents, Gaussian noise) for each candidate k.

    The fit with the highest final ELBO is returned as the best one.

    Args:
        X: (m x n) data matrix.
        k_candidates: int or iterable of ints — numbers of latents to try (processed ascending).
        R: mixture components per latent.
        max_iter: iteration cap per k.
        tol: stop when the relative ELBO change falls below this value.
        sigma2_init: initial noise variance; None -> max(1e-8, 0.01 * var(X)).
        rng_seed: seed for the random parts of the initialization.
        verbose: print progress.
        init_from_pca: initialize A from the leading left singular vectors of X
            scaled by their singular values; otherwise random N(0, 0.01).
        a0, b0: Gamma hyperprior parameters for the ARD precisions alpha.
        ard_prune_thresh, norm_prune_thresh: column j counts as active when
            alpha_j < ard_prune_thresh and ||A_j|| > norm_prune_thresh.

    Returns:
        (best_result, all_results):
          - best_result: the dict for the k with the largest final ELBO,
            or None if ``k_candidates`` is empty;
          - all_results: maps k -> dict with keys A, S_mean, S_var, pis, mus, vars,
            sigma2, alpha, active_mask, effective_k, elbo_trace, elbo_final, k.

    NOTE(review): A is renormalized to unit-norm columns every iteration. The ARD
    update therefore gives alpha_j ~= (m/2 + a0) / (0.5 + b0) for every non-zero
    column, so ARD cannot switch such columns off and effective_k is ~k in practice.
    """
    rng = np.random.default_rng(rng_seed)
    m, n = X.shape
    if isinstance(k_candidates, int):
        k_candidates = [k_candidates]
    k_candidates = sorted(list(k_candidates))
    if sigma2_init is None:
        sigma2_init = max(1e-8, np.var(X) * 0.01)

    # The PCA basis does not depend on k: compute the SVD once, not once per candidate.
    if init_from_pca:
        U, Sdiag, _ = np.linalg.svd(X / np.sqrt(n), full_matrices=False)
        pca_basis = U * Sdiag  # column i scaled by singular value i

    all_results = {}
    best_elbo = -np.inf
    best_result = None

    for k in k_candidates:
        if verbose:
            print(f"\n--- VB-IFA+ARD: trying k={k} ---")

        # --- initialize loadings A (unit-norm columns) ---
        if init_from_pca:
            A = np.zeros((m, k))
            take = min(pca_basis.shape[1], k)
            A[:, :take] = pca_basis[:, :take]
            if k > take:
                A[:, take:] = rng.normal(0, 0.01, size=(m, k - take))
        else:
            A = rng.normal(0, 0.01, size=(m, k))
        A, _ = _normalize_columns(A)

        # ARD precisions (one per column of A)
        alpha = np.ones(k) * 1e-2

        pis, mus, vars_ = _init_mixture_params(k, R=R, rng=rng)
        sigma2 = float(sigma2_init)

        # initial q(S) and mixture parameters
        S_mean, S_var = _compute_posterior_S(X, A, sigma2, pis, mus, vars_)
        resp = _update_mixture_responsibilities(S_mean, S_var, pis, mus, vars_)
        pis, mus, vars_ = _update_mixture_params_from_resp(S_mean, S_var, resp)

        elbo_trace = []
        for it in range(max_iter):
            # --- E-step for q(S): mixture-prior precision + likelihood precision ---
            precision_prior = np.sum(resp / (vars_[:, :, None] + 1e-12), axis=1)       # (k, n)
            A_col_sqsum = np.sum(A**2, axis=0)
            S_var = 1.0 / (precision_prior + A_col_sqsum[:, None] / (sigma2 + 1e-12) + 1e-12)

            prior_num = np.sum(resp * (mus / (vars_ + 1e-12))[:, :, None], axis=1)     # (k, n)
            S_mean = S_var * ((A.T @ X) / (sigma2 + 1e-12) + prior_num)

            # --- M-step for A with ARD prior p(A_j) = N(0, alpha_j^{-1} I_m) ---
            ESS = S_mean @ S_mean.T + np.diag(np.sum(S_var, axis=1))
            XS_T = X @ S_mean.T
            ESS_reg = ESS + sigma2 * np.diag(alpha + 1e-12) + 1e-8 * np.eye(k)
            # A_new = XS_T @ inv(ESS_reg); ESS_reg is symmetric, so solve the
            # transposed system instead of forming the explicit inverse.
            A_new = np.linalg.solve(ESS_reg, XS_T.T).T
            A_new, _ = _normalize_columns(A_new)  # fixes the scale ambiguity between A and S

            # --- noise variance ---
            recon_mean = A_new @ S_mean
            term_recon = np.sum((X - recon_mean)**2) + np.sum((A_new**2) @ S_var)
            sigma2_new = max(1e-12, term_recon / (m * n))

            # --- mixture responsibilities / parameters ---
            resp = _update_mixture_responsibilities(S_mean, S_var, pis, mus, vars_)
            pis, mus, vars_ = _update_mixture_params_from_resp(S_mean, S_var, resp)

            # --- ARD precision under the Gamma(a0, b0) hyperprior ---
            # alpha_j = (m/2 + a0) / (0.5 * ||A_j||^2 + b0)
            sqnorms = 0.5 * np.sum(A_new**2, axis=0)
            alpha = (m / 2.0 + a0) / (sqnorms + b0 + 1e-12)

            A = A_new
            sigma2 = sigma2_new

            elbo, _ = _compute_elbo(X, A, sigma2, S_mean, S_var, resp, pis, mus, vars_)
            # Only the quadratic part of log p(A | alpha) is added (for stability).
            elbo += -0.5 * np.sum(alpha * np.sum(A**2, axis=0))
            elbo_trace.append(elbo)

            if verbose and (it % 20 == 0 or it == max_iter-1):
                active = np.sum((alpha < ard_prune_thresh) & (np.linalg.norm(A, axis=0) > norm_prune_thresh))
                print(f"k={k} it={it} ELBO={elbo:.4f} sigma2={sigma2:.4e} active_factors={active}")

            if it > 2:
                rel = abs(elbo_trace[-1] - elbo_trace[-2]) / (abs(elbo_trace[-2]) + 1e-12)
                if rel < tol:
                    if verbose:
                        print(f"k={k} converged it={it} rel_change={rel:.3e}")
                    break

        # effective number of factors
        active_mask = (alpha < ard_prune_thresh) & (np.linalg.norm(A, axis=0) > norm_prune_thresh)
        effective_k = int(np.sum(active_mask))
        result = {
            "A": A,
            "S_mean": S_mean,
            "S_var": S_var,
            "pis": pis,
            "mus": mus,
            "vars": vars_,
            "sigma2": sigma2,
            "alpha": alpha,
            "active_mask": active_mask,
            "effective_k": effective_k,
            "elbo_trace": np.array(elbo_trace),
            "elbo_final": float(elbo_trace[-1]) if len(elbo_trace) else -np.inf,
            "k": k
        }
        all_results[k] = result
        if result["elbo_final"] > best_elbo:
            best_elbo = result["elbo_final"]
            best_result = result
        if verbose:
            print(f"Finished k={k} ELBO_final={result['elbo_final']:.4f} effective_k={effective_k}")

    return best_result, all_results


# ---------------------------
# estimate B via row-wise Ridge
# ---------------------------
def estimate_B_ridge(X, alpha=5e-2):
    """Regress every variable on all the others with ridge (L2) regularization and no intercept.

    Row i holds the coefficients of X[i] on {X[j] : j != i}:
        beta_i = (X_{-i} X_{-i}^T + alpha I)^{-1} X_{-i} X[i]^T,
    and the diagonal is 0.

    Args:
        X: (m x n) data matrix, variables in rows.
        alpha: ridge penalty.

    Returns:
        (m x m) coefficient matrix B_hat.
    """
    m = X.shape[0]
    # Every per-row system uses blocks of the same Gram matrix: compute it once
    # (O(m^2 n)) instead of re-multiplying an (m-1) x n slice for every row.
    gram = X @ X.T
    ridge = alpha * np.eye(max(m - 1, 0))
    B_hat = np.zeros((m, m))
    for i in range(m):
        others = np.arange(m) != i
        G = gram[np.ix_(others, others)] + ridge
        B_hat[i, others] = np.linalg.solve(G, gram[others, i])
    return B_hat

# ---------------------------
#  CF tools: empirical CF and distance
# ---------------------------
def empirical_cf_1d(z, tgrid):
    """Empirical characteristic function phi(t) = mean_n exp(i * t * z_n) at every t in ``tgrid``.

    ``np.outer`` flattens both inputs; the result is a complex array of length ``tgrid.size``.
    """
    phases = np.outer(tgrid, z)  # (T, n)
    return np.exp(1j * phases).mean(axis=1)

def empirical_cf_multivariate(samples, tgrid):
    """Empirical multivariate characteristic function.

    Args:
        samples: (d x n) matrix, one sample per column.
        tgrid: (T x d) matrix, one frequency vector per row.

    Returns:
        Complex array of length T with phi(t) = mean_n exp(i <t, x_n>) for each row t.
    """
    inner = tgrid @ samples  # (T, n)
    return np.exp(1j * inner).mean(axis=1)

def cf_distance(X_true, X_est, m_grid=200, seed=0):
    """Mean squared modulus of the difference between two empirical characteristic functions.

    Frequencies are ``m_grid`` vectors drawn from N(0, I_d) with ``seed``.
    ``X_true`` and ``X_est`` are (d x n) sample matrices, one sample per column.

    Returns:
        Python float.
    """
    d, _ = X_true.shape
    freqs = np.random.default_rng(seed).normal(0, 1, size=(m_grid, d))
    gap = empirical_cf_multivariate(X_true, freqs) - empirical_cf_multivariate(X_est, freqs)
    return float(np.mean(np.abs(gap) ** 2).real)

# ---------------------------
# 5) CF-based alignment (empirical CF per latent) + Hungarian
#    Align estimated S_est rows to Q_true rows
# ---------------------------
def empirical_cf_vector(z, tgrid):
    """Empirical characteristic function of 1-D samples ``z`` at each frequency of ``tgrid`` (complex, length T)."""
    return np.mean(np.exp(1j * np.outer(tgrid, z)), axis=1)

def cf_alignment(Q_true, S_est, tgrid=None):
    """Match estimated latent rows to true latent rows via their empirical characteristic functions.

    Args:
        Q_true: (k_true x n) true latent samples.
        S_est: (k_est x n) estimated latent samples, with k_est >= k_true.
        tgrid: 1-D frequency grid; default ``np.linspace(-3, 3, 121)``.

    Returns:
        (perm, signs):
          - perm[i]: row of ``S_est`` matched to ``Q_true[i]``, found by Hungarian
            assignment on the squared CF distance summed over ``tgrid``;
          - signs[i]: sign of their Pearson correlation (+1.0 when the correlation
            is NaN or exactly 0).

    Raises:
        ValueError: if the sample counts differ or k_est < k_true. Previously this was
            an ``assert`` (stripped under -O) and an IndexError, respectively.
    """
    if tgrid is None:
        tgrid = np.linspace(-3, 3, 121)
    tgrid = np.asarray(tgrid, dtype=float).ravel()

    k_true, n = Q_true.shape
    k_est, n2 = S_est.shape
    if n != n2:
        raise ValueError(f"Q_true and S_est must have the same number of samples (got {n} and {n2})")
    if k_est < k_true:
        raise ValueError(f"need at least as many estimated latents as true ones "
                         f"(k_est={k_est} < k_true={k_true})")

    def _cf_rows(Z):
        # One empirical CF per row of Z -> (rows x T) complex matrix.
        out = np.empty((Z.shape[0], tgrid.size), dtype=complex)
        for row, z in enumerate(Z):
            out[row] = np.exp(1j * np.outer(tgrid, z)).mean(axis=1)
        return out

    cf_true = _cf_rows(Q_true)
    cf_est = _cf_rows(S_est)

    # C[i, j] = sum_t |phi_true_i(t) - phi_est_j(t)|^2
    C = np.sum(np.abs(cf_true[:, None, :] - cf_est[None, :, :]) ** 2, axis=2)
    # With k_true <= k_est the returned rows are 0..k_true-1, so col is indexed by true latent.
    _, perm = linear_sum_assignment(C)

    # Sign of each match from the (real-part) correlation; 0 or NaN keeps +1.
    signs = np.ones(k_true)
    for i in range(k_true):
        with np.errstate(invalid='ignore', divide='ignore'):
            corr = np.corrcoef(Q_true[i].real, S_est[perm[i]].real)[0, 1]
        if np.isfinite(corr) and corr != 0:
            signs[i] = np.sign(corr)
    return perm, signs

# ---------------------------
#  Align A_est columns to A_true (via CF alignment on S_est)
# ---------------------------
def align_A_via_CF(A_est, X, Q_true, tgrid=None):
    """Reorder and sign-correct the columns of ``A_est`` to match the true latents ``Q_true``.

    The latents are re-estimated as ``S_est = pinv(A_est) @ X`` and matched to
    ``Q_true`` with ``cf_alignment``.

    Returns:
        (A_aligned, perm, signs, S_est):
          - A_aligned: (m x k_true), column i = signs[i] * A_est[:, perm[i]];
          - perm, signs: as returned by ``cf_alignment``;
          - S_est: (k_est x n) kept in A_est's original column order; matched rows
            are multiplied by their sign and unmatched rows are untouched.
    """
    S_est = np.linalg.pinv(A_est) @ X
    perm, signs = cf_alignment(Q_true, S_est, tgrid=tgrid)

    n_true = Q_true.shape[0]
    A_aligned = np.zeros((A_est.shape[0], n_true))
    for i in range(n_true):
        src = perm[i]
        A_aligned[:, i] = signs[i] * A_est[:, src]
        S_est[src, :] *= signs[i]
    return A_aligned, perm, signs, S_est

# ---------------------------
#  Bootstrap stability (fit A and B on bootstrap samples)
# ---------------------------
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import cosine_similarity

def _align_bootstrap_to_reference(A_ref, A_b):
    """Align the columns of A_b to A_ref: Hungarian permutation on |cosine| plus sign flip.

    After the permutation, each matched column is flipped so that its inner product
    with the reference column is non-negative.
    """
    def _unit(Z):
        # Columns scaled to unit norm; all-zero (padding) columns stay zero.
        norms = np.linalg.norm(Z, axis=0)
        norms[norms == 0] = 1.0
        return Z / norms

    sim = np.abs(_unit(A_ref).T @ _unit(A_b))
    sim = np.nan_to_num(sim, nan=0.0, posinf=0.0, neginf=0.0)
    _, col_ind = linear_sum_assignment(-sim)
    A_aligned = A_b[:, col_ind]  # fancy indexing -> copy, safe to modify
    flip = np.sum(A_ref * A_aligned, axis=0) < 0
    A_aligned[:, flip] *= -1.0
    return A_aligned


def bootstrap_stability(X, k_est, repeats=40, tau_A=0.12, tau_B=0.06, seed=0, sigma2=0, alpha_B=1e-2,
                        k_candidates=None):
    """Bootstrap selection frequencies and medians of the IFA loadings A and the ridge estimate of B.

    For each of ``repeats`` resamples (columns of X drawn with replacement):
      1. fit ``learn_ifa_vb_full_with_ard`` over ``k_candidates``;
      2. pad/truncate the selected A to ``k_target`` columns;
      3. align it to the first successful replicate: Hungarian matching on |cosine|,
         then a sign flip giving a non-negative inner product (the same sign rule as
         ``align_columns`` / ``final_sign_correction``);
      4. fit B with ``estimate_B_ridge``.

    Args:
        X: (m x n) data.
        k_est: int, or a sequence whose first element is used — number of stored A columns.
        repeats: number of bootstrap resamples.
        tau_A, tau_B: magnitude thresholds for counting an entry as selected.
        seed: RNG seed for resampling (replicate b also seeds the learner with seed + b).
        sigma2: passed to the learner as ``sigma2_init``. NOTE(review): the default 0
            bypasses the learner's automatic initialization — confirm this is intended.
        alpha_B: ridge penalty for ``estimate_B_ridge``.
        k_candidates: latent counts tried by the learner; default [2, ..., 8].

    Returns:
        (A_prob, A_median, B_prob, B_median):
          - A_prob, B_prob: frequencies of |A| > tau_A and |B| > tau_B;
          - A_median, B_median: element-wise medians.
        All are computed over successful replicates only. If none succeeds, the
        probabilities are 0 and the medians NaN.

    Fixes relative to the previous version:
      - The alignment reference was stack slice 0, which is all-NaN when replicate 0 fails.
      - Failed replicates counted as "absent" in the probabilities, because abs(nan) > tau is False.
    """
    rng = np.random.default_rng(seed)
    m, n = X.shape
    k_target = k_est if isinstance(k_est, int) else (int(k_est[0]) if hasattr(k_est, '__iter__') else int(k_est))
    if k_candidates is None:
        k_candidates = [2, 3, 4, 5, 6, 7, 8]

    A_stack = np.full((m, k_target, repeats), np.nan)
    B_stack = np.full((m, m, repeats), np.nan)
    succeeded = np.zeros(repeats, dtype=bool)
    A_ref = None  # aligned A of the first successful replicate

    for b in range(repeats):
        idx = rng.integers(0, n, size=n)
        Xb = X[:, idx]

        # Best-effort: a failing replicate is reported and skipped.
        try:
            best, _ = learn_ifa_vb_full_with_ard(
                Xb,
                k_candidates,
                R=3,
                max_iter=500,
                tol=1e-6,
                sigma2_init=sigma2,
                rng_seed=seed+b,
                verbose=False,
                init_from_pca=True,
                a0=1e-6,
                b0=1e-6,
                ard_prune_thresh=1e3,
                norm_prune_thresh=1e-3
            )
        except Exception as e:
            print(f"learn_ifa failed on bootstrap {b}: {e}")
            continue

        if best is None:
            print(f"learn_ifa returned None at bootstrap {b}, skipping.")
            continue
        print("Selected k:", best['k'])
        A_b = best['A']
        if A_b is None:
            print(f"learn_ifa returned None at bootstrap {b}, skipping.")
            continue

        # pad / truncate A_b to k_target columns
        A_pad = np.zeros((m, k_target))
        kk = min(A_b.shape[1], k_target)
        A_pad[:, :kk] = A_b[:, :kk]
        if not np.all(np.isfinite(A_pad)):
            print(f"non-finite loadings on bootstrap {b}, skipping.")
            continue

        try:
            B_b = estimate_B_ridge(Xb, alpha=alpha_B)
        except Exception as e:
            print(f"estimate_B_ridge failed on bootstrap {b}: {e}")
            continue

        # permutation + sign alignment against the first successful replicate
        if A_ref is None:
            A_ref = A_pad
        else:
            A_pad = _align_bootstrap_to_reference(A_ref, A_pad)

        A_stack[:, :, b] = A_pad
        B_stack[:, :, b] = B_b
        succeeded[b] = True

    if not succeeded.any():
        # Nothing succeeded: zero selection probability and NaN medians.
        return (np.zeros((m, k_target)), np.full((m, k_target), np.nan),
                np.zeros((m, m)), np.full((m, m), np.nan))

    A_ok = A_stack[:, :, succeeded]
    B_ok = B_stack[:, :, succeeded]
    A_prob = np.mean(np.abs(A_ok) > tau_A, axis=2)
    B_prob = np.mean(np.abs(B_ok) > tau_B, axis=2)
    A_median = np.nanmedian(A_ok, axis=2)
    B_median = np.nanmedian(B_ok, axis=2)

    return A_prob, A_median, B_prob, B_median




#---------------------------------
# 9) Prune
#--------------------------------



def optimal_topo_order_by_dp(Bw):
    """Exact maximum-weight node ordering by dynamic programming over subsets.

    Definitions:
      - the weight of edge j -> i is ``Bw[i, j]``;
      - an ordering scores the total weight of edges pointing forward (earlier -> later);
      - dp[S] is the best score of any ordering of subset S, found by choosing the node v placed last:
            dp[S] = max_{v in S} dp[S - {v}] + sum_{u in S - {v}} Bw[v, u].

    Cost: O(m * 2^m) time and O(m * 2^m) floats of memory.

    Args:
        Bw: (m x m) weight matrix.

    Returns:
        list of node indices, earliest first.

    Raises:
        ValueError: if no finite choice exists for some subset (e.g. NaN weights).
    """
    W = np.asarray(Bw, dtype=float)
    m = W.shape[0]
    N = 1 << m

    # add_gain[v, S] = sum_{u in S} W[v, u]: the gain of appending v after all nodes of S.
    # Built incrementally: masks in [2^b, 2^(b+1)) are the masks below 2^b with node b added.
    add_gain = np.zeros((m, N), dtype=float)
    for b in range(m):
        lo = 1 << b
        add_gain[:, lo:2 * lo] = add_gain[:, :lo] + W[:, b][:, None]

    dp = np.full(N, -np.inf, dtype=float)
    dp[0] = 0.0
    # best_last[S] = node placed last in an optimal ordering of S. The previous version
    # flagged every improving candidate and reconstructed from the lowest-index flag,
    # which ignored the weights and always returned the reversed index order.
    best_last = np.full(N, -1, dtype=int)
    for mask in range(1, N):
        rest = mask
        while rest:
            lsb = rest & -rest
            v = lsb.bit_length() - 1
            prev = mask ^ lsb
            cand = dp[prev] + add_gain[v, prev]
            if cand > dp[mask]:
                dp[mask] = cand
                best_last[mask] = v
            rest ^= lsb

    # Reconstruct from the full set, peeling off the last node each time.
    order_rev = []
    mask = N - 1
    while mask:
        v = int(best_last[mask])
        if v < 0:
            raise ValueError("optimal_topo_order_by_dp: no valid choice (non-finite weights?)")
        order_rev.append(v)
        mask ^= 1 << v
    return order_rev[::-1]

def spectral_init_order(Bw):
    """Heuristic initial ordering for large m: rank nodes by net outgoing weight.

    Despite the name, this is not a spectral method. With ``Bw[i, j]`` the weight
    of edge j -> i, node j scores ``sum_i Bw[i, j] - sum_i Bw[j, i]`` (outgoing minus
    incoming). Nodes are returned from the highest to the lowest score.
    """
    net_flow = Bw.sum(axis=0) - Bw.sum(axis=1)
    return list(np.argsort(-net_flow))

def refine_order_2opt(order, Bw, iters=200):
    """Local search over pairwise swaps that maximizes the forward edge weight of ``order``.

    Objective: the sum over positions a < b of ``Bw[order[b], order[a]]`` (weight of
    edge order[a] -> order[b]).

    Procedure:
      - each round tries swapping every pair (a, b);
      - a swap is kept only if it improves the objective by more than 1e-12;
      - a round ends after the first outer index a that produced an improvement;
      - at most ``iters`` rounds run.

    ``order`` is modified in place and also returned.
    """
    size = len(order)

    def score():
        # forward weight, accumulated in position order
        return sum((Bw[order[later], order[earlier]]
                    for earlier in range(size)
                    for later in range(earlier + 1, size)), 0.0)

    best = score()
    rounds = 0
    changed = True
    while changed and rounds < iters:
        changed = False
        rounds += 1
        for a in range(size - 1):
            for b in range(a + 1, size):
                order[a], order[b] = order[b], order[a]
                candidate = score()
                if candidate > best + 1e-12:
                    best = candidate
                    changed = True
                else:
                    # undo the swap
                    order[a], order[b] = order[b], order[a]
            if changed:
                break
    return order

def find_dag_order(B, use_prob=None):
    """
    Choose a node order for turning B into a DAG.

    Edge weights are ``|B|``, optionally multiplied element-wise by the
    bootstrap selection probabilities ``use_prob`` so that unstable edges
    count less. Self-loops are ignored.

    - ``m <= 20``: ``optimal_topo_order_by_dp`` (DP over node subsets).
    - ``m > 20``: ``spectral_init_order`` seed refined by
      ``refine_order_2opt`` (300 sweeps max).

    Parameters
    ----------
    B : (m, m) array; ``B[i, j]`` is the edge ``j -> i``.
    use_prob : optional (m, m) array of edge stability probabilities.

    Returns
    -------
    list
        Node indices in the selected order.
    """
    weights = np.abs(B).astype(float)
    if use_prob is not None:
        # Fold stability into the weights (not in-place: keep broadcasting
        # semantics of a plain product).
        weights = weights * use_prob
    np.fill_diagonal(weights, 0.0)

    n_nodes = B.shape[0]
    if n_nodes > 20:
        seed_order = spectral_init_order(weights)
        return refine_order_2opt(seed_order, weights, iters=300)
    return optimal_topo_order_by_dp(weights)

def project_B_to_dag(B, order):
    """
    Zero every edge of B that violates ``order`` so the result is acyclic.

    ``B[i, j]`` encodes the edge ``j -> i``; it survives only when ``j``
    appears strictly before ``i`` in ``order``. The diagonal is always
    cleared. For m > 1, ``order`` must list every node index of B (a
    missing one raises KeyError). Returns a modified copy; B is untouched.
    """
    rank = {node: position for position, node in enumerate(order)}
    projected = B.copy()
    n_nodes = projected.shape[0]
    if n_nodes > 1:
        node_rank = np.array([rank[node] for node in range(n_nodes)])
        # keep[i, j] <=> rank[j] < rank[i]; False on the diagonal.
        keep = node_rank[np.newaxis, :] < node_rank[:, np.newaxis]
        projected[~keep] = 0.0
    np.fill_diagonal(projected, 0.0)
    return projected


def prune_B_advanced(B_median, B_prob=None, tau=0.1, keep_prob=0.6):
    """
    Prune the observed->observed matrix and force it to be a DAG.

    1. Zero entries with ``|B| < tau``.
    2. If ``B_prob`` is given, zero entries whose bootstrap selection
       probability is below ``keep_prob``.
    3. Pick a (non-greedy) order with ``find_dag_order`` and drop every
       edge that violates it via ``project_B_to_dag``.

    Returns
    -------
    (B_pruned, order)
    """
    pruned = B_median.copy()
    pruned[np.abs(pruned) < tau] = 0.0
    if B_prob is not None:
        pruned[B_prob < keep_prob] = 0.0

    order = find_dag_order(pruned, use_prob=B_prob)
    dag = project_B_to_dag(pruned, order)
    return dag, order

def prune_A_advanced_with_order(A_median, B_order, A_prob=None,
                                tau=0.1, keep_prob=0.6, r_min=2,
                                soft_threshold=False, soft_lambda=0.6):
    """
    Prune the latent->observed loading matrix A using the causal order of B.

    Steps
    -----
    1. Magnitude threshold. Entries with ``|A| < tau`` are zeroed (hard). If
       ``soft_threshold`` is set, they are instead multiplied by
       ``1 - soft_lambda``.
    2. Stability filter. If ``A_prob`` is given, entries with selection
       probability below ``keep_prob`` are zeroed.
    3. Positional constraint per latent column j. Only observed nodes whose
       position in ``B_order`` is at most a column-dependent threshold are
       allowed; the others are zeroed. When ``k < m`` the order is split into
       k roughly equal slices and column j may use the first j+1 slices.
    4. r_min guarantee. If a column keeps fewer than ``r_min`` non-zeros,
       the best allowed zero rows are restored. Ranking is by ``A_prob``
       first, then ``|A_median|``. Restored values have magnitude >= ``tau``.

    Parameters
    ----------
    A_median : (m, k) array
        Median bootstrap estimate of A.
    B_order : sequence of node indices
        Node order, e.g. from ``prune_B_advanced``. Every index 0..m-1 must
        appear, otherwise a KeyError is raised.
    A_prob : optional (m, k) array
        Bootstrap selection probabilities.
    tau, keep_prob, r_min, soft_threshold, soft_lambda
        See steps above.

    Returns
    -------
    (m, k) float array
        Pruned copy of ``A_median``.
    """
    A = A_median.copy().astype(float)
    m, k = A.shape

    # Position of each observed node in the causal order.
    B_order = np.array(B_order)
    pos = {node: idx for idx, node in enumerate(B_order)}

    # 1) hard / soft threshold on magnitude
    if soft_threshold:
        small_mask = np.abs(A) < tau
        A[small_mask] *= (1.0 - soft_lambda)
    else:
        A[np.abs(A) < tau] = 0.0

    # 2) bootstrap stability filter
    if A_prob is not None:
        A[A_prob < keep_prob] = 0.0

    # 3) positional constraint and r_min guarantee for each latent column
    nodes = np.arange(m)
    for j in range(k):
        if k >= m:
            threshold_pos = min(j, m - 1)
        else:
            threshold_pos = int(np.floor((j + 1) * m / float(k))) - 1
            threshold_pos = max(0, min(threshold_pos, m - 1))

        allowed_nodes = np.array([node for node in nodes if pos[node] <= threshold_pos], dtype=int)
        if allowed_nodes.size == 0:
            # Fallback: no node qualifies, allow all of them.
            allowed_nodes = nodes.copy()

        mask_non_allowed = np.ones(m, dtype=bool)
        mask_non_allowed[allowed_nodes] = False
        A[mask_non_allowed, j] = 0.0

        need = r_min - np.flatnonzero(np.abs(A[:, j]) > 0).size
        if need <= 0:
            continue

        # Candidates: allowed rows that are currently zero.
        zero_idx_allowed = np.intersect1d(np.flatnonzero(A[:, j] == 0.0), allowed_nodes,
                                          assume_unique=True)
        if zero_idx_allowed.size == 0:
            continue

        # Rank by stability first (scaled so it dominates), then magnitude.
        if A_prob is not None:
            score = (1e6 * A_prob[:, j]) + np.abs(A_median[:, j])
        else:
            score = np.abs(A_median[:, j])
        cand_sorted = zero_idx_allowed[np.argsort(-score[zero_idx_allowed])]

        for idx in cand_sorted[:need]:
            val = A_median[idx, j]
            if np.abs(val) < tau:
                # Lift to magnitude tau. np.sign(0) == 0 would "restore" an
                # exact zero and silently break the r_min guarantee, so zero
                # medians are lifted to +tau.
                val = tau if val >= 0 else -tau
            A[idx, j] = val

    return A







def prune_A_with_min_nonzero(A_median, r_min=2, tau=0.1):
    """
    Ensure every column of A has at least ``r_min`` non-zero entries.

    No magnitude thresholding is done; existing entries are kept as-is. For
    a column with fewer than ``r_min`` entries satisfying ``|a| > 0``, the
    lowest-index exact-zero rows are set to ``+tau``. This continues until
    the column reaches ``r_min`` or runs out of zero rows.

    Parameters
    ----------
    A_median : (m, k) array
    r_min : int
        Minimum number of non-zeros per column.
    tau : float
        Fill value for added entries.

    Returns
    -------
    (m, k) float array
        Modified copy; ``A_median`` is not changed.
    """
    A = A_median.copy().astype(float)

    for j in range(A.shape[1]):
        column = A[:, j]
        # |a| > 0 (not count_nonzero) so NaN entries are not counted.
        need = r_min - np.flatnonzero(np.abs(column) > 0).size
        if need <= 0:
            continue
        zero_rows = np.flatnonzero(column == 0)
        # Candidates are exact zeros of A_median, so the former
        # "sign of A_median" was always non-negative: fill with +tau.
        A[zero_rows[:need], j] = tau
    return A






# ---------------------------
# 8) Graphviz DOT writer & render PNG
# ---------------------------
def graph_to_dot_png(B, A, title, filename_png, var_prefix_obs='x', var_prefix_lat='h'):
    """
    Write a Graphviz DOT description of the model and render it to PNG.

    Latent nodes (``<var_prefix_lat>1..k``) and observed nodes
    (``<var_prefix_obs>1..m``) are each drawn on one rank. Edges are emitted
    for every ``|weight| >= 1e-12``.
    - ``A[i, j]`` gives the edge latent j -> observed i.
    - ``B[i, j]`` gives the edge observed j -> observed i; the diagonal is
      skipped.

    Negative weights are drawn red, ``|w| < 0.12`` dashed, and pen width
    grows with ``|w|``. The DOT source is written to ``filename_png + ".dot"``
    before rendering.

    Parameters
    ----------
    B : (m, m) array or None
    A : (m, k) array
    title : str
        Graph label. Double quotes are escaped for DOT.
    filename_png : str
        Output PNG path.

    Raises
    ------
    RuntimeError
        If the ``dot`` executable is missing or fails.
    """
    m, k = A.shape
    # Escape double quotes so a title cannot terminate the DOT string early.
    safe_title = str(title).replace('"', '\\"')

    lines = []
    lines.append('digraph G {')
    lines.append('  graph [rankdir=TB, splines=true];')
    lines.append('  node [shape=circle, style=filled, fontsize=10];')
    # latents on the top rank
    lines.append('  { rank = same;')
    for j in range(k):
        name = f'{var_prefix_lat}{j+1}'
        lines.append(f'    "{name}" [fillcolor="#FFCC80", style="dashed,filled"];')
    lines.append('  }')
    # observed on the bottom rank
    lines.append('  { rank = same;')
    for i in range(m):
        name = f'{var_prefix_obs}{i+1}'
        lines.append(f'    "{name}" [fillcolor="#AECBFA"];')
    lines.append('  }')
    # latent -> observed
    for i in range(m):
        for j in range(k):
            w = float(A[i, j])
            if abs(w) < 1e-12:
                continue
            pen = 1.0 + min(3.0, abs(w)*3.0)
            style = 'dashed' if abs(w) < 0.12 else 'solid'
            color = 'red' if w < 0 else 'black'
            lines.append(f'  "{var_prefix_lat}{j+1}" -> "{var_prefix_obs}{i+1}" [label="{w:.2f}", color="{color}", style="{style}", penwidth={pen}];')
    # observed -> observed (B[i, j] is the edge j -> i)
    if B is not None:
        for i in range(m):
            for j in range(m):
                if i == j:
                    continue
                w = float(B[i, j])
                if abs(w) < 1e-12:
                    continue
                pen = 1.0 + min(3.0, abs(w)*3.0)
                style = 'dashed' if abs(w) < 0.12 else 'solid'
                color = 'red' if w < 0 else 'black'
                lines.append(f'  "{var_prefix_obs}{j+1}" -> "{var_prefix_obs}{i+1}" [label="{w:.2f}", color="{color}", style="{style}", penwidth={pen}];')
    lines.append(f'  labelloc="t"; label="{safe_title}"; fontsize=14;')
    lines.append('}')
    dot_text = "\n".join(lines)
    dotfile = filename_png + ".dot"
    with open(dotfile, "w", encoding="utf-8") as f:
        f.write(dot_text)
    # Render. OSError covers a missing 'dot' binary; CalledProcessError a
    # non-zero exit status.
    try:
        subprocess.check_call(["dot", "-Tpng", dotfile, "-o", filename_png])
    except (OSError, subprocess.CalledProcessError) as e:
        raise RuntimeError(f"Graphviz dot rendering failed: {e}\nMake sure 'dot' (graphviz) is installed.") from e


#----------------------------
# 10) Perecision- Recall
#----------------------------

def precision_recall(Mat_true, Mat_est):
    """
    Support-based precision and recall of an estimated matrix.

    Only the sparsity pattern is compared: an entry is an "edge" when it is
    non-zero. ``1e-12`` is added to each denominator, so an empty support
    gives 0.0 rather than a division error.

    Returns
    -------
    (precision, recall)
    """
    true_edges = Mat_true != 0
    est_edges = Mat_est != 0
    hits = np.sum(true_edges & est_edges)
    false_alarms = np.sum(~true_edges & est_edges)
    misses = np.sum(true_edges & ~est_edges)
    precision = hits / (hits + false_alarms + 1e-12)
    recall = hits / (hits + misses + 1e-12)
    return precision, recall

# ---------------------------
# 12) full advanced pipeline
# ---------------------------
def run_pipeline_advanced(B_manual, M_manual, config: RunConfig):
    """
    Run the full IFA + CF pipeline for one ground-truth model.

    Steps
    -----
    1. Simulate data from ``(B_manual, M_manual)``.
    2. Estimate A and B with bootstrap stability.
    3. Prune B to a DAG, then prune A using B's node order.
    4. Align estimated latent factors to the true ones. CF alignment is used;
       on failure it falls back to cosine similarity with Hungarian matching.
    5. Apply stability-probability and magnitude filters.
    6. Compute precision/recall, CF distance and reconstruction error.
    7. Render true/estimated graphs with Graphviz and save a JSON summary.

    Parameters
    ----------
    B_manual : (m, m) array
        True observed->observed coefficients; ``B[i, j]`` is the edge j -> i.
    M_manual : (m, k_true) array
        True latent->observed loadings.
    config : RunConfig
        Run configuration. ``config.scm`` must be set; its seed and sigma2
        are used.

    Returns
    -------
    dict
        Metrics and matrices, also written to
        ``<config.outdir>/results_advanced_cf.json``.
    """
    os.makedirs(config.outdir, exist_ok=True)
    m, k_true = M_manual.shape
    assert B_manual.shape == (m, m), "B must be (m x m)"

    # ------------------------------------------------------------------
    # 1) generate data
    # ------------------------------------------------------------------
    print("Generating data...")
    X, Q_true, H = generate_data(B_manual, M_manual, n=config.n, sigma2=config.scm.sigma2,
                                 seed=config.scm.seed, mog_config=config.mog)

    # Effective latent -> observed mixing seen in X: (I - B)^{-1} M.
    A_true = np.linalg.inv(np.eye(m) - B_manual) @ M_manual
    print("A_true:\n", A_true)

    # ------------------------------------------------------------------
    # 2) bootstrap stability (estimate A and B many times)
    # ------------------------------------------------------------------
    print(f"Bootstrap stability: repeats={config.boot.repeats}, k_est={config.ifa.k_est} ...")
    A_prob, A_median, B_prob, B_median = bootstrap_stability(
        X, k_est=config.ifa.k_est,
        repeats=config.boot.repeats,
        tau_A=config.boot.tau_A, tau_B=config.boot.tau_B,
        seed=config.scm.seed, sigma2=config.scm.sigma2, alpha_B=5e-2)

    # ------------------------------------------------------------------
    # 3) pruning
    # ------------------------------------------------------------------
    print("B_median_Before\n", B_median)
    print("Pruning B...")
    B_pruned, order_B = prune_B_advanced(B_median, B_prob=B_prob, tau=config.boot.tau_B,
                                         keep_prob=config.boot.keep_prob)
    # Show the pruned result (previously B_median was printed a second time).
    print("B_median_After prun :\n", B_pruned)

    print("A_median_Before\n", A_median)
    print("Pruning A...")
    print("B_order :\n", order_B)
    A_pruned = prune_A_advanced_with_order(A_median, B_order=order_B, A_prob=A_prob,
                                           tau=config.boot.tau_A, keep_prob=config.boot.keep_prob,
                                           r_min=2)
    print("A_median_After prun:\n", A_pruned)

    # ------------------------------------------------------------------
    # 4) align estimated factors (k_est) to true factors (k_true)
    # ------------------------------------------------------------------
    print("Aligning estimated factors to true factors using CF...")
    # perm[i]: estimated factor matched to true factor i; signs[i]: its sign.
    # Both stay None unless CF alignment completes successfully.
    perm = None
    signs = None
    S_est_full = None
    try:
        S_est_full = np.linalg.pinv(A_pruned) @ X  # k_est x n
        cf_perm, cf_signs = cf_alignment(Q_true, S_est_full, tgrid=np.linspace(-3, 3, 121))

        k_true = Q_true.shape[0]
        A_aligned = np.zeros((m, k_true))
        for i in range(k_true):
            A_aligned[:, i] = cf_signs[i] * A_pruned[:, cf_perm[i]]

        # Commit only once the whole alignment succeeded, so a partial
        # failure cannot leave a stale perm/signs for the steps below.
        perm, signs = cf_perm, cf_signs
        A_prob_aligned = A_prob[:, perm] if A_prob is not None else None
    except Exception as e:
        print("CF alignment failed:", e)
        perm = signs = None
        # Fallback: cosine similarity of loading columns + Hungarian matching.
        A_est_n = normalize_cols(A_pruned)
        A_true_n = normalize_cols(A_true)
        C = 1 - (A_true_n.T @ A_est_n)
        row, col = linear_sum_assignment(C)
        k_true = A_true.shape[1]
        A_aligned = np.zeros((m, k_true))
        # Permute A_prob the same way. A_prob has k_est columns, so it must be
        # aligned before the filter below to line up with A_aligned.
        A_prob_aligned = np.zeros((m, k_true)) if A_prob is not None else None
        for i, j in zip(row, col):
            if i < k_true and j < A_pruned.shape[1]:
                # np.sign(0) == 0 would wipe a matched column; treat a zero
                # dot product as a positive match.
                sgn = 1.0 if (A_true[:, i] * A_pruned[:, j]).sum() >= 0 else -1.0
                A_aligned[:, i] = sgn * A_pruned[:, j]
                if A_prob_aligned is not None:
                    A_prob_aligned[:, i] = A_prob[:, j]

    print("A_alined:\n", A_aligned)

    # ------------------------------------------------------------------
    # 5) stability filtering
    # ------------------------------------------------------------------
    print("Applying stability filtering A...")
    A_final = A_aligned.copy()
    # a) bootstrap probability filter
    if A_prob_aligned is not None:
        A_final[A_prob_aligned < config.boot.keep_prob] = 0.0
    # b) magnitude threshold
    A_final[np.abs(A_final) < config.boot.tau_A] = 0.0
    print("A_final:\n", A_final)

    print("Applying stability filtering to B...")
    B_final = B_pruned.copy()
    B_final[B_prob < config.boot.keep_prob] = 0.0
    B_final[np.abs(B_final) < config.boot.tau_B] = 0.0
    np.fill_diagonal(B_final, 0.0)
    print("B_final:\n", B_final)

    # ------------------------------------------------------------------
    # 6) latent scores used for reconstruction
    # ------------------------------------------------------------------
    print("compute S_est for final A selection (for reconstruction)...")
    S_sel = None
    if perm is not None:
        try:
            S_sel = np.zeros((k_true, X.shape[1]))
            for i in range(k_true):
                S_sel[i, :] = signs[i] * S_est_full[perm[i], :]
        except Exception:
            S_sel = None
    if S_sel is None:
        # Least-squares fallback: S = pinv(A_final) X
        S_sel = np.linalg.pinv(A_final) @ X
    print("S_sel : \n", S_sel)

    # ------------------------------------------------------------------
    # 7) metrics
    # ------------------------------------------------------------------
    print("Computing metrics...")
    prec_B, rec_B = precision_recall(B_manual, B_final)
    prec_A, rec_A = precision_recall(A_true, A_final)
    # CF distance between true and estimated latent scores
    cf_dist = cf_distance(Q_true, S_sel, m_grid=200, seed=config.scm.seed)
    # relative Frobenius reconstruction error
    recon_err = np.linalg.norm(X - A_final @ S_sel, 'fro') / np.linalg.norm(X, 'fro')

    print(f"Precision B: {prec_B:.3f}, Recall B: {rec_B:.3f}")
    print(f"Precision A: {prec_A:.3f}, Recall A: {rec_A:.3f}")
    print(f"Reconstruction Error: {recon_err:.4f}")
    print(f"CF Distance: {cf_dist:.6f}")

    # ------------------------------------------------------------------
    # 8) render graphs with Graphviz and show them side by side
    # ------------------------------------------------------------------
    png_true = os.path.join(config.outdir, "true_graph.png")
    png_est = os.path.join(config.outdir, "est_graph.png")

    print("Rendering graphs (requires 'dot' from graphviz)...")
    try:
        graph_to_dot_png(B_manual, M_manual, "True model (H M & B)", png_true)
        graph_to_dot_png(B_final, A_final, "Estimated (bootstrap-stable) model", png_est)
    except Exception as e:
        # The .dot files are written before rendering and remain for inspection.
        print("Graphviz render failed:", e)

    fig = None
    try:
        img1 = plt.imread(png_true)
        img2 = plt.imread(png_est)
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        axes[0].imshow(img1); axes[0].axis('off'); axes[0].set_title("True model")
        axes[1].imshow(img2); axes[1].axis('off'); axes[1].set_title("Estimated model")

        # Comparison figure is saved relative to the current working directory.
        comparison_name = config.senario + "_model_comparison" + ".png"
        plt.savefig(comparison_name, bbox_inches='tight', dpi=300)
        print(f"Comparison plot saved as: {comparison_name}")

        plt.show()
    except Exception as e:
        print("Show PNGs failed (maybe dot did not produce images):", e)
    finally:
        # Release the figure; many scenarios run in one session.
        if fig is not None:
            plt.close(fig)

    # ------------------------------------------------------------------
    # 9) save JSON results
    # ------------------------------------------------------------------
    out = {
        "precision_B": float(prec_B),
        "recall_B": float(rec_B),
        "precision_A": float(prec_A),
        "recall_A": float(rec_A),
        "reconstruction_error": float(recon_err),
        "cf_distance": float(cf_dist),
        "B_true": B_manual.tolist(),
        "M_true": M_manual.tolist(),

        "A_final": A_final.tolist(),
        "B_final": B_final.tolist(),
        "A_prob": A_prob.tolist(),
        "B_prob": B_prob.tolist(),
        # np.asarray: works whether cf_alignment returns lists or arrays.
        "perm": np.asarray(perm).tolist() if perm is not None else None,
        "signs": np.asarray(signs).tolist() if signs is not None else None
    }
    with open(os.path.join(config.outdir, "results_advanced_cf.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print("Saved results to", os.path.join(config.outdir, "results_advanced_cf.json"))
    return out



# ---------------------------
# Example: use your matrices (replace with larger examples if desired)
# ---------------------------


if __name__ == "__main__":
    # Run every scenario from run_comprehensive_data_tests() and echo the
    # returned value. `res` is kept as a module-level name for later cells.
    res = run_comprehensive_data_tests()
    print(res)

    # Single-model example (disabled). Uncomment to run the pipeline on a
    # 3-variable chain with two latent factors:
    #
    # B_manual = np.array([
    #     [0.0, 1.0, 0.0],
    #     [0.0, 0.0, 1.0],
    #     [0.0, 0.0, 0.0]
    # ], dtype=float)
    #
    # M_manual = np.array([
    #     [1.0, 0.0],
    #     [0.0, 0.0],
    #     [0.0, 1.0]
    # ], dtype=float)
    #
    # cfg = RunConfig(
    #     n=20000,
    #     scm=SCMConfig(B=B_manual, M=M_manual, sigma2=4.3, seed=42),
    #     mog=MOGConfig(n_components=2, means_strategy="spread",
    #                   variance_strategy="low", weights_strategy="balanced"),
    #     ifa=IFAConfig(k_est=3, max_iter=500, tol=1e-7, reg_latent=5e-3),
    #     boot=BootstrapConfig(repeats=20, tau_A=0.2, tau_B=0.1, keep_prob=0.6),
    #     senario="Default",
    #     outdir="./ifa_cf_results_strong",
    #     use_graphviz=True,
    # )
    #
    # res = run_pipeline_advanced(B_manual, M_manual, cfg)
    # print(res)

    #----------------------------------------------------------------







Running scenario: Complex_Structure
Generating data...
A_true:
 [[1.  0.  0.7 0.7]
 [0.  1.  0.7 0.7]
 [0.  0.  1.  0. ]
 [0.  0.  0.  1. ]
 [0.  0.  0.  0. ]]
Bootstrap stability: repeats=100, k_est=3 ...
Selected k: 4
Selected k: 4
Selected k: 4
Selected k: 4
Selected k: 4
Selected k: 4
Selected k: 4
Selected k: 4
