In [1]:
import os
import sys

# Set working directory to project root
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)
os.chdir(project_root)

In [2]:
from src.data_processing.data_loader import load_synthetic_data
from src.data_processing.preprocessor import preprocess_data
from src.utils.graph_utils import get_dag 
from src.evaluation.metrics import calculate_base_performance

dataset = "asia"
target_node = "dysp"
task = "classification"

df = load_synthetic_data(dataset=dataset)
train, val, test = preprocess_data(df)

true_model, true_adj_matrix = get_dag(dataset)
num_classes = df[target_node].nunique()

In [3]:
from src.models.causal_discovery.discovery_factory import CausalDiscoveryFactory

factory = CausalDiscoveryFactory(method='grasp')

learned_adj_matrix = factory.discover_graph(train)

GRaSP edge count: 7    
GRaSP completed in: 0.07s 


In [4]:
from src.models.baselines.xgb import XGBBaseline

xgb_model = XGBBaseline(target_node=target_node, task=task, num_classes=num_classes)
xgb_model.train(train, val)

xgb_preds = xgb_model.predict(test)
print(xgb_preds)
results = calculate_base_performance(test[target_node], xgb_preds, task=task)

print(f"Results for XGB on {dataset} dataset predicting {target_node}:")
print(results)

num_differences = (xgb_preds != test[target_node]).sum()
print(f"Number of differences between XGB predictions and true values: {num_differences}")

[1 1 0 ... 1 0 1]
Results for XGB on asia dataset predicting dysp:
{'accuracy': 0.8605, 'precision': 0.8636565388183053, 'recall': 0.857138696160475, 'f1': 0.8588690555228864}
Number of differences between XGB predictions and true values: 279


In [5]:
from src.models.baselines.mlp import MLPBaseline

mlp_model = MLPBaseline(target_node=target_node, task=task, num_classes=num_classes)
mlp_model.train(train, val, patience=10, epochs=1000)

mlp_preds = mlp_model.predict(test)
results = calculate_base_performance(test[target_node], mlp_preds, task=task)

print(f"Results for MLP on {dataset} dataset predicting {target_node}:")
print(results)

Epoch 1 - Train Loss: 0.5068, Val Loss: 0.3592
Epoch 2 - Train Loss: 0.4223, Val Loss: 0.3529
Epoch 3 - Train Loss: 0.4180, Val Loss: 0.3493
Epoch 9 - Train Loss: 0.4155, Val Loss: 0.3491
Epoch 17 - Train Loss: 0.4154, Val Loss: 0.3472
Epoch 18 - Train Loss: 0.4142, Val Loss: 0.3440
Stopping early at epoch 28
Results for MLP on asia dataset predicting dysp:
{'accuracy': 0.8605, 'precision': 0.8636565388183053, 'recall': 0.857138696160475, 'f1': 0.8588690555228864}


In [6]:
from src.models.baselines.hierarchical_xgb import HierarchicalXGB

hier_model = HierarchicalXGB(graph=true_adj_matrix, target_node=target_node)
hier_model.train(train, val)

hier_preds = hier_model.predict(test)

results = calculate_base_performance(test[target_node], hier_preds, task=task)

print(f"Results for Hierarchical XGB on {dataset} dataset predicting {target_node}:")
print(results)

Training tub | Task: classification | Num Classes: 2 | Parents: ['asia']
Training lung | Task: classification | Num Classes: 2 | Parents: ['smoke']
Training bronc | Task: classification | Num Classes: 2 | Parents: ['smoke']
Training either | Task: classification | Num Classes: 2 | Parents: ['tub', 'lung']
Training xray | Task: classification | Num Classes: 2 | Parents: ['either']
Training dysp | Task: classification | Num Classes: 2 | Parents: ['bronc', 'either']
Results for Hierarchical XGB on asia dataset predicting dysp:
{'accuracy': 0.8605, 'precision': 0.8636565388183053, 'recall': 0.857138696160475, 'f1': 0.8588690555228864}


In [7]:
from src.models.causal_gnns.gcn import GCNBaseline

gcn_model = GCNBaseline(target_node=target_node, task=task, num_classes=num_classes, adj_mat=true_adj_matrix)
gcn_model.train(train, val, patience=10, epochs=500)

gcn_preds = gcn_model.predict(test)
results = calculate_base_performance(test[target_node], gcn_preds, task=task)

print(f"Results for GCN on {dataset} dataset predicting {target_node}:")
print(results)

Epoch 1 - Train Loss: 0.6345, Val Loss: 0.5647, Val Acc: 0.639
Epoch 2 - Train Loss: 0.5499, Val Loss: 0.4787, Val Acc: 0.801
Epoch 3 - Train Loss: 0.5034, Val Loss: 0.4534, Val Acc: 0.879
Epoch 4 - Train Loss: 0.4876, Val Loss: 0.4299, Val Acc: 0.801
Epoch 5 - Train Loss: 0.4832, Val Loss: 0.4269, Val Acc: 0.801
Epoch 6 - Train Loss: 0.4816, Val Loss: 0.4233, Val Acc: 0.801
Epoch 9 - Train Loss: 0.4805, Val Loss: 0.4216, Val Acc: 0.801
Early stopping at epoch 19
Results for GCN on asia dataset predicting dysp:
{'accuracy': 0.791, 'precision': 0.77223539807261, 'recall': 0.7993906272670117, 'f1': 0.7781125553258825}


In [11]:
from src.evaluation.metrics import calculate_cace


print(true_model.cpds)

intervention_node = "bronc"
models = [xgb_model, mlp_model, hier_model, gcn_model]
results = calculate_cace(
    models=models,
    data=test, 
    intervention_node=intervention_node, 
    adj_mat=true_adj_matrix,
    cpds = true_model.cpds,
    )

print(f"CACE on {dataset} dataset intervening on {intervention_node}: {results}")

[<TabularCPD representing P(asia:2) at 0x1766464b0>, <TabularCPD representing P(bronc:2 | smoke:2) at 0x176bdc620>, <TabularCPD representing P(dysp:2 | bronc:2, either:2) at 0x176ae7800>, <TabularCPD representing P(either:2 | lung:2, tub:2) at 0x177aa8980>, <TabularCPD representing P(lung:2 | smoke:2) at 0x177154d10>, <TabularCPD representing P(smoke:2) at 0x177154ef0>, <TabularCPD representing P(tub:2 | asia:2) at 0x177107200>, <TabularCPD representing P(xray:2 | either:2) at 0x177b06840>]
CACE on asia dataset intervening on xray: ({<src.models.baselines.xgb.XGBBaseline object at 0x16c57be00>: {0.0}, <src.models.baselines.mlp.MLPBaseline object at 0x17679c7a0>: {0.022}, <src.models.baselines.hierarchical_xgb.HierarchicalXGB object at 0x3029509e0>: {0.0}, <src.models.causal_gnns.gcn.GCNBaseline object at 0x30293a510>: {0.0}}, 0.0)
