In [1]:
import random

from hyperopt import hp
from hyperopt.pyll import scope
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch

In [7]:
def train_model(config, input_data):
    factor = random.uniform(0, 1)
    for i in range(input_data["epochs"],1,-1):
        train.report({input_data["metric"]: i*factor})


In [9]:
num_samples = 100
epochs = 100
metric = "mse"

input_data = {
    "epochs": epochs,
    "metric": metric,
}

example_config = {
    "lr_rnvae": 1e-3,
    "lr_imputer": 1e-4,
    "dropout": .1,
    "sm_emb_size": 64,
    "cell_emb_size": 32,
    "latent_dim": 256,
    "hidden_dim": 512,
    "kld_weight": 1,
    "impute_loss_weight": 2,
}

space = {
    "lr_rnvae": hp.loguniform("lr_rnvae", -10, -1),
    "lr_imputer": hp.loguniform("lr_imputer", -10, -1),
    "dropout": hp.uniform("dropout", 0, 1),
    "sm_emb_size": scope.int(hp.qloguniform("sm_emb_size", 0, 3, 1)),
    "cell_emb_size": scope.int(hp.qloguniform("cell_emb_size", 0, 3, 1)),
    "latent_dim": scope.int(hp.qloguniform("latent_dim", 0, 7, 1)),
    "hidden_dim": scope.int(hp.qloguniform("hidden_dim", 0, 7, 1)),
    "kld_weight": hp.loguniform("kld_weight", -2, 2),
    "impute_loss_weight": hp.loguniform("impute_loss_weight", -2, 2),
}

mode = "min"
hyperopt_search = HyperOptSearch(space, metric=metric, mode=mode)
scheduler = ASHAScheduler(metric=metric, grace_period=5, mode=mode, max_t=epochs)
tuner = tune.Tuner(
    tune.with_parameters(train_model, input_data=input_data),
    tune_config=tune.TuneConfig(
        num_samples=num_samples,
        search_alg=hyperopt_search,
        scheduler=scheduler
    ),
    run_config=train.RunConfig(
        failure_config=train.FailureConfig(fail_fast=False))
)
results = tuner.fit()

best_result = results.get_best_result(metric, mode=mode)
print(best_result.path)
print("CONFIG:", best_result.config)
print("METRICS:", best_result.metrics)

0,1
Current time:,2023-10-21 14:28:52
Running for:,00:00:12.52
Memory:,6.6/8.0 GiB

Trial name,status,loc,cell_emb_size,dropout,hidden_dim,impute_loss_weight,kld_weight,latent_dim,lr_imputer,lr_rnvae,sm_emb_size,iter,total time (s),mse
train_model_e7acab17,TERMINATED,127.0.0.1:26032,2,0.326163,21,1.04033,7.16542,37,0.00166993,0.0178721,5,99,0.00505853,1.90042
train_model_07d2d42f,TERMINATED,127.0.0.1:26032,2,0.591583,783,2.54235,0.4033,5,0.000140948,0.0365279,2,99,0.00607777,1.05982
train_model_3de17014,TERMINATED,127.0.0.1:26032,2,0.741777,12,1.36922,1.23219,30,0.00882943,0.000686908,7,99,0.00999975,0.877159
train_model_54cb8e06,TERMINATED,127.0.0.1:26032,1,0.114099,3,3.93176,1.4754,477,0.000287224,0.34495,4,99,0.0194569,0.266338
train_model_46bf1dd9,TERMINATED,127.0.0.1:26032,13,0.116371,3,2.35041,6.68817,1,0.000246769,0.00436596,7,99,0.0335176,0.35422
train_model_30c68789,TERMINATED,127.0.0.1:26032,5,0.116028,161,1.94583,6.25246,402,0.214411,0.333646,6,5,0.00172997,17.6377
train_model_4de5a241,TERMINATED,127.0.0.1:26032,8,0.283694,642,1.13836,0.76805,44,0.00667491,0.0202679,8,99,0.0275426,0.147578
train_model_4283be7c,TERMINATED,127.0.0.1:26032,2,0.788342,1,0.665047,0.15152,18,0.00339999,0.000683166,4,5,0.00805473,79.0467
train_model_de12339f,TERMINATED,127.0.0.1:26032,8,0.693519,48,0.168206,0.345401,17,0.00220484,0.000357469,2,5,0.00206804,69.4948
train_model_22fab9b2,TERMINATED,127.0.0.1:26032,1,0.517484,43,0.812304,2.01851,97,0.00477447,0.000164582,2,5,0.0045073,69.6128


2023-10-21 14:28:53,403	INFO tune.py:1143 -- Total run time: 13.53 seconds (12.47 seconds for the tuning loop).


/Users/laurasisson/ray_results/train_model_2023-10-21_14-28-39/train_model_94eae705_60_cell_emb_size=3,dropout=0.7424,hidden_dim=13,impute_loss_weight=0.1813,kld_weight=7.2437,latent_dim=1,lr_i_2023-10-21_14-28-48
CONFIG: {'cell_emb_size': 3, 'dropout': 0.7424053974648346, 'hidden_dim': 13, 'impute_loss_weight': 0.1812690579822345, 'kld_weight': 7.243671545471919, 'latent_dim': 1, 'lr_imputer': 0.004563194716937017, 'lr_rnvae': 0.00011361265198956045, 'sm_emb_size': 5}
METRICS: {'mse': 0.04290977545466945, 'timestamp': 1697912928, 'done': True, 'training_iteration': 99, 'trial_id': '94eae705', 'date': '2023-10-21_14-28-48', 'time_this_iter_s': 0.00016307830810546875, 'time_total_s': 0.02109551429748535, 'pid': 26032, 'hostname': 'Lauras-Air', 'node_ip': '127.0.0.1', 'config': {'cell_emb_size': 3, 'dropout': 0.7424053974648346, 'hidden_dim': 13, 'impute_loss_weight': 0.1812690579822345, 'kld_weight': 7.243671545471919, 'latent_dim': 1, 'lr_imputer': 0.004563194716937017, 'lr_rnvae': 0.0