Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add option to log selected config only #1159

Merged
merged 4 commits into from
May 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 14 additions & 2 deletions mmengine/visualization/vis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,10 @@ class MLflowVisBackend(BaseVisBackend):
tracking_uri (str, optional): The tracking uri. Default to None.
artifact_suffix (Tuple[str] or str, optional): The artifact suffix.
Default to ('.json', '.log', '.py', 'yaml').
tracked_config_keys (dict, optional): The top level keys of config that
will be added to the experiment. Default to None, which means all
the config will be added.
`New in version 0.7.4.`
"""

def __init__(self,
Expand All @@ -659,14 +663,16 @@ def __init__(self,
params: Optional[dict] = None,
tracking_uri: Optional[str] = None,
artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py',
'yaml')):
'yaml'),
tracked_config_keys: Optional[dict] = None):
super().__init__(save_dir)
self._exp_name = exp_name
self._run_name = run_name
self._tags = tags
self._params = params
self._tracking_uri = tracking_uri
self._artifact_suffix = artifact_suffix
self._tracked_config_keys = tracked_config_keys

def _init_env(self):
"""Setup env for MLflow."""
Expand Down Expand Up @@ -729,7 +735,13 @@ def add_config(self, config: Config, **kwargs) -> None:
config (Config): The Config object
"""
self.cfg = config
self._mlflow.log_params(self._flatten(self.cfg))
if self._tracked_config_keys is None:
self._mlflow.log_params(self._flatten(self.cfg))
else:
tracked_cfg = dict()
for k in self._tracked_config_keys:
tracked_cfg[k] = self.cfg[k]
self._mlflow.log_params(self._flatten(tracked_cfg))
self._mlflow.log_text(self.cfg.pretty_text, 'config.py')

@force_init_env
Expand Down