In [1]:
import gymnasium as gym
import numpy as np
from IPython.display import display, Video

from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
import pickle

import os
os.environ['MUJOCO_GL']='egl'

In [6]:
def run_model(regressors, scaler, env_name, seed):
    env = gym.make(env_name, render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, "videos", episode_trigger=lambda x: True)

    average_reward = 0.0
    for i in range(10):
        #obs = env.reset(seed=seed+i*100)[0]
        # setting fixed seed
        obs = env.reset(seed=seed)[0]
        
        action = np.zeros(env.action_space.shape[0])
     
        terminated, truncated = False, False
    
        while not (terminated or truncated):
            predicted_action = np.ones(shape=env.action_space.shape[0])
            for action in range(3):
                regressor, _ = regressors[f"action{action}"]
                scaled_obs = scaler.transform(obs.reshape(1,-1))
                predicted_action[action] = regressor.predict(scaled_obs).item()
            
            predicted_action = np.clip(predicted_action, -1, 1)
            print(predicted_action)
            obs, reward, terminated, truncated, _ = env.step(predicted_action)
            average_reward += reward

        #added for debugging
        break
    average_reward /= 10
    print(f"Average reward over the 10 episodes: {average_reward:.3f}")
    env.close()

In [14]:
model_path = 'Hopper-nn.pkl'

with open(model_path, 'rb') as f:
    regressors, scaler = pickle.load(f)
    run_model(regressors, scaler, "Hopper-v5", seed=24012000)

  logger.warn(


[-0.52649476 -0.64370866  0.79777801]
[-0.46175129 -0.82826961  0.96724233]
[-0.44824957 -0.98248908  0.98764176]
[0.11861941 0.01631967 0.96761764]
[0.50101383 0.6413175  1.        ]
[0.57505322 0.39024551 1.        ]
[0.39918131 0.69366266 0.88208748]
[-0.06441767  0.0165815  -0.03956992]
[ 0.10261173 -0.12674334 -0.16027298]
[ 0.18249423 -0.13695611 -0.31779314]
[ 0.13512995 -0.1201026  -0.4168805 ]
[ 0.08951256 -0.06743977 -0.37862422]
[ 0.02946282 -0.00225615 -0.270189  ]
[-0.0352912   0.09450354 -0.16258565]
[-0.08943338  0.21704459 -0.05861215]
[-0.11859188  0.3528368   0.00370588]
[-0.10909224  0.48500773  0.01179472]
[-0.10342255  0.54739221 -0.01793614]
[-0.13926432  0.55278893 -0.04386773]
[-0.18733551  0.57205596 -0.00084239]
[-0.22927366  0.54151448  0.0413753 ]
[-0.26078024  0.48957612  0.05745178]
[-0.5738273   0.67549602  0.21442675]
[-0.50062838  0.76614196  0.20264868]
[0.26568351 0.5763565  0.07391363]
[-0.16517423  0.41949224  0.04874044]
[ 0.2543509  -0.00815018  0

In [13]:
# Display recorded video
display(Video(url='videos/rl-video-episode-0.mp4'))