# In [None]:
import gym
import numpy as np
from joblib import load
import csv
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv, VecEnvWrapper
from stable_baselines3.common.env_util import make_vec_env

# Number of evaluation episodes to play
num_test_episodes = 1000


def _build_base_env():
    """Create the raw Montezuma environment with grayscale observations."""
    raw_env = gym.make('MontezumaRevenge', render_mode="rgb_array")
    raw_env.reward_range = (-2, 2)
    # keep_dim=True is passed so the observation keeps its channel axis
    return GrayScaleObservation(raw_env, keep_dim=True)


# Evaluation environment: one base env inside a DummyVecEnv (the list holds a
# single factory), with the last 4 frames stacked along the last axis.
env = VecFrameStack(DummyVecEnv([_build_base_env]), 4, channels_order='last')

# Load the trained PPO checkpoint to evaluate
ppo_model = PPO.load("./no_exploration_ppo_models/no_expl_ppo_montezuma_2246656_steps.zip")

# Helpers to run evaluation episodes and log per-episode results to a CSV file
def _as_scalar(value):
    """Collapse a plain scalar or a size-1 array into a Python scalar.

    Vectorized environments return ``reward``/``done`` as length-1 arrays;
    accumulating those directly would turn the running totals into arrays and
    write cells like ``[400.]`` to the CSV.  ``ndarray.item()`` raises
    ``ValueError`` for arrays with more than one element, so a multi-env batch
    is rejected instead of being silently mis-scored.
    """
    return np.asarray(value).item()


def test_agent(model, env, num_episodes, csv_filename):
    """Play ``num_episodes`` episodes with ``model`` and append the scores to a CSV.

    One row ``[episode number, episode reward, running total]`` is appended per
    episode.  The header row is written only when the file is empty, so
    repeated calls keep extending the same file.

    Args:
        model: Object exposing ``predict(obs)`` that returns ``(action, state)``.
        env: Environment exposing ``reset()`` and ``step(action)``.  ``reset``
            may return either the observation or a tuple whose first item is
            the observation.  ``step`` may return either
            ``(obs, reward, done, info)`` or
            ``(obs, reward, terminated, truncated, info)``.  ``reward`` and the
            done flags may be scalars or size-1 arrays (single-env VecEnv).
        num_episodes: Number of episodes to run.
        csv_filename: Path of the CSV file to append to.

    Returns:
        The sum of all rewards collected over every episode, as a Python
        number (``0`` when ``num_episodes`` is 0).

    Raises:
        ValueError: If ``reward`` or a done flag holds more than one element
            (i.e. the environment batches several sub-environments).
    """
    cumulative_score = 0
    with open(csv_filename, mode='a', newline='') as file:
        writer = csv.writer(file)
        # In append mode the position starts at the end of the file, so a
        # position of 0 means the file is new/empty and still needs a header.
        if file.tell() == 0:
            writer.writerow(["Episode", "Total Reward", "Cumulative Score"])

        for episode in range(num_episodes):
            obs = env.reset()
            if isinstance(obs, tuple):  # reset returned (obs, info)
                obs = obs[0]
            done = False
            total_reward = 0
            while not done:
                action, _states = model.predict(obs)
                step_result = env.step(action)
                if len(step_result) == 5:
                    # Newer step API: the episode ends on either flag.
                    obs, reward, terminated, truncated, _info = step_result
                    done = _as_scalar(terminated) or _as_scalar(truncated)
                else:
                    obs, reward, done, _info = step_result
                    done = _as_scalar(done)
                reward = _as_scalar(reward)
                total_reward += reward
                cumulative_score += reward

            # Write episode data to the CSV
            writer.writerow([episode + 1, total_reward, cumulative_score])

    return cumulative_score


# The ``test_`` prefix would otherwise make pytest collect this helper as a
# test case whenever it is imported into a test module.
test_agent.__test__ = False

# Evaluate the loaded PPO model; per-episode rows are appended to
# test_results.csv and the overall reward total is kept for later use.
ppo_cumulative_score = test_agent(
    model=ppo_model,
    env=env,
    num_episodes=num_test_episodes,
    csv_filename='test_results.csv',
)