Skip to content
Merged
7 changes: 4 additions & 3 deletions examples/custom_tasks_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import lighteval.tasks.default_prompts as prompt
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.tasks.gpqa import gpqa_instruct_prompt
from lighteval.tasks.tasks.gsm8k import gsm8k_prompt


gsm8k_test = LightevalTaskConfig(
name="gsm8k_test",
prompt_function=prompt.gsm8k,
prompt_function=gsm8k_prompt,
hf_repo="gsm8k",
hf_subset="main",
hf_avail_splits=["train", "test"],
Expand All @@ -42,7 +43,7 @@

gpqa_diamond_test = LightevalTaskConfig(
name="gpqa:diamond_test",
prompt_function=prompt.gpqa_instruct,
prompt_function=gpqa_instruct_prompt,
hf_repo="Idavidrein/gpqa",
hf_subset="gpqa_diamond",
hf_avail_splits=["train"],
Expand Down
41 changes: 24 additions & 17 deletions examples/nanotron/custom_evaluation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,21 @@
"""

import re
from string import ascii_uppercase
from typing import List, Tuple

import lighteval.tasks.default_prompts as prompt
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.normalizations import LogProbCharNorm, helm_normalizer, math_normalizer
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.tasks.tasks.arc import arc_prompt
from lighteval.tasks.tasks.gsm8k import gsm8k_prompt
from lighteval.tasks.tasks.math import math_prompt
from lighteval.tasks.tasks.openbookqa import openbookqa_prompt
from lighteval.tasks.tasks.piqa import piqa_prompt
from lighteval.tasks.tasks.quac import quac_prompt
from lighteval.tasks.tasks.triviaqa import triviaqa_prompt
from lighteval.tasks.tasks.winogrande import winogrande_prompt


_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = []
Expand All @@ -48,7 +55,7 @@ def commonsense_qa_prompt(line, task_name: str = None):
task_name=task_name,
query=line["question"],
choices=[f" {c}" for c in line["choices"]["text"]],
gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
gold_index=ascii_uppercase.index(line["answerKey"].strip()),
instruction="",
)

Expand Down Expand Up @@ -99,7 +106,7 @@ def preprocess(text):
),
LightevalTaskConfig(
name="winogrande",
prompt_function=prompt.winogrande,
prompt_function=winogrande_prompt,
hf_repo="winogrande",
hf_subset="winogrande_xl",
metrics=[
Expand All @@ -112,7 +119,7 @@ def preprocess(text):
),
LightevalTaskConfig(
name="piqa",
prompt_function=prompt.piqa_harness,
prompt_function=piqa_prompt,
hf_repo="piqa",
hf_subset="plain_text",
metrics=[
Expand All @@ -139,7 +146,7 @@ def preprocess(text):
),
LightevalTaskConfig(
name="openbookqa",
prompt_function=prompt.openbookqa,
prompt_function=openbookqa_prompt,
hf_repo="openbookqa",
hf_subset="main",
metrics=[
Expand All @@ -152,7 +159,7 @@ def preprocess(text):
),
LightevalTaskConfig(
name="arc:easy",
prompt_function=prompt.arc,
prompt_function=arc_prompt,
hf_repo="ai2_arc",
hf_subset="ARC-Easy",
evaluation_splits=["test"],
Expand All @@ -167,7 +174,7 @@ def preprocess(text):
),
LightevalTaskConfig(
name="arc:challenge",
prompt_function=prompt.arc,
prompt_function=arc_prompt,
hf_repo="ai2_arc",
hf_subset="ARC-Challenge",
evaluation_splits=["test"],
Expand Down Expand Up @@ -216,7 +223,7 @@ def natural_questions_prompt(line, task_name: str = None):
WORLD_KNOWLEDGE_TASKS = [
LightevalTaskConfig(
name="trivia_qa",
prompt_function=prompt.triviaqa,
prompt_function=triviaqa_prompt,
hf_repo="trivia_qa",
hf_subset="rc.nocontext",
metrics=[
Expand Down Expand Up @@ -266,7 +273,7 @@ def boolq_prompt(line, task_name: str = None):
),
LightevalTaskConfig(
name="quac",
prompt_function=prompt.quac,
prompt_function=quac_prompt,
hf_repo="lighteval/quac_helm",
hf_subset="deault",
metrics=[
Expand All @@ -290,7 +297,7 @@ class CustomMathEvaluationTask(LightevalTaskConfig):
def __init__(
self,
name,
prompt_function=prompt.math,
prompt_function=math_prompt,
hf_repo="DigitalLearningGmbH/MATH-lighteval",
hf_subset=None,
metrics=[
Expand Down Expand Up @@ -329,7 +336,7 @@ def __init__(
]
GSM8K = LightevalTaskConfig(
name="gsm8k",
prompt_function=prompt.gsm8k,
prompt_function=gsm8k_prompt,
hf_repo="gsm8k",
hf_subset="main",
hf_avail_splits=["train", "test"],
Expand All @@ -352,10 +359,10 @@ def mmlu_harness(line, task_name: str = None):
topic = line["subject"]
prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
prompt += line["question"] + "\n"
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"])])
prompt += "Answer:"

gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
gold_ix = ascii_uppercase.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
"__few_shots" in line and line["__few_shots"] is True # We are adding few shots

return Doc(
Expand Down Expand Up @@ -590,7 +597,7 @@ def agi_eval_prompt(line, task_name: str = None):
prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
prompt += "Answer: "

choices = LETTER_INDICES[: len(line["options"])]
choices = ascii_uppercase[: len(line["options"])]

output = Doc(
query=prompt,
Expand All @@ -599,7 +606,7 @@ def agi_eval_prompt(line, task_name: str = None):

if line["label"]:
output.choices = choices
output.gold_index = LETTER_INDICES.index(line["label"].strip())
output.gold_index = ascii_uppercase.index(line["label"].strip())
else:
output.choices = [line["answer"]]
output.gold_index = 0
Expand All @@ -616,7 +623,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None):
output = Doc(
query=line["question"],
choices=cleaned_options,
gold_index=LETTER_INDICES.index(line["label"].strip()),
gold_index=ascii_uppercase.index(line["label"].strip()),
instruction="",
)

Expand Down
Loading
Loading