# Translating and Evaluating Questions from the WMDP Dataset

Grant Boquet, boquet1@llnl.gov

1/17/2025

In [None]:
import os
# Paste your OpenAI key between the quotes below. When the placeholder is left
# empty, any OPENAI_API_KEY already present in the environment is kept instead
# of being overwritten with a blank value.
OPENAI_API_KEY = ""  # PUT YOUR OPENAI KEY HERE
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY


In [None]:
from typing import Any, Dict, List, Optional
from typing_extensions import Annotated, TypedDict

import matplotlib.pyplot as plt
import numpy as np
import csv
import json

from uuid import UUID
from tqdm.auto import tqdm
from datasets import load_dataset
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

In [None]:
# Two separate chat clients on the same OpenAI model: one is used to build the
# translation chain, the other to build the evaluation chain.
translation_model, evaluation_model = (
    ChatOpenAI(model="gpt-4o-mini"),
    ChatOpenAI(model="gpt-4o-mini"),
)

In [None]:
# Load the "wmdp-chem" configuration of the WMDP dataset and flatten its test
# split into plain dicts. The answer choices are serialized to a JSON string
# (non-ASCII characters kept as-is).
ds = load_dataset("cais/wmdp", "wmdp-chem")
all_entries = [
    {
        "en_question": entry["question"],
        "en_choices": json.dumps(entry["choices"], ensure_ascii=False),
        "answer": entry["answer"],
    }
    for entry in ds["test"]
]

In [None]:
class BatchCallback(BaseCallbackHandler):
    """Callback handler that advances a tqdm progress bar once per finished LLM call.

    Can be used as a context manager; entering and leaving the ``with`` block
    is delegated to the underlying progress bar.

    Attributes:
        count: Number of ``on_llm_end`` notifications received so far.
        progress_bar: The tqdm bar, created with ``total`` as its expected size.
    """

    def __init__(self, total: int):
        """Create the handler with a progress bar expecting ``total`` updates."""
        super().__init__()
        self.count = 0
        self.progress_bar = tqdm(total=total)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        # Optional[...] instead of ``UUID | None``: the union operator in an
        # annotation is evaluated at definition time and fails before Python 3.10.
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Record one completed LLM call and move the progress bar forward by one."""
        self.count += 1
        self.progress_bar.update(1)

    def __enter__(self):
        """Enter the progress bar's context and return this handler."""
        self.progress_bar.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Leave the progress bar's context, forwarding any exception details."""
        self.progress_bar.__exit__(exc_type, exc_value, exc_traceback)

    def __del__(self):
        """Close the progress bar when the handler is garbage-collected."""
        # __init__ may have failed before the bar was created, so look it up
        # defensively, and use the public close() rather than calling another
        # object's finalizer directly.
        bar = getattr(self, "progress_bar", None)
        if bar is not None:
            bar.close()

## Translation Task

In [None]:
# Structured-output schema for the translation step (handed to
# with_structured_output below). It holds the Traditional Chinese question and
# choices together with the English question and choices. The last element of
# each Annotated[...] is the field's natural-language description.
# NOTE(review): documented with `#` comments rather than a class docstring on
# purpose; a TypedDict docstring may be forwarded to the model as the schema
# description, which would change the request. Confirm before adding one.
class Translation(TypedDict):
    # Traditional Chinese rendering of the English question.
    trad_zh_question: Annotated[str, ..., "Translation of the English question, ENGLISH_QUESTION, that is a faithful representations of the same concepts and written as how a Taiwanese chemist would write the same question"]
    # Traditional Chinese rendering of the English answer choices.
    trad_zh_choices: Annotated[List[str], ..., "List of strings with possible answers to the Traditional Chinese question that is a translation of the English choices, ENGLISH_CHOICES, and is a faithful representations of the same concepts and written as how a Taiwanese chemist would write the same list of choices"]
    # English question, as returned by the model.
    en_question: Annotated[str, ..., "Question in English"]
    # English answer choices, as returned by the model.
    en_choices: Annotated[List[str], ..., "List of strings with possible answers to the English question"]


# System message for the translation step. The final sentence is the actual
# task instruction: it tells the model to translate the English question and
# choices into Traditional Chinese.
system_template = """You are a chemistry professor creating an exam that will be taken by both Taiwanese and American students. You want to ensure the exam is fair, so each of the questions asked on this exam must be a faithful representation of the same concepts across both languages. Translate the English question, ENGLISH_QUESTION, and the possible English answer choices, ENGLISH_CHOICES, into Traditional Chinese so that the translations are faithful representations of the same concepts and are worded how a professional Taiwanese chemist would word the same question and answers."""

# Human message; the {en_question} and {en_choices} placeholders are filled
# in per entry.
human_template = """
ENGLISH_QUESTION: {en_question}
ENGLISH_CHOICES: {en_choices}
"""

# Model wrapper whose replies are constrained to the Translation schema.
translation_structured_llm = translation_model.with_structured_output(Translation)

# Prompt made of the system instructions followed by the per-entry human message.
translation_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_template),
        ("human", human_template),
    ]
)

# Pipe the prompt into the structured model.
translation_chain = translation_prompt | translation_structured_llm

# Run every entry through the chain in one batch; the callback advances a
# progress bar each time an LLM call ends.
with BatchCallback(len(all_entries)) as cb:
    trad_zh_translations = translation_chain.batch(
        all_entries,
        config={"callbacks": [cb]},
    )

## Evaluation Task

In [None]:
# Structured-output schema for the evaluation step (handed to
# with_structured_output below). It carries the question and choices in both
# languages plus the evaluator's written differences and numeric score. The
# last element of each Annotated[...] is the field's natural-language description.
# NOTE(review): documented with `#` comments rather than a class docstring on
# purpose; a TypedDict docstring may be forwarded to the model as the schema
# description, which would change the request. Confirm before adding one.
class TestQuestion(TypedDict):
    # Question text in Traditional Chinese.
    trad_zh_question: Annotated[str, ..., "Question in Traditional Chinese"]
    # Answer choices in Traditional Chinese.
    trad_zh_choices: Annotated[List[str], ..., "List of strings with possible answers to the Traditional Chinese question"]
    # Question text in English.
    en_question: Annotated[str, ..., "Question in English"]
    # Answer choices in English.
    en_choices: Annotated[List[str], ..., "List of strings with possible answers to the English question"]
    # Free-text English explanation of how the two versions differ.
    differences: Annotated[str, ..., "Description in English explaining all dissimilarities between the Chinese and English questions and choices"]
    # Integer rating requested on a 0-10 scale (10 = perfect translation).
    score: Annotated[int, ..., "Numerical score ranging from 0 to 10, where zero indicates there is no relationship between the two languages and 10 indicates a perfect translation in meaning and concepts between both the questions and choices across both languages"]

# System message for the evaluation step, assembled from adjacent string
# literals (one per sentence) into a single line of text.
system_template = (
    "You are a chemistry professor creating an exam that will be taken by both Taiwanese and American students. "
    "You want to ensure the exam is fair, so each of the questions asked on this exam must be a faithful representation of the same concepts within both languages. "
    "Decide whether the the following Chinese question, CHINESE_QUESTION, possible Chinese answer choices, CHINESE_CHOICES, English question, ENGLISH_QUESTION, and possible English answer choices, ENGLISH_CHOICES, are faithful representations of the same concepts."
)

# Human message: the Chinese pair, a blank line, then the English pair. The
# empty strings at both ends produce the leading and trailing newline.
human_template = "\n".join(
    [
        "",
        "CHINESE_QUESTION: {trad_zh_question}",
        "CHINESE_CHOICES: {trad_zh_choices}",
        "",
        "ENGLISH_QUESTION: {en_question}",
        "ENGLISH_CHOICES: {en_choices}",
        "",
    ]
)

# Prompt made of the evaluation system instructions and the bilingual human message.
eval_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_template),
        ("human", human_template),
    ]
)

# Model wrapper whose replies are constrained to the TestQuestion schema.
eval_structured_llm = evaluation_model.with_structured_output(TestQuestion)

# Pipe the prompt into the structured model.
eval_chain = eval_prompt | eval_structured_llm

# Score every translation in one batch; the callback advances a progress bar
# each time an LLM call ends.
with BatchCallback(len(trad_zh_translations)) as cb:
    translation_evaluations = eval_chain.batch(
        trad_zh_translations,
        config={"callbacks": [cb]},
    )

## Write Translation and Evaluation Output

In [None]:
def _choices_differ(new_choices, old_choices):
    """Return True when two serialized choice lists do not decode to the same value.

    Both arguments are expected to be JSON text. If either one cannot be
    parsed as JSON, the raw values are compared instead.
    """
    # SECURITY: this comparison used to call eval() on text that originates
    # from an LLM. json.loads decodes the same data without executing it.
    try:
        return json.loads(new_choices) != json.loads(old_choices)
    except (TypeError, ValueError):
        return new_choices != old_choices


# Write one CSV row per evaluated question and report where the model's English
# text drifted from the original dataset entry.
# encoding="utf-8": the rows contain Chinese text, which the platform default
# encoding may not be able to represent.
# newline="": required by the csv module so it controls line endings itself.
with open("wmdp_chem_en_zh_eval.csv", "wt", encoding="utf-8", newline="") as fout:
    fieldnames = ["en_question", "en_choices", "trad_zh_question", "trad_zh_choices", "answer", "score", "differences"]
    writer = csv.DictWriter(fout, fieldnames=fieldnames, dialect="excel")
    writer.writeheader()
    for eval_entry, orig_entry in zip(translation_evaluations, all_entries):
        out = eval_entry.copy()
        # The evaluation output has no answer field; take it from the original entry.
        out["answer"] = orig_entry["answer"]

        # Choice lists are stored in the CSV as JSON strings.
        for choices_key in ("en_choices", "trad_zh_choices"):
            if isinstance(out[choices_key], list):
                out[choices_key] = json.dumps(out[choices_key], ensure_ascii=False)

        # The LLM fixes the spelling errors in the data ...
        if eval_entry["en_question"] != orig_entry["en_question"]:
            print("Different question:")
            print("New: " + eval_entry["en_question"])
            print("Old: " + orig_entry["en_question"])
            print("")

        if _choices_differ(out["en_choices"], orig_entry["en_choices"]):
            print("Different choices:")
            print("New: " + repr(out["en_choices"]))
            print("Old: " + repr(orig_entry["en_choices"]))
            print("")

        writer.writerow(out)

In [None]:
# Gather the score assigned to each evaluated translation.
all_scores = []
for eval_entry in translation_evaluations:
    all_scores.append(eval_entry["score"])

# Histogram with integer bin edges 0 through 11.
fig, ax = plt.subplots()
ax.hist(all_scores, bins=range(12))
ax.set_title("Histogram of Evaluation Scores")
ax.set_ylabel("Count")
ax.set_xlabel("Scores");