Skip to content

Commit

Permalink
Implement test
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Jun 9, 2022
1 parent e1abcee commit 8e88df1
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/integration_tests/test_mlflow.py
Expand Up @@ -536,3 +536,39 @@ def test_multiobjective_raises_on_type_mismatch(tmpdir: py.path.local, metrics:
tracking_uri = f"file:{tmpdir}"
with pytest.raises(TypeError):
MLflowCallback(tracking_uri=tracking_uri, metric_name=metrics)


def test_chunk_info(tmpdir: py.path.local) -> None:

num_objective = mlflow.utils.validation.MAX_METRICS_PER_BATCH + 1
num_params = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH + 1

def objective(trial: optuna.trial.Trial) -> Tuple[float, ...]:
for i in range(num_params):
trial.suggest_float(f"x_{i}", 1, 2)

return tuple([1.] * num_objective)

tracking_uri = f"file:{tmpdir}"
study_name = "my_study"

n_trials = 1

mlflc = MLflowCallback(tracking_uri=tracking_uri)
study = optuna.create_study(study_name=study_name, directions=["maximize"] * num_objective)
study.optimize(objective, n_trials=n_trials, callbacks=[mlflc])

mlfl_client = MlflowClient(tracking_uri)

experiment = mlfl_client.list_experiments()[0]
run_infos = mlfl_client.list_run_infos(experiment.experiment_id)
assert len(run_infos) == n_trials

run = mlfl_client.get_run(run_infos[0].run_id)
run_dict = run.to_dictionary()

# The `tags` contains param's distributions and other information too, such as trial number.
print(run_dict["data"]["tags"])
assert len(run_dict["data"]["tags"]) == num_params
assert len(run_dict["data"]["params"]) == num_params
assert len(run_dict["data"]["metrics"]) == num_objective

0 comments on commit 8e88df1

Please sign in to comment.