In [1]:
from openai import OpenAI
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
import textwrap
import traceback
from typing import List, Union
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langsmith import traceable
from langsmith import Client, traceable, evaluate
from langsmith import evaluate, Client
import langsmith as ls
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item
from datetime import datetime
from prompts.few_shot_PaL import few_shot_tabmwp,few_shot_tatqa,few_shot_gsm8k

In [26]:
file_path = '../dataset_langsmith/gsm8k.jsonl'

# JSON Lines: one example per line. Whitespace-only lines (e.g. an extra newline at the
# end of the file) are skipped instead of crashing json.loads('').
with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f if line.strip()]

In [27]:
# TAT-QA: read the raw JSON, then let the project helper turn it into QA inputs.
folder_path, filename = "../dataset_langsmith/", "tatqa.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [28]:
# TabMWP: the raw JSON is used as-is (no preprocessing helper is applied).
folder_path, filename = "../dataset_langsmith/", "tabmwp.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    tabmwp = json.load(f)

In [None]:
name = "tatqa"  # one of: "tatqa", "tabmwp", "gsm8k"
length_test = 1  # NOTE(review): not referenced anywhere in the visible notebook

# Pick the loaded examples and the LangSmith dataset name for the active benchmark.
# An unrecognised `name` is an error rather than a silent fall-through to TabMWP.
if name == "gsm8k":
    DATA, name_dataset = gsm8k, "GSM8K"
elif name == "tatqa":
    DATA, name_dataset = tatqa, "TATQA"
elif name == "tabmwp":
    DATA, name_dataset = tabmwp, "TABMWP"
else:
    raise ValueError(f"Unknown dataset name {name!r}; expected 'gsm8k', 'tatqa' or 'tabmwp'")

In [30]:
# Load settings from a local .env file into the environment
# (presumably provides the OpenAI API key — confirm against your .env).
load_dotenv()
model = init_chat_model(
    "gpt-4o-mini",
    model_provider="openai",
    temperature=0.2,  # NOTE(review): non-zero temperature makes evaluation runs non-deterministic
)

In [31]:
class State(TypedDict):
    """Shared state passed between the PaL LangGraph nodes."""

    question: str                   # question text for the current example
    context: Optional[str]          # optional table/passage context; omitted from the prompt when empty
    program: Optional[str]          # Python source extracted from the model reply (PoT node)
    result: Optional[str]           # str() of the program's `result` variable (Exec node)
    final_answer: Optional[str]     # answer handed to the evaluator (write_final_answer node)
    error: Optional[str]            # exception message from execution, or None on success
class IntermediateProgram(BaseModel):
    """Pydantic schema wrapping a single generated program string.

    NOTE(review): not referenced anywhere in the visible notebook (pot_node parses the
    raw model reply instead) — candidate for removal.
    """

    program: str


In [32]:
# Few-shot PaL exemplars matching the active dataset `name`.
# An unrecognised name is an error rather than a silent fall-through to the TabMWP prompts.
if name == "gsm8k":
    select_fewshot = few_shot_gsm8k
elif name == "tatqa":
    select_fewshot = few_shot_tatqa
elif name == "tabmwp":
    select_fewshot = few_shot_tabmwp
else:
    raise ValueError(f"Unknown dataset name {name!r}; expected 'gsm8k', 'tatqa' or 'tabmwp'")

In [33]:
def extract_code_from_markdown(text):
    """Extract Python source from a model reply that may wrap it in a Markdown fence.

    Lookup order:
    1. the first fence tagged ``python``, ``python3`` or ``py`` (case-insensitive);
    2. otherwise the first untagged fence (three backticks directly followed by a newline);
    3. otherwise the whole reply.

    Fences tagged with another language (e.g. ``json``) are never picked in step 2.
    CRLF line endings after the fence are accepted.

    Args:
        text: raw model reply.

    Returns:
        The extracted code, stripped of surrounding whitespace.
    """
    tagged = re.search(
        r"```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)```", text, re.DOTALL | re.IGNORECASE
    )
    if tagged:
        return tagged.group(1).strip()

    untagged = re.search(r"```[ \t]*\r?\n(.*?)```", text, re.DOTALL)
    if untagged:
        return untagged.group(1).strip()

    # No fence at all: assume the reply is plain code.
    return text.strip()

In [34]:
@traceable(run_type="prompt")
def pot_node(state: State) -> State:
    """LangGraph node: ask the chat model for a Program-of-Thought solution.

    Sends a system message plus a user message containing the few-shot exemplars
    (``select_fewshot``), fixed instructions, the optional context and the question.
    It then invokes ``model`` and returns a copy of ``state`` whose ``program`` is the
    code extracted from the reply.
    """
    context_block = ""
    if state.get("context"):
        context_block = f"# Context:\n{state['context']}\n"

    system_msg = SystemMessage("You will write python program to solve math problems. You will only write code blocks.")
    # NOTE(review): prompt text (including the 'variale' typo) is kept verbatim so results
    # stay comparable with earlier runs.
    user_msg = HumanMessage(content=f"""
{select_fewshot}
# Answer this question by implementing a solver() function.
# Include a final answer as a single number, no units or symbols.
# 'CALL' the solver() function and then 'MUST' assign the variale 'result'.
# If the question includes time points, pay attention to time formats.
# Before returning the final result, DOUBLE-CHECK each variable assignment and calculation to ensure they match the problem statement.
{context_block}
# Question: {state["question"]}
""")

    reply = model.invoke([system_msg, user_msg])
    return {**state, "program": extract_code_from_markdown(reply.content)}

def exec_node(state: State) -> State:
    """LangGraph node: run the generated program and capture its answer.

    The answer is the module-level ``result`` variable the program assigns. If the
    program defines ``solver()`` but never assigns ``result``, ``solver()`` is called
    as a fallback, since the prompt asks for exactly that call.

    Returns:
        A copy of ``state`` with ``result`` set to ``str(answer)`` and ``error=None`` on
        success, or ``result=None`` and ``error`` set to the exception message on failure
        (including a missing/None ``result`` and the program calling ``sys.exit``).

    SECURITY NOTE(review): ``state["program"]`` is model-generated text executed with
    plain ``exec`` in this process, with full builtins and no timeout. Only run this
    on trusted infrastructure; consider a subprocess sandbox with a time limit.
    """
    try:
        exec_globals = {}
        exec(state["program"], exec_globals)
        result = exec_globals.get("result", None)

        solver = exec_globals.get("solver")
        if result is None and callable(solver):
            # The model defined solver() but forgot `result = solver()`.
            result = solver()

        if result is None:
            raise ValueError("Missing `result`")
        return {**state, "result": str(result), "error": None}
    except (Exception, SystemExit) as e:
        # SystemExit is caught too so that generated code calling exit() cannot abort the run.
        return {**state, "result": None, "error": str(e)}

def check_eos(state: State) -> bool:
    """Return True when the executed program finished without an error.

    NOTE(review): not wired into the graph (there are no conditional edges).
    """
    return state["error"] is None
def write_final_answer_node(state: State) -> State:
    """LangGraph node: publish the executed ``result`` as ``final_answer``.

    When execution failed (``error`` is set), the answer is the sentinel string "9999".
    NOTE(review): a true answer of 9999 would then be scored correct even on failure;
    a non-numeric sentinel would avoid that collision.
    """
    execution_failed = state["error"] is not None
    answer = "9999" if execution_failed else str(state["result"])
    return {**state, "final_answer": answer}


# Linear pipeline: PoT (generate code) -> Exec (run it) -> write_final_answer -> END.
builder = StateGraph(State)
for node_name, node_fn in (
    ("PoT", pot_node),
    ("Exec", exec_node),
    ("write_final_answer", write_final_answer_node),
):
    builder.add_node(node_name, node_fn)

builder.set_entry_point("PoT")
for src, dst in (
    ("PoT", "Exec"),
    ("Exec", "write_final_answer"),
    ("write_final_answer", END),
):
    builder.add_edge(src, dst)
graph = builder.compile()


In [35]:
def extract_ground_truth(answer, dataset_type):
    """Normalize a reference answer into a comparable string for ``dataset_type``.

    - ``"gsm8k"``: the number after the ``####`` marker, with commas removed and a
      leading minus sign allowed. Falls back to the whole stripped text when no marker
      is found.
    - ``"tatqa"``: the first element of a list/tuple answer (``""`` for an empty one),
      or the answer itself. Only digits, ``-`` and ``.`` are kept, plus ``/`` when the
      answer contains a fraction.
    - anything else: the stripped string form of ``answer``.

    Args:
        answer: reference answer (string, or list/tuple for TAT-QA).
        dataset_type: ``"gsm8k"``, ``"tatqa"`` or another dataset name.

    Returns:
        The normalized answer string.
    """
    if dataset_type == "gsm8k":
        text = str(answer)
        match = re.search(r"####\s*(-?[\d,./]+)", text)
        if match:
            return match.group(1).replace(",", "").strip()
        return text.strip()

    if dataset_type == "tatqa":
        if isinstance(answer, (list, tuple)):
            ans = str(answer[0]).strip() if answer else ""
        else:
            ans = str(answer).strip()
        # Removes brackets, quotes, %, $, units and thousands separators in one pass.
        strip_pattern = r"[^-\d/\.]" if "/" in ans else r"[^-\d\.]"
        return re.sub(strip_pattern, "", ans)

    return str(answer).strip()


def _parse_number(text):
    """Parse ``text`` as a float (commas and a single ``a/b`` fraction allowed); None if not numeric."""
    cleaned = str(text).strip().replace(",", "")
    try:
        if "/" in cleaned:
            numerator, denominator = cleaned.split("/")
            return float(numerator) / float(denominator)
        return float(cleaned)
    except (ValueError, ZeroDivisionError):
        # Covers non-numeric text, several '/' (unpacking error) and division by zero.
        return None


def compare_answers(predicted: str, actual: str, eps: float = 1e-2) -> bool:
    """Return True when ``predicted`` matches ``actual``.

    If both sides parse as numbers (plain floats, thousands separators such as
    ``"1,000"``, or simple fractions like ``"3/4"``), they match when they differ by at
    most ``eps``. Otherwise the stripped, lower-cased strings are compared. Non-str
    inputs are compared via ``str()``.

    Args:
        predicted: model answer.
        actual: normalized reference answer.
        eps: absolute numeric tolerance.

    Returns:
        Whether the two answers are considered equal.
    """
    pred_num = _parse_number(predicted)
    act_num = _parse_number(actual)
    if pred_num is not None and act_num is not None:
        return abs(pred_num - act_num) <= eps
    return str(predicted).strip().lower() == str(actual).strip().lower()

def unwrap_singleton(value):
    """Unwrap a one-element container into its only element.

    Handles real sequences (``[5]`` / ``(5,)`` -> ``5``) and their string renderings,
    since graph answers arrive as ``str()`` of the program's result. ``"[2018]"``,
    ``"['2018']"`` and ``'["2018"]'`` all become ``"2018"``. Anything else, including
    multi-element lists, is returned unchanged.
    """
    # Python list or tuple with exactly one element.
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    # String such as '[2018]', "['2018']" or '["2018"]' (bare, single- or double-quoted).
    if isinstance(value, str):
        match = re.fullmatch(r"""\[\s*['"]?([-\w.]+)['"]?\s*\]""", value.strip())
        if match:
            return match.group(1)
    return value

In [36]:

def run_graph(inputs: dict):
    """Run the compiled PaL graph on one example's ``inputs`` and project its outputs.

    Returns a dict with ``final_answer``, ``program`` and ``response``, each defaulting
    to ``""`` when absent from the final state.
    """
    initial_state = {
        "question": inputs["question"],
        "context": inputs.get("context", ""),
        "error": None,
    }
    final_state = graph.invoke(initial_state)
    # NOTE(review): no node writes "response", so it is always "" — kept for a stable output schema.
    return {
        key: final_state.get(key, "")
        for key in ("final_answer", "program", "response")
    }
all_results = []  # one record per evaluated example; re-running this cell resets it


def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """LangSmith evaluator: score one PaL prediction and refresh the on-disk summary.

    Steps:
    - Normalizes the prediction with ``unwrap_singleton``.
    - Normalizes the reference with ``extract_ground_truth`` for the active dataset ``name``.
    - Compares them with ``compare_answers``: numeric within ``1e-2``, otherwise
      case-insensitive text.
    - Appends a record to ``all_results``.
    - Rewrites ``save_log/PaL_results - {name}.json`` with the running accuracy and all
      wrong answers. Logged answers are the normalized strings, before any fraction-to-
      decimal conversion.

    Returns:
        ``{"key": "is_correct", "score": 0 or 1}``.
    """
    predicted = str(unwrap_singleton(outputs.get("final_answer", "")))
    actual = extract_ground_truth(str(reference_outputs["answer"]), name)
    score = compare_answers(predicted, actual, eps=1e-2)

    all_results.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "question": inputs["question"],
        "program": outputs.get("program", ""),
        "true_answer": actual,
        "predicted_answer": predicted,
        "context": inputs.get("context", ""),
        "correct": score,
    })

    # Running summary over every example scored so far in this kernel session.
    correct = sum(1 for record in all_results if record["correct"])
    total = len(all_results)
    summary = {
        "accuracy": correct / total * 100,
        "correct": correct,
        "total": total,
        "wrong_answers": [record for record in all_results if not record["correct"]],
    }

    # Create the log directory on first use so a fresh checkout doesn't hit FileNotFoundError.
    os.makedirs("save_log", exist_ok=True)
    # NOTE(review): the whole file is rewritten after every example (quadratic total I/O);
    # fine for a few hundred examples, consider writing once after evaluate() for larger runs.
    with open(f"save_log/PaL_results - {name}.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return {"key": "is_correct", "score": int(score)}
    

@traceable(run_type="chain")
def target_function(inputs: dict):
    """LangSmith target: run the PaL graph for one example and return its outputs."""
    return run_graph(inputs)


# Run the PaL pipeline over the "base" split of the LangSmith dataset for the active
# benchmark; compare_result scores each example and maintains the JSON log.
client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=name_dataset, splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix="PaL_" + name_dataset,
)


View the evaluation results for experiment: 'PaL_GSM8K-6267fbe5' at:
https://smith.langchain.com/o/d2c4d5e3-1ef4-4585-9aa2-091744d33e24/datasets/a09d6e43-4429-452e-965a-0613d78f1c81/compare?selectedSessions=25f80e31-2386-46a0-b361-3c730b8434ed




0it [00:00, ?it/s]

Unnamed: 0,inputs.question,outputs.final_answer,outputs.program,outputs.response,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Raymond and Samantha are cousins. Raymond was ...,14,"def solver():\n """"""Return how many years ag...",,,When Raymond's son was born Samantha was 23 - ...,1,2.682401,00d6514d-9915-4f5b-a227-e5c4de06e2db,431b3dbb-01b4-4445-ae2e-d35476aa6083
1,"Will buys 15 oranges. When he gets home, he as...",3.0,"def solver():\n """"""Return the number of unw...",,,"The older son washes 8 oranges, so the younger...",1,2.449563,01d8715e-444b-431c-ad50-cce166ac7896,5b2d6cb2-a160-46a9-8318-ef7521284795
2,Brittany and her mom go to the museum. The cos...,30,"def solver():\n """"""Return the amount of mon...",,,The total cost of the tickets is 12+10=<<12+10...,1,2.015948,01da9071-b651-49c4-baed-6438599c7a19,212dcb53-84c3-4a17-a811-4cc6fe7c6c63
3,Darrell and Allen's ages are in the ratio of 7...,109.0,"def solver():\n """"""Return Allen's age 10 ye...",,,The total ratio representing their ages is 7+1...,1,3.781126,0309257f-8577-48f5-9ff3-c4ebf7738374,36c7cb0c-30cb-4eab-a330-7da42438abb4
4,Three friends spent $20.25 on 3 tickets to the...,34.0,"def solver():\n """"""Return how much each fri...",,,They spent $20.25 - $4.50 = $<<20.25-4.5=15.75...,1,2.070327,04c43153-d487-40f8-bf27-e7ea21ca3622,3f6db9b5-c474-4d9b-a528-9df0c95857c9
...,...,...,...,...,...,...,...,...,...,...
295,The Doubtfire sisters are driving home with 7 ...,40,"def solver():\n """"""Return the total number ...",,,Patchy has just had 3 * 7 = <<3*7=21>>21 kitte...,1,2.490289,faebd878-f103-48c2-a2ad-5b7d2fb74c30,7167dba5-b472-4078-9884-b7cd97e21101
296,Morisette and Kael were asked to bring fruits....,27.0,"def solver():\n """"""Return the total number ...",,,Kael brought 5 x 2 = <<5*2=10>>10 apples.\nAnd...,1,4.526390,fc250399-f7eb-45c1-88f7-b6b44d24316f,afe75a10-9012-4271-85a9-fed1ec5c19f3
297,"Rani has ten more crabs than Monic, who has 4 ...",122,"def solver():\n """"""Return the total number ...",,,"If Bo has 40 crabs, then Monic, who has 4 fewe...",1,2.381886,fd626d5b-afd3-454c-89c9-e72124caa7ed,0346abe4-c4d2-4421-9c03-be68ebb92861
298,Charlie has three times as many Facebook frien...,16.0,"def solver():\n """"""Return the number of Fac...",,,Dorothy has 12 / 3 = <<12/3=4>>4 Facebook frie...,1,1.560296,fde8b920-c5ee-463a-958c-45829d675453,19d4a9e3-ec9f-4f35-89c6-6f57eb119ed2
