In [None]:
# test_cartpole.py
import cartpolev1
import random
import time


def run_cartpole_c(max_total_steps: int = 10_000):
    env = cartpolev1.CartPoleEnv()
    state = env.reset()
    # print("Initial State:", state)

    done = False
    time_steps = 0
    total_steps = 0
    max_time_steps = 500

    start_time = time.time()

    while total_steps < max_total_steps:

        while not done and time_steps < max_time_steps:
            # Agent selects a random action
            action = random.randint(0, 1)
            state, reward, done = env.step(action)
            # print(f"Time Step: {time_steps}, Action: {action}, State: {state}, Reward: {reward}, Done: {done}")
            time_steps += 1
            total_steps += 1

            if total_steps >= max_total_steps:
                break

        if done:
            time_steps = 0
            done = False
            # print(f"Episode terminated after {time_steps} time steps.")
        # else:
            # print(f"Episode reached max time steps ({max_time_steps}).")
        env.reset()

    sps = total_steps / (time.time() - start_time)
    print(f"CartPole C:\nTotal Steps: {total_steps}, Steps per Second: {sps:.2f}")
    return sps

In [None]:
run_cartpole_c(10_000)

In [109]:
import gymnasium as gym

def run_cartpole_py(max_total_steps: int = 10_000):
    env = gym.make("CartPole-v1")
    state, info = env.reset()
    # print("Initial State:", state)

    done = False
    time_steps = 0
    total_steps = 0
    max_time_steps = 500

    start_time = time.time()

    while total_steps < max_total_steps:

        while not done and time_steps < max_time_steps:
            # Agent selects a random action
            action = env.action_space.sample()
            _, _, done, truncated, _ = env.step(action)
            time_steps += 1
            total_steps += 1

            if total_steps >= max_total_steps:
                break

        if done or truncated:
            time_steps = 0
            done = False

        env.reset()

    sps = total_steps / (time.time() - start_time)
    print(f"CartPole Python:\nTotal Steps: {total_steps}, Steps per Second: {sps:.2f}")
    return sps


In [None]:
run_cartpole_py(10_000)

In [111]:
import time
import jax
jax.config.update("jax_enable_x64", True)
import gymnax

def run_cartpole_jax(max_total_steps: int = 10_000):
    key = jax.random.PRNGKey(0)
    env, env_params = gymnax.make("CartPole-v1")

    # JIT the environment functions
    jit_step = jax.jit(env.step)
    jit_reset = jax.jit(env.reset)

    total_steps = 0
    max_time_steps = 1000

    start_time = time.time()

    while total_steps < max_total_steps:
        # Reset the environment and keys at the start of each episode
        key, subkey = jax.random.split(key)
        obs, state = jit_reset(subkey, env_params)
        done = False
        time_steps = 0

        while not done and time_steps < max_time_steps:
            # Split keys for action and step
            key, subkey = jax.random.split(key)
            action = env.action_space(env_params).sample(subkey)

            key, subkey = jax.random.split(key)
            obs, state, reward, done, _ = jit_step(subkey, state, action, env_params)

            time_steps += 1
            total_steps += 1

            if total_steps >= max_total_steps:
                break

    sps = total_steps / (time.time() - start_time)
    print(f"CartPole JAX:\nTotal Steps: {total_steps}, Steps per Second: {sps:.2f}")
    return sps


In [None]:
run_cartpole_jax(10_000)

In [None]:
c_sps = run_cartpole_c(10_000)
py_sps = run_cartpole_py(10_000)
jax_sps = run_cartpole_jax(10_000)

print(f"Speedup: {c_sps / jax_sps:.2f}")
print(f"Speedup: {py_sps / jax_sps:.2f}")


In [None]:
# Plot a histogram of the results
import matplotlib.pyplot as plt


plt.bar(["CartPole C", "CartPole Python", "CartPole JAX"], [c_sps, py_sps, jax_sps], color=["blue", "orange", "green"])
plt.ylabel("Steps per Second")
plt.show()

In [None]:
# Vectorise Environment
import cartpole
import random

# Number of parallel environments
num_envs = 5

# Create a batch of environments
env_batch = cartpole.CartPoleBatch(num_envs=num_envs)

# Reset all environments
states = env_batch.reset()
# print("Initial States:")
# for i, state in enumerate(states):
    # print(f"Environment {i}: {state}")

done = [False] * num_envs
time_steps = 0
max_time_steps = 500

while not all(done) and time_steps < max_time_steps:
    # Agent selects actions for each environment
    actions = [random.randint(0, 1) for _ in range(num_envs)]
    
    # Step all environments
    states, rewards, dones = env_batch.step(actions)
    
    for i in range(num_envs):
        if not done[i]:  # Only print if the environment is still running
            # print(f"Env {i}, Time Step: {time_steps}, Action: {actions[i]}, State: {states[i]}, Reward: {rewards[i]}, Done: {dones[i]}")
            done[i] = dones[i]
    
    time_steps += 1

# print(f"Simulation ended after {time_steps} time steps.")

In [10]:
import time


def run_cartpole_batch_c(max_total_steps: int = 10_000):
    # Number of parallel environments
    num_envs = 5

    # Create a batch of environments
    env_batch = cartpole.CartPoleBatch(num_envs=num_envs)

    # Reset all environments
    states = env_batch.reset()
    # print("Initial State:", state)

    done = [False] * num_envs
    time_steps = 0
    total_steps = 0
    max_time_steps = 500

    start_time = time.time()

    while total_steps < max_total_steps:

        # Agent selects actions for each environment
        actions = [random.randint(0, 1) for _ in range(num_envs)]
        
        # Step all environments
        states, rewards, dones = env_batch.step(actions)
        # print(f"Time Step: {time_steps}, Action: {action}, State: {state}, Reward: {reward}, Done: {done}")
        time_steps += num_envs
        total_steps += num_envs

        if total_steps >= max_total_steps:
            break
        
        env_batch.reset()

    sps = total_steps / (time.time() - start_time)
    print(f"CartPole C:\nTotal Steps: {total_steps}, Steps per Second: {sps:.2f}")
    return sps

In [None]:
run_cartpole_batch_c(1000)

In [1]:
import cartpole
import random
import time

def run_cartpole_batch_c(max_total_steps: int = 10_000, num_envs: int = 10_000):
    # Number of parallel environments
    num_envs = 5

    # Create a batch of environments
    env_batch = cartpole.CartPoleBatch(num_envs=num_envs)

    # Initialize done flags
    done = [False] * num_envs

    # Reset all environments initially
    states = env_batch.reset()
    # print("Initial States:")
    # for i, state in enumerate(states):
    #     print(f"Environment {i}: {state}")

    time_steps = 0
    total_steps = 0
    max_time_steps = 500

    start_time = time.time()

    while total_steps < max_total_steps:
        # Agent selects actions for each environment
        actions = []
        for i in range(num_envs):
            if not done[i]:
                actions.append(random.randint(0, 1))
            else:
                actions.append(0)  # Dummy action for completed environments

        # Step all environments
        states, rewards, dones = env_batch.step(actions)

        # for i in range(num_envs):
        #     if not done[i]:  # Only process environments that are not done
        #         # print(f"Env {i}, Time Step: {time_steps}, Action: {actions[i]}, State: {states[i]}, Reward: {rewards[i]}, Done: {dones[i]}")
        #         done[i] = dones[i]
        #         if done[i]:
        #             # Reset the environment that is done
        #             reset_states = env_batch.reset([i])
        #             # print(f"Env {i} reset to State: {reset_states[0]}")
        #             done[i] = False  # Reset done flag after resetting

        time_steps += num_envs
        total_steps += num_envs

        if total_steps >= max_total_steps:
            break
    
    sps = total_steps / (time.time() - start_time)
    print(f"CartPole Batch C:\nTotal Steps: {total_steps}, Steps per Second: {sps:.2f}")

# print(f"Simulation ended after {time_steps} time steps.")

In [6]:
run_cartpole_batch_c(10_000, 10)

CartPole Batch C:
Total Steps: 10000, Steps per Second: 838006.03
