In [3]:
import mlflow
from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union
from mlflow.tracking import MlflowClient
from pathlib import Path
from omegaconf import DictConfig, OmegaConf

In [30]:
class MLFlowLogger():
    def __init__(
        self,
        tracking_uri: str,
        cfg: Union[DictConfig, dict],
        exp_name: Optional[str] = None,
    ):
        super().__init__()
        self.client = MlflowClient(tracking_uri)
        self.cfg = cfg

        if cfg["debug"]:
            exp_name = "Debug"
        else:
            exp_name = cfg["exp_name"]

        self.experiment = self.client.get_experiment_by_name(exp_name)
        if self.experiment is None:
            self.experiment_id = self.client.create_experiment(exp_name)
            self.experiment = self.client.get_experiment(self.experiment_id)
        else:
            self.experiment_id = self.experiment.experiment_id

        # convert hydra config to dict
        self.run = self.client.create_run(self.experiment_id)
        self.run_id = self.run.info.run_id

        self.local_run_dir = (
            Path(".")
            / Path(tracking_uri.lstrip("file:"))
            / self.experiment_id
            / self.run_id
            / "artifacts"
        ).resolve()

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_metric(self, key, value, step: Union[str, int], prefix: str = ""):
        print(step)
        self.client.log_metric(self.run_id, prefix + key, value, step)

    def log_metrics(self, metrics: Dict[str, Any], step: Union[str, int], prefix: str = ""):
        for k, v in metrics.items():
            self.log_metric(k, v, step, prefix)

    def log_hparams(self, params: Union[Dict[str, Any], DictConfig]) -> None:

        if isinstance(params, DictConfig):
            params = OmegaConf.to_container(params, resolve=True)

        for k, v in params.items():
            if len(str(v)) > 250:
                
                continue

            self.log_param(k, v)

    def log_artifact(self, local_path, artifact_path=None):
        self.client.log_artifact(self.run_id, local_path, artifact_path)

    def close(self):
        self.client.set_terminated(self.run_id)


In [31]:
logger = MLFlowLogger(tracking_uri="file:./mlruns", cfg={"debug": True})

In [32]:
for i in range(50):
    logger.client.log_metric(logger.client.run.info.run_id, "test", i, i)

AttributeError: 'MlflowClient' object has no attribute 'run'

In [39]:
client = MlflowClient(tracking_uri="file:./mlruns")

In [40]:
run = client.create_run(experiment_id="0")

In [41]:
run_id = run.info.run_id

In [42]:
for i in range(50):
    client.log_metric(run_id, "test", i, i)
client.set_terminated(run_id)

In [43]:
with mlflow.start_run() as run:
    for i in range(50):
        mlflow.log_metric("test", i, i)