In [1]:
import gymnasium as gym
import numpy as np
from IPython.display import display, Video

from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
import pickle

import os
os.environ['MUJOCO_GL']='egl'

In [2]:
def run_model(regressors, scaler, env_name, seed):
    env = gym.make(env_name, render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, "videos", episode_trigger=lambda x: True)

    average_reward = 0.0
    for i in range(10):
        #obs = env.reset(seed=seed+i*100)[0]
        # setting fixed seed
        obs = env.reset(seed=seed)[0]
        
        action = np.zeros(env.action_space.shape[0])
     
        terminated, truncated = False, False
    
        while not (terminated or truncated):
            predicted_action = np.ones(shape=env.action_space.shape[0])
            for action in range(3):
                regressor, _ = regressors[f"action{action}"]
                scaled_obs = scaler.transform(obs.reshape(1,-1))
                predicted_action[action] = regressor.predict(scaled_obs).item()
            
            predicted_action = np.clip(predicted_action, -1, 1)
            print(predicted_action)
            obs, reward, terminated, truncated, _ = env.step(predicted_action)
            average_reward += reward

        #added for debugging
        break
    average_reward /= 10
    print(f"Average reward over the 10 episodes: {average_reward:.3f}")
    env.close()

In [3]:
model_path = 'Hopper-nn.pkl'

with open(model_path, 'rb') as f:
    regressors, scaler = pickle.load(f)
    run_model(regressors, scaler, "Hopper-v5", seed=3655591450)

  logger.warn(


[-0.5176572  -0.70567551  0.78241   ]
[-0.46465046 -0.88104536  0.95071343]
[-0.45579049 -1.          0.96491657]
[0.1361995  0.06157271 0.93857448]
[0.48665532 0.64668991 1.        ]
[0.56373953 0.40101423 1.        ]
[0.41299781 0.69946064 0.90208666]
[-0.01252179 -0.01708126 -0.0956989 ]
[ 0.12956679 -0.11903547 -0.21741082]
[ 0.16838857 -0.13021313 -0.36668975]
[ 0.11241111 -0.08944691 -0.41466396]
[ 0.0718824  -0.02538016 -0.33736196]
[ 0.00255368  0.04725299 -0.22628854]
[-0.06362484  0.14876955 -0.12817006]
[-0.11384929  0.26546951 -0.03861151]
[-0.13439038  0.3972249   0.01261008]
[-0.11539443  0.50830564  0.00824478]
[-0.11494823  0.54786027 -0.01692595]
[-0.14807356  0.56350661 -0.04332664]
[-0.19617057  0.57696302 -0.00224748]
[-0.23772535  0.53638359  0.02839433]
[-0.27486148  0.50133697  0.02612156]
[-0.28403638  0.45789051 -0.00258923]
[-0.4004231   0.57575915  0.05484132]
[-0.12275235  0.59130976  0.03709209]
[0.11087409 0.3716459  0.01442127]
[ 0.05537042  0.02986749 -0

In [5]:
# Display recorded video
display(Video(url='videos/rl-video-episode-0.mp4'))