In [None]:
import math
import time
from functools import partial
from pathlib import Path
from typing import NamedTuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax import block_until_ready, grad, jit, lax, random, vmap

# jax_pso.py

In [None]:
class JaxSwarmState(NamedTuple):
    """Loop carry threaded through the outer ``lax.scan`` of ``jax_gd_pso``."""

    # Current particle positions; built as (num_particles, num_dims) in jax_gd_pso.
    positions: jnp.ndarray
    # Current particle velocities, same shape as ``positions``.
    velocities: jnp.ndarray
    # Best position each particle has visited so far.
    p_best_pos: jnp.ndarray
    # Objective value at each particle's ``p_best_pos`` (one value per particle).
    p_best_fit: jnp.ndarray
    # Best position found by the whole swarm (possibly refined by gradient descent).
    g_best_pos: jnp.ndarray
    # Objective value at ``g_best_pos``.
    g_best_fit: jnp.ndarray
    # PRNG key consumed and re-split on every iteration (holds a key array;
    # the annotation names the key-constructing function).
    rng: random.PRNGKey

class GradientState(NamedTuple):
    """Carry for the inner ``lax.scan`` of gradient-descent steps on the global best."""

    # Position being refined by projected gradient descent.
    new_g_best_pos: jnp.ndarray
    # Box bounds each step is clipped to; carried unchanged through the scan.
    # jax_gd_pso passes the Python floats from ``bounds`` here.
    lower: jnp.ndarray
    upper: jnp.ndarray
    # Gradient-descent step size, carried unchanged through the scan.
    eta: float

@partial(
    jit,
    static_argnames=(
        "objective_fn",
        "bounds",
        "num_dims",
        "num_particles",
        "max_iters",
        "c1",
        "c2",
        "w",
        "eta",
        "steps",
        "grad_every",
    ),
)
def jax_gd_pso(
    objective_fn: callable,
    bounds: tuple,
    num_dims: int,
    num_particles: int,
    max_iters: int,
    c1: float,
    c2: float,
    w: float,
    seed: random.PRNGKey,
    eta: float,
    steps: int,
    grad_every: int = 10,
) -> tuple:
    """Global-best PSO with periodic gradient-descent refinement, in JAX.

    Every iteration performs a standard PSO velocity/position update (positions
    clipped to ``bounds``). On iterations where ``i % grad_every == 0``
    (0-based ``i``), the candidate global best is additionally refined by
    ``steps`` projected gradient-descent steps of size ``eta``. The refined
    point is kept only if it strictly improves the objective.

    Args:
        objective_fn: Scalar objective of a single position vector. It must be
            differentiable with ``jax.grad`` and hashable, because it is a
            static argument.
        bounds: ``(lower, upper)`` scalar box bounds applied to every dimension.
        num_dims: Dimensionality of the search space.
        num_particles: Swarm size.
        max_iters: Number of PSO iterations.
        c1: Cognitive (personal-best) acceleration coefficient.
        c2: Social (global-best) acceleration coefficient.
        w: Inertia weight.
        seed: JAX PRNG key. It is traced, not static, so different seeds reuse
            the same compiled function.
        eta: Gradient-descent step size.
        steps: Number of gradient-descent steps per refinement.
        grad_every: Refinement interval in iterations. Defaults to 10, the
            interval that was previously hard-coded.

    Returns:
        ``(g_best_pos, g_best_fit, history)``. ``history`` has length
        ``max_iters + 1``: the initial global-best fitness followed by the
        global-best fitness after each iteration.

    Raises:
        ValueError: If ``grad_every < 1``. This is raised while tracing.
    """
    if grad_every < 1:
        raise ValueError(f"grad_every must be >= 1, got {grad_every}")

    key = seed
    lower, upper = bounds
    k_pos, k_vel, k_state = random.split(key, 3)

    # Initial velocities are drawn from +/- 10% of the search range.
    search_range = upper - lower
    velocity_scale = 0.1
    limit = search_range * velocity_scale

    init_positions = random.uniform(
        k_pos, (num_particles, num_dims), minval=lower, maxval=upper,
    )
    init_velocities = random.uniform(
        k_vel, (num_particles, num_dims), minval=-limit, maxval=limit,
    )
    init_fitness = vmap(objective_fn)(init_positions)

    best_idx = jnp.argmin(init_fitness)
    g_best_pos = init_positions[best_idx]
    g_best_fit = init_fitness[best_idx]

    initial_state = JaxSwarmState(
        positions=init_positions,
        velocities=init_velocities,
        p_best_pos=init_positions,
        p_best_fit=init_fitness,
        g_best_pos=g_best_pos,
        g_best_fit=g_best_fit,
        rng=k_state,
    )

    gradient_fn = grad(objective_fn)

    def update_step(swarm_state: JaxSwarmState, i: int) -> tuple:
        # One PSO iteration; ``i`` is the 0-based iteration index from the scan.
        k1, k2, k_next = random.split(swarm_state.rng, 3)
        r1 = random.uniform(k1, (num_particles, num_dims))
        r2 = random.uniform(k2, (num_particles, num_dims))

        inertia = w * swarm_state.velocities
        cognitive = c1 * r1 * (swarm_state.p_best_pos - swarm_state.positions)
        social = c2 * r2 * (swarm_state.g_best_pos - swarm_state.positions)

        new_velocities = inertia + cognitive + social
        new_positions = swarm_state.positions + new_velocities
        new_positions = jnp.clip(new_positions, lower, upper)

        new_fitness = vmap(objective_fn)(new_positions)

        improved = new_fitness < swarm_state.p_best_fit
        mask = improved[:, None]
        new_p_best_pos = jnp.where(mask, new_positions, swarm_state.p_best_pos)
        new_p_best_fit = jnp.where(improved, new_fitness, swarm_state.p_best_fit)

        current_g_best_idx = jnp.argmin(new_p_best_fit)
        current_g_best_fit = new_p_best_fit[current_g_best_idx]

        global_improved = current_g_best_fit < swarm_state.g_best_fit
        candidate_g_pos = jnp.where(
            global_improved, new_p_best_pos[current_g_best_idx], swarm_state.g_best_pos,
        )
        candidate_g_fit = jnp.where(
            global_improved, current_g_best_fit, swarm_state.g_best_fit,
        )

        def gradient_descent_step(g_state: GradientState, _: None) -> tuple:
            # One projected gradient step: descend, then clip back into the box.
            grads = gradient_fn(g_state.new_g_best_pos)
            step = g_state.eta * grads
            updated_pos = g_state.new_g_best_pos - step
            updated_pos = jnp.clip(updated_pos, g_state.lower, g_state.upper)
            next_state = GradientState(
                new_g_best_pos=updated_pos,
                lower=g_state.lower,
                upper=g_state.upper,
                eta=g_state.eta,
            )
            return next_state, None

        def apply_gradient(_: None) -> tuple:
            init_grad_state = GradientState(candidate_g_pos, lower, upper, eta)
            final_grad_state, _ = lax.scan(
                gradient_descent_step,
                init_grad_state,
                None,
                steps,
            )
            final_pos = final_grad_state.new_g_best_pos
            return final_pos, objective_fn(final_pos)

        def skip_gradient(_: None) -> tuple:
            return candidate_g_pos, candidate_g_fit

        gradient_g_pos, gradient_g_fit = lax.cond(
            i % grad_every == 0,
            apply_gradient,
            skip_gradient,
            None,
        )

        # Keep the refined point only if it strictly beats the PSO candidate
        # (this also rejects NaN fitness values).
        global_improved = gradient_g_fit < candidate_g_fit
        final_g_pos = jnp.where(global_improved, gradient_g_pos, candidate_g_pos)
        final_g_fit = jnp.where(global_improved, gradient_g_fit, candidate_g_fit)

        next_state = JaxSwarmState(
            positions=new_positions,
            velocities=new_velocities,
            p_best_pos=new_p_best_pos,
            p_best_fit=new_p_best_fit,
            g_best_pos=final_g_pos,
            g_best_fit=final_g_fit,
            rng=k_next,
        )

        return next_state, final_g_fit

    final_state, history = lax.scan(update_step, initial_state, jnp.arange(max_iters))
    # Prepend the initial global best so history[k] is the best after k iterations.
    full_history = jnp.concatenate([jnp.array([initial_state.g_best_fit]), history])

    return final_state.g_best_pos, final_state.g_best_fit, full_history

# numpy_pso.py

In [None]:
class SwarmState(NamedTuple):
    """Per-iteration swarm snapshot rebuilt on every loop pass of ``numpy_pso``."""

    # Current particle positions; built as (num_particles, num_dims) in numpy_pso.
    positions: np.ndarray
    # Current particle velocities, same shape as ``positions``.
    velocities: np.ndarray
    # Best position each particle has visited so far.
    p_best_pos: np.ndarray
    # Objective value at each particle's ``p_best_pos`` (one value per particle).
    p_best_fit: np.ndarray
    # Best position found by the whole swarm.
    g_best_pos: np.ndarray
    # Objective value at ``g_best_pos``.
    g_best_fit: np.ndarray
    # Single Generator reused across iterations (numpy_pso never re-creates it).
    rng: np.random.Generator
    # Global-best fitness history buffer; numpy_pso writes into it in place.
    history: np.ndarray

def numpy_pso(
    objective_fn: callable,
    bounds: tuple,
    num_dims: int,
    num_particles: int,
    max_iters: int,
    c1: float,
    c2: float,
    w: float,
    seed: int,
    **_: any,
) -> tuple:
    """Standard global-best particle swarm optimisation implemented with NumPy.

    Args:
        objective_fn: Scalar objective evaluated on one position vector at a time.
        bounds: ``(lower, upper)`` scalar box bounds applied to every dimension.
            Positions are clipped to them after each move.
        num_dims: Dimensionality of the search space.
        num_particles: Swarm size.
        max_iters: Number of PSO iterations.
        c1: Cognitive (personal-best) acceleration coefficient.
        c2: Social (global-best) acceleration coefficient.
        w: Inertia weight.
        seed: Seed for ``np.random.default_rng``.
        **_: Ignored. This lets callers pass JAX-only hyperparameters
            (e.g. ``eta``, ``steps``) through one shared dict.

    Returns:
        ``(g_best_pos, g_best_fit, history)``. ``history`` has length
        ``max_iters + 1``: entry 0 is the best initial fitness, and entry
        ``i + 1`` is the global-best fitness after iteration ``i``. This is the
        same layout as the history returned by ``jax_gd_pso``.
    """
    lower, upper = bounds
    rng = np.random.default_rng(seed)

    init_positions = rng.uniform(lower, upper, (num_particles, num_dims))
    # NOTE: velocities start at zero here, whereas jax_gd_pso draws them
    # uniformly from +/- 10% of the search range.
    init_velocities = np.zeros((num_particles, num_dims))
    init_fitness = np.array([objective_fn(position) for position in init_positions])

    best_idx = np.argmin(init_fitness)
    g_best_pos = init_positions[best_idx]
    g_best_fit = init_fitness[best_idx]

    # One extra slot so the initial best is not overwritten by iteration 0.
    # This also makes max_iters == 0 valid.
    history = np.zeros(max_iters + 1)
    history[0] = g_best_fit

    swarm_state = SwarmState(
        positions=init_positions,
        velocities=init_velocities,
        p_best_pos=init_positions,
        p_best_fit=init_fitness,
        g_best_pos=g_best_pos,
        g_best_fit=g_best_fit,
        rng=rng,
        history=history,
    )

    for i in range(max_iters):
        r1 = swarm_state.rng.random((num_particles, num_dims))
        r2 = swarm_state.rng.random((num_particles, num_dims))

        inertia = w * swarm_state.velocities
        cognitive = c1 * r1 * (swarm_state.p_best_pos - swarm_state.positions)
        social = c2 * r2 * (swarm_state.g_best_pos - swarm_state.positions)

        new_velocities = inertia + cognitive + social
        new_positions = swarm_state.positions + new_velocities
        new_positions = np.clip(new_positions, lower, upper)

        new_fitness = np.array([objective_fn(pos) for pos in new_positions])

        improved = new_fitness < swarm_state.p_best_fit
        mask = improved[:, None]
        new_p_best_pos = np.where(mask, new_positions, swarm_state.p_best_pos)
        new_p_best_fit = np.where(improved, new_fitness, swarm_state.p_best_fit)

        current_g_best_idx = np.argmin(new_p_best_fit)
        current_g_best_fit = new_p_best_fit[current_g_best_idx]
        global_improved = current_g_best_fit < swarm_state.g_best_fit
        new_g_best_pos = np.where(
            global_improved, new_p_best_pos[current_g_best_idx], swarm_state.g_best_pos,
        )
        new_g_best_fit = np.where(
            global_improved, current_g_best_fit, swarm_state.g_best_fit,
        )

        # The buffer is shared with swarm_state and written in place;
        # slot i + 1 holds the best after iteration i.
        new_history = swarm_state.history
        new_history[i + 1] = new_g_best_fit

        swarm_state = SwarmState(
            positions=new_positions,
            velocities=new_velocities,
            p_best_pos=new_p_best_pos,
            p_best_fit=new_p_best_fit,
            g_best_pos=new_g_best_pos,
            g_best_fit=new_g_best_fit,
            rng=swarm_state.rng,
            history=new_history,
        )

    return swarm_state.g_best_pos, swarm_state.g_best_fit, swarm_state.history

# plot_benchmarks.py

In [None]:
def _save_figure(fig: plt.Figure, filename: str, config: dict) -> None:
    """Save ``fig`` to ``config["output_path"] / filename`` and close it.

    The output directory is created if it does not exist, because ``savefig``
    will not create it. The figure is closed even if saving fails, so plotting
    loops do not accumulate open figures.

    Args:
        fig: Figure to save.
        filename: File name, relative to the output directory. Its extension
            selects the format.
        config: Must contain ``"output_path"`` (``str`` or ``Path``).
    """
    save_path = Path(config["output_path"]) / filename
    try:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    finally:
        plt.close(fig)

def plot_execution_time_dims(df: pd.DataFrame, config: dict) -> None:
    """Plot mean execution time against dimension, one panel per benchmark.

    Each panel draws one line per algorithm (the mean) with a shaded band of
    +/- one standard deviation, on a log-scaled y-axis. The grid is sized to
    the number of benchmarks. Each algorithm keeps one colour across all
    panels, its bands and the shared figure legend. The figure is saved as
    ``execution_time_dims.pdf``.

    Args:
        df: Results table with ``Benchmark``, ``Algorithm``, ``Dimension``,
            ``Mean of Execution Times (s)`` and
            ``Standard Deviation of Execution Times (s)`` columns.
        config: Needs ``"palette"`` (a seaborn palette name) and
            ``"output_path"`` (used by ``_save_figure``).
    """
    benchmarks = df["Benchmark"].unique()
    if len(benchmarks) == 0:
        return

    algorithms = df["Algorithm"].unique()
    # One explicit level -> colour map for lines, bands and legend. Previously
    # the lines used config["palette"] but the bands used the default palette,
    # so their colours did not match.
    colors = sns.color_palette(config["palette"], n_colors=len(algorithms))
    palette = dict(zip(algorithms, colors, strict=True))

    # Size the grid to the data. A fixed 4x4 grid zipped strict=True against
    # the benchmarks raised ValueError unless there were exactly 16 of them.
    ncols = math.ceil(math.sqrt(len(benchmarks)))
    nrows = math.ceil(len(benchmarks) / ncols)
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(7, 7),
        constrained_layout=True,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes_flat = axes.flatten()

    for ax, benchmark in zip(axes_flat, benchmarks):
        df_benchmark = df[df["Benchmark"] == benchmark].sort_values("Dimension")

        sns.lineplot(
            data=df_benchmark,
            x="Dimension",
            y="Mean of Execution Times (s)",
            hue="Algorithm",
            marker="o",
            palette=palette,
            ax=ax,
            legend=False,
        )

        for algorithm, data_algorithm in df_benchmark.groupby("Algorithm", sort=False):
            mean = data_algorithm["Mean of Execution Times (s)"]
            std = data_algorithm["Standard Deviation of Execution Times (s)"]

            ax.fill_between(
                data_algorithm["Dimension"],
                mean - std,
                mean + std,
                color=palette[algorithm],
                alpha=0.2,
            )

        ax.set_title(f"{benchmark}", fontweight="bold")
        ax.set(xlabel="", ylabel="", yscale="log")

    # Hide grid cells left over when the benchmark count does not fill the grid.
    for ax in axes_flat[len(benchmarks):]:
        ax.set_visible(False)

    fig.supxlabel("Dimension", fontsize=8)
    fig.supylabel("Mean of Execution Times (s)", fontsize=8)

    # Build legend handles from the shared palette rather than scraping them
    # from per-axis legends. This covers every algorithm, not only those in
    # the first panel.
    handles = [
        plt.Line2D([], [], color=color, marker="o", label=algorithm)
        for algorithm, color in palette.items()
    ]
    fig.legend(
        handles=handles,
        loc="upper center",
        bbox_to_anchor=(0.5, 1.05),
        ncol=3,
        frameon=False,
    )

    _save_figure(fig, "execution_time_dims.pdf", config)

def plot_convergence(df: pd.DataFrame, config: dict) -> None:
    """Save one fitness-history figure per (benchmark, dimension) pair.

    The per-iteration ``"Convergence History"`` column is used when the frame
    has one; the x-axis is then "Iteration". Otherwise the function falls back
    to ``"Fitness History"``, which ``run_experiment`` fills with one final
    fitness per run, and labels the x-axis "Run" so the plot is not mislabelled.
    Benchmark/dimension pairs with no rows are skipped. Files are saved as
    ``convergence_{benchmark}_{dimension}d.pdf``.

    Args:
        df: Results table with ``Benchmark``, ``Dimension``, ``Algorithm`` and
            a history column holding a sequence of numbers per row.
        config: Passed through to ``_save_figure`` (needs ``"output_path"``).
    """
    if "Convergence History" in df.columns:
        history_column, xlabel = "Convergence History", "Iteration"
    else:
        history_column, xlabel = "Fitness History", "Run"

    benchmarks = df["Benchmark"].unique()
    dimensions = df["Dimension"].unique()

    for benchmark in benchmarks:
        for dimension in dimensions:
            df_subset = df[
                (df["Dimension"] == dimension) & (df["Benchmark"] == benchmark)
            ]
            if df_subset.empty:
                continue

            fig, ax = plt.subplots()

            for _, row in df_subset.iterrows():
                history = np.asarray(row[history_column], dtype=float)
                ax.plot(np.arange(len(history)), history, label=row["Algorithm"])

            ax.set(xlabel=xlabel, ylabel="Fitness")
            ax.set_title(
                f"Convergence History: {benchmark} ({dimension}D)",
                fontweight="bold",
            )
            ax.legend(title="Algorithms")
            _save_figure(fig, f"convergence_{benchmark}_{dimension}d.pdf", config)

def generate_summary_tables(
    df: pd.DataFrame,
    config: dict,
    value_column: str = "Mean of Execution Times (s)",
    file_prefix: str = "execution_time_table",
) -> None:
    """Write one Benchmark x Algorithm pivot CSV per dimension.

    Each file is ``{file_prefix}_{dim}d.csv`` in ``config["output_path"]``.
    The directory is created if missing. Its rows are benchmarks, its columns
    are algorithms and its cells are ``value_column``, aggregated by
    ``pivot_table``'s default (mean).

    Args:
        df: Results table with ``Dimension``, ``Benchmark``, ``Algorithm`` and
            ``value_column`` columns.
        config: Must contain ``"output_path"`` (``str`` or ``Path``).
        value_column: Column to tabulate. The default keeps the original
            execution-time tables.
        file_prefix: File-name prefix for the CSVs.
    """
    output_path = Path(config["output_path"])
    output_path.mkdir(parents=True, exist_ok=True)

    for dim, df_dim in df.groupby("Dimension"):
        pivot_table = df_dim.pivot_table(
            index="Benchmark",
            columns="Algorithm",
            values=value_column,
        )
        pivot_table.to_csv(output_path / f"{file_prefix}_{dim}d.csv")

def generate_visualizations(df: pd.DataFrame) -> None:
    """Render every figure and summary table for the experiment results.

    Outputs go to ``./results/``, which is created if missing. All plotting
    runs inside ``plt.rc_context`` with the configured font sizes. Previously
    that font dict was defined but never applied.

    Args:
        df: Results table produced from ``run_experiment``'s records.
    """
    config = {
        "output_path": Path("./results/"),
        "palette": "viridis",
        "font": {
            "font.size": 7,
            "axes.titlesize": 9,
            "legend.fontsize": 8,
            "axes.labelsize": 10,
            "xtick.labelsize": 7.5,
            "ytick.labelsize": 7.5,
        },
    }
    config["output_path"].mkdir(parents=True, exist_ok=True)

    with plt.rc_context(config["font"]):
        print("Plotting execution time by dimensions...")
        plot_execution_time_dims(df, config)

        print("Generating summary tables...")
        generate_summary_tables(df, config)

        print("Plotting convergence histories...")
        plot_convergence(df, config)

# benchmarks.py

In [None]:
def ackley_py(x: list) -> float:
    """Ackley benchmark in pure Python; evaluates to ~0 at the origin.

    Raises ZeroDivisionError for an empty input.
    """
    dims = len(x)
    mean_square = sum([xi**2 for xi in x]) / dims
    mean_cosine = sum([math.cos(2 * math.pi * xi) for xi in x]) / dims
    envelope = -20 * math.exp(-0.2 * math.sqrt(mean_square))
    return envelope - math.exp(mean_cosine) + 20 + math.e

def rastrigin_py(x: list) -> float:
    """Rastrigin benchmark in pure Python; 0 at the origin."""
    ripple = sum([xi**2 - 10 * math.cos(2 * math.pi * xi) for xi in x])
    return 10 * len(x) + ripple

def sphere_py(x: list) -> float:
    """Sphere benchmark in pure Python: the sum of squared coordinates."""
    squares = [value**2 for value in x]
    return sum(squares)

def rosenbrock_py(x: list) -> float:
    """Rosenbrock benchmark in pure Python; 0 at the all-ones point."""
    return sum(
        100 * (following - current**2)**2 + (current - 1)**2
        for current, following in zip(x, x[1:])
    )

def ackley_np(x: np.ndarray) -> float:
    """Ackley benchmark for a 1-D NumPy vector; evaluates to ~0 at the origin."""
    dims = x.shape[0]
    envelope = np.exp(-0.2 * np.sqrt(np.sum(x**2) / dims))
    oscillation = np.exp(np.sum(np.cos(2 * np.pi * x)) / dims)
    return -20 * envelope - oscillation + 20 + np.e

def rastrigin_np(x: np.ndarray) -> float:
    """Rastrigin benchmark for a 1-D NumPy vector; 0 at the origin."""
    ripple = x**2 - 10 * np.cos(2 * np.pi * x)
    return 10 * x.shape[0] + np.sum(ripple)

def sphere_np(x: np.ndarray) -> float:
    """Sphere benchmark for a NumPy vector: the sum of squared coordinates."""
    squared = x**2
    return np.sum(squared)

def rosenbrock_np(x: np.ndarray) -> float:
    """Rosenbrock benchmark for a 1-D NumPy vector; 0 at the all-ones point."""
    head, tail = x[:-1], x[1:]
    return np.sum(100 * (tail - head**2)**2 + (1 - head)**2)

@jit
def ackley_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jit-compiled Ackley benchmark for a 1-D JAX vector; ~0 at the origin."""
    dims = x.shape[0]
    envelope = jnp.exp(-0.2 * jnp.sqrt(jnp.sum(x**2) / dims))
    oscillation = jnp.exp(jnp.sum(jnp.cos(2 * jnp.pi * x)) / dims)
    return -20 * envelope - oscillation + 20 + jnp.e

@jit
def rastrigin_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jit-compiled Rastrigin benchmark for a 1-D JAX vector; 0 at the origin."""
    ripple = x**2 - 10 * jnp.cos(2 * jnp.pi * x)
    return 10 * x.shape[0] + jnp.sum(ripple)

@jit
def sphere_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jit-compiled sphere benchmark: the sum of squared coordinates."""
    squared = x**2
    return jnp.sum(squared)

@jit
def rosenbrock_jax(x: jnp.ndarray) -> jnp.ndarray:
    """Jit-compiled Rosenbrock benchmark for a 1-D JAX vector; 0 at all-ones."""
    head, tail = x[:-1], x[1:]
    return jnp.sum(100 * (tail - head**2)**2 + (1 - head)**2)

# Benchmark table: name -> search bounds plus one objective per implementation.
# Each row is (name, bounds, pure-Python fn, NumPy fn, JAX fn); the dict keeps
# row order, which fixes the order the experiment visits benchmarks.
BENCHMARKS = {
    name: {
        "bounds": search_bounds,
        "Python PSO": python_fn,
        "NumPy PSO": numpy_fn,
        "JAX PSO": jax_fn,
    }
    for name, search_bounds, python_fn, numpy_fn, jax_fn in (
        ("Ackley", (-32.768, 32.768), ackley_py, ackley_np, ackley_jax),
        ("Rastrigin", (-5.12, 5.12), rastrigin_py, rastrigin_np, rastrigin_jax),
        ("Rosenbrock", (-5.0, 10.0), rosenbrock_py, rosenbrock_np, rosenbrock_jax),
        ("Sphere", (-5.12, 5.12), sphere_py, sphere_np, sphere_jax),
    )
}

# Algorithms actually benchmarked; keys select the objective column above.
ALGORITHMS = dict(
    [
        ("NumPy PSO", numpy_pso),
        ("JAX PSO", jax_gd_pso),
    ],
)

# Problem dimensionalities swept by the experiment.
DIMS = [10, 30, 50, 100]

# Shared keyword arguments; num_dims and seed are filled in per run.
# eta and steps are only consumed by the JAX variant (NumPy swallows them).
HYPERPARAMETERS = dict(
    num_dims=None,
    num_particles=30,
    max_iters=100,
    c1=1.5,
    c2=1.5,
    w=0.7,
    seed=None,
    eta=0.001,
    steps=5,
)

# Independent repetitions per (dimension, algorithm, benchmark) configuration.
NUM_RUNS = 10

# main.py

In [None]:
def run_experiment() -> list[dict]:
    """Benchmark every (dimension, algorithm, benchmark) combination.

    Each configuration is run ``NUM_RUNS`` times with seeds ``0 .. NUM_RUNS - 1``
    (wrapped in ``random.PRNGKey`` for JAX).

    JAX configurations first get one untimed warm-up call, so JIT compilation
    is not counted in the first run's execution time. Each timed JAX call is
    then synchronised with ``block_until_ready``, because JAX dispatches work
    asynchronously.

    Returns:
        One dict per configuration. It holds the per-run execution times, the
        final fitness of each run ("Fitness History"), their means and
        population standard deviations, and "Convergence History": the
        per-iteration global-best fitness averaged over runs.
    """
    results = []

    total_configs = len(DIMS) * len(ALGORITHMS) * len(BENCHMARKS)
    current_config = 0

    for dim in DIMS:
        print(f"Dimension: {dim}")
        for algorithm_name, algorithm_fn in ALGORITHMS.items():
            is_jax = algorithm_name == "JAX PSO"
            for benchmark_name, benchmark_config in BENCHMARKS.items():
                current_config += 1

                objective_fn = benchmark_config[algorithm_name]
                bounds = benchmark_config["bounds"]
                hyperparameters = HYPERPARAMETERS.copy()
                hyperparameters["num_dims"] = dim

                print(
                    f"[{current_config}/{total_configs}] Running {algorithm_name} "
                    f"on {benchmark_name}",
                )

                if is_jax:
                    # Compile outside the timed region. The seed is traced, not
                    # static, so the timed runs below reuse this compilation.
                    warmup_hyperparameters = {**hyperparameters, "seed": random.PRNGKey(0)}
                    block_until_ready(
                        algorithm_fn(objective_fn, bounds, **warmup_hyperparameters),
                    )

                execution_times = []
                fitness_history = []
                convergence_histories = []
                for i in range(NUM_RUNS):
                    hyperparameters["seed"] = random.PRNGKey(i) if is_jax else i

                    start = time.perf_counter()
                    result = algorithm_fn(objective_fn, bounds, **hyperparameters)

                    if is_jax:
                        block_until_ready(result)

                    end = time.perf_counter()
                    execution_times.append(end - start)

                    _, fitness, history = result

                    print(f"Iteration {i + 1} | Fitness: {fitness}")
                    fitness_history.append(float(fitness))
                    convergence_histories.append(np.asarray(history, dtype=float))

                # Host-side statistics; np.std defaults to ddof=0 like jnp.std.
                mean_time = float(np.mean(execution_times))
                std_time = float(np.std(execution_times))

                mean_fitness = float(np.mean(fitness_history))
                std_fitness = float(np.std(fitness_history))

                mean_convergence = (
                    np.mean(np.stack(convergence_histories), axis=0).tolist()
                    if convergence_histories
                    else []
                )

                results.append(
                    {
                        "Dimension": dim,
                        "Benchmark": benchmark_name,
                        "Algorithm": algorithm_name,
                        "Execution Time History": execution_times,
                        "Mean of Execution Times (s)": mean_time,
                        "Standard Deviation of Execution Times (s)": std_time,
                        "Fitness History": fitness_history,
                        "Mean of Fitness": mean_fitness,
                        "Standard Deviation of Fitness": std_fitness,
                        "Convergence History": mean_convergence,
                    },
                )

    return results

In [None]:
# Run the full benchmark sweep and tabulate one row per configuration.
results = run_experiment()
df = pd.DataFrame(data=results)

In [None]:
# Persist the raw results, then render figures and summary tables. The
# RangeIndex carries no information, so don't write it as an extra
# "Unnamed: 0" column.
df.to_csv("./experiment_results.csv", index=False)
generate_visualizations(df)