In [None]:
from openai import OpenAI
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
import pandas as pd
import langsmith as ls
import textwrap
from typing import TypedDict, Optional
from openai import OpenAI
from pydantic import BaseModel
import traceback
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langsmith import traceable, trace
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item
from langsmith import Client, traceable, evaluate

In [2]:
file_path = '../data/GSM8K/test.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f]

In [3]:
folder_path = "../dataset_langsmith/"
filename="tatqa_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
folder_path = "../dataset_langsmith/"
filename="tabmwp_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    tabmwp = json.load(f)

In [5]:
name= "tatqa" # tatqa, tabmwp
length_test= 300 # số lượng mẫu muốn test
if name == "gsm8k":
   DATA=gsm8k 
   name_model="GSM8K"
elif name == "tatqa":
   DATA=tatqa
   name_model="TATQA"
else:
   DATA=tabmwp
   name_model="TABMWP"

In [None]:
load_dotenv()
model = init_chat_model("gpt-4.1-mini", model_provider="openai", temperature=0.2)
client = Client()

In [7]:
class State(TypedDict):
    question: str
    context: Optional[str]
    program: Optional[str]
    result: Optional[str]
    final_answer: Optional[str]
    error: Optional[str]
class IntermediateProgram(BaseModel):
    program: str


In [8]:

def pot_node(state: State) -> State:
    context_str = f"# Context:\n{state['context']}\n" if state.get("context") else ""
    pot_messages = [
        SystemMessage("""
# Answer this question by implementing a solver() function.
# Write a Python program, and then return the answer.
                      
Question: Carlos is planning a lemon tree. The tree will cost $90 to plant. Each year it will grow 7 lemons, which he can sell for $1.5 each. It costs $3 a year to water and feed the tree. How many years will it take before he starts earning money on the lemon tree?
def solver():
    total_cost = 90
    cost_of_watering_and_feeding = 3
    cost_of_each_lemon = 1.5
    num_of_lemon_per_year = 7
    ans = 0
    while total_cost > 0:
        total_cost += cost_of_watering_and_feeding
        total_cost -= num_of_lemon_per_year * cost_of_each_lemon
        ans += 1
    return ans

                      
# Context:
|  |  | Three Months Ended |  | % Variation |  |\n| --- | --- | --- | --- | --- | --- |\n|  | December 31, 2019 | September 29, 2019 | December 31, 2018 | Sequential | Year-Over-Year |\n|  |  |  | (Unaudited, in millions) |  |  |\n| Selling, general and administrative expenses | $(285) | $(267) | $(285) | (6.3)% | 0.4% |\n| Research and development expenses | (387) | (362) | (345) | (7.0) | (12.3) |\n| Total operating expenses | $(672) | $(629) | $(630) | (6.7)% | (6.6)% |\n| As percentage of net revenues | (24.4)% | (24.7)% | (23.8)% | +30 bps | -60 bps |\n
Question: What is the increase/ (decrease) in total operating expenses from the period December 31, 2018 to 2019?
def solver():
    # Giá trị total operating expenses (triệu USD)
    expenses_2019 = 672
    expenses_2018 = 630
    # Tính mức tăng
    increase = expenses_2019 - expenses_2018
    # Nếu kết quả là số thập phân .0 thì chuyển thành số nguyên
    if isinstance(increase, float) and increase.is_integer():
        increase = int(increase)
    return increase

# Context:
| Number of times | Frequency |
|-----------------|-----------|
| 0               | 1         |
| 1               | 18        |
| 2               | 12        |
| 3               | 13        |
| 4               | 0         |
Question: How many customers are there in all?
def solver():
    frequencies = [1, 18, 12, 13, 0]
    ans = sum(frequencies)
    if isinstance(ans, float) and ans.is_integer():
        ans = int(ans)
    return ans
"""),
        HumanMessage(content=f"""
{context_str}
# Question: {state["question"]}
# Include a final answer as a single number, no units or symbols.
# 'CALL' the solver() function and then 'MUST' assign the variale 'result'.
# Before returning the final result, DOUBLE-CHECK each variable assignment and calculation to ensure they match the problem statement.
""")]

    model_pot = model.with_structured_output(IntermediateProgram)
    model_invoke=model_pot.invoke(pot_messages)
    code=model_invoke.program
    return {**state, "program": code}

def exec_node(state: State) -> State:
    try:
        exec_globals = {}
        exec(state["program"], {}, exec_globals)
        result = exec_globals.get("result", None)
        if result is None:
            raise ValueError("Missing `result`")
        return {**state, "result": str(result), "error": None}
    except Exception as e:
        return {**state, "result": None, "error": str(e)}

def check_eos(state: State) -> bool:
    if state["error"] is None:
        return True
    else:
        return False

def write_final_answer_node(state:State)->State:

    if state["error"] is None:
        result=str(state["result"])
    else:
        result=str(9999)
    return {**state,"final_answer":result}


builder = StateGraph(State)
builder.add_node("PoT", pot_node)
builder.add_node("Exec", exec_node)

builder.add_node("write_final_answer",write_final_answer_node)

builder.set_entry_point("PoT")
builder.add_edge("PoT", "Exec")
builder.add_edge("Exec", "write_final_answer")
builder.add_edge("write_final_answer",END)
graph = builder.compile()


In [9]:
import re

def extract_ground_truth(answer, dataset_type):
    if dataset_type == "gsm8k":
        match = re.search(r"####\s*([\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()
    elif dataset_type == "tatqa":
        # Có thể là list hoặc chuỗi, lấy phần đầu tiên nếu là list
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        # Nếu là dạng [2019] hoặc ["2019"], loại bỏ ngoặc vuông và dấu nháy
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Nếu là phân số thì giữ lại dấu "/"
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        return ans
    else:
        return str(answer).strip()

    
def compare_answers(predicted: str, actual: str, eps: float = 1e-3) -> bool:
    try:
        pred = round(float(predicted.strip()))
        act = round(float(actual.strip()))
        return abs(pred - act) < eps
    except ValueError:
        return predicted.strip().lower() == actual.strip().lower()



In [10]:

def process_item(item, dataset_type):
    question = item["question"]
    context = item.get("context", "")
    true_answer = extract_ground_truth(item["answer"], dataset_type)
    try:
        # Nếu pipeline/graph của bạn cần context thì truyền vào, nếu không thì chỉ truyền question
        if context:
            result = graph.invoke({"question": question, "context": context})
        else:
            result = graph.invoke({"question": question})
        return {
            "question": question,
            "program": result.get("program", ""),
            "true_answer": true_answer,
            "predicted_answer": result["final_answer"],
            "context": context,
            "correct": compare_answers(result["final_answer"], true_answer)
        }
    except Exception as e:
        return {"error": str(e), "question": question}


dataset = []
for item in DATA:
    dataset.extend(standardize_item(item, f"{name}"))

results = []
correct = 0
total = len(dataset[:length_test])
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(process_item, item, f"{name}") for item in dataset[:length_test]]
    for future in tqdm(as_completed(futures), total=total):
        result = future.result()
        if "error" not in result:
            results.append(result)
            if result["correct"]:
                correct += 1
        else:
            print(f"Error on question: {result['question'][:60]}... => {result['error']}")
accuracy = correct / total * 100
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")


  0%|          | 0/300 [00:00<?, ?it/s]Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=4781d9c9-e5c2-413b-946c-85c0b9632596,id=4781d9c9-e5c2-413b-946c-85c0b9632596; trace=4781d9c9-e5c2-413b-946c-85c0b9632596,id=fcc2b717-02b6-440b-90f3-25d69ef07a3c; trace=4781d9c9-e5c2-413b-946c-85c0b9632596,id=4179c46a-1211-4a97-9a29-edbc59696c8a; trace=4781d9c9-e5c2-413b-946c-85c0b9632596,id=c312255a-7f33-4ec9-86cf-607a5802afe2
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multip

Accuracy: 92.33% (277/300)





In [11]:
output_path = "PoT_results.json"
wrong_answers = [r for r in results if not r["correct"]]
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(wrong_answers, f, ensure_ascii=False, indent=2)
print(f"Đã lưu kết quả vào {output_path}")


Đã lưu kết quả vào PoT_results.json


In [None]:
import math
def run_graph(inputs: dict):
    # Chuẩn bị state đầu vào cho graph
    state = {
        "question": inputs["question"],
        "context": inputs.get("context", ""),
        "error": None,
    }
    # Chạy graph
    final_state = graph.invoke(state)
    # Lấy kết quả cuối cùng
    result = final_state
    # Nếu bạn muốn log token, có thể lấy từ state (nếu đã lưu ở node PoT)
    token_info = {
        "input_tokens": final_state.get("input_tokens", None),
        "output_tokens": final_state.get("output_tokens", None),
        "total_tokens": final_state.get("total_tokens", None),
    }
    return token_info, result
@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    # Lấy đúng trường kết quả
    reference_response = extract_ground_truth(reference_outputs["answer"],f"{name}")
    run_response = outputs.get("final_answer")
    # Đảm bảo là string
    reference_response = str(reference_response).strip()
    run_response = str(run_response).strip()
    # So sánh số
    try:
        score = math.isclose(float(reference_response), float(run_response), rel_tol=1e-3)
    except Exception:
        score = (reference_response == run_response)
    return {"key": "is_correct", "score": int(score)}

@traceable(run_type="chain")
def target_function(inputs: dict):
    token, result = run_graph(inputs)
    # Nếu dùng LangSmith, log token vào run tree
    rt = ls.get_current_run_tree()
    if token["input_tokens"] is not None:
        rt.metadata["input_tokens"] = token["input_tokens"]
        rt.metadata["output_tokens"] = token["output_tokens"]
        rt.metadata["total_tokens"] = token["total_tokens"]
    return result

evaluate(
    target_function,
    data=client.list_examples(dataset_name=f"{name_model}", splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"{name_model} - Test Dataset"
)


Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=3d758d13-08f6-4b66-8223-11da43195db4,id=dd92fe14-0f91-4484-a535-c7f8ad76cd25; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=7e90e388-5f91-49a5-845b-68fe2c2934b6; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=7e90e388-5f91-49a5-845b-68fe2c2934b6; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=0315f1dc-4d7d-4337-8bf8-07616219d973; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=f865c118-a364-49f5-9bfd-10db74ec7c52; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=e897ed5b-d6fe-40e0-92de-6da7f1fdb0bc; trace=3d758d13-08f6-4b66-8223-11da43195db4,id=e897ed5b-d6fe-40e0-92de-6da7f1fdb0bc; trace=3d758d13-08f6-4b66-8223-11da43195db4

View the evaluation results for experiment: 'TATQA - Test Dataset-ad34426f' at:
https://smith.langchain.com/o/e7b5e917-6c40-46ad-a54b-83be84870fd4/datasets/04f51027-12f7-4ffc-b682-25bf1d978fcb/compare?selectedSessions=7340e3c6-8803-4075-99fb-ec9c696f7783




0it [00:00, ?it/s]

Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=d04ac9c8-ec41-4ef5-9769-032fd3871b34,id=d04ac9c8-ec41-4ef5-9769-032fd3871b34; trace=d04ac9c8-ec41-4ef5-9769-032fd3871b34,id=474524d4-79e6-49d9-8229-73cb33c3995e; trace=d04ac9c8-ec41-4ef5-9769-032fd3871b34,id=aadcff25-6175-499d-9096-a0a0a698790f; trace=d04ac9c8-ec41-4ef5-9769-032fd3871b34,id=36d81bdc-c34e-4347-b621-e477bd6c3847; trace=d04ac9c8-ec41-4ef5-9769-032fd3871b34,id=eba54abe-757c-4959-8b14-1cff9f47049b
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url

Unnamed: 0,inputs.context,inputs.question,outputs.question,outputs.context,outputs.program,outputs.result,outputs.final_answer,outputs.error,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,12. Geographic Information\n\nThe following ta...,What is the change in total revenue between 20...,What is the change in total revenue between 20...,12. Geographic Information\n\nThe following ta...,def solver():\n revenue_2019 = 212628\n ...,-19595,-19595,,,-8.44,0,1.712723,0015f4da-34bb-46d4-9747-caa9464cf78b,d04ac9c8-ec41-4ef5-9769-032fd3871b34
1,(2) Includes property and equipment acquired u...,What is the change in Computer equipment from ...,What is the change in Computer equipment from ...,(2) Includes property and equipment acquired u...,def solver():\n # Computer equipment values...,41,41,,,41,1,2.612138,00391ef2-5bfb-4bca-82f9-dbfec7735043,37a5dde2-fe25-46f8-805f-399a7e30d406
2,Notes to Consolidated Financial Statements - (...,"What was the income tax expense (benefit), in ...","What was the income tax expense (benefit), in ...",Notes to Consolidated Financial Statements - (...,def solver():\n # Income tax expense (benef...,0.5,0.5,,,0.5,1,1.617770,0155e365-c5d6-494b-b582-356362eb88e6,36c2d166-d6ba-4a06-8c74-7bd581a94204
3,15. Product Warranties\n\nThe Company generall...,What is the change in beginning balance betwee...,What is the change in beginning balance betwee...,15. Product Warranties\n\nThe Company generall...,def solver():\n beginning_balance_2019 = 52...,-37,-37,,,-37,1,1.357607,025eefed-96d8-4332-9b3f-fbc873b4553b,1465dc1c-42b4-4199-ba1b-84a052504663
4,The Company has adopted five share option sche...,How many percent of the total shares granted a...,How many percent of the total shares granted a...,The Company has adopted five share option sche...,def solver():\n # Number of shares granted ...,19,19,,,18.68,0,3.375867,049c5bec-65d9-4160-945c-a826a0e63b5d,e6c717d9-f3c6-488f-89e1-a1bbbaa34b8b
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,8. Goodwill\n\nThe changes in the carrying amo...,What was the difference in balance in 2017 bet...,What was the difference in balance in 2017 bet...,8. Goodwill\n\nThe changes in the carrying amo...,"def solver():\n # Balance at December 31, 2...",6708,6708,,,6708,1,1.792027,fe6ec1f5-5491-491b-af1c-9e1ac303b16d,9e512a97-f61f-4d29-9208-1f2b908504fc
296,Cash used in investing activities\n\nDetail of...,What was the change in the cash used in Leaseh...,What was the change in the cash used in Leaseh...,Cash used in investing activities\n\nDetail of...,def solver():\n leasehold_2017 = 0.5\n l...,0.4,0.4,,,0.4,1,2.181957,ff17fbf5-d718-4a60-9ceb-1fd924f9f196,bf089d5b-9035-450a-96a3-7fac2763529c
297,16. FINANCIAL INSTRUMENTS AND OTHER FAIR VALUE...,What is the average value of the 2018 and 2019...,What is the average value of the 2018 and 2019...,16. FINANCIAL INSTRUMENTS AND OTHER FAIR VALUE...,def solver():\n # Given fair values of inve...,2511.0,2511.0,,,2511,1,2.695663,ff4c9828-42dc-4493-bec4-b77e030b0d04,dbb40321-2d61-44f6-a18d-7b9ba5a55ccf
298,NAVIOS MARITIME HOLDINGS INC. NOTES TO THE CON...,What was the Accumulated Amortization of favor...,What was the Accumulated Amortization of favor...,NAVIOS MARITIME HOLDINGS INC. NOTES TO THE CON...,"def solver():\n # From the table, accumulat...",2143,2143,,,(2143),1,1.561620,ffc1d281-f636-481a-ace8-d349cdcf01ba,f5d28cd7-52ac-4d04-934b-5663b5429a3a
