Skip to content

Commit

Permalink
mixture_or_task_with_new_vocab should respect `add_to_seqio_registr…
Browse files Browse the repository at this point in the history
…y` when creating subtasks.

PiperOrigin-RevId: 512466234
  • Loading branch information
dhgarrette authored and SeqIO committed Feb 26, 2023
1 parent b5aadfe commit 95476d1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
6 changes: 3 additions & 3 deletions seqio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ def _validate_output_features(og_output_features, new_output_features):
new_tasks_and_rates = []
for task_name, rate in og_mix._task_to_rate.items():
new_task_name = f"{new_mixture_or_task_name}.{task_name}"
_ = mixture_or_task_with_new_vocab(
new_task = mixture_or_task_with_new_vocab(
task_name,
new_task_name,
new_vocab=new_vocab,
new_output_features=new_output_features,
add_to_seqio_registry=True,
add_to_seqio_registry=add_to_seqio_registry,
)
new_tasks_and_rates.append((new_task_name, rate))
new_tasks_and_rates.append((new_task, rate))

new_mix = dp.Mixture(
new_mixture_or_task_name,
Expand Down
33 changes: 26 additions & 7 deletions seqio/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,35 @@ def test_mixture_new_output_features(self):
add_to_seqio_registry=False,
)

# Step 4: Get new Tasks and Mixtures from the Registry.
# Step 4: Get new Tasks and Mixtures.
self.assertNotIn("my_new_test_mix2", dp.MixtureRegistry.names())
new_submix = dp.get_mixture_or_task("my_new_test_mix2.my_test_mix1")
new_submix_subtask1 = dp.get_mixture_or_task(
"my_new_test_mix2.my_test_mix1.my_test_task1"

self.assertNotIn(
"my_new_test_mix2.my_test_mix1", dp.MixtureRegistry.names()
)
new_submix_subtask2 = dp.get_mixture_or_task(
"my_new_test_mix2.my_test_mix1.my_test_task2"
self.assertLen(new_mix._sub_mixtures, 1)
new_submix = new_mix._sub_mixtures[0]
self.assertEqual(new_submix.name, "my_new_test_mix2.my_test_mix1")

self.assertNotIn(
"my_new_test_mix2.my_test_mix1.my_test_task1", dp.TaskRegistry.names()
)
new_subtask = dp.get_mixture_or_task("my_new_test_mix2.my_test_task1")
self.assertNotIn(
"my_new_test_mix2.my_test_mix1.my_test_task2", dp.TaskRegistry.names()
)
self.assertLen(new_submix._tasks, 2)
new_submix_subtask1, new_submix_subtask2 = new_submix._tasks
self.assertEqual(
new_submix_subtask1.name, "my_new_test_mix2.my_test_mix1.my_test_task1"
)
self.assertEqual(
new_submix_subtask2.name, "my_new_test_mix2.my_test_mix1.my_test_task2"
)

self.assertNotIn("my_new_test_mix2.my_test_task1", dp.TaskRegistry.names())
self.assertLen(new_mix._tasks, 1)
new_subtask = new_mix._tasks[0]
self.assertEqual(new_subtask.name, "my_new_test_mix2.my_test_task1")

# Step 5: Verify mixing rates for new mixtures.
self.assertDictEqual(
Expand Down

0 comments on commit 95476d1

Please sign in to comment.