In [1]:
from openai import OpenAI
from dateutil.relativedelta import relativedelta
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI, init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
import langsmith as ls
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from preprocess_data import prepare_qa_input_with_answer_filter
from datetime import datetime



In [2]:
# GSM8K is stored as JSON Lines: one JSON object per line.
folder_path = "../dataset_langsmith/"
file_path = os.path.join(folder_path, "gsm8k.jsonl")

with open(file_path, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. an extra newline at the end of the file);
    # json.loads("") would raise JSONDecodeError on them.
    gsm8k = [json.loads(line) for line in f if line.strip()]

In [3]:
# TAT-QA is a single JSON document. The raw records are passed through the
# project helper `prepare_qa_input_with_answer_filter` (from preprocess_data;
# its filtering rules are not visible in this notebook).
folder_path, filename = "../dataset_langsmith/", "tatqa.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
# TabMWP is a single JSON document, used exactly as loaded (no preprocessing).
folder_path, filename = "../dataset_langsmith/", "tabmwp.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    tabmwp = json.load(f)

In [5]:
# Benchmark to evaluate: one of "gsm8k", "tatqa", "tabmwp".
name = "tatqa"
length_test = 1  # NOTE(review): not referenced in the visible cells — confirm before removing.

# dataset key -> (loaded examples, LangSmith dataset name passed to evaluate()).
_DATASETS = {
    "gsm8k": (gsm8k, "GSM8K"),
    "tatqa": (tatqa, "TATQA"),
    "tabmwp": (tabmwp, "TABMWP"),
}
if name not in _DATASETS:
    # Fail fast: `name` also selects the answer parser in extract_ground_truth,
    # so silently falling back to TabMWP would mix data and parsing rules.
    raise ValueError(f"Unknown dataset {name!r}; expected one of {sorted(_DATASETS)}")
DATA, name_dataset = _DATASETS[name]

In [6]:
# Load local .env settings into the environment (presumably the OpenAI API
# key — confirm) before the chat model is created.
load_dotenv()

# gpt-4o-mini through the OpenAI provider; the low temperature (0.2) limits
# sampling randomness across the benchmark run.
model = init_chat_model(
    'gpt-4o-mini',
    model_provider='openai',
    temperature=0.2,
)

In [7]:
# One reasoning step in the model's structured reply (the system prompt in
# target_function asks for "a list of `Step`s to break down the reasoning").
class Step(BaseModel):
    explanation: str  # free-text explanation of this step
    output: str  # intermediate result of this step, kept as free text
# Structured-output schema bound to the chat model via
# `model.with_structured_output(MathReasoning)`; each reply is parsed into this
# object and read as `ai_msg.steps` / `ai_msg.final_answer`.
# NOTE(review): a class docstring would likely be forwarded to the model as the
# schema description — confirm before adding one.
class MathReasoning(BaseModel):
    steps: list[Step]  # reasoning steps, in order
    final_answer: str  # a str even though the prompt asks for a single number; compare_result parses it

In [8]:
# Wrap the chat model so each invoke() returns a parsed MathReasoning
# (consumed as ai_msg.final_answer / ai_msg.steps in target_function).
model_with_tools = model.with_structured_output(MathReasoning)

In [9]:
def extract_ground_truth(answer, dataset_type):
    """Normalise a reference answer into a string comparable with predictions.

    Args:
        answer: The raw reference answer.
            - "gsm8k": a worked solution containing "#### <number>".
            - "tatqa": a scalar, a list (its first element is used), or the
              str() of a list such as "['2018']".
            - Anything else: the value is stringified and stripped.
        dataset_type: "gsm8k", "tatqa", or any other value.

    Returns:
        The normalised answer string. For tatqa, numeric answers are reduced
        to digits, "-", "." (and "/" for fractions). Text answers with no
        digits are returned as text with list brackets and quotes removed.
    """
    if dataset_type == "gsm8k":
        # The optional "-" keeps negative answers. Without it the search fails
        # and the whole worked solution would be returned as the answer.
        match = re.search(r"####\s*(-?[\d,./]+)", answer)
        if match:
            return match.group(1).replace(",", "").strip()
        return answer.strip()

    if dataset_type == "tatqa":
        if isinstance(answer, list):
            # Guard against an empty list instead of raising IndexError.
            ans = str(answer[0]).strip() if answer else ""
        else:
            ans = str(answer).strip()
        # Strip surrounding [ ] / " from a purely numeric token.
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Keep "/" only when the answer looks like a fraction.
        non_numeric = r"[^-\d/\.]" if '/' in ans else r"[^-\d\.]"
        numeric = re.sub(non_numeric, "", ans)
        if numeric:
            return numeric
        # Span/text answers contain no digits: keep the text (minus list
        # brackets and quotes) rather than collapsing it to "".
        return re.sub(r"^\[\s*['\"]?(.*?)['\"]?\s*\]$", r"\1", ans)

    return str(answer).strip()


def _parse_number(text):
    """Parse a decimal ("1,200.5") or simple fraction ("3/4"); return None if not numeric."""
    # Thousands separators are dropped so "1,200" and "1200" compare equal.
    cleaned = text.strip().replace(",", "")
    try:
        if "/" in cleaned:
            numerator, denominator = cleaned.split("/")
            return float(numerator) / float(denominator)
        return float(cleaned)
    except (ValueError, ZeroDivisionError):
        # Not a number, a malformed fraction ("1/2/3"), or division by zero.
        return None


def compare_answers(predicted: str, actual: str, eps: float = 1e-2) -> bool:
    """Return True when two answer strings denote the same value.

    Both sides are parsed as numbers first: plain decimals, decimals with
    thousands separators, or "a/b" fractions. Two numbers match when they
    differ by at most ``eps`` (absolute tolerance). If either side is not
    numeric, the answers match on a case-insensitive, whitespace-stripped
    string comparison.

    Args:
        predicted: Model answer.
        actual: Reference answer.
        eps: Absolute tolerance for numeric comparison.

    Returns:
        True if the answers are considered equal.
    """
    predicted, actual = str(predicted), str(actual)
    pred = _parse_number(predicted)
    act = _parse_number(actual)
    if pred is None or act is None:
        return predicted.strip().lower() == actual.strip().lower()
    return abs(pred - act) <= eps

def unwrap_singleton(value):
    """Unwrap a one-element container or its string representation.

    - A Python list/tuple of length 1 returns its only element.
    - A string such as "[2018]", "['2018']" or '["2018"]' returns the inner
      token ("2018"). The token may contain letters, digits, "_", "-" and ".".
    - Anything else is returned unchanged.
    """
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    if isinstance(value, str):
        # Quotes are optional and may be single (Python repr) or double (JSON).
        # Uses the module-level `re` import; no local import needed.
        match = re.fullmatch(r"\[\s*['\"]?([-\w\.]+)['\"]?\s*\]", value.strip())
        if match:
            return match.group(1)
    return value

In [None]:

@traceable(run_type="chain")
def target_function(inputs: dict):
    """Ask the structured-output model to solve one benchmark example.

    Args:
        inputs: Example inputs. Must contain "question". May contain "context",
            which is prepended to the question when it is non-blank.

    Returns:
        dict with the model's "final_answer" and "steps", plus the "question"
        and "context" that were sent.

    Raises:
        ValueError: if the model returns no structured result.
    """
    question = inputs["question"]
    # `or ""` also covers examples whose context is stored as None, which
    # would otherwise crash on `.strip()` below.
    context = inputs.get("context") or ""
    if context.strip():
        user_content = f"# Context:\n{context}\n\n# Question: {question}"
    else:
        user_content = question

    messages = [
        SystemMessage(content="""
        You are a math expert.
        For every question, you **must** respond using the `MathReasoning` tool.
        - Do not respond with plain text or natural language.
        - Let's think step by step.
        - Use a list of `Step`s to break down the reasoning.
        - Include a `final_answer` as a single number, no units or symbols.
        """),
        HumanMessage(content=user_content)
    ]
    ai_msg = model_with_tools.invoke(messages)
    if ai_msg is None:
        # NOTE(review): presumably happens when the model skips the structured
        # output; raise a clear error instead of an AttributeError.
        raise ValueError("model returned no structured MathReasoning output")
    predicted_answer = ai_msg.final_answer

    return {
        "final_answer": predicted_answer,
        "steps": getattr(ai_msg, "steps", None),
        "question": question,
        "context": context
    }
# Per-example grading log shared by every compare_result call; re-running this
# cell starts a fresh log.
all_results = []


def _step_to_dict(step):
    """Return a JSON-serialisable dict for one reasoning step.

    Accepts pydantic v2 models (``model_dump``), pydantic v1 models (``dict``)
    and steps that are already plain dicts.
    """
    if isinstance(step, dict):
        return step
    if hasattr(step, "model_dump"):
        return step.model_dump()
    return step.dict()


@traceable(run_type="tool")
def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """LangSmith evaluator: grade one prediction and refresh the results file.

    Args:
        inputs: Example inputs (reads "question" and optional "context").
        reference_outputs: Reference outputs (reads "answer").
        outputs: target_function's return value (reads "final_answer" and
            "steps").

    Returns:
        ``{"key": "is_correct", "score": 0 or 1}``.

    Side effects:
        Appends a record to ``all_results`` and rewrites
        ``save_log/CoT_results - {name}.json`` with the running accuracy and
        every wrong answer so far.
    """
    predicted = str(unwrap_singleton(outputs["final_answer"]))
    actual = extract_ground_truth(str(reference_outputs["answer"]), name)
    # Reuse the shared comparison rather than an inline copy of it.
    # The logged answers below are the normalised strings, before numeric parsing.
    score = compare_answers(predicted, actual, eps=1e-2)

    all_results.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "question": inputs["question"],
        # target_function may return steps=None; log [] instead of crashing.
        "steps": [_step_to_dict(step) for step in (outputs.get("steps") or [])],
        "true_answer": actual,
        "predicted_answer": predicted,
        "context": inputs.get("context", ""),
        "correct": score
    })

    # Running summary over everything graded so far (total >= 1 after append).
    correct = sum(1 for x in all_results if x["correct"])
    total = len(all_results)
    accuracy = correct / total * 100
    wrong_answers = [x for x in all_results if not x["correct"]]

    summary = {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "wrong_answers": wrong_answers
    }

    # Create the log folder on first use instead of failing with FileNotFoundError.
    os.makedirs("save_log", exist_ok=True)
    # NOTE(review): the whole summary is rewritten on every call (O(n) per
    # example); acceptable for a few hundred examples.
    with open(f"save_log/CoT_results - {name}.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return {"key": "is_correct", "score": int(score)}

# Run the experiment: each example in the "base" split of the LangSmith dataset
# is answered by target_function and graded by compare_result. The call's
# return value is the cell's output.
client = Client()
base_examples = client.list_examples(dataset_name=f"{name_dataset}", splits=["base"])
evaluate(
    target_function,
    data=base_examples,
    evaluators=[compare_result],
    experiment_prefix=f"CoT_{name_dataset}",
)

View the evaluation results for experiment: 'CoT_TATQA-dbda439e' at:
https://smith.langchain.com/o/c422d8c3-e7d7-402f-a3bb-0998c67d5b6a/datasets/eacfd289-4d35-4de1-8e33-4072e26cdc28/compare?selectedSessions=49d179e5-0810-4a87-8d20-6b4a1bf1079b




0it [00:00, ?it/s]

Unnamed: 0,inputs.context,inputs.question,outputs.final_answer,outputs.steps,outputs.question,outputs.context,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,"| | | Year ended March 31, | |\n| --- | --...",What was the average Expected volatility betwe...,40.77,[explanation='Identify the expected volatility...,What was the average Expected volatility betwe...,"| | | Year ended March 31, | |\n| --- | --...",,40.77,1,5.715340,224a752d-98cc-4f75-9e76-e0f32e632042,813a08ef-e6e7-402a-bd71-34f32318b7a6
1,Reconciliations of total segment net revenues ...,What is the segment net revenues in 2018?,6835,[explanation='Identify the segment net revenue...,What is the segment net revenues in 2018?,Reconciliations of total segment net revenues ...,,6835,1,1.601649,226012ee-8124-41a2-8bbf-c8ed7f285364,57719fab-fb34-4dd6-b5b8-77ee7c9834ca
2,| | 2019 | | 2018 | |\n| --- | --- | --- | ...,What is the company's average gross profit in ...,1171074,[explanation='Identify the gross profit for 20...,What is the company's average gross profit in ...,| | 2019 | | 2018 | |\n| --- | --- | --- | ...,,1171074.5,0,6.949831,22aede05-bf0d-4206-a65c-e4ece3a6cf44,1a2c6086-cd28-4d14-98bb-44e7aaa96b4b
3,"| | | Years Ended December 31, | |\n| --- |...",What is the average Operating expenses?,14.7,[explanation='Identify the operating expenses ...,What is the average Operating expenses?,"| | | Years Ended December 31, | |\n| --- |...",,14.7,1,2.454621,2434e921-efcf-4cb5-aa64-8e2bde0d9964,b20a7ca4-ddbe-4517-b652-67914f898f78
4,(1) Revenues for Corporate and Other represent...,What were the operating expenses for Software ...,394.8,"[explanation=""The operating expenses for Softw...",What were the operating expenses for Software ...,(1) Revenues for Corporate and Other represent...,,394.8,1,1.884235,2623683e-4068-4742-a12b-d3f5a7f48a65,9e28a802-0cbb-4f25-9a61-24451bcd8c9b
...,...,...,...,...,...,...,...,...,...,...,...,...
265,| ($ in millions) | | | | |\n| --- | --- |...,What is the average of Cloud & Data Platforms ...,9051,[explanation='Identify the revenue for Cloud &...,What is the average of Cloud & Data Platforms ...,| ($ in millions) | | | | |\n| --- | --- |...,,9051,1,2.043263,fc005317-ac54-4cbf-909f-2b5accdee8cf,6fa3070b-8749-4eaf-85f8-3c10d0958d88
266,"| | December 31, | |\n| --- | --- | --- |\n|...",What is the difference between the Unused line...,109.9,[explanation='Identify the unused lines of cre...,What is the difference between the Unused line...,"| | December 31, | |\n| --- | --- | --- |\n|...",,109.9,1,2.714561,fcaf59a5-a97e-4961-9ce5-acfd4f238c42,5bd25cd8-eed5-4a85-a7d0-2d8ff8ee160d
267,"| | | Year Ended December 31, | |\n| --- | ...",What was the percentage change in cost of reve...,16.75,[explanation='Identify the cost of revenue for...,What was the percentage change in cost of reve...,"| | | Year Ended December 31, | |\n| --- | ...",,16.76,0,3.711781,fe31299d-d5d6-4e6b-9d55-cb6b6b155312,490c8445-bf9c-46bd-aff0-9cfec3ad9ab3
268,"| At December 31, 2019 | Operating Leases | Fi...",In which year was Operating Leases greater tha...,2022,[explanation='We need to identify the years wh...,In which year was Operating Leases greater tha...,"| At December 31, 2019 | Operating Leases | Fi...",,2022,1,7.459038,fea00b83-cc28-445f-8598-6c891980a9f8,7519e540-136c-4087-840a-345484585821
