In [1]:
import argparse

from causaldag import DAG

import networkx as nx
import os
import pandas as pd
import wandb
import numpy as np
import random

from algo.greedy_search import greedy_search_mec_size, greedy_search_confidence, greedy_search_bic

from models.noisy_expert import NoisyExpert
from models.oracles import EpsilonOracle

from utils.data_generation import generate_dataset
from utils.dag_utils import get_undirected_edges, is_dag_in_mec, get_mec, get_undirected_edges_pdag
from utils.metrics import get_mec_shd
from utils.language_models import get_lms_probs, temperature_scaling
from utils.CI_utils import load_data_from_file, save_data_to_file, data2pdag

INFO:datasets:PyTorch version 2.5.1 available.
INFO:httpx:HTTP Request: GET https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json "HTTP/1.1 200 OK"


In [2]:
# ---- Experiment configuration ------------------------------------------------
algo = "greedy_conf"       # one of: greedy_mec, greedy_conf, greedy_bic, global_scoring, PC, naive
prior = "mec"              # one of: mec, independent
dataset = "asia"           # name of a .bif file under _raw_bayesian_nets/
tabular = False            # True: simulated EpsilonOracle expert; False: LLM expert
probability = "posterior"  # one of: posterior, likelihood, prior
preprocess = False         # True: estimate the CPDAG from (cached) data via data2pdag
calibrate = False          # True: temperature-scale the LLM before querying it
epsilon = 0.05             # passed to EpsilonOracle (tabular expert only)
tolerance = 0.1            # passed as `tol` to the search; some algos override it below
seed = 953100
verbose = False

# `seed` used to be declared but never applied, so each run sampled a different
# dataset. Seed the global RNGs this notebook imports.
# NOTE(review): assumes generate_dataset / the oracle draw from these global
# RNGs — TODO confirm; any other RNG used by helpers is not seeded here.
random.seed(seed)
np.random.seed(seed)


def blindly_follow_expert(observed_arcs, model, cpdag, *args, **kwargs):
    """Baseline "search" that accepts every orientation the expert proposed.

    Args:
        observed_arcs: the expert's proposed edge orientations.
        model: callable scoring an ``(observations, edges)`` pair; both
            arguments are the expert's arcs here.
        cpdag: object exposing ``.arcs``, the edges the CPDAG already directs.
        *args, **kwargs: accepted only for signature compatibility with the
            other search algorithms (e.g. ``likelihoods``, ``tol``); ignored.

    Returns:
        A one-element list holding a single edge list (expert arcs followed by
        the CPDAG's arcs), the expert arcs themselves as the decisions, and the
        score ``model`` assigns to them.
    """
    expert_edges = list(observed_arcs)
    only_dag = expert_edges + list(cpdag.arcs)
    score = model(observed_arcs, observed_arcs)
    return [only_dag], observed_arcs, score


match algo:
    case "greedy_mec":
        algo = greedy_search_mec_size
    case "greedy_conf":
        algo = greedy_search_confidence
    case "greedy_bic":
        algo = greedy_search_bic
        tolerance = 1.

    case "global_scoring":
        from algo.global_scoring import global_scoring
        algo = global_scoring
        tolerance = 1.

    case "PC":
        algo = lambda a, b, cpdag, c, tol: (get_mec(cpdag), dict(), 1.)
        tolerance = 0.
    case "naive":
        algo = blindly_follow_expert
        tolerance = 1.

match prior:
    case "mec":
        from models.priors import MECPrior
        prior_type = MECPrior
    
    case "independent":
        from models.priors import IndependentPrior
        prior_type = IndependentPrior

# Fetch the Bayesian-network .bif files on first run.
if not os.path.exists("_raw_bayesian_nets"):
    from utils.download_datasets import download_datasets
    download_datasets()

# Ground-truth DAG and data generated from the chosen network.
true_G, data = generate_dataset(os.path.join("_raw_bayesian_nets", dataset + ".bif"))

if preprocess:
    # Estimate the CPDAG from data (data2pdag; PC/GES per the original note).
    # The data is cached on disk so later runs estimate from the same sample.
    data_cache = os.path.join(".", dataset + "_data.pkl")
    try:
        # NOTE(review): the cache is presumably a pickle — only load files you created.
        data = load_data_from_file(data_cache)
    except Exception:
        # No usable cache yet: persist the freshly generated data for next time.
        # (Previously a bare `except:`, which also swallowed KeyboardInterrupt.)
        save_data_to_file(data, data_cache)

    cpdag = data2pdag(data)

else:
    # CPDAG computed directly from the true DAG.
    cpdag = DAG.from_nx(true_G).cpdag()

In [7]:
# True orientations of the edges the expert is asked about (presumably the
# edges left undirected by the true DAG's CPDAG — compare the printed output below).
undirected_edges = get_undirected_edges(true_G, verbose=verbose)

# NOTE(review): OpenAI has retired the text-davinci-002 completion model;
# confirm that get_lms_probs / temperature_scaling can still serve this engine.
llm_engine = "text-davinci-002"

if tabular:
    # Simulated expert: EpsilonOracle decides every edge, parameterised by `epsilon`.
    oracle = EpsilonOracle(undirected_edges, epsilon=epsilon)
    observations = oracle.decide_all()
    likelihoods = oracle.likelihoods

else:
    # LLM expert. The codebook is optional: on failure the helpers receive None.
    codebook_path = os.path.join("codebooks", dataset + ".csv")
    try:
        codebook = pd.read_csv(codebook_path)
    except Exception as exc:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit are not swallowed;
        # the reason is printed so a missing file and a malformed one can be told apart.
        print('cannot load the codebook (%s)' % exc)
        codebook = None

    if calibrate:
        # Calibrates on cpdag.arcs, the edges the CPDAG already orients
        # (presumably because their true direction is known).
        tmp_scale, eps = temperature_scaling(cpdag.arcs, codebook, engine=llm_engine)
        print("LLM has %.3f error rate" % eps)
    else:
        tmp_scale = 1.

    likelihoods, observations = get_lms_probs(undirected_edges, codebook, tmp_scale, engine=llm_engine)

In [8]:
# Display the CPDAG; its repr appears to list the edges it leaves undirected.
cpdag

{('asia', 'tub'), ('smoke', 'bronc'), ('smoke', 'lung')}

In [9]:
# Compare the true orientations with what the expert proposed.
for caption, shown in (("\nTrue Orientations:", undirected_edges),
                       ("\nOrientations given by the expert:", observations)):
    print(caption, shown)
print(likelihoods)

# Combine the prior (built from the CPDAG) with the expert's likelihoods.
# NOTE: `prior` is rebound from the config string to the prior object here;
# the "prior" probability option in the next cell reads this object.
prior = prior_type(cpdag)
model = NoisyExpert(prior, likelihoods)


True Orientations: {('smoke', 'lung'), ('smoke', 'bronc'), ('asia', 'tub')}

Orientations given by the expert: [('smoke', 'lung'), ('smoke', 'bronc'), ('asia', 'tub')]
{('smoke', 'lung'): 0.9972375690555538, ('lung', 'smoke'): 0.002762430944446141, ('smoke', 'bronc'): 0.9982206405683903, ('bronc', 'smoke'): 0.0017793594316096988, ('asia', 'tub'): 0.8571428571360543, ('tub', 'asia'): 0.14285714286394557}


In [10]:
match probability:
    case "posterior":
        prob_method = model.posterior
    
    case "likelihood":
        prob_method = model.likelihood

    case "prior":
        prob_method = lambda _, edges: prior(edges)

new_mec, decisions, p_correct = algo(observations, prob_method, cpdag, likelihoods, tol=tolerance)

In [11]:
shd, learned_adj = get_mec_shd(true_G, new_mec)

# Rebuild the learned graph from its adjacency matrix and rename integer node i
# to the i-th node of true_G.
# NOTE(review): assumes get_mec_shd orders matrix rows like true_G.nodes — TODO confirm.
learned_G = nx.relabel_nodes(
    nx.from_numpy_array(learned_adj, create_using=nx.DiGraph),
    dict(enumerate(true_G.nodes)),
)

# Edges of the learned graph that do not appear in the true DAG.
diff = nx.difference(learned_G, true_G)
print("\nFinal wrong orientations:", diff.edges)

print('\nConfidence true DAG is in final MEC: %.3f' % p_correct)
print("Final MEC's SHD: ", shd)
print('MEC size: ', len(new_mec))
print('true-still-in-MEC: ', is_dag_in_mec(true_G, new_mec))


Final wrong orientations: [('tub', 'asia')]

Confidence true DAG is in final MEC: 0.994
Final MEC's SHD:  1.0
MEC size:  2
true-still-in-MEC:  1.0


In [12]:
import networkx as nx
def mec_to_graph(mec_list):
    """Summarise a collection of DAGs as one partially directed graph.

    Each DAG in ``mec_list`` is an iterable of edges whose first two entries
    are ``(source, target)``; any further entries (e.g. edge data) are ignored.
    An edge oriented the same way every time it appears stays directed; an
    edge seen in both orientations becomes undirected.

    Args:
        mec_list: iterable of DAGs (e.g. the output of ``get_mec`` or of a
            search algorithm).

    Returns:
        ``(directed_set, undirected_set)``: two sets of ``(u, v)`` tuples. Each
        undirected edge is stored once, in the orientation seen when the
        disagreement was detected.
    """
    directed_set = set()
    undirected_set = set()
    for dag in mec_list:
        for edge in dag:
            # Normalise to a hashable (u, v) pair. Previously the raw edge was
            # stored, so list edges crashed and edges with extra data never
            # matched the (u, v) membership checks below.
            u, v = edge[0], edge[1]
            if (u, v) in directed_set or (u, v) in undirected_set or (v, u) in undirected_set:
                # Already classified consistently with this orientation.
                continue
            if (v, u) in directed_set:
                # Seen before with the opposite orientation: the DAGs disagree.
                directed_set.remove((v, u))
                undirected_set.add((u, v))
            else:
                directed_set.add((u, v))
    return directed_set, undirected_set

# Summarise the MEC before (from the CPDAG) and after the expert-guided search.
for dag_list in (get_mec(cpdag), new_mec):
    print(mec_to_graph(dag_list))

({('bronc', 'dysp'), ('lung', 'either'), ('either', 'xray'), ('either', 'dysp'), ('tub', 'either')}, {('bronc', 'smoke'), ('asia', 'tub'), ('smoke', 'lung')})
({('bronc', 'dysp'), ('lung', 'either'), ('smoke', 'bronc'), ('either', 'xray'), ('either', 'dysp'), ('smoke', 'lung'), ('tub', 'either')}, {('asia', 'tub')})


In [14]:
# Directed edges of the CPDAG (edges whose orientation is already fixed).
cpdag.arcs

{('bronc', 'dysp'),
 ('either', 'dysp'),
 ('either', 'xray'),
 ('lung', 'either'),
 ('tub', 'either')}

In [16]:
# Orientation decisions returned by the search algorithm.
# NOTE(review): the saved output lists each edge twice — confirm whether the
# algorithm intends to record duplicates.
decisions

[('smoke', 'lung'), ('smoke', 'lung'), ('smoke', 'bronc'), ('smoke', 'bronc')]