In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload imported modules before each cell
# runs, so edits to the local project modules are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for high-DPI ("retina") figure output.
%config InlineBackend.figure_format = 'retina'

## Cluttered Env

In [2]:
import os
import sys

import jax
import jax.numpy as jnp

from cluttered_env import ClutteredEnv, EnvParams

# Put the parent of the current working directory at the front of sys.path.
# NOTE(review): presumably this is where the `purejaxrl` checkout imported
# just below lives — confirm against the repository layout.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.insert(0, parent_dir)

from purejaxrl.purejaxrl.wrappers import (
    NormalizeVecObsEnvState,
    ClipAction,
)

# Hyperparameters handed to make_train() in the training cell; "ACTIVATION" is
# also read again when the ActorCritic network is rebuilt for evaluation.
# NOTE(review): the per-key notes below follow the usual PPO naming
# conventions; the authoritative meaning is whatever
# ppo_continuous_action.make_train does with each key — confirm there.
config = {
    "LR": 3e-4,  # presumably the optimiser learning rate
    "NUM_ENVS": 1024,  # presumably the number of parallel environments
    "NUM_STEPS": 256,  # presumably rollout length per environment per update
    "TOTAL_TIMESTEPS": 5e7,  # presumably the total environment-step budget (a float)
    "UPDATE_EPOCHS": 4,  # presumably optimisation epochs per collected batch
    "NUM_MINIBATCHES": 32,  # presumably minibatches per epoch
    "GAMMA": 0.99,  # presumably the discount factor
    "GAE_LAMBDA": 0.95,  # presumably the GAE lambda
    "CLIP_EPS": 0.2,  # presumably the PPO clipping range
    "ENT_COEF": 0.0,  # presumably the entropy-bonus weight (0.0 = disabled)
    "VF_COEF": 0.5,  # presumably the value-loss weight
    "MAX_GRAD_NORM": 0.5,  # presumably the gradient-clipping norm
    "ACTIVATION": "relu",  # activation name passed to ActorCritic
    "ANNEAL_LR": False,  # presumably toggles learning-rate annealing
    "NORMALIZE_ENV": True,  # presumably toggles observation/reward normalisation wrappers
    "DEBUG": True,  # presumably enables debug logging during training
}

# Smoke-test the environment once: reset it, sample a random action, and take
# a single step. One independent PRNG key is derived for each of those calls.
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Default parameters, and the environment wrapped in ClipAction.
env_params = EnvParams()
env = ClipAction(ClutteredEnv())

# reset -> random action -> one transition
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
n_obs, n_state, reward, done, _ = env.step(
    key_step, state, action, env_params
)

In [3]:
from ppo_continuous_action import ActorCritic, make_train

In [4]:
# Build the training function for the cluttered environment from
# ppo_continuous_action.make_train and wrap it in jax.jit.
# The PRNG key is supplied when train_jit is called (see the training cell),
# so no key is created here; a key defined in this cell would be overwritten
# before use and only obscure which seed actually drives training.
train_jit = jax.jit(make_train(ClutteredEnv(), EnvParams(), config))

In [None]:
import time
import matplotlib.pyplot as plt

# Seed used for the training run.
rng = jax.random.PRNGKey(42)

# Time the full training call. perf_counter() is a monotonic clock intended
# for measuring elapsed durations, unlike time.time() which can jump if the
# system clock is adjusted. block_until_ready() makes sure the asynchronous
# JAX computation has finished before the clock is read; on the first call
# the measurement also includes JIT tracing/compilation.
t0 = time.perf_counter()
out = jax.block_until_ready(train_jit(rng))
print(f"time: {time.perf_counter() - t0:.2f} s")

# Average the logged episode returns over the last axis and flatten the rest
# into a single curve.
mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)

fig, ax = plt.subplots()
ax.plot(mean_returns)
ax.set_xlabel("Update Step")
ax.set_ylabel("Return")
plt.show()

In [6]:
# Pull the pieces needed for evaluation out of the final runner state.
runner_state = out["runner_state"]
train_state = runner_state[0]

# runner_state[1].env_state holds the running observation statistics
# (mean / var / count). Collapse mean and var along axis 0 so they can be
# applied to a single observation.
# NOTE(review): axis 0 is presumably the parallel-environment axis — confirm.
obs_norm_state = runner_state[1].env_state
env_state = NormalizeVecObsEnvState(
    mean=jnp.mean(obs_norm_state.mean, axis=0),
    var=jnp.mean(obs_norm_state.var, axis=0),
    count=obs_norm_state.count,
    env_state=obs_norm_state.env_state,
)

# Evaluation uses episodes capped at 1000 steps.
env_params = EnvParams(max_steps_in_episode=1000)

# Rebuild the actor-critic module with the same activation used in training;
# the action dimension is read from the environment's action space.
action_dim = env.action_space(env_params).shape[0]
network = ActorCritic(action_dim=action_dim, activation=config["ACTIVATION"])

def policy_inference(network, params, obs):
    """Return the mode of the policy distribution for ``obs``.

    Args:
        network: Object exposing ``apply(params, obs)`` that returns a
            ``(distribution, value)`` pair.
        params: Parameters forwarded unchanged to ``network.apply``.
        obs: Observation forwarded unchanged to ``network.apply``.

    Returns:
        ``distribution.mode()``; the second element of the pair returned by
        ``network.apply`` is ignored.
    """
    action_dist, _unused_value = network.apply(params, obs)
    chosen_action = action_dist.mode()
    return chosen_action

In [7]:
# Roll out the trained policy and record every environment state for rendering.
NUM_ROLLOUT_STEPS = 1000  # number of policy steps to record
ROLLOUT_SEED = 1

# Use a dedicated key stream for the rollout. The notebook-level `rng` has
# already been consumed by the training call, and JAX keys must not be reused:
# every split below also advances `rollout_rng`, so no key is used twice.
# Seeding here also makes the cell give the same trajectory on every re-run.
rollout_rng = jax.random.PRNGKey(ROLLOUT_SEED)
rollout_rng, key_reset = jax.random.split(rollout_rng)

obs, state = env.reset(key_reset, env_params)
state_seq = [state]

for _step in range(NUM_ROLLOUT_STEPS):
    # Normalise the raw observation with the statistics gathered in training.
    # NOTE(review): the 1e-8 epsilon is assumed to match the normalisation
    # wrapper used during training — confirm against purejaxrl.wrappers.
    obs_normalized = (obs - env_state.mean) / jnp.sqrt(env_state.var + 1e-8)

    action = policy_inference(network, train_state.params, obs_normalized)

    rollout_rng, key_reset, key_step = jax.random.split(rollout_rng, 3)
    obs, state, reward, done, _info = env.step(
        key_step, state, action, env_params
    )

    # NOTE(review): if ClutteredEnv.step already auto-resets on `done` (as
    # gymnax-style environments commonly do), this explicit reset is redundant
    # and replaces the auto-reset state with another fresh one — confirm.
    if done:
        obs, state = env.reset(key_reset, env_params)

    state_seq.append(state)

In [None]:
import pygame
import numpy as np
from IPython.display import display, clear_output
from PIL import Image

# Unpack the recorded trajectory into arrays with one entry per recorded state.
x = np.array([state.x for state in state_seq])
y = np.array([state.y for state in state_seq])
theta = np.array([state.theta for state in state_seq])  # collected but not drawn below

target_pos = np.array([state.target_state for state in state_seq])
obs_pos = np.array([state.obs_state[:, 0:2] for state in state_seq])

# Render every recorded state instead of a hard-coded frame count, so this
# cell stays in sync with however long the rollout was.
num_frames = len(state_seq)

# Canvas and world-to-pixel mapping.
width, height = 600, 600
# The drawing maps world coordinates in [-WORLD_HALF_EXTENT, WORLD_HALF_EXTENT]
# onto the canvas.
# NOTE(review): assumes the arena spans +/-10 world units — confirm against ClutteredEnv.
WORLD_HALF_EXTENT = 10
FPS = 50  # upper bound on displayed frames per second

# Basic colors (RGB)
BG = (15, 16, 31)
ROBOT = (255, 255, 255)
TARGET = (0, 255, 0)
OBS = (0, 101, 252)

radius = 10  # pixel radius used for robot, obstacles and target alike


def world_to_pixel(wx, wy):
    """Convert a world-space point to integer pixel coordinates on the canvas.

    -WORLD_HALF_EXTENT maps to pixel 0 and +WORLD_HALF_EXTENT maps to the
    canvas width/height; values are truncated with ``int``.
    """
    return (
        int((wx + WORLD_HALF_EXTENT) * width / (2 * WORLD_HALF_EXTENT)),
        int((wy + WORLD_HALF_EXTENT) * height / (2 * WORLD_HALF_EXTENT)),
    )


pygame.init()
try:
    screen = pygame.Surface((width, height))  # off-screen surface, no window
    clock = pygame.time.Clock()

    for i in range(num_frames):
        screen.fill(BG)  # clear the previous frame

        # Draw order: robot, then obstacles, then target on top.
        pygame.draw.circle(screen, ROBOT, world_to_pixel(x[i], y[i]), radius)
        for obstacle_xy in obs_pos[i]:
            pygame.draw.circle(
                screen, OBS, world_to_pixel(obstacle_xy[0], obstacle_xy[1]), radius
            )
        pygame.draw.circle(
            screen,
            TARGET,
            world_to_pixel(target_pos[i, 0], target_pos[i, 1]),
            radius,
        )

        # array3d() is indexed [x][y]; rot90 turns it into a row-major image
        # and, as a side effect, flips it vertically so that larger world y
        # appears higher in the displayed picture.
        arr = pygame.surfarray.array3d(screen)
        img = Image.fromarray(np.rot90(arr))
        clear_output(wait=True)
        display(img)
        # To export frames instead of (or in addition to) displaying them:
        # pygame.image.save(screen, f"frames/frame_{i:04d}.png")
        clock.tick(FPS)
finally:
    # Always shut pygame down, even if rendering fails or is interrupted.
    pygame.quit()