In [13]:
from openai import OpenAI
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
import textwrap
import traceback
from typing import List, Union
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langsmith import traceable
from langsmith import Client, traceable, evaluate
from langsmith import evaluate, Client
import langsmith as ls
from preprocess_data import prepare_qa_input_with_answer_filter
import datetime
from datetime import datetime
from prompts.few_shot_PoT import few_shot_tabmwp,few_shot_tatqa,few_shot_gsm8k

In [14]:
file_path = '../dataset_langsmith/gsm8k.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f]

In [15]:
folder_path = "../dataset_langsmith/"
filename="tatqa.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [16]:
folder_path = "../dataset_langsmith/"
filename="tabmwp.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    tabmwp = json.load(f)

In [17]:
name= "tabmwp" # tatqa, tabmwp 
length_test= 1 # số lượng mẫu muốn test
if name == "gsm8k":
   DATA=gsm8k 
   name_dataset="GSM8K"
elif name == "tatqa":
   DATA=tatqa
   name_dataset="TATQA"
else:
   DATA=tabmwp
   name_dataset="TABMWP"

In [18]:
load_dotenv()
model = init_chat_model("gpt-4o-mini", model_provider="openai", temperature=0.2)
client = Client()

In [19]:
class State(TypedDict):
    question: str
    response: str
    context: Optional[str]
    program: Optional[str]
    result: Optional[str]
    final_answer: Optional[str]
    error: Optional[str]



In [20]:
if name=="gsm8k":
    select_fewshot=few_shot_gsm8k
elif name=="tatqa":
    select_fewshot=few_shot_tatqa
else:
    select_fewshot=few_shot_tabmwp

In [21]:

def extract_code_from_markdown(text):
    # Tìm tất cả các đoạn code giữa ```python và ```
    code_blocks = re.findall(r"```python\n(.*?)```", text, re.DOTALL)
    # Gộp các đoạn code lại, cách nhau bởi 2 dòng trống
    return "\n\n".join(block.strip() for block in code_blocks)


In [None]:
@traceable(run_type="prompt")
def pot_node(state: State) -> State:
    context_str = f"# Context:\n{state['context']}\n" if state.get("context") else ""
    pot_messages = [
        SystemMessage(f"You will write python program to solve math problems."),
        HumanMessage(content=f"""
{few_shot_gsm8k}
# Include a final answer as a single number, no units or symbols.
# The final answer 'MUST' be assigned the variable 'result'.
# If the question includes time points, pay attention to time formats.
# Before returning the final result, DOUBLE-CHECK each variable assignment and calculation to ensure they match the problem statement.
{context_str}
# Question: {state["question"]}
""")]
    model_invoke=model.invoke(pot_messages)
    code = extract_code_from_markdown(model_invoke.content)
    return {**state, "program": code, "response": model_invoke.content}

def exec_node(state: State) -> State:
    try:
        exec_globals = {}
        exec(state["program"], exec_globals)
        result = exec_globals.get("result", None)
        if result is None:
            raise ValueError("Missing `result`")
        return {**state, "result": str(result), "error": None}
    except Exception as e:
        return {**state, "result": None, "error": str(e)}

def check_eos(state: State) -> bool:
    if state["error"] is None:
        return True
    else:
        return False

def write_final_answer_node(state:State)->State:

    if state["error"] is None:
        result=str(state["result"])
    else:
        result=str(9999)
    return {**state,"final_answer":result}


builder = StateGraph(State)
builder.add_node("PoT", pot_node)
builder.add_node("Exec", exec_node)

builder.add_node("write_final_answer",write_final_answer_node)

builder.set_entry_point("PoT")
builder.add_edge("PoT", "Exec")
builder.add_edge("Exec", "write_final_answer")
builder.add_edge("write_final_answer",END)
graph = builder.compile()


In [23]:
def extract_ground_truth(answer, dataset_type):
    if dataset_type == "gsm8k":
        match = re.search(r"####\s*([\d,./]+)", answer)
        if match:
            raw_ans = match.group(1).replace(",", "").strip()
        else:
            raw_ans = answer.strip()
    elif dataset_type == "tatqa":
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        raw_ans = ans
    else:
        raw_ans = str(answer).strip()
    return raw_ans


def compare_answers(predicted: str, actual: str, eps: float = 1e-2) -> bool:
    try:
        if '/' in predicted:
            try:
                numerator, denominator = predicted.split('/')
                predicted = str(float(numerator) / float(denominator))
            except Exception:
                predicted=str(predicted)
        if '/' in actual:
            try:
                numerator, denominator = actual.split('/')
                actual = str(float(numerator) / float(denominator))
            except Exception:
                actual=str(actual)
        pred = float(predicted.strip())
        act = float(actual.strip())
        return abs(pred - act) <= eps
    except ValueError:
        return predicted.strip().lower() == actual.strip().lower()

def unwrap_singleton(value):
    # Nếu là list hoặc tuple Python
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    # Nếu là chuỗi dạng '[2018]' hoặc "['2018']"
    if isinstance(value, str):
        import re
        match = re.fullmatch(r"\[\s*'?([-\w\.]+)'?\s*\]", value.strip())
        if match:
            return match.group(1)
    return value

In [24]:
def run_graph(inputs: dict):
    # Chuẩn bị state đầu vào cho graph
    state = {
        "question": inputs["question"],
        "context": inputs.get("context", ""),
        "error": None,
    }
    # Chạy graph
    final_state = graph.invoke(state)

    result = {
        "response": final_state.get("response", ""),
        "final_answer": final_state.get("final_answer", ""),
        "program": final_state.get("program", ""),
        "response": final_state.get("response", "")
    }
    return  result

all_results = [] 
@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    predicted=str(unwrap_singleton(outputs["final_answer"]))
    actual = extract_ground_truth(str(reference_outputs["answer"]), f"{name}")
    eps = 1e-2
    try:
        if '/' in predicted:
            try:
                numerator, denominator = predicted.split('/')
                predicted = str(float(numerator) / float(denominator))
            except Exception:
                predicted = str(predicted)
        if '/' in actual:
            try:
                numerator, denominator = actual.split('/')
                actual = str(float(numerator) / float(denominator))
            except Exception:
                actual = str(actual)
        pred = float(predicted.strip())
        act = float(actual.strip())
        score = abs(pred - act) <= eps
    except ValueError:
        score = predicted.strip().lower() == actual.strip().lower()
    
    all_results.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "question": inputs["question"],
        "program": outputs["program"],
        "response": outputs["response"],
        "true_answer": actual,
        "predicted_answer": predicted,
        "context": inputs.get("context", ""),
        "correct": score
    })

    # Sau khi chạy xong:
    correct = sum(1 for x in all_results if x["correct"])
    total = len(all_results)
    accuracy = correct / total * 100
    wrong_answers = [x for x in all_results if not x["correct"]]

    summary = {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "wrong_answers": wrong_answers
    }

    with open(f"save_log/PoT_results - {name}.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return {"key": "is_correct", "score": int(score)}
    

@traceable(run_type="chain")
def target_function(inputs: dict):
    result = run_graph(inputs)
    return result


client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=f"{name_dataset}", splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"PoT_{name_dataset}"
)


View the evaluation results for experiment: 'PoT_TABMWP-f7fae1be' at:
https://smith.langchain.com/o/eb55acc7-3259-44e5-99ee-f98b0aec16fb/datasets/93216d09-6d69-41e9-af74-8f2bd79d8ed5/compare?selectedSessions=3620218e-1182-44ff-bd80-976166c740cd




0it [00:00, ?it/s]

Unnamed: 0,inputs.context,inputs.question,outputs.response,outputs.final_answer,outputs.program,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,| Column 1 | Column 2 |\n| --- | --- |\n| poet...,How much money does Tony need to buy a textboo...,```python\n# Step 1: Set the prices of the boo...,13.59,# Step 1: Set the prices of the books\ntextboo...,,13.59,1,2.086832,001f8eb9-90de-4d6d-ab72-423d4f037c9a,b8b337f8-53ec-4c87-97ab-eb1594c2f023
1,Basketball hoops\n\n| Park | Number of basketb...,The parks department compared how many basketb...,```python\n# Step 1: List the number of basket...,3,# Step 1: List the number of basketball hoops ...,,3,1,2.473560,007cfb6b-7c76-4777-8381-2499edafd0b8,a44660a0-41be-42eb-9157-f9f2ab6d31db
2,| Employee | Pay period | |\n| --- | --- | --...,Look at Max's pay stub. Max lives in a state t...,```python\n# Step 1: Set the values for Max's ...,1283.3899999999999,# Step 1: Set the values for Max's taxes\nfede...,,1283.39,1,2.329159,01895585-df96-40e2-b79e-208050490b31,deb0743e-d461-42c1-bad9-d975f3f61e61
3,Cans of food collected\n\n| Name | Number of c...,"As part of a food drive, four friends collecte...",```python\n# Step 1: Set the number of cans co...,62/277,# Step 1: Set the number of cans collected by ...,,62/277,1,4.084544,02220f4e-baee-4d6c-b408-ea0083a61264,aa459d7d-79e5-4040-8c3a-a253e7ae5119
4,Making key chains\n\n| Key chains made | Frequ...,The parents running this year's craft sale cou...,```python\n# Step 1: Define the frequency of k...,31,# Step 1: Define the frequency of key chains m...,,31,1,3.214672,025776aa-16dd-4094-9356-a6bc74cf7174,fc432622-f5e7-40a6-96ad-93401d7f3498
...,...,...,...,...,...,...,...,...,...,...,...
295,Fire hydrants\n\n| Street | Number of hydrants...,The city recorded how many fire hydrants there...,```python\n# Step 1: List the number of fire h...,7,# Step 1: List the number of fire hydrants on ...,,7,1,3.152766,f8a864c3-bd7c-4734-9e75-4ca6faffec80,ff631dc2-5741-4125-80cb-7ea82bc2e846
296,Scores in a card game\n\n| Score | Frequency |...,Molly figured out the scores at the end of a c...,```python\n# Step 1: Define the scores and the...,53,# Step 1: Define the scores and their frequenc...,,53,1,3.256365,f8e819eb-6fba-4a99-b998-b364edd000a2,7b0bfd2b-a5eb-4b1a-8763-0535ac0a5c8f
297,| Column 1 | Column 2 |\n| --- | --- |\n| stea...,How much money does Erica need to buy a cheese...,```python\n# Step 1: Set the prices for cheese...,9,# Step 1: Set the prices for cheese pizza and ...,,9,1,2.212567,f936184b-fad0-45fc-871a-9154b1cfa7bb,ef723618-20a8-4070-a1e2-bbabc75a0090
298,Computers in classrooms\n\n| Teacher | Number ...,The teachers at a middle school counted how ma...,```python\n# Step 1: List all the number of co...,5,# Step 1: List all the number of computers for...,,5,1,2.925965,f967f9aa-94ea-480b-a3b0-11c1a01cd30d,79310d48-d10d-4a23-a806-b1355a67315e
