In [None]:
# type: ignore
import math
import sys
from collections import OrderedDict

import numpy as np
from loguru import logger

from carbs import CARBS
from carbs import CARBSParams
from carbs import LogSpace
from carbs import LogitSpace
from carbs import ObservationInParam
from carbs import ParamDictType
from carbs import Param

# Swap loguru's existing handler(s) for a single stdout sink that prints only
# the message text (no timestamp/level prefix) at DEBUG level and above.
logger.remove()
logger.add(sys.stdout, level="DEBUG", format="{message}")



def run_test_fn(input_in_param: ParamDictType):
    """Evaluate a noisy synthetic objective for one hyperparameter setting.

    The deterministic part is the squared distance of ``log10(learning_rate)``
    from -3, multiplied by ``512 / epochs``; it is therefore zero at
    ``learning_rate=1e-3`` and shrinks as ``epochs`` grows. Uniform noise in
    ``[0, 0.1)`` is added on top, drawn from NumPy's global random state.

    Only the ``"learning_rate"`` and ``"epochs"`` entries of
    ``input_in_param`` are read.
    """
    log_lr_offset = math.log10(input_in_param["learning_rate"]) + 3
    penalty = log_lr_offset ** 2 * 512 / input_in_param["epochs"]
    noise = np.random.uniform() * 0.1
    return penalty + noise

# Search space: one Param per tunable hyperparameter, listed in the same order
# they are handed to CARBS.
learning_rate_param = Param(
    name="learning_rate",
    space=LogSpace(scale=0.5),
    search_center=1e-4,
)
# NOTE: run_test_fn only reads "learning_rate" and "epochs", so "momentum" has
# no effect on the observed output in this demo.
momentum_param = Param(
    name="momentum",
    space=LogitSpace(),
    search_center=0.9,
)
epochs_param = Param(
    name="epochs",
    space=LogSpace(is_integer=True, min=2, max=512),
    search_center=10,
)
param_spaces = [learning_rate_param, momentum_param, epochs_param]

# better_direction_sign=-1 presumably tells CARBS that lower outputs are
# better (the test function is meant to be minimised) -- confirm against the
# CARBS documentation. Weights & Biases logging is switched off.
carbs_params = CARBSParams(
    better_direction_sign=-1,
    is_wandb_logging_enabled=False,
    resample_frequency=0,
)
carbs = CARBS(carbs_params, param_spaces)
# Ten rounds of suggest -> evaluate -> observe, logging the latest and the
# best observation after each round.
for _trial in range(10):
    suggestion = carbs.suggest().suggestion
    observed_value = run_test_fn(suggestion)

    # The suggested epoch count is reported as the cost of this evaluation.
    observation = ObservationInParam(
        input=suggestion,
        output=observed_value,
        cost=suggestion["epochs"],
    )
    obs_out = carbs.observe(observation)
    logs = obs_out.logs

    logger.info(f"Observation {logs['observation_count']}")

    observed_line = (
        f"Observed lr={logs['observation/learning_rate']:.2e}, "
        f"epochs={logs['observation/epochs']}, "
        f"output {logs['observation/output']:.3f}"
    )
    logger.info(observed_line)

    best_line = (
        f"Best lr={logs['best_observation/learning_rate']:.2e}, "
        f"epochs={logs['best_observation/epochs']}, "
        f"output {logs['best_observation/output']:.3f}"
    )
    logger.info(best_line)

Running CARBS with params CARBSParams(better_direction_sign=-1, seed=0, num_random_samples=4, is_wandb_logging_enabled=False, wandb_params=WandbLoggingParams(project_name=None, group_name=None, run_name=None, run_id=None, is_suggestion_logged=True, is_observation_logged=True, is_search_space_logged=True, root_dir='/mnt/private'), is_saved_on_every_observation=True, initial_search_radius=0.3, exploration_bias=1.0, num_candidates_for_suggestion_per_dim=100, resample_frequency=-1, max_cost=None, min_pareto_cost_fraction=0.2, is_pareto_group_selection_conservative=True, is_expected_improvement_pareto_value_clamped=True, is_expected_improvement_value_always_max=False, outstanding_suggestion_estimator=<OutstandingSuggestionEstimatorEnum.THOMPSON: 'THOMPSON'>)
Observation 1
Observed lr=3.93e-05, epochs=15, output 67.511
Best lr=3.93e-05, epochs=15, output 67.511
Observation 2
Observed lr=2.84e-04, epochs=33, output 4.718
Best lr=2.84e-04, epochs=33, output 4.718
Observation 3
Observed lr=5.43

L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A).mH().
This transform will produce equivalent results for all valid (symmetric positive definite) inputs. (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1744.)
  Lff = Kff.cholesky()


Observation 5
Observed lr=2.76e-04, epochs=82, output 1.993
Best lr=2.76e-04, epochs=82, output 1.993


torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2183.)
  Lffinv_pack = pack.triangular_solve(Lff, upper=False)[0]


Observation 6
Observed lr=2.64e-04, epochs=163, output 1.118
Best lr=2.64e-04, epochs=163, output 1.118
Observation 7
Observed lr=3.01e-04, epochs=207, output 0.718
Best lr=3.01e-04, epochs=207, output 0.718
Observation 8
Observed lr=2.45e-04, epochs=207, output 1.012
Best lr=3.01e-04, epochs=207, output 0.718
Observation 9
Observed lr=4.58e-04, epochs=173, output 0.437
Best lr=4.58e-04, epochs=173, output 0.437
Observation 10
Observed lr=7.14e-04, epochs=111, output 0.137
Best lr=7.14e-04, epochs=111, output 0.137
