In [13]:
from openai import OpenAI
from dateutil.relativedelta import relativedelta
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI, init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
import langsmith as ls
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item

In [14]:
# Load the GSM8K test split from a JSON Lines file (one JSON record per line).
file_path = '../data/GSM8K/test.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    # Skip blank / whitespace-only lines: json.loads() raises JSONDecodeError
    # on them, which would otherwise abort the whole load.
    gsm8k = [json.loads(line) for line in f if line.strip()]

In [15]:
# Read the filtered TAT-QA export, then pass it through the project helper
# prepare_qa_input_with_answer_filter (imported from preprocess_data).
folder_path = "../dataset_langsmith/"
filename = "tatqa_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, mode="r", encoding="utf-8") as f:
    raw_data_tatqa = json.loads(f.read())
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [16]:
# Read the filtered TabMWP export; it is used as loaded, with no extra preprocessing here.
folder_path = "../dataset_langsmith/"
filename = "tabmwp_filtered.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, mode="r", encoding="utf-8") as f:
    tabmwp = json.loads(f.read())

In [17]:
# Experiment configuration: which benchmark to run and how many samples to use.
name = "gsm8k"  # one of: "gsm8k", "tatqa", "tabmwp"
length_test = 300  # number of samples to evaluate

# Bind the selected dataset (DATA) and the LangSmith dataset name (name_model).
# An unrecognised name is rejected instead of silently falling back to TabMWP:
# `name` is also passed on as the dataset type for answer extraction, so a typo
# would otherwise pair TabMWP data with the wrong parsing rules.
if name == "gsm8k":
    DATA = gsm8k
    name_model = "GSM8K"
elif name == "tatqa":
    DATA = tatqa
    name_model = "TATQA"
elif name == "tabmwp":
    DATA = tabmwp
    name_model = "TABMWP"
else:
    raise ValueError(
        f"Unknown dataset name {name!r}; expected 'gsm8k', 'tatqa' or 'tabmwp'"
    )

In [18]:
# Load variables from a .env file before the chat model is created
# (presumably this is where the API credentials come from -- confirm).
load_dotenv()

# Chat model shared by every cell below.
model = init_chat_model(
    'gpt-4.1-mini',
    model_provider='openai',
    temperature=0.2,
)

In [19]:
# One reasoning step of a structured solution (nested inside MathReasoning.steps).
# NOTE: documented with `#` comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its JSON schema description, and this
# schema is handed to model.with_structured_output(), so a docstring could change
# what is sent to the LLM.
class Step(BaseModel):
    explanation: str  # presumably the free-text reasoning for this step
    output: str  # presumably the step's intermediate result, kept as text
# Structured reply requested from the model: the reasoning steps plus the answer.
# NOTE: `#` comments are used instead of a docstring because pydantic would copy a
# docstring into the schema description passed to model.with_structured_output().
class MathReasoning(BaseModel):
    steps: list[Step]  # reasoning steps; stored/logged alongside the prediction
    # Read as the model's prediction by process_item / target_function; the system
    # prompt asks for a single number (no units) or "unknown".
    final_answer: str

In [20]:
# Bind the MathReasoning schema to the chat model; the functions below read
# `.steps` and `.final_answer` from whatever model_with_tools.invoke() returns.
model_with_tools = model.with_structured_output(MathReasoning)

In [21]:
def extract_ground_truth(answer, dataset_type):
    """Normalise a dataset's reference answer into a comparable string.

    Args:
        answer: Raw reference answer. For "gsm8k" this is the solution text whose
            final value follows a ``####`` marker; for "tatqa" it may be a list
            (only the first element is used) or a scalar.
        dataset_type: "gsm8k", "tatqa", or any other value (plain-text handling).

    Returns:
        str: For "gsm8k", the number after ``####`` with thousands separators
        removed (the stripped input if no marker is found). For "tatqa", the
        numeric characters of the answer (digits, ``-``, ``.``, and ``/`` for
        fractions); the trimmed original text when it contains no digit at all;
        ``""`` for an empty list. Otherwise ``str(answer).strip()``.
    """
    if dataset_type == "gsm8k":
        # Allow an optional leading minus sign so negative answers are captured
        # instead of falling through and returning the whole solution text.
        match = re.search(r"####\s*(-?[\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()
    elif dataset_type == "tatqa":
        # The answer may be a list or a string; take the first element of a list.
        if isinstance(answer, list):
            if not answer:
                # Nothing to extract; avoid IndexError on answer[0].
                return ""
            ans = str(answer[0]).strip()
        else:
            ans = str(answer).strip()
        original_text = ans
        # Forms such as [2019] or ["2019"]: drop the brackets and quotes.
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Keep "/" only when the answer looks like a fraction.
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        if not re.search(r"\d", ans):
            # Purely textual answer: stripping would leave "" or stray
            # punctuation, so keep the text for string comparison instead.
            return original_text
        return ans
    else:
        return str(answer).strip()
    
def compare_answers(predicted: str, actual: str, eps: float = 1e-3) -> bool:
    """Decide whether a predicted answer matches the reference answer.

    Both values are parsed as floats (thousands separators removed) and compared
    with ``math.isclose`` using ``eps`` as both relative and absolute tolerance.
    If either value is not a finite number, a case-insensitive string comparison
    of the trimmed texts is used instead.

    Args:
        predicted: Answer produced by the model.
        actual: Reference answer.
        eps: Numeric tolerance (relative and absolute).

    Returns:
        bool: True when the two answers are considered equal.
    """
    import math  # local import: the notebook only imports math in a later cell

    pred_text = str(predicted).strip()
    act_text = str(actual).strip()
    try:
        pred = float(pred_text.replace(",", ""))
        act = float(act_text.replace(",", ""))
    except ValueError:
        return pred_text.lower() == act_text.lower()

    # "inf"/"nan" parse as floats but are not meaningful numeric answers.
    if not (math.isfinite(pred) and math.isfinite(act)):
        return pred_text.lower() == act_text.lower()

    # Compare the real values: rounding to integers first would make `eps`
    # meaningless and count e.g. 3.4 as equal to 3.
    return math.isclose(pred, act, rel_tol=eps, abs_tol=eps)

In [22]:
def process_item(item,dataset_type):
    """Ask the model one question and grade its answer against the reference.

    Args:
        item: Mapping with "question" and "answer" keys and an optional "context".
        dataset_type: Dataset identifier forwarded to extract_ground_truth.

    Returns:
        dict: On success, the keys "question", "context", "true_answer", "step",
        "predicted_answer" and "correct". On any failure after the question has
        been read, ``{"error": <message>, "question": <question>}``.
    """
    question = item["question"]
    # `or ""` also covers a record that stores an explicit None under "context".
    context = item.get("context") or ""
    try:
        # Done inside the try so a malformed reference answer is reported as a
        # per-item error instead of escaping from the worker thread and aborting
        # the whole run when future.result() is called.
        true_answer = extract_ground_truth(item["answer"], dataset_type)

        # If there is context, put it in front of the question.
        if context.strip():
            user_content = f"# Context:\n{context}\n\n# Question: {question}"
        else:
            user_content = question

        messages = [
            SystemMessage(content="""
            You are a math expert.
            For every question and context, you **must** respond using the `MathReasoning` tool.
            - Do not respond with plain text or natural language.
            - Use a list of `Step`s to break down the reasoning.
            - Include a `final_answer` as a single number, no units or symbols.
            - If you cannot solve it, return a final_answer of "unknown".
            - When dealing with money, do not round to thousands unless explicitly stated.
            """),
            HumanMessage(content=user_content)
        ]
        ai_msg = model_with_tools.invoke(messages)
        predicted_answer = ai_msg.final_answer

        return {
            "question": question,
            "context": context,
            "true_answer": true_answer,
            "step": ai_msg.steps,
            "predicted_answer": predicted_answer,
            "correct": compare_answers(predicted_answer, true_answer)
        }
    except Exception as e:
        # Best-effort evaluation: report the failure and let the caller continue.
        return {"error": str(e), "question": question}

# Flatten every raw record into standardized items via standardize_item.
dataset = []
for item in DATA:
    dataset.extend(standardize_item(item, name))

# Evaluate only the first `length_test` items. Slice once so the submitted work
# and the accuracy denominator always refer to the same list.
test_items = dataset[:length_test]

results = []
correct = 0
total = len(test_items)
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = [executor.submit(process_item, item, name) for item in test_items]
    # Results arrive in completion order, not submission order.
    for future in tqdm(as_completed(futures), total=total):
        result = future.result()
        if "error" not in result:
            results.append(result)
            if result["correct"]:
                correct += 1
        else:
            print(f"Error on question: {result['question'][:60]}... => {result['error']}")

# Failed items stay in the denominator, i.e. they count as incorrect.
# Guard against an empty selection to avoid ZeroDivisionError.
accuracy = correct / total * 100 if total else 0.0
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")



  0%|          | 0/300 [00:00<?, ?it/s]Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=39f63ba3-110a-4de1-8361-65e330ddfc4e,id=39f63ba3-110a-4de1-8361-65e330ddfc4e; trace=39f63ba3-110a-4de1-8361-65e330ddfc4e,id=0baa8bc2-1176-4235-9482-1933f345b2c6; trace=7bf2ad64-64c0-4285-8be3-43270dca439f,id=7bf2ad64-64c0-4285-8be3-43270dca439f; trace=7bf2ad64-64c0-4285-8be3-43270dca439f,id=8873b18d-d232-4990-b35b-451377738eac; trace=464c0da7-4bb4-4b71-99c3-802ad57b4f3b,id=464c0da7-4bb4-4b71-99c3-802ad57b4f3b; trace=464c0da7-4bb4-4b71-99c3-802ad57b4f3b,id=a69e3e47-dbfe-4f0a-8036-58071fedc3f5; trace=4ff92500-8bb5-48d6-b06d-93f4142e8636,id=4ff92500-8bb5-48d6-b06d-93f4142e8636; tra

Accuracy: 98.00% (294/300)





In [23]:
# Destination of the saved results; make sure its directory exists
# ("." when the path has no directory component).
output_path = "CoT_results.json"
output_dir = os.path.dirname(output_path) or "."
os.makedirs(output_dir, exist_ok=True)
def custom_encoder(obj):
    """`default` hook for json.dump: turn model-like objects into plain data.

    Tries ``obj.model_dump()`` first, then ``obj.dict()``.

    Raises:
        TypeError: If the object offers neither method.
    """
    for method_name in ("model_dump", "dict"):
        if hasattr(obj, method_name):
            return getattr(obj, method_name)()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")

# Keep only the graded results whose prediction was marked incorrect.
wrong_answers = [r for r in results if not r.get("correct", False)]

try:
    # Serialise before opening the file: if some object cannot be encoded, the
    # TypeError is raised while the previous file contents are still intact,
    # instead of leaving a truncated, half-written JSON file behind.
    payload = json.dumps(wrong_answers, ensure_ascii=False, indent=2, default=custom_encoder)
except TypeError as e:
    print(f"Lỗi khi ghi file JSON: {e}")
else:
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(payload)
    print(f"Đã lưu {len(wrong_answers)} kết quả sai vào {output_path}")


Đã lưu 6 kết quả sai vào CoT_results.json


In [24]:
import math
@traceable(run_type="chain")
def target_function(inputs: dict):
    """Evaluation target: answer one dataset example with the structured model.

    Args:
        inputs: Example inputs with a "question" key and an optional "context".

    Returns:
        dict: ``{"final_answer": <model answer>, "steps": <reasoning steps or None>}``.
    """
    question = inputs["question"]
    # `or ""` also covers an example that stores an explicit None under "context",
    # which would otherwise fail on .strip() below.
    context = inputs.get("context") or ""
    # If there is context, put it in front of the question.
    if context.strip():
        user_content = f"# Context:\n{context}\n\n# Question: {question}"
    else:
        user_content = question

    messages = [
        SystemMessage(content="""
        You are a math expert.
        For every question, you **must** respond using the `MathReasoning` tool.
        - Do not respond with plain text or natural language.
        - Use a list of `Step`s to break down the reasoning.
        - Include a `final_answer` as a single number, no units or symbols.
        - If you cannot solve it, return a final_answer of "unknown".
        - When dealing with money, do not round to thousands unless explicitly stated.
        """),
        HumanMessage(content=user_content)
    ]
    ai_msg = model_with_tools.invoke(messages)
    predicted_answer = ai_msg.final_answer
    # The reasoning steps are returned as well so they are logged with the run.
    return {
        "final_answer": predicted_answer,
        "steps": getattr(ai_msg, "steps", None)
    }

@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """Evaluator: score 1 when the run's final answer matches the reference, else 0.

    The reference is normalised with extract_ground_truth for the dataset selected
    by the notebook-level `name`. Both sides are compared as floats with a relative
    tolerance of 1e-3; if that fails for any reason, the trimmed strings are
    compared for exact equality.

    Returns:
        dict: ``{"key": "is_correct", "score": 0 or 1}``.
    """
    expected = str(extract_ground_truth(reference_outputs["answer"], f"{name}")).strip()
    produced = str(outputs.get("final_answer")).strip()
    try:
        is_match = math.isclose(float(expected), float(produced), rel_tol=1e-3)
    except Exception:
        # Non-numeric (or otherwise unparseable) answers: exact string match.
        is_match = expected == produced
    return {"key": "is_correct", "score": int(is_match)}

# Run the LangSmith experiment on the "base" split of the selected dataset,
# scoring every run with compare_result.
client = Client()
examples = client.list_examples(dataset_name=f"{name_model}", splits=["base"])
evaluate(
    target_function,
    data=examples,
    evaluators=[compare_result],
    experiment_prefix=f"{name_model} - Test Dataset",
)


Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=d86a1d92-d407-40b0-b790-d75d6614c535,id=5c8fc6e5-d08e-492e-8141-07288db6986c; trace=d86a1d92-d407-40b0-b790-d75d6614c535,id=e930657b-36eb-4900-954d-f59b59666b10; trace=d86a1d92-d407-40b0-b790-d75d6614c535,id=e930657b-36eb-4900-954d-f59b59666b10; trace=d86a1d92-d407-40b0-b790-d75d6614c535,id=d86a1d92-d407-40b0-b790-d75d6614c535


View the evaluation results for experiment: 'GSM8K - Test Dataset-722f6f45' at:
https://smith.langchain.com/o/e7b5e917-6c40-46ad-a54b-83be84870fd4/datasets/a0f9251e-0ade-4046-a15d-3eb36e10746f/compare?selectedSessions=d7554ff4-c2a9-4c8d-91f0-1e428ebd2d03




0it [00:00, ?it/s]

Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=c46075d5-3129-4699-9436-f9187f536015,id=c46075d5-3129-4699-9436-f9187f536015; trace=c46075d5-3129-4699-9436-f9187f536015,id=2749dde3-63e2-409e-8545-a236464ccfef; trace=c46075d5-3129-4699-9436-f9187f536015,id=07160b64-3f65-4cff-bee5-2c3658660c1d
Failed to send compressed multipart ingest: langsmith.utils.LangSmithRateLimitError: Rate limit exceeded for https://api.smith.langchain.com/runs/multipart. HTTPError('429 Client Error: Too Many Requests for url: https://api.smith.langchain.com/runs/multipart', '{"error":"Too many requests: tenant exceeded usage limits: Monthly unique traces usage limit exceeded"}\n')trace=c46

Unnamed: 0,inputs.question,outputs.final_answer,outputs.steps,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Becky bought 20 apples for 45 cents each and r...,1,[explanation='Calculate the total cost for Bec...,,"Before the discount, Becky would have paid 20 ...",1,7.024641,01525728-4311-4402-b357-ff84108cf744,c46075d5-3129-4699-9436-f9187f536015
1,The bananas at the supermarket cost $0.80 each...,2,[explanation='Calculate the total number of ba...,,Ten bunches makes 10*3=$<<10*3=30>>30.\nThere ...,1,6.247846,034612d6-f271-4086-8803-c284072c82bf,c2916821-9a66-475f-91b2-ffe1e2f688a1
2,Verna loves to eat fruit. She bought three app...,7,[explanation='Calculate the total cost of the ...,,Three apples cost 3 x $1.50 = $<<3*1.5=4.50>>4...,1,4.826414,03e33507-0abf-476f-bd85-5eaaad4f87f9,ed2fa59e-2315-403a-927a-d11c737b7106
3,Emma can make and upload 72 vlogs per month. B...,18,[explanation='Calculate the total number of vl...,,Emma was able to make 18 + 21 + 15 = <<18+21+1...,1,2.385746,041a29d6-665d-46f8-af8c-ca74d776f0c9,a88608c7-26a8-435c-b5b0-94c9c337e9fb
4,"The combined age of Peter, Paul and Jean is 10...",50,[explanation='Let the age of John be x years.'...,,Let x be the age of John. Paul’s age is x + 10...,1,10.136547,06562ae2-9130-4f38-b334-265cc66ef8da,1acaf407-5759-4c80-9004-593915256ae8
...,...,...,...,...,...,...,...,...,...
295,A customer’s loyalty card at a store gives the...,31,[explanation='Calculate the rewards earned fro...,,The customer had 80 / 20 = $<<80/20=4>>4 in re...,1,3.806641,fc37b2c2-feda-46cb-80ab-69628352c1e3,7e9ac645-a401-46c6-bd03-c4ce7dd198d7
296,Edgar eats 18 pretzels a day. If his brother e...,63,[explanation='Edgar eats 18 pretzels a day. Hi...,,His brother eats 9 pretzels a day because 18 /...,1,2.176553,fccfc440-10fb-406b-8804-83951929c7f3,968fd66d-e68e-416c-8848-4d32258842bd
297,Madeline ate 6 grapes. Her brother wanted to m...,84,[explanation='Madeline ate 6 grapes.' output='...,,Madeline's brother used 6 x 5 = <<6*5=30>>30 g...,1,4.503954,fd267da6-b5f4-4d1d-9818-a88068193dd6,44660a89-2e24-4519-a8c8-98f45598be64
298,"As Sally walked to school, she was holding the...",34,[explanation='Calculate the number of red ball...,,"Sally started out with 25 red, 7 green, and 12...",1,5.010440,fd4425c2-c834-4e06-ba24-d31e18111299,9c71b6fe-f6cb-43e9-b5f3-6493579b87f3
