diff --git a/mlflow/store/sqlalchemy_store.py b/mlflow/store/sqlalchemy_store.py index dc30c34276ae5..ec91d2be9294f 100644 --- a/mlflow/store/sqlalchemy_store.py +++ b/mlflow/store/sqlalchemy_store.py @@ -1,6 +1,6 @@ import sqlalchemy import uuid - +from contextlib import contextmanager from six.moves import urllib from mlflow.entities.lifecycle_stage import LifecycleStage @@ -11,11 +11,10 @@ from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS, \ - INVALID_STATE, RESOURCE_DOES_NOT_EXIST + INVALID_STATE, RESOURCE_DOES_NOT_EXIST, INTERNAL_ERROR from mlflow.tracking.utils import _is_local_uri from mlflow.utils.file_utils import build_path, mkdir from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_RUN_NAME -from mlflow.protos.databricks_pb2 import INTERNAL_ERROR class SqlAlchemyStore(AbstractStore): @@ -55,25 +54,54 @@ def __init__(self, db_uri, default_artifact_root): self.engine = sqlalchemy.create_engine(db_uri) Base.metadata.create_all(self.engine) Base.metadata.bind = self.engine - self.SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine) - self.session = self.SessionMaker() + SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine) + self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker) if _is_local_uri(default_artifact_root): mkdir(default_artifact_root) if len(self.list_experiments()) == 0: - self._create_default_experiment() + with self.ManagedSessionMaker() as session: + self._create_default_experiment(session) - def _set_no_auto_for_zero_values(self): + @staticmethod + def _get_managed_session_maker(SessionMaker): + """ + Creates a factory for producing exception-safe SQLAlchemy sessions that are made available + using a context manager. Any session produced by this factory is automatically committed + if no exceptions are encountered within its associated context. 
If an exception is + encountered, the session is rolled back. Finally, any session produced by this factory is + automatically closed when the session's associated context is exited. + """ + + @contextmanager + def make_managed_session(): + """Provide a transactional scope around a series of operations.""" + session = SessionMaker() + try: + yield session + session.commit() + except MlflowException: + session.rollback() + raise + except Exception as e: + session.rollback() + raise MlflowException(message=e, error_code=INTERNAL_ERROR) + finally: + session.close() + + return make_managed_session + + def _set_no_auto_for_zero_values(self, session): if self.db_type == MYSQL: - self.session.execute("SET @@SESSION.sql_mode='NO_AUTO_VALUE_ON_ZERO';") + session.execute("SET @@SESSION.sql_mode='NO_AUTO_VALUE_ON_ZERO';") # DB helper methods to allow zero values for columns with auto increments - def _unset_no_auto_for_zero_values(self): + def _unset_no_auto_for_zero_values(self, session): if self.db_type == MYSQL: - self.session.execute("SET @@SESSION.sql_mode='';") + session.execute("SET @@SESSION.sql_mode='';") - def _create_default_experiment(self): + def _create_default_experiment(self, session): """ MLflow UI and client code expects a default experiment with ID 0. 
This method uses SQL insert statement to create the default experiment as a hack, since @@ -101,35 +129,31 @@ def decorate(s): values = ", ".join([decorate(default_experiment.get(c)) for c in columns]) try: - self._set_no_auto_for_zero_values() - self.session.execute("INSERT INTO {} ({}) VALUES ({});".format(table, - ", ".join(columns), - values)) + self._set_no_auto_for_zero_values(session) + session.execute("INSERT INTO {} ({}) VALUES ({});".format( + table, ", ".join(columns), values)) finally: - self._unset_no_auto_for_zero_values() - self.session.commit() + self._unset_no_auto_for_zero_values(session) - def _save_to_db(self, objs): + def _save_to_db(self, session, objs): """ Store in db """ if type(objs) is list: - self.session.add_all(objs) + session.add_all(objs) else: # single object - self.session.add(objs) - - self.session.commit() + session.add(objs) - def _get_or_create(self, model, **kwargs): - instance = self.session.query(model).filter_by(**kwargs).first() + def _get_or_create(self, session, model, **kwargs): + instance = session.query(model).filter_by(**kwargs).first() created = False if instance: return instance, created else: instance = model(**kwargs) - self._save_to_db(instance) + self._save_to_db(objs=instance, session=session) created = True return instance, created @@ -141,26 +165,24 @@ def create_experiment(self, name, artifact_location=None): if name is None or name == '': raise MlflowException('Invalid experiment name', INVALID_PARAMETER_VALUE) - new_session = self.SessionMaker() - try: - experiment = SqlExperiment( - name=name, lifecycle_stage=LifecycleStage.ACTIVE, - artifact_location=artifact_location - ) - new_session.add(experiment) - if not artifact_location: - # this requires a double write. 
The first one to generate an autoincrement-ed ID - eid = new_session.query(SqlExperiment).filter_by(name=name).first().experiment_id - experiment.artifact_location = self._get_artifact_location(eid) - new_session.commit() - except sqlalchemy.exc.IntegrityError as e: - new_session.rollback() - raise MlflowException('Experiment(name={}) already exists. ' - 'Error: {}'.format(name, str(e)), RESOURCE_ALREADY_EXISTS) - - return experiment.experiment_id - - def _list_experiments(self, ids=None, names=None, view_type=ViewType.ACTIVE_ONLY): + with self.ManagedSessionMaker() as session: + try: + experiment = SqlExperiment( + name=name, lifecycle_stage=LifecycleStage.ACTIVE, + artifact_location=artifact_location + ) + session.add(experiment) + if not artifact_location: + # this requires a double write. The first one to generate an autoincrement-ed ID + eid = session.query(SqlExperiment).filter_by(name=name).first().experiment_id + experiment.artifact_location = self._get_artifact_location(eid) + except sqlalchemy.exc.IntegrityError as e: + raise MlflowException('Experiment(name={}) already exists. 
' + 'Error: {}'.format(name, str(e)), RESOURCE_ALREADY_EXISTS) + + return experiment.experiment_id + + def _list_experiments(self, session, ids=None, names=None, view_type=ViewType.ACTIVE_ONLY): stages = LifecycleStage.view_type_to_stages(view_type) conditions = [SqlExperiment.lifecycle_stage.in_(stages)] @@ -170,13 +192,16 @@ def _list_experiments(self, ids=None, names=None, view_type=ViewType.ACTIVE_ONLY if names and len(names) > 0: conditions.append(SqlExperiment.name.in_(names)) - return self.session.query(SqlExperiment).filter(*conditions) + return session.query(SqlExperiment).filter(*conditions) def list_experiments(self, view_type=ViewType.ACTIVE_ONLY): - return [exp.to_mlflow_entity() for exp in self._list_experiments(view_type=view_type)] + with self.ManagedSessionMaker() as session: + return [exp.to_mlflow_entity() for exp in + self._list_experiments(session=session, view_type=view_type)] - def _get_experiment(self, experiment_id, view_type): - experiments = self._list_experiments(ids=[experiment_id], view_type=view_type).all() + def _get_experiment(self, session, experiment_id, view_type): + experiments = self._list_experiments( + session=session, ids=[experiment_id], view_type=view_type).all() if len(experiments) == 0: raise MlflowException('No Experiment with id={} exists'.format(experiment_id), RESOURCE_DOES_NOT_EXIST) @@ -187,71 +212,78 @@ def _get_experiment(self, experiment_id, view_type): return experiments[0] def get_experiment(self, experiment_id): - return self._get_experiment(experiment_id, ViewType.ALL).to_mlflow_entity() + with self.ManagedSessionMaker() as session: + return self._get_experiment(session, experiment_id, ViewType.ALL).to_mlflow_entity() def get_experiment_by_name(self, experiment_name): """ Specialized implementation for SQL backed store. 
""" - experiments = self._list_experiments(names=[experiment_name], view_type=ViewType.ALL).all() - if len(experiments) == 0: - return None + with self.ManagedSessionMaker() as session: + experiments = self._list_experiments( + names=[experiment_name], view_type=ViewType.ALL, session=session).all() + if len(experiments) == 0: + return None - if len(experiments) > 1: - raise MlflowException('Expected only 1 experiment with name={}. Found {}.'.format( - experiment_name, len(experiments)), INVALID_STATE) + if len(experiments) > 1: + raise MlflowException('Expected only 1 experiment with name={}. Found {}.'.format( + experiment_name, len(experiments)), INVALID_STATE) - return experiments[0] + return experiments[0].to_mlflow_entity() def delete_experiment(self, experiment_id): - experiment = self._get_experiment(experiment_id, ViewType.ACTIVE_ONLY) - experiment.lifecycle_stage = LifecycleStage.DELETED - self._save_to_db(experiment) + with self.ManagedSessionMaker() as session: + experiment = self._get_experiment(session, experiment_id, ViewType.ACTIVE_ONLY) + experiment.lifecycle_stage = LifecycleStage.DELETED + self._save_to_db(objs=experiment, session=session) def restore_experiment(self, experiment_id): - experiment = self._get_experiment(experiment_id, ViewType.DELETED_ONLY) - experiment.lifecycle_stage = LifecycleStage.ACTIVE - self._save_to_db(experiment) + with self.ManagedSessionMaker() as session: + experiment = self._get_experiment(session, experiment_id, ViewType.DELETED_ONLY) + experiment.lifecycle_stage = LifecycleStage.ACTIVE + self._save_to_db(objs=experiment, session=session) def rename_experiment(self, experiment_id, new_name): - experiment = self._get_experiment(experiment_id, ViewType.ALL) - if experiment.lifecycle_stage != LifecycleStage.ACTIVE: - raise MlflowException('Cannot rename a non-active experiment.', INVALID_STATE) + with self.ManagedSessionMaker() as session: + experiment = self._get_experiment(session, experiment_id, ViewType.ALL) + if 
experiment.lifecycle_stage != LifecycleStage.ACTIVE: + raise MlflowException('Cannot rename a non-active experiment.', INVALID_STATE) - experiment.name = new_name - self._save_to_db(experiment) + experiment.name = new_name + self._save_to_db(objs=experiment, session=session) def create_run(self, experiment_id, user_id, run_name, source_type, source_name, entry_point_name, start_time, source_version, tags, parent_run_id): - experiment = self.get_experiment(experiment_id) + with self.ManagedSessionMaker() as session: + experiment = self.get_experiment(experiment_id) + + if experiment.lifecycle_stage != LifecycleStage.ACTIVE: + raise MlflowException('Experiment id={} must be active'.format(experiment_id), + INVALID_STATE) + + run_uuid = uuid.uuid4().hex + artifact_location = build_path(experiment.artifact_location, run_uuid, + SqlAlchemyStore.ARTIFACTS_FOLDER_NAME) + run = SqlRun(name=run_name or "", artifact_uri=artifact_location, run_uuid=run_uuid, + experiment_id=experiment_id, source_type=SourceType.to_string(source_type), + source_name=source_name, entry_point_name=entry_point_name, + user_id=user_id, status=RunStatus.to_string(RunStatus.RUNNING), + start_time=start_time, end_time=None, + source_version=source_version, lifecycle_stage=LifecycleStage.ACTIVE) - if experiment.lifecycle_stage != LifecycleStage.ACTIVE: - raise MlflowException('Experiment id={} must be active'.format(experiment_id), - INVALID_STATE) - - run_uuid = uuid.uuid4().hex - artifact_location = build_path(experiment.artifact_location, run_uuid, - SqlAlchemyStore.ARTIFACTS_FOLDER_NAME) - run = SqlRun(name=run_name or "", artifact_uri=artifact_location, run_uuid=run_uuid, - experiment_id=experiment_id, source_type=SourceType.to_string(source_type), - source_name=source_name, entry_point_name=entry_point_name, - user_id=user_id, status=RunStatus.to_string(RunStatus.RUNNING), - start_time=start_time, end_time=None, - source_version=source_version, lifecycle_stage=LifecycleStage.ACTIVE) - - for tag 
in tags: - run.tags.append(SqlTag(key=tag.key, value=tag.value)) - if parent_run_id: - run.tags.append(SqlTag(key=MLFLOW_PARENT_RUN_ID, value=parent_run_id)) - if run_name: - run.tags.append(SqlTag(key=MLFLOW_RUN_NAME, value=run_name)) + for tag in tags: + run.tags.append(SqlTag(key=tag.key, value=tag.value)) + if parent_run_id: + run.tags.append(SqlTag(key=MLFLOW_PARENT_RUN_ID, value=parent_run_id)) + if run_name: + run.tags.append(SqlTag(key=MLFLOW_RUN_NAME, value=run_name)) - self._save_to_db([run]) + self._save_to_db(objs=run, session=session) - return run.to_mlflow_entity() + return run.to_mlflow_entity() - def _get_run(self, run_uuid): - runs = self.session.query(SqlRun).filter(SqlRun.run_uuid == run_uuid).all() + def _get_run(self, session, run_uuid): + runs = session.query(SqlRun).filter(SqlRun.run_uuid == run_uuid).all() if len(runs) == 0: raise MlflowException('Run with id={} not found'.format(run_uuid), @@ -276,109 +308,140 @@ def _check_run_is_deleted(self, run): INVALID_PARAMETER_VALUE) def update_run_info(self, run_uuid, run_status, end_time): - run = self._get_run(run_uuid) - self._check_run_is_active(run) - run.status = RunStatus.to_string(run_status) - run.end_time = end_time + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_uuid, session=session) + self._check_run_is_active(run) + run.status = RunStatus.to_string(run_status) + run.end_time = end_time - self._save_to_db(run) - run = run.to_mlflow_entity() + self._save_to_db(objs=run, session=session) + run = run.to_mlflow_entity() - return run.info + return run.info def get_run(self, run_uuid): - run = self._get_run(run_uuid) - return run.to_mlflow_entity() + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_uuid, session=session) + return run.to_mlflow_entity() def restore_run(self, run_id): - run = self._get_run(run_id) - self._check_run_is_deleted(run) - run.lifecycle_stage = LifecycleStage.ACTIVE - self._save_to_db(run) + with 
self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_id, session=session) + self._check_run_is_deleted(run) + run.lifecycle_stage = LifecycleStage.ACTIVE + self._save_to_db(objs=run, session=session) def delete_run(self, run_id): - run = self._get_run(run_id) - self._check_run_is_active(run) - run.lifecycle_stage = LifecycleStage.DELETED - self._save_to_db(run) + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_id, session=session) + self._check_run_is_active(run) + run.lifecycle_stage = LifecycleStage.DELETED + self._save_to_db(objs=run, session=session) def log_metric(self, run_uuid, metric): - run = self._get_run(run_uuid) - self._check_run_is_active(run) - try: - # This will check for various integrity checks for metrics table. - # ToDo: Consider prior checks for null, type, metric name validations, ... etc. - self._get_or_create(SqlMetric, run_uuid=run_uuid, key=metric.key, - value=metric.value, timestamp=metric.timestamp) - except sqlalchemy.exc.IntegrityError as ie: - # Querying metrics from run entails pushing the query down to DB layer. - # Hence the rollback. - self.session.rollback() - existing_metric = [m for m in run.metrics - if m.key == metric.key and m.timestamp == metric.timestamp] - if len(existing_metric) == 0: - raise MlflowException("Log metric request failed for run ID={}. Attempted to log" - " metric={}. Error={}".format(run_uuid, - (metric.key, metric.value), - str(ie))) - else: - m = existing_metric[0] - raise MlflowException('Metric={} must be unique. 
Metric already logged value {} ' - 'at {}'.format(metric, m.value, m.timestamp), - INVALID_PARAMETER_VALUE) + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_uuid, session=session) + self._check_run_is_active(run) + + try: + self._get_or_create(model=SqlMetric, run_uuid=run_uuid, key=metric.key, + value=metric.value, timestamp=metric.timestamp, session=session) + # Explicitly commit the session in order to catch potential integrity errors + # while maintaining the current managed session scope ("commit" checks that + # a transaction satisfies uniqueness constraints and throws integrity errors + # when they are violated; "get_or_create()" does not perform these checks). It is + # important that we maintain the same session scope because, in the case of + # an integrity error, we want to examine the uniqueness of metric (timestamp, value) + # tuples using the same database state that the session uses during "commit". + # Creating a new session synchronizes the state with the database. As a result, if + # the conflicting (timestamp, value) tuple were to be removed prior to the creation + # of a new session, we would be unable to determine the cause of failure for the + # first session's "commit" operation. + session.commit() + except sqlalchemy.exc.IntegrityError: + # Roll back the current session to make it usable for further transactions. In the + # event of an error during "commit", a rollback is required in order to continue + # using the session. In this case, we re-use the session because the SqlRun, `run`, + # is lazily evaluated during the invocation of `run.metrics`. + session.rollback() + existing_metric = [m for m in run.metrics + if m.key == metric.key and m.timestamp == metric.timestamp] + if len(existing_metric) > 0: + m = existing_metric[0] + raise MlflowException( + "Metric={} must be unique. 
Metric already logged value {}" + " at {}".format(metric, m.value, m.timestamp), INVALID_PARAMETER_VALUE) + else: + raise def get_metric_history(self, run_uuid, metric_key): - metrics = self.session.query(SqlMetric).filter_by(run_uuid=run_uuid, key=metric_key).all() - return [metric.to_mlflow_entity() for metric in metrics] + with self.ManagedSessionMaker() as session: + metrics = session.query(SqlMetric).filter_by(run_uuid=run_uuid, key=metric_key).all() + return [metric.to_mlflow_entity() for metric in metrics] def log_param(self, run_uuid, param): - run = self._get_run(run_uuid) - self._check_run_is_active(run) - # if we try to update the value of an existing param this will fail - # because it will try to create it with same run_uuid, param key - try: - # This will check for various integrity checks for params table. - # ToDo: Consider prior checks for null, type, param name validations, ... etc. - self._get_or_create(SqlParam, run_uuid=run_uuid, key=param.key, - value=param.value) - except sqlalchemy.exc.IntegrityError as ie: - # Querying metrics from run entails pushing the query down to DB layer. - # Hence the rollback. - self.session.rollback() - existing_params = [p.value for p in run.params if p.key == param.key] - if len(existing_params) == 0: - raise MlflowException("Log param request failed for run ID={}. Attempted to log" - " param={}. Error={}".format(run_uuid, - (param.key, param.value), - str(ie))) - else: - old_value = existing_params[0] - raise MlflowException("Changing param value is not allowed. Param with key='{}' was" - " already logged with value='{}' for run ID='{}. 
Attempted " - " logging new value '{}'.".format(param.key, old_value, - run_uuid, param.value), - INVALID_PARAMETER_VALUE) + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_uuid, session=session) + self._check_run_is_active(run) + # if we try to update the value of an existing param this will fail + # because it will try to create it with same run_uuid, param key + try: + # This will check for various integrity checks for params table. + # ToDo: Consider prior checks for null, type, param name validations, ... etc. + self._get_or_create(model=SqlParam, session=session, run_uuid=run_uuid, + key=param.key, value=param.value) + # Explicitly commit the session in order to catch potential integrity errors + # while maintaining the current managed session scope ("commit" checks that + # a transaction satisfies uniqueness constraints and throws integrity errors + # when they are violated; "get_or_create()" does not perform these checks). It is + # important that we maintain the same session scope because, in the case of + # an integrity error, we want to examine the uniqueness of parameter values using + # the same database state that the session uses during "commit". Creating a new + # session synchronizes the state with the database. As a result, if the conflicting + # parameter value were to be removed prior to the creation of a new session, + # we would be unable to determine the cause of failure for the first session's + # "commit" operation. + session.commit() + except sqlalchemy.exc.IntegrityError: + # Roll back the current session to make it usable for further transactions. In the + # event of an error during "commit", a rollback is required in order to continue + # using the session. In this case, we re-use the session because the SqlRun, `run`, + # is lazily evaluated during the invocation of `run.params`. 
+ session.rollback() + existing_params = [p.value for p in run.params if p.key == param.key] + if len(existing_params) > 0: + old_value = existing_params[0] + raise MlflowException( + "Changing param value is not allowed. Param with key='{}' was already" + " logged with value='{}' for run ID='{}. Attempted logging new value" + " '{}'.".format( + param.key, old_value, run_uuid, param.value), INVALID_PARAMETER_VALUE) + else: + raise def set_tag(self, run_uuid, tag): - run = self._get_run(run_uuid) - self._check_run_is_active(run) - self.session.merge(SqlTag(run_uuid=run_uuid, key=tag.key, value=tag.value)) - self.session.commit() + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_uuid, session=session) + self._check_run_is_active(run) + session.merge(SqlTag(run_uuid=run_uuid, key=tag.key, value=tag.value)) def search_runs(self, experiment_ids, search_filter, run_view_type): - runs = [run.to_mlflow_entity() - for exp in experiment_ids - for run in self._list_runs(exp, run_view_type)] - return [run for run in runs if not search_filter or search_filter.filter(run)] - - def _list_runs(self, experiment_id, run_view_type): - exp = self._list_experiments(ids=[experiment_id], view_type=ViewType.ALL).first() + with self.ManagedSessionMaker() as session: + runs = [run.to_mlflow_entity() + for exp in experiment_ids + for run in self._list_runs(session, exp, run_view_type)] + return [run for run in runs if not search_filter or search_filter.filter(run)] + + def _list_runs(self, session, experiment_id, run_view_type): + exp = self._list_experiments( + ids=[experiment_id], view_type=ViewType.ALL, session=session).first() stages = set(LifecycleStage.view_type_to_stages(run_view_type)) return [run for run in exp.runs if run.lifecycle_stage in stages] def log_batch(self, run_id, metrics, params, tags): - run = self._get_run(run_id) - self._check_run_is_active(run) + with self.ManagedSessionMaker() as session: + run = self._get_run(run_uuid=run_id, 
session=session) + self._check_run_is_active(run) try: for param in params: self.log_param(run_id, param) diff --git a/tests/store/test_sqlalchemy_store.py b/tests/store/test_sqlalchemy_store.py index 67a1ab1b89542..19cf3cc6733ab 100644 --- a/tests/store/test_sqlalchemy_store.py +++ b/tests/store/test_sqlalchemy_store.py @@ -1,17 +1,17 @@ import shutil +import six import unittest import warnings import mock -import sqlalchemy import time import mlflow import uuid -from mlflow.entities import Metric, Param -from mlflow.entities import ViewType, RunTag, SourceType, RunStatus +from mlflow.entities import ViewType, RunTag, SourceType, RunStatus, Experiment, Metric, Param from mlflow.protos.service_pb2 import SearchRuns, SearchExpression -from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST, INVALID_PARAMETER_VALUE +from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST,\ + INVALID_PARAMETER_VALUE, INTERNAL_ERROR from mlflow.store.dbmodels import models from mlflow import entities from mlflow.exceptions import MlflowException @@ -25,15 +25,14 @@ class TestSqlAlchemyStoreSqliteInMemory(unittest.TestCase): + def _setup_database(self, filename=''): # use a static file name to initialize sqllite to test retention. 
self.store = SqlAlchemyStore(DB_URI + filename, ARTIFACT_URI) - self.session = self.store.session def setUp(self): self.maxDiff = None # print all differences on assert failures self.store = None - self.session = None self._setup_database() def tearDown(self): @@ -54,9 +53,8 @@ def _verify_logged(self, run_uuid, metrics, params, tags): assert len(all_metrics) == len(metrics) logged_metrics = [(m.key, m.value, m.timestamp) for m in all_metrics] assert set(logged_metrics) == set([(m.key, m.value, m.timestamp) for m in metrics]) - assert len(run.data.tags) == len(tags) - logged_tags = [(tag.key, tag.value) for tag in run.data.tags] - assert set(logged_tags) == set([(tag.key, tag.value) for tag in tags]) + logged_tags = set([(tag.key, tag.value) for tag in run.data.tags]) + assert set([(tag.key, tag.value) for tag in tags]) <= logged_tags assert len(run.data.params) == len(params) logged_params = [(param.key, param.value) for param in run.data.params] assert set(logged_params) == set([(param.key, param.value) for param in params]) @@ -73,43 +71,42 @@ def test_default_experiment_lifecycle(self): with TempDir(chdr=True) as tmp: tmp_file_name = "sqlite_file_to_lifecycle_test_{}.db".format(int(time.time())) self._setup_database("/" + tmp.path(tmp_file_name)) - default = self.session.query(models.SqlExperiment).filter_by(name='Default').first() - self.assertEqual(default.experiment_id, 0) - self.assertEqual(default.lifecycle_stage, entities.LifecycleStage.ACTIVE) + + default_experiment = self.store.get_experiment(experiment_id=0) + self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME) + self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.ACTIVE) self._experiment_factory('aNothEr') all_experiments = [e.name for e in self.store.list_experiments()] - - self.assertSequenceEqual(set(['aNothEr', 'Default']), set(all_experiments)) + six.assertCountEqual(self, set(['aNothEr', 'Default']), set(all_experiments)) 
self.store.delete_experiment(0) - self.assertSequenceEqual(['aNothEr'], [e.name for e in self.store.list_experiments()]) + six.assertCountEqual(self, ['aNothEr'], [e.name for e in self.store.list_experiments()]) another = self.store.get_experiment(1) self.assertEqual('aNothEr', another.name) - default = self.session.query(models.SqlExperiment).filter_by(name='Default').first() - self.assertEqual(default.experiment_id, 0) - self.assertEqual(default.lifecycle_stage, entities.LifecycleStage.DELETED) + default_experiment = self.store.get_experiment(experiment_id=0) + self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME) + self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.DELETED) # destroy SqlStore and make a new one del self.store self._setup_database("/" + tmp.path(tmp_file_name)) # test that default experiment is not reactivated - default = self.session.query(models.SqlExperiment).filter_by(name='Default').first() - self.assertEqual(default.experiment_id, 0) - self.assertEqual(default.lifecycle_stage, entities.LifecycleStage.DELETED) + default_experiment = self.store.get_experiment(experiment_id=0) + self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME) + self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.DELETED) - self.assertSequenceEqual(['aNothEr'], [e.name for e in self.store.list_experiments()]) + six.assertCountEqual(self, ['aNothEr'], [e.name for e in self.store.list_experiments()]) all_experiments = [e.name for e in self.store.list_experiments(ViewType.ALL)] - self.assertSequenceEqual(set(['aNothEr', 'Default']), set(all_experiments)) + six.assertCountEqual(self, set(['aNothEr', 'Default']), set(all_experiments)) # ensure that experiment ID dor active experiment is unchanged another = self.store.get_experiment(1) self.assertEqual('aNothEr', another.name) - self.session.close() self.store = None def test_raise_duplicate_experiments(self): @@ -126,13 +123,13 
@@ def test_delete_experiment(self): all_experiments = self.store.list_experiments() self.assertEqual(len(all_experiments), len(experiments) + 1) # default - exp = experiments[0] - self.store.delete_experiment(exp) + exp_id = experiments[0] + self.store.delete_experiment(exp_id) - actual = self.session.query(models.SqlExperiment).get(exp) - self.assertEqual(len(self.store.list_experiments()), len(all_experiments) - 1) + updated_exp = self.store.get_experiment(exp_id) + self.assertEqual(updated_exp.lifecycle_stage, entities.LifecycleStage.DELETED) - self.assertEqual(actual.lifecycle_stage, entities.LifecycleStage.DELETED) + self.assertEqual(len(self.store.list_experiments()), len(all_experiments) - 1) def test_get_experiment(self): name = 'goku' @@ -153,97 +150,120 @@ def test_list_experiments(self): self.assertEqual(len(experiments) + 1, len(actual)) # default - for experiment_id in experiments: - res = self.session.query(models.SqlExperiment).filter_by( - experiment_id=experiment_id).first() - self.assertIn(res.name, testnames) - self.assertEqual(res.experiment_id, experiment_id) + with self.store.ManagedSessionMaker() as session: + for experiment_id in experiments: + res = session.query(models.SqlExperiment).filter_by( + experiment_id=experiment_id).first() + self.assertIn(res.name, testnames) + self.assertEqual(res.experiment_id, experiment_id) def test_create_experiments(self): - result = self.session.query(models.SqlExperiment).all() - self.assertEqual(len(result), 1) + with self.store.ManagedSessionMaker() as session: + result = session.query(models.SqlExperiment).all() + self.assertEqual(len(result), 1) experiment_id = self.store.create_experiment(name='test exp') - result = self.session.query(models.SqlExperiment).all() - self.assertEqual(len(result), 2) - test_exp = self.session.query(models.SqlExperiment).filter_by(name='test exp').first() + with self.store.ManagedSessionMaker() as session: + result = session.query(models.SqlExperiment).all() + 
self.assertEqual(len(result), 2) - self.assertEqual(test_exp.experiment_id, experiment_id) - self.assertEqual(test_exp.name, 'test exp') + test_exp = session.query(models.SqlExperiment).filter_by(name='test exp').first() + self.assertEqual(test_exp.experiment_id, experiment_id) + self.assertEqual(test_exp.name, 'test exp') actual = self.store.get_experiment(experiment_id) self.assertEqual(actual.experiment_id, experiment_id) self.assertEqual(actual.name, 'test exp') def test_run_tag_model(self): - run_data = models.SqlTag(run_uuid='tuuid', key='test', value='val') - self.session.add(run_data) - self.session.commit() - tags = self.session.query(models.SqlTag).all() - self.assertEqual(len(tags), 1) - - actual = tags[0].to_mlflow_entity() - - self.assertEqual(actual.value, run_data.value) - self.assertEqual(actual.key, run_data.key) + # Create a run whose UUID we can reference when creating tag models. + # `run_uuid` is a foreign key in the tags table; therefore, in order + # to insert a tag with a given run UUID, the UUID must be present in + # the runs table + run = self._run_factory() + with self.store.ManagedSessionMaker() as session: + new_tag = models.SqlTag(run_uuid=run.info.run_uuid, key='test', value='val') + session.add(new_tag) + session.commit() + added_tags = [ + tag for tag in session.query(models.SqlTag).all() + if tag.key == new_tag.key + ] + self.assertEqual(len(added_tags), 1) + added_tag = added_tags[0].to_mlflow_entity() + self.assertEqual(added_tag.value, new_tag.value) def test_metric_model(self): - run_data = models.SqlMetric(run_uuid='testuid', key='accuracy', value=0.89) - self.session.add(run_data) - self.session.commit() - metrics = self.session.query(models.SqlMetric).all() - self.assertEqual(len(metrics), 1) - - actual = metrics[0].to_mlflow_entity() + # Create a run whose UUID we can reference when creating metric models. 
+ # `run_uuid` is a foreign key in the metrics table; therefore, in order + # to insert a metric with a given run UUID, the UUID must be present in + # the runs table + run = self._run_factory() + with self.store.ManagedSessionMaker() as session: + new_metric = models.SqlMetric(run_uuid=run.info.run_uuid, key='accuracy', value=0.89) + session.add(new_metric) + session.commit() + metrics = session.query(models.SqlMetric).all() + self.assertEqual(len(metrics), 1) - self.assertEqual(actual.value, run_data.value) - self.assertEqual(actual.key, run_data.key) + added_metric = metrics[0].to_mlflow_entity() + self.assertEqual(added_metric.value, new_metric.value) + self.assertEqual(added_metric.key, new_metric.key) def test_param_model(self): - run_data = models.SqlParam(run_uuid='test', key='accuracy', value='test param') - self.session.add(run_data) - self.session.commit() - params = self.session.query(models.SqlParam).all() - self.assertEqual(len(params), 1) - - actual = params[0].to_mlflow_entity() - - self.assertEqual(actual.value, run_data.value) - self.assertEqual(actual.key, run_data.key) + # Create a run whose UUID we can reference when creating parameter models. 
+ # `run_uuid` is a foreign key in the params table; therefore, in order + # to insert a parameter with a given run UUID, the UUID must be present in + # the runs table + run = self._run_factory() + with self.store.ManagedSessionMaker() as session: + new_param = models.SqlParam( + run_uuid=run.info.run_uuid, key='accuracy', value='test param') + session.add(new_param) + session.commit() + params = session.query(models.SqlParam).all() + self.assertEqual(len(params), 1) + + added_param = params[0].to_mlflow_entity() + self.assertEqual(added_param.value, new_param.value) + self.assertEqual(added_param.key, new_param.key) def test_run_needs_uuid(self): - run = models.SqlRun() - self.session.add(run) - - with self.assertRaises(sqlalchemy.exc.IntegrityError): + # Depending on the implementation, a NULL identity key may result in different + # exceptions, including IntegrityError (sqlite) and FlushError (MySQL). + # Therefore, we check for the MlflowException raised by the store's managed + # session wrapper, rather than a backend-specific SQLAlchemy error + with self.assertRaises(MlflowException) as exception_context: warnings.simplefilter("ignore") - with warnings.catch_warnings(): - self.session.commit() - warnings.resetwarnings() + with self.store.ManagedSessionMaker() as session, warnings.catch_warnings(): + run = models.SqlRun() + session.add(run) + warnings.resetwarnings() + assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR) def test_run_data_model(self): - m1 = models.SqlMetric(key='accuracy', value=0.89) - m2 = models.SqlMetric(key='recal', value=0.89) - p1 = models.SqlParam(key='loss', value='test param') - p2 = models.SqlParam(key='blue', value='test param') + with self.store.ManagedSessionMaker() as session: + m1 = models.SqlMetric(key='accuracy', value=0.89) + m2 = models.SqlMetric(key='recal', value=0.89) + p1 = models.SqlParam(key='loss', value='test param') + p2 = models.SqlParam(key='blue', value='test param') - self.session.add_all([m1, m2, p1, p2]) + session.add_all([m1, m2, p1, p2]) - run_data = 
models.SqlRun(run_uuid=uuid.uuid4().hex) - run_data.params.append(p1) - run_data.params.append(p2) - run_data.metrics.append(m1) - run_data.metrics.append(m2) + run_data = models.SqlRun(run_uuid=uuid.uuid4().hex) + run_data.params.append(p1) + run_data.params.append(p2) + run_data.metrics.append(m1) + run_data.metrics.append(m2) - self.session.add(run_data) - self.session.commit() + session.add(run_data) + session.commit() - run_datums = self.session.query(models.SqlRun).all() - actual = run_datums[0] - self.assertEqual(len(run_datums), 1) - self.assertEqual(len(actual.params), 2) - self.assertEqual(len(actual.metrics), 2) + run_datums = session.query(models.SqlRun).all() + actual = run_datums[0] + self.assertEqual(len(run_datums), 1) + self.assertEqual(len(actual.params), 2) + self.assertEqual(len(actual.metrics), 2) def test_run_info(self): experiment_id = self._experiment_factory('test exp') @@ -273,21 +293,18 @@ def test_run_info(self): else: self.assertEqual(v, v2) - def _get_run_configs(self, name='test', experiment_id=None): + def _get_run_configs(self, name='test', experiment_id=None, tags=(), parent_run_id=None): return { 'experiment_id': experiment_id, - 'name': name, + 'run_name': name, 'user_id': 'Anderson', - 'run_uuid': uuid.uuid4().hex, - 'status': RunStatus.to_string(RunStatus.SCHEDULED), - 'source_type': SourceType.to_string(SourceType.NOTEBOOK), + 'source_type': SourceType.NOTEBOOK, 'source_name': 'Python application', 'entry_point_name': 'main.py', 'start_time': int(time.time()), - 'end_time': int(time.time()), 'source_version': mlflow.__version__, - 'lifecycle_stage': entities.LifecycleStage.ACTIVE, - 'artifact_uri': '//' + 'tags': tags, + 'parent_run_id': parent_run_id, } def _run_factory(self, config=None): @@ -299,66 +316,60 @@ def _run_factory(self, config=None): experiment_id = self._experiment_factory('test exp') config["experiment_id"] = experiment_id - run = models.SqlRun(**config) - self.session.add(run) - - return run + return 
self.store.create_run(**config) - def test_create_run(self): + def test_create_run_with_tags(self): + run_name = "test-run-1" experiment_id = self._experiment_factory('test_create_run') - expected = self._get_run_configs('booyya', experiment_id=experiment_id) - tags = [RunTag('3', '4'), RunTag('1', '2')] - actual = self.store.create_run(expected["experiment_id"], expected["user_id"], - expected["name"], - SourceType.from_string(expected["source_type"]), - expected["source_name"], expected["entry_point_name"], - expected["start_time"], expected["source_version"], - tags, None) - - self.assertEqual(actual.info.experiment_id, expected["experiment_id"]) + expected = self._get_run_configs(name=run_name, experiment_id=experiment_id, tags=tags) + + actual = self.store.create_run(**expected) + + self.assertEqual(actual.info.experiment_id, experiment_id) self.assertEqual(actual.info.user_id, expected["user_id"]) - self.assertEqual(actual.info.name, 'booyya') - self.assertEqual(actual.info.source_type, SourceType.from_string(expected["source_type"])) + self.assertEqual(actual.info.name, run_name) + self.assertEqual(actual.info.source_type, expected["source_type"]) self.assertEqual(actual.info.source_name, expected["source_name"]) self.assertEqual(actual.info.source_version, expected["source_version"]) self.assertEqual(actual.info.entry_point_name, expected["entry_point_name"]) self.assertEqual(actual.info.start_time, expected["start_time"]) - self.assertEqual(len(actual.data.tags), 3) - name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value='booyya').to_mlflow_entity() + # Run creation should add an additional tag containing the run name. 
Check for + # its existence + self.assertEqual(len(actual.data.tags), len(tags) + 1) + name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity() self.assertListEqual(actual.data.tags, tags + [name_tag]) def test_create_run_with_parent_id(self): - exp = self._experiment_factory('test_create_run_with_parent_id') - expected = self._get_run_configs('booyya', experiment_id=exp) + run_name = "test-run-1" + parent_run_id = "parent_uuid_5" + experiment_id = self._experiment_factory('test_create_run') + expected = self._get_run_configs( + name=run_name, experiment_id=experiment_id, parent_run_id=parent_run_id) - tags = [RunTag('3', '4'), RunTag('1', '2')] - actual = self.store.create_run(expected["experiment_id"], expected["user_id"], - expected["name"], - SourceType.from_string(expected["source_type"]), - expected["source_name"], expected["entry_point_name"], - expected["start_time"], expected["source_version"], - tags, "parent_uuid_5") - - self.assertEqual(actual.info.experiment_id, expected["experiment_id"]) + actual = self.store.create_run(**expected) + + self.assertEqual(actual.info.experiment_id, experiment_id) self.assertEqual(actual.info.user_id, expected["user_id"]) - self.assertEqual(actual.info.name, 'booyya') - self.assertEqual(actual.info.source_type, SourceType.from_string(expected["source_type"])) + self.assertEqual(actual.info.name, run_name) + self.assertEqual(actual.info.source_type, expected["source_type"]) self.assertEqual(actual.info.source_name, expected["source_name"]) self.assertEqual(actual.info.source_version, expected["source_version"]) self.assertEqual(actual.info.entry_point_name, expected["entry_point_name"]) self.assertEqual(actual.info.start_time, expected["start_time"]) - self.assertEqual(len(actual.data.tags), 4) - name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value='booyya').to_mlflow_entity() + # Run creation should add two additional tags containing the run name and parent run id. 
+ # Check for the existence of these two tags + self.assertEqual(len(actual.data.tags), 2) + name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity() parent_id_tag = models.SqlTag(key=MLFLOW_PARENT_RUN_ID, - value='parent_uuid_5').to_mlflow_entity() - self.assertListEqual(actual.data.tags, tags + [parent_id_tag, name_tag]) + value=parent_run_id).to_mlflow_entity() + self.assertListEqual(actual.data.tags, [parent_id_tag, name_tag]) def test_to_mlflow_entity(self): + # Create a run and obtain an MLflow Run entity associated with the new run run = self._run_factory() - run = run.to_mlflow_entity() self.assertIsInstance(run.info, entities.RunInfo) self.assertIsInstance(run.data, entities.RunData) @@ -374,41 +385,27 @@ def test_to_mlflow_entity(self): def test_delete_run(self): run = self._run_factory() - self.session.commit() - run_uuid = run.run_uuid - self.store.delete_run(run_uuid) - actual = self.session.query(models.SqlRun).filter_by(run_uuid=run_uuid).first() - self.assertEqual(actual.lifecycle_stage, entities.LifecycleStage.DELETED) + self.store.delete_run(run.info.run_uuid) - deleted_run = self.store.get_run(run_uuid) - self.assertEqual(actual.run_uuid, deleted_run.info.run_uuid) + with self.store.ManagedSessionMaker() as session: + actual = session.query(models.SqlRun).filter_by(run_uuid=run.info.run_uuid).first() + self.assertEqual(actual.lifecycle_stage, entities.LifecycleStage.DELETED) + + deleted_run = self.store.get_run(run.info.run_uuid) + self.assertEqual(actual.run_uuid, deleted_run.info.run_uuid) def test_log_metric(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = 100.0 metric = entities.Metric(tkey, tval, int(time.time())) metric2 = entities.Metric(tkey, tval, int(time.time()) + 2) - self.store.log_metric(run.run_uuid, metric) - self.store.log_metric(run.run_uuid, metric2) - - actual = self.session.query(models.SqlMetric).filter_by(key=tkey, value=tval) - - self.assertIsNotNone(actual) - - 
run = self.store.get_run(run.run_uuid) - - # SQL store _get_run method returns full history of recorded metrics. - # Should return duplicates as well - # MLflow RunData contains only the last reported values for metrics. - sql_run_metrics = self.store._get_run(run.info.run_uuid).metrics - self.assertEqual(2, len(sql_run_metrics)) - self.assertEqual(1, len(run.data.metrics)) + self.store.log_metric(run.info.run_uuid, metric) + self.store.log_metric(run.info.run_uuid, metric2) + run = self.store.get_run(run.info.run_uuid) found = False for m in run.data.metrics: if m.key == tkey and m.value == tval: @@ -416,51 +413,49 @@ def test_log_metric(self): self.assertTrue(found) + # SQL store _get_run method returns full history of recorded metrics. + # Should return duplicates as well + # MLflow RunData contains only the last reported values for metrics. + with self.store.ManagedSessionMaker() as session: + sql_run_metrics = self.store._get_run(session, run.info.run_uuid).metrics + self.assertEqual(2, len(sql_run_metrics)) + self.assertEqual(1, len(run.data.metrics)) + def test_log_metric_uniqueness(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = 100.0 metric = entities.Metric(tkey, tval, int(time.time())) metric2 = entities.Metric(tkey, 1.02, int(time.time())) - self.store.log_metric(run.run_uuid, metric) + self.store.log_metric(run.info.run_uuid, metric) with self.assertRaises(MlflowException) as e: - self.store.log_metric(run.run_uuid, metric2) + self.store.log_metric(run.info.run_uuid, metric2) self.assertIn("must be unique. 
Metric already logged value", e.exception.message) def test_log_null_metric(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = None metric = entities.Metric(tkey, tval, int(time.time())) - with self.assertRaises(MlflowException) as e: - self.store.log_metric(run.run_uuid, metric) - self.assertIn("Log metric request failed for run ID=", e.exception.message) - self.assertIn("IntegrityError", e.exception.message) + with self.assertRaises(MlflowException) as exception_context: + self.store.log_metric(run.info.run_uuid, metric) + assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR) def test_log_param(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = '100.0' param = entities.Param(tkey, tval) param2 = entities.Param('new param', 'new key') - self.store.log_param(run.run_uuid, param) - self.store.log_param(run.run_uuid, param2) + self.store.log_param(run.info.run_uuid, param) + self.store.log_param(run.info.run_uuid, param2) - actual = self.session.query(models.SqlParam).filter_by(key=tkey, value=tval) - self.assertIsNotNone(actual) - - run = self.store.get_run(run.run_uuid) + run = self.store.get_run(run.info.run_uuid) self.assertEqual(2, len(run.data.params)) found = False @@ -473,34 +468,27 @@ def test_log_param(self): def test_log_param_uniqueness(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = '100.0' param = entities.Param(tkey, tval) param2 = entities.Param(tkey, 'newval') - self.store.log_param(run.run_uuid, param) + self.store.log_param(run.info.run_uuid, param) with self.assertRaises(MlflowException) as e: - self.store.log_param(run.run_uuid, param2) + self.store.log_param(run.info.run_uuid, param2) self.assertIn("Changing param value is not allowed. 
Param with key=", e.exception.message) def test_log_empty_str(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = '' param = entities.Param(tkey, tval) param2 = entities.Param('new param', 'new key') - self.store.log_param(run.run_uuid, param) - self.store.log_param(run.run_uuid, param2) - - actual = self.session.query(models.SqlParam).filter_by(key=tkey, value=tval) - self.assertIsNotNone(actual) + self.store.log_param(run.info.run_uuid, param) + self.store.log_param(run.info.run_uuid, param2) - run = self.store.get_run(run.run_uuid) + run = self.store.get_run(run.info.run_uuid) self.assertEqual(2, len(run.data.params)) found = False @@ -513,32 +501,23 @@ def test_log_empty_str(self): def test_log_null_param(self): run = self._run_factory() - self.session.commit() - tkey = 'blahmetric' tval = None param = entities.Param(tkey, tval) - with self.assertRaises(MlflowException) as e: - self.store.log_param(run.run_uuid, param) - self.assertIn("Log param request failed for run ID=", e.exception.message) - self.assertIn("IntegrityError", e.exception.message) + with self.assertRaises(MlflowException) as exception_context: + self.store.log_param(run.info.run_uuid, param) + assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR) def test_set_tag(self): run = self._run_factory() - self.session.commit() - tkey = 'test tag' tval = 'a boogie' tag = entities.RunTag(tkey, tval) - self.store.set_tag(run.run_uuid, tag) - - actual = self.session.query(models.SqlTag).filter_by(key=tkey, value=tval) + self.store.set_tag(run.info.run_uuid, tag) - self.assertIsNotNone(actual) - - run = self.store.get_run(run.run_uuid) + run = self.store.get_run(run.info.run_uuid) found = False for m in run.data.tags: @@ -549,7 +528,7 @@ def test_set_tag(self): def test_get_metric_history(self): run = self._run_factory() - self.session.commit() + key = 'test' expected = [ models.SqlMetric(key=key, value=0.6, timestamp=1).to_mlflow_entity(), @@ 
-557,29 +536,30 @@ def test_get_metric_history(self): ] for metric in expected: - self.store.log_metric(run.run_uuid, metric) + self.store.log_metric(run.info.run_uuid, metric) - actual = self.store.get_metric_history(run.run_uuid, key) + actual = self.store.get_metric_history(run.info.run_uuid, key) - self.assertSequenceEqual([(m.key, m.value, m.timestamp) for m in expected], - [(m.key, m.value, m.timestamp) for m in actual]) + six.assertCountEqual(self, + [(m.key, m.value, m.timestamp) for m in expected], + [(m.key, m.value, m.timestamp) for m in actual]) def test_list_run_infos(self): experiment_id = self._experiment_factory('test_exp') - r1 = self._run_factory(self._get_run_configs('t1', experiment_id)).run_uuid - r2 = self._run_factory(self._get_run_configs('t2', experiment_id)).run_uuid + r1 = self._run_factory(config=self._get_run_configs('t1', experiment_id)).info.run_uuid + r2 = self._run_factory(config=self._get_run_configs('t2', experiment_id)).info.run_uuid def _runs(experiment_id, view_type): return [r.run_uuid for r in self.store.list_run_infos(experiment_id, view_type)] - self.assertSequenceEqual([r1, r2], _runs(experiment_id, ViewType.ALL)) - self.assertSequenceEqual([r1, r2], _runs(experiment_id, ViewType.ACTIVE_ONLY)) + six.assertCountEqual(self, [r1, r2], _runs(experiment_id, ViewType.ALL)) + six.assertCountEqual(self, [r1, r2], _runs(experiment_id, ViewType.ACTIVE_ONLY)) self.assertEqual(0, len(_runs(experiment_id, ViewType.DELETED_ONLY))) self.store.delete_run(r1) - self.assertSequenceEqual([r1, r2], _runs(experiment_id, ViewType.ALL)) - self.assertSequenceEqual([r2], _runs(experiment_id, ViewType.ACTIVE_ONLY)) - self.assertSequenceEqual([r1], _runs(experiment_id, ViewType.DELETED_ONLY)) + six.assertCountEqual(self, [r1, r2], _runs(experiment_id, ViewType.ALL)) + six.assertCountEqual(self, [r2], _runs(experiment_id, ViewType.ACTIVE_ONLY)) + six.assertCountEqual(self, [r1], _runs(experiment_id, ViewType.DELETED_ONLY)) def 
test_rename_experiment(self): new_name = 'new name' @@ -592,10 +572,11 @@ def test_rename_experiment(self): def test_update_run_info(self): run = self._run_factory() + new_status = entities.RunStatus.FINISHED endtime = int(time.time()) - actual = self.store.update_run_info(run.run_uuid, new_status, endtime) + actual = self.store.update_run_info(run.info.run_uuid, new_status, endtime) self.assertEqual(actual.status, new_status) self.assertEqual(actual.end_time, endtime) @@ -619,34 +600,32 @@ def test_restore_experiment(self): def test_delete_restore_run(self): run = self._run_factory() - self.assertEqual(run.lifecycle_stage, entities.LifecycleStage.ACTIVE) - - run_uuid = run.run_uuid + self.assertEqual(run.info.lifecycle_stage, entities.LifecycleStage.ACTIVE) with self.assertRaises(MlflowException) as e: - self.store.restore_run(run_uuid) + self.store.restore_run(run.info.run_uuid) self.assertIn("must be in 'deleted' state", e.exception.message) - self.store.delete_run(run_uuid) + self.store.delete_run(run.info.run_uuid) with self.assertRaises(MlflowException) as e: - self.store.delete_run(run_uuid) + self.store.delete_run(run.info.run_uuid) self.assertIn("must be in 'active' state", e.exception.message) - deleted = self.store.get_run(run_uuid) - self.assertEqual(deleted.info.run_uuid, run_uuid) + deleted = self.store.get_run(run.info.run_uuid) + self.assertEqual(deleted.info.run_uuid, run.info.run_uuid) self.assertEqual(deleted.info.lifecycle_stage, entities.LifecycleStage.DELETED) - self.store.restore_run(run_uuid) + self.store.restore_run(run.info.run_uuid) with self.assertRaises(MlflowException) as e: - self.store.restore_run(run_uuid) + self.store.restore_run(run.info.run_uuid) self.assertIn("must be in 'deleted' state", e.exception.message) - restored = self.store.get_run(run_uuid) - self.assertEqual(restored.info.run_uuid, run_uuid) + restored = self.store.get_run(run.info.run_uuid) + self.assertEqual(restored.info.run_uuid, run.info.run_uuid) 
self.assertEqual(restored.info.lifecycle_stage, entities.LifecycleStage.ACTIVE) def test_error_logging_to_deleted_run(self): exp = self._experiment_factory('error_logging') - run_uuid = self._run_factory(self._get_run_configs(experiment_id=exp)).run_uuid + run_uuid = self._run_factory(self._get_run_configs(experiment_id=exp)).info.run_uuid self.store.delete_run(run_uuid) self.assertEqual(self.store.get_run(run_uuid).info.lifecycle_stage, @@ -715,29 +694,29 @@ def _metric_expression(self, key, comparator, val): def test_search_vanilla(self): exp = self._experiment_factory('search_vanilla') - runs = [self._run_factory(self._get_run_configs('r_%d' % r, exp)).run_uuid + runs = [self._run_factory(self._get_run_configs('r_%d' % r, exp)).info.run_uuid for r in range(3)] - self.assertSequenceEqual(runs, self._search(exp, run_view_type=ViewType.ALL)) - self.assertSequenceEqual(runs, self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) - self.assertSequenceEqual([], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) + six.assertCountEqual(self, runs, self._search(exp, run_view_type=ViewType.ALL)) + six.assertCountEqual(self, runs, self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) + six.assertCountEqual(self, [], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) first = runs[0] self.store.delete_run(first) - self.assertSequenceEqual(runs, self._search(exp, run_view_type=ViewType.ALL)) - self.assertSequenceEqual(runs[1:], self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) - self.assertSequenceEqual([first], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) + six.assertCountEqual(self, runs, self._search(exp, run_view_type=ViewType.ALL)) + six.assertCountEqual(self, runs[1:], self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) + six.assertCountEqual(self, [first], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) self.store.restore_run(first) - self.assertSequenceEqual(runs, self._search(exp, run_view_type=ViewType.ALL)) - 
self.assertSequenceEqual(runs, self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) - self.assertSequenceEqual([], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) + six.assertCountEqual(self, runs, self._search(exp, run_view_type=ViewType.ALL)) + six.assertCountEqual(self, runs, self._search(exp, run_view_type=ViewType.ACTIVE_ONLY)) + six.assertCountEqual(self, [], self._search(exp, run_view_type=ViewType.DELETED_ONLY)) def test_search_params(self): experiment_id = self._experiment_factory('search_params') - r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid - r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid + r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid + r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid self.store.log_param(r1, entities.Param('generic_param', 'p_val')) self.store.log_param(r2, entities.Param('generic_param', 'p_val')) @@ -750,35 +729,35 @@ def test_search_params(self): # test search returns both runs expr = self._param_expression("generic_param", "=", "p_val") - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) # test search returns appropriate run (same key different values per run) expr = self._param_expression("generic_2", "=", "some value") - self.assertSequenceEqual([r1], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("generic_2", "=", "another value") - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r2], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("generic_param", "=", "wrong_val") - self.assertSequenceEqual([], self._search(experiment_id, 
param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("generic_param", "!=", "p_val") - self.assertSequenceEqual([], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("generic_param", "!=", "wrong_val") - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("generic_2", "!=", "wrong_val") - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("p_a", "=", "abc") - self.assertSequenceEqual([r1], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, param_expressions=[expr])) expr = self._param_expression("p_b", "=", "ABC") - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r2], self._search(experiment_id, param_expressions=[expr])) def test_search_metrics(self): experiment_id = self._experiment_factory('search_params') - r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid - r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid + r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid + r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid self.store.log_metric(r1, entities.Metric("common", 1.0, 1)) self.store.log_metric(r2, entities.Metric("common", 1.0, 1)) @@ -793,70 +772,70 @@ def test_search_metrics(self): self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3)) expr = self._metric_expression("common", "=", 1.0) 
- self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", ">", 0.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", ">=", 0.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", "<", 4.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", "<=", 4.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", "!=", 1.0) - self.assertSequenceEqual([], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", ">=", 3.0) - self.assertSequenceEqual([], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("common", "<=", 0.75) - self.assertSequenceEqual([], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) # tests for same metric name across runs with different values and timestamps expr = self._metric_expression("measure_a", ">", 0.0) - self.assertSequenceEqual([r1, r2], 
self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", "<", 50.0) - self.assertSequenceEqual([r1], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", "<", 1000.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", "!=", -12.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", ">", 50.0) - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r2], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", "=", 1.0) - self.assertSequenceEqual([r1], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("measure_a", "=", 400.0) - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r2], self._search(experiment_id, param_expressions=[expr])) # test search with unique metric keys expr = self._metric_expression("m_a", ">", 1.0) - self.assertSequenceEqual([r1], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, param_expressions=[expr])) expr = self._metric_expression("m_b", ">", 1.0) - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, 
[r2], self._search(experiment_id, param_expressions=[expr])) # there is a recorded metric this threshold but not last timestamp expr = self._metric_expression("m_b", ">", 5.0) - self.assertSequenceEqual([], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [], self._search(experiment_id, param_expressions=[expr])) # metrics matches last reported timestamp for 'm_b' expr = self._metric_expression("m_b", "=", 4.0) - self.assertSequenceEqual([r2], self._search(experiment_id, param_expressions=[expr])) + six.assertCountEqual(self, [r2], self._search(experiment_id, param_expressions=[expr])) def test_search_full(self): experiment_id = self._experiment_factory('search_params') - r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid - r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid + r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid + r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid self.store.log_param(r1, entities.Param('generic_param', 'p_val')) self.store.log_param(r2, entities.Param('generic_param', 'p_val')) @@ -874,37 +853,37 @@ def test_search_full(self): p_expr = self._param_expression("generic_param", "=", "p_val") m_expr = self._metric_expression("common", "=", 1.0) - self.assertSequenceEqual([r1, r2], self._search(experiment_id, - param_expressions=[p_expr], - metrics_expressions=[m_expr])) + six.assertCountEqual(self, [r1, r2], self._search(experiment_id, + param_expressions=[p_expr], + metrics_expressions=[m_expr])) # all params and metrics match p_expr = self._param_expression("generic_param", "=", "p_val") m1_expr = self._metric_expression("common", "=", 1.0) m2_expr = self._metric_expression("m_a", ">", 1.0) - self.assertSequenceEqual([r1], self._search(experiment_id, - param_expressions=[p_expr], - metrics_expressions=[m1_expr, m2_expr])) + six.assertCountEqual(self, [r1], self._search(experiment_id, + 
param_expressions=[p_expr], + metrics_expressions=[m1_expr, m2_expr])) # test with mismatch param p_expr = self._param_expression("random_bad_name", "=", "p_val") m1_expr = self._metric_expression("common", "=", 1.0) m2_expr = self._metric_expression("m_a", ">", 1.0) - self.assertSequenceEqual([], self._search(experiment_id, - param_expressions=[p_expr], - metrics_expressions=[m1_expr, m2_expr])) + six.assertCountEqual(self, [], self._search(experiment_id, + param_expressions=[p_expr], + metrics_expressions=[m1_expr, m2_expr])) # test with mismatch metric p_expr = self._param_expression("generic_param", "=", "p_val") m1_expr = self._metric_expression("common", "=", 1.0) m2_expr = self._metric_expression("m_a", ">", 100.0) - self.assertSequenceEqual([], self._search(experiment_id, - param_expressions=[p_expr], - metrics_expressions=[m1_expr, m2_expr])) + six.assertCountEqual(self, [], self._search(experiment_id, + param_expressions=[p_expr], + metrics_expressions=[m1_expr, m2_expr])) def test_log_batch(self): experiment_id = self._experiment_factory('log_batch') - run_uuid = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid + run_uuid = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid metric_entities = [Metric("m1", 0.87, 12345), Metric("m2", 0.49, 12345)] param_entities = [Param("p1", "p1val"), Param("p2", "p2val")] tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")] @@ -914,7 +893,7 @@ def test_log_batch(self): tags = [(t.key, t.value) for t in run.data.tags] metrics = [(m.key, m.value, m.timestamp) for m in run.data.metrics] params = [(p.key, p.value) for p in run.data.params] - assert set(tags) == set([("t1", "t1val"), ("t2", "t2val")]) + assert set([("t1", "t1val"), ("t2", "t2val")]) <= set(tags) assert set(metrics) == set([("m1", 0.87, 12345), ("m2", 0.49, 12345)]) assert set(params) == set([("p1", "p1val"), ("p2", "p2val")]) @@ -922,8 +901,8 @@ def test_log_batch_limits(self): # Test that log batch at 
the maximum allowed request size succeeds (i.e doesn't hit # SQL limitations, etc) experiment_id = self._experiment_factory('log_batch_limits') - run_uuid = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid - metric_tuples = [("m%s" % i, i * 0.1, 12345) for i in range(1000)] + run_uuid = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid + metric_tuples = [("m%s" % i, i, 12345) for i in range(1000)] metric_entities = [Metric(*metric_tuple) for metric_tuple in metric_tuples] self.store.log_batch(run_id=run_uuid, metrics=metric_entities, params=[], tags=[]) run = self.store.get_run(run_uuid) @@ -934,48 +913,44 @@ def test_log_batch_param_overwrite_disallowed(self): # Test that attempting to overwrite a param via log_batch results in an exception and that # no partial data is logged run = self._run_factory() - self.session.commit() tkey = 'my-param' param = entities.Param(tkey, 'orig-val') - self.store.log_param(run.run_uuid, param) + self.store.log_param(run.info.run_uuid, param) overwrite_param = entities.Param(tkey, 'newval') tag = entities.RunTag("tag-key", "tag-val") metric = entities.Metric("metric-key", 3.0, 12345) with self.assertRaises(MlflowException) as e: - self.store.log_batch(run.run_uuid, metrics=[metric], params=[overwrite_param], + self.store.log_batch(run.info.run_uuid, metrics=[metric], params=[overwrite_param], tags=[tag]) self.assertIn("Changing param value is not allowed. 
Param with key=", e.exception.message) assert e.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) - self._verify_logged(run.run_uuid, metrics=[], params=[param], tags=[]) + self._verify_logged(run.info.run_uuid, metrics=[], params=[param], tags=[]) def test_log_batch_param_overwrite_disallowed_single_req(self): # Test that attempting to overwrite a param via log_batch results in an exception run = self._run_factory() - self.session.commit() pkey = "common-key" param0 = entities.Param(pkey, "orig-val") param1 = entities.Param(pkey, 'newval') tag = entities.RunTag("tag-key", "tag-val") metric = entities.Metric("metric-key", 3.0, 12345) with self.assertRaises(MlflowException) as e: - self.store.log_batch(run.run_uuid, metrics=[metric], params=[param0, param1], + self.store.log_batch(run.info.run_uuid, metrics=[metric], params=[param0, param1], tags=[tag]) self.assertIn("Changing param value is not allowed. Param with key=", e.exception.message) assert e.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE) - self._verify_logged(run.run_uuid, metrics=[], params=[param0], tags=[]) + self._verify_logged(run.info.run_uuid, metrics=[], params=[param0], tags=[]) def test_log_batch_accepts_empty_payload(self): run = self._run_factory() - self.session.commit() - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=[]) - self._verify_logged(run.run_uuid, metrics=[], params=[], tags=[]) + self.store.log_batch(run.info.run_uuid, metrics=[], params=[], tags=[]) + self._verify_logged(run.info.run_uuid, metrics=[], params=[], tags=[]) def test_log_batch_internal_error(self): # Verify that internal errors during the DB save step for log_batch result in # MlflowExceptions run = self._run_factory() - self.session.commit() def _raise_exception_fn(*args, **kwargs): # pylint: disable=unused-argument raise Exception("Some internal error") @@ -991,7 +966,7 @@ def _raise_exception_fn(*args, **kwargs): # pylint: disable=unused-argument log_batch_kwargs = 
{"metrics": [], "params": [], "tags": []} log_batch_kwargs.update(kwargs) with self.assertRaises(MlflowException) as e: - self.store.log_batch(run.run_uuid, **log_batch_kwargs) + self.store.log_batch(run.info.run_uuid, **log_batch_kwargs) self.assertIn(str(e.exception.message), "Some internal error") def test_log_batch_nonexistent_run(self): @@ -1002,47 +977,47 @@ def test_log_batch_nonexistent_run(self): def test_log_batch_params_idempotency(self): run = self._run_factory() - self.session.commit() params = [Param("p-key", "p-val")] - self.store.log_batch(run.run_uuid, metrics=[], params=params, tags=[]) - self.store.log_batch(run.run_uuid, metrics=[], params=params, tags=[]) - self._verify_logged(run.run_uuid, metrics=[], params=params, tags=[]) + self.store.log_batch(run.info.run_uuid, metrics=[], params=params, tags=[]) + self.store.log_batch(run.info.run_uuid, metrics=[], params=params, tags=[]) + self._verify_logged(run.info.run_uuid, metrics=[], params=params, tags=[]) def test_log_batch_tags_idempotency(self): run = self._run_factory() - self.session.commit() - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) - self._verify_logged(run.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) + self.store.log_batch( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) + self.store.log_batch( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) + self._verify_logged( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "t-val")]) def test_log_batch_allows_tag_overwrite(self): run = self._run_factory() - self.session.commit() - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "val")]) - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "newval")]) - self._verify_logged(run.run_uuid, metrics=[], 
params=[], tags=[RunTag("t-key", "newval")]) + self.store.log_batch( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "val")]) + self.store.log_batch( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "newval")]) + self._verify_logged( + run.info.run_uuid, metrics=[], params=[], tags=[RunTag("t-key", "newval")]) def test_log_batch_allows_tag_overwrite_single_req(self): run = self._run_factory() - self.session.commit() tags = [RunTag("t-key", "val"), RunTag("t-key", "newval")] - self.store.log_batch(run.run_uuid, metrics=[], params=[], tags=tags) - self._verify_logged(run.run_uuid, metrics=[], params=[], tags=[tags[-1]]) + self.store.log_batch(run.info.run_uuid, metrics=[], params=[], tags=tags) + self._verify_logged(run.info.run_uuid, metrics=[], params=[], tags=[tags[-1]]) def test_log_batch_same_metric_repeated_single_req(self): run = self._run_factory() - self.session.commit() metric0 = Metric(key="metric-key", value=1, timestamp=2) metric1 = Metric(key="metric-key", value=2, timestamp=3) - self.store.log_batch(run.run_uuid, params=[], metrics=[metric0, metric1], tags=[]) - self._verify_logged(run.run_uuid, params=[], metrics=[metric0, metric1], tags=[]) + self.store.log_batch(run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[]) + self._verify_logged(run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[]) def test_log_batch_same_metric_repeated_multiple_reqs(self): run = self._run_factory() - self.session.commit() metric0 = Metric(key="metric-key", value=1, timestamp=2) metric1 = Metric(key="metric-key", value=2, timestamp=3) - self.store.log_batch(run.run_uuid, params=[], metrics=[metric0], tags=[]) - self._verify_logged(run.run_uuid, params=[], metrics=[metric0], tags=[]) - self.store.log_batch(run.run_uuid, params=[], metrics=[metric1], tags=[]) - self._verify_logged(run.run_uuid, params=[], metrics=[metric0, metric1], tags=[]) + self.store.log_batch(run.info.run_uuid, params=[], metrics=[metric0], 
tags=[]) + self._verify_logged(run.info.run_uuid, params=[], metrics=[metric0], tags=[]) + self.store.log_batch(run.info.run_uuid, params=[], metrics=[metric1], tags=[]) + self._verify_logged(run.info.run_uuid, params=[], metrics=[metric0, metric1], tags=[])