In [None]:
import time
from functools import partial
from pathlib import Path
from typing import NamedTuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax import block_until_ready, grad, jit, lax, random, vmap

# jax_gd_pso.py

In [None]:
class JaxGdSwarmState(NamedTuple):
    """Carry threaded through the outer ``lax.scan`` of ``jax_gd_pso``.

    Shapes follow from how ``jax_gd_pso`` builds the initial state: the
    per-particle arrays have a leading ``num_particles`` axis, while the
    global-best fields describe a single position and its fitness.
    """

    positions: jnp.ndarray  # current particle positions, (num_particles, num_dims)
    velocities: jnp.ndarray  # current particle velocities, (num_particles, num_dims)
    p_best_pos: jnp.ndarray  # best position seen by each particle, (num_particles, num_dims)
    p_best_fit: jnp.ndarray  # fitness at p_best_pos, (num_particles,)
    g_best_pos: jnp.ndarray  # best position seen by the swarm, (num_dims,)
    g_best_fit: jnp.ndarray  # fitness at g_best_pos (scalar)
    rng: random.PRNGKey  # key that is split afresh on every iteration


class GradientState(NamedTuple):
    """Carry for the inner gradient-descent ``lax.scan`` inside ``jax_gd_pso``."""

    current_pos: jnp.ndarray  # position being refined by gradient steps, (num_dims,)


@partial(
    jit,
    static_argnames=(
        "objective_fn",
        "num_dims",
        "num_particles",
        "max_iters",
        "steps",
    ),
)
def jax_gd_pso(
    objective_fn: callable,
    bounds: tuple,
    num_dims: int,
    num_particles: int,
    max_iters: int,
    c1: float,
    c2: float,
    w: float,
    seed: random.PRNGKey,
    eta: float,
    steps: int,
) -> tuple:
    """Global-best PSO hybridised with periodic gradient descent, fully jitted.

    Standard PSO runs for ``max_iters`` iterations inside ``lax.scan``. On
    every iteration whose index is a multiple of 10, the current global-best
    candidate is refined with ``steps`` gradient-descent steps of size ``eta``.
    After each step it is clipped back into ``bounds``. The refined point
    replaces the global best only if its fitness is lower.

    Args:
        objective_fn: Scalar function of one ``(num_dims,)`` position. It is
            mapped with ``vmap`` and differentiated with ``grad``, so it must
            be JAX-traceable. Static under ``jit``, so it must be hashable.
        bounds: ``(lower, upper)`` box bounds applied to every dimension.
        num_dims: Search-space dimensionality (static).
        num_particles: Swarm size (static).
        max_iters: Number of PSO iterations (static).
        c1: Cognitive (personal-best) acceleration coefficient.
        c2: Social (global-best) acceleration coefficient.
        w: Inertia weight.
        seed: JAX PRNG key used for initialisation and per-iteration noise.
        eta: Gradient-descent step size.
        steps: Gradient-descent steps per refinement (static; scan length).

    Returns:
        ``(g_best_pos, g_best_fit, history)``. ``history`` has ``max_iters + 1``
        entries: the initial global-best fitness, followed by the global-best
        fitness after each iteration.
    """
    key = seed
    lower, upper = jnp.array(bounds[0]), jnp.array(bounds[1])
    k_pos, k_vel, k_state = random.split(key, 3)

    # Initial velocities are drawn from +/- 10% of the search range.
    search_range = upper - lower
    velocity_scale = 0.1
    limit = search_range * velocity_scale

    init_positions = random.uniform(k_pos, (num_particles, num_dims), minval=lower, maxval=upper)
    init_velocities = random.uniform(k_vel, (num_particles, num_dims), minval=-limit, maxval=limit)
    init_fitness = vmap(objective_fn)(init_positions)

    best_idx = jnp.argmin(init_fitness)
    g_best_pos = init_positions[best_idx]
    g_best_fit = init_fitness[best_idx]

    initial_state = JaxGdSwarmState(
        positions=init_positions,
        velocities=init_velocities,
        p_best_pos=init_positions,
        p_best_fit=init_fitness,
        g_best_pos=g_best_pos,
        g_best_fit=g_best_fit,
        rng=k_state,
    )

    gradient_fn = grad(objective_fn)

    def update_step(swarm_state: JaxGdSwarmState, i: int) -> tuple:
        """One PSO iteration (plus optional GD refinement); ``i`` is the iteration index."""
        k1, k2, k_next = random.split(swarm_state.rng, 3)
        r1 = random.uniform(k1, (num_particles, num_dims))
        r2 = random.uniform(k2, (num_particles, num_dims))

        # Classic velocity update: inertia + pull to personal best + pull to global best.
        inertia = w * swarm_state.velocities
        cognitive = c1 * r1 * (swarm_state.p_best_pos - swarm_state.positions)
        social = c2 * r2 * (swarm_state.g_best_pos - swarm_state.positions)

        new_velocities = inertia + cognitive + social
        new_positions = swarm_state.positions + new_velocities
        # Positions are clipped to the box; velocities are not.
        new_positions = jnp.clip(new_positions, lower, upper)

        new_fitness = vmap(objective_fn)(new_positions)

        # Per-particle personal-best update.
        improved = new_fitness < swarm_state.p_best_fit

        new_p_best_pos = jnp.where(improved[:, None], new_positions, swarm_state.p_best_pos)
        new_p_best_fit = jnp.where(improved, new_fitness, swarm_state.p_best_fit)

        # Best personal best in the swarm after this iteration.
        current_g_best_idx = jnp.argmin(new_p_best_fit)
        current_g_best_fit = new_p_best_fit[current_g_best_idx]

        global_improved = current_g_best_fit < swarm_state.g_best_fit

        # Candidate global best before any gradient refinement.
        candidate_g_pos = jnp.where(
            global_improved,
            new_p_best_pos[current_g_best_idx],
            swarm_state.g_best_pos,
        )

        candidate_g_fit = jnp.where(global_improved, current_g_best_fit, swarm_state.g_best_fit)

        def gradient_descent_step(g_state: GradientState, _: None) -> tuple:
            """Single gradient step, then clip back into the bounds."""
            grads = gradient_fn(g_state.current_pos)
            updated_pos = g_state.current_pos - eta * grads
            updated_pos = jnp.clip(updated_pos, lower, upper)
            return GradientState(updated_pos), None

        def apply_gradient(_: None) -> tuple:
            """Run ``steps`` gradient steps from the candidate and re-evaluate it."""
            init_grad_state = GradientState(candidate_g_pos)
            final_grad_state, _ = lax.scan(
                gradient_descent_step,
                init_grad_state,
                None,
                steps,
            )
            final_pos = final_grad_state.current_pos
            final_fit = objective_fn(final_pos)
            return final_pos, final_fit

        def skip_gradient(_: None) -> tuple:
            """Leave the candidate unchanged (so ``gd_improved`` below is False)."""
            return candidate_g_pos, candidate_g_fit

        # Gradient refinement runs only every 10th iteration (hard-coded period).
        gradient_g_pos, gradient_g_fit = lax.cond(
            i % 10 == 0,
            apply_gradient,
            skip_gradient,
            None,
        )

        # Accept the refined point only if it strictly improves on the candidate.
        gd_improved = gradient_g_fit < candidate_g_fit
        final_g_pos = jnp.where(gd_improved, gradient_g_pos, candidate_g_pos)
        final_g_fit = jnp.where(gd_improved, gradient_g_fit, candidate_g_fit)

        # NOTE(review): candidate_g_fit <= swarm_state.g_best_fit by construction,
        # so gd_improved already implies any_improvement; the conjunctions below
        # are redundant but harmless.
        any_improvement = final_g_fit < swarm_state.g_best_fit

        target_idx = current_g_best_idx

        # Write a GD-refined point back into the personal best of the particle
        # that currently holds the lowest personal-best fitness, so PSO can
        # exploit it on later iterations.
        mask_winner = (jnp.arange(num_particles) == target_idx)[:, None]
        should_update_mask = (gd_improved & any_improvement) & mask_winner

        final_p_best_pos = jnp.where(
            should_update_mask,
            final_g_pos,
            new_p_best_pos,
        )

        final_p_best_fit = jnp.where(
            (gd_improved & any_improvement) & (jnp.arange(num_particles) == target_idx),
            final_g_fit,
            new_p_best_fit,
        )

        next_state = JaxGdSwarmState(
            positions=new_positions,
            velocities=new_velocities,
            p_best_pos=final_p_best_pos,
            p_best_fit=final_p_best_fit,
            g_best_pos=final_g_pos,
            g_best_fit=final_g_fit,
            rng=k_next,
        )

        # The per-iteration output collected by scan is the global-best fitness.
        return next_state, final_g_fit

    final_state, history = lax.scan(update_step, initial_state, jnp.arange(max_iters))
    # Prepend the initial global best so history has max_iters + 1 entries.
    full_history = jnp.concatenate([jnp.array([initial_state.g_best_fit]), history])

    return final_state.g_best_pos, final_state.g_best_fit, full_history

# pso.py

In [None]:
class SwarmState(NamedTuple):
    """Per-iteration state of the NumPy reference ``pso`` implementation."""

    positions: np.ndarray  # current particle positions, (num_particles, num_dims)
    velocities: np.ndarray  # current particle velocities, (num_particles, num_dims)
    p_best_pos: np.ndarray  # best position seen by each particle
    p_best_fit: np.ndarray  # fitness at p_best_pos, one value per particle
    g_best_pos: np.ndarray  # best position seen by the whole swarm
    g_best_fit: np.ndarray  # fitness at g_best_pos
    rng: np.random.Generator  # source of the r1/r2 random coefficients
    history: np.ndarray  # global-best fitness trace, filled in place by pso


def pso(
    objective_fn: callable,
    bounds: tuple,
    num_dims: int,
    num_particles: int,
    max_iters: int,
    c1: float,
    c2: float,
    w: float,
    seed: int,
    **_: object,
) -> tuple:
    """Reference global-best particle swarm optimisation in plain NumPy.

    Args:
        objective_fn: Maps one ``(num_dims,)`` position to a scalar fitness
            (lower is better). Evaluated once per particle per iteration.
        bounds: ``(lower, upper)`` box bounds applied to every dimension.
        num_dims: Search-space dimensionality.
        num_particles: Swarm size.
        max_iters: Number of PSO iterations. ``0`` returns the initial best.
        c1: Cognitive (personal-best) acceleration coefficient.
        c2: Social (global-best) acceleration coefficient.
        w: Inertia weight.
        seed: Seed for ``np.random.default_rng``.
        **_: Ignored. This lets a shared hyper-parameter dict that also carries
            JAX-GD-PSO-only keys (e.g. ``eta``, ``steps``) be splatted in.

    Returns:
        ``(g_best_pos, g_best_fit, history)``. ``history`` has
        ``max_iters + 1`` entries: the initial global-best fitness, followed by
        the global-best fitness after each iteration. This is the same layout
        ``jax_gd_pso`` returns, so both convergence curves line up.
    """
    lower, upper = bounds
    rng = np.random.default_rng(seed)

    init_positions = rng.uniform(lower, upper, (num_particles, num_dims))
    # Particles start at rest; the JAX variant samples initial velocities instead.
    init_velocities = np.zeros((num_particles, num_dims))
    init_fitness = np.array([objective_fn(position) for position in init_positions])

    best_idx = np.argmin(init_fitness)
    g_best_pos = init_positions[best_idx]
    g_best_fit = init_fitness[best_idx]

    # Slot 0 keeps the initial best and slot i + 1 receives the best after
    # iteration i, so the initial value is never overwritten. Sizing the array
    # max_iters + 1 also keeps max_iters == 0 valid.
    history = np.zeros(max_iters + 1)
    history[0] = g_best_fit

    swarm_state = SwarmState(
        positions=init_positions,
        velocities=init_velocities,
        p_best_pos=init_positions,
        p_best_fit=init_fitness,
        g_best_pos=g_best_pos,
        g_best_fit=g_best_fit,
        rng=rng,
        history=history,
    )

    for i in range(max_iters):
        r1 = swarm_state.rng.random((num_particles, num_dims))
        r2 = swarm_state.rng.random((num_particles, num_dims))

        # Classic velocity update: inertia + pull to personal best + pull to global best.
        inertia = w * swarm_state.velocities
        cognitive = c1 * r1 * (swarm_state.p_best_pos - swarm_state.positions)
        social = c2 * r2 * (swarm_state.g_best_pos - swarm_state.positions)

        new_velocities = inertia + cognitive + social
        new_positions = swarm_state.positions + new_velocities
        new_positions = np.clip(new_positions, lower, upper)

        new_fitness = np.array([objective_fn(pos) for pos in new_positions])

        # Per-particle personal-best update.
        improved = new_fitness < swarm_state.p_best_fit
        new_p_best_pos = np.where(improved[:, None], new_positions, swarm_state.p_best_pos)
        new_p_best_fit = np.where(improved, new_fitness, swarm_state.p_best_fit)

        # Global best only ever moves to a strictly better personal best.
        current_g_best_idx = np.argmin(new_p_best_fit)
        current_g_best_fit = new_p_best_fit[current_g_best_idx]
        global_improved = current_g_best_fit < swarm_state.g_best_fit
        new_g_best_pos = np.where(
            global_improved,
            new_p_best_pos[current_g_best_idx],
            swarm_state.g_best_pos,
        )
        new_g_best_fit = np.where(
            global_improved,
            current_g_best_fit,
            swarm_state.g_best_fit,
        )

        new_history = swarm_state.history
        new_history[i + 1] = new_g_best_fit

        swarm_state = SwarmState(
            positions=new_positions,
            velocities=new_velocities,
            p_best_pos=new_p_best_pos,
            p_best_fit=new_p_best_fit,
            g_best_pos=new_g_best_pos,
            g_best_fit=new_g_best_fit,
            rng=swarm_state.rng,
            history=new_history,
        )

    return swarm_state.g_best_pos, swarm_state.g_best_fit, swarm_state.history


# plot_benchmarks.py

In [None]:
def _save_figure(fig: plt.Figure, filename: str, config: dict) -> None:
    """Save ``fig`` as a PDF named ``filename`` under ``config["output_path"]``, then close it.

    The output directory is created if it does not exist yet, so the notebook
    runs end to end on a fresh checkout. ``output_path`` may be a ``str`` or a
    ``Path``.
    """
    save_path = Path(config["output_path"]) / filename
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight", format="pdf", dpi=300)
    # Figures are produced in a batch; close each one so they do not accumulate.
    plt.close(fig)


def plot_execution_time(df: pd.DataFrame, config: dict) -> None:
    """Plot mean ± std execution time against dimension, one panel per benchmark.

    Args:
        df: Results frame with ``Benchmark``, ``Algorithm``, ``Dimension``,
            ``Mean of Execution Times (s)`` and
            ``Standard Deviation of Execution Times (s)`` columns.
        config: Needs ``palette`` (seaborn palette name) and ``output_path``.
    """
    benchmarks = df["Benchmark"].unique()
    algorithms = df["Algorithm"].unique()
    colors = sns.color_palette(config["palette"], n_colors=len(algorithms))

    # Two panels per row; as many rows as needed (2x2 for four benchmarks).
    ncols = 2
    nrows = max(1, (len(benchmarks) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 4 * nrows), squeeze=False)
    axes_flattened = axes.flatten()

    for panel_idx, (ax, benchmark) in enumerate(zip(axes_flattened, benchmarks)):
        df_benchmark = df[df["Benchmark"] == benchmark]

        for idx, algorithm in enumerate(algorithms):
            # Use the subset's own (sorted) dimensions as x-values so each point
            # is drawn at the dimension it was measured at, whatever the row order.
            df_subset = df_benchmark[df_benchmark["Algorithm"] == algorithm].sort_values("Dimension")
            dims = df_subset["Dimension"].to_numpy()
            mean = df_subset["Mean of Execution Times (s)"].to_numpy()
            std = df_subset["Standard Deviation of Execution Times (s)"].to_numpy()

            ax.plot(dims, mean, marker="o", label=algorithm, color=colors[idx])
            ax.fill_between(dims, mean - std, mean + std, color=colors[idx], alpha=0.2)

        ax.set_xlabel("Dimension")
        ax.set_ylabel("Time (s)")
        ax.set_title(f"{benchmark}", fontweight="bold")

        if panel_idx == 0:
            ax.legend(title="Algorithm")

    # Hide grid cells left over when the benchmark count does not fill the grid.
    for ax in axes_flattened[len(benchmarks):]:
        ax.set_visible(False)

    fig.tight_layout()
    _save_figure(fig, "execution_time_plot.pdf", config)


def plot_convergence(df: pd.DataFrame, config: dict) -> None:
    """Plot mean ± std convergence curves on a benchmark-by-dimension grid.

    Args:
        df: Results frame with ``Benchmark``, ``Dimension``, ``Algorithm``,
            ``Mean Fitness History`` and ``Std Fitness History`` columns. The
            history columns hold per-iteration sequences.
        config: Needs ``palette`` (seaborn palette name) and ``output_path``.
    """
    benchmarks = df["Benchmark"].unique()
    dimensions = df["Dimension"].unique()
    algorithms = df["Algorithm"].unique()
    palette = sns.color_palette(config["palette"], n_colors=len(algorithms))
    # Key colours by algorithm so an algorithm has the same colour in every panel.
    color_of = dict(zip(algorithms, palette))

    nrows = max(1, len(benchmarks))
    ncols = max(1, len(dimensions))
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)

    for i, benchmark in enumerate(benchmarks):
        for j, dimension in enumerate(dimensions):
            ax = axes[i, j]
            df_subset = df[(df["Dimension"] == dimension) & (df["Benchmark"] == benchmark)]

            for _, row in df_subset.iterrows():
                mean_history = np.asarray(row["Mean Fitness History"], dtype=float)
                std_history = np.asarray(row["Std Fitness History"], dtype=float)
                iterations = np.arange(mean_history.size)
                color = color_of[row["Algorithm"]]

                ax.plot(iterations, mean_history, label=row["Algorithm"], color=color)
                ax.fill_between(
                    iterations,
                    mean_history - std_history,
                    mean_history + std_history,
                    alpha=0.2,
                    color=color,
                )

            ax.set_xlabel("Iteration")
            ax.set_ylabel("Fitness")
            ax.set_title(f"{benchmark} - {dimension}D", fontsize=10)

    # One shared legend suffices because colours are consistent across panels.
    axes[0, 0].legend(title="Algorithm")

    fig.tight_layout()
    _save_figure(fig, "convergence_plot.pdf", config)


def _generate_comparison_table(
    df: pd.DataFrame,
    config: dict,
    mean_col: str,
    std_col: str,
    output_filename: str,
    caption: str,
    label: str,
) -> None:
    df_proc = df.copy()

    df_proc["formatted"] = (
        df_proc[mean_col].map("{:.2e}".format) + r" $\pm$ " + df_proc[std_col].map("{:.2e}".format)
    )

    df_pivot = df_proc.pivot_table(
        index=["Benchmark", "Dimension"],
        columns="Algorithm",
        values=["formatted", mean_col],
        aggfunc="first",
    )

    means = df_pivot[mean_col].fillna(float("inf"))

    display_df = df_pivot["formatted"].copy()

    for index, row in means.iterrows():
        min_val = row.min()
        is_min = row == min_val
        for col in display_df.columns:
            if is_min[col]:
                display_df.loc[index, col] = f"\\textbf{{{display_df.loc[index, col]}}}"

    display_df = display_df.reset_index()

    output_path = config["output_path"] / output_filename

    latex_code = display_df.style.hide(axis="index").to_latex(
        column_format="llcc",
        hrules=True,
        caption=caption,
        label=label,
        position="h",
    )

    with output_path.open("w") as f:
        f.write(latex_code)


def create_convergence_table(df: pd.DataFrame, config: dict) -> None:
    """Write the fitness comparison table to ``convergence_table.tex``."""
    table_options = {
        "mean_col": "Mean of Fitness",
        "std_col": "Standard Deviation of Fitness",
        "output_filename": "convergence_table.tex",
        "caption": r"Convergence comparison (Mean Fitness $\pm$ Std Dev). Best results in bold.",
        "label": "tab:convergence",
    }
    _generate_comparison_table(df, config, **table_options)


def create_execution_time_table(df: pd.DataFrame, config: dict) -> None:
    """Write the execution-time comparison table to ``execution_time_table.tex``."""
    table_options = {
        "mean_col": "Mean of Execution Times (s)",
        "std_col": "Standard Deviation of Execution Times (s)",
        "output_filename": "execution_time_table.tex",
        "caption": "Execution time comparison in seconds.",
        "label": "tab:execution_time",
    }
    _generate_comparison_table(df, config, **table_options)


def generate_visualizations(df: pd.DataFrame, config: dict) -> None:
    """Produce both figures and both LaTeX tables for the results in ``df``.

    Steps run in a fixed order, and each is announced on stdout before it starts.
    """
    pipeline = (
        ("Plotting convergence...", plot_convergence),
        ("Plotting execution time...", plot_execution_time),
        ("Creating convergence table (LaTeX)...", create_convergence_table),
        ("Creating execution time table (LaTeX)...", create_execution_time_table),
    )
    for message, step in pipeline:
        print(message)
        step(df, config)


# Settings shared by the plotting and LaTeX-table helpers:
# output_path is where figures/tables are written; palette is a seaborn palette name.
config = dict(
    output_path=Path("./results/"),
    palette="viridis",
)

# benchmarks.py

In [None]:
def ackley_np(x: np.ndarray) -> float:
    """Ackley function for a 1-D NumPy vector; evaluates to 0 at the origin."""
    dim = x.shape[0]
    mean_square = np.sum(x**2) / dim
    mean_cos = np.sum(np.cos(2 * np.pi * x)) / dim
    return -20 * np.exp(-0.2 * np.sqrt(mean_square)) - np.exp(mean_cos) + 20 + np.e


def rastrigin_np(x: np.ndarray) -> float:
    """Rastrigin function for a 1-D NumPy vector; evaluates to 0 at the origin."""
    oscillation = x**2 - 10 * np.cos(2 * np.pi * x)
    return 10 * x.shape[0] + np.sum(oscillation)


def sphere_np(x: np.ndarray) -> float:
    """Sphere function: the sum of squared coordinates of a NumPy array."""
    squared = x**2
    return squared.sum()


def rosenbrock_np(x: np.ndarray) -> float:
    """Rosenbrock function for a 1-D NumPy vector; evaluates to 0 at (1, ..., 1)."""
    head, tail = x[:-1], x[1:]
    valley = 100 * (tail - head**2) ** 2
    return np.sum(valley + (1 - head) ** 2)


@jit
def ackley_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jitted Ackley function for a 1-D JAX vector; evaluates to ~0 at the origin."""
    dim = x.shape[0]
    mean_square = jnp.sum(x**2) / dim
    mean_cos = jnp.sum(jnp.cos(2 * jnp.pi * x)) / dim
    return -20 * jnp.exp(-0.2 * jnp.sqrt(mean_square)) - jnp.exp(mean_cos) + 20 + jnp.e


@jit
def rastrigin_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jitted Rastrigin function for a 1-D JAX vector; evaluates to 0 at the origin."""
    dim = x.shape[0]
    oscillation = x**2 - 10 * jnp.cos(2 * jnp.pi * x)
    return 10 * dim + jnp.sum(oscillation)


@jit
def sphere_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jitted sphere function: the sum of squared coordinates."""
    squared = x**2
    return jnp.sum(squared)


@jit
def rosenbrock_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jitted Rosenbrock function for a 1-D JAX vector; evaluates to 0 at (1, ..., 1)."""
    head, tail = x[:-1], x[1:]
    valley = 100 * (tail - head**2) ** 2
    return jnp.sum(valley + (1 - head) ** 2)


# Benchmark name -> search bounds plus the objective implementation used by
# each algorithm (NumPy for PSO, jitted JAX for JAX-GD-PSO).
BENCHMARKS = {
    name: {"bounds": bounds, "PSO": numpy_fn, "JAX-GD-PSO": jax_fn}
    for name, bounds, numpy_fn, jax_fn in (
        ("Ackley", (-32.768, 32.768), ackley_np, ackley_jax),
        ("Rastrigin", (-5.12, 5.12), rastrigin_np, rastrigin_jax),
        ("Rosenbrock", (-5.0, 10.0), rosenbrock_np, rosenbrock_jax),
        ("Sphere", (-5.12, 5.12), sphere_np, sphere_jax),
    )
}

# Algorithm name -> optimiser; names must match the per-benchmark keys above.
ALGORITHMS = {
    "PSO": pso,
    "JAX-GD-PSO": jax_gd_pso,
}

DIMS = [30, 100, 500, 1000]

# Shared keyword arguments for both optimisers. run_experiment fills in
# num_dims and seed per run; pso ignores the GD-only keys (eta, steps).
HYPERPARAMETERS = dict(
    num_dims=None,
    num_particles=30,
    max_iters=1000,
    c1=1.5,
    c2=1.5,
    w=0.7,
    seed=None,
    eta=0.01,
    steps=5,
)

NUM_RUNS = 50

# main.py

In [None]:
def run_experiment() -> list[dict]:
    """Run every (dimension, algorithm, benchmark) configuration ``NUM_RUNS`` times.

    For each configuration, the per-run global-best histories, final fitness
    values and wall-clock times are collected and summarised into one record.
    "Mean of Fitness" and "Standard Deviation of Fitness" summarise the FINAL
    global-best fitness across runs. The per-iteration curves are kept
    separately in "Mean Fitness History" and "Std Fitness History".

    Returns:
        One dict per configuration, suitable for ``pd.DataFrame(results)``.
    """
    results = []

    total_configs = len(DIMS) * len(ALGORITHMS) * len(BENCHMARKS)
    current_config = 0

    for dim in DIMS:
        print(f"Dimension: {dim}")
        for algorithm_name, algorithm_fn in ALGORITHMS.items():
            is_jax = algorithm_name == "JAX-GD-PSO"
            for benchmark_name, benchmark_config in BENCHMARKS.items():
                current_config += 1

                objective_fn = benchmark_config[algorithm_name]
                bounds = benchmark_config["bounds"]
                hyperparameters = HYPERPARAMETERS.copy()
                hyperparameters["num_dims"] = dim

                print(
                    f"[{current_config}/{total_configs}] Running {algorithm_name} "
                    f"on {benchmark_name}",
                )

                if is_jax:
                    # Compile once per configuration, outside the timed region.
                    # The jit cache is keyed on static arguments and argument
                    # shapes/dtypes, so any key works. Blocking guarantees the
                    # warm-up has finished executing before timing starts
                    # (JAX dispatches asynchronously).
                    warmup_params = {**hyperparameters, "seed": random.PRNGKey(0)}
                    block_until_ready(algorithm_fn(objective_fn, bounds, **warmup_params))

                execution_times = []
                fitness_history = []
                final_fitness = []
                for i in range(NUM_RUNS):
                    hyperparameters["seed"] = random.PRNGKey(i) if is_jax else i

                    start = time.perf_counter()
                    result = algorithm_fn(objective_fn, bounds, **hyperparameters)
                    if is_jax:
                        # Wait for the device computation, not just its dispatch.
                        block_until_ready(result)
                    end = time.perf_counter()
                    execution_times.append(end - start)

                    _, fitness, history = result

                    print(f"Iteration {i + 1} | Fitness: {fitness}")
                    fitness_history.append(np.asarray(history, dtype=float))
                    final_fitness.append(float(fitness))

                # Statistics in float64 on the host. Every run of a configuration
                # yields a history of the same length, so the runs stack into a 2-D array.
                histories = np.stack(fitness_history)
                times = np.asarray(execution_times)
                finals = np.asarray(final_fitness)

                results.append(
                    {
                        "Dimension": dim,
                        "Benchmark": benchmark_name,
                        "Algorithm": algorithm_name,
                        "Execution Time History": execution_times,
                        "Mean of Execution Times (s)": float(times.mean()),
                        "Standard Deviation of Execution Times (s)": float(times.std()),
                        "Mean Fitness History": histories.mean(axis=0).tolist(),
                        "Std Fitness History": histories.std(axis=0).tolist(),
                        "Mean of Fitness": float(finals.mean()),
                        "Standard Deviation of Fitness": float(finals.std()),
                    },
                )

    return results

In [None]:
# Run the full benchmark grid. This is slow: every configuration is repeated
# NUM_RUNS times. Collect one summary record per configuration into a DataFrame.
results = run_experiment()
df = pd.DataFrame(results)

In [None]:
# Persist the raw results, then write the figures and LaTeX tables to config["output_path"].
df.to_csv("./experiment_results.csv")
generate_visualizations(df, config)