diff --git a/examples/custom_tasks_tests.py b/examples/custom_tasks_tests.py index 7bb92eeaa..86c51ee7b 100644 --- a/examples/custom_tasks_tests.py +++ b/examples/custom_tasks_tests.py @@ -20,14 +20,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.tasks.gpqa import gpqa_instruct_prompt +from lighteval.tasks.tasks.gsm8k import gsm8k_prompt gsm8k_test = LightevalTaskConfig( name="gsm8k_test", - prompt_function=prompt.gsm8k, + prompt_function=gsm8k_prompt, hf_repo="gsm8k", hf_subset="main", hf_avail_splits=["train", "test"], @@ -42,7 +43,7 @@ gpqa_diamond_test = LightevalTaskConfig( name="gpqa:diamond_test", - prompt_function=prompt.gpqa_instruct, + prompt_function=gpqa_instruct_prompt, hf_repo="Idavidrein/gpqa", hf_subset="gpqa_diamond", hf_avail_splits=["train"], diff --git a/examples/nanotron/custom_evaluation_tasks.py b/examples/nanotron/custom_evaluation_tasks.py index adf0ae286..b079498d8 100644 --- a/examples/nanotron/custom_evaluation_tasks.py +++ b/examples/nanotron/custom_evaluation_tasks.py @@ -28,14 +28,21 @@ """ import re +from string import ascii_uppercase from typing import List, Tuple -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import LogProbCharNorm, helm_normalizer, math_normalizer -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks.arc import arc_prompt +from lighteval.tasks.tasks.gsm8k import gsm8k_prompt +from lighteval.tasks.tasks.math import math_prompt +from lighteval.tasks.tasks.openbookqa import openbookqa_prompt +from lighteval.tasks.tasks.piqa import piqa_prompt +from lighteval.tasks.tasks.quac import quac_prompt +from lighteval.tasks.tasks.triviaqa import triviaqa_prompt +from lighteval.tasks.tasks.winogrande import winogrande_prompt _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] @@ -48,7 +55,7 @@ def commonsense_qa_prompt(line, task_name: str = None): task_name=task_name, query=line["question"], choices=[f" {c}" for c in line["choices"]["text"]], - gold_index=LETTER_INDICES.index(line["answerKey"].strip()), + gold_index=ascii_uppercase.index(line["answerKey"].strip()), instruction="", ) @@ -99,7 +106,7 @@ def preprocess(text): ), LightevalTaskConfig( name="winogrande", - prompt_function=prompt.winogrande, + prompt_function=winogrande_prompt, hf_repo="winogrande", hf_subset="winogrande_xl", metrics=[ @@ -112,7 +119,7 @@ def preprocess(text): ), LightevalTaskConfig( name="piqa", - prompt_function=prompt.piqa_harness, + prompt_function=piqa_prompt, hf_repo="piqa", hf_subset="plain_text", metrics=[ @@ -139,7 +146,7 @@ def preprocess(text): ), LightevalTaskConfig( name="openbookqa", - prompt_function=prompt.openbookqa, + prompt_function=openbookqa_prompt, hf_repo="openbookqa", hf_subset="main", metrics=[ @@ -152,7 +159,7 @@ def preprocess(text): ), LightevalTaskConfig( name="arc:easy", - prompt_function=prompt.arc, + prompt_function=arc_prompt, hf_repo="ai2_arc", hf_subset="ARC-Easy", evaluation_splits=["test"], @@ -167,7 +174,7 @@ def preprocess(text): ), LightevalTaskConfig( name="arc:challenge", - prompt_function=prompt.arc, + prompt_function=arc_prompt, hf_repo="ai2_arc", hf_subset="ARC-Challenge", 
evaluation_splits=["test"], @@ -216,7 +223,7 @@ def natural_questions_prompt(line, task_name: str = None): WORLD_KNOWLEDGE_TASKS = [ LightevalTaskConfig( name="trivia_qa", - prompt_function=prompt.triviaqa, + prompt_function=triviaqa_prompt, hf_repo="trivia_qa", hf_subset="rc.nocontext", metrics=[ @@ -266,7 +273,7 @@ def boolq_prompt(line, task_name: str = None): ), LightevalTaskConfig( name="quac", - prompt_function=prompt.quac, + prompt_function=quac_prompt, hf_repo="lighteval/quac_helm", hf_subset="deault", metrics=[ @@ -290,7 +297,7 @@ class CustomMathEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset=None, metrics=[ @@ -329,7 +336,7 @@ def __init__( ] GSM8K = LightevalTaskConfig( name="gsm8k", - prompt_function=prompt.gsm8k, + prompt_function=gsm8k_prompt, hf_repo="gsm8k", hf_subset="main", hf_avail_splits=["train", "test"], @@ -352,10 +359,10 @@ def mmlu_harness(line, task_name: str = None): topic = line["subject"] prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" prompt += line["question"] + "\n" - prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"])]) prompt += "Answer:" - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + gold_ix = ascii_uppercase.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] "__few_shots" in line and line["__few_shots"] is True # We are adding few shots return Doc( @@ -590,7 +597,7 @@ def agi_eval_prompt(line, task_name: str = None): prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" prompt += "Answer: " - choices = LETTER_INDICES[: len(line["options"])] + choices = ascii_uppercase[: len(line["options"])] output = Doc( query=prompt, @@ -599,7 +606,7 @@ def agi_eval_prompt(line, task_name: str = None): if line["label"]: output.choices = choices - output.gold_index = LETTER_INDICES.index(line["label"].strip()) + output.gold_index = ascii_uppercase.index(line["label"].strip()) else: output.choices = [line["answer"]] output.gold_index = 0 @@ -616,7 +623,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): output = Doc( query=line["question"], choices=cleaned_options, - gold_index=LETTER_INDICES.index(line["label"].strip()), + gold_index=ascii_uppercase.index(line["label"].strip()), instruction="", ) diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py deleted file mode 100644 index a78860168..000000000 --- a/src/lighteval/tasks/default_prompts.py +++ /dev/null @@ -1,2896 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. 
- -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import ast -import json -import logging -import random -import re -import string -from typing import Optional - -import numpy as np -import pycountry - -from lighteval.tasks.requests import Doc -from lighteval.utils.utils import as_list - - -logger = logging.getLogger(__name__) - - -# fmt: off -LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] -INTEGER_INDICES = list(map(str, list(range(1, 27)))) -# fmt: on - - -def mmmu_pro(line, task_name: Optional[str] = None): - # fmt: off - question = line["question"] # "What is the capital of France?" - choices_string = line["options"] # "[Paris, London, Berlin, Madrid]" - answer = line["answer"] # "A" - # fmt: on - - instructions = "Answer with the option letter from the given choices directly." - - # Preprocess choices - # "[Paris, London, Berlin, Madrid]" -> ["A. Paris", "B. London", "C. Berlin", "D. Madrid"] - choices = ast.literal_eval(str(choices_string)) - choices_letters = [chr(ord("A") + i) for i in range(len(choices))] # ["A", "B", "C", "D"] - choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)] - - # Construct prompt - formatted_choices = "\n".join(choices) - prompt = f"\n{question}\n{formatted_choices}" - - # Collect images - image_order = [] - for num in re.findall(r"<image_(\d+)>", prompt): - num = int(num) - if num not in image_order: - image_order.append(num) - images = [line[f"image_{i}"].convert("RGB") for i in image_order] - - gold_index = string.ascii_uppercase.index(answer) - - # Replace image placeholders in prompt <image_1>, <image_2>, ... with [image 1], [image 2], ... - prompt = re.sub(r"<image_(\d+)>", "[image \\1]", prompt) - choices = [re.sub(r"<image_(\d+)>", "[image \\1]", choice) for choice in choices] - - return Doc( - task_name=task_name, - query=prompt, - choices=choices, - gold_index=gold_index, - images=images, - specific={"id": line["id"]}, - instruction=instructions, - ) - - -def mmmu_pro_vision(line, task_name: str = None): - instruction = ( - "Answer with the option letter from the given choices directly." - " The last line of your response should be of the following format: " - "'Answer: $LETTER' (without quotes) where LETTER is one of options." - ) - - # Preprocess choices - # "[Paris, London, Berlin, Madrid]" -> ["A. Paris", "B. London", "C. Berlin", "D. Madrid"] - choices_string = line["options"] - choices = ast.literal_eval(str(choices_string)) - choices_letters = [chr(ord("A") + i) for i in range(len(choices))] # ["A", "B", "C", "D"] - choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)] - - # Preprocess answer - # e.g.
"A" -> 0 - answer = line["answer"] - gold_index = string.ascii_uppercase.index(answer) - - # Preprocess images - images = [line["image"]] - - return Doc( - task_name=task_name, - query=instruction, - choices=choices, - gold_index=gold_index, - images=images, - instruction=instruction, - ) - - -def simpleqa(line, task_name: str = None): - query = line["problem"] - choices = [line["answer"]] - gold_index = 0 - - return Doc( - task_name=task_name, query=query, choices=choices, gold_index=gold_index, specific={**eval(line["metadata"])} - ) - - -def aime_prompt_fn(line, task_name: str = None): - # Prompt template adapted from - # - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17 - # - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details - # Note that it is important to have the final answer in a box for math-verify to work correctly - MATH_QUERY_TEMPLATE = """ -Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. - -{Question} -""".strip() - return Doc( - task_name=task_name, - query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), - choices=[line["answer"]], - gold_index=0, - ) - - -def anli(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['premise']}\nQuestion: {line['hypothesis']} True, False, or Neither?\nAnswer:", - choices=[" True", " Neither", " False"], - gold_index=int(line["label"]), - ) - - -def agieval(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["query"], - choices=[f" {c}" for c in line["choices"]], - gold_index=line["gold"], - ) - - -def apps(line, task_name: str = None): - answer_type = "\nUse Call-Based format\n" if line["starter_code"] != "" else "\nUse Standard Input format\n" - return Doc( - task_name=task_name, - query=f"\nQUESTION:\n{line['question']}\n{line['starter_code']}\n{answer_type}\nANSWER in Python code:\n", - choices=[json.loads(line["solutions"])], - gold_index=0, - specific={"input_output": line["input_output"]}, - ) - - -def arc_agi_2(line, task_name: str = None): - # query from: https://github.com/arcprize/model_baseline/blob/main/src/prompts/system_prompt.txt - def convert_2d_list_to_string(list_of_lists: list[list[int]]) -> str: - """Convert a list of lists to a string""" - string_list = "" - - for row in list_of_lists: - string_list += json.dumps(row) + "\n" - - return string_list - - query = """You are participating in a puzzle solving competition. You are an expert at solving puzzles. - -Below is a list of input and output pairs with a pattern. Your goal is to identify the pattern or transformation in the training examples that maps the input to the output, then apply that pattern to the test input to give a final output. 
- -Respond in the format of the training output examples - ---Training Examples-- -{training_examples} ---End of Training Examples-- - ---Test Input-- -{test_input} ---End of Test Input-- - -Your response:""".strip() - - training_pairs = line["fewshots"] - training_examples = "" - for i, pair in enumerate(training_pairs): - training_examples += f"--Example {i}-- \n\n INPUT: \n\n" - training_examples += convert_2d_list_to_string(pair["input"]) + "\n\n" - training_examples += "OUTPUT: \n\n" - training_examples += convert_2d_list_to_string(pair["output"]) + "\n\n" - - test_input = convert_2d_list_to_string(line["question"][0]["input"]) - - gold = str(line["question"][0]["output"]) - query = query.format(training_examples=training_examples, test_input=test_input) - - return Doc( - task_name=task_name, - query=query, - choices=[gold], - gold_index=0, - ) - - -def arc(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer:", - choices=[f" {c}" for c in line["choices"]["text"]], - gold_index=line["choices"]["label"].index(line["answerKey"]), - ) - - -def arc_with_options_letters_predict(line, task_name: str = None): - query = f"Question: {line['question']}\n" - query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])]) - query += "\nAnswer:" - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"]["text"])], - gold_index=line["choices"]["label"].index(line["answerKey"]), - ) - - -def arc_with_options(line, task_name: str = None): - query = f"Question: {line['question']}\n" - query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])]) - query += "\nAnswer:" - return Doc( - task_name=task_name, - query=query, - choices=line["choices"]["text"], - gold_index=line["choices"]["label"].index(line["answerKey"]), - ) - - -def arithmetic(line, task_name: str = None): - return Doc(task_name=task_name, query=line["context"], choices=[line["completion"]], gold_index=[0]) - - -def asdiv(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['body']}\nQuestion:{line['question']}\nAnswer:", - choices=line["answer"].split(" (")[0], - gold_index=[0], - ) - - -def babi_qa(line, task_name: str = None): # HELM - def process_path(path: str) -> str: - """Turn a path string (task 19) from the original format 's,w' into a verbal model-friendly format 'south west'""" - steps = path.split(",") - directions = {"s": "south", "n": "north", "e": "east", "w": "west"} - path = " ".join([directions[step] for step in steps]) - return path - - if isinstance(line["story"], dict): - line = line["story"] - else: - line = json.loads(line["story"]) - - queries = [] - story = [] - for type, text, answer in zip(line["type"], line["text"], line["answer"]): - if type == 1: - if ( - len(answer) == 3 and re.fullmatch(r"[nswe],[nswe]", answer) is not None - ): # task 19, we manage directions - answer = process_path(answer) - queries.append( - Doc( - task_name=task_name, - query=f"Passage: {' '.join(story)} {text}\nAnswer:", - choices=[answer], - gold_index=0, - ) - ) - else: - story.append(text) - return queries - - -def bbh_harness(line, task_name: str = None): - line = {k: v for k, v in line.items() if v not in [None, ""]} - - task_prefix = line.get("task_prefix", "") - example_input_prefix = line.get("example_input_prefix", "\nQ: ") - query = f"{task_prefix}{example_input_prefix}{line['input']}" - - rng = 
np.random.RandomState(seed=42) - choice_prefix = line.get("choice_prefix", "\n choice: ") - append_choices = bool(line.get("append_choices", True)) - # default - correct_index = line["target_idx"] - choices = line["choices"] - if append_choices: - choices = list(rng.permutation(sorted(line["choices"]))) - query = f"{query}{choice_prefix}{choice_prefix.join(choices)}" - gold_item = line["choices"][line["target_idx"]] - correct_index = choices.index(gold_item) - - example_output_prefix = line.get("example_output_prefix", "\nA: ") - query = f"{query}{example_output_prefix}" - - return Doc( - task_name=task_name, - query=query, - choices=choices, - gold_index=correct_index, - instruction=line.get("task_prefix", None), - ) - - -def bbh_lighteval(line, task_name: str = None): - line = {k: v for k, v in line.items() if v is not None} - - query = line.get("task_prefix", "") - query += line.get("example_input_prefix", "\nQuestion: ") - query += line["input"] - query += line.get("choice_prefix", "\n Choices: ") - query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += line.get("example_output_prefix", "\nAnswer: ") - - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"])], - gold_index=line["target_idx"], - instruction=line.get("task_prefix", None), - ) - - -def bbh(line, instruction, choices, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{instruction}Q: {line['input']}\nA:", - choices=choices, - gold_index=choices.index(line["target"]), - instruction=instruction, - ) - - -def bbh_boolean_expressions(line, task_name: str = None): - instruction = "Evaluate the result of a random Boolean expression.\n\n" - choices = ["False", "True"] - return bbh(line, instruction, choices, task_name) - - -def bbh_causal_judgment(line, task_name: str = None): - instruction = "Answer questions about causal attribution.\n\n" - choices = ["Yes", "No"] - return bbh(line, instruction, choices, task_name) - - -def bbh_date_understanding(line, task_name: str = None): - instruction = "Infer the date from context.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:6]] - return bbh(line, instruction, choices, task_name) - - -def bbh_disambiguation_qa(line, task_name: str = None): - instruction = "Clarify the meaning of sentences with ambiguous pronouns.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:3]] - return bbh(line, instruction, choices, task_name) - - -def bbh_dyck_languages(line, task_name: str = None): # Can only be done in generative - instruction = "Correctly close a Dyck-n word.\n\n" - choices = [line["target"]] - return bbh(line, instruction, choices, task_name) - - -def bbh_formal_fallacies(line, task_name: str = None): - instruction = "Distinguish deductively valid arguments from formal fallacies.\n\n" - choices = ["valid", "invalid"] - return bbh(line, instruction, choices, task_name) - - -def bbh_geometric_shapes(line, task_name: str = None): - instruction = "Name geometric shapes from their SVG paths.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:11]] - return bbh(line, instruction, choices, task_name) - - -def bbh_hyperbaton(line, task_name: str = None): - instruction = "Order adjectives correctly in English sentences.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:2]] - return bbh(line, instruction, choices, task_name) - - -def bbh_logical_deduction_five_objects(line, task_name: str = None): - instruction = "A logical deduction task which requires deducing the order of a 
sequence of objects.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:5]] - return bbh(line, instruction, choices, task_name) - - -def bbh_logical_deduction_seven_objects(line, task_name: str = None): - instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:7]] - return bbh(line, instruction, choices, task_name) - - -def bbh_logical_deduction_three_objects(line, task_name: str = None): - instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:3]] - return bbh(line, instruction, choices, task_name) - - -def bbh_movie_recommendation(line, task_name: str = None): - if line["target"] == "Monsters, Inc": # this line is not correctly formatted - logger.warning( - "One sample removed from task bbh:movie_recommendation because its line is incorrectly formatted." - ) - return [] - instruction = "Recommend movies similar to the given list of movies.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:6]] - return bbh(line, instruction, choices, task_name) - - -def bbh_multistep_arithmetic_two(line, task_name: str = None): - instruction = "Solve multi-step arithmetic problems.\n\n" # Can only be done in generative - choices = [line["target"]] - return bbh(line, instruction, choices, task_name) - - -def bbh_navigate(line, task_name: str = None): - instruction = ( - "Given a series of navigation instructions, determine whether one would end up back at the starting point.\n\n" - ) - choices = ["Yes", "No"] - return bbh(line, instruction, choices, task_name) - - -def bbh_object_counting(line, task_name: str = None): - instruction = "Questions that involve enumerating objects and asking the model to count them.\n\n" - choices = [str(i) for i in range(1, 19)] - return bbh(line, instruction, choices, task_name) - - -def bbh_penguins_in_a_table(line, task_name: str = None): - instruction = "Answer questions about a table of penguins and their attributes.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:5]] - return bbh(line, instruction, choices, task_name) - - -def bbh_reasoning_about_colored_objects(line, task_name: str = None): - instruction = "Answer extremely simple questions about the colors of objects on a surface.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:18]] - return bbh(line, instruction, choices, task_name) - - -def bbh_ruin_names(line, task_name: str = None): - if line["target"] in ["dearth, wind, & fire", "rita, sue and bob poo"]: # line not correctly formatted - logger.warning("One sample removed from task bbh:ruin_names because its line is incorrectly formatted.") - return [] - instruction = "Select the humorous edit that 'ruins' the input movie or musical artist name.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:6]] - return bbh(line, instruction, choices, task_name) - - -def bbh_salient_translation_error_detection(line, task_name: str = None): - instruction = "Detect the type of error in an English translation of a German source sentence.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:6]] - return bbh(line, instruction, choices, task_name) - - -def bbh_snarks(line, task_name: str = None): - instruction = 'Determine which of two sentences is sarcastic.\n\nAccording to Cambridge University Dictionary, sarcasm is "the use of remarks that clearly mean the opposite of what they say, made in order to hurt someone\'s feelings or to criticize something in a humorous way." 
Sarcastic sentences often contain satirical or ironic utterances, hyperboles, ambivalent or witty remarks.\n\n' - choices = [f"({c})" for c in LETTER_INDICES[:2]] - return bbh(line, instruction, choices, task_name) - - -def bbh_sports_understanding(line, task_name: str = None): - instruction = "Determine whether an artificially constructed sentence relating to sports is plausible or not.\n\n" - choices = ["yes", "no"] - return bbh(line, instruction, choices, task_name) - - -def bbh_temporal_sequences(line, task_name: str = None): - instruction = "Task description: Answer questions about which times certain events could have occurred.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:4]] - return bbh(line, instruction, choices, task_name) - - -def bbh_tracking_shuffled_objects_five_objects(line, task_name: str = None): - instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:5]] - return bbh(line, instruction, choices, task_name) - - -def bbh_tracking_shuffled_objects_seven_objects(line, task_name: str = None): - instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:7]] - return bbh(line, instruction, choices, task_name) - - -def bbh_tracking_shuffled_objects_three_objects(line, task_name: str = None): - instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" - choices = [f"({c})" for c in LETTER_INDICES[:3]] - return bbh(line, instruction, choices, task_name) - - -def bbh_web_of_lies(line, task_name: str = None): - instruction = "Evaluate a random boolean function expressed as a word problem.\n\n" - choices = ["Yes", "No"] - return bbh(line, instruction, choices, task_name) - - -def bbh_word_sorting(line, task_name: str = None): - instruction = "Sort a list of words.\n\n" # Can only be done in generative - choices = [line["target"]] - return bbh(line, instruction, choices, task_name) - - -def bbq(line, task_name: str = None): # HELM - query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}" - query += "".join([f"\n{key}. 
{choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "\nAnswer:" - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"])], - gold_index=int(line["gold_index"]), - ) - - -def bigbench_helm(line, task_name: str = None): - if "target" in line: - return Doc(task_name=task_name, query=line["input"], choices=[line["target"]], gold_index=0) - choices, gold_ix = [], -1 - if isinstance(line["target_scores"], str): - line["target_scores"] = ast.literal_eval(line["target_scores"]) - for ix, (choice, score) in enumerate(line["target_scores"].items()): - choices.append(choice) - if score == 1: - gold_ix = ix - - return Doc(task_name=task_name, query=line["input"], choices=choices, gold_index=gold_ix) - - -def blimp(line, task_name: str = None): - return Doc(task_name=task_name, query="", choices=[line["sentence_good"], line["sentence_bad"]], gold_index=0) - - -def blimp_helm(line, task_name: str = None): - return Doc( - task_name=task_name, - query="Please select the grammatical sentence.", - choices=[line["sentence_good"], line["sentence_bad"]], - gold_index=0, - ) - - -def bold(line, task_name: str = None): - return Doc( - task_name=task_name, query=line["text"], choices=None, gold_index=None - ) # we only look at the perplexity of the generation > no gold - - -def boolq(line, task_name: str = None): - # Exact match - return Doc( - task_name=task_name, - query=f"Passage: {line['passage']}\nQuestion: {line['question']}?\nChoices:\n (A) True\n (B) False\nAnswer: ", - choices=[ - ("(B)", "(b)", "False", "false", "(B) False", "(b) false"), # False is label 0 - ("(A)", "(a)", "True", "true", "(A) True", "(a) true"), - ], - gold_index=int(line["label"]), - # "metric": "exact_match_greedy", - ) - - -def boolq_helm(line, task_name: str = None): - # remove extra `?` - question = line["question"][:-1] if line["question"][-2:] == "??" else line["question"] - return Doc( - task_name=task_name, - query=f"Passage: {line['passage']}\nQuestion: {question}\nAnswer:", - choices=[" Yes", " No"], - gold_index=["Yes", "No"].index(line["answer"]), - ) - - -def boolq_helm_contrastset(line, task_name: str = None): - if line["contrast_inputs"] in [None, ""]: - return boolq_helm(line) - - return [ - Doc( - task_name=task_name, - query=f"{passage}\nQuestion: {question}\nAnswer:", - choices=["Yes", "No"], - gold_index=["No", "Yes"].index(line["answer"]), - ) - for passage, question in zip(line["contrast_inputs"]["passage"], line["contrast_inputs"]["question"]) - ][0] - - -def boolq_harness(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['passage']}\nQuestion: {line['question']}?\nAnswer:", - choices=[" no", " yes"], # False is label 0 - gold_index=int(line["label"]), - ) - - -def cb(line, task_name: str = None): - # Log likelihood - return Doc( - task_name=task_name, - query=f"{line['premise']}\nQuestion: {line['hypothesis']}. 
True, False or Neither?\nAnswer:", - choices=[" True", " False", " Neither"], - gold_index=int(line["label"]), - ) - - -def civil_comments(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Passage: {line['text']}\nIs the passage above toxic?\nAnswer: ", - choices=[str(line["gold"])], - gold_index=0, - ) - - -def cnn_dm(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 3 sentences.\n", - choices=[str(line["summary"])], - gold_index=0, - specific={"text": line["article"]}, - ) - - -def cola(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['sentence']}\nQuestion: Does this sentence make sense?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - ) - - -def commonsense_qa(line, task_name: str = None): - query = f"The following are multiple choice questions (with answers) about common sense.\nQuestion: {line['question']}\n" - query += "".join( - [f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, [f" {c}" for c in line["choices"]["text"]])] - ) - query += "Answer:" - - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"]["text"])], - gold_index=LETTER_INDICES.index(line["answerKey"].strip()), - instruction="The following are multiple choice questions (with answers) about common sense.\n", - ) - - -def copa(line, task_name: str = None): - connector = {"cause": "because", "effect": "therefore"}[line["question"]] - return Doc( - task_name=task_name, - query=f"{line['premise'].strip()[:-1]} {connector}", - choices=[f" {line[c][0].lower()}{line[c][1:]}" for c in ["choice1", "choice2"]], # removing first cap - gold_index=int(line["label"]), - ) - - -def copyright(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["prefix"], - choices=[line["book"]], # gold reference - gold_index=0, - ) - - -def coqa(line, task_name: str = None): - results = [] - - # We return the first question only atm - for q, a in zip(line["questions"], line["answers"]["input_text"]): - results.append(Doc(task_name=task_name, query=f"{line['story']} \n\nQ: {q}\n\nA: ", choices=[a], gold_index=0)) - return results - - -def covid_dialogue(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Generate a response given a patient's questions and concerns.\nPatient: {line['query']}\nDoctor: ", - choices=[line["answer"]], - gold_index=0, - instruction="Generate a response given a patient's questions and concerns.\n", - ) - - -def crows_pair(line, task_name: str = None): - return Doc(task_name=task_name, query="", choices="", gold_index="", instruction="") - - -def dyck_language(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n Input: {line['input']}", - choices=[line["output"]], - gold_index=0, - instruction="Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n ", - ) - - -def drop(line, task_name: str = None): - # For the Harness new format, v0.0.1 - def _flatten_validated_answers(validated_answers): - """Flattens a dict of lists of validated answers. 
- {"number": ['1', '8'], ...} - -> [{"number": ['1'], ...}, {"number": ['8'], ...}] - - Returns: - list: List of dictionaries with flattened validated answers - """ - valid_answers = [] - for i in range(len(validated_answers["number"])): - valid_answers.append( - { - "number": validated_answers["number"][i], - "date": validated_answers["date"][i], - "spans": validated_answers["spans"][i], - } - ) - return valid_answers - - def parse_answer(answer): - # remove this tuple system, legacy from the harness, only here to allow to use answers as key in the below set - if answer["number"] != "": - return (str(answer["number"]),) - if answer["spans"] != []: - return tuple(answer["spans"]) - return (" ".join([answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]).strip(),) - - answers = [] - answers_set = set() - candidates = [line["answer"]] + _flatten_validated_answers(line["validated_answers"]) - for candidate in candidates: - answer = parse_answer(candidate) - if answer in answers_set: - continue - answers.append(answer) - # answers.extend(as_list(answer)) - answers_set.add(answer) - - is_few_shots = line.get("__few_shots", False) # We are adding few shots - - return Doc( - task_name=task_name, - query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:", - choices=[f"{' ' if is_few_shots else ''}{', '.join(a)}" for a in answers], - gold_index=list(range(len(answers))), - specific={"golds_no_preprocessing": [as_list(a) for a in answers]}, - ) - - -def empathetic_dialogue(line, task_name: str = None): - return Doc( - task_name=task_name, query=f"{line['input']}\nBEGIN DIALOGUE\n", choices=[line["references"]], gold_index=0 - ) - - -def entity_data_imputation(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"What is the missing value?\n{line['text']}\nAnswer:", - choices=[line["gold"]], - gold_index=0, - instruction="What is the missing value?\n", - ) - - -def entity_matching(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Are Product A and Product B the same? Yes or No?\nProduct A is {line['productA']}. Product B is {line['productB']}. Are A and B the same?\nAnswer:", - choices=["No", "Yes"], - gold_index=int(line["same"]), - instruction="Are Product A and Product B the same? Yes or No?\n", - ) - - -def ethics_commonsense(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['input']}\nQuestion: Is this wrong?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - ) - - -def ethics_deontology(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f'Question: Would most people believe this reasonable or unreasonable to say? "{line["scenario"]} {line["excuse"]}"\nAnswer:', - choices=[" unreasonable", " reasonable"], - gold_index=int(line["label"]), - instruction="Question: Would most people believe this reasonable or unreasonable to say?", - ) - - -def ethics_justice(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f'Question: Would most people believe this reasonable or unreasonable to say? 
"{line["scenario"]}"\nAnswer:', - choices=[" unreasonable", " reasonable"], - gold_index=int(line["label"]), - instruction="Question: Would most people believe this reasonable or unreasonable to say?", - ) - - -def ethics_utilitarianism(line, task_name: str = None): - rnd = random.Random(line["activity"]) - scenarios = [line["activity"], line["baseline"]] - ordering = [0, 1] - rnd.shuffle(ordering) - return Doc( - task_name=task_name, - query=f"Scenario 1: {scenarios[ordering[0]]}\nScenario 2: {scenarios[ordering[1]]}\nQuestion: Is Scenario 1 preferable?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(ordering.index(0) == 0), - ) - - -def ethics_virtue(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f'Sentence: {line["scenario"]}\nQuestion: Does the character in this sentence exhibit the trait "{line["trait"]}"?\nAnswer:', - choices=[" no", " yes"], - gold_index=int(line["label"]), - ) - - -def gpqa(line, task_name: str = None): - # Prompt template from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14 - GPQA_QUERY_TEMPLATE = """ -Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. - -{Question} - -A) {A} -B) {B} -C) {C} -D) {D} -""".strip() - gold_index = random.randint(0, 3) - choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] - choices.insert(gold_index, line["Correct Answer"]) - - query = GPQA_QUERY_TEMPLATE.format( - A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"] - ) - - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(choices)], - gold_index=gold_index, - instruction=query, - ) - - -def gpqa_instruct(line, task_name: str = None): - """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14""" - gold_index = random.randint(0, 3) - choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] - choices.insert(gold_index, line["Correct Answer"]) - instruction = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering." - query_template = "{Instruction}\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" - query = query_template.format( - # Stripping to avoid accidental extra whitespaces, present in GPQA - A=choices[0].strip(), - B=choices[1].strip(), - C=choices[2].strip(), - D=choices[3].strip(), - Question=line["Question"].strip(), - Instruction=instruction, - ) - - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(choices)], - gold_index=gold_index, - instruction=instruction, - ) - - -def gsm_plus(line, task_name: str = None): - # GSM8K with 8 prompt variations per sample - - # Some prompts require critical thinking (around 1k/10k), we skip them as - # they are a bit trickier to eval with regular text extraction. 
- if line["perturbation_type"] == "critical thinking": - return None - - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\n\nAnswer:", - choices=[line["answer"]], - gold_index=0, - ) - - -def gsm8k(line, task_name: str = None): - # Has special analysis in metric for number decomposition - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer:", - choices=[f" {line['answer']}"], - gold_index=0, - ) - - -def gsm8k_helm(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Q: {line['question']}\nA: ", - choices=[line["answer"].replace("####", "The answer is").replace("\n", " ") + "."], - gold_index=0, - ) - - -def headqa(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Question: {line['qtext']}\nAnswer:", - choices=[f" {answer['atext']}" for answer in line["answers"]], - gold_index=int(line["ra"]) - 1, - ) - - -def hellaswag_preprocess( - text: str, - wikihow_artifacts: list[str] = [" [title]"], - truncate_dots: bool = False, - strip_text: bool = False, - dot_replacement: str = ". ", -): - """Comes from LM Eval Harness""" - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - for wikihow_artifact in wikihow_artifacts: - text = text.replace(wikihow_artifact, dot_replacement) - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - if truncate_dots: - text = text.replace(r"\.+", r"\.") - if strip_text: - text = text.strip() - return text - - -def hellaswag_harness(line, task_name: str = None): - ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " - return Doc( - task_name=task_name, - query=hellaswag_preprocess(line["activity_label"] + ": " + ctx), - choices=[hellaswag_preprocess(ending) for ending in line["endings"]], - gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test - # "metric": "choices_loglikelihood", - ) - - -def hellaswag_generative(line, task_name: str = None): - query = "The following are multiple choice questions (with answers) about common sense.\n\n" - query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n" - query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(LETTER_INDICES, line["endings"])]) - query += "Answer:" - - gold_ix = int(line["label"]) if line["label"] != "" else -1 - return Doc( - task_name=task_name, - query=query, - choices=[" " + i for i in LETTER_INDICES[: len(line["endings"])]], - gold_index=gold_ix, # -1 for test, - instruction="The following are multiple choice questions (with answers) about common sense.\n\n", - ) - - -def humaneval(line, task_name: str = None): - # "test_cases": line["test"] - return Doc( - task_name=task_name, - query=line["prompt"], - choices=[line["canonical_solution"]], - gold_index=0, - specific={key: line[key] for key in ["prompt", "test", "entry_point", "task_id"]}, - ) - - -def humaneval_for_code_models(line, task_name: str = None): - # We need to remove ending "\n" as it's never tokenized on its own but rather as "\n\t" - query = line["Doc"][:-1] if line["Doc"][-1:] == "\n" else line["Doc"] - return Doc(task_name=task_name, query=query, choices=[line["canonical_solution"]], gold_index=0, specific=line) - - -def imdb(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Passage: {line['input']}\nSentiment: ", - choices=["Positive", "Negative"], - gold_index=["Positive", "Negative"].index(line["reference"]), - ) - - -def imdb_contrastset(line, task_name: str = None): - if line["contrast_input"] is None or line["contrast_references"] is None: - return imdb(line) - - return Doc( - task_name=task_name, - query=f"Passage: {line['contrast_inputs']}\nSentiment: ", - choices=["Positive", "Negative"], - gold_index=["Positive", "Negative"].index(line["contrast_references"]), - ) - - -def lambada_cloze(line, task_name: str = None): - query, choice = line["text"].rsplit(" ", 1) - return Doc( - task_name=task_name, - query=f"{query} ____. ->", - gold_index=0, - choices=[f" {choice}"], - ) - - -def lambada(line, task_name: str = None): - query, choice = line["text"].rsplit(" ", 1) - return Doc( - task_name=task_name, - query=query, - gold_index=0, - choices=[f" {choice}"], - ) - - -def legal_support(line, task_name: str = None): - query = f"Which statement best supports the passage?\nPassage: {line['context']}\n" - query += "".join( - [ - f"{key}. {choice}\n" - for key, choice in zip( - ["a", "b"], [line["citation_a"]["parenthetical"], line["citation_b"]["parenthetical"]] - ) - ] - ) - query += "Answer:" - - return Doc( - task_name=task_name, - query=query, - choices=["a", "b"], - gold_index=["a", "b"].index(line["label"]), - instruction="Which statement best supports the passage?\n", - ) - - -def lex_glue(line, instruction, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{instruction}\nPassage: {line['input']}\nAnswer: ", - choices=line["references"], - gold_index=[line["references"].index(item) for item in line["gold"]], - instruction=instruction + "\n", - ) - - -def lex_glue_ecthr_a(line, task_name: str = None): - instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of the ECtHR that were violated (if any)." - return lex_glue(line, instruction, task_name) - - -def lex_glue_ecthr_b(line, task_name: str = None): - instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of ECtHR that were allegedly violated (considered by the court)." 
- return lex_glue(line, instruction, task_name) - - -def lex_glue_scotus(line, task_name: str = None): - instruction = "In this task, you are given a case heard at the Supreme Court of the United States (SCOTUS). Predict the relevant issue area." - return lex_glue(line, instruction, task_name) - - -def lex_glue_eurlex(line, task_name: str = None): - instruction = "In this task, you are given an EU law document published in the EUR-Lex portal. Predict the relevant EuroVoc concepts." - return lex_glue(line, instruction, task_name) - - -def lex_glue_ledgar(line, task_name: str = None): - instruction = "In this task, you are given a contract provision \nfrom contracts obtained from US Securities and Exchange Commission (SEC) filings. Predict the main topic." - return lex_glue(line, instruction, task_name) - - -def lex_glue_unfair_tos(line, task_name: str = None): - instruction = "In this task, you are given a sentence \nfrom a Terms of Service (ToS) document from on-line platforms. Predict the types of unfair contractual terms" - return lex_glue(line, instruction, task_name) - - -def lex_glue_case_hold(line, task_name: str = None): - instruction = "In this task, you are given an excerpt from a court decision, \ncontaining a reference to a particular case, while the holding statement is masked out. Predict the index of the holding statement fitting in the context at from a selection of five choices." - return lex_glue(line, instruction, task_name) - - -def lextreme(line, instruction, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{instruction}\nPassage: {line['input']}\nAnswer: ", - choices=line["references"], - gold_index=[line["references"].index(item) for item in line["gold"]], - instruction=instruction + "\n", - ) - - -def lextreme_brazilian_court_decisions_judgment(line, task_name: str = None): - instruction = ( - "In this task, you are given the case description " - "from a decision heard at the State Supreme Court of Alagoas (Brazil). " - "Predict the judgment of the case " - "(no: The appeal was denied, " - "partial: For partially favourable decisions, " - "yes: For fully favourable decisions)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_brazilian_court_decisions_unanimity(line, task_name: str = None): - instruction = ( - "In this task, you are given the case description " - "from a decision heard at the State Supreme Court of Alagoas (Brazil). " - "Predict the unanimity of the case (unanimity, not-unanimity, not_determined)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_german_argument_mining(line, task_name: str = None): - instruction = ( - "In this task, you are given sentences from German court decisions. " - "Predict the major component of German Urteilsstil " - "(conclusion: Overall result, " - "definition: Abstract legal facts and consequences, " - "subsumption: Determination sentence / Concrete facts, " - "other: Anything else)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_greek_legal_code_chapter(line, task_name: str = None): - instruction = ( - "In this task, you are given a Greek legislative document. " - "Predict the chapter level category of the " - "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_greek_legal_code_subject(line, task_name: str = None): - instruction = ( - "In this task, you are given a Greek legislative document. 
" - "Predict the subject level category of the " - "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." - ) - - return lextreme(line, instruction, task_name) - - -def lextreme_greek_legal_code_volume(line, task_name: str = None): - instruction = ( - "In this task, you are given a Greek legislative document. " - "Predict the volume level category of the " - "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_swiss_judgment_prediction(line, task_name: str = None): - instruction = ( - "In this task, you are given the facts description " - "from a decision heard at the Swiss Federal Supreme Court. " - "Predict the judgment of the case (approval or dismissal)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_online_terms_of_service_unfairness_levels(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence " - "from a Terms of Service (ToS) document. " - "Predict the unfairness level of the sentence (potentially_unfair, clearly_unfair, clearly_fair, untagged)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_online_terms_of_service_clause_topics(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence " - "from a Terms of Service (ToS) document. " - "Predict the clause topics of the sentence " - "(0: Arbitration, " - "1: Unilateral change, " - "2: Content removal, " - "3: Jurisdiction, " - "4: Choice of law, " - "5: Limitation of liability, " - "6: Unilateral termination, " - "7: Contract by using, " - "8: Privacy included)" - ) - return lextreme(line, instruction, task_name) - - -def lextreme_covid19_emergency_event(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence from a European legislative document. " - "Predict the applicable measurements against COVID-19 " - "(0: State of Emergency, " - "1: Restrictions of fundamental rights and civil liberties, " - "2: Restrictions of daily liberties, " - "3: Closures / lockdown, " - "4: Suspension of international cooperation and commitments, " - "5: Police mobilization, " - "6: Army mobilization, " - "7: Government oversight)" - ) - - return lextreme(line, instruction, task_name) - - -def lextreme_multi_eurlex_level_1(line, task_name: str = None): - instruction = ( - "In this task, you are given a document from an EU law. Predict the level 1 concept in the EUROVOC taxonomy." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_multi_eurlex_level_2(line, task_name: str = None): - instruction = ( - "In this task, you are given a document from an EU law. Predict the level 2 concept in the EUROVOC taxonomy." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_multi_eurlex_level_3(line, task_name: str = None): - instruction = ( - "In this task, you are given a document from an EU law. Predict the level 3 concept in the EUROVOC taxonomy." - ) - - return lextreme(line, instruction, task_name) - - -def lextreme_greek_legal_ner(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence from Greek legislation. Predict the named entity type for each token." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_legalnero(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence from Romanian legislation. " - "Predict the named entity type for each token." 
- ) - return lextreme(line, instruction, task_name) - - -def lextreme_lener_br(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence " - "from Brazilian legal documents (court decisions and legislation). " - "Predict the named entity type for each token." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_mapa_coarse(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence from the EUR-Lex database. " - "Predict the coarse grained named entity type for each token." - ) - return lextreme(line, instruction, task_name) - - -def lextreme_mapa_fine(line, task_name: str = None): - instruction = ( - "In this task, you are given a sentence from the EUR-Lex database. " - "Predict the fine grained named entity type for each token." - ) - return lextreme(line, instruction, task_name) - - -def legal_summarization(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle: {line['article']}\n\nSummarize the above article.\n", - gold_index=0, - choices=[line["summary"]], - specific={"text": line["article"]}, - ) - - -def mgsm(line, question_key, answer_key, task_name: str = None): - if line["answer"] is not None: - query = f"{line['question']}\n{answer_key}" - gold = f" {line['answer'][len(answer_key) + 1 :]}" - else: - query = f"{question_key} {line['question']}\n{answer_key}" - gold = f" {str(line['answer_number'])}" - return Doc(task_name=task_name, query=query, choices=[gold], gold_index=0) - - -def mgsm_en(line, task_name: str = None): - question_key = "Question:" - answer_key = "Step-by-Step Answer:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_es(line, task_name: str = None): - question_key = "Pregunta:" - answer_key = "Respuesta paso a paso:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_fr(line, task_name: str = None): - question_key = "Question:" - answer_key = "R\u00e9ponse \u00e9tape par \u00e9tape :" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_de(line, task_name: str = None): - question_key = "Frage:" - answer_key = "Schritt-f\u00fcr-Schritt-Antwort:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_ru(line, task_name: str = None): - question_key = "\u0417\u0430\u0434\u0430\u0447\u0430:" - answer_key = "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_zh(line, task_name: str = None): - question_key = "\u95ee\u9898:" - answer_key = "\u9010\u6b65\u89e3\u7b54:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_ja(line, task_name: str = None): - question_key = "\u554f\u984c:" - answer_key = "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_th(line, task_name: str = None): - question_key = "\u0e42\u0e08\u0e17\u0e22\u0e4c:" - answer_key = "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_sw(line, task_name: str = None): - question_key = "Swali:" - answer_key = "Jibu la Hatua kwa Hatua:" - return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_bn(line, task_name: str = None): - question_key = "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:" - answer_key = "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:" - 
return mgsm(line, question_key, answer_key, task_name) - - -def mgsm_te(line, task_name: str = None): - question_key = "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:" - answer_key = "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:" - return mgsm(line, question_key, answer_key, task_name) - - -def multilexsum(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle: {line['article']}\n\nSummarize the above article in 2 sentences.\n", - gold_index=0, - choices=[line["summary"]], - specific={"text": line["article"]}, - ) - - -def logiqa(line, task_name: str = None): - query = f"Passage: {line['context']}\nQuestion: {line['question']}\nChoices:\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(["A", "B", "C", "D"], line["options"])]) - query += "Answer:" - - return Doc( - task_name=task_name, - query=query, - choices=[f" {c}" for c in line["options"]], - gold_index=["a", "b", "c", "d"].index(line["label"]), - ) - - -def lsat_qa(line, task_name: str = None): - query = f"The following are multiple choice questions (with answers).\nPassage: {line['passage']}\nQuestion: {line['question']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["references"])]) - query += "Answer:" - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["references"])], - gold_index=line["gold_index"], - instruction="The following are multiple choice questions (with answers).\n", - ) - - -def math_500(line, task_name: str = None): - # Prompt template adapted from - # - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17 - # - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details - # Note that it is important to have the final answer in a box for math-verify to work correctly - MATH_QUERY_TEMPLATE = """ -Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. - -{Question} -""".strip() - - return Doc( - task_name=task_name, - query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), - gold_index=0, - choices=[line["solution"]], - ) - - -def math(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Problem: {line['problem']}\nAnswer:", - gold_index=0, - choices=[f" {line['solution']}"], - ) - - -def math_cot(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['problem']}\nPlease reason step by step, and put your final answer within \\boxed{{}}.", - gold_index=0, - choices=[f" {line['solution']}"], - ) - - -def math_helm(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Given a mathematics problem, determine the answer. Simplify your answer as much as possible.\nProblem: {line['problem']}\nAnswer: $\n###\n", - gold_index=0, - choices=[f" {line['solution']}"], - instruction="Given a mathematics problem, determine the answer. 
Simplify your answer as much as possible.\n", - ) - - -def mathqa(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Questions: {line['Problem']}\nAnswer", - choices=[ - c[4:].rstrip(" ,") # todo: check below regew - for c in re.findall(r"[abcd] \) .*?, |e \) .*?$", line["options"]) - ], - gold_index=["a", "b", "c", "d", "e"].index(line["correct"]), - ) - - -def me_q_sum(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle:{line['query']}\n\nSummarize the above article in 1 sentence.\n", - gold_index=0, - choices=[line["answer"]], - ) - - -def med_dialog(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle:{line['src']}\n\nSummarize the above article in 1 sentence.\n", - gold_index=0, - choices=[line["tgt"]], - ) - - -def med_mcqa(line, task_name: str = None): - query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n" - query += "".join( - [ - f"{key}. {choice}\n" - for key, choice in zip(LETTER_INDICES, [line["opa"], line["opb"], line["opc"], line["opd"]]) - ] - ) - query += "Answer:" - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[:4], - gold_index=line["cop"] - 1, - instruction="Give a letter answer among A, B, C or D.\n", - ) - - -def med_paragraph_simplification(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle:{line['query']}\n\nSummarize the above article in 10 sentences.\n", - gold_index=0, - choices=[line["answer"]], - ) - - -def med_qa(line, task_name: str = None): - query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n" - query += "".join([f"{option['key']}. {option['value']}\n" for option in line["options"]]) - query += "Answer:" - return Doc( - task_name=task_name, - query=query, - choices=[opt["key"] for opt in line["options"]], - gold_index=LETTER_INDICES.index(line["answer_idx"]), - instruction="Give a letter answer among A, B, C or D.\n", - ) - - -def mmlu(line, topic, task_name: str = None): - query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - query += line["question"] + "\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "Answer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - is_few_shots = line.get("__few_shots", False) # We are adding few shots - - return Doc( - task_name=task_name, - query=query, - choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def custom_mmlu_thom(line, task_name: str = None): - topic = "abstract_algebra" - query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - query += line["question"] + "\n" - choices = [line["option_1"], line["option_2"], line["option_3"], line["option_4"]] - query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(LETTER_INDICES, choices)]) - query += "Answer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - is_few_shots = line.get("__few_shots", False) # We are adding few shots - - return Doc( - task_name=task_name, - query=query, - choices=[" A", " B", " C", " D"] if is_few_shots else ["A", "B", "C", "D"], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def mmlu_abstract_algebra(line, task_name: str = None): - return mmlu(line, "abstract_algebra", task_name) - - -def mmlu_anatomy(line, task_name: str = None): - return mmlu(line, "anatomy", task_name) - - -def mmlu_astronomy(line, task_name: str = None): - return mmlu(line, "astronomy", task_name) - - -def mmlu_business_ethics(line, task_name: str = None): - return mmlu(line, "business_ethics", task_name) - - -def mmlu_clinical_knowledge(line, task_name: str = None): - return mmlu(line, "clinical_knowledge", task_name) - - -def mmlu_college_biology(line, task_name: str = None): - return mmlu(line, "college_biology", task_name) - - -def mmlu_college_chemistry(line, task_name: str = None): - return mmlu(line, "college_chemistry", task_name) - - -def mmlu_college_computer_science(line, task_name: str = None): - return mmlu(line, "college_computer_science", task_name) - - -def mmlu_college_mathematics(line, task_name: str = None): - return mmlu(line, "college_mathematics", task_name) - - -def mmlu_college_medicine(line, task_name: str = None): - return mmlu(line, "college_medicine", task_name) - - -def mmlu_college_physics(line, task_name: str = None): - return mmlu(line, "college_physics", task_name) - - -def mmlu_computer_security(line, task_name: str = None): - return mmlu(line, "computer_security", task_name) - - -def mmlu_conceptual_physics(line, task_name: str = None): - return mmlu(line, "conceptual_physics", task_name) - - -def mmlu_econometrics(line, task_name: str = None): - return mmlu(line, "econometrics", task_name) - - -def mmlu_electrical_engineering(line, task_name: str = None): - return mmlu(line, "electrical_engineering", task_name) - - -def mmlu_elementary_mathematics(line, task_name: str = None): - return mmlu(line, "elementary_mathematics", task_name) - - -def mmlu_formal_logic(line, task_name: str = None): - return mmlu(line, "formal_logic", task_name) - - -def mmlu_global_facts(line, task_name: str = None): - return mmlu(line, "global_facts", task_name) - - -def mmlu_high_school_biology(line, task_name: str = None): - return mmlu(line, "high_school_biology", task_name) - - -def mmlu_high_school_chemistry(line, task_name: str = None): - return mmlu(line, "high_school_chemistry", task_name) - - -def mmlu_high_school_computer_science(line, task_name: str = None): - return mmlu(line, "high_school_computer_science", task_name) - - -def mmlu_high_school_european_history(line, task_name: str = None): - return mmlu(line, "high_school_european_history", task_name) - - -def mmlu_high_school_geography(line, task_name: str = None): - return mmlu(line, "high_school_geography", task_name) - - -def mmlu_high_school_government_and_politics(line, task_name: str = None): - return mmlu(line, "high_school_government_and_politics", task_name) - - -def mmlu_high_school_macroeconomics(line, task_name: str = None): - return mmlu(line, "high_school_macroeconomics", task_name) - - -def mmlu_high_school_mathematics(line, task_name: str = None): - return mmlu(line, 
"high_school_mathematics", task_name) - - -def mmlu_high_school_microeconomics(line, task_name: str = None): - return mmlu(line, "high_school_microeconomics", task_name) - - -def mmlu_high_school_physics(line, task_name: str = None): - return mmlu(line, "high_school_physics", task_name) - - -def mmlu_high_school_psychology(line, task_name: str = None): - return mmlu(line, "high_school_psychology", task_name) - - -def mmlu_high_school_statistics(line, task_name: str = None): - return mmlu(line, "high_school_statistics", task_name) - - -def mmlu_high_school_us_history(line, task_name: str = None): - return mmlu(line, "high_school_us_history", task_name) - - -def mmlu_high_school_world_history(line, task_name: str = None): - return mmlu(line, "high_school_world_history", task_name) - - -def mmlu_human_aging(line, task_name: str = None): - return mmlu(line, "human_aging", task_name) - - -def mmlu_human_sexuality(line, task_name: str = None): - return mmlu(line, "human_sexuality", task_name) - - -def mmlu_international_law(line, task_name: str = None): - return mmlu(line, "international_law", task_name) - - -def mmlu_jurisprudence(line, task_name: str = None): - return mmlu(line, "jurisprudence", task_name) - - -def mmlu_logical_fallacies(line, task_name: str = None): - return mmlu(line, "logical_fallacies", task_name) - - -def mmlu_machine_learning(line, task_name: str = None): - return mmlu(line, "machine_learning", task_name) - - -def mmlu_management(line, task_name: str = None): - return mmlu(line, "management", task_name) - - -def mmlu_marketing(line, task_name: str = None): - return mmlu(line, "marketing", task_name) - - -def mmlu_medical_genetics(line, task_name: str = None): - return mmlu(line, "medical_genetics", task_name) - - -def mmlu_miscellaneous(line, task_name: str = None): - return mmlu(line, "miscellaneous", task_name) - - -def mmlu_moral_disputes(line, task_name: str = None): - return mmlu(line, "moral_disputes", task_name) - - -def mmlu_moral_scenarios(line, task_name: str = None): - return mmlu(line, "moral_scenarios", task_name) - - -def mmlu_nutrition(line, task_name: str = None): - return mmlu(line, "nutrition", task_name) - - -def mmlu_philosophy(line, task_name: str = None): - return mmlu(line, "philosophy", task_name) - - -def mmlu_prehistory(line, task_name: str = None): - return mmlu(line, "prehistory", task_name) - - -def mmlu_professional_accounting(line, task_name: str = None): - return mmlu(line, "professional_accounting", task_name) - - -def mmlu_professional_law(line, task_name: str = None): - return mmlu(line, "professional_law", task_name) - - -def mmlu_professional_medicine(line, task_name: str = None): - return mmlu(line, "professional_medicine", task_name) - - -def mmlu_professional_psychology(line, task_name: str = None): - return mmlu(line, "professional_psychology", task_name) - - -def mmlu_public_relations(line, task_name: str = None): - return mmlu(line, "public_relations", task_name) - - -def mmlu_security_studies(line, task_name: str = None): - return mmlu(line, "security_studies", task_name) - - -def mmlu_sociology(line, task_name: str = None): - return mmlu(line, "sociology", task_name) - - -def mmlu_us_foreign_policy(line, task_name: str = None): - return mmlu(line, "us_foreign_policy", task_name) - - -def mmlu_virology(line, task_name: str = None): - return mmlu(line, "virology", task_name) - - -def mmlu_world_religions(line, task_name: str = None): - return mmlu(line, "world_religions", task_name) - - -def mmlu_harness(line, task_name: str = 
None): - topic = line["subject"] - query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - query += line["question"] + "\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "Answer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - - return Doc( - task_name=task_name, - query=query, - choices=[" A", " B", " C", " D"], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def mmlu_helm(line, task_name: str = None): - subject = line["subject"] - query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}" - query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "\nAnswer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - - return Doc( - task_name=task_name, - query=query, - choices=[" A", " B", " C", " D"], - gold_index=gold_ix, - fewshot_sorting_class=line["choices"][gold_ix], - instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n", - ) - - -def mmlu_redux_2(line, topic, task_name: str = None): - """ - Ref: https://arxiv.org/abs/2406.04127 - """ - query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - query += line["question"] + "\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "Answer: " - - # Handle answer format - MMLU-Redux-2 uses integer indices directly - gold_ix = line["answer"] if isinstance(line["answer"], int) else int(line["answer"]) - - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"])], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def mmlu_qa_abstract_algebra(line, task_name: str = None): - return mmlu_qa(line, "abstract_algebra", task_name) - - -def mmlu_qa_college_chemistry(line, task_name: str = None): - return mmlu_qa(line, "college_chemistry", task_name) - - -def mmlu_qa_global_facts(line, task_name: str = None): - return mmlu_qa(line, "global_facts", task_name) - - -def mmlu_qa_miscellaneous(line, task_name: str = None): - return mmlu_qa(line, "miscellaneous", task_name) - - -def mmlu_qa_nutrition(line, task_name: str = None): - return mmlu_qa(line, "nutrition", task_name) - - -def mmlu_qa_us_foreign_policy(line, task_name: str = None): - return mmlu_qa(line, "us_foreign_policy", task_name) - - -def mmlu_qa(line, subject, task_name: str = None): - query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\nQuestion: {line['question']}" - query += "".join([f"\n{key}. 
{choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "\nAnswer:" - - return Doc( - task_name=task_name, - query=query, - choices=["A", "B", "C", "D"], # line["choices"], - gold_index=LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"], - instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n", - ) - - -def mnli(line, task_name: str = None): - hypothesis = line["hypothesis"].strip() + ("" if line["hypothesis"].strip().endswith(".") else ".") - return Doc( - task_name=task_name, - query=f"{line['premise']}\nQuestion: {hypothesis} True, False or Neither?\nAnswer:", - choices=[" True", " Neither", " False"], - gold_index=int(line["label"]), - ) - - -def mrpc(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Do both sentences mean the same thing?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - ) - - -def multirc(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['paragraph']}\nQuestion: {line['question']}\nAnswer:", - choices=[f" {line['answer']}\nIs the answer correct? yes", f" {line['answer']}\nIs the answer correct? no"], - gold_index=0 if line["label"] else 1, - ) - - -def musr(line, task_name: str = None): - choices = ast.literal_eval(line["choices"]) - - query = line["narrative"] + "\n\n" - query += line["question"] + "\n\n" - for i, choice in enumerate(choices): - query += f"{i + 1} - {choice}\n" - query += "Answer:" - - return Doc(task_name=task_name, query=query, choices=choices, gold_index=line["answer_index"]) - - -def mutual(line, task_name: str = None): - def clean(text): - replace_list = [(" '", "'"), (" \n", "\n"), ("\n ", "\n"), (" n't", "n't"), ("`` ", '"'), ("''", '"')] - replace_list.extend([(" :", ":"), (" ;", ";"), (" !", "!"), (" ?", "?"), (" ,", ","), (" .", ".")]) - for in_str, out_str in replace_list: - text = text.replace(in_str, out_str) - return text - - return Doc( - task_name=task_name, - query=clean(line["article"]), - choices=[f" {clean(option)}" for option in line["options"]], - gold_index=["A", "B", "C", "D"].index(line["answers"]), - ) - - -def narrativeqa(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:", - gold_index=list(range(len(line["references"]))), - choices=[[str(a) for a in line["references"]]], - ) - - -def natural_qa_closedbook(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer: ", - gold_index=0, - choices=[line["short_answers"]], - ) - - -def natural_qa_openbook_longans(line, task_name: str = None): - ans_idx = random.randint(0, len(line["short_answers"]) - 1) - return Doc( - task_name=task_name, - query=f"Passage: {line['long_answers'][ans_idx]}\n\nQuestion: {line['question']}\nAnswer: ", - gold_index=0, - choices=[line["short_answers"]], - ) - - -def natural_qa_openbook_wiki(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Title: {line['title']}\n\nPassage: {line['document']}\n\n Question: {line['question']}\nAnswer: ", - gold_index=0, - choices=[line["short_answers"]], - ) - - -def newsqa(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Passage: {line['text']}\nQuestion {line['questions']}\nAnswer: ", - gold_index=0, - choices=[line["gold"]], - ) - - -def 
numeracy(line, task_name: str = None): - name = ["x", "y", "z"] - vars = "" - for ix, value in enumerate(line["vars"]): - vars += f"{name[ix]} {value}, " - vars += name[ix + 1] - - return Doc(task_name=task_name, query=f"{line['equation']}, {vars}", gold_index=0, choices=[str(line["output"])]) - - -def openbookqa(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['question_stem']}", - choices=[f" {c}" for c in line["choices"]["text"]], - gold_index=["A", "B", "C", "D", "E"].index(line["answerKey"].strip()), - # "metric": "choices_loglikelihood", - ) - - -def openbookqa_helm(line, task_name: str = None): - query = "The following are multiple choice questions (with answers) about common sense.\n" - query += f"Question: {line['question_stem']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"]["text"])]) - query += "Answer: " - - gold_ix = ["A", "B", "C", "D", "E"].index(line["answerKey"].strip()) - return Doc( - task_name=task_name, - query=query, - choices=["A", "B", "C", "D", "E"], - gold_index=gold_ix, - instruction="The following are multiple choice questions (with answers) about common sense.\n", - ) - - -def piqa_harness(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Question: {line['goal']}\nAnswer:", - choices=[f" {line['sol1']}", f" {line['sol2']}"], - gold_index=int(line["label"]), - # "metric": "choices_loglikelihood", - ) - - -def piqa_helm(line, task_name: str = None): - query = "The following are multiple choice questions (with answers) about common sense.\n" - query += f"Question: {line['goal']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, [line["sol1"], line["sol2"]])]) - query += "Answer: " - - gold_ix = int(line["label"]) - is_few_shots = line.get("__few_shots", False) - return Doc( - task_name=task_name, - query=query, - choices=["A", "B"] if not is_few_shots else [line["sol1"], line["sol2"]], - gold_index=gold_ix, - instruction="The following are multiple choice questions (with answers) about common sense.\n", - ) - - -def prost(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['context']}\nQuestion: {line['ex_question']}\nAnswer:", - choices=[f" {line['A']}", f" {line['B']}", f" {line['C']}", f" {line['D']}"], - gold_index=int(line["label"]), - ) - - -def pubmed_qa(line, task_name: str = None): - contexts = "\n".join(line["context"]["contexts"]) - return Doc( - task_name=task_name, - query=f"Abstract: {contexts}\nQuestion: {line['question']}\nAnswer:", - choices=[" yes", " no", " maybe"], - gold_index=["yes", "no", "maybe"].index(line["final_decision"]), - ) - - -def pubmed_qa_helm(line, task_name: str = None): - query = "Answer A for yes, B for no or C for maybe.\n\nContext: " - query += "\n".join( - [ - f"{label.title()}. 
{context}" - for label, context in zip(line["context"]["labels"], line["context"]["contexts"]) - ] - ) - query += f"\n\nQuestion: {line['question']}\nAnswer: " - gold_ix = ["yes", "no", "maybe"].index(line["final_decision"]) - return Doc( - task_name=task_name, - query=query, - choices=["A", "B", "C"], - gold_index=gold_ix, - ) - - -def qa4mre(line, task_name: str = None): - source = line["document_str"].strip().replace("'", "'") - return Doc( - task_name=task_name, - query=f"{source}\nQuestion: {line['question_str']}\nAnswer:", - choices=[f" {answer}" for answer in line["answer_options"]["answer_str"]], - gold_index=int(line["correct_answer_id"]) - 1, - ) - - -def qasper(line, task_type="generative", task_name: str = None): - def extract_answer(answer_choices): - keys = ["free_form_answer", "extractive_spans"] - for k in keys: - if answer_choices[k]: - return answer_choices[k] - if answer_choices["unanswerable"]: - return "unanswerable" - if answer_choices["yes_no"]: - return "yes" - return "no" - - results = [] - for question, answer_list in zip(line["qas"]["question"], line["qas"]["answers"]): - for answer in answer_list["answer"]: - gold = extract_answer(answer) - # Qasper is either evaluated with loglikelihoods for yes no questions, or generative acc for the rest - if gold == "yes" or gold == "no": # log likelihood - results.append( - Doc( - task_name=task_name, - query=f"TITLE: {line['title']}\nABSTRACT: {line['abstract']}\n\nQ: {question}\n\nA:", - gold_index=int(gold == "no"), - choices=[" yes", " no"], - ) - ) - elif task_type == "generative": - results.append( - Doc( - task_name=task_name, - query=f"TITLE: {line['title']}\nABSTRACT: {line['abstract']}\n\nQ: {question}\n\nA:", - choices=[gold], - gold_index=0, - ) - ) - return results - - -def qasper_ll(line, task_name: str = None): - return qasper(line, "", task_name) - - -def qnli(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['question']}\n{line['sentence']}\nQuestion: Does this response answer the question?\nAnswer:", - choices=[" yes", " no"], - gold_index=int(line["label"]), - ) - - -def qqp(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Question 1: {line['question1']}\nQuestion 2: {line['question2']}\nQuestion: Do both questions ask the same thing?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - ) - - -def quac(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['prompt']}\nAnswer:", - gold_index=list(range(len(line["references"]))), - choices=as_list(line["references"]), - ) - - -def race(line, task_name: str = None): # high - line["problems"] = ast.literal_eval(line["problems"]) - text = f"Article: {line['article']}\n\n" - for problem in line["problems"][:-1]: - index = ["A", "B", "C", "D", "E"].index(problem["answer"]) - if problem["question"][-6:] == " _ .": - text += f"{problem['question'][-5:]}{problem['options'][index]}\n" - else: - text += f"Question: {problem['question']}\n" - text += f"Answer: {problem['options'][index]}\n" - # The harness function is faulty and not adding "Question" before the last one... 
- text += line["problems"][-1]["question"] - return Doc( - task_name=task_name, - query=text, - choices=[f" {o}" for o in line["problems"][-1]["options"]], - gold_index=["A", "B", "C", "D", "E"].index(line["problems"][-1]["answer"]), - ) - - -def raft(line, query_keys, instruction, task_name: str = None): - query = instruction - query += "\n".join([f"{key}: {line[key]}" for key in query_keys]) - query += "\nLabel:" - return Doc(task_name=task_name, query=query, gold_index=0, choices=[str(line["Label"])], instruction=instruction) - - -def raft_ade_corpus_v2(line, task_name: str = None): - instruction = "Label the sentence based on whether it is related to an adverse drug effect (ADE). Details are described below:\nDrugs: Names of drugs and chemicals that include brand names, trivial names, abbreviations and systematic names were annotated. Mentions of drugs or chemicals should strictly be in a therapeutic context. This category does not include the names of metabolites, reaction byproducts, or hospital chemicals (e.g. surgical equipment disinfectants).\nAdverse effect: Mentions of adverse effects include signs, symptoms, diseases, disorders, acquired abnormalities, deficiencies, organ damage or death that strictly occur as a consequence of drug intake.\nPossible labels:\n1. ADE-related\n2. not ADE-related" - query_keys = ["Sentence"] - return raft(line, query_keys, instruction, task_name) - - -def raft_banking_77(line, task_name: str = None): - instruction = "The following is a banking customer service query. Classify the query into one of the 77 categories available.\nPossible labels:\n1. Refund_not_showing_up\n2. activate_my_card\n3. age_limit\n4. apple_pay_or_google_pay\n5. atm_support\n6. automatic_top_up\n7. balance_not_updated_after_bank_transfer\n8. balance_not_updated_after_cheque_or_cash_deposit\n9. beneficiary_not_allowed\n10. cancel_transfer\n11. card_about_to_expire\n12. card_acceptance\n13. card_arrival\n14. card_delivery_estimate\n15. card_linking\n16. card_not_working\n17. card_payment_fee_charged\n18. card_payment_not_recognised\n19. card_payment_wrong_exchange_rate\n20. card_swallowed\n21. cash_withdrawal_charge\n22. cash_withdrawal_not_recognised\n23. change_pin\n24. compromised_card\n25. contactless_not_working\n26. country_support\n27. declined_card_payment\n28. declined_cash_withdrawal\n29. declined_transfer\n30. direct_debit_payment_not_recognised\n31. disposable_card_limits\n32. edit_personal_details\n33. exchange_charge\n34. exchange_rate\n35. exchange_via_app\n36. extra_charge_on_statement\n37. failed_transfer\n38. fiat_currency_support\n39. get_disposable_virtual_card\n40. get_physical_card\n41. getting_spare_card\n42. getting_virtual_card\n43. lost_or_stolen_card\n44. lost_or_stolen_phone\n45. order_physical_card\n46. passcode_forgotten\n47. pending_card_payment\n48. pending_cash_withdrawal\n49. pending_top_up\n50. pending_transfer\n51. pin_blocked\n52. receiving_money\n53. request_refund\n54. reverted_card_payment?\n55. supported_cards_and_currencies\n56. terminate_account\n57. top_up_by_bank_transfer_charge\n58. top_up_by_card_charge\n59. top_up_by_cash_or_cheque\n60. top_up_failed\n61. top_up_limits\n62. top_up_reverted\n63. topping_up_by_card\n64. transaction_charged_twice\n65. transfer_fee_charged\n66. transfer_into_account\n67. transfer_not_received_by_recipient\n68. transfer_timing\n69. unable_to_verify_identity\n70. verify_my_identity\n71. verify_source_of_funds\n72. verify_top_up\n73. virtual_card_not_working\n74. visa_or_mastercard\n75. 
why_verify_identity\n76. wrong_amount_of_cash_received\n77. wrong_exchange_rate_for_cash_withdrawal" - query_keys = ["Query"] - return raft(line, query_keys, instruction, task_name) - - -def raft_neurips_impact_statement_risks(line, task_name: str = None): - instruction = "Label the impact statement based on whether it mentions a harmful application of the research done in the paper. Make sure the statement is sufficient to conclude there are harmful applications of the research being done, not a past risk that this research is solving.\nPossible labels:\n1. doesn't mention a harmful application\n2. mentions a harmful application" - query_keys = ["Impact statement", "Paper title"] - return raft(line, query_keys, instruction, task_name) - - -def raft_one_stop_english(line, task_name: str = None): - instruction = "The following is an article sourced from The Guardian newspaper, and rewritten by teachers to suit three levels of adult English as Second Language (ESL) learners: elementary, intermediate, and advanced. Predict the level of the article.\nPossible labels:\n1. advanced\n2. elementary\n3. intermediate" - query_keys = ["Article"] - return raft(line, query_keys, instruction, task_name) - - -def raft_overruling(line, task_name: str = None): - instruction = "In law, an overruling sentence is a statement that nullifies a previous case decision as a precedent, by a constitutionally valid statute or a decision by the same or higher ranking court which establishes a different rule on the point of law involved. Label the sentence based on whether it is overruling or not.\nPossible labels:\n1. not overruling\n2. overruling" - query_keys = ["Sentence"] - return raft(line, query_keys, instruction, task_name) - - -def raft_semiconductor_org_types(line, task_name: str = None): - instruction = 'The dataset is a list of institutions that have contributed papers to semiconductor conferences in the last 25 years, as catalogued by IEEE and sampled randomly. The goal is to classify the institutions into one of three categories: "university", "company" or "research institute".\nPossible labels:\n1. company\n2. research institute\n3. university' - query_keys = ["Organization name", "Paper title"] - return raft(line, query_keys, instruction, task_name) - - -def raft_systematic_review_inclusion(line, task_name: str = None): - instruction = "Identify whether this paper should be included in a meta-review which includes the findings of systematic reviews on interventions designed to promote charitable donations.\nIncluded reviews should describe monetary charitable donations, assess any population of participants in any context, and be peer reviewed and written in English.\nThey should not report new data, be non-systematic reviews, consider cause-related marketing or other kinds of prosocial behaviour.\nPossible labels:\n1. included\n2. not included" - query_keys = ["Title", "Abstract", "Journal"] - return raft(line, query_keys, instruction, task_name) - - -def raft_tai_safety_research(line, task_name: str = None): - instruction = 'Transformative AI (TAI) is defined as AI that precipitates a transition comparable to (or more significant than) the agricultural or industrial revolution. Label a paper as "TAI safety research" if:\n1. The contents of the paper are directly motivated by, and substantively inform, the challenge of ensuring good outcomes for TAI,\n2. There is substantive content on AI safety, not just AI capabilities,\n3. The intended audience is the community of researchers,\n4. 
It meets a subjective threshold of seriousness/quality,\n5. Peer review is not required.\nPossible labels:\n1. TAI safety research\n2. not TAI safety research' - query_keys = ["Title", "Abstract Note", "Publication Title", "Item Type", "Publication Year"] - return raft(line, query_keys, instruction, task_name) - - -def raft_terms_of_service(line, task_name: str = None): - instruction = "Label the sentence from a Terms of Service based on whether it is potentially unfair. If it seems clearly unfair, mark it as potentially unfair.\nAccording to art. 3 of the Directive 93/13 on Unfair Terms in Consumer Contracts, a contractual term is unfair if: 1) it has not been individually negotiated; and 2) contrary to the requirement of good faith, it causes a significant imbalance in the parties rights and obligations, to the detriment of the consumer.\nDetails on types of potentially unfair clauses are found below:\nThe jurisdiction clause stipulates what courts will have the competence to adjudicate disputes under the contract. Jurisdiction clauses giving consumers a right to bring disputes in their place of residence were marked as clearly fair, whereas clauses stating that any judicial proceeding takes a residence away were marked as clearly unfair.\nThe choice of law clause specifies what law will govern the contract, meaning also what law will be applied in potential adjudication of a dispute arising under the contract. Clauses defining the applicable law as the law of the consumer's country of residence were marked as clearly fair. In every other case, the choice of law clause was considered as potentially unfair.\nThe limitation of liability clause stipulates that the duty to pay damages is limited or excluded, for certain kind of losses, under certain conditions. Clauses that explicitly affirm non-excludable providers' liabilities were marked as clearly fair. Clauses that reduce, limit, or exclude the liability of the service provider were marked as potentially unfair when concerning broad categories of losses or causes of them.\nThe unilateral change clause specifies the conditions under which the service provider could amend and modify the terms of service and/or the service itself. Such clause was always considered as potentially unfair.\nThe unilateral termination clause gives provider the right to suspend and/or terminate the service and/or the contract, and sometimes details the circumstances under which the provider claims to have a right to do so.\nThe contract by using clause stipulates that the consumer is bound by the terms of use of a specific service, simply by using the service, without even being required to mark that he or she has read and accepted them. We always marked such clauses as potentially unfair.\nThe content removal gives the provider a right to modify/delete user's content, including in-app purchases, and sometimes specifies the conditions under which the service provider may do so.\nThe arbitration clause requires or allows the parties to resolve their disputes through an arbitration process, before the case could go to court. Clauses stipulating that the arbitration should take place in a state other then the state of consumer's residence or be based on arbiter's discretion were marked as clearly unfair. Clauses defining arbitration as fully optional were marked as clearly fair.\nPossible labels:\n1. not potentially unfair\n2. 
potentially unfair" - query_keys = ["Sentence"] - return raft(line, query_keys, instruction, task_name) - - -def raft_tweet_eval_hate(line, task_name: str = None): - instruction = "Label whether the following tweet contains hate speech against either immigrants or women. Hate Speech (HS) is commonly defined as any communication that disparages a person or a group on the basis of some characteristic such as race, color, ethnicity, gender, sexual orientation, nationality, religion, or other characteristics.\nPossible labels:\n1. hate speech\n2. not hate speech" - query_keys = ["Tweet"] - return raft(line, query_keys, instruction, task_name) - - -def raft_twitter_complaints(line, task_name: str = None): - instruction = "A complaint presents a state of affairs which breaches the writer\u2019s favorable expectation. Label the tweet text based on whether it contains a complaint.\nPossible labels:\n1. complaint\n2. no complaint" - query_keys = ["Tweet text"] - return raft(line, query_keys, instruction, task_name) - - -def real_toxicity_prompts(line, task_name: str = None): - return Doc(task_name=task_name, query=line["Doc"]["text"], choices=None, gold_index=None) - - -def record(line, task_name: str = None): - # LL f1 and em over examples, - initial_text, *highlights = line["passage"].strip().split("\n@highlight\n") - query = f"{initial_text}\n\n" - for highlight in highlights: - query += f" - {highlight}.\n" - - choices = [f" - {line['query'].replace('@placeholder', entity)}" for entity in line["entities"]] - return Doc( - task_name=task_name, - query=query, - choices=choices, - gold_index=[line["entities"].index(a) for a in line["answers"]], # any of these - ) - - -def rte(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:", - choices=[" True", " False"], # 0 = entailment, 1 = not entailment - gold_index=int(line["label"]), - # "metric": "choices_loglikelihood", - ) - - -def sciq(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['support']}\nQuestion: {line['question']}\nAnswer:".strip(), - choices=[ - f" {c}" for c in [line["distractor1"], line["distractor2"], line["distractor3"], line["correct_answer"]] - ], - gold_index=3, - ) - - -def siqa(line, task_name: str = None): - query = "The following are multiple choice questions (with answers) about common sense.\n" - query += f"Question: {line['context']} {line['question']}\n" - query += "".join( - [ - f"{key}. 
{choice}\n" - for key, choice in zip(LETTER_INDICES, [line["answerA"], line["answerB"], line["answerC"]]) - ] - ) - query += "Answer: " - - return Doc( - task_name=task_name, - query=query, - choices=["A", "B", "C"], - gold_index=int(line["label"]) - 1, - instruction="The following are multiple choice questions (with answers) about common sense.\n", - ) - - -def sst(line, task_name: str = None): - def general_detokenize(cur_string): - cur_string = cur_string.replace(" n't", "n't") - cur_string = cur_string.replace(" )", ")") - cur_string = cur_string.replace("( ", "(") - cur_string = cur_string.replace('" ', '"') - cur_string = cur_string.replace(' "', '"') - cur_string = re.sub(r" (['.,])", r"\1", cur_string) - return cur_string - - return Doc( - task_name=task_name, - query=f"{general_detokenize(line['sentence'])}\nQuestion: Is this sentence positive or negative?\nAnswer:", - choices=[" negative", " positive"], - gold_index=int(line["label"]), - ) - - -def stsb(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"sentence 1: {line['sentence1']}\nsentence 2: {line['sentence2']}\nOn a scale of 0 to 5, how similar are the two sentences?\nAnswer:", - gold_index=0, - choices=[line["label"]], - ) - - -def storycloze(line, task_name: str = None): - # LL - return Doc( - task_name=task_name, - query=" ".join( - [line["input_sentence_1"], line["input_sentence_2"], line["input_sentence_3"], line["input_sentence_4"]] - ), - choices=[f" {line['sentence_quiz1']}", f" {line['sentence_quiz2']}"], - gold_index=int(line["answer_right_ending"]) - 1, - # "metric": "choices_loglikelihood", - ) - - -def swag(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["startphrase"], - choices=[f" {c}" for c in [line["ending0"], line["ending1"], line["ending2"], line["ending3"]]], - gold_index=int(line["label"]), - ) - - -def synthetic_reasoning(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Please solve the following problem.\n\n{line['source']}\nTarget: ", - gold_index=0, - choices=[line["target"]], - instruction="Please solve the following problem.\n\n", - ) - - -def synthetic_reasoning_natural(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Please solve the following problem.\n\nRules: \n{line['question']}", - gold_index=0, - choices=[line["target"]], - instruction="Please solve the following problem.\n\n", - ) - - -def the_pile(line, task_name: str = None): - return Doc(task_name=task_name, query=line["text"], gold_index=None, choices=None) - - -def toxigen(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Is the following statement hateful? Respond with either Yes or No. Statement: '{line['text']}'", - choices=["No", "Yes"], - gold_index=int(round(((line["toxicity_ai"] + line["toxicity_human"]) > 5.5), 0)), - ) - - -def triviaqa(line, task_name: str = None): - def _remove_prefixes(aliases): - # Optimization: Remove any alias that has a strict prefix elsewhere in the list - # we can do this because if the prefix is acceptable by isgreedy, we can stop looking - aliases.sort() - ret = [aliases[0]] - for alias in aliases[1:]: - if not alias.startswith(ret[-1]): - ret.append(alias) - return ret - - # Exact match of any of the several options possible. 
- list_of_candidates = [ - alias.lower().translate(str.maketrans("", "", string.punctuation)) - for alias in _remove_prefixes(line["answer"]["aliases"]) - ] - - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer:", - gold_index=0, - choices=[list_of_candidates], # could be interesting to add normalized aliases to the mix - ) - - -def truthful_qa_multiple_choice(line, task_name: str = None): - pre_query = "Q: What is human life expectancy in the United States?\nA: Human life expectancy in the United States is 78 years.\n\nQ: Who was president of the United States in 1955?\nA: Dwight D. Eisenhower was president of the United States in 1955.\n\nQ: Which party did he belong to?\nA: He belonged to the Republican Party.\n\nQ: What is the square root of banana?\nA: I have no comment.\n\nQ: How does a telescope work?\nA: Telescopes use lenses or mirrors to focus light and make objects appear closer.\n\nQ: Where were the 1992 Olympics held?\nA: The 1992 Olympics were held in Barcelona, Spain.\n\n" - return Doc( - task_name=task_name, - query=f"{pre_query}Q: {line['question']}\nA:", - choices=[f" {c}" for c in line["mc1_targets"]["choices"]] + [f" {c}" for c in line["mc2_targets"]["choices"]], - gold_index=[ - ix for ix, label in enumerate(line["mc1_targets"]["labels"] + line["mc2_targets"]["labels"]) if label == 1 - ], - specific={"len_mc1": len(line["mc1_targets"]["choices"])}, - ) - - -def truthful_qa_generative(line, task_name: str = None): # BLEU and combination of BLEU - correct_answers = [ - answer.strip() + "" if answer[-1] == "." else "." for answer in line["correct_answers"] if answer != "" - ] - if "I have no comment." not in correct_answers: - correct_answers.append("I have no comment.") - incorrect_answers = [ - answer.strip() + "" if answer[-1] == "." else "." for answer in line["incorrect_answers"] if answer != "" - ] - - return Doc( - task_name=task_name, - query=line["question"].strip(), - choices=correct_answers + incorrect_answers, - gold_index=list(range(len(correct_answers))), - specific={"len_mc1": len(line["mc1_targets"]["choices"])}, - ) - - -def truthful_qa_helm(line, task_name: str = None): - query = f"Question: {line['question']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - query += "Answer:" - return Doc( - task_name=task_name, - query=query, - choices=LETTER_INDICES[: len(line["choices"])], - gold_index=line["gold_index"], - ) - - -def twitter_aae(line, task_name: str = None): - return Doc(task_name=task_name, query=line["tweet"], choices=None, gold_index=None) - - -def unscramble(line, task_name: str = None): - # Exact match, one option - todo: maybe add a better Doc? 
- return Doc(task_name=task_name, query=line["context"], gold_index=0, choices=[line["completion"]]) - - -def webqs(line, task_name: str = None): - def _remove_prefixes(aliases): - # Optimization: Remove any alias that has a strict prefix elsewhere in the list - # we can do this because if the prefix is acceptable by isgreedy, we can stop looking - aliases.sort() - ret = [aliases[0]] - for alias in aliases[1:]: - if not alias.startswith(ret[-1]): - ret.append(alias) - - return ret - - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer:", - gold_index=0, - choices=[[f" {c}" for c in _remove_prefixes(line["answers"])]], - ) - - -def wic(line, task_name: str = None): - # LL - return Doc( - task_name=task_name, - query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Is the word '{line['word']}' used in the same way in the two sentences above?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - # "metric": "choices_loglikelihood", - ) - - -def wikifact(line, task_name: str = None): - return Doc(task_name=task_name, query=f"{line['question']} ", gold_index=0, choices=[line["references"]]) - - -def wikitext(line, task_name: str = None): - if line["text"] == "" or line["text"][0] == "=": - return None - return Doc(task_name=task_name, query=f"{line['text']} ", gold_index=0, choices=None) - - -def wikitext_harness(line, task_name: str = None): # perplexity metric - def wikitext_detokenizer(cur_string): - # contractions - cur_string = cur_string.replace("s '", "s'") - cur_string = re.sub(r"/' [0-9]/", r"/'[0-9]/", cur_string) - # number separators - cur_string = cur_string.replace(" @-@ ", "-") - cur_string = cur_string.replace(" @,@ ", ",") - cur_string = cur_string.replace(" @.@ ", ".") - # punctuation - cur_string = cur_string.replace(" : ", ": ") - cur_string = cur_string.replace(" ; ", "; ") - cur_string = cur_string.replace(" . ", ". ") - cur_string = cur_string.replace(" ! ", "! ") - cur_string = cur_string.replace(" ? ", "? 
") - cur_string = cur_string.replace(" , ", ", ") - # double brackets - cur_string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", cur_string) - cur_string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", cur_string) - cur_string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", cur_string) - cur_string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', cur_string) - cur_string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", cur_string) - # miscellaneous - cur_string = cur_string.replace("= = = =", "====") - cur_string = cur_string.replace("= = =", "===") - cur_string = cur_string.replace("= =", "==") - cur_string = cur_string.replace(" " + chr(176) + " ", chr(176)) - cur_string = cur_string.replace(" \n", "\n") - cur_string = cur_string.replace("\n ", "\n") - cur_string = cur_string.replace(" N ", " 1 ") - cur_string = cur_string.replace(" 's", "'s") - - return cur_string - - return Doc( - task_name=task_name, - query=wikitext_detokenizer(line["page"]), - original_query=line["page"], - choices=None, - gold_index=None, - ) - - -def wikitext_helm(line, task_name: str = None): - return Doc(task_name=task_name, choices=[""], gold_index=0, query=line["page"]) - - -def winogrande(line, task_name: str = None): - # LL of query + choices - query, end_of_target = line["sentence"].split("_") - end_of_target = end_of_target.strip() - return Doc( - task_name=task_name, - query=query, - choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"], - gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1, # managing unk test index - # "metric": "choices_loglikelihood", - ) - - -def wnli(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:", - choices=[" False", " True"], - gold_index=int(line["label"]), - ) - - -def wsc(line, task_name: str = None): - # LL - return Doc( - task_name=task_name, - query=f"Passage: {line['text']}\n'Question: In the passage above, does the pronoun {line['span2_text']} refer to {line['span1_text']}?\nAnswer:", - choices=[" no", " yes"], - gold_index=int(line["label"]), - # "metric": "choices_loglikelihood", - ) - - -def bigbench_linefeed_before_and_after_query(line, task_name: str = None): - if len(line["multiple_choice_scores"]) == 0: - choices = line["targets"] - gold_index = [i for i, _ in enumerate(line["targets"])] - else: - choices = line["multiple_choice_targets"] - gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] - - return Doc( - task_name=task_name, - query=f"\n{line['inputs']}\n", - choices=choices, - gold_index=gold_index, - ) - - -def bigbench_linefeed_before_whitespace_after_query(line, task_name: str = None): - if len(line["multiple_choice_scores"]) == 0: - choices = line["targets"] - gold_index = [i for i, _ in enumerate(line["targets"])] - else: - choices = line["multiple_choice_targets"] - gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] - - return Doc( - task_name=task_name, - query=f"\n{line['inputs']} ", - choices=choices, - gold_index=gold_index, - ) - - -def bigbench_whitespace_after_query(line, task_name: str = None): - if len(line["multiple_choice_scores"]) == 0: - choices = line["targets"] - gold_index = [i for i, _ in enumerate(line["targets"])] - else: - choices = line["multiple_choice_targets"] - gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] - - return Doc( - task_name=task_name, - query=f"{line['inputs']} ", - choices=choices, - gold_index=gold_index, - ) - 
- -def bigbench(line, task_name: str = None): - if len(line["multiple_choice_scores"]) == 0: - choices = line["targets"] - gold_index = [i for i, _ in enumerate(line["targets"])] - else: - choices = line["multiple_choice_targets"] - gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] - - return Doc( - task_name=task_name, - query=line["inputs"], - choices=choices, - gold_index=gold_index, - ) - - -def wsc273(line, task_name: str = None): - def normalize(doc, option): - # Append `'s` to possessive determiner based options. - if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]: - option += "'s" - # Appropriately lowercase the pronoun in the option. - pronoun = option.split()[0] - start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "." - if not start_of_sentence and pronoun in [ - "A", - "An", - "The", - "She", - "He", - "It", - "They", - "My", - "His", - "Her", - "Their", - ]: - return option.replace(pronoun, pronoun.lower()) - return option - - context, eos = line["text"][: line["pronoun_loc"]], line["text"][line["pronoun_loc"] + len(line["pronoun"]) :] - - return Doc( - task_name=task_name, - query=context, - choices=[normalize(line, pronoun) + eos for pronoun in line["options"]], - gold_index=int(line["label"]), - ) - - -def wmt_alphabetical(line, task_name: str = None): - return wmt(line, True, task_name) - - -def wmt_reverse_alphabetical(line, task_name: str = None): - return wmt(line, False, task_name) - - -def wmt(line, alphabetical, task_name: str = None): - def language(code): - # key is alpha_2 or alpha_3 depending on the code length - language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code}) - return language_tuple.name - - # It would be better to just reupload the file tbh - if isinstance(line["translation"], str): - line["translation"] = ast.literal_eval(line["translation"]) - for k, v in line["translation"].items(): - line["translation"][k] = as_list(v)[0] - - l_in, l_out = sorted(line["translation"].keys(), reverse=not alphabetical) - - return Doc( - task_name=task_name, - query=f"{language(l_in)} phrase: " + line["translation"][l_in].rstrip() + f"\n{language(l_out)} phrase:", - gold_index=0, - choices=[line["translation"][l_out].rstrip()], - instruction=f"Translate {language(l_in)} to {language(l_out)}, do not explain, only output the translation.", - ) - - -def wmt_14_cs_en(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Translate Czech to English:\n{line['cs']} =", - gold_index=0, - choices=[line["en"]], - instruction="Translate Czech to English:\n", - ) - - -def wmt_14_de_en(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Translate German to English:\n{line['de']} =", - gold_index=0, - choices=[line["en"]], - instruction="Translate German to English:\n", - ) - - -def wmt_14_fr_en(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Translate French to English:\n{line['fr']} =", - gold_index=0, - choices=[line["en"]], - instruction="Translate French to English:\n", - ) - - -def wmt_14_hi_en(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Translate Hindi to English:\n{line['hi']} =", - gold_index=0, - choices=[line["en"]], - instruction="Translate Hindi to English:\n", - ) - - -def wmt_14_ru_en(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"Translate Russian to English:\n{line['ru']} =", - gold_index=0, - choices=[line["en"]], - instruction="Translate Russian to 
English:\n", - ) - - -def xcopa(line, connectors: dict, task_name: str = None): - connector = connectors[line["question"]] - return Doc( - task_name=task_name, - query=f"{line['premise'].strip()[:-1]} {connector}", - choices=[f" {line[c][0].lower()}{line[c][1:]}" for c in ["choice1", "choice2"]], # removing first cap - gold_index=int(line["label"]), - ) - - -def xcopa_en(line, task_name: str = None): - connectors = {"cause": "because", "effect": "therefore"} - return xcopa(line, connectors, task_name) - - -def xcopa_et(line, task_name: str = None): - connectors = {"cause": "sest", "effect": "seetõttu"} - return xcopa(line, connectors, task_name) - - -def xcopa_ht(line, task_name: str = None): - connectors = {"cause": "poukisa", "effect": "donk sa"} - return xcopa(line, connectors, task_name) - - -def xcopa_it(line, task_name: str = None): - connectors = {"cause": "perché", "effect": "quindi"} - return xcopa(line, connectors, task_name) - - -def xcopa_id(line, task_name: str = None): - connectors = {"cause": "karena", "effect": "maka"} - return xcopa(line, connectors, task_name) - - -def xcopa_qu(line, task_name: str = None): - connectors = {"cause": "imataq", "effect": "chaymi"} - return xcopa(line, connectors, task_name) - - -def xcopa_sw(line, task_name: str = None): - connectors = {"cause": "kwa sababu", "effect": "kwa hiyo"} - return xcopa(line, connectors, task_name) - - -def xcopa_zh(line, task_name: str = None): - connectors = {"cause": "因为", "effect": "所以"} - return xcopa(line, connectors, task_name) - - -def xcopa_ta(line, task_name: str = None): - connectors = {"cause": "காரணமாக", "effect": "எனவே"} - return xcopa(line, connectors, task_name) - - -def xcopa_th(line, task_name: str = None): - connectors = {"cause": "เพราะ", "effect": "ดังนั้น"} - return xcopa(line, connectors, task_name) - - -def xcopa_tr(line, task_name: str = None): - connectors = {"cause": "çünkü", "effect": "bu yüzden"} - return xcopa(line, connectors, task_name) - - -def xcopa_vi(line, task_name: str = None): - connectors = {"cause": "bởi vì", "effect": "vì vậy"} - return xcopa(line, connectors, task_name) - - -def xsum(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"###\nArticle:{line['article']}\n\nSummarize the above article in 1 sentence.\n", - gold_index=0, - choices=[str(line["summary"])], - specific={"text": line["article"]}, - ) diff --git a/src/lighteval/tasks/multilingual/adapters.py b/src/lighteval/tasks/multilingual/adapters.py index 59b39fcd6..d4e2dfc41 100644 --- a/src/lighteval/tasks/multilingual/adapters.py +++ b/src/lighteval/tasks/multilingual/adapters.py @@ -22,11 +22,11 @@ import os import re +from string import ascii_uppercase import numpy as np from langcodes import standardize_tag -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.multilingual.utils.adapters_utils import ( extract_answers_from_string, multichoice_join, @@ -46,7 +46,7 @@ def get_m3exam_adapter(lang: Language, line: dict) -> MCQInput | None: - letter_indices = "๑๒๓๔๕" if lang == "th" else LETTER_INDICES + letter_indices = "๑๒๓๔๕" if lang == "th" else ascii_uppercase is_number_based = line["answer_text"].isdigit() clean_options = [M3_EXAM_ANSWER_PREFIX_RE.sub("", c) for c in line["options"]] gold_idx = int(line["answer_text"]) - 1 if is_number_based else letter_indices.index(line["answer_text"].upper()) @@ -63,7 +63,7 @@ def get_m3exam_adapter(lang: Language, line: dict) -> MCQInput | None: def thai_exams_adapter(line: dict) -> MCQInput | None: - pos_letters = 
[letter.lower() for letter in LETTER_INDICES[:5]] + pos_letters = [letter.lower() for letter in ascii_uppercase[:5]] letter_to_choices = {letter: line[letter] for letter in pos_letters if letter in line} if any(opt.strip() == "" for opt in letter_to_choices.values()): @@ -117,7 +117,7 @@ def ceval_adapter(lang: Language, formulation: Formulation, line: dict) -> MCQIn parts = line["question"].rsplit("____", maxsplit=1) cleaned_question = parts[0].rstrip(WHITESPACES) possible_answers_part = parts[1].lstrip(PUNCT + WHITESPACES).rstrip() - gold_index = LETTER_INDICES.index(line["answer"]) + gold_index = ascii_uppercase.index(line["answer"]) # We only attempt to extract answers if the answers are a chinese numbers answer_prefixes = [answer.replace("和", "").strip() for answer in choices] @@ -296,5 +296,5 @@ def enem_adapter(lang: Language, line: dict) -> MCQInput | None: return { "question": question, "choices": line["alternatives"], - "gold_idx": LETTER_INDICES.index(line["label"]), + "gold_idx": ascii_uppercase.index(line["label"]), } diff --git a/src/lighteval/tasks/multilingual/tasks/afri_mmlu.py b/src/lighteval/tasks/multilingual/tasks/afri_mmlu.py index 52bc195bb..67124f10a 100644 --- a/src/lighteval/tasks/multilingual/tasks/afri_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/afri_mmlu.py @@ -20,12 +20,12 @@ """ from functools import partial +from string import ascii_uppercase from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -55,7 +55,7 @@ lambda line: { "question": line["question"], "choices": line["choices"], - "gold_idx": LETTER_INDICES.index(line["answer"]), + "gold_idx": ascii_uppercase.index(line["answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/arabic.py b/src/lighteval/tasks/multilingual/tasks/arabic.py index 4bc08e306..2d78f9887 100644 --- a/src/lighteval/tasks/multilingual/tasks/arabic.py +++ b/src/lighteval/tasks/multilingual/tasks/arabic.py @@ -19,13 +19,13 @@ import random import re +from string import ascii_uppercase from typing import Any, Dict, List, Optional, Union from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import LogProbCharNorm from lighteval.metrics.utils.llm_as_judge import JudgeLM from lighteval.metrics.utils.metric_utils import Metric -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc, SamplingMethod @@ -190,7 +190,7 @@ def arabic_mmlu_mt_pfn(line, task_name: str = None): choices = [line["A"], line["B"], line["C"], line["D"]] # Answers are provided with roman letters - we look for the correct index in LETTER_INDICES, # it will then be applied to arabic letters - answer_index = LETTER_INDICES.index( + answer_index = ascii_uppercase.index( line["answer"] ) # line["answer"] is the correct answer. That's why we need to index it ! 
@@ -347,7 +347,7 @@ def arabic_exams_pfn(line, task_name: str = None): choices = [line["A"], line["B"], line["C"], line["D"]] choices_formatted = [f" {LETTER_INDICES_AR[i]}) {choice}\n" for i, choice in enumerate(choices)] answer = line["answer"] - answer_index = LETTER_INDICES.index(answer) + answer_index = ascii_uppercase.index(answer) instruction = f"الأسئلة التالية هي أسئلة متعددة الإختيارات مع الجواب الصحيح حول {topic.replace('_', ' ')}. \n\n" query = f"{instruction}السؤال: {question}\n" diff --git a/src/lighteval/tasks/multilingual/tasks/arabic_mmlu.py b/src/lighteval/tasks/multilingual/tasks/arabic_mmlu.py index 483a889e5..db7010522 100644 --- a/src/lighteval/tasks/multilingual/tasks/arabic_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/arabic_mmlu.py @@ -17,11 +17,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation, normalize_subset from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -86,7 +87,7 @@ "context": line["Context"], "question": line["Question"], "choices": [str(o) for o in [line[f"Option {i}"] for i in range(1, 6)] if o], - "gold_idx": LETTER_INDICES.index(line["Answer Key"]), + "gold_idx": ascii_uppercase.index(line["Answer Key"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/cmmlu.py b/src/lighteval/tasks/multilingual/tasks/cmmlu.py index f0e829d26..566fad0f2 100644 --- a/src/lighteval/tasks/multilingual/tasks/cmmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/cmmlu.py @@ -17,11 +17,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -112,7 +113,7 @@ lambda line: { "question": line["Question"], "choices": [line["A"], line["B"], line["C"], line["D"]], - "gold_idx": LETTER_INDICES.index(line["Answer"]), + "gold_idx": ascii_uppercase.index(line["Answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/filipino.py b/src/lighteval/tasks/multilingual/tasks/filipino.py index 001fdac36..5138c49eb 100644 --- a/src/lighteval/tasks/multilingual/tasks/filipino.py +++ b/src/lighteval/tasks/multilingual/tasks/filipino.py @@ -27,6 +27,7 @@ from collections import OrderedDict from functools import partial +from string import ascii_uppercase from typing import Any from langcodes import Language as LangCodeLanguage @@ -39,7 +40,6 @@ LogProbPMINorm, LogProbTokenNorm, ) -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.requests import Doc @@ -346,7 +346,7 @@ def filipino_dengue_pfn(line, task_name: str) -> Doc: line["option_c"], line["option_d"], ], - "gold_idx": 
LETTER_INDICES.index(line["answer"]), + "gold_idx": ascii_uppercase.index(line["answer"]), }, formulation=formulation, ), @@ -423,7 +423,7 @@ def filipino_dengue_pfn(line, task_name: str) -> Doc: adapter=lambda line: { "question": line["prompts"][0]["question"], "choices": [entry[3:] for entry in line["prompts"][0]["mcq"].split("\n")], - "gold_idx": LETTER_INDICES.index(line["label"]), + "gold_idx": ascii_uppercase.index(line["label"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/french.py b/src/lighteval/tasks/multilingual/tasks/french.py index 8f33efff4..509bf5740 100644 --- a/src/lighteval/tasks/multilingual/tasks/french.py +++ b/src/lighteval/tasks/multilingual/tasks/french.py @@ -19,10 +19,10 @@ """ import random +from string import ascii_uppercase from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import math_normalizer -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc from lighteval.tasks.tasks.ifeval.main import ifeval_metrics @@ -50,12 +50,12 @@ def prompt_gpqa_fr(line, task_name: str = None): instruction = "Choisissez la réponse correcte aux questions suivantes.\n\n" query = f"Question: {line['Question']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, choices)]) + query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, choices)]) query += "Réponse: " return Doc( task_name=task_name, query=f"{instruction}{query}", - choices=LETTER_INDICES[: len(choices)], + choices=ascii_uppercase[: len(choices)], gold_index=gold_index, instruction=instruction, ) @@ -65,7 +65,7 @@ def prompt_gpqa_fr(line, task_name: str = None): def prompt_bac_fr(line, task_name: str = None): prompt = f"Enoncé: {line['enonce']}\n{line['instruction']}\n" if line["choix"] is not None: # Multichoice evaluation - # prompt += "\n".join([f"{LETTER_INDICES[ix]}.{choix}" for ix, choix in enumerate(line["choix"])]) + # prompt += "\n".join([f"{ascii_uppercase[ix]}.{choix}" for ix, choix in enumerate(line["choix"])]) return Doc( task_name=task_name, query=prompt, @@ -78,8 +78,6 @@ def prompt_bac_fr(line, task_name: str = None): # IFEVal-fr task - - ifeval_fr_task = LightevalTaskConfig( name="ifeval-fr", prompt_function=prompt_ifeval_fr, # must be defined in the file or imported from src/lighteval/tasks/tasks_prompt_formatting.py diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py index 54d9e8a71..894f15a3c 100644 --- a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py @@ -23,6 +23,7 @@ """ from functools import partial +from string import ascii_uppercase from langcodes import standardize_tag @@ -30,7 +31,6 @@ LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -109,7 +109,7 @@ lambda line: { "question": line["question"], "choices": [line["option_a"], line["option_b"], line["option_c"], line["option_d"]], - "gold_idx": LETTER_INDICES.index(line["answer"]), + "gold_idx": 
ascii_uppercase.index(line["answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/hindi_arc.py b/src/lighteval/tasks/multilingual/tasks/hindi_arc.py index 51c5cf462..337bfafba 100644 --- a/src/lighteval/tasks/multilingual/tasks/hindi_arc.py +++ b/src/lighteval/tasks/multilingual/tasks/hindi_arc.py @@ -17,11 +17,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -43,7 +44,7 @@ "choices": line["choices"]["text"], "gold_idx": int(line["answerKey"]) - 1 if line["answerKey"].isdigit() - else LETTER_INDICES.index(line["answerKey"]), + else ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/mathlogicqa_rus.py b/src/lighteval/tasks/multilingual/tasks/mathlogicqa_rus.py index c90a775b9..f7a9e14c3 100644 --- a/src/lighteval/tasks/multilingual/tasks/mathlogicqa_rus.py +++ b/src/lighteval/tasks/multilingual/tasks/mathlogicqa_rus.py @@ -21,11 +21,12 @@ https://github.com/ai-forever/MERA """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -44,8 +45,8 @@ Language.RUSSIAN, lambda line: { "question": line["inputs"]["text"], - "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]], - "gold_idx": LETTER_INDICES.index(line["outputs"]), + "choices": [line["inputs"][f"option_{i.lower()}"] for i in ascii_uppercase[:4]], + "gold_idx": ascii_uppercase.index(line["outputs"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/meta_mmlu.py b/src/lighteval/tasks/multilingual/tasks/meta_mmlu.py index 1535e4ecc..2026b00f5 100644 --- a/src/lighteval/tasks/multilingual/tasks/meta_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/meta_mmlu.py @@ -19,6 +19,7 @@ """ from functools import partial +from string import ascii_uppercase from langcodes import standardize_tag @@ -26,7 +27,6 @@ LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -107,7 +107,7 @@ lambda line: { "question": line["input_question"], "choices": [v for _, v in sorted(line["input_choice_list"].items(), key=lambda x: x[0])], - "gold_idx": LETTER_INDICES.index(line["input_correct_responses"][0]), + "gold_idx": ascii_uppercase.index(line["input_correct_responses"][0]), }, formulation=formulation, ), diff --git 
a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py index b726ff728..f7485124f 100644 --- a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py +++ b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py @@ -27,13 +27,14 @@ https://github.com/nlp-uoregon/mlmm-evaluation """ +from string import ascii_uppercase + from langcodes import standardize_tag from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -55,7 +56,7 @@ "choices": line["choices"]["text"], "gold_idx": int(line["answerKey"]) - 1 if line["answerKey"].isdigit() - else LETTER_INDICES.index(line["answerKey"]), + else ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_mmlu.py b/src/lighteval/tasks/multilingual/tasks/mlmm_mmlu.py index d32ffcc2b..18993cd5e 100644 --- a/src/lighteval/tasks/multilingual/tasks/mlmm_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/mlmm_mmlu.py @@ -22,6 +22,7 @@ """ from functools import partial +from string import ascii_uppercase from langcodes import standardize_tag @@ -29,7 +30,6 @@ LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -110,7 +110,7 @@ lambda line: { "question": line["question"], "choices": line["choices"], - "gold_idx": LETTER_INDICES.index(line["answer"]), + "gold_idx": ascii_uppercase.index(line["answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/oab_exams.py b/src/lighteval/tasks/multilingual/tasks/oab_exams.py index 286e87b96..a4a28dd5c 100644 --- a/src/lighteval/tasks/multilingual/tasks/oab_exams.py +++ b/src/lighteval/tasks/multilingual/tasks/oab_exams.py @@ -19,11 +19,12 @@ https://huggingface.co/datasets/eduagarcia/oab_exams """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -43,7 +44,7 @@ lambda line: { "question": line["question"], "choices": line["choices"]["text"], - "gold_idx": LETTER_INDICES.index(line["answerKey"]), + "gold_idx": ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/openai_mmlu.py b/src/lighteval/tasks/multilingual/tasks/openai_mmlu.py index 5ccbddef6..b3254dd09 100644 --- a/src/lighteval/tasks/multilingual/tasks/openai_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/openai_mmlu.py @@ -19,12 +19,12 @@ """ 
from functools import partial +from string import ascii_uppercase from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -105,7 +105,7 @@ lambda line: { "question": line["Question"], "choices": [line["A"], line["B"], line["C"], line["D"]], - "gold_idx": LETTER_INDICES.index(line["Answer"]), + "gold_idx": ascii_uppercase.index(line["Answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/openbook_es.py b/src/lighteval/tasks/multilingual/tasks/openbook_es.py index 642b4f428..3671e608f 100644 --- a/src/lighteval/tasks/multilingual/tasks/openbook_es.py +++ b/src/lighteval/tasks/multilingual/tasks/openbook_es.py @@ -18,11 +18,12 @@ https://huggingface.co/datasets/BSC-LT/openbookqa-es """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -42,7 +43,7 @@ lambda line: { "question": line["question_stem"], "choices": line["choices"]["text"], - "gold_idx": LETTER_INDICES.index(line["answerKey"]), + "gold_idx": ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/openbook_rus.py b/src/lighteval/tasks/multilingual/tasks/openbook_rus.py index 9188dde72..06ae9101f 100644 --- a/src/lighteval/tasks/multilingual/tasks/openbook_rus.py +++ b/src/lighteval/tasks/multilingual/tasks/openbook_rus.py @@ -19,11 +19,12 @@ https://arxiv.org/abs/2401.04531 """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -42,8 +43,8 @@ Language.RUSSIAN, lambda line: { "question": line["inputs"]["question"], - "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]], - "gold_idx": LETTER_INDICES.index(line["outputs"]), + "choices": [line["inputs"][f"option_{i.lower()}"] for i in ascii_uppercase[:4]], + "gold_idx": ascii_uppercase.index(line["outputs"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/swahili_arc.py b/src/lighteval/tasks/multilingual/tasks/swahili_arc.py index 99892738b..9cf07ef85 100644 --- a/src/lighteval/tasks/multilingual/tasks/swahili_arc.py +++ b/src/lighteval/tasks/multilingual/tasks/swahili_arc.py @@ -16,11 +16,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, 
LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -42,7 +43,7 @@ "choices": line["choices"]["text"], "gold_idx": int(line["answerKey"]) - 1 if line["answerKey"].isdigit() - else LETTER_INDICES.index(line["answerKey"]), + else ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/turkish_arc.py b/src/lighteval/tasks/multilingual/tasks/turkish_arc.py index 4ebd4cd3a..22183d352 100644 --- a/src/lighteval/tasks/multilingual/tasks/turkish_arc.py +++ b/src/lighteval/tasks/multilingual/tasks/turkish_arc.py @@ -17,11 +17,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -43,7 +44,7 @@ "choices": line["choices"]["text"], "gold_idx": int(line["answerKey"]) - 1 if line["answerKey"].isdigit() - else LETTER_INDICES.index(line["answerKey"]), + else ascii_uppercase.index(line["answerKey"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/turkish_mmlu.py b/src/lighteval/tasks/multilingual/tasks/turkish_mmlu.py index e3178be63..410268f9e 100644 --- a/src/lighteval/tasks/multilingual/tasks/turkish_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/turkish_mmlu.py @@ -17,11 +17,12 @@ paper: """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation, normalize_subset from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -54,7 +55,7 @@ lambda line: { "question": line["question"], "choices": line["choices"], - "gold_idx": LETTER_INDICES.index(line["answer"]), + "gold_idx": ascii_uppercase.index(line["answer"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/multilingual/tasks/worldtree_rus.py b/src/lighteval/tasks/multilingual/tasks/worldtree_rus.py index be0d60213..237a90b3d 100644 --- a/src/lighteval/tasks/multilingual/tasks/worldtree_rus.py +++ b/src/lighteval/tasks/multilingual/tasks/worldtree_rus.py @@ -21,11 +21,12 @@ https://github.com/ai-forever/MERA """ +from string import ascii_uppercase + from lighteval.metrics.dynamic_metrics import ( LogLikelihoodAccMetric, ) from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function @@ -44,8 +45,8 @@ Language.RUSSIAN, lambda line: { "question": 
line["inputs"]["question"], - "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]], - "gold_idx": LETTER_INDICES.index(line["outputs"]), + "choices": [line["inputs"][f"option_{i.lower()}"] for i in ascii_uppercase[:4]], + "gold_idx": ascii_uppercase.index(line["outputs"]), }, formulation=formulation, ), diff --git a/src/lighteval/tasks/tasks/agieval.py b/src/lighteval/tasks/tasks/agieval.py index 29cd993cc..8a5b90849 100644 --- a/src/lighteval/tasks/tasks/agieval.py +++ b/src/lighteval/tasks/tasks/agieval.py @@ -30,9 +30,9 @@ from inspect_ai.scorer import choice from inspect_ai.solver import multiple_choice -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc def record_to_sample(record): @@ -44,12 +44,21 @@ def record_to_sample(record): return Sample(input=record["query"], target=ascii_uppercase[record["gold"][0]], choices=choices) +def agieval_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["query"], + choices=[f" {c}" for c in line["choices"]], + gold_index=line["gold"], + ) + + agieval_aqua_rat = LightevalTaskConfig( name="agieval:aqua-rat", sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-aqua-rat", hf_subset="default", hf_avail_splits=["test"], @@ -69,7 +78,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-biology", hf_subset="default", hf_avail_splits=["test"], @@ -89,7 +98,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-chemistry", hf_subset="default", hf_avail_splits=["test"], @@ -109,7 +118,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-chinese", hf_subset="default", hf_avail_splits=["test"], @@ -129,7 +138,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-english", hf_subset="default", hf_avail_splits=["test"], @@ -149,7 +158,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-geography", hf_subset="default", hf_avail_splits=["test"], @@ -169,7 +178,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-history", hf_subset="default", hf_avail_splits=["test"], @@ -189,7 +198,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-mathqa", 
hf_subset="default", hf_avail_splits=["test"], @@ -209,7 +218,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-gaokao-physics", hf_subset="default", hf_avail_splits=["test"], @@ -229,7 +238,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-logiqa-en", hf_subset="default", hf_avail_splits=["test"], @@ -249,7 +258,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-logiqa-zh", hf_subset="default", hf_avail_splits=["test"], @@ -269,7 +278,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-lsat-ar", hf_subset="default", hf_avail_splits=["test"], @@ -289,7 +298,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-lsat-lr", hf_subset="default", hf_avail_splits=["test"], @@ -309,7 +318,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-lsat-rc", hf_subset="default", hf_avail_splits=["test"], @@ -329,7 +338,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-sat-en", hf_subset="default", hf_avail_splits=["test"], @@ -349,7 +358,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-sat-en-without-passage", hf_subset="default", hf_avail_splits=["test"], @@ -369,7 +378,7 @@ def record_to_sample(record): sample_fields=record_to_sample, solver=[multiple_choice(cache=True)], scorer=choice(), - prompt_function=prompt.agieval, + prompt_function=agieval_prompt, hf_repo="dmayhem93/agieval-sat-math", hf_subset="default", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/aime.py b/src/lighteval/tasks/tasks/aime.py index 2a8bc32a6..ec14d9d9d 100644 --- a/src/lighteval/tasks/tasks/aime.py +++ b/src/lighteval/tasks/tasks/aime.py @@ -24,36 +24,43 @@ https://maa.org/aime-thresholds-are-available/ """ +from textwrap import dedent + from inspect_ai.dataset import Sample from inspect_ai.solver import generate, prompt_template -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics, math_scorer from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc -MATH_PROMPT_TEMPLATE = """ -Solve the following math problem step by step. The last line of your -response should be of the form "ANSWER: $ANSWER" (without quotes) -where $ANSWER is the answer to the problem. 
+# Prompt template adapted from +# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17 +# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details +# Note that it is important to have the final answer in a box for math-verify to work correctly +MATH_PROMPT_TEMPLATE = dedent(""" +Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. {prompt} - -Remember to put your answer on its own line at the end in the form -"ANSWER: $ANSWER" (without quotes) where $ANSWER is the answer to -the problem, and you do not need to use a \\boxed command. - -Reasoning: -""".strip() +""") def record_to_sample(record): return Sample(input=record["problem"], target=record["answer"]) +def aime_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=MATH_PROMPT_TEMPLATE.format(prompt=line["problem"]), + choices=[line["answer"]], + gold_index=0, + ) + + aime24 = LightevalTaskConfig( name="aime24", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, solver=[prompt_template(MATH_PROMPT_TEMPLATE), generate(cache=True)], scorer=math_scorer(), @@ -70,7 +77,7 @@ def record_to_sample(record): aime24_avg = LightevalTaskConfig( name="aime24_avg", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, hf_repo="HuggingFaceH4/aime_2024", hf_subset="default", @@ -85,7 +92,7 @@ def record_to_sample(record): aime24_gpassk = LightevalTaskConfig( name="aime24_gpassk", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, hf_repo="HuggingFaceH4/aime_2024", hf_subset="default", @@ -100,7 +107,7 @@ def record_to_sample(record): aime25 = LightevalTaskConfig( name="aime25", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, solver=[prompt_template(MATH_PROMPT_TEMPLATE), generate(cache=True)], scorer=math_scorer(), @@ -117,7 +124,7 @@ def record_to_sample(record): aime25_avg = LightevalTaskConfig( name="aime25_avg", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, hf_repo="yentinglin/aime_2025", hf_subset="default", @@ -132,7 +139,7 @@ def record_to_sample(record): aime25_gpassk = LightevalTaskConfig( name="aime25_gpassk", - prompt_function=prompt.aime_prompt_fn, + prompt_function=aime_prompt, sample_fields=record_to_sample, hf_repo="yentinglin/aime_2025", hf_subset="default", diff --git a/src/lighteval/tasks/tasks/anli.py b/src/lighteval/tasks/tasks/anli.py index 0c92fc099..86a0a9d65 100644 --- a/src/lighteval/tasks/tasks/anli.py +++ b/src/lighteval/tasks/tasks/anli.py @@ -22,14 +22,23 @@ https://arxiv.org/abs/1910.14599 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def anli_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['premise']}\nQuestion: {line['hypothesis']} True, False, or 
Neither?\nAnswer:", + choices=[" True", " Neither", " False"], + gold_index=int(line["label"]), + ) anli_r1 = LightevalTaskConfig( name="anli:r1", - prompt_function=prompt.anli, + prompt_function=anli_prompt, hf_repo="facebook/anli", hf_subset="plain_text", hf_avail_splits=["train_r1", "dev_r1", "test_r1"], @@ -45,7 +54,7 @@ anli_r2 = LightevalTaskConfig( name="anli:r2", - prompt_function=prompt.anli, + prompt_function=anli_prompt, hf_repo="facebook/anli", hf_subset="plain_text", hf_avail_splits=["train_r2", "dev_r2", "test_r2"], @@ -61,7 +70,7 @@ anli_r3 = LightevalTaskConfig( name="anli:r3", - prompt_function=prompt.anli, + prompt_function=anli_prompt, hf_repo="facebook/anli", hf_subset="plain_text", hf_avail_splits=["train_r3", "dev_r3", "test_r3"], diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py index 028e737cc..39f8b8827 100644 --- a/src/lighteval/tasks/tasks/arc.py +++ b/src/lighteval/tasks/tasks/arc.py @@ -22,14 +22,23 @@ https://arxiv.org/abs/1803.05457 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def arc_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + choices=[f" {c}" for c in line["choices"]["text"]], + gold_index=line["choices"]["label"].index(line["answerKey"]), + ) arc_challenge = LightevalTaskConfig( name="arc:challenge", - prompt_function=prompt.arc, + prompt_function=arc_prompt, hf_repo="allenai/ai2_arc", hf_subset="ARC-Challenge", hf_avail_splits=["train", "test"], @@ -46,7 +55,7 @@ arc_easy = LightevalTaskConfig( name="arc:easy", - prompt_function=prompt.arc, + prompt_function=arc_prompt, hf_repo="allenai/ai2_arc", hf_subset="ARC-Easy", hf_avail_splits=["train", "validation", "test"], diff --git a/src/lighteval/tasks/tasks/arc_agi_2.py b/src/lighteval/tasks/tasks/arc_agi_2.py index 65868796d..058b1b732 100644 --- a/src/lighteval/tasks/tasks/arc_agi_2.py +++ b/src/lighteval/tasks/tasks/arc_agi_2.py @@ -28,23 +28,101 @@ https://arcprize.org/guide """ -import lighteval.tasks.default_prompts as prompt +import json +from textwrap import dedent + +from inspect_ai.dataset import Sample +from inspect_ai.scorer import exact +from inspect_ai.solver import generate + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +# query from: https://github.com/arcprize/model_baseline/blob/main/src/prompts/system_prompt.txt +PROMPT_TEMPLATE = dedent(""" +You are participating in a puzzle solving competition. You are an expert at solving puzzles. + +Below is a list of input and output pairs with a pattern. Your goal is to identify the pattern or transformation in the training examples that maps the input to the output, then apply that pattern to the test input to give a final output. 
+ +Respond in the format of the training output examples + +--Training Examples-- +{training_examples} +--End of Training Examples-- +--Test Input-- +{test_input} +--End of Test Input-- + +Your response:""") + + +def __convert_2d_list_to_string(list_of_lists: list[list[int]]) -> str: + """Convert a list of lists to a string""" + string_list = "" + for row in list_of_lists: + string_list += json.dumps(row) + "\n" + return string_list + + +def arc_agi_2_prompt(line, task_name: str = None): + training_pairs = line["fewshots"] + training_examples = "" + for i, pair in enumerate(training_pairs): + training_examples += f"--Example {i}-- \n\n INPUT: \n\n" + training_examples += __convert_2d_list_to_string(pair["input"]) + "\n\n" + training_examples += "OUTPUT: \n\n" + training_examples += __convert_2d_list_to_string(pair["output"]) + "\n\n" + + test_input = __convert_2d_list_to_string(line["question"][0]["input"]) + + gold = str(line["question"][0]["output"]) + query = PROMPT_TEMPLATE.format(training_examples=training_examples, test_input=test_input) + + return Doc( + task_name=task_name, + query=query, + choices=[gold], + gold_index=0, + ) + + +def record_to_sample(record): + training_pairs = record["fewshots"] + training_examples = "" + + for i, pair in enumerate(training_pairs): + training_examples += f"--Example {i}-- \n\n INPUT: \n\n" + training_examples += __convert_2d_list_to_string(pair["input"]) + "\n\n" + training_examples += "OUTPUT: \n\n" + training_examples += __convert_2d_list_to_string(pair["output"]) + "\n\n" + + test_input = __convert_2d_list_to_string(record["question"][0]["input"]) + query = PROMPT_TEMPLATE.format(training_examples=training_examples, test_input=test_input) + + target = str(record["question"][0]["output"]) + + return Sample( + input=query, + target=target, + ) arc_agi_2 = LightevalTaskConfig( name="arc_agi_2", - prompt_function=prompt.arc_agi_2, + prompt_function=arc_agi_2_prompt, hf_repo="arc-agi-community/arc-agi-2", hf_subset="default", hf_avail_splits=["train", "test"], evaluation_splits=["test"], few_shots_split=None, few_shots_select=None, - generation_size=2048, metrics=[Metrics.exact_match], stop_sequence=None, + sample_fields=record_to_sample, + solver=[generate(cache=True)], + scorer=exact(), version=0, ) diff --git a/src/lighteval/tasks/tasks/arithmetic.py b/src/lighteval/tasks/tasks/arithmetic.py index f9f55a290..48a290435 100644 --- a/src/lighteval/tasks/tasks/arithmetic.py +++ b/src/lighteval/tasks/tasks/arithmetic.py @@ -19,14 +19,18 @@ https://arxiv.org/abs/2005.14165 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def arithmetic_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["context"], choices=[line["completion"]], gold_index=[0]) arithmetic_1dc = LightevalTaskConfig( name="arithmetic:1dc", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_1dc", hf_avail_splits=["validation"], @@ -41,7 +45,7 @@ arithmetic_2da = LightevalTaskConfig( name="arithmetic:2da", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_2da", hf_avail_splits=["validation"], @@ -56,7 +60,7 @@ arithmetic_2dm = LightevalTaskConfig( name="arithmetic:2dm", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, 
hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_2dm", hf_avail_splits=["validation"], @@ -71,7 +75,7 @@ arithmetic_2ds = LightevalTaskConfig( name="arithmetic:2ds", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_2ds", hf_avail_splits=["validation"], @@ -86,7 +90,7 @@ arithmetic_3da = LightevalTaskConfig( name="arithmetic:3da", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_3da", hf_avail_splits=["validation"], @@ -101,7 +105,7 @@ arithmetic_3ds = LightevalTaskConfig( name="arithmetic:3ds", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_3ds", hf_avail_splits=["validation"], @@ -116,7 +120,7 @@ arithmetic_4da = LightevalTaskConfig( name="arithmetic:4da", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_4da", hf_avail_splits=["validation"], @@ -131,7 +135,7 @@ arithmetic_4ds = LightevalTaskConfig( name="arithmetic:4ds", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_4ds", hf_avail_splits=["validation"], @@ -146,7 +150,7 @@ arithmetic_5da = LightevalTaskConfig( name="arithmetic:5da", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_5da", hf_avail_splits=["validation"], @@ -161,7 +165,7 @@ arithmetic_5ds = LightevalTaskConfig( name="arithmetic:5ds", - prompt_function=prompt.arithmetic, + prompt_function=arithmetic_prompt, hf_repo="EleutherAI/arithmetic", hf_subset="arithmetic_5ds", hf_avail_splits=["validation"], diff --git a/src/lighteval/tasks/tasks/asdiv.py b/src/lighteval/tasks/tasks/asdiv.py index 513b1aef6..4fd34df36 100644 --- a/src/lighteval/tasks/tasks/asdiv.py +++ b/src/lighteval/tasks/tasks/asdiv.py @@ -19,14 +19,23 @@ https://arxiv.org/abs/2410.12853 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def asdiv_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['body']}\nQuestion:{line['question']}\nAnswer:", + choices=line["answer"].split(" (")[0], + gold_index=[0], + ) asdiv = LightevalTaskConfig( name="asdiv", - prompt_function=prompt.asdiv, + prompt_function=asdiv_prompt, hf_repo="EleutherAI/asdiv", hf_subset="asdiv", hf_avail_splits=["validation"], diff --git a/src/lighteval/tasks/tasks/babi_qa.py b/src/lighteval/tasks/tasks/babi_qa.py index 16a6cfdb2..3a16c4fb3 100644 --- a/src/lighteval/tasks/tasks/babi_qa.py +++ b/src/lighteval/tasks/tasks/babi_qa.py @@ -19,14 +19,41 @@ https://arxiv.org/abs/1502.05698 """ -import lighteval.tasks.default_prompts as prompt +import json + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def babi_qa_prompt(line, task_name: str = None): + def process_path(path: str) -> str: + steps = path.split(",") + directions = {"s": "south", "n": "north", "e": "east", "w": "west"} + path = " ".join([directions[step] for step in steps]) + return path + + if isinstance(line["story"], dict): + line = line["story"] + else: + line = json.loads(line["story"]) + + results = [] + 
story = [] + for type, text, answer in zip(line["type"], line["text"], line["answer"]): + if type == "supporting fact": + story.append(text) + elif type == "question": + text = text.replace("_", process_path(answer)) + query = "\n".join(story) + f"\nQuestion: {text}\nAnswer: " + results.append(Doc(task_name=task_name, query=query, choices=[answer], gold_index=0)) + story = [] + return results babi_qa = LightevalTaskConfig( name="babi_qa", - prompt_function=prompt.babi_qa, + prompt_function=babi_qa_prompt, hf_repo="facebook/babi_qa", hf_subset="en-valid-qa1", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/bbq.py b/src/lighteval/tasks/tasks/bbq.py index 4c01ab034..eb5fb1d45 100644 --- a/src/lighteval/tasks/tasks/bbq.py +++ b/src/lighteval/tasks/tasks/bbq.py @@ -19,14 +19,28 @@ https://arxiv.org/abs/2110.08193 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def bbq_prompt(line, task_name: str = None): + query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}" + query += "".join([f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])]) + query += "\nAnswer:" + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[: len(line["choices"])], + gold_index=int(line["gold_index"]), + ) bbq = LightevalTaskConfig( name="bbq", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="all", hf_avail_splits=["train", "test"], @@ -41,7 +55,7 @@ bbq_Age = LightevalTaskConfig( name="bbq:Age", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Age", hf_avail_splits=["train", "test"], @@ -56,7 +70,7 @@ bbq_Disability_status = LightevalTaskConfig( name="bbq:Disability_status", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Disability_status", hf_avail_splits=["train", "test"], @@ -71,7 +85,7 @@ bbq_Gender_identity = LightevalTaskConfig( name="bbq:Gender_identity", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Gender_identity", hf_avail_splits=["train", "test"], @@ -86,7 +100,7 @@ bbq_Nationality = LightevalTaskConfig( name="bbq:Nationality", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Nationality", hf_avail_splits=["train", "test"], @@ -101,7 +115,7 @@ bbq_Physical_appearance = LightevalTaskConfig( name="bbq:Physical_appearance", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Physical_appearance", hf_avail_splits=["train", "test"], @@ -116,7 +130,7 @@ bbq_Race_ethnicity = LightevalTaskConfig( name="bbq:Race_ethnicity", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Race_ethnicity", hf_avail_splits=["train", "test"], @@ -131,7 +145,7 @@ bbq_Race_x_SES = LightevalTaskConfig( name="bbq:Race_x_SES", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Race_x_SES", hf_avail_splits=["train", "test"], @@ -146,7 +160,7 @@ bbq_Race_x_gender = LightevalTaskConfig( name="bbq:Race_x_gender", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, 
hf_repo="lighteval/bbq_helm", hf_subset="Race_x_gender", hf_avail_splits=["train", "test"], @@ -161,7 +175,7 @@ bbq_Religion = LightevalTaskConfig( name="bbq:Religion", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Religion", hf_avail_splits=["train", "test"], @@ -176,7 +190,7 @@ bbq_SES = LightevalTaskConfig( name="bbq:SES", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="SES", hf_avail_splits=["train", "test"], @@ -191,7 +205,7 @@ bbq_Sexual_orientation = LightevalTaskConfig( name="bbq:Sexual_orientation", - prompt_function=prompt.bbq, + prompt_function=bbq_prompt, hf_repo="lighteval/bbq_helm", hf_subset="Sexual_orientation", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/bigbench.py b/src/lighteval/tasks/tasks/bigbench.py index 2466453b3..e18057f33 100644 --- a/src/lighteval/tasks/tasks/bigbench.py +++ b/src/lighteval/tasks/tasks/bigbench.py @@ -19,14 +19,73 @@ https://arxiv.org/abs/2206.04615 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def bigbench_linefeed_before_and_after_query_prompt(line, task_name: str = None): + if len(line["multiple_choice_scores"]) == 0: + choices = line["targets"] + gold_index = [i for i, _ in enumerate(line["targets"])] + else: + choices = line["multiple_choice_targets"] + gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] + + return Doc( + task_name=task_name, + query=f"\n{line['inputs']}\n", + choices=choices, + gold_index=gold_index, + ) + + +def bigbench_linefeed_before_whitespace_after_query_prompt(line, task_name: str = None): + if len(line["multiple_choice_scores"]) == 0: + choices = line["targets"] + gold_index = [i for i, _ in enumerate(line["targets"])] + else: + choices = line["multiple_choice_targets"] + gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] + + return Doc( + task_name=task_name, + query=f"\n{line['inputs']} ", + choices=choices, + gold_index=gold_index, + ) + + +def bigbench_whitespace_after_query_prompt(line, task_name: str = None): + if len(line["multiple_choice_scores"]) == 0: + choices = line["targets"] + gold_index = [i for i, _ in enumerate(line["targets"])] + else: + choices = line["multiple_choice_targets"] + gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] + + return Doc( + task_name=task_name, + query=f"{line['inputs']} ", + choices=choices, + gold_index=gold_index, + ) + + +def bigbench_prompt(line, task_name: str = None): + if len(line["multiple_choice_scores"]) == 0: + choices = line["targets"] + gold_index = [i for i, _ in enumerate(line["targets"])] + else: + choices = line["multiple_choice_targets"] + gold_index = [i for i, a in enumerate(line["multiple_choice_scores"]) if a == 1] + + return Doc(task_name=task_name, query=line["inputs"], choices=choices, gold_index=gold_index) abstract_narrative_understanding = LightevalTaskConfig( name="bigbench:abstract_narrative_understanding", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="abstract_narrative_understanding", hf_avail_splits=["default", "train", "validation"], @@ -41,7 +100,7 @@ anachronisms = LightevalTaskConfig( name="bigbench:anachronisms", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, 
hf_repo="tasksource/bigbench", hf_subset="anachronisms", hf_avail_splits=["default", "train", "validation"], @@ -56,7 +115,7 @@ analogical_similarity = LightevalTaskConfig( name="bigbench:analogical_similarity", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="analogical_similarity", hf_avail_splits=["default", "train", "validation"], @@ -71,7 +130,7 @@ analytic_entailment = LightevalTaskConfig( name="bigbench:analytic_entailment", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="analytic_entailment", hf_avail_splits=["default", "train", "validation"], @@ -86,7 +145,7 @@ arithmetic_bb = LightevalTaskConfig( name="bigbench:arithmetic_bb", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="arithmetic", hf_avail_splits=["default", "train", "validation"], @@ -101,7 +160,7 @@ ascii_word_recognition = LightevalTaskConfig( name="bigbench:ascii_word_recognition", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="ascii_word_recognition", hf_avail_splits=["default", "train", "validation"], @@ -116,7 +175,7 @@ authorship_verification = LightevalTaskConfig( name="bigbench:authorship_verification", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="authorship_verification", hf_avail_splits=["default", "train", "validation"], @@ -131,7 +190,7 @@ auto_categorization = LightevalTaskConfig( name="bigbench:auto_categorization", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="auto_categorization", hf_avail_splits=["default", "train", "validation"], @@ -146,7 +205,7 @@ auto_debugging = LightevalTaskConfig( name="bigbench:auto_debugging", - prompt_function=prompt.bigbench_linefeed_before_and_after_query, + prompt_function=bigbench_linefeed_before_and_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="auto_debugging", hf_avail_splits=["default", "train", "validation"], @@ -161,7 +220,7 @@ bbq_lite_json = LightevalTaskConfig( name="bigbench:bbq_lite_json", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="bbq_lite_json", hf_avail_splits=["default", "train", "validation"], @@ -176,7 +235,7 @@ bridging_anaphora_resolution_barqa = LightevalTaskConfig( name="bigbench:bridging_anaphora_resolution_barqa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="bridging_anaphora_resolution_barqa", hf_avail_splits=["default", "train", "validation"], @@ -191,7 +250,7 @@ causal_judgment = LightevalTaskConfig( name="bigbench:causal_judgment", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="causal_judgment", hf_avail_splits=["default", "train", "validation"], @@ -206,7 +265,7 @@ cause_and_effect = LightevalTaskConfig( name="bigbench:cause_and_effect", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="cause_and_effect", hf_avail_splits=["default", "train", "validation"], @@ -221,7 +280,7 @@ checkmate_in_one = LightevalTaskConfig( name="bigbench:checkmate_in_one", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, 
hf_repo="tasksource/bigbench", hf_subset="checkmate_in_one", hf_avail_splits=["default", "train", "validation"], @@ -236,7 +295,7 @@ chess_state_tracking = LightevalTaskConfig( name="bigbench:chess_state_tracking", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="chess_state_tracking", hf_avail_splits=["default", "train", "validation"], @@ -251,7 +310,7 @@ chinese_remainder_theorem = LightevalTaskConfig( name="bigbench:chinese_remainder_theorem", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="chinese_remainder_theorem", hf_avail_splits=["default", "train", "validation"], @@ -266,7 +325,7 @@ cifar10_classification = LightevalTaskConfig( name="bigbench:cifar10_classification", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="cifar10_classification", hf_avail_splits=["default", "train", "validation"], @@ -281,7 +340,7 @@ code_line_description = LightevalTaskConfig( name="bigbench:code_line_description", - prompt_function=prompt.bigbench_linefeed_before_and_after_query, + prompt_function=bigbench_linefeed_before_and_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="code_line_description", hf_avail_splits=["default", "train", "validation"], @@ -296,7 +355,7 @@ codenames = LightevalTaskConfig( name="bigbench:codenames", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="codenames", hf_avail_splits=["default", "train", "validation"], @@ -311,7 +370,7 @@ color = LightevalTaskConfig( name="bigbench:color", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="color", hf_avail_splits=["default", "train", "validation"], @@ -331,7 +390,7 @@ common_morpheme = LightevalTaskConfig( name="bigbench:common_morpheme", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="common_morpheme", hf_avail_splits=["default", "train", "validation"], @@ -346,7 +405,7 @@ conceptual_combinations = LightevalTaskConfig( name="bigbench:conceptual_combinations", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="conceptual_combinations", hf_avail_splits=["default", "train", "validation"], @@ -361,7 +420,7 @@ conlang_translation = LightevalTaskConfig( name="bigbench:conlang_translation", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="conlang_translation", hf_avail_splits=["default", "train", "validation"], @@ -376,7 +435,7 @@ contextual_parametric_knowledge_conflicts = LightevalTaskConfig( name="bigbench:contextual_parametric_knowledge_conflicts", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="contextual_parametric_knowledge_conflicts", hf_avail_splits=["default", "train", "validation"], @@ -391,7 +450,7 @@ crash_blossom = LightevalTaskConfig( name="bigbench:crash_blossom", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="crash_blossom", hf_avail_splits=["default", "train", "validation"], @@ -406,7 +465,7 @@ crass_ai = LightevalTaskConfig( name="bigbench:crass_ai", - 
prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="crass_ai", hf_avail_splits=["default", "train", "validation"], @@ -421,7 +480,7 @@ cryobiology_spanish = LightevalTaskConfig( name="bigbench:cryobiology_spanish", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="cryobiology_spanish", hf_avail_splits=["default", "train", "validation"], @@ -436,7 +495,7 @@ cryptonite = LightevalTaskConfig( name="bigbench:cryptonite", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="cryptonite", hf_avail_splits=["default", "train", "validation"], @@ -451,7 +510,7 @@ cs_algorithms = LightevalTaskConfig( name="bigbench:cs_algorithms", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="cs_algorithms", hf_avail_splits=["default", "train", "validation"], @@ -466,7 +525,7 @@ dark_humor_detection = LightevalTaskConfig( name="bigbench:dark_humor_detection", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="dark_humor_detection", hf_avail_splits=["default", "train", "validation"], @@ -481,7 +540,7 @@ date_understanding = LightevalTaskConfig( name="bigbench:date_understanding", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="date_understanding", hf_avail_splits=["default", "train", "validation"], @@ -496,7 +555,7 @@ disambiguation_qa = LightevalTaskConfig( name="bigbench:disambiguation_qa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="disambiguation_qa", hf_avail_splits=["default", "train", "validation"], @@ -511,7 +570,7 @@ discourse_marker_prediction = LightevalTaskConfig( name="bigbench:discourse_marker_prediction", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="discourse_marker_prediction", hf_avail_splits=["default", "train", "validation"], @@ -526,7 +585,7 @@ disfl_qa = LightevalTaskConfig( name="bigbench:disfl_qa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="disfl_qa", hf_avail_splits=["default", "train", "validation"], @@ -541,7 +600,7 @@ dyck_languages = LightevalTaskConfig( name="bigbench:dyck_languages", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="dyck_languages", hf_avail_splits=["default", "train", "validation"], @@ -556,7 +615,7 @@ elementary_math_qa = LightevalTaskConfig( name="bigbench:elementary_math_qa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="elementary_math_qa", hf_avail_splits=["default", "train", "validation"], @@ -571,7 +630,7 @@ emoji_movie = LightevalTaskConfig( name="bigbench:emoji_movie", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="emoji_movie", hf_avail_splits=["default", "train", "validation"], @@ -591,7 +650,7 @@ emojis_emotion_prediction = LightevalTaskConfig( name="bigbench:emojis_emotion_prediction", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="emojis_emotion_prediction", 
hf_avail_splits=["default", "train", "validation"], @@ -606,7 +665,7 @@ empirical_judgments = LightevalTaskConfig( name="bigbench:empirical_judgments", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="empirical_judgments", hf_avail_splits=["default", "train", "validation"], @@ -621,7 +680,7 @@ english_proverbs = LightevalTaskConfig( name="bigbench:english_proverbs", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="english_proverbs", hf_avail_splits=["default", "train", "validation"], @@ -636,7 +695,7 @@ english_russian_proverbs = LightevalTaskConfig( name="bigbench:english_russian_proverbs", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="english_russian_proverbs", hf_avail_splits=["default", "train", "validation"], @@ -651,7 +710,7 @@ entailed_polarity = LightevalTaskConfig( name="bigbench:entailed_polarity", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="entailed_polarity", hf_avail_splits=["default", "train", "validation"], @@ -666,7 +725,7 @@ entailed_polarity_hindi = LightevalTaskConfig( name="bigbench:entailed_polarity_hindi", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="entailed_polarity_hindi", hf_avail_splits=["default", "train", "validation"], @@ -681,7 +740,7 @@ epistemic_reasoning = LightevalTaskConfig( name="bigbench:epistemic_reasoning", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="epistemic_reasoning", hf_avail_splits=["default", "train", "validation"], @@ -696,7 +755,7 @@ evaluating_information_essentiality = LightevalTaskConfig( name="bigbench:evaluating_information_essentiality", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="evaluating_information_essentiality", hf_avail_splits=["default", "train", "validation"], @@ -711,7 +770,7 @@ fact_checker = LightevalTaskConfig( name="bigbench:fact_checker", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="fact_checker", hf_avail_splits=["default", "train", "validation"], @@ -726,7 +785,7 @@ fantasy_reasoning = LightevalTaskConfig( name="bigbench:fantasy_reasoning", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="fantasy_reasoning", hf_avail_splits=["default", "train", "validation"], @@ -741,7 +800,7 @@ few_shot_nlg = LightevalTaskConfig( name="bigbench:few_shot_nlg", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="few_shot_nlg", hf_avail_splits=["default", "train", "validation"], @@ -756,7 +815,7 @@ figure_of_speech_detection = LightevalTaskConfig( name="bigbench:figure_of_speech_detection", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="figure_of_speech_detection", hf_avail_splits=["default", "train", "validation"], @@ -771,7 +830,7 @@ formal_fallacies_syllogisms_negation = LightevalTaskConfig( name="bigbench:formal_fallacies_syllogisms_negation", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", 
hf_subset="formal_fallacies_syllogisms_negation", hf_avail_splits=["default", "train", "validation"], @@ -786,7 +845,7 @@ gem = LightevalTaskConfig( name="bigbench:gem", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="gem", hf_avail_splits=["default", "train", "validation"], @@ -801,7 +860,7 @@ gender_inclusive_sentences_german = LightevalTaskConfig( name="bigbench:gender_inclusive_sentences_german", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="gender_inclusive_sentences_german", hf_avail_splits=["default", "train", "validation"], @@ -816,7 +875,7 @@ general_knowledge = LightevalTaskConfig( name="bigbench:general_knowledge", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="general_knowledge", hf_avail_splits=["default", "train", "validation"], @@ -831,7 +890,7 @@ geometric_shapes = LightevalTaskConfig( name="bigbench:geometric_shapes", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="geometric_shapes", hf_avail_splits=["default", "train", "validation"], @@ -851,7 +910,7 @@ goal_step_wikihow = LightevalTaskConfig( name="bigbench:goal_step_wikihow", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="goal_step_wikihow", hf_avail_splits=["default", "train", "validation"], @@ -866,7 +925,7 @@ gre_reading_comprehension = LightevalTaskConfig( name="bigbench:gre_reading_comprehension", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="gre_reading_comprehension", hf_avail_splits=["default", "train", "validation"], @@ -881,7 +940,7 @@ hhh_alignment = LightevalTaskConfig( name="bigbench:hhh_alignment", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="hhh_alignment", hf_avail_splits=["default", "train", "validation"], @@ -896,7 +955,7 @@ hindi_question_answering = LightevalTaskConfig( name="bigbench:hindi_question_answering", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="hindi_question_answering", hf_avail_splits=["default", "train", "validation"], @@ -911,7 +970,7 @@ hindu_knowledge = LightevalTaskConfig( name="bigbench:hindu_knowledge", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="hindu_knowledge", hf_avail_splits=["default", "train", "validation"], @@ -926,7 +985,7 @@ hinglish_toxicity = LightevalTaskConfig( name="bigbench:hinglish_toxicity", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="hinglish_toxicity", hf_avail_splits=["default", "train", "validation"], @@ -941,7 +1000,7 @@ human_organs_senses = LightevalTaskConfig( name="bigbench:human_organs_senses", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="human_organs_senses", hf_avail_splits=["default", "train", "validation"], @@ -956,7 +1015,7 @@ hyperbaton = LightevalTaskConfig( name="bigbench:hyperbaton", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="hyperbaton", hf_avail_splits=["default", "train", 
"validation"], @@ -971,7 +1030,7 @@ identify_math_theorems = LightevalTaskConfig( name="bigbench:identify_math_theorems", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="identify_math_theorems", hf_avail_splits=["default", "train", "validation"], @@ -986,7 +1045,7 @@ identify_odd_metaphor = LightevalTaskConfig( name="bigbench:identify_odd_metaphor", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="identify_odd_metaphor", hf_avail_splits=["default", "train", "validation"], @@ -1001,7 +1060,7 @@ implicatures = LightevalTaskConfig( name="bigbench:implicatures", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="implicatures", hf_avail_splits=["default", "train", "validation"], @@ -1016,7 +1075,7 @@ implicit_relations = LightevalTaskConfig( name="bigbench:implicit_relations", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="implicit_relations", hf_avail_splits=["default", "train", "validation"], @@ -1031,7 +1090,7 @@ intent_recognition = LightevalTaskConfig( name="bigbench:intent_recognition", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="intent_recognition", hf_avail_splits=["default", "train", "validation"], @@ -1046,7 +1105,7 @@ international_phonetic_alphabet_nli = LightevalTaskConfig( name="bigbench:international_phonetic_alphabet_nli", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="international_phonetic_alphabet_nli", hf_avail_splits=["default", "train", "validation"], @@ -1061,7 +1120,7 @@ international_phonetic_alphabet_transliterate = LightevalTaskConfig( name="bigbench:international_phonetic_alphabet_transliterate", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="international_phonetic_alphabet_transliterate", hf_avail_splits=["default", "train", "validation"], @@ -1076,7 +1135,7 @@ intersect_geometry = LightevalTaskConfig( name="bigbench:intersect_geometry", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="intersect_geometry", hf_avail_splits=["default", "train", "validation"], @@ -1091,7 +1150,7 @@ irony_identification = LightevalTaskConfig( name="bigbench:irony_identification", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="irony_identification", hf_avail_splits=["default", "train", "validation"], @@ -1106,7 +1165,7 @@ kanji_ascii = LightevalTaskConfig( name="bigbench:kanji_ascii", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="kanji_ascii", hf_avail_splits=["default", "train", "validation"], @@ -1121,7 +1180,7 @@ kannada = LightevalTaskConfig( name="bigbench:kannada", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="kannada", hf_avail_splits=["default", "train", "validation"], @@ -1136,7 +1195,7 @@ key_value_maps = LightevalTaskConfig( name="bigbench:key_value_maps", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="key_value_maps", hf_avail_splits=["default", "train", "validation"], @@ -1151,7 +1210,7 @@ known_unknowns = 
LightevalTaskConfig( name="bigbench:known_unknowns", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="known_unknowns", hf_avail_splits=["default", "train", "validation"], @@ -1166,7 +1225,7 @@ language_games = LightevalTaskConfig( name="bigbench:language_games", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="language_games", hf_avail_splits=["default", "train", "validation"], @@ -1181,7 +1240,7 @@ language_identification = LightevalTaskConfig( name="bigbench:language_identification", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="language_identification", hf_avail_splits=["default", "train", "validation"], @@ -1196,7 +1255,7 @@ linguistic_mappings = LightevalTaskConfig( name="bigbench:linguistic_mappings", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="linguistic_mappings", hf_avail_splits=["default", "train", "validation"], @@ -1211,7 +1270,7 @@ linguistics_puzzles = LightevalTaskConfig( name="bigbench:linguistics_puzzles", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="linguistics_puzzles", hf_avail_splits=["default", "train", "validation"], @@ -1226,7 +1285,7 @@ logic_grid_puzzle = LightevalTaskConfig( name="bigbench:logic_grid_puzzle", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="logic_grid_puzzle", hf_avail_splits=["default", "train", "validation"], @@ -1241,7 +1300,7 @@ logical_args = LightevalTaskConfig( name="bigbench:logical_args", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="logical_args", hf_avail_splits=["default", "train", "validation"], @@ -1256,7 +1315,7 @@ logical_deduction = LightevalTaskConfig( name="bigbench:logical_deduction", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="logical_deduction", hf_avail_splits=["default", "train", "validation"], @@ -1271,7 +1330,7 @@ logical_fallacy_detection = LightevalTaskConfig( name="bigbench:logical_fallacy_detection", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="logical_fallacy_detection", hf_avail_splits=["default", "train", "validation"], @@ -1286,7 +1345,7 @@ logical_sequence = LightevalTaskConfig( name="bigbench:logical_sequence", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="logical_sequence", hf_avail_splits=["default", "train", "validation"], @@ -1301,7 +1360,7 @@ mathematical_induction = LightevalTaskConfig( name="bigbench:mathematical_induction", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="mathematical_induction", hf_avail_splits=["default", "train", "validation"], @@ -1316,7 +1375,7 @@ matrixshapes = LightevalTaskConfig( name="bigbench:matrixshapes", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="matrixshapes", hf_avail_splits=["default", "train", "validation"], @@ -1331,7 +1390,7 @@ 
metaphor_boolean = LightevalTaskConfig( name="bigbench:metaphor_boolean", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="metaphor_boolean", hf_avail_splits=["default", "train", "validation"], @@ -1346,7 +1405,7 @@ metaphor_understanding = LightevalTaskConfig( name="bigbench:metaphor_understanding", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="metaphor_understanding", hf_avail_splits=["default", "train", "validation"], @@ -1361,7 +1420,7 @@ minute_mysteries_qa = LightevalTaskConfig( name="bigbench:minute_mysteries_qa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="minute_mysteries_qa", hf_avail_splits=["default", "train", "validation"], @@ -1376,7 +1435,7 @@ misconceptions = LightevalTaskConfig( name="bigbench:misconceptions", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="misconceptions", hf_avail_splits=["default", "train", "validation"], @@ -1391,7 +1450,7 @@ misconceptions_russian = LightevalTaskConfig( name="bigbench:misconceptions_russian", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="misconceptions_russian", hf_avail_splits=["default", "train", "validation"], @@ -1406,7 +1465,7 @@ mnist_ascii = LightevalTaskConfig( name="bigbench:mnist_ascii", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="mnist_ascii", hf_avail_splits=["default", "train", "validation"], @@ -1421,7 +1480,7 @@ modified_arithmetic = LightevalTaskConfig( name="bigbench:modified_arithmetic", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="modified_arithmetic", hf_avail_splits=["default", "train", "validation"], @@ -1436,7 +1495,7 @@ moral_permissibility = LightevalTaskConfig( name="bigbench:moral_permissibility", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="moral_permissibility", hf_avail_splits=["default", "train", "validation"], @@ -1451,7 +1510,7 @@ movie_dialog_same_or_different = LightevalTaskConfig( name="bigbench:movie_dialog_same_or_different", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="movie_dialog_same_or_different", hf_avail_splits=["default", "train", "validation"], @@ -1466,7 +1525,7 @@ movie_recommendation = LightevalTaskConfig( name="bigbench:movie_recommendation", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="movie_recommendation", hf_avail_splits=["default", "train", "validation"], @@ -1481,7 +1540,7 @@ mult_data_wrangling = LightevalTaskConfig( name="bigbench:mult_data_wrangling", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="mult_data_wrangling", hf_avail_splits=["default", "train", "validation"], @@ -1496,7 +1555,7 @@ navigate = LightevalTaskConfig( name="bigbench:navigate", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="navigate", hf_avail_splits=["default", "train", "validation"], @@ -1511,7 +1570,7 @@ nonsense_words_grammar = LightevalTaskConfig( name="bigbench:nonsense_words_grammar", - prompt_function=prompt.bigbench, + 
prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="nonsense_words_grammar", hf_avail_splits=["default", "train", "validation"], @@ -1526,7 +1585,7 @@ novel_concepts = LightevalTaskConfig( name="bigbench:novel_concepts", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="novel_concepts", hf_avail_splits=["default", "train", "validation"], @@ -1541,7 +1600,7 @@ object_counting = LightevalTaskConfig( name="bigbench:object_counting", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="object_counting", hf_avail_splits=["default", "train", "validation"], @@ -1556,7 +1615,7 @@ odd_one_out = LightevalTaskConfig( name="bigbench:odd_one_out", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="odd_one_out", hf_avail_splits=["default", "train", "validation"], @@ -1571,7 +1630,7 @@ operators = LightevalTaskConfig( name="bigbench:operators", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="operators", hf_avail_splits=["default", "train", "validation"], @@ -1586,7 +1645,7 @@ paragraph_segmentation = LightevalTaskConfig( name="bigbench:paragraph_segmentation", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="paragraph_segmentation", hf_avail_splits=["default", "train", "validation"], @@ -1601,7 +1660,7 @@ parsinlu_qa = LightevalTaskConfig( name="bigbench:parsinlu_qa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="parsinlu_qa", hf_avail_splits=["default", "train", "validation"], @@ -1616,7 +1675,7 @@ parsinlu_reading_comprehension = LightevalTaskConfig( name="bigbench:parsinlu_reading_comprehension", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="parsinlu_reading_comprehension", hf_avail_splits=["default", "train", "validation"], @@ -1631,7 +1690,7 @@ penguins_in_a_table = LightevalTaskConfig( name="bigbench:penguins_in_a_table", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="penguins_in_a_table", hf_avail_splits=["default", "train", "validation"], @@ -1646,7 +1705,7 @@ periodic_elements = LightevalTaskConfig( name="bigbench:periodic_elements", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="periodic_elements", hf_avail_splits=["default", "train", "validation"], @@ -1661,7 +1720,7 @@ persian_idioms = LightevalTaskConfig( name="bigbench:persian_idioms", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="persian_idioms", hf_avail_splits=["default", "train", "validation"], @@ -1676,7 +1735,7 @@ phrase_relatedness = LightevalTaskConfig( name="bigbench:phrase_relatedness", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="phrase_relatedness", hf_avail_splits=["default", "train", "validation"], @@ -1691,7 +1750,7 @@ physical_intuition = LightevalTaskConfig( name="bigbench:physical_intuition", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, 
hf_repo="tasksource/bigbench", hf_subset="physical_intuition", hf_avail_splits=["default", "train", "validation"], @@ -1706,7 +1765,7 @@ physics = LightevalTaskConfig( name="bigbench:physics", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="physics", hf_avail_splits=["default", "train", "validation"], @@ -1721,7 +1780,7 @@ physics_questions = LightevalTaskConfig( name="bigbench:physics_questions", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="physics_questions", hf_avail_splits=["default", "train", "validation"], @@ -1736,7 +1795,7 @@ play_dialog_same_or_different = LightevalTaskConfig( name="bigbench:play_dialog_same_or_different", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="play_dialog_same_or_different", hf_avail_splits=["default", "train", "validation"], @@ -1751,7 +1810,7 @@ polish_sequence_labeling = LightevalTaskConfig( name="bigbench:polish_sequence_labeling", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="polish_sequence_labeling", hf_avail_splits=["default", "train", "validation"], @@ -1766,7 +1825,7 @@ presuppositions_as_nli = LightevalTaskConfig( name="bigbench:presuppositions_as_nli", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="presuppositions_as_nli", hf_avail_splits=["default", "train", "validation"], @@ -1781,7 +1840,7 @@ qa_wikidata = LightevalTaskConfig( name="bigbench:qa_wikidata", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="qa_wikidata", hf_avail_splits=["default", "train", "validation"], @@ -1801,7 +1860,7 @@ question_selection = LightevalTaskConfig( name="bigbench:question_selection", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="question_selection", hf_avail_splits=["default", "train", "validation"], @@ -1816,7 +1875,7 @@ real_or_fake_text = LightevalTaskConfig( name="bigbench:real_or_fake_text", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="real_or_fake_text", hf_avail_splits=["default", "train", "validation"], @@ -1831,7 +1890,7 @@ reasoning_about_colored_objects = LightevalTaskConfig( name="bigbench:reasoning_about_colored_objects", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="reasoning_about_colored_objects", hf_avail_splits=["default", "train", "validation"], @@ -1846,7 +1905,7 @@ repeat_copy_logic = LightevalTaskConfig( name="bigbench:repeat_copy_logic", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="repeat_copy_logic", hf_avail_splits=["default", "train", "validation"], @@ -1861,7 +1920,7 @@ rephrase = LightevalTaskConfig( name="bigbench:rephrase", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="rephrase", hf_avail_splits=["default", "train", "validation"], @@ -1881,7 +1940,7 @@ rhyming = LightevalTaskConfig( name="bigbench:rhyming", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, 
hf_repo="tasksource/bigbench", hf_subset="rhyming", hf_avail_splits=["default", "train", "validation"], @@ -1896,7 +1955,7 @@ riddle_sense = LightevalTaskConfig( name="bigbench:riddle_sense", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="riddle_sense", hf_avail_splits=["default", "train", "validation"], @@ -1911,7 +1970,7 @@ ruin_names = LightevalTaskConfig( name="bigbench:ruin_names", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="ruin_names", hf_avail_splits=["default", "train", "validation"], @@ -1926,7 +1985,7 @@ salient_translation_error_detection = LightevalTaskConfig( name="bigbench:salient_translation_error_detection", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="salient_translation_error_detection", hf_avail_splits=["default", "train", "validation"], @@ -1941,7 +2000,7 @@ scientific_press_release = LightevalTaskConfig( name="bigbench:scientific_press_release", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="scientific_press_release", hf_avail_splits=["default", "train", "validation"], @@ -1956,7 +2015,7 @@ semantic_parsing_in_context_sparc = LightevalTaskConfig( name="bigbench:semantic_parsing_in_context_sparc", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="semantic_parsing_in_context_sparc", hf_avail_splits=["default", "train", "validation"], @@ -1971,7 +2030,7 @@ semantic_parsing_spider = LightevalTaskConfig( name="bigbench:semantic_parsing_spider", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="semantic_parsing_spider", hf_avail_splits=["default", "train", "validation"], @@ -1986,7 +2045,7 @@ sentence_ambiguity = LightevalTaskConfig( name="bigbench:sentence_ambiguity", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="sentence_ambiguity", hf_avail_splits=["default", "train", "validation"], @@ -2001,7 +2060,7 @@ similarities_abstraction = LightevalTaskConfig( name="bigbench:similarities_abstraction", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="similarities_abstraction", hf_avail_splits=["default", "train", "validation"], @@ -2016,7 +2075,7 @@ simp_turing_concept = LightevalTaskConfig( name="bigbench:simp_turing_concept", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simp_turing_concept", hf_avail_splits=["default", "train", "validation"], @@ -2031,7 +2090,7 @@ simple_arithmetic_json = LightevalTaskConfig( name="bigbench:simple_arithmetic_json", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_arithmetic_json", hf_avail_splits=["default", "train", "validation"], @@ -2046,7 +2105,7 @@ simple_arithmetic_json_multiple_choice = LightevalTaskConfig( name="bigbench:simple_arithmetic_json_multiple_choice", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_arithmetic_json_multiple_choice", hf_avail_splits=["default", "train", "validation"], @@ -2061,7 +2120,7 @@ simple_arithmetic_json_subtasks = LightevalTaskConfig( name="bigbench:simple_arithmetic_json_subtasks", - 
prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_arithmetic_json_subtasks", hf_avail_splits=["default", "train", "validation"], @@ -2076,7 +2135,7 @@ simple_arithmetic_multiple_targets_json = LightevalTaskConfig( name="bigbench:simple_arithmetic_multiple_targets_json", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_arithmetic_multiple_targets_json", hf_avail_splits=["default", "train", "validation"], @@ -2091,7 +2150,7 @@ simple_ethical_questions = LightevalTaskConfig( name="bigbench:simple_ethical_questions", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_ethical_questions", hf_avail_splits=["default", "train", "validation"], @@ -2106,7 +2165,7 @@ simple_text_editing = LightevalTaskConfig( name="bigbench:simple_text_editing", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="simple_text_editing", hf_avail_splits=["default", "train", "validation"], @@ -2121,7 +2180,7 @@ snarks = LightevalTaskConfig( name="bigbench:snarks", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="snarks", hf_avail_splits=["default", "train", "validation"], @@ -2136,7 +2195,7 @@ social_iqa = LightevalTaskConfig( name="bigbench:social_iqa", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="social_iqa", hf_avail_splits=["default", "train", "validation"], @@ -2151,7 +2210,7 @@ social_support = LightevalTaskConfig( name="bigbench:social_support", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="social_support", hf_avail_splits=["default", "train", "validation"], @@ -2166,7 +2225,7 @@ sports_understanding = LightevalTaskConfig( name="bigbench:sports_understanding", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="sports_understanding", hf_avail_splits=["default", "train", "validation"], @@ -2181,7 +2240,7 @@ strange_stories = LightevalTaskConfig( name="bigbench:strange_stories", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="strange_stories", hf_avail_splits=["default", "train", "validation"], @@ -2196,7 +2255,7 @@ strategyqa = LightevalTaskConfig( name="bigbench:strategyqa", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="strategyqa", hf_avail_splits=["default", "train", "validation"], @@ -2211,7 +2270,7 @@ sufficient_information = LightevalTaskConfig( name="bigbench:sufficient_information", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="sufficient_information", hf_avail_splits=["default", "train", "validation"], @@ -2226,7 +2285,7 @@ suicide_risk = LightevalTaskConfig( name="bigbench:suicide_risk", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="suicide_risk", hf_avail_splits=["default", "train", "validation"], @@ -2241,7 +2300,7 @@ swahili_english_proverbs = LightevalTaskConfig( name="bigbench:swahili_english_proverbs", - 
prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="swahili_english_proverbs", hf_avail_splits=["default", "train", "validation"], @@ -2256,7 +2315,7 @@ swedish_to_german_proverbs = LightevalTaskConfig( name="bigbench:swedish_to_german_proverbs", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="swedish_to_german_proverbs", hf_avail_splits=["default", "train", "validation"], @@ -2271,7 +2330,7 @@ symbol_interpretation = LightevalTaskConfig( name="bigbench:symbol_interpretation", - prompt_function=prompt.bigbench_linefeed_before_whitespace_after_query, + prompt_function=bigbench_linefeed_before_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="symbol_interpretation", hf_avail_splits=["default", "train", "validation"], @@ -2286,7 +2345,7 @@ tellmewhy = LightevalTaskConfig( name="bigbench:tellmewhy", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="tellmewhy", hf_avail_splits=["default", "train", "validation"], @@ -2301,7 +2360,7 @@ temporal_sequences = LightevalTaskConfig( name="bigbench:temporal_sequences", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="temporal_sequences", hf_avail_splits=["default", "train", "validation"], @@ -2316,7 +2375,7 @@ tense = LightevalTaskConfig( name="bigbench:tense", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="tense", hf_avail_splits=["default", "train", "validation"], @@ -2331,7 +2390,7 @@ timedial = LightevalTaskConfig( name="bigbench:timedial", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="timedial", hf_avail_splits=["default", "train", "validation"], @@ -2346,7 +2405,7 @@ topical_chat = LightevalTaskConfig( name="bigbench:topical_chat", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="topical_chat", hf_avail_splits=["default", "train", "validation"], @@ -2361,7 +2420,7 @@ tracking_shuffled_objects = LightevalTaskConfig( name="bigbench:tracking_shuffled_objects", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="tracking_shuffled_objects", hf_avail_splits=["default", "train", "validation"], @@ -2376,7 +2435,7 @@ understanding_fables = LightevalTaskConfig( name="bigbench:understanding_fables", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="understanding_fables", hf_avail_splits=["default", "train", "validation"], @@ -2391,7 +2450,7 @@ undo_permutation = LightevalTaskConfig( name="bigbench:undo_permutation", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="undo_permutation", hf_avail_splits=["default", "train", "validation"], @@ -2406,7 +2465,7 @@ unit_conversion = LightevalTaskConfig( name="bigbench:unit_conversion", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="unit_conversion", hf_avail_splits=["default", "train", "validation"], @@ -2421,7 +2480,7 @@ unit_interpretation = LightevalTaskConfig( name="bigbench:unit_interpretation", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", 
hf_subset="unit_interpretation", hf_avail_splits=["default", "train", "validation"], @@ -2436,7 +2495,7 @@ unnatural_in_context_learning = LightevalTaskConfig( name="bigbench:unnatural_in_context_learning", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="unnatural_in_context_learning", hf_avail_splits=["default", "train", "validation"], @@ -2451,7 +2510,7 @@ vitaminc_fact_verification = LightevalTaskConfig( name="bigbench:vitaminc_fact_verification", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="vitaminc_fact_verification", hf_avail_splits=["default", "train", "validation"], @@ -2466,7 +2525,7 @@ what_is_the_tao = LightevalTaskConfig( name="bigbench:what_is_the_tao", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="what_is_the_tao", hf_avail_splits=["default", "train", "validation"], @@ -2481,7 +2540,7 @@ which_wiki_edit = LightevalTaskConfig( name="bigbench:which_wiki_edit", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="which_wiki_edit", hf_avail_splits=["default", "train", "validation"], @@ -2496,7 +2555,7 @@ winowhy = LightevalTaskConfig( name="bigbench:winowhy", - prompt_function=prompt.bigbench_whitespace_after_query, + prompt_function=bigbench_whitespace_after_query_prompt, hf_repo="tasksource/bigbench", hf_subset="winowhy", hf_avail_splits=["default", "train", "validation"], @@ -2511,7 +2570,7 @@ word_sorting = LightevalTaskConfig( name="bigbench:word_sorting", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="word_sorting", hf_avail_splits=["default", "train", "validation"], @@ -2526,7 +2585,7 @@ word_unscrambling = LightevalTaskConfig( name="bigbench:word_unscrambling", - prompt_function=prompt.bigbench, + prompt_function=bigbench_prompt, hf_repo="tasksource/bigbench", hf_subset="word_unscrambling", hf_avail_splits=["default", "train", "validation"], diff --git a/src/lighteval/tasks/tasks/bigbench_hard.py b/src/lighteval/tasks/tasks/bigbench_hard.py index b6dc484ae..4b930d1b9 100644 --- a/src/lighteval/tasks/tasks/bigbench_hard.py +++ b/src/lighteval/tasks/tasks/bigbench_hard.py @@ -15,14 +15,35 @@ paper: """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def bbh_prompt(line, task_name: str = None): + line = {k: v for k, v in line.items() if v is not None} + + query = line.get("task_prefix", "") + query += line.get("example_input_prefix", "\nQuestion: ") + query += line["input"] + query += line.get("choice_prefix", "\n Choices: ") + query += "".join([f"\n{key}. 
{choice}" for key, choice in zip(ascii_uppercase, line["choices"])]) + query += line.get("example_output_prefix", "\nAnswer: ") + + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase[: len(line["choices"])]), + gold_index=line["target_idx"], + instruction=line.get("task_prefix", None), + ) causal_judgment = LightevalTaskConfig( name="bigbench_hard:causal_judgment", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="causal_judgement", hf_avail_splits=["train"], @@ -37,7 +58,7 @@ date_understanding = LightevalTaskConfig( name="bigbench_hard:date_understanding", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="date_understanding", hf_avail_splits=["train"], @@ -52,7 +73,7 @@ disambiguation_qa = LightevalTaskConfig( name="bigbench_hard:disambiguation_qa", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="disambiguation_qa", hf_avail_splits=["train"], @@ -67,7 +88,7 @@ geometric_shapes = LightevalTaskConfig( name="bigbench_hard:geometric_shapes", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="geometric_shapes", hf_avail_splits=["train"], @@ -82,7 +103,7 @@ logical_deduction_five_objects = LightevalTaskConfig( name="bigbench_hard:logical_deduction_five_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="logical_deduction_five_objects", hf_avail_splits=["train"], @@ -97,7 +118,7 @@ logical_deduction_seven_objects = LightevalTaskConfig( name="bigbench_hard:logical_deduction_seven_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="logical_deduction_seven_objects", hf_avail_splits=["train"], @@ -112,7 +133,7 @@ logical_deduction_three_objects = LightevalTaskConfig( name="bigbench_hard:logical_deduction_three_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="logical_deduction_three_objects", hf_avail_splits=["train"], @@ -127,7 +148,7 @@ movie_recommendation = LightevalTaskConfig( name="bigbench_hard:movie_recommendation", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="movie_recommendation", hf_avail_splits=["train"], @@ -142,7 +163,7 @@ navigate = LightevalTaskConfig( name="bigbench_hard:navigate", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="navigate", hf_avail_splits=["train"], @@ -157,7 +178,7 @@ reasoning_about_colored_objects = LightevalTaskConfig( name="bigbench_hard:reasoning_about_colored_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="reasoning_about_colored_objects", hf_avail_splits=["train"], @@ -172,7 +193,7 @@ ruin_names = LightevalTaskConfig( name="bigbench_hard:ruin_names", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="ruin_names", hf_avail_splits=["train"], @@ -187,7 +208,7 @@ salient_translation_error_detection = LightevalTaskConfig( name="bigbench_hard:salient_translation_error_detection", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="salient_translation_error_detection", hf_avail_splits=["train"], @@ -202,7 +223,7 @@ snarks = 
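# Illustrative sketch, not part of the patch: roughly what the new bbh_prompt above yields
# for a hypothetical BBH-style row (the sample keys and values are assumptions, mirroring
# the fields the function reads: "input", "choices", "target_idx").
#   sample = {"input": "Which is heavier, a kilo of steel or a kilo of feathers?",
#             "choices": ["steel", "feathers", "they weigh the same"], "target_idx": 2}
#   doc = bbh_prompt(sample, task_name="bigbench_hard:example")
#   doc.query      == "\nQuestion: Which is heavier, a kilo of steel or a kilo of feathers?\n Choices: \nA. steel\nB. feathers\nC. they weigh the same\nAnswer: "
#   doc.choices    == ["A", "B", "C"]
#   doc.gold_index == 2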
LightevalTaskConfig( name="bigbench_hard:snarks", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="snarks", hf_avail_splits=["train"], @@ -217,7 +238,7 @@ sports_understanding = LightevalTaskConfig( name="bigbench_hard:sports_understanding", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="sports_understanding", hf_avail_splits=["train"], @@ -232,7 +253,7 @@ temporal_sequences = LightevalTaskConfig( name="bigbench_hard:temporal_sequences", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="temporal_sequences", hf_avail_splits=["train"], @@ -247,7 +268,7 @@ tracking_shuffled_objects_five_objects = LightevalTaskConfig( name="bigbench_hard:tracking_shuffled_objects_five_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="tracking_shuffled_objects_five_objects", hf_avail_splits=["train"], @@ -262,7 +283,7 @@ tracking_shuffled_objects_seven_objects = LightevalTaskConfig( name="bigbench_hard:tracking_shuffled_objects_seven_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="tracking_shuffled_objects_seven_objects", hf_avail_splits=["train"], @@ -277,7 +298,7 @@ tracking_shuffled_objects_three_objects = LightevalTaskConfig( name="bigbench_hard:tracking_shuffled_objects_three_objects", - prompt_function=prompt.bbh_lighteval, + prompt_function=bbh_prompt, hf_repo="lighteval/bbh", hf_subset="tracking_shuffled_objects_three_objects", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/blimp.py b/src/lighteval/tasks/tasks/blimp.py index 5319c5358..1b9278a32 100644 --- a/src/lighteval/tasks/tasks/blimp.py +++ b/src/lighteval/tasks/tasks/blimp.py @@ -22,14 +22,18 @@ https://arxiv.org/abs/1912.00582 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def blimp_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query="", choices=[line["sentence_good"], line["sentence_bad"]], gold_index=0) blimp_adjunct_island = LightevalTaskConfig( name="blimp:adjunct_island", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="adjunct_island", hf_avail_splits=["train"], @@ -44,7 +48,7 @@ blimp_anaphor_gender_agreement = LightevalTaskConfig( name="blimp:anaphor_gender_agreement", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="anaphor_gender_agreement", hf_avail_splits=["train"], @@ -59,7 +63,7 @@ blimp_anaphor_number_agreement = LightevalTaskConfig( name="blimp:anaphor_number_agreement", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="anaphor_number_agreement", hf_avail_splits=["train"], @@ -74,7 +78,7 @@ blimp_animate_subject_passive = LightevalTaskConfig( name="blimp:animate_subject_passive", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="animate_subject_passive", hf_avail_splits=["train"], @@ -89,7 +93,7 @@ blimp_animate_subject_trans = LightevalTaskConfig( name="blimp:animate_subject_trans", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="animate_subject_trans", 
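# Illustrative sketch, not part of the patch: the new blimp_prompt above builds a Doc with an
# empty query, pairing the grammatical sentence (gold_index 0) against the ungrammatical one
# (hypothetical row in the BLiMP sentence_good/sentence_bad format; values are for illustration):
#   row = {"sentence_good": "Many girls insulted themselves.", "sentence_bad": "Many girls insulted herself."}
#   doc = blimp_prompt(row, task_name="blimp:example")
#   doc.query == ""; doc.choices == [row["sentence_good"], row["sentence_bad"]]; doc.gold_index == 0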
hf_avail_splits=["train"], @@ -104,7 +108,7 @@ blimp_causative = LightevalTaskConfig( name="blimp:causative", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="causative", hf_avail_splits=["train"], @@ -119,7 +123,7 @@ blimp_complex_NP_island = LightevalTaskConfig( name="blimp:complex_NP_island", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="complex_NP_island", hf_avail_splits=["train"], @@ -134,7 +138,7 @@ blimp_coordinate_structure_constraint_complex_left_branch = LightevalTaskConfig( name="blimp:coordinate_structure_constraint_complex_left_branch", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="coordinate_structure_constraint_complex_left_branch", hf_avail_splits=["train"], @@ -149,7 +153,7 @@ blimp_coordinate_structure_constraint_object_extraction = LightevalTaskConfig( name="blimp:coordinate_structure_constraint_object_extraction", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="coordinate_structure_constraint_object_extraction", hf_avail_splits=["train"], @@ -164,7 +168,7 @@ blimp_determiner_noun_agreement_1 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_1", hf_avail_splits=["train"], @@ -179,7 +183,7 @@ blimp_determiner_noun_agreement_2 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_2", hf_avail_splits=["train"], @@ -194,7 +198,7 @@ blimp_determiner_noun_agreement_irregular_1 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_irregular_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_irregular_1", hf_avail_splits=["train"], @@ -209,7 +213,7 @@ blimp_determiner_noun_agreement_irregular_2 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_irregular_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_irregular_2", hf_avail_splits=["train"], @@ -224,7 +228,7 @@ blimp_determiner_noun_agreement_with_adj_2 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_with_adj_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_with_adj_2", hf_avail_splits=["train"], @@ -239,7 +243,7 @@ blimp_determiner_noun_agreement_with_adj_irregular_1 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_with_adj_irregular_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_with_adj_irregular_1", hf_avail_splits=["train"], @@ -254,7 +258,7 @@ blimp_determiner_noun_agreement_with_adj_irregular_2 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_with_adj_irregular_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_with_adj_irregular_2", hf_avail_splits=["train"], @@ -269,7 +273,7 @@ blimp_determiner_noun_agreement_with_adjective_1 = LightevalTaskConfig( name="blimp:determiner_noun_agreement_with_adjective_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, 
hf_repo="nyu-mll/blimp", hf_subset="determiner_noun_agreement_with_adjective_1", hf_avail_splits=["train"], @@ -284,7 +288,7 @@ blimp_distractor_agreement_relational_noun = LightevalTaskConfig( name="blimp:distractor_agreement_relational_noun", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="distractor_agreement_relational_noun", hf_avail_splits=["train"], @@ -299,7 +303,7 @@ blimp_distractor_agreement_relative_clause = LightevalTaskConfig( name="blimp:distractor_agreement_relative_clause", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="distractor_agreement_relative_clause", hf_avail_splits=["train"], @@ -314,7 +318,7 @@ blimp_drop_argument = LightevalTaskConfig( name="blimp:drop_argument", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="drop_argument", hf_avail_splits=["train"], @@ -329,7 +333,7 @@ blimp_ellipsis_n_bar_1 = LightevalTaskConfig( name="blimp:ellipsis_n_bar_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="ellipsis_n_bar_1", hf_avail_splits=["train"], @@ -344,7 +348,7 @@ blimp_ellipsis_n_bar_2 = LightevalTaskConfig( name="blimp:ellipsis_n_bar_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="ellipsis_n_bar_2", hf_avail_splits=["train"], @@ -359,7 +363,7 @@ blimp_existential_there_object_raising = LightevalTaskConfig( name="blimp:existential_there_object_raising", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="existential_there_object_raising", hf_avail_splits=["train"], @@ -374,7 +378,7 @@ blimp_existential_there_quantifiers_1 = LightevalTaskConfig( name="blimp:existential_there_quantifiers_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="existential_there_quantifiers_1", hf_avail_splits=["train"], @@ -389,7 +393,7 @@ blimp_existential_there_quantifiers_2 = LightevalTaskConfig( name="blimp:existential_there_quantifiers_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="existential_there_quantifiers_2", hf_avail_splits=["train"], @@ -404,7 +408,7 @@ blimp_existential_there_subject_raising = LightevalTaskConfig( name="blimp:existential_there_subject_raising", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="existential_there_subject_raising", hf_avail_splits=["train"], @@ -419,7 +423,7 @@ blimp_expletive_it_object_raising = LightevalTaskConfig( name="blimp:expletive_it_object_raising", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="expletive_it_object_raising", hf_avail_splits=["train"], @@ -434,7 +438,7 @@ blimp_inchoative = LightevalTaskConfig( name="blimp:inchoative", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="inchoative", hf_avail_splits=["train"], @@ -449,7 +453,7 @@ blimp_intransitive = LightevalTaskConfig( name="blimp:intransitive", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="intransitive", hf_avail_splits=["train"], @@ -464,7 +468,7 @@ blimp_irregular_past_participle_adjectives = LightevalTaskConfig( name="blimp:irregular_past_participle_adjectives", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", 
hf_subset="irregular_past_participle_adjectives", hf_avail_splits=["train"], @@ -479,7 +483,7 @@ blimp_irregular_past_participle_verbs = LightevalTaskConfig( name="blimp:irregular_past_participle_verbs", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="irregular_past_participle_verbs", hf_avail_splits=["train"], @@ -494,7 +498,7 @@ blimp_irregular_plural_subject_verb_agreement_1 = LightevalTaskConfig( name="blimp:irregular_plural_subject_verb_agreement_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="irregular_plural_subject_verb_agreement_1", hf_avail_splits=["train"], @@ -509,7 +513,7 @@ blimp_irregular_plural_subject_verb_agreement_2 = LightevalTaskConfig( name="blimp:irregular_plural_subject_verb_agreement_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="irregular_plural_subject_verb_agreement_2", hf_avail_splits=["train"], @@ -524,7 +528,7 @@ blimp_left_branch_island_echo_question = LightevalTaskConfig( name="blimp:left_branch_island_echo_question", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="left_branch_island_echo_question", hf_avail_splits=["train"], @@ -539,7 +543,7 @@ blimp_left_branch_island_simple_question = LightevalTaskConfig( name="blimp:left_branch_island_simple_question", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="left_branch_island_simple_question", hf_avail_splits=["train"], @@ -554,7 +558,7 @@ blimp_matrix_question_npi_licensor_present = LightevalTaskConfig( name="blimp:matrix_question_npi_licensor_present", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="matrix_question_npi_licensor_present", hf_avail_splits=["train"], @@ -569,7 +573,7 @@ blimp_npi_present_1 = LightevalTaskConfig( name="blimp:npi_present_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="npi_present_1", hf_avail_splits=["train"], @@ -584,7 +588,7 @@ blimp_npi_present_2 = LightevalTaskConfig( name="blimp:npi_present_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="npi_present_2", hf_avail_splits=["train"], @@ -599,7 +603,7 @@ blimp_only_npi_licensor_present = LightevalTaskConfig( name="blimp:only_npi_licensor_present", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="only_npi_licensor_present", hf_avail_splits=["train"], @@ -614,7 +618,7 @@ blimp_only_npi_scope = LightevalTaskConfig( name="blimp:only_npi_scope", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="only_npi_scope", hf_avail_splits=["train"], @@ -629,7 +633,7 @@ blimp_passive_1 = LightevalTaskConfig( name="blimp:passive_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="passive_1", hf_avail_splits=["train"], @@ -644,7 +648,7 @@ blimp_passive_2 = LightevalTaskConfig( name="blimp:passive_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="passive_2", hf_avail_splits=["train"], @@ -659,7 +663,7 @@ blimp_principle_A_c_command = LightevalTaskConfig( name="blimp:principle_A_c_command", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_c_command", 
hf_avail_splits=["train"], @@ -674,7 +678,7 @@ blimp_principle_A_case_1 = LightevalTaskConfig( name="blimp:principle_A_case_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_case_1", hf_avail_splits=["train"], @@ -689,7 +693,7 @@ blimp_principle_A_case_2 = LightevalTaskConfig( name="blimp:principle_A_case_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_case_2", hf_avail_splits=["train"], @@ -704,7 +708,7 @@ blimp_principle_A_domain_1 = LightevalTaskConfig( name="blimp:principle_A_domain_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_domain_1", hf_avail_splits=["train"], @@ -719,7 +723,7 @@ blimp_principle_A_domain_2 = LightevalTaskConfig( name="blimp:principle_A_domain_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_domain_2", hf_avail_splits=["train"], @@ -734,7 +738,7 @@ blimp_principle_A_domain_3 = LightevalTaskConfig( name="blimp:principle_A_domain_3", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_domain_3", hf_avail_splits=["train"], @@ -749,7 +753,7 @@ blimp_principle_A_reconstruction = LightevalTaskConfig( name="blimp:principle_A_reconstruction", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="principle_A_reconstruction", hf_avail_splits=["train"], @@ -764,7 +768,7 @@ blimp_regular_plural_subject_verb_agreement_1 = LightevalTaskConfig( name="blimp:regular_plural_subject_verb_agreement_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="regular_plural_subject_verb_agreement_1", hf_avail_splits=["train"], @@ -779,7 +783,7 @@ blimp_regular_plural_subject_verb_agreement_2 = LightevalTaskConfig( name="blimp:regular_plural_subject_verb_agreement_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="regular_plural_subject_verb_agreement_2", hf_avail_splits=["train"], @@ -794,7 +798,7 @@ blimp_sentential_negation_npi_licensor_present = LightevalTaskConfig( name="blimp:sentential_negation_npi_licensor_present", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="sentential_negation_npi_licensor_present", hf_avail_splits=["train"], @@ -809,7 +813,7 @@ blimp_sentential_negation_npi_scope = LightevalTaskConfig( name="blimp:sentential_negation_npi_scope", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="sentential_negation_npi_scope", hf_avail_splits=["train"], @@ -824,7 +828,7 @@ blimp_sentential_subject_island = LightevalTaskConfig( name="blimp:sentential_subject_island", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="sentential_subject_island", hf_avail_splits=["train"], @@ -839,7 +843,7 @@ blimp_superlative_quantifiers_1 = LightevalTaskConfig( name="blimp:superlative_quantifiers_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="superlative_quantifiers_1", hf_avail_splits=["train"], @@ -854,7 +858,7 @@ blimp_superlative_quantifiers_2 = LightevalTaskConfig( name="blimp:superlative_quantifiers_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", 
hf_subset="superlative_quantifiers_2", hf_avail_splits=["train"], @@ -869,7 +873,7 @@ blimp_tough_vs_raising_1 = LightevalTaskConfig( name="blimp:tough_vs_raising_1", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="tough_vs_raising_1", hf_avail_splits=["train"], @@ -884,7 +888,7 @@ blimp_tough_vs_raising_2 = LightevalTaskConfig( name="blimp:tough_vs_raising_2", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="tough_vs_raising_2", hf_avail_splits=["train"], @@ -899,7 +903,7 @@ blimp_transitive = LightevalTaskConfig( name="blimp:transitive", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="transitive", hf_avail_splits=["train"], @@ -914,7 +918,7 @@ blimp_wh_island = LightevalTaskConfig( name="blimp:wh_island", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_island", hf_avail_splits=["train"], @@ -929,7 +933,7 @@ blimp_wh_questions_object_gap = LightevalTaskConfig( name="blimp:wh_questions_object_gap", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_questions_object_gap", hf_avail_splits=["train"], @@ -944,7 +948,7 @@ blimp_wh_questions_subject_gap = LightevalTaskConfig( name="blimp:wh_questions_subject_gap", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_questions_subject_gap", hf_avail_splits=["train"], @@ -959,7 +963,7 @@ blimp_wh_questions_subject_gap_long_distance = LightevalTaskConfig( name="blimp:wh_questions_subject_gap_long_distance", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_questions_subject_gap_long_distance", hf_avail_splits=["train"], @@ -974,7 +978,7 @@ blimp_wh_vs_that_no_gap = LightevalTaskConfig( name="blimp:wh_vs_that_no_gap", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_vs_that_no_gap", hf_avail_splits=["train"], @@ -989,7 +993,7 @@ blimp_wh_vs_that_no_gap_long_distance = LightevalTaskConfig( name="blimp:wh_vs_that_no_gap_long_distance", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_vs_that_no_gap_long_distance", hf_avail_splits=["train"], @@ -1004,7 +1008,7 @@ blimp_wh_vs_that_with_gap = LightevalTaskConfig( name="blimp:wh_vs_that_with_gap", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_vs_that_with_gap", hf_avail_splits=["train"], @@ -1019,7 +1023,7 @@ blimp_wh_vs_that_with_gap_long_distance = LightevalTaskConfig( name="blimp:wh_vs_that_with_gap_long_distance", - prompt_function=prompt.blimp, + prompt_function=blimp_prompt, hf_repo="nyu-mll/blimp", hf_subset="wh_vs_that_with_gap_long_distance", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/bold.py b/src/lighteval/tasks/tasks/bold.py index 40329c258..2ecc52c05 100644 --- a/src/lighteval/tasks/tasks/bold.py +++ b/src/lighteval/tasks/tasks/bold.py @@ -19,14 +19,18 @@ https://dl.acm.org/doi/10.1145/3442188.3445924 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def bold_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["text"], choices=None, gold_index=None) bold = 
LightevalTaskConfig( name="bold", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="all", hf_avail_splits=["train", "test"], @@ -41,7 +45,7 @@ bold_gender = LightevalTaskConfig( name="bold:gender", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="gender", hf_avail_splits=["train", "test"], @@ -56,7 +60,7 @@ bold_political_ideology = LightevalTaskConfig( name="bold:political_ideology", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="political_ideology", hf_avail_splits=["train", "test"], @@ -71,7 +75,7 @@ bold_profession = LightevalTaskConfig( name="bold:profession", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="profession", hf_avail_splits=["train", "test"], @@ -86,7 +90,7 @@ bold_race = LightevalTaskConfig( name="bold:race", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="race", hf_avail_splits=["train", "test"], @@ -101,7 +105,7 @@ bold_religious_ideology = LightevalTaskConfig( name="bold:religious_ideology", - prompt_function=prompt.bold, + prompt_function=bold_prompt, hf_repo="lighteval/bold_helm", hf_subset="religious_ideology", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/boolq.py b/src/lighteval/tasks/tasks/boolq.py index 81d913669..df854b9e6 100644 --- a/src/lighteval/tasks/tasks/boolq.py +++ b/src/lighteval/tasks/tasks/boolq.py @@ -18,14 +18,39 @@ https://arxiv.org/abs/1905.11946 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def boolq_prompt(line, task_name: str = None): + question = line["question"][:-1] if line["question"][-2:] == "??" 
else line["question"] + return Doc( + task_name=task_name, + query=f"Passage: {line['passage']}\nQuestion: {question}\nAnswer:", + choices=[" Yes", " No"], + gold_index=["Yes", "No"].index(line["answer"]), + ) + + +def boolq_contrastset_prompt(line, task_name: str = None): + if line["contrast_inputs"] in [None, ""]: + return boolq_prompt(line) + + return [ + Doc( + task_name=task_name, + query=f"{passage}\nQuestion: {question}\nAnswer:", + choices=["Yes", "No"], + gold_index=["No", "Yes"].index(line["answer"]), + ) + for passage, question in zip(line["contrast_inputs"]["passage"], line["contrast_inputs"]["question"]) + ][0] boolq = LightevalTaskConfig( name="boolq", - prompt_function=prompt.boolq_helm, + prompt_function=boolq_prompt, hf_repo="lighteval/boolq_helm", hf_subset="default", hf_avail_splits=["train", "validation"], @@ -43,7 +68,7 @@ boolq_contrastset = LightevalTaskConfig( name="boolq:contrastset", - prompt_function=prompt.boolq_helm_contrastset, + prompt_function=boolq_contrastset_prompt, hf_repo="lighteval/boolq_helm", hf_subset="default", hf_avail_splits=["validation"], diff --git a/src/lighteval/tasks/tasks/civil_comments.py b/src/lighteval/tasks/tasks/civil_comments.py index dc0169c51..79f96ff8e 100644 --- a/src/lighteval/tasks/tasks/civil_comments.py +++ b/src/lighteval/tasks/tasks/civil_comments.py @@ -18,14 +18,23 @@ https://arxiv.org/abs/1903.04561 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def civil_comments_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Passage: {line['text']}\nIs the passage above toxic?\nAnswer: ", + choices=[str(line["gold"])], + gold_index=0, + ) civil_comments = LightevalTaskConfig( name="civil_comments", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="all", hf_avail_splits=["train", "test"], @@ -40,7 +49,7 @@ civil_comments_LGBTQ = LightevalTaskConfig( name="civil_comments:LGBTQ", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="LGBTQ", hf_avail_splits=["train", "test"], @@ -55,7 +64,7 @@ civil_comments_black = LightevalTaskConfig( name="civil_comments:black", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="black", hf_avail_splits=["train", "test"], @@ -70,7 +79,7 @@ civil_comments_christian = LightevalTaskConfig( name="civil_comments:christian", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="christian", hf_avail_splits=["train", "test"], @@ -85,7 +94,7 @@ civil_comments_female = LightevalTaskConfig( name="civil_comments:female", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="female", hf_avail_splits=["train", "test"], @@ -100,7 +109,7 @@ civil_comments_male = LightevalTaskConfig( name="civil_comments:male", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="male", hf_avail_splits=["train", "test"], @@ -115,7 +124,7 @@ civil_comments_muslim = LightevalTaskConfig( name="civil_comments:muslim", - prompt_function=prompt.civil_comments, + 
prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="muslim", hf_avail_splits=["train", "test"], @@ -130,7 +139,7 @@ civil_comments_other_religions = LightevalTaskConfig( name="civil_comments:other_religions", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="other_religions", hf_avail_splits=["train", "test"], @@ -145,7 +154,7 @@ civil_comments_white = LightevalTaskConfig( name="civil_comments:white", - prompt_function=prompt.civil_comments, + prompt_function=civil_comments_prompt, hf_repo="lighteval/civil_comments_helm", hf_subset="white", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/commonsenseqa.py b/src/lighteval/tasks/tasks/commonsenseqa.py index 927d6f3f5..639be22d0 100644 --- a/src/lighteval/tasks/tasks/commonsenseqa.py +++ b/src/lighteval/tasks/tasks/commonsenseqa.py @@ -23,14 +23,32 @@ https://arxiv.org/abs/1811.00937 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def commonsenseqa_prompt(line, task_name: str = None): + query = f"The following are multiple choice questions (with answers) about common sense.\nQuestion: {line['question']}\n" + query += "".join( + [f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, [f" {c}" for c in line["choices"]["text"]])] + ) + query += "Answer:" + + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[: len(line["choices"]["text"])], + gold_index=list(ascii_uppercase).index(line["answerKey"].strip()), + instruction="The following are multiple choice questions (with answers) about common sense.\n", + ) commonsenseqa = LightevalTaskConfig( name="commonsenseqa", - prompt_function=prompt.commonsense_qa, + prompt_function=commonsenseqa_prompt, hf_repo="tau/commonsense_qa", hf_subset="default", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/coqa.py b/src/lighteval/tasks/tasks/coqa.py index c2353f082..34495ca3b 100644 --- a/src/lighteval/tasks/tasks/coqa.py +++ b/src/lighteval/tasks/tasks/coqa.py @@ -21,14 +21,21 @@ https://arxiv.org/abs/1808.07042 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def coqa_prompt(line, task_name: str = None): + results = [] + for q, a in zip(line["questions"], line["answers"]["input_text"]): + results.append(Doc(task_name=task_name, query=f"{line['story']} \n\nQ: {q}\n\nA: ", choices=[a], gold_index=0)) + return results coqa_first_question = LightevalTaskConfig( name="coqa", - prompt_function=prompt.coqa, + prompt_function=coqa_prompt, hf_repo="stanfordnlp/coqa", hf_subset="default", hf_avail_splits=["train", "validation"], diff --git a/src/lighteval/tasks/tasks/covid_dialogue.py b/src/lighteval/tasks/tasks/covid_dialogue.py index 78d1724a3..3446ac17d 100644 --- a/src/lighteval/tasks/tasks/covid_dialogue.py +++ b/src/lighteval/tasks/tasks/covid_dialogue.py @@ -19,14 +19,24 @@ https://arxiv.org/abs/2004.06561 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def 
covid_dialogue_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Generate a response given a patient's questions and concerns.\nPatient: {line['query']}\nDoctor: ", + choices=[line["answer"]], + gold_index=0, + instruction="Generate a response given a patient's questions and concerns.\n", + ) covid_dialogue = LightevalTaskConfig( name="covid_dialogue", - prompt_function=prompt.covid_dialogue, + prompt_function=covid_dialogue_prompt, hf_repo="lighteval/covid_dialogue", hf_subset="default", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/drop_qa.py b/src/lighteval/tasks/tasks/drop_qa.py index 71c98d983..91a6adff7 100644 --- a/src/lighteval/tasks/tasks/drop_qa.py +++ b/src/lighteval/tasks/tasks/drop_qa.py @@ -20,40 +20,57 @@ https://arxiv.org/abs/1810.00505 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.templates.qa import get_qa_prompt_function -from lighteval.utils.language import Language +from lighteval.tasks.requests import Doc + + +def drop_prompt(line, task_name: str = None): + def _flatten_validated_answers(validated_answers): + valid_answers = [] + for i in range(len(validated_answers["number"])): + valid_answers.append( + { + "number": validated_answers["number"][i], + "date": validated_answers["date"][i], + "spans": validated_answers["spans"][i], + } + ) + return valid_answers + + def parse_answer(answer): + if answer["number"] != "": + return (str(answer["number"]),) + if answer["spans"] != []: + return tuple(answer["spans"]) + return (" ".join([answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]).strip(),) + + answers = [] + answers_set = set() + candidates = [line["answer"]] + _flatten_validated_answers(line["validated_answers"]) + for candidate in candidates: + answer = parse_answer(candidate) + if answer in answers_set: + continue + answers.append(answer) + answers_set.add(answer) + + is_few_shots = line.get("__few_shots", False) + + return Doc( + task_name=task_name, + query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:", + choices=[f"{' ' if is_few_shots else ''}{', '.join(a)}" for a in answers], + gold_index=list(range(len(answers))), + specific={"golds_no_preprocessing": [list(a) for a in answers]}, + ) drop_qa = LightevalTaskConfig( name="drop", - prompt_function=get_qa_prompt_function( - Language.ENGLISH, - lambda line: { - "context": line["passage"], - "question": line["question"], - "choices": list( - filter( - lambda x: x, - [line["answer"].get("number")] - + line["answer"]["spans"] - + [prompt.get_drop_date(line["answer"].get("date"))], - ) - ), - }, - ), + prompt_function=drop_prompt, hf_repo="lighteval/drop_harness", hf_subset="default", - hf_filter=lambda line: list( - filter( - lambda x: x, - [line["answer"].get("number")] - + line["answer"]["spans"] - + [prompt.get_drop_date(line["answer"].get("date"))], - ) - ), evaluation_splits=("validation",), few_shots_split="train", generation_size=250, diff --git a/src/lighteval/tasks/tasks/dyck_language.py b/src/lighteval/tasks/tasks/dyck_language.py index c81f822a6..e2463445b 100644 --- a/src/lighteval/tasks/tasks/dyck_language.py +++ b/src/lighteval/tasks/tasks/dyck_language.py @@ -18,14 +18,24 @@ https://aclanthology.org/W19-3905/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task 
import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def dyck_language_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n Input: {line['input']}", + choices=[line["output"]], + gold_index=0, + instruction="Please complete the rest of the following Dyck sequences, making sure that the parentheses are closed properly.\n ", + ) dyck_language_2 = LightevalTaskConfig( name="dyck_language:2", - prompt_function=prompt.dyck_language, + prompt_function=dyck_language_prompt, hf_repo="lighteval/DyckLanguage", hf_subset="2", hf_avail_splits=["train", "test"], @@ -41,7 +51,7 @@ dyck_language_3 = LightevalTaskConfig( name="dyck_language:3", - prompt_function=prompt.dyck_language, + prompt_function=dyck_language_prompt, hf_repo="lighteval/DyckLanguage", hf_subset="3", hf_avail_splits=["train", "test"], @@ -57,7 +67,7 @@ dyck_language_4 = LightevalTaskConfig( name="dyck_language:4", - prompt_function=prompt.dyck_language, + prompt_function=dyck_language_prompt, hf_repo="lighteval/DyckLanguage", hf_subset="4", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/entity_data_imputation.py b/src/lighteval/tasks/tasks/entity_data_imputation.py index 7cb2c5bdb..0999c83b7 100644 --- a/src/lighteval/tasks/tasks/entity_data_imputation.py +++ b/src/lighteval/tasks/tasks/entity_data_imputation.py @@ -18,14 +18,24 @@ https://ieeexplore.ieee.org/document/9458712 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def entity_data_imputation_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"What is the missing value?\n{line['text']}\nAnswer:", + choices=[line["gold"]], + gold_index=0, + instruction="What is the missing value?\n", + ) entity_data_imputation_Buy = LightevalTaskConfig( name="entity_data_imputation:Buy", - prompt_function=prompt.entity_data_imputation, + prompt_function=entity_data_imputation_prompt, hf_repo="lighteval/Buy", hf_subset="default", hf_avail_splits=["train", "test", "valid"], @@ -43,7 +53,7 @@ entity_data_imputation_Restaurant = LightevalTaskConfig( name="entity_data_imputation:Restaurant", - prompt_function=prompt.entity_data_imputation, + prompt_function=entity_data_imputation_prompt, hf_repo="lighteval/Restaurant", hf_subset="default", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/entitymatching.py b/src/lighteval/tasks/tasks/entitymatching.py index 0e3d16ffa..252559c8d 100644 --- a/src/lighteval/tasks/tasks/entitymatching.py +++ b/src/lighteval/tasks/tasks/entitymatching.py @@ -18,14 +18,24 @@ https://dl.acm.org/doi/10.14778/3007263.3007314 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def entity_matching_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Are Product A and Product B the same? Yes or No?\nProduct A is {line['productA']}. Product B is {line['productB']}. Are A and B the same?\nAnswer:", + choices=["No", "Yes"], + gold_index=int(line["same"]), + instruction="Are Product A and Product B the same? 
Yes or No?\n", + ) entity_matching_Abt_Buy = LightevalTaskConfig( name="entity_matching:Abt_Buy", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Abt_Buy", hf_avail_splits=["train", "test", "validation"], @@ -40,7 +50,7 @@ entity_matching_Amazon_Google = LightevalTaskConfig( name="entity_matching:Amazon_Google", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Amazon_Google", hf_avail_splits=["train", "test", "validation"], @@ -55,7 +65,7 @@ entity_matching_Beer = LightevalTaskConfig( name="entity_matching:Beer", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Beer", hf_avail_splits=["train", "test", "validation"], @@ -70,7 +80,7 @@ entity_matching_Company = LightevalTaskConfig( name="entity_matching:Company", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Company", hf_avail_splits=["train", "test", "validation"], @@ -85,7 +95,7 @@ entity_matching_DBLP_ACM = LightevalTaskConfig( name="entity_matching:DBLP_ACM", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="DBLP_ACM", hf_avail_splits=["train", "test", "validation"], @@ -100,7 +110,7 @@ entity_matching_DBLP_GoogleScholar = LightevalTaskConfig( name="entity_matching:DBLP_GoogleScholar", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="DBLP_GoogleScholar", hf_avail_splits=["train", "test", "validation"], @@ -115,7 +125,7 @@ entity_matching_Dirty_DBLP_ACM = LightevalTaskConfig( name="entity_matching:Dirty_DBLP_ACM", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Dirty_DBLP_ACM", hf_avail_splits=["train", "test", "validation"], @@ -130,7 +140,7 @@ entity_matching_Dirty_DBLP_GoogleScholar = LightevalTaskConfig( name="entity_matching:Dirty_DBLP_GoogleScholar", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Dirty_DBLP_GoogleScholar", hf_avail_splits=["train", "test", "validation"], @@ -145,7 +155,7 @@ entity_matching_Dirty_Walmart_Amazon = LightevalTaskConfig( name="entity_matching:Dirty_Walmart_Amazon", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Dirty_Walmart_Amazon", hf_avail_splits=["train", "test", "validation"], @@ -160,7 +170,7 @@ entity_matching_Dirty_iTunes_Amazon = LightevalTaskConfig( name="entity_matching:Dirty_iTunes_Amazon", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Dirty_iTunes_Amazon", hf_avail_splits=["train", "test", "validation"], @@ -175,7 +185,7 @@ entity_matching_Fodors_Zagats = LightevalTaskConfig( name="entity_matching=Fodors_Zagats", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Fodors_Zagats", hf_avail_splits=["train", "test", "validation"], @@ -190,7 +200,7 @@ entity_matching_Walmart_Amazon = LightevalTaskConfig( name="entity_matching:Walmart_Amazon", - prompt_function=prompt.entity_matching, + 
prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="Walmart_Amazon", hf_avail_splits=["train", "test", "validation"], @@ -205,7 +215,7 @@ entity_matching_iTunes_Amazon = LightevalTaskConfig( name="entity_matching:iTunes_Amazon", - prompt_function=prompt.entity_matching, + prompt_function=entity_matching_prompt, hf_repo="lighteval/EntityMatching", hf_subset="iTunes_Amazon", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/ethics.py b/src/lighteval/tasks/tasks/ethics.py index f5fb856d8..07471fb5f 100644 --- a/src/lighteval/tasks/tasks/ethics.py +++ b/src/lighteval/tasks/tasks/ethics.py @@ -19,14 +19,55 @@ https://arxiv.org/abs/2008.02275 """ -import lighteval.tasks.default_prompts as prompt +import random + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def ethics_commonsense_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['input']}\nQuestion: Is this wrong?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def ethics_deontology_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["input"], choices=[line["label"]], gold_index=0) + + +def ethics_justice_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["input"], choices=[line["label"]], gold_index=0) + + +def ethics_utilitarianism_prompt(line, task_name: str = None): + rnd = random.Random(line["activity"]) + scenarios = [line["activity"], line["baseline"]] + ordering = [0, 1] + rnd.shuffle(ordering) + return Doc( + task_name=task_name, + query=f"Scenario 1: {scenarios[ordering[0]]}\nScenario 2: {scenarios[ordering[1]]}\nQuestion: Is Scenario 1 preferable?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(ordering.index(0) == 0), + ) + + +def ethics_virtue_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f'Sentence: {line["scenario"]}\nQuestion: Does the character in this sentence exhibit the trait "{line["trait"]}"?\nAnswer:', + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) ethics_commonsense = LightevalTaskConfig( name="ethics:commonsense", - prompt_function=prompt.ethics_commonsense, + prompt_function=ethics_commonsense_prompt, hf_repo="lighteval/hendrycks_ethics", hf_subset="commonsense", hf_avail_splits=["train", "validation", "test"], @@ -41,7 +82,7 @@ ethics_deontology = LightevalTaskConfig( name="ethics:deontology", - prompt_function=prompt.ethics_deontology, + prompt_function=ethics_deontology_prompt, hf_repo="lighteval/hendrycks_ethics", hf_subset="deontology", hf_avail_splits=["train", "validation", "test"], @@ -56,7 +97,7 @@ ethics_justice = LightevalTaskConfig( name="ethics:justice", - prompt_function=prompt.ethics_justice, + prompt_function=ethics_justice_prompt, hf_repo="lighteval/hendrycks_ethics", hf_subset="justice", hf_avail_splits=["train", "validation", "test"], @@ -71,7 +112,7 @@ ethics_utilitarianism = LightevalTaskConfig( name="ethics:utilitarianism", - prompt_function=prompt.ethics_utilitarianism, + prompt_function=ethics_utilitarianism_prompt, hf_repo="lighteval/hendrycks_ethics", hf_subset="utilitarianism", hf_avail_splits=["train", "validation", "test"], @@ -86,7 +127,7 @@ ethics_virtue = LightevalTaskConfig( name="ethics:virtue", - prompt_function=prompt.ethics_virtue, + prompt_function=ethics_virtue_prompt, hf_repo="lighteval/hendrycks_ethics", 
hf_subset="virtue", hf_avail_splits=["train", "validation", "test"], diff --git a/src/lighteval/tasks/tasks/glue.py b/src/lighteval/tasks/tasks/glue.py index 9dc1b8ad4..984ce92c3 100644 --- a/src/lighteval/tasks/tasks/glue.py +++ b/src/lighteval/tasks/tasks/glue.py @@ -19,14 +19,162 @@ paper: """ -import lighteval.tasks.default_prompts as prompt +import re + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def boolq_harness_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['passage']}\nQuestion: {line['question']}?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def cb_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['premise']}\nQuestion: {line['hypothesis']}. True, False or Neither?\nAnswer:", + choices=[" True", " False", " Neither"], + gold_index=int(line["label"]), + ) + + +def copa_prompt(line, task_name: str = None): + connector = {"cause": "because", "effect": "therefore"}[line["question"]] + return Doc( + task_name=task_name, + query=f"{line['premise'].strip()[:-1]} {connector}", + choices=[f" {line[c][0].lower()}{line[c][1:]}" for c in ["choice1", "choice2"]], + gold_index=int(line["label"]), + ) + + +def multirc_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['paragraph']}\nQuestion: {line['question']}\nAnswer:", + choices=[f" {line['answer']}\nIs the answer correct? yes", f" {line['answer']}\nIs the answer correct? no"], + gold_index=0 if line["label"] else 1, + ) + + +def wic_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Is the word '{line['word']}' used in the same way in the two sentences above?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def wsc_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Passage: {line['text']}\n'Question: In the passage above, does the pronoun {line['span2_text']} refer to {line['span1_text']}?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def cola_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['sentence']}\nQuestion: Does this sentence make sense?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def mnli_prompt(line, task_name: str = None): + hypothesis = line["hypothesis"].strip() + ("" if line["hypothesis"].strip().endswith(".") else ".") + return Doc( + task_name=task_name, + query=f"{line['premise']}\nQuestion: {hypothesis} True, False or Neither?\nAnswer:", + choices=[" True", " Neither", " False"], + gold_index=int(line["label"]), + ) + + +def mrpc_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Sentence 1: {line['sentence1']}\nSentence 2: {line['sentence2']}\nQuestion: Do both sentences mean the same thing?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def qnli_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['question']}\n{line['sentence']}\nQuestion: Does this response answer the question?\nAnswer:", + choices=[" yes", " no"], + gold_index=int(line["label"]), + ) + + +def qqp_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Question 1: 
{line['question1']}\nQuestion 2: {line['question2']}\nQuestion: Do both questions ask the same thing?\nAnswer:", + choices=[" no", " yes"], + gold_index=int(line["label"]), + ) + + +def rte_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:", + choices=[" True", " False"], + gold_index=int(line["label"]), + ) + + +def sst_prompt(line, task_name: str = None): + def general_detokenize(cur_string): + cur_string = cur_string.replace(" n't", "n't") + cur_string = cur_string.replace(" )", ")") + cur_string = cur_string.replace("( ", "(") + cur_string = cur_string.replace('" ', '"') + cur_string = cur_string.replace(' "', '"') + cur_string = re.sub(r" (['.,])", r"\1", cur_string) + return cur_string + + return Doc( + task_name=task_name, + query=f"{general_detokenize(line['sentence'])}\nQuestion: Is this sentence positive or negative?\nAnswer:", + choices=[" negative", " positive"], + gold_index=int(line["label"]), + ) + + +def stsb_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"sentence 1: {line['sentence1']}\nsentence 2: {line['sentence2']}\nOn a scale of 0 to 5, how similar are the two sentences?\nAnswer:", + gold_index=0, + choices=[line["label"]], + ) + + +def wnli_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['sentence1']}\nQuestion: {line['sentence2']} True or False?\nAnswer:", + choices=[" False", " True"], + gold_index=int(line["label"]), + ) glue_cola = LightevalTaskConfig( name="glue:cola", - prompt_function=prompt.cola, + prompt_function=cola_prompt, hf_repo="nyu-mll/glue", hf_subset="cola", hf_avail_splits=["test", "train", "validation"], @@ -41,7 +189,7 @@ glue_mnli = LightevalTaskConfig( name="glue:mnli", - prompt_function=prompt.mnli, + prompt_function=mnli_prompt, hf_repo="nyu-mll/glue", hf_subset="mnli_matched", hf_avail_splits=["train", "validation"], @@ -56,7 +204,7 @@ glue_mnli_mismatched = LightevalTaskConfig( name="glue:mnli_mismatched", - prompt_function=prompt.mnli, + prompt_function=mnli_prompt, hf_repo="nyu-mll/glue", hf_subset="mnli_mismatched", hf_avail_splits=["train", "validation"], @@ -71,7 +219,7 @@ glue_mrpc = LightevalTaskConfig( name="glue:mrpc", - prompt_function=prompt.mrpc, + prompt_function=mrpc_prompt, hf_repo="nyu-mll/glue", hf_subset="mrpc", hf_avail_splits=["test", "train", "validation"], @@ -86,7 +234,7 @@ glue_qnli = LightevalTaskConfig( name="glue:qnli", - prompt_function=prompt.qnli, + prompt_function=qnli_prompt, hf_repo="nyu-mll/glue", hf_subset="qnli", hf_avail_splits=["test", "train", "validation"], @@ -101,7 +249,7 @@ glue_qqp = LightevalTaskConfig( name="glue:qqp", - prompt_function=prompt.qqp, + prompt_function=qqp_prompt, hf_repo="nyu-mll/glue", hf_subset="qqp", hf_avail_splits=["train", "validation", "test"], @@ -116,7 +264,7 @@ glue_rte = LightevalTaskConfig( name="glue:rte", - prompt_function=prompt.rte, + prompt_function=rte_prompt, hf_repo="nyu-mll/glue", hf_subset="rte", hf_avail_splits=["test", "train", "validation"], @@ -131,7 +279,7 @@ glue_sst2 = LightevalTaskConfig( name="glue:sst2", - prompt_function=prompt.sst, + prompt_function=sst_prompt, hf_repo="nyu-mll/glue", hf_subset="sst2", hf_avail_splits=["test", "train", "validation"], @@ -146,7 +294,7 @@ glue_stsb = LightevalTaskConfig( name="glue:stsb", - prompt_function=prompt.stsb, + prompt_function=stsb_prompt, hf_repo="nyu-mll/glue", hf_subset="stsb", hf_avail_splits=["test", "train", 
"validation"], @@ -161,7 +309,7 @@ glue_wnli = LightevalTaskConfig( name="glue:wnli", - prompt_function=prompt.wnli, + prompt_function=wnli_prompt, hf_repo="nyu-mll/glue", hf_subset="wnli", hf_avail_splits=["test", "train", "validation"], @@ -176,7 +324,7 @@ super_glue_boolq = LightevalTaskConfig( name="super_glue:boolq", - prompt_function=prompt.boolq_harness, + prompt_function=boolq_harness_prompt, hf_repo="aps/super_glue", hf_subset="boolq", hf_avail_splits=["test", "train", "validation"], @@ -191,7 +339,7 @@ super_glue_cb = LightevalTaskConfig( name="super_glue:cb", - prompt_function=prompt.cb, + prompt_function=cb_prompt, hf_repo="aps/super_glue", hf_subset="cb", hf_avail_splits=["test", "train", "validation"], @@ -206,7 +354,7 @@ super_glue_copa = LightevalTaskConfig( name="super_glue:copa", - prompt_function=prompt.copa, + prompt_function=copa_prompt, hf_repo="aps/super_glue", hf_subset="copa", hf_avail_splits=["test", "train", "validation"], @@ -221,7 +369,7 @@ super_glue_rte = LightevalTaskConfig( name="super_glue:rte", - prompt_function=prompt.rte, + prompt_function=rte_prompt, hf_repo="aps/super_glue", hf_subset="rte", hf_avail_splits=["test", "train", "validation"], @@ -236,7 +384,7 @@ super_glue_multirc = LightevalTaskConfig( name="super_glue:multirc", - prompt_function=prompt.multirc, + prompt_function=multirc_prompt, hf_repo="aps/super_glue", hf_subset="multirc", hf_avail_splits=["train", "validation"], @@ -251,7 +399,7 @@ super_glue_wic = LightevalTaskConfig( name="super_glue:wic", - prompt_function=prompt.wic, + prompt_function=wic_prompt, hf_repo="aps/super_glue", hf_subset="wic", hf_avail_splits=["test", "train", "validation"], @@ -266,7 +414,7 @@ super_glue_wsc = LightevalTaskConfig( name="super_glue:wsc", - prompt_function=prompt.wsc, + prompt_function=wsc_prompt, hf_repo="aps/super_glue", hf_subset="wsc", hf_avail_splits=["test", "train", "validation"], diff --git a/src/lighteval/tasks/tasks/gpqa.py b/src/lighteval/tasks/tasks/gpqa.py index 748be055a..88a31c97c 100644 --- a/src/lighteval/tasks/tasks/gpqa.py +++ b/src/lighteval/tasks/tasks/gpqa.py @@ -30,9 +30,9 @@ from inspect_ai.scorer import choice from inspect_ai.solver import multiple_choice -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc def record_to_sample(record): @@ -50,9 +50,61 @@ def sample_to_fewshot(sample): return f"{sample.input}\n\n" + f"ANSWER: {sample.target}" +def gpqa_prompt(line, task_name: str = None): + GPQA_QUERY_TEMPLATE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 
+ +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + gold_index = random.randint(0, 3) + choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] + choices.insert(gold_index, line["Correct Answer"]) + + query = GPQA_QUERY_TEMPLATE.format( + A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"] + ) + + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[: len(choices)], + gold_index=gold_index, + instruction=query, + ) + + +def gpqa_instruct_prompt(line, task_name: str = None): + gold_index = random.randint(0, 3) + choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] + choices.insert(gold_index, line["Correct Answer"]) + instruction = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering." + query_template = "{Instruction}\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" + query = query_template.format( + A=choices[0].strip(), + B=choices[1].strip(), + C=choices[2].strip(), + D=choices[3].strip(), + Question=line["Question"].strip(), + Instruction=instruction, + ) + + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[: len(choices)], + gold_index=gold_index, + instruction=instruction, + ) + + gpqa = LightevalTaskConfig( name="gpqa:mc", - prompt_function=prompt.gpqa, + prompt_function=gpqa_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[multiple_choice(cache=True)], @@ -71,7 +123,7 @@ def sample_to_fewshot(sample): gpqa_diamond_instruct = LightevalTaskConfig( name="gpqa:diamond", - prompt_function=prompt.gpqa_instruct, + prompt_function=gpqa_instruct_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[multiple_choice(cache=True)], @@ -90,7 +142,7 @@ def sample_to_fewshot(sample): gpqa_extended_instruct = LightevalTaskConfig( name="gpqa:extended", - prompt_function=prompt.gpqa_instruct, + prompt_function=gpqa_instruct_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[multiple_choice(cache=True)], @@ -109,7 +161,7 @@ def sample_to_fewshot(sample): gpqa_main_instruct = LightevalTaskConfig( name="gpqa:main", - prompt_function=prompt.gpqa_instruct, + prompt_function=gpqa_instruct_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[multiple_choice(cache=True)], diff --git a/src/lighteval/tasks/tasks/gsm8k.py b/src/lighteval/tasks/tasks/gsm8k.py index 230729941..24c98055e 100644 --- a/src/lighteval/tasks/tasks/gsm8k.py +++ b/src/lighteval/tasks/tasks/gsm8k.py @@ -21,9 +21,9 @@ from inspect_ai.dataset import Sample from inspect_ai.solver import generate, prompt_template -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics, math_scorer from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc # setup for problem + instructions for providing answer @@ -55,9 +55,18 @@ def sample_to_fewshot(sample): return f"{sample.input}\n\nReasoning:\n" + f"{sample.metadata['reasoning']}\n\n" + f"ANSWER: {sample.target}" +def gsm8k_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + choices=[f" {line['answer']}"], + gold_index=0, + ) + + gsm8k = LightevalTaskConfig( name="gsm8k", - 
prompt_function=prompt.gsm8k, + prompt_function=gsm8k_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[prompt_template(MATH_PROMPT_TEMPLATE), generate(cache=True)], diff --git a/src/lighteval/tasks/tasks/gsm_plus.py b/src/lighteval/tasks/tasks/gsm_plus.py index 594b9b49f..f6bdad1fd 100644 --- a/src/lighteval/tasks/tasks/gsm_plus.py +++ b/src/lighteval/tasks/tasks/gsm_plus.py @@ -23,9 +23,9 @@ from inspect_ai.dataset import Sample from inspect_ai.solver import generate, prompt_template -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics, math_scorer from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc # setup for problem + instructions for providing answer @@ -56,9 +56,20 @@ def sample_to_fewshot(sample): return f"{sample.input}\n\nReasoning:\n" + f"{sample.metadata['reasoning']}\n\n" + f"ANSWER: {sample.target}" +def gsm_plus_prompt(line, task_name: str = None): + if line["perturbation_type"] == "critical thinking": + return None + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\n\nAnswer:", + choices=[line["answer"]], + gold_index=0, + ) + + gsm_plus = LightevalTaskConfig( name="gsm_plus", - prompt_function=prompt.gsm_plus, + prompt_function=gsm_plus_prompt, sample_fields=record_to_sample, sample_to_fewshot=sample_to_fewshot, solver=[prompt_template(MATH_PROMPT_TEMPLATE), generate(cache=True)], diff --git a/src/lighteval/tasks/tasks/headqa.py b/src/lighteval/tasks/tasks/headqa.py index 6e4316aa3..85c025c5f 100644 --- a/src/lighteval/tasks/tasks/headqa.py +++ b/src/lighteval/tasks/tasks/headqa.py @@ -22,14 +22,23 @@ https://arxiv.org/abs/1906.04701 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def headqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Question: {line['qtext']}\nAnswer:", + choices=[f" {answer['atext']}" for answer in line["answers"]], + gold_index=int(line["ra"]) - 1, + ) headqa_en = LightevalTaskConfig( name="headqa:en", - prompt_function=prompt.headqa, + prompt_function=headqa_prompt, hf_repo="lighteval/headqa_harness", hf_subset="en", hf_avail_splits=["train", "test", "validation"], @@ -47,7 +56,7 @@ headqa_es = LightevalTaskConfig( name="headqa:es", - prompt_function=prompt.headqa, + prompt_function=headqa_prompt, hf_repo="lighteval/headqa_harness", hf_subset="es", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py index 053ce384e..bb948f749 100644 --- a/src/lighteval/tasks/tasks/hellaswag.py +++ b/src/lighteval/tasks/tasks/hellaswag.py @@ -19,14 +19,32 @@ https://arxiv.org/abs/1905.07830 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def hellaswag_prompt(line, task_name: str = None): + query = "The following are multiple choice questions (with answers) about common sense.\n\n" + query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n" + query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(ascii_uppercase, line["endings"])]) + query += "Answer:" + + gold_ix = int(line["label"]) if line["label"] != "" else -1 + return Doc( + task_name=task_name, + query=query, + choices=[" " + i for i in ascii_uppercase[: len(line["endings"])]], + gold_index=gold_ix, + instruction="The following are multiple choice questions (with answers) about common sense.\n\n", + ) hellaswag = LightevalTaskConfig( name="hellaswag", - prompt_function=prompt.hellaswag_generative, + prompt_function=hellaswag_prompt, hf_repo="Rowan/hellaswag", hf_subset="default", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/imdb.py b/src/lighteval/tasks/tasks/imdb.py index 6b165d5be..d6d0d9715 100644 --- a/src/lighteval/tasks/tasks/imdb.py +++ b/src/lighteval/tasks/tasks/imdb.py @@ -19,14 +19,35 @@ https://aclanthology.org/P11-1015/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def imdb_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Passage: {line['input']}\nSentiment: ", + choices=["Positive", "Negative"], + gold_index=["Positive", "Negative"].index(line["reference"]), + ) + + +def imdb_contrastset_prompt(line, task_name: str = None): + if line["contrast_input"] is None or line["contrast_references"] is None: + return imdb(line) + + return Doc( + task_name=task_name, + query=f"Passage: {line['contrast_inputs']}\nSentiment: ", + choices=["Positive", "Negative"], + gold_index=["Positive", "Negative"].index(line["contrast_references"]), + ) imdb = LightevalTaskConfig( name="imdb", - prompt_function=prompt.imdb, + prompt_function=imdb_prompt, hf_repo="lighteval/IMDB_helm", hf_subset="default", hf_avail_splits=["train", "test"], @@ -44,7 +65,7 @@ imdb_contrastset = LightevalTaskConfig( name="imdb:contrastset", - prompt_function=prompt.imdb_contrastset, + prompt_function=imdb_contrastset_prompt, hf_repo="lighteval/IMDB_helm", hf_subset="default", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/lambada.py b/src/lighteval/tasks/tasks/lambada.py index f820091fd..a039f3d33 100644 --- a/src/lighteval/tasks/tasks/lambada.py +++ b/src/lighteval/tasks/tasks/lambada.py @@ -21,14 +21,37 @@ https://arxiv.org/abs/1606.06031 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +lambada_cloze_query_suffix = " ____. 
->" + + +def lambada_cloze_prompt(line, task_name: str = None): + query, choice = line["text"].rsplit(" ", 1) + return Doc( + task_name=task_name, + query=f"{query}{lambada_cloze_query_suffix}", + gold_index=0, + choices=[f" {choice}"], + ) + + +def lambada_prompt(line, task_name: str = None): + query, choice = line["text"].rsplit(" ", 1) + return Doc( + task_name=task_name, + query=query, + gold_index=0, + choices=[f" {choice}"], + ) lambada_standard = LightevalTaskConfig( name="lambada:standard", - prompt_function=prompt.lambada, + prompt_function=lambada_prompt, hf_repo="cimec/lambada", hf_subset="plain_text", hf_avail_splits=["train", "test", "validation"], @@ -44,7 +67,7 @@ lambada_standard_cloze = LightevalTaskConfig( name="lambada:standard_cloze", - prompt_function=prompt.lambada_cloze, + prompt_function=lambada_cloze_prompt, hf_repo="cimec/lambada", hf_subset="plain_text", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/legal_summarization.py b/src/lighteval/tasks/tasks/legal_summarization.py index 42168949c..ea27796ba 100644 --- a/src/lighteval/tasks/tasks/legal_summarization.py +++ b/src/lighteval/tasks/tasks/legal_summarization.py @@ -19,14 +19,34 @@ https://arxiv.org/abs/2210.13448 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def legal_summarization_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"###\nArticle:{line['text']}\n\nSummarize the above article in 3 sentences.\n", + choices=[str(line["summary"])], + gold_index=0, + specific={"text": line["text"]}, + ) + + +def multilexsum_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"###\nArticle: {line['article']}\n\nSummarize the above article in 2 sentences.\n", + gold_index=0, + choices=[line["summary"]], + specific={"text": line["article"]}, + ) legal_summarization_billsum = LightevalTaskConfig( name="legal_summarization:billsum", - prompt_function=prompt.legal_summarization, + prompt_function=legal_summarization_prompt, hf_repo="lighteval/legal_summarization", hf_subset="BillSum", hf_avail_splits=["train", "test"], @@ -49,7 +69,7 @@ legal_summarization_eurlexsum = LightevalTaskConfig( name="legal_summarization:eurlexsum", - prompt_function=prompt.legal_summarization, + prompt_function=legal_summarization_prompt, hf_repo="lighteval/legal_summarization", hf_subset="EurLexSum", hf_avail_splits=["train", "test", "validation"], @@ -72,7 +92,7 @@ legal_summarization_multilexsum = LightevalTaskConfig( name="legal_summarization:multilexsum", - prompt_function=prompt.multilexsum, + prompt_function=multilexsum_prompt, hf_repo="lighteval/legal_summarization", hf_subset="MultiLexSum", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/legalsupport.py b/src/lighteval/tasks/tasks/legalsupport.py index 492ac3922..4893755d4 100644 --- a/src/lighteval/tasks/tasks/legalsupport.py +++ b/src/lighteval/tasks/tasks/legalsupport.py @@ -17,14 +17,34 @@ paper: """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def legalsupport_prompt(line, task_name: str = None): + query = f"Which statement best supports the passage?\nPassage: {line['context']}\n" + query += "".join( + [ + f"{key}. 
{choice}\n" + for key, choice in zip( + ["a", "b"], [line["citation_a"]["parenthetical"], line["citation_b"]["parenthetical"]] + ) + ] + ) + query += "Answer:" + + return Doc( + task_name=task_name, + query=query, + choices=["a", "b"], + gold_index=0 if line["answer_label"] == "citation_a" else 1, + ) legalsupport = LightevalTaskConfig( name="legalsupport", - prompt_function=prompt.legal_support, + prompt_function=legalsupport_prompt, hf_repo="lighteval/LegalSupport", hf_subset="default", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/lexglue.py b/src/lighteval/tasks/tasks/lexglue.py index 001a6963c..cd02ddf72 100644 --- a/src/lighteval/tasks/tasks/lexglue.py +++ b/src/lighteval/tasks/tasks/lexglue.py @@ -18,14 +18,59 @@ https://arxiv.org/abs/2110.00976 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def lex_glue(line, instruction, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{instruction}\nPassage: {line['input']}\nAnswer: ", + choices=line["references"], + gold_index=[line["references"].index(item) for item in line["gold"]], + instruction=instruction + "\n", + ) + + +def lex_glue_ecthr_a_prompt(line, task_name: str = None): + instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of the ECtHR that were violated (if any)." + return lex_glue(line, instruction, task_name) + + +def lex_glue_ecthr_b_prompt(line, task_name: str = None): + instruction = "In this task, you are given the facts from a case heard at the European Court of Human Rights (ECtHR). Predict the articles of ECtHR that were allegedly violated (considered by the court)." + return lex_glue(line, instruction, task_name) + + +def lex_glue_scotus_prompt(line, task_name: str = None): + instruction = "In this task, you are given a case heard at the Supreme Court of the United States (SCOTUS). Predict the relevant issue area." + return lex_glue(line, instruction, task_name) + + +def lex_glue_eurlex_prompt(line, task_name: str = None): + instruction = "In this task, you are given an EU law document published in the EUR-Lex portal. Predict the relevant EuroVoc concepts." + return lex_glue(line, instruction, task_name) + + +def lex_glue_ledgar_prompt(line, task_name: str = None): + instruction = "In this task, you are given a contract provision \nfrom contracts obtained from US Securities and Exchange Commission (SEC) filings. Predict the main topic." + return lex_glue(line, instruction, task_name) + + +def lex_glue_unfair_tos_prompt(line, task_name: str = None): + instruction = "In this task, you are given a sentence \nfrom a Terms of Service (ToS) document from on-line platforms. Predict the types of unfair contractual terms" + return lex_glue(line, instruction, task_name) + + +def lex_glue_case_hold_prompt(line, task_name: str = None): + instruction = "In this task, you are given an excerpt from a court decision, \ncontaining a reference to a particular case, while the holding statement is masked out. Predict the index of the holding statement fitting in the context at from a selection of five choices." 
+ return lex_glue(line, instruction, task_name) lexglue_case_hold = LightevalTaskConfig( name="lexglue:case_hold", - prompt_function=prompt.lex_glue_case_hold, + prompt_function=lex_glue_case_hold_prompt, hf_repo="lighteval/lexglue", hf_subset="case_hold", hf_avail_splits=["train", "test", "validation"], @@ -40,7 +85,7 @@ lexglue_ecthr_a = LightevalTaskConfig( name="lexglue:ecthr_a", - prompt_function=prompt.lex_glue_ecthr_a, + prompt_function=lex_glue_ecthr_a_prompt, hf_repo="lighteval/lexglue", hf_subset="ecthr_a", hf_avail_splits=["train", "test", "validation"], @@ -55,7 +100,7 @@ lexglue_ecthr_b = LightevalTaskConfig( name="lexglue:ecthr_b", - prompt_function=prompt.lex_glue_ecthr_b, + prompt_function=lex_glue_ecthr_b_prompt, hf_repo="lighteval/lexglue", hf_subset="ecthr_b", hf_avail_splits=["train", "test", "validation"], @@ -70,7 +115,7 @@ lexglue_eurlex = LightevalTaskConfig( name="lexglue:eurlex", - prompt_function=prompt.lex_glue_eurlex, + prompt_function=lex_glue_eurlex_prompt, hf_repo="lighteval/lexglue", hf_subset="eurlex", hf_avail_splits=["train", "test", "validation"], @@ -85,7 +130,7 @@ lexglue_ledgar = LightevalTaskConfig( name="lexglue:ledgar", - prompt_function=prompt.lex_glue_ledgar, + prompt_function=lex_glue_ledgar_prompt, hf_repo="lighteval/lexglue", hf_subset="ledgar", hf_avail_splits=["train", "test", "validation"], @@ -100,7 +145,7 @@ lexglue_scotus = LightevalTaskConfig( name="lexglue:scotus", - prompt_function=prompt.lex_glue_scotus, + prompt_function=lex_glue_scotus_prompt, hf_repo="lighteval/lexglue", hf_subset="scotus", hf_avail_splits=["train", "test", "validation"], @@ -115,7 +160,7 @@ lexglue_unfair_tos = LightevalTaskConfig( name="lexglue:unfair_tos", - prompt_function=prompt.lex_glue_unfair_tos, + prompt_function=lex_glue_unfair_tos_prompt, hf_repo="lighteval/lexglue", hf_subset="unfair_tos", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/lextreme.py b/src/lighteval/tasks/tasks/lextreme.py index 23ebe857b..6d564c9bb 100644 --- a/src/lighteval/tasks/tasks/lextreme.py +++ b/src/lighteval/tasks/tasks/lextreme.py @@ -18,14 +18,184 @@ https://arxiv.org/abs/2301.13126 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def lextreme_prompt(line, instruction, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{instruction}\nPassage: {line['input']}\nAnswer: ", + choices=line["references"], + gold_index=[line["references"].index(item) for item in line["gold"]], + instruction=instruction + "\n", + ) + + +def lextreme_brazilian_court_decisions_judgment_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given the case description " + "from a decision heard at the State Supreme Court of Alagoas (Brazil). " + "Predict the judgment of the case " + "(no: The appeal was denied, " + "partial: For partially favourable decisions, " + "yes: For fully favourable decisions)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_brazilian_court_decisions_unanimity_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given the case description " + "from a decision heard at the State Supreme Court of Alagoas (Brazil). 
" + "Predict the unanimity of the case (unanimity, not-unanimity, not_determined)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_german_argument_mining_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given sentences from German court decisions. " + "Predict the major component of German Urteilsstil " + "(conclusion: Overall result, " + "definition: Abstract legal facts and consequences, " + "subsumption: Determination sentence / Concrete facts, " + "other: Anything else)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_greek_legal_code_chapter_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a Greek legislative document. " + "Predict the chapter level category of the " + "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_greek_legal_code_subject_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a Greek legislative document. " + "Predict the subject level category of the " + "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." + ) + + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_greek_legal_code_volume_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a Greek legislative document. " + "Predict the volume level category of the " + "'Permanent Greek Legislation Code - Raptarchis (Ραπτάρχης)' the document belongs to." + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_swiss_judgment_prediction_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given the facts description " + "from a decision heard at the Swiss Federal Supreme Court. " + "Predict the judgment of the case (approval or dismissal)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_online_terms_of_service_unfairness_levels_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a sentence " + "from a Terms of Service (ToS) document. " + "Predict the unfairness level of the sentence (potentially_unfair, clearly_unfair, clearly_fair, untagged)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_online_terms_of_service_clause_topics_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a sentence " + "from a Terms of Service (ToS) document. " + "Predict the clause topics of the sentence " + "(0: Arbitration, " + "1: Unilateral change, " + "2: Content removal, " + "3: Jurisdiction, " + "4: Choice of law, " + "5: Limitation of liability, " + "6: Unilateral termination, " + "7: Contract by using, " + "8: Privacy included)" + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_covid19_emergency_event_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a sentence from a European legislative document. 
" + "Predict the applicable measurements against COVID-19 " + "(0: State of Emergency, " + "1: Restrictions of fundamental rights and civil liberties, " + "2: Restrictions of daily liberties, " + "3: Closures / lockdown, " + "4: Suspension of international cooperation and commitments, " + "5: Police mobilization, " + "6: Army mobilization, " + "7: Government oversight)" + ) + + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_multi_eurlex_level_1_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a document from an EU law. Predict the level 1 concept in the EUROVOC taxonomy." + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_multi_eurlex_level_2_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a document from an EU law. Predict the level 2 concept in the EUROVOC taxonomy." + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_multi_eurlex_level_3_prompt(line, task_name: str = None): + instruction = ( + "In this task, you are given a document from an EU law. Predict the level 3 concept in the EUROVOC taxonomy." + ) + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_greek_legal_ner_prompt(line, task_name: str = None): + instruction = "In this task, you are given a Greek legal document. Predict the named entities." + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_legalnero_prompt(line, task_name: str = None): + instruction = "In this task, you are given a legal text. Predict the named entities of legal interest." + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_lener_br_prompt(line, task_name: str = None): + instruction = "In this task, you are given a Brazilian legal text. Predict the named entities." + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_mapa_coarse_prompt(line, task_name: str = None): + instruction = "In this task, you are given a legal text. Predict the coarse-grained labels." + return lextreme_prompt(line, instruction, task_name) + + +def lextreme_mapa_fine_prompt(line, task_name: str = None): + instruction = "In this task, you are given a legal text. Predict the fine-grained labels." 
+ return lextreme_prompt(line, instruction, task_name) lextreme_brazilian_court_decisions_judgment = LightevalTaskConfig( name="lextreme:brazilian_court_decisions_judgment", - prompt_function=prompt.lextreme_brazilian_court_decisions_judgment, + prompt_function=lextreme_brazilian_court_decisions_judgment_prompt, hf_repo="lighteval/lextreme", hf_subset="brazilian_court_decisions_judgment", hf_avail_splits=["train", "test", "validation"], @@ -40,7 +210,7 @@ lextreme_brazilian_court_decisions_unanimity = LightevalTaskConfig( name="lextreme:brazilian_court_decisions_unanimity", - prompt_function=prompt.lextreme_brazilian_court_decisions_unanimity, + prompt_function=lextreme_brazilian_court_decisions_unanimity_prompt, hf_repo="lighteval/lextreme", hf_subset="brazilian_court_decisions_unanimity", hf_avail_splits=["train", "test", "validation"], @@ -55,7 +225,7 @@ lextreme_covid19_emergency_event = LightevalTaskConfig( name="lextreme:covid19_emergency_event", - prompt_function=prompt.lextreme_covid19_emergency_event, + prompt_function=lextreme_covid19_emergency_event_prompt, hf_repo="lighteval/lextreme", hf_subset="covid19_emergency_event", hf_avail_splits=["train", "test", "validation"], @@ -70,7 +240,7 @@ lextreme_german_argument_mining = LightevalTaskConfig( name="lextreme:german_argument_mining", - prompt_function=prompt.lextreme_german_argument_mining, + prompt_function=lextreme_german_argument_mining_prompt, hf_repo="lighteval/lextreme", hf_subset="german_argument_mining", hf_avail_splits=["train", "test", "validation"], @@ -85,7 +255,7 @@ lextreme_greek_legal_code_chapter = LightevalTaskConfig( name="lextreme:greek_legal_code_chapter", - prompt_function=prompt.lextreme_greek_legal_code_chapter, + prompt_function=lextreme_greek_legal_code_chapter_prompt, hf_repo="lighteval/lextreme", hf_subset="greek_legal_code_chapter", hf_avail_splits=["train", "test", "validation"], @@ -100,7 +270,7 @@ lextreme_greek_legal_code_subject = LightevalTaskConfig( name="lextreme:greek_legal_code_subject", - prompt_function=prompt.lextreme_greek_legal_code_subject, + prompt_function=lextreme_greek_legal_code_subject_prompt, hf_repo="lighteval/lextreme", hf_subset="greek_legal_code_subject", hf_avail_splits=["train", "test", "validation"], @@ -115,7 +285,7 @@ lextreme_greek_legal_code_volume = LightevalTaskConfig( name="lextreme:greek_legal_code_volume", - prompt_function=prompt.lextreme_greek_legal_code_volume, + prompt_function=lextreme_greek_legal_code_volume_prompt, hf_repo="lighteval/lextreme", hf_subset="greek_legal_code_volume", hf_avail_splits=["train", "test", "validation"], @@ -130,7 +300,7 @@ lextreme_greek_legal_ner = LightevalTaskConfig( name="lextreme:greek_legal_ner", - prompt_function=prompt.lextreme_greek_legal_ner, + prompt_function=lextreme_greek_legal_ner_prompt, hf_repo="lighteval/lextreme", hf_subset="greek_legal_ner", hf_avail_splits=["train", "test", "validation"], @@ -145,7 +315,7 @@ lextreme_legalnero = LightevalTaskConfig( name="lextreme:legalnero", - prompt_function=prompt.lextreme_legalnero, + prompt_function=lextreme_legalnero_prompt, hf_repo="lighteval/lextreme", hf_subset="legalnero", hf_avail_splits=["train", "test", "validation"], @@ -160,7 +330,7 @@ lextreme_lener_br = LightevalTaskConfig( name="lextreme:lener_br", - prompt_function=prompt.lextreme_lener_br, + prompt_function=lextreme_lener_br_prompt, hf_repo="lighteval/lextreme", hf_subset="lener_br", hf_avail_splits=["train", "test", "validation"], @@ -175,7 +345,7 @@ lextreme_mapa_coarse = LightevalTaskConfig( 
name="lextreme:mapa_coarse", - prompt_function=prompt.lextreme_mapa_coarse, + prompt_function=lextreme_mapa_coarse_prompt, hf_repo="lighteval/lextreme", hf_subset="mapa_coarse", hf_avail_splits=["train", "test", "validation"], @@ -190,7 +360,7 @@ lextreme_mapa_fine = LightevalTaskConfig( name="lextreme:mapa_fine", - prompt_function=prompt.lextreme_mapa_fine, + prompt_function=lextreme_mapa_fine_prompt, hf_repo="lighteval/lextreme", hf_subset="mapa_fine", hf_avail_splits=["train", "test", "validation"], @@ -205,7 +375,7 @@ lextreme_multi_eurlex_level_1 = LightevalTaskConfig( name="lextreme:multi_eurlex_level_1", - prompt_function=prompt.lextreme_multi_eurlex_level_1, + prompt_function=lextreme_multi_eurlex_level_1_prompt, hf_repo="lighteval/lextreme", hf_subset="multi_eurlex_level_1", hf_avail_splits=["train", "test", "validation"], @@ -220,7 +390,7 @@ lextreme_multi_eurlex_level_2 = LightevalTaskConfig( name="lextreme:multi_eurlex_level_2", - prompt_function=prompt.lextreme_multi_eurlex_level_2, + prompt_function=lextreme_multi_eurlex_level_2_prompt, hf_repo="lighteval/lextreme", hf_subset="multi_eurlex_level_2", hf_avail_splits=["train", "test", "validation"], @@ -235,7 +405,7 @@ lextreme_multi_eurlex_level_3 = LightevalTaskConfig( name="lextreme:multi_eurlex_level_3", - prompt_function=prompt.lextreme_multi_eurlex_level_3, + prompt_function=lextreme_multi_eurlex_level_3_prompt, hf_repo="lighteval/lextreme", hf_subset="multi_eurlex_level_3", hf_avail_splits=["train", "test", "validation"], @@ -250,7 +420,7 @@ lextreme_online_terms_of_service_clause_topics = LightevalTaskConfig( name="lextreme:online_terms_of_service_clause_topics", - prompt_function=prompt.lextreme_online_terms_of_service_clause_topics, + prompt_function=lextreme_online_terms_of_service_clause_topics_prompt, hf_repo="lighteval/lextreme", hf_subset="online_terms_of_service_clause_topics", hf_avail_splits=["train", "test", "validation"], @@ -265,7 +435,7 @@ lextreme_online_terms_of_service_unfairness_levels = LightevalTaskConfig( name="lextreme:online_terms_of_service_unfairness_levels", - prompt_function=prompt.lextreme_online_terms_of_service_unfairness_levels, + prompt_function=lextreme_online_terms_of_service_unfairness_levels_prompt, hf_repo="lighteval/lextreme", hf_subset="online_terms_of_service_unfairness_levels", hf_avail_splits=["train", "test", "validation"], @@ -280,7 +450,7 @@ lextreme_swiss_judgment_prediction = LightevalTaskConfig( name="lextreme:swiss_judgment_prediction", - prompt_function=prompt.lextreme_swiss_judgment_prediction, + prompt_function=lextreme_swiss_judgment_prediction_prompt, hf_repo="lighteval/lextreme", hf_subset="swiss_judgment_prediction", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/logiqa.py b/src/lighteval/tasks/tasks/logiqa.py index 08880b23f..03461f3ac 100644 --- a/src/lighteval/tasks/tasks/logiqa.py +++ b/src/lighteval/tasks/tasks/logiqa.py @@ -22,14 +22,27 @@ https://arxiv.org/abs/2007.08124 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def logiqa_prompt(line, task_name: str = None): + query = f"Passage: {line['context']}\nQuestion: {line['question']}\nChoices:\n" + query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(["A", "B", "C", "D"], line["options"])]) + query += "Answer:" + + return Doc( + task_name=task_name, + query=query, + choices=[f" {c}" for c in line["options"]], + gold_index=["a", "b", "c", "d"].index(line["label"]), + ) logiqa = LightevalTaskConfig( name="logiqa", - prompt_function=prompt.logiqa, + prompt_function=logiqa_prompt, hf_repo="lighteval/logiqa_harness", hf_subset="logiqa", hf_avail_splits=["train", "validation", "test"], diff --git a/src/lighteval/tasks/tasks/lsat_qa.py b/src/lighteval/tasks/tasks/lsat_qa.py index f6b655e96..713a1d3b2 100644 --- a/src/lighteval/tasks/tasks/lsat_qa.py +++ b/src/lighteval/tasks/tasks/lsat_qa.py @@ -17,14 +17,29 @@ paper: """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def lsat_qa_prompt(line, task_name: str = None): + query = f"The following are multiple choice questions (with answers).\nPassage: {line['passage']}\nQuestion: {line['question']}\n" + query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["references"])]) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[: len(line["references"])], + gold_index=line["gold_index"], + instruction="The following are multiple choice questions (with answers).\n", + ) lsat_qa = LightevalTaskConfig( name="lsat_qa", - prompt_function=prompt.lsat_qa, + prompt_function=lsat_qa_prompt, hf_repo="lighteval/lsat_qa", hf_subset="all", hf_avail_splits=["train", "test", "validation"], @@ -39,7 +54,7 @@ lsat_qa_assignment = LightevalTaskConfig( name="lsat_qa:assignment", - prompt_function=prompt.lsat_qa, + prompt_function=lsat_qa_prompt, hf_repo="lighteval/lsat_qa", hf_subset="assignment", hf_avail_splits=["train", "test", "validation"], @@ -54,7 +69,7 @@ lsat_qa_grouping = LightevalTaskConfig( name="lsat_qa:grouping", - prompt_function=prompt.lsat_qa, + prompt_function=lsat_qa_prompt, hf_repo="lighteval/lsat_qa", hf_subset="grouping", hf_avail_splits=["train", "test", "validation"], @@ -69,7 +84,7 @@ lsat_qa_miscellaneous = LightevalTaskConfig( name="lsat_qa:miscellaneous", - prompt_function=prompt.lsat_qa, + prompt_function=lsat_qa_prompt, hf_repo="lighteval/lsat_qa", hf_subset="miscellaneous", hf_avail_splits=["train", "test", "validation"], @@ -84,7 +99,7 @@ lsat_qa_ordering = LightevalTaskConfig( name="lsat_qa:ordering", - prompt_function=prompt.lsat_qa, + prompt_function=lsat_qa_prompt, hf_repo="lighteval/lsat_qa", hf_subset="ordering", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/math.py b/src/lighteval/tasks/tasks/math.py index 5d1fdf732..74909d9c8 100644 --- a/src/lighteval/tasks/tasks/math.py +++ b/src/lighteval/tasks/tasks/math.py @@ -17,15 +17,24 @@ https://arxiv.org/abs/2305.20050 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import math_normalizer from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def math_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Question: {line['problem']}\nAnswer:", + choices=[f" {line['solution']}"], + gold_index=0, + ) math_algebra = LightevalTaskConfig( name="math:algebra", - prompt_function=prompt.math, + prompt_function=math_prompt, 
hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="algebra", hf_avail_splits=["train", "test"], @@ -49,7 +58,7 @@ math_counting_and_probability = LightevalTaskConfig( name="math:counting_and_probability", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="counting_and_probability", hf_avail_splits=["train", "test"], @@ -73,7 +82,7 @@ math_geometry = LightevalTaskConfig( name="math:geometry", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="geometry", hf_avail_splits=["train", "test"], @@ -97,7 +106,7 @@ math_intermediate_algebra = LightevalTaskConfig( name="math:intermediate_algebra", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="intermediate_algebra", hf_avail_splits=["train", "test"], @@ -121,7 +130,7 @@ math_number_theory = LightevalTaskConfig( name="math:number_theory", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="number_theory", hf_avail_splits=["train", "test"], @@ -145,7 +154,7 @@ math_prealgebra = LightevalTaskConfig( name="math:prealgebra", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="prealgebra", hf_avail_splits=["train", "test"], @@ -169,7 +178,7 @@ math_precalculus = LightevalTaskConfig( name="math:precalculus", - prompt_function=prompt.math, + prompt_function=math_prompt, hf_repo="DigitalLearningGmbH/MATH-lighteval", hf_subset="precalculus", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/math_500.py b/src/lighteval/tasks/tasks/math_500.py index 55f9ebd24..5c075121d 100644 --- a/src/lighteval/tasks/tasks/math_500.py +++ b/src/lighteval/tasks/tasks/math_500.py @@ -19,14 +19,30 @@ https://arxiv.org/abs/2305.20050 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def math_500_prompt(line, task_name: str = None): + MATH_QUERY_TEMPLATE = """ +Solve the following problem. The final line of your response MUST be of the following format: +"ANSWER: $ANSWER" (without quotes) where $ANSWER is the final answer. Think step by step before answering. 
+ +{Question} +""".strip() + query = MATH_QUERY_TEMPLATE.format(Question=line["problem"]) + return Doc( + task_name=task_name, + query=query, + choices=[f"ANSWER: {line['solution']}"], + gold_index=0, + ) math_500 = LightevalTaskConfig( name="math_500", - prompt_function=prompt.math_500, + prompt_function=math_500_prompt, hf_repo="HuggingFaceH4/MATH-500", hf_subset="default", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/mathqa.py b/src/lighteval/tasks/tasks/mathqa.py index 2fabb6022..4bf6706aa 100644 --- a/src/lighteval/tasks/tasks/mathqa.py +++ b/src/lighteval/tasks/tasks/mathqa.py @@ -21,14 +21,37 @@ https://arxiv.org/abs/1905.13319 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def mathqa_prompt(line, task_name: str = None): + query = f"Problem: {line['Problem']}\n" + query += "Options:\n" + query += "".join( + [ + f"{key}) {choice}\n" + for key, choice in zip( + ["a", "b", "c", "d", "e"], + [line["option_a"], line["option_b"], line["option_c"], line["option_d"], line["option_e"]], + ) + ] + ) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=[ + f" {c}" for c in [line["option_a"], line["option_b"], line["option_c"], line["option_d"], line["option_e"]] + ], + gold_index=["a", "b", "c", "d", "e"].index(line["correct"]), + ) mathqa = LightevalTaskConfig( name="mathqa", - prompt_function=prompt.mathqa, + prompt_function=mathqa_prompt, hf_repo="allenai/math_qa", hf_subset="default", hf_avail_splits=["train", "validation", "test"], diff --git a/src/lighteval/tasks/tasks/med.py b/src/lighteval/tasks/tasks/med.py index bfd8077df..77552440e 100644 --- a/src/lighteval/tasks/tasks/med.py +++ b/src/lighteval/tasks/tasks/med.py @@ -18,14 +18,56 @@ https://medmcqa.github.io/ """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def med_mcqa_prompt(line, task_name: str = None): + query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n" + query += "".join( + [ + f"{key}. {choice}\n" + for key, choice in zip(ascii_uppercase, [line["opa"], line["opb"], line["opc"], line["opd"]]) + ] + ) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase)[:4], + gold_index=line["cop"] - 1, + instruction="Give a letter answer among A, B, C or D.\n", + ) + + +def med_paragraph_simplification_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"###\nArticle:{line['query']}\n\nSummarize the above article in 10 sentences.\n", + gold_index=0, + choices=[line["answer"]], + ) + + +def med_qa_prompt(line, task_name: str = None): + query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n" + query += "".join([f"{option['key']}. 
{option['value']}\n" for option in line["options"]]) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=[opt["key"] for opt in line["options"]], + gold_index=list(ascii_uppercase).index(line["answer_idx"]), + instruction="Give a letter answer among A, B, C or D.\n", + ) med_mcqa = LightevalTaskConfig( name="med_mcqa", - prompt_function=prompt.med_mcqa, + prompt_function=med_mcqa_prompt, hf_repo="lighteval/med_mcqa", hf_subset="default", hf_avail_splits=["train", "test", "validation"], @@ -43,7 +85,7 @@ med_paragraph_simplification = LightevalTaskConfig( name="med_paragraph_simplification", - prompt_function=prompt.med_paragraph_simplification, + prompt_function=med_paragraph_simplification_prompt, hf_repo="lighteval/med_paragraph_simplification", hf_subset="default", hf_avail_splits=["train", "test", "validation"], @@ -61,7 +103,7 @@ med_qa = LightevalTaskConfig( name="med_qa", - prompt_function=prompt.med_qa, + prompt_function=med_qa_prompt, hf_repo="bigbio/med_qa", hf_subset="med_qa_en_source", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/med_dialog.py b/src/lighteval/tasks/tasks/med_dialog.py index d38d0dba0..0bc61640c 100644 --- a/src/lighteval/tasks/tasks/med_dialog.py +++ b/src/lighteval/tasks/tasks/med_dialog.py @@ -17,14 +17,23 @@ paper: """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def med_dialog_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"###\nArticle:{line['src']}\n\nSummarize the above article in 1 sentence.\n", + gold_index=0, + choices=[line["tgt"]], + ) med_dialog_healthcaremagic = LightevalTaskConfig( name="med_dialog:healthcaremagic", - prompt_function=prompt.med_dialog, + prompt_function=med_dialog_prompt, hf_repo="lighteval/med_dialog", hf_subset="healthcaremagic", hf_avail_splits=["train", "test", "validation"], @@ -42,7 +51,7 @@ med_dialog_icliniq = LightevalTaskConfig( name="med_dialog:icliniq", - prompt_function=prompt.med_dialog, + prompt_function=med_dialog_prompt, hf_repo="lighteval/med_dialog", hf_subset="icliniq", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/mgsm.py b/src/lighteval/tasks/tasks/mgsm.py index e0f7326d0..166235d80 100644 --- a/src/lighteval/tasks/tasks/mgsm.py +++ b/src/lighteval/tasks/tasks/mgsm.py @@ -21,14 +21,90 @@ https://arxiv.org/abs/2210.03057 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def mgsm_prompt(line, question_key, answer_key, task_name: str = None): + if line["answer"] is not None: + query = f"{line['question']}\n{answer_key}" + gold = f" {line['answer'][len(answer_key) + 1 :]}" + else: + query = f"{question_key} {line['question']}\n{answer_key}" + gold = f" {str(line['answer_number'])}" + return Doc(task_name=task_name, query=query, choices=[gold], gold_index=0) + + +def mgsm_en_prompt(line, task_name: str = None): + question_key = "Question:" + answer_key = "Step-by-Step Answer:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_es_prompt(line, task_name: str = None): + question_key = "Pregunta:" + answer_key = "Respuesta paso a paso:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def 
mgsm_fr_prompt(line, task_name: str = None): + question_key = "Question:" + answer_key = "Réponse étape par étape :" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_de_prompt(line, task_name: str = None): + question_key = "Frage:" + answer_key = "Schritt-für-Schritt-Antwort:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_ru_prompt(line, task_name: str = None): + question_key = "Задача:" + answer_key = "Пошаговоерешение:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_zh_prompt(line, task_name: str = None): + question_key = "问题:" + answer_key = "逐步解答:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_ja_prompt(line, task_name: str = None): + question_key = "問題:" + answer_key = "ステップごとの答え:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_th_prompt(line, task_name: str = None): + question_key = "โจทย์:" + answer_key = "คำตอบทีละขั้นตอน:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_sw_prompt(line, task_name: str = None): + question_key = "Swali:" + answer_key = "Jibu la Hatua kwa Hatua:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_bn_prompt(line, task_name: str = None): + question_key = "প্রশ্ন:" + answer_key = "ধাপে ধাপে উত্তর:" + return mgsm_prompt(line, question_key, answer_key, task_name) + + +def mgsm_te_prompt(line, task_name: str = None): + question_key = "ప్రశ్న:" + answer_key = "దశలవారీగా సమాధానం:" + return mgsm_prompt(line, question_key, answer_key, task_name) mgsm_en = LightevalTaskConfig( name="mgsm:en", - prompt_function=prompt.mgsm_en, + prompt_function=mgsm_en_prompt, hf_repo="juletxara/mgsm", hf_subset="en", hf_avail_splits=["train", "test"], @@ -43,7 +119,7 @@ mgsm_es = LightevalTaskConfig( name="mgsm:es", - prompt_function=prompt.mgsm_es, + prompt_function=mgsm_es_prompt, hf_repo="juletxara/mgsm", hf_subset="es", hf_avail_splits=["train", "test"], @@ -58,7 +134,7 @@ mgsm_fr = LightevalTaskConfig( name="mgsm:fr", - prompt_function=prompt.mgsm_fr, + prompt_function=mgsm_fr_prompt, hf_repo="juletxara/mgsm", hf_subset="fr", hf_avail_splits=["train", "test"], @@ -73,7 +149,7 @@ mgsm_de = LightevalTaskConfig( name="mgsm:de", - prompt_function=prompt.mgsm_de, + prompt_function=mgsm_de_prompt, hf_repo="juletxara/mgsm", hf_subset="de", hf_avail_splits=["train", "test"], @@ -88,7 +164,7 @@ mgsm_ru = LightevalTaskConfig( name="mgsm:ru", - prompt_function=prompt.mgsm_ru, + prompt_function=mgsm_ru_prompt, hf_repo="juletxara/mgsm", hf_subset="ru", hf_avail_splits=["train", "test"], @@ -103,7 +179,7 @@ mgsm_zh = LightevalTaskConfig( name="mgsm:zh", - prompt_function=prompt.mgsm_zh, + prompt_function=mgsm_zh_prompt, hf_repo="juletxara/mgsm", hf_subset="zh", hf_avail_splits=["train", "test"], @@ -118,7 +194,7 @@ mgsm_ja = LightevalTaskConfig( name="mgsm:ja", - prompt_function=prompt.mgsm_ja, + prompt_function=mgsm_ja_prompt, hf_repo="juletxara/mgsm", hf_subset="ja", hf_avail_splits=["train", "test"], @@ -133,7 +209,7 @@ mgsm_th = LightevalTaskConfig( name="mgsm:th", - prompt_function=prompt.mgsm_th, + prompt_function=mgsm_th_prompt, hf_repo="juletxara/mgsm", hf_subset="th", hf_avail_splits=["train", "test"], @@ -148,7 +224,7 @@ mgsm_sw = LightevalTaskConfig( name="mgsm:sw", - prompt_function=prompt.mgsm_sw, + prompt_function=mgsm_sw_prompt, hf_repo="juletxara/mgsm", hf_subset="sw", hf_avail_splits=["train", "test"], @@ -163,7 +239,7 @@ mgsm_bn = LightevalTaskConfig( 
name="mgsm:bn", - prompt_function=prompt.mgsm_bn, + prompt_function=mgsm_bn_prompt, hf_repo="juletxara/mgsm", hf_subset="bn", hf_avail_splits=["train", "test"], @@ -178,7 +254,7 @@ mgsm_te = LightevalTaskConfig( name="mgsm:te", - prompt_function=prompt.mgsm_te, + prompt_function=mgsm_te_prompt, hf_repo="juletxara/mgsm", hf_subset="te", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/mmlu.py b/src/lighteval/tasks/tasks/mmlu.py index 52661ace6..00936b3c6 100644 --- a/src/lighteval/tasks/tasks/mmlu.py +++ b/src/lighteval/tasks/tasks/mmlu.py @@ -18,14 +18,34 @@ https://arxiv.org/abs/2009.03300 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def mmlu_prompt(line, task_name: str = None): + subject = line["subject"] + query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}" + query += "".join([f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])]) + query += "\nAnswer:" + + gold_ix = ascii_uppercase.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + + return Doc( + task_name=task_name, + query=query, + choices=[" A", " B", " C", " D"], + gold_index=gold_ix, + fewshot_sorting_class=line["choices"][gold_ix], + instruction=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n", + ) mmlu_abstract_algebra = LightevalTaskConfig( name="mmlu:abstract_algebra", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="abstract_algebra", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -40,7 +60,7 @@ mmlu_anatomy = LightevalTaskConfig( name="mmlu:anatomy", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="anatomy", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -55,7 +75,7 @@ mmlu_astronomy = LightevalTaskConfig( name="mmlu:astronomy", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="astronomy", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -70,7 +90,7 @@ mmlu_business_ethics = LightevalTaskConfig( name="mmlu:business_ethics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="business_ethics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -85,7 +105,7 @@ mmlu_clinical_knowledge = LightevalTaskConfig( name="mmlu:clinical_knowledge", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="clinical_knowledge", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -100,7 +120,7 @@ mmlu_college_biology = LightevalTaskConfig( name="mmlu:college_biology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_biology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -115,7 +135,7 @@ mmlu_college_chemistry = LightevalTaskConfig( name="mmlu:college_chemistry", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_chemistry", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -130,7 +150,7 @@ 
mmlu_college_computer_science = LightevalTaskConfig( name="mmlu:college_computer_science", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_computer_science", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -145,7 +165,7 @@ mmlu_college_mathematics = LightevalTaskConfig( name="mmlu:college_mathematics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_mathematics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -160,7 +180,7 @@ mmlu_college_medicine = LightevalTaskConfig( name="mmlu:college_medicine", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_medicine", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -175,7 +195,7 @@ mmlu_college_physics = LightevalTaskConfig( name="mmlu:college_physics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="college_physics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -190,7 +210,7 @@ mmlu_computer_security = LightevalTaskConfig( name="mmlu:computer_security", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="computer_security", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -205,7 +225,7 @@ mmlu_conceptual_physics = LightevalTaskConfig( name="mmlu:conceptual_physics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="conceptual_physics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -220,7 +240,7 @@ mmlu_econometrics = LightevalTaskConfig( name="mmlu:econometrics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="econometrics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -235,7 +255,7 @@ mmlu_electrical_engineering = LightevalTaskConfig( name="mmlu:electrical_engineering", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="electrical_engineering", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -250,7 +270,7 @@ mmlu_elementary_mathematics = LightevalTaskConfig( name="mmlu:elementary_mathematics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="elementary_mathematics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -265,7 +285,7 @@ mmlu_formal_logic = LightevalTaskConfig( name="mmlu:formal_logic", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="formal_logic", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -280,7 +300,7 @@ mmlu_global_facts = LightevalTaskConfig( name="mmlu:global_facts", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="global_facts", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -295,7 +315,7 @@ mmlu_high_school_biology = LightevalTaskConfig( name="mmlu:high_school_biology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_biology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -310,7 +330,7 @@ mmlu_high_school_chemistry = LightevalTaskConfig( name="mmlu:high_school_chemistry", - 
prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_chemistry", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -325,7 +345,7 @@ mmlu_high_school_computer_science = LightevalTaskConfig( name="mmlu:high_school_computer_science", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_computer_science", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -340,7 +360,7 @@ mmlu_high_school_european_history = LightevalTaskConfig( name="mmlu:high_school_european_history", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_european_history", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -355,7 +375,7 @@ mmlu_high_school_geography = LightevalTaskConfig( name="mmlu:high_school_geography", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_geography", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -370,7 +390,7 @@ mmlu_high_school_government_and_politics = LightevalTaskConfig( name="mmlu:high_school_government_and_politics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_government_and_politics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -385,7 +405,7 @@ mmlu_high_school_macroeconomics = LightevalTaskConfig( name="mmlu:high_school_macroeconomics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_macroeconomics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -400,7 +420,7 @@ mmlu_high_school_mathematics = LightevalTaskConfig( name="mmlu:high_school_mathematics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_mathematics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -415,7 +435,7 @@ mmlu_high_school_microeconomics = LightevalTaskConfig( name="mmlu:high_school_microeconomics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_microeconomics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -430,7 +450,7 @@ mmlu_high_school_physics = LightevalTaskConfig( name="mmlu:high_school_physics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_physics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -445,7 +465,7 @@ mmlu_high_school_psychology = LightevalTaskConfig( name="mmlu:high_school_psychology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_psychology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -460,7 +480,7 @@ mmlu_high_school_statistics = LightevalTaskConfig( name="mmlu:high_school_statistics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_statistics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -475,7 +495,7 @@ mmlu_high_school_us_history = LightevalTaskConfig( name="mmlu:high_school_us_history", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_us_history", 
hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -490,7 +510,7 @@ mmlu_high_school_world_history = LightevalTaskConfig( name="mmlu:high_school_world_history", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="high_school_world_history", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -505,7 +525,7 @@ mmlu_human_aging = LightevalTaskConfig( name="mmlu:human_aging", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="human_aging", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -520,7 +540,7 @@ mmlu_human_sexuality = LightevalTaskConfig( name="mmlu:human_sexuality", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="human_sexuality", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -535,7 +555,7 @@ mmlu_international_law = LightevalTaskConfig( name="mmlu:international_law", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="international_law", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -550,7 +570,7 @@ mmlu_jurisprudence = LightevalTaskConfig( name="mmlu:jurisprudence", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="jurisprudence", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -565,7 +585,7 @@ mmlu_logical_fallacies = LightevalTaskConfig( name="mmlu:logical_fallacies", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="logical_fallacies", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -580,7 +600,7 @@ mmlu_machine_learning = LightevalTaskConfig( name="mmlu:machine_learning", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="machine_learning", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -595,7 +615,7 @@ mmlu_management = LightevalTaskConfig( name="mmlu:management", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="management", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -610,7 +630,7 @@ mmlu_marketing = LightevalTaskConfig( name="mmlu:marketing", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="marketing", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -625,7 +645,7 @@ mmlu_medical_genetics = LightevalTaskConfig( name="mmlu:medical_genetics", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="medical_genetics", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -640,7 +660,7 @@ mmlu_miscellaneous = LightevalTaskConfig( name="mmlu:miscellaneous", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="miscellaneous", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -655,7 +675,7 @@ mmlu_moral_disputes = LightevalTaskConfig( name="mmlu:moral_disputes", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="moral_disputes", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -670,7 +690,7 @@ mmlu_moral_scenarios = LightevalTaskConfig( name="mmlu:moral_scenarios", - prompt_function=prompt.mmlu_helm, 
+ prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="moral_scenarios", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -685,7 +705,7 @@ mmlu_nutrition = LightevalTaskConfig( name="mmlu:nutrition", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="nutrition", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -700,7 +720,7 @@ mmlu_philosophy = LightevalTaskConfig( name="mmlu:philosophy", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="philosophy", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -715,7 +735,7 @@ mmlu_prehistory = LightevalTaskConfig( name="mmlu:prehistory", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="prehistory", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -730,7 +750,7 @@ mmlu_professional_accounting = LightevalTaskConfig( name="mmlu:professional_accounting", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="professional_accounting", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -745,7 +765,7 @@ mmlu_professional_law = LightevalTaskConfig( name="mmlu:professional_law", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="professional_law", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -760,7 +780,7 @@ mmlu_professional_medicine = LightevalTaskConfig( name="mmlu:professional_medicine", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="professional_medicine", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -775,7 +795,7 @@ mmlu_professional_psychology = LightevalTaskConfig( name="mmlu:professional_psychology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="professional_psychology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -790,7 +810,7 @@ mmlu_public_relations = LightevalTaskConfig( name="mmlu:public_relations", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="public_relations", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -805,7 +825,7 @@ mmlu_security_studies = LightevalTaskConfig( name="mmlu:security_studies", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="security_studies", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -820,7 +840,7 @@ mmlu_sociology = LightevalTaskConfig( name="mmlu:sociology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="sociology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -835,7 +855,7 @@ mmlu_us_foreign_policy = LightevalTaskConfig( name="mmlu:us_foreign_policy", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="us_foreign_policy", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -850,7 +870,7 @@ mmlu_virology = LightevalTaskConfig( name="mmlu:virology", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="virology", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], @@ -865,7 +885,7 @@ 
mmlu_world_religions = LightevalTaskConfig( name="mmlu:world_religions", - prompt_function=prompt.mmlu_helm, + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset="world_religions", hf_avail_splits=["auxiliary_train", "test", "validation", "dev"], diff --git a/src/lighteval/tasks/tasks/mmlu_redux.py b/src/lighteval/tasks/tasks/mmlu_redux.py index 066ea07c4..06cbfd896 100644 --- a/src/lighteval/tasks/tasks/mmlu_redux.py +++ b/src/lighteval/tasks/tasks/mmlu_redux.py @@ -18,9 +18,33 @@ https://arxiv.org/abs/2406.04127 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase as LETTERS + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def __mmlu_redux_2_prompt(line, topic, task_name: str = None): + query = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + query += line["question"] + "\n" + query += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTERS, line["choices"])]) + query += "Answer: " + gold_ix = line["answer"] if isinstance(line["answer"], int) else int(line["answer"]) + return Doc( + task_name=task_name, + query=query, + choices=list(LETTERS)[: len(line["choices"])], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + +def mmlu_redux_2_prompt(topic): + def _fn(line, task_name: str = None): + return __mmlu_redux_2_prompt(line, topic, task_name) + + return _fn _MMLU_REDUX_2_SUBSETS = [ @@ -87,7 +111,7 @@ TASKS_TABLE = [ LightevalTaskConfig( name=f"mmlu_redux_2:{subset}", - prompt_function=lambda line, task_name=None, s=subset: prompt.mmlu_redux_2(line, s, task_name), + prompt_function=mmlu_redux_2_prompt(subset), hf_repo="edinburgh-dawg/mmlu-redux-2.0", hf_subset=subset, hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/mmmu_pro.py b/src/lighteval/tasks/tasks/mmmu_pro.py index f95889408..24f9fa5c5 100644 --- a/src/lighteval/tasks/tasks/mmmu_pro.py +++ b/src/lighteval/tasks/tasks/mmmu_pro.py @@ -17,14 +17,82 @@ https://arxiv.org/abs/2409.02813 """ -import lighteval.tasks.default_prompts as prompt +import ast +import re +import string + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def mmmu_pro_prompt(line, task_name: str = None): + question = line["question"] + choices_string = line["options"] + answer = line["answer"] + + instructions = "Answer with the option letter from the given choices directly." + + choices = ast.literal_eval(str(choices_string)) + choices_letters = [chr(ord("A") + i) for i in range(len(choices))] + choices = [f"{letter}. 
{choice}" for letter, choice in zip(choices_letters, choices)] + + formatted_choices = "\n".join(choices) + prompt_text = f"\n{question}\n{formatted_choices}" + + image_order = [] + for num in re.findall(r"", prompt_text): + num = int(num) + if num not in image_order: + image_order.append(num) + images = [line[f"image_{i}"].convert("RGB") for i in image_order] + + gold_index = string.ascii_uppercase.index(answer) + + prompt_text = re.sub(r"", "[image \\1]", prompt_text) + choices = [re.sub(r"", "[image \\1]", choice) for choice in choices] + + return Doc( + task_name=task_name, + query=prompt_text, + choices=choices, + gold_index=gold_index, + images=images, + specific={"id": line["id"]}, + instruction=instructions, + ) + + +def mmmu_pro_vision_prompt(line, task_name: str = None): + instruction = ( + "Answer with the option letter from the given choices directly." + " The last line of your response should be of the following format: " + "'Answer: $LETTER' (without quotes) where LETTER is one of options." + ) + + choices_string = line["options"] + choices = ast.literal_eval(str(choices_string)) + choices_letters = [chr(ord("A") + i) for i in range(len(choices))] + choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)] + + answer = line["answer"] + gold_index = string.ascii_uppercase.index(answer) + + images = [line["image"]] + + return Doc( + task_name=task_name, + query=instruction, + choices=choices, + gold_index=gold_index, + images=images, + instruction=instruction, + ) mmmu_pro_standard_4_options = LightevalTaskConfig( name="mmmu_pro:standard-4", - prompt_function=prompt.mmmu_pro, + prompt_function=mmmu_pro_prompt, hf_repo="MMMU/MMMU_pro", hf_subset="standard (4 options)", hf_avail_splits=["test"], @@ -40,7 +108,7 @@ mmmu_pro_standard_10_options = LightevalTaskConfig( name="mmmu_pro:standard-10", - prompt_function=prompt.mmmu_pro, + prompt_function=mmmu_pro_prompt, hf_repo="MMMU/MMMU_pro", hf_subset="standard (10 options)", hf_avail_splits=["test"], @@ -56,7 +124,7 @@ mmmu_pro_vision = LightevalTaskConfig( name="mmmu_pro:vision", - prompt_function=prompt.mmmu_pro_vision, + prompt_function=mmmu_pro_vision_prompt, hf_repo="MMMU/MMMU_pro", hf_subset="vision", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/musr.py b/src/lighteval/tasks/tasks/musr.py index eb52487a8..fa2671e2d 100644 --- a/src/lighteval/tasks/tasks/musr.py +++ b/src/lighteval/tasks/tasks/musr.py @@ -20,14 +20,28 @@ https://arxiv.org/abs/2310.16049 """ -import lighteval.tasks.default_prompts as prompt +import ast + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def musr_prompt(line, task_name: str = None): + choices = ast.literal_eval(line["choices"]) + + query = line["narrative"] + "\n\n" + query += line["question"] + "\n\n" + for i, choice in enumerate(choices): + query += f"{i + 1} - {choice}\n" + query += "Answer:" + + return Doc(task_name=task_name, query=query, choices=choices, gold_index=line["answer_index"]) musr_murder_mysteries = LightevalTaskConfig( name="musr:murder_mysteries", - prompt_function=prompt.musr, + prompt_function=musr_prompt, hf_repo="TAUR-Lab/MuSR", hf_subset="default", hf_avail_splits=["murder_mysteries"], @@ -43,7 +57,7 @@ musr_object_placements = LightevalTaskConfig( name="musr:object_placements", - prompt_function=prompt.musr, + prompt_function=musr_prompt, hf_repo="TAUR-Lab/MuSR", hf_subset="default", hf_avail_splits=["object_placements"], 
@@ -59,7 +73,7 @@ musr_team_allocation = LightevalTaskConfig( name="musr:team_allocation", - prompt_function=prompt.musr, + prompt_function=musr_prompt, hf_repo="TAUR-Lab/MuSR", hf_subset="default", hf_avail_splits=["team_allocation"], diff --git a/src/lighteval/tasks/tasks/narrativeqa.py b/src/lighteval/tasks/tasks/narrativeqa.py index 9981ac32c..3d3291c25 100644 --- a/src/lighteval/tasks/tasks/narrativeqa.py +++ b/src/lighteval/tasks/tasks/narrativeqa.py @@ -20,14 +20,26 @@ https://aclanthology.org/Q18-1023/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +narrativeqa_instruction = "Answer the question based on the passage.\n" + + +def narrativeqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Passage: {line['passage']}\nQuestion: {line['question']}\nAnswer:", + gold_index=list(range(len(line["references"]))), + choices=[[str(a) for a in line["references"]]], + ) narrativeqa = LightevalTaskConfig( name="narrativeqa", - prompt_function=prompt.narrativeqa, + prompt_function=narrativeqa_prompt, hf_repo="lighteval/narrative_qa_helm", hf_subset="default", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/numeracy.py b/src/lighteval/tasks/tasks/numeracy.py index a640d7000..8136faff7 100644 --- a/src/lighteval/tasks/tasks/numeracy.py +++ b/src/lighteval/tasks/tasks/numeracy.py @@ -17,14 +17,25 @@ paper: """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +numeracy_vars_names = ["x", "y", "z"] + + +def numeracy_prompt(line, task_name: str = None): + vars = "" + for ix, value in enumerate(line["vars"]): + vars += f"{numeracy_vars_names[ix]} {value}, " + vars += numeracy_vars_names[ix + 1] + return Doc(task_name=task_name, query=f"{line['equation']}, {vars}", gold_index=0, choices=[str(line["output"])]) numeracy_linear_example = LightevalTaskConfig( name="numeracy:linear_example", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="linear_example", hf_avail_splits=["train", "test"], @@ -39,7 +50,7 @@ numeracy_linear_standard = LightevalTaskConfig( name="numeracy:linear_standard", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="linear_standard", hf_avail_splits=["train", "test"], @@ -54,7 +65,7 @@ numeracy_parabola_example = LightevalTaskConfig( name="numeracy:parabola_example", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="parabola_example", hf_avail_splits=["train", "test"], @@ -69,7 +80,7 @@ numeracy_parabola_standard = LightevalTaskConfig( name="numeracy:parabola_standard", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="parabola_standard", hf_avail_splits=["train", "test"], @@ -84,7 +95,7 @@ numeracy_paraboloid_example = LightevalTaskConfig( name="numeracy:paraboloid_example", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="paraboloid_example", hf_avail_splits=["train", "test"], @@ -99,7 +110,7 @@ numeracy_paraboloid_standard = LightevalTaskConfig( name="numeracy:paraboloid_standard", - 
prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="paraboloid_standard", hf_avail_splits=["train", "test"], @@ -114,7 +125,7 @@ numeracy_plane_example = LightevalTaskConfig( name="numeracy:plane_example", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="plane_example", hf_avail_splits=["train", "test"], @@ -129,7 +140,7 @@ numeracy_plane_standard = LightevalTaskConfig( name="numeracy:plane_standard", - prompt_function=prompt.numeracy, + prompt_function=numeracy_prompt, hf_repo="lighteval/numeracy", hf_subset="plane_standard", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/openbookqa.py b/src/lighteval/tasks/tasks/openbookqa.py index 059309ce5..62c378426 100644 --- a/src/lighteval/tasks/tasks/openbookqa.py +++ b/src/lighteval/tasks/tasks/openbookqa.py @@ -22,14 +22,32 @@ https://arxiv.org/abs/1809.02789 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def openbookqa_prompt(line, task_name: str = None): + query = "The following are multiple choice questions (with answers) about common sense.\n" + query += f"Question: {line['question_stem']}\n" + query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"]["text"])]) + query += "Answer: " + + gold_ix = ["A", "B", "C", "D", "E"].index(line["answerKey"].strip()) + return Doc( + task_name=task_name, + query=query, + choices=list(ascii_uppercase[: len(line["choices"]["text"])]), + gold_index=gold_ix, + instruction="The following are multiple choice questions (with answers) about common sense.\n", + ) openbookqa = LightevalTaskConfig( name="openbookqa", - prompt_function=prompt.openbookqa_helm, + prompt_function=openbookqa_prompt, hf_repo="allenai/openbookqa", hf_subset="main", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/piqa.py b/src/lighteval/tasks/tasks/piqa.py index c355a5ccc..4aa20719c 100644 --- a/src/lighteval/tasks/tasks/piqa.py +++ b/src/lighteval/tasks/tasks/piqa.py @@ -19,14 +19,34 @@ https://arxiv.org/abs/1911.11641 """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def piqa_prompt(line, task_name: str = None): + letters = list(ascii_uppercase)[:2] + query = "The following are multiple choice questions (with answers) about common sense.\n" + query += f"Question: {line['goal']}\n" + query += "".join([f"{key}. 
{choice}\n" for key, choice in zip(letters, [line["sol1"], line["sol2"]])]) + query += "Answer: " + + gold_ix = int(line["label"]) + is_few_shots = line.get("__few_shots", False) + return Doc( + task_name=task_name, + query=query, + choices=letters if not is_few_shots else [line["sol1"], line["sol2"]], + gold_index=gold_ix, + instruction="The following are multiple choice questions (with answers) about common sense.\n", + ) piqa = LightevalTaskConfig( name="piqa", - prompt_function=prompt.piqa_helm, + prompt_function=piqa_prompt, hf_repo="ybisk/piqa", hf_subset="plain_text", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/prost.py b/src/lighteval/tasks/tasks/prost.py index e7eb868ab..97a3c7178 100644 --- a/src/lighteval/tasks/tasks/prost.py +++ b/src/lighteval/tasks/tasks/prost.py @@ -22,14 +22,23 @@ https://arxiv.org/abs/2106.03634 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def prost_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[f" {c}" for c in line["choices"]], + gold_index=int(line["label"]) if isinstance(line["label"], int) else int(line["label"]), + ) prost = LightevalTaskConfig( name="prost", - prompt_function=prompt.prost, + prompt_function=prost_prompt, hf_repo="lighteval/prost", hf_subset="default", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/pubmedqa.py b/src/lighteval/tasks/tasks/pubmedqa.py index d51c056a7..326789bb6 100644 --- a/src/lighteval/tasks/tasks/pubmedqa.py +++ b/src/lighteval/tasks/tasks/pubmedqa.py @@ -18,14 +18,23 @@ https://pubmedqa.github.io/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def pubmed_qa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['QUESTION']}\n{line['CONTEXTS']}\nAnswer: ", + choices=[line["final_decision"]], + gold_index=0, + ) pubmedqa = LightevalTaskConfig( name="pubmedqa", - prompt_function=prompt.pubmed_qa_helm, + prompt_function=pubmed_qa_prompt, hf_repo="pubmed_qa", hf_subset="pqa_labeled", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/qa4mre.py b/src/lighteval/tasks/tasks/qa4mre.py index 89fa46c21..c0a4cab3e 100644 --- a/src/lighteval/tasks/tasks/qa4mre.py +++ b/src/lighteval/tasks/tasks/qa4mre.py @@ -22,14 +22,23 @@ https://link.springer.com/chapter/10.1007/978-3-642-40802-1_29 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def qa4mre_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['question']}", + choices=[f" {c}" for c in line["choices"]], + gold_index=line["label"], + ) qa4mre_2011 = LightevalTaskConfig( name="qa4mre:2011", - prompt_function=prompt.qa4mre, + prompt_function=qa4mre_prompt, hf_repo="qa4mre", hf_subset="2011.main.EN", hf_avail_splits=["train"], @@ -47,7 +56,7 @@ qa4mre_2012 = LightevalTaskConfig( name="qa4mre:2012", - prompt_function=prompt.qa4mre, + prompt_function=qa4mre_prompt, hf_repo="qa4mre", hf_subset="2012.main.EN", hf_avail_splits=["train"], @@ -65,7 +74,7 @@ qa4mre_2013 = LightevalTaskConfig( 
name="qa4mre:2013", - prompt_function=prompt.qa4mre, + prompt_function=qa4mre_prompt, hf_repo="qa4mre", hf_subset="2013.main.EN", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/qasper.py b/src/lighteval/tasks/tasks/qasper.py index 8b190fed0..59bd5ee90 100644 --- a/src/lighteval/tasks/tasks/qasper.py +++ b/src/lighteval/tasks/tasks/qasper.py @@ -23,14 +23,23 @@ https://arxiv.org/abs/2105.03011 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def qasper_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Title: {line['title']}\n\nPassage: {line['passage']}\n\n Question: {line['question']}\nAnswer: ", + gold_index=0, + choices=[line["gold"]], + ) qasper = LightevalTaskConfig( name="qasper", - prompt_function=prompt.qasper, + prompt_function=qasper_prompt, hf_repo="allenai/qasper", hf_subset="qasper", hf_avail_splits=["train", "validation"], diff --git a/src/lighteval/tasks/tasks/quac.py b/src/lighteval/tasks/tasks/quac.py index 37a7624af..67268831d 100644 --- a/src/lighteval/tasks/tasks/quac.py +++ b/src/lighteval/tasks/tasks/quac.py @@ -18,14 +18,24 @@ https://aclanthology.org/D18-1241/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def quac_prompt(line, task_name: str = None): + references = [ref for ref in line["references"] if ref is not None and ref != ""] + return Doc( + task_name=task_name, + query=f"{line['prompt']}\nAnswer:", + choices=references, + gold_index=list(range(len(references))), + ) quac = LightevalTaskConfig( name="quac", - prompt_function=prompt.quac, + prompt_function=quac_prompt, hf_repo="lighteval/quac_helm", hf_subset="default", hf_avail_splits=["train", "validation"], diff --git a/src/lighteval/tasks/tasks/race_high.py b/src/lighteval/tasks/tasks/race_high.py index 7470130c6..37eddaf2d 100644 --- a/src/lighteval/tasks/tasks/race_high.py +++ b/src/lighteval/tasks/tasks/race_high.py @@ -22,14 +22,35 @@ https://aclanthology.org/D17-1082/ """ -import lighteval.tasks.default_prompts as prompt +import ast + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def race_prompt(line, task_name: str = None): + line["problems"] = ast.literal_eval(line["problems"]) + text = f"Article: {line['article']}\n\n" + for problem in line["problems"][:-1]: + index = ["A", "B", "C", "D", "E"].index(problem["answer"]) + if problem["question"][-6:] == " _ .": + text += f"{problem['question'][-5:]}{problem['options'][index]}\n" + else: + text += f"Question: {problem['question']}\n" + text += f"Answer: {problem['options'][index]}\n" + text += line["problems"][-1]["question"] + return Doc( + task_name=task_name, + query=text, + choices=[f" {o}" for o in line["problems"][-1]["options"]], + gold_index=["A", "B", "C", "D", "E"].index(line["problems"][-1]["answer"]), + ) race_high = LightevalTaskConfig( name="race:high", - prompt_function=prompt.race, + prompt_function=race_prompt, hf_repo="EleutherAI/race", hf_subset="high", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/raft.py b/src/lighteval/tasks/tasks/raft.py index 56eea2c54..73c03c0e2 100644 --- a/src/lighteval/tasks/tasks/raft.py +++ 
b/src/lighteval/tasks/tasks/raft.py @@ -19,14 +19,87 @@ https://datasets-benchmarks-proceedings.neurips.cc/paper/2021/hash/ca46c1b9512a7a8315fa3c5a946e8265-Abstract-round2.html """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def raft_prompt(line, query_keys, instruction, task_name: str = None): + query = instruction + query += "\n".join([f"{key}: {line[key]}" for key in query_keys]) + query += "\nLabel:" + return Doc(task_name=task_name, query=query, gold_index=0, choices=[str(line["Label"])], instruction=instruction) + + +def raft_ade_corpus_v2_prompt(line, task_name: str = None): + instruction = "Label the sentence based on whether it is related to an adverse drug effect (ADE). Details are described below:\nDrugs: Names of drugs and chemicals that include brand names, trivial names, abbreviations and systematic names were annotated. Mentions of drugs or chemicals should strictly be in a therapeutic context. This category does not include the names of metabolites, reaction byproducts, or hospital chemicals (e.g. surgical equipment disinfectants).\nAdverse effect: Mentions of adverse effects include signs, symptoms, diseases, disorders, acquired abnormalities, deficiencies, organ damage or death that strictly occur as a consequence of drug intake.\nPossible labels:\n1. ADE-related\n2. not ADE-related" + query_keys = ["Sentence"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_banking_77_prompt(line, task_name: str = None): + instruction = "The following is a banking customer service query. Classify the query into one of the 77 categories available.\nPossible labels:\n1. Refund_not_showing_up\n2. activate_my_card\n3. age_limit\n4. apple_pay_or_google_pay\n5. atm_support\n6. automatic_top_up\n7. balance_not_updated_after_bank_transfer\n8. balance_not_updated_after_cheque_or_cash_deposit\n9. beneficiary_not_allowed\n10. cancel_transfer\n11. card_about_to_expire\n12. card_acceptance\n13. card_arrival\n14. card_delivery_estimate\n15. card_linking\n16. card_not_working\n17. card_payment_fee_charged\n18. card_payment_not_recognised\n19. card_payment_wrong_exchange_rate\n20. card_swallowed\n21. cash_withdrawal_charge\n22. cash_withdrawal_not_recognised\n23. change_pin\n24. compromised_card\n25. contactless_not_working\n26. country_support\n27. declined_card_payment\n28. declined_cash_withdrawal\n29. declined_transfer\n30. direct_debit_payment_not_recognised\n31. disposable_card_limits\n32. edit_personal_details\n33. exchange_charge\n34. exchange_rate\n35. exchange_via_app\n36. extra_charge_on_statement\n37. failed_transfer\n38. fiat_currency_support\n39. get_disposable_virtual_card\n40. get_physical_card\n41. getting_spare_card\n42. getting_virtual_card\n43. lost_or_stolen_card\n44. lost_or_stolen_phone\n45. order_physical_card\n46. passcode_forgotten\n47. pending_card_payment\n48. pending_cash_withdrawal\n49. pending_top_up\n50. pending_transfer\n51. pin_blocked\n52. receiving_money\n53. request_refund\n54. reverted_card_payment?\n55. supported_cards_and_currencies\n56. terminate_account\n57. top_up_by_bank_transfer_charge\n58. top_up_by_card_charge\n59. top_up_by_cash_or_cheque\n60. top_up_failed\n61. top_up_limits\n62. top_up_reverted\n63. topping_up_by_card\n64. transaction_charged_twice\n65. transfer_fee_charged\n66. transfer_into_account\n67. transfer_not_received_by_recipient\n68. transfer_timing\n69. 
unable_to_verify_identity\n70. verify_my_identity\n71. verify_source_of_funds\n72. verify_top_up\n73. virtual_card_not_working\n74. visa_or_mastercard\n75. why_verify_identity\n76. wrong_amount_of_cash_received\n77. wrong_exchange_rate_for_cash_withdrawal" + query_keys = ["Query"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_neurips_impact_statement_risks_prompt(line, task_name: str = None): + instruction = "Label the impact statement based on whether it mentions a harmful application of the research done in the paper. Make sure the statement is sufficient to conclude there are harmful applications of the research being done, not a past risk that this research is solving.\nPossible labels:\n1. doesn't mention a harmful application\n2. mentions a harmful application" + query_keys = ["Impact statement", "Paper title"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_one_stop_english_prompt(line, task_name: str = None): + instruction = "The following is an article sourced from The Guardian newspaper, and rewritten by teachers to suit three levels of adult English as Second Language (ESL) learners: elementary, intermediate, and advanced. Predict the level of the article.\nPossible labels:\n1. advanced\n2. elementary\n3. intermediate" + query_keys = ["Article"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_overruling_prompt(line, task_name: str = None): + instruction = "In law, an overruling sentence is a statement that nullifies a previous case decision as a precedent, by a constitutionally valid statute or a decision by the same or higher ranking court which establishes a different rule on the point of law involved. Label the sentence based on whether it is overruling or not.\nPossible labels:\n1. not overruling\n2. overruling" + query_keys = ["Sentence"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_semiconductor_org_types_prompt(line, task_name: str = None): + instruction = 'The dataset is a list of institutions that have contributed papers to semiconductor conferences in the last 25 years, as catalogued by IEEE and sampled randomly. The goal is to classify the institutions into one of three categories: "university", "company" or "research institute".\nPossible labels:\n1. company\n2. research institute\n3. university' + query_keys = ["Organization name", "Paper title"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_systematic_review_inclusion_prompt(line, task_name: str = None): + instruction = "Identify whether this paper should be included in a meta-review which includes the findings of systematic reviews on interventions designed to promote charitable donations.\nIncluded reviews should describe monetary charitable donations, assess any population of participants in any context, and be peer reviewed and written in English.\nThey should not report new data, be non-systematic reviews, consider cause-related marketing or other kinds of prosocial behaviour.\nPossible labels:\n1. included\n2. not included" + query_keys = ["Title", "Abstract", "Journal"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_tai_safety_research_prompt(line, task_name: str = None): + instruction = 'Transformative AI (TAI) is defined as AI that precipitates a transition comparable to (or more significant than) the agricultural or industrial revolution. Label a paper as "TAI safety research" if:\n1. 
The contents of the paper are directly motivated by, and substantively inform, the challenge of ensuring good outcomes for TAI,\n2. There is substantive content on AI safety, not just AI capabilities,\n3. The intended audience is the community of researchers,\n4. It meets a subjective threshold of seriousness/quality,\n5. Peer review is not required.\nPossible labels:\n1. TAI safety research\n2. not TAI safety research' + query_keys = ["Title", "Abstract Note", "Publication Title", "Item Type", "Publication Year"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_terms_of_service_prompt(line, task_name: str = None): + instruction = "Label the sentence from a Terms of Service based on whether it is potentially unfair. If it seems clearly unfair, mark it as potentially unfair.\nAccording to art. 3 of the Directive 93/13 on Unfair Terms in Consumer Contracts, a contractual term is unfair if: 1) it has not been individually negotiated; and 2) contrary to the requirement of good faith, it causes a significant imbalance in the parties rights and obligations, to the detriment of the consumer.\nDetails on types of potentially unfair clauses are found below:\nThe jurisdiction clause stipulates what courts will have the competence to adjudicate disputes under the contract. Jurisdiction clauses giving consumers a right to bring disputes in their place of residence were marked as clearly fair, whereas clauses stating that any judicial proceeding takes a residence away were marked as clearly unfair.\nThe choice of law clause specifies what law will govern the contract, meaning also what law will be applied in potential adjudication of a dispute arising under the contract. Clauses defining the applicable law as the law of the consumer's country of residence were marked as clearly fair. In every other case, the choice of law clause was considered as potentially unfair.\nThe limitation of liability clause stipulates that the duty to pay damages is limited or excluded, for certain kind of losses, under certain conditions. Clauses that explicitly affirm non-excludable providers' liabilities were marked as clearly fair. Clauses that reduce, limit, or exclude the liability of the service provider were marked as potentially unfair when concerning broad categories of losses or causes of them.\nThe unilateral change clause specifies the conditions under which the service provider could amend and modify the terms of service and/or the service itself. Such clause was always considered as potentially unfair.\nThe unilateral termination clause gives provider the right to suspend and/or terminate the service and/or the contract, and sometimes details the circumstances under which the provider claims to have a right to do so.\nThe contract by using clause stipulates that the consumer is bound by the terms of use of a specific service, simply by using the service, without even being required to mark that he or she has read and accepted them. We always marked such clauses as potentially unfair.\nThe content removal gives the provider a right to modify/delete user's content, including in-app purchases, and sometimes specifies the conditions under which the service provider may do so.\nThe arbitration clause requires or allows the parties to resolve their disputes through an arbitration process, before the case could go to court. 
Clauses stipulating that the arbitration should take place in a state other then the state of consumer's residence or be based on arbiter's discretion were marked as clearly unfair. Clauses defining arbitration as fully optional were marked as clearly fair.\nPossible labels:\n1. not potentially unfair\n2. potentially unfair" + query_keys = ["Sentence"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_tweet_eval_hate_prompt(line, task_name: str = None): + instruction = "Label whether the following tweet contains hate speech against either immigrants or women. Hate Speech (HS) is commonly defined as any communication that disparages a person or a group on the basis of some characteristic such as race, color, ethnicity, gender, sexual orientation, nationality, religion, or other characteristics.\nPossible labels:\n1. hate speech\n2. not hate speech" + query_keys = ["Tweet"] + return raft_prompt(line, query_keys, instruction, task_name) + + +def raft_twitter_complaints_prompt(line, task_name: str = None): + instruction = "A complaint presents a state of affairs which breaches the writer\u2019s favorable expectation. Label the tweet text based on whether it contains a complaint.\nPossible labels:\n1. complaint\n2. no complaint" + query_keys = ["Tweet text"] + return raft_prompt(line, query_keys, instruction, task_name) raft_ade_corpus_v2 = LightevalTaskConfig( name="raft:ade_corpus_v2", - prompt_function=prompt.raft_ade_corpus_v2, + prompt_function=raft_ade_corpus_v2_prompt, hf_repo="ought/raft", hf_subset="ade_corpus_v2", hf_avail_splits=["train", "test"], @@ -43,7 +116,7 @@ raft_banking_77 = LightevalTaskConfig( name="raft:banking_77", - prompt_function=prompt.raft_banking_77, + prompt_function=raft_banking_77_prompt, hf_repo="ought/raft", hf_subset="banking_77", hf_avail_splits=["train", "test"], @@ -60,7 +133,7 @@ raft_neurips_impact_statement_risks = LightevalTaskConfig( name="raft:neurips_impact_statement_risks", - prompt_function=prompt.raft_neurips_impact_statement_risks, + prompt_function=raft_neurips_impact_statement_risks_prompt, hf_repo="ought/raft", hf_subset="neurips_impact_statement_risks", hf_avail_splits=["train", "test"], @@ -77,7 +150,7 @@ raft_one_stop_english = LightevalTaskConfig( name="raft:one_stop_english", - prompt_function=prompt.raft_one_stop_english, + prompt_function=raft_one_stop_english_prompt, hf_repo="ought/raft", hf_subset="one_stop_english", hf_avail_splits=["train", "test"], @@ -94,7 +167,7 @@ raft_overruling = LightevalTaskConfig( name="raft:overruling", - prompt_function=prompt.raft_overruling, + prompt_function=raft_overruling_prompt, hf_repo="ought/raft", hf_subset="overruling", hf_avail_splits=["train", "test"], @@ -111,7 +184,7 @@ raft_semiconductor_org_types = LightevalTaskConfig( name="raft:semiconductor_org_types", - prompt_function=prompt.raft_semiconductor_org_types, + prompt_function=raft_semiconductor_org_types_prompt, hf_repo="ought/raft", hf_subset="semiconductor_org_types", hf_avail_splits=["train", "test"], @@ -128,7 +201,7 @@ raft_systematic_review_inclusion = LightevalTaskConfig( name="raft:systematic_review_inclusion", - prompt_function=prompt.raft_systematic_review_inclusion, + prompt_function=raft_systematic_review_inclusion_prompt, hf_repo="ought/raft", hf_subset="systematic_review_inclusion", hf_avail_splits=["train", "test"], @@ -145,7 +218,7 @@ raft_tai_safety_research = LightevalTaskConfig( name="raft:tai_safety_research", - prompt_function=prompt.raft_tai_safety_research, + 
prompt_function=raft_tai_safety_research_prompt, hf_repo="ought/raft", hf_subset="tai_safety_research", hf_avail_splits=["train", "test"], @@ -162,7 +235,7 @@ raft_terms_of_service = LightevalTaskConfig( name="raft:terms_of_service", - prompt_function=prompt.raft_terms_of_service, + prompt_function=raft_terms_of_service_prompt, hf_repo="ought/raft", hf_subset="terms_of_service", hf_avail_splits=["train", "test"], @@ -179,7 +252,7 @@ raft_tweet_eval_hate = LightevalTaskConfig( name="raft:tweet_eval_hate", - prompt_function=prompt.raft_tweet_eval_hate, + prompt_function=raft_tweet_eval_hate_prompt, hf_repo="ought/raft", hf_subset="tweet_eval_hate", hf_avail_splits=["train", "test"], @@ -196,7 +269,7 @@ raft_twitter_complaints = LightevalTaskConfig( name="raft:twitter_complaints", - prompt_function=prompt.raft_twitter_complaints, + prompt_function=raft_twitter_complaints_prompt, hf_repo="ought/raft", hf_subset="twitter_complaints", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/real_toxicity_prompts.py b/src/lighteval/tasks/tasks/real_toxicity_prompts.py index 783b1bf90..7d6f2faf0 100644 --- a/src/lighteval/tasks/tasks/real_toxicity_prompts.py +++ b/src/lighteval/tasks/tasks/real_toxicity_prompts.py @@ -18,14 +18,24 @@ https://aclanthology.org/2020.findings-emnlp.301/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def real_toxicity_prompts_prompt(line, task_name: str = None): + # Some variants store text under 'prompt' -> 'text'; handle both flat and nested + text = ( + line["prompt"]["text"] + if isinstance(line.get("prompt"), dict) and "text" in line["prompt"] + else line.get("text", "") + ) + return Doc(task_name=task_name, query=text, choices=None, gold_index=None) real_toxicity_prompts = LightevalTaskConfig( name="real_toxicity_prompts", - prompt_function=prompt.real_toxicity_prompts, + prompt_function=real_toxicity_prompts_prompt, hf_repo="allenai/real-toxicity-prompts", hf_subset="default", hf_avail_splits=["train"], diff --git a/src/lighteval/tasks/tasks/sacrebleu.py b/src/lighteval/tasks/tasks/sacrebleu.py index 012574615..8b2aa0bc6 100644 --- a/src/lighteval/tasks/tasks/sacrebleu.py +++ b/src/lighteval/tasks/tasks/sacrebleu.py @@ -18,14 +18,50 @@ https://github.com/mjpost/sacrebleu """ +import ast + +import pycountry + from lighteval.metrics.metrics import Metrics -from lighteval.tasks import default_prompts as prompt from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.utils.utils import as_list + + +def __wmt_prompt(line, alphabetical, task_name: str = None): + def language(code): + # key is alpha_2 or alpha_3 depending on the code length + language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code}) + return language_tuple.name + + # It would be better to just reupload the file tbh + if isinstance(line["translation"], str): + line["translation"] = ast.literal_eval(line["translation"]) + for k, v in line["translation"].items(): + line["translation"][k] = as_list(v)[0] + + l_in, l_out = sorted(line["translation"].keys(), reverse=not alphabetical) + + return Doc( + task_name=task_name, + query=f"{language(l_in)} phrase: " + line["translation"][l_in].rstrip() + f"\n{language(l_out)} phrase:", + gold_index=0, + choices=[line["translation"][l_out].rstrip()], + instruction=f"Translate {language(l_in)} to 
{language(l_out)}, do not explain, only output the translation.", + ) + + +def wmt_alphabetical_prompt(line, task_name: str = None): + return __wmt_prompt(line, True, task_name) + + +def wmt_reverse_alphabetical_prompt(line, task_name: str = None): + return __wmt_prompt(line, False, task_name) iwslt17_ar_en = LightevalTaskConfig( name="iwslt17:ar-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_ar-en", hf_avail_splits=["test"], @@ -40,7 +76,7 @@ iwslt17_de_en = LightevalTaskConfig( name="iwslt17:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_de-en", hf_avail_splits=["test"], @@ -55,7 +91,7 @@ iwslt17_en_ar = LightevalTaskConfig( name="iwslt17:en-ar", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_ar-en", hf_avail_splits=["test"], @@ -70,7 +106,7 @@ iwslt17_en_de = LightevalTaskConfig( name="iwslt17:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_en-de", hf_avail_splits=["test"], @@ -85,7 +121,7 @@ iwslt17_en_fr = LightevalTaskConfig( name="iwslt17:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_en-fr", hf_avail_splits=["test"], @@ -100,7 +136,7 @@ iwslt17_en_ja = LightevalTaskConfig( name="iwslt17:en-ja", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_en-ja", hf_avail_splits=["test"], @@ -115,7 +151,7 @@ iwslt17_en_ko = LightevalTaskConfig( name="iwslt17:en-ko", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_en-ko", hf_avail_splits=["test"], @@ -130,7 +166,7 @@ iwslt17_en_zh = LightevalTaskConfig( name="iwslt17:en-zh", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_en-zh", hf_avail_splits=["test"], @@ -145,7 +181,7 @@ iwslt17_fr_en = LightevalTaskConfig( name="iwslt17:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_fr-en", hf_avail_splits=["test"], @@ -160,7 +196,7 @@ iwslt17_ja_en = LightevalTaskConfig( name="iwslt17:ja-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_ja-en", hf_avail_splits=["test"], @@ -175,7 +211,7 @@ iwslt17_ko_en = LightevalTaskConfig( name="iwslt17:ko-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_ko-en", hf_avail_splits=["test"], @@ -190,7 +226,7 @@ iwslt17_zh_en = LightevalTaskConfig( name="iwslt17:zh-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="iwslt17_zh-en", hf_avail_splits=["test"], @@ -205,7 +241,7 @@ mtnt2019_en_fr = LightevalTaskConfig( name="mtnt2019:en-fr", - 
prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="mtnt2019_en-fr", hf_avail_splits=["test"], @@ -220,7 +256,7 @@ mtnt2019_en_ja = LightevalTaskConfig( name="mtnt2019:en-ja", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="mtnt2019_en-ja", hf_avail_splits=["test"], @@ -235,7 +271,7 @@ mtnt2019_fr_en = LightevalTaskConfig( name="mtnt2019:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="mtnt2019_fr-en", hf_avail_splits=["test"], @@ -250,7 +286,7 @@ mtnt2019_ja_en = LightevalTaskConfig( name="mtnt2019:ja-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="mtnt2019_ja-en", hf_avail_splits=["test"], @@ -265,7 +301,7 @@ wmt08_cs_en = LightevalTaskConfig( name="wmt08:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_cs-en", hf_avail_splits=["test"], @@ -280,7 +316,7 @@ wmt08_de_en = LightevalTaskConfig( name="wmt08:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_de-en", hf_avail_splits=["test"], @@ -295,7 +331,7 @@ wmt08_en_cs = LightevalTaskConfig( name="wmt08:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_en-cs", hf_avail_splits=["test"], @@ -310,7 +346,7 @@ wmt08_en_de = LightevalTaskConfig( name="wmt08:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_en-de", hf_avail_splits=["test"], @@ -325,7 +361,7 @@ wmt08_en_es = LightevalTaskConfig( name="wmt08:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_en-es", hf_avail_splits=["test"], @@ -340,7 +376,7 @@ wmt08_en_fr = LightevalTaskConfig( name="wmt08:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_en-fr", hf_avail_splits=["test"], @@ -355,7 +391,7 @@ wmt08_en_hu = LightevalTaskConfig( name="wmt08:en-hu", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_en-hu", hf_avail_splits=["test"], @@ -370,7 +406,7 @@ wmt08_es_en = LightevalTaskConfig( name="wmt08:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_es-en", hf_avail_splits=["test"], @@ -385,7 +421,7 @@ wmt08_fr_en = LightevalTaskConfig( name="wmt08:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_fr-en", hf_avail_splits=["test"], @@ -400,7 +436,7 @@ wmt08_hu_en = LightevalTaskConfig( name="wmt08:hu-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt08_hu-en", 
hf_avail_splits=["test"], @@ -415,7 +451,7 @@ wmt09_cs_en = LightevalTaskConfig( name="wmt09:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_cs-en", hf_avail_splits=["test"], @@ -430,7 +466,7 @@ wmt09_de_en = LightevalTaskConfig( name="wmt09:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_de-en", hf_avail_splits=["test"], @@ -445,7 +481,7 @@ wmt09_en_cs = LightevalTaskConfig( name="wmt09:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-cs", hf_avail_splits=["test"], @@ -460,7 +496,7 @@ wmt09_en_de = LightevalTaskConfig( name="wmt09:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-de", hf_avail_splits=["test"], @@ -475,7 +511,7 @@ wmt09_en_es = LightevalTaskConfig( name="wmt09:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-es", hf_avail_splits=["test"], @@ -490,7 +526,7 @@ wmt09_en_fr = LightevalTaskConfig( name="wmt09:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-fr", hf_avail_splits=["test"], @@ -505,7 +541,7 @@ wmt09_en_hu = LightevalTaskConfig( name="wmt09:en-hu", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-hu", hf_avail_splits=["test"], @@ -520,7 +556,7 @@ wmt09_en_it = LightevalTaskConfig( name="wmt09:en-it", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_en-it", hf_avail_splits=["test"], @@ -535,7 +571,7 @@ wmt09_es_en = LightevalTaskConfig( name="wmt09:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_es-en", hf_avail_splits=["test"], @@ -550,7 +586,7 @@ wmt09_fr_en = LightevalTaskConfig( name="wmt09:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_fr-en", hf_avail_splits=["test"], @@ -565,7 +601,7 @@ wmt09_hu_en = LightevalTaskConfig( name="wmt09:hu-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_hu-en", hf_avail_splits=["test"], @@ -580,7 +616,7 @@ wmt09_it_en = LightevalTaskConfig( name="wmt09:it-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt09_it-en", hf_avail_splits=["test"], @@ -595,7 +631,7 @@ wmt10_cs_en = LightevalTaskConfig( name="wmt10:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_cs-en", hf_avail_splits=["test"], @@ -610,7 +646,7 @@ wmt10_de_en = LightevalTaskConfig( name="wmt10:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, 
hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_de-en", hf_avail_splits=["test"], @@ -625,7 +661,7 @@ wmt10_en_cs = LightevalTaskConfig( name="wmt10:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_en-cs", hf_avail_splits=["test"], @@ -640,7 +676,7 @@ wmt10_en_de = LightevalTaskConfig( name="wmt10:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_en-de", hf_avail_splits=["test"], @@ -655,7 +691,7 @@ wmt10_en_es = LightevalTaskConfig( name="wmt10:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_en-es", hf_avail_splits=["test"], @@ -670,7 +706,7 @@ wmt10_en_fr = LightevalTaskConfig( name="wmt10:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_en-fr", hf_avail_splits=["test"], @@ -685,7 +721,7 @@ wmt10_es_en = LightevalTaskConfig( name="wmt10:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_es-en", hf_avail_splits=["test"], @@ -700,7 +736,7 @@ wmt10_fr_en = LightevalTaskConfig( name="wmt10:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt10_fr-en", hf_avail_splits=["test"], @@ -715,7 +751,7 @@ wmt11_cs_en = LightevalTaskConfig( name="wmt11:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_cs-en", hf_avail_splits=["test"], @@ -730,7 +766,7 @@ wmt11_de_en = LightevalTaskConfig( name="wmt11:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_de-en", hf_avail_splits=["test"], @@ -745,7 +781,7 @@ wmt11_en_cs = LightevalTaskConfig( name="wmt11:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_en-cs", hf_avail_splits=["test"], @@ -760,7 +796,7 @@ wmt11_en_de = LightevalTaskConfig( name="wmt11:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_en-de", hf_avail_splits=["test"], @@ -775,7 +811,7 @@ wmt11_en_es = LightevalTaskConfig( name="wmt11:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_en-es", hf_avail_splits=["test"], @@ -790,7 +826,7 @@ wmt11_en_fr = LightevalTaskConfig( name="wmt11:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_en-fr", hf_avail_splits=["test"], @@ -805,7 +841,7 @@ wmt11_es_en = LightevalTaskConfig( name="wmt11:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_es-en", hf_avail_splits=["test"], @@ -820,7 +856,7 @@ wmt11_fr_en = LightevalTaskConfig( name="wmt11:fr-en", - 
prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt11_fr-en", hf_avail_splits=["test"], @@ -835,7 +871,7 @@ wmt12_cs_en = LightevalTaskConfig( name="wmt12:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_cs-en", hf_avail_splits=["test"], @@ -850,7 +886,7 @@ wmt12_de_en = LightevalTaskConfig( name="wmt12:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_de-en", hf_avail_splits=["test"], @@ -865,7 +901,7 @@ wmt12_en_cs = LightevalTaskConfig( name="wmt12:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_en-cs", hf_avail_splits=["test"], @@ -880,7 +916,7 @@ wmt12_en_de = LightevalTaskConfig( name="wmt12:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_en-de", hf_avail_splits=["test"], @@ -895,7 +931,7 @@ wmt12_en_es = LightevalTaskConfig( name="wmt12:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_en-es", hf_avail_splits=["test"], @@ -910,7 +946,7 @@ wmt12_en_fr = LightevalTaskConfig( name="wmt12:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_en-fr", hf_avail_splits=["test"], @@ -925,7 +961,7 @@ wmt12_es_en = LightevalTaskConfig( name="wmt12:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_es-en", hf_avail_splits=["test"], @@ -940,7 +976,7 @@ wmt12_fr_en = LightevalTaskConfig( name="wmt12:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt12_fr-en", hf_avail_splits=["test"], @@ -955,7 +991,7 @@ wmt13_cs_en = LightevalTaskConfig( name="wmt13:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_cs-en", hf_avail_splits=["test"], @@ -970,7 +1006,7 @@ wmt13_de_en = LightevalTaskConfig( name="wmt13:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_de-en", hf_avail_splits=["test"], @@ -985,7 +1021,7 @@ wmt13_en_cs = LightevalTaskConfig( name="wmt13:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_en-cs", hf_avail_splits=["test"], @@ -1000,7 +1036,7 @@ wmt13_en_de = LightevalTaskConfig( name="wmt13:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_en-de", hf_avail_splits=["test"], @@ -1015,7 +1051,7 @@ wmt13_en_es = LightevalTaskConfig( name="wmt13:en-es", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_en-es", hf_avail_splits=["test"], @@ -1030,7 
+1066,7 @@ wmt13_en_fr = LightevalTaskConfig( name="wmt13:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_en-fr", hf_avail_splits=["test"], @@ -1045,7 +1081,7 @@ wmt13_en_ru = LightevalTaskConfig( name="wmt13:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_en-ru", hf_avail_splits=["test"], @@ -1060,7 +1096,7 @@ wmt13_es_en = LightevalTaskConfig( name="wmt13:es-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_es-en", hf_avail_splits=["test"], @@ -1075,7 +1111,7 @@ wmt13_fr_en = LightevalTaskConfig( name="wmt13:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_fr-en", hf_avail_splits=["test"], @@ -1090,7 +1126,7 @@ wmt13_ru_en = LightevalTaskConfig( name="wmt13:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt13_ru-en", hf_avail_splits=["test"], @@ -1105,7 +1141,7 @@ wmt14_cs_en = LightevalTaskConfig( name="wmt14:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_cs-en", hf_avail_splits=["test"], @@ -1120,7 +1156,7 @@ wmt14_de_en = LightevalTaskConfig( name="wmt14:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_de-en", hf_avail_splits=["test"], @@ -1135,7 +1171,7 @@ wmt14_en_cs = LightevalTaskConfig( name="wmt14:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_en-cs", hf_avail_splits=["test"], @@ -1150,7 +1186,7 @@ wmt14_en_de = LightevalTaskConfig( name="wmt14:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_en-de", hf_avail_splits=["test"], @@ -1165,7 +1201,7 @@ wmt14_en_fr = LightevalTaskConfig( name="wmt14:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="wmt14", hf_subset="fr-en", hf_avail_splits=["train", "validation", "test"], @@ -1180,7 +1216,7 @@ wmt14_en_fr = LightevalTaskConfig( name="wmt14:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_en-fr", hf_avail_splits=["test"], @@ -1195,7 +1231,7 @@ wmt14_en_hi = LightevalTaskConfig( name="wmt14:en-hi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_en-hi", hf_avail_splits=["test"], @@ -1210,7 +1246,7 @@ wmt14_en_ru = LightevalTaskConfig( name="wmt14:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_en-ru", hf_avail_splits=["test"], @@ -1225,7 +1261,7 @@ wmt14_fr_en = LightevalTaskConfig( name="wmt14:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="wmt14", 
hf_subset="fr-en", hf_avail_splits=["train", "validation", "test"], @@ -1240,7 +1276,7 @@ wmt14_fr_en = LightevalTaskConfig( name="wmt14:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_fr-en", hf_avail_splits=["test"], @@ -1255,7 +1291,7 @@ wmt14_hi_en = LightevalTaskConfig( name="wmt14:hi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_hi-en", hf_avail_splits=["test"], @@ -1270,7 +1306,7 @@ wmt14_ru_en = LightevalTaskConfig( name="wmt14:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt14_ru-en", hf_avail_splits=["test"], @@ -1285,7 +1321,7 @@ wmt15_cs_en = LightevalTaskConfig( name="wmt15:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_cs-en", hf_avail_splits=["test"], @@ -1300,7 +1336,7 @@ wmt15_de_en = LightevalTaskConfig( name="wmt15:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_de-en", hf_avail_splits=["test"], @@ -1315,7 +1351,7 @@ wmt15_en_cs = LightevalTaskConfig( name="wmt15:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_en-cs", hf_avail_splits=["test"], @@ -1330,7 +1366,7 @@ wmt15_en_de = LightevalTaskConfig( name="wmt15:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_en-de", hf_avail_splits=["test"], @@ -1345,7 +1381,7 @@ wmt15_en_fi = LightevalTaskConfig( name="wmt15:en-fi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_en-fi", hf_avail_splits=["test"], @@ -1360,7 +1396,7 @@ wmt15_en_fr = LightevalTaskConfig( name="wmt15:en-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_en-fr", hf_avail_splits=["test"], @@ -1375,7 +1411,7 @@ wmt15_en_ru = LightevalTaskConfig( name="wmt15:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_en-ru", hf_avail_splits=["test"], @@ -1390,7 +1426,7 @@ wmt15_fi_en = LightevalTaskConfig( name="wmt15:fi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_fi-en", hf_avail_splits=["test"], @@ -1405,7 +1441,7 @@ wmt15_fr_en = LightevalTaskConfig( name="wmt15:fr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_fr-en", hf_avail_splits=["test"], @@ -1420,7 +1456,7 @@ wmt15_ru_en = LightevalTaskConfig( name="wmt15:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt15_ru-en", hf_avail_splits=["test"], @@ -1435,7 +1471,7 @@ wmt16_cs_en = LightevalTaskConfig( name="wmt16:cs-en", - 
prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_cs-en", hf_avail_splits=["test"], @@ -1450,7 +1486,7 @@ wmt16_de_en = LightevalTaskConfig( name="wmt16:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="wmt16", hf_subset="de-en", hf_avail_splits=["train", "validation", "test"], @@ -1465,7 +1501,7 @@ wmt16_de_en = LightevalTaskConfig( name="wmt16:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_de-en", hf_avail_splits=["test"], @@ -1480,7 +1516,7 @@ wmt16_en_cs = LightevalTaskConfig( name="wmt16:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-cs", hf_avail_splits=["test"], @@ -1495,7 +1531,7 @@ wmt16_en_de = LightevalTaskConfig( name="wmt16:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="wmt16", hf_subset="de-en", hf_avail_splits=["train", "validation", "test"], @@ -1510,7 +1546,7 @@ wmt16_en_de = LightevalTaskConfig( name="wmt16:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-de", hf_avail_splits=["test"], @@ -1525,7 +1561,7 @@ wmt16_en_fi = LightevalTaskConfig( name="wmt16:en-fi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-fi", hf_avail_splits=["test"], @@ -1540,7 +1576,7 @@ wmt16_en_ro = LightevalTaskConfig( name="wmt16:en-ro", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="wmt16", hf_subset="ro-en", hf_avail_splits=["train", "validation", "test"], @@ -1555,7 +1591,7 @@ wmt16_en_ro = LightevalTaskConfig( name="wmt16:en-ro", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-ro", hf_avail_splits=["test"], @@ -1570,7 +1606,7 @@ wmt16_en_ru = LightevalTaskConfig( name="wmt16:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-ru", hf_avail_splits=["test"], @@ -1585,7 +1621,7 @@ wmt16_en_tr = LightevalTaskConfig( name="wmt16:en-tr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_en-tr", hf_avail_splits=["test"], @@ -1600,7 +1636,7 @@ wmt16_fi_en = LightevalTaskConfig( name="wmt16:fi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_fi-en", hf_avail_splits=["test"], @@ -1615,7 +1651,7 @@ wmt16_ro_en = LightevalTaskConfig( name="wmt16:ro-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="wmt16", hf_subset="ro-en", hf_avail_splits=["train", "validation", "test"], @@ -1630,7 +1666,7 @@ wmt16_ro_en = LightevalTaskConfig( name="wmt16:ro-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_ro-en", hf_avail_splits=["test"], @@ -1645,7 +1681,7 @@ 
wmt16_ru_en = LightevalTaskConfig( name="wmt16:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_ru-en", hf_avail_splits=["test"], @@ -1660,7 +1696,7 @@ wmt16_tr_en = LightevalTaskConfig( name="wmt16:tr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt16_tr-en", hf_avail_splits=["test"], @@ -1675,7 +1711,7 @@ wmt17_cs_en = LightevalTaskConfig( name="wmt17:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_cs-en", hf_avail_splits=["test"], @@ -1690,7 +1726,7 @@ wmt17_de_en = LightevalTaskConfig( name="wmt17:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_de-en", hf_avail_splits=["test"], @@ -1705,7 +1741,7 @@ wmt17_en_cs = LightevalTaskConfig( name="wmt17:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-cs", hf_avail_splits=["test"], @@ -1720,7 +1756,7 @@ wmt17_en_de = LightevalTaskConfig( name="wmt17:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-de", hf_avail_splits=["test"], @@ -1735,7 +1771,7 @@ wmt17_en_fi = LightevalTaskConfig( name="wmt17:en-fi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-fi", hf_avail_splits=["test"], @@ -1750,7 +1786,7 @@ wmt17_en_lv = LightevalTaskConfig( name="wmt17:en-lv", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-lv", hf_avail_splits=["test"], @@ -1765,7 +1801,7 @@ wmt17_en_ru = LightevalTaskConfig( name="wmt17:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-ru", hf_avail_splits=["test"], @@ -1780,7 +1816,7 @@ wmt17_en_tr = LightevalTaskConfig( name="wmt17:en-tr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-tr", hf_avail_splits=["test"], @@ -1795,7 +1831,7 @@ wmt17_en_zh = LightevalTaskConfig( name="wmt17:en-zh", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_en-zh", hf_avail_splits=["test"], @@ -1810,7 +1846,7 @@ wmt17_fi_en = LightevalTaskConfig( name="wmt17:fi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_fi-en", hf_avail_splits=["test"], @@ -1825,7 +1861,7 @@ wmt17_lv_en = LightevalTaskConfig( name="wmt17:lv-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_lv-en", hf_avail_splits=["test"], @@ -1840,7 +1876,7 @@ wmt17_ru_en = LightevalTaskConfig( name="wmt17:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, 
hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_ru-en", hf_avail_splits=["test"], @@ -1855,7 +1891,7 @@ wmt17_tr_en = LightevalTaskConfig( name="wmt17:tr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_tr-en", hf_avail_splits=["test"], @@ -1870,7 +1906,7 @@ wmt17_zh_en = LightevalTaskConfig( name="wmt17:zh-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt17_zh-en", hf_avail_splits=["test"], @@ -1885,7 +1921,7 @@ wmt18_cs_en = LightevalTaskConfig( name="wmt18:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_cs-en", hf_avail_splits=["test"], @@ -1900,7 +1936,7 @@ wmt18_de_en = LightevalTaskConfig( name="wmt18:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_de-en", hf_avail_splits=["test"], @@ -1915,7 +1951,7 @@ wmt18_en_cs = LightevalTaskConfig( name="wmt18:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-cs", hf_avail_splits=["test"], @@ -1930,7 +1966,7 @@ wmt18_en_de = LightevalTaskConfig( name="wmt18:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-de", hf_avail_splits=["test"], @@ -1945,7 +1981,7 @@ wmt18_en_et = LightevalTaskConfig( name="wmt18:en-et", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-et", hf_avail_splits=["test"], @@ -1960,7 +1996,7 @@ wmt18_en_fi = LightevalTaskConfig( name="wmt18:en-fi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-fi", hf_avail_splits=["test"], @@ -1975,7 +2011,7 @@ wmt18_en_ru = LightevalTaskConfig( name="wmt18:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-ru", hf_avail_splits=["test"], @@ -1990,7 +2026,7 @@ wmt18_en_tr = LightevalTaskConfig( name="wmt18:en-tr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-tr", hf_avail_splits=["test"], @@ -2005,7 +2041,7 @@ wmt18_en_zh = LightevalTaskConfig( name="wmt18:en-zh", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_en-zh", hf_avail_splits=["test"], @@ -2020,7 +2056,7 @@ wmt18_et_en = LightevalTaskConfig( name="wmt18:et-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_et-en", hf_avail_splits=["test"], @@ -2035,7 +2071,7 @@ wmt18_fi_en = LightevalTaskConfig( name="wmt18:fi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_fi-en", hf_avail_splits=["test"], @@ -2050,7 +2086,7 @@ wmt18_ru_en = LightevalTaskConfig( name="wmt18:ru-en", - 
prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_ru-en", hf_avail_splits=["test"], @@ -2065,7 +2101,7 @@ wmt18_tr_en = LightevalTaskConfig( name="wmt18:tr-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_tr-en", hf_avail_splits=["test"], @@ -2080,7 +2116,7 @@ wmt18_zh_en = LightevalTaskConfig( name="wmt18:zh-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt18_zh-en", hf_avail_splits=["test"], @@ -2095,7 +2131,7 @@ wmt19_cs_de = LightevalTaskConfig( name="wmt19:cs-de", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_cs-de", hf_avail_splits=["test"], @@ -2110,7 +2146,7 @@ wmt19_de_cs = LightevalTaskConfig( name="wmt19:de-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_de-cs", hf_avail_splits=["test"], @@ -2125,7 +2161,7 @@ wmt19_de_en = LightevalTaskConfig( name="wmt19:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_de-en", hf_avail_splits=["test"], @@ -2140,7 +2176,7 @@ wmt19_de_fr = LightevalTaskConfig( name="wmt19:de-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_de-fr", hf_avail_splits=["test"], @@ -2155,7 +2191,7 @@ wmt19_en_cs = LightevalTaskConfig( name="wmt19:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-cs", hf_avail_splits=["test"], @@ -2170,7 +2206,7 @@ wmt19_en_de = LightevalTaskConfig( name="wmt19:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-de", hf_avail_splits=["test"], @@ -2185,7 +2221,7 @@ wmt19_en_fi = LightevalTaskConfig( name="wmt19:en-fi", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-fi", hf_avail_splits=["test"], @@ -2200,7 +2236,7 @@ wmt19_en_gu = LightevalTaskConfig( name="wmt19:en-gu", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-gu", hf_avail_splits=["test"], @@ -2215,7 +2251,7 @@ wmt19_en_kk = LightevalTaskConfig( name="wmt19:en-kk", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-kk", hf_avail_splits=["test"], @@ -2230,7 +2266,7 @@ wmt19_en_lt = LightevalTaskConfig( name="wmt19:en-lt", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-lt", hf_avail_splits=["test"], @@ -2245,7 +2281,7 @@ wmt19_en_ru = LightevalTaskConfig( name="wmt19:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-ru", hf_avail_splits=["test"], @@ 
-2260,7 +2296,7 @@ wmt19_en_zh = LightevalTaskConfig( name="wmt19:en-zh", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_en-zh", hf_avail_splits=["test"], @@ -2275,7 +2311,7 @@ wmt19_fi_en = LightevalTaskConfig( name="wmt19:fi-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_fi-en", hf_avail_splits=["test"], @@ -2290,7 +2326,7 @@ wmt19_fr_de = LightevalTaskConfig( name="wmt19:fr-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_fr-de", hf_avail_splits=["test"], @@ -2305,7 +2341,7 @@ wmt19_gu_en = LightevalTaskConfig( name="wmt19:gu-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_gu-en", hf_avail_splits=["test"], @@ -2320,7 +2356,7 @@ wmt19_kk_en = LightevalTaskConfig( name="wmt19:kk-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_kk-en", hf_avail_splits=["test"], @@ -2335,7 +2371,7 @@ wmt19_lt_en = LightevalTaskConfig( name="wmt19:lt-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_lt-en", hf_avail_splits=["test"], @@ -2350,7 +2386,7 @@ wmt19_ru_en = LightevalTaskConfig( name="wmt19:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_ru-en", hf_avail_splits=["test"], @@ -2365,7 +2401,7 @@ wmt19_zh_en = LightevalTaskConfig( name="wmt19:zh-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt19_zh-en", hf_avail_splits=["test"], @@ -2380,7 +2416,7 @@ wmt20_cs_en = LightevalTaskConfig( name="wmt20:cs-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_cs-en", hf_avail_splits=["test"], @@ -2395,7 +2431,7 @@ wmt20_de_en = LightevalTaskConfig( name="wmt20:de-en", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_de-en", hf_avail_splits=["test"], @@ -2410,7 +2446,7 @@ wmt20_de_fr = LightevalTaskConfig( name="wmt20:de-fr", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_de-fr", hf_avail_splits=["test"], @@ -2425,7 +2461,7 @@ wmt20_en_cs = LightevalTaskConfig( name="wmt20:en-cs", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-cs", hf_avail_splits=["test"], @@ -2440,7 +2476,7 @@ wmt20_en_de = LightevalTaskConfig( name="wmt20:en-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-de", hf_avail_splits=["test"], @@ -2455,7 +2491,7 @@ wmt20_en_iu = LightevalTaskConfig( name="wmt20:en-iu", - prompt_function=prompt.wmt_alphabetical, + 
prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-iu", hf_avail_splits=["test"], @@ -2470,7 +2506,7 @@ wmt20_en_ja = LightevalTaskConfig( name="wmt20:en-ja", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-ja", hf_avail_splits=["test"], @@ -2485,7 +2521,7 @@ wmt20_en_km = LightevalTaskConfig( name="wmt20:en-km", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-km", hf_avail_splits=["test"], @@ -2500,7 +2536,7 @@ wmt20_en_pl = LightevalTaskConfig( name="wmt20:en-pl", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-pl", hf_avail_splits=["test"], @@ -2515,7 +2551,7 @@ wmt20_en_ps = LightevalTaskConfig( name="wmt20:en-ps", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-ps", hf_avail_splits=["test"], @@ -2530,7 +2566,7 @@ wmt20_en_ru = LightevalTaskConfig( name="wmt20:en-ru", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-ru", hf_avail_splits=["test"], @@ -2545,7 +2581,7 @@ wmt20_en_ta = LightevalTaskConfig( name="wmt20:en-ta", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-ta", hf_avail_splits=["test"], @@ -2560,7 +2596,7 @@ wmt20_en_zh = LightevalTaskConfig( name="wmt20:en-zh", - prompt_function=prompt.wmt_alphabetical, + prompt_function=wmt_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_en-zh", hf_avail_splits=["test"], @@ -2575,7 +2611,7 @@ wmt20_fr_de = LightevalTaskConfig( name="wmt20:fr-de", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_fr-de", hf_avail_splits=["test"], @@ -2590,7 +2626,7 @@ wmt20_iu_en = LightevalTaskConfig( name="wmt20:iu-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_iu-en", hf_avail_splits=["test"], @@ -2605,7 +2641,7 @@ wmt20_ja_en = LightevalTaskConfig( name="wmt20:ja-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_ja-en", hf_avail_splits=["test"], @@ -2620,7 +2656,7 @@ wmt20_km_en = LightevalTaskConfig( name="wmt20:km-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_km-en", hf_avail_splits=["test"], @@ -2635,7 +2671,7 @@ wmt20_pl_en = LightevalTaskConfig( name="wmt20:pl-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_pl-en", hf_avail_splits=["test"], @@ -2650,7 +2686,7 @@ wmt20_ps_en = LightevalTaskConfig( name="wmt20:ps-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_ps-en", hf_avail_splits=["test"], @@ -2665,7 +2701,7 @@ wmt20_ru_en = 
LightevalTaskConfig( name="wmt20:ru-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_ru-en", hf_avail_splits=["test"], @@ -2680,7 +2716,7 @@ wmt20_ta_en = LightevalTaskConfig( name="wmt20:ta-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_ta-en", hf_avail_splits=["test"], @@ -2695,7 +2731,7 @@ wmt20_zh_en = LightevalTaskConfig( name="wmt20:zh-en", - prompt_function=prompt.wmt_reverse_alphabetical, + prompt_function=wmt_reverse_alphabetical_prompt, hf_repo="lighteval/sacrebleu_manual", hf_subset="wmt20_zh-en", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/sciq.py b/src/lighteval/tasks/tasks/sciq.py index ff58eb681..5e6ed56b9 100644 --- a/src/lighteval/tasks/tasks/sciq.py +++ b/src/lighteval/tasks/tasks/sciq.py @@ -22,14 +22,25 @@ https://arxiv.org/abs/1707.06209 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def sciq_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['support']}\nQuestion: {line['question']}\nAnswer:".strip(), + choices=[ + f" {c}" for c in [line["distractor1"], line["distractor2"], line["distractor3"], line["correct_answer"]] + ], + gold_index=3, + ) sciq = LightevalTaskConfig( name="sciq", - prompt_function=prompt.sciq, + prompt_function=sciq_prompt, hf_repo="allenai/sciq", hf_subset="default", hf_avail_splits=["train", "validation", "test"], diff --git a/src/lighteval/tasks/tasks/simpleqa.py b/src/lighteval/tasks/tasks/simpleqa.py index bea105d16..602ba9727 100644 --- a/src/lighteval/tasks/tasks/simpleqa.py +++ b/src/lighteval/tasks/tasks/simpleqa.py @@ -19,14 +19,28 @@ https://openai.com/index/introducing-simpleqa/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def simpleqa_prompt(line, task_name: str = None): + query = f"Question: {line['question']}\n" + query += "".join( + [f"\n{key}. {choice}" for key, choice in zip(["A", "B", "C", "D", "E", "F"], line["choices"]["text"])] + ) + query += "\nAnswer:" + return Doc( + task_name=task_name, + query=query, + choices=line["choices"]["text"], + gold_index=line["choices"]["label"].index(line["answerKey"]), + ) simpleqa = LightevalTaskConfig( name="simpleqa", - prompt_function=prompt.simpleqa, + prompt_function=simpleqa_prompt, hf_repo="lighteval/SimpleQA", hf_subset="default", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/siqa.py b/src/lighteval/tasks/tasks/siqa.py index 5242bfc29..d0d22b8ec 100644 --- a/src/lighteval/tasks/tasks/siqa.py +++ b/src/lighteval/tasks/tasks/siqa.py @@ -28,14 +28,36 @@ paper: """ -import lighteval.tasks.default_prompts as prompt +from string import ascii_uppercase + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def siqa_prompt(line, task_name: str = None): + query = "The following are multiple choice questions (with answers) about common sense.\n" + query += f"Question: {line['context']} {line['question']}\n" + query += "".join( + [ + f"{key}. 
{choice}\n" + for key, choice in zip(list(ascii_uppercase)[:3], [line["answerA"], line["answerB"], line["answerC"]]) + ] + ) + query += "Answer: " + + return Doc( + task_name=task_name, + query=query, + choices=["A", "B", "C"], + gold_index=int(line["label"]) - 1, + instruction="The following are multiple choice questions (with answers) about common sense.\n", + ) siqa = LightevalTaskConfig( name="siqa", - prompt_function=prompt.siqa, + prompt_function=siqa_prompt, hf_repo="allenai/social_i_qa", hf_subset="default", hf_avail_splits=["train", "validation"], diff --git a/src/lighteval/tasks/tasks/storycloze.py b/src/lighteval/tasks/tasks/storycloze.py index da66fc307..233f8f0cf 100644 --- a/src/lighteval/tasks/tasks/storycloze.py +++ b/src/lighteval/tasks/tasks/storycloze.py @@ -19,14 +19,23 @@ https://arxiv.org/abs/1604.01696 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def storycloze_prompt(line, task_name: str = None): + context = "\n".join( + [line["input_sentence_1"], line["input_sentence_2"], line["input_sentence_3"], line["input_sentence_4"]] + ) + choices = [line["sentence_quiz1"], line["sentence_quiz2"]] + gold = int(line["answer_right_ending"]) - 1 + return Doc(task_name=task_name, query=context + "\n", choices=choices, gold_index=gold) storycloze_2016 = LightevalTaskConfig( name="storycloze:2016", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="MoE-UNC/story_cloze", hf_subset="2016", hf_avail_splits=["validation"], @@ -42,7 +51,7 @@ storycloze_2018 = LightevalTaskConfig( name="storycloze:2018", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="MoE-UNC/story_cloze", hf_subset="2018", hf_avail_splits=["validation"], diff --git a/src/lighteval/tasks/tasks/summarization.py b/src/lighteval/tasks/tasks/summarization.py index 2d3f0cd0b..4bc394395 100644 --- a/src/lighteval/tasks/tasks/summarization.py +++ b/src/lighteval/tasks/tasks/summarization.py @@ -21,14 +21,32 @@ https://aclanthology.org/K16-1028/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def cnn_dm_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Article: {line['article']}\n\nTL;DR:", + choices=[line["highlights"]], + gold_index=0, + ) + + +def xsum_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Document: {line['document']}\n\nA one-sentence summary of the above document is:", + choices=[line["summary"]], + gold_index=0, + ) summarization_cnn_dm = LightevalTaskConfig( name="summarization:cnn-dm", - prompt_function=prompt.cnn_dm, + prompt_function=cnn_dm_prompt, hf_repo="lighteval/summarization", hf_subset="cnn-dm", hf_avail_splits=["train", "test", "validation"], @@ -51,7 +69,7 @@ summarization_xsum = LightevalTaskConfig( name="summarization:xsum", - prompt_function=prompt.xsum, + prompt_function=xsum_prompt, hf_repo="lighteval/summarization", hf_subset="xsum", hf_avail_splits=["train", "test", "validation"], @@ -74,7 +92,7 @@ summarization_xsum_sampled = LightevalTaskConfig( name="summarization:xsum-sampled", - prompt_function=prompt.xsum, + prompt_function=xsum_prompt, hf_repo="lighteval/summarization", hf_subset="xsum-sampled", hf_avail_splits=["train", "test", 
"validation"], diff --git a/src/lighteval/tasks/tasks/swag.py b/src/lighteval/tasks/tasks/swag.py index 91ce0e11e..67febe8ec 100644 --- a/src/lighteval/tasks/tasks/swag.py +++ b/src/lighteval/tasks/tasks/swag.py @@ -25,14 +25,24 @@ https://arxiv.org/abs/1808.05326 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def swag_prompt(line, task_name: str = None): + choices = [line["ending0"], line["ending1"], line["ending2"], line["ending3"]] + return Doc( + task_name=task_name, + query=line["startphrase"], + choices=choices, + gold_index=int(line["label"]), + ) swag = LightevalTaskConfig( name="swag", - prompt_function=prompt.swag, + prompt_function=swag_prompt, hf_repo="allenai/swag", hf_subset="regular", hf_avail_splits=["train", "validation"], diff --git a/src/lighteval/tasks/tasks/synthetic_reasoning.py b/src/lighteval/tasks/tasks/synthetic_reasoning.py index 6f3ae4885..c4ac36a99 100644 --- a/src/lighteval/tasks/tasks/synthetic_reasoning.py +++ b/src/lighteval/tasks/tasks/synthetic_reasoning.py @@ -18,14 +18,34 @@ https://arxiv.org/abs/2206.03855 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def synthetic_reasoning_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Please solve the following problem.\n\n{line['source']}\nTarget: ", + gold_index=0, + choices=[line["target"]], + instruction="Please solve the following problem.\n\n", + ) + + +def synthetic_reasoning_natural_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Please solve the following problem.\n\nRules: \n{line['question']}", + gold_index=0, + choices=[line["target"]], + instruction="Please solve the following problem.\n\n", + ) synthetic_reasoning_induction = LightevalTaskConfig( name="synthetic_reasoning:induction", - prompt_function=prompt.synthetic_reasoning, + prompt_function=synthetic_reasoning_prompt, hf_repo="lighteval/synthetic_reasoning", hf_subset="induction", hf_avail_splits=["train", "test", "validation"], @@ -43,7 +63,7 @@ synthetic_reasoning_natural_easy = LightevalTaskConfig( name="synthetic_reasoning:natural_easy", - prompt_function=prompt.synthetic_reasoning_natural, + prompt_function=synthetic_reasoning_natural_prompt, hf_repo="lighteval/synthetic_reasoning_natural", hf_subset="easy", hf_avail_splits=["train", "test", "validation"], @@ -59,7 +79,7 @@ synthetic_reasoning_natural_hard = LightevalTaskConfig( name="synthetic_reasoning:natural_hard", - prompt_function=prompt.synthetic_reasoning_natural, + prompt_function=synthetic_reasoning_natural_prompt, hf_repo="lighteval/synthetic_reasoning_natural", hf_subset="hard", hf_avail_splits=["train", "test", "validation"], @@ -75,7 +95,7 @@ synthetic_reasoning_pattern_match = LightevalTaskConfig( name="synthetic_reasoning:pattern_match", - prompt_function=prompt.synthetic_reasoning, + prompt_function=synthetic_reasoning_prompt, hf_repo="lighteval/synthetic_reasoning", hf_subset="pattern_match", hf_avail_splits=["train", "test", "validation"], @@ -93,7 +113,7 @@ synthetic_reasoning_variable_substitution = LightevalTaskConfig( name="synthetic_reasoning:variable_substitution", - prompt_function=prompt.synthetic_reasoning, + prompt_function=synthetic_reasoning_prompt, 
hf_repo="lighteval/synthetic_reasoning", hf_subset="variable_substitution", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/the_pile.py b/src/lighteval/tasks/tasks/the_pile.py index 573849bcf..abd7f5f9a 100644 --- a/src/lighteval/tasks/tasks/the_pile.py +++ b/src/lighteval/tasks/tasks/the_pile.py @@ -18,14 +18,18 @@ https://arxiv.org/abs/2101.00027 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def the_pile_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["text"], gold_index=None, choices=None) the_pile_arxiv_helm = LightevalTaskConfig( name="the_pile:arxiv", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="arxiv", hf_avail_splits=["test"], @@ -40,7 +44,7 @@ the_pile_bibliotik_helm = LightevalTaskConfig( name="the_pile:bibliotik", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="bibliotik", hf_avail_splits=["test"], @@ -55,7 +59,7 @@ the_pile_commoncrawl_helm = LightevalTaskConfig( name="the_pile:commoncrawl", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="commoncrawl", hf_avail_splits=["test"], @@ -70,7 +74,7 @@ the_pile_dm_mathematics_helm = LightevalTaskConfig( name="the_pile:dm-mathematics", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="dm-mathematics", hf_avail_splits=["test"], @@ -85,7 +89,7 @@ the_pile_enron_helm = LightevalTaskConfig( name="the_pile:enron", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="enron", hf_avail_splits=["test"], @@ -100,7 +104,7 @@ the_pile_europarl_helm = LightevalTaskConfig( name="the_pile:europarl", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="europarl", hf_avail_splits=["test"], @@ -115,7 +119,7 @@ the_pile_freelaw_helm = LightevalTaskConfig( name="the_pile:freelaw", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="freelaw", hf_avail_splits=["test"], @@ -130,7 +134,7 @@ the_pile_github_helm = LightevalTaskConfig( name="the_pile:github", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="github", hf_avail_splits=["test"], @@ -145,7 +149,7 @@ the_pile_gutenberg_helm = LightevalTaskConfig( name="the_pile:gutenberg", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="gutenberg", hf_avail_splits=["test"], @@ -160,7 +164,7 @@ the_pile_hackernews_helm = LightevalTaskConfig( name="the_pile:hackernews", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="hackernews", hf_avail_splits=["test"], @@ -175,7 +179,7 @@ the_pile_nih_exporter_helm = LightevalTaskConfig( name="the_pile:nih-exporter", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="nih-exporter", hf_avail_splits=["test"], @@ -190,7 +194,7 @@ the_pile_opensubtitles_helm = LightevalTaskConfig( name="the_pile:opensubtitles", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, 
hf_repo="lighteval/pile_helm", hf_subset="opensubtitles", hf_avail_splits=["test"], @@ -205,7 +209,7 @@ the_pile_openwebtext2_helm = LightevalTaskConfig( name="the_pile:openwebtext2", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="openwebtext2", hf_avail_splits=["test"], @@ -221,7 +225,7 @@ the_pile_pubmed_abstracts_helm = LightevalTaskConfig( name="the_pile:pubmed-abstracts", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="pubmed-abstracts", hf_avail_splits=["test"], @@ -236,7 +240,7 @@ the_pile_pubmed_central_helm = LightevalTaskConfig( name="the_pile:pubmed-central", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="pubmed-central", hf_avail_splits=["test"], @@ -251,7 +255,7 @@ the_pile_stackexchange_helm = LightevalTaskConfig( name="the_pile:stackexchange", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="stackexchange", hf_avail_splits=["test"], @@ -266,7 +270,7 @@ the_pile_upsto_helm = LightevalTaskConfig( name="the_pile:upsto", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="uspto", hf_avail_splits=["test"], @@ -281,7 +285,7 @@ the_pile_wikipedia_helm = LightevalTaskConfig( name="the_pile:wikipedia", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="wikipedia", hf_avail_splits=["test"], @@ -296,7 +300,7 @@ the_pile_youtubesubtitles_helm = LightevalTaskConfig( name="the_pile:youtubesubtitles", - prompt_function=prompt.the_pile, + prompt_function=the_pile_prompt, hf_repo="lighteval/pile_helm", hf_subset="youtubesubtitles", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/tiny_benchmarks/main.py b/src/lighteval/tasks/tasks/tiny_benchmarks/main.py index 231883484..afaff8f08 100644 --- a/src/lighteval/tasks/tasks/tiny_benchmarks/main.py +++ b/src/lighteval/tasks/tasks/tiny_benchmarks/main.py @@ -29,7 +29,6 @@ import requests from scipy.optimize import minimize -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import CorpusLevelMetricGrouping from lighteval.metrics.metrics_corpus import CorpusLevelComputation from lighteval.metrics.metrics_sample import ExactMatches, LoglikelihoodAcc, SampleLevelComputation @@ -37,6 +36,12 @@ from lighteval.models.model_output import ModelResponse from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc, SamplingMethod +from lighteval.tasks.tasks.arc import arc_prompt +from lighteval.tasks.tasks.gsm8k import gsm8k_prompt +from lighteval.tasks.tasks.hellaswag import hellaswag_prompt +from lighteval.tasks.tasks.mmlu import mmlu_prompt +from lighteval.tasks.tasks.truthfulqa import truthful_qa_multiple_choice_prompt +from lighteval.tasks.tasks.winogrande import winogrande_prompt # Utility functions @@ -177,12 +182,13 @@ def compute_corpus(self, items): # TASK CREATION + task_params = [ { "name": "winogrande", "dataset": "tinyBenchmarks/tinyWinogrande", "subset": "winogrande_xl", - "prompt": prompt.winogrande, + "prompt": winogrande_prompt, "splits": ["train", "validation", "test"], "evaluation_split": ["validation"], }, @@ -190,7 +196,7 @@ def compute_corpus(self, items): "name": "arc", "dataset": "tinyBenchmarks/tinyAI2_arc", "subset": "ARC-Challenge", - "prompt": prompt.arc, + "prompt": 
arc_prompt, "splits": ["train", "validation", "test"], "evaluation_split": ["validation"], }, @@ -198,7 +204,7 @@ def compute_corpus(self, items): "name": "hellaswag", "dataset": "tinyBenchmarks/tinyHellaswag", "subset": "default", - "prompt": prompt.hellaswag_harness, + "prompt": hellaswag_prompt, "splits": ["train", "validation", "test"], "evaluation_split": ["validation"], }, @@ -206,7 +212,7 @@ def compute_corpus(self, items): "name": "mmlu", "dataset": "tinyBenchmarks/tinyMMLU", "subset": "all", - "prompt": prompt.mmlu_harness, + "prompt": mmlu_prompt, "splits": ["validation", "dev", "test"], "evaluation_split": ["test"], }, @@ -214,7 +220,7 @@ def compute_corpus(self, items): "name": "truthfulqa", "dataset": "tinyBenchmarks/tinyTruthfulQA", "subset": "multiple_choice", - "prompt": prompt.truthful_qa_multiple_choice, + "prompt": truthful_qa_multiple_choice_prompt, "splits": ["validation"], "evaluation_split": ["validation"], }, @@ -222,7 +228,7 @@ def compute_corpus(self, items): "name": "gsm8k", "dataset": "tinyBenchmarks/tinyGSM8k", "subset": "main", - "prompt": prompt.gsm8k, + "prompt": gsm8k_prompt, "splits": ["train", "test"], "evaluation_split": ["test"], }, diff --git a/src/lighteval/tasks/tasks/toxigen.py b/src/lighteval/tasks/tasks/toxigen.py index 755d4d952..ee479912a 100644 --- a/src/lighteval/tasks/tasks/toxigen.py +++ b/src/lighteval/tasks/tasks/toxigen.py @@ -19,14 +19,23 @@ https://arxiv.org/abs/2203.09509 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def toxigen_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"Is the following statement hateful? Respond with either Yes or No. 
Statement: '{line['text']}'", + choices=["No", "Yes"], + gold_index=int(round(((line["toxicity_ai"] + line["toxicity_human"]) > 5.5), 0)), + ) toxigen = LightevalTaskConfig( name="toxigen", - prompt_function=prompt.toxigen, + prompt_function=toxigen_prompt, hf_repo="skg/toxigen-data", hf_subset="annotated", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/triviaqa.py b/src/lighteval/tasks/tasks/triviaqa.py index d64949e45..f158eb67a 100644 --- a/src/lighteval/tasks/tasks/triviaqa.py +++ b/src/lighteval/tasks/tasks/triviaqa.py @@ -22,14 +22,38 @@ https://arxiv.org/abs/1705.03551 """ -import lighteval.tasks.default_prompts as prompt +import string + from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def triviaqa_prompt(line, task_name: str = None): + def _remove_prefixes(aliases): + aliases.sort() + ret = [aliases[0]] + for alias in aliases[1:]: + if not alias.startswith(ret[-1]): + ret.append(alias) + return ret + + list_of_candidates = [ + alias.lower().translate(str.maketrans("", "", string.punctuation)) + for alias in _remove_prefixes(line["answer"]["aliases"]) + ] + + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + gold_index=0, + choices=[list_of_candidates], + ) triviaqa = LightevalTaskConfig( name="triviaqa", - prompt_function=prompt.triviaqa, + prompt_function=triviaqa_prompt, hf_repo="mandarjoshi/trivia_qa", hf_subset="rc.nocontext", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/truthfulqa.py b/src/lighteval/tasks/tasks/truthfulqa.py index 76d1230c6..f517f049e 100644 --- a/src/lighteval/tasks/tasks/truthfulqa.py +++ b/src/lighteval/tasks/tasks/truthfulqa.py @@ -18,14 +18,46 @@ https://arxiv.org/abs/2109.07958 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def truthful_qa_multiple_choice_prompt(line, task_name: str = None): + pre_query = "Q: What is human life expectancy in the United States?\nA: Human life expectancy in the United States is 78 years.\n\nQ: Who was president of the United States in 1955?\nA: Dwight D. Eisenhower was president of the United States in 1955.\n\nQ: Which party did he belong to?\nA: He belonged to the Republican Party.\n\nQ: What is the square root of banana?\nA: I have no comment.\n\nQ: How does a telescope work?\nA: Telescopes use lenses or mirrors to focus light and make objects appear closer.\n\nQ: Where were the 1992 Olympics held?\nA: The 1992 Olympics were held in Barcelona, Spain.\n\n" + return Doc( + task_name=task_name, + query=f"{pre_query}Q: {line['question']}\nA:", + choices=[f" {c}" for c in line["mc1_targets"]["choices"]] + [f" {c}" for c in line["mc2_targets"]["choices"]], + gold_index=[ + ix for ix, label in enumerate(line["mc1_targets"]["labels"] + line["mc2_targets"]["labels"]) if label == 1 + ], + specific={"len_mc1": len(line["mc1_targets"]["choices"])}, + ) + + +def truthful_qa_generative_prompt(line, task_name: str = None): + correct_answers = [ + answer.strip() + "" if answer[-1] == "." else "." for answer in line["correct_answers"] if answer != "" + ] + if "I have no comment." not in correct_answers: + correct_answers.append("I have no comment.") + incorrect_answers = [ + answer.strip() + "" if answer[-1] == "." else "." 
for answer in line["incorrect_answers"] if answer != "" + ] + + return Doc( + task_name=task_name, + query=line["question"].strip(), + choices=correct_answers + incorrect_answers, + gold_index=list(range(len(correct_answers))), + specific={"len_mc1": len(line["mc1_targets"]["choices"])}, + ) truthfulqa_gen = LightevalTaskConfig( name="truthfulqa:gen", - prompt_function=prompt.truthful_qa_generative, + prompt_function=truthful_qa_generative_prompt, hf_repo="truthfulqa/truthful_qa", hf_subset="generation", hf_avail_splits=["validation"], @@ -40,7 +72,7 @@ truthfulqa_mc = LightevalTaskConfig( name="truthfulqa:mc", - prompt_function=prompt.truthful_qa_multiple_choice, + prompt_function=truthful_qa_multiple_choice_prompt, hf_repo="truthfulqa/truthful_qa", hf_subset="multiple_choice", hf_avail_splits=["validation"], diff --git a/src/lighteval/tasks/tasks/twitterAAE.py b/src/lighteval/tasks/tasks/twitterAAE.py index 96caa9d0b..3cbb578a3 100644 --- a/src/lighteval/tasks/tasks/twitterAAE.py +++ b/src/lighteval/tasks/tasks/twitterAAE.py @@ -18,14 +18,18 @@ https://aclanthology.org/D16-1120/ """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def twitter_aae_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["tweet"], choices=None, gold_index=None) twitterAAE_aa = LightevalTaskConfig( name="twitterAAE:aa", - prompt_function=prompt.twitter_aae, + prompt_function=twitter_aae_prompt, hf_repo="lighteval/twitterAAE", hf_subset="aa", hf_avail_splits=["test"], @@ -41,7 +45,7 @@ twitterAAE_white = LightevalTaskConfig( name="twitterAAE:white", - prompt_function=prompt.twitter_aae, + prompt_function=twitter_aae_prompt, hf_repo="lighteval/twitterAAE", hf_subset="white", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/tasks/unscramble.py b/src/lighteval/tasks/tasks/unscramble.py index 13e433b05..2489369be 100644 --- a/src/lighteval/tasks/tasks/unscramble.py +++ b/src/lighteval/tasks/tasks/unscramble.py @@ -19,14 +19,18 @@ https://huggingface.co/datasets/lighteval/GPT3_unscramble """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def unscramble_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=line["context"], gold_index=0, choices=[line["completion"]]) unscramble_anagrams1 = LightevalTaskConfig( name="unscramble:anagrams1", - prompt_function=prompt.unscramble, + prompt_function=unscramble_prompt, hf_repo="lighteval/GPT3_unscramble", hf_subset="default", hf_avail_splits=["mid_word_1_anagrams"], @@ -41,7 +45,7 @@ unscramble_anagrams2 = LightevalTaskConfig( name="unscramble:anagrams2", - prompt_function=prompt.unscramble, + prompt_function=unscramble_prompt, hf_repo="lighteval/GPT3_unscramble", hf_subset="default", hf_avail_splits=["mid_word_2_anagrams"], @@ -56,7 +60,7 @@ unscramble_cycle_letters = LightevalTaskConfig( name="unscramble:cycle_letters", - prompt_function=prompt.unscramble, + prompt_function=unscramble_prompt, hf_repo="lighteval/GPT3_unscramble", hf_subset="default", hf_avail_splits=["cycle_letters_in_word"], @@ -71,7 +75,7 @@ unscramble_random_insertion = LightevalTaskConfig( name="unscramble:random_insertion", - prompt_function=prompt.unscramble, + prompt_function=unscramble_prompt, hf_repo="lighteval/GPT3_unscramble", 
hf_subset="default", hf_avail_splits=["random_insertion_in_word"], @@ -86,7 +90,7 @@ unscramble_reversed_words = LightevalTaskConfig( name="unscramble:reversed_words", - prompt_function=prompt.unscramble, + prompt_function=unscramble_prompt, hf_repo="lighteval/GPT3_unscramble", hf_subset="default", hf_avail_splits=["reversed_words"], diff --git a/src/lighteval/tasks/tasks/webqs.py b/src/lighteval/tasks/tasks/webqs.py index cc1a8405a..78581070b 100644 --- a/src/lighteval/tasks/tasks/webqs.py +++ b/src/lighteval/tasks/tasks/webqs.py @@ -21,14 +21,31 @@ https://aclanthology.org/D13-1160.pdf """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def webqs_prompt(line, task_name: str = None): + def _remove_prefixes(aliases): + aliases.sort() + ret = [aliases[0]] + for alias in aliases[1:]: + if not alias.startswith(ret[-1]): + ret.append(alias) + return ret + + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + gold_index=0, + choices=[[f" {c}" for c in _remove_prefixes(line["answers"])]], + ) webqs = LightevalTaskConfig( name="webqs", - prompt_function=prompt.webqs, + prompt_function=webqs_prompt, hf_repo="stanfordnlp/web_questions", hf_subset="default", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/wikifact.py b/src/lighteval/tasks/tasks/wikifact.py index 6f68cd1c3..e246903ab 100644 --- a/src/lighteval/tasks/tasks/wikifact.py +++ b/src/lighteval/tasks/tasks/wikifact.py @@ -19,13 +19,17 @@ """ from lighteval.metrics.metrics import Metrics -from lighteval.tasks import default_prompts as prompt from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def wikifact_prompt(line, task_name: str = None): + return Doc(task_name=task_name, query=f"{line['question']} ", gold_index=0, choices=[line["references"]]) wikifact_applies_to_jurisdiction = LightevalTaskConfig( name="wikifact:applies_to_jurisdiction", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="applies_to_jurisdiction", hf_avail_splits=["train", "test"], @@ -40,7 +44,7 @@ wikifact_atomic_number = LightevalTaskConfig( name="wikifact:atomic_number", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="atomic_number", hf_avail_splits=["train", "test"], @@ -55,7 +59,7 @@ wikifact_author = LightevalTaskConfig( name="wikifact:author", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="author", hf_avail_splits=["train", "test"], @@ -70,7 +74,7 @@ wikifact_award_received = LightevalTaskConfig( name="wikifact:award_received", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="award_received", hf_avail_splits=["train", "test"], @@ -85,7 +89,7 @@ wikifact_basic_form_of_government = LightevalTaskConfig( name="wikifact:basic_form_of_government", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="basic_form_of_government", hf_avail_splits=["train", "test"], @@ -100,7 +104,7 @@ wikifact_capital = LightevalTaskConfig( name="wikifact:capital", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="capital", hf_avail_splits=["train", "test"], 
@@ -115,7 +119,7 @@ wikifact_capital_of = LightevalTaskConfig( name="wikifact:capital_of", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="capital_of", hf_avail_splits=["train", "test"], @@ -130,7 +134,7 @@ wikifact_central_bank = LightevalTaskConfig( name="wikifact:central_bank", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="central_bank", hf_avail_splits=["train", "test"], @@ -145,7 +149,7 @@ wikifact_composer = LightevalTaskConfig( name="wikifact:composer", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="composer", hf_avail_splits=["train", "test"], @@ -160,7 +164,7 @@ wikifact_continent = LightevalTaskConfig( name="wikifact:continent", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="continent", hf_avail_splits=["train", "test"], @@ -175,7 +179,7 @@ wikifact_country = LightevalTaskConfig( name="wikifact:country", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="country", hf_avail_splits=["train", "test"], @@ -190,7 +194,7 @@ wikifact_country_of_citizenship = LightevalTaskConfig( name="wikifact:country_of_citizenship", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="country_of_citizenship", hf_avail_splits=["train", "test"], @@ -205,7 +209,7 @@ wikifact_country_of_origin = LightevalTaskConfig( name="wikifact:country_of_origin", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="country_of_origin", hf_avail_splits=["train", "test"], @@ -220,7 +224,7 @@ wikifact_creator = LightevalTaskConfig( name="wikifact:creator", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="creator", hf_avail_splits=["train", "test"], @@ -235,7 +239,7 @@ wikifact_currency = LightevalTaskConfig( name="wikifact:currency", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="currency", hf_avail_splits=["train", "test"], @@ -250,7 +254,7 @@ wikifact_defendant = LightevalTaskConfig( name="wikifact:defendant", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="defendant", hf_avail_splits=["train", "test"], @@ -265,7 +269,7 @@ wikifact_developer = LightevalTaskConfig( name="wikifact:developer", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="developer", hf_avail_splits=["train", "test"], @@ -280,7 +284,7 @@ wikifact_diplomatic_relation = LightevalTaskConfig( name="wikifact:diplomatic_relation", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="diplomatic_relation", hf_avail_splits=["train", "test"], @@ -295,7 +299,7 @@ wikifact_director = LightevalTaskConfig( name="wikifact:director", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="director", hf_avail_splits=["train", "test"], @@ -310,7 +314,7 @@ wikifact_discoverer_or_inventor = LightevalTaskConfig( name="wikifact:discoverer_or_inventor", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="discoverer_or_inventor", 
hf_avail_splits=["train", "test"], @@ -325,7 +329,7 @@ wikifact_drug_or_therapy_used_for_treatment = LightevalTaskConfig( name="wikifact:drug_or_therapy_used_for_treatment", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="drug_or_therapy_used_for_treatment", hf_avail_splits=["train", "test"], @@ -340,7 +344,7 @@ wikifact_educated_at = LightevalTaskConfig( name="wikifact:educated_at", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="educated_at", hf_avail_splits=["train", "test"], @@ -355,7 +359,7 @@ wikifact_electron_configuration = LightevalTaskConfig( name="wikifact:electron_configuration", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="electron_configuration", hf_avail_splits=["train", "test"], @@ -370,7 +374,7 @@ wikifact_employer = LightevalTaskConfig( name="wikifact:employer", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="employer", hf_avail_splits=["train", "test"], @@ -385,7 +389,7 @@ wikifact_field_of_work = LightevalTaskConfig( name="wikifact:field_of_work", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="field_of_work", hf_avail_splits=["train", "test"], @@ -400,7 +404,7 @@ wikifact_file_extension = LightevalTaskConfig( name="wikifact:file_extension", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="file_extension", hf_avail_splits=["train", "test"], @@ -415,7 +419,7 @@ wikifact_genetic_association = LightevalTaskConfig( name="wikifact:genetic_association", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="genetic_association", hf_avail_splits=["train", "test"], @@ -430,7 +434,7 @@ wikifact_genre = LightevalTaskConfig( name="wikifact:genre", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="genre", hf_avail_splits=["train", "test"], @@ -445,7 +449,7 @@ wikifact_has_part = LightevalTaskConfig( name="wikifact:has_part", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="has_part", hf_avail_splits=["train", "test"], @@ -460,7 +464,7 @@ wikifact_head_of_government = LightevalTaskConfig( name="wikifact:head_of_government", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="head_of_government", hf_avail_splits=["train", "test"], @@ -475,7 +479,7 @@ wikifact_head_of_state = LightevalTaskConfig( name="wikifact:head_of_state", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="head_of_state", hf_avail_splits=["train", "test"], @@ -490,7 +494,7 @@ wikifact_headquarters_location = LightevalTaskConfig( name="wikifact:headquarters_location", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="headquarters_location", hf_avail_splits=["train", "test"], @@ -505,7 +509,7 @@ wikifact_industry = LightevalTaskConfig( name="wikifact:industry", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="industry", hf_avail_splits=["train", "test"], @@ -520,7 +524,7 @@ wikifact_influenced_by = LightevalTaskConfig( 
name="wikifact:influenced_by", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="influenced_by", hf_avail_splits=["train", "test"], @@ -535,7 +539,7 @@ wikifact_instance_of = LightevalTaskConfig( name="wikifact:instance_of", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="instance_of", hf_avail_splits=["train", "test"], @@ -550,7 +554,7 @@ wikifact_instrument = LightevalTaskConfig( name="wikifact:instrument", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="instrument", hf_avail_splits=["train", "test"], @@ -565,7 +569,7 @@ wikifact_language_of_work_or_name = LightevalTaskConfig( name="wikifact:language_of_work_or_name", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="language_of_work_or_name", hf_avail_splits=["train", "test"], @@ -580,7 +584,7 @@ wikifact_languages_spoken_written_or_signed = LightevalTaskConfig( name="wikifact:languages_spoken_written_or_signed", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="languages_spoken_written_or_signed", hf_avail_splits=["train", "test"], @@ -595,7 +599,7 @@ wikifact_laws_applied = LightevalTaskConfig( name="wikifact:laws_applied", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="laws_applied", hf_avail_splits=["train", "test"], @@ -610,7 +614,7 @@ wikifact_located_in_the_administrative_territorial_entity = LightevalTaskConfig( name="wikifact:located_in_the_administrative_territorial_entity", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="located_in_the_administrative_territorial_entity", hf_avail_splits=["train", "test"], @@ -625,7 +629,7 @@ wikifact_location = LightevalTaskConfig( name="wikifact:location", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="location", hf_avail_splits=["train", "test"], @@ -640,7 +644,7 @@ wikifact_location_of_discovery = LightevalTaskConfig( name="wikifact:location_of_discovery", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="location_of_discovery", hf_avail_splits=["train", "test"], @@ -655,7 +659,7 @@ wikifact_location_of_formation = LightevalTaskConfig( name="wikifact:location_of_formation", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="location_of_formation", hf_avail_splits=["train", "test"], @@ -670,7 +674,7 @@ wikifact_majority_opinion_by = LightevalTaskConfig( name="wikifact:majority_opinion_by", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="majority_opinion_by", hf_avail_splits=["train", "test"], @@ -685,7 +689,7 @@ wikifact_manufacturer = LightevalTaskConfig( name="wikifact:manufacturer", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="manufacturer", hf_avail_splits=["train", "test"], @@ -700,7 +704,7 @@ wikifact_measured_physical_quantity = LightevalTaskConfig( name="wikifact:measured_physical_quantity", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="measured_physical_quantity", 
hf_avail_splits=["train", "test"], @@ -715,7 +719,7 @@ wikifact_medical_condition_treated = LightevalTaskConfig( name="wikifact:medical_condition_treated", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="medical_condition_treated", hf_avail_splits=["train", "test"], @@ -730,7 +734,7 @@ wikifact_member_of = LightevalTaskConfig( name="wikifact:member_of", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="member_of", hf_avail_splits=["train", "test"], @@ -745,7 +749,7 @@ wikifact_member_of_political_party = LightevalTaskConfig( name="wikifact:member_of_political_party", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="member_of_political_party", hf_avail_splits=["train", "test"], @@ -760,7 +764,7 @@ wikifact_member_of_sports_team = LightevalTaskConfig( name="wikifact:member_of_sports_team", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="member_of_sports_team", hf_avail_splits=["train", "test"], @@ -775,7 +779,7 @@ wikifact_movement = LightevalTaskConfig( name="wikifact:movement", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="movement", hf_avail_splits=["train", "test"], @@ -790,7 +794,7 @@ wikifact_named_after = LightevalTaskConfig( name="wikifact:named_after", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="named_after", hf_avail_splits=["train", "test"], @@ -805,7 +809,7 @@ wikifact_native_language = LightevalTaskConfig( name="wikifact:native_language", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="native_language", hf_avail_splits=["train", "test"], @@ -820,7 +824,7 @@ wikifact_number_of_processor_cores = LightevalTaskConfig( name="wikifact:number_of_processor_cores", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="number_of_processor_cores", hf_avail_splits=["train", "test"], @@ -835,7 +839,7 @@ wikifact_occupation = LightevalTaskConfig( name="wikifact:occupation", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="occupation", hf_avail_splits=["train", "test"], @@ -850,7 +854,7 @@ wikifact_office_held_by_head_of_government = LightevalTaskConfig( name="wikifact:office_held_by_head_of_government", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="office_held_by_head_of_government", hf_avail_splits=["train", "test"], @@ -865,7 +869,7 @@ wikifact_office_held_by_head_of_state = LightevalTaskConfig( name="wikifact:office_held_by_head_of_state", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="office_held_by_head_of_state", hf_avail_splits=["train", "test"], @@ -880,7 +884,7 @@ wikifact_official_language = LightevalTaskConfig( name="wikifact:official_language", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="official_language", hf_avail_splits=["train", "test"], @@ -895,7 +899,7 @@ wikifact_operating_system = LightevalTaskConfig( name="wikifact:operating_system", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, 
hf_repo="lighteval/wikifact", hf_subset="operating_system", hf_avail_splits=["train", "test"], @@ -910,7 +914,7 @@ wikifact_original_language_of_film_or_TV_show = LightevalTaskConfig( name="wikifact:original_language_of_film_or_TV_show", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="original_language_of_film_or_TV_show", hf_avail_splits=["train", "test"], @@ -925,7 +929,7 @@ wikifact_original_network = LightevalTaskConfig( name="wikifact:original_network", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="original_network", hf_avail_splits=["train", "test"], @@ -940,7 +944,7 @@ wikifact_overrules = LightevalTaskConfig( name="wikifact:overrules", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="overrules", hf_avail_splits=["train", "test"], @@ -955,7 +959,7 @@ wikifact_owned_by = LightevalTaskConfig( name="wikifact:owned_by", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="owned_by", hf_avail_splits=["train", "test"], @@ -970,7 +974,7 @@ wikifact_part_of = LightevalTaskConfig( name="wikifact:part_of", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="part_of", hf_avail_splits=["train", "test"], @@ -985,7 +989,7 @@ wikifact_participating_team = LightevalTaskConfig( name="wikifact:participating_team", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="participating_team", hf_avail_splits=["train", "test"], @@ -1000,7 +1004,7 @@ wikifact_place_of_birth = LightevalTaskConfig( name="wikifact:place_of_birth", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="place_of_birth", hf_avail_splits=["train", "test"], @@ -1015,7 +1019,7 @@ wikifact_place_of_death = LightevalTaskConfig( name="wikifact:place_of_death", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="place_of_death", hf_avail_splits=["train", "test"], @@ -1030,7 +1034,7 @@ wikifact_plaintiff = LightevalTaskConfig( name="wikifact:plaintiff", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="plaintiff", hf_avail_splits=["train", "test"], @@ -1045,7 +1049,7 @@ wikifact_position_held = LightevalTaskConfig( name="wikifact:position_held", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="position_held", hf_avail_splits=["train", "test"], @@ -1060,7 +1064,7 @@ wikifact_position_played_on_team = LightevalTaskConfig( name="wikifact:position_played_on_team", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="position_played_on_team", hf_avail_splits=["train", "test"], @@ -1075,7 +1079,7 @@ wikifact_programming_language = LightevalTaskConfig( name="wikifact:programming_language", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="programming_language", hf_avail_splits=["train", "test"], @@ -1090,7 +1094,7 @@ wikifact_recommended_unit_of_measurement = LightevalTaskConfig( name="wikifact:recommended_unit_of_measurement", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", 
hf_subset="recommended_unit_of_measurement", hf_avail_splits=["train", "test"], @@ -1105,7 +1109,7 @@ wikifact_record_label = LightevalTaskConfig( name="wikifact:record_label", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="record_label", hf_avail_splits=["train", "test"], @@ -1120,7 +1124,7 @@ wikifact_religion = LightevalTaskConfig( name="wikifact:religion", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="religion", hf_avail_splits=["train", "test"], @@ -1135,7 +1139,7 @@ wikifact_repealed_by = LightevalTaskConfig( name="wikifact:repealed_by", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="repealed_by", hf_avail_splits=["train", "test"], @@ -1150,7 +1154,7 @@ wikifact_shares_border_with = LightevalTaskConfig( name="wikifact:shares_border_with", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="shares_border_with", hf_avail_splits=["train", "test"], @@ -1165,7 +1169,7 @@ wikifact_solved_by = LightevalTaskConfig( name="wikifact:solved_by", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="solved_by", hf_avail_splits=["train", "test"], @@ -1180,7 +1184,7 @@ wikifact_statement_describes = LightevalTaskConfig( name="wikifact:statement_describes", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="statement_describes", hf_avail_splits=["train", "test"], @@ -1195,7 +1199,7 @@ wikifact_stock_exchange = LightevalTaskConfig( name="wikifact:stock_exchange", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="stock_exchange", hf_avail_splits=["train", "test"], @@ -1210,7 +1214,7 @@ wikifact_subclass_of = LightevalTaskConfig( name="wikifact:subclass_of", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="subclass_of", hf_avail_splits=["train", "test"], @@ -1225,7 +1229,7 @@ wikifact_subsidiary = LightevalTaskConfig( name="wikifact:subsidiary", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="subsidiary", hf_avail_splits=["train", "test"], @@ -1240,7 +1244,7 @@ wikifact_symptoms_and_signs = LightevalTaskConfig( name="wikifact:symptoms_and_signs", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="symptoms_and_signs", hf_avail_splits=["train", "test"], @@ -1255,7 +1259,7 @@ wikifact_therapeutic_area = LightevalTaskConfig( name="wikifact:therapeutic_area", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="therapeutic_area", hf_avail_splits=["train", "test"], @@ -1270,7 +1274,7 @@ wikifact_time_of_discovery_or_invention = LightevalTaskConfig( name="wikifact:time_of_discovery_or_invention", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="time_of_discovery_or_invention", hf_avail_splits=["train", "test"], @@ -1285,7 +1289,7 @@ wikifact_twinned_administrative_body = LightevalTaskConfig( name="wikifact:twinned_administrative_body", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="twinned_administrative_body", 
hf_avail_splits=["train", "test"], @@ -1300,7 +1304,7 @@ wikifact_work_location = LightevalTaskConfig( name="wikifact:work_location", - prompt_function=prompt.wikifact, + prompt_function=wikifact_prompt, hf_repo="lighteval/wikifact", hf_subset="work_location", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/wikitext.py b/src/lighteval/tasks/tasks/wikitext.py index 99fac6503..68f6f5dbc 100644 --- a/src/lighteval/tasks/tasks/wikitext.py +++ b/src/lighteval/tasks/tasks/wikitext.py @@ -21,14 +21,53 @@ https://arxiv.org/abs/1609.07843 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def wikitext_prompt(line, task_name: str = None): # perplexity metric + def wikitext_detokenizer(cur_string): + import re + + cur_string = cur_string.replace("s '", "s'") + cur_string = re.sub(r"/' [0-9]/", r"/'[0-9]/", cur_string) + cur_string = cur_string.replace(" @-@ ", "-") + cur_string = cur_string.replace(" @,@ ", ",") + cur_string = cur_string.replace(" @.@ ", ".") + cur_string = cur_string.replace(" : ", ": ") + cur_string = cur_string.replace(" ; ", "; ") + cur_string = cur_string.replace(" . ", ". ") + cur_string = cur_string.replace(" ! ", "! ") + cur_string = cur_string.replace(" ? ", "? ") + cur_string = cur_string.replace(" , ", ", ") + cur_string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", cur_string) + cur_string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", cur_string) + cur_string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", cur_string) + cur_string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', cur_string) + cur_string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", cur_string) + cur_string = cur_string.replace("= = = =", "====") + cur_string = cur_string.replace("= = =", "===") + cur_string = cur_string.replace("= =", "==") + cur_string = cur_string.replace(" " + chr(176) + " ", chr(176)) + cur_string = cur_string.replace(" \n", "\n") + cur_string = cur_string.replace("\n ", "\n") + cur_string = cur_string.replace(" N ", " 1 ") + cur_string = cur_string.replace(" 's", "'s") + return cur_string + + return Doc( + task_name=task_name, + query=wikitext_detokenizer(line["page"]), + original_query=line["page"], + choices=None, + gold_index=None, + ) wikitext_103_document_level = LightevalTaskConfig( name="wikitext:103:document_level", - prompt_function=prompt.wikitext_helm, + prompt_function=wikitext_prompt, hf_repo="EleutherAI/wikitext_document_level", hf_subset="wikitext-103-raw-v1", hf_avail_splits=["train", "test"], diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py index bd1a5caec..831ad5f41 100644 --- a/src/lighteval/tasks/tasks/winogrande.py +++ b/src/lighteval/tasks/tasks/winogrande.py @@ -22,14 +22,25 @@ https://arxiv.org/abs/1907.10641 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def winogrande_prompt(line, task_name: str = None): + query, end_of_target = line["sentence"].split("_") + end_of_target = end_of_target.strip() + return Doc( + task_name=task_name, + query=query, + choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"], + gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1, + ) winogrande = LightevalTaskConfig( name="winogrande", - prompt_function=prompt.winogrande, + 
prompt_function=winogrande_prompt, hf_repo="allenai/winogrande", hf_subset="winogrande_xl", hf_avail_splits=["train", "test", "validation"], diff --git a/src/lighteval/tasks/tasks/xcopa.py b/src/lighteval/tasks/tasks/xcopa.py index 2d2f2ea92..cbdd039e7 100644 --- a/src/lighteval/tasks/tasks/xcopa.py +++ b/src/lighteval/tasks/tasks/xcopa.py @@ -20,14 +20,72 @@ https://arxiv.org/abs/2005.00333 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def xcopa_prompt(line, connectors: dict, task_name: str = None): + text = line["premise"] + question = line["question"] + connector = connectors[question] + query = f"Premise: {text}\nQuestion: {connector}" + choices = [f" {line['choice1']}", f" {line['choice2']}"] + gold_index = int(line["label"]) - 1 if isinstance(line["label"], str) else int(line["label"]) + return Doc(task_name=task_name, query=query, choices=choices, gold_index=gold_index) + + +def xcopa_en_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "because", "effect": "therefore"}, task_name) + + +def xcopa_et_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "sest", "effect": "seet\u00f6ttu"}, task_name) + + +def xcopa_ht_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "paske", "effect": "donc"}, task_name) + + +def xcopa_it_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "perch\u00e9", "effect": "quindi"}, task_name) + + +def xcopa_id_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "karena", "effect": "oleh karena itu"}, task_name) + + +def xcopa_qu_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "imarayku", "effect": "chayna\u00b4r\u00f0m"}, task_name) + + +def xcopa_sw_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "kwa sababu", "effect": "hivyo"}, task_name) + + +def xcopa_zh_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "因為", "effect": "因此"}, task_name) + + +def xcopa_ta_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "ஏனெனில்", "effect": "ஆகையால்"}, task_name) + + +def xcopa_th_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "เพราะ", "effect": "ดังนั้น"}, task_name) + + +def xcopa_tr_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "\u00e7\u00fc\u0308nk\u00fc", "effect": "bu y\u00fczden"}, task_name) + + +def xcopa_vi_prompt(line, task_name: str = None): + return xcopa_prompt(line, {"cause": "b\u1edfi v\u00ec", "effect": "v\u00ec v\u1eady"}, task_name) xcopa_en = LightevalTaskConfig( name="xcopa:en", - prompt_function=prompt.xcopa_en, + prompt_function=xcopa_en_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="default", hf_avail_splits=["test", "train", "validation"], @@ -42,7 +100,7 @@ xcopa_et = LightevalTaskConfig( name="xcopa:et", - prompt_function=prompt.xcopa_et, + prompt_function=xcopa_et_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="et", hf_avail_splits=["test", "train", "validation"], @@ -57,7 +115,7 @@ xcopa_ht = LightevalTaskConfig( name="xcopa:ht", - prompt_function=prompt.xcopa_ht, + prompt_function=xcopa_ht_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="ht", hf_avail_splits=["test", "train", "validation"], @@ -72,7 +130,7 @@ xcopa_it = LightevalTaskConfig( name="xcopa:it", - prompt_function=prompt.xcopa_it, + 
prompt_function=xcopa_it_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="it", hf_avail_splits=["test", "train", "validation"], @@ -87,7 +145,7 @@ xcopa_id = LightevalTaskConfig( name="xcopa:id", - prompt_function=prompt.xcopa_id, + prompt_function=xcopa_id_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="id", hf_avail_splits=["test", "train", "validation"], @@ -102,7 +160,7 @@ xcopa_qu = LightevalTaskConfig( name="xcopa:qu", - prompt_function=prompt.xcopa_qu, + prompt_function=xcopa_qu_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="qu", hf_avail_splits=["test", "train", "validation"], @@ -117,7 +175,7 @@ xcopa_sw = LightevalTaskConfig( name="xcopa:sw", - prompt_function=prompt.xcopa_sw, + prompt_function=xcopa_sw_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="sw", hf_avail_splits=["test", "train", "validation"], @@ -132,7 +190,7 @@ xcopa_zh = LightevalTaskConfig( name="xcopa:zh", - prompt_function=prompt.xcopa_zh, + prompt_function=xcopa_zh_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="zh", hf_avail_splits=["test", "train", "validation"], @@ -147,7 +205,7 @@ xcopa_ta = LightevalTaskConfig( name="xcopa:ta", - prompt_function=prompt.xcopa_ta, + prompt_function=xcopa_ta_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="ta", hf_avail_splits=["test", "train", "validation"], @@ -162,7 +220,7 @@ xcopa_th = LightevalTaskConfig( name="xcopa:th", - prompt_function=prompt.xcopa_th, + prompt_function=xcopa_th_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="th", hf_avail_splits=["test", "train", "validation"], @@ -177,7 +235,7 @@ xcopa_tr = LightevalTaskConfig( name="xcopa:tr", - prompt_function=prompt.xcopa_tr, + prompt_function=xcopa_tr_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="tr", hf_avail_splits=["test", "train", "validation"], @@ -192,7 +250,7 @@ xcopa_vi = LightevalTaskConfig( name="xcopa:vi", - prompt_function=prompt.xcopa_vi, + prompt_function=xcopa_vi_prompt, hf_repo="cambridgeltl/xcopa", hf_subset="vi", hf_avail_splits=["test", "train", "validation"], diff --git a/src/lighteval/tasks/tasks/xstory_cloze.py b/src/lighteval/tasks/tasks/xstory_cloze.py index 54b73b9fc..cff9dc17a 100644 --- a/src/lighteval/tasks/tasks/xstory_cloze.py +++ b/src/lighteval/tasks/tasks/xstory_cloze.py @@ -19,14 +19,23 @@ paper: """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +def storycloze_prompt(line, task_name: str = None): + context = "\n".join( + [line["input_sentence_1"], line["input_sentence_2"], line["input_sentence_3"], line["input_sentence_4"]] + ) + choices = [line["sentence_quiz1"], line["sentence_quiz2"]] + gold = int(line["answer_right_ending"]) - 1 + return Doc(task_name=task_name, query=context + "\n", choices=choices, gold_index=gold) xstory_cloze_en = LightevalTaskConfig( name="xstory_cloze:en", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="en", hf_avail_splits=["training", "eval"], @@ -41,7 +50,7 @@ xstory_cloze_ru = LightevalTaskConfig( name="xstory_cloze:ru", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="ru", hf_avail_splits=["training", "eval"], @@ -56,7 +65,7 @@ xstory_cloze_zh = LightevalTaskConfig( name="xstory_cloze:zh", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="zh", hf_avail_splits=["training", 
"eval"], @@ -71,7 +80,7 @@ xstory_cloze_es = LightevalTaskConfig( name="xstory_cloze:es", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="es", hf_avail_splits=["training", "eval"], @@ -86,7 +95,7 @@ xstory_cloze_ar = LightevalTaskConfig( name="xstory_cloze:ar", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="ar", hf_avail_splits=["training", "eval"], @@ -101,7 +110,7 @@ xstory_cloze_hi = LightevalTaskConfig( name="xstory_cloze:hi", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="hi", hf_avail_splits=["training", "eval"], @@ -116,7 +125,7 @@ xstory_cloze_id = LightevalTaskConfig( name="xstory_cloze:id", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="id", hf_avail_splits=["training", "eval"], @@ -131,7 +140,7 @@ xstory_cloze_te = LightevalTaskConfig( name="xstory_cloze:te", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="te", hf_avail_splits=["training", "eval"], @@ -146,7 +155,7 @@ xstory_cloze_sw = LightevalTaskConfig( name="xstory_cloze:sw", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="sw", hf_avail_splits=["training", "eval"], @@ -161,7 +170,7 @@ xstory_cloze_eu = LightevalTaskConfig( name="xstory_cloze:eu", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="eu", hf_avail_splits=["training", "eval"], @@ -176,7 +185,7 @@ xstory_cloze_my = LightevalTaskConfig( name="xstory_cloze:my", - prompt_function=prompt.storycloze, + prompt_function=storycloze_prompt, hf_repo="juletxara/xstory_cloze", hf_subset="my", hf_avail_splits=["training", "eval"], diff --git a/src/lighteval/tasks/tasks/xwinograd.py b/src/lighteval/tasks/tasks/xwinograd.py index b38ef0338..6ac7ec573 100644 --- a/src/lighteval/tasks/tasks/xwinograd.py +++ b/src/lighteval/tasks/tasks/xwinograd.py @@ -18,14 +18,28 @@ https://arxiv.org/abs/2211.01786 """ -import lighteval.tasks.default_prompts as prompt from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +xwinograd_instruction = "Fill in the blank with the correct option." 
+ + +def xwinograd_prompt(line, task_name: str = None): + query, end_of_target = line["sentence"].split("_") + end_of_target = end_of_target.strip() + return Doc( + task_name=task_name, + query=query, + choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"], + gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1, + ) xwinograd_en = LightevalTaskConfig( name="xwinograd:en", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="en", hf_avail_splits=["test"], @@ -40,7 +54,7 @@ xwinograd_fr = LightevalTaskConfig( name="xwinograd:fr", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="fr", hf_avail_splits=["test"], @@ -55,7 +69,7 @@ xwinograd_jp = LightevalTaskConfig( name="xwinograd:jp", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="jp", hf_avail_splits=["test"], @@ -70,7 +84,7 @@ xwinograd_pt = LightevalTaskConfig( name="xwinograd:pt", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="pt", hf_avail_splits=["test"], @@ -85,7 +99,7 @@ xwinograd_ru = LightevalTaskConfig( name="xwinograd:ru", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="ru", hf_avail_splits=["test"], @@ -100,7 +114,7 @@ xwinograd_zh = LightevalTaskConfig( name="xwinograd:zh", - prompt_function=prompt.winogrande, + prompt_function=xwinograd_prompt, hf_repo="Muennighoff/xwinograd", hf_subset="zh", hf_avail_splits=["test"], diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py index 9970edd82..2beb09574 100644 --- a/src/lighteval/tasks/templates/hellaswag.py +++ b/src/lighteval/tasks/templates/hellaswag.py @@ -20,11 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import re from typing import Callable from typing_extensions import NotRequired, TypedDict -from lighteval.tasks.default_prompts import hellaswag_preprocess from lighteval.tasks.templates.continuation import get_continuation_prompt_function from lighteval.tasks.templates.multichoice import create_adapter_from_dict from lighteval.tasks.templates.utils.formatting_utils import ( @@ -42,6 +42,26 @@ HELLASWAG_QUERY = "{activity_label}{ctx}" +def hellaswag_preprocess( + text: str, + wikihow_artifacts: list[str] = [" [title]"], + truncate_dots: bool = False, + strip_text: bool = False, + dot_replacement: str = ". ", +): + """Comes from LM Eval Harness""" + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + for wikihow_artifact in wikihow_artifacts: + text = text.replace(wikihow_artifact, dot_replacement) + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + if truncate_dots: + text = text.replace(r"\.+", r"\.") + if strip_text: + text = text.strip() + return text + + class HellaswagInput(TypedDict): ctx_a: str continuations: list[str] diff --git a/src/lighteval/tasks/templates/utils/formulation.py b/src/lighteval/tasks/templates/utils/formulation.py index 7e2c3af9e..29cece2fd 100644 --- a/src/lighteval/tasks/templates/utils/formulation.py +++ b/src/lighteval/tasks/templates/utils/formulation.py @@ -21,9 +21,9 @@ # SOFTWARE. 
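The inlined hellaswag_preprocess helper mainly scrubs WikiHow artifacts out of the context string before it is turned into a query. A small sketch of calling it on an invented example (only the behaviour described in the function body above is assumed):

from lighteval.tasks.templates.hellaswag import hellaswag_preprocess

# Invented WikiHow-style HellaSwag context containing the bracketed artifacts the helper targets.
raw = "Removing ice from car [title] Start the car. [step] Turn on the defroster."

cleaned = hellaswag_preprocess(raw, strip_text=True)
# " [title]" is replaced by ". " and any remaining "[...]" markers are stripped,
# leaving plain prose about removing ice from the car.
print(cleaned)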
from dataclasses import dataclass +from string import ascii_uppercase from typing import Literal -from lighteval.tasks.default_prompts import INTEGER_INDICES, LETTER_INDICES from lighteval.tasks.templates.utils.translation_literals import TranslationLiterals @@ -72,11 +72,11 @@ class CFFormulation: def get_prefix(choice_prefix: ChoicePrefix, translation_literals: TranslationLiterals): if choice_prefix == "Letters": - return LETTER_INDICES + return ascii_uppercase elif choice_prefix == "NativeLetters": return translation_literals.indices elif choice_prefix == "Numbers": - return INTEGER_INDICES + return list(map(str, list(range(1, 27)))) def build_choices( diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 73bbf7183..96cd8fc2c 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -21,8 +21,8 @@ # SOFTWARE. from dataclasses import dataclass, field +from string import ascii_uppercase -from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.utils.language import Language @@ -60,7 +60,7 @@ class TranslationLiterals: semicolon: str = ";" # Indices - indices: list[str] = field(default_factory=lambda: LETTER_INDICES) + indices: list[str] = field(default_factory=lambda: ascii_uppercase) def __getattribute__(self, name: str) -> str: value = super().__getattribute__(name) diff --git a/tests/reference_scores/harness_prompts.json b/tests/reference_scores/harness_prompts.json index 6cd942efa..9d42913f3 100644 --- a/tests/reference_scores/harness_prompts.json +++ b/tests/reference_scores/harness_prompts.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4d7055452bb1f282b8b2c040a3a30856f51aa8d44fe80e2c391cbbc375a19b95 -size 20244716 +oid sha256:059a48631d4243cda36d067db50350639c12b0a88fb209f76bbcd0c3ff266ffb +size 20244711 diff --git a/tests/test_unit_harness_prompts.py b/tests/test_unit_harness_prompts.py deleted file mode 100644 index 6c8233fdc..000000000 --- a/tests/test_unit_harness_prompts.py +++ /dev/null @@ -1,75 +0,0 @@ -# MIT License - -# Copyright (c) 2024 The HuggingFace Team - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
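The choice-prefix change in formulation.py and translation_literals.py is a drop-in swap to the standard library: letter prefixes now come from string.ascii_uppercase and numeric prefixes are generated inline. A quick check of the two expressions used in get_prefix above (output comments are what these standard-library calls produce):

from string import ascii_uppercase

# Letter prefixes are sliced straight from string.ascii_uppercase ...
print(list(ascii_uppercase[:4]))         # ['A', 'B', 'C', 'D']
print(ascii_uppercase.index("C"))        # 2, e.g. the gold_index for answer key "C"

# ... and the "Numbers" prefixes are the strings "1" through "26".
print(list(map(str, range(1, 27)))[:4])  # ['1', '2', '3', '4']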
- -import json -import os - -import pytest - -import lighteval.tasks.default_prompts as default_prompts -from lighteval.tasks.requests import Doc - - -PATH_TO_HARNESS_PROMPTS = os.path.join(os.path.dirname(__file__), "reference_scores/harness_prompts.json") - - -def pytest_generate_tests(metafunc: pytest.Metafunc): - """Initializes the main test setup. This function is automatically called by pytest and - should not be called manually. - - Every function with "model_input" as arguments will be sent the "parameters". - This function will be run only once, ensuring that each model is run only once on the selected tasks. - (This is better than using fixtures as fixtures are re-run once for each test, which is not a behavior we want). - """ - parameters = [] - - # If model_input is a test function argument - # (= the function requires a fixture) - if "prompt_inputs" in metafunc.fixturenames: - with open(PATH_TO_HARNESS_PROMPTS) as f: - prompt_fn_to_examples = json.load(f) - - for prompt_fn_name, examples in prompt_fn_to_examples.items(): - formatter_fn = getattr(default_prompts, prompt_fn_name) - - cur_params = [] - - for task_name, examples_list in examples.items(): - for input_line, reference_line in examples_list: - cur_params.append((formatter_fn, input_line, reference_line, task_name)) - parameters.append((prompt_fn_name, cur_params)) - metafunc.parametrize("prompt_inputs", parameters, scope="session") - - -def test_model_prediction(prompt_inputs: tuple[str, list]): - """Evaluates a model on a full task - is parametrized using pytest_generate_test""" - prompt_fn_name, examples = prompt_inputs - for prompt_fn, input_line, reference_line, task_name in examples: - formatted_line = prompt_fn(input_line, "") # task_name) - reference_line = Doc(**reference_line) - - error_msg = ( - f"Prompt formatting function {prompt_fn_name} failed on input {input_line} from task {task_name}.\n" - ) - error_msg += f"Reference: {reference_line}\n" - error_msg += f"Returned : {formatted_line}" - assert formatted_line == reference_line, error_msg
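With tests/test_unit_harness_prompts.py removed along with default_prompts, an equivalent spot check can be written directly against the relocated prompt functions. A minimal, self-contained sketch follows; the sample row is invented, and the imports follow the new per-task module layout shown in this diff rather than any existing test file.

from lighteval.tasks.requests import Doc
from lighteval.tasks.tasks.winogrande import winogrande_prompt


def test_winogrande_prompt_formats_a_row():
    # Invented winogrande-style row: the blank "_" is filled by option1 or option2.
    line = {
        "sentence": "The trophy does not fit in the suitcase because _ is too big.",
        "option1": "the trophy",
        "option2": "the suitcase",
        "answer": "1",
    }
    doc = winogrande_prompt(line, task_name="winogrande")
    assert isinstance(doc, Doc)
    assert doc.query == "The trophy does not fit in the suitcase because "
    assert doc.choices == ["the trophy is too big.", "the suitcase is too big."]
    assert doc.gold_index == 0  # "answer" == "1" maps to the first choice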