In [None]:
import optuna
import joblib
import optuna.visualization as vis
import time

import sys
import os

# Make the project root (the parent of the notebook's working directory) importable so
# the `generators.*` import below resolves. The membership check keeps this cell
# idempotent: re-running it does not append duplicate entries to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from generators.Optimizers.ParticleSwarm import ParticleSwarm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ----- Fixed experiment settings -----
# Constructor arguments shared by every trial; `objective` layers the sampled
# Particle Swarm hyperparameters on top of a copy of this dict.
# NOTE(review): cnn_model_path is relative to the current working directory, while the
# saved study is loaded from '../data/...'; confirm how ParticleSwarm resolves this path.
base_kwargs = dict(
    cnn_model_path='Models/CNN_6_1_2.keras',
    masked_sequence="AATACTAGAGGTCTTCCGACNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNGTGTGGGCGGGAAGACAACTAGGGG",
    target_expression=1,
    max_iter=100,
    seed=42,
)

# ----- Objective Function -----
def objective(trial):
    """Optuna objective: run one Particle Swarm optimization with sampled hyperparameters.

    Samples ``c1``, ``c2``, ``w`` (each in [0, 1]) and ``n_particles`` (integer in
    [1, 100]), merges them into a copy of ``base_kwargs`` and runs ``ParticleSwarm``
    once. The two returned values are the two objectives that the study created in
    ``run_optimization`` minimizes.

    Args:
        trial: The Optuna trial object supplied by ``study.optimize``.

    Returns:
        tuple: ``(best_error, run_time_seconds)``. ``best_error`` is the third value
        returned by ``ParticleSwarm.run()``. ``run_time_seconds`` is the wall-clock
        duration of that ``run()`` call (construction is not timed). If construction
        or the run raises, the error is printed and ``(nan, nan)`` is returned, so
        Optuna records the trial as FAILED. It is not scored as a result.
    """
    # Presumably the standard PSO coefficients (c1/c2 attraction weights, w inertia);
    # confirm against ParticleSwarm.
    c1 = trial.suggest_float("c1", 0, 1.0)
    c2 = trial.suggest_float("c2", 0, 1.0)
    w = trial.suggest_float("w", 0, 1.0)
    n_particles = trial.suggest_int("n_particles", 1, 100)

    # --- PSO kwargs: fixed settings plus this trial's sampled values ---
    kwargs = {
        **base_kwargs,
        "c1": c1,
        "c2": c2,
        "w": w,
        "n_particles": n_particles,
    }

    # --- Run PSO ---
    try:
        ps = ParticleSwarm(**kwargs)
        # perf_counter is monotonic and high-resolution, unlike time.time(), so it is
        # the appropriate clock for measuring durations.
        start_time = time.perf_counter()
        _, _, best_error = ps.run()
        run_time = time.perf_counter() - start_time
    except Exception as e:
        print(f"Trial failed: {e}")
        # One value per study direction. NaN makes Optuna mark the trial FAILED.
        # A bare 0.0 has the wrong arity and would read as a perfect (zero) error.
        return float("nan"), float("nan")
    return best_error, run_time

# ----- Run Optimization -----
def run_optimization(n_trials=50, seed=None):
    """Run a two-objective Optuna search over Particle Swarm hyperparameters.

    Args:
        n_trials: Number of Optuna trials to evaluate.
        seed: Optional seed for the TPE sampler. ``None`` (the default) keeps the
            unseeded behavior. Pass an int (e.g. ``base_kwargs["seed"]``) to make the
            sampled hyperparameter sequence reproducible.

    Returns:
        The Optuna study. Both objectives returned by ``objective`` (error, run time)
        are minimized.
    """
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(directions=["minimize", "minimize"], sampler=sampler)
    study.optimize(objective, n_trials=n_trials)
    return study

In [None]:
# Re-running the 200-trial search is slow, so it is disabled by default and the next
# cell loads the saved study instead. Set RERUN_SEARCH = True to regenerate the file.
RERUN_SEARCH = False
if RERUN_SEARCH:
    study = run_optimization(n_trials=200)
    # Save to exactly the path the loading cell reads from.
    study_path = "../data/optimizer_hp/PS_hp_200.pkl"
    os.makedirs(os.path.dirname(study_path), exist_ok=True)
    joblib.dump(study, study_path)

In [5]:
# Load the saved 200-trial Particle Swarm hyperparameter study. joblib.load unpickles
# the file, so only load studies produced by this project.
study = joblib.load(filename="../data/optimizer_hp/PS_hp_200.pkl")

In [6]:
# Hyperparameter importances for the loaded study.
importance_fig = vis.plot_param_importances(study)
importance_fig.show()

In [7]:
# Pareto front of the two objectives returned by `objective`: best error vs. run time.
# Keep the figure object and show it separately. `.show()` returns None, so
# `fig = ....show()` would discard the figure.
# With `targets` omitted, Optuna plots each trial's objective values directly.
fig = vis.plot_pareto_front(
    study,
    target_names=["Error", "Run time"],
)
fig.show()

In [8]:
# Inspect one trial's objective values and hyperparameters.
# NOTE(review): index 172 is hard-coded; presumably it was picked by hand from the
# Pareto front above. Confirm, and keep it in sync if the study is regenerated.
picked_trial = study.trials[172]
print(picked_trial.values)
for param_name, param_value in picked_trial.params.items():
    print(f"  {param_name}: {param_value}")

[0.3782031536102295, 78.66221451759338]
  c1: 0.43255869497860583
  c2: 0.6195857637320343
  w: 0.5260564952340614
  n_particles: 85


In [9]:
import gc

# Force a full garbage-collection pass to release memory held by unreachable objects.
# The displayed value is the number of unreachable objects the collector found.
collected_objects = gc.collect()
collected_objects

302