Skip to content

Commit

Permalink
added_outline_for_table
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Apr 23, 2024
1 parent d6bd6f4 commit 8140f5f
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/creat_tasks_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from __future__ import annotations

import mteb

2 changes: 2 additions & 0 deletions mteb/abstasks/AbsTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def languages(self) -> set[str]:
"""Returns the languages of the task"""
return self.metadata.languages



def __repr__(self) -> str:
"""Format the representation of the task such that it appears as:
Expand Down
117 changes: 117 additions & 0 deletions mteb/overview_fns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""This script contains functions that are used to get an overview of the MTEB benchmark."""

from __future__ import annotations

import logging

from mteb.abstasks import AbsTask
from mteb.abstasks.languages import (
ISO_TO_LANGUAGE,
ISO_TO_SCRIPT,
path_to_lang_codes,
path_to_lang_scripts,
)
from mteb.abstasks.TaskMetadata import TASK_DOMAIN, TASK_TYPE
from mteb.tasks import * # import all tasks

logger = logging.getLogger(__name__)


def check_is_valid_script(script: str) -> None:
if script not in ISO_TO_SCRIPT:
raise ValueError(
f"Invalid script code: {script}, you can find valid ISO 15924 codes in {path_to_lang_scripts}"
)


def check_is_valid_language(lang: str) -> None:
if lang not in ISO_TO_LANGUAGE:
raise ValueError(
f"Invalid language code: {lang}, you can find valid ISO 639-3 codes in {path_to_lang_codes}"
)


def filter_superseeded_datasets(tasks: list[AbsTask]) -> list[AbsTask]:
return [t for t in tasks if t.is_superseeded is None]


def filter_tasks_by_languages(
tasks: list[AbsTask], languages: list[str]
) -> list[AbsTask]:
[check_is_valid_language(lang) for lang in languages]
langs_to_keep = set(languages)
return [t for t in tasks if langs_to_keep.intersection(t.metadata.languages)]


def filter_tasks_by_script(tasks: list[AbsTask], script: list[str]) -> list[AbsTask]:
[check_is_valid_script(s) for s in script]
script_to_keep = set(script)
return [t for t in tasks if script_to_keep.intersection(t.metadata.scripts)]


def filter_tasks_by_domains(
tasks: list[AbsTask], domains: list[TASK_DOMAIN]
) -> list[AbsTask]:
domains_to_keep = set(domains)

def _convert_to_set(domain: list[TASK_DOMAIN] | None) -> set:
return set(domain) if domain is not None else set()

return [
t
for t in tasks
if domains_to_keep.intersection(_convert_to_set(t.metadata.domains))
]


def filter_tasks_by_task_type(
tasks: list[AbsTask], task_type: TASK_TYPE
) -> list[AbsTask]:
return [t for t in tasks if t.metadata.type == task_type]


def get_tasks(
languages: list[str] | None = None,
script: list[str] | None = None,
domains: list[TASK_DOMAIN] | None = None,
task_type: TASK_TYPE | None = None,
exclude_superseeded_datasets: bool = True,
) -> list[AbsTask]:
"""Get a list of tasks based on the specified filters.
Args:
languages: A list of languages either specified as 3 letter languages codes (ISO 639-3, e.g. "eng") or as script languages codes e.g.
"eng-Latn".
script: A list of script codes (ISO 15924 codes). If None, all scripts are included.
domains: A list of task domains.
task_type: A string specifying the type of task. If None, all tasks are included.
exclude_superseeded_datasets: A boolean flag to exclude datasets which are superseeded by another.
Returns:
A list of all initialized tasks objects which pass all of the filters (AND operation).
Examples:
>>> get_tasks(languages=["eng", "deu"], script=["Latn"], domains=["Legal"])
>>> get_tasks(languages=["eng"], script=["Latn"], task_type="Classification")
>>> get_tasks(languages=["eng"], script=["Latn"], task_type="Clustering", exclude_superseeded_datasets=False)
"""
tasks_categories_cls = [cls for cls in AbsTask.__subclasses__()]
tasks = [
cls()
for cat_cls in tasks_categories_cls
for cls in cat_cls.__subclasses__()
if cat_cls.__name__.startswith("AbsTask")
]

if languages:
tasks = filter_tasks_by_languages(tasks, languages)
if script:
tasks = filter_tasks_by_script(tasks, script)
if domains:
tasks = filter_tasks_by_domains(tasks, domains)
if exclude_superseeded_datasets:
tasks = filter_superseeded_datasets(tasks)
if task_type:
tasks = filter_tasks_by_task_type(tasks, task_type)

return tasks

0 comments on commit 8140f5f

Please sign in to comment.