# Final hyperparameter tuning results

In [None]:
import numpy as np
import pandas as pd

In [None]:
def _result_entry(val_auc, runtime, nr_trials, **best_params):
    """Build the summary record of one model's tuning run.

    Args:
        val_auc: Validation AUC of the run.
        runtime: Tuning runtime in hours.
        nr_trials: Number of trials.
        **best_params: Best hyperparameters, kept in the order given.

    Returns:
        A dict with the keys "val AUC", "best_params", "runtime" and
        "nr_trials", in that order.
    """
    return {
        "val AUC": val_auc,
        "best_params": best_params,
        "runtime": runtime,
        "nr_trials": nr_trials,
    }


# One entry per model: validation AUC, runtime (hours), trial count, and the
# best hyperparameters (every keyword after `nr_trials`).
results_dict = {
    "MLP Baseline": _result_entry(
        val_auc=0.6686054843599283,
        runtime=1.55,
        nr_trials=150,
        dim_embedding=128,
        predictor_hidden=64,
        lr=0.00032143048707089674,
        random_reverse=True,
        hidden_size=512,
        dropout=0.4,
    ),
    "MLP Freq": _result_entry(
        val_auc=0.7248128243994304,
        runtime=16.53,
        nr_trials=150,
        dim_embedding=64,
        predictor_hidden=128,
        lr=0.0006018234365492453,
        random_reverse=True,
        hidden_size=512,
        dropout=0.0,
    ),
    "CNN": _result_entry(
        val_auc=0.6908984428827338,
        runtime=2.05,
        nr_trials=150,
        dim_embedding=128,
        predictor_hidden=128,
        lr=0.0009882333318255191,
        random_reverse=True,
        num_kernels_conv1=64,
        num_kernels_conv2=32,
        kernel_size_conv1=6,
        kernel_size_conv2=9,
        max_pool1=20,
        max_pool2=5,
    ),
    "LegNet": _result_entry(
        val_auc=0.710789582472096,
        runtime=24.0,
        nr_trials=80,
        dim_embedding=128,
        predictor_hidden=32,
        lr=0.002304424223674682,
        random_reverse=False,
        kernel_size=11,
        resize_factor=6,
        se_reduction=8,
        filter_per_group=2,
    ),
    "RiboNN": _result_entry(
        val_auc=0.6990592990675669,
        runtime=6.4,
        nr_trials=150,
        dim_embedding=128,
        predictor_hidden=32,
        lr=0.0018021153781358553,
        random_reverse=True,
        num_layers=4,
        dropout=0.1,
        grad_clip_norm=0.2745393751090204,
    ),
    "LSTM": _result_entry(
        val_auc=0.6774309861742686,
        runtime=24.0,
        nr_trials=48,
        dim_embedding=128,
        predictor_hidden=32,
        lr=0.00015599959855408808,
        random_reverse=False,
        bidirectional=True,
        dropout=0.1,
        rnn_hidden_size=512,
        num_layers=2,
    ),
    "GRU": _result_entry(
        val_auc=0.6836840292131734,
        runtime=24.0,
        nr_trials=18,
        dim_embedding=128,
        predictor_hidden=64,
        lr=1.3196675488432542e-05,
        random_reverse=True,
        bidirectional=True,
        dropout=0.1,
        rnn_hidden_size=256,
        num_layers=2,
    ),
    "xLSTM": _result_entry(
        val_auc=0.6890615957007028,
        runtime=24.0,
        nr_trials=9,
        dim_embedding=64,
        predictor_hidden=64,
        lr=0.00029630029586396017,
        random_reverse=False,
        conv1d_kernel_size=9,
        m_qkv_proj_blocksize=8,
        num_heads=8,
        proj_factor=1,
        num_blocks=5,
        dropout=0.4,
    ),
    "Transformer": _result_entry(
        val_auc=0.680868127325341,
        runtime=24.0,
        nr_trials=11,
        dim_embedding=32,
        predictor_hidden=128,
        lr=0.0010919470279424945,
        random_reverse=True,
        num_layers=12,
        dim_feedforward=128,
        dropout=0.2,
        num_heads=4,
    ),
    "Mamba": _result_entry(
        val_auc=0.686000643057278,
        runtime=5.28,
        nr_trials=150,
        dim_embedding=128,
        predictor_hidden=128,
        lr=0.00021643583369462338,
        random_reverse=True,
        d_state=16,
        d_conv=2,
        num_layers=2,
    ),
    "PTRnet": _result_entry(
        val_auc=0.6954712816604476,
        runtime=42.36,
        nr_trials=754,
        predictor_hidden=32,
        lr=0.0004589348543322555,
        predictor_dropout=0.0,
        weight_decay=0.0006792444466769486,
        reset_epochs=20,
        T_mult=1,
        num_layers=6,
        dropout=0.0,
        align_aug=True,
        concat_tissue_feature=False,
        frequency_features=True,
        seq_only=False,
        seq_encoding="embedding",
        grad_clip_norm=0.15166663984361584,
        dim_embedding=32,
    ),
}

In [None]:
# Flatten the per-model results into one row per model.
records = [
    {
        "model": name,
        "val AUC": stats["val AUC"],
        "runtime (h)": stats["runtime"],
        "nr_trials": stats["nr_trials"],
        "best_params": stats["best_params"],
    }
    for name, stats in results_dict.items()
]

# Rank the models by validation AUC (best first) and renumber rows from 0.
df = pd.DataFrame(records)
df = df.sort_values(by="val AUC", ascending=False)
df = df.reset_index(drop=True)
df

In [None]:
# Render the whole results table as LaTeX source; as the last expression of
# the cell, the resulting string is what the notebook displays.
df.to_latex()

In [None]:
# PTRnet only: select the row by model name instead of by position, so the
# export stays correct if the AUC-based ordering of the table changes.
# Raises IndexError if the table has no "PTRnet" row.
df.loc[df["model"] == "PTRnet"].iloc[0].to_latex()

## Export Optuna tuning plots as PNG images

In [None]:
import sqlite3
import pandas as pd

# Open the SQLite file that holds the PTRnet tuning trials.
conn = sqlite3.connect("/export/share/krausef99dm/tuning_dbs/ptrnet.db")

# Fetch every completed trial of study 3.
# Variants used while exploring: filter with "WHERE study_id==2 AND number==150",
# or run "PRAGMA table_info(trials)" to list the table's columns.
query = "SELECT * FROM trials WHERE study_id==3 AND state=='COMPLETE'"
df = pd.read_sql_query(query, conn)
df

In [None]:
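# DANGEROUS: the statements below permanently delete rows from the `trials` table; keep them commented out unless the deletion is intended.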
#conn.execute("DELETE FROM trials WHERE study_id==2 AND number==0")
#conn.commit()

In [None]:
import os
import optuna
import time
from tqdm import tqdm 
from optuna.visualization import (
    plot_param_importances,
    plot_intermediate_values,
    plot_optimization_history,
    plot_timeline,
)
import plotly.io as pio

# Model name -> name of the Optuna study to export. The model name also
# selects the SQLite file and the output sub-directory. Uncomment an entry to
# include that model in the export.
models = {
    #"baseline": "codon_final_2",
    #"baseline_freq": "freq_final_2",
    #"cnn": "codon_final_2",
    #"gru": "codon_final_2",
    #"lstm": "codon_final_2",
    #"xlstm": "codon_final_2",
    #"mamba": "codon_final_2",
    #"transformer": "codon_final_2",
    #"LegNet": "codon_final_2",
    #"RiboNN": "codon_final_2",
    "ptrnet": "final_1",
}

# Output file stem -> plotting function; every loaded study is rendered with
# each of them.
plots = {
    "param_importances": plot_param_importances,
    "intermediate_values": plot_intermediate_values,
    "optimization_history": plot_optimization_history,
    "timeline": plot_timeline,
}

base_storage_path = "/export/share/krausef99dm/tuning_dbs"

for model, study_name in tqdm(models.items()):
    # "baseline_freq" is read from baseline.db; every other model is read
    # from a database file named after the model itself.
    storage_url = (
        f"sqlite:///{base_storage_path}/baseline.db"
        if model == "baseline_freq"
        else f"sqlite:///{base_storage_path}/{model}.db"
    )

    try:
        study = optuna.load_study(study_name=study_name, storage=storage_url)
    except Exception as exc:
        # Best effort: report the failure and move on to the next model.
        print(f"[{model}] Failed to load study: {exc}")
        continue
    # NOTE(review): the reason for this pause is not evident from this cell;
    # confirm before removing it.
    time.sleep(3)

    # One output directory per model, created on demand.
    image_dir = os.path.join("/export/share/krausef99dm/tuning_images", model)
    os.makedirs(image_dir, exist_ok=True)

    for plot_name, plot_func in plots.items():
        try:
            fig = plot_func(study)
            time.sleep(1)
            # Drop the title, shrink the margins, make the paper background
            # fully transparent (alpha 0) and set the font size.
            fig.update_layout(
                title=None,
                margin={"l": 10, "r": 10, "t": 10, "b": 10},
                paper_bgcolor="rgba(0,0,0,0)",
                font={"size": 14},
            )
            #fig.show()
            output_path = os.path.join(image_dir, f"{plot_name}.png")
            pio.write_image(fig, output_path)
            print(f"[{model}] Saved {plot_name}.png")
        except Exception as exc:
            # A failing plot must not prevent the remaining plots from being
            # written.
            print(f"[{model}] Failed to generate {plot_name}: {exc}")

# NOTE: optimization history for xlstm and transformer might fail due to indexing bug

In [None]:
# Print and plot the hyperparameter importances of the "final_1" study stored
# in ptrnet.db.
# Note: the Optuna dashboard seems to use all trials (not only the completed
# ones) and a different importance evaluator than the functions used here.
# Hence, the importance values can differ considerably between the two.
from optuna.importance import get_param_importances

ptrnet_storage = "sqlite:////export/share/krausef99dm/tuning_dbs/ptrnet.db"
study = optuna.load_study(study_name="final_1", storage=ptrnet_storage)

# Passing None leaves the choice of evaluator to Optuna; the alternatives
# below can be swapped in for comparison.
#evaluator = optuna.importance.MeanDecreaseImpurityImportanceEvaluator()
#evaluator = optuna.importance.FanovaImportanceEvaluator()
evaluator = None

importances = get_param_importances(study, evaluator=evaluator)
for param, importance in importances.items():
    print(f"{param}: {importance:.3f}")

# Same study and evaluator as the printed values, shown as a figure.
fig = plot_param_importances(study, evaluator=evaluator)
fig.show()