Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

internal #483

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,21 @@ def get_mixture_or_task(task_or_mixture_name: str):
)


def maybe_get_mixture_or_task(
task: Union[str, Task, Mixture]
) -> Union[Task, Mixture]:
"""Given a task name, Task, or Mixture object, return an object."""
if isinstance(task, str):
return get_mixture_or_task(task)

if not isinstance(task, (Task, Mixture)):
raise ValueError(
"User passed in a task that was not a string, Task, or Mixture."
f"Got type: {type(task)}"
)
return task


def get_subtasks(task_or_mixture):
"""Returns all the Tasks in a Mixture as a list or the Task itself."""
if isinstance(task_or_mixture, Task):
Expand Down
40 changes: 19 additions & 21 deletions seqio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def num_input_examples(self, split: str) -> Optional[int]:


def mixture_or_task_with_truncated_data(
mixture_or_task_name: str,
mixture_or_task: Union[dp.Task, dp.Mixture, str],
new_mixture_or_task_name: str,
*,
split_sizes: Mapping[str, int],
Expand All @@ -263,7 +263,8 @@ def mixture_or_task_with_truncated_data(
and few-shot fine-tuning datasets.

Args:
mixture_or_task_name: The name of the original Task or Mixture.
mixture_or_task: The original Task or Mixture, or the name of a registered
Task or Mixture.
new_mixture_or_task_name: The name of the new Task or Mixture. For Mixtures,
this is also used as a prefix for subtasks, e.g. "subtask_1" is registered
with the new vocabulary as "new_mixture_or_task_name.subtask_1".
Expand All @@ -277,48 +278,45 @@ def mixture_or_task_with_truncated_data(
Returns:
The new `Task` or `Mixture` object.
"""
if isinstance(mixture_or_task, str):
mixture_or_task = dp.get_mixture_or_task(mixture_or_task)

if mixture_or_task_name in dp.TaskRegistry.names():
if isinstance(mixture_or_task, dp.Task):
# This is a `Task`.
og_task: dp.Task = dp.get_mixture_or_task(mixture_or_task_name)

new_task = dp.Task(
new_mixture_or_task_name,
source=TruncatedDatasetProvider(
og_task.source,
mixture_or_task.source,
split_sizes=split_sizes,
shuffle_buffer_size=og_task._shuffle_buffer_size,
shuffle_buffer_size=mixture_or_task._shuffle_buffer_size,
),
output_features=og_task.output_features,
preprocessors=og_task.preprocessors,
postprocess_fn=og_task.postprocessor,
metric_fns=og_task.metric_fns,
shuffle_buffer_size=og_task._shuffle_buffer_size,
output_features=mixture_or_task.output_features,
preprocessors=mixture_or_task.preprocessors,
postprocess_fn=mixture_or_task.postprocessor,
metric_fns=mixture_or_task.metric_fns,
shuffle_buffer_size=mixture_or_task._shuffle_buffer_size,
)
if add_to_seqio_registry:
dp.TaskRegistry.add_provider(new_mixture_or_task_name, new_task)
return new_task
else:
# This is a Mixture. Create and register new sub-Tasks/Mixtures with the
# provided vocab/output_features, then create a new Mixture.
og_mix: dp.Mixture = dp.get_mixture_or_task(mixture_or_task_name)

new_tasks_and_rates = []
for task_name, rate in og_mix._task_to_rate.items():
new_task_name = f"{new_mixture_or_task_name}.{task_name}"
_ = mixture_or_task_with_truncated_data(
for task_name, rate in mixture_or_task._task_to_rate.items():
new_task = mixture_or_task_with_truncated_data(
task_name,
new_task_name,
f"{new_mixture_or_task_name}.{task_name}",
split_sizes=split_sizes,
add_to_seqio_registry=True,
add_to_seqio_registry=add_to_seqio_registry,
)
new_tasks_and_rates.append((new_task_name, rate))
new_tasks_and_rates.append((new_task, rate))

new_mix = dp.Mixture(
new_mixture_or_task_name,
new_tasks_and_rates,
default_rate=None,
sample_fn=og_mix._sample_fn,
sample_fn=mixture_or_task._sample_fn,
)
if add_to_seqio_registry:
dp.MixtureRegistry.add_provider(new_mixture_or_task_name, new_mix)
Expand Down