This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Experiment tracking support with ClearML Logger v1 #4896

Merged
merged 16 commits on Mar 22, 2023
168 changes: 168 additions & 0 deletions parlai/core/logs.py
@@ -276,3 +276,171 @@ def finish(self):

def flush(self):
pass


class ClearMLLogger(object):
"""
Log objects to ClearML.

    Logs all the necessary details of a ParlAI experiment for MLOps. After logging,
    the details can be viewed in the ClearML Experiment Manager Web UI.
"""

@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
"""
Add ClearML CLI args.
"""
logger = parser.add_argument_group('ClearML Arguments')
logger.add_argument(
'-clearmllog',
'--clearml-log',
type=bool,
default=False,
help="Creates a ClearML Task. Default: False. If True, ClearML logging will be enabled.",
hidden=False,
)

logger.add_argument(
'-clearmlproject',
'--clearml-project-name',
type=str,
default="ParlAI",
            help='ClearML Project Name. All the logs will be stored under this project in the ClearML WebUI. If not set, defaults to "ParlAI".',
hidden=False,
)

logger.add_argument(
'-clearmltask',
'--clearml-task-name',
type=str,
default="Default Task",
            help='ClearML Task Name. All the logs will be stored under this task in the ClearML WebUI. If not set, defaults to "Default Task".',
hidden=False,
)

return logger

def __init__(self, opt: Opt):
try:
from clearml import Task, Logger
except ImportError:
raise ImportError('Please run `pip install clearml`.')

# Set ClearML Project Name
project_name = opt.get('clearml_project_name')
# Set ClearML Task Name
task_name = opt.get('clearml_task_name')
        # Reuse the current ClearML Task if one exists; otherwise instantiate a new one
if Task.current_task():
self.clearml_task = Task.current_task()
else:
self.clearml_task = Task.init(
project_name=project_name,
task_name=task_name,
auto_connect_arg_parser=False,
auto_connect_frameworks={'tensorboard': False},
output_uri=True,
)

# Report Hyperparameter Configurations
self.clearml_task.connect(opt)

# Initialize ClearML Logger
self.clearml_logger = Logger.current_logger()

def log_metrics(self, setting, step, report):
"""
        Log all metrics (iteratively during training and validation) to the ClearML WebUI.

        :param setting:
            One of train/valid/test. Used as the series name for the scalar plots.
:param step:
Number of parleys
:param report:
The report to log
"""
for k, v in report.items():
v = v.value() if isinstance(v, Metric) else v
if not isinstance(v, numbers.Number):
logging.error(f'k {k} v {v} is not a number')
continue
display = get_metric_display_data(metric=k)

try:
self.clearml_logger.report_scalar(
title=f"{display.title} ({k})",
series=f'{setting}',
value=v,
iteration=step,
)

except Exception as exception:
print(exception)

def log_final(self, setting, report):
"""
Log final single value metrics to ClearML WebUI.

        :param setting:
            One of train/valid/test (typically "valid" or "test"). Appended to each metric's display name in the WebUI.
:param report:
The report to log
"""
report = dict_report(report)
for k, v in report.items():
if isinstance(v, numbers.Number):
self.clearml_logger.report_single_value(
f'{get_metric_display_data(metric=k).title} - {setting}', v
)

def log_debug_samples(self, series, debug_samples, index=0, title="dialogues"):
"""
        Log test/validation samples as debug samples in the ClearML WebUI.

        :param series:
            Name of the series shown in the WebUI; one of train/valid/test or similar.
        :param debug_samples:
            The samples to log.
        :param index:
            Iteration number. Default: 0.
        :param title:
            Title under which the samples appear in the ClearML WebUI. Default: "dialogues".
"""

# Report Test/Validation Samples as debug samples
self.clearml_logger.report_media(
title=title,
series=series,
iteration=index,
stream=debug_samples,
file_extension=".txt",
)

def upload_artifact(self, artifact_name, artifact_path):
"""
Upload custom artifacts/models to ClearML.

:param artifact_name:
            Name of the artifact/model to log or display in the ClearML WebUI.
:param artifact_path:
The disk location of the artifact/model for uploading.
"""

self.clearml_task.update_output_model(
model_path=artifact_path, model_name=artifact_name, auto_delete_file=False
)

def flush(self):
"""
Flush logger manually.
"""
self.clearml_logger.flush()

def close(self):
"""
Close current ClearML Task after completing the experiment.
"""
self.clearml_task.close()
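
A minimal sanity-check sketch for the new logger, outside of a training run (assumptions: `clearml` is installed, and the offline mode plus the opt values below are illustrative only, not part of this PR):

from clearml import Task

# Run in offline mode so no ClearML server is required for this demo.
Task.set_offline(offline_mode=True)

from parlai.core.logs import ClearMLLogger

# These opt keys mirror the --clearml-* flags added above.
opt = {'clearml_project_name': 'ParlAI', 'clearml_task_name': 'Demo Task'}
clearml_logger = ClearMLLogger(opt)

# Report a scalar at parley 0, then flush and close the task.
clearml_logger.log_metrics('train', 0, {'loss': 1.23})
clearml_logger.flush()
clearml_logger.close()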
31 changes: 28 additions & 3 deletions parlai/scripts/train_model.py
@@ -36,7 +36,7 @@
import parlai.utils.logging as logging
from parlai.core.agents import create_agent, create_agent_from_shared
from parlai.core.exceptions import StopTrainException
from parlai.core.logs import TensorboardLogger, WandbLogger
from parlai.core.logs import TensorboardLogger, WandbLogger, ClearMLLogger
from parlai.core.metrics import Metric
from parlai.core.metrics import (
aggregate_named_reports,
@@ -281,6 +281,7 @@ def setup_args(parser=None) -> ParlaiParser:
WorldLogger.add_cmdline_args(parser, partial_opt=None)
TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
WandbLogger.add_cmdline_args(parser, partial_opt=None)
ClearMLLogger.add_cmdline_args(parser, partial_opt=None)
parser = setup_dict_args(parser)
return parser

@@ -472,6 +473,8 @@ def __init__(self, opt):
if opt['wandb_log'] and is_primary_worker():
model = self.agent.model if hasattr(self.agent, 'model') else None
self.wb_logger = WandbLogger(opt, model)
if opt['clearml_log'] and is_primary_worker():
self.clearml_logger = ClearMLLogger(opt)

def save_model(self, suffix=None):
"""
@@ -567,6 +570,11 @@ def validate(self):
valid_report['total_exs'] = self._total_exs
self.wb_logger.log_metrics('valid', self.parleys, valid_report)

if opt['clearml_log'] and is_primary_worker():
valid_report['total_exs'] = self._total_exs
self.clearml_logger.log_metrics('valid', self.parleys, valid_report)
self.clearml_logger.flush()

# send valid metrics to agent if the agent wants them
if hasattr(self.agent, 'receive_metrics'):
self.agent.receive_metrics(valid_report)
@@ -629,7 +637,9 @@ def validate(self):
return True
return False

def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):
def _run_single_eval(
self, opt, valid_world, max_exs, datatype, is_multitask, task, index
):

# run evaluation on a single world
valid_world.reset()
@@ -652,6 +662,10 @@ def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, ta
if cnt == 0 and opt['display_examples']:
print(valid_world.display() + '\n~~')
print(valid_world.report())
if opt['clearml_log'] and is_primary_worker():
self.clearml_logger.log_debug_samples(
datatype, valid_world.display(), index=index
)
cnt = valid_world.report().get('exs') or 0

if world_logger is not None:
@@ -707,7 +721,7 @@ def _run_eval(
else:
task = opt['task'].split(',')[index]
task_report = self._run_single_eval(
opt, v_world, max_exs_per_worker, datatype, is_multitask, task
opt, v_world, max_exs_per_worker, datatype, is_multitask, task, index
)
reports.append(task_report)

@@ -757,6 +771,8 @@ def _run_final_extra_eval(self, opt):
)
if opt['wandb_log'] and is_primary_worker():
self.wb_logger.log_final(final_datatype, final_valid_report)
if opt['clearml_log'] and is_primary_worker():
self.clearml_logger.log_final(final_datatype, final_valid_report)

return final_valid_report

@@ -898,6 +914,8 @@ def log(self):
self.tb_logger.log_metrics('train', self.parleys, train_report)
if opt['wandb_log'] and is_primary_worker():
self.wb_logger.log_metrics('train', self.parleys, train_report)
if opt['clearml_log'] and is_primary_worker():
self.clearml_logger.log_metrics('train', self.parleys, train_report)

return train_report

@@ -1026,6 +1044,9 @@ def train(self):
self.wb_logger.log_final('valid', self.final_valid_report)
self.wb_logger.log_final('test', self.final_test_report)
self.wb_logger.finish()
if opt['clearml_log'] and is_primary_worker():
self.clearml_logger.log_final('Validation Report', self.final_valid_report)
self.clearml_logger.log_final('Test Report', self.final_test_report)

if valid_worlds:
for valid_world in valid_worlds:
@@ -1042,6 +1063,10 @@ def train(self):
if opt['wandb_log'] and is_primary_worker():
self.wb_logger.finish()

if opt['clearml_log'] and is_primary_worker():
self.clearml_logger.upload_artifact('dictionary', opt['dict_file'])
self.clearml_logger.close()

self._save_train_stats()

return self.final_valid_report, self.final_test_report
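
With the hooks above in place, ClearML logging can be driven end to end through ParlAI's script API. A hedged sketch (task, model, and paths are placeholder values, not taken from this PR):

from parlai.scripts.train_model import TrainModel

# Short demo training run with ClearML logging enabled; all values
# below are illustrative placeholders.
TrainModel.main(
    task='babi:task10k:1',
    model='seq2seq',
    model_file='/tmp/babi_seq2seq/model',
    num_epochs=0.1,
    clearml_log=True,
    clearml_project_name='ParlAI',
    clearml_task_name='babi-seq2seq-demo',
)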
26 changes: 26 additions & 0 deletions tests/test_clearml_logger.py
@@ -0,0 +1,26 @@
import unittest


def setUpModule():
unittest.defaultTestLoader.parallelism = 1


class TestClearMLLogger(unittest.TestCase):
def test_task_init(self):
from clearml import Task

Task.set_offline(offline_mode=True)
from parlai.core.logs import ClearMLLogger

opt = {}
try:
self.clearml_callback = ClearMLLogger(opt)
except Exception as exc:
self.clearml_callback = None
self.fail(exc)

self.assertEqual(Task.current_task()._project_name[1], "ParlAI")


if __name__ == '__main__':
unittest.main()
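
For a quick local check, the module can also be run in-process with the standard library runner (the invocation below is an assumption, equivalent to `python -m pytest tests/test_clearml_logger.py`):

import unittest

# Discover and run only the ClearML logger tests.
suite = unittest.defaultTestLoader.discover('tests', pattern='test_clearml_logger.py')
unittest.TextTestRunner(verbosity=2).run(suite)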