Add bias-shades #37

Merged · 18 commits · Nov 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
@@ -70,6 +70,12 @@ optional arguments:
be comma-separated keyword args, e.g. `key1=value1,key2=value2`, with no spaces
--task_name TASK_NAME
Name of the task to use as found in the lm_eval registry. See: `lm_eval.list_tasks()`
--task_args TASK_ARGS
Optional task constructor args that you'd pass into the task class specified by `--task_name`.
These must be comma-separated keyword args, e.g. `key1=value1,key2=value2`, with no spaces.
WARNING: To avoid parsing errors, ensure your strings are quoted. For example,
`example_separator='\n+++\n'`
WARNING: Values must NOT contain commas.
--template_names TEMPLATE_NAMES
Comma-separated list of template names for the specified task. Example:
`> python main.py ... --task_name rte --template_names imply,mean`
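For instance, a full command combining these flags might look like the following (model arguments elided; task, template, and separator values illustrative):
`> python main.py ... --task_name wnli --template_names confident --task_args "save_examples=False,example_separator='\n+++\n'"`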
5 changes: 5 additions & 0 deletions lm_eval/api/task.py
@@ -254,9 +254,14 @@ def __init__(
Example:
Q: Where is the Eiffel Tower located? A:{text_target_separator}Paris
"""
assert isinstance(save_examples, bool), "`save_examples` must be a bool."
assert isinstance(example_separator, str) and isinstance(
text_target_separator, str
), "Separator args must be strings."
assert (
text_target_separator.isspace()
), f"`text_target_separator` must be whitespace only. Got: `{text_target_separator}`"

super().__init__(data_dir, cache_dir, download_mode)
self.prompt_template = prompt_template
self.save_examples = save_examples
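A quick standalone sketch of the new separator guard (plain Python semantics, not part of the diff): the whitespace rule rides on `str.isspace()`, which also returns `False` for the empty string, so an empty `text_target_separator` is rejected as well.

```python
# Illustration of the `isspace` guard used above (standalone, not library code):
assert " ".isspace() and " \t ".isspace()  # whitespace-only separators pass
assert not "-->".isspace()                 # visible characters trip the assertion
assert not "".isspace()                    # the empty string fails too
```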
6 changes: 4 additions & 2 deletions lm_eval/api/utils.py
@@ -287,13 +287,15 @@ def parse_cli_args_string(args: str) -> dict:
"""Parses a string in the following format to a kwargs dictionary.
"args1=val1,arg2=val2"
"""
-    args = args.strip()
+    # Remove leading whitespace but not trailing in case a `val` contains necessary whitespace.
+    args = args.lstrip()
if not args:
return {}
arg_list = args.split(",")
args_dict = {}
for arg in arg_list:
-        k, v = arg.split("=")
+        # Split on the first `=` to allow for `=`s in `val`.
+        k, v = arg.split("=", 1)
args_dict[k] = str_to_builtin_type(v)
return args_dict
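A usage sketch of the updated parser (output assumes `str_to_builtin_type` coerces numeric and boolean literals and leaves other values as strings):

```python
from lm_eval.api.utils import parse_cli_args_string

# Leading whitespace before the first key is dropped; only the first `=` splits each pair.
print(parse_cli_args_string("  num_fewshot=2,example_separator=a=b"))
# Expected: {'num_fewshot': 2, 'example_separator': 'a=b'}
```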

10 changes: 9 additions & 1 deletion lm_eval/evaluator.py
@@ -25,6 +25,7 @@ def cli_evaluate(
model_api_name: str,
model_args: str,
task_name: str,
task_args: str,
template_names: List[str],
num_fewshot: Optional[int] = 0,
batch_size: Optional[int] = None,
@@ -47,6 +48,10 @@
`lm_eval.api.model.get_model_from_args_string`
task_name (str):
The task name of the task to evaluate the model on.
task_args (str):
String arguments for the task. See:
`lm_eval.tasks.get_task_list_from_args_string`
WARNING: To avoid parse errors, separators must not contain commas.
template_names (List[str]):
List of template names for the specified `task_name` to evaluate
under.
@@ -69,7 +74,9 @@
Returns:
Dictionary of results.
"""
-    tasks = lm_eval.tasks.get_task_list(task_name, template_names)
+    tasks = lm_eval.tasks.get_task_list_from_args_string(
+        task_name, template_names, task_args
+    )
model = lm_eval.models.get_model_from_args_string(
model_api_name, model_args, {"batch_size": batch_size, "device": device}
)
@@ -93,6 +100,7 @@
results["config"] = {
"model": model_api_name,
"model_args": model_args,
"task_args": task_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"device": device,
2 changes: 1 addition & 1 deletion lm_eval/models/__init__.py
@@ -54,7 +54,7 @@ def get_model_from_args_string(
model_args: A string of comma-separated key=value pairs that will be passed
to the model constructor. E.g. "pretrained=gpt2,batch_size=32".
additional_config: An additional dictionary of key=value pairs that will be
-        passed to the model constructor
+        passed to the model constructor.

Returns:
A language model instance.
39 changes: 38 additions & 1 deletion lm_eval/tasks/__init__.py
@@ -1,10 +1,12 @@
import logging
-from typing import List, Tuple, Type, Union
+from typing import List, Mapping, Tuple, Type, Optional, Union
from promptsource.templates import DatasetTemplates

import lm_eval.api.utils
from lm_eval.api.task import Task

from . import anli
from . import bias_shades
from . import blimp
from . import diabla
from . import cnn_dailymail
@@ -134,6 +136,9 @@
# WMT
# Format: `wmt{year}_{lang1}_{lang2}`
**wmt.construct_tasks(),
# Bias-Shades
# Format: `bias_shades_{lang}`
**bias_shades.construct_tasks(),
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
@@ -280,6 +285,38 @@ def get_templates(task_name: str) -> DatasetTemplates:
return _get_templates_from_task(task_class)


def get_task_list_from_args_string(
task_name: str,
template_names: List[str],
task_args: str,
additional_config: Optional[Mapping[str, str]] = None,
) -> List[Task]:
"""Returns a list of the same task but with multiple prompt templates, each
task instantiated with the given kwargs.

Args:
task_name: Name of the task to use as found in the task registry.
template_names: Names of the prompt templates from `promptsource` to use
for this task.
task_args: A string of comma-separated key=value pairs that will be passed
to the task constructor. E.g. "data_dir=./datasets,example_separator=\n\n"
additional_config: An additional dictionary of key=value pairs that will
be passed to the task constructor.

Returns:
A list of `Task` instances.
"""
kwargs = lm_eval.api.utils.parse_cli_args_string(task_args)
assert "prompt_template" not in kwargs, (
"Cannot specify a `prompt_template` object in the `task_args` string. "
"Only primitive type arguments are allowed."
)
additional_config = {} if additional_config is None else additional_config
additional_args = {k: v for k, v in additional_config.items() if v is not None}
kwargs.update(additional_args)
return get_task_list(task_name, template_names, **kwargs)
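For reference, a usage sketch of this helper (the `wnli`/`confident` pair mirrors the tests below; separator values illustrative):

```python
import lm_eval.tasks as tasks

task_list = tasks.get_task_list_from_args_string(
    task_name="wnli",
    template_names=["confident"],
    task_args="save_examples=False,example_separator=\n\n",
    additional_config={"data_dir": None},  # None values are filtered out before use
)
assert task_list[0].save_examples is False
```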


# Helper functions


71 changes: 71 additions & 0 deletions lm_eval/tasks/bias_shades.py
@@ -0,0 +1,71 @@
"""
Multilingual dataset for measuring social biases in language models.
https://huggingface.co/datasets/BigScienceBiasEval/bias-shades/viewer/spanish/test

TODO: Add `arabic`, `german`, `russian`, and `tamil` subsets when `promptsource`
templates become available.
"""
from lm_eval.api.task import PromptSourceTask


_CITATION = """"""


class BiasShadesBase(PromptSourceTask):
def has_training_docs(self):
return False

def has_validation_docs(self):
return False

def has_test_docs(self):
return True

def training_docs(self):
pass

def validation_docs(self):
pass

def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]


class BiasShadesEnglish(BiasShadesBase):
VERSION = 0
DATASET_PATH = "BigScienceBiasEval/bias-shades"
DATASET_NAME = "english"


class BiasShadesFrench(BiasShadesBase):
VERSION = 0
DATASET_PATH = "BigScienceBiasEval/bias-shades"
DATASET_NAME = "french"


class BiasShadesHindi(BiasShadesBase):
VERSION = 0
DATASET_PATH = "BigScienceBiasEval/bias-shades"
DATASET_NAME = "hindi"


class BiasShadesSpanish(BiasShadesBase):
VERSION = 0
DATASET_PATH = "BigScienceBiasEval/bias-shades"
DATASET_NAME = "spanish"


BIAS_SHADES_CLASSES = [
BiasShadesEnglish,
BiasShadesFrench,
BiasShadesHindi,
BiasShadesSpanish,
]


def construct_tasks():
tasks = {}
for bias_shades_class in BIAS_SHADES_CLASSES:
tasks[f"bias_shades_{bias_shades_class.DATASET_NAME}"] = bias_shades_class
return tasks
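As a sanity check on the factory, it yields one registry entry per language subset (sketch grounded in the classes above):

```python
from lm_eval.tasks import bias_shades

assert sorted(bias_shades.construct_tasks()) == [
    "bias_shades_english",
    "bias_shades_french",
    "bias_shades_hindi",
    "bias_shades_spanish",
]
```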
11 changes: 11 additions & 0 deletions main.py
@@ -31,6 +31,16 @@ def parse_args():
help="Name of the task to use as found "
"in the lm_eval registry. See: `lm_eval.list_tasks()`",
)
parser.add_argument(
"--task_args",
default="",
help="""Optional task constructor args that you'd pass into a task class of kind "
`--task_name`. These must be comma-separated keyword args, e.g.
`key1=value1,key2=value2`, with no spaces.
WARNING: To avoid parsing errors, ensure your strings are quoted. For example,
`example_separator='\\n+++\\n'`
WARNING: Values must NOT contain commas.""",
)
parser.add_argument(
"--template_names",
default="all_templates",
@@ -167,6 +177,7 @@ def main():
model_api_name=args.model_api_name,
model_args=args.model_args,
task_name=args.task_name,
task_args=args.task_args,
template_names=template_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
99 changes: 98 additions & 1 deletion tests/test_tasks.py
@@ -1,13 +1,14 @@
import logging
import pytest
import numpy as np
from typing import Optional, Tuple
from itertools import islice
from promptsource.templates import Template

import lm_eval.tasks as tasks
from lm_eval.api.task import Task
from lm_eval.api.request import Request
-from lm_eval.api.utils import set_seed
+from lm_eval.api.utils import set_seed, DEFAULT_SEED


logger = logging.getLogger(__name__)
@@ -132,3 +133,99 @@ def test_documents_and_requests(task_name: str, task_class: Task):
# TODO: Mock lm after refactoring evaluator.py to not be a mess
for req in requests:
assert isinstance(req, Request)


def test_arg_string_task_creation():
import itertools

TEST_EXAMPLE_SEPS = [
# Test `=` symbol in value string
"\n===TEST_SEPARATOR===\n",
# Test whitespace only separators
" ",
" \t\t ",
"\n\n\n\n",
# Test empty string separator
"",
# Test misc. symbols in separator
"[[[[]]]]",
"<<___>>",
"(())",
]
TEST_TEXT_TARGET_SEPS = [
# Test whitespace separators
" ",
" \t ",
"\n\n\n",
]

# Ensure parsing properly handles args.
for example_sep, text_target_sep in itertools.product(
TEST_EXAMPLE_SEPS, TEST_TEXT_TARGET_SEPS
):
test_arg_string = f" save_examples=False,example_separator={example_sep},text_target_separator={text_target_sep}"
task = tasks.get_task_list_from_args_string(
"wnli",
template_names=["confident"],
task_args=test_arg_string,
)[0]

assert task.save_examples is False
assert task.example_separator == example_sep
assert task.text_target_separator == text_target_sep

# Ensure fewshot context is formatted as expected.
TEST_EXAMPLE_SEP = "\n===TEST_SEPARATOR===\n"
TEST_TEXT_TARGET_SEP = " "
test_arg_string = f" save_examples=False,example_separator={TEST_EXAMPLE_SEP},text_target_separator={TEST_TEXT_TARGET_SEP}"
task = tasks.get_task_list_from_args_string(
"wnli",
template_names=["confident"],
task_args=test_arg_string,
)[0]
context = task.fewshot_context(
task.validation_docs()[0],
num_fewshot=2,
rng=np.random.default_rng(DEFAULT_SEED),
)[0]
expected = f"""If it's true that
The man couldn't lift his son because he was so heavy.
how confident should I be that
The man was so heavy.
very confident or not confident? not confident
===TEST_SEPARATOR===
If it's true that
As Ollie carried Tommy up the long winding steps, his legs ached.
how confident should I be that
Ollie's legs ached.
very confident or not confident? very confident
===TEST_SEPARATOR===
If it's true that
The drain is clogged with hair. It has to be cleaned.
how confident should I be that
The hair has to be cleaned.
very confident or not confident?"""
assert context == expected

    # Ensure tasks don't instantiate with invalid args. Each case gets its own
    # `pytest.raises` context; otherwise only the first call would ever execute.
    bad_save_examples_arg_string = "example_separator=\t,save_examples=yes"
    with pytest.raises(AssertionError):
        tasks.get_task_list_from_args_string(
            "wnli",
            template_names=["confident"],
            task_args=bad_save_examples_arg_string,
        )

    bad_example_sep_arg_string = "example_separator=False,save_examples=False"
    with pytest.raises(AssertionError):
        tasks.get_task_list_from_args_string(
            "wnli",
            template_names=["confident"],
            task_args=bad_example_sep_arg_string,
        )

    bad_text_sep_arg_string = "text_target_separator=___"
    with pytest.raises(AssertionError):
        tasks.get_task_list_from_args_string(
            "wnli",
            template_names=["confident"],
            task_args=bad_text_sep_arg_string,
        )