In [33]:
import itertools
import os
import optuna
from optuna.exceptions import ExperimentalWarning
from optuna.visualization.matplotlib import plot_contour
from optuna.visualization.matplotlib import plot_edf
from optuna.visualization.matplotlib import plot_intermediate_values
from optuna.visualization.matplotlib import plot_optimization_history
from optuna.visualization.matplotlib import plot_parallel_coordinate
from optuna.visualization.matplotlib import plot_param_importances
from optuna.visualization.matplotlib import plot_rank
from optuna.visualization.matplotlib import plot_slice
from optuna.visualization.matplotlib import plot_timeline
import matplotlib.pyplot as plt
import itertools
import warnings

from transformers.style_interpolation import step_size


In [62]:
def plot_study(study, contour_step=2):
    """Show Optuna's matplotlib diagnostic plots for a single study.

    Plots, in order:
      1. the optimization history;
      2. the intermediate values, only if the best trial reported any;
      3. the parallel-coordinate plot;
      4. contour plots over consecutive groups of the best trial's parameters.

    Args:
        study: an ``optuna.Study``. It needs at least one completed trial,
            because ``study.best_params`` / ``study.best_trial`` raise
            ``ValueError`` otherwise.
        contour_step: how many parameters go into each contour figure.
            The default is 2 (one pairwise contour per figure), and values
            below 2 are raised to 2. Parameters are grouped in the order of
            ``study.best_params``. A trailing group with a single parameter is
            replaced by the last two parameters, because a contour plot needs
            at least two axes.
    """
    params = list(study.best_params)

    # Suppress ExperimentalWarning once for every plot below, including
    # plot_intermediate_values, rather than repeating catch_warnings per plot.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ExperimentalWarning)

        plot_optimization_history(study)
        plt.show()

        # An empty dict means the trial never called trial.report(...).
        if study.best_trial.intermediate_values:
            plot_intermediate_values(study)
            plt.show()

        plot_parallel_coordinate(study)
        plt.show()

        # A contour needs two parameters; with fewer there is nothing to draw.
        if len(params) >= 2:
            step = max(2, contour_step)
            for start in range(0, len(params), step):
                group = params[start:start + step]
                if len(group) < 2:
                    # Odd leftover parameter: pair it with its predecessor so it
                    # still gets a real contour instead of an empty figure.
                    group = params[-2:]
                plot_contour(study, params=group)
                plt.show()


In [68]:
# Collect (study_name, best_value) for every Optuna study stored in the
# SQLite files under `study_dir`, then print the pairs sorted by value.
study_dir = "trials"
# Sort the paths so the loop order does not depend on os.listdir's order.
study_files = sorted(
    os.path.join(study_dir, file)
    for file in os.listdir(study_dir)
    if file.endswith(".db")
)

best_values = []
for study_file in study_files:
    storage_uri = f"sqlite:///{study_file}"
    storage = optuna.storages.RDBStorage(storage_uri)
    # A database can contain zero or several studies; visit every one of them
    # instead of indexing [0], which fails on an empty file.
    for frozen_study in storage.get_all_studies():
        study_name = frozen_study.study_name
        study = optuna.load_study(study_name=study_name, storage=storage)
        try:
            best_value = study.best_value
        except ValueError:
            # Optuna raises ValueError when no trial has completed yet; skip
            # the study rather than aborting the whole loop.
            print(f"skipping {study_name}: no completed trials")
            continue
        best_values.append((study_name, best_value))
        # plot_study(study)  # uncomment to render per-study diagnostics

# Ascending by best value, i.e. best first if the studies minimise
# (NOTE(review): direction is not visible here — confirm).
print(sorted(best_values, key=lambda pair: pair[1]))


[('STUDY_ResNet18_0_pretrained_False_transformations_False_2025-07-07_20:30:16', 0.07806632213760167), ('STUDY_ResNet18_0_pretrained_True_transformations_False_2025-07-06_15:23:55', 0.09095924987923354), ('STUDY_ResNet18_1_pretrained_False_transformations_False_2025-07-07_23:06:31', 0.1452832967042923), ('STUDY_ResNet18_1_pretrained_True_transformations_False_2025-07-07_08:22:07', 0.15408412204124033), ('STUDY_ResNet18_3_pretrained_False_transformations_False_2025-07-08_03:41:50', 0.253402944188565), ('STUDY_ResNet18_3_pretrained_True_transformations_False_2025-07-07_18:09:42', 0.27541077276691794), ('STUDY_ResNet18_2_pretrained_True_transformations_False_2025-07-07_12:42:01', 0.5258762333542109), ('STUDY_ResNet18_2_pretrained_False_transformations_False_2025-07-08_01:33:03', 0.5330962724983692)]
