# In [None]:
# MARL with IQL (CNN) - Pursuit Environment
# Independent Q-Learning adapted for the PettingZoo SISL Pursuit environment.
# This version uses a CNN for processing image-like observations from Pursuit.
# Basic imports
import random
from collections import OrderedDict  # For CqlAgentConfig if used, not primary here
from collections.abc import Callable  # Used in the eval/update helper annotations
from itertools import count
from pprint import pprint

import gymnasium as gym  # PettingZoo uses Gymnasium spaces
import numpy as np
import torch
import torch.nn as nn  # Used for nn.Module annotations; agent.py and dqn.py use it too

# Plotting imports
import matplotlib
import matplotlib.pyplot as plt

# Environment and Agent imports
from pettingzoo.sisl import pursuit_v4  # MODIFIED: Import Pursuit environment
from agent import DqnAgent, DqnAgentConfig  # MODIFIED: Assuming agent.py is updated
# DQN_CNN and DQN_MLP are imported within agent.py from dqn.py

# Detect an inline (notebook) matplotlib backend; the IPython display helper
# is only imported in that case. Interactive plotting is switched on either way.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()

# Pick the compute device in priority order: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

# Environment Setup
# Parameters for the Pursuit environment (example values, tune as needed).
MAX_CYCLES_ENV = 500
N_EVADERS = 5
N_PURSUERS = 4  # Number of agents we will control
OBS_RANGE = 7  # Affects observation HxW; H=W=OBS_RANGE
N_CATCH = 2  # Number of pursuers needed to catch an evader

# Build the parallel-API Pursuit environment from the constants above.
env = pursuit_v4.parallel_env(
    max_cycles=MAX_CYCLES_ENV,
    n_evaders=N_EVADERS,
    n_pursuers=N_PURSUERS,
    obs_range=OBS_RANGE,
    n_catch=N_CATCH,
    render_mode='rgb_array',  # For potential rendering/observation consistency
)
env.reset()  # env.agents is read immediately after this call

# Controlled agents are the env agents whose id contains "pursuer".
# NOTE(review): presumably only pursuers are exposed as agents in pursuit_v4 -- confirm.
ALL_AGENT_IDS_FROM_ENV = env.agents
CONTROLLED_AGENT_KEYS = sorted(
    filter(lambda agent_id: "pursuer" in agent_id, ALL_AGENT_IDS_FROM_ENV)
)

if len(CONTROLLED_AGENT_KEYS) == 0:
    raise ValueError("No pursuers found to control. Check N_PURSUERS or agent naming in the environment.")
num_controlled_agents = len(CONTROLLED_AGENT_KEYS)
print(f"Controlling {num_controlled_agents} pursuer(s): {CONTROLLED_AGENT_KEYS}")

# Per-agent observation shapes and action counts.
# The env reports observations as (H, W, C); they are stored here as (C, H, W)
# so they match PyTorch's channel-first layout.
# NOTE(review): the observation value range/dtype is not checked here -- confirm
# against the Pursuit docs before relying on a 0-255 image assumption.
obs_shapes = {}  # agent_key -> (C, H, W)
act_dims = {}  # agent_key -> number of discrete actions

for agent_key in CONTROLLED_AGENT_KEYS:
    obs_space = env.observation_space(agent_key)
    act_space = env.action_space(agent_key)

    # Only rank-3 Box observation spaces are accepted.
    if not (isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 3):
        raise ValueError(f"Expected Box observation space with 3 dims (H,W,C) for agent {agent_key}, got {obs_space}")

    h, w, c = obs_space.shape  # Original HWC
    obs_shapes[agent_key] = (c, h, w)
    act_dims[agent_key] = int(act_space.n)

print(f"Observation Shapes (C, H, W) for controlled agents: {obs_shapes}")
print(f"Action Dimensions for controlled agents: {act_dims}")

# Example observation check
# temp_states_dict, _ = env.reset()
# if CONTROLLED_AGENT_KEYS:
#     example_obs_raw = temp_states_dict[CONTROLLED_AGENT_KEYS[0]]
#     print(f"Example raw observation for {CONTROLLED_AGENT_KEYS[0]} - shape: {example_obs_raw.shape}, dtype: {example_obs_raw.dtype}")
# del temp_states_dict, example_obs_raw

# Agent Initialization
USE_CNN = True  # Forwarded to the config as use_cnn; True selects the CNN path

# One DqnAgentConfig per controlled agent (hyperparameters are starting points).
AGENT_CONFIGS = {
    agent_key: DqnAgentConfig(
        # CNN path gets the (C, H, W) shape; MLP path gets the flattened size.
        obs_shape=obs_shapes[agent_key] if USE_CNN else None,
        obs_dim=None if USE_CNN else int(np.prod(obs_shapes[agent_key])),
        act_dim=act_dims[agent_key],
        hidden_dim=256,  # Num units in FC layer after CNN, or hidden_dim for MLP
        batch_size=32,  # Might need adjustment based on memory
        lr=1e-4,  # CNNs often benefit from smaller LRs
        grad_clip_value=1.0,  # Common for DQN
        gamma=0.99,
        eps_start=1.0,
        eps_decay=0.9995,  # Slower decay for more complex tasks
        eps_min=0.05,
        mem_size=50000,  # Replay memory size
        use_cnn=USE_CNN,
    )
    for agent_key in CONTROLLED_AGENT_KEYS
}

# Instantiate one independent learner per config; each gets its own action
# sampler taken from the environment's action space for that agent.
cur_agents = {}
for agent_key, agent_config in AGENT_CONFIGS.items():
    cur_agents[agent_key] = DqnAgent(
        sid=agent_key,
        config=agent_config,
        act_sampler=env.action_space(agent_key).sample,
        device=DEVICE,
    )

if not cur_agents:
    print("No agents were created.")
else:
    print(f"Created {len(cur_agents)} DqnAgent(s). First agent's policy network:")
    print(next(iter(cur_agents.values())).policy_net)


# Helper Function for State Preprocessing (for CNN)
def preprocess_observation(obs_hwc: np.ndarray, device: torch.device) -> torch.Tensor:
    """Turn one (H, W, C) observation into a float32 (1, C, H, W) tensor.

    Values are divided by 255.0 and the result is placed on ``device``.

    Raises:
        ValueError: if ``obs_hwc`` is None.
    """
    if obs_hwc is None:  # Should not happen if agent is active
        raise ValueError("Received None observation for an active agent.")

    # Cast to float32 and scale by 1/255.
    scaled = np.asarray(obs_hwc, dtype=np.float32) / 255.0

    # Channel-last -> channel-first, then prepend a batch axis of size 1.
    hwc_tensor = torch.tensor(scaled, device=device)
    chw_tensor = hwc_tensor.permute(2, 0, 1)
    return chw_tensor[None]


# Evaluation Function
# Assuming gym is imported if needed by PettingZoo, or import directly



def _run_eval_episode(eval_env, dqn_agents_dict, networks_to_use, eval_device, render_mode_eval):
    """Play one greedy episode in ``eval_env`` and return each agent's summed reward."""
    raw_states_dict, _ = eval_env.reset()

    # Only agents that actually received an initial observation get a state.
    current_states = {
        agent_key: preprocess_observation(raw_states_dict[agent_key], eval_device)
        for agent_key in dqn_agents_dict
        if agent_key in raw_states_dict
    }
    episode_dones = dict.fromkeys(dqn_agents_dict, False)
    episode_rewards = dict.fromkeys(dqn_agents_dict, 0.0)

    while True:
        if render_mode_eval == 'human':
            eval_env.render()

        # Greedy actions for every agent that is still running and has a state.
        # Evaluation never backpropagates, so gradient tracking is switched off.
        with torch.no_grad():
            actions_to_env = {
                agent_key: agent_instance.select_action_greedy(
                    current_states[agent_key], networks_to_use[agent_key]
                ).item()
                for agent_key, agent_instance in dqn_agents_dict.items()
                if agent_key in current_states and not episode_dones[agent_key]
            }

        if not actions_to_env:  # Nobody left to act
            break

        next_obs, rewards, terminations, truncations, _ = eval_env.step(actions_to_env)

        next_states = {}
        for agent_key in dqn_agents_dict:
            if episode_dones[agent_key]:
                # Finished earlier: no more reward is counted and no state is kept,
                # because a done agent is never asked for an action again.
                continue

            episode_rewards[agent_key] += rewards.get(agent_key, 0)
            episode_dones[agent_key] = (
                terminations.get(agent_key, False) or truncations.get(agent_key, False)
            )
            if episode_dones[agent_key]:
                continue

            if agent_key in next_obs:
                next_states[agent_key] = preprocess_observation(next_obs[agent_key], eval_device)
            elif agent_key in current_states:
                # No fresh observation for a still-active agent: keep its last state.
                next_states[agent_key] = current_states[agent_key]

        current_states = next_states

        if all(episode_dones.values()):
            break

    return episode_rewards


def eval_agent(
        eval_env_lambda: Callable,  # Lambda function to create a new eval environment
        dqn_agents_dict: dict[str, DqnAgent],
        networks_to_use: dict[str, nn.Module],  # agent_key -> network (e.g., policy_net or target_net)
        n_episodes: int = 1,
        render_mode_eval: str | None = None  # 'human' or 'rgb_array'
) -> dict[str, list[float]]:
    """Run greedy evaluation episodes and collect per-agent episode returns.

    A fresh environment is built with ``eval_env_lambda`` for every episode and is
    always closed afterwards, even if the episode raises.

    Args:
        eval_env_lambda: Zero-argument callable returning a new environment that
            exposes ``reset()``, ``step(actions)``, ``render()`` and ``close()``.
        dqn_agents_dict: Maps agent key to the agent whose
            ``select_action_greedy(state, network)`` picks the action.
        networks_to_use: Maps agent key to the network passed to
            ``select_action_greedy`` for that agent.
        n_episodes: Number of episodes to play.
        render_mode_eval: When equal to ``'human'``, ``render()`` is called before
            every step. The environment's own render mode is fixed by
            ``eval_env_lambda``; this argument does not change it.

    Returns:
        Dict mapping each agent key to a list with one summed reward per episode.
        An empty ``dqn_agents_dict`` yields an empty dict.
    """
    cumulative_rewards_per_agent = {agent_key: [] for agent_key in dqn_agents_dict}
    if not dqn_agents_dict:
        # Nothing to evaluate (and no agent to take the device from).
        return cumulative_rewards_per_agent

    eval_device = next(iter(dqn_agents_dict.values())).device  # Device of the first agent

    for _ in range(n_episodes):
        eval_env = eval_env_lambda()  # Create a fresh environment
        try:
            episode_rewards = _run_eval_episode(
                eval_env, dqn_agents_dict, networks_to_use, eval_device, render_mode_eval
            )
        finally:
            # Each episode owns its environment, so release it unconditionally.
            eval_env.close()

        for agent_key, total_reward in episode_rewards.items():
            cumulative_rewards_per_agent[agent_key].append(total_reward)

    return cumulative_rewards_per_agent


def get_agent_wise_cumulative_rewards(cumulative_rewards_dict: dict[str, list[float]]) -> dict[str, float]:
    """Return the mean of each agent's reward list; an empty list maps to 0.0."""
    mean_by_agent = {}
    for agent_key, rewards_list in cumulative_rewards_dict.items():
        if rewards_list:
            mean_by_agent[agent_key] = sum(rewards_list) / len(rewards_list)
        else:
            mean_by_agent[agent_key] = 0.0
    return mean_by_agent


# Baseline Evaluation (Before Training)
# Scores the untrained agents so training has a reference point;
# all_avg_baseline_res is later used as the initial "best" reward.
if cur_agents:
    print("Running baseline evaluation...")
    # Lambda to create pursuit environment for evaluation
    eval_env_lambda_pursuit = lambda: pursuit_v4.parallel_env(
        max_cycles=MAX_CYCLES_ENV, n_evaders=N_EVADERS, n_pursuers=num_controlled_agents,
        obs_range=OBS_RANGE, n_catch=N_CATCH, render_mode='rgb_array'  # Eval does not render human by default
    )

    baseline_eval_res = eval_agent(
        eval_env_lambda_pursuit,
        cur_agents,
        # eval_agent's parameter is named networks_to_use; evaluate the target nets.
        networks_to_use={agent.sid: agent.target_net for agent in cur_agents.values()},
        n_episodes=10  # Number of episodes for baseline evaluation
    )
    avg_baseline_res = get_agent_wise_cumulative_rewards(baseline_eval_res)
    # Pretty-print the dict itself; pprint on a formatted string only prints its quoted repr.
    print("Average baseline rewards per agent:")
    pprint(avg_baseline_res)
    all_avg_baseline_res = sum(avg_baseline_res.values()) / len(avg_baseline_res) if avg_baseline_res else 0.0
    print(f"Overall average baseline reward: {all_avg_baseline_res}")
else:
    print("No agents to evaluate for baseline.")
    # -inf so that any later evaluated reward counts as an improvement.
    all_avg_baseline_res = float('-inf')


# Plotting Module
def plot_episodes(avg_returns: list[float], title_suffix: str = ''):
    """Redraw figure 1 with the per-episode returns and a 10-episode rolling mean.

    Args:
        avg_returns: One average return per logged episode, in order.
        title_suffix: Text appended to the 'Training... ' plot title.
    """
    plt.figure(1)
    plt.clf()
    returns_t = torch.tensor(avg_returns, dtype=torch.float)
    plt.title(f'Training... {title_suffix}')
    plt.xlabel('Episode')
    plt.ylabel('Avg. Return (across all controlled agents)')
    # Label the raw curve so the legend always has an entry, including the
    # first nine calls when the rolling mean is not drawn yet.
    plt.plot(returns_t.numpy(), label='Episode Avg.')

    if len(returns_t) >= 10:  # Plot 10-episode rolling mean
        means = returns_t.unfold(0, 10, 1).mean(1).reshape(-1)
        # The first 9 positions have no full window; NaN keeps the x-axes aligned.
        means = torch.cat((torch.full((9,), float('nan')), means))
        plt.plot(means.numpy(), label='10-ep Avg.')
    plt.legend()
    plt.pause(0.001)  # Give the GUI event loop a moment to draw
    if is_ipython:
        display.display(plt.gcf())
        display.clear_output(wait=True)


# Training Setup
N_TRAIN_EPISODES = 500  # MODIFIED: Number of training episodes
MAX_TRAIN_EPISODE_STEPS = MAX_CYCLES_ENV  # Max steps per episode, from env
DQN_UPDATE_TARGET_FREQ_STEPS = 200  # MODIFIED: How often (in total steps) to consider updating target net

# Lambda for creating training environment
train_env_lambda = lambda: pursuit_v4.parallel_env(
    max_cycles=MAX_CYCLES_ENV, n_evaders=N_EVADERS, n_pursuers=num_controlled_agents,
    obs_range=OBS_RANGE, n_catch=N_CATCH, render_mode='rgb_array'
)
# Re-initialize main training env using the lambda for consistency.
# Close the environment created during setup first so it is not leaked
# when the name `env` is rebound.
env.close()
env = train_env_lambda()


# Target Network Update Logic
def update_all_agents_target_dqns(
        current_agents: dict[str, DqnAgent],
        eval_env_fn: Callable,
        current_best_mean_reward: float,
        n_eval_episodes: int = 5
) -> float:
    """Evaluate the policy nets and update the target nets if performance improved.

    The policy networks are scored with ``eval_agent``; the per-agent mean
    rewards are averaged into one number. Only when that number is strictly
    greater than ``current_best_mean_reward`` does every agent's
    ``update_target_network()`` get called.

    Args:
        current_agents: Maps agent key to agent; each agent's ``sid`` is used as
            the key of the network dict handed to ``eval_agent``.
        eval_env_fn: Zero-argument callable that builds an evaluation environment.
        current_best_mean_reward: Best overall mean reward seen so far.
        n_eval_episodes: Number of evaluation episodes to run.

    Returns:
        The new best overall mean reward (unchanged when there was no improvement).
    """
    print("Evaluating policy networks for potential target network update...")
    policy_eval_res = eval_agent(
        eval_env_fn,
        current_agents,
        # Must match eval_agent's parameter name (networks_to_use).
        networks_to_use={agent.sid: agent.policy_net for agent in current_agents.values()},
        n_episodes=n_eval_episodes
    )
    avg_policy_rewards = get_agent_wise_cumulative_rewards(policy_eval_res)
    # With no agents there is nothing to average; -inf can never beat the current best.
    overall_avg_policy_reward = sum(avg_policy_rewards.values()) / len(
        avg_policy_rewards) if avg_policy_rewards else float('-inf')

    print(
        f"Policy net eval: current avg reward {overall_avg_policy_reward:.2f} vs best mean {current_best_mean_reward:.2f}")
    if overall_avg_policy_reward > current_best_mean_reward:
        print(f"Improvement found! Updating target networks. New best mean: {overall_avg_policy_reward:.2f}")
        current_best_mean_reward = overall_avg_policy_reward
        for agent in current_agents.values():
            agent.update_target_network()
    else:
        print("No improvement, target networks not updated.")
    return current_best_mean_reward


# Training Loop
# Independent learners: every controlled agent picks its own action, stores its
# own transition and runs its own training step after each environment step.
if not cur_agents:
    print("No agents defined. Skipping training.")
else:
    total_steps_done = 0
    best_overall_mean_reward = all_avg_baseline_res  # Initialize with baseline
    # Seed the plotted curve with the baseline score when one was measured.
    episode_mean_returns_log = [all_avg_baseline_res] if all_avg_baseline_res != float('-inf') else []

    print(f"Starting training for {N_TRAIN_EPISODES} episodes...")
    for i_episode in range(N_TRAIN_EPISODES):
        raw_states_dict, _ = env.reset()

        current_states_processed = {
            agent_key: preprocess_observation(raw_states_dict[agent_key], DEVICE)
            for agent_key in CONTROLLED_AGENT_KEYS if agent_key in raw_states_dict
        }

        # Dones for controlled agents for the current episode
        episode_dones = {agent_key: False for agent_key in CONTROLLED_AGENT_KEYS}
        episode_total_reward_for_plot = 0

        for _ in range(MAX_TRAIN_EPISODE_STEPS):
            # Action tensors (kept for replay memory) for every agent that is
            # still running and has a current state.
            current_actions_this_step = {
                agent_key: cur_agents[agent_key].select_action(current_states_processed[agent_key])
                for agent_key in CONTROLLED_AGENT_KEYS
                if not episode_dones[agent_key] and agent_key in current_states_processed
            }

            if not current_actions_this_step:  # All controlled agents are done
                break

            # The environment expects plain Python ints, not tensors.
            actions_for_env_step = {
                agent_key: action_tensor.item()
                for agent_key, action_tensor in current_actions_this_step.items()
            }

            raw_next_obs_dict, raw_rewards_dict, raw_terminations_dict, raw_truncations_dict, _ = env.step(
                actions_for_env_step)

            next_states_processed_this_step = {}

            for agent_key in CONTROLLED_AGENT_KEYS:
                # Agents that finished on an earlier step took no action this
                # step, so there is nothing to store or train on for them.
                if episode_dones[agent_key]:
                    continue

                agent = cur_agents[agent_key]

                reward_val = raw_rewards_dict.get(agent_key, 0)
                episode_total_reward_for_plot += reward_val  # Summing rewards for plotting average later

                terminated = raw_terminations_dict.get(agent_key, False)
                truncated = raw_truncations_dict.get(agent_key, False)
                is_done_this_step = terminated or truncated

                # A None next state marks a terminal transition in replay memory
                # (assumes DqnAgent treats None as "do not bootstrap" -- confirm in agent.py).
                # Only true termination is terminal. A time-limit truncation keeps
                # its next observation so the value target can still bootstrap from it.
                next_state_for_memory = None
                if not terminated and agent_key in raw_next_obs_dict:
                    next_state_for_memory = preprocess_observation(raw_next_obs_dict[agent_key], DEVICE)

                # The agent only keeps acting if the episode is not over for it.
                if not is_done_this_step and next_state_for_memory is not None:
                    next_states_processed_this_step[agent_key] = next_state_for_memory

                # Memorize, only if an action was taken by this agent this step
                if agent_key in current_actions_this_step:
                    agent.memorize(
                        current_states_processed[agent_key],
                        current_actions_this_step[agent_key],
                        next_state_for_memory,
                        torch.tensor([[reward_val]], device=DEVICE, dtype=torch.float32)
                    )

                agent.train()  # Try to train agent
                episode_dones[agent_key] = is_done_this_step  # Update done status

            current_states_processed = next_states_processed_this_step  # Move to next state

            # Update target network periodically
            if total_steps_done % DQN_UPDATE_TARGET_FREQ_STEPS == 0 and total_steps_done > 0:
                best_overall_mean_reward = update_all_agents_target_dqns(
                    cur_agents, train_env_lambda, best_overall_mean_reward
                )

            for agent in cur_agents.values():  # Epsilon decay for all controlled agents
                agent.update_eps()

            total_steps_done += 1

            if all(episode_dones.values()):  # If all controlled agents are done
                break

        # End of episode actions
        # Always try to update target networks at the end of an episode based on policy net performance
        best_overall_mean_reward = update_all_agents_target_dqns(
            cur_agents, train_env_lambda, best_overall_mean_reward
        )

        # Log episode return for plotting (average across controlled agents)
        avg_episode_return = episode_total_reward_for_plot / num_controlled_agents if num_controlled_agents > 0 else 0
        episode_mean_returns_log.append(avg_episode_return)
        plot_episodes(episode_mean_returns_log)

        print(
            f"Episode {i_episode + 1}/{N_TRAIN_EPISODES} finished. Avg Return: {avg_episode_return:.2f}. Epsilon (agent 0): {cur_agents[CONTROLLED_AGENT_KEYS[0]].eps:.3f}. Total steps: {total_steps_done}")

    print("Training finished.")
    plot_episodes(episode_mean_returns_log, title_suffix=" - Final")  # Final plot
    plt.ioff()  # Turn off interactive plotting
    plt.show()

# Post-Training Evaluation
# Scores the target networks after training, using environments built the
# same way as the training environment.
if cur_agents:
    print("Running post-training evaluation...")
    post_train_eval_res = eval_agent(
        train_env_lambda,  # Use the same lambda as training for consistency
        cur_agents,
        # eval_agent's parameter is named networks_to_use; evaluate the final target nets.
        networks_to_use={agent.sid: agent.target_net for agent in cur_agents.values()},
        n_episodes=20
    )
    avg_post_train_res = get_agent_wise_cumulative_rewards(post_train_eval_res)
    # Pretty-print the dict itself; pprint on a formatted string only prints its quoted repr.
    print("Average post-training rewards per agent:")
    pprint(avg_post_train_res)
    all_avg_post_train_res = sum(avg_post_train_res.values()) / len(avg_post_train_res) if avg_post_train_res else 0.0
    print(f"Overall average post-training reward: {all_avg_post_train_res}")
else:
    print("No agents to evaluate post-training.")

# Rendered Evaluation (Optional)
# if cur_agents:
#     print("Running rendered evaluation...")
#     # For rendered eval, create an env with 'human' mode if supported and desired
#     rendered_eval_env_lambda = lambda: pursuit_v4.parallel_env(
#         max_cycles=MAX_CYCLES_ENV, n_evaders=N_EVADERS, n_pursuers=num_controlled_agents,
#         obs_range=OBS_RANGE, n_catch=N_CATCH, render_mode='human' # Set to human for visual
#     )
#     eval_agent(
#         rendered_eval_env_lambda,
#         cur_agents,
#         networks_to_use={agent.sid: agent.target_net for agent in cur_agents.values()},
#         n_episodes=1,
#         render_mode_eval='human' # Pass render mode to eval_agent
#     )
#     print("Rendered evaluation finished. (Close the window if it opened)")