In [None]:
# default_exp optuna

# Optuna: A hyperparameter optimization framework

> Optuna is an automatic hyperparameter optimization software framework, particularly designed for machine learning. It features an imperative, define-by-run style user API. Thanks to our define-by-run API, the code written with Optuna enjoys high modularity, and the user of Optuna can dynamically construct the search spaces for the hyperparameters.

In [None]:
# export
from pathlib import Path
from fastcore.script import *
import joblib
from tsai.imports import *
from importlib import import_module
import warnings
warnings.filterwarnings("ignore")


@call_parse
def optuna_study(
    config:             Param('Path to the study config file', str), 
    study_type:         Param('Type of study', str)=None,
    multivariate:       Param('Flag to use the multivariate TPE sampler or not.', store_false)=True,
    study_name:         Param("Study's name. If this argument is set to None, a unique name is generated automatically.", str)=None, 
    seed:               Param('Seed for random number generator.', int)=None, 
    search_space:       Param('Path to dictionary whose keys and values are a parameter name and the corresponding candidates of values', str)=None, 
    direction:          Param('Direction of optimization.', str)='maximize',
    n_trials:           Param('The number of trials.', int)=None, 
    timeout:            Param('Stop study after the given number of second(s).', int)=None, 
    gc_after_trial:     Param('Flag to determine whether to automatically run garbage collection after each trial.', store_true)=False,
    show_progress_bar:  Param('Flag to show progress bars or not.', store_false)=True,
    show_plots:         Param('Flag to show plots or not.', store_false)=True,
    save:               Param('Flag to save study to disk or not.', store_false)=True,
    path:               Param('Path where the study will be saved', str)='optuna', 
    ):
    """Run an Optuna study whose objective function is defined in a Python config file.

    The config module must define `objective`, which is passed to `Study.optimize`.
    For grid search it must also define a `search_space` attribute, which is passed
    to `optuna.samplers.GridSampler`.

    Args:
        config: path to the config file. Leading '/', ' ' and '.' characters are
            stripped (e.g. './cfg.py' -> 'cfg.py'). If the remaining path contains
            a directory, that directory is appended to `sys.path` before importing.
        study_type: case-insensitive sampler choice:
            None / 'bayesian' / 'tpe' / 'tpesampler' -> TPESampler,
            'gridsearch' / 'gridsampler'              -> GridSampler,
            'randomsearch' / 'randomsampler'          -> RandomSampler.
        multivariate: forwarded to TPESampler.
        study_name: forwarded to `optuna.create_study`.
        seed: seed for the TPE and random samplers.
        search_space: currently not read. The grid is taken from the config module's
            `search_space` attribute. The parameter is kept for interface compatibility.
        direction: optimization direction forwarded to `optuna.create_study`.
        n_trials, timeout, gc_after_trial, show_progress_bar: forwarded to `Study.optimize`.
        show_plots: display Optuna visualizations when more than one trial exists.
        save: pickle the study with joblib to `path/<study_name>.pkl`.
        path: directory where the study is saved.

    Returns:
        The Optuna study object, including trials completed before a Ctrl-C interrupt.

    Raises:
        ImportError: if optuna is not installed.
        ValueError: if `config` is empty after stripping, or `study_type` is unknown.
        AssertionError: if the config module lacks `objective` (or `search_space` for grid search).
    """
    try: import optuna
    except ImportError: raise ImportError('You need to install optuna!') 

    # Strip leading '/', ' ' and '.' chars so the file resolves relative to sys.path.
    # NOTE(review): like the original loop, this also turns absolute paths into relative ones.
    config = config.lstrip("/ .")
    if not config: raise ValueError("`config` must be a path to a study config file")
    if '/' in config:
        config_dir = config.rsplit('/', 1)[0]
        if config_dir not in sys.path: sys.path.append(config_dir)
    m = import_file_as_module(config)
    assert hasattr(m, 'objective'), f"there's no objective function in {config}"
    objective = getattr(m, "objective")

    kind = "bayesian" if study_type is None else study_type.lower()
    if kind in ["bayesian", "tpe", "tpesampler"]:
        sampler = optuna.samplers.TPESampler(seed=seed, multivariate=multivariate)
    elif kind in ["gridsearch", "gridsampler"]:
        assert hasattr(m, 'search_space'), f"there's no search_space function in {config}"
        sampler = optuna.samplers.GridSampler(search_space=getattr(m, 'search_space'))
    elif kind in ["randomsearch", "randomsampler"]:
        sampler = optuna.samplers.RandomSampler(seed=seed)
    else:
        # Fail early instead of hitting an unbound `sampler` below.
        raise ValueError(f"Unknown study_type {study_type!r}. Use 'bayesian', 'gridsearch' or 'randomsearch'.")

    # Create the study outside the try so `study` is always bound for the code below.
    study = optuna.create_study(sampler=sampler, study_name=study_name, direction=direction)
    try:
        study.optimize(objective, n_trials=n_trials, timeout=timeout, gc_after_trial=gc_after_trial,
                       show_progress_bar=show_progress_bar)
    except KeyboardInterrupt:
        # Ctrl-C stops the search early; completed trials are still saved and reported.
        pass

    if save: 
        full_path = Path(path)/f'{study.study_name}.pkl'
        full_path.parent.mkdir(parents=True, exist_ok=True)
        # joblib pickles the study; only reload files from trusted sources.
        joblib.dump(study, full_path)
        print(f'\nOptuna study saved to {full_path}')
        print(f"To reload the study run: study = joblib.load('{full_path}')")

    if show_plots and len(study.trials) > 1: 
        for plot_name in ('plot_optimization_history', 'plot_param_importances',
                          'plot_slice', 'plot_parallel_coordinate'):
            # Best-effort: a single plot may fail (e.g. an optional plotting dependency
            # is missing, or the trials don't support it). Skip it without aborting.
            try: display(getattr(optuna.visualization, plot_name)(study))
            except Exception: pass

    print("\nStudy stats   : ")
    print("===============")
    print(f"Study name    : {study.study_name}")
    print(f"  n_trials    : {len(study.trials)}")
    print("Best trial    :")
    try:
        # best_trial raises ValueError when no trial has completed.
        trial = study.best_trial
    except ValueError:
        print('No trials are completed yet.')
    else:
        print(f"  value       : {trial.value}")
        print(f"  best_params = {trial.params}\n")
    return study

In [None]:
#hide
# Regenerate the library scripts from the notebooks. The "Converted ..." lines in the
# output below appear to come from this call.
# beep(out) produces a JavaScript display object, presumably an audible "done"
# signal (TODO confirm).
out = create_scripts()
beep(out)

<IPython.core.display.Javascript object>

Converted 000_utils.ipynb.
Converted 000b_data.validation.ipynb.
Converted 000c_data.preparation.ipynb.
Converted 001_data.external.ipynb.
Converted 002_data.core.ipynb.
Converted 002b_data.unwindowed.ipynb.
Converted 002c_data.metadatasets.ipynb.
Converted 003_data.preprocessing.ipynb.
Converted 003b_data.transforms.ipynb.
Converted 003c_data.mixed_augmentation.ipynb.
Converted 003d_data.image.ipynb.
Converted 003e_data.features.ipynb.
Converted 005_data.tabular.ipynb.
Converted 006_data.mixed.ipynb.
Converted 050_losses.ipynb.
Converted 051_metrics.ipynb.
Converted 052_learner.ipynb.
Converted 052b_tslearner.ipynb.
Converted 053_optimizer.ipynb.
Converted 060_callback.core.ipynb.
Converted 061_callback.noisy_student.ipynb.
Converted 063_callback.MVP.ipynb.
Converted 064_callback.PredictionDynamics.ipynb.
Converted 100_models.layers.ipynb.
Converted 100b_models.utils.ipynb.
Converted 100c_models.explainability.ipynb.
Converted 101_models.ResNet.ipynb.
Converted 101b_models.ResNetPlus.