## 1. Setup

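Clone the repository and install its dependencies. **Internet access must be enabled** in the notebook settings (right sidebar); otherwise `git clone` and the dependency installation will fail.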

In [None]:
# Clone the EMPO repository
!git clone --depth 1 https://github.com/mensch72/empo.git
%cd empo

In [None]:
# Run the automated setup script
# This installs dependencies and configures paths
# The script path is relative: it relies on the previous cell having changed
# the working directory into the repository root.
%run scripts/kaggle_setup.py

## 2. Verify GPU Setup

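When a GPU accelerator is enabled, Kaggle provides a T4 or P100 GPU. The next cell checks whether PyTorch can see it and selects the device used for the rest of the notebook.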

In [None]:
import torch
import numpy as np

# Report library versions and choose the compute device used by later cells.
cuda_ok = torch.cuda.is_available()

print("System Information:")
print(f"  PyTorch version: {torch.__version__}")
print(f"  NumPy version: {np.__version__}")
print(f"  CUDA available: {cuda_ok}")

if not cuda_ok:
    print("  ⚠️ No GPU available - training will be slower")
    print("  Enable GPU: Settings (right sidebar) → Accelerator → GPU")
    device = 'cpu'
else:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"  GPU: {gpu_name}")
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  GPU Memory: {gpu_mem_gb:.1f} GB")
    device = 'cuda'

print(f"\n✓ Using device: {device}")

## 3. Create and Explore an Environment

Let's create a simple multi-agent gridworld environment.

In [None]:
# Create a small environment for demonstration
from envs.one_or_three_chambers import SmallOneOrThreeChambersMapEnv

# Instantiate the demo map and put it into its initial configuration.
env = SmallOneOrThreeChambersMapEnv()
env.reset()

env_summary = [
    "Environment Information:",
    f"  Grid size: {env.width} x {env.height}",
    f"  Number of agents: {len(env.agents)}",
    f"  Max steps: {env.max_steps}",
]
print("\n".join(env_summary))

In [None]:
# Display agent information.
# In this demo the agent colour encodes its role: yellow = human, grey = robot.
ROLE_BY_COLOR = {'yellow': "Human", 'grey': "Robot"}

for idx, agent in enumerate(env.agents):
    role = ROLE_BY_COLOR.get(agent.color, "Other")
    pushes_rocks = getattr(agent, 'can_push_rocks', False)  # attribute may be absent
    print(f"Agent {idx} ({role}): pos={tuple(agent.pos)}, color={agent.color}, can_push_rocks={pushes_rocks}")

In [None]:
# Render the environment as an RGB image and show it without axes.
import matplotlib.pyplot as plt

frame = env.render(mode='rgb_array', highlight=False)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(frame)
ax.set_title('Initial Environment State')
ax.axis('off')
plt.show()

## 4. State Management

EMPO extends MultiGrid with explicit state management for model-based planning.

In [None]:
# Get the current state (hashable, can be stored and restored)
state = env.get_state()

# Show every component named in the format line, including mutable objects,
# which were previously announced but never printed.
print("State format: (step_count, agent_states, mobile_objects, mutable_objects)")
print("\nCurrent state:")
print(f"  Step count: {state[0]}")
print(f"  Agent states: {state[1]}")
print(f"  Mobile objects: {state[2]}")
print(f"  Mutable objects: {state[3]}")

In [None]:
# Advance one step with every agent moving forward, then roll back to the snapshot.
joint_forward = [env.actions.forward for _ in env.agents]
obs, rewards, done, info = env.step(joint_forward)

after_step = env.get_state()
print(f"After step: step_count={after_step[0]}")

# Restore the state captured in the previous cell; get_state() should now equal it.
env.set_state(state)
print(f"State restored: {state == env.get_state()}")

## 5. Compute the State DAG

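For finite environments, EMPO can enumerate the complete reachable state space as a directed acyclic graph (DAG). The step count is part of every state, so states cannot repeat within an episode and the graph has no cycles. A small `max_steps` keeps the graph small enough for a quick demo.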

In [None]:
import time

# Use a separate instance with a tiny horizon so the DAG stays small.
small_env = SmallOneOrThreeChambersMapEnv()
small_env.max_steps = 2  # Small for quick demo
small_env.reset()

print(f"Environment: {small_env.width}x{small_env.height} grid, {small_env.max_steps} max steps")

# Compute the DAG: all states, a state -> index map, and per-state successor lists.
print("Computing state DAG...")
start = time.time()
states, state_to_idx, successors = small_env.get_dag()
dag_seconds = time.time() - start

n_terminal = sum(len(succ) == 0 for succ in successors)  # states with no successors
n_transitions = sum(map(len, successors))

print(f"✓ DAG computed in {dag_seconds:.2f}s")
print(f"  Total states: {len(states)}")
print(f"  Terminal states: {n_terminal}")
print(f"  Total transitions: {n_transitions}")

## 6. Compute Human Policy Prior

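This is the core EMPO computation: a policy prior for each human agent and each of its possible goals, obtained by backward induction. We first define what a goal is and which goals a human may have.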

In [None]:
from empo import PossibleGoal, PossibleGoalGenerator, compute_human_policy_prior
from typing import Iterator, Tuple

# Define a simple goal: reaching a specific cell
class ReachCellGoal(PossibleGoal):
    """Goal satisfied when one particular human agent stands on a target cell.

    Args:
        world_model: environment, forwarded to ``PossibleGoal``.
        human_agent_index: position of the human in the state's agent-state sequence.
        target_pos: (x, y) coordinates of the target cell.
    """

    def __init__(self, world_model, human_agent_index: int, target_pos: tuple):
        super().__init__(world_model)
        self.human_agent_index = human_agent_index
        self.target_pos = np.array(target_pos)

    def is_achieved(self, state) -> int:
        """Return 1 if the agent's first two state entries equal the target (x, y), else 0.

        ``state`` must be a 4-tuple (step_count, agent_states, mobile_objects,
        mutable_objects). An index beyond the available agents counts as not achieved.
        """
        _, agent_states, _, _ = state
        if self.human_agent_index >= len(agent_states):
            return 0
        agent_state = agent_states[self.human_agent_index]
        target_x, target_y = self.target_pos[0], self.target_pos[1]
        if agent_state[0] == target_x and agent_state[1] == target_y:
            return 1
        return 0

    def __str__(self):
        target_x, target_y = self.target_pos[0], self.target_pos[1]
        return f"ReachCell({self.human_agent_index}->({target_x},{target_y}))"

    def __hash__(self):
        # Hash the same fields __eq__ compares, so equal goals hash equally.
        key = (self.human_agent_index, self.target_pos[0], self.target_pos[1])
        return hash(key)

    def __eq__(self, other):
        if isinstance(other, ReachCellGoal):
            return (self.human_agent_index == other.human_agent_index
                    and np.array_equal(self.target_pos, other.target_pos))
        return False

In [None]:
# Define a goal generator
class SimpleCellGoalGenerator(PossibleGoalGenerator):
    """Generates uniformly weighted ``ReachCellGoal``s over a few non-wall cells.

    Candidate cells are those whose grid entry is ``None`` or an object whose
    ``type`` is not ``'wall'``. This includes cells occupied by non-wall objects,
    not only truly empty floor. Cells are scanned column by column (x outer,
    y inner) and only the first ``max_cells`` candidates are kept.

    Args:
        world_model: environment exposing ``width``, ``height`` and ``grid.get(x, y)``.
        max_cells: maximum number of target cells to keep (default 5, to keep the
            demo fast). ``None`` keeps every candidate cell.

    Raises:
        ValueError: if ``max_cells`` is negative.
    """

    def __init__(self, world_model, max_cells=5):
        super().__init__(world_model)
        if max_cells is not None and max_cells < 0:
            raise ValueError(f"max_cells must be non-negative or None, got {max_cells}")

        candidates = []
        for x in range(world_model.width):
            for y in range(world_model.height):
                cell = world_model.grid.get(x, y)
                # Keep cells with no object, or with any non-wall object.
                # Objects lacking a `type` attribute are skipped.
                if cell is None or (hasattr(cell, 'type') and cell.type != 'wall'):
                    candidates.append((x, y))

        # The attribute name is kept for backward compatibility, even though the
        # list may contain occupied (non-wall) cells.
        self.empty_cells = candidates if max_cells is None else candidates[:max_cells]
        print(f"Goal generator: {len(self.empty_cells)} target cells")

    def generate(self, state, human_agent_index: int) -> Iterator[Tuple[PossibleGoal, float]]:
        """Yield ``(ReachCellGoal, weight)`` pairs with equal weights summing to 1.

        ``state`` is unused: the same goals are offered in every state.
        Yields nothing when no target cells were found.
        """
        if not self.empty_cells:
            return
        weight = 1.0 / len(self.empty_cells)
        for pos in self.empty_cells:
            yield ReachCellGoal(self.world_model, human_agent_index, pos), weight

In [None]:
# Humans are the yellow agents; collect their indices in the small environment.
human_agent_indices = [
    idx for idx, agent in enumerate(small_env.agents)
    if agent.color == 'yellow'
]
print(f"Human agents: {human_agent_indices}")

# Scan the grid for goal cells starting from the initial configuration.
small_env.reset()
goal_generator = SimpleCellGoalGenerator(small_env)


def _level_of(state):
    """Level passed to compute_human_policy_prior: the step counter, state[0]."""
    return state[0]


print("\nComputing human policy prior...")
t0 = time.time()

human_policy_prior = compute_human_policy_prior(
    world_model=small_env,
    human_agent_indices=human_agent_indices,
    possible_goal_generator=goal_generator,
    parallel=False,  # Single-threaded for Kaggle
    level_fct=_level_of,
)

print(f"✓ Human policy prior computed in {time.time() - t0:.2f}s")

In [None]:
# Test the policy prior: show the action distribution of the first human
# for the first goal offered in the initial state.
if human_agent_indices:
    small_env.reset()
    initial_state = small_env.get_state()
    first_human_idx = human_agent_indices[0]

    # Take the first (goal, weight) pair. Fall back to None instead of raising
    # StopIteration when the generator yields no goals (e.g. no target cells).
    first_pair = next(iter(goal_generator.generate(initial_state, first_human_idx)), None)
    if first_pair is None:
        print("No goals available for the first human agent; nothing to show.")
    else:
        first_goal = first_pair[0]
        action_probs = human_policy_prior(initial_state, first_human_idx, first_goal)
        action_names = small_env.actions.available
        print(f"Action probabilities for {first_goal}:")
        for action_idx, prob in enumerate(action_probs):
            # Use a placeholder name for indices without a named action.
            action_name = action_names[action_idx] if action_idx < len(action_names) else f"action_{action_idx}"
            print(f"  {action_name}: {prob:.4f}")
else:
    print("No human (yellow) agents found; skipping the policy-prior check.")

## 7. Neural Network Training with GPU

Train a neural policy prior using Kaggle's GPU.

In [None]:
import os
import random

import numpy as np
import torch

# Import neural network training components
from empo.multigrid import MultiGridGoalSampler
from empo.nn_based.multigrid import train_multigrid_neural_policy_prior

# Seed every global RNG (Python, NumPy, PyTorch) *before* building the
# environment and goal sampler. Any randomness they draw from these RNGs
# during construction/reset is then reproducible too.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Create training environment
train_env = SmallOneOrThreeChambersMapEnv()
train_env.max_steps = 10
train_env.reset()

train_human_indices = [i for i, agent in enumerate(train_env.agents) if agent.color == 'yellow']
goal_sampler = MultiGridGoalSampler(train_env)

# TensorBoard logging
# NOTE: this writer is not passed to train_multigrid_neural_policy_prior, so it
# only receives what the notebook itself logs to it.
from torch.utils.tensorboard import SummaryWriter
log_dir = '/kaggle/working/tensorboard_logs'
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

print(f"Training environment: {train_env.width}x{train_env.height}, {len(train_human_indices)} human(s)")
print(f"Device: {device}")
print(f"TensorBoard logs: {log_dir}")

In [None]:
# Train neural policy prior
NUM_EPISODES = 100  # Use 500+ for real training

print(f"\nTraining neural policy prior ({NUM_EPISODES} episodes)...")
t0 = time.time()

neural_prior = train_multigrid_neural_policy_prior(
    world_model=train_env,
    human_agent_indices=train_human_indices,
    goal_sampler=goal_sampler,
    num_episodes=NUM_EPISODES,
    steps_per_episode=10,
    batch_size=32,
    learning_rate=1e-3,
    gamma=0.99,
    beta=100.0,
    replay_buffer_size=5000,
    updates_per_episode=2,
    epsilon=0.3,
    reward_shaping=True,
    device=device,
    verbose=True
)

elapsed = time.time() - t0
num_params = sum(p.numel() for p in neural_prior.q_network.parameters())
print(f"\n✓ Training completed in {elapsed:.2f}s")
print(f"  Parameters: {num_params:,}")

# The trainer is not given the TensorBoard writer, so without these calls the
# log directory would never receive any events. Record run-level summaries,
# then close() to flush them to disk (SummaryWriter reopens itself if used again).
writer.add_scalar('train/wall_time_seconds', elapsed, NUM_EPISODES)
writer.add_scalar('model/num_parameters', num_params, 0)
writer.close()

In [None]:
# Save model checkpoint to Kaggle output (files in /kaggle/working/ become notebook outputs)
checkpoint_path = '/kaggle/working/neural_policy_prior.pt'
checkpoint = {
    'q_network_state_dict': neural_prior.q_network.state_dict(),
    'training_episodes': 100,  # keep in sync with num_episodes used for training
    'device': device,
}
torch.save(checkpoint, checkpoint_path)
print(f"✓ Model saved to: {checkpoint_path}")

## 8. Visualize Episode

Generate and visualize a sample episode.

In [None]:
# Generate episode frames: roll out uniformly random joint actions and render
# after every step, stopping early if the episode ends.
env.reset()
frames = [env.render(mode='rgb_array', highlight=False)]

max_demo_steps = 10
done = False
step = 0
for _ in range(max_demo_steps):
    if done:
        break
    joint_action = [env.action_space.sample() for _agent in env.agents]
    obs, rewards, done, info = env.step(joint_action)
    frames.append(env.render(mode='rgb_array', highlight=False))
    step += 1

print(f"Generated {len(frames)} frames over {step} steps")

In [None]:
# Display up to six frames, evenly spaced from the first to the last, in one row.
n_frames = min(6, len(frames))
# squeeze=False keeps `axes` a 2-D array even for a single panel; otherwise
# plt.subplots(1, 1) returns a bare Axes and iterating over it raises TypeError.
fig, axes = plt.subplots(1, n_frames, figsize=(3 * n_frames, 3), squeeze=False)

for i, ax in enumerate(axes[0]):
    # Map panel i onto frame indices 0 .. len(frames)-1 (frame k shows step k).
    idx = i * (len(frames) - 1) // (n_frames - 1) if n_frames > 1 else 0
    ax.imshow(frames[idx])
    ax.set_title(f'Step {idx}')
    ax.axis('off')

fig.tight_layout()
fig.suptitle('Sample Episode', y=1.02)
fig.savefig('/kaggle/working/episode_frames.png', dpi=150, bbox_inches='tight')
plt.show()
print("✓ Saved to /kaggle/working/episode_frames.png")

## 9. Kaggle-Specific Tips

### Resource Limits:
- **GPU quota**: 30 hours/week (T4 or P100)
- **Session length**: Max 12 hours
- **Disk space**: 20GB in `/kaggle/working/`

### Best Practices:
- Save checkpoints to `/kaggle/working/` (persisted as output)
- Use `torch.cuda.empty_cache()` to free GPU memory
- Commit with **Save Version → Save & Run All** so that files written to `/kaggle/working/` become downloadable outputs
- Use datasets for large input files

### Not Supported:
- MPI distributed training (use `parallel=False`)
- Docker containers
- Long background processes

In [None]:
# Return unused cached GPU memory and report current usage (GPU runtimes only).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU memory allocated: {allocated_gb:.2f} GB")
    print(f"GPU memory cached: {reserved_gb:.2f} GB")

In [None]:
# List output files (sorted). Directories are labelled rather than sized,
# because os.path.getsize() on a directory reports the directory entry,
# not the size of its contents.
import os

OUTPUT_DIR = '/kaggle/working/'
if not os.path.isdir(OUTPUT_DIR):
    print(f"{OUTPUT_DIR} not found (not running on Kaggle?)")
else:
    print("Output files in /kaggle/working/:")
    for name in sorted(os.listdir(OUTPUT_DIR)):
        path = os.path.join(OUTPUT_DIR, name)
        if os.path.isdir(path):
            print(f"  {name}/: directory")
        else:
            print(f"  {name}: {os.path.getsize(path) / 1024:.1f} KB")

## Summary

This notebook demonstrated:

1. **Setup**: Automated setup with `kaggle_setup.py`
2. **Environment**: Creating MultiGrid environments
3. **State Management**: `get_state()` / `set_state()` for planning
4. **DAG Computation**: State-space exploration
5. **Policy Priors**: Backward induction computation
6. **Neural Training**: GPU-accelerated training with checkpoints
7. **Visualization**: Episode rendering

For more information:
- [GitHub Repository](https://github.com/mensch72/empo)
- [API Documentation](https://github.com/mensch72/empo/blob/main/docs/API.md)

In [None]:
# Closing banner.
banner = "=" * 60
print("\n" + banner)
print("EMPO Kaggle Demo Complete!")
print(banner)