
#4299 log_batch actually writes in batches #5460

Merged
merged 35 commits into mlflow:master on Apr 14, 2022

Conversation


@erensahin erensahin commented Mar 6, 2022

What changes are proposed in this pull request?

In this PR, SqlAlchemyStore.log_batch is reimplemented so that it actually writes in batches. Three new public methods are added:

  • log_params: logs the given params at once. On any IntegrityError, rolls back the session and calls log_param for each param, letting it handle the IntegrityError.
  • log_metrics: logs the given metrics at once and calls _update_latest_metric_if_necessary for each metric instance that was newly created.
  • set_tags: sets the given tags at once.

In all three methods, the logging logic for a single param, metric, or tag is preserved.

The main idea is to share one session across the batch rather than opening and committing a new session for each atomic operation. To preserve the current behavior, log_params commits its own session so that it can catch the IntegrityError and fall back to logging the params one by one via log_param, which handles the error.

SqlAlchemyStore lacks docstrings and type annotations, so I did not add them (I can if desired). I only added them to the protected _get_or_create_many method to clarify its intent.
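
To make the session sharing described above concrete, here is a rough sketch (not the actual diff) of how log_batch can delegate to the new methods. The log_metrics signature follows the one quoted later in this thread; the set_tags signature and the overall shape are assumptions for illustration:

# Sketch only: a SqlAlchemyStore method, assuming the new public methods described above.
def log_batch(self, run_id, metrics, params, tags):
    with self.ManagedSessionMaker() as session:
        run = self._get_run(run_uuid=run_id, session=session)
        self._check_run_is_active(run)
        # log_params commits its own session so an IntegrityError can be caught
        # and delegated to log_param for per-param handling.
        self.log_params(run_id, params)
        # metrics and tags share the session opened here.
        self.log_metrics(run_id, metrics, session=session, check_run=False)
        self.set_tags(run_id, tags, session=session, check_run=False)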

How is this patch tested?

It passes the existing unit test cases in tests/store/tracking/test_sqlalchemy_store.py, and new test cases were added for log_params, log_metrics, and set_tags.

I also did a small load test:

import os
from timeit import default_timer as timer

import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import Param, RunTag, Metric


os.environ["MLFLOW_TRACKING_URI"] = "sqlite:///mock.db"

params = []
metrics = []
tags = []
N = 100 * 100
MAX_BATCH_SIZE = 100

for i in range(N):
    params.append(Param(str(i), "val"))
    metrics.append(Metric(str(i), i, 0, 0))
    tags.append(RunTag(str(i), "val"))


with mlflow.start_run() as run:
    start = timer()
    for i in range(0, N, MAX_BATCH_SIZE):
        MlflowClient().log_batch(
            run.info.run_id,
            metrics=metrics[i:i+MAX_BATCH_SIZE],
            params=params[i:i+MAX_BATCH_SIZE],
            tags=tags[i:i+MAX_BATCH_SIZE]
        )
    runtime = round(timer() - start, 3)
    print(f"Runtime: {runtime} seconds...")

Results of master branch: Runtime: 246.888 seconds...
Results of this branch: Runtime: 4.769 seconds...

Does this PR change the documentation?

  • No. You can skip the rest of this section.
  • Yes. Make sure the changed pages / sections render correctly by following the steps below.
  1. Check the status of the ci/circleci: build_doc check. If it's successful, proceed to the
    next step, otherwise fix it.
  2. Click Details on the right to open the job page of CircleCI.
  3. Click the Artifacts tab.
  4. Click docs/build/html/index.html.
  5. Find the changed pages / sections and make sure they render correctly.

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

In this change, the SqlAlchemyStore.log_batch method is modified so that it actually writes in batches. New public methods log_params, log_metrics, and set_tags are also added to SqlAlchemyStore.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

So that metrics can be compared by their composite properties

Signed-off-by: Eren Sahin <sahineren.09@gmail.com>
Add the following new public methods:

* log_params: log the given params at once. In case of
any IntegrityError, rollback the session and call log_param for
each param and let it handle IntegrityError
* log_metrics: log the given metrics at once, and call
_update_latest_metric_if_necessary if the metric instance is
recently created.
* set_tags: set the given tags at once.

All these methods receive a session and can share the same session,
which is created in log_batch. Since log_params commits the session,
it is called first without the shared session and commits its own
session independently.

log_metrics and set_tags share the same session.

Signed-off-by: Eren Sahin <sahineren.09@gmail.com>
…low#4299)

Additionally, modify test_log_batch_internal_error test case to use
correct mock method names.

Signed-off-by: Eren Sahin <sahineren.09@gmail.com>
@github-actions bot added the area/tracking (Tracking service, tracking client APIs, autologging) and rn/bug-fix (Mention under Bug Fixes in Changelogs) labels on Mar 6, 2022
@erensahin
Contributor Author

@WeichenXu123 hi, you were reviewing #3024, so I am tagging you to notify you 🙂

@WeichenXu123
Collaborator

thank you for your contribution! I will review this soon.

@dbczumar dbczumar left a comment

@erensahin Thanks so much for picking this up! Awesome PR! I left a comment that should hopefully help us simplify things and, possibly, improve performance further. Let me know what you think!

@@ -222,6 +223,29 @@ def _get_or_create(self, session, model, **kwargs):

return instance, created

def _get_or_create_many(self, session, model, list_of_args):
@dbczumar dbczumar Mar 7, 2022

Re-examining the _get_or_create code after a few years, I'm wondering if it's truly necessary to have this logic for point writes or for batch writes. The only place where the created boolean output of _get_or_create is used is

logged_metric, just_created = self._get_or_create(...)

where we check it to determine whether or not to update latest_metrics.

In all other cases, _get_or_create is duplicating the function of an integrity constraint.

Given that, can we remove the _get_or_create_many semantics from this change in favor of implementing the following logic for the following procedures?

log_params

  1. Create a session
  2. Within the session, perform a _get_run() operation to fetch all of the params associated with the run
  3. Attempt to insert the batch of parameters into the params table via session.add_all() & session.commit()
  4. In case of an integrity error, check to see if there are any value differences between the new batch (from 3) and the existing run parameters (from 2). If there are, throw a specific MlflowException. If there aren't, return successfully in order to provide idempotency (Params: once written, param values cannot be changed (attempting to overwrite a param value will result in an error). However, logging the same param (key, value) is permitted - that is, logging a param is idempotent. from https://mlflow.org/docs/latest/rest-api.html#log-batch).

This is very similar to the following existing logic:

existing_params = [p.value for p in run.params if p.key == param.key]
if len(existing_params) > 0:
    old_value = existing_params[0]
    raise MlflowException(
        "Changing param values is not allowed. Param with key='{}' was already"
        " logged with value='{}' for run ID='{}'. Attempted logging new value"
        " '{}'.".format(param.key, old_value, run_id, param.value),
        INVALID_PARAMETER_VALUE,
    )
else:
    raise

except that we don't re-raise in the case where there isn't any mismatch between the new parameter values and the existing values in the DB (a sketch of the proposed flow is included below).
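
A minimal sketch (written as a SqlAlchemyStore method) of the log_params flow proposed above, using the SqlParam, MlflowException, and INVALID_PARAMETER_VALUE names quoted elsewhere in this thread; the surrounding structure is illustrative rather than the actual implementation:

from sqlalchemy.exc import IntegrityError

def _log_params(self, run_id, params):
    with self.ManagedSessionMaker() as session:
        run = self._get_run(run_uuid=run_id, session=session)
        new_params = [SqlParam(run_uuid=run_id, key=p.key, value=p.value) for p in params]
        try:
            # Step 3: attempt the batch insert.
            session.add_all(new_params)
            session.commit()
        except IntegrityError:
            session.rollback()
            # Step 4: compare incoming values with what is already stored. A
            # mismatch is an error; an exact match is an idempotent no-op.
            existing = {p.key: p.value for p in run.params}
            for param in params:
                if param.key in existing and existing[param.key] != param.value:
                    raise MlflowException(
                        "Changing param values is not allowed. Param with key='{}' was "
                        "already logged with value='{}' for run ID='{}'.".format(
                            param.key, existing[param.key], run_id
                        ),
                        INVALID_PARAMETER_VALUE,
                    )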

log_metrics

  1. Create a session
  2. Attempt to insert the batch of metrics into the metrics table via session.add_all() & session.commit()
  3. Unconditionally attempt to update the latest_metrics table
  4. If any integrity errors are encountered, raise them directly to the caller

set_tags

  1. Create a session
  2. Attempt to upsert the batch of tags into the tags table via session.merge()
  3. If any integrity errors are encountered (unlikely), raise them directly to the caller

Collaborator

After implementing these changes, I'd be curious to see whether the performance is even better, given that we'd no longer perform entity lookups prior to insertion.

Contributor Author

Hi @dbczumar, thank you so much for such a detailed comment.

I didn't want to make radical changes and wanted to keep things behaving as they do when logging one by one, but I too think these parts should be redesigned. I totally agree with your comment on the necessity of _get_or_create.

log_params

Yeah, it makes sense. You expect to ignore the integrity error if the parameters are the same, and if even one parameter does not match, we should roll back the changes. Did I understand correctly?

log_metrics

Makes sense. And I should update the latest_metrics table in batch afterwards.

set_tags

Makes sense. I also would not expect an integrity error since we merge.

Contributor Author

I tried to implement your requests, but there are some problems:

  1. SQLAlchemy does not support batch upsert operations via session.merge(), so I didn't implement that part.
  2. Since we changed the behavior in _log_params, some test cases now fail. The problem is that we handle the exception (and conditionally raise) for the whole batch, and during handling it the whole session is rolled back. But the test cases (test_log_batch_param_overwrite_disallowed_single_req) still expect that at least one of the params was logged.

@@ -604,6 +632,52 @@ def log_metric(self, run_id, metric):
            if just_created:
                self._update_latest_metric_if_necessary(logged_metric, session)

def log_metrics(self, run_id, metrics, session=None, check_run=True):
Collaborator

Nit: Can we make the new functions introduced by this PR "private" by prepending an underscore, since these methods aren't part of the AbstractStore API and don't need to be called by external consumers?

Suggested change
def log_metrics(self, run_id, metrics, session=None, check_run=True):
def _log_metrics(self, run_id, metrics, session=None, check_run=True):

Contributor Author

Sure, thanks for the suggestion

Comment on lines 675 to 678
for (sql_metric, just_created) in instances:
    if just_created:
        self._update_latest_metric_if_necessary(sql_metric, session)
@dbczumar dbczumar Mar 7, 2022

Can we make this a batch operation as well? Given a list of metrics, we can query the latest_metrics table for the corresponding metric names to determine which metrics to update, then, we can do a batch session.merge().

Contributor Author

Yeah that's a great suggestion, thanks.

By renaming them as _log_metrics, _log_params, and _set_tag respectively

Signed-off-by: Eren Sahin <sahineren.09@gmail.com>
@erensahin
Contributor Author

erensahin commented Mar 15, 2022

Hi @dbczumar, I think it is ready for another review round. I tried to implement your requests, but there are some problems:

  1. SQLAlchemy does not support batch upsert operations via session.merge(), so I didn't implement that part.
  2. Since we changed the behavior in _log_params, some test cases now fail. The problem is that we handle the exception (and conditionally raise) for the whole batch, and during handling it the whole session is rolled back. But the test cases (test_log_batch_param_overwrite_disallowed_single_req) still expect that at least one of the params was logged.

@dbczumar dbczumar left a comment

Hi @dbczumar, I think it is ready for another review round. I tried to implement your requests, but there are some problems:

  1. SQLAlchemy does not support batch upsert operations via session.merge(), so I didn't implement that part.
  2. Since we changed the behavior in _log_params, some test cases now fail. The problem is that we handle the exception (and conditionally raise) for the whole batch, and during handling it the whole session is rolled back. But the test cases (test_log_batch_param_overwrite_disallowed_single_req) still expect that at least one of the params was logged.

Thanks so much for trying this out and addressing feedback, @erensahin!

  1. My mistake! Apologies for missing this. I've left some suggestions for other approaches to improve performance without using merge().

  2. This is okay. We can remove the old conflicting test coverage. Partial writes aren't part of the documented API contract; IMO, removing undocumented partial writes is an improvement.

We should be all good to go once remaining comments are addressed. Let me know if I can provide any additional clarifications / assistance here. Thank you for all your hard work!

# ignore the exception since idempotency is reached.
non_matching_params = []
for param in param_instances:
    existing_value = [p.value for p in run.params if p.key == param.key]
Collaborator

Performance nit: Can we transform run.params into a dictionary representation mapping from param.key -> param.value so that a linear scan isn't required for each element of param_instances? Then, we can do existing_value = run_params_dict.get(param.key)
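
A small sketch of the suggested transformation (the variable names mirror the snippet above; what gets collected is illustrative):

# Build the key -> value lookup once, then each lookup is O(1).
run_params_dict = {p.key: p.value for p in run.params}

non_matching_params = []
for param in param_instances:
    existing_value = run_params_dict.get(param.key)
    if existing_value is not None and existing_value != param.value:
        non_matching_params.append(param)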

Contributor Author

Yeah, that's a smart idea. thanks.

Contributor Author

Done

# keep the behavior of updating latest metric table if necessary
# for each created or retrieved metric
for logged_metric in metric_instances:
    self._update_latest_metric_if_necessary(logged_metric, session)
@dbczumar dbczumar Mar 16, 2022

Apologies for the oversight regarding the lack of batch support for session.merge().

Even without session.merge(), I think we may be able to write a more performant batch version of _update_latest_metric_if_necessary by making the following adjustments:

  1. Start a serializable transaction via:

with session.connection(execution_options={'isolation_level': 'SERIALIZABLE'}) as connection:
    ...

(Reference: https://www.bookstack.cn/read/sqlalchemy-1.3/d36f81a90089ab55.md - see "Setting Isolation For Individual Transactions")

  2. Perform the metrics table insert operations defined above.

  3. Issue a single select query to fetch all of the latest_metrics records corresponding to the metric keys that were supplied to _log_metrics.

  4. For each SqlLatestMetric object returned by the query in (3), update its attributes if necessary (as determined by the logic in _compare_metrics). For reference, sqlalchemy allows records to be updated by mutating the attributes of their corresponding objects (e.g. https://stackoverflow.com/a/2334917/11952869).

  5. For new metrics that do not have corresponding latest_metrics entries, create new SqlLatestMetric objects and add them to the session via session.add_all() (_save_to_db()).

  6. Commit the transaction.

A serializable transaction provides even better isolation guarantees for concurrent batch logging than the approach taken prior to this PR. Taking row locks on the latest_metrics table is insufficient because two concurrent operations may attempt to insert a brand new metric with differing values; in such a case, there would be no corresponding row in latest_metrics to lock.

Finally, after implementing this logic, can we update log_metric to call _log_metrics as well? This way, our serializable transaction is the only transaction that ever writes to metrics and latest_metrics, ensuring isolation.
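
A rough sketch (written as a SqlAlchemyStore method) of the flow proposed above. SqlMetric, SqlLatestMetric, _compare_metrics, and _save_to_db are names used elsewhere in this thread; the SqlLatestMetric columns and the assumption that _compare_metrics(new, old) returns True when the new value should win are illustrative:

def _log_metrics(self, run_id, metric_instances):
    # metric_instances is assumed to be a list of already-built SqlMetric objects.
    session = self.SessionMaker()
    # Step 1: elevate this transaction's isolation level to SERIALIZABLE.
    session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
    try:
        # Step 2: insert the metric history rows in one batch.
        self._save_to_db(session=session, objs=metric_instances)
        # Step 3: fetch the latest_metrics rows for the affected keys in one query.
        latest = {
            m.key: m
            for m in session.query(SqlLatestMetric)
            .filter(
                SqlLatestMetric.run_uuid == run_id,
                SqlLatestMetric.key.in_([m.key for m in metric_instances]),
            )
            .all()
        }
        new_latest = []
        for metric in metric_instances:
            existing = latest.get(metric.key)
            if existing is None:
                # Step 5: brand-new keys get new SqlLatestMetric rows.
                new_latest.append(
                    SqlLatestMetric(
                        run_uuid=run_id, key=metric.key, value=metric.value,
                        timestamp=metric.timestamp, step=metric.step,
                    )
                )
            elif self._compare_metrics(metric, existing):
                # Step 4: update preexisting rows in place.
                existing.value = metric.value
                existing.timestamp = metric.timestamp
                existing.step = metric.step
        self._save_to_db(session=session, objs=new_latest)
        # Step 6: commit, releasing the serializable transaction.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()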

Contributor Author

This looks good to me, and it is quite a promising improvement as far as I can foresee right now, thanks for the suggestion :) Steps 4-5 cover an upsert operation here. I am not sure session.merge handles the WHEN NOT MATCHED BY SOURCE THEN DELETE logic (I don't think it does), so there's nothing that requires a delete operation here, right?

  1. I think step 3 would return SqlLatestMetric objects, and I should not convert them to the mlflow entity (which is Metric) here, so that I can benefit from updating the attributes.
  2. I think I should create SqlLatestMetric objects here.

Collaborator

Fantastic! I agree, we never need to worry about record deletion for metrics / latest_metrics.

Good catch - apologies for the oversight. 4 and 5 should refer to SqlLatestMetric; I'll update the original message.

Contributor Author

Done, I guess :)

@@ -731,6 +731,95 @@ def test_log_null_metric(self):
        warnings.resetwarnings()
        assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    def test__log_metrics(self):
Collaborator

Can we provide a similar level of test coverage by testing the log_batch() method with different inputs, rather than testing internal methods? If this isn't easy to accomplish, let me know :)

Contributor Author

I didn't modify any test cases regarding log_batch, but I'll try to add more cases for it :)

Collaborator

Ah, I think it would be great if we could write test cases for log_batch() and remove the test cases for the private methods. I.e. replace private method testing with public method testing, as long as we can get the same level of coverage.

Contributor Author

I revisited all unit test cases and converted them to call log_batch instead of private methods

with context as session:
    for tag in tags:
        _validate_tag(tag.key, tag.value)
        session.merge(SqlTag(run_uuid=run_id, key=tag.key, value=tag.value))
Collaborator

Because session.merge() doesn't support batch upserts, I think the logic you were using previously is more performant. Can we do the following?

  1. Query the database for existing tags with the specified keys. Use with_for_update() to lock the existing tag rows, ensuring that they aren't deleted by a concurrent DeleteTag operation before we commit the update.

  2. For each existing tag object returned by the query in (1), update its value. For reference, sqlalchemy allows records to be updated by mutating the attributes of their corresponding objects (e.g. https://stackoverflow.com/a/2334917/11952869).

  3. For new tags that were not present in the query results from (1), add them to the session via session.add_all() (_save_to_db()).

With this proposed implementation, it's possible in rare cases that _set_tags will fail because two concurrent operations attempt to insert the same new tag key into the tags table. Accordingly, can we retry 3 times with randomized exponential backoff in case of an IntegrityError?

It's probably best to use a new session for _set_tags to allow for this retry logic to occur without rolling back metric inserts. While this hurts atomicity, it improves performance, and the existing API wasn't atomic anyway :). Accordingly, _log_params, _set_tags, and _log_metrics can all use their own sessions.
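
A hedged sketch (written as a SqlAlchemyStore method) of the _set_tags flow proposed above. The retry/backoff parameters and the assumption that the IntegrityError propagates out of the session context manager are illustrative; SqlTag, _validate_tag, and _save_to_db are names used elsewhere in this thread:

import random
import time

from sqlalchemy.exc import IntegrityError

def _set_tags(self, run_id, tags, max_retries=3):
    for tag in tags:
        _validate_tag(tag.key, tag.value)
    for attempt in range(1, max_retries + 1):
        try:
            with self.ManagedSessionMaker() as session:
                # Step 1: lock existing tag rows so a concurrent DeleteTag cannot
                # remove them before this transaction commits.
                current_tags = (
                    session.query(SqlTag)
                    .filter(SqlTag.run_uuid == run_id, SqlTag.key.in_([t.key for t in tags]))
                    .with_for_update()
                    .all()
                )
                existing = {t.key: t for t in current_tags}
                new_tag_dict = {}
                for tag in tags:
                    if tag.key in existing:
                        # Step 2: update preexisting rows in place.
                        existing[tag.key].value = tag.value
                    else:
                        # Step 3: collect brand-new rows; a later duplicate of the
                        # same key overwrites the earlier one.
                        new_tag_dict[tag.key] = SqlTag(
                            run_uuid=run_id, key=tag.key, value=tag.value
                        )
                self._save_to_db(session=session, objs=list(new_tag_dict.values()))
            return
        except IntegrityError:
            # Two writers raced to insert the same new key: back off and retry.
            if attempt == max_retries:
                raise
            time.sleep(random.uniform(0, (2 ** attempt) - 1))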

Contributor Author

Yeah, it is quite clear. I got your point about using their own sessions in all these methods.

For the retry case, do you have any suggestions on implementing it with SQLAlchemy? I've seen some implementations built on the Query interface; are they OK to use? (https://stackoverflow.com/questions/53287215/retry-failed-sqlalchemy-queries)

Collaborator

Good question! For retries, I think we can just write some retry logic around the entire _set_tags body, without worrying about SqlAlchemy-specific integrations. I noticed the following response on the thread you linked:

hi there - SQLAlchemy author here. I've just been pointed to this recipe. I would highly recommend against using a pattern such as the above. "retry" for connectivity should be performed at the top level of a transaction only, and that's what the pool pre_ping feature is for. if you lose a connection in the middle of a transaction, you need to re-run your whole operation. explicit is better than implicit. –

Contributor Author

I tried to implement a retry mechanism, but I was unable to write a comprehensive unit test for it; I did not want to get into multiprocessing/threading. Do you have a suggestion?

)

with context as session:
    self._save_to_db(session=session, objs=metric_instances)
@dbczumar dbczumar Mar 16, 2022

It just occurred to me that we may run into trouble with the idempotence of LogBatch due to this use of session.add_all() (_save_to_db()) to insert a collection of new SqlMetric records into the metrics table. If two LogBatch requests with the same metric content are issued sequentially, the second one will fail with a primary key constraint violation.

Can we implement logic similar to what we're using in _log_params by catching IntegrityError and checking the run's metric history? If the metric history contains all of the metrics being inserted, we can simply return, since we know that the metrics table contains the desired values and that they were reflected correctly by latest_metrics when they were originally inserted.

Contributor Author

Yeah, it is a good idea. If the metric history does not contain all of the metrics being inserted, I should insert the ones that are not present and update the latest_metrics table, is that correct?

Collaborator

Yep! Exactly :).

Contributor Author

I think this one is done

Comment on lines 819 to 820
session = self.SessionMaker()
session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
Collaborator

We don't need a serializable transaction for log_params and set_tags because we're only writing to a single table and we aren't performing any gap locking (e.g. select for update). Can we use ManagedSessionMaker instead?

Contributor Author

Done

# isolation level for individual transaction.
# Here, we open a connection with SERIALIZABLE isolation level.
# it will be released when session.commit is called.
session = self.SessionMaker()
Collaborator

We don't need a serializable transaction for log_params and set_tags because we're only writing to a single table and we aren't performing any gap locking (e.g. select for update). Can we use ManagedSessionMaker instead?

Contributor Author

Done

current_tags = (
    session.query(SqlTag)
    .filter(SqlTag.run_uuid == run_id, SqlTag.key.in_([t.key for t in tags]))
    .with_for_update()
Collaborator

I think we can loosen the isolation guarantees for set_tags because, unlike log_metrics, we're only modifying a single table. For concurrent requests setting multiple tags, there's no well-defined behavior. If we want to use with_for_update, we'd need to use a serializable transaction to avoid deadlocks due to gap locking (see #4353 (comment)).

Can we remove with_for_update()?

Contributor Author

I wanted to make sure that the isolation level is guaranteed, but if you say we don't need it for the SqlTag entity, that makes sense. I removed it.


new_tag_dict = {}
for tag in tags:
    _validate_tag(tag.key, tag.value)
Collaborator

Nit: Can we move tag validation to the top of the method so that we don't make any DB queries if tags are invalid?

Contributor Author

Yeah that's a good idea, done

[t.key for t in tags], max_retries
)
)
sleep_duration = 0.1 * ((2 ** attempt) - 1)
@dbczumar dbczumar Apr 6, 2022

I think we should increase the sleep duration slightly.

Suggested change
sleep_duration = 0.1 * ((2 ** attempt) - 1)
sleep_duration = (2 ** attempt) - 1
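
For reference, assuming attempt takes the values 1, 2, and 3, the original formula sleeps for 0.1 s, 0.3 s, and 0.7 s, while the suggested change sleeps for 1 s, 3 s, and 7 s:

for attempt in (1, 2, 3):
    # original vs. suggested backoff: (0.1, 1), (0.3, 3), (0.7, 7) seconds
    print(0.1 * ((2 ** attempt) - 1), (2 ** attempt) - 1)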

Contributor Author

Good idea, done

attempt = (attempt or 0) + 1
if attempt > max_retries:
    raise MlflowException(
        "Failed to set tags with given keys={} within {} retries".format(
Collaborator

Nit: Can we place the list of keys at the end of the message?

Suggested change
"Failed to set tags with given keys={} within {} retries".format(
"Failed to set tags within {} retries. Keys: {}".format(

Contributor Author

Done

Comment on lines 807 to 809
param_instances = {}
for param in params:
    if param.key not in param_instances:
        param_instances[param.key] = SqlParam(
            run_uuid=run_id, key=param.key, value=param.value
        )
param_instances = list(param_instances.values())
@dbczumar dbczumar Apr 6, 2022

I think we can replace this with

params = set(params)
param_instances = [SqlParam(run_uuid=run_id, key=param.key, value=param.value) for param in params]
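
This dedup relies on params being hashable and comparable by their composite properties, which the commit messages in this PR ("So that params can be compared by their composite properties") indicate was added. A self-contained illustration of the dedup behavior, using a stand-in dataclass rather than the real Param entity:

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeParam:  # stand-in for mlflow.entities.Param, for illustration only
    key: str
    value: str

params = [FakeParam("lr", "0.1"), FakeParam("lr", "0.1"), FakeParam("epochs", "10")]
assert len(set(params)) == 2  # the exact duplicate collapses; distinct keys remain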

@erensahin erensahin Apr 11, 2022

I am not sure why I chose this way of doing it, but I think I wanted to make sure that we keep the order of params while making them distinct, since set does not preserve order. But I can replace it with that if you want :)

Collaborator

Got it. Params are unordered, so we don't need to preserve insertion order :)

Contributor Author

Oh, the order matters for tags, which is what made me think so :) Applied it over here.

Comment on lines 930 to 931
# if the SqlTag instance is recently created,
# update with incoming
Collaborator

Nice! Thanks for handling this case. Small comment suggestion:

Suggested change
# if the SqlTag instance is recently created,
# update with incoming
# if a SqlTag instance is already present in `new_tag_dict`, this means that multiple
# tags with the same key were passed to `set_tags`. In this case, we resolve potential
# conflicts by updating the value of the existing instance to the value of `tag`

Contributor Author

Thanks, I applied it

@@ -621,29 +674,80 @@ def _compare_metrics(metric_a, metric_b):
metric_b.value,
)

# Fetch the latest metric value corresponding to the specified run_id and metric key and
# lock its associated row for the remainder of the transaction in order to ensure
def _merge_metric(new_metric, old_metric):
Collaborator

Nit: I think _overwrite_metric is a more descriptive name here

Suggested change
def _merge_metric(new_metric, old_metric):
def _overwrite_metric(new_metric, old_metric):

Contributor Author

Applied your suggestion, that was my other alternative :)

@dbczumar dbczumar left a comment

@erensahin This is looking great! I left a few small comments, but I think we should be ready to merge once they're addressed! Thank you for all your hard work on this!

It looks like there are a few small test failures in https://github.com/mlflow/mlflow/runs/5825511155?check_suite_focus=true. Please let me know if you'd like me to help debug them.

@erensahin erensahin force-pushed the 4299-log_batch branch 2 times, most recently from 641bcb1 to ea2ca5d on April 11, 2022 20:06
So that params can be compared by their composite properties

Signed-off-by: Eren Sahin <sahineren.09@gmail.com>
@dbczumar
Collaborator

My guess regarding Windows is that the sqlite version used for our Windows CI is older than the one used for Linux:

SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999 for SQLite versions prior to 3.32.0 (2020-05-22) or 32766 for SQLite versions after 3.32.0.

In practice, 2 parts should be equivalent to batches of 500, right? Since the maximum number of keys is 1000. If that max limit was ever raised, we'd want to keep using a smaller batch size.

@erensahin
Contributor Author

My guess regarding Windows is that the sqlite version used for our Windows CI is older than the one used for Linux:

SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999 for SQLite versions prior to 3.32.0 (2020-05-22) or 32766 for SQLite versions after 3.32.0.

In practice, 2 parts should be equivalent to batches of 500, right? Since the maximum number of keys is 1000. If that max limit was ever raised, we'd want to keep using a smaller batch size.

Yes, it is the same in practice

# and try to handle them.
try:
    self._save_to_db(session=session, objs=metric_instances)
    self._update_latest_metrics_if_necessary(metric_instances, session)
@erensahin erensahin Apr 13, 2022

I think this might be what's causing the "deadlock", since we operate on metric_instances that are not committed yet. These are other tests that fail on Windows (test_log_batch_same_metrics_repeated_multiple_reqs). It might be less performant, but I would suggest something like the following:

try:
    self._save_to_db(session=session, objs=metric_instances)
    session.commit()
except:
    session.rollback()
    ...

self._update_latest_metrics_if_necessary(metric_instances, session)
session.commit()

Collaborator

I discovered an issue when running locally against sqlite where self._update_latest_metrics_if_necessary(metric_instances, session) is failing with a duplicate element and the error handling logic in https://github.com/mlflow/mlflow/pull/5460/files#diff-b2b10e69c9c9134afc376963051160539426069d6ed9ebd6af8252ce2b51de09R645-R664 gets stuck. Investigating further...

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
@dbczumar
Collaborator

dbczumar commented Apr 13, 2022

@erensahin I just pushed a commit for the following:

  1. Use a chunk size of 500 when querying latest metrics.
  2. Use ManagedSessionMaker for creating sessions that use custom transaction isolation levels, ensuring that sessions are committed and closed properly.

This seems to have fixed the database tests.
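
For reference, the chunked querying mentioned in point 1 can be done with a small helper like the following (names illustrative; the actual change may differ):

def _chunk(seq, chunk_size):
    # Yield successive chunk_size-sized slices of seq, e.g. to keep each
    # IN (...) clause under SQLITE_MAX_VARIABLE_NUMBER.
    for i in range(0, len(seq), chunk_size):
        yield seq[i:i + chunk_size]

# e.g. query latest_metrics for at most 500 keys at a time:
# for keys in _chunk(metric_keys, 500):
#     ... session.query(SqlLatestMetric).filter(SqlLatestMetric.key.in_(keys)) ...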

Comment on lines 644 to 661
metric_history = (
    session.query(SqlMetric)
    .filter(
        SqlMetric.run_uuid == run_id,
        SqlMetric.key.in_([m.key for m in metric_instances]),
    )
    .all()
)
# convert to a set of Metric entities to take advantage of their hashability,
# then obtain the metrics that were not logged earlier within this run_id
metric_history = {m.to_mlflow_entity() for m in metric_history}
non_existing_metrics = [
    m for m in metric_instances if m.to_mlflow_entity() not in metric_history
]
# if there are metrics that were rolled back even though they did not
# violate the primary key, log them now.
if non_existing_metrics:
    self._log_metrics(run_id, non_existing_metrics)
Collaborator

On second thought, reading up to 1000 metric histories into memory for conflict resolution could be fairly expensive from a latency / memory pressure perspective.

I'm wondering if it might be better to throw in this case instead. For certain sql dialects, we can use a "on duplicate key ignore" construct and otherwise fail open. I'll update this with some more details shortly.

Thoughts @erensahin ?

Collaborator

The more I think about it, the more I favor the current approach. The data volume concern can be addressed via further chunking

Contributor Author

Thank you so much for handling the problem @dbczumar :)

In my opinion, throwing an error when a duplicate metric entry (key-timestamp-step-value) is given in a batch is perfectly fine, without even trying to insert; that is how log_params treats duplicates. I agree with the "on duplicate key ignore" approach for certain dialects, but for the sake of stability, raising an early error makes more sense to me.

@dbczumar dbczumar Apr 14, 2022

Understood! I was tempted to go that direction, but I don't want to break any existing user workloads that might rely on idempotent logging of duplicate metric keys. I prefer your existing logic :D. I've just made a minor adjustment here to chunk the metrics into batches of 100 to avoid loading too much data into memory at once; I think the data volume is acceptable since this case is unlikely to occur frequently. I also refactored slightly so that we call an internal method when handling the IntegrityError, rather than retrying the whole _log_metric method, which would create a nested session. I've gone ahead and pushed the changes; let me know how they look! If CI passes, I think we're all set to merge! Thanks for all your hard work @erensahin !

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
run = self._get_run(run_uuid=run_id, session=session)
self._check_run_is_active(run)

def _try_insert_tags(attempt_number, max_retries=3):
Collaborator

@erensahin I also made a slight tweak here to call an internal method when retrying tag setting; this helps us avoid creating a new session nested within the previous session. Let me know if you have any concerns!

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
@dbczumar dbczumar left a comment

@erensahin I believe that this also addresses #4353 by serializing metric write operations. I've added a test case to confirm. Awesome work! I'll plan to merge tomorrow once tests pass! Thank you for your contribution!

Signed-off-by: dbczumar <corey.zumar@databricks.com>
@erensahin
Contributor Author

@erensahin I believe that this also addresses #4353 by serializing metric write operations. I've added a test case to confirm. Awesome work! I'll plan to merge tomorrow once tests pass! Thank you for your contribution!

That's great news! Thank you for your support. Do you have a planned release date, i.e., when can we get these new changes?
I would also appreciate it if you could apply these changes to Databricks-managed MLflow as well. We are using it and hope to see a major speed-up :)

@erensahin
Contributor Author

I also updated the PR description: we started with a 240 s --> 25 s improvement and ended up at 240 s --> 5 s.

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
Signed-off-by: dbczumar <corey.zumar@databricks.com>
# Order by the metric run ID and key to ensure a consistent locking order
# across transactions, reducing deadlock likelihood
.order_by(SqlLatestMetric.run_uuid, SqlLatestMetric.key)
.with_for_update()
Collaborator

@erensahin After adding a concurrency test, I found that using a serializable transaction leads to deadlock due to range locking. I've instead reverted back to the previous with_for_update() row locking procedure and derisked it by ensuring that we only lock preexisting rows in the database, resolving the issue described in #4353 (comment). While it's possible that two separate transactions may simultaneously insert a brand new metric, this case is highly unlikely to occur in practice, and it wasn't handled before. The slightly more common case where two threads are logging values for a preexisting metric simultaneously is safe from a concurrency perspective with row locking.

@dbczumar dbczumar left a comment

@erensahin Thanks again for all your hard work! This will be released with MLflow 1.26.0, which should roll out some time in May. The Databricks-managed MLflow Tracking service implements fast, scalable, and threadsafe batch metadata logging; feel free to check it out!

Merging!

@dbczumar dbczumar merged commit 314d60d into mlflow:master Apr 14, 2022
@harupy harupy mentioned this pull request Jun 2, 2022
@szczesniak-piotr szczesniak-piotr mentioned this pull request Jun 12, 2023
Development

Successfully merging this pull request may close these issues.

[FR] SqlAlchemyStore.log_batch should really write in batches