<div dir="rtl" style="text-align: right;">


این کد دو الگوریتم نمونه‌گیری **زنجیره مارکوف مونت‌کارلو (MCMC) از نوع متروپولیس–هیستینگز** و **نمونه‌گیری گیبس (Gibbs Sampling)** را برای یادگیری ساختار شبکه‌های بیزی مقایسه می‌کند. این مقایسه روی داده‌های شبیه‌سازی‌شده از شبکه بیزی «آسیا» انجام می‌شود و عملکرد هر دو روش با معیارهایی مانند دقت (Precision)، بازیابی (Recall)، امتیاز F1 و فاصله همینگ ساختاری (SHD) ارزیابی می‌شود.

In [None]:
import numpy as np
from scipy.special import gammaln
from typing import List, Tuple, Dict, Optional, Union
import warnings
from collections import Counter
import time
import hashlib

class MCMCStructureLearner:
    """Metropolis-Hastings sampler over DAGs for Bayesian network structure learning.

    The chain moves between DAGs through single-edge additions, deletions and
    reversals; moves that would create a directed cycle are never proposed.
    ``data`` is indexed as ``data[node, case]`` and every entry is used as a
    zero-based state index of its node.
    """

    # Log Bayes factors are clipped to +/- this bound before exponentiation so
    # that ``np.exp`` stays finite in double precision.
    _MAX_LOG_FACTOR = 700.0

    def __init__(self, data: np.ndarray, node_states: List[int]):
        """Store the data set.

        Args:
            data: Integer array of shape ``(n_nodes, n_cases)``.
            node_states: Number of discrete states of each node.
        """
        self.data = data
        self.node_states = np.array(node_states)
        self.n_nodes, self.n_cases = data.shape

    def learn_structure(self, nsamples: Optional[int] = None, burnin: Optional[int] = None,
                        init_dag: Optional[np.ndarray] = None,
                        scoring_fn: str = 'bayesian', **kwargs) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
        """Run the chain and collect the post-burn-in DAGs.

        Args:
            nsamples: Number of DAGs to keep; defaults to ``100 * n_nodes``.
            burnin: Number of initial steps to discard; defaults to ``5 * n_nodes``.
            init_dag: Starting adjacency matrix (``dag[i, j] == 1`` means i -> j);
                defaults to the empty graph. It is copied, not modified.
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``(sampled_graphs, accept_ratio, num_edges)`` where the two arrays
            have one entry per step (burn-in included).

        Raises:
            ValueError: If ``scoring_fn`` is not a known scoring function.
        """
        if nsamples is None:
            nsamples = 100 * self.n_nodes
        if burnin is None:
            burnin = 5 * self.n_nodes
        if init_dag is None:
            init_dag = np.zeros((self.n_nodes, self.n_nodes), dtype=int)

        dag = init_dag.copy()

        total_steps = burnin + nsamples
        accept_ratio = np.zeros(total_steps)
        num_edges = np.zeros(total_steps)
        sampled_graphs = []

        # Both counters start at 1, so the running ratio starts from 0.5
        # instead of dividing by zero.
        num_accepts = 1
        num_rejects = 1

        for t in range(total_steps):
            if t % 200 == 0:
                print(f"  MCMC Iteration {t}/{total_steps}")

            dag, accept = self._take_step(dag, scoring_fn)

            num_edges[t] = np.sum(dag)
            num_accepts += accept
            num_rejects += (1 - accept)
            accept_ratio[t] = num_accepts / (num_accepts + num_rejects)

            if t >= burnin:
                sampled_graphs.append(dag.copy())

        return sampled_graphs, accept_ratio, num_edges

    def _take_step(self, dag: np.ndarray, scoring_fn: str) -> Tuple[np.ndarray, int]:
        """Propose one neighbouring DAG and accept or reject it.

        Returns the (possibly unchanged) DAG and 1 if the move was accepted,
        0 otherwise.
        """
        neighbors = self._get_valid_neighbors(dag)

        if len(neighbors) == 0:
            return dag, 0

        idx = np.random.randint(len(neighbors))
        new_dag, operation, i, j = neighbors[idx]

        log_bf = self._log_bayes_factor(dag, new_dag, operation, i, j, scoring_fn)

        # Hastings correction: the proposal is uniform over each DAG's
        # neighbourhood, so the ratio is |N(old)| / |N(new)|.
        new_neighbors = self._get_valid_neighbors(new_dag)
        log_ratio = log_bf + np.log(len(neighbors)) - np.log(max(1, len(new_neighbors)))

        # Working in log space avoids the inf * 0 = nan product that made the
        # acceptance test accept unconditionally when the scores were far apart.
        u = np.random.random()
        if log_ratio >= 0 or u < np.exp(log_ratio):
            return new_dag, 1
        return dag, 0

    def _get_valid_neighbors(self, dag: np.ndarray) -> List[Tuple[np.ndarray, str, int, int]]:
        """List every acyclic graph one edge operation away from ``dag``.

        Each entry is ``(new_dag, operation, i, j)`` with ``operation`` one of
        ``'delete'``, ``'reverse'`` or ``'add'`` applied to the pair i -> j.
        """
        neighbors = []

        for i in range(self.n_nodes):
            for j in range(self.n_nodes):
                if i == j:
                    continue

                if dag[i, j] == 1:
                    # Deleting an edge can never create a cycle.
                    new_dag = dag.copy()
                    new_dag[i, j] = 0
                    neighbors.append((new_dag, 'delete', i, j))

                    # Reversal is only kept when the result is still acyclic.
                    new_dag = dag.copy()
                    new_dag[i, j] = 0
                    new_dag[j, i] = 1
                    if not self._creates_cycle_fast(new_dag):
                        neighbors.append((new_dag, 'reverse', i, j))

                else:
                    # Addition is only kept when the result is still acyclic.
                    new_dag = dag.copy()
                    new_dag[i, j] = 1
                    if not self._creates_cycle_fast(new_dag):
                        neighbors.append((new_dag, 'add', i, j))

        return neighbors

    def _creates_cycle_fast(self, dag: np.ndarray) -> bool:
        """Return True if ``dag`` contains a directed cycle (three-colour DFS)."""
        n_nodes = dag.shape[0]
        WHITE, GRAY, BLACK = 0, 1, 2
        colors = np.zeros(n_nodes, dtype=int)  # every node starts WHITE

        def dfs_visit(node):
            if colors[node] == GRAY:  # reached a node on the current path
                return True
            if colors[node] == BLACK:  # already fully explored
                return False

            colors[node] = GRAY

            for child in range(n_nodes):
                if dag[node, child] == 1:
                    if dfs_visit(child):
                        return True

            colors[node] = BLACK
            return False

        for node in range(n_nodes):
            if colors[node] == WHITE:
                if dfs_visit(node):
                    return True

        return False

    def _log_bayes_factor(self, old_dag: np.ndarray, new_dag: np.ndarray,
                          operation: str, i: int, j: int, scoring_fn: str) -> float:
        """Return ``log score(new_dag) - log score(old_dag)`` for one edge move."""
        cases = list(range(self.n_cases))
        # Add/delete only changes the family of j; a reversal also changes i's.
        changed_nodes = (j, i) if operation == 'reverse' else (j,)

        log_bf = 0.0
        for node in changed_nodes:
            old_score = self._score_family(node, self._get_parents(old_dag, node), scoring_fn, cases)
            new_score = self._score_family(node, self._get_parents(new_dag, node), scoring_fn, cases)
            log_bf += new_score - old_score
        return log_bf

    def _calculate_bayes_factor(self, old_dag: np.ndarray, new_dag: np.ndarray,
                               operation: str, i: int, j: int, scoring_fn: str) -> float:
        """Return the Bayes factor of ``new_dag`` against ``old_dag``.

        The factor is accumulated in log space and clipped before the final
        exponentiation, so the result is always finite.
        """
        log_bf = self._log_bayes_factor(old_dag, new_dag, operation, i, j, scoring_fn)
        bound = self._MAX_LOG_FACTOR
        return float(np.exp(min(max(log_bf, -bound), bound)))

    def _get_parents(self, dag: np.ndarray, node: int) -> List[int]:
        """Return the indices ``p`` with ``dag[p, node] == 1``."""
        return list(np.where(dag[:, node] == 1)[0])

    def _score_family(self, node: int, parents: List[int], scoring_fn: str, unclamped_cases: List[int]) -> float:
        """Dispatch to the requested family score.

        Raises:
            ValueError: If ``scoring_fn`` is neither ``'bayesian'`` nor ``'bic'``.
        """
        if scoring_fn == 'bayesian':
            return self._bayesian_score(node, parents, unclamped_cases)
        elif scoring_fn == 'bic':
            return self._bic_score(node, parents, unclamped_cases)
        else:
            raise ValueError(f"Unknown scoring function: {scoring_fn}")

    def _parent_configurations(self, parents: List[int], unclamped_cases: List[int]) -> Tuple[np.ndarray, int]:
        """Label each selected case with the index of its joint parent state.

        A node without parents has a single configuration with index 0.
        """
        if len(parents) == 0:
            return np.zeros(len(unclamped_cases), dtype=int), 1
        parent_data = self.data[parents][:, unclamped_cases]
        return self._get_configurations(parent_data, self.node_states[parents])

    def _bayesian_score(self, node: int, parents: List[int], unclamped_cases: List[int]) -> float:
        """Log marginal likelihood of ``node`` given ``parents``.

        Each observed parent configuration uses a Dirichlet prior with total
        weight ``alpha = 1`` split evenly over the node's states. Returns 0.0
        when no cases are selected.
        """
        if len(unclamped_cases) == 0:
            return 0.0

        parent_configs, n_parent_configs = self._parent_configurations(parents, unclamped_cases)
        node_data = self.data[node, unclamped_cases]
        n_states = self.node_states[node]
        alpha = 1.0
        prior = alpha / n_states

        score = 0.0
        for config in range(n_parent_configs):
            config_data = node_data[parent_configs == config]
            n_config = len(config_data)
            if n_config == 0:
                # Unobserved configurations contribute nothing.
                continue

            counts = np.array([np.sum(config_data == state) for state in range(n_states)], dtype=float)
            score += (gammaln(alpha) - gammaln(alpha + n_config) +
                      np.sum(gammaln(prior + counts) - gammaln(prior)))

        return score

    def _bic_score(self, node: int, parents: List[int], unclamped_cases: List[int]) -> float:
        """Maximised log-likelihood of the family minus the BIC penalty.

        Returns 0.0 when no cases are selected.
        """
        if len(unclamped_cases) == 0:
            return 0.0

        n_states = self.node_states[node]
        node_data = self.data[node, unclamped_cases]
        parent_configs, n_parent_configs = self._parent_configurations(parents, unclamped_cases)

        log_likelihood = 0.0
        for config in range(n_parent_configs):
            config_node_data = node_data[parent_configs == config]
            if len(config_node_data) == 0:
                continue

            counts = np.bincount(config_node_data, minlength=n_states)
            probs = counts / len(config_node_data)

            # States with zero count are skipped, so log(0) is never taken.
            for state in range(n_states):
                if counts[state] > 0:
                    log_likelihood += counts[state] * np.log(probs[state])

        n_params = (n_states - 1) * np.prod(self.node_states[parents]) if len(parents) > 0 else (n_states - 1)
        penalty = 0.5 * n_params * np.log(len(unclamped_cases))

        return log_likelihood - penalty

    def _get_configurations(self, data: np.ndarray, states: np.ndarray) -> Tuple[np.ndarray, int]:
        """Encode each column of ``data`` as one mixed-radix integer.

        Args:
            data: ``(n_vars, n_cases)`` state indices (a 1-D array is returned as is).
            states: Number of states of each variable, used as the radices.

        Returns:
            The per-case configuration index and the number of possible
            configurations.
        """
        if len(data.shape) == 1:
            return data, states[0] if len(states) == 1 else max(states)

        n_vars, n_cases = data.shape
        configs = np.zeros(n_cases, dtype=int)
        multiplier = 1

        for i in range(n_vars):
            configs += data[i] * multiplier
            multiplier *= states[i]

        return configs, multiplier


class GibbsStructureLearner:
    """Edge-wise Gibbs sampler over DAGs for Bayesian network structure learning.

    Every sweep visits all ordered node pairs in random order and resamples
    the presence of that edge from its conditional distribution given the rest
    of the graph. ``data`` is indexed as ``data[node, case]`` and every entry
    is used as a zero-based state index of its node.
    """

    # Finite log-scores are not allowed to fall more than this far below the
    # best one, so a feasible edge state never gets probability exactly zero.
    _MIN_LOG_WEIGHT = -700.0

    def __init__(self, data: np.ndarray, node_states: List[int]):
        """Store the data set.

        Args:
            data: Integer array of shape ``(n_nodes, n_cases)``.
            node_states: Number of discrete states of each node.
        """
        self.data = data
        self.node_states = np.array(node_states)
        self.n_nodes, self.n_cases = data.shape

    def learn_structure(self, nsamples: Optional[int] = None, burnin: Optional[int] = None,
                        init_dag: Optional[np.ndarray] = None,
                        scoring_fn: str = 'bayesian', max_parents: int = 3, **kwargs) -> Tuple[List[np.ndarray], np.ndarray]:
        """Run the sampler and collect the post-burn-in DAGs.

        Args:
            nsamples: Number of sweeps to keep; defaults to ``100 * n_nodes``.
            burnin: Number of initial sweeps to discard; defaults to ``5 * n_nodes``.
            init_dag: Starting adjacency matrix (``dag[i, j] == 1`` means i -> j);
                defaults to the empty graph. It is copied, not modified.
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            max_parents: An absent edge is not added to a node that already
                has this many parents.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``(sampled_graphs, edge_probabilities)`` where the second item is
            the fraction of kept samples that contain each edge.

        Raises:
            ValueError: If ``scoring_fn`` is not a known scoring function.
        """
        if nsamples is None:
            nsamples = 100 * self.n_nodes
        if burnin is None:
            burnin = 5 * self.n_nodes
        if init_dag is None:
            init_dag = np.zeros((self.n_nodes, self.n_nodes), dtype=int)

        dag = init_dag.copy()

        total_steps = burnin + nsamples
        sampled_graphs = []
        edge_counts = np.zeros((self.n_nodes, self.n_nodes))

        for t in range(total_steps):
            if t % 200 == 0:
                print(f"  Gibbs Iteration {t}/{total_steps}")

            dag = self._gibbs_sweep(dag, scoring_fn, max_parents)

            if t >= burnin:
                sampled_graphs.append(dag.copy())
                edge_counts += dag

        edge_probabilities = edge_counts / nsamples if nsamples > 0 else edge_counts

        return sampled_graphs, edge_probabilities

    def _gibbs_sweep(self, dag: np.ndarray, scoring_fn: str, max_parents: int) -> np.ndarray:
        """Resample every ordered pair (i, j), i != j, once in random order."""
        new_dag = dag.copy()

        edges = []
        for i in range(self.n_nodes):
            for j in range(self.n_nodes):
                if i != j:
                    edges.append((i, j))

        np.random.shuffle(edges)

        for i, j in edges:
            new_dag[i, j] = self._sample_edge_gibbs(new_dag, i, j, scoring_fn, max_parents)

        return new_dag

    def _sample_edge_gibbs(self, dag: np.ndarray, i: int, j: int, scoring_fn: str, max_parents: int) -> int:
        """Sample whether the edge i -> j is present, everything else fixed.

        Only the family of ``j`` depends on this edge, so the two candidate
        states are weighted by the family score of ``j``. A state that would
        close a directed cycle gets probability exactly zero.
        """
        current_parents = np.sum(dag[:, j])
        if dag[i, j] == 0 and current_parents >= max_parents:
            return 0

        scores = []
        for edge_state in (0, 1):
            temp_dag = dag.copy()
            temp_dag[i, j] = edge_state

            if edge_state == 1 and self._creates_cycle_fast(temp_dag):
                scores.append(-np.inf)
            else:
                unclamped_j = list(range(self.n_cases))
                parents_j = self._get_parents(temp_dag, j)
                scores.append(self._score_family(j, parents_j, scoring_fn, unclamped_j))

        scores = np.array(scores, dtype=float)
        feasible = np.isfinite(scores)
        if not feasible.any():
            return dag[i, j]

        # Shift by the best finite score for numerical stability and floor
        # only the finite entries; infeasible (-inf) states must stay at
        # probability zero instead of being lifted to exp(-700).
        log_weights = np.full(scores.shape, -np.inf)
        log_weights[feasible] = np.maximum(scores[feasible] - scores[feasible].max(),
                                           self._MIN_LOG_WEIGHT)

        probs = np.exp(log_weights)
        probs = probs / np.sum(probs)

        return np.random.choice([0, 1], p=probs)

    def _creates_cycle_fast(self, dag: np.ndarray) -> bool:
        """Return True if ``dag`` contains a directed cycle (three-colour DFS)."""
        n_nodes = dag.shape[0]
        WHITE, GRAY, BLACK = 0, 1, 2
        colors = np.zeros(n_nodes, dtype=int)  # every node starts WHITE

        def dfs_visit(node):
            if colors[node] == GRAY:  # reached a node on the current path
                return True
            if colors[node] == BLACK:  # already fully explored
                return False

            colors[node] = GRAY

            for child in range(n_nodes):
                if dag[node, child] == 1:
                    if dfs_visit(child):
                        return True

            colors[node] = BLACK
            return False

        for node in range(n_nodes):
            if colors[node] == WHITE:
                if dfs_visit(node):
                    return True

        return False

    def _get_parents(self, dag: np.ndarray, node: int) -> List[int]:
        """Return the indices ``p`` with ``dag[p, node] == 1``."""
        return list(np.where(dag[:, node] == 1)[0])

    def _score_family(self, node: int, parents: List[int], scoring_fn: str, unclamped_cases: List[int]) -> float:
        """Dispatch to the requested family score.

        Raises:
            ValueError: If ``scoring_fn`` is neither ``'bayesian'`` nor ``'bic'``.
        """
        if scoring_fn == 'bayesian':
            return self._bayesian_score(node, parents, unclamped_cases)
        elif scoring_fn == 'bic':
            return self._bic_score(node, parents, unclamped_cases)
        else:
            raise ValueError(f"Unknown scoring function: {scoring_fn}")

    def _parent_configurations(self, parents: List[int], unclamped_cases: List[int]) -> Tuple[np.ndarray, int]:
        """Label each selected case with the index of its joint parent state.

        A node without parents has a single configuration with index 0.
        """
        if len(parents) == 0:
            return np.zeros(len(unclamped_cases), dtype=int), 1
        parent_data = self.data[parents][:, unclamped_cases]
        return self._get_configurations(parent_data, self.node_states[parents])

    def _bayesian_score(self, node: int, parents: List[int], unclamped_cases: List[int]) -> float:
        """Log marginal likelihood of ``node`` given ``parents``.

        Each observed parent configuration uses a Dirichlet prior with total
        weight ``alpha = 1`` split evenly over the node's states. Returns 0.0
        when no cases are selected.
        """
        if len(unclamped_cases) == 0:
            return 0.0

        parent_configs, n_parent_configs = self._parent_configurations(parents, unclamped_cases)
        node_data = self.data[node, unclamped_cases]
        n_states = self.node_states[node]
        alpha = 1.0
        prior = alpha / n_states

        score = 0.0
        for config in range(n_parent_configs):
            config_data = node_data[parent_configs == config]
            n_config = len(config_data)
            if n_config == 0:
                # Unobserved configurations contribute nothing.
                continue

            counts = np.array([np.sum(config_data == state) for state in range(n_states)], dtype=float)
            score += (gammaln(alpha) - gammaln(alpha + n_config) +
                      np.sum(gammaln(prior + counts) - gammaln(prior)))

        return score

    def _bic_score(self, node: int, parents: List[int], unclamped_cases: List[int]) -> float:
        """Maximised log-likelihood of the family minus the BIC penalty.

        Returns 0.0 when no cases are selected.
        """
        if len(unclamped_cases) == 0:
            return 0.0

        n_states = self.node_states[node]
        node_data = self.data[node, unclamped_cases]
        parent_configs, n_parent_configs = self._parent_configurations(parents, unclamped_cases)

        log_likelihood = 0.0
        for config in range(n_parent_configs):
            config_node_data = node_data[parent_configs == config]
            if len(config_node_data) == 0:
                continue

            counts = np.bincount(config_node_data, minlength=n_states)
            probs = counts / len(config_node_data)

            # States with zero count are skipped, so log(0) is never taken.
            for state in range(n_states):
                if counts[state] > 0:
                    log_likelihood += counts[state] * np.log(probs[state])

        n_params = (n_states - 1) * np.prod(self.node_states[parents]) if len(parents) > 0 else (n_states - 1)
        penalty = 0.5 * n_params * np.log(len(unclamped_cases))

        return log_likelihood - penalty

    def _get_configurations(self, data: np.ndarray, states: np.ndarray) -> Tuple[np.ndarray, int]:
        """Encode each column of ``data`` as one mixed-radix integer.

        Args:
            data: ``(n_vars, n_cases)`` state indices (a 1-D array is returned as is).
            states: Number of states of each variable, used as the radices.

        Returns:
            The per-case configuration index and the number of possible
            configurations.
        """
        if len(data.shape) == 1:
            return data, states[0] if len(states) == 1 else max(states)

        n_vars, n_cases = data.shape
        configs = np.zeros(n_cases, dtype=int)
        multiplier = 1

        for i in range(n_vars):
            configs += data[i] * multiplier
            multiplier *= states[i]

        return configs, multiplier


def generate_asia_data(n_samples: int = 5000, seed: int = 42) -> Tuple[np.ndarray, List[str], np.ndarray]:
    """Simulate binary cases from the eight-node Asia network.

    The global NumPy random generator is re-seeded with ``seed`` and then used
    for every draw, one case at a time.

    Returns:
        ``(data, var_names, true_dag)`` where ``data`` has shape
        ``(8, n_samples)`` with rows in ``var_names`` order, and
        ``true_dag[i, j] == 1`` means ``var_names[i] -> var_names[j]``.
    """
    np.random.seed(seed)

    var_names = ['Asia', 'Smoking', 'Tuberculosis', 'LungCancer', 'Bronchitis',
                 'Either', 'Xray', 'Dyspnoea']

    # (parent, child) index pairs of the generating network.
    true_edge_list = [
        (0, 2),          # Asia -> Tuberculosis
        (1, 3), (1, 4),  # Smoking -> LungCancer, Bronchitis
        (2, 5),          # Tuberculosis -> Either
        (3, 5),          # LungCancer -> Either
        (4, 7),          # Bronchitis -> Dyspnoea
        (5, 6), (5, 7),  # Either -> Xray, Dyspnoea
    ]
    true_dag = np.zeros((8, 8), dtype=int)
    for parent, child in true_edge_list:
        true_dag[parent, child] = 1

    # Probability of dyspnoea indexed by how many of (either, bronchitis) hold.
    dyspnoea_prob = {0: 0.1, 1: 0.7, 2: 0.9}

    flip = np.random.binomial
    data = np.zeros((8, n_samples), dtype=int)

    for case in range(n_samples):
        # The order of the draws below fixes the random stream; keep it.
        asia = flip(1, 0.01)
        smoking = flip(1, 0.5)

        tuberculosis = flip(1, 0.05 if asia == 1 else 0.01)
        lung_cancer = flip(1, 0.1 if smoking == 1 else 0.01)
        bronchitis = flip(1, 0.6 if smoking == 1 else 0.3)
        # Deterministic OR node.
        either = int(tuberculosis == 1 or lung_cancer == 1)
        xray = flip(1, 0.98 if either == 1 else 0.05)

        active_causes = int(either == 1) + int(bronchitis == 1)
        dyspnoea = flip(1, dyspnoea_prob[active_causes])

        data[:, case] = [asia, smoking, tuberculosis, lung_cancer,
                         bronchitis, either, xray, dyspnoea]

    return data, var_names, true_dag

def graph_to_hash(graph: np.ndarray) -> str:
    """Return the MD5 hex digest of the array's raw bytes.

    Used as a cheap equality key for adjacency matrices.
    """
    hasher = hashlib.md5()
    hasher.update(graph.tobytes())
    return hasher.hexdigest()

def evaluate_method(sampled_graphs: List[np.ndarray], true_dag: np.ndarray,
                   method_name: str, edge_probabilities: Optional[np.ndarray] = None) -> Dict:
    """Score the graphs sampled by one method against the true DAG.

    The graph that is evaluated is the consensus graph (edges whose
    probability exceeds 0.5) when ``edge_probabilities`` is given, otherwise
    the most frequently sampled graph.

    Returns:
        A dict with the evaluated graphs, edge sets, precision/recall/F1 and
        the structural Hamming distance (false positives + false negatives).
        When ``sampled_graphs`` is empty the dict only has ``"error"`` and
        ``"method"`` keys.
    """
    if not sampled_graphs:
        return {"error": "No graphs learned", "method": method_name}

    # Mode of the sample: hash every graph, then take the first graph whose
    # hash is the most common one.
    graph_hashes = [graph_to_hash(graph) for graph in sampled_graphs]
    hash_counts = Counter(graph_hashes)
    top_hash, top_count = hash_counts.most_common(1)[0]
    most_frequent_graph = sampled_graphs[graph_hashes.index(top_hash)].copy()

    if edge_probabilities is None:
        consensus_graph = None
        eval_graph = most_frequent_graph
    else:
        consensus_graph = (edge_probabilities > 0.5).astype(int)
        eval_graph = consensus_graph

    # Directed edge sets over the index range of the true DAG.
    n_nodes = true_dag.shape[0]
    true_edges = set()
    learned_edges = set()
    for pair in ((src, dst) for src in range(n_nodes) for dst in range(n_nodes)):
        if true_dag[pair] == 1:
            true_edges.add(pair)
        if eval_graph[pair] == 1:
            learned_edges.add(pair)

    true_positives = len(true_edges & learned_edges)
    false_positives = len(learned_edges - true_edges)
    false_negatives = len(true_edges - learned_edges)

    def ratio(numerator, denominator):
        # An undefined ratio (empty denominator) is reported as 0.
        return numerator / denominator if denominator > 0 else 0

    precision = ratio(true_positives, true_positives + false_positives)
    recall = ratio(true_positives, true_positives + false_negatives)
    f1_score = ratio(2 * precision * recall, precision + recall)

    # Every missing or extra directed edge counts once.
    shd = false_positives + false_negatives

    return {
        "method": method_name,
        "most_frequent_graph": most_frequent_graph,
        "consensus_graph": consensus_graph,
        "eval_graph": eval_graph,
        "frequency": top_count,
        "total_samples": len(sampled_graphs),
        "unique_structures": len(hash_counts),
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "structural_hamming_distance": shd,
        "true_edges": true_edges,
        "learned_edges": learned_edges,
        "edge_probabilities": edge_probabilities
    }

def print_network_info(dag: np.ndarray, var_names: List[str], title: str = "Network Structure"):
    """Print the edges of ``dag`` as ``parent -> child`` lines under a title.

    ``dag[i, j] == 1`` is read as an edge from ``var_names[i]`` to
    ``var_names[j]``; only the first ``len(var_names)`` rows/columns are used.
    """
    print(f"\n{title}:")
    print("=" * len(title))

    n_vars = len(var_names)
    edge_lines = [
        f"{var_names[src]} -> {var_names[dst]}"
        for src in range(n_vars)
        for dst in range(n_vars)
        if dag[src, dst] == 1
    ]

    if not edge_lines:
        print("  No edges (independent variables)")
        return

    for line in edge_lines:
        print(f"  {line}")
    print(f"\nTotal edges: {len(edge_lines)}")

def compare_methods_detailed(results_mcmc: Dict, results_gibbs: Dict, var_names: List[str]):
    """Print a metric table, both evaluated structures and an edge-by-edge table.

    Both arguments are result dicts as produced by ``evaluate_method``; the
    true edge set is read from ``results_mcmc``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("DETAILED COMPARISON OF METHODS")
    print(rule)

    # --- metric table -------------------------------------------------------
    print(f"\n{'Metric':<25} {'MCMC':<15} {'Gibbs':<15} {'Winner':<15}")
    print("-" * 70)

    metric_rows = [
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1-Score', 'f1_score'),
        ('SHD (lower better)', 'structural_hamming_distance'),
        ('Unique Structures', 'unique_structures')
    ]

    tally = {'MCMC': 0, 'Gibbs': 0, 'Tie': 0}

    for label, key in metric_rows:
        mcmc_val = results_mcmc.get(key, 0)
        gibbs_val = results_gibbs.get(key, 0)

        # SHD is the only metric where the smaller value wins.
        if key == 'structural_hamming_distance':
            mcmc_ahead = mcmc_val < gibbs_val
            gibbs_ahead = gibbs_val < mcmc_val
        else:
            mcmc_ahead = mcmc_val > gibbs_val
            gibbs_ahead = gibbs_val > mcmc_val

        if mcmc_ahead:
            winner = 'MCMC'
        elif gibbs_ahead:
            winner = 'Gibbs'
        else:
            winner = 'Tie'
        tally[winner] += 1

        print(f"{label:<25} {mcmc_val:<15.3f} {gibbs_val:<15.3f} {winner:<15}")

    print("-" * 70)
    print(f"{'Overall Winner':<25} MCMC: {tally['MCMC']}, Gibbs: {tally['Gibbs']}, Ties: {tally['Tie']}")

    # --- structures ---------------------------------------------------------
    print(f"\n{'='*40} STRUCTURES {'='*40}")

    print_network_info(results_mcmc['eval_graph'], var_names, "MCMC Best Structure")
    print_network_info(results_gibbs['eval_graph'], var_names, "Gibbs Best Structure")

    # --- edge table ---------------------------------------------------------
    print(f"\n{'='*35} EDGE COMPARISON {'='*35}")

    true_edges = results_mcmc['true_edges']
    mcmc_edges = results_mcmc['learned_edges']
    gibbs_edges = results_gibbs['learned_edges']

    listed_edges = true_edges | mcmc_edges | gibbs_edges

    # Status text keyed by (edge is true, MCMC has it, Gibbs has it).
    status_by_flags = {
        (True, True, True): "Both found ✓",
        (True, True, False): "Only MCMC +",
        (True, False, True): "Only Gibbs +",
        (True, False, False): "Both missed ✗",
        (False, True, True): "Both wrong ✗",
        (False, True, False): "MCMC wrong -",
        (False, False, True): "Gibbs wrong -",
        (False, False, False): "Both correct ✓",
    }

    def mark(present):
        return '✓' if present else '✗'

    print(f"{'Edge':<20} {'True':<8} {'MCMC':<8} {'Gibbs':<8} {'Status'}")
    print("-" * 60)

    for src, dst in sorted(listed_edges):
        edge = (src, dst)
        flags = (edge in true_edges, edge in mcmc_edges, edge in gibbs_edges)

        edge_str = f"{var_names[src]}->{var_names[dst]}"
        true_val, mcmc_val, gibbs_val = (mark(flag) for flag in flags)
        status = status_by_flags[flags]

        print(f"{edge_str:<20} {true_val:<8} {mcmc_val:<8} {gibbs_val:<8} {status}")

def run_comparison_study(data: np.ndarray, node_states: List[int], var_names: List[str],
                        true_dag: np.ndarray, nsamples: int = 500, burnin: int = 200,
                        scoring_fn: str = 'bic', max_parents: int = 4):
    """Run both samplers on the same data, evaluate them and print a report.

    Args:
        data: Integer array of shape ``(n_nodes, n_cases)``.
        node_states: Number of discrete states of each node.
        var_names: Display name of each node, in row order of ``data``.
        true_dag: Ground-truth adjacency matrix used for evaluation.
        nsamples: Number of graphs kept per method.
        burnin: Number of initial iterations discarded per method.
        scoring_fn: Family score handed to both learners (``'bic'`` or
            ``'bayesian'``).
        max_parents: Parent limit handed to the Gibbs learner.

    Returns:
        ``(results_mcmc, results_gibbs)``: the ``evaluate_method`` dicts for
        the two methods, each extended with a ``'time'`` entry in seconds.
    """

    print("BAYESIAN NETWORK STRUCTURE LEARNING COMPARISON")
    print("=" * 60)
    print(f"Dataset: {data.shape[1]} samples, {data.shape[0]} variables")
    print(f"Variables: {', '.join(var_names)}")
    print(f"Samples per method: {nsamples}, Burn-in: {burnin}")

    # Show true structure
    print_network_info(true_dag, var_names, "TRUE NETWORK STRUCTURE")

    print(f"\nStarting structure learning comparison...")
    print(f"This may take several minutes...\n")

    # The initial DAG is sized from the data instead of being fixed at 8x8.
    n_nodes = data.shape[0]
    init_dag = np.zeros((n_nodes, n_nodes), dtype=int)
    # Warm start with two fixed edges. NOTE: these are not random; for the
    # Asia network they are true edges (Smoking -> Bronchitis and
    # Asia -> Tuberculosis). Pairs outside the node range are skipped.
    for parent, child in ((1, 4), (0, 2)):
        if parent < n_nodes and child < n_nodes:
            init_dag[parent, child] = 1

    # Run MCMC
    print("1. Running MCMC (Metropolis-Hastings)...")
    start_time = time.time()

    mcmc_learner = MCMCStructureLearner(data, node_states)
    mcmc_graphs, accept_ratio, num_edges = mcmc_learner.learn_structure(
        nsamples=nsamples,
        burnin=burnin,
        init_dag=init_dag,
        scoring_fn=scoring_fn
    )

    mcmc_time = time.time() - start_time
    print(f"   MCMC completed in {mcmc_time:.1f} seconds")
    print(f"   Final acceptance ratio: {accept_ratio[-1]:.3f}")
    print(f"   Average edges: {np.mean([np.sum(g) for g in mcmc_graphs]):.1f}")

    # Run Gibbs from the same starting graph
    print("\n2. Running Gibbs Sampling...")
    start_time = time.time()

    gibbs_learner = GibbsStructureLearner(data, node_states)
    gibbs_graphs, edge_probabilities = gibbs_learner.learn_structure(
        nsamples=nsamples,
        burnin=burnin,
        init_dag=init_dag.copy(),
        scoring_fn=scoring_fn,
        max_parents=max_parents
    )

    gibbs_time = time.time() - start_time
    print(f"   Gibbs completed in {gibbs_time:.1f} seconds")
    print(f"   Average edges: {np.mean([np.sum(g) for g in gibbs_graphs]):.1f}")

    print("\n3. Evaluating Results...")

    # MCMC is judged on its most frequent graph, Gibbs on its consensus graph
    # (edge probabilities are only passed for Gibbs).
    results_mcmc = evaluate_method(mcmc_graphs, true_dag, "MCMC")
    results_gibbs = evaluate_method(gibbs_graphs, true_dag, "Gibbs", edge_probabilities)

    # Add timing information
    results_mcmc['time'] = mcmc_time
    results_gibbs['time'] = gibbs_time

    print(f"\n{'='*25} BASIC RESULTS {'='*25}")

    print(f"\nMCMC Results:")
    print(f"  Time: {mcmc_time:.1f}s")
    print(f"  Precision: {results_mcmc['precision']:.3f}")
    print(f"  Recall: {results_mcmc['recall']:.3f}")
    print(f"  F1-Score: {results_mcmc['f1_score']:.3f}")
    print(f"  SHD: {results_mcmc['structural_hamming_distance']}")
    print(f"  Unique structures: {results_mcmc['unique_structures']}")

    print(f"\nGibbs Results:")
    print(f"  Time: {gibbs_time:.1f}s")
    print(f"  Precision: {results_gibbs['precision']:.3f}")
    print(f"  Recall: {results_gibbs['recall']:.3f}")
    print(f"  F1-Score: {results_gibbs['f1_score']:.3f}")
    print(f"  SHD: {results_gibbs['structural_hamming_distance']}")
    print(f"  Unique structures: {results_gibbs['unique_structures']}")

    # Show detailed comparison
    compare_methods_detailed(results_mcmc, results_gibbs, var_names)

    # Show edge probabilities from Gibbs
    if edge_probabilities is not None:
        print(f"\n{'='*25} GIBBS EDGE PROBABILITIES {'='*25}")
        print("All edges with their probabilities:")

        for i in range(len(var_names)):
            for j in range(len(var_names)):
                if i != j and edge_probabilities[i, j] > 0.01:  # hide near-zero edges
                    true_edge = "✓" if (i, j) in results_gibbs['true_edges'] else "✗"
                    print(f"  {var_names[i]} -> {var_names[j]}: {edge_probabilities[i, j]:.3f} {true_edge}")

    return results_mcmc, results_gibbs

# Main execution
if __name__ == "__main__":
    # Simulate the Asia benchmark; all eight variables are binary.
    print("Generating Asia dataset...")
    data, var_names, true_dag = generate_asia_data(n_samples=5000, seed=42)
    node_states = [2] * 8

    # An earlier revision ran this study with nsamples=500 and burnin=200;
    # the chains are now run ten times longer.
    mcmc_results, gibbs_results = run_comparison_study(
        data, node_states, var_names, true_dag,
        nsamples=5000,
        burnin=2000,
    )

    banner = "=" * 60
    print(f"\n{banner}")
    print("COMPARISON STUDY COMPLETED")
    print(banner)

    # Final summary: the method with the higher F1-score is declared winner.
    mcmc_f1 = mcmc_results['f1_score']
    gibbs_f1 = gibbs_results['f1_score']
    if mcmc_f1 > gibbs_f1:
        print(f"\n🏆 MCMC wins with F1-score of {mcmc_f1:.3f} vs {gibbs_f1:.3f}")
    elif gibbs_f1 > mcmc_f1:
        print(f"\n🏆 Gibbs wins with F1-score of {gibbs_f1:.3f} vs {mcmc_f1:.3f}")
    else:
        print(f"\n🤝 It's a tie! Both methods achieved F1-score of {mcmc_f1:.3f}")

    print("\nDone! 🎉")


Generating Asia dataset...
BAYESIAN NETWORK STRUCTURE LEARNING COMPARISON
Dataset: 5000 samples, 8 variables
Variables: Asia, Smoking, Tuberculosis, LungCancer, Bronchitis, Either, Xray, Dyspnoea
Samples per method: 5000, Burn-in: 2000

TRUE NETWORK STRUCTURE:
  Asia -> Tuberculosis
  Smoking -> LungCancer
  Smoking -> Bronchitis
  Tuberculosis -> Either
  LungCancer -> Either
  Bronchitis -> Dyspnoea
  Either -> Xray
  Either -> Dyspnoea

Total edges: 8

Starting structure learning comparison...
This may take several minutes...

1. Running MCMC (Metropolis-Hastings)...
  MCMC Iteration 0/7000


  bf1 = np.exp(new_score_j - old_score_j)
  bf2 = np.exp(new_score_i - old_score_i)
  return bf1 * bf2


  MCMC Iteration 200/7000
  MCMC Iteration 400/7000
  MCMC Iteration 600/7000
  MCMC Iteration 800/7000
  MCMC Iteration 1000/7000
  MCMC Iteration 1200/7000
  MCMC Iteration 1400/7000
  MCMC Iteration 1600/7000
  MCMC Iteration 1800/7000
  MCMC Iteration 2000/7000
  MCMC Iteration 2200/7000
  MCMC Iteration 2400/7000
  MCMC Iteration 2600/7000
  MCMC Iteration 2800/7000
  MCMC Iteration 3000/7000
  MCMC Iteration 3200/7000
  MCMC Iteration 3400/7000
  MCMC Iteration 3600/7000
  MCMC Iteration 3800/7000
  MCMC Iteration 4000/7000
  MCMC Iteration 4200/7000
  MCMC Iteration 4400/7000
  MCMC Iteration 4600/7000
  MCMC Iteration 4800/7000
  MCMC Iteration 5000/7000
  MCMC Iteration 5200/7000
  MCMC Iteration 5400/7000
  MCMC Iteration 5600/7000
  MCMC Iteration 5800/7000
  MCMC Iteration 6000/7000
  MCMC Iteration 6200/7000
  MCMC Iteration 6400/7000
  MCMC Iteration 6600/7000
  MCMC Iteration 6800/7000
   MCMC completed in 34.7 seconds
   Final acceptance ratio: 0.097
   Average edges: 9

In [None]:
import numpy as np
from scipy.special import gammaln
from typing import List, Tuple, Dict, Optional, Union
import warnings
from collections import Counter
import time
import hashlib
from functools import lru_cache

class MCMCStructureLearner:
    """Metropolis-Hastings search over DAGs for Bayesian network structure learning.

    ``data`` is indexed as ``data[node, case]`` and must hold integer state
    codes in ``0 .. node_states[node] - 1``.  A DAG is an adjacency matrix in
    which ``dag[i, j] == 1`` denotes the edge ``i -> j``.
    """

    # The family-score memo is reset once it holds this many entries.
    _MAX_CACHED_SCORES = 100000
    # Largest exponent handed to exp() so that Bayes factors stay finite.
    _MAX_LOG_BAYES_FACTOR = 700.0

    def __init__(self, data: np.ndarray, node_states: List[int]):
        self.data = data
        self.node_states = np.array(node_states)
        self.n_nodes, self.n_cases = data.shape
        # Per-instance memo of family scores.  A dict on the instance (rather
        # than functools.lru_cache on the method) neither keeps the learner
        # alive in a global cache nor hashes an n_cases-long tuple per lookup.
        self._score_cache = {}

    def learn_structure(self, nsamples: Optional[int] = None, burnin: Optional[int] = None,
                       init_dag: Optional[np.ndarray] = None,
                       scoring_fn: str = 'bayesian', max_parents: int = 3, **kwargs) -> Tuple[List[np.ndarray], np.ndarray, np.ndarray]:
        """Run the chain and return ``(sampled_graphs, accept_ratio, num_edges)``.

        Args:
            nsamples: Number of post-burn-in steps to record (default ``200 * n_nodes``).
            burnin: Number of initial steps to discard (default ``10 * n_nodes``).
            init_dag: Starting adjacency matrix (default: the empty graph).
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            max_parents: Maximum in-degree allowed for proposed graphs.
            **kwargs: Accepted and ignored.

        Returns:
            The list of DAG copies recorded after burn-in, the running
            acceptance ratio per step and the edge count per step (the latter
            two cover burn-in as well).
        """
        if nsamples is None:
            nsamples = 200 * self.n_nodes
        if burnin is None:
            burnin = 10 * self.n_nodes
        if init_dag is None:
            init_dag = np.zeros((self.n_nodes, self.n_nodes), dtype=int)

        dag = init_dag.copy()

        total_steps = burnin + nsamples
        accept_ratio = np.zeros(total_steps)
        num_edges = np.zeros(total_steps)
        sampled_graphs = []

        # Both counters start at 1 so the running ratio begins at 0.5.
        num_accepts = 1
        num_rejects = 1

        for t in range(total_steps):
            if t % 200 == 0:
                print(f"  MCMC Iteration {t}/{total_steps}")

            dag, accept = self._take_step(dag, scoring_fn, t, total_steps, max_parents)

            num_edges[t] = np.sum(dag)
            num_accepts += accept
            num_rejects += (1 - accept)
            accept_ratio[t] = num_accepts / (num_accepts + num_rejects)

            if t >= burnin:
                sampled_graphs.append(dag.copy())

        return sampled_graphs, accept_ratio, num_edges

    def _take_step(self, dag: np.ndarray, scoring_fn: str, step: int, total_steps: int, max_parents: int) -> Tuple[np.ndarray, int]:
        """Propose a uniformly chosen neighbour and accept or reject it.

        Returns ``(dag_after_step, accepted)`` with ``accepted`` in ``{0, 1}``.
        """
        neighbors = self._get_valid_neighbors(dag, max_parents)
        if not neighbors:
            return dag, 0

        idx = np.random.randint(len(neighbors))
        new_dag, operation, i, j = neighbors[idx]

        # Everything stays in log space: exponentiating score differences
        # overflows to inf, and inf * 0 in a reverse move yields nan, which
        # min(1, nan) silently turns into "always accept".
        log_bf = self._log_bayes_factor(dag, new_dag, operation, i, j, scoring_fn)

        # Hastings correction for the differing neighbourhood sizes.
        new_neighbors = self._get_valid_neighbors(new_dag, max_parents)
        log_ratio = log_bf + np.log(len(neighbors)) - np.log(max(1, len(new_neighbors)))

        # Annealing: the temperature falls from 1 towards 0.5 over the run, so
        # the acceptance ratio is raised to a power growing from 1 towards 2.
        temperature = 1.0 / (1 + step / total_steps)
        log_accept = min(0.0, log_ratio / temperature)

        if np.random.random() < np.exp(log_accept):
            return new_dag, 1
        return dag, 0

    def _get_valid_neighbors(self, dag: np.ndarray, max_parents: int) -> List[Tuple[np.ndarray, str, int, int]]:
        """List every graph one edge deletion, reversal or addition away.

        Each entry is ``(new_dag, operation, i, j)`` where ``(i, j)`` is the
        edge position in ``dag`` that the operation acts on.  Reversals and
        additions are kept only if they stay acyclic and within ``max_parents``.
        """
        neighbors = []

        for i in range(self.n_nodes):
            for j in range(self.n_nodes):
                if i == j:
                    continue

                if dag[i, j] == 1:
                    # Deleting an edge never creates a cycle.
                    without_edge = dag.copy()
                    without_edge[i, j] = 0
                    neighbors.append((without_edge, 'delete', i, j))

                    # Reversal turns i -> j into j -> i, so i gains a parent.
                    flipped = without_edge.copy()
                    flipped[j, i] = 1
                    if np.sum(flipped[:, i]) <= max_parents and not self._creates_cycle_fast(flipped):
                        neighbors.append((flipped, 'reverse', i, j))

                elif np.sum(dag[:, j]) < max_parents:
                    with_edge = dag.copy()
                    with_edge[i, j] = 1
                    if not self._creates_cycle_fast(with_edge):
                        neighbors.append((with_edge, 'add', i, j))

        return neighbors

    def _creates_cycle_fast(self, dag: np.ndarray) -> bool:
        """Return True if ``dag`` contains a directed cycle (self-loops included).

        Repeatedly removes nodes without incoming edges; whatever cannot be
        removed lies on, or downstream of, a cycle.  Iterative, so the graph
        size is not limited by the interpreter's recursion depth.
        """
        adjacency = np.asarray(dag) == 1
        n_nodes = adjacency.shape[0]
        in_degree = adjacency.sum(axis=0)

        ready = [int(node) for node in np.flatnonzero(in_degree == 0)]
        n_removed = 0
        while ready:
            node = ready.pop()
            n_removed += 1
            for child in np.flatnonzero(adjacency[node]):
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    ready.append(int(child))

        return n_removed != n_nodes

    def _log_bayes_factor(self, old_dag: np.ndarray, new_dag: np.ndarray,
                          operation: str, i: int, j: int, scoring_fn: str) -> float:
        """Return ``log score(new_dag) - log score(old_dag)`` for a single move.

        Only the family of ``j`` changes for 'add' and 'delete'; a 'reverse'
        additionally changes the family of ``i``.
        """
        log_bf = self._family_score_delta(old_dag, new_dag, j, scoring_fn)
        if operation == 'reverse':
            log_bf += self._family_score_delta(old_dag, new_dag, i, scoring_fn)
        return log_bf

    def _family_score_delta(self, old_dag: np.ndarray, new_dag: np.ndarray,
                            node: int, scoring_fn: str) -> float:
        """Return the change in ``node``'s family score between the two graphs."""
        old_score = self._score_family(node, tuple(self._get_parents(old_dag, node)), scoring_fn)
        new_score = self._score_family(node, tuple(self._get_parents(new_dag, node)), scoring_fn)
        return new_score - old_score

    def _calculate_bayes_factor(self, old_dag: np.ndarray, new_dag: np.ndarray,
                               operation: str, i: int, j: int, scoring_fn: str) -> float:
        """Return the Bayes factor of ``new_dag`` against ``old_dag``.

        The log factor is capped before exponentiation, so the result is
        always finite (very small factors underflow to 0.0).
        """
        log_bf = self._log_bayes_factor(old_dag, new_dag, operation, i, j, scoring_fn)
        return float(np.exp(min(log_bf, self._MAX_LOG_BAYES_FACTOR)))

    def _get_parents(self, dag: np.ndarray, node: int) -> List[int]:
        """Return the indices ``p`` with ``dag[p, node] == 1``."""
        return list(np.where(dag[:, node] == 1)[0])

    def _score_family(self, node: int, parents: tuple, scoring_fn: str,
                      unclamped_cases: Optional[tuple] = None) -> float:
        """Return the memoised score of ``node`` given ``parents``.

        Args:
            node: Index of the child variable.
            parents: Indices of its parents.
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            unclamped_cases: Case indices to score on; ``None`` means all cases.

        Raises:
            ValueError: If ``scoring_fn`` is not recognised.
        """
        parents_key = tuple(int(p) for p in parents)
        cases_key = None if unclamped_cases is None else tuple(unclamped_cases)
        key = (int(node), parents_key, scoring_fn, cases_key)

        cached = self._score_cache.get(key)
        if cached is not None:
            return cached

        cases = None if cases_key is None else list(cases_key)
        if scoring_fn == 'bayesian':
            score = self._bayesian_score(node, list(parents_key), cases)
        elif scoring_fn == 'bic':
            score = self._bic_score(node, list(parents_key), cases)
        else:
            raise ValueError(f"Unknown scoring function: {scoring_fn}")

        if len(self._score_cache) >= self._MAX_CACHED_SCORES:
            self._score_cache.clear()
        self._score_cache[key] = score
        return score

    def _family_counts(self, node: int, parents: List[int],
                       unclamped_cases: Optional[List[int]] = None) -> np.ndarray:
        """Return the table ``counts[parent_config, node_state]`` over the chosen cases.

        Raises:
            ValueError: If the data holds a state code outside the range
                declared in ``node_states``.
        """
        parents = [int(p) for p in parents]
        cases = slice(None) if unclamped_cases is None else np.asarray(unclamped_cases, dtype=int)

        node_data = np.asarray(self.data[node, cases], dtype=np.int64)
        n_states = int(self.node_states[node])

        if parents:
            parent_data = self.data[parents][:, cases]
            parent_configs, n_configs = self._get_configurations(parent_data, self.node_states[parents])
        else:
            # A parentless node has exactly one (empty) parent configuration,
            # and it must be numbered 0 like every other configuration.
            parent_configs = np.zeros(node_data.shape[0], dtype=np.int64)
            n_configs = 1
        n_configs = int(n_configs)

        joint = np.bincount(parent_configs * n_states + node_data, minlength=n_configs * n_states)
        if joint.size != n_configs * n_states:
            raise ValueError("data contains states outside the declared node_states range")
        return joint.reshape(n_configs, n_states)

    def _bayesian_score(self, node: int, parents: List[int],
                        unclamped_cases: Optional[List[int]] = None) -> float:
        """Return the log marginal likelihood of ``node``'s family.

        Each parent configuration gets a symmetric Dirichlet prior with total
        weight ``alpha = 1`` spread evenly over the node's states.  Returns 0.0
        when no cases are selected.
        """
        counts = self._family_counts(node, parents, unclamped_cases)
        if counts.sum() == 0:
            return 0.0

        alpha = 1.0  # total prior weight per parent configuration; tune here
        per_state = alpha / counts.shape[1]
        config_totals = counts.sum(axis=1)

        # Unobserved configurations contribute exactly 0 to both sums.
        score = (np.sum(gammaln(alpha) - gammaln(alpha + config_totals)) +
                 np.sum(gammaln(per_state + counts) - gammaln(per_state)))
        return float(score)

    def _bic_score(self, node: int, parents: List[int],
                   unclamped_cases: Optional[List[int]] = None) -> float:
        """Return the BIC score: maximised log-likelihood minus ``0.5 * k * log(n)``.

        ``k`` is ``(node states - 1) * number of parent configurations`` and
        ``n`` the number of selected cases.  Returns 0.0 when no cases are
        selected.
        """
        counts = self._family_counts(node, parents, unclamped_cases)
        n_selected = int(counts.sum())
        if n_selected == 0:
            return 0.0

        config_totals = np.broadcast_to(counts.sum(axis=1, keepdims=True), counts.shape)
        observed = counts > 0  # empty cells contribute nothing and would hit log(0)
        log_likelihood = np.sum(counts[observed] * np.log(counts[observed] / config_totals[observed]))

        n_configs, n_states = counts.shape
        n_params = (n_states - 1) * n_configs
        penalty = 0.5 * n_params * np.log(n_selected)

        return float(log_likelihood - penalty)

    def _get_configurations(self, data: np.ndarray, states: np.ndarray) -> Tuple[np.ndarray, int]:
        """Encode each column of ``data`` as one mixed-radix integer.

        Returns ``(configs, n_configs)`` where ``n_configs`` is the product of
        ``states``.  A 1-D ``data`` is returned unchanged together with its
        state count.
        """
        if len(data.shape) == 1:
            return data, states[0] if len(states) == 1 else max(states)

        n_vars, n_cases = data.shape
        configs = np.zeros(n_cases, dtype=int)
        multiplier = 1

        for i in range(n_vars):
            configs += data[i] * multiplier
            multiplier *= states[i]

        return configs, multiplier


class GibbsStructureLearner:
    """Edge-wise Gibbs sampler over DAGs for Bayesian network structure learning.

    ``data`` is indexed as ``data[node, case]`` and must hold integer state
    codes in ``0 .. node_states[node] - 1``.  A DAG is an adjacency matrix in
    which ``dag[i, j] == 1`` denotes the edge ``i -> j``.
    """

    # The family-score memo is reset once it holds this many entries.
    _MAX_CACHED_SCORES = 100000

    def __init__(self, data: np.ndarray, node_states: List[int]):
        self.data = data
        self.node_states = np.array(node_states)
        self.n_nodes, self.n_cases = data.shape
        # Per-instance memo of family scores.  A dict on the instance (rather
        # than functools.lru_cache on the method) neither keeps the learner
        # alive in a global cache nor hashes an n_cases-long tuple per lookup.
        self._score_cache = {}

    def learn_structure(self, nsamples: Optional[int] = None, burnin: Optional[int] = None,
                       init_dag: Optional[np.ndarray] = None,
                       scoring_fn: str = 'bayesian', max_parents: int = 3, **kwargs) -> Tuple[List[np.ndarray], np.ndarray]:
        """Run the sampler and return ``(sampled_graphs, edge_probabilities)``.

        Args:
            nsamples: Number of post-burn-in sweeps to record (default ``200 * n_nodes``).
            burnin: Number of initial sweeps to discard (default ``10 * n_nodes``).
            init_dag: Starting adjacency matrix (default: the empty graph).
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            max_parents: Maximum in-degree allowed when adding an edge.
            **kwargs: Accepted and ignored.

        Returns:
            The DAG copies recorded after burn-in and, per edge position, the
            fraction of recorded graphs that contain that edge.
        """
        if nsamples is None:
            nsamples = 200 * self.n_nodes
        if burnin is None:
            burnin = 10 * self.n_nodes
        if init_dag is None:
            init_dag = np.zeros((self.n_nodes, self.n_nodes), dtype=int)

        dag = init_dag.copy()

        total_steps = burnin + nsamples
        sampled_graphs = []
        edge_counts = np.zeros((self.n_nodes, self.n_nodes))

        for t in range(total_steps):
            if t % 200 == 0:
                print(f"  Gibbs Iteration {t}/{total_steps}")

            dag = self._gibbs_sweep(dag, scoring_fn, max_parents)

            if t >= burnin:
                sampled_graphs.append(dag.copy())
                edge_counts += dag

        edge_probabilities = edge_counts / nsamples if nsamples > 0 else edge_counts

        return sampled_graphs, edge_probabilities

    def _gibbs_sweep(self, dag: np.ndarray, scoring_fn: str, max_parents: int) -> np.ndarray:
        """Resample every ordered node pair once, in random order, on a copy of ``dag``."""
        new_dag = dag.copy()

        edges = [(i, j)
                 for i in range(self.n_nodes)
                 for j in range(self.n_nodes)
                 if i != j]
        np.random.shuffle(edges)

        for i, j in edges:
            new_dag[i, j] = self._sample_edge_gibbs(new_dag, i, j, scoring_fn, max_parents)

        return new_dag

    def _sample_edge_gibbs(self, dag: np.ndarray, i: int, j: int, scoring_fn: str, max_parents: int) -> int:
        """Sample the state (0 or 1) of edge ``i -> j`` given the rest of ``dag``.

        The two states are weighted by ``exp(family score of j)``.  Adding the
        edge is impossible (weight exactly 0) when it would close a cycle, and
        an absent edge stays absent when ``j`` already has ``max_parents``
        parents.
        """
        current_parents = np.sum(dag[:, j])
        if dag[i, j] == 0 and current_parents >= max_parents:
            return 0

        log_weights = np.full(2, -np.inf)
        for edge_state in (0, 1):
            temp_dag = dag.copy()
            temp_dag[i, j] = edge_state

            if edge_state == 1 and self._creates_cycle_fast(temp_dag):
                continue  # keep -inf: this state must never be drawn

            parents_j = self._get_parents(temp_dag, j)
            log_weights[edge_state] = self._score_family(j, tuple(parents_j), scoring_fn)

        finite = np.isfinite(log_weights)
        if not finite.any():
            return int(dag[i, j])

        # Shifting by the maximum makes the largest weight exp(0) = 1, so the
        # normaliser cannot underflow to 0; -inf entries become probability 0.
        log_weights = log_weights - log_weights[finite].max()
        probs = np.exp(log_weights)
        probs = probs / probs.sum()

        return int(np.random.choice([0, 1], p=probs))

    def _creates_cycle_fast(self, dag: np.ndarray) -> bool:
        """Return True if ``dag`` contains a directed cycle (self-loops included).

        Repeatedly removes nodes without incoming edges; whatever cannot be
        removed lies on, or downstream of, a cycle.  Iterative, so the graph
        size is not limited by the interpreter's recursion depth.
        """
        adjacency = np.asarray(dag) == 1
        n_nodes = adjacency.shape[0]
        in_degree = adjacency.sum(axis=0)

        ready = [int(node) for node in np.flatnonzero(in_degree == 0)]
        n_removed = 0
        while ready:
            node = ready.pop()
            n_removed += 1
            for child in np.flatnonzero(adjacency[node]):
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    ready.append(int(child))

        return n_removed != n_nodes

    def _get_parents(self, dag: np.ndarray, node: int) -> List[int]:
        """Return the indices ``p`` with ``dag[p, node] == 1``."""
        return list(np.where(dag[:, node] == 1)[0])

    def _score_family(self, node: int, parents: tuple, scoring_fn: str,
                      unclamped_cases: Optional[tuple] = None) -> float:
        """Return the memoised score of ``node`` given ``parents``.

        Args:
            node: Index of the child variable.
            parents: Indices of its parents.
            scoring_fn: ``'bayesian'`` or ``'bic'``.
            unclamped_cases: Case indices to score on; ``None`` means all cases.

        Raises:
            ValueError: If ``scoring_fn`` is not recognised.
        """
        parents_key = tuple(int(p) for p in parents)
        cases_key = None if unclamped_cases is None else tuple(unclamped_cases)
        key = (int(node), parents_key, scoring_fn, cases_key)

        cached = self._score_cache.get(key)
        if cached is not None:
            return cached

        cases = None if cases_key is None else list(cases_key)
        if scoring_fn == 'bayesian':
            score = self._bayesian_score(node, list(parents_key), cases)
        elif scoring_fn == 'bic':
            score = self._bic_score(node, list(parents_key), cases)
        else:
            raise ValueError(f"Unknown scoring function: {scoring_fn}")

        if len(self._score_cache) >= self._MAX_CACHED_SCORES:
            self._score_cache.clear()
        self._score_cache[key] = score
        return score

    def _family_counts(self, node: int, parents: List[int],
                       unclamped_cases: Optional[List[int]] = None) -> np.ndarray:
        """Return the table ``counts[parent_config, node_state]`` over the chosen cases.

        Raises:
            ValueError: If the data holds a state code outside the range
                declared in ``node_states``.
        """
        parents = [int(p) for p in parents]
        cases = slice(None) if unclamped_cases is None else np.asarray(unclamped_cases, dtype=int)

        node_data = np.asarray(self.data[node, cases], dtype=np.int64)
        n_states = int(self.node_states[node])

        if parents:
            parent_data = self.data[parents][:, cases]
            parent_configs, n_configs = self._get_configurations(parent_data, self.node_states[parents])
        else:
            # A parentless node has exactly one (empty) parent configuration,
            # and it must be numbered 0 like every other configuration.
            parent_configs = np.zeros(node_data.shape[0], dtype=np.int64)
            n_configs = 1
        n_configs = int(n_configs)

        joint = np.bincount(parent_configs * n_states + node_data, minlength=n_configs * n_states)
        if joint.size != n_configs * n_states:
            raise ValueError("data contains states outside the declared node_states range")
        return joint.reshape(n_configs, n_states)

    def _bayesian_score(self, node: int, parents: List[int],
                        unclamped_cases: Optional[List[int]] = None) -> float:
        """Return the log marginal likelihood of ``node``'s family.

        Each parent configuration gets a symmetric Dirichlet prior with total
        weight ``alpha = 1`` spread evenly over the node's states.  Returns 0.0
        when no cases are selected.
        """
        counts = self._family_counts(node, parents, unclamped_cases)
        if counts.sum() == 0:
            return 0.0

        alpha = 1.0  # total prior weight per parent configuration
        per_state = alpha / counts.shape[1]
        config_totals = counts.sum(axis=1)

        # Unobserved configurations contribute exactly 0 to both sums.
        score = (np.sum(gammaln(alpha) - gammaln(alpha + config_totals)) +
                 np.sum(gammaln(per_state + counts) - gammaln(per_state)))
        return float(score)

    def _bic_score(self, node: int, parents: List[int],
                   unclamped_cases: Optional[List[int]] = None) -> float:
        """Return the BIC score: maximised log-likelihood minus ``0.5 * k * log(n)``.

        ``k`` is ``(node states - 1) * number of parent configurations`` and
        ``n`` the number of selected cases.  Returns 0.0 when no cases are
        selected.
        """
        counts = self._family_counts(node, parents, unclamped_cases)
        n_selected = int(counts.sum())
        if n_selected == 0:
            return 0.0

        config_totals = np.broadcast_to(counts.sum(axis=1, keepdims=True), counts.shape)
        observed = counts > 0  # empty cells contribute nothing and would hit log(0)
        log_likelihood = np.sum(counts[observed] * np.log(counts[observed] / config_totals[observed]))

        n_configs, n_states = counts.shape
        n_params = (n_states - 1) * n_configs
        penalty = 0.5 * n_params * np.log(n_selected)

        return float(log_likelihood - penalty)

    def _get_configurations(self, data: np.ndarray, states: np.ndarray) -> Tuple[np.ndarray, int]:
        """Encode each column of ``data`` as one mixed-radix integer.

        Returns ``(configs, n_configs)`` where ``n_configs`` is the product of
        ``states``.  A 1-D ``data`` is returned unchanged together with its
        state count.
        """
        if len(data.shape) == 1:
            return data, states[0] if len(states) == 1 else max(states)

        n_vars, n_cases = data.shape
        configs = np.zeros(n_cases, dtype=int)
        multiplier = 1

        for i in range(n_vars):
            configs += data[i] * multiplier
            multiplier *= states[i]

        return configs, multiplier


def generate_asia_data(n_samples: int = 5000, seed: int = 42) -> Tuple[np.ndarray, List[str], np.ndarray]:
    """Simulate binary cases from the Asia network.

    Seeds NumPy's global random generator with ``seed`` and draws the cases
    one at a time, parents before children.

    Returns:
        ``(data, var_names, true_dag)``: ``data`` has one row per variable and
        one column per case, and ``true_dag[i, j] == 1`` marks the edge
        ``var_names[i] -> var_names[j]``.
    """
    np.random.seed(seed)

    var_names = ['Asia', 'Smoking', 'Tuberculosis', 'LungCancer', 'Bronchitis',
                 'Either', 'Xray', 'Dyspnoea']

    # (parent, child) index pairs of the generating network.
    true_edges = [
        (0, 2),          # Asia -> Tuberculosis
        (1, 3), (1, 4),  # Smoking -> LungCancer, Bronchitis
        (2, 5),          # Tuberculosis -> Either
        (3, 5),          # LungCancer -> Either
        (4, 7),          # Bronchitis -> Dyspnoea
        (5, 6), (5, 7),  # Either -> Xray, Dyspnoea
    ]
    true_dag = np.zeros((8, 8), dtype=int)
    for parent, child in true_edges:
        true_dag[parent, child] = 1

    def bernoulli(p):
        return np.random.binomial(1, p)

    # Dyspnoea probability indexed by how many of (Either, Bronchitis) are present.
    dyspnoea_prob = (0.1, 0.7, 0.9)

    data = np.zeros((8, n_samples), dtype=int)

    for case in range(n_samples):
        # The order of the draws is part of the contract: it fixes the
        # random stream consumed for a given seed.
        asia = bernoulli(0.01)
        smoking = bernoulli(0.5)

        tuberculosis = bernoulli(0.05 if asia == 1 else 0.01)
        lung_cancer = bernoulli(0.1 if smoking == 1 else 0.01)
        bronchitis = bernoulli(0.6 if smoking == 1 else 0.3)
        either = 1 if (tuberculosis == 1 or lung_cancer == 1) else 0  # deterministic OR
        xray = bernoulli(0.98 if either == 1 else 0.05)
        dyspnoea = bernoulli(dyspnoea_prob[either + bronchitis])

        data[:, case] = [asia, smoking, tuberculosis, lung_cancer,
                         bronchitis, either, xray, dyspnoea]

    return data, var_names, true_dag

def graph_to_hash(graph: np.ndarray) -> str:
    """Return an MD5 hex digest identifying ``graph`` by shape and contents.

    Boolean and integer matrices are normalised to 64-bit integers first, so
    the same adjacency matrix hashes identically whatever integer width or
    memory layout it is stored in.  The dtype and shape are hashed along with
    the values, so matrices of different shapes never share a digest just
    because their raw bytes coincide.  The digest is for grouping equal
    graphs only, not for security.
    """
    matrix = np.asarray(graph)
    if matrix.dtype.kind in "biu":
        matrix = matrix.astype(np.int64)

    digest = hashlib.md5()
    digest.update(f"{matrix.dtype.str}{matrix.shape}".encode("ascii"))
    digest.update(matrix.tobytes())  # tobytes() always emits C order
    return digest.hexdigest()

def evaluate_method(sampled_graphs: List[np.ndarray], true_dag: np.ndarray,
                   method_name: str, edge_probabilities: Optional[np.ndarray] = None,
                   consensus_threshold: float = 0.7) -> Dict:
    """Score a structure-learning run against the true network.

    Args:
        sampled_graphs: Adjacency matrices produced by the sampler.
        true_dag: Ground-truth adjacency matrix (``[i, j] == 1`` is ``i -> j``).
        method_name: Label stored in the result.
        edge_probabilities: Optional per-edge frequencies.  When given, the
            evaluated graph is the consensus graph containing every edge whose
            frequency exceeds ``consensus_threshold``; otherwise it is the most
            frequently sampled graph.
        consensus_threshold: Frequency an edge must exceed to enter the
            consensus graph.

    Returns:
        A dict with the evaluated graph, edge sets, precision/recall/F1 and
        the structural Hamming distance (false positives plus false negatives,
        so a reversed edge counts twice).  If ``sampled_graphs`` is empty, only
        ``{"error": ..., "method": ...}`` is returned.
    """
    if not sampled_graphs:
        return {"error": "No graphs learned", "method": method_name}

    # Count structures by hash and remember the first graph seen for each, so
    # the modal graph is recovered without rescanning the samples.
    hash_counts = Counter()
    first_graph_for = {}
    for graph in sampled_graphs:
        graph_hash = graph_to_hash(graph)
        hash_counts[graph_hash] += 1
        first_graph_for.setdefault(graph_hash, graph)

    most_common_hash, frequency = hash_counts.most_common(1)[0]
    most_frequent_graph = first_graph_for[most_common_hash].copy()

    # A threshold above 0.5 trades recall for precision.
    consensus_graph = None
    if edge_probabilities is not None:
        consensus_graph = (edge_probabilities > consensus_threshold).astype(int)

    eval_graph = consensus_graph if consensus_graph is not None else most_frequent_graph

    n_nodes = true_dag.shape[0]
    node_pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes)]
    true_edges = {(i, j) for i, j in node_pairs if true_dag[i, j] == 1}
    learned_edges = {(i, j) for i, j in node_pairs if eval_graph[i, j] == 1}

    true_positives = len(true_edges & learned_edges)
    false_positives = len(learned_edges - true_edges)
    false_negatives = len(true_edges - learned_edges)

    # Each ratio falls back to 0 when its denominator is empty.
    n_learned = true_positives + false_positives
    n_true = true_positives + false_negatives
    precision = true_positives / n_learned if n_learned > 0 else 0
    recall = true_positives / n_true if n_true > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    shd = false_positives + false_negatives

    return {
        "method": method_name,
        "most_frequent_graph": most_frequent_graph,
        "consensus_graph": consensus_graph,
        "eval_graph": eval_graph,
        "frequency": frequency,
        "total_samples": len(sampled_graphs),
        "unique_structures": len(hash_counts),
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "structural_hamming_distance": shd,
        "true_edges": true_edges,
        "learned_edges": learned_edges,
        "edge_probabilities": edge_probabilities
    }

def print_network_info(dag: np.ndarray, var_names: List[str], title: str = "Network Structure"):
    """Print ``title`` followed by one ``parent -> child`` line per edge of ``dag``.

    An entry ``dag[i, j] == 1`` is shown as ``var_names[i] -> var_names[j]``;
    a graph without edges is reported as such.
    """
    n_vars = len(var_names)

    print(f"\n{title}:")
    print("=" * len(title))

    edge_lines = [
        f"{var_names[src]} -> {var_names[dst]}"
        for src in range(n_vars)
        for dst in range(n_vars)
        if dag[src, dst] == 1
    ]

    if not edge_lines:
        print("  No edges (independent variables)")
        return

    for line in edge_lines:
        print(f"  {line}")
    print(f"\nTotal edges: {len(edge_lines)}")

def compare_methods_detailed(results_mcmc: Dict, results_gibbs: Dict, var_names: List[str]):
    """Print a metric table, both learned structures and an edge-by-edge comparison.

    Args:
        results_mcmc: Result dict of ``evaluate_method`` for the MCMC run.
        results_gibbs: Result dict of ``evaluate_method`` for the Gibbs run.
        var_names: Variable names used to label edges.

    Either dict may be the short error form (without metrics or graphs):
    missing metrics count as 0, a missing graph is reported as unavailable and
    missing edge sets are treated as empty.
    """
    print("\n" + "="*80)
    print("DETAILED COMPARISON OF METHODS")
    print("="*80)

    # Summary table
    print(f"\n{'Metric':<25} {'MCMC':<15} {'Gibbs':<15} {'Winner':<15}")
    print("-" * 70)

    metrics = [
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1-Score', 'f1_score'),
        ('SHD (lower better)', 'structural_hamming_distance'),
        ('Unique Structures', 'unique_structures')
    ]
    # Every metric not listed here is compared as "higher wins".
    lower_is_better = {'structural_hamming_distance'}

    winners = {'MCMC': 0, 'Gibbs': 0, 'Tie': 0}

    for metric_name, metric_key in metrics:
        mcmc_val = results_mcmc.get(metric_key, 0)
        gibbs_val = results_gibbs.get(metric_key, 0)

        if metric_key in lower_is_better:
            mcmc_better, gibbs_better = mcmc_val < gibbs_val, gibbs_val < mcmc_val
        else:
            mcmc_better, gibbs_better = mcmc_val > gibbs_val, gibbs_val > mcmc_val

        winner = 'MCMC' if mcmc_better else 'Gibbs' if gibbs_better else 'Tie'
        winners[winner] += 1

        print(f"{metric_name:<25} {mcmc_val:<15.3f} {gibbs_val:<15.3f} {winner:<15}")

    print("-" * 70)
    print(f"{'Overall Winner':<25} MCMC: {winners['MCMC']}, Gibbs: {winners['Gibbs']}, Ties: {winners['Tie']}")

    # Show structures one after the other
    print(f"\n{'='*40} STRUCTURES {'='*40}")

    for results, title in ((results_mcmc, "MCMC Best Structure"),
                           (results_gibbs, "Gibbs Best Structure")):
        graph = results.get('eval_graph')
        if graph is None:
            # Error results carry no graph; report that instead of raising.
            print(f"\n{title}:")
            print("=" * len(title))
            print(f"  Not available ({results.get('error', 'no evaluated graph')})")
        else:
            print_network_info(graph, var_names, title)

    # Edge-by-edge comparison
    print(f"\n{'='*35} EDGE COMPARISON {'='*35}")

    true_edges = results_mcmc.get('true_edges') or results_gibbs.get('true_edges') or set()
    mcmc_edges = results_mcmc.get('learned_edges') or set()
    gibbs_edges = results_gibbs.get('learned_edges') or set()

    all_possible_edges = true_edges.union(mcmc_edges).union(gibbs_edges)

    print(f"{'Edge':<20} {'True':<8} {'MCMC':<8} {'Gibbs':<8} {'Status'}")
    print("-" * 60)

    for i, j in sorted(all_possible_edges):
        edge = (i, j)
        is_true = edge in true_edges
        in_mcmc = edge in mcmc_edges
        in_gibbs = edge in gibbs_edges

        edge_str = f"{var_names[i]}->{var_names[j]}"
        true_val = '✓' if is_true else '✗'
        mcmc_val = '✓' if in_mcmc else '✗'
        gibbs_val = '✓' if in_gibbs else '✗'

        if is_true:
            if in_mcmc and in_gibbs:
                status = "Both found ✓"
            elif in_mcmc:
                status = "Only MCMC +"
            elif in_gibbs:
                status = "Only Gibbs +"
            else:
                status = "Both missed ✗"
        elif in_mcmc and in_gibbs:
            status = "Both wrong ✗"
        elif in_mcmc:
            status = "MCMC wrong -"
        else:
            # A listed non-true edge was learned by at least one method, so
            # only Gibbs remains here.
            status = "Gibbs wrong -"

        print(f"{edge_str:<20} {true_val:<8} {mcmc_val:<8} {gibbs_val:<8} {status}")

def _print_method_summary(label: str, elapsed: float, results) -> None:
    """Print the wall-clock time and headline metrics of one learner under ``label``."""
    print(f"\n{label} Results:")
    print(f"  Time: {elapsed:.1f}s")
    print(f"  Precision: {results['precision']:.3f}")
    print(f"  Recall: {results['recall']:.3f}")
    print(f"  F1-Score: {results['f1_score']:.3f}")
    print(f"  SHD: {results['structural_hamming_distance']}")
    print(f"  Unique structures: {results['unique_structures']}")


def run_comparison_study(data: np.ndarray, node_states: List[int], var_names: List[str],
                        true_dag: np.ndarray, nsamples: int = 5000, burnin: int = 2000, scoring_fn: str = 'bayesian', max_parents: int = 4,
                        init_dag: Optional[np.ndarray] = None):
    """Run the MCMC and Gibbs structure learners on the same data and report both.

    The function prints a header, runs ``MCMCStructureLearner`` and then
    ``GibbsStructureLearner`` with identical settings, evaluates each set of
    sampled graphs against ``true_dag`` and prints a summary, a detailed
    comparison, and the Gibbs edge probabilities.

    Args:
        data: Observations; axis 0 indexes variables and axis 1 indexes samples.
        node_states: Passed unchanged to both learner constructors.
        var_names: Display name of each variable, in variable-index order.
        true_dag: Reference adjacency matrix used for evaluation and printing.
        nsamples: Number of samples requested from each learner.
        burnin: Number of burn-in steps requested from each learner.
        scoring_fn: Scoring function name forwarded to both learners.
        max_parents: Parent limit forwarded to both learners.
        init_dag: Starting adjacency matrix for both learners. When omitted,
            a module-level ``init_dag`` is used if one is defined (this is
            what earlier versions of this function relied on implicitly);
            otherwise both learners start from the empty graph.

    Returns:
        ``(results_mcmc, results_gibbs)`` as produced by ``evaluate_method``,
        each with an added ``'time'`` entry holding the elapsed seconds.
    """
    if init_dag is None:
        # Backward compatibility: this function used to read the global
        # ``init_dag`` directly, which raised NameError whenever the caller
        # had not defined one. Keep honouring it, but do not require it.
        init_dag = globals().get('init_dag')
    if init_dag is None:
        n_vars = data.shape[0]
        init_dag = np.zeros((n_vars, n_vars), dtype=int)

    print("BAYESIAN NETWORK STRUCTURE LEARNING COMPARISON")
    print("=" * 60)
    print(f"Dataset: {data.shape[1]} samples, {data.shape[0]} variables")
    print(f"Variables: {', '.join(var_names)}")
    print(f"Samples per method: {nsamples}, Burn-in: {burnin}")

    # Show true structure
    print_network_info(true_dag, var_names, "TRUE NETWORK STRUCTURE")

    print("\nStarting structure learning comparison...")
    print("This may take several minutes...\n")

    # Settings shared by both learners so the comparison is like-for-like.
    shared_settings = dict(
        nsamples=nsamples,
        burnin=burnin,
        scoring_fn=scoring_fn,
        max_parents=max_parents,
    )

    # --- MCMC -------------------------------------------------------------
    print("1. Running MCMC (Metropolis-Hastings)...")
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments during a multi-second run.
    start_time = time.perf_counter()

    mcmc_learner = MCMCStructureLearner(data, node_states)
    # Each learner receives its own copy so neither can alter the other's
    # starting point or the caller's array.
    mcmc_graphs, accept_ratio, num_edges = mcmc_learner.learn_structure(
        init_dag=init_dag.copy(), **shared_settings
    )

    mcmc_time = time.perf_counter() - start_time
    print(f"   MCMC completed in {mcmc_time:.1f} seconds")
    print(f"   Final acceptance ratio: {accept_ratio[-1]:.3f}")
    print(f"   Average edges: {np.mean([np.sum(g) for g in mcmc_graphs]):.1f}")

    # --- Gibbs ------------------------------------------------------------
    print("\n2. Running Gibbs Sampling...")
    start_time = time.perf_counter()

    gibbs_learner = GibbsStructureLearner(data, node_states)
    gibbs_graphs, edge_probabilities = gibbs_learner.learn_structure(
        init_dag=init_dag.copy(), **shared_settings
    )

    gibbs_time = time.perf_counter() - start_time
    print(f"   Gibbs completed in {gibbs_time:.1f} seconds")
    print(f"   Average edges: {np.mean([np.sum(g) for g in gibbs_graphs]):.1f}")

    # --- Evaluation -------------------------------------------------------
    print("\n3. Evaluating Results...")

    results_mcmc = evaluate_method(mcmc_graphs, true_dag, "MCMC")
    results_gibbs = evaluate_method(gibbs_graphs, true_dag, "Gibbs", edge_probabilities)

    # Add timing information
    results_mcmc['time'] = mcmc_time
    results_gibbs['time'] = gibbs_time

    print(f"\n{'='*25} BASIC RESULTS {'='*25}")
    _print_method_summary("MCMC", mcmc_time, results_mcmc)
    _print_method_summary("Gibbs", gibbs_time, results_gibbs)

    # Show detailed comparison
    compare_methods_detailed(results_mcmc, results_gibbs, var_names)

    # Show edge probabilities from Gibbs
    if edge_probabilities is not None:
        print(f"\n{'='*25} GIBBS EDGE PROBABILITIES {'='*25}")
        print("All edges with their probabilities:")

        true_edges = results_gibbs['true_edges']
        for i, src in enumerate(var_names):
            for j, dst in enumerate(var_names):
                # Skip self-loops and edges that were (almost) never sampled.
                if i != j and edge_probabilities[i, j] > 0.01:
                    true_edge = "✓" if (i, j) in true_edges else "✗"
                    print(f"  {src} -> {dst}: {edge_probabilities[i, j]:.3f} {true_edge}")

    return results_mcmc, results_gibbs

# Script entry point: simulate the Asia data, run both learners, announce a winner.
if __name__ == "__main__":
    # Simulated Asia dataset; every one of its 8 variables is binary.
    print("Generating Asia dataset...")
    data, var_names, true_dag = generate_asia_data(n_samples=5000, seed=42)
    node_states = [2 for _ in range(8)]

    # Starting graph built from prior knowledge, listed as (parent, child)
    # index pairs. It is not passed to run_comparison_study below; that
    # function picks it up through the module-level name ``init_dag``.
    init_dag = np.zeros((8, 8), dtype=int)
    for parent, child in (
        (0, 2),  # Asia -> Tuberculosis
        (1, 3),  # Smoking -> LungCancer
        (1, 4),  # Smoking -> Bronchitis
        (2, 5),  # Tuberculosis -> Either
        (3, 5),  # LungCancer -> Either
        (4, 7),  # Bronchitis -> Dyspnoea
        (5, 6),  # Either -> Xray
        (5, 7),  # Either -> Dyspnoea
    ):
        init_dag[parent, child] = 1

    # Run both learners with the same sampling budget.
    mcmc_results, gibbs_results = run_comparison_study(
        data,
        node_states,
        var_names,
        true_dag,
        nsamples=5000,
        burnin=2000,
    )

    print(f"\n{'='*60}")
    print("COMPARISON STUDY COMPLETED")
    print(f"{'='*60}")

    # Decide the headline verdict from the F1 scores alone.
    mcmc_f1 = mcmc_results['f1_score']
    gibbs_f1 = gibbs_results['f1_score']
    if mcmc_f1 > gibbs_f1:
        verdict = f"\n🏆 MCMC wins with F1-score of {mcmc_results['f1_score']:.3f} vs {gibbs_results['f1_score']:.3f}"
    elif mcmc_f1 < gibbs_f1:
        verdict = f"\n🏆 Gibbs wins with F1-score of {gibbs_results['f1_score']:.3f} vs {mcmc_results['f1_score']:.3f}"
    else:
        verdict = f"\n🤝 It's a tie! Both methods achieved F1-score of {mcmc_results['f1_score']:.3f}"
    print(verdict)

    print("\nDone! 🎉")

Generating Asia dataset...
BAYESIAN NETWORK STRUCTURE LEARNING COMPARISON
Dataset: 5000 samples, 8 variables
Variables: Asia, Smoking, Tuberculosis, LungCancer, Bronchitis, Either, Xray, Dyspnoea
Samples per method: 5000, Burn-in: 2000

TRUE NETWORK STRUCTURE:
  Asia -> Tuberculosis
  Smoking -> LungCancer
  Smoking -> Bronchitis
  Tuberculosis -> Either
  LungCancer -> Either
  Bronchitis -> Dyspnoea
  Either -> Xray
  Either -> Dyspnoea

Total edges: 8

Starting structure learning comparison...
This may take several minutes...

1. Running MCMC (Metropolis-Hastings)...
  MCMC Iteration 0/7000


  bf1 = np.exp(new_score_j - old_score_j)
  return bf1 * bf2


  MCMC Iteration 200/7000
  MCMC Iteration 400/7000
  MCMC Iteration 600/7000
  MCMC Iteration 800/7000
  MCMC Iteration 1000/7000
  MCMC Iteration 1200/7000
  MCMC Iteration 1400/7000
  MCMC Iteration 1600/7000
  MCMC Iteration 1800/7000
  MCMC Iteration 2000/7000
  MCMC Iteration 2200/7000
  MCMC Iteration 2400/7000
  MCMC Iteration 2600/7000
  MCMC Iteration 2800/7000
  MCMC Iteration 3000/7000
  MCMC Iteration 3200/7000
  MCMC Iteration 3400/7000
  MCMC Iteration 3600/7000
  MCMC Iteration 3800/7000
  MCMC Iteration 4000/7000
  MCMC Iteration 4200/7000
  MCMC Iteration 4400/7000
  MCMC Iteration 4600/7000
  MCMC Iteration 4800/7000
  MCMC Iteration 5000/7000
  MCMC Iteration 5200/7000
  MCMC Iteration 5400/7000
  MCMC Iteration 5600/7000
  MCMC Iteration 5800/7000
  MCMC Iteration 6000/7000
  MCMC Iteration 6200/7000
  MCMC Iteration 6400/7000
  MCMC Iteration 6600/7000
  MCMC Iteration 6800/7000
   MCMC completed in 27.0 seconds
   Final acceptance ratio: 0.004
   Average edges: 6