# Configurable closed-loop optimization with Ax `Scheduler`

*We recommend reading through the ["Developer API" tutorial](https://ax.dev/tutorials/gpei_hartmann_developer.html) before getting started with the `Scheduler`, as using it in this tutorial will require an Ax `Experiment` and an understanding of the experiment's subcomponents like the search space and the runner.*

### Contents:
1. **Scheduler and external systems for trial evalution** –– overview of how scheduler works with an external system to run a closed-loop optimization.
2. **Set up a mock external system** –– creating a dummy external system client, which will be used to illustrate a scheduler setup in this tutorial.
3. **Set up an experiment according to the mock external system** –– set up a runner that deploys trials to the dummy external system from part 2 and a metric that fetches trial results from that system, then leverage those runner and metric and set up an experiment.
4. **Set up a scheduler**, given an experiment.
   1. Create a scheduler subclass to poll trial status.
   2. Set up a generation strategy using an auto-selection utility.
5. **Running the optimization** via `Scheduler.run_n_trials`.
6. **Leveraging SQL storage and experiment resumption** –– resuming an experiment in one line of code.
7. **Configuring the scheduler** –– overview of the many options scheduler provides to configure the closed-loop down to granular detail.
8. **Advanced functionality**:
   1. Reporting results to an external system during the optimization.
   2. Using `Scheduler.run_trials_and_yield_results` to run the optimization via a generator method.

# 1. Define SearchSpace

In [4]:
from ax import *

from typing import Any, Dict, NamedTuple, Union

from ax.core.base_trial import TrialStatus

Search Space

In [5]:
search_space = SearchSpace(
    parameters = [
        RangeParameter(
            name="antisolvent_volume", 
            parameter_type=ParameterType.FLOAT, 
            lower=30, 
            upper=150,
        ),
        RangeParameter(
            name="antisolvent_rate", 
            parameter_type=ParameterType.FLOAT, 
            lower=30, 
            upper=150,
        ),
        RangeParameter(
            name="antisolvent_timing", 
            parameter_type=ParameterType.FLOAT, 
            lower=-30, 
            upper=-5,
        ),
        RangeParameter(
            name="anneal_duration", 
            parameter_type=ParameterType.FLOAT, 
            lower=15*60, 
            upper=60*60,
        ),
    ]
)

Job Queue

In [6]:
from frgpascal.experimentaldesign.tasks import *
from frgpascal.hardware.liquidlabware import TipRack, LiquidLabware, AVAILABLE_VERSIONS as liquid_labware_versions
from frgpascal.hardware.sampletray import SampleTray, AVAILABLE_VERSIONS as sampletray_versions

from frgpascal.analysis import photoluminescence as PL 

from frgpascal.bridge import PASCALAxQueue


In [7]:
class PASCALJob:
    """Dummy class to represent a job scheduled on `MockJobQueue`."""

    # id: int
    # parameters: Dict[str, Union[str, float, int, bool]]
    
    def __init__(self, job_id, parameters):
        self.job_id = job_id
        self.parameters = parameters

In [10]:
class JobQueue(PASCALAxQueue):
    ### PASCAL methods
    def __init__(self):
        super().__init__()
    def initialize_labware(self):
        self.tipracks = [
            TipRack(
                version='sartorius_safetyspace_tiprack_200ul', 
                deck_slot=7,
                starting_tip="D7"
            ),
            TipRack(
                version='sartorius_safetyspace_tiprack_200ul', 
                deck_slot=10,
                starting_tip="A2"
            ),
        ]
        self.sampletray = SampleTray(
            name='Tray1',
            version='storage_v1',
            gantry=None,
            gripper=None,
            p0=[0,0,0]
        )
        self.solutions = {
            'methylacetate': Solution(
                solvent='MethylAcetate',
                labware='4mL_b_AntisolventTray',
                well='D1',
            ),
            'absorber': Solution(
                solutes= 'FA0.78_MA0.1_Cs0.12_(Pb_(I0.8_Br0.1_I0.1)3)1.09',
                solvent= 'DMF3_DMSO1',
                molarity= 1.2,
            )
        }

    def build_sample(self, parameters: Dict[str, Union[str, float, int, bool]]) -> Sample:
        spincoat_absorber = Spincoat(
            steps=[
                [3000,2000,50], #speed (rpm), acceleration (rpm/s), duration (s)
            ],
            drops = [
                Drop(
                    solution=self.solutions['absorber'],  #this will be filled later using the list of psk solutions
                    volume=20,
                    time=-1,
                    blow_out=True,
                    # pre_mix = (5,50),
                ),
                Drop(
                    solution=self.solutions['methylacetate'],
                    volume=parameters['antisolvent_volume'],
                    time=50+parameters['antisolvent_timing'],
                    reuse_tip=True,
                    touch_tip=False,
                    rate=parameters['antisolvent_rate'],
                    pre_mix = (3,100),
                    slow_travel=True
                )
            ],
        )
        anneal_absorber = Anneal(
            temperature=100,
            duration=parameters['anneal_duration']
        )

        samplename = f'sample{self.sample_counter}'
        sample = Sample(
            name = samplename,
            substrate='1mm glass',
            worklist = [
                spincoat_absorber,
                anneal_absorber,
                Rest(180),
                Characterize()
            ],
            storage_slot = {
                "tray": self.sampletray.name, 
                "slot": self.sampletray.load(samplename)
                },
        )
        return sample

    ### Ax methods
    def schedule_job_with_parameters(
        self, parameters: Dict[str, Union[str, float, int, bool]]
    ) -> int:
        """Schedules an evaluation job with given parameters and returns job ID."""
        # Code to actually schedule the job and produce an ID would go here;
        # using timestamp as dummy ID for this example.
        sample = self.build_sample(parameters)
        self.add_sample(sample=sample)
        job_id = sample.name
        self.jobs[job_id] = PASCALJob(job_id, parameters)
        self.protocols_in_progress.append(job_id)
        return job_id

    def get_job_status(self, job_id: str) -> TrialStatus:
        """ "Get status of the job by a given ID. For simplicity of the example,
        return an Ax `TrialStatus`.
        """
        # sample_name = self.jobs[job_id]
        # Instead of randomizing trial status, code to check actual job status
        # would go here.
        # time.sleep(1)
        if job_id in self.completed_protocols:
            return TrialStatus.COMPLETED
        return TrialStatus.RUNNING

    def get_outcome_value_for_completed_job(self, job_id: str) -> Dict[str, float]:
        """Get evaluation results for a given completed job."""
        job = self.jobs[job_id]

        fid = os.path.join(self.experiment_folder, 'PL_635', '{job_id}_pl.csv')
        wl, cps = PL.load_spectrum(fid)
        fit = PL.fit_spectrum(wl=wl, cts=cps, wlmin=650, wlmax=1100, wlguess=730, plot=False)
        # In a real external system, this would retrieve real relevant outcomes and
        # not a synthetic function value.
        return {"redplspec_intensity": (fit['intensity'], 0.0)} 


MOCK_JOB_QUEUE_CLIENT = JobQueue()


def get_mock_job_queue_client() -> JobQueue:
    """Obtain the singleton job queue instance."""
    return MOCK_JOB_QUEUE_CLIENT

Metric

In [11]:
import pandas as pd

from ax.core.metric import Metric
from ax.core.base_trial import BaseTrial
from ax.core.data import Data

class PLBrightnessMetric(Metric):  # Pulls data for trial from external system.
    
    def fetch_trial_data(self, trial: BaseTrial) -> Data:
        """Obtains data via fetching it from ` for a given trial."""
        if not isinstance(trial, Trial):
            raise ValueError("This metric only handles `Trial`.")
        
        mock_job_queue = get_mock_job_queue_client()
        
        # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
        sample_data = mock_job_queue.get_outcome_value_for_completed_job(
            job_id=trial.run_metadata.get("job_id")
        )
        df_dict = {
            "trial_index": trial.index,
            "metric_name": "redplspec_intensity",
            "arm_name": trial.arm.name,
            "mean": sample_data.get("redplspec_intensity"),
            # Can be set to 0.0 if function is known to be noiseless
            # or to an actual value when SEM is known. Setting SEM to
            # `None` results in Ax assuming unknown noise and inferring
            # noise level from data.
            "sem": None,
        }
        return Data(df=pd.DataFrame.from_records([df_dict]))

Experiment

In [None]:
from ax import *

def make_experiment_with_runner_and_metric() -> Experiment:

    objective=Objective(
        metric=PLBrightnessMetric(
            name="redplspec_intensity"
            ), 
        minimize=False
        )

    return Experiment(
        name="branin_test_experiment",
        search_space=SearchSpace(parameters=parameters),
        optimization_config=OptimizationConfig(objective=objective),
        runner=MockJobRunner(),
        is_test=True,  # Marking this experiment as a test experiment.
    )

experiment = make_branin_experiment_with_runner_and_metric()

## 2. Set up a mock external execution system 

An example of an 'external system' running trial evaluations could be a remote server executing scheduled jobs, a subprocess conducting ML training runs, an engine running physics simulations, etc. For the sake of example here, let us assume a dummy external system with the following client:

In [6]:
import numpy as np
import os
import time
from queue import Queue
import json
import asyncio
from abc import ABC, abstractmethod
import ntplib
from frgpascal.experimentaldesign.tasks import Sample
from frgpascal.workers import (
    Worker_Hotplate,
    Worker_Storage,
    Worker_GantryGripper,
    Worker_Characterization,
    Worker_SpincoaterLiquidHandler,
)
from frgpascal.bridge import PASCALAxQueue
from frgpascal import system

In [None]:
class MockJobQueue(PASCALAxQueue):
    ### PASCAL methods
    def build_sample(self, parameters: Dict[str, Union[str, float, int, bool]]) -> Sample:

    ### Ax methods
    def schedule_job_with_parameters(
        self, parameters: Dict[str, Union[str, float, int, bool]]
    ) -> int:
        """Schedules an evaluation job with given parameters and returns job ID."""
        # Code to actually schedule the job and produce an ID would go here;
        # using timestamp as dummy ID for this example.
        job_id = int(time.time())
        self.jobs[job_id] = MockJob(job_id, parameters)
        return job_id

    def get_job_status(self, job_id: int) -> TrialStatus:
        """ "Get status of the job by a given ID. For simplicity of the example,
        return an Ax `TrialStatus`.
        """
        job = self.jobs[job_id]
        # Instead of randomizing trial status, code to check actual job status
        # would go here.
        # time.sleep(1)
        if (time.time() - job.t0) >= job.duration:
            return TrialStatus.COMPLETED
        return TrialStatus.RUNNING

    def get_outcome_value_for_completed_job(self, job_id: int) -> Dict[str, float]:
        """Get evaluation results for a given completed job."""
        job = self.jobs[job_id]
        # In a real external system, this would retrieve real relevant outcomes and
        # not a synthetic function value.
        return {"branin": branin(job.parameters.get("x1"), job.parameters.get("x2"))} 


MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient()


def get_mock_job_queue_client() -> MockJobQueueClient:
    """Obtain the singleton job queue instance."""
    return MOCK_JOB_QUEUE_CLIENT

## 3. Set up an experiment according to the mock external system

As mentioned above, using a `Scheduler` requires a fully set up experiment with metrics and a runner. Refer to the "Building Blocks of Ax" tutorial to learn more about those components, as here we assume familiarity with them. 

The following runner and metric set up intractions between the `Scheduler` and the mock external system we assume:

In [74]:
from collections import defaultdict
from typing import Iterable, Set

from ax.core.base_trial import BaseTrial
from ax.core.runner import Runner
from ax.core.trial import Trial


class MockJobRunner(Runner):  # Deploys trials to external system.
    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Deploys a trial based on custom runner subclass implementation.

        Args:
            trial: The trial to deploy.

        Returns:
            Dict of run metadata from the deployment process.
        """
        if not isinstance(trial, Trial):
            raise ValueError("This runner only handles `Trial`.")

        mock_job_queue = get_mock_job_queue_client()
        job_id = mock_job_queue.schedule_job_with_parameters(
            parameters=trial.arm.parameters
        )
        # This run metadata will be attached to trial as `trial.run_metadata`
        # by the base `Scheduler`.
        return {"job_id": job_id}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Checks the status of any non-terminal trials and returns their
        indices as a mapping from TrialStatus to a list of indices. Required
        for runners used with Ax ``Scheduler``.

        NOTE: Does not need to handle waiting between polling calls while trials
        are running; this function should just perform a single poll.

        Args:
            trials: Trials to poll.

        Returns:
            A dictionary mapping TrialStatus to a list of trial indices that have
            the respective status at the time of the polling. This does not need to
            include trials that at the time of polling already have a terminal
            (ABANDONED, FAILED, COMPLETED) status (but it may).
        """
        status_dict = defaultdict(set)
        for trial in trials:
            mock_job_queue = get_mock_job_queue_client()
            status = mock_job_queue.get_job_status(
                job_id=trial.run_metadata.get("job_id")
            )
            status_dict[status].add(trial.index)

        return status_dict

In [75]:
import pandas as pd

from ax.core.metric import Metric
from ax.core.base_trial import BaseTrial
from ax.core.data import Data

class BraninForMockJobMetric(Metric):  # Pulls data for trial from external system.
    
    def fetch_trial_data(self, trial: BaseTrial) -> Data:
        """Obtains data via fetching it from ` for a given trial."""
        if not isinstance(trial, Trial):
            raise ValueError("This metric only handles `Trial`.")
        
        mock_job_queue = get_mock_job_queue_client()
        
        # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
        branin_data = mock_job_queue.get_outcome_value_for_completed_job(
            job_id=trial.run_metadata.get("job_id")
        )
        df_dict = {
            "trial_index": trial.index,
            "metric_name": "branin",
            "arm_name": trial.arm.name,
            "mean": branin_data.get("branin"),
            # Can be set to 0.0 if function is known to be noiseless
            # or to an actual value when SEM is known. Setting SEM to
            # `None` results in Ax assuming unknown noise and inferring
            # noise level from data.
            "sem": None,
        }
        return Data(df=pd.DataFrame.from_records([df_dict]))

Now we can set up the experiment using the runner and metric we defined. This experiment will have a single-objective optimization config, minimizing the Branin function, and the search space that corresponds to that function.

## 4. Setting up a `Scheduler`

### 4a. Subclassing `Scheduler`

The base Ax `Scheduler` is abstract and must be subclassed, but only one method must be implemented on the subclass: `poll_trial_status`. 

### 4B. Auto-selecting a generation strategy

A `Scheduler` also requires an Ax `GenerationStrategy` specifying the algorithm to use for the optimization. Here we use the `choose_generation_strategy` utility that auto-picks a generation strategy based on the search space properties. To construct a custom generation strategy instead, refer to the ["Generation Strategy" tutorial](https://ax.dev/tutorials/generation_strategy.html).

Importantly, a generation strategy in Ax limits allowed parallelism levels for each generation step it contains. If you would like the `Scheduler` to ensure parallelism limitations, set `max_examples` on each generation step in your generation strategy.

In [40]:
from ax.modelbridge.dispatch_utils import choose_generation_strategy

generation_strategy = choose_generation_strategy(
    search_space=experiment.search_space, 
    max_parallelism_cap=3,
)

[INFO 01-20 13:37:44] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 01-20 13:37:44] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to  model-fitting.


In [86]:
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep

generation_strategy = GenerationStrategy(
    steps=[
        # 1. Initialization step (does not require pre-existing data and is well-suited for
        # initial sampling of the search space)
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
            min_trials_observed=3,  # How many trials need to be completed to move to next model
            max_parallelism=5,  # Max parallelism for this step
            model_kwargs={"seed": 999},  # Any kwargs you want passed into the model
            model_gen_kwargs={},  # Any kwargs you want passed to `modelbridge.gen`
        ),
        # 2. Bayesian optimization step (requires data obtained from previous phase and learns
        # from all data available at the time of each new candidate generation call)
        GenerationStep(
            model=Models.GPEI,
            num_trials=20,  # No limitation on how many trials should be produced from this step
            max_parallelism=5,  # Parallelism limit for this step, often lower than for Sobol
            # More on parallelism vs. required samples in BayesOpt:
            # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
        ),
    ]
)

Now we have all the components needed to start the scheduler:

In [87]:
from ax.service.scheduler import Scheduler, SchedulerOptions


scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=generation_strategy,
    options=SchedulerOptions(),
)

[INFO 01-20 14:06:32] Scheduler: `Scheduler` requires experiment to have immutable search space and optimization config. Setting property immutable_search_space_and_opt_config to `True` on experiment.


In [88]:
# scheduler.poll_trial_status = scheduler.experiment.runner.poll_trial_status

## 5. Running the optimization

Once the `Scheduler` instance is set up, user can execute `run_n_trials` as many times as needed, and each execution will add up to the specified `max_trials` trials to the experiment. The number of trials actually run might be less than `max_trials` if the optimization was concluded (e.g. there are no more points in the search space).

In [89]:
scheduler.run_n_trials(max_trials=30)

[INFO 01-20 14:06:36] Scheduler: Running trials [0]...
[INFO 01-20 14:06:37] Scheduler: Running trials [1]...
[INFO 01-20 14:06:38] Scheduler: Running trials [2]...
[INFO 01-20 14:06:39] Scheduler: Running trials [3]...
[INFO 01-20 14:06:40] Scheduler: Running trials [4]...
[INFO 01-20 14:06:41] Scheduler: Generated all trials that can be generated currently. Model requires more data to generate more trials.
[INFO 01-20 14:06:41] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 5).
[INFO 01-20 14:06:42] Scheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 5).
[INFO 01-20 14:06:43] Scheduler: Waiting for completed trials (for 2 sec, currently running trials: 5).
[INFO 01-20 14:06:46] Scheduler: Retrieved COMPLETED trials: [1].
[INFO 01-20 14:06:46] Scheduler: Fetching data for trials: [1].
[INFO 01-20 14:06:46] Scheduler: Waiting for completed trials (for 1 sec, currently running trials: 4).
[INFO 01-20 14:06:47] Scheduler: Waiting

OptimizationResult()

We can examine `experiment` to see that it now has three trials:

In [91]:
from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)

In [92]:
import matplotlib.pyplot as plt

In [None]:
from ax.service.utils.report_utils import exp_to_df

exp_to_df(experiment)

Unnamed: 0,branin,trial_index,arm_name,x1,x2,trial_status,generation_method
0,46.244599,0,0_0,4.430947,7.722154,COMPLETED,Sobol
3,14.735405,1,1_0,1.523239,5.943987,COMPLETED,Sobol
15,2.808085,2,2_0,-3.865318,14.02389,COMPLETED,Sobol
23,9.95684,3,3_0,7.948669,0.581467,COMPLETED,Sobol
24,95.420824,4,4_0,9.308774,12.123547,COMPLETED,Sobol
25,50.700722,5,5_0,-1.509119,2.362892,COMPLETED,Sobol
26,21.301719,6,6_0,-0.804579,9.519027,COMPLETED,Sobol
27,3.173198,7,7_0,2.982035,4.031629,COMPLETED,Sobol
28,157.877833,8,8_0,4.037862,14.084431,COMPLETED,Sobol
29,41.053458,9,9_0,0.016886,1.341454,COMPLETED,Sobol


Now we can run `run_n_trials` again to add three more trials to the experiment.

In [32]:
scheduler.run_n_trials(max_trials=3)

[INFO 01-20 13:35:42] Scheduler: Running trials [6]...
[INFO 01-20 13:35:43] Scheduler: Running trials [7]...
[INFO 01-20 13:35:45] Scheduler: Running trials [8]...
[INFO 01-20 13:35:48] Scheduler: Retrieved COMPLETED trials: 6 - 8.
[INFO 01-20 13:35:48] Scheduler: Fetching data for trials: 6 - 8.


OptimizationResult()

Examiniming the experiment, we now see 6 trials, one of which is produced by Bayesian optimization (GPEI):

In [10]:
exp_to_df(experiment)

Unnamed: 0,branin,trial_index,arm_name,x1,x2,trial_status,generation_method
0,27.402094,0,0_0,-4.022825,9.644323,COMPLETED,Sobol
1,4.438386,1,1_0,5.277253,8.187497,COMPLETED,Sobol
2,4.438386,2,2_0,4.094302,1.70019,COMPLETED,Sobol
3,20.659917,3,3_0,1.132628,6.929395,COMPLETED,Sobol
4,65.515889,4,4_0,6.307456,7.877028,COMPLETED,Sobol
5,0.517103,5,5_0,9.42251,2.127847,COMPLETED,GPEI


For each call to `run_n_trials`, one can specify a timeout; if `run_n_trials` has been running for too long without finishing its `max_trials`, the operation will exit gracefully:

In [13]:
scheduler.run_n_trials(max_trials=3)#, timeout_hours=0.00001)


A not p.d., added jitter of 1.0e-08 to the diagonal

[INFO 01-20 13:31:29] Scheduler: Running trials [9]...
[INFO 01-20 13:31:31] Scheduler: Running trials [10]...

A not p.d., added jitter of 1.0e-08 to the diagonal

[INFO 01-20 13:31:32] Scheduler: Running trials [11]...
[INFO 01-20 13:31:33] Scheduler: Retrieved COMPLETED trials: [11].
[INFO 01-20 13:31:33] Scheduler: Fetching data for trials: [11].
[INFO 01-20 13:31:33] Scheduler: Done submitting trials, waiting for remaining 2 running trials...
[INFO 01-20 13:31:33] Scheduler: Retrieved COMPLETED trials: 9 - 10.
[INFO 01-20 13:31:33] Scheduler: Fetching data for trials: 9 - 10.


OptimizationResult()

In [14]:
exp_to_df(experiment)

Unnamed: 0,branin,trial_index,arm_name,x1,x2,trial_status,generation_method
0,27.402094,0,0_0,-4.022825,9.644323,COMPLETED,Sobol
3,4.438386,1,1_0,5.277253,8.187497,COMPLETED,Sobol
4,4.438386,2,2_0,4.094302,1.70019,COMPLETED,Sobol
5,20.659917,3,3_0,1.132628,6.929395,COMPLETED,Sobol
6,65.515889,4,4_0,6.307456,7.877028,COMPLETED,Sobol
7,0.517103,5,5_0,9.42251,2.127847,COMPLETED,GPEI
8,308.129096,6,6_0,-5.0,0.0,COMPLETED,GPEI
9,145.872191,7,7_0,10.0,15.0,COMPLETED,GPEI
10,10.960889,8,8_0,10.0,0.0,COMPLETED,GPEI
11,66.329507,9,9_0,-0.031439,12.886256,COMPLETED,GPEI


## 7. Configuring the scheduler

`Scheduler` exposes many options to configure the exact settings of the closed-loop optimization to perform. A few notable ones are:
- `trial_type` –– currently only `Trial` and not `BatchTrial` is supported, but support for `BatchTrial`-s will follow,
- `tolerated_trial_failure_rate` and `min_failed_trials_for_failure_rate_check` –– together these two settings control how the scheduler monitors the failure rate among trial runs it deploys. Once `min_failed_trials_for_failure_rate_check` is deployed, the scheduler will start checking whether the ratio of failed to total trials is greater than `tolerated_trial_failure_rate`, and if it is, scheduler will exit the optimization with a `FailureRateExceededError`,
- `ttl_seconds_for_trials` –– sometimes a failure in a trial run means that it will be difficult to query its status (e.g. due to a crash). If this setting is specified, the Ax `Experiment` will automatically mark trials that have been running for too long (more than their 'time-to-live' (TTL) seconds) as failed,
- `run_trials_in_batches` –– if `True`, the scheduler will attempt to run trials not by calling `Scheduler.run_trial` in a loop, but by calling `Scheduler.run_trials` on all ready-to-deploy trials at once. This could allow for saving compute in cases where the deployment operation has large overhead and deploying many trials at once saves compute. Note that using this option successfully will require your scheduler subclass to implement `MySchedulerSubclass.run_trials` and `MySchedulerSubclass.poll_available_capacity`.

The rest of the options is described in the docstring below:

In [16]:
print(SchedulerOptions.__doc__)

Settings for a scheduler instance.

    Attributes:
        max_pending_trials: Maximum number of pending trials the scheduler
            can have ``STAGED`` or ``RUNNING`` at once, required. If looking
            to use ``Runner.poll_available_capacity`` as a primary guide for
            how many trials should be pending at a given time, set this limit
            to a high number, as an upper bound on number of trials that
            should not be exceeded.
        trial_type: Type of trials (1-arm ``Trial`` or multi-arm ``Batch
            Trial``) that will be deployed using the scheduler. Defaults
            to 1-arm `Trial`. NOTE: use ``BatchTrial`` only if need to
            evaluate multiple arms *together*, e.g. in an A/B-test
            influenced by data nonstationarity. For cases where just
            deploying multiple arms at once is beneficial but the trials
            are evaluated *independently*, implement ``run_trials`` method
            in scheduler subcla

### 8b. Using `run_trials_and_yield_results` generator method

In some systems it's beneficial to have greater control over `Scheduler.run_n_trials` instead of just starting it and needing to wait for it to run all the way to completion before having access to its output. For this purpose, the `Scheduler` implements a generator method `run_trials_and_yield_results`, which yields the output of `Scheduler.report_results` each time there are new completed trials and can be used like so:

In [78]:
class ResultReportingScheduler(Scheduler):
    def report_results(self):
        return True, {
            "trials so far": len(self.experiment.trials),
            "currently producing trials from generation step": self.generation_strategy._curr.model_name,
            "running trials": [t.index for t in self.running_trials],
        }

In [83]:
experiment = make_branin_experiment_with_runner_and_metric()

generation_strategy = GenerationStrategy(
    steps=[
        # 1. Initialization step (does not require pre-existing data and is well-suited for
        # initial sampling of the search space)
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
            min_trials_observed=3,  # How many trials need to be completed to move to next model
            max_parallelism=5,  # Max parallelism for this step
            model_kwargs={"seed": 999},  # Any kwargs you want passed into the model
            model_gen_kwargs={},  # Any kwargs you want passed to `modelbridge.gen`
        ),
        # 2. Bayesian optimization step (requires data obtained from previous phase and learns
        # from all data available at the time of each new candidate generation call)
        GenerationStep(
            model=Models.GPEI,
            num_trials=20,  # No limitation on how many trials should be produced from this step
            max_parallelism=3,  # Parallelism limit for this step, often lower than for Sobol
            # More on parallelism vs. required samples in BayesOpt:
            # https://ax.dev/docs/bayesopt.html#tradeoff-between-parallelism-and-total-number-of-trials
        ),
    ]
)

scheduler = ResultReportingScheduler(
    experiment=experiment,
    generation_strategy=generation_strategy,
    options=SchedulerOptions(),
)

for reported_result in scheduler.run_trials_and_yield_results(max_trials=20):
    print("Reported result: ", reported_result)

[INFO 01-20 14:03:35] ax.core.experiment: The is_test flag has been set to True. This flag is meant purely for development and integration testing purposes. If you are running a live experiment, please set this flag to False
[INFO 01-20 14:03:35] ResultReportingScheduler: `Scheduler` requires experiment to have immutable search space and optimization config. Setting property immutable_search_space_and_opt_config to `True` on experiment.
[INFO 01-20 14:03:35] ResultReportingScheduler: Running trials [0]...
[INFO 01-20 14:03:36] ResultReportingScheduler: Running trials [1]...
[INFO 01-20 14:03:37] ResultReportingScheduler: Running trials [2]...
[INFO 01-20 14:03:38] ResultReportingScheduler: Running trials [3]...
[INFO 01-20 14:03:39] ResultReportingScheduler: Running trials [4]...
[INFO 01-20 14:03:40] ResultReportingScheduler: Generated all trials that can be generated currently. Model requires more data to generate more trials.
[INFO 01-20 14:03:40] ResultReportingScheduler: Retrieved

Reported result:  (True, {'trials so far': 5, 'currently producing trials from generation step': 'Sobol', 'running trials': [0, 1, 3, 4]})


[INFO 01-20 14:03:41] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 4).
[INFO 01-20 14:03:43] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 4).
[INFO 01-20 14:03:45] ResultReportingScheduler: Waiting for completed trials (for 3 sec, currently running trials: 4).
[INFO 01-20 14:03:49] ResultReportingScheduler: Waiting for completed trials (for 5 sec, currently running trials: 4).
[INFO 01-20 14:03:54] ResultReportingScheduler: Retrieved COMPLETED trials: [0].
[INFO 01-20 14:03:54] ResultReportingScheduler: Fetching data for trials: [0].
[INFO 01-20 14:03:54] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).


Reported result:  (True, {'trials so far': 5, 'currently producing trials from generation step': 'Sobol', 'running trials': [1, 3, 4]})


[INFO 01-20 14:03:55] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:03:56] ResultReportingScheduler: Retrieved COMPLETED trials: 3 - 4.
[INFO 01-20 14:03:56] ResultReportingScheduler: Fetching data for trials: 3 - 4.
[INFO 01-20 14:03:57] ResultReportingScheduler: Running trials [5]...


Reported result:  (True, {'trials so far': 5, 'currently producing trials from generation step': 'Sobol', 'running trials': [1]})


[INFO 01-20 14:03:58] ResultReportingScheduler: Running trials [6]...
[INFO 01-20 14:03:59] ResultReportingScheduler: Running trials [7]...
[INFO 01-20 14:04:00] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:00] ResultReportingScheduler: Retrieved COMPLETED trials: [6].
[INFO 01-20 14:04:00] ResultReportingScheduler: Fetching data for trials: [6].


Reported result:  (True, {'trials so far': 8, 'currently producing trials from generation step': 'GPEI', 'running trials': [1, 5, 7]})


[INFO 01-20 14:04:00] ResultReportingScheduler: Running trials [8]...
[INFO 01-20 14:04:02] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:02] ResultReportingScheduler: Retrieved COMPLETED trials: [1, 7].
[INFO 01-20 14:04:02] ResultReportingScheduler: Fetching data for trials: [1, 7].


Reported result:  (True, {'trials so far': 9, 'currently producing trials from generation step': 'GPEI', 'running trials': [5, 8]})



A not p.d., added jitter of 1.0e-08 to the diagonal

[INFO 01-20 14:04:02] ResultReportingScheduler: Running trials [9]...
[INFO 01-20 14:04:03] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:03] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:04] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:06] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:04:08] ResultReportingScheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 01-20 14:04:11] ResultReportingScheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 01-20 14:04:16] ResultReportingScheduler: Retrieved COMPLETED trials: [5, 8].
[INFO 01-20 14:04:16] ResultReportingScheduler: Fetching d

Reported result:  (True, {'trials so far': 10, 'currently producing trials from generation step': 'GPEI', 'running trials': [9]})


[INFO 01-20 14:04:17] ResultReportingScheduler: Running trials [10]...
[INFO 01-20 14:04:18] ResultReportingScheduler: Running trials [11]...
[INFO 01-20 14:04:19] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:19] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:20] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:22] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:04:24] ResultReportingScheduler: Retrieved COMPLETED trials: [9].
[INFO 01-20 14:04:24] ResultReportingScheduler: Fetching data for trials: [9].


Reported result:  (True, {'trials so far': 12, 'currently producing trials from generation step': 'GPEI', 'running trials': [10, 11]})


[INFO 01-20 14:04:24] ResultReportingScheduler: Running trials [12]...
[INFO 01-20 14:04:25] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:25] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:26] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:28] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:04:30] ResultReportingScheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 01-20 14:04:34] ResultReportingScheduler: Retrieved COMPLETED trials: [10].
[INFO 01-20 14:04:34] ResultReportingScheduler: Fetching data for trials: [10].


Reported result:  (True, {'trials so far': 13, 'currently producing trials from generation step': 'GPEI', 'running trials': [11, 12]})


[INFO 01-20 14:04:34] ResultReportingScheduler: Running trials [13]...
[INFO 01-20 14:04:35] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:35] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:36] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:38] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:04:40] ResultReportingScheduler: Retrieved COMPLETED trials: [11].
[INFO 01-20 14:04:40] ResultReportingScheduler: Fetching data for trials: [11].


Reported result:  (True, {'trials so far': 14, 'currently producing trials from generation step': 'GPEI', 'running trials': [12, 13]})


[INFO 01-20 14:04:40] ResultReportingScheduler: Running trials [14]...
[INFO 01-20 14:04:41] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:41] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:42] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:44] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:04:46] ResultReportingScheduler: Waiting for completed trials (for 3 sec, currently running trials: 3).
[INFO 01-20 14:04:50] ResultReportingScheduler: Waiting for completed trials (for 5 sec, currently running trials: 3).
[INFO 01-20 14:04:55] ResultReportingScheduler: Retrieved COMPLETED trials: [12].
[INFO 01-20 14:04:55] ResultReportingScheduler: Fetching data for trials: [12].


Reported result:  (True, {'trials so far': 15, 'currently producing trials from generation step': 'GPEI', 'running trials': [13, 14]})


[INFO 01-20 14:04:55] ResultReportingScheduler: Running trials [15]...
[INFO 01-20 14:04:56] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:04:56] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:04:57] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:04:59] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:05:01] ResultReportingScheduler: Retrieved COMPLETED trials: 13 - 14.
[INFO 01-20 14:05:01] ResultReportingScheduler: Fetching data for trials: 13 - 14.


Reported result:  (True, {'trials so far': 16, 'currently producing trials from generation step': 'GPEI', 'running trials': [15]})



A not p.d., added jitter of 1.0e-08 to the diagonal

[INFO 01-20 14:05:03] ResultReportingScheduler: Running trials [16]...
[INFO 01-20 14:05:04] ResultReportingScheduler: Running trials [17]...
[INFO 01-20 14:05:06] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:05:06] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:05:07] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:05:08] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:05:10] ResultReportingScheduler: Retrieved COMPLETED trials: [17].
[INFO 01-20 14:05:10] ResultReportingScheduler: Fetching data for trials: [17].


Reported result:  (True, {'trials so far': 18, 'currently producing trials from generation step': 'GPEI', 'running trials': [15, 16]})


[INFO 01-20 14:05:11] ResultReportingScheduler: Running trials [18]...
[INFO 01-20 14:05:12] ResultReportingScheduler: Generated all trials that can be generated currently. Max parallelism currently reached.
[INFO 01-20 14:05:12] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:05:13] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:05:15] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:05:17] ResultReportingScheduler: Retrieved COMPLETED trials: [16].
[INFO 01-20 14:05:17] ResultReportingScheduler: Fetching data for trials: [16].


Reported result:  (True, {'trials so far': 19, 'currently producing trials from generation step': 'GPEI', 'running trials': [18, 15]})



A not p.d., added jitter of 1.0e-08 to the diagonal

[INFO 01-20 14:05:18] ResultReportingScheduler: Running trials [19]...
[INFO 01-20 14:05:19] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 3).
[INFO 01-20 14:05:20] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 3).
[INFO 01-20 14:05:21] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 3).
[INFO 01-20 14:05:24] ResultReportingScheduler: Retrieved COMPLETED trials: [15].
[INFO 01-20 14:05:24] ResultReportingScheduler: Fetching data for trials: [15].
[INFO 01-20 14:05:24] ResultReportingScheduler: Done submitting trials, waiting for remaining 2 running trials...
[INFO 01-20 14:05:24] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 2).


Reported result:  (True, {'trials so far': 20, 'currently producing trials from generation step': 'GPEI', 'running trials': [18, 19]})


[INFO 01-20 14:05:25] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 2).
[INFO 01-20 14:05:26] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 2).
[INFO 01-20 14:05:29] ResultReportingScheduler: Waiting for completed trials (for 3 sec, currently running trials: 2).
[INFO 01-20 14:05:32] ResultReportingScheduler: Retrieved COMPLETED trials: [19].
[INFO 01-20 14:05:32] ResultReportingScheduler: Fetching data for trials: [19].
[INFO 01-20 14:05:32] ResultReportingScheduler: Waiting for completed trials (for 1 sec, currently running trials: 1).


Reported result:  (True, {'trials so far': 20, 'currently producing trials from generation step': 'GPEI', 'running trials': [18]})


[INFO 01-20 14:05:33] ResultReportingScheduler: Waiting for completed trials (for 1.5 sec, currently running trials: 1).
[INFO 01-20 14:05:34] ResultReportingScheduler: Waiting for completed trials (for 2 sec, currently running trials: 1).
[INFO 01-20 14:05:37] ResultReportingScheduler: Retrieved COMPLETED trials: [18].
[INFO 01-20 14:05:37] ResultReportingScheduler: Fetching data for trials: [18].


Reported result:  (True, {'trials so far': 20, 'currently producing trials from generation step': 'GPEI', 'running trials': []})
Reported result:  (True, {'trials so far': 20, 'currently producing trials from generation step': 'GPEI', 'running trials': []})


In [84]:
exp_to_df(experiment)

Unnamed: 0,branin,trial_index,arm_name,x1,x2,trial_status,generation_method
0,46.244599,0,0_0,4.430947,7.722154,COMPLETED,Sobol
3,14.735405,1,1_0,1.523239,5.943987,COMPLETED,Sobol
12,2.808085,2,2_0,-3.865318,14.02389,COMPLETED,Sobol
13,9.95684,3,3_0,7.948669,0.581467,COMPLETED,Sobol
14,95.420824,4,4_0,9.308774,12.123547,COMPLETED,Sobol
15,79.888519,5,5_0,-5.0,8.991949,COMPLETED,GPEI
16,20.041068,6,6_0,1.82353,0.0,COMPLETED,GPEI
17,90.12128,7,7_0,-0.347321,15.0,COMPLETED,GPEI
18,308.129096,8,8_0,-5.0,0.0,COMPLETED,GPEI
19,48.629291,9,9_0,-2.735696,4.432332,COMPLETED,GPEI
