In [None]:
from openai import OpenAI
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
import textwrap
import traceback
from typing import List, Union
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langsmith import traceable
from langsmith import Client, traceable, evaluate
from langsmith import evaluate, Client
import langsmith as ls
from preprocess_data import prepare_qa_input_with_answer_filter
import datetime
from datetime import datetime
from prompts.few_shot_PoT import few_shot_tabmwp,few_shot_tatqa,few_shot_gsm8k

In [2]:
# GSM8K is read as JSON Lines: every line of the file is parsed into its own record.
file_path = '../dataset_langsmith/gsm8k.jsonl'

with open(file_path, mode="r", encoding="utf-8") as f:
    gsm8k = list(map(json.loads, f))

In [3]:
# Read the raw TAT-QA file, then pass it through the project's preprocessing helper.
folder_path, filename = "../dataset_langsmith/", "tatqa.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:  # text read mode is the default
    raw_data_tatqa = json.loads(f.read())
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
# The TabMWP file is parsed as a single JSON document.
folder_path, filename = "../dataset_langsmith/", "tabmwp.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:  # text read mode is the default
    tabmwp = json.loads(f.read())

In [5]:
name = "gsm8k"  # alternatives: "tatqa", "tabmwp"
length_test = 1  # number of samples to test

# Any value other than "gsm8k" or "tatqa" falls through to the TabMWP data.
DATA = gsm8k if name == "gsm8k" else tatqa if name == "tatqa" else tabmwp
name_dataset = {"gsm8k": "GSM8K", "tatqa": "TATQA"}.get(name, "TABMWP")

In [6]:
# Load variables from a .env file into the process environment before any client is created.
# NOTE(review): presumably this is where the OpenAI / LangSmith credentials come from — confirm.
load_dotenv()
# Chat model used by pot_node: gpt-4o-mini through the OpenAI provider, temperature 0.2.
model = init_chat_model("gpt-4o-mini", model_provider="openai", temperature=0.2)
# LangSmith client.
client = Client()

In [7]:
class State(TypedDict):
    """State dictionary passed between the nodes of the PoT graph."""
    # Problem statement; interpolated into the prompt by pot_node.
    question: str
    # Raw text of the model reply, stored by pot_node.
    response: str
    # Optional supporting context; pot_node adds it to the prompt only when it is truthy.
    context: Optional[str]
    # Python source extracted from the fenced code blocks of the reply.
    program: Optional[str]
    # str() of the `result` variable after exec_node ran the program; None on failure.
    result: Optional[str]
    # Written by write_final_answer_node: the result, or "9999" when an error is set.
    final_answer: Optional[str]
    # Exception message captured by exec_node, or None when execution succeeded.
    error: Optional[str]



In [9]:
# Few-shot examples matching `name`; anything other than gsm8k/tatqa gets the TabMWP set.
select_fewshot = {
    "gsm8k": few_shot_gsm8k,
    "tatqa": few_shot_tatqa,
}.get(name, few_shot_tabmwp)

In [10]:

def extract_code_from_markdown(text):
    """Collect the Python code held in the fenced code blocks of ``text``.

    A block is recognised when its opening fence is three backticks followed
    by ``python`` or ``py`` (any letter case), optional spaces or tabs, and a
    line break (LF or CRLF). Everything up to the next three backticks is the
    block body.

    Args:
        text: Markdown text, typically a chat model reply.

    Returns:
        The stripped block bodies joined by one blank line, or an empty
        string when no Python fence is present.
    """
    # The fence is matched leniently: requiring exactly "```python" + LF would
    # silently drop replies with CRLF endings, a trailing space or another case.
    code_blocks = re.findall(
        r"```(?:python|py)[ \t]*\r?\n(.*?)```",
        text,
        re.DOTALL | re.IGNORECASE,
    )
    # Join the blocks, separated by a blank line.
    return "\n\n".join(block.strip() for block in code_blocks)


In [11]:
@traceable(run_type="prompt")
def pot_node(state: State) -> State:
    """Ask the chat model for a Python program that answers the question.

    The human message is made of the selected few-shot examples, the fixed
    instruction lines, the optional context and the question. The reply text
    is stored under "response" and the code taken from its fenced blocks
    under "program"; all other state entries are carried over unchanged.
    """
    context = state.get("context")
    # The context section is left out entirely when no context was supplied.
    context_str = f"# Context:\n{context}\n" if context else ""
    question = state["question"]
    prompt = f"""
{select_fewshot}
# Include a final answer as a single number, no units or symbols.
# For each step, provide a very brief explanation in one short sentence only.
# The final answer 'MUST' be assigned the variable 'result'.
# If the question includes time points, pay attention to time formats.
# Before returning the final result, DOUBLE-CHECK each variable assignment and calculation to ensure they match the problem statement.
{context_str}
# Question: {question}
"""
    reply = model.invoke([
        SystemMessage("You will write python program to solve math problems."),
        HumanMessage(content=prompt),
    ])
    updated = dict(state)
    updated["program"] = extract_code_from_markdown(reply.content)
    updated["response"] = reply.content
    return updated

def exec_node(state: State) -> State:
    """Execute the generated program and capture the value of its ``result`` variable.

    Args:
        state: Graph state; ``state["program"]`` holds the Python source to run.

    Returns:
        A copy of ``state`` with ``"result"`` set to ``str(result)`` and
        ``"error"`` set to None on success, or ``"result"`` set to None and
        ``"error"`` set to a message when the program fails, exits, or does
        not assign ``result``.
    """
    try:
        exec_globals = {}
        # SECURITY: this runs model-generated code in-process with the full
        # privileges of the interpreter. Only use it in a disposable environment.
        exec(state["program"], exec_globals)
        result = exec_globals.get("result", None)
        if result is None:
            raise ValueError("Missing `result`")
        return {**state, "result": str(result), "error": None}
    except SystemExit as e:
        # sys.exit() inside the generated program raises SystemExit, which is
        # not an Exception subclass; without this clause it would escape the
        # node instead of being recorded as a failed execution.
        return {**state, "result": None, "error": f"SystemExit: {e}"}
    except Exception as e:
        return {**state, "result": None, "error": str(e)}

def check_eos(state: State) -> bool:
    """Return True when the state's "error" entry is None, False otherwise."""
    return state["error"] is None

def write_final_answer_node(state:State)->State:
    """Copy the execution result into "final_answer", or a sentinel when an error is set.

    Returns a copy of ``state`` whose "final_answer" is ``str(state["result"])``
    when "error" is None and "9999" otherwise.
    """
    failed = state["error"] is not None
    # 9999 is the placeholder answer reported for any run that carries an error.
    answer = str(9999) if failed else str(state["result"])
    return dict(state, final_answer=answer)


# Build the graph over `State` with three nodes run in a fixed order:
# PoT (generate program) -> Exec (run it) -> write_final_answer -> END.
builder = StateGraph(State)
builder.add_node("PoT", pot_node)
builder.add_node("Exec", exec_node)

builder.add_node("write_final_answer",write_final_answer_node)

# Only unconditional edges are registered; check_eos is not referenced here,
# so a failed execution still flows on to write_final_answer.
builder.set_entry_point("PoT")
builder.add_edge("PoT", "Exec")
builder.add_edge("Exec", "write_final_answer")
builder.add_edge("write_final_answer",END)
# Compiled graph invoked by run_graph.
graph = builder.compile()


In [12]:
def extract_ground_truth(answer, dataset_type):
    """Normalise a reference answer into a plain comparable string.

    Args:
        answer: Reference answer as stored in the dataset (string, or a list
            for "tatqa", in which case only its first element is used).
        dataset_type: "gsm8k", "tatqa", or anything else for the generic path.

    Returns:
        - "gsm8k": the number after the "####" marker with thousands
          separators removed (a leading minus sign is kept); the stripped
          answer text when no marker is found.
        - "tatqa": the answer reduced to digits, "-", "." and, when the
          answer contains one, "/".
        - otherwise: ``str(answer)`` stripped of surrounding whitespace.
    """
    if dataset_type == "gsm8k":
        text = str(answer)
        # "-?" keeps negative references such as "#### -15"; without it the
        # marker would not match and the whole answer text would be returned.
        match = re.search(r"####\s*(-?[\d,./]+)", text)
        if match:
            raw_ans = match.group(1).replace(",", "").strip()
        else:
            raw_ans = text.strip()
    elif dataset_type == "tatqa":
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        # Unwrap a bare number surrounded by brackets and/or double quotes.
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Drop everything that is not part of a number; "/" survives only
        # when the answer is written as a fraction.
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        raw_ans = ans
    else:
        raw_ans = str(answer).strip()
    return raw_ans


def compare_answers(predicted: str, actual: str, eps: float = 1e-2) -> bool:
    """Decide whether a predicted answer matches the reference answer.

    Both strings may be simple "a/b" fractions, which are evaluated first.
    When both values parse as floats they match if they differ by at most
    ``eps``; otherwise they are compared as stripped, lower-cased text.
    """

    def as_decimal(text):
        # Evaluate "a/b"; a malformed fraction (or no "/") is returned as is.
        if '/' not in text:
            return text
        try:
            top, bottom = text.split('/')
            return str(float(top) / float(bottom))
        except Exception:
            return str(text)

    try:
        predicted = as_decimal(predicted)
        actual = as_decimal(actual)
        gap = float(predicted.strip()) - float(actual.strip())
        return abs(gap) <= eps
    except ValueError:
        # At least one side is not numeric: fall back to a text comparison.
        return predicted.strip().lower() == actual.strip().lower()

def unwrap_singleton(value):
    """Return the single element of a one-item container, else ``value`` itself.

    Handles a list or tuple of length one, and a string that looks like a
    one-item list such as '[2018]' or "['2018']". Anything else is returned
    unchanged.
    """
    if isinstance(value, str):
        # Text form of a one-item list: optional single quotes around one token.
        found = re.fullmatch(r"\[\s*'?([-\w\.]+)'?\s*\]", value.strip())
        return found.group(1) if found else value
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    return value

In [None]:
def run_graph(inputs: dict):
    """Run one example through the compiled graph.

    Args:
        inputs: Example inputs; "question" is required, "context" defaults to "".

    Returns:
        A dict with "response", "final_answer" and "program" taken from the
        final graph state, each defaulting to "" when absent.
    """
    # Initial state handed to the graph.
    initial_state = {
        "question": inputs["question"],
        "context": inputs.get("context", ""),
        "error": None,
    }
    # Run the graph.
    final_state = graph.invoke(initial_state)
    wanted = ("response", "final_answer", "program")
    return {key: final_state.get(key, "") for key in wanted}
# Per-example records collected across evaluator calls (reset whenever this cell is re-run).
all_results = []


@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """Score one prediction against its reference answer and refresh the summary file.

    Args:
        inputs: Example inputs; "question" is required, "context" is optional.
        reference_outputs: Reference record; its "answer" entry is stringified
            and normalised with extract_ground_truth for the dataset in `name`.
        outputs: Target outputs; "final_answer", "program" and "response" are read.

    Returns:
        ``{"key": "is_correct", "score": 0 or 1}``.

    Side effects:
        Appends a record to the module-level ``all_results`` list and, on
        every call, rewrites "save_log/PoT_results - <name>.json" with the
        running accuracy and the list of wrong answers.
    """

    def to_decimal(text):
        # Evaluate a simple "a/b" fraction; anything else is returned as is.
        if '/' not in text:
            return text
        try:
            numerator, denominator = text.split('/')
            return str(float(numerator) / float(denominator))
        except Exception:
            return str(text)

    predicted = to_decimal(str(unwrap_singleton(outputs["final_answer"])))
    actual = to_decimal(extract_ground_truth(str(reference_outputs["answer"]), name))
    eps = 1e-2
    try:
        score = abs(float(predicted.strip()) - float(actual.strip())) <= eps
    except ValueError:
        # At least one side is not numeric: compare as case-insensitive text.
        score = predicted.strip().lower() == actual.strip().lower()

    all_results.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "question": inputs["question"],
        "program": outputs["program"],
        "response": outputs["response"],
        "true_answer": actual,
        "predicted_answer": predicted,
        "context": inputs.get("context", ""),
        "correct": score
    })

    # Recompute the running summary over everything scored so far.
    wrong_answers = [x for x in all_results if not x["correct"]]
    total = len(all_results)
    correct = total - len(wrong_answers)
    accuracy = correct / total * 100

    summary = {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "wrong_answers": wrong_answers
    }

    log_path = f"save_log/PoT_results - {name}.json"
    # open(..., "w") does not create missing directories; make sure the log
    # folder exists so a fresh checkout does not fail on the first example.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return {"key": "is_correct", "score": int(score)}
    

@traceable(run_type="chain")
def target_function(inputs: dict):
    """Evaluation target: pass the example inputs to run_graph and return its result."""
    return run_graph(inputs)


# Evaluate target_function on the "base" split of the selected LangSmith dataset,
# scoring every example with compare_result.
client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=name_dataset, splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"PoT_{name_dataset}",
)


View the evaluation results for experiment: 'PoT_GSM8K-07ed6943' at:
https://smith.langchain.com/o/943b2ecf-878d-466a-a3ad-a779a4be18b4/datasets/95ec945e-8b0f-4361-92e4-638295e8bcb3/compare?selectedSessions=67ea92f4-c389-473c-8e23-e383f1e99e21




0it [00:00, ?it/s]

Unnamed: 0,inputs.question,outputs.response,outputs.final_answer,outputs.program,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Josie grows grapes on her 10-acre farm. Each ...,```python\n# Step 1: Set the number of acres o...,100,# Step 1: Set the number of acres on Josie's f...,,If each acre produces 5 tons of grapes per yea...,1,5.419777,0042ea8d-aba3-4e78-90af-fda9d178120f,62d6e882-5cfc-4a58-867b-3e219f2f10ba
1,Jackie is trying to decide whether to do her t...,```python\n# Step 1: Set the hourly rate for f...,-15,# Step 1: Set the hourly rate for freelance wo...,,First find the total lost revenue if Jackie do...,0,5.701788,006cd270-8519-442d-a286-130ac608a060,6c63e8ee-1026-441b-8c9e-03a32afa4180
2,Raymond and Samantha are cousins. Raymond was ...,```python\n# Step 1: Set the current age of Sa...,14,# Step 1: Set the current age of Samantha\nsam...,,When Raymond's son was born Samantha was 23 - ...,1,3.305586,008dccc0-4f60-44cb-a59f-3c1368fedb46,899f1731-1d8b-4acd-bc15-494e8ecddc8c
3,A water tank is filled with 120 liters of wate...,```python\n# Step 1: Set the initial amount of...,90,# Step 1: Set the initial amount of water in t...,,"After watering Celine's garden, 120 - 90 = <<1...",1,2.917127,01f2dc43-093e-46e8-a126-25bb29a1301e,733c0a48-f4fa-4480-adc6-f6ed8be3c6e0
4,Uriah's book bag is getting too heavy for him....,```python\n# Step 1: Set the total weight Uria...,15,# Step 1: Set the total weight Uriah needs to ...,,30 comic books weigh 7.5 pounds because 30 x ....,1,5.232508,03b7cc4c-d658-480a-b316-44d1dda6c211,b6a4b06a-5d24-45b9-8274-5c7277dbb126
...,...,...,...,...,...,...,...,...,...,...
295,Jean has 30 lollipops. Jean eats 2 of the loll...,```python\n# Step 1: Set the initial number of...,14,# Step 1: Set the initial number of lollipops ...,,Jean has 30 - 2 = <<30-2=28>>28 lollipops\nJea...,1,4.045573,f9d015bb-525f-442f-b96e-9b07090aa6d8,2d91b12e-038e-41e4-ac22-0670d662fcfe
296,"In a neighborhood, the number of rabbits pets ...",```python\n# Step 1: Set the number of dogs\nd...,348,# Step 1: Set the number of dogs\ndogs = 60\n\...,,"If there are two cats for every dog, and the n...",1,2.315251,fb141fed-989d-4619-9fe2-7e4200350594,4c8be200-ab24-4f3e-9714-fccbcb1fb6cd
297,A mother goes shopping. She buys cocoa at $4.2...,```python\n# Step 1: Set the cost of cocoa\nco...,5.000000000000002,# Step 1: Set the cost of cocoa\ncocoa_cost = ...,,The total shopping cost is $4.2 + $9.45 + $1.3...,1,2.752946,ff002aa3-3abd-4e06-a87d-c8ee027543d7,5630f379-e04e-4c34-a3c6-8848d2cea0b1
298,Dylan needed chicken sausages and fish sausage...,```python\n# Step 1: Set the number of chicken...,82,# Step 1: Set the number of chicken sausages D...,,He bought 38 + 6 = <<38+6=44>>44 fish sausages...,1,2.297663,ff19d22f-f0e7-4487-89d1-7fe3bfc8e058,95673919-d889-4cb1-8672-a6587e53a281
