In [1]:
import argparse

from causaldag import DAG

import networkx as nx
import os
import pandas as pd
import wandb
import numpy as np
import random

from algo.greedy_search import greedy_search_mec_size, greedy_search_confidence, greedy_search_bic

from models.noisy_expert import NoisyExpert
from models.oracles import EpsilonOracle

from utils.data_generation import generate_dataset
from utils.dag_utils import get_undirected_edges, is_dag_in_mec, get_mec, get_undirected_edges_pdag
from utils.metrics import get_mec_shd
from utils.language_models import get_lms_probs, temperature_scaling
from utils.CI_utils import load_data_from_file, save_data_to_file, data2pdag

INFO:datasets:PyTorch version 2.5.1 available.
INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK"


In [2]:
# --- Experiment configuration -------------------------------------------------
# Search strategy; one of "greedy_mec", "greedy_conf", "greedy_bic",
# "global_scoring", "PC", "naive".
algo="greedy_conf"
# Prior over orientations; one of "mec", "independent".
prior="mec"
# Name of the Bayesian network; used to build the .bif, .pkl and codebook paths.
dataset="asia"
# True -> simulated EpsilonOracle expert; False -> language-model expert.
tabular=False
# Scoring function handed to the search; one of "posterior", "likelihood", "prior".
probability="posterior"
# True -> estimate the CPDAG from data; False -> derive it from the true graph.
preprocess=True
# True -> run temperature scaling on the language model before querying it.
calibrate=False
# Passed as `epsilon` to EpsilonOracle (only used when tabular=True).
epsilon=0.05
# Default tolerance; several algorithm choices overwrite it further down.
tolerance=0.1
seed=953100
verbose=False

# Seed both global RNGs so stochastic steps are repeatable across kernel restarts.
random.seed(seed)
np.random.seed(seed)


def blindly_follow_expert(observed_arcs, model, cpdag, *args, **kwargs):
    """Baseline that accepts every expert orientation without searching.

    Returns a 3-tuple:
      * a one-element list holding the observed arcs followed by the arcs
        already present in ``cpdag``,
      * ``observed_arcs`` unchanged,
      * ``model(observed_arcs, observed_arcs)``.

    Extra positional and keyword arguments are accepted and ignored so the
    function can be called with the same arguments as the other algorithms.
    """
    accepted_arcs = [*observed_arcs, *cpdag.arcs]
    confidence = model(observed_arcs, observed_arcs)
    return [accepted_arcs], observed_arcs, confidence


match algo:
    case "greedy_mec":
        algo = greedy_search_mec_size
    case "greedy_conf":
        algo = greedy_search_confidence
    case "greedy_bic":
        algo = greedy_search_bic
        tolerance = 1.

    case "global_scoring":
        from algo.global_scoring import global_scoring
        algo = global_scoring
        tolerance = 1.

    case "PC":
        algo = lambda a, b, cpdag, c, tol: (get_mec(cpdag), dict(), 1.)
        tolerance = 0.
    case "naive":
        algo = blindly_follow_expert
        tolerance = 1.

match prior:
    case "mec":
        from models.priors import MECPrior
        prior_type = MECPrior
    
    case "independent":
        from models.priors import IndependentPrior
        prior_type = IndependentPrior

# Fetch the raw Bayesian-network files the first time the notebook runs.
if not os.path.exists("_raw_bayesian_nets"):
    from utils.download_datasets import download_datasets
    download_datasets()

true_G, data = generate_dataset('_raw_bayesian_nets/' + dataset + '.bif')

if preprocess:
    # use pc and ges to generate the cpdag
    # Reuse a previously saved sample when one can be loaded; otherwise persist
    # the freshly generated one so later runs see the same data.
    data_cache_path = './' + dataset+'_data' + '.pkl'
    try:
        data = load_data_from_file(data_cache_path)
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit are not swallowed.
        save_data_to_file(data, data_cache_path)

    cpdag = data2pdag(data)

else:
    # Use the CPDAG of the ground-truth graph directly.
    cpdag = DAG.from_nx(true_G).cpdag()

  0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
# Edges of the CPDAG that the expert is asked to orient.
undirected_edges = get_undirected_edges_pdag(cpdag, verbose=True)
llm_engine="text-davinci-002"
if tabular:
    # Simulated expert configured with the `epsilon` value from the config cell.
    oracle = EpsilonOracle(undirected_edges, epsilon=epsilon)
    observations = oracle.decide_all()
    likelihoods = oracle.likelihoods

else:
    # Language-model expert. The codebook is optional: on failure we continue with None.
    try:
        codebook = pd.read_csv('codebooks/' + dataset + '.csv')
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        print('cannot load the codebook')
        codebook = None

    if calibrate:
        # Calibrate on the arcs the CPDAG already orients.
        tmp_scale, eps = temperature_scaling(cpdag.arcs, codebook, engine=llm_engine)
        print("LLM has %.3f error rate" % eps)
    else:
        tmp_scale = 1.

    likelihoods, observations = get_lms_probs(undirected_edges, codebook, tmp_scale, engine=llm_engine)

In [4]:
# Show the orientations returned by the expert (oracle or language model).
observations

[('lung', 'either'),
 ('asia', 'tub'),
 ('xray', 'either'),
 ('smoke', 'bronc'),
 ('tub', 'either'),
 ('smoke', 'lung')]

In [5]:
# NOTE(review): the first label reads "True Orientations", but the value printed is
# `undirected_edges` as returned by get_undirected_edges_pdag — confirm the label is accurate.
print("\nTrue Orientations:", undirected_edges)
print("\nOrientations given by the expert:", observations)
print(likelihoods)
# `prior` is rebound here: from the configured name (str) to a prior_type instance built on the CPDAG.
prior = prior_type(cpdag)
# NoisyExpert is built from the prior and the expert likelihoods; its
# `posterior` / `likelihood` methods are what the search is scored with.
model = NoisyExpert(prior, likelihoods)


True Orientations: {('either', 'lung'), ('asia', 'tub'), ('xray', 'either'), ('smoke', 'bronc'), ('tub', 'either'), ('smoke', 'lung')}

Orientations given by the expert: [('lung', 'either'), ('asia', 'tub'), ('xray', 'either'), ('smoke', 'bronc'), ('tub', 'either'), ('smoke', 'lung')]
{('either', 'lung'): 0.5, ('lung', 'either'): 0.5, ('asia', 'tub'): 0.8571428571360543, ('tub', 'asia'): 0.14285714286394557, ('xray', 'either'): 0.6086956521540643, ('either', 'xray'): 0.39130434784593576, ('smoke', 'bronc'): 0.9982206405683903, ('bronc', 'smoke'): 0.0017793594316096988, ('tub', 'either'): 0.8906249999963379, ('either', 'tub'): 0.10937500000366211, ('smoke', 'lung'): 0.9972375690555538, ('lung', 'smoke'): 0.002762430944446141}


In [6]:
match probability:
    case "posterior":
        prob_method = model.posterior
    
    case "likelihood":
        prob_method = model.likelihood

    case "prior":
        prob_method = lambda _, edges: prior(edges)

new_mec, decisions, p_correct = algo(observations, prob_method, cpdag, likelihoods, tol=0.3)

In [7]:
# Show the per-orientation likelihoods that were given to NoisyExpert.
likelihoods

{('either', 'lung'): 0.5,
 ('lung', 'either'): 0.5,
 ('asia', 'tub'): 0.8571428571360543,
 ('tub', 'asia'): 0.14285714286394557,
 ('xray', 'either'): 0.6086956521540643,
 ('either', 'xray'): 0.39130434784593576,
 ('smoke', 'bronc'): 0.9982206405683903,
 ('bronc', 'smoke'): 0.0017793594316096988,
 ('tub', 'either'): 0.8906249999963379,
 ('either', 'tub'): 0.10937500000366211,
 ('smoke', 'lung'): 0.9972375690555538,
 ('lung', 'smoke'): 0.002762430944446141}

In [8]:
# Show the second value returned by the search algorithm (its orientation decisions).
decisions

[('lung', 'either'),
 ('lung', 'either'),
 ('tub', 'asia'),
 ('tub', 'asia'),
 ('either', 'xray'),
 ('either', 'xray'),
 ('smoke', 'bronc'),
 ('smoke', 'bronc'),
 ('either', 'tub'),
 ('either', 'tub'),
 ('smoke', 'lung'),
 ('smoke', 'lung')]

In [36]:
# Show the MEC returned by the search: a list of DAGs, each given as a list of arcs.
new_mec

[[('lung', 'either'),
  ('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('smoke', 'bronc'),
  ('either', 'tub'),
  ('smoke', 'lung')]]

In [37]:
# For comparison: the DAGs that get_mec enumerates for the starting CPDAG.
get_mec(cpdag)

[[('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('either', 'lung'),
  ('lung', 'smoke'),
  ('smoke', 'bronc'),
  ('either', 'tub')],
 [('lung', 'either'),
  ('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('either', 'tub'),
  ('bronc', 'smoke'),
  ('smoke', 'lung')],
 [('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('either', 'lung'),
  ('lung', 'smoke'),
  ('smoke', 'bronc'),
  ('tub', 'either')],
 [('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('asia', 'tub'),
  ('either', 'lung'),
  ('lung', 'smoke'),
  ('smoke', 'bronc'),
  ('tub', 'either')],
 [('lung', 'either'),
  ('either', 'dysp'),
  ('either', 'xray'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('smoke', 'bronc'),
  ('either', 'tub'),
  ('smoke', 'lung')],
 [('either', 'dysp'),
  ('bronc', 'dysp'),
  ('tub', 'asia'),
  ('either', 'lung'),
  ('lung', 'smoke'),
  ('smoke', 'bronc'),
  ('either', 'tub'

In [9]:
# Score the selected MEC against the ground truth.
# NOTE(review): `learned_adj` is presumably an adjacency matrix whose rows/columns
# follow the order of true_G.nodes — confirm against get_mec_shd.
shd, learned_adj = get_mec_shd(true_G, new_mec)

# Rebuild a directed graph from the array, then rename its nodes to true_G's
# node names by pairing the two node sequences positionally.
learned_G = nx.from_numpy_array(learned_adj, create_using=nx.DiGraph)
learned_G = nx.relabel_nodes(learned_G, {i: n for i, n in zip(learned_G.nodes, true_G.nodes)})

# Arcs present in the learned graph but not in the true graph.
diff = nx.difference(learned_G, true_G)
print("\nFinal wrong orientations:", diff.edges)

print('\nConfidence true DAG is in final MEC: %.3f' % p_correct)
print("Final MEC's SHD: ", shd)
print('MEC size: ', len(new_mec))
print('true-still-in-MEC: ', is_dag_in_mec(true_G, new_mec))


Final wrong orientations: [('either', 'tub'), ('tub', 'asia')]

Confidence true DAG is in final MEC: 0.732
Final MEC's SHD:  2.0
MEC size:  1
true-still-in-MEC:  0.0


In [10]:
import networkx as nx
def mec_to_graph(mec_list):
    """Summarise a list of DAGs as one set of directed and one set of undirected edges.

    ``mec_list`` is an iterable of edge collections. An edge that is only ever
    seen in one orientation ends up in the directed set. As soon as the
    opposite orientation shows up, the pair is moved to the undirected set,
    stored in the orientation that triggered the move, and stays there.

    Returns ``(directed_set, undirected_set)``.
    """
    directed_set, undirected_set = set(), set()
    for dag_edges in mec_list:
        for edge in dag_edges:
            forward = (edge[0], edge[1])
            backward = (edge[1], edge[0])

            # Pair already known to be undirected: nothing left to learn.
            if forward in undirected_set or backward in undirected_set:
                continue
            # Same orientation seen before: still consistent.
            if forward in directed_set:
                continue

            if backward in directed_set:
                # Conflicting orientation: demote the pair to undirected.
                directed_set.remove(backward)
                undirected_set.add(forward)
            else:
                directed_set.add(edge)
    return directed_set, undirected_set

# Directed/undirected edge summary of the starting CPDAG's MEC, then of the MEC
# returned by the search.
print(mec_to_graph(get_mec(cpdag)))
print(mec_to_graph(new_mec))

({('bronc', 'dysp'), ('either', 'dysp')}, {('either', 'lung'), ('asia', 'tub'), ('xray', 'either'), ('lung', 'smoke'), ('bronc', 'smoke'), ('tub', 'either')})
({('lung', 'either'), ('either', 'xray'), ('either', 'tub'), ('bronc', 'dysp'), ('tub', 'asia'), ('smoke', 'bronc'), ('either', 'dysp'), ('smoke', 'lung')}, set())


In [17]:
# Arcs of the ground-truth graph, for comparison with the summaries printed earlier.
true_G.edges

OutEdgeView([('asia', 'tub'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp'), ('lung', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('tub', 'either')])

In [17]:
def edge_acc(directed_edges, undirected_edges, true_G):
    """Edge accuracy of a partially directed graph against the true DAG.

    The score is ``1 - incorrect / total`` where ``total = n * (n - 1) / 2`` is
    the number of unordered node pairs of ``true_G`` and ``incorrect`` counts:

    * every directed edge that is not an arc of ``true_G`` (spurious or reversed),
    * every undirected edge (an unresolved orientation is always an error),
    * every arc of ``true_G`` whose node pair appears in neither edge
      collection in either orientation.

    Args:
        directed_edges: collection of ``(source, target)`` pairs.
        undirected_edges: collection of node pairs, in either orientation.
        true_G: graph exposing ``.nodes`` and ``.edges``.

    Returns:
        The accuracy as a float.

    Raises:
        ZeroDivisionError: if ``true_G`` has fewer than two nodes.
    """
    n = len(true_G.nodes)
    total_possible_edges = n * (n - 1) / 2

    true_edges = set(true_G.edges)

    # Directed edges that do not match an arc of the true graph.
    incorrect = sum(1 for edge in directed_edges if edge not in true_edges)

    # Undirected edges are penalised unconditionally, whatever the true graph says.
    incorrect += sum(1 for _ in undirected_edges)

    # True arcs whose node pair is missing from the prediction altogether.
    for edge in true_edges:
        reverse_edge = (edge[1], edge[0])
        if (edge not in directed_edges and
                reverse_edge not in directed_edges and
                edge not in undirected_edges and
                reverse_edge not in undirected_edges):
            incorrect += 1

    correct = total_possible_edges - incorrect
    return correct / total_possible_edges
# Edge accuracy of the MEC returned by the search: res[0] holds the directed
# edges, res[1] the undirected ones.
res=mec_to_graph(new_mec)
edge_acc(res[0], res[1], true_G)

0.9285714285714286

In [16]:
def edge_acc_pc(cpdag, true_G):
    """Edge accuracy of a CPDAG against the true DAG.

    Takes the directed edges from ``cpdag.arcs`` and the undirected edges from
    ``get_undirected_edges_pdag(cpdag)``, then scores them with ``edge_acc`` so
    both metrics share a single implementation instead of a copied body.

    Args:
        cpdag: partially directed graph exposing ``.arcs``.
        true_G: graph exposing ``.nodes`` and ``.edges``.

    Returns:
        The accuracy as a float, exactly as computed by ``edge_acc``.
    """
    directed_edges = set(cpdag.arcs)
    undirected_edges = set(get_undirected_edges_pdag(cpdag))
    return edge_acc(directed_edges, undirected_edges, true_G)

# Baseline: edge accuracy of the starting CPDAG, before any expert input is applied.
edge_acc_pc(cpdag, true_G)

0.7857142857142857