In [17]:
from openai import OpenAI
from dateutil.relativedelta import relativedelta
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI, init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import pandas as pd
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
import langsmith as ls
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item

In [2]:
file_path = '../data/GSM8K/test.jsonl'

# JSON Lines file: one JSON object per line. Blank lines (for example a
# trailing newline at the end of the file) are skipped, because
# json.loads("") raises JSONDecodeError.
with open(file_path, "r", encoding="utf-8") as f:
    gsm8k = [json.loads(line) for line in f if line.strip()]

In [3]:
# Read the filtered tatqa JSON export, then pass it through the project's
# prepare_qa_input_with_answer_filter helper.
folder_path, filename = "../dataset_langsmith/", "tatqa_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, mode="r", encoding="utf-8") as f:
    raw_data_tatqa = json.loads(f.read())
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
# Read the filtered tabmwp JSON export; it is used as loaded, with no
# further preprocessing in this cell.
folder_path, filename = "../dataset_langsmith/", "tabmwp_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, mode="r", encoding="utf-8") as f:
    tabmwp = json.loads(f.read())

In [15]:
name= "tabmwp" # one of: gsm8k, tatqa, tabmwp
length_test= 200 # number of samples to test
# Pick the loaded dataset and the upper-case label used for the LangSmith
# dataset / experiment names. An unknown name is rejected explicitly so that a
# typo cannot silently evaluate the wrong dataset.
if name == "gsm8k":
    DATA=gsm8k
    name_model="GSM8K"
elif name == "tatqa":
    DATA=tatqa
    name_model="TATQA"
elif name == "tabmwp":
    DATA=tabmwp
    name_model="TABMWP"
else:
    raise ValueError(
        f"Unknown dataset name: {name!r} (expected 'gsm8k', 'tatqa' or 'tabmwp')"
    )

In [6]:
# Load variables from a .env file before the chat model client is built.
load_dotenv()
model = init_chat_model(
    "gpt-4.1-mini",
    model_provider="openai",
    temperature=0.2,
)

In [7]:
# Schema for a single reasoning step inside the model's structured reply.
# NOTE(review): documentation is kept in `#` comments on purpose. Pydantic
# copies a class docstring into the generated schema description, which
# presumably reaches the model through with_structured_output. Confirm before
# adding a docstring here.
class Step(BaseModel):
    explanation: str  # text of this step (the prompt asks for reasoning broken into steps)
    output: str  # string field paired with the explanation
# Top-level structured reply: the list of reasoning steps plus the final answer.
# This is the schema bound to the model via with_structured_output.
# NOTE(review): kept as `#` comments rather than a docstring. Pydantic copies a
# class docstring into the schema description, which would presumably change
# what the model sees. Confirm before adding one.
class MathReasoning(BaseModel):
    steps: list[Step]
    final_answer: str  # kept as a string; the prompt allows the literal "unknown"

In [8]:
# Bind the MathReasoning schema to the chat model. Code below reads
# `.steps` and `.final_answer` from the value returned by `.invoke(...)`.
model_with_tools = model.with_structured_output(MathReasoning)

In [9]:
def extract_ground_truth(answer, dataset_type):
    """Normalize a dataset's reference answer into a comparable string.

    Args:
        answer: Raw reference answer. For "gsm8k" this is a string that may
            contain a "#### <number>" marker. For "tatqa" it is a list or a
            scalar. For any other dataset type it is simply stringified.
        dataset_type: "gsm8k", "tatqa", or any other value for the
            pass-through branch.

    Returns:
        The cleaned answer as a string.
    """
    if dataset_type == "gsm8k":
        # The optional leading "-" lets negative answers be captured. Without
        # it a negative answer fails to match and the whole rationale text is
        # returned instead of the number.
        match = re.search(r"####\s*(-?[\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()
    elif dataset_type == "tatqa":
        # May be a list or a string; take the first element if it is a list.
        if isinstance(answer, list):
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        # If it looks like [2019] or ["2019"], drop the brackets and quotes.
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Keep "/" when the answer is a fraction; otherwise keep only digits,
        # "-" and ".".
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        return ans
    else:
        return str(answer).strip()
    
def compare_answers(predicted: str, actual: str, eps: float = 1e-3) -> bool:
    """Decide whether a predicted answer matches the reference answer.

    Both values are compared numerically when they parse as finite floats,
    and as case-insensitive strings otherwise.

    Args:
        predicted: Answer produced by the model.
        actual: Reference answer.
        eps: Absolute tolerance for the numeric comparison.

    Returns:
        True when the answers match, False otherwise.
    """
    import math  # local import so this definition does not depend on a later cell

    pred_text = predicted.strip()
    act_text = actual.strip()
    try:
        pred = float(pred_text)
        act = float(act_text)
    except ValueError:
        return pred_text.lower() == act_text.lower()

    # "nan" / "inf" parse as floats but cannot be compared by difference.
    if not (math.isfinite(pred) and math.isfinite(act)):
        return pred_text.lower() == act_text.lower()

    # Compare the raw floats. Rounding to integers first would make values
    # such as 2.4 and 2 compare equal and would leave `eps` with no effect.
    return abs(pred - act) < eps

In [22]:
def process_item(item,dataset_type):
    """Ask the model one question and score its answer against the reference.

    Args:
        item: Mapping with "question" and "answer" keys and an optional
            "context" key.
        dataset_type: Dataset name forwarded to `extract_ground_truth`.

    Returns:
        On success, a dict with "question", "context", "true_answer", "step",
        "predicted_answer" and "correct". If anything raises while the item is
        processed, a dict with "error" and "question" instead.
    """
    question = item["question"]
    # `or ""` also covers a "context" key that is present but set to None.
    context = item.get("context") or ""
    try:
        # Parsed inside the try so that a malformed reference answer becomes an
        # error record like any other failure, instead of escaping the function.
        true_answer = extract_ground_truth(item["answer"],dataset_type)

        # When context is present, put it in front of the question.
        if context.strip():
            user_content = f"# Context:\n{context}\n\n# Question: {question}"
        else:
            user_content = question

        messages = [
            SystemMessage(content="""
            You are a math expert.
            For every question and context, you **must** respond using the `MathReasoning` tool.
            - Do not respond with plain text or natural language.
            - Use a list of `Step`s to break down the reasoning.
            - Include a `final_answer` as a single number, no units or symbols.
            - If you cannot solve it, return a final_answer of "unknown".
            - When dealing with money, do not round to thousands unless explicitly stated.
            """),
            HumanMessage(content=user_content)
        ]
        ai_msg = model_with_tools.invoke(messages)
        predicted_answer = ai_msg.final_answer

        return {
            "question": question,
            "context": context,
            "true_answer": true_answer,
            "step": ai_msg.steps,
            "predicted_answer": predicted_answer,
            "correct": compare_answers(predicted_answer, true_answer)
        }
    except Exception as e:
        # Best-effort by design: the caller prints failed items and moves on.
        return {"error": str(e), "question": question}

# Flatten every raw record into the item format consumed by process_item.
dataset = []
for item in DATA:
    dataset.extend(standardize_item(item, name))

# Slice once so the submitted work and the reported total cannot drift apart.
test_items = dataset[:length_test]

results = []
correct = 0
total = len(test_items)
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(process_item, item, name) for item in test_items]
    for future in tqdm(as_completed(futures), total=total):
        result = future.result()
        if "error" not in result:
            results.append(result)
            if result["correct"]:
                correct += 1
        else:
            print(f"Error on question: {result['question'][:60]}... => {result['error']}")

# Failed items stay in the denominator, so they count as wrong answers.
# The guard avoids ZeroDivisionError when no items were selected.
accuracy = correct / total * 100 if total else 0.0
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")



  0%|          | 0/300 [00:00<?, ?it/s]Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=39f63ba3-110a-4de1-8361-65e330ddfc4e,id=39f63ba3-110a-4de1-8361-65e330ddfc4e; trace=39f63ba3-110a-4de1-8361-65e330ddfc4e,id=0baa8bc2-1176-4235-9482-1933f345b2c6; trace=7bf2ad64-64c0-4285-8be3-43270dca439f,id=7bf2ad64-64c0-4285-8be3-43270dca439f; trace=7bf2ad64-64c0-4285-8be3-43270dca439f,id=8873b18d-d232-4990-b35b-451377738eac; trace=464c0da7-4bb4-4b71-99c3-802ad57b4f3b,id=464c0da7-4bb4-4b71-99c3-802ad57b4f3b; trace=464c0da7-4bb4-4b71-99c3-802ad57b4f3b,id=a69e3e47-dbfe-4f0a-8036-58071fedc3f5; trace=4ff92500-8bb5-48d6-b06d-93f4142e8636,id=4ff92500-8bb5-48d6-b06d-93f4142e8636; tra

Accuracy: 98.00% (294/300)





In [23]:
output_path = "CoT_results.json"
# A bare filename has an empty dirname, so fall back to the current directory.
output_dir = os.path.dirname(output_path) or "."
os.makedirs(output_dir, exist_ok=True)
def custom_encoder(obj):
    """Fallback serializer for `json.dump(..., default=custom_encoder)`.

    Calls `obj.model_dump()` when that attribute exists, otherwise
    `obj.dict()` when that one exists.

    Raises:
        TypeError: if the object exposes neither attribute.
    """
    # Checked in order: `model_dump` wins when both attributes are present.
    for method_name in ("model_dump", "dict"):
        if hasattr(obj, method_name):
            return getattr(obj, method_name)()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

# Collect the results that were not marked correct.
wrong_answers = []
for r in results:
    if not r.get("correct", False):
        wrong_answers.append(r)

# Write them as pretty-printed UTF-8 JSON. Objects that json cannot handle
# natively go through custom_encoder.
try:
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(
            wrong_answers,
            f,
            ensure_ascii=False,
            indent=2,
            default=custom_encoder,
        )
    print(f"Đã lưu {len(wrong_answers)} kết quả sai vào {output_path}")
except TypeError as e:
    print(f"Lỗi khi ghi file JSON: {e}")


Đã lưu 6 kết quả sai vào CoT_results.json


In [16]:
import math
@traceable(run_type="chain")
def target_function(inputs: dict):
    """Run the structured-output model on one example.

    Args:
        inputs: Mapping with a "question" key and an optional "context" key.

    Returns:
        A dict with "final_answer" and "steps" taken from the model reply
        ("steps" is None if the reply has no such attribute).
    """
    question = inputs["question"]
    context = inputs.get("context", "")

    # When context is present, put it in front of the question.
    has_context = context.strip()
    user_content = (
        f"# Context:\n{context}\n\n# Question: {question}" if has_context else question
    )

    system_message = SystemMessage(content="""
        You are a math expert.
        For every question, you **must** respond using the `MathReasoning` tool.
        - Do not respond with plain text or natural language.
        - Use a list of `Step`s to break down the reasoning.
        - Include a `final_answer` as a single number, no units or symbols.
        - If you cannot solve it, return a final_answer of "unknown".
        - When dealing with money, do not round to thousands unless explicitly stated.
        """)
    human_message = HumanMessage(content=user_content)

    reply = model_with_tools.invoke([system_message, human_message])

    # The steps are returned alongside the answer so they appear in the output.
    return {
        "final_answer": reply.final_answer,
        "steps": getattr(reply, "steps", None),
    }

@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """Score one run by comparing its final answer with the reference answer.

    Args:
        inputs: Example inputs (not used here).
        reference_outputs: Mapping holding the reference under "answer".
        outputs: Mapping returned by the target; "final_answer" is read from it.

    Returns:
        {"key": "is_correct", "score": 1 or 0}.
    """
    raw_expected = extract_ground_truth(reference_outputs["answer"], f"{name}")
    raw_produced = outputs.get("final_answer")
    expected = str(raw_expected).strip()
    produced = str(raw_produced).strip()

    try:
        # Numeric comparison with a relative tolerance of 1e-3.
        is_match = math.isclose(float(expected), float(produced), rel_tol=1e-3)
    except Exception:
        # Not both numeric: fall back to exact string equality.
        is_match = expected == produced

    return {"key": "is_correct", "score": int(is_match)}

client = Client()

# Evaluate target_function on the "base" split of the selected dataset and
# score every run with compare_result.
evaluate(
    target_function,
    data=client.list_examples(dataset_name=name_model, splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix=f"{name_model} - CoT",
)


View the evaluation results for experiment: 'TABMWP - CoT-9a3b5b26' at:
https://smith.langchain.com/o/5fc25493-0003-4d31-ac07-9d677640262f/datasets/eb4dd623-5fda-46f4-8535-aa2f9d48f874/compare?selectedSessions=74614afd-ccbb-4c94-9e75-bd7797d4912a




0it [00:00, ?it/s]

Unnamed: 0,inputs.context,inputs.question,outputs.final_answer,outputs.steps,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Track team sizes (number of members)\n\n| Stem...,"Mr. McCall, a track coach, recorded the sizes ...",6,[explanation='Identify the teams with at least...,,6,1,3.166458,002abfd1-6f45-4d30-a8f0-0bb4bb1d8556,425e9bf1-356d-4a64-bbff-b9c8a6687366
1,Fish per tank\n\n| Stem | Leaf |\n| --- | ---...,A pet store owner had his staff count the numb...,7,[explanation='Identify the numbers of fish in ...,,7,1,4.081901,01b1dbd1-4d6f-45d9-831c-a7a69b3b9859,1524a6c9-b08c-4b7c-9d7d-3a7722e52734
2,| Column 1 | Column 2 |\n| --- | --- |\n| busi...,How much money does Estelle need to buy a brig...,8557,[explanation='Identify the price of the bright...,,8557,1,2.459812,039dcdea-6eee-42df-8018-e8cd018f74a3,efafde81-0c4c-4542-936c-a19c2acea75f
3,Pages written\n\n| Day | Number of pages |\n| ...,An author kept a log of how many pages he wrot...,3,[explanation='List the number of pages written...,,3,1,6.908228,0536acf5-d205-4d21-9e9a-1c964fa316c9,9cdd2681-6424-4ebe-bf01-3df8859aa198
4,| Column 1 | Column 2 |\n| --- | --- |\n| Euro...,How much money does Donald need to buy an Afri...,2332,[explanation='Identify the prices of the Afric...,,2332,1,1.885036,09d33bfe-5f0e-4b4f-be25-e8d96ef1a75a,bca4e980-72a5-4d90-bc39-5544ae926a1b
...,...,...,...,...,...,...,...,...,...,...
195,Middletown School District sports budget\n\n| ...,Each year the Middletown School District publi...,1,[explanation='Identify the per-student budget ...,,1,1,2.292172,fa39a138-8df4-4d69-8982-1c25bb0fa827,7d2d8889-c567-4b8e-bff6-a5f967267319
196,| Column 1 | Column 2 |\n| --- | --- |\n| Aust...,How much money does Sasha need to buy a Europe...,18456,[explanation='Identify the cost of one Europea...,,18456,1,2.514443,fa77d470-7211-432c-b430-1322e7670d00,a100ec8f-eee1-4882-a004-9d579ab4f2c7
197,Clubs\n\n| Name | Number of clubs |\n| --- | -...,Some students compared how many clubs they bel...,5,[explanation='List the number of clubs each st...,,5,1,3.166517,fcdd496f-89de-4bf0-a679-2b47d8447606,a541a534-7a70-4723-a89a-ed32595c3c41
198,| Column 1 | Column 2 |\n| --- | --- |\n| blac...,"Darnel purchased 4 pounds of rocks, 2 pounds o...",30,[explanation='Calculate the cost of 4 pounds o...,,30,1,2.860447,fd72da95-eda4-4c1a-8e41-fd12a2034a9e,2d773ed6-4cb8-491a-be8a-60777f18fe74


In [18]:
experiment_name = "GSM8K - CoT-af9abd37"

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

# Build one record per run: timing, cost, token counts and the "is_correct"
# feedback score (None when the run has no such feedback).
data = []
count = 0
for count, run in enumerate(runs, start=1):
    is_correct = None
    for fb in client.list_feedback(run_ids=[run.id]):
        if fb.key == "is_correct":
            # The last matching feedback wins if there are several.
            is_correct = fb.score

    if run.end_time and run.start_time:
        latency_sec = (run.end_time - run.start_time).total_seconds()
    else:
        latency_sec = None

    data.append(
        {
            "run_id": run.id,
            "error": run.error,
            "latency_sec": latency_sec,
            "total_cost": run.total_cost,
            "input_tokens": run.prompt_tokens,
            "output_tokens": run.completion_tokens,
            "total_tokens": run.total_tokens,
            "is_correct": is_correct,
        }
    )

df_cot_gsm8k = pd.DataFrame(data)

In [19]:
experiment_name = "TATQA - CoT-5c817cf2"

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

# Build one record per run: timing, cost, token counts and the "is_correct"
# feedback score (None when the run has no such feedback).
data = []
count = 0
for count, run in enumerate(runs, start=1):
    is_correct = None
    for fb in client.list_feedback(run_ids=[run.id]):
        if fb.key == "is_correct":
            # The last matching feedback wins if there are several.
            is_correct = fb.score

    if run.end_time and run.start_time:
        latency_sec = (run.end_time - run.start_time).total_seconds()
    else:
        latency_sec = None

    data.append(
        {
            "run_id": run.id,
            "error": run.error,
            "latency_sec": latency_sec,
            "total_cost": run.total_cost,
            "input_tokens": run.prompt_tokens,
            "output_tokens": run.completion_tokens,
            "total_tokens": run.total_tokens,
            "is_correct": is_correct,
        }
    )

df_cot_tatqa = pd.DataFrame(data)

In [20]:
experiment_name = "TABMWP - CoT-9a3b5b26"

runs = list(client.list_runs(project_name=experiment_name, execution_order=1))

# Build one record per run: timing, cost, token counts and the "is_correct"
# feedback score (None when the run has no such feedback).
data = []
count = 0
for count, run in enumerate(runs, start=1):
    is_correct = None
    for fb in client.list_feedback(run_ids=[run.id]):
        if fb.key == "is_correct":
            # The last matching feedback wins if there are several.
            is_correct = fb.score

    if run.end_time and run.start_time:
        latency_sec = (run.end_time - run.start_time).total_seconds()
    else:
        latency_sec = None

    data.append(
        {
            "run_id": run.id,
            "error": run.error,
            "latency_sec": latency_sec,
            "total_cost": run.total_cost,
            "input_tokens": run.prompt_tokens,
            "output_tokens": run.completion_tokens,
            "total_tokens": run.total_tokens,
            "is_correct": is_correct,
        }
    )

df_cot_tabmwp = pd.DataFrame(data)

In [21]:
# Write each per-experiment frame to its own CSV, without the index column.
for csv_name, frame in (
    ("CoT_TATQA.csv", df_cot_tatqa),
    ("CoT_GSM8K.csv", df_cot_gsm8k),
    ("CoT_TABMWP.csv", df_cot_tabmwp),
):
    frame.to_csv(csv_name, index=False)