In [1]:
from openai import OpenAI
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
import textwrap
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from openai import OpenAI
from pydantic import BaseModel
import traceback
from typing import List, Union
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from langsmith import evaluate, Client
import langsmith as ls
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item


In [2]:
file_path = '../data/GSM8K/test.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f]

In [3]:
folder_path = "../dataset_langsmith/"
filename="tatqa_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
folder_path = "../dataset_langsmith/"
filename="tabmwp_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    tabmwp = json.load(f)

In [5]:
name= "tatqa" # tatqa, tabmwp
length_test= 300 # số lượng mẫu muốn test
if name == "gsm8k":
   DATA=gsm8k 
   name_model="GSM8K"
elif name == "tatqa":
   DATA=tatqa
   name_model="TATQA"
else:
   DATA=tabmwp
   name_model="TABMWP"

In [6]:
load_dotenv()
model = init_chat_model("gpt-4.1-mini", model_provider="openai", temperature=0.2)

In [7]:
class State(TypedDict):
    question: str
    context: Optional[str]
    program: Optional[str]
    result: Optional[str]
    final_answer: Optional[str]
    error: Optional[str]
class IntermediateProgram(BaseModel):
    program: str


In [8]:
def pot_node(state: State) -> State:
    context_str = f"# Context:\n{state['context']}\n" if state.get("context") else ""
    pot_messages = [
        SystemMessage('''You will write python program to solve math problems. You will only write code blocks.'''),
        HumanMessage(content=f'''
# Answer this question by implementing a solver() function.
# Write a Python program, and then return the answer.
Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
```
def solver():
    """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"""
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    ans = money_left
    return ans
result = solver()
```
Context: |  |  | Three Months Ended |  | % Variation |  |\n| --- | --- | --- | --- | --- | --- |\n|  | December 31, 2019 | September 29, 2019 | December 31, 2018 | Sequential | Year-Over-Year |\n|  |  |  | (Unaudited, in millions) |  |  |\n| Selling, general and administrative expenses | $(285) | $(267) | $(285) | (6.3)% | 0.4% |\n| Research and development expenses | (387) | (362) | (345) | (7.0) | (12.3) |\n| Total operating expenses | $(672) | $(629) | $(630) | (6.7)% | (6.6)% |\n| As percentage of net revenues | (24.4)% | (24.7)% | (23.8)% | +30 bps | -60 bps |\n
Question: What is the increase/ (decrease) in total operating expenses from the period December 31, 2018 to 2019?
```
def solver():
    """What is the increase/ (decrease) in total operating expenses from the period December 31, 2018 to 2019?"""
    expenses_2019 = 672
    expenses_2018 = 630
    # Tính mức tăng
    increase = expenses_2019 - expenses_2018
    # Nếu kết quả là số thập phân .0 thì chuyển thành số nguyên
    if isinstance(increase, float) and increase.is_integer():
        increase = int(increase)
    return increase           
```
Context:
| Number of times | Frequency |
|-----------------|-----------|
| 0               | 1         |
| 1               | 18        |
| 2               | 12        |
| 3               | 13        |
| 4               | 0         |
Question: How many customers are there in all?
```
def solver():
    """How many customers are there in all?"""
    frequencies = [1, 18, 12, 13, 0]
    ans = sum(frequencies)
    if isinstance(ans, float) and ans.is_integer():
        ans = int(ans)
    return ans
```
How about this question?
{context_str}
# Question: {state["question"]}
# Include a final answer as a single number, no units or symbols.
# 'CALL' the solver() function and then 'MUST' assign the variale 'result'.
# Before returning the final result, DOUBLE-CHECK each variable assignment and calculation to ensure they match the problem statement.
''')]

    model_pot = model.with_structured_output(IntermediateProgram)
    model_invoke=model_pot.invoke(pot_messages)
    code=model_invoke.program
    return {**state, "program": code}

def exec_node(state: State) -> State:
    try:
        exec_globals = {}
        exec(state["program"], {}, exec_globals)
        result = exec_globals.get("result", None)
        if result is None:
            raise ValueError("Missing `result`")
        return {**state, "result": str(result), "error": None}
    except Exception as e:
        return {**state, "result": None, "error": str(e)}

def check_eos(state: State) -> bool:
    if state["error"] is None:
        return True
    else:
        return False
def write_final_answer_node(state:State)->State:

    if state["error"] is None:
        result=str(state["result"])
    else:
        result=str(9999)
    return {**state,"final_answer":result}


builder = StateGraph(State)
builder.add_node("PoT", pot_node)
builder.add_node("Exec", exec_node)

builder.add_node("write_final_answer",write_final_answer_node)

builder.set_entry_point("PoT")
builder.add_edge("PoT", "Exec")
builder.add_edge("Exec", "write_final_answer")
builder.add_edge("write_final_answer",END)
graph = builder.compile()


In [9]:
def extract_ground_truth(answer, dataset_type):
    if dataset_type == "gsm8k":
        match = re.search(r"####\s*([\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()
    elif dataset_type == "tatqa":
        # Có thể là list hoặc chuỗi, lấy phần đầu tiên nếu là list
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        # Nếu là dạng [2019] hoặc ["2019"], loại bỏ ngoặc vuông và dấu nháy
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Nếu là phân số thì giữ lại dấu "/"
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        return ans
    else:
        return str(answer).strip()
    
def compare_answers(predicted: str, actual: str, eps: float = 1e-3) -> bool:
    try:
        pred = round(float(predicted.strip()))
        act = round(float(actual.strip()))
        return abs(pred - act) < eps
    except ValueError:
        return predicted.strip().lower() == actual.strip().lower()



In [10]:

def process_item(item, dataset_type):
    question = item["question"]
    context = item.get("context", "")
    true_answer = extract_ground_truth(item["answer"], dataset_type)
    try:
        # Nếu pipeline/graph của bạn cần context thì truyền vào, nếu không thì chỉ truyền question
        if context:
            result = graph.invoke({"question": question, "context": context})
        else:
            result = graph.invoke({"question": question})
        return {
            "question": question,
            "program": result.get("program", ""),
            "true_answer": true_answer,
            "predicted_answer": result["final_answer"],
            "context": context,
            "correct": compare_answers(result["final_answer"], true_answer)
        }
    except Exception as e:
        return {"error": str(e), "question": question}


dataset = []
for item in DATA:
    dataset.extend(standardize_item(item, f"{name}"))

results = []
correct = 0
total = len(dataset[:length_test])
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(process_item, item, f"{name}") for item in dataset[:length_test]]
    for future in tqdm(as_completed(futures), total=total):
        result = future.result()
        if "error" not in result:
            results.append(result)
            if result["correct"]:
                correct += 1
        else:
            print(f"Error on question: {result['question'][:60]}... => {result['error']}")
accuracy = correct / total * 100
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")


  0%|          | 0/300 [00:00<?, ?it/s]Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=c73d8c5e-2242-4d58-9c9d-21f3e5009f9e,id=c73d8c5e-2242-4d58-9c9d-21f3e5009f9e; trace=c73d8c5e-2242-4d58-9c9d-21f3e5009f9e,id=650d8301-dfa9-45a1-a382-f74a2e726509
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=6f858e58-6e13-4f3c-a42a-5247dbf7d8ff,id=6f858e58

Accuracy: 92.67% (278/300)





Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=48780671-70e0-41eb-bf09-6753603736fe; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=acf8b518-f396-4e8a-ace2-19742fa3774a; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=acf8b518-f396-4e8a-ace2-19742fa3774a; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=f8c97b84-3cfa-40ee-8e61-1303504c4ae9; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=9e05dafc-3426-4a99-b2f6-52870952950f; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=18a88b44-ceff-428f-a9fa-0b49d3fc6864; trace=97c38f31-2ffb-459a-a1df-af08919ee98f,id=18a88b44-ceff-428f-a9fa-0b49d3fc6864; trace=97c38f31-2ffb-459a-a1df-af08919ee98f

In [11]:
output_path = "PaL_results.json"
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
def custom_encoder(obj):
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    elif hasattr(obj, "dict"):
        return obj.dict()
    else:
        raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

wrong_answers = [r for r in results if not r.get("correct", False)]

# Ghi file JSON
try:
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(wrong_answers, f, ensure_ascii=False, indent=2, default=custom_encoder)
    print(f"Đã lưu {len(wrong_answers)} kết quả sai vào {output_path}")
except TypeError as e:
    print(f"Lỗi khi ghi file JSON: {e}")


Đã lưu 22 kết quả sai vào PaL_results.json


In [12]:
import math
def run_graph(inputs: dict):
    # Chuẩn bị state đầu vào cho graph
    state = {
        "question": inputs["question"],
        "context": inputs.get("context", ""),
        "error": None,
    }
    # Chạy graph
    final_state = graph.invoke(state)
    # Lấy kết quả cuối cùng
    result = final_state
    # Nếu bạn muốn log token, có thể lấy từ state (nếu đã lưu ở node PoT)
    token_info = {
        "input_tokens": final_state.get("input_tokens", None),
        "output_tokens": final_state.get("output_tokens", None),
        "total_tokens": final_state.get("total_tokens", None),
    }
    return token_info, result
@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    # Lấy đúng trường kết quả
    reference_response = extract_ground_truth(reference_outputs["answer"],f"{name}")
    run_response = outputs.get("final_answer")
    # Đảm bảo là string
    reference_response = str(reference_response).strip()
    run_response = str(run_response).strip()
    # So sánh số
    try:
        score = math.isclose(float(reference_response), float(run_response), rel_tol=1e-3)
    except Exception:
        score = (reference_response == run_response)
    return {"key": "is_correct", "score": int(score)}

@traceable(run_type="chain")
def target_function(inputs: dict):
    token, result = run_graph(inputs)
    # Nếu dùng LangSmith, log token vào run tree
    rt = ls.get_current_run_tree()
    if token["input_tokens"] is not None:
        rt.metadata["input_tokens"] = token["input_tokens"]
        rt.metadata["output_tokens"] = token["output_tokens"]
        rt.metadata["total_tokens"] = token["total_tokens"]
    return result


client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=f"{name_model}", splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"{name_model} - Test Dataset"
)


Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=47b11616-3345-42c8-af27-dff86649988e,id=53d48178-7eb8-4f39-a602-c908d81f19b1; trace=47b11616-3345-42c8-af27-dff86649988e,id=acb8490e-8c96-493a-824f-7ed00a749d59; trace=47b11616-3345-42c8-af27-dff86649988e,id=acb8490e-8c96-493a-824f-7ed00a749d59; trace=47b11616-3345-42c8-af27-dff86649988e,id=d0c17b79-5a9b-4035-8af5-887b1a6a4702; trace=47b11616-3345-42c8-af27-dff86649988e,id=d4fa5c14-7fcf-4db6-8057-0a69d9a3dcc5; trace=47b11616-3345-42c8-af27-dff86649988e,id=8b3b00da-255a-4ffd-bc8c-7f2d9c4667d2; trace=47b11616-3345-42c8-af27-dff86649988e,id=8b3b00da-255a-4ffd-bc8c-7f2d9c4667d2; trace=47b11616-3345-42c8-af27-dff86649988e

View the evaluation results for experiment: 'TATQA - Test Dataset-e604794d' at:
https://smith.langchain.com/o/e7b5e917-6c40-46ad-a54b-83be84870fd4/datasets/04f51027-12f7-4ffc-b682-25bf1d978fcb/compare?selectedSessions=1808a6bc-484b-4933-a47c-f84dfee2e86a




0it [00:00, ?it/s]

Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=279ca5b8-9e90-483a-b78b-9c00daf0e8e8,id=279ca5b8-9e90-483a-b78b-9c00daf0e8e8; trace=279ca5b8-9e90-483a-b78b-9c00daf0e8e8,id=483f65aa-5f1b-4337-b0e8-c221b75dcf48; trace=279ca5b8-9e90-483a-b78b-9c00daf0e8e8,id=6366e998-25a2-45f5-9d58-6fffba10cdd1; trace=279ca5b8-9e90-483a-b78b-9c00daf0e8e8,id=42db7aaa-5167-4ed8-b064-87f00880f81a; trace=279ca5b8-9e90-483a-b78b-9c00daf0e8e8,id=800446bb-f40a-41da-8b56-b0b8df2ce8a8
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url

Unnamed: 0,inputs.context,inputs.question,outputs.question,outputs.context,outputs.program,outputs.result,outputs.final_answer,outputs.error,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,12. Geographic Information\n\nThe following ta...,What is the change in total revenue between 20...,What is the change in total revenue between 20...,12. Geographic Information\n\nThe following ta...,"def solver():\n """"""What is the change in to...",-19595,-19595,,,-8.44,0,4.381304,0015f4da-34bb-46d4-9747-caa9464cf78b,279ca5b8-9e90-483a-b78b-9c00daf0e8e8
1,(2) Includes property and equipment acquired u...,What is the change in Computer equipment from ...,What is the change in Computer equipment from ...,(2) Includes property and equipment acquired u...,"def solver():\n """"""What is the change in Co...",41,41,,,41,1,4.043920,00391ef2-5bfb-4bca-82f9-dbfec7735043,dd7e6f3e-fd80-4527-b453-d85e0c575a03
2,Notes to Consolidated Financial Statements - (...,"What was the income tax expense (benefit), in ...","What was the income tax expense (benefit), in ...",Notes to Consolidated Financial Statements - (...,"def solver():\n """"""What was the income tax ...",0.5,0.5,,,0.5,1,2.997291,0155e365-c5d6-494b-b582-356362eb88e6,86eb32cd-74fa-4c4f-a39b-c8cde6c4ed84
3,15. Product Warranties\n\nThe Company generall...,What is the change in beginning balance betwee...,What is the change in beginning balance betwee...,15. Product Warranties\n\nThe Company generall...,"def solver():\n """"""What is the change in be...",-37,-37,,,-37,1,2.100618,025eefed-96d8-4332-9b3f-fbc873b4553b,1f5cdd7a-0357-454c-bc8f-2895b08bb268
4,The Company has adopted five share option sche...,How many percent of the total shares granted a...,How many percent of the total shares granted a...,The Company has adopted five share option sche...,"def solver():\n """"""Calculate the percentage...",18.679352687647395,18.679352687647395,,,18.68,1,5.496540,049c5bec-65d9-4160-945c-a826a0e63b5d,638da4bf-5b30-4dd2-b460-308273923977
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,8. Goodwill\n\nThe changes in the carrying amo...,What was the difference in balance in 2017 bet...,What was the difference in balance in 2017 bet...,8. Goodwill\n\nThe changes in the carrying amo...,"def solver():\n """"""What was the difference ...",6708,6708,,,6708,1,2.046631,fe6ec1f5-5491-491b-af1c-9e1ac303b16d,c5eddceb-7353-4a5e-b5be-da8738ce0b10
296,Cash used in investing activities\n\nDetail of...,What was the change in the cash used in Leaseh...,What was the change in the cash used in Leaseh...,Cash used in investing activities\n\nDetail of...,"def solver():\n """"""Calculate the change in ...",0.4,0.4,,,0.4,1,2.040758,ff17fbf5-d718-4a60-9ceb-1fd924f9f196,9f68bae2-856b-4033-ad68-bc323b016055
297,16. FINANCIAL INSTRUMENTS AND OTHER FAIR VALUE...,What is the average value of the 2018 and 2019...,What is the average value of the 2018 and 2019...,16. FINANCIAL INSTRUMENTS AND OTHER FAIR VALUE...,"def solver():\n """"""Calculate the average va...",2511,2511,,,2511,1,2.224238,ff4c9828-42dc-4493-bec4-b77e030b0d04,964996d9-4c49-4916-909e-1c566a9047c2
298,NAVIOS MARITIME HOLDINGS INC. NOTES TO THE CON...,What was the Accumulated Amortization of favor...,What was the Accumulated Amortization of favor...,NAVIOS MARITIME HOLDINGS INC. NOTES TO THE CON...,"def solver():\n """"""What was the Accumulated...",2143,2143,,,(2143),1,1.660338,ffc1d281-f636-481a-ace8-d349cdcf01ba,524b6a4f-0394-4878-bc3b-3aa0054f3605
