# Final hyperparameter tuning results

In [16]:
import numpy as np
import pandas as pd

In [19]:
results_dict = {
    "MLP Baseline": {
        "val AUC": 0.6686054843599283,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 64,
            "lr": 0.00032143048707089674,
            "random_reverse": True,
            "hidden_size": 512,
            "dropout": 0.4
        },
        "runtime": 1.55,
        "nr_trials": 150
    },
    "MLP Freq": {
        "val AUC": 0.7248128243994304,
        "best_params": {
            "dim_embedding": 64,
            "predictor_hidden": 128,
            "lr": 0.0006018234365492453,
            "random_reverse": True,
            "hidden_size": 512,
            "dropout": 0.0
        },
        "runtime": 16.53,
        "nr_trials": 150
    },
    "CNN": {
        "val AUC": 0.6908984428827338,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 128,
            "lr": 0.0009882333318255191,
            "random_reverse": True,
            "num_kernels_conv1": 64,
            "num_kernels_conv2": 32,
            "kernel_size_conv1": 6,
            "kernel_size_conv2": 9,
            "max_pool1": 20,
            "max_pool2": 5
        },
        "runtime": 2.05,
        "nr_trials": 150
    },
    "LegNet": {
        "val AUC": 0.710789582472096,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 32,
            "lr": 0.002304424223674682,
            "random_reverse": False,
            "kernel_size": 11,
            "resize_factor": 6,
            "se_reduction": 8,
            "filter_per_group": 2
        },
        "runtime": 24.0,
        "nr_trials": 80
    },
    "RiboNN": {
        "val AUC": 0.6990592990675669,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 32,
            "lr": 0.0018021153781358553,
            "random_reverse": True,
            "num_layers": 4,
            "dropout": 0.1,
            "grad_clip_norm": 0.2745393751090204
        },
        "runtime": 6.4,
        "nr_trials": 150
    },
    "LSTM": {
        "val AUC": 0.6774309861742686,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 32,
            "lr": 0.00015599959855408808,
            "random_reverse": False,
            "bidirectional": True,
            "dropout": 0.1,
            "rnn_hidden_size": 512,
            "num_layers": 2
        },
        "runtime": 24.0,
        "nr_trials": 48
    },
    "GRU": {
        "val AUC": 0.6836840292131734,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 64,
            "lr": 1.3196675488432542e-05,
            "random_reverse": True,
            "bidirectional": True,
            "dropout": 0.1,
            "rnn_hidden_size": 256,
            "num_layers": 2
        },
        "runtime": 24.0,
        "nr_trials": 18
    },
    "xLSTM": {
        "val AUC": 0.6890615957007028,
        "best_params": {
            "dim_embedding": 64,
            "predictor_hidden": 64,
            "lr": 0.00029630029586396017,
            "random_reverse": False,
            "conv1d_kernel_size": 9,
            "m_qkv_proj_blocksize": 8,
            "num_heads": 8,
            "proj_factor": 1,
            "num_blocks": 5,
            "dropout": 0.4
        },
        "runtime": 24.0,
        "nr_trials": 9
    },
    "Transformer": {
        "val AUC": 0.680868127325341,
        "best_params": {
            "dim_embedding": 32,
            "predictor_hidden": 128,
            "lr": 0.0010919470279424945,
            "random_reverse": True,
            "num_layers": 12,
            "dim_feedforward": 128,
            "dropout": 0.2,
            "num_heads": 4
        },
        "runtime": 24.0,
        "nr_trials": 11
    },
    "Mamba": {
        "val AUC": 0.686000643057278,
        "best_params": {
            "dim_embedding": 128,
            "predictor_hidden": 128,
            "lr": 0.00021643583369462338,
            "random_reverse": True,
            "d_state": 16,
            "d_conv": 2,
            "num_layers": 2
        },
        "runtime": 5.28,  # in hours
        "nr_trials": 150
    }
}

In [20]:
records = [
    {
        "model": model,
        "val AUC": info["val AUC"],
        "runtime (h)": info["runtime"],
        "nr_trials": info["nr_trials"],
        "best_params": info["best_params"]
    }
    for model, info in results_dict.items()
]

df = pd.DataFrame(records).sort_values("val AUC", ascending=False).reset_index(drop=True)
df

Unnamed: 0,model,val AUC,runtime (h),nr_trials,best_params
0,MLP Freq,0.724813,16.53,150,"{'dim_embedding': 64, 'predictor_hidden': 128,..."
1,LegNet,0.71079,24.0,80,"{'dim_embedding': 128, 'predictor_hidden': 32,..."
2,RiboNN,0.699059,6.4,150,"{'dim_embedding': 128, 'predictor_hidden': 32,..."
3,CNN,0.690898,2.05,150,"{'dim_embedding': 128, 'predictor_hidden': 128..."
4,xLSTM,0.689062,24.0,9,"{'dim_embedding': 64, 'predictor_hidden': 64, ..."
5,Mamba,0.686001,5.28,150,"{'dim_embedding': 128, 'predictor_hidden': 128..."
6,GRU,0.683684,24.0,18,"{'dim_embedding': 128, 'predictor_hidden': 64,..."
7,Transformer,0.680868,24.0,11,"{'dim_embedding': 32, 'predictor_hidden': 128,..."
8,LSTM,0.677431,24.0,48,"{'dim_embedding': 128, 'predictor_hidden': 32,..."
9,MLP Baseline,0.668605,1.55,150,"{'dim_embedding': 128, 'predictor_hidden': 64,..."


In [14]:
df.to_latex()

"\\begin{tabular}{llrrrl}\n\\toprule\n & model & val AUC & runtime (h) & nr_trials & best_params \\\\\n\\midrule\n0 & MLP Freq & 0.724813 & 16.530000 & 150 & {'dim_embedding': 64, 'predictor_hidden': 128, 'lr': 0.0006018234365492453, 'random_reverse': True, 'hidden_size': 512, 'dropout': 0.0} \\\\\n1 & LegNet & 0.710790 & 24.000000 & 80 & {'dim_embedding': 128, 'predictor_hidden': 32, 'lr': 0.002304424223674682, 'random_reverse': False, 'kernel_size': 11, 'resize_factor': 6, 'se_reduction': 8, 'filter_per_group': 2} \\\\\n2 & CNN & 0.690898 & 2.050000 & 150 & {'dim_embedding': 128, 'predictor_hidden': 128, 'lr': 0.0009882333318255191, 'random_reverse': True, 'num_kernels_conv1': 64, 'num_kernels_conv2': 32, 'kernel_size_conv1': 6, 'kernel_size_conv2': 9, 'max_pool1': 20, 'max_pool2': 5} \\\\\n3 & xLSTM & 0.689062 & 24.000000 & 9 & {'dim_embedding': 64, 'predictor_hidden': 64, 'lr': 0.00029630029586396017, 'random_reverse': False, 'conv1d_kernel_size': 9, 'm_qkv_proj_blocksize': 8, 'num

## Export and store PNG images of Optuna training

In [22]:
import sqlite3
import pandas as pd

# Connect to the database
conn = sqlite3.connect("/export/share/krausef99dm/tuning_dbs/baseline.db")

# Read all trials
df = pd.read_sql_query("SELECT * FROM trials", conn) # WHERE study_id==2 AND number==150
# PRAGMA table_info(trials)
df

Unnamed: 0,trial_id,number,study_id,state,datetime_start,datetime_complete
0,1,0,1,COMPLETE,2024-11-27 18:17:43.098059,2024-11-27 18:22:25.765928
1,2,0,2,COMPLETE,2024-11-27 18:21:53.146553,2024-11-27 18:23:12.560393
2,3,1,1,COMPLETE,2024-11-27 18:22:25.860969,2024-11-27 18:28:58.165485
3,4,1,2,COMPLETE,2024-11-27 18:23:12.644701,2024-11-27 18:24:16.711223
4,5,2,2,COMPLETE,2024-11-27 18:24:16.789116,2024-11-27 18:25:12.622144
...,...,...,...,...,...,...
1046,1199,145,8,COMPLETE,2025-05-02 08:29:58.235831,2025-05-02 08:40:15.359812
1047,1200,146,8,PRUNED,2025-05-02 08:40:15.448294,2025-05-02 08:40:43.139412
1048,1201,147,8,COMPLETE,2025-05-02 08:40:43.246198,2025-05-02 08:51:03.745128
1049,1202,148,8,PRUNED,2025-05-02 08:51:03.833239,2025-05-02 08:51:27.183700


In [18]:
# DANGEROUS
#conn.execute("DELETE FROM trials WHERE study_id==2 AND number==0")
#conn.commit()

In [1]:
import os
import optuna
from tqdm import tqdm 
from optuna.visualization import (
    plot_param_importances,
    plot_intermediate_values,
    plot_optimization_history,
    plot_timeline,
)
import plotly.io as pio

models = {
    "baseline": "codon_final_2",
    "baseline_freq": "freq_final_2",
    "cnn": "codon_final_2",
    "gru": "codon_final_2",
    "lstm": "codon_final_2",
    "xlstm": "codon_final_2",
    "mamba": "codon_final_2",
    "transformer": "codon_final_2",
    "LegNet": "codon_final_2",
    "RiboNN": "codon_final_2",
}

plots = {
    "param_importances": plot_param_importances,
    "intermediate_values": plot_intermediate_values,
    "optimization_history": plot_optimization_history,
    "timeline": plot_timeline,
}

base_storage_path = "/export/share/krausef99dm/tuning_dbs"

for model, study_name in tqdm(models.items()):
    if model != "baseline_freq":
        storage_url = f"sqlite:///{base_storage_path}/{model}.db"
    else:
        storage_url = f"sqlite:///{base_storage_path}/baseline.db"
    try:
        study = optuna.load_study(study_name=study_name, storage=storage_url)
    except Exception as e:
        print(f"[{model}] Failed to load study: {e}")
        continue

    os.makedirs(os.path.join("/export/share/krausef99dm/tuning_images", model), exist_ok=True)

    for plot_name, plot_func in plots.items():
        try:
            fig = plot_func(study)
            output_path = os.path.join("/export/share/krausef99dm/tuning_images", model, f"{plot_name}.png")
            pio.write_image(fig, output_path)
            print(f"[{model}] Saved {plot_name}.png")
        except Exception as e:
            print(f"[{model}] Failed to generate {plot_name}: {e}")

  0%|          | 0/10 [00:00<?, ?it/s]

[baseline] Saved param_importances.png
[baseline] Saved intermediate_values.png
[baseline] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 10%|█         | 1/10 [00:04<00:39,  4.37s/it]

[baseline] Saved timeline.png
[baseline_freq] Saved param_importances.png
[baseline_freq] Saved intermediate_values.png
[baseline_freq] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 20%|██        | 2/10 [00:07<00:28,  3.60s/it]

[baseline_freq] Saved timeline.png
[cnn] Saved param_importances.png
[cnn] Saved intermediate_values.png
[cnn] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 30%|███       | 3/10 [00:09<00:20,  2.90s/it]

[cnn] Saved timeline.png
[gru] Saved param_importances.png
[gru] Saved intermediate_values.png
[gru] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 40%|████      | 4/10 [00:10<00:13,  2.19s/it]

[gru] Saved timeline.png
[lstm] Saved param_importances.png
[lstm] Saved intermediate_values.png
[lstm] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 50%|█████     | 5/10 [00:12<00:09,  1.96s/it]

[lstm] Saved timeline.png
[xlstm] Saved param_importances.png
[xlstm] Saved intermediate_values.png
[xlstm] Failed to generate optimization_history: list index out of range



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 60%|██████    | 6/10 [00:13<00:06,  1.63s/it]

[xlstm] Saved timeline.png
[mamba] Saved param_importances.png
[mamba] Saved intermediate_values.png
[mamba] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 70%|███████   | 7/10 [00:15<00:05,  1.71s/it]

[mamba] Saved timeline.png
[transformer] Saved param_importances.png
[transformer] Saved intermediate_values.png
[transformer] Failed to generate optimization_history: list index out of range



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 80%|████████  | 8/10 [00:16<00:03,  1.53s/it]

[transformer] Saved timeline.png
[LegNet] Saved param_importances.png
[LegNet] Saved intermediate_values.png
[LegNet] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

 90%|█████████ | 9/10 [00:17<00:01,  1.49s/it]

[LegNet] Saved timeline.png
[RiboNN] Saved param_importances.png
[RiboNN] Saved intermediate_values.png
[RiboNN] Saved optimization_history.png



plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

100%|██████████| 10/10 [00:19<00:00,  1.97s/it]

[RiboNN] Saved timeline.png



