In [None]:
import os
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from lingam.direct_lingam import DirectLiNGAM
from lingam.utils import make_prior_knowledge, make_dot
from sklearn.preprocessing import StandardScaler
from evaluation import evaluate_graph
from true_graph import create_true_graph_student, create_true_graph_student_small, create_true_graph_adult, create_true_graph_adult_small

In [2]:
def set_random_seed(seed):
    """Seed the global random number generators for reproducible runs.

    Seeds Python's built-in ``random`` module and NumPy's legacy global
    generator, and exports ``PYTHONHASHSEED`` into the process environment.

    Parameters
    ----------
    seed : int
        Value used to seed every generator.

    Notes
    -----
    ``PYTHONHASHSEED`` is read by the interpreter at start-up, so assigning
    it here does not change hashing in the already-running kernel; the value
    is only inherited by child processes started afterwards.
    """
    # Local import: `random` is not among the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

In [3]:
def create_dag_from_adjacency_matrix(adjacency_matrix, labels):
    """Build a labelled DAG from a LiNGAM-style adjacency matrix.

    LiNGAM stores its coefficients so that ``adjacency_matrix[i, j]`` is the
    effect of variable ``j`` on variable ``i`` (rows are effects, columns are
    causes). networkx interprets a matrix entry ``[i, j]`` as an edge
    ``i -> j``, so the matrix is transposed before the graph is built; the
    resulting edges then point from cause to effect.

    Parameters
    ----------
    adjacency_matrix : array-like, shape (n_variables, n_variables)
        Weighted adjacency matrix; zero entries produce no edge.
    labels : sequence
        Node names; position ``i`` names variable ``i``.

    Returns
    -------
    networkx.DiGraph
        Acyclic directed graph whose nodes are relabelled with ``labels``.
        Edge weights are kept in the ``weight`` attribute.
    """
    # Transpose so that edges run cause -> effect (see docstring).
    graph = nx.DiGraph(np.asarray(adjacency_matrix).T)
    graph = nx.relabel_nodes(graph, dict(enumerate(labels)))

    # Break cycles one edge at a time. Dropping a single edge is enough to
    # break a cycle, so only the weakest (smallest |weight|) edge is removed
    # instead of discarding every edge that takes part in the cycle.
    while not nx.is_directed_acyclic_graph(graph):
        cycle_edges = nx.find_cycle(graph)
        source, target = min(
            cycle_edges,
            key=lambda edge: abs(graph.edges[edge[0], edge[1]].get('weight', 0)),
        )[:2]
        graph.remove_edge(source, target)

    return graph

In [4]:
def plot_dag_with_matplotlib(dag, title="DAG"):
    """Draw ``dag`` in a new 12x8 matplotlib figure and display it.

    Node positions come from a spring layout (``k=0.85``); nodes are drawn
    labelled, and ``title`` is used as the figure title.
    """
    draw_options = {
        "with_labels": True,
        "node_size": 3000,
        "node_color": "skyblue",
        "font_size": 12,
        "font_weight": "bold",
        "arrowsize": 20,
    }

    plt.figure(figsize=(12, 8))
    node_positions = nx.spring_layout(dag, k=0.85)
    nx.draw(dag, node_positions, **draw_options)
    plt.title(title)
    plt.show()

In [5]:
def run_direct_lingam(data, labels, prior_knowledge=None, seed=None):
    """Fit DirectLiNGAM on standardized data and return its adjacency matrix.

    Parameters
    ----------
    data : array-like
        Observations; standardized with ``StandardScaler`` before fitting.
    labels : list
        Variable names. Not used by this function.
    prior_knowledge : array-like or None
        Passed unchanged to ``DirectLiNGAM``.
    seed : int or None
        When not ``None``, ``set_random_seed(seed)`` is called first.

    Returns
    -------
    The ``adjacency_matrix_`` attribute of the fitted model.

    The prior-knowledge matrix and the fitted adjacency matrix are printed
    as debugging output.
    """
    if seed is not None:
        set_random_seed(seed)

    # Standardize the inputs before fitting.
    standardized = StandardScaler().fit_transform(data)

    # Debug output: the prior knowledge handed to the model.
    print("Prior Knowledge Matrix:")
    print(prior_knowledge)

    lingam_model = DirectLiNGAM(prior_knowledge=prior_knowledge)
    lingam_model.fit(standardized)
    estimated = lingam_model.adjacency_matrix_

    # Debug output: the estimated structure.
    print("Adjacency Matrix After Fitting:")
    print(estimated)

    return estimated

In [None]:
# Per-dataset settings: processed CSV path, factory for the reference graph,
# and the variables declared exogenous / sink in the prior knowledge.
# The variable names must match the column names of the processed CSV.
_DATASET_CONFIGS = {
    'adult': {
        'data_file': 'data/processed_adult.csv',
        'true_graph_factory': create_true_graph_adult,
        'exogenous_vars': ['age', 'native.country'],
        'sink_vars': ['income'],
    },
    'adult_small': {
        'data_file': 'data/processed_adult_small.csv',
        'true_graph_factory': create_true_graph_adult_small,
        'exogenous_vars': ['age', 'native.country'],
        'sink_vars': ['income'],
    },
    'student': {
        'data_file': 'data/processed_student.csv',
        'true_graph_factory': create_true_graph_student,
        'exogenous_vars': ['Medu', 'health'],
        'sink_vars': ['G_avg'],
    },
    'student_small': {
        'data_file': 'data/processed_student_small.csv',
        'true_graph_factory': create_true_graph_student_small,
        'exogenous_vars': ['Medu', 'health'],
        'sink_vars': ['G_avg'],
    },
}


def _label_indices(labels, variables, role):
    """Translate variable names into column positions, failing clearly on unknown names."""
    missing = [var for var in variables if var not in labels]
    if missing:
        raise ValueError(
            f"{role} variable(s) {missing} not found in dataset columns: {labels}"
        )
    return [labels.index(var) for var in variables]


def _fit_plot_evaluate(data, labels, true_graph, prior_knowledge,
                       matrix_heading, plot_title, metrics_prefix):
    """Fit DirectLiNGAM once, plot the estimated DAG and print its metrics."""
    adjacency_matrix = run_direct_lingam(data, labels, prior_knowledge=prior_knowledge, seed=42)
    print(matrix_heading)
    print(adjacency_matrix)

    # Build the graph once and reuse it for both plotting and evaluation.
    estimated_graph = create_dag_from_adjacency_matrix(adjacency_matrix, labels)
    plot_dag_with_matplotlib(estimated_graph, title=plot_title)

    if true_graph is not None:
        shd, recall, precision = evaluate_graph(estimated_graph, true_graph)
        print(f"{metrics_prefix} - SHD: {shd}, Recall: {recall}, Precision: {precision}")


def main(dataset='adult_small'):
    """Compare DirectLiNGAM with and without background knowledge on one dataset.

    Loads the processed CSV for ``dataset``, fits DirectLiNGAM twice (first
    without prior knowledge, then with exogenous/sink constraints), plots
    each estimated DAG and prints SHD, recall and precision against the
    dataset's reference graph.

    Parameters
    ----------
    dataset : str
        One of 'adult', 'adult_small', 'student', 'student_small'.

    Raises
    ------
    ValueError
        If ``dataset`` is not a known name, or if a configured exogenous or
        sink variable is not a column of the loaded CSV.
    """
    # Validate before touching the filesystem.
    if dataset not in _DATASET_CONFIGS:
        raise ValueError("Invalid dataset specified. Choose 'adult', 'adult_small', 'student', or 'student_small'.")
    config = _DATASET_CONFIGS[dataset]

    # Load the processed CSV and the matching reference graph.
    df_encoded = pd.read_csv(config['data_file'])
    labels = df_encoded.columns.tolist()
    data = df_encoded.to_numpy()
    true_graph = config['true_graph_factory']()

    print(f"Processing dataset: {dataset}")
    print("Data loaded and prepared.")
    print(f"Labels: {labels}")

    # Construct the prior knowledge matrix from the configured variable roles.
    prior_knowledge = make_prior_knowledge(
        n_variables=len(labels),
        exogenous_variables=_label_indices(labels, config['exogenous_vars'], "Exogenous"),
        sink_variables=_label_indices(labels, config['sink_vars'], "Sink"),
    )

    # Run 1: no background knowledge.
    _fit_plot_evaluate(
        data, labels, true_graph,
        prior_knowledge=None,
        matrix_heading="Adjacency matrix with no background knowledge:",
        plot_title="DAG with No Background Knowledge",
        metrics_prefix="Without Background Knowledge",
    )

    # Run 2: with background knowledge.
    _fit_plot_evaluate(
        data, labels, true_graph,
        prior_knowledge=prior_knowledge,
        matrix_heading="Adjacency matrix with background knowledge:",
        plot_title="DAG with Background Knowledge",
        metrics_prefix="With Background Knowledge",
    )



In [None]:
# Run the DirectLiNGAM pipeline defined in main().
main()