In [34]:
import os
import random
import numpy as np
import networkx as nx
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import kci 
from causallearn.utils.cit import fisherz
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.graph.Node import Node
from plotting_utils import causal_learn_to_networkx, plot_and_save_graph
from data_preparation import load_and_prepare_student_data, load_and_prepare_adult_data, apply_variable_mapping, student_variable_mapping
from dag_utils import load_dag
from true_graph import create_true_graph_student, create_true_graph_adult, plot_true_graph
from evaluation import evaluate_graph

In [35]:
def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

In [40]:
def convert_labels_to_nodes(labels):
    """Converts a list of labels to a list of Node instances."""
    return [Node(i, label) for i, label in enumerate(labels)]

In [44]:
def run_pc_algorithm(data, labels, alpha=0.01, stable=True, uc_rule=2, root_nodes=None, seed=None):
    """
    Runs the PC algorithm with optional background knowledge and returns the estimated causal graph.

    Args:
        data (numpy.ndarray): The dataset as a NumPy array.
        labels (list): A list of string labels for the variables in the dataset.
        alpha (float): Significance level for the conditional independence tests.
        stable (bool): Whether to use the conservative PC (stable) algorithm.
        uc_rule (int): The rule for handling unshielded colliders (0, 1, or 2).
        root_nodes (list, optional): A list of node names that should be considered root nodes (no parents).
        seed (int, optional): Seed for random number generators to ensure reproducibility.

    Returns:
        networkx.DiGraph: The estimated causal graph as a NetworkX DiGraph object.
    """
    if seed is not None:
        set_random_seed(seed)
    
    # Create background knowledge object
    bk = BackgroundKnowledge()
    
    # Convert labels to Node instances
    node_labels = convert_labels_to_nodes(labels)
    node_dict = {label: node for label, node in zip(labels, node_labels)}

    # Add forbidden edges to enforce root nodes (if specified)
    if root_nodes:
        for root_node in root_nodes:
            for other_node in labels:
                if other_node != root_node:
                    bk.add_forbidden_by_node(node_dict[other_node], node_dict[root_node])

    # Run PC algorithm with background knowledge
    cg_pc = pc(data, alpha=alpha, indep_test=fisherz, stable=stable, uc_rule=uc_rule, background_knowledge=bk)

    # Convert CausalLearn GeneralGraph to NetworkX DiGraph using the utility function
    nx_graph = causal_learn_to_networkx(cg_pc.G)

    # Ensure the graph is a DAG by removing bidirectional edges
    bidirectional_edges = [(u, v) for u, v in nx_graph.edges if nx_graph.has_edge(v, u)]
    for u, v in bidirectional_edges:
        nx_graph.remove_edge(v, u)

    return nx_graph

In [43]:
def remove_causal_edges(graph, node_name):
    """
    Remove all causal edges connected to a specific node.

    Args:
        graph (networkx.DiGraph): The causal graph.
        node_name (str): The name of the node whose edges are to be removed.

    Returns:
        networkx.DiGraph: The modified causal graph with specified edges removed.
    """
    edges_to_remove = list(graph.edges(node_name)) + list(graph.in_edges(node_name))
    graph.remove_edges_from(edges_to_remove)
    return graph

In [38]:
def main():
    # Specify dataset and root nodes
    dataset = 'adult'  # Change to 'student' for the student dataset
    data_file = 'data/adult_cleaned.csv' if dataset == 'adult' else 'data/student-por_raw.csv'
    root_nodes = ['sex', 'age', 'race', 'native.country'] if dataset == 'adult' else ['Medu']

    # Load and prepare data
    if dataset == 'adult':
        df_encoded, labels, data = load_and_prepare_adult_data(data_file)
        true_graph = create_true_graph_adult()
    else:
        df_encoded, labels, data = load_and_prepare_student_data(data_file)
        true_graph = create_true_graph_student()

    # Run the PC algorithm with forbidden edges to enforce root nodes
    estimated_graph_1 = run_pc_algorithm(data, labels, alpha=0.01, stable=True, uc_rule=0, root_nodes=root_nodes, seed=42)
    print("First run completed.")

    # Remove all edges connected to a specific node
    node_to_remove_edges = 'Medu' if dataset == 'student' else 'race'
    estimated_graph_1 = remove_causal_edges(estimated_graph_1, node_to_remove_edges)
    print(f"Removed all edges connected to {node_to_remove_edges}")

    # Run again with a different parameter for variability
    estimated_graph_2 = run_pc_algorithm(data, labels, alpha=0.01, stable=True, uc_rule=1, root_nodes=root_nodes, seed=42)
    print("Second run with different parameter completed.")

    # Evaluate the first estimated graph
    shd_1, recall_1, precision_1 = evaluate_graph(estimated_graph_1, true_graph)
    print("Evaluation of the first estimated graph:")
    print(f"Structural Hamming Distance (SHD): {shd_1}")
    print(f"Recall: {recall_1}")
    print(f"Precision: {precision_1}")

    # Evaluate the second estimated graph
    shd_2, recall_2, precision_2 = evaluate_graph(estimated_graph_2, true_graph)
    print("Evaluation of the second estimated graph:")
    print(f"Structural Hamming Distance (SHD): {shd_2}")
    print(f"Recall: {recall_2}")
    print(f"Precision: {precision_2}")

    return estimated_graph_2

In [39]:
main()

TypeError: Node() takes no arguments