In [16]:
from openai import OpenAI
from dateutil.relativedelta import relativedelta
import os
import json
import pandas as pd
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI, init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
import langsmith as ls
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item

In [2]:
file_path = '../data/GSM8K/test.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f]

In [3]:
folder_path = "../dataset_langsmith/"
filename="tatqa_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
folder_path = "../dataset_langsmith/"
filename="tabmwp_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, "r", encoding="utf-8") as f:
    tabmwp = json.load(f)

In [13]:
name= "tabmwp" # tatqa, tabmwp
length_test= 200 # số lượng mẫu muốn test
if name == "gsm8k":
   DATA=gsm8k 
   name_model="GSM8K"
elif name == "tatqa":
   DATA=tatqa
   name_model="TATQA"
else:
   DATA=tabmwp
   name_model="TABMWP"

In [6]:
load_dotenv()
model=init_chat_model('gpt-4.1-mini',model_provider='openai',temperature=0.2)

In [7]:
class MathReasoning(BaseModel):
    explanation: str
    final_answer: str

In [8]:
model_with_tools = model.with_structured_output(MathReasoning)

In [9]:
def extract_ground_truth(answer, dataset_type):
    if dataset_type == "gsm8k":
        match = re.search(r"####\s*([\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()
    elif dataset_type == "tatqa":
        # Có thể là list hoặc chuỗi, lấy phần đầu tiên nếu là list
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        # Nếu là dạng [2019] hoặc ["2019"], loại bỏ ngoặc vuông và dấu nháy
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Nếu là phân số thì giữ lại dấu "/"
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        return ans
    else:
        return str(answer).strip()
    
def compare_answers(predicted: str, actual: str, eps: float = 1e-3) -> bool:
    try:
        pred = round(float(predicted.strip()))
        act = round(float(actual.strip()))
        return abs(pred - act) < eps
    except ValueError:
        return predicted.strip().lower() == actual.strip().lower()

In [22]:
def process_item(item,dataset_type):
    question = item["question"]
    context = item.get("context", "")
    true_answer = extract_ground_truth(item["answer"],dataset_type)
    try:
        # Nếu có context, nối vào trước question
        if context.strip():
            user_content = f"# Context:\n{context}\n\n# Question: {question}"
        else:
            user_content = question

        messages = [
            SystemMessage(content="""
            You are a math expert.
            For every question and context, you **must** respond using the `MathReasoning` tool.
            - Do not respond with plain text or natural language.
            - Include a `final_answer` as a single number, no units or symbols.
            - If you cannot solve it, return a final_answer of "unknown".
            - When dealing with money, do not round to thousands unless explicitly stated.
            """),
            HumanMessage(content=user_content)
        ]
        ai_msg = model_with_tools.invoke(messages)
        predicted_answer = ai_msg.final_answer

        return {
            "question": question,
            "context": context,
            "true_answer": true_answer,
            "predicted_answer": predicted_answer,
            "correct": compare_answers(predicted_answer, true_answer)
        }
    except Exception as e:
        return {"error": str(e), "question": question}

dataset = []
for item in DATA:
    dataset.extend(standardize_item(item, f"{name}"))

results = []
correct = 0
total = len(dataset[:length_test])
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(process_item, item, f"{name}") for item in dataset[:length_test]]
    for future in tqdm(as_completed(futures), total=total):
        result = future.result()
        if "error" not in result:
            results.append(result)
            if result["correct"]:
                correct += 1
        else:
            print(f"Error on question: {result['question'][:60]}... => {result['error']}")
accuracy = correct / total * 100
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")

  0%|          | 0/300 [00:00<?, ?it/s]Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=8f73d4be-a23d-442d-b82e-a3f7acedd869,id=8f73d4be-a23d-442d-b82e-a3f7acedd869; trace=8f73d4be-a23d-442d-b82e-a3f7acedd869,id=23a3ed6c-7cf2-465d-b30d-4dd58087bdd7; trace=2db1b249-74ad-4866-afb1-644bfeef00de,id=2db1b249-74ad-4866-afb1-644bfeef00de; trace=2db1b249-74ad-4866-afb1-644bfeef00de,id=723f7dfb-2e59-4ec5-b767-2e079df09256; trace=82ea7807-7475-4ea2-bf2f-df9408fb5865,id=82ea7807-7475-4ea2-bf2f-df9408fb5865; trace=43553a8f-cd18-4909-a3bd-0c5d97342564,id=43553a8f-cd18-4909-a3bd-0c5d97342564; trace=82ea7807-7475-4ea2-bf2f-df9408fb5865,id=6093ae53-04e8-4983-bf2b-faae8e38090c; tra

Accuracy: 97.00% (291/300)





In [23]:
output_path = "Zero_shot_results.json"
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
def custom_encoder(obj):
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    elif hasattr(obj, "dict"):
        return obj.dict()
    else:
        raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

wrong_answers = [r for r in results if not r.get("correct", False)]

try:
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(wrong_answers, f, ensure_ascii=False, indent=2, default=custom_encoder)
    print(f"Đã lưu {len(wrong_answers)} kết quả sai vào {output_path}")
except TypeError as e:
    print(f"Lỗi khi ghi file JSON: {e}")


Đã lưu 9 kết quả sai vào Zero_shot_results.json


In [14]:
import math
@traceable(run_type="chain")
def target_function(inputs: dict):
    question = inputs["question"]
    context = inputs.get("context", "")
    # Nếu có context, nối vào trước question
    if context.strip():
        user_content = f"# Context:\n{context}\n\n# Question: {question}"
    else:
        user_content = question

    messages = [
        SystemMessage(content="""
        You are a math expert.
        For every question, you **must** respond using the `MathReasoning` tool.
        - Do not respond with plain text or natural language.
        - Include a `final_answer` as a single number, no units or symbols.
        - If you cannot solve it, return a final_answer of "unknown".
        - When dealing with money, do not round to thousands unless explicitly stated.
        """),
        HumanMessage(content=user_content)
    ]
    ai_msg = model_with_tools.invoke(messages)
    predicted_answer = ai_msg.final_answer
    # Nếu muốn log reasoning steps
    return {
        "final_answer": predicted_answer,
        "steps": getattr(ai_msg, "steps", None)
    }

@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    reference_response = extract_ground_truth(reference_outputs["answer"],f"{name}")
    run_response = outputs.get("final_answer")
    reference_response = str(reference_response).strip()
    run_response = str(run_response).strip()
    try:
        score = math.isclose(float(reference_response), float(run_response), rel_tol=1e-3)
    except Exception:
        score = (reference_response == run_response)
    return {"key": "is_correct", "score": int(score)}

client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=f"{name_model}", splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"{name_model} - Zero-shot"
)

View the evaluation results for experiment: 'TABMWP - Zero-shot-59e86993' at:
https://smith.langchain.com/o/5fc25493-0003-4d31-ac07-9d677640262f/datasets/eb4dd623-5fda-46f4-8535-aa2f9d48f874/compare?selectedSessions=cccc57a6-ac97-42ba-bf3a-33b62058d7e7




0it [00:00, ?it/s]

Unnamed: 0,inputs.context,inputs.question,outputs.final_answer,outputs.steps,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Track team sizes (number of members)\n\n| Stem...,"Mr. McCall, a track coach, recorded the sizes ...",6,,,6,1,4.570451,002abfd1-6f45-4d30-a8f0-0bb4bb1d8556,ddbfd75c-964d-4466-bbcc-b8f3a0f37377
1,Fish per tank\n\n| Stem | Leaf |\n| --- | ---...,A pet store owner had his staff count the numb...,7,,,7,1,4.148864,01b1dbd1-4d6f-45d9-831c-a7a69b3b9859,f9a8e373-defc-4c7c-b855-324bbbbda052
2,| Column 1 | Column 2 |\n| --- | --- |\n| busi...,How much money does Estelle need to buy a brig...,8557,,,8557,1,1.561358,039dcdea-6eee-42df-8018-e8cd018f74a3,e0abb05d-29d4-41c5-b67e-2f8de608f63e
3,Pages written\n\n| Day | Number of pages |\n| ...,An author kept a log of how many pages he wrot...,3,,,3,1,2.717519,0536acf5-d205-4d21-9e9a-1c964fa316c9,8ee7d92c-2343-4f4b-878d-454441ab7488
4,| Column 1 | Column 2 |\n| --- | --- |\n| Euro...,How much money does Donald need to buy an Afri...,2332,,,2332,1,1.749655,09d33bfe-5f0e-4b4f-be25-e8d96ef1a75a,a0fe9f71-a74f-4ae9-ab27-6af9c7ed880f
...,...,...,...,...,...,...,...,...,...,...
195,Middletown School District sports budget\n\n| ...,Each year the Middletown School District publi...,1,,,1,1,2.427433,fa39a138-8df4-4d69-8982-1c25bb0fa827,641fd841-cbb1-49e7-abeb-c811d5122758
196,| Column 1 | Column 2 |\n| --- | --- |\n| Aust...,How much money does Sasha need to buy a Europe...,18456,,,18456,1,1.877386,fa77d470-7211-432c-b430-1322e7670d00,74e48f96-1277-48bb-869d-07be77df1803
197,Clubs\n\n| Name | Number of clubs |\n| --- | -...,Some students compared how many clubs they bel...,5,,,5,1,2.276781,fcdd496f-89de-4bf0-a679-2b47d8447606,9a6bc18d-c5e5-4482-927a-0d392947f7f1
198,| Column 1 | Column 2 |\n| --- | --- |\n| blac...,"Darnel purchased 4 pounds of rocks, 2 pounds o...",30,,,30,1,1.945407,fd72da95-eda4-4c1a-8e41-fd12a2034a9e,9e0a031a-898b-46af-8e17-12afafe85ee2


In [17]:
experiment_name = "GSM8K - Zero-shot-5bc125d4" 

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

data = []
count = 0
for run in runs:
    is_correct = None
    feedback_list = client.list_feedback(run_ids=[run.id])
    for fb in feedback_list:
        if fb.key == "is_correct":
            is_correct = fb.score
    count+=1
    row = {
        "run_id": run.id,
        "error": run.error,
        "latency_sec": (run.end_time - run.start_time).total_seconds() if run.end_time and run.start_time else None,
        "total_cost": run.total_cost,
        "input_tokens": run.prompt_tokens,
        "output_tokens": run.completion_tokens,
        "total_tokens": run.total_tokens,
        "is_correct": is_correct,
    }
    data.append(row)

df_zeroshot_gsm8k = pd.DataFrame(data)

In [18]:
experiment_name = "TATQA - Zero-shot-49e74dd1" 

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

data = []
count = 0
for run in runs:
    is_correct = None
    feedback_list = client.list_feedback(run_ids=[run.id])
    for fb in feedback_list:
        if fb.key == "is_correct":
            is_correct = fb.score
    count+=1
    row = {
        "run_id": run.id,
        "error": run.error,
        "latency_sec": (run.end_time - run.start_time).total_seconds() if run.end_time and run.start_time else None,
        "total_cost": run.total_cost,
        "input_tokens": run.prompt_tokens,
        "output_tokens": run.completion_tokens,
        "total_tokens": run.total_tokens,
        "is_correct": is_correct,
    }
    data.append(row)

df_zeroshot_tatqa = pd.DataFrame(data)

In [19]:
experiment_name = "TABMWP - Zero-shot-59e86993" 

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

data = []
count = 0
for run in runs:
    is_correct = None
    feedback_list = client.list_feedback(run_ids=[run.id])
    for fb in feedback_list:
        if fb.key == "is_correct":
            is_correct = fb.score
    count+=1
    row = {
        "run_id": run.id,
        "error": run.error,
        "latency_sec": (run.end_time - run.start_time).total_seconds() if run.end_time and run.start_time else None,
        "total_cost": run.total_cost,
        "input_tokens": run.prompt_tokens,
        "output_tokens": run.completion_tokens,
        "total_tokens": run.total_tokens,
        "is_correct": is_correct,
    }
    data.append(row)

df_zeroshot_tabmwp = pd.DataFrame(data)

In [20]:
df_zeroshot_tatqa.to_csv("Zero-shot_TATQA.csv", index=False)
df_zeroshot_gsm8k.to_csv("Zero-shot_GSM8K.csv", index=False)
df_zeroshot_tabmwp.to_csv("Zero-shot_TABMWP.csv", index=False)