Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Fixes train_model world logging for multitask with mutators. #4414

Merged
merged 2 commits into from Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 5 additions & 4 deletions parlai/scripts/train_model.py
Expand Up @@ -619,7 +619,7 @@ def validate(self):
return True
return False

def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

# run evaluation on a single world
valid_world.reset()
Expand All @@ -629,7 +629,7 @@ def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
# set up world logger for the "test" fold
if opt['world_logs'] and datatype == 'test':
task_opt['world_logs'] = get_task_world_logs(
valid_world.getID(), opt['world_logs'], is_multitask
task, opt['world_logs'], is_multitask
)
world_logger = WorldLogger(task_opt)

Expand Down Expand Up @@ -691,9 +691,10 @@ def _run_eval(

max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
is_multitask = len(valid_worlds) > 1
for v_world in valid_worlds:
for index, v_world in enumerate(valid_worlds):
task = opt['task'].split(',')[index]
task_report = self._run_single_eval(
opt, v_world, max_exs_per_worker, datatype, is_multitask
opt, v_world, max_exs_per_worker, datatype, is_multitask, task
)
reports.append(task_report)

Expand Down
27 changes: 27 additions & 0 deletions tests/test_train_model.py
Expand Up @@ -271,6 +271,33 @@ def test_save_multiple_world_logs(self):
json_lines = f.readlines()
assert len(json_lines) == 5

def test_save_multiple_world_logs_mutator(self):
    """
    Test that we can save multiple world_logs from train model on multiple tasks
    with mutators present.

    Regression test: with mutators in the task string, per-task world-log file
    names must be derived from the task spec (not the world ID), so each task
    gets its own distinct log file.
    """
    with testing_utils.tempdir() as tmpdir:
        log_report = os.path.join(tmpdir, 'world_logs.jsonl')
        # Two tasks that differ only in teacher/mutator spec; what matters for
        # this test is that there is more than one task.
        multitask = 'integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
        valid, test = testing_utils.train_model(
            {
                'task': multitask,
                'validation_max_exs': 10,
                'model': 'repeat_label',
                'short_final_eval': True,
                'num_epochs': 1.0,
                'world_logs': log_report,
            }
        )

        for task in multitask.split(','):
            # File name must be derivable from the task spec alone.
            task_log_report = get_task_world_logs(
                task, log_report, is_multitask=True
            )
            with PathManager.open(task_log_report) as f:
                json_lines = f.readlines()
            # short_final_eval caps the test fold at 5 logged episodes per task.
            assert len(json_lines) == 5


@register_agent("fake_report")
class FakeReportAgent(Agent):
Expand Down