In [2]:
import time
from typing import Any
from tqdm import tqdm

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

from swarm.bas import Swarm, Agent, BASEnv, Blueprint, wrappers, RenderWrapper

In [7]:
def benchmark_env(
    env: gym.Env, num_resets: int, num_steps: int, action: Any
) -> np.ndarray:
    """Time ``env.reset()`` and a fixed number of ``env.step(action)`` calls.

    Performs ``num_resets`` runs; each run is one reset followed by
    ``num_steps`` steps that all receive the same ``action``.

    Args:
        env: Environment to benchmark.
        num_resets: Number of reset-then-step runs (must be >= 0).
        num_steps: Number of steps taken after each reset (must be >= 0).
        action: Action passed unchanged to every ``env.step`` call.

    Returns:
        Array of shape ``(num_resets, num_steps + 1)`` with durations in
        seconds (differences of ``time.perf_counter()`` readings), where
        ``array[r, 0]`` is the time for reset ``r`` and ``array[r, s + 1]``
        is the time for step ``s`` after reset ``r``.

    Raises:
        ValueError: If ``num_resets`` or ``num_steps`` is negative.
    """
    if num_resets < 0 or num_steps < 0:
        raise ValueError("num_resets and num_steps must be non-negative")

    # Per run: one timestamp before the reset, one after it, and one after
    # every step. Preallocating keeps the result 2-D even when num_resets is 0,
    # so the np.diff along axis 1 below is always valid.
    timestamps = np.empty((num_resets, num_steps + 2), dtype=float)

    with tqdm(total=num_resets * (num_steps + 1)) as pbar:
        for run in range(num_resets):
            timestamps[run, 0] = time.perf_counter()
            env.reset()
            timestamps[run, 1] = time.perf_counter()
            pbar.update(1)

            for step in range(num_steps):
                # The value returned by env.step is discarded, so stepping
                # continues regardless of what the environment reports back.
                env.step(action)
                timestamps[run, step + 2] = time.perf_counter()
                # The progress-bar update falls inside the next measured
                # interval, i.e. its cost is part of the reported step time.
                pbar.update(1)

    # Consecutive timestamp differences are the individual durations.
    return np.diff(timestamps, axis=1)


In [8]:
# Build the three components of the base environment separately so each
# configuration can be read (and tweaked) on its own.
blueprint = Blueprint(world_size=np.array([100, 100]))

swarm_config = dict(
    num_boids=100,
    radius=1,
    max_velocity=1,
    max_acceleration=0.1,
    separation_range=5,
    cohesion_range=10,
    alignment_range=10,
    steering_weights=(1.1, 1, 1),
    obstacle_margin=3,
)
swarm = Swarm(**swarm_config)

agent = Agent(radius=1, max_velocity=1, reset_position=np.array([50, 50]))

base_env = BASEnv(blueprint=blueprint, swarm=swarm, agent=agent)

# Wrapper stack, applied innermost first: observation, then reward, then
# action wrapper. `env` is the fully wrapped environment used by later cells.
env = wrappers.SectionObservationWrapper(
    base_env,
    num_sections=8,
    max_range=20,
)
env = wrappers.DistanceToTargetRewardWrapper(
    env,
    target_position=np.array([90, 90]),
)
env = wrappers.DiscreteActionWrapper(
    env,
    num_actions=5,
)

In [12]:
# 20 runs of one reset plus 10,000 steps each (200,020 timed calls in total),
# always replaying the constant action 1.
times = benchmark_env(
    env,
    num_resets=20,
    num_steps=10_000,
    action=1,
)


100%|██████████| 200020/200020 [06:34<00:00, 507.58it/s]


In [14]:
# Width (in steps) of the rolling mean used to smooth every curve.
ROLLING_WINDOW = 100

qs = [0, 0.25, 0.5, 0.75, 1]

# Column 0 of `times` is the reset duration; keep only the per-step durations.
step_times = times[:, 1:]

# Steps per second (1 / step duration), reduced across resets to one quantile
# curve per entry of `qs`; row i of `values` corresponds to qs[i].
values = np.quantile(1 / step_times, qs, axis=0)

fig = go.Figure(layout_title="SPS")
for q, v in zip(qs, values):
    fig.add_scatter(
        y=pd.Series(v).rolling(ROLLING_WINDOW).mean(),
        mode="lines",
        name=str(q),  # trace names are strings; make the conversion explicit
    )

# The "mean" curve is 1 / (mean step duration across resets), which is not the
# same as the mean of the per-reset SPS values used for the quantile curves.
fig.add_scatter(
    y=pd.Series(1 / step_times.mean(axis=0)).rolling(ROLLING_WINDOW).mean(),
    mode="lines",
    name="mean",
)

# Label the axes so the figure can be read on its own.
fig.update_layout(
    xaxis_title="step index after reset",
    yaxis_title=f"steps per second (rolling mean over {ROLLING_WINDOW} steps)",
)
fig
