In [7]:
import torch
from typing import List, Optional
from pydantic import BaseModel, Field

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline
from langchain.output_parsers import PydanticOutputParser

# --- Load GPT-OSS locally ---
model_id = "openai/gpt-oss-20b"  # or "openai/gpt-oss-120b"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)

Fetching 40 files:   0%|          | 0/40 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.2,
    do_sample=True,
    return_full_text=False, 
)

llm = HuggingFacePipeline(pipeline=pipe)

Device set to use cuda:0


In [9]:
from langchain_core.runnables import Runnable
from langchain.output_parsers import OutputFixingParser

class Query(BaseModel):
    fromNode: str = Field(..., min_length=1)
    toNode: str = Field(..., min_length=1)

parser = PydanticOutputParser(pydantic_object=Query)

# 2. Prompt with format instructions from the parser
prompt = PromptTemplate(
    template="""
You are a Bayesian Network expert. Consider this question: “In the graph {BN}, can information flow from {fromNode} to {toNode}?”, what are the two nodes in this question?

Your task is to answer the question and return a JSON format matching the schema below.

{format_instructions}

DO NOT include any explanation. Only return the JSON.
""",
    input_variables=["BN", "fromNode", "toNode"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

# prompt = PromptTemplate(
#     template="""
# You are a Bayesian Network expert. Consider this question: “In the graph, can information flow from fromNode to toNode?”, what are the two nodes in this question?

# Your task is to answer the question and return a JSON format matching the schema below.

# {format_instructions}

# DO NOT include any explanation. Only return the JSON.
# """,
#     input_variables=[],
#     partial_variables={"format_instructions": parser.get_format_instructions()},
# )

In [10]:
import json
import re
from typing import Any

def extract_json_from_text(text: str) -> Any:
    pattern = r"\{.*?\}|\[.*?\]"  # Matches {object} or [array], non-greedy
    matches = re.findall(pattern, text, flags=re.DOTALL)

    for candidate in matches:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    if not matches:
        raise ValueError(f"No JSON found in text: {text}")

In [11]:
from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel, Field

def parse_text(text: str) -> Query:
    return Query.model_validate(extract_json_from_text(text))

# chain: Runnable = prompt | llm | RunnableLambda(parse_text)
# Equal to 
# def chain(input):
#     prompt_text = prompt.invoke(input)
#     llm_output = llm.invoke(prompt_text)
#     parsed_result = parse_text(llm_output)
#     return parsed_result

# result = chain.invoke({})
# print(result.model_dump_json(indent=2))

In [16]:
def isConnected(net, fromNode, toNode):
  relatedNodes = net.node(fromNode).getRelated("d_connected")
  for node in relatedNodes:
    if node.name() == toNode:
      return True
  return False

def correctIdentification(prompt, net, fromNode, toNode):
    BN = ""
    
    for node in net.nodes():
        BN += f"{node.name()} -> {[child.name() for child in node.children()]}\n"

    chain: Runnable = prompt | llm 
    # chain: Runnable = prompt | llm | RunnableLambda(parse_text)
    # result = chain.invoke({})

    result = chain.invoke({
        "BN": BN,
        "fromNode": fromNode,
        "toNode": toNode
    })

    queryFromNode = result.fromNode
    queryToNode = result.toNode

    return queryFromNode == fromNode and queryToNode == toNode, queryFromNode, queryToNode

import random 
def pickTwoRandomNodes(net):
    nodes = net.nodes()
    if len(nodes) < 2:
        return None, None
    node1, node2 = random.sample(nodes, 2)
    return node1.name(), node2.name()

def printNet(net):
    for node in net.nodes():
        print(f"{node.name()} -> {[child.name() for child in node.children()]}")

In [13]:
bn_path = "./nets/collection/"
from bni_netica.bni_netica import *
from bni_netica.bni_netica import Net

CancerNeapolitanNet = Net(bn_path+"Cancer Neapolitan.neta")
ChestClinicNet = Net(bn_path+"ChestClinic.neta")
ClassifierNet = Net(bn_path+"Classifier.neta")
CoronaryRiskNet = Net(bn_path+"Coronary Risk.neta")
FireNet = Net(bn_path+"Fire.neta")
MendelGeneticsNet = Net(bn_path+"Mendel Genetics.neta")
RatsNet = Net(bn_path+"Rats.neta")
WetGrassNet = Net(bn_path+"Wet Grass.neta")
RatsNoisyOr = Net(bn_path+"Rats_NoisyOr.dne")
Derm = Net(bn_path+"Derm 7.9 A.dne")

listOfNets = [CancerNeapolitanNet, ChestClinicNet, ClassifierNet, CoronaryRiskNet, FireNet, MendelGeneticsNet, RatsNet, WetGrassNet, RatsNoisyOr, Derm]

In [17]:
total = 0
correct = 0

for net in listOfNets:
  for _ in range(10):
    total += 1
    fromNode, toNode = pickTwoRandomNodes(net)
    if fromNode and toNode:
        
        correctIdentified, queryFromNode, queryToNode = correctIdentification(prompt, net, fromNode, toNode)
        if correctIdentified:
          correct += 1
        else:
          print(f"Incorrect identification for {net.name()}")
          printNet(net)
          print()
          print("Expected:", fromNode, "->", toNode)
          print("Reality:", queryFromNode, "->", queryToNode)
          print("----------------------------------------------------")

print(f"Total: {total}, Correct: {correct}, Accuracy: {correct/total:.2%}")
print("<------------------------------------------------------------------------->")

AttributeError: 'str' object has no attribute 'fromNode'

In [33]:
BN = ""

for node in ChestClinicNet.nodes():
    BN += f"{node.name()} -> {[child.name() for child in node.children()]}\n"

probe = prompt | llm
out = probe.invoke({"BN": BN, "fromNode": "Smoking", "toNode": "Lung Cancer"})
print(type(out), out)  # you'll likely see a list or a message object

<class 'str'> We need to parse the question: "In the graph VisitAsia -> ['Tuberculosis'] Tuberculosis -> ['TbOrCa'] Smoking -> ['Cancer', 'Bronchitis'] Cancer -> ['TbOrCa'] TbOrCa -> ['XRay', 'Dyspnea'] XRay -> [] Bronchitis -> ['Dyspnea'] Dyspnea -> [], can information flow from Smoking to Lung Cancer?, what are the two nodes in this question?"

They ask: "can information flow from Smoking to Lung Cancer?" The two nodes are Smoking and Cancer? Wait they ask "what are the two nodes in this question?" So answer should be fromNode: "Smoking" toNode: "Cancer". Provide JSON accordingly.

Let's produce JSON: {"fromNode":"Smoking","toNode":"Cancer"}.

Check schema: properties fromNode string, toNode string, required both. Good. Return only JSON.assistantfinal{"fromNode":"Smoking","toNode":"Cancer"}
