# Interactive Episode Visualization Demo

This notebook demonstrates the enhanced visualization capabilities for BinAX episodes.

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from binax.environment import BinPackingEnv
from binax.algorithms import PPOAgent, PPOConfig
from binax.networks import SimplePolicyValueNetwork
from binax.types import BinPackingAction
from binax.visualizer import EpisodeVisualizer
from binax.interactive_viz import InteractiveEpisodeExplorer, EpisodeComparator

# Reached only if every import above resolved without raising.
print("Imports successful!")

## Setup Environment and Agent

In [None]:
# Initialize environment
env = BinPackingEnv(
    bin_capacity=1.0,
    max_bins=50,
    max_items=15,
    item_size_range=(0.1, 0.6)
)

# Initialize agent; max_bins is read from the environment so both agree
network = SimplePolicyValueNetwork(
    hidden_dims=[64, 64],
    max_bins=env.max_bins
)

config = PPOConfig(
    learning_rate=3e-4,
    clip_eps=0.2,
    value_loss_coeff=0.5,
    entropy_coeff=0.01
)

agent = PPOAgent(network=network, config=config)

# Initialize network parameters.
# A JAX PRNG key should be consumed only once: reusing the same key for the
# environment reset and the parameter init would correlate the two random
# draws, so each consumer gets its own subkey.
key = jax.random.PRNGKey(42)
key, reset_key, init_key = jax.random.split(key, 3)
dummy_state = env.reset(reset_key)
network_params = agent.init_params(init_key, dummy_state)

print("Environment and agent initialized!")

## Run Episode with Visualization

In [None]:
def run_visualized_episode(env, agent, network_params, visualizer, seed=0):
    """Roll out one episode with sampled actions, recording every step.

    The visualizer's history is cleared first. For each step, the pre-action
    state, the sampled action, the reward, the masked action probabilities
    and the value estimate are passed to ``visualizer.record_step``. One
    summary line is printed per step and one more when the episode ends.

    Args:
        env: Environment exposing ``reset(key)``, ``get_valid_actions(state)``
            and ``step(state, action, key)``.
        agent: Agent whose ``network.apply(params, state, training=False)``
            returns an object with ``action_logits`` and ``value``.
        network_params: Parameters forwarded to ``agent.network.apply``.
        visualizer: Recorder providing ``clear_history()`` and
            ``record_step(...)``.
        seed: Integer seed for this episode's PRNG key.

    Returns:
        The summed episode reward as a Python ``float`` (``0.0`` when the
        initial state is already done).
    """
    key = jax.random.PRNGKey(seed)
    key, reset_key = jax.random.split(key)

    state = env.reset(reset_key)
    visualizer.clear_history()

    # Accumulate on the host as a plain float so the return type is the same
    # whether the loop runs zero times or many.
    total_reward = 0.0
    step_count = 0

    while not state.done:
        # Fresh subkeys per step: one for action sampling, one for the env.
        key, action_key, step_key = jax.random.split(key, 3)

        # Get valid actions and network output
        valid_actions = env.get_valid_actions(state)
        network_output = agent.network.apply(network_params, state, training=False)

        # Push invalid actions to a very negative logit so softmax/sampling
        # give them (numerically) zero probability.
        masked_logits = jnp.where(valid_actions, network_output.action_logits, -1e9)
        action_probs = jax.nn.softmax(masked_logits)

        # Sample action
        action_idx = jax.random.categorical(action_key, masked_logits)
        action = BinPackingAction(bin_idx=action_idx)

        # Step environment
        next_state, reward, _ = env.step(state, action, step_key)
        # Convert once; the same host float feeds the recorder, the running
        # total and the log line.
        reward = float(reward)

        # Record the state the decision was made in, not the resulting one
        visualizer.record_step(
            state=state,
            action=action,
            reward=reward,
            action_probs=np.array(action_probs),
            value_estimate=float(network_output.value)
        )

        total_reward += reward
        step_count += 1
        state = next_state

        # NOTE(review): reads the just-placed item from the post-step state,
        # presuming current_item_idx advanced by one in env.step - confirm.
        current_item = float(state.item_queue[state.current_item_idx-1])
        prob = float(action_probs[action_idx])
        print(f"Step {step_count}: Item {current_item:.3f} → Bin {action.bin_idx} (prob: {prob:.2f}), Reward: {reward:.3f}")

    print(f"\nEpisode completed in {step_count} steps with total reward: {total_reward:.3f}")
    return total_reward

# Build the step recorder from the environment's own limits, then roll out
# a single seeded episode through it.
visualizer = EpisodeVisualizer(
    bin_capacity=env.bin_capacity,
    max_bins=env.max_bins,
)
total_reward = run_visualized_episode(
    env,
    agent,
    network_params,
    visualizer,
    seed=42,
)

## Episode Summary Visualization

In [None]:
# Draw the summary figure from the steps the visualizer recorded, then display it.
fig = visualizer.plot_episode_summary(figsize=(16, 10))
plt.show()

## Interactive Explorer (if ipywidgets is available)

In [None]:
# Prefer the widget-based explorer; if it cannot be imported, fall back to a
# static grid of action-probability bar charts at a few key steps.
try:
    explorer = InteractiveEpisodeExplorer(visualizer)
    explorer.display()
except ImportError:
    print("ipywidgets not available. Install with: uv add ipywidgets")
    print("Showing static key steps instead.")

    num_steps = len(visualizer.episode_history)
    if num_steps == 0:
        # Indexing an empty history would raise IndexError.
        print("No steps recorded; nothing to plot.")
    else:
        # First, one-third, two-thirds and last step. Episodes shorter than
        # four steps would repeat indices, so collapse duplicates.
        key_steps = sorted({0, num_steps // 3, 2 * num_steps // 3, num_steps - 1})

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        axes = axes.flatten()

        for ax, step_idx in zip(axes, key_steps):
            step = visualizer.episode_history[step_idx]

            if step.action_probs is not None:
                # Show at most the first 10 bins to keep the bars readable.
                x = np.arange(min(10, len(step.action_probs)))
                probs = step.action_probs[:len(x)]
                # Highlight the bin that was actually chosen at this step.
                colors = ['red' if j == step.bin_selected else 'skyblue' for j in range(len(x))]
                ax.bar(x, probs, color=colors)
                ax.set_title(f'Step {step_idx+1}: Item {step.item_size:.3f}')
                ax.set_xlabel('Bin Index')
                ax.set_ylabel('Probability')
                ax.set_ylim(0, 1)

        # Hide panels left over when fewer than four distinct steps exist.
        for ax in axes[len(key_steps):]:
            ax.set_visible(False)

        plt.tight_layout()
        plt.show()

## Animation

In [None]:
# Build the step-by-step animation and report whether an object came back.
anim = visualizer.create_episode_animation(interval=800, figsize=(12, 8))
# Note: Animation display may require additional setup in some environments
print("Animation created successfully!" if anim else "Animation creation failed")