22MLflow
33------
44"""
5- import os
65from argparse import Namespace
76from time import time
87from typing import Optional , Dict , Any , Union
1110 import mlflow
1211 from mlflow .tracking import MlflowClient
1312 _MLFLOW_AVAILABLE = True
14- except ImportError : # pragma: no-cover
13+ except ModuleNotFoundError : # pragma: no-cover
1514 mlflow = None
1615 MlflowClient = None
1716 _MLFLOW_AVAILABLE = False
1817
18+
1919from pytorch_lightning import _logger as log
2020from pytorch_lightning .loggers .base import LightningLoggerBase , rank_zero_experiment
2121from pytorch_lightning .utilities import rank_zero_only
2222
2323
24+ LOCAL_FILE_URI_PREFIX = "file:"
25+
26+
2427class MLFlowLogger (LightningLoggerBase ):
2528 """
2629 Log using `MLflow <https://mlflow.org>`_. Install it with pip:
@@ -52,59 +55,71 @@ class MLFlowLogger(LightningLoggerBase):
5255 Args:
5356 experiment_name: The name of the experiment
5457 tracking_uri: Address of local or remote tracking server.
55- If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri` `.
58+ If not provided, defaults to `file:<save_dir>`.
5659 tags: A dictionary of tags for the experiment.
60+ save_dir: A path to a local directory where the MLflow runs get saved.
61+ Defaults to `./mlruns` if `tracking_uri` is not provided.
62+ Has no effect if `tracking_uri` is provided.
5763
5864 """
5965
6066 def __init__ (self ,
6167 experiment_name : str = 'default' ,
6268 tracking_uri : Optional [str ] = None ,
6369 tags : Optional [Dict [str , Any ]] = None ,
64- save_dir : Optional [str ] = None ):
70+ save_dir : Optional [str ] = './mlruns' ):
6571
6672 if not _MLFLOW_AVAILABLE :
6773 raise ImportError ('You want to use `mlflow` logger which is not installed yet,'
6874 ' install it with `pip install mlflow`.' )
6975 super ().__init__ ()
70- if not tracking_uri and save_dir :
71- tracking_uri = f'file:{ os .sep * 2 } { save_dir } '
72- self ._mlflow_client = MlflowClient (tracking_uri )
73- self .experiment_name = experiment_name
76+ if not tracking_uri :
77+ tracking_uri = f'{ LOCAL_FILE_URI_PREFIX } { save_dir } '
78+
79+ self ._experiment_name = experiment_name
80+ self ._experiment_id = None
81+ self ._tracking_uri = tracking_uri
7482 self ._run_id = None
7583 self .tags = tags
84+ self ._mlflow_client = MlflowClient (tracking_uri )
7685
7786 @property
7887 @rank_zero_experiment
7988 def experiment (self ) -> MlflowClient :
8089 r"""
81- Actual MLflow object. To use mlflow features in your
90+ Actual MLflow object. To use MLflow features in your
8291 :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
8392
8493 Example::
8594
8695 self.logger.experiment.some_mlflow_function()
8796
8897 """
89- return self ._mlflow_client
90-
91- @property
92- def run_id (self ):
93- if self ._run_id is not None :
94- return self ._run_id
95-
96- expt = self ._mlflow_client .get_experiment_by_name (self .experiment_name )
98+ expt = self ._mlflow_client .get_experiment_by_name (self ._experiment_name )
9799
98100 if expt :
99- self ._expt_id = expt .experiment_id
101+ self ._experiment_id = expt .experiment_id
100102 else :
101- log .warning (f'Experiment with name { self .experiment_name } not found. Creating it.' )
102- self ._expt_id = self ._mlflow_client .create_experiment (name = self .experiment_name )
103+ log .warning (f'Experiment with name { self ._experiment_name } not found. Creating it.' )
104+ self ._experiment_id = self ._mlflow_client .create_experiment (name = self ._experiment_name )
103105
104- run = self ._mlflow_client .create_run (experiment_id = self ._expt_id , tags = self .tags )
105- self ._run_id = run .info .run_id
106+ if not self ._run_id :
107+ run = self ._mlflow_client .create_run (experiment_id = self ._experiment_id , tags = self .tags )
108+ self ._run_id = run .info .run_id
109+ return self ._mlflow_client
110+
111+ @property
112+ def run_id (self ):
113+ # create the experiment if it does not exist to get the run id
114+ _ = self .experiment
106115 return self ._run_id
107116
117+ @property
118+ def experiment_id (self ):
119+ # create the experiment if it does not exist to get the experiment id
120+ _ = self .experiment
121+ return self ._experiment_id
122+
108123 @rank_zero_only
109124 def log_hyperparams (self , params : Union [Dict [str , Any ], Namespace ]) -> None :
110125 params = self ._convert_params (params )
@@ -126,14 +141,26 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
126141 @rank_zero_only
127142 def finalize (self , status : str = 'FINISHED' ) -> None :
128143 super ().finalize (status )
129- if status == 'success' :
130- status = 'FINISHED'
131- self .experiment .set_terminated (self .run_id , status )
144+ status = 'FINISHED' if status == 'success' else status
145+ if self .experiment .get_run (self .run_id ):
146+ self .experiment .set_terminated (self .run_id , status )
147+
148+ @property
149+ def save_dir (self ) -> Optional [str ]:
150+ """
151+ The root file directory in which MLflow experiments are saved.
152+
153+ Return:
154+ Local path to the root experiment directory if the tracking uri is local.
155+ Otherwise returns `None`.
156+ """
157+ if self ._tracking_uri .startswith (LOCAL_FILE_URI_PREFIX ):
158+ return self ._tracking_uri .lstrip (LOCAL_FILE_URI_PREFIX )
132159
133160 @property
134161 def name (self ) -> str :
135- return self .experiment_name
162+ return self .experiment_id
136163
137164 @property
138165 def version (self ) -> str :
139- return self ._run_id
166+ return self .run_id
0 commit comments