From 8fd2291aee6ab2f69d2f0e7707c902b877f9626b Mon Sep 17 00:00:00 2001 From: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Date: Fri, 10 May 2024 19:13:15 -0700 Subject: [PATCH] fix MLFLOW example (#2575) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- .../app/config/config_fed_server.conf | 1 + .../cifar10/code/fl/train_with_mlflow.py | 2 +- .../step-by-step/cifar10/sag/sag.ipynb | 46 ++++++++++++++++++- .../sag_pt_in_proc/config_fed_server.conf | 2 +- 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf index 9bc187c8ab..5b0f694a6a 100644 --- a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf +++ b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf @@ -48,6 +48,7 @@ "id": "mlflow_receiver_with_tracking_uri", "path": "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver", "args": { + tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns" "kwargs": { "experiment_name": "hello-pt-experiment", "run_name": "hello-pt-with-mlflow", diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py index 1d43b88d4d..3e05f715b2 100644 --- a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py +++ b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py @@ -139,7 +139,7 @@ def evaluate(input_weights): running_loss += loss.item() if i % 2000 == 1999: # print every 2000 mini-batches print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") - global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i + global_step = input_model.current_round * steps + epoch * len(trainloader) + i mlflow.log_metric("loss", running_loss / 2000, global_step) running_loss = 0.0 diff --git a/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb b/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb index dbd515b0f0..6d8b8ca76e 100644 --- a/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb @@ -232,8 +232,8 @@ "source": [ "! nvflare job create -j /tmp/nvflare/jobs/cifar10_sag_pt -w sag_pt_in_proc \\\n", "-f meta.conf min_clients=2 \\\n", - "-f config_fed_client.conf app_script=train.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n", - "-f config_fed_server.conf num_rounds=5 \\\n", + "-f config_fed_client.conf app_script=train_with_mlflow.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n", + "-f config_fed_server.conf num_rounds=2 \\\n", "-sd ../code/fl \\\n", "-force" ] @@ -289,6 +289,48 @@ "The next 5 examples will use the same ScatterAndGather workflow, but will demonstrate different execution APIs and feature.\n", "In the next example [sag_deploy_map](../sag_deploy_map/sag_deploy_map.ipynb), we will learn about the deploy_map configuration for deployment of apps to different sites." ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a49b430b-a65b-4b1e-8793-9b3befcfcfd9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!tree /tmp/nvflare/jobs/cifar10_sag_pt_workspace/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50594df7-b4c9-4e5e-944a-403b5a105c27", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!mlflow ui --port 5000 --backend-store-uri /tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/mlruns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af2b6628-61af-4bc8-84d4-a9876a27c7c2", + "metadata": {}, + "outputs": [], + "source": [ + "!tensorboard --logdir=/tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/tb_events" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3ad11c3-6ef7-46cd-8778-0090505b14e1", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/job_templates/sag_pt_in_proc/config_fed_server.conf b/job_templates/sag_pt_in_proc/config_fed_server.conf index ab5691c4b7..deb678189f 100644 --- a/job_templates/sag_pt_in_proc/config_fed_server.conf +++ b/job_templates/sag_pt_in_proc/config_fed_server.conf @@ -107,7 +107,7 @@ path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver" args { # tracking_uri = "http://0.0.0.0:5000" - tracking_uri = "" + tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns" kwargs { experiment_name = "nvflare-sag-pt-experiment" run_name = "nvflare-sag-pt-with-mlflow"