In [2]:
import os
import random
import numpy as np
import pandas as pd
import lingam
from lingam import LiM
import lingam.utils as lutils
import networkx as nx
import graphviz
from lingam.utils import make_dot
from sklearn.preprocessing import StandardScaler
from data_preparation import load_and_prepare_student_data, load_and_prepare_adult_data
from true_graph import create_true_graph_student, create_true_graph_student_small, create_true_graph_adult, create_true_graph_adult_small
from evaluation import evaluate_graph

In [3]:
def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

In [4]:
def run_lim_algorithm(data, labels, discrete_columns, seed=None):
    if seed is not None:
        set_random_seed(seed)

    # Standardize the data
    scaler = StandardScaler()
    data = scaler.fit_transform(data)

    # Ensure discrete_columns is a 2D array as required by the algorithm
    discrete_columns = np.array([discrete_columns], dtype=float)

    # Initialize and fit the LiM model
    model_lim = LiM(max_iter=1000)
    model_lim.fit(data, discrete_columns, only_global=True)

    # Get the adjacency matrix
    adjacency_matrix = model_lim._adjacency_matrix

    return adjacency_matrix, labels

In [6]:
def main():
    # Specify dataset
    dataset = 'adult_small'  # Change to 'adult', 'adult_small', 'student', or 'student_small'
    
    # Initialize data, labels, and discrete_columns
    if dataset == 'adult':
        data_file = 'data/processed_adult.csv'
        df_encoded = pd.read_csv(data_file)
        labels = df_encoded.columns.tolist()
        data = df_encoded.to_numpy()
        true_graph = create_true_graph_adult()
        discrete_columns = [1 if col in ['workclass', 'education', 'marital.status', 'occupation', 'relationship', 'race', 'sex', 'native.country', 'income'] else 0 for col in labels]

    elif dataset == 'adult_small':
        data_file = 'data/processed_adult_small.csv'
        df_encoded = pd.read_csv(data_file)
        labels = df_encoded.columns.tolist()
        data = df_encoded.to_numpy()
        true_graph = create_true_graph_adult_small()
        discrete_columns = [1 if col in ['workclass', 'education', 'occupation', 'native.country', 'income'] else 0 for col in labels]

    elif dataset == 'student':
        data_file = 'data/processed_student.csv'
        df_encoded = pd.read_csv(data_file)
        labels = df_encoded.columns.tolist()
        data = df_encoded.to_numpy()
        true_graph = create_true_graph_student()
        discrete_columns = [1 if col in ['internet_yes', 'higher_yes', 'famsup_yes', 'paid_yes', 'schoolsup', 'Pstatus', 'failures'] else 0 for col in labels]

    elif dataset == 'student_small':
        data_file = 'data/processed_student_small.csv'
        df_encoded = pd.read_csv(data_file)
        labels = df_encoded.columns.tolist()
        data = df_encoded.to_numpy()
        true_graph = create_true_graph_student_small()
        discrete_columns = [1 if col in ['higher_yes'] else 0 for col in labels]

    else:
        raise ValueError("Invalid dataset specified. Choose 'adult', 'adult_small', 'student', or 'student_small'.")

    print(f"Processing dataset: {dataset}")

    # Run the LiM algorithm
    adjacency_matrix, labels = run_lim_algorithm(data, labels, discrete_columns, seed=42)
    print("LiM Algorithm graph created.")

    # Evaluate the estimated graph against the true graph
    estimated_graph = nx.DiGraph(adjacency_matrix)
    shd, recall, precision = evaluate_graph(estimated_graph, true_graph)
    print(f"Structural Hamming Distance (SHD): {shd}")
    print(f"Recall: {recall}")
    print(f"Precision: {precision}")

    # Plot using make_dot from lingam.utils
    dot = make_dot(adjacency_matrix, labels=labels)
    dot.render("causal_graph")  # Save the graph as a PDF/PNG file

    # Optional: Display the graph inline (if in Jupyter Notebook or similar)
    dot.view()  # This will open the saved graph in the default viewer

    return adjacency_matrix, labels, shd, recall, precision

if __name__ == "__main__":
    adjacency_matrix, labels, shd, recall, precision = main()

Processing dataset: adult_small
W_est (without the 2nd phase) is: 
 [[0.         0.28541651 0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.        ]
 [0.43835038 0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.        ]
 [0.         0.         0.         0.         0.         0.
  0.        ]]
LiM Algorithm graph created.
Structural Hamming Distance (SHD): 18
Recall: 0.0
Precision: 0.0
