
Fix race condition for trial number computation. #1490

Merged
merged 7 commits into from Jul 13, 2020

Conversation

@hvy (Member) commented Jul 7, 2020

Motivation

Fixes #1488.

Description of the changes

Allows count_past_trials to circumvent the repeatable read isolation level. See the comments in the code for details.

Note

Verified the changes with MySQL by inspection, checking for unique trial numbers after running a distributed optimization with the script below. The bug is otherwise reproducible by simply running many trials and then checking:

SELECT COUNT(DISTINCT(number)) FROM trials;
import optuna
from optuna.samplers import TPESampler

def objective(trial, n_params):
    return sum(trial.suggest_float(f"x{i}", 0.0, 1.0) for i in range(n_params))

if __name__ == "__main__":
    n_params = 10
    n_trials = 200

    database = "duplicatenumber1488"
    storage = "mysql://root@localhost/duplicatenumber1488"
    study = optuna.create_study(sampler=TPESampler(), study_name=database, storage=storage, load_if_exists=True)

    study.optimize(lambda trial: objective(trial, n_params), n_trials=n_trials)

About isolation levels in SQLAlchemy for different dialects:
https://docs.sqlalchemy.org/en/13/core/connections.html#sqlalchemy.engine.Connection.execution_options.params.isolation_level.
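A minimal, runnable illustration of that API (a sketch using an in-memory SQLite database; supported level names vary by dialect):

```python
from sqlalchemy import create_engine

# Set the isolation level per connection via execution_options(), as
# described in the SQLAlchemy docs linked above. SQLite accepts
# "SERIALIZABLE" and "READ UNCOMMITTED"; MySQL/PostgreSQL accept more levels.
engine = create_engine("sqlite://")
conn = engine.connect().execution_options(isolation_level="READ UNCOMMITTED")
print(conn.get_isolation_level())
```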

@hvy hvy added bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself. optuna.storages Related to the `optuna.storages` submodule. This is automatically labeled by github-actions. labels Jul 7, 2020
@hvy hvy marked this pull request as ready for review July 7, 2020 09:20
@c-bata (Member) commented Jul 7, 2020

"Read Uncommitted" is a lower isolation level than "Repeatable Read". It looks like this PR depends on dirty reads, right?

[Screenshot: table of isolation levels and the read phenomena they allow]
https://en.wikipedia.org/wiki/Isolation_(database_systems)

As the table says, dirty reads may occur when using "Read Uncommitted", so this fix does not look safe.

@c-bata c-bata self-requested a review July 7, 2020 09:49
@hvy hvy force-pushed the fix-trial-number-race-condition branch from 2ebd008 to d3060a2 Compare July 7, 2020 13:33
@hvy (Member, Author) commented Jul 7, 2020

Thanks @c-bata for your quick comment.

I changed the logic to do row-level locking on all trials for the given study and to retry on failures, rather than relaxing the isolation level. Verified the logic with MySQL and PostgreSQL locally. Note that this has the downside of making trial creation significantly slower. However, this cannot be helped given the constraint that trial numbers must be unique.
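A rough sketch of the approach (not the actual Optuna implementation; `TrialModel` here is a minimal stand-in for the internal model, and SQLite ignores `FOR UPDATE`, so the lock only takes effect on MySQL/PostgreSQL):

```python
from sqlalchemy import Column, Integer, create_engine, func
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class TrialModel(Base):
    # Minimal stand-in for Optuna's internal trial model.
    __tablename__ = "trials"
    trial_id = Column(Integer, primary_key=True)
    study_id = Column(Integer, index=True)
    number = Column(Integer)


def create_trial_with_retry(session_factory, study_id, max_retries=3):
    for attempt in range(max_retries):
        session = session_factory()
        try:
            # Row-level lock on all trials of this study (SELECT ... FOR
            # UPDATE). On MySQL/PostgreSQL this serializes concurrent
            # creators; a deadlock raises OperationalError and we retry.
            session.query(TrialModel).filter(
                TrialModel.study_id == study_id
            ).with_for_update().all()

            # The next number is simply the count of existing trials.
            number = (
                session.query(func.count(TrialModel.trial_id))
                .filter(TrialModel.study_id == study_id)
                .scalar()
            )
            trial = TrialModel(study_id=study_id, number=number)
            session.add(trial)
            session.commit()
            return trial.number
        except OperationalError:
            session.rollback()
            if attempt == max_retries - 1:
                raise
        finally:
            session.close()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
print([create_trial_with_retry(Session, study_id=1) for _ in range(3)])
```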

@hvy (Member, Author) commented Jul 7, 2020

Mini benchmark with MySQL.

Code

import argparse
import math
import time

import optuna
from optuna.samplers import TPESampler
import sqlalchemy


class Profile:
    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()

    def get(self):
        return self.end - self.start


def build_objective_fun(n_param):
    def objective(trial):
        return sum(
            [
                math.sin(trial.suggest_uniform("param-{}".format(i), 0, math.pi * 2))
                for i in range(n_param)
            ]
        )

    return objective


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mysql_user", type=str)
    parser.add_argument("mysql_host", type=str)
    args = parser.parse_args()

    storage_str = "mysql+pymysql://{}@{}/".format(args.mysql_user, args.mysql_host)

    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    print("| #params | #trials | time(sec) |")
    print("| ------- | ------- | --------- |")

    for n_param in [1, 2, 4, 8, 16, 32]:
        for n_trial in [1, 10, 100, 1000]:
            engine = sqlalchemy.create_engine(storage_str)
            conn = engine.connect()
            conn.execute("commit")
            database_str = "profile_storage_t{}_p{}".format(n_trial, n_param)
            try:
                conn.execute("drop database {}".format(database_str))
            except Exception:
                pass
            conn.execute("create database {}".format(database_str))
            conn.close()

            storage = optuna.storages.get_storage(storage_str + database_str)
            study = optuna.create_study(storage=storage, sampler=TPESampler())

            with Profile() as prof:
                study.optimize(
                    build_objective_fun(n_param), n_trials=n_trial, gc_after_trial=False,
                )

            print(f"| {n_param} | {n_trial} | {prof.get():.2f} |")

| #params | #trials | PR (sec) | master (sec) |
| ------- | ------- | -------- | ------------ |
| 1 | 1 | 0.02 | 0.01 |
| 1 | 10 | 0.12 | 0.11 |
| 1 | 100 | 1.60 | 1.18 |
| 1 | 1000 | 35.97 | 16.84 |
| 2 | 1 | 0.02 | 0.02 |
| 2 | 10 | 0.13 | 0.11 |
| 2 | 100 | 1.53 | 1.25 |
| 2 | 1000 | 38.03 | 20.17 |
| 4 | 1 | 0.02 | 0.02 |
| 4 | 10 | 0.14 | 0.12 |
| 4 | 100 | 1.83 | 1.53 |
| 4 | 1000 | 42.60 | 26.23 |
| 8 | 1 | 0.04 | 0.03 |
| 8 | 10 | 0.16 | 0.15 |
| 8 | 100 | 2.36 | 2.04 |
| 8 | 1000 | 52.53 | 36.11 |
| 16 | 1 | 0.08 | 0.06 |
| 16 | 10 | 0.20 | 0.20 |
| 16 | 100 | 3.21 | 2.98 |
| 16 | 1000 | 71.93 | 54.78 |
| 32 | 1 | 0.11 | 0.11 |
| 32 | 10 | 0.30 | 0.29 |
| 32 | 100 | 4.95 | 4.64 |
| 32 | 1000 | 110.76 | 93.52 |

Comment on lines 467 to 471
# Lock all trials belonging to this study. This might lead to a deadlock
# (`OperationalError`) in which case we will retry.
session.query(models.TrialModel).filter(
    models.TrialModel.study_id == study_id
).with_for_update().all()
Member

How about locking an entry in the studies table instead of locking all the trials (though I haven't checked whether it improves the latency)?

@hvy (Member, Author) Jul 8, 2020

Ah, that might scale better with the number of trials. Let me just verify it.

Member Author

Looks good.

| #params | #trials | time (sec) |
| ------- | ------- | ---------- |
| 1 | 1 | 0.02 |
| 1 | 10 | 0.11 |
| 1 | 100 | 1.26 |
| 1 | 1000 | 17.41 |
| 2 | 1 | 0.02 |
| 2 | 10 | 0.12 |
| 2 | 100 | 1.32 |
| 2 | 1000 | 20.22 |
| 4 | 1 | 0.02 |
| 4 | 10 | 0.13 |
| 4 | 100 | 1.61 |
| 4 | 1000 | 25.82 |
| 8 | 1 | 0.04 |
| 8 | 10 | 0.16 |
| 8 | 100 | 2.07 |
| 8 | 1000 | 35.07 |
| 16 | 1 | 0.06 |
| 16 | 10 | 0.20 |
| 16 | 100 | 3.02 |
| 16 | 1000 | 55.23 |
| 32 | 1 | 0.12 |
| 32 | 10 | 0.30 |
| 32 | 100 | 4.76 |
| 32 | 1000 | 90.28 |

@c-bata (Member) left a comment

LGTM!

Note that this has the downside of making trial creation significantly slower. However, this cannot be helped given the constraint that trial numbers must be unique.

Totally agree.

@hvy (Member, Author) commented Jul 8, 2020

Thanks @c-bata as always. Regarding the performance hit, it's actually negligible now with @ytsmiling's suggestion of restricting the lock to a single row (at most), instead of all trials.

@c-bata (Member) left a comment

@hvy I just noticed that we need to modify the implementation of count_past_trials.

def count_past_trials(self, session):
    # type: (orm.Session) -> int
    trial_count = session.query(func.count(TrialModel.trial_id)).filter(
        TrialModel.study_id == self.study_id, TrialModel.trial_id < self.trial_id
    )
    return trial_count.scalar()

We need to remove TrialModel.trial_id < self.trial_id from here.

@c-bata (Member) left a comment

Sorry, that was my misunderstanding. We now lock the study before inserting a trial, so the current logic has no problem.

@ytsmiling (Member) left a comment

Thank you for addressing this issue. LGTM.

@ytsmiling (Member)

It's highly unlikely, but if trial_id is not assigned sequentially, the count_past_trials method can fail (a duplicated number can be assigned). I'm okay with leaving the method as is, but if you'd like to change the implementation, I'll re-review this PR.
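A pure-Python illustration of that edge case (hypothetical ids, not Optuna code): numbering each new trial by the count of existing trials with a smaller id duplicates numbers as soon as ids arrive out of creation order.

```python
def assign_numbers(ids_in_creation_order):
    # For each newly created trial, compute its number as the count of
    # already-existing trials with a smaller id (the count_past_trials logic).
    existing, numbers = [], []
    for tid in ids_in_creation_order:
        numbers.append(sum(1 for t in existing if t < tid))
        existing.append(tid)
    return numbers

print(assign_numbers([1, 2, 3]))  # sequential ids -> [0, 1, 2], all unique
print(assign_numbers([5, 3]))     # out-of-order ids -> [0, 0], duplicated
```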

@@ -464,9 +464,46 @@ def _create_new_trial(

        session = self.scoped_session()

        # Ensure that the study exists.
        models.StudyModel.find_or_raise_by_id(study_id, session)
        try:
Member Author

Memo: Try n (maybe 3) times and propagate the OperationalError in case they all fail. It should be easier to debug.

@hvy (Member, Author) commented Jul 10, 2020

Changed the logic to propagate SQLAlchemy errors after 3 retries, to reduce the risk of silencing unexpected errors and to aid debugging in cases that would previously have resulted in a maximum recursion depth error.
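The retry policy can be sketched in isolation like this (plain Python; `RuntimeError` stands in for `sqlalchemy.exc.OperationalError`, and `flaky_create_trial` is a made-up stand-in for trial creation):

```python
MAX_RETRIES = 3


def run_with_retries(op):
    # Try up to MAX_RETRIES times; after the final failure, re-raise the
    # error instead of recursing, so unexpected errors are not silenced.
    for attempt in range(MAX_RETRIES):
        try:
            return op()
        except RuntimeError:  # stand-in for OperationalError (e.g. deadlock)
            if attempt == MAX_RETRIES - 1:
                raise


calls = {"count": 0}


def flaky_create_trial():
    # Fails twice (simulated deadlocks), then succeeds on the third attempt.
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("deadlock")
    return "trial created"


print(run_with_retries(flaky_create_trial))  # succeeds on the third attempt
```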

@hvy (Member, Author) commented Jul 10, 2020

PTAL.

@toshihikoyanase (Member) left a comment

LGTM!

@ytsmiling (Member) left a comment

LGTM!

@HideakiImamura (Member) left a comment

Thanks! LGTM!

Successfully merging this pull request may close these issues.

Duplicate trial numbers on distributed optimization
5 participants