In [18]:
import argparse

from causaldag import DAG

import networkx as nx
import os
import pandas as pd
import wandb
import numpy as np
import random

from algo.greedy_search import greedy_search_mec_size, greedy_search_confidence, greedy_search_bic

from models.noisy_expert import NoisyExpert
from models.oracles import EpsilonOracle

from utils.data_generation import generate_dataset
from utils.dag_utils import get_undirected_edges, is_dag_in_mec, get_mec, get_undirected_edges_pdag
from utils.metrics import get_mec_shd
from utils.language_models import get_lms_probs, temperature_scaling
from utils.CI_utils import load_data_from_file, save_data_to_file, data2pdag

In [28]:
# Experiment configuration: everything a reader may want to tune lives here.

# Search strategy, resolved to a callable by the `match algo` statement below.
# Accepted values: "greedy_mec", "greedy_conf", "greedy_bic",
# "global_scoring", "PC", "naive".
algo = "greedy_conf"
# Prior family: "mec" (MECPrior) or "independent" (IndependentPrior).
prior = "mec"
# Network name; the graph is read from '_raw_bayesian_nets/<dataset>.bif'.
dataset = "child"
# True: the expert is an EpsilonOracle; False: a language model is queried.
tabular = False
# Scoring quantity: "posterior", "likelihood" or "prior".
probability = "posterior"
# True: build the CPDAG from data with data2pdag; False: take the CPDAG of
# the true graph.
preprocess = True
# True: run temperature_scaling before querying the language model.
calibrate = False
# Passed to EpsilonOracle as `epsilon` (only read when `tabular` is True).
epsilon = 0.05
# Tolerance setting; overwritten for some algorithms by `match algo` below.
tolerance = 0.1
seed = 953100
# NOTE(review): not read by any code visible in this notebook section.
verbose = False

# Apply the seed so that anything drawing from the global `random` / NumPy
# generators is reproducible after Restart & Run All. Previously `seed` was
# defined but never used.
random.seed(seed)
np.random.seed(seed)


def blindly_follow_expert(observed_arcs, model, cpdag, *args, **kwargs):
    """Accept every orientation proposed by the expert without any search.

    Args:
        observed_arcs: iterable of arcs as oriented by the expert.
        model: callable invoked as ``model(observed_arcs, observed_arcs)``;
            its result is returned unchanged as the third element.
        cpdag: object exposing an ``arcs`` iterable; those arcs are appended
            after the expert's arcs.
        *args, **kwargs: accepted and ignored.

    Returns:
        A 3-tuple ``(candidates, decisions, score)``: ``candidates`` is a
        one-element list holding the expert arcs followed by ``cpdag.arcs``,
        ``decisions`` is ``observed_arcs`` itself, and ``score`` is the
        value returned by ``model``.
    """
    single_candidate = [*observed_arcs, *cpdag.arcs]
    score = model(observed_arcs, observed_arcs)
    return [single_candidate], observed_arcs, score


match algo:
    case "greedy_mec":
        algo = greedy_search_mec_size
    case "greedy_conf":
        algo = greedy_search_confidence
    case "greedy_bic":
        algo = greedy_search_bic
        tolerance = 1.

    case "global_scoring":
        from algo.global_scoring import global_scoring
        algo = global_scoring
        tolerance = 1.

    case "PC":
        algo = lambda a, b, cpdag, c, tol: (get_mec(cpdag), dict(), 1.)
        tolerance = 0.
    case "naive":
        algo = blindly_follow_expert
        tolerance = 1.

match prior:
    case "mec":
        from models.priors import MECPrior
        prior_type = MECPrior
    
    case "independent":
        from models.priors import IndependentPrior
        prior_type = IndependentPrior

# Fetch the raw Bayesian networks if the local directory is missing.
if not os.path.exists("_raw_bayesian_nets"):
    from utils.download_datasets import download_datasets
    download_datasets()

true_G, data = generate_dataset('_raw_bayesian_nets/' + dataset + '.bif')

if preprocess:
    # use pc and ges to generate the cpdag
    # Reuse a previously saved sample when one can be loaded; otherwise
    # persist the sample that was just generated.
    data_cache_path = './' + dataset + '_data' + '.pkl'
    try:
        data = load_data_from_file(data_cache_path)
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate; any load failure still falls back to
        # saving the fresh sample, as before.
        save_data_to_file(data, data_cache_path)

    cpdag = data2pdag(data)

else:
    # Use the CPDAG of the ground-truth graph directly.
    cpdag = DAG.from_nx(true_G).cpdag()



  0%|          | 0/20 [00:00<?, ?it/s]

In [29]:
# Edges of the CPDAG that are still undirected; these are what the expert
# is asked to orient.
undirected_edges = get_undirected_edges_pdag(cpdag, verbose=True)
llm_engine="text-davinci-002"
if tabular:
    # Simulated expert.
    oracle = EpsilonOracle(undirected_edges, epsilon=epsilon)
    observations = oracle.decide_all()
    likelihoods = oracle.likelihoods

else:
    # Language-model expert. The codebook is optional: on any load failure
    # the run continues with `codebook = None`.
    try:
        codebook = pd.read_csv('codebooks/' + dataset + '.csv')
    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit are not swallowed.
        print('cannot load the codebook')
        codebook = None

    if calibrate:
        tmp_scale, eps = temperature_scaling(cpdag.arcs, codebook, engine=llm_engine)
        print("LLM has %.3f error rate" % eps)
    else:
        tmp_scale = 1.

    likelihoods, observations = get_lms_probs(undirected_edges, codebook, tmp_scale, engine=llm_engine)

LowerBodyO2 is not defined
HypDistrib is not defined
HypoxiaInO2 is not defined
RUQO2 is not defined
Grunting is not defined
HypoxiaInO2 is not defined
HypoxiaInO2 is not defined
Grunting is not defined
HypDistrib is not defined
HypoxiaInO2 is not defined
LowerBodyO2 is not defined


In [30]:
# Show what the expert was asked about and what it answered.
print("\nTrue Orientations:", undirected_edges)
print("\nOrientations given by the expert:", observations)
print(likelihoods)
# NOTE: this rebinds `prior` from the configuration string to an instance of
# the selected prior class; later cells call it as `prior(edges)`.
prior = prior_type(cpdag)
# Expert model built from the prior and the expert's likelihoods.
model = NoisyExpert(prior, likelihoods)


True Orientations: {('CO2Report', 'CO2'), ('LowerBodyO2', 'HypDistrib'), ('CardiacMixing', 'Disease'), ('BirthAsphyxia', 'DuctFlow'), ('XrayReport', 'ChestXray'), ('HypoxiaInO2', 'RUQO2'), ('DuctFlow', 'LungParench'), ('GruntingReport', 'Grunting'), ('LungFlow', 'CardiacMixing'), ('CardiacMixing', 'BirthAsphyxia'), ('Disease', 'Sick'), ('HypoxiaInO2', 'DuctFlow'), ('Disease', 'DuctFlow'), ('HypoxiaInO2', 'ChestXray'), ('LVH', 'Disease'), ('LungFlow', 'Age'), ('Grunting', 'LungParench'), ('CO2', 'LungParench'), ('LVH', 'DuctFlow'), ('Age', 'Sick'), ('LungFlow', 'ChestXray'), ('Age', 'DuctFlow'), ('Disease', 'LungParench'), ('CardiacMixing', 'HypDistrib'), ('ChestXray', 'LungParench'), ('HypoxiaInO2', 'LowerBodyO2')}

Orientations given by the expert: [('CO2', 'CO2Report'), ('HypDistrib', 'LowerBodyO2'), ('Disease', 'CardiacMixing'), ('DuctFlow', 'BirthAsphyxia'), ('XrayReport', 'ChestXray'), ('RUQO2', 'HypoxiaInO2'), ('DuctFlow', 'LungParench'), ('GruntingReport', 'Grunting'), ('Cardia

In [31]:
match probability:
    case "posterior":
        prob_method = model.posterior
    
    case "likelihood":
        prob_method = model.likelihood

    case "prior":
        prob_method = lambda _, edges: prior(edges)

new_mec, decisions, p_correct = blindly_follow_expert(observations, prob_method, cpdag, undirected_edges, tol=0.5)

In [6]:
# Display the decisions returned by the search call.
decisions

[('CO2', 'CO2Report'),
 ('HypDistrib', 'LowerBodyO2'),
 ('Disease', 'CardiacMixing'),
 ('DuctFlow', 'BirthAsphyxia'),
 ('XrayReport', 'ChestXray'),
 ('RUQO2', 'HypoxiaInO2'),
 ('DuctFlow', 'LungParench'),
 ('GruntingReport', 'Grunting'),
 ('CardiacMixing', 'LungFlow'),
 ('BirthAsphyxia', 'CardiacMixing'),
 ('Disease', 'Sick'),
 ('DuctFlow', 'HypoxiaInO2'),
 ('Disease', 'DuctFlow'),
 ('HypoxiaInO2', 'ChestXray'),
 ('Disease', 'LVH'),
 ('LungFlow', 'Age'),
 ('LungParench', 'Grunting'),
 ('CO2', 'LungParench'),
 ('DuctFlow', 'LVH'),
 ('Age', 'Sick'),
 ('LungFlow', 'ChestXray'),
 ('Age', 'DuctFlow'),
 ('Disease', 'LungParench'),
 ('HypDistrib', 'CardiacMixing'),
 ('LungParench', 'ChestXray'),
 ('LowerBodyO2', 'HypoxiaInO2')]

In [26]:
# Display the candidate edge lists returned by the search call.
new_mec

[[('CO2', 'CO2Report'),
  ('GruntingReport', 'Grunting'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('ChestXray', 'XrayReport'),
  ('HypoxiaInO2', 'RUQO2'),
  ('LungParench', 'CO2'),
  ('DuctFlow', 'LVH')],
 [('CO2', 'CO2Report'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('XrayReport', 'ChestXray'),
  ('LVH', 'LVHreport'),
  ('HypoxiaInO2', 'RUQO2'),
  ('LungParench', 'CO2'),
  ('DuctFlow', 'LVH')],
 [('GruntingReport', 'Grunting'),
  ('CO2Report', 'CO2'),
  ('RUQO2', 'HypoxiaInO2'),
  ('XrayReport', 'ChestXray'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('CO2', 'LungParench'),
  ('DuctFlow', 'LVH')],
 [('CO2Report', 'CO2'),
  ('RUQO2', 'HypoxiaInO2'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('ChestXray', 'XrayReport'),
  ('CO2', 'LungParench'),
  ('DuctFlow', 'LVH')],
 [('CO2Report', 'CO2'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('XrayReport', 'ChestXray'),
  ('LVH', 'LVHreport'),
  ('CO2', 'Lu

In [27]:
# Display the node names of the ground-truth graph.
true_G.nodes

NodeView(('Age', 'BirthAsphyxia', 'CO2', 'CO2Report', 'CardiacMixing', 'ChestXray', 'Disease', 'DuctFlow', 'Grunting', 'GruntingReport', 'HypDistrib', 'HypoxiaInO2', 'LVH', 'LVHreport', 'LowerBodyO2', 'LungFlow', 'LungParench', 'RUQO2', 'Sick', 'XrayReport'))

In [32]:
import networkx as nx
def mec_to_graph(mec_list):
    """Collapse a list of edge lists into directed and undirected edge sets.

    An edge seen with only one orientation across all members ends up in
    the directed set. As soon as the opposite orientation is seen, the pair
    moves to the undirected set (stored with the orientation that triggered
    the move) and later occurrences of either orientation are ignored.

    Args:
        mec_list: iterable of edge collections; each edge is indexable as
            ``edge[0]`` (tail) and ``edge[1]`` (head).

    Returns:
        ``(directed_set, undirected_set)`` as two sets of edges.
    """
    directed_set, undirected_set = set(), set()
    for member in mec_list:
        for edge in member:
            forward = (edge[0], edge[1])
            backward = (edge[1], edge[0])

            # Same orientation already recorded, or the pair is already
            # known to be undirected: nothing to do.
            already_settled = (
                forward in directed_set
                or forward in undirected_set
                or backward in undirected_set
            )
            if already_settled:
                continue

            if backward in directed_set:
                # Both orientations occur: demote the pair to undirected.
                directed_set.remove(backward)
                undirected_set.add(forward)
            else:
                directed_set.add(edge)
    return directed_set, undirected_set

# Compare the (directed, undirected) edge sets of the full MEC of the CPDAG
# with those of the MEC returned by the search.
print(mec_to_graph(get_mec(cpdag)))
print(mec_to_graph(new_mec))

({('Disease', 'LVH'), ('LVH', 'LVHreport'), ('DuctFlow', 'LVH')}, {('CO2Report', 'CO2'), ('RUQO2', 'HypoxiaInO2'), ('Grunting', 'GruntingReport'), ('XrayReport', 'ChestXray'), ('CO2', 'LungParench')})
({('CO2', 'CO2Report'), ('Grunting', 'Sick'), ('Disease', 'CardiacMixing'), ('XrayReport', 'ChestXray'), ('HypDistrib', 'LowerBodyO2'), ('HypDistrib', 'CardiacMixing'), ('DuctFlow', 'LungParench'), ('DuctFlow', 'LVH'), ('GruntingReport', 'Grunting'), ('LowerBodyO2', 'HypoxiaInO2'), ('Disease', 'Sick'), ('Disease', 'DuctFlow'), ('LVH', 'LVHreport'), ('HypoxiaInO2', 'ChestXray'), ('LungParench', 'Grunting'), ('LungFlow', 'Age'), ('DuctFlow', 'HypoxiaInO2'), ('RUQO2', 'HypoxiaInO2'), ('CO2', 'LungParench'), ('Age', 'Sick'), ('LungFlow', 'ChestXray'), ('Age', 'DuctFlow'), ('DuctFlow', 'BirthAsphyxia'), ('BirthAsphyxia', 'CardiacMixing'), ('LungParench', 'ChestXray'), ('Disease', 'LungParench'), ('CardiacMixing', 'LungFlow'), ('Disease', 'LVH')}, set())


In [33]:
# Display the `edges` attribute of the CPDAG.
cpdag.edges

{frozenset({'Grunting', 'GruntingReport'}),
 frozenset({'Age', 'DuctFlow'}),
 frozenset({'Disease', 'LungParench'}),
 frozenset({'HypoxiaInO2', 'LowerBodyO2'}),
 frozenset({'HypDistrib', 'LowerBodyO2'}),
 frozenset({'DuctFlow', 'LungParench'}),
 frozenset({'ChestXray', 'HypoxiaInO2'}),
 frozenset({'HypoxiaInO2', 'RUQO2'}),
 frozenset({'CardiacMixing', 'LungFlow'}),
 frozenset({'DuctFlow', 'LVH'}),
 frozenset({'DuctFlow', 'HypoxiaInO2'}),
 frozenset({'ChestXray', 'LungFlow'}),
 frozenset({'ChestXray', 'LungParench'}),
 frozenset({'BirthAsphyxia', 'DuctFlow'}),
 frozenset({'Age', 'LungFlow'}),
 frozenset({'CO2', 'CO2Report'}),
 frozenset({'ChestXray', 'XrayReport'}),
 frozenset({'Disease', 'LVH'}),
 frozenset({'Grunting', 'LungParench'}),
 frozenset({'Age', 'Sick'}),
 frozenset({'CardiacMixing', 'Disease'}),
 frozenset({'BirthAsphyxia', 'CardiacMixing'}),
 frozenset({'Disease', 'DuctFlow'}),
 frozenset({'Disease', 'Sick'}),
 frozenset({'CO2', 'LungParench'}),
 frozenset({'CardiacMixing',

In [34]:
# Score the final MEC against the ground truth. get_mec_shd also returns an
# adjacency matrix for the learned graph.
shd, learned_adj = get_mec_shd(true_G, new_mec)

# Rebuild the learned graph and rename its nodes to the true graph's node
# names, pairing them by position.
# assumes learned_adj follows the order of true_G.nodes — TODO confirm
learned_G = nx.from_numpy_array(learned_adj, create_using=nx.DiGraph)
position_to_name = dict(zip(learned_G.nodes, true_G.nodes))
learned_G = nx.relabel_nodes(learned_G, position_to_name)

# Edges present in the learned graph but absent from the true graph.
diff = nx.difference(learned_G, true_G)
print("\nFinal wrong orientations:", diff.edges)

print('\nConfidence true DAG is in final MEC: %.3f' % p_correct)
print("Final MEC's SHD: ", shd)
print('MEC size: ', len(new_mec))
print('true-still-in-MEC: ', is_dag_in_mec(true_G, new_mec))


Final wrong orientations: [('Age', 'DuctFlow'), ('Age', 'Sick'), ('BirthAsphyxia', 'CardiacMixing'), ('CO2', 'LungParench'), ('CardiacMixing', 'LungFlow'), ('DuctFlow', 'BirthAsphyxia'), ('DuctFlow', 'HypoxiaInO2'), ('DuctFlow', 'LVH'), ('DuctFlow', 'LungParench'), ('Grunting', 'Sick'), ('GruntingReport', 'Grunting'), ('HypDistrib', 'CardiacMixing'), ('HypoxiaInO2', 'ChestXray'), ('LowerBodyO2', 'HypoxiaInO2'), ('LungFlow', 'Age'), ('RUQO2', 'HypoxiaInO2'), ('XrayReport', 'ChestXray')]

Confidence true DAG is in final MEC: nan
Final MEC's SHD:  23.0
MEC size:  1
true-still-in-MEC:  0.0


In [35]:
# Number of nodes in the ground-truth graph.
len(true_G.nodes)

20

In [19]:
# Number of nodes in the CPDAG; should match the ground-truth graph.
len(cpdag.nodes)

20

In [29]:
# Display the candidate edge lists returned by the search call.
new_mec

[[('CO2', 'CO2Report'),
  ('DuctFlow', 'HypoxiaInO2'),
  ('HypoxiaInO2', 'ChestXray'),
  ('HypDistrib', 'CardiacMixing'),
  ('DuctFlow', 'LVH'),
  ('CO2', 'LungParench'),
  ('RUQO2', 'HypoxiaInO2'),
  ('DuctFlow', 'LungParench'),
  ('CardiacMixing', 'BirthAsphyxia'),
  ('LungFlow', 'ChestXray'),
  ('Age', 'DuctFlow'),
  ('HypDistrib', 'LowerBodyO2'),
  ('Disease', 'Sick'),
  ('LowerBodyO2', 'HypoxiaInO2'),
  ('Disease', 'CardiacMixing'),
  ('LungParench', 'Grunting'),
  ('Age', 'Sick'),
  ('Disease', 'LVH'),
  ('XrayReport', 'ChestXray'),
  ('Disease', 'LungParench'),
  ('DuctFlow', 'BirthAsphyxia'),
  ('LungParench', 'ChestXray'),
  ('LungFlow', 'Age'),
  ('CardiacMixing', 'LungFlow'),
  ('GruntingReport', 'Grunting'),
  ('Disease', 'DuctFlow'),
  ('Grunting', 'Sick'),
  ('LVH', 'LVHreport')]]

In [30]:
# Display the `edges` attribute of the CPDAG.
cpdag.edges

{frozenset({'BirthAsphyxia', 'DuctFlow'}),
 frozenset({'CardiacMixing', 'Disease'}),
 frozenset({'CO2', 'LungParench'}),
 frozenset({'Age', 'Sick'}),
 frozenset({'Disease', 'LVH'}),
 frozenset({'ChestXray', 'LungParench'}),
 frozenset({'CardiacMixing', 'LungFlow'}),
 frozenset({'DuctFlow', 'LungParench'}),
 frozenset({'BirthAsphyxia', 'CardiacMixing'}),
 frozenset({'Age', 'LungFlow'}),
 frozenset({'CardiacMixing', 'HypDistrib'}),
 frozenset({'HypoxiaInO2', 'LowerBodyO2'}),
 frozenset({'HypoxiaInO2', 'RUQO2'}),
 frozenset({'HypDistrib', 'LowerBodyO2'}),
 frozenset({'ChestXray', 'HypoxiaInO2'}),
 frozenset({'Disease', 'Sick'}),
 frozenset({'CO2', 'CO2Report'}),
 frozenset({'ChestXray', 'XrayReport'}),
 frozenset({'Grunting', 'GruntingReport'}),
 frozenset({'ChestXray', 'LungFlow'}),
 frozenset({'DuctFlow', 'HypoxiaInO2'}),
 frozenset({'Disease', 'DuctFlow'}),
 frozenset({'DuctFlow', 'LVH'}),
 frozenset({'Age', 'DuctFlow'}),
 frozenset({'Grunting', 'LungParench'}),
 frozenset({'Disease', 

In [16]:
def edge_acc(directed_edges, undirected_edges, true_G):
    """Fraction of unordered node pairs on which a predicted graph is right.

    The denominator is the number of unordered node pairs, ``n*(n-1)/2``.
    A pair counts as wrong for each of the following:

    * a predicted directed edge that is not an edge of ``true_G`` (covers
      both spurious and reversed edges);
    * every predicted undirected edge, unconditionally, because an edge
      left unoriented is never treated as a correct orientation;
    * a true edge that the prediction lacks in either direction, both as a
      directed and as an undirected edge.

    Args:
        directed_edges: collection of ``(tail, head)`` tuples.
        undirected_edges: collection of 2-tuples, orientation irrelevant.
        true_G: graph exposing ``nodes`` and ``edges`` (``(tail, head)``
            tuples), e.g. a ``networkx.DiGraph``.

    Returns:
        ``(total_pairs - incorrect) / total_pairs`` as a float. For a graph
        with fewer than two nodes there are no pairs to get wrong and 1.0
        is returned (this used to raise ``ZeroDivisionError``).
    """
    n = len(true_G.nodes)
    total_possible_edges = n * (n - 1) / 2
    if total_possible_edges == 0:
        return 1.0

    true_edges = set(true_G.edges)

    # Directed predictions that are not true edges.
    incorrect = sum(1 for edge in directed_edges if edge not in true_edges)

    # Every undirected prediction is counted as wrong.
    incorrect += sum(1 for _ in undirected_edges)

    # True edges that the prediction misses entirely.
    for edge in true_edges:
        reverse_edge = (edge[1], edge[0])
        predicted_somehow = (
            edge in directed_edges
            or reverse_edge in directed_edges
            or edge in undirected_edges
            or reverse_edge in undirected_edges
        )
        if not predicted_somehow:
            incorrect += 1

    correct = total_possible_edges - incorrect
    return correct / total_possible_edges
# Accuracy of the MEC returned by the search: mec_to_graph yields the
# (directed, undirected) pair that edge_acc takes as its first two arguments.
res = mec_to_graph(new_mec)
edge_acc(*res, true_G)

0.8789473684210526

In [17]:
def edge_acc_pc(cpdag, true_G):
    """Edge accuracy of a CPDAG against the ground-truth graph.

    Takes the directed edges from ``cpdag.arcs`` and the undirected edges
    from ``get_undirected_edges_pdag(cpdag)``, then scores them with
    ``edge_acc``. The scoring was previously a line-for-line copy of
    ``edge_acc``; delegating keeps the two metrics in agreement.

    Args:
        cpdag: CPDAG object exposing ``arcs``.
        true_G: ground-truth graph exposing ``nodes`` and ``edges``.

    Returns:
        The accuracy computed by ``edge_acc``.
    """
    directed_edges = set(cpdag.arcs)
    undirected_edges = set(get_undirected_edges_pdag(cpdag))
    return edge_acc(directed_edges, undirected_edges, true_G)

# Baseline: accuracy of the CPDAG itself, before any expert orientation.
edge_acc_pc(cpdag, true_G)

0.8263157894736842

In [36]:
# Keep the result of get_mec(cpdag) for inspection.
mec=get_mec(cpdag)


[[('CO2', 'CO2Report'),
  ('GruntingReport', 'Grunting'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('ChestXray', 'XrayReport'),
  ('HypoxiaInO2', 'RUQO2'),
  ('LungParench', 'CO2'),
  ('DuctFlow', 'LVH')],
 [('CO2', 'CO2Report'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('XrayReport', 'ChestXray'),
  ('LVH', 'LVHreport'),
  ('HypoxiaInO2', 'RUQO2'),
  ('LungParench', 'CO2'),
  ('DuctFlow', 'LVH')],
 [('GruntingReport', 'Grunting'),
  ('CO2Report', 'CO2'),
  ('RUQO2', 'HypoxiaInO2'),
  ('XrayReport', 'ChestXray'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('CO2', 'LungParench'),
  ('DuctFlow', 'LVH')],
 [('CO2Report', 'CO2'),
  ('RUQO2', 'HypoxiaInO2'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('LVH', 'LVHreport'),
  ('ChestXray', 'XrayReport'),
  ('CO2', 'LungParench'),
  ('DuctFlow', 'LVH')],
 [('CO2Report', 'CO2'),
  ('Grunting', 'GruntingReport'),
  ('Disease', 'LVH'),
  ('XrayReport', 'ChestXray'),
  ('LVH', 'LVHreport'),
  ('CO2', 'Lu