In [None]:
from openai import OpenAI
from dateutil.relativedelta import relativedelta
import os
import json
from pydantic import BaseModel,Field
from langchain.llms import OpenAI
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI, init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm import tqdm
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
import langsmith as ls
from langsmith import traceable, trace
from langsmith import Client, traceable, evaluate
from preprocess_data import prepare_qa_input_with_answer_filter,standardize_item
from datetime import datetime

In [2]:
# GSM8K is stored as JSON Lines: one JSON object per line.
file_path = '../dataset_langsmith/gsm8k.jsonl'

with open(file_path, "r", encoding="utf-8") as f:
    # Skip blank / whitespace-only lines, which json.loads would reject.
    gsm8k = [json.loads(line) for line in f if line.strip()]

In [3]:
# TAT-QA: read the raw JSON, then post-process it with the project helper
# prepare_qa_input_with_answer_filter (imported from preprocess_data).
folder_path, filename = "../dataset_langsmith/", "tatqa.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    raw_data_tatqa = json.load(f)
tatqa = prepare_qa_input_with_answer_filter(raw_data_tatqa)

In [4]:
# TabMWP: the JSON file is used as loaded, with no preprocessing step.
folder_path, filename = "../dataset_langsmith/", "tabmwp.json"
file_path = os.path.join(folder_path, filename)
with open(file_path, encoding="utf-8") as f:
    tabmwp = json.load(f)

In [None]:
# Dataset to evaluate: one of "gsm8k", "tatqa", "tabmwp".
# `name` is also passed to extract_ground_truth as the dataset type.
name = "gsm8k"
length_test = 2  # number of samples to test
# NOTE(review): length_test is not used by any cell shown here; the
# evaluation below runs over the whole "base" split.

if name == "gsm8k":
    DATA = gsm8k
    name_dataset = "GSM8K"
elif name == "tatqa":
    DATA = tatqa
    name_dataset = "TATQA"
elif name == "tabmwp":
    DATA = tabmwp
    name_dataset = "TABMWP"
else:
    # Fail fast on a typo instead of silently evaluating TabMWP with the
    # wrong ground-truth parsing.
    raise ValueError(
        f"Unknown dataset name {name!r}; expected 'gsm8k', 'tatqa' or 'tabmwp'"
    )

In [6]:
# Load environment variables from a local .env file (presumably the OpenAI
# API key used by the provider below — confirm).
load_dotenv()
# gpt-4o-mini through the OpenAI provider, with a low sampling temperature.
model = init_chat_model(
    'gpt-4o-mini',
    model_provider='openai',
    temperature=0.2,
)

In [7]:
# Structured-output schema for the model's reply: a single string field.
# NOTE(review): documentation is kept in comments on purpose — a class
# docstring or Field(description=...) is presumably forwarded to the model by
# with_structured_output as part of the schema, which would change the prompt.
class MathReasoning(BaseModel):
    # The answer as text: a bare number, or "unknown" per the system prompt.
    final_answer: str

In [8]:
# Wrap the chat model so .invoke() returns a parsed MathReasoning object;
# target_function reads its .final_answer attribute.
model_with_tools = model.with_structured_output(schema=MathReasoning)

In [9]:
def extract_ground_truth(answer, dataset_type):
    """Normalize a reference answer into a comparable string.

    Args:
        answer: The raw reference answer. For ``"gsm8k"`` it must be the
            solution text (a ``str``) whose final value follows ``####``.
            For ``"tatqa"`` it may be a list (only the first element is
            used) or any scalar. For any other dataset it is stringified.
        dataset_type: ``"gsm8k"``, ``"tatqa"``, or anything else (the answer
            is then used as-is after stripping whitespace).

    Returns:
        The extracted answer as a string. A single ``a/b`` fraction is
        converted to its decimal form; if that conversion fails (e.g. a zero
        denominator or several slashes) the text is returned unchanged.
    """
    if dataset_type == "gsm8k":
        # The final answer follows "####"; accept a leading minus sign and
        # drop thousands separators. Without a marker, fall back to the full text.
        match = re.search(r"####\s*(-?[\d,./]+)", answer)
        if match:
            raw_ans = match.group(1).replace(",", "").strip()
        else:
            raw_ans = answer.strip()
    elif dataset_type == "tatqa":
        if isinstance(answer, list):
            # Only the first answer is scored; an empty list becomes "".
            ans = str(answer[0]).strip() if answer else ""
        else:
            ans = str(answer).strip()
        # Unwrap a bare numeric token from surrounding brackets / double quotes.
        ans = re.sub(r'^[\[\"]*([\d\-\.\/]+)[\]\"]*$', r'\1', ans)
        # Keep only digits, '-', '.', and (for fractions) '/'.
        if '/' in ans:
            ans = re.sub(r"[^-\d/\.]", "", ans)
        else:
            ans = re.sub(r"[^-\d\.]", "", ans)
        raw_ans = ans
    else:
        raw_ans = str(answer).strip()

    # Convert a simple fraction to its decimal string, if present.
    if '/' not in raw_ans:
        return raw_ans
    try:
        numerator, denominator = raw_ans.split('/')
        return str(float(numerator) / float(denominator))
    except (ValueError, ZeroDivisionError):
        return raw_ans

    
def compare_answers(predicted: str, actual: str, eps: float = 1e-2) -> bool:
    """Return True when a predicted answer matches the reference answer.

    Both sides are parsed as numbers when possible. Parsing strips whitespace
    and thousands separators and converts a simple ``a/b`` fraction to a float.
    The values then match if they differ by at most ``eps``, the same tolerance
    rule as ``compare_result``. If either side is not numeric (e.g.
    ``"unknown"``), a case-insensitive string comparison is used instead.

    Args:
        predicted: The model's answer.
        actual: The reference answer.
        eps: Absolute tolerance for numeric comparison.
    """
    def _as_number(text: str) -> float:
        # Raises ValueError / ZeroDivisionError for non-numeric text.
        cleaned = text.strip().replace(",", "")
        if "/" in cleaned:
            numerator, denominator = cleaned.split("/")
            return float(numerator) / float(denominator)
        return float(cleaned)

    try:
        return abs(_as_number(predicted) - _as_number(actual)) <= eps
    except (ValueError, ZeroDivisionError):
        return predicted.strip().lower() == actual.strip().lower()

def unwrap_singleton(value):
    """Remove one level of single-element wrapping from an answer.

    - A list/tuple with exactly one element returns that element.
    - A string such as ``'[2018]'``, ``"['2018']"`` or ``'["2018"]'`` returns
      the inner token ``'2018'``. The token may contain word characters,
      ``-`` and ``.``.
    - Anything else is returned unchanged.
    """
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return value[0]
    if isinstance(value, str):
        # Optional single or double quotes around the token; uses the
        # module-level `re` import.
        match = re.fullmatch(r"\[\s*['\"]?([-\w.]+)['\"]?\s*\]", value.strip())
        if match:
            return match.group(1)
    return value



In [None]:

def target_function(inputs: dict):
    """Run the structured-output model on one LangSmith example.

    Args:
        inputs: Example inputs. ``"question"`` is required. ``"context"`` is
            optional; when it has non-blank content it is placed before the
            question under "# Context:" / "# Question:" headings. A missing
            or ``None`` context is treated as an empty string.

    Returns:
        Dict with the model's ``final_answer`` plus the ``question`` and
        ``context`` that were used, so they are recorded as experiment outputs.
    """
    question = inputs["question"]
    # `or ""` also covers an explicit None, which would crash on .strip().
    context = inputs.get("context") or ""
    # Prepend the context only when it has visible content; str() guards
    # against non-string context values.
    if str(context).strip():
        user_content = f"# Context:\n{context}\n\n# Question: {question}"
    else:
        user_content = question

    messages = [
        SystemMessage(content="""
        You are a math expert.
            - Include a `final_answer` as a single number, no units or symbols.
            - If you cannot solve it, return a final_answer of "unknown".
            - When dealing with money, do not round to thousands unless explicitly stated.
        """),
        HumanMessage(content=user_content)
    ]
    ai_msg = model_with_tools.invoke(messages)

    return {
        "final_answer": ai_msg.final_answer,
        "question": question,
        "context": context,
    }
# Per-example evaluation records appended by compare_result; reset each time
# this cell runs.
all_results = []


def _fraction_to_decimal_str(text: str) -> str:
    """Return the decimal string for a simple ``a/b`` fraction, else ``text`` unchanged."""
    if "/" not in text:
        return text
    try:
        numerator, denominator = text.split("/")
        return str(float(numerator) / float(denominator))
    except (ValueError, ZeroDivisionError):
        return text


def compare_result(inputs: dict, reference_outputs: dict, outputs: dict):
    """LangSmith evaluator: score one prediction against its reference answer.

    Args:
        inputs: Example inputs; ``"question"`` and optional ``"context"`` are
            copied into the log record.
        reference_outputs: Example reference outputs; ``"answer"`` is
            normalized with ``extract_ground_truth`` using the notebook-level
            dataset ``name``.
        outputs: ``target_function``'s return value; ``"final_answer"`` is the
            prediction.

    Returns:
        ``{"key": "is_correct", "score": 1 or 0}``. Numeric answers match when
        they differ by at most 0.01; otherwise a case-insensitive string match
        is used.

    Side effects:
        Appends a record to ``all_results`` and rewrites the running summary
        (accuracy, counts, wrong answers) to
        ``save_log/Zero-shot_results - <name>.json``.
    """
    predicted = _fraction_to_decimal_str(str(unwrap_singleton(outputs["final_answer"])))
    actual = _fraction_to_decimal_str(
        extract_ground_truth(str(reference_outputs["answer"]), name)
    )
    eps = 1e-2
    try:
        score = abs(float(predicted.strip()) - float(actual.strip())) <= eps
    except ValueError:
        # Non-numeric answers (e.g. "unknown") fall back to a string match.
        score = predicted.strip().lower() == actual.strip().lower()

    all_results.append({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "question": inputs["question"],
        "true_answer": actual,
        "predicted_answer": predicted,
        "context": inputs.get("context", ""),
        "correct": score
    })

    # Running summary over every example scored so far in this run.
    correct = sum(1 for x in all_results if x["correct"])
    total = len(all_results)
    summary = {
        "accuracy": correct / total * 100,
        "correct": correct,
        "total": total,
        "wrong_answers": [x for x in all_results if not x["correct"]],
    }

    # Create the log directory on first use; otherwise open() raises
    # FileNotFoundError and the evaluator errors on every example.
    os.makedirs("save_log", exist_ok=True)
    with open(f"save_log/Zero-shot_results - {name}.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return {"key": "is_correct", "score": int(score)}

# Zero-shot experiment over the "base" split of the selected LangSmith
# dataset: each example goes through target_function and is scored by
# compare_result. The evaluate(...) result is the cell's displayed output.
client = Client()
evaluate(
    target_function,
    data=client.list_examples(dataset_name=name_dataset, splits=["base"]),
    evaluators=[compare_result],
    experiment_prefix="Zero-shot_" + name_dataset,
)

View the evaluation results for experiment: 'Zero-shot_GSM8K-8f0612d7' at:
https://smith.langchain.com/o/943b2ecf-878d-466a-a3ad-a779a4be18b4/datasets/95ec945e-8b0f-4361-92e4-638295e8bcb3/compare?selectedSessions=e18d128b-a6e4-45a8-8764-e2e1a6ed9c69




0it [00:00, ?it/s]

Unnamed: 0,inputs.question,outputs.final_answer,outputs.question,outputs.context,error,reference.answer,feedback.is_correct,execution_time,example_id,id
0,Josie grows grapes on her 10-acre farm. Each ...,100,Josie grows grapes on her 10-acre farm. Each ...,,,If each acre produces 5 tons of grapes per yea...,1,1.161696,0042ea8d-aba3-4e78-90af-fda9d178120f,d4a845d5-e00c-47bd-b974-42705c65f4cc
1,Jackie is trying to decide whether to do her t...,105,Jackie is trying to decide whether to do her t...,,,First find the total lost revenue if Jackie do...,0,0.655072,006cd270-8519-442d-a286-130ac608a060,e881702c-5bd7-4b65-9473-255fbe923a67
2,Raymond and Samantha are cousins. Raymond was ...,6,Raymond and Samantha are cousins. Raymond was ...,,,When Raymond's son was born Samantha was 23 - ...,0,0.666580,008dccc0-4f60-44cb-a59f-3c1368fedb46,99b42803-efea-48a7-ba95-2389bea96fe7
3,A water tank is filled with 120 liters of wate...,60,A water tank is filled with 120 liters of wate...,,,"After watering Celine's garden, 120 - 90 = <<1...",0,0.662889,01f2dc43-093e-46e8-a126-25bb29a1301e,15d8c9f0-e89f-4322-bcb5-a4c4d837f8bf
4,Uriah's book bag is getting too heavy for him....,30,Uriah's book bag is getting too heavy for him....,,,30 comic books weigh 7.5 pounds because 30 x ....,0,0.781380,03b7cc4c-d658-480a-b316-44d1dda6c211,54fe29a5-6819-42a4-be60-be70be85fd44
...,...,...,...,...,...,...,...,...,...,...
295,Jean has 30 lollipops. Jean eats 2 of the loll...,14,Jean has 30 lollipops. Jean eats 2 of the loll...,,,Jean has 30 - 2 = <<30-2=28>>28 lollipops\nJea...,1,0.611368,f9d015bb-525f-442f-b96e-9b07090aa6d8,6759dcbf-0e42-4983-aaad-076d3f99552d
296,"In a neighborhood, the number of rabbits pets ...",122,"In a neighborhood, the number of rabbits pets ...",,,"If there are two cats for every dog, and the n...",0,0.494173,fb141fed-989d-4619-9fe2-7e4200350594,b773ebbd-64b2-4201-bb49-f08a6e21664c
297,A mother goes shopping. She buys cocoa at $4.2...,5.00,A mother goes shopping. She buys cocoa at $4.2...,,,The total shopping cost is $4.2 + $9.45 + $1.3...,1,0.642035,ff002aa3-3abd-4e06-a87d-c8ee027543d7,a96fb4f8-95a2-461e-87a8-33a5ada9b8c3
298,Dylan needed chicken sausages and fish sausage...,44,Dylan needed chicken sausages and fish sausage...,,,He bought 38 + 6 = <<38+6=44>>44 fish sausages...,0,0.627994,ff19d22f-f0e7-4487-89d1-7fe3bfc8e058,f5eff820-32e2-4f8c-83f1-37d1a58d03ff
