From 5fe7341f5785193787d9bbf6bc046cc4843d8291 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 30 Mar 2026 14:23:43 -0700 Subject: [PATCH] example of starting ax experiment from dataframe data (#5098) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5098 This diff adds a new Ax tutorial notebook demonstrating how to initialize an experiment from pre-existing DataFrame data. **Key Features Covered:** 1. **Attaching Historical Trials** — Shows how to use `client.attach_trial()` and `client.complete_trial()` to import historical experiment data (arms with parameters and their observed metric values) from a pandas DataFrame 2. **Warm-Starting Bayesian Optimization** — After attaching historical data, uses Ax's `Client` to generate new candidate trials with model-based optimization (BoTorch) 3. **Analysis & Visualization** — Demonstrates Ax's built-in analysis tools: cross-validation plots, utility progression tracking, and arm effects visualization (both observed and predicted) **Use Cases:** - Migrating experiments run outside of Ax into the Ax framework - Leveraging existing CSV/database data to warm-start Bayesian optimization - Building on historical configurations and outcomes **Technical Details:** - Uses the Branin function as a benchmark optimization problem (2 parameters, known global minimum ≈ 0.398) - Creates 15 quasi-random historical evaluations and attaches them as completed trials - Generates and evaluates a new candidate using `client.get_next_trial()` - Compares predicted vs. observed arm effects Reviewed By: andycylmeta Differential Revision: D93760064 --- .../experiment_from_dataframe.ipynb | 836 ++++++++++++++++++ 1 file changed, 836 insertions(+) create mode 100644 tutorials/experiment_from_dataframe/experiment_from_dataframe.ipynb diff --git a/tutorials/experiment_from_dataframe/experiment_from_dataframe.ipynb b/tutorials/experiment_from_dataframe/experiment_from_dataframe.ipynb new file mode 100644 index 00000000000..b728e222fbd --- /dev/null +++ b/tutorials/experiment_from_dataframe/experiment_from_dataframe.ipynb @@ -0,0 +1,836 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Attaching Pre-Existing Trial Data with Ax's Client\n", + "\n", + "In many real-world scenarios you may already have data from previous experiments \u2014 perhaps logged in a spreadsheet, database, or DataFrame \u2014 and you want to leverage Ax's modeling and analysis capabilities on top of that data before generating new candidates.\n", + "\n", + "This tutorial demonstrates how to:\n", + "1. Create an experiment using Ax's `Client`\n", + "2. Attach historical trials with arms (and their parameters) from a DataFrame\n", + "3. Attach metric observations from a DataFrame\n", + "4. Analyze the results using Ax's built-in analyses (cross-validation, utility progression, and arm effects)\n", + "5. Generate a new candidate trial using Ax's Bayesian optimization model\n", + "6. Visualize the predicted effects of the new candidate\n", + "\n", + "### When is this useful?\n", + "- You ran experiments outside of Ax and want to use Ax to model them and suggest next steps\n", + "- You have a CSV or database table of past configurations and their outcomes\n", + "- You want to warm-start Bayesian optimization with historical data\n", + "\n", + "### Prerequisites\n", + "- Familiarity with Python and pandas DataFrames\n", + "- Basic understanding of [adaptive experimentation](../../intro-to-ae.mdx) and [Bayesian optimization](../../intro-to-bo.mdx)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Necessary Modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from ax.api.client import Client\n", + "from ax.api.configs import RangeParameterConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Prepare Historical Data\n", + "\n", + "Imagine we've previously evaluated the [Branin function](https://www.sfu.ca/~ssurjano/branin.html) at several points and recorded the results in a DataFrame. The Branin function is a common benchmark for optimization with two parameters ($x_1$ and $x_2$) and a known global minimum of $\\approx 0.398$.\n", + "\n", + "$$\n", + "f(x_1, x_2) = \\left(x_2 - \\frac{5.1}{4\\pi^2} x_1^2 + \\frac{5}{\\pi} x_1 - 6\\right)^2 + 10 \\left(1 - \\frac{1}{8\\pi}\\right) \\cos(x_1) + 10\n", + "$$\n", + "\n", + "We'll create a DataFrame representing this historical data \u2014 each row is a previously-evaluated configuration with its observed metric value." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
arm_namex1x2branin
0historical_06.6093413.40858124.314643
1historical_11.5831778.31877230.263474
2historical_27.8789690.95725910.033110
3historical_35.46052012.414468143.167098
4historical_4-3.5873409.47496616.522254
5historical_59.63433511.37131676.539605
6historical_66.4170965.31789037.251285
7historical_76.79096514.560470198.245457
8historical_8-3.07829613.3968172.038708
9historical_91.75578911.67575273.389817
10historical_100.5619702.91958123.084072
11historical_118.9014757.00081526.006108
12historical_124.6579770.65705610.014235
13historical_137.3414242.31434215.782220
14historical_141.65121310.24573451.758615
\n", + "
" + ], + "text/plain": [ + " arm_name x1 x2 branin\n", + "0 historical_0 6.609341 3.408581 24.314643\n", + "1 historical_1 1.583177 8.318772 30.263474\n", + "2 historical_2 7.878969 0.957259 10.033110\n", + "3 historical_3 5.460520 12.414468 143.167098\n", + "4 historical_4 -3.587340 9.474966 16.522254\n", + "5 historical_5 9.634335 11.371316 76.539605\n", + "6 historical_6 6.417096 5.317890 37.251285\n", + "7 historical_7 6.790965 14.560470 198.245457\n", + "8 historical_8 -3.078296 13.396817 2.038708\n", + "9 historical_9 1.755789 11.675752 73.389817\n", + "10 historical_10 0.561970 2.919581 23.084072\n", + "11 historical_11 8.901475 7.000815 26.006108\n", + "12 historical_12 4.657977 0.657056 10.014235\n", + "13 historical_13 7.341424 2.314342 15.782220\n", + "14 historical_14 1.651213 10.245734 51.758615" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def branin(x1: float, x2: float) -> float:\n", + " \"\"\"The Branin function \u2014 a standard optimization benchmark.\"\"\"\n", + " return (\n", + " (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2\n", + " + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)\n", + " + 10\n", + " )\n", + "\n", + "\n", + "# Generate historical data: 15 quasi-random evaluations\n", + "rng = np.random.default_rng(42)\n", + "n_historical = 15\n", + "historical_data = pd.DataFrame(\n", + " {\n", + " \"arm_name\": [f\"historical_{i}\" for i in range(n_historical)],\n", + " \"x1\": rng.uniform(-5, 10, n_historical),\n", + " \"x2\": rng.uniform(0, 15, n_historical),\n", + " }\n", + ")\n", + "historical_data[\"branin\"] = historical_data.apply(\n", + " lambda row: branin(row[\"x1\"], row[\"x2\"]), axis=1\n", + ")\n", + "\n", + "historical_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Initialize the Client and Configure the Experiment\n", + "\n", + "We create a `Client` and define the search space to match our historical data. The Branin function is typically evaluated on $x_1 \\in [-5, 10]$ and $x_2 \\in [0, 15]$." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 02-19 16:20:55] ax.api.client: GenerationStrategy(name='Center+Sobol+MBM:fast', nodes=[CenterGenerationNode(next_node_name='Sobol', use_existing_trials_for_initialization=True), GenerationNode(name='Sobol', generator_specs=[GeneratorSpec(generator_enum=Sobol, generator_key_override=None)], transition_criteria=[MinTrials(transition_to='MBM'), MinTrials(transition_to='MBM')], suggested_experiment_status=ExperimentStatus.INITIALIZATION), GenerationNode(name='MBM', generator_specs=[GeneratorSpec(generator_enum=BoTorch, generator_key_override=None)], transition_criteria=[], suggested_experiment_status=ExperimentStatus.OPTIMIZATION)]) chosen based on user input and problem structure.\n" + ] + } + ], + "source": [ + "client = Client(random_seed=42)\n", + "\n", + "client.configure_experiment(\n", + " name=\"branin_historical\",\n", + " parameters=[\n", + " RangeParameterConfig(name=\"x1\", parameter_type=\"float\", bounds=(-5.0, 10.0)),\n", + " RangeParameterConfig(name=\"x2\", parameter_type=\"float\", bounds=(0.0, 15.0)),\n", + " ],\n", + ")\n", + "\n", + "client.configure_optimization(objective=\"-branin\") # minimize\n", + "client.configure_generation_strategy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Attach Historical Trials from the DataFrame\n", + "\n", + "Now we iterate over our DataFrame to attach each historical evaluation as a trial. For each row we:\n", + "1. Call `client.attach_trial` with the arm's parameters and name \u2014 this creates a trial in the RUNNING state\n", + "2. Call `client.complete_trial` with the observed metric data \u2014 this marks the trial COMPLETED\n", + "\n", + "The `raw_data` argument to `complete_trial` is a dictionary mapping metric names to their observed values. Each value can be either a `float` (mean only) or a `tuple[float, float]` (mean, SEM) if you have uncertainty estimates." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 02-19 16:20:55] ax.api.client: Trial 0 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 1 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 2 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 3 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 4 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 5 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 6 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 7 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 8 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 9 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 10 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 11 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 12 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 13 marked COMPLETED.\n", + "[INFO 02-19 16:20:55] ax.api.client: Trial 14 marked COMPLETED.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attached 15 historical trials.\n" + ] + } + ], + "source": [ + "for _, row in historical_data.iterrows():\n", + " # Attach the trial with its parameterization\n", + " trial_index = client.attach_trial(\n", + " parameters={\"x1\": row[\"x1\"], \"x2\": row[\"x2\"]},\n", + " arm_name=row[\"arm_name\"],\n", + " )\n", + "\n", + " # Complete the trial with observed metric data\n", + " client.complete_trial(\n", + " trial_index=trial_index,\n", + " raw_data={\"branin\": row[\"branin\"]},\n", + " )\n", + "\n", + "print(f\"Attached {len(historical_data)} historical trials.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect the state of the experiment using `client.summarize()`, which returns a DataFrame with one row per arm showing its parameters, metric values, and trial status." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
trial_indexarm_nametrial_statusbraninx1x2
00historical_0COMPLETED24.3146436.6093413.408581
11historical_1COMPLETED30.2634741.5831778.318772
22historical_2COMPLETED10.0331107.8789690.957259
33historical_3COMPLETED143.1670985.46052012.414468
44historical_4COMPLETED16.522254-3.5873409.474966
55historical_5COMPLETED76.5396059.63433511.371316
66historical_6COMPLETED37.2512856.4170965.317890
77historical_7COMPLETED198.2454576.79096514.560470
88historical_8COMPLETED2.038708-3.07829613.396817
99historical_9COMPLETED73.3898171.75578911.675752
1010historical_10COMPLETED23.0840720.5619702.919581
1111historical_11COMPLETED26.0061088.9014757.000815
1212historical_12COMPLETED10.0142354.6579770.657056
1313historical_13COMPLETED15.7822207.3414242.314342
1414historical_14COMPLETED51.7586151.65121310.245734
\n", + "
" + ], + "text/plain": [ + " trial_index arm_name trial_status branin x1 x2\n", + "0 0 historical_0 COMPLETED 24.314643 6.609341 3.408581\n", + "1 1 historical_1 COMPLETED 30.263474 1.583177 8.318772\n", + "2 2 historical_2 COMPLETED 10.033110 7.878969 0.957259\n", + "3 3 historical_3 COMPLETED 143.167098 5.460520 12.414468\n", + "4 4 historical_4 COMPLETED 16.522254 -3.587340 9.474966\n", + "5 5 historical_5 COMPLETED 76.539605 9.634335 11.371316\n", + "6 6 historical_6 COMPLETED 37.251285 6.417096 5.317890\n", + "7 7 historical_7 COMPLETED 198.245457 6.790965 14.560470\n", + "8 8 historical_8 COMPLETED 2.038708 -3.078296 13.396817\n", + "9 9 historical_9 COMPLETED 73.389817 1.755789 11.675752\n", + "10 10 historical_10 COMPLETED 23.084072 0.561970 2.919581\n", + "11 11 historical_11 COMPLETED 26.006108 8.901475 7.000815\n", + "12 12 historical_12 COMPLETED 10.014235 4.657977 0.657056\n", + "13 13 historical_13 COMPLETED 15.782220 7.341424 2.314342\n", + "14 14 historical_14 COMPLETED 51.758615 1.651213 10.245734" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.summarize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Analyze the Experiment\n", + "\n", + "With historical data attached, we can use Ax's analysis framework to understand the experiment so far. We'll compute three analyses:\n", + "\n", + "1. **Cross-Validation Plot**: Assesses how well Ax's surrogate model fits the observed data using leave-one-out cross-validation. Points near the diagonal $y = x$ line indicate good model fit.\n", + "\n", + "2. **Utility Progression**: Shows how the best observed objective value improves over completed trials. This helps us understand whether the historical evaluations were trending toward the optimum.\n", + "\n", + "3. **Arm Effects Plot (Observed)**: Displays the observed metric values for each arm with confidence intervals, providing an overview of all evaluations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ax.analysis.plotly.cross_validation import CrossValidationPlot\n", + "from ax.analysis.plotly.utility_progression import UtilityProgressionAnalysis\n", + "from ax.analysis.plotly.arm_effects import ArmEffectsPlot\n", + "cards = client.compute_analyses(\n", + " analyses=[\n", + " CrossValidationPlot(),\n", + " UtilityProgressionAnalysis(),\n", + " ArmEffectsPlot(metric_name=\"branin\", use_model_predictions=False),\n", + " ],\n", + " display=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Generate a New Candidate Trial\n", + "\n", + "Now that Ax has a trained model based on the historical data, we can ask it to suggest the next best point to evaluate. Ax uses Bayesian optimization to balance exploration (trying new regions) with exploitation (refining promising regions)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 02-19 16:21:07] ax.api.client: Generated new trial 15 with parameters {'x1': -5.0, 'x2': 15.0} using GenerationNode MBM.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trial 15: {'x1': -5.0, 'x2': 15.0}\n" + ] + } + ], + "source": [ + "new_trials = client.get_next_trials(max_trials=1)\n", + "\n", + "for trial_index, parameters in new_trials.items():\n", + " print(f\"Trial {trial_index}: {parameters}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Visualize Predicted Arm Effects for the Candidate\n", + "\n", + "Before actually evaluating the candidate, we can use the `ArmEffectsPlot` with `use_model_predictions=True` to see what the model predicts for all arms \u2014 including the newly generated candidate. This is useful for understanding:\n", + "- How the model expects the new candidate to perform relative to historical arms\n", + "- The model's uncertainty about each arm's true effect\n", + "- Whether the candidate is expected to improve upon the best historical result\n", + "\n", + "The candidate trial will appear highlighted with a distinct color." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cards = client.compute_analyses(\n", + " analyses=[\n", + " ArmEffectsPlot(metric_name=\"branin\", use_model_predictions=True),\n", + " ],\n", + " display=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Evaluate the Candidate and Continue the Loop\n", + "\n", + "In practice you would evaluate the candidate in your system, then report the result back to Ax and continue the optimization loop. Here we evaluate the Branin function directly and complete the trial." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 02-19 16:21:11] ax.api.client: Trial 15 marked COMPLETED.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trial 15 completed with branin = 17.5083\n" + ] + } + ], + "source": [ + "for trial_index, parameters in new_trials.items():\n", + " result = branin(parameters[\"x1\"], parameters[\"x2\"])\n", + " client.complete_trial(\n", + " trial_index=trial_index,\n", + " raw_data={\"branin\": result},\n", + " )\n", + " print(f\"Trial {trial_index} completed with branin = {result:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run a few more optimization rounds to see Ax improve upon the historical data, then examine the final results." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO 02-19 16:21:12] ax.api.client: Generated new trial 16 with parameters {'x1': 10.0, 'x2': 3.051505} using GenerationNode MBM.\n", + "[INFO 02-19 16:21:12] ax.api.client: Trial 16 marked COMPLETED.\n", + "[INFO 02-19 16:21:13] ax.api.client: Generated new trial 17 with parameters {'x1': -1.478571, 'x2': 9.250075} using GenerationNode MBM.\n", + "[INFO 02-19 16:21:13] ax.api.client: Trial 17 marked COMPLETED.\n", + "[INFO 02-19 16:21:13] ax.api.client: Generated new trial 18 with parameters {'x1': 10.0, 'x2': 0.0} using GenerationNode MBM.\n", + "[INFO 02-19 16:21:13] ax.api.client: Trial 18 marked COMPLETED.\n", + "[INFO 02-19 16:21:13] ax.api.client: Generated new trial 19 with parameters {'x1': -2.930578, 'x2': 15.0} using GenerationNode MBM.\n", + "[INFO 02-19 16:21:13] ax.api.client: Trial 19 marked COMPLETED.\n", + "[INFO 02-19 16:21:13] ax.api.client: Generated new trial 20 with parameters {'x1': -3.225827, 'x2': 12.428191} using GenerationNode MBM.\n", + "[INFO 02-19 16:21:13] ax.api.client: Trial 20 marked COMPLETED.\n" + ] + } + ], + "source": [ + "for _ in range(5):\n", + " trials = client.get_next_trials(max_trials=1)\n", + " for trial_index, parameters in trials.items():\n", + " result = branin(parameters[\"x1\"], parameters[\"x2\"])\n", + " client.complete_trial(\n", + " trial_index=trial_index,\n", + " raw_data={\"branin\": result},\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best parameters: {'x1': 10.0, 'x2': 3.051504770822107}\n", + "Prediction (mean, sem): {'branin': (np.float64(1.9996011891514804), np.float64(3.26333652262047))}\n", + "From trial 16 (arm '16_0')\n", + "\n", + "Known global minimum: ~0.398 at (x1, x2) in {(-pi, 12.275), (pi, 2.275), (9.425, 2.475)}\n" + ] + } + ], + "source": [ + "best_parameters, prediction, trial_index, arm_name = (\n", + " client.get_best_parameterization()\n", + ")\n", + "print(f\"Best parameters: {best_parameters}\")\n", + "print(f\"Prediction (mean, sem): {prediction}\")\n", + "print(f\"From trial {trial_index} (arm '{arm_name}')\")\n", + "print(f\"\\nKnown global minimum: ~0.398 at (x1, x2) in \"\n", + " \"{(-pi, 12.275), (pi, 2.275), (9.425, 2.475)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Final Analysis\n", + "\n", + "Let's compute a final suite of analyses to see how well the optimization performed after warm-starting from historical data. We include:\n", + "- **Cross-Validation**: Updated model fit assessment with all data\n", + "- **Utility Progression**: Shows how the best value improved \u2014 including through the historical warm-start phase and the Bayesian optimization phase\n", + "- **Modeled Arm Effects**: Model-based predictions for all arms with uncertainty estimates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cards = client.compute_analyses(\n", + " analyses=[\n", + " CrossValidationPlot(),\n", + " UtilityProgressionAnalysis(),\n", + " ArmEffectsPlot(metric_name=\"branin\", use_model_predictions=True),\n", + " ],\n", + " display=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial demonstrated how to use Ax's `Client` to:\n", + "- **Attach historical trial data** from a DataFrame using `attach_trial` and `complete_trial`\n", + "- **Analyze the experiment** with cross-validation, utility progression, and arm effects plots\n", + "- **Generate new candidates** via Bayesian optimization, warm-started on historical observations\n", + "- **Visualize predicted arm effects** to understand model expectations before evaluating candidates\n", + "\n", + "This workflow is applicable to any scenario where you have pre-existing evaluations and want to continue optimizing using Ax's Bayesian optimization engine." + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "aff55789-d879-4c9a-b6cf-435b78b72edc", + "isAdHoc": false, + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}