# Active Avoidance Task Visualization Notebook

This notebook allows you to train an agent for the active avoidance task and then visualize its performance without having to retrain the agent each time.

## 1. Import Libraries

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from IPython.display import display, HTML

# Import the environment
from envs.active_avoidance_env_test import ActiveAvoidanceEnv2D

# Import agents
from agents.ppo_agent import PPOAgent
from agents.maxent_agent import MaxEntAgent, FisherMaxEntAgent
from agents.trpo_agent import TRPOAgent

# Import the visualization functions
from utils.avoidance_visualization import (
    plot_avoidance_training_curves,
    plot_avoidance_trajectories_2d,
    plot_avoidance_heatmap_2d,
    plot_avoidance_trajectory_step_by_step,
    plot_multiple_avoidance_trajectories
)

## 2. Configuration

Define your configuration parameters here. You can modify these to experiment with different settings.

In [2]:
# Fix the torch and NumPy RNG seeds so repeated runs start from the same state
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Output folders: figures go to 'plots', pickled agent/history to 'saved_models'
os.makedirs('plots', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

# --- Configuration ---
# Single dictionary holding every tunable value used by the cells below.
config = {
    # -- Environment --
    'height': 10,
    'width': 20,
    'tone_duration_steps': 30,
    'shock_delay_steps': 30,
    'max_steps_per_episode': 100,
    'initial_task': ActiveAvoidanceEnv2D.AVOID_TONE_1,

    # -- Agent --
    'agent_class': MaxEntAgent,
    'agent_name': 'MaxEnt_RNN',   # used in saved-model and plot file names
    'policy_type': 'rnn',
    'hidden_dim': 128,            # read for 'rnn' / 'transformer' policies
    # 'hidden_dims': [128, 128],  # read for 'mlp' policies instead of hidden_dim
    'lr': 0.001,
    'gamma': 0.993,
    # 'epsilon': 0.2,             # read only when agent_class is PPOAgent
    'temperature': 0.15,          # read only for MaxEntAgent / FisherMaxEntAgent
    # 'kl_delta': 0.01,           # read only when agent_class is TRPOAgent
    'rnn_type': 'gru',            # 'gru', 'lstm' or 'rnn'

    # -- Training --
    'num_episodes': 2000,
    'task_switch_episode': 1000
}

# Echo the settings; the agent class is shown by name instead of its repr
print("Configuration:")
for name, setting in config.items():
    shown = setting.__name__ if name == 'agent_class' else setting
    print(f"  {name}: {shown}")

Configuration:
  height: 10
  width: 20
  tone_duration_steps: 30
  shock_delay_steps: 30
  max_steps_per_episode: 100
  initial_task: 1
  agent_class: MaxEntAgent
  agent_name: MaxEnt_RNN
  policy_type: rnn
  hidden_dim: 128
  lr: 0.001
  gamma: 0.993
  temperature: 0.15
  rnn_type: gru
  num_episodes: 2000
  task_switch_episode: 1000


## 3. Create Environment and Agent

Initialize the environment and agent based on the configuration.

In [3]:
# Pick the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): this cell only prints `device`; the agent is not moved to it here.
# The training cell calls `.to(device)` on the initial RNN hidden state, so confirm
# that the agent places its networks on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

env = ActiveAvoidanceEnv2D(
    height=config['height'],
    width=config['width'],
    tone_duration_steps=config['tone_duration_steps'],
    shock_delay_steps=config['shock_delay_steps'],
    max_steps_per_episode=config['max_steps_per_episode'],
    initial_task=config['initial_task']
)

# Read the dimensions from the environment instead of assuming them
# (the recorded run of this notebook reports a state dim of 5).
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# --- Create Agent ---
AgentClass = config['agent_class']

# Keyword arguments shared by every agent class
agent_params = {
    'policy_type': config['policy_type'],
    'state_dim': state_dim,
    'action_dim': action_dim,
    'lr': config['lr'],
    'gamma': config['gamma']
}

# Network-size arguments depend on the policy architecture
if config['policy_type'] == 'mlp':
    agent_params['hidden_dims'] = config.get('hidden_dims', [128, 128])
elif config['policy_type'] in ['rnn', 'transformer']:
    agent_params['hidden_dim'] = config['hidden_dim']
    if config['policy_type'] == 'rnn':
        agent_params['rnn_type'] = config.get('rnn_type', 'gru')

# Each algorithm needs exactly one extra hyperparameter from `config`.
# Agent classes not listed here receive no algorithm-specific argument.
algo_param_by_agent = {
    PPOAgent: 'epsilon',
    MaxEntAgent: 'temperature',
    FisherMaxEntAgent: 'temperature',
    TRPOAgent: 'kl_delta',
}
algo_param = algo_param_by_agent.get(AgentClass)
if algo_param is not None:
    # Fail with an actionable message when the matching entry is still
    # commented out in the configuration cell.
    if algo_param not in config:
        raise KeyError(
            f"config['{algo_param}'] is required for {AgentClass.__name__}; "
            f"add or uncomment it in the Configuration cell."
        )
    agent_params[algo_param] = config[algo_param]

# Instantiate the agent
agent = AgentClass(**agent_params)

print(f"Created {config['agent_name']} with {config['policy_type']} policy.")
print(f"State dim: {state_dim}, Action dim: {action_dim}")

Using device: cpu
Created MaxEnt_RNN with rnn policy.
State dim: 5, Action dim: 5


## 4. Train the Agent

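Train the agent and save the training history and model. You only need to train once: on later runs, set `LOAD_AGENT = True` (and optionally `TRAIN_AGENT = False`) in the cell below to reload the saved agent and history instead of retraining.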

In [None]:
# Set to True if you want to train the agent
TRAIN_AGENT = True

# Set to True if you want to load a previously trained agent
LOAD_AGENT = False

# Path to save/load the agent and training history
agent_save_path = f"saved_models/{config['agent_name']}_agent.pkl"
history_save_path = f"saved_models/{config['agent_name']}_history.pkl"

# Loading needs both files; say so explicitly instead of silently falling through
saved_files_exist = os.path.exists(agent_save_path) and os.path.exists(history_save_path)
if LOAD_AGENT and not saved_files_exist:
    fallback = "training a new agent instead" if TRAIN_AGENT else "nothing will be loaded"
    print(f"LOAD_AGENT is True but {agent_save_path} and/or {history_save_path} "
          f"was not found; {fallback}.")

if LOAD_AGENT and saved_files_exist:
    # SECURITY: pickle.load can execute arbitrary code. Only load files that
    # this notebook (or another trusted source) wrote.
    print(f"Loading agent from {agent_save_path}...")
    with open(agent_save_path, 'rb') as f:
        agent = pickle.load(f)

    print(f"Loading training history from {history_save_path}...")
    with open(history_save_path, 'rb') as f:
        training_history = pickle.load(f)

    rewards_history = training_history['rewards']
    losses_history = training_history['losses']
    metric_history = training_history['metrics']
    avoidance_rate_history = training_history['avoidance_rates']
    shock_rate_history = training_history['shock_rates']

    print("Agent and training history loaded successfully.")

elif TRAIN_AGENT:
    # --- Training ---
    # NOTE(review): `env` and `agent` come from the previous cell and keep their
    # state, so re-running this cell without re-creating them continues from the
    # already-trained agent and whatever task the environment was left on.
    print("Starting training...")
    rewards_history = []
    losses_history = []
    metric_history = []
    avoidance_rate_history = []
    shock_rate_history = []

    is_recurrent = config['policy_type'] == "rnn"
    # Label for the logged metric; it does not change during training
    metric_name = 'Entropy' if isinstance(agent, (MaxEntAgent, FisherMaxEntAgent)) else 'KL'

    for episode in range(config['num_episodes']):
        # Task Switching Logic
        if 'task_switch_episode' in config and episode == config['task_switch_episode']:
            print(f"\n--- Switching Task at Episode {episode} ---")
            env.switch_task()

        # Run Episode
        state = env.reset()
        states, actions, rewards, log_probs_old_list = [], [], [], []
        hidden_states_list = []
        episode_reward = 0
        # Used as-is when the step loop ends without the environment reporting `done`
        ep_info = {'avoided': False, 'shocked': False}

        hidden_state = None
        if is_recurrent and hasattr(agent.policy_net, 'init_hidden'):
            # NOTE(review): assumes init_hidden() returns a single tensor; an LSTM
            # that returns an (h, c) tuple would not support `.to` — confirm.
            hidden_state = agent.policy_net.init_hidden().to(device)

        for step in range(config['max_steps_per_episode']):
            # Select action
            if is_recurrent:
                action, log_prob, _, h_new = agent.select_action(state, hidden_state)
                # Store the hidden state that was fed INTO this step
                hidden_states_list.append(hidden_state)
                hidden_state = h_new
            else:
                action, log_prob, _ = agent.select_action(state)

            # Step environment
            next_state, reward, done, info = env.step(action)

            # Store transition
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            log_probs_old_list.append(log_prob)  # Used by PPO/TRPO, ignored by MaxEnt update

            episode_reward += reward
            state = next_state

            if done:
                ep_info = info
                break

        # Agent Update
        update_args = [states, actions, rewards, log_probs_old_list]  # MaxEnt ignores log_probs_old_list
        if is_recurrent:
            update_args.append(hidden_states_list)

        update_info = agent.update(*update_args)

        # Record History
        rewards_history.append(episode_reward)
        losses_history.append(update_info.get('loss', update_info.get('policy_loss', 0)))
        # MaxEnt returns 'entropy', PPO/TRPO return 'approx_kl' or 'kl'
        metric_history.append(update_info.get('approx_kl', update_info.get('entropy', update_info.get('kl', 0))))
        avoidance_rate_history.append(1 if ep_info['avoided'] else 0)
        shock_rate_history.append(1 if ep_info['shocked'] else 0)

        # Print Progress: averages over the most recent 50 episodes
        if (episode + 1) % 50 == 0:
            avg_reward = np.mean(rewards_history[-50:])
            avg_loss = np.mean(losses_history[-50:])
            avg_metric = np.mean(metric_history[-50:])
            avg_avoid = np.mean(avoidance_rate_history[-50:]) * 100
            avg_shock = np.mean(shock_rate_history[-50:]) * 100

            print(f"Ep {episode+1}/{config['num_episodes']} | Task: {'Tone1' if env.current_task_id == 1 else 'Tone2'} | "
                  f"Rwd: {avg_reward:.2f} | Loss: {avg_loss:.4f} | {metric_name}: {avg_metric:.4f} | "
                  f"Avoid%: {avg_avoid:.1f} | Shock%: {avg_shock:.1f}")

    print("Training finished.")

    # Save the agent and training history
    print(f"Saving agent to {agent_save_path}...")
    with open(agent_save_path, 'wb') as f:
        pickle.dump(agent, f)

    print(f"Saving training history to {history_save_path}...")
    training_history = {
        'rewards': rewards_history,
        'losses': losses_history,
        'metrics': metric_history,
        'avoidance_rates': avoidance_rate_history,
        'shock_rates': shock_rate_history
    }
    with open(history_save_path, 'wb') as f:
        pickle.dump(training_history, f)

    print("Agent and training history saved successfully.")

else:
    print("Skipping training. Set TRAIN_AGENT=True to train the agent or LOAD_AGENT=True to load a previously trained agent.")

## 5. Plot Training Curves

Visualize the training performance of the agent.

In [None]:
# The curves need all five per-episode series produced by the train/load cell
required_histories = (
    'rewards_history',
    'losses_history',
    'metric_history',
    'avoidance_rate_history',
    'shock_rate_history',
)

if set(required_histories).issubset(locals()):
    # MaxEnt-style agents are labelled 'Entropy'; every other agent 'KL'
    if isinstance(agent, (MaxEntAgent, FisherMaxEntAgent)):
        metric_name = 'Entropy'
    else:
        metric_name = 'KL'

    curves_path = f"plots/{config['agent_name']}_training_curves.png"

    plot_avoidance_training_curves(
        rewards=rewards_history,
        losses=losses_history,
        metrics=metric_history,
        metric_name=metric_name,
        avoidance_rates=avoidance_rate_history,
        shock_rates=shock_rate_history,
        smooth_window=50,
        save_path=curves_path,
    )

    print(f"Training curves saved to {curves_path}")
else:
    print("No training history available. Please train the agent or load a previously trained agent.")

## 6. Visualization Options

Choose which visualization you want to generate. You can run these cells multiple times with different parameters without retraining the agent.

### 6.1 Plot 2D Trajectories

In [None]:
# Settings for the static 2D trajectory figure; the output file name is tagged
# with the task the environment is currently set to
num_trajectories = 8
max_steps = config['max_steps_per_episode']
save_path = f"plots/{config['agent_name']}_trajectory_task{env.current_task_id}.png"

# Plotting is best-effort: a failure is reported and the notebook keeps going
try:
    plot_avoidance_trajectories_2d(
        env,
        agent,
        num_trajectories=num_trajectories,
        max_steps=max_steps,
        save_path=save_path,
    )
    print(f"2D trajectories saved to {save_path}")
except Exception as plot_error:
    print(f"Could not generate trajectory plot: {plot_error}")

### 6.2 Plot Heatmap

In [None]:
# Settings for the heatmap; the output file name is tagged with the task the
# environment is currently set to
num_episodes = 100
max_steps = config['max_steps_per_episode']
save_path = f"plots/{config['agent_name']}_heatmap_task{env.current_task_id}.png"

# Plotting is best-effort: a failure is reported and the notebook keeps going
try:
    plot_avoidance_heatmap_2d(
        env,
        agent,
        num_episodes=num_episodes,
        max_steps=max_steps,
        save_path=save_path,
    )
    print(f"Heatmap saved to {save_path}")
except Exception as plot_error:
    print(f"Could not generate heatmap plot: {plot_error}")

### 6.3 Plot Step-by-Step Trajectory Animation

In [None]:
# Settings for the single-episode, step-by-step rendering
max_steps = config['max_steps_per_episode']
save_path = f"plots/{config['agent_name']}_step_by_step_task{env.current_task_id}.gif"
visualization_type = 'animation'  # either 'animation' or 'gallery'
fps = 5
gallery_save_dir = None  # give a directory path here to also keep the individual frames

# Plotting is best-effort: a failure is reported and the notebook keeps going
try:
    plot_avoidance_trajectory_step_by_step(
        env,
        agent,
        max_steps=max_steps,
        save_path=save_path,
        visualization_type=visualization_type,
        fps=fps,
        gallery_save_dir=gallery_save_dir,
    )
    print(f"Step-by-step trajectory animation saved to {save_path}")
except Exception as plot_error:
    print(f"Could not generate step-by-step trajectory animation: {plot_error}")

### 6.4 Plot Multiple Trajectories Side by Side

In [None]:
# Settings for rendering several episodes next to each other
num_runs = 4
max_steps = config['max_steps_per_episode']
save_path = f"plots/{config['agent_name']}_multiple_trajectories_task{env.current_task_id}.gif"
visualization_type = 'animation'  # either 'animation' or 'gallery'
fps = 5
gallery_save_dir = None  # give a directory path here to also keep the individual frames
random_seed = 42  # None falls back to the environment's default randomness

# Plotting is best-effort: a failure is reported and the notebook keeps going
try:
    plot_multiple_avoidance_trajectories(
        env,
        agent,
        num_runs=num_runs,
        max_steps=max_steps,
        save_path=save_path,
        visualization_type=visualization_type,
        fps=fps,
        gallery_save_dir=gallery_save_dir,
        random_seed=random_seed,
    )
    print(f"Multiple trajectories animation saved to {save_path}")
except Exception as plot_error:
    print(f"Could not generate multiple trajectories animation: {plot_error}")

## 7. Experiment with Different Parameters

You can modify the parameters in the visualization cells above to experiment with different settings without retraining the agent.

In [None]:
# Example: Try different numbers of runs for multiple trajectories
num_runs_options = [2, 4, 6, 8]

# Define the step limit here so this cell does not depend on one of the
# optional visualization cells above having been executed first.
max_steps = config['max_steps_per_episode']

for num_runs in num_runs_options:
    # One output file per run count, tagged with the environment's current task
    save_path = f"plots/{config['agent_name']}_multiple_trajectories_{num_runs}_runs_task{env.current_task_id}.gif"

    # Best-effort: a failure for one run count is reported and the loop continues
    try:
        plot_multiple_avoidance_trajectories(
            env, agent, num_runs=num_runs, max_steps=max_steps, save_path=save_path,
            visualization_type='animation', fps=5, random_seed=42
        )
        print(f"Multiple trajectories animation with {num_runs} runs saved to {save_path}")
    except Exception as e:
        print(f"Could not generate multiple trajectories animation with {num_runs} runs: {e}")

## 8. Switch Task and Visualize

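You can switch the task and visualize the agent's performance on the new task. Note that every execution of the cell below calls `env.switch_task()` again, so re-running it changes the active task each time.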

In [None]:
# Switch task
# NOTE: this mutates `env`, and every re-run of the cell calls switch_task() again,
# so the active task depends on how many times the cell has been executed.
env.switch_task()
print(f"Switched to Task {'Tone1' if env.current_task_id == 1 else 'Tone2'}")

# Define the step limit here so this cell does not depend on one of the
# optional visualization cells above having been executed first.
max_steps = config['max_steps_per_episode']

# Visualize the agent's performance on the new task
save_path = f"plots/{config['agent_name']}_multiple_trajectories_task{env.current_task_id}.gif"

# Best-effort: a failure is reported and the notebook keeps going
try:
    plot_multiple_avoidance_trajectories(
        env, agent, num_runs=4, max_steps=max_steps, save_path=save_path,
        visualization_type='animation', fps=5, random_seed=42
    )
    print(f"Multiple trajectories animation for the new task saved to {save_path}")
except Exception as e:
    print(f"Could not generate multiple trajectories animation for the new task: {e}")