# In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import time
import numpy as np

from minigrid.wrappers import ImgObsWrapper
from minigrid.core.constants import COLOR_NAMES

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# The wrapper, feature extractor and env factory below are copied from our
# training script; they are reused here for evaluation.

class GoToRedBallLogicWrapper(gym.Wrapper):
    """Reward-shaping wrapper that penalises rule violations in a go-to-red-ball task.

    After every step the wrapper writes ``info['logic_violation']`` (``None``
    when no rule was broken) and adds ``violation_penalty`` to the reward when
    the agent

    * tries to pick up a ball that is not red (``'pickup_wrong_ball'``), or
    * issues ``done`` while the red ball is neither in the cell in front of it
      nor in its own cell (``'premature_done'``), or while no red ball was
      found at reset (``'premature_done_no_target'``).

    NOTE: the red ball position is only recorded in ``reset``; if the ball is
    moved during an episode the stored position is not refreshed.

    NOTE(review): the pickup and ``done`` checks differ from the version this
    was copied from (see comments in ``step``); keep the training script in
    sync if rewards must be comparable.

    Args:
        env: Environment whose unwrapped instance exposes ``actions``,
            ``grid``, ``agent_pos`` and ``front_pos``.
        violation_penalty: Value added to the reward on each violation
            (negative to penalise).
    """

    def __init__(self, env: gym.Env, violation_penalty: float = -0.5):
        super().__init__(env)
        self.penalty = violation_penalty
        try:
            self.red_color_index = COLOR_NAMES.index('red')
        except ValueError:
            print(f"Error: 'red' not found in COLOR_NAMES list: {COLOR_NAMES}")
            self.red_color_index = 0
        # (x, y) of the red ball located at the last reset, or None if absent.
        self.red_ball_pos = None
        assert hasattr(self.env.unwrapped, 'actions'), "Environment must have an 'actions' attribute"

    @staticmethod
    def _color_index(cell) -> int:
        """Return the COLOR_NAMES index of ``cell``'s colour, or -1 if it cannot be resolved."""
        color = getattr(cell, 'color', None)
        if isinstance(color, str):
            try:
                return COLOR_NAMES.index(color)
            except ValueError:
                return -1
        if isinstance(color, int) and 0 <= color < len(COLOR_NAMES):
            return color
        return -1

    @staticmethod
    def _is_ball(cell) -> bool:
        """Return True when ``cell`` is a grid object whose ``type`` is ``'ball'``."""
        return cell is not None and getattr(cell, 'type', None) == 'ball'

    def _find_red_ball(self):
        """Scan the grid column by column and return the first red ball's (x, y), or None."""
        grid = self.env.unwrapped.grid
        for i in range(grid.width):
            for j in range(grid.height):
                cell = grid.get(i, j)
                if self._is_ball(cell) and self._color_index(cell) == self.red_color_index:
                    return (i, j)
        return None

    def reset(self, **kwargs):
        """Reset the wrapped env and record where the red ball is in the new grid."""
        obs, info = self.env.reset(**kwargs)
        self.red_ball_pos = self._find_red_ball()
        return obs, info

    def step(self, action):
        """Step the wrapped env, then apply penalties and set ``info['logic_violation']``.

        Returns:
            The usual ``(obs, reward, terminated, truncated, info)`` tuple with
            the reward adjusted by any penalties.
        """
        unwrapped_env = self.env.unwrapped
        actions = unwrapped_env.actions

        # The pickup target has to be inspected BEFORE stepping: MiniGrid takes
        # a successfully picked-up object off the grid during step(), so a
        # post-step look at the front cell would only ever see failed pickups.
        wrong_ball_pickup = False
        if action == actions.pickup:
            target = unwrapped_env.grid.get(*unwrapped_env.front_pos)
            wrong_ball_pickup = (
                self._is_ball(target)
                and self._color_index(target) != self.red_color_index
            )

        obs, reward, terminated, truncated, info = self.env.step(action)
        info['logic_violation'] = None

        if wrong_ball_pickup:
            reward += self.penalty
            info['logic_violation'] = 'pickup_wrong_ball'

        if action == actions.done:
            if self.red_ball_pos is None:
                reward += self.penalty
                info['logic_violation'] = 'premature_done_no_target'
            else:
                # The agent cannot occupy the ball's cell, so "at the ball"
                # means the ball is in the cell the agent is facing. The
                # original same-cell comparison is kept as well so behaviour
                # is only ever relaxed, never tightened.
                agent_pos = tuple(unwrapped_env.agent_pos)
                front_pos = tuple(unwrapped_env.front_pos)
                if self.red_ball_pos not in (agent_pos, front_pos):
                    reward += self.penalty
                    info['logic_violation'] = 'premature_done'

        return obs, reward, terminated, truncated, info

class MiniGridCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor: three 3x3 conv layers followed by a linear head.

    The first dimension of ``observation_space.shape`` is used as the number
    of input channels, so the space is expected to be channel-first.

    Args:
        observation_space: Box space describing one observation.
        features_dim: Size of the feature vector produced by ``forward``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)

        # Channel widths through the conv stack; the first entry is the input depth.
        widths = (observation_space.shape[0], 16, 32, 64)
        stack = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
            stack.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            stack.append(nn.ReLU())
        stack.append(nn.Flatten())
        self.cnn = nn.Sequential(*stack)

        # Push one sampled observation through the conv stack to learn the
        # flattened size the linear head has to accept.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            flat_size = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Cast observations to float and map them to ``features_dim`` features."""
        conv_out = self.cnn(observations.float())
        return self.linear(conv_out)

def make_env(env_id, violation_penalty=-0.5, seed=0, render_mode=None):
    """Return a zero-argument factory that builds the wrapped environment.

    The factory creates ``env_id`` with ``gym.make``, applies ``ImgObsWrapper``
    and then ``GoToRedBallLogicWrapper`` with ``violation_penalty``.

    Note: ``seed`` is accepted but not used here; seeding is mostly handled by
    VecEnv wrappers now. ``render_mode`` is forwarded to ``gym.make`` only
    when it is truthy.
    """
    def _init():
        make_kwargs = {}
        if render_mode:
            make_kwargs['render_mode'] = render_mode
        image_env = ImgObsWrapper(gym.make(env_id, **make_kwargs))
        return GoToRedBallLogicWrapper(image_env, violation_penalty=violation_penalty)

    return _init

# Gymnasium id of the environment built by make_env for evaluation and visualisation.
ENV_ID = 'BabyAI-GoToRedBallGrey-v0'
# Path of the saved model file that the scripts below hand to PPO.load.
MODEL_SAVE_PATH = "ppo_goto_redball_logic_custom_cnn.zip"
# Value passed to make_env as violation_penalty (added to the reward on each violation).
VIOLATION_PENALTY = -0.5
# Not referenced in the visible code; presumably the MiniGridCNN features_dim
# used at training time — TODO confirm.
FEATURES_DIM = 64

# evaluation script: load the saved PPO model and report its mean episode reward
if __name__ == "__main__":
    N_EVAL_EPISODES = 50
    EVAL_NUM_ENVS = 1
    # Fixed seed offset for eval. make_env ignores its seed argument, so the
    # seed is also applied to the vectorised env below.
    EVAL_SEED = 1000

    print("\n--- Starting Evaluation ---")
    print(f"Loading model: {MODEL_SAVE_PATH}")
    print(f"Evaluating on env: {ENV_ID} for {N_EVAL_EPISODES} episodes.")

    eval_env = DummyVecEnv([
        make_env(ENV_ID, violation_penalty=VIOLATION_PENALTY, seed=EVAL_SEED + rank)
        for rank in range(EVAL_NUM_ENVS)
    ])

    # try/finally guarantees the env is closed on every exit path, including
    # a failed model load or an exception raised during evaluation.
    try:
        try:
            model = PPO.load(MODEL_SAVE_PATH, env=eval_env)
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Ensure the model file exists and was saved correctly.")
            # Non-zero status so callers can detect the failure; SystemExit is
            # used directly because exit() is only a site/interactive helper.
            raise SystemExit(1) from e

        # Seed after loading so that nothing done while attaching the env to
        # the model can override it; the seed takes effect at the next reset.
        # NOTE(review): confirm against the installed Stable-Baselines3 version.
        eval_env.seed(EVAL_SEED)

        mean_reward, std_reward = evaluate_policy(
            model,
            eval_env,
            n_eval_episodes=N_EVAL_EPISODES,
            deterministic=True,
            render=False,
        )

        print("\n--- Evaluation Results ---")
        print(f"Mean reward over {N_EVAL_EPISODES} episodes: {mean_reward:.2f} +/- {std_reward:.2f}")
        print("Note: 'Reward' here includes penalties from the wrapper.")
        print("Higher reward generally means better performance (fewer penalties, reaching goal faster).")
    finally:
        eval_env.close()

    print("\nEvaluation finished.")

# Visualise the trained agent and record each episode as a video.
from gymnasium.wrappers import RecordVideo
from IPython import display as ipythondisplay
from pathlib import Path
import base64

def show_video(video_path, video_width = 600):
    """Return an IPython HTML object embedding the video file as an inline player.

    The file's bytes are base64-encoded into a ``data:video/mp4`` URL, so the
    resulting HTML does not reference the file on disk.

    Args:
        video_path: Path of the video file to embed.
        video_width: Value written to the ``width`` attribute of ``<video>``.

    Returns:
        The ``ipythondisplay.HTML`` object wrapping the ``<video>`` markup.
    """
    # Read-only binary mode inside a context manager: the handle is closed
    # deterministically and no write permission on the file is required.
    with open(video_path, "rb") as video_file:
        encoded = base64.b64encode(video_file.read()).decode()
    video_url = f"data:video/mp4;base64,{encoded}"
    return ipythondisplay.HTML(f"""<video width={video_width} controls><source src="{video_url}"></video>""")

# visualization script: roll out the saved model for a few episodes, record
# every episode to VIDEO_FOLDER and display the most recent recording
if __name__ == "__main__":
    N_VISUAL_EPISODES = 3
    MAX_STEPS_PER_EPISODE = 200
    VIDEO_FOLDER = './videos'

    print("\n--- Starting Visualization (Colab Video Recording) ---")
    print(f"Loading model: {MODEL_SAVE_PATH}")

    # RecordVideo needs frames, hence render_mode='rgb_array'.
    base_env = make_env(ENV_ID, violation_penalty=VIOLATION_PENALTY, render_mode='rgb_array')()

    # episode_trigger always returns True, so every episode is recorded.
    viz_env = RecordVideo(base_env, VIDEO_FOLDER, episode_trigger=lambda _: True, name_prefix=f"ppo-{ENV_ID}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # try/finally closes the recorder on every exit path so that a failure
    # mid-rollout does not leave it open.
    try:
        try:
            model = PPO.load(MODEL_SAVE_PATH, device=device)
        except Exception as e:
            print(f"Error loading model: {e}")
            # Non-zero status signals the failure; exit() is only a
            # site/interactive helper.
            raise SystemExit(1) from e

        for episode in range(N_VISUAL_EPISODES):
            print(f"\nStarting visualization episode {episode + 1}/{N_VISUAL_EPISODES}")
            obs, info = viz_env.reset()
            terminated = False
            truncated = False
            step = 0
            total_reward = 0
            while not terminated and not truncated and step < MAX_STEPS_PER_EPISODE:
                action, _states = model.predict(obs, deterministic=True)
                # predict returns a NumPy array; hand the env a plain scalar.
                obs, reward, terminated, truncated, info = viz_env.step(action.item())
                total_reward += reward
                step += 1
            print(f"Episode {episode+1} finished after {step} steps. Total Reward: {total_reward:.2f}")
    finally:
        viz_env.close()

    print(f"\nVisualization finished. Videos saved in {VIDEO_FOLDER}")

    video_files = list(Path(VIDEO_FOLDER).glob("*.mp4"))
    if video_files:
        # Pick by modification time: a name sort puts "episode-10" before
        # "episode-2" and can select a stale file left by an earlier run.
        latest_video = max(video_files, key=lambda path: path.stat().st_mtime)
        print(f"Displaying video: {latest_video}")
        # Use the imported IPython display module; the bare display() name
        # only exists as a global inside IPython/Jupyter sessions.
        ipythondisplay.display(show_video(latest_video))
    else:
        print(f"No video files found in {VIDEO_FOLDER}")