In [6]:
%load_ext autoreload
%autoreload 2
import pygame
import os
import numpy as np
from datetime import datetime
from stable_baselines3 import PPO
from gymnasium.wrappers import FlattenObservation,TimeLimit
from deform_rl.sim.Rectangle_env.environment import Rectangle1D

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from gymnasium import ObservationWrapper
from gymnasium.spaces import Box,Dict


# Simulation settings handed to Rectangle1D via `sim_config` in `_init`.
# NOTE(review): the meaning/units of each key are defined by the simulator,
# not in this notebook -- confirm against deform_rl.sim before tuning.
sim_cfg = dict(
    width=800,
    height=600,
    FPS=60,
    gravity=0,
    damping=.15,
    collision_slope=0.01,
)

class CustomNormalizeObsrvation(ObservationWrapper):
    """Observation wrapper that rescales the env's dict observation.

    The wrapped observation must provide the keys 'position', 'velocity'
    and 'target'. The wrapper emits 'position', 'velocity' and 'goal'
    (the latter computed from 'target'):

    * 'position' / 'target': shifted by half of (width, height) and divided
      by (width, height). Coordinates inside [0, width] x [0, height] thus
      map into [-0.5, 0.5].
    * 'velocity': squashed element-wise with tanh.
    """

    def __init__(self, env):
        """Cache the env's `width`/`height` and declare the new observation space."""
        super().__init__(env)
        self.width = env.width
        self.height = env.height

        def unit_box():
            # Every output entry is declared as a 2-vector bounded by [-1, 1].
            return Box(low=-1, high=1, shape=(2,), dtype=np.float64)

        self.observation_space = Dict({
            'position': unit_box(),
            'velocity': unit_box(),
            'goal': unit_box(),
        })

    def observation(self, observation):
        """Return the rescaled observation dict with keys position/velocity/goal."""
        extent = [self.width, self.height]
        centre = np.array(extent) / 2

        def rescale(point):
            # Centre on the middle of the arena, then divide by its full size.
            return (point - centre) / extent

        return {
            'position': rescale(observation['position']),
            'velocity': np.tanh(observation['velocity']),
            'goal': rescale(observation['target']),
        }

In [8]:
def _init():
    """Create the Rectangle1D environment with the notebook's wrapper stack.

    Wrapping order (innermost first): Rectangle1D -> CustomNormalizeObsrvation
    -> FlattenObservation -> TimeLimit (1000 steps per episode).

    Returns:
        The fully wrapped environment, constructed with render_mode='human'.
    """
    base_env = Rectangle1D(sim_config=sim_cfg, threshold=30, oneD=False, render_mode='human')
    normalized_env = CustomNormalizeObsrvation(base_env)
    # Flatten the dict observation before applying the episode step limit.
    flat_env = FlattenObservation(normalized_env)
    return TimeLimit(flat_env, max_episode_steps=1000)

In [10]:
# Roll out the saved PPO policy (deterministic actions) and visualize it.
# Runs for at most 10000 environment steps, or until the pygame window is closed.
tenv = _init()
try:
    save_dir = "./saved_models"
    t_model = PPO.load(os.path.join(save_dir, "best_model.zip"), force_reset=True)

    obs, _ = tenv.reset()
    cnt = 0  # number of steps taken in the current episode
    for step_idx in range(10000):
        action, _ = t_model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = tenv.step(action)
        cnt += 1
        tenv.render()

        # `done` is the env's own termination signal; `truncated` is raised by
        # the TimeLimit wrapper applied in `_init`, so no manual step budget
        # has to be tracked here.
        if done:
            print("Episode done: ", cnt)
        elif truncated:
            print("Killed by timeout")
        if done or truncated:
            obs, _ = tenv.reset()
            cnt = 0

        # Stop as soon as the user closes the window.
        if pygame.event.get(pygame.QUIT):
            break
finally:
    # Always release the window/simulation, even on errors or interrupts.
    tenv.close()