In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
from pathlib import Path
import base64
from IPython import display as ipythondisplay

from minigrid.wrappers import ImgObsWrapper
try:
    from minigrid.core.constants import COLOR_NAMES
except ImportError:
    from minigrid.minigrid_env import COLOR_NAMES


from stable_baselines3 import DQN, PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.env_util import make_vec_env

# Prefix under which the trained checkpoints live. The checkpoint file names
# are appended directly (no separator is inserted), so a directory prefix
# should end with a path separator.
base_path = "GITHUB PATH"
DQN_MODEL_PATH = f"{base_path}dqn_goto_redball_logic_custom_cnn.zip"
PPO_MODEL_PATH = f"{base_path}ppo_goto_redball_logic_custom_cnn.zip"

# Echo the resolved checkpoint locations so a wrong prefix is easy to spot.
for _label, _checkpoint in (("DQN", DQN_MODEL_PATH), ("PPO", PPO_MODEL_PATH)):
    print(f"{_label} Path: {Path(_checkpoint)}")


class GoToRedBallLogicWrapper(gym.Wrapper):
    """Applies penalties based on symbolic rules for the GoToRedBall task.

    Rules checked on every ``step``; each broken rule adds ``violation_penalty``
    to that step's reward:

    * ``'pickup_wrong_ball'`` -- ``pickup`` issued while facing a ball whose
      colour is not red.
    * ``'premature_done_no_target'`` -- ``done`` issued although no red ball
      was found on the grid at reset.
    * ``'premature_done'`` -- ``done`` issued while not facing the red ball.

    Every ``step`` sets ``info['logic_violation']`` (bool) and
    ``info['logic_violation_type']`` (one of the strings above, or ``None``);
    ``reset`` sets ``info['logic_violation_count'] = 0``.
    """

    def __init__(self, env: gym.Env, violation_penalty: float = -0.5):
        super().__init__(env)
        self.penalty = violation_penalty
        try:
            self.red_color_index = COLOR_NAMES.index('red')
        except ValueError:
            # Best-effort fallback: report the problem but keep the wrapper usable.
            print(f"Error: 'red' not found in COLOR_NAMES list: {COLOR_NAMES}")
            self.red_color_index = 0
        # (x, y) grid cell of the red ball, recomputed on every reset; None if absent.
        self.red_ball_pos = None
        assert hasattr(self.env.unwrapped, 'actions'), "Environment must have an 'actions' attribute"

    @staticmethod
    def _color_index(cell) -> int:
        """Return the COLOR_NAMES index of ``cell.color`` (name or index), or -1 if unknown."""
        raw_color = getattr(cell, 'color', None)
        if isinstance(raw_color, str):
            try:
                return COLOR_NAMES.index(raw_color)
            except ValueError:
                return -1
        if isinstance(raw_color, int) and 0 <= raw_color < len(COLOR_NAMES):
            return raw_color
        return -1

    @staticmethod
    def _is_ball(cell) -> bool:
        """Return True if ``cell`` is a non-empty grid object of type ``'ball'``."""
        return bool(cell) and getattr(cell, 'type', None) == 'ball'

    def _find_red_ball(self):
        """Return the first (x, y) cell holding a red ball (column-major scan), or None."""
        grid = self.env.unwrapped.grid
        for i in range(grid.width):
            for j in range(grid.height):
                cell = grid.get(i, j)
                if self._is_ball(cell) and self._color_index(cell) == self.red_color_index:
                    return (i, j)
        return None

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.red_ball_pos = self._find_red_ball()
        info['logic_violation_count'] = 0
        return obs, info

    def step(self, action):
        unwrapped_env = self.env.unwrapped
        actions = unwrapped_env.actions

        # Inspect the cell in front *before* stepping: a successful pickup removes
        # the object from the grid, so checking afterwards only ever caught failed
        # pickups and let actual wrong-ball pickups go unpenalised.
        picking_wrong_ball = False
        if action == actions.pickup:
            fx, fy = unwrapped_env.front_pos
            cell_in_front = unwrapped_env.grid.get(fx, fy)
            picking_wrong_ball = (
                self._is_ball(cell_in_front)
                and self._color_index(cell_in_front) != self.red_color_index
            )

        obs, reward, terminated, truncated, info = self.env.step(action)
        logic_violation_type = None

        if picking_wrong_ball:
            reward += self.penalty
            logic_violation_type = 'pickup_wrong_ball'

        if action == actions.done:
            # Balls are not overlappable in MiniGrid, so the agent can never stand
            # on the red ball's cell; being "at the target" means facing it.
            facing_pos = tuple(unwrapped_env.front_pos)
            if self.red_ball_pos is None:
                reward += self.penalty
                logic_violation_type = 'premature_done_no_target'
            elif facing_pos != self.red_ball_pos:
                reward += self.penalty
                logic_violation_type = 'premature_done'

        info['logic_violation'] = logic_violation_type is not None
        info['logic_violation_type'] = logic_violation_type

        return obs, reward, terminated, truncated, info


class MiniGridCNN(BaseFeaturesExtractor):
    """Custom CNN feature extractor for small MiniGrid-like image observations (e.g. 7x7).

    Three 3x3 convolutions (stride 1, padding 1, so height and width are
    preserved), each followed by ReLU, then a flatten and a Linear+ReLU
    projection to ``features_dim``.

    :param observation_space: channel-first image space; ``shape[0]`` is used as
        the number of input channels.
    :param features_dim: size of the output feature vector.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]

        # NOTE: the attribute names (`cnn`, `linear`) and the layer order define
        # the state_dict keys stored in saved checkpoints -- keep them stable.
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened size with one dummy forward pass rather than
        # hard-coding it from the input height/width.
        with torch.no_grad():
            dummy_input = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(dummy_input).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to a ``(batch, features_dim)`` feature tensor."""
        cnn_output = self.cnn(observations.float())
        return self.linear(cnn_output)


def make_env(env_id, violation_penalty=-0.5, seed=0, render_mode=None):
    """Return a zero-argument factory that builds the wrapped environment.

    Wrapping order: ``gym.make`` -> ``ImgObsWrapper`` -> ``GoToRedBallLogicWrapper``.

    :param env_id: Gymnasium environment id.
    :param violation_penalty: forwarded to ``GoToRedBallLogicWrapper``.
    :param seed: seeds the environment (initial ``reset(seed=seed)``) and its
        action space when the environment is created; ``None`` leaves both unseeded.
    :param render_mode: forwarded to ``gym.make`` only when truthy.
    """
    def _init():
        env_kwargs = {'render_mode': render_mode} if render_mode else {}
        env = gym.make(env_id, **env_kwargs)
        env = ImgObsWrapper(env)
        env = GoToRedBallLogicWrapper(env, violation_penalty=violation_penalty)
        if seed is not None:
            # The seed argument was previously accepted but never applied.
            env.reset(seed=seed)
            env.action_space.seed(seed)
        return env
    return _init

# --- Evaluation configuration -------------------------------------------
ENV_ID = 'BabyAI-GoToRedBallGrey-v0'   # Gymnasium id for every environment built below
EVAL_SEED = 42                          # base seed handed to the evaluation routine
N_EVAL_EPISODES = 50                    # episodes run per evaluated model
MAX_STEPS_PER_EPISODE = 200             # step cap enforced by the evaluation loop

# Presumably these must match the values used when the checkpoints were
# trained -- TODO confirm against the training script.
FEATURES_DIM = 64                       # output width of the MiniGridCNN extractor
VIOLATION_PENALTY = -0.5                # reward added per logic-rule violation

def evaluate_model_detailed(model, env_id, n_eval_episodes, max_steps, violation_penalty, seed, features_dim):
    """
    Evaluates an agent, collecting rewards, steps, and logic violations.

    A single (non-vectorised) environment is built with ``make_env``, and
    episode ``k`` is reset with ``seed + k``, so calls with the same ``seed``
    replay the same episodes. Actions come from
    ``model.predict(obs, deterministic=True)``. 3-D Box observations (assumed
    height x width x channels) are transposed to channel-first and given a
    batch axis before prediction; other observations only get a batch axis.

    :param model: object exposing ``predict(obs, deterministic=...)`` (e.g. an SB3 model).
    :param env_id: Gymnasium environment id.
    :param n_eval_episodes: number of episodes to run.
    :param max_steps: step cap per episode, applied in addition to the
        environment's own termination/truncation.
    :param violation_penalty: penalty forwarded to the logic wrapper.
    :param seed: base seed; episode ``k`` uses ``seed + k``.
    :param features_dim: unused; kept so existing call sites keep working.
    :returns: dict with mean/std of per-episode reward, length and violation
        count, ``total_violations_by_type`` (violation type -> count summed over
        all episodes) and ``all_rewards`` (list of per-episode rewards).
    """
    eval_env = make_env(env_id, violation_penalty=violation_penalty, seed=seed)()

    obs_space = eval_env.observation_space
    transpose_needed = isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 3

    episode_rewards = []
    episode_lengths = []
    episode_violations = []
    total_violations_map = {}

    # try/finally so the environment is released even if predict/step raises.
    try:
        for episode in range(n_eval_episodes):
            obs, info = eval_env.reset(seed=seed + episode)
            terminated = truncated = False
            step = 0
            episode_reward = 0
            violations_this_episode = 0

            while not (terminated or truncated) and step < max_steps:
                if transpose_needed:
                    processed_obs = np.transpose(obs, (2, 0, 1))[None]
                else:
                    processed_obs = obs[None]

                action, _ = model.predict(processed_obs, deterministic=True)
                # predict returns an array for batched input; the raw env wants a scalar.
                if isinstance(action, (np.ndarray, np.number)):
                    action = action.item()

                obs, reward, terminated, truncated, info = eval_env.step(action)

                episode_reward += reward
                step += 1

                if info.get('logic_violation', False):
                    violations_this_episode += 1
                    violation_type = info.get('logic_violation_type', 'unknown')
                    total_violations_map[violation_type] = total_violations_map.get(violation_type, 0) + 1

            episode_rewards.append(episode_reward)
            episode_lengths.append(step)
            episode_violations.append(violations_this_episode)
    finally:
        eval_env.close()

    return {
        "mean_reward": np.mean(episode_rewards),
        "std_reward": np.std(episode_rewards),
        "mean_length": np.mean(episode_lengths),
        "std_length": np.std(episode_lengths),
        "mean_violations": np.mean(episode_violations),
        "std_violations": np.std(episode_violations),
        "total_violations_by_type": total_violations_map,
        "all_rewards": episode_rewards,
    }


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _make_custom_objects():
    """Return a fresh ``custom_objects`` override forcing the MiniGridCNN extractor on load."""
    return {
        "policy_kwargs": dict(
            features_extractor_class=MiniGridCNN,
            features_extractor_kwargs=dict(features_dim=FEATURES_DIM),
        ),
    }


# Separate dict instances so that mutating one cannot affect the other.
custom_objects_dqn = _make_custom_objects()
custom_objects_ppo = _make_custom_objects()


def _load_model(algo_cls, model_path, custom_objects, label):
    """Load an SB3 checkpoint together with a single-env, channel-first vec env.

    Returns ``(model, env)``. Both are ``None`` when the file is missing or
    loading fails; errors are printed rather than raised so the other model can
    still be evaluated. On failure any env already created is closed.
    The env is tracked in a local variable, so a stale env left over from a
    previous notebook run can never be closed by mistake.
    """
    print(f"\nLoading {label} model...")
    env = None
    try:
        if not Path(model_path).exists():
            print(f"{label} model file not found at {model_path}")
            return None, None
        env = make_vec_env(
            make_env(ENV_ID, violation_penalty=VIOLATION_PENALTY),
            n_envs=1,
            vec_env_cls=DummyVecEnv,
        )
        env = VecTransposeImage(env)
        model = algo_cls.load(model_path, env=env, custom_objects=custom_objects, device=device)
        print(f"{label} model loaded successfully.")
        return model, env
    except Exception as e:  # top-level notebook boundary: report and continue
        print(f"Error loading {label} model: {e}")
        if env is not None:
            env.close()
        return None, None


model_dqn, temp_env_dqn = _load_model(DQN, DQN_MODEL_PATH, custom_objects_dqn, "DQN")
model_ppo, temp_env_ppo = _load_model(PPO, PPO_MODEL_PATH, custom_objects_ppo, "PPO")


def _run_timed_evaluation(model, label, seed):
    """Evaluate ``model`` with the module-level settings and report wall-clock time.

    Returns the ``evaluate_model_detailed`` results dict, or ``None`` (after
    printing a skip message) when ``model`` is ``None``.
    """
    if model is None:
        print(f"\nSkipping {label} evaluation as model failed to load.")
        return None
    print(f"\nEvaluating {label} model for {N_EVAL_EPISODES} episodes...")
    start_time = time.perf_counter()
    results = evaluate_model_detailed(
        model, ENV_ID, N_EVAL_EPISODES, MAX_STEPS_PER_EPISODE, VIOLATION_PENALTY, seed, FEATURES_DIM
    )
    print(f"{label} Evaluation finished in {time.perf_counter() - start_time:.2f} seconds.")
    return results


# Both agents are evaluated on the *same* seeds, i.e. the same episode layouts,
# so the comparison is paired rather than confounded by per-layout difficulty.
results_dqn = _run_timed_evaluation(model_dqn, "DQN", EVAL_SEED)
results_ppo = _run_timed_evaluation(model_ppo, "PPO", EVAL_SEED)


# (label, results) pairs; a falsy entry means that model was not evaluated.
labelled_results = (("DQN", results_dqn), ("PPO", results_ppo))

# One summary row per evaluated model.
comparison_data = []
for model_name, model_results in labelled_results:
    if not model_results:
        continue
    comparison_data.append({
        "Model": model_name,
        "Mean Reward": model_results['mean_reward'],
        "Std Reward": model_results['std_reward'],
        "Mean Length": model_results['mean_length'],
        "Std Length": model_results['std_length'],
        "Mean Violations": model_results['mean_violations'],
        "Violation Details": model_results['total_violations_by_type'],
    })

if not comparison_data:
    print("\nNo evaluation results to display.")
else:
    df_comparison = pd.DataFrame(comparison_data).set_index("Model")

    print("\n\n--- Performance Comparison ---")
    # Render numeric columns with two decimals (this turns them into strings).
    float_cols = ["Mean Reward", "Std Reward", "Mean Length", "Std Length", "Mean Violations"]
    for column_name in (c for c in float_cols if c in df_comparison.columns):
        df_comparison[column_name] = df_comparison[column_name].map('{:.2f}'.format)

    print(df_comparison)

    # Long-format table: one row per (model, episode reward) for the box plot.
    plot_data = []
    for model_name, model_results in labelled_results:
        if model_results:
            plot_data.extend(
                {"Model": model_name, "Episode Reward": episode_return}
                for episode_return in model_results['all_rewards']
            )

    if not plot_data:
        print("\nNo data available for plotting reward distribution.")
    else:
        df_plot = pd.DataFrame(plot_data)
        plt.figure(figsize=(10, 6))
        sns.boxplot(x="Model", y="Episode Reward", data=df_plot)
        plt.title(f'Reward Distribution per Episode ({N_EVAL_EPISODES} Episodes)')
        plt.ylabel("Total Reward (including penalties)")
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.show()


print("\nComparison script finished.")