# Interactive Episode Visualization Demo

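This notebook demonstrates the episode-visualization tools in BinAX. It records an agent's rollout step by step, then summarizes, explores, animates and compares episodes, so you can see how the agent makes each decision. Note that the PPO agent below is only initialized, never trained, so the decisions shown come from a randomly initialized policy.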

In [None]:
import jax
import numpy as np
import matplotlib.pyplot as plt

from binax.environment import BinPackingEnv
from binax.algorithms import PPO
from binax.networks import SimpleNetwork
from binax.visualizer import EpisodeVisualizer
from binax.interactive_viz import InteractiveEpisodeExplorer, EpisodeComparator

# Enable interactive plots
%matplotlib widget

## Setup Environment and Agent

In [None]:
# Initialize the environment; every entry of `env_params` is forwarded to BinPackingEnv.
env_params = {
    "num_items": 15,
    "item_dist": "uniform",
    "item_sizes": (0.1, 0.6),
    "bin_capacity": 1.0,
    "reward_type": "utilization",
}
env = BinPackingEnv(**env_params)

# Initialize the agent: a small network wrapped in PPO. `max_bins` is taken
# from the environment so the network and the env agree on the bin count.
network = SimpleNetwork(
    hidden_dim=64,
    num_layers=2,
    max_bins=env.max_bins,
)

agent = PPO(
    network=network,
    learning_rate=3e-4,
    clip_epsilon=0.2,
    value_coef=0.5,
    entropy_coef=0.01,
)

# Initialize network parameters. NOTE: the agent is never trained in this
# notebook, so every decision visualized below comes from this random init.
#
# Split the root key rather than reusing it: handing the same PRNG key to two
# consumers (env reset and parameter init) correlates their randomness.
key = jax.random.PRNGKey(42)
reset_key, init_key = jax.random.split(key)
dummy_state = env.reset(reset_key)
network_params = network.init(init_key, dummy_state)
agent.network_params = network_params

print("Environment and agent initialized!")

## Run Episode with Visualization Recording

In [None]:
def run_visualized_episode(env, agent, visualizer, seed=0, verbose=True):
    """Run a single episode with ``agent`` and record every step in ``visualizer``.

    At each step the policy logits are sampled to choose an action, the value
    head is queried for a state-value estimate, and the transition is handed to
    ``visualizer.record_step``. The visualizer's history is cleared first, so
    afterwards it holds exactly this episode.

    Args:
        env: Environment exposing ``reset(key)``, ``step(state, action)`` (returning
            ``(next_state, reward)``) and ``idx_to_action(idx, state)``.
        agent: Object with ``network`` (providing ``apply``, ``policy`` and
            ``value``) and ``network_params``.
        visualizer: Recorder exposing ``clear_history()`` and ``record_step(...)``.
        seed: Seed for the JAX PRNG that drives both the reset and action sampling.
        verbose: If True, print one line per step plus a final summary line.

    Returns:
        The undiscounted sum of per-step rewards, as a plain Python ``float``.
    """
    key = jax.random.PRNGKey(seed)
    key, reset_key = jax.random.split(key)

    # Reset environment and start a fresh recording.
    state = env.reset(reset_key)
    visualizer.clear_history()

    # Accumulate as a Python float: summing JAX scalars would make the return
    # value a device array, which then leaks into numpy stats and formatting.
    total_reward = 0.0
    step_count = 0

    if verbose:
        print("Starting episode...\n")

    while not state.done:
        key, action_key = jax.random.split(key)

        # Policy logits -> probabilities (kept for the visualizer) and a sampled action.
        action_logits = agent.network.apply(
            agent.network_params, state, method=agent.network.policy
        )
        action_probs = jax.nn.softmax(action_logits)
        action_idx = jax.random.categorical(action_key, action_logits)
        action = env.idx_to_action(action_idx, state)

        # Value estimate for the pre-step state.
        value = agent.network.apply(
            agent.network_params, state, method=agent.network.value
        )

        next_state, reward = env.step(state, action)
        reward_value = float(reward)  # convert once; reused below

        # Record the state the decision was made in, with the action taken from it.
        visualizer.record_step(
            state=state,
            action=action,
            reward=reward_value,
            action_probs=np.array(action_probs),
            value_estimate=float(value),
        )

        total_reward += reward_value
        step_count += 1

        if verbose:
            item_size = float(state.item_queue[state.current_item_idx])
            prob = float(action_probs[action_idx])
            print(f"Step {step_count}: Item {item_size:.3f} → Bin {action.bin_index} "
                  f"(prob: {prob:.2f}), Reward: {reward_value:.3f}")

        state = next_state

    if verbose:
        print(f"\nEpisode completed in {step_count} steps with total reward: {total_reward:.3f}")

    return total_reward

# Size the recorder to match the environment, then play one seeded episode
# (seed=42) through it; the returned total is kept for reference.
visualizer = EpisodeVisualizer(bin_capacity=env.bin_capacity, max_bins=env.max_bins)
total_reward = run_visualized_episode(
    env,
    agent,
    visualizer,
    seed=42,
)

## Episode Summary Visualization

In [None]:
# Draw one multi-panel overview of the episode recorded above.
fig = visualizer.plot_episode_summary(
    figsize=(16, 10),
)
plt.show()

## Interactive Episode Explorer

Use the controls below to step through the episode and see how the agent made each decision.

In [None]:
# Build the step-through explorer on top of the recorded history and render its controls.
explorer = InteractiveEpisodeExplorer(
    visualizer,
)
explorer.display()

## Animated Episode Playback

In [None]:
# Create the animation and embed it as a JavaScript player.
# `HTML(...)` inside an `if` is not the cell's final expression, so IPython
# would silently drop it; it must be passed to `display()` explicitly.
from IPython.display import HTML, display

anim = visualizer.create_episode_animation(interval=800, figsize=(12, 8))
if anim:
    display(HTML(anim.to_jshtml()))
else:
    print("No animation data available")

## Compare Multiple Episodes

Let's run three episodes with different seeds and compare their performance.

In [None]:
# Record three independent episodes (seeds 0, 123, 246), each into its own
# visualizer, so they can be compared side by side below.
visualizers = []
rewards = []

for i, seed in enumerate(range(0, 3 * 123, 123)):
    viz = EpisodeVisualizer(bin_capacity=env.bin_capacity, max_bins=env.max_bins)
    reward = run_visualized_episode(env, agent, viz, seed=seed, verbose=False)

    visualizers.append(viz)
    rewards.append(reward)

    episode_len = len(viz.episode_history)
    print(f"Episode {i+1}: {episode_len} steps, reward: {reward:.3f}")

mean_reward, std_reward = np.mean(rewards), np.std(rewards)
print(f"\nAverage reward: {mean_reward:.3f} ± {std_reward:.3f}")

In [None]:
# Compare episodes
comparator = EpisodeComparator(
    visualizers=visualizers,
    labels=[f"Episode {i+1}" for i in range(len(visualizers))]
)

# Show comparison plot
fig = comparator.compare_episodes(figsize=(18, 12))
plt.show()

# Show performance summary
print("\nPerformance Summary:")
summary = comparator.create_performance_summary()

## Analyze Agent Decision Patterns

In [None]:
# Analyze decision patterns from the first episode.
first_viz = visualizers[0]
history = first_viz.episode_history

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Probability the policy assigned to the action it actually took.
# NOTE(review): this indexes `action_probs` by `bin_selected`, which assumes the
# action index equals the bin index — confirm against `env.idx_to_action`.
ax = axes[0, 0]
all_probs = [
    step.action_probs[step.bin_selected]
    for step in history
    if step.action_probs is not None
]
ax.hist(all_probs, bins=20, alpha=0.7, color='skyblue')
ax.set_xlabel('Selected Action Probability')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Selected Action Probabilities')
if all_probs:  # np.mean([]) is nan (with a RuntimeWarning); skip the marker then
    ax.axvline(np.mean(all_probs), color='red', linestyle='--', label=f'Mean: {np.mean(all_probs):.2f}')
    ax.legend()

# 2. Value estimates over time.
ax = axes[0, 1]
values = [step.value_estimate for step in history if step.value_estimate is not None]
ax.plot(values, color='green', linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Value Estimate')
ax.set_title('Agent Value Estimates')
ax.grid(True, alpha=0.3)

# 3. Item size vs. the policy's peak probability ("confidence"); 0 when no probs recorded.
ax = axes[1, 0]
item_sizes = [step.item_size for step in history]
action_confidence = [
    float(np.max(step.action_probs)) if step.action_probs is not None else 0.0
    for step in history
]

ax.scatter(item_sizes, action_confidence, alpha=0.7, color='purple')
ax.set_xlabel('Item Size')
ax.set_ylabel('Max Action Probability')
ax.set_title('Item Size vs Action Confidence')
ax.grid(True, alpha=0.3)

# 4. Per-step rewards. Named `step_rewards` so this cell does not overwrite the
# per-episode `rewards` list built in the comparison section.
ax = axes[1, 1]
step_rewards = [step.reward for step in history]
ax.hist(step_rewards, bins=15, alpha=0.7, color='orange')
ax.set_xlabel('Reward')
ax.set_ylabel('Frequency')
ax.set_title('Reward Distribution')
if step_rewards:
    ax.axvline(np.mean(step_rewards), color='red', linestyle='--', label=f'Mean: {np.mean(step_rewards):.3f}')
    ax.legend()

plt.tight_layout()
plt.show()

print("\nDecision Analysis:")
if action_confidence:  # min()/max() raise on an empty list
    print(f"Average action confidence: {np.mean(action_confidence):.3f}")
    print(f"Most confident decision: {max(action_confidence):.3f}")
    print(f"Least confident decision: {min(action_confidence):.3f}")
else:
    print("No steps recorded in the first episode.")
if values:
    print(f"Average value estimate: {np.mean(values):.3f}")

## Save Visualizations

In [None]:
# Save static visualizations
fig_summary = visualizers[0].plot_episode_summary(figsize=(16, 10))
fig_summary.savefig('episode_summary.png', dpi=300, bbox_inches='tight')
print("Saved episode_summary.png")

# Save comparison
fig_comparison = comparator.compare_episodes(figsize=(18, 12))
fig_comparison.savefig('episode_comparison.png', dpi=300, bbox_inches='tight')
print("Saved episode_comparison.png")

# Try to save animation
anim = visualizers[0].create_episode_animation(interval=1000, figsize=(10, 8))
if anim:
    try:
        anim.save('episode_animation.gif', writer='pillow', fps=1)
        print("Saved episode_animation.gif")
    except Exception as e:
        print(f"Could not save animation: {e}")
        print("Install pillow for GIF support: pip install pillow")
        
plt.show()