# Initialize

In [8]:
MODEL = "gemma3n:e4b"

In [69]:
import requests
import json
from IPython.display import display, Markdown, clear_output

def answer_this_prompt(prompt, stream=False, model=MODEL, temperature=0, format=None):
    payload = {
        "prompt": prompt,
        "model": model,
        "temperature": temperature,
        "max_new_tokens": 50, # only when stream = False work
        "format": format
    }
    headers = {
        'Content-Type': 'application/json'
    }
    endpoint = "http://localhost:11434/api/generate"

    # Send the POST request with streaming enabled
    with requests.post(endpoint, headers=headers, json=payload, stream=True) as response:
        if response.status_code == 200:
            try:
                # Process the response incrementally
                full_response = ""
                for line in response.iter_lines(decode_unicode=True):
                    if line.strip():  # Skip empty lines
                        response_json = json.loads(line)
                        chunk = response_json.get("response", "")
                        full_response += chunk
                        
                        # Render the response as Markdown
                        if stream:
                            clear_output(wait=True)
                            display(Markdown(full_response))
                        
                return full_response
            except json.JSONDecodeError as e:
                return "Failed to parse JSON: " + str(e)
        else:
            return "Failed to retrieve response: " + str(response.status_code)
        
def multiple_answer_this_prompt(prompt, stream=False, model=MODEL, temperature=0, format=None, n_answers=1):
    answers = []
    for _ in range(n_answers):
        answer = answer_this_prompt(prompt, stream=stream, model=model, temperature=temperature, format=format)
        answers.append(answer)
    return answers

# Example usage
# ans = answer_this_prompt("What is the Big Bang theory?", stream=True)
# print("------------------------")
# print(type(ans))

In [3]:
from pgmpy.models import BayesianNetwork
from pgmpy.utils import get_example_model

asia = get_example_model('asia')
# print(asia.nodes())
print(asia.edges())
# print(asia.get_cpds())

def query_bn(bn, query_variables, evidence):
    """
    Perform exact inference on a Bayesian Network (BN) given the query variables and evidence.
    
    Args:
    - bn: A Bayesian Network object (pgmpy.models.BayesianNetwork).
    - query_variables: A list of strings specifying the query variables.
    - evidence: A dictionary where keys are strings of the evidence variables and values are the observed states.
    
    Returns:
    - A list of tuples where each tuple contains the state and probability of the query variable.
    """
    # Perform exact inference using Variable Elimination
    from pgmpy.inference import VariableElimination
    inference = VariableElimination(bn)
    result = inference.query(variables=query_variables, evidence=evidence)
    return result

# Example usage
query_variables = ['lung']
evidence = {'asia': 'yes', 'xray': 'yes'}
result = query_bn(asia, query_variables, evidence)
rs1 = str(result)
print(rs1)

[('asia', 'tub'), ('tub', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('lung', 'either'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp')]
+-----------+-------------+
| lung      |   phi(lung) |
| lung(yes) |      0.3715 |
+-----------+-------------+
| lung(no)  |      0.6285 |
+-----------+-------------+


In [4]:
def do_intervention(model, variable, forced_value_index, state_names):
    import copy
    from pgmpy.factors.discrete import TabularCPD
    
    intervened = copy.deepcopy(model)
    intervened.remove_cpds(variable)
    for parent in intervened.get_parents(variable):
        intervened.remove_edge(parent, variable)
    
    values = [[1.0 if i == forced_value_index else 0.0] for i in range(len(state_names))]
    
    cpd = TabularCPD(
        variable=variable,
        variable_card=len(state_names),
        values=values,
        state_names={variable: state_names}
    )
    intervened.add_cpds(cpd)
    assert intervened.check_model()
    intervened_variable = f"""'{variable}': '{state_names[forced_value_index]}'"""
    return intervened, intervened_variable

In [5]:
from pgmpy.inference import VariableElimination

do_model, itv_var = do_intervention(asia, 'asia', forced_value_index=1, state_names=['yes', 'no'])
print(do_model.edges())
inference2 = VariableElimination(do_model)
rs2 = str(inference2.query(variables=['lung'], evidence={'xray': 'yes'}))
print(rs2)
print(itv_var)

[('asia', 'tub'), ('tub', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('lung', 'either'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp')]
+-----------+-------------+
| lung      |   phi(lung) |
| lung(yes) |      0.4903 |
+-----------+-------------+
| lung(no)  |      0.5097 |
+-----------+-------------+
'asia': 'no'


# Run

In [10]:
# BN = """[('asia', 'tub'), ('tub', 'either'), ('smoke', 'lung'), ('smoke', 'bronc'), ('lung', 'either'), ('bronc', 'dysp'), ('either', 'xray'), ('either', 'dysp')]"""
BN = asia.edges()
QUERY = ''
BN_LOG = """
Bayesian Network Analysis Log
------------------------------
Network Structure: {BN}
Observed Evidence: {evidence}
Target Variables: {query_variables}

Initial Query Result (No Intervention): {rs1}

Intervention Applied: {itv_var}
Query Result After Intervention: {rs2} 

Notes:
- The purpose of this analysis is to observe how the intervention affects the target variables.
- The changes in query results are expected to reflect causal effects, assuming the network structure and CPDs are correct.
- Interpretations should be based on the structural dependencies encoded in the Bayesian Network.
"""

EXPL_PROMPT = """Generate an explanation of the information in the following Bayesian Network (BN) Analysis Log.

Only respond with the explanation, do not include any extraneous text.

BN Analysis Log:

{BN_LOG}
"""

In [17]:
bg_log = BN_LOG.format(BN=BN, evidence=evidence, query_variables=query_variables, rs1=rs1, itv_var=itv_var, rs2=rs2)
# print(bg_log)
gemma3n_answer = answer_this_prompt(EXPL_PROMPT.format(BN_LOG=bg_log), stream=True, model=MODEL)
print("------------------------")

This Bayesian Network (BN) analysis explores the relationship between 'asia', 'smoke', 'lung', 'bronc', 'dysp', and 'xrray', with the goal of understanding how intervening on 'asia' influences the probability of 'lung' disease. 

The network structure indicates that 'asia' influences 'tub', which in turn influences 'either' (representing either 'lung' or 'bronc' disease). 'Smoke' influences both 'lung' and 'bronc' diseases. 'Lung' and 'bronc' both influence 'dysp' (dyspnea).  'Either' influences 'xrray'. The target variable is 'lung'.

Initially, without any intervention, the probability of 'lung' is 37.15% and the probability of 'no lung' is 62.85%.

An intervention is applied to 'asia', setting it to 'no'. After this intervention, the probability of 'lung' increases to 49.03% and the probability of 'no lung' decreases to 50.97%.

The analysis suggests that intervening to negate 'asia' (presumably a risk factor related to Asian populations) increases the probability of 'lung' disease. This change in probability is attributed to the causal relationships encoded in the Bayesian Network, specifically the influence of 'asia' on 'tub', which subsequently affects the probability of 'lung' disease. The intervention effectively modifies the path leading to 'lung' through 'tub' and 'either'.


------------------------


# Judge

In [36]:
QUIZ_PROMPT = """
Generate a multiple-choice quiz based on the information in the following Bayesian Network (BN) Analysis Log.

Example:

```
1. What evidence make the most changed the target variable?
A. asia
B. tub
C. smoke
D. lung

2. What would happen to the target variable if we change the observation of the evidence?
A. It would increase
B. It would decrease
C. It would remain the same
D. It would become undefined

3. What would happen to the target variable if we intervene the evidence variable?
A. It would increase
B. It would decrease
C. It would remain the same
D. It would become undefined

4. What would happen to 'lung' if we intervene on 'xray'?
A. It would increase
B. It would decrease
C. It would remain the same
D. It would become undefined

```

===== ANSWERS =====
1. C
2. A
3. A
4. A
```

Limit the length of the quiz to the top 10 most relevant questions for BN explaination about the analysis log.

BN Analysis Log:

{BN_LOG}
"""

In [49]:
from pydantic import BaseModel
from random import shuffle

class Question(BaseModel):
  text: str
  options: list[str]
  answer: int

  def shuffle_options(self) -> None:
    correct = self.options[self.answer]

    shuffled = self.options.copy()
    shuffle(shuffled)

    self.options = shuffled
    self.answer = shuffled.index(correct)

  def __str__(self) -> str:
    output = [self.text]

    for i, option in enumerate(self.options):
      output.append(f"{chr(65+i)}. {option}")

    return "\n".join(output)
  
class Quiz(BaseModel):
  questions: list[Question]

  def shuffle_all_questions(self) -> None:
    for question in self.questions:
      question.shuffle_options()

  def __str__(self):
    output = []

    for i, question in enumerate(self.questions):
      output.append(f"\nQuestion {i+1}:")
      output.append(str(question))

    return "\n".join(output)

In [50]:
def create_quiz(BN_LOG: str):
  prompt = QUIZ_PROMPT.format(BN_LOG=BN_LOG)
  ans = answer_this_prompt(prompt, format=Quiz.model_json_schema())
  quiz = Quiz.model_validate_json(ans)
  quiz.shuffle_all_questions()
  return quiz

quiz = create_quiz(bg_log)

In [51]:
print(quiz)


Question 1:
What evidence was observed in this analysis?
A. asia: no, xray: no
B. asia: yes, xray: yes
C. asia: yes, xray: no
D. asia: no, xray: yes

Question 2:
What is the target variable in this analysis?
A. lung
B. bronc
C. asia
D. tub

Question 3:
What is the initial probability of 'lung(yes)' before any intervention?
A. 0.6285
B. 0.5097
C. 0.3715
D. 0.5

Question 4:
What is the probability of 'lung(yes)' after intervening on 'asia' to 'no'?
A. 0.6285
B. 0.4903
C. 0.3715
D. 0.5097

Question 5:
How does intervening on 'asia' to 'no' affect the probability of 'lung(no)'?
A. It becomes undefined
B. It increases
C. It remains the same
D. It decreases

Question 6:
The analysis aims to observe the effect of what on the target variable?
A. Indirect effect
B. Causal effect
C. Correlation
D. Direct effect

Question 7:
According to the notes, what is assumed to be correct for the interpretations?
A. The network structure and Conditional Probability Distributions (CPDs)
B. The intervention


In [74]:
letter_to_index = {"A": 0, "B": 1, "C": 2, "D": 3}
index_to_letter = ["A", "B", "C", "D"]


TAKE_QUIZ_PROMPT = """Use the provided Bayesian Network Explanation of a Bayesian Network Analysis Log
to answer the following quiz.

Quiz:

{quiz}

Bayesian Network Explanation:

{bn_explanation}

Respond with just a list of answers and no additional text, 
for example:

[A, D, C, B, B, C, D, A, A, B]

You must provide an answer for all 10 questions. 
If you don't know the answer, answer with "0" for that question. 
Example:

[A, D, 0, B, B, C, D, A, A, B]
"""


In [75]:
def take_quiz(quiz: Quiz, bn_explanation: str):
    question_strs = []
    for question in quiz.questions:
        question_str = question.text
        for i, option in enumerate(question.options):
            letter = index_to_letter[i]
            question_str += f"\n{letter}. {option}"
        question_strs.append(question_str)
    quiz_str = "\n\n".join(question_strs)

    prompt = TAKE_QUIZ_PROMPT.format(quiz=quiz_str, bn_explanation=bn_explanation)
    res_str = answer_this_prompt(prompt)
    ans = res_str.strip("[]").split(", ")
    return ans

In [76]:
answer = take_quiz(quiz, gemma3n_answer)
print(answer)

['C', 'A', 'C', 'B', 'D', 'B', 'C', 'C', 'B', 'D']


In [77]:
def score_quiz_answer(answers: list[str], quiz: Quiz):
  assert len(answers) == len(quiz.questions), "Number of answers must match number of questions"
  total = len(answers)

  correct = 0
  for answer, question in zip(answers, quiz.questions):
    expected_answer = index_to_letter[question.answer]
    if answer == expected_answer:
      correct += 1

  return correct / total

In [78]:
print(len(answer))  
print(len(quiz.questions))

10
10


In [79]:
score = score_quiz_answer(answer, quiz)
print(score)

0.5


In [70]:
import numpy as np
from tabulate import tabulate

def compute_advantages(rewards: list):
    rewards = np.array(rewards)

    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)

    if std_reward == 0:
        return [0] * len(rewards)
    
    advantages = (rewards - mean_reward) / std_reward
    return advantages.tolist()

def print_quiz_table(all_answers, rewards):
    advantages = compute_advantages(rewards)
    length = len(all_answers)
    elems = list(zip(range(length), rewards, advantages))

    headers = ["Index", "Reward", "Advantage"]
    table = tabulate(elems, headers=headers, tablefmt="grid").split("\n")
    for row in table:
        print(row)

In [80]:
bn_explanations = multiple_answer_this_prompt(EXPL_PROMPT.format(BN_LOG=bg_log), model=MODEL, n_answers=10)

In [81]:
all_answers = []
quiz_rewards = []

for bn_expl in bn_explanations:
  answer = take_quiz(quiz, bn_expl)
  all_answers.append(answer)
  quiz_rewards.append(score_quiz_answer(answer, quiz))

In [82]:
print_quiz_table(all_answers, quiz_rewards)

+---------+----------+-------------+
|   Index |   Reward |   Advantage |
|       0 |      0.6 |           1 |
+---------+----------+-------------+
|       1 |      0.5 |          -1 |
+---------+----------+-------------+
|       2 |      0.5 |          -1 |
+---------+----------+-------------+
|       3 |      0.6 |           1 |
+---------+----------+-------------+
|       4 |      0.5 |          -1 |
+---------+----------+-------------+
|       5 |      0.5 |          -1 |
+---------+----------+-------------+
|       6 |      0.6 |           1 |
+---------+----------+-------------+
|       7 |      0.5 |          -1 |
+---------+----------+-------------+
|       8 |      0.6 |           1 |
+---------+----------+-------------+
|       9 |      0.6 |           1 |
+---------+----------+-------------+


In [83]:
draft_answer = take_quiz(quiz, BN_LOG)
draft_score = score_quiz_answer(draft_answer, quiz)

print(draft_answer)
print(draft_score)

['C', 'A', 'B', 'B', 'D', 'B', 'C', 'C', 'C', 'B']
0.5


In [None]:
def check_answers_get_reward(quiz: Quiz, bn_explanations: list[str]) -> list[float]:
  quiz_rewards = []
  for bn_expl in bn_explanations:
    answer = take_quiz(quiz, bn_expl)
    quiz_rewards.append(score_quiz_answer(answer, quiz))
  return quiz_rewards

# GRPO

In [1]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.01, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)

NotImplementedError: Unsloth currently only works on NVIDIA GPUs and Intel GPUs.

In [None]:
maximum_length = 1024
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)


In [None]:
# For optional training + evaluation
# new_dataset = dataset.train_test_split(test_size = 0.01)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        check_answers_get_reward,
    ],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)
trainer.train()

## Inference

In [None]:
text = "Why intervene asia change lung"

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

In [None]:
model.save_lora("grpo_saved_lora")

Verify LoRA is actually trained!

In [None]:
from safetensors import safe_open

tensors = {}
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero
    for key in f.keys():
        tensor = f.get_tensor(key)
        n_zeros = (tensor == 0).sum() / tensor.numel()
        assert(n_zeros.item() != tensor.numel())

Now we load the LoRA and test

In [None]:
messages = [
    {"role": "system", "content": BN_LOG},
    {"role": "user",   "content": 'What is the probability of having lung cancer if do go to asia'},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 2048,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output