# BinAX Quick Start Guide

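This notebook gives a quick introduction to BinAX for bin packing with reinforcement learning. It covers creating an environment, setting up a network and PPO agent, running a sample episode, and configuring training.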

## Installation

```bash
pip install -e .
```

In [None]:
import jax
import jax.numpy as jnp
from jax import random

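# Import BinAX. Appending ".." to sys.path lets the import resolve from the
# parent directory when the package has not been installed.
import sys

sys.path.append("..")
from binax import BinPackingEnv, PolicyValueNetwork, PPOAgent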

## 1. Create Environment

In [None]:
# Create bin packing environment
env = BinPackingEnv(
    bin_capacity=1.0, max_bins=10, max_items=20, item_size_range=(0.1, 0.5)
)

# Reset environment
key = random.PRNGKey(42)
state = env.reset(key, num_items=8)

print(env.render_state(state))

## 2. Create Network and Agent

In [None]:
# Create network
network = PolicyValueNetwork(hidden_dim=128, num_layers=2, max_bins=10)

# Create PPO agent
agent = PPOAgent(network, action_dim=11)

# Initialize parameters
key, param_key = random.split(key)
params = agent.init_params(param_key, state)

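# jax.tree_leaves was removed from the top-level namespace in recent JAX
# releases; jax.tree_util.tree_leaves is the stable equivalent.
num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"Network parameters: {num_params:,}")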

## 3. Test Agent

In [None]:
# Test agent action selection
valid_actions = env.get_valid_actions(state)
key, action_key = random.split(key)

action, log_prob, value = agent.select_action(params, state, action_key, valid_actions)

print(f"Selected action: {action.bin_idx}")
print(f"Log probability: {log_prob:.4f}")
print(f"Value estimate: {value:.4f}")
print(f"Valid actions: {valid_actions}")
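
The masking idea behind `valid_actions` can be illustrated without BinAX. The sketch below is plain Python and rests on two assumptions: that valid actions form a boolean mask with one entry per bin, and that a bin is valid when the current item still fits. The helper names `feasible_mask` and `masked_softmax` are hypothetical and are not part of the BinAX API.

```python
import math


def feasible_mask(bin_loads, item_size, capacity=1.0, eps=1e-9):
    """Return True for every bin that still has room for the item (assumed rule)."""
    return [load + item_size <= capacity + eps for load in bin_loads]


def masked_softmax(logits, mask):
    """Softmax over logits where masked-out entries get probability zero."""
    if not any(mask):
        raise ValueError("at least one action must be valid")
    # Invalid actions get a logit of -inf, so exp() maps them to exactly 0.
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(l - peak) for l in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up example: three bins, an item of size 0.3, arbitrary policy logits.
mask = feasible_mask([0.9, 0.4, 0.0], item_size=0.3)
probs = masked_softmax([2.0, 0.5, 0.1], mask)
print(mask)
print(probs)
```

Because the first bin cannot hold the item, its probability is exactly zero even though it has the largest raw logit, so sampling from `probs` can never select it.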

## 4. Run Episode

In [None]:
# Run a complete episode
current_state = env.reset(random.PRNGKey(123), num_items=6)
episode_reward = 0
steps = 0

print("Running episode...")
while not current_state.done and steps < 20:
    valid_actions = env.get_valid_actions(current_state)
    # Split separate subkeys for action sampling and the environment step;
    # reusing one key for both would correlate their randomness.
    key, action_key, step_key = random.split(key, 3)

    action, _, _ = agent.select_action(params, current_state, action_key, valid_actions)

    next_state, reward, done = env.step(current_state, action, step_key)

    item_size = current_state.item_queue[current_state.current_item_idx]
    print(
        f"Step {steps}: Item {item_size:.3f} -> Bin {action.bin_idx}, Reward: {reward:.2f}"
    )

    episode_reward += reward
    current_state = next_state
    steps += 1

print("\nEpisode completed!")
print(f"Total reward: {episode_reward:.2f}")
print(f"Bins used: {jnp.sum(current_state.bin_utilization > 0)}")
print("\nFinal state:")
print(env.render_state(current_state))
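
To put the "Bins used" figure in context, it helps to know the smallest number of bins any packing could need. The sketch below computes the standard volume lower bound in plain Python. The function name `min_bins_lower_bound` and the item sizes are made up for illustration and do not come from BinAX.

```python
import math


def min_bins_lower_bound(item_sizes, bin_capacity=1.0):
    """Volume bound: no packing can use fewer than ceil(total / capacity) bins."""
    total = sum(item_sizes)
    # Subtract a tiny tolerance so float rounding does not push an exact
    # multiple of the capacity up to the next integer.
    return max(0, math.ceil(total / bin_capacity - 1e-9))


# Hypothetical item sizes within the (0.1, 0.5) range used above.
items = [0.4, 0.3, 0.5, 0.2, 0.45, 0.35]
print(min_bins_lower_bound(items))  # total size is 2.2, so at least 3 bins
```

An agent that finishes an episode with a bin count equal to this bound has packed optimally; a larger count does not necessarily mean a mistake, since the bound ignores how individual items fit together.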

## 5. Training

For full training, use the trainer:

In [None]:
from binax.trainer import Trainer, TrainingConfig

# Create training configuration
config = TrainingConfig(
    total_timesteps=50_000,  # Small for demo
    num_envs=16,
    network_type="simple",
    use_wandb=False,
)

# Create trainer
trainer = Trainer(config, seed=42)

# Start training (uncomment the two lines below to run)
# print("Starting training...")
# trainer.train()
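
For readers new to PPO, the core of the algorithm is its clipped surrogate objective. The sketch below evaluates the textbook form of that objective for a single sample in plain Python. It is not taken from BinAX: the function name `ppo_clipped_objective` and the clip value of 0.2 are assumptions, and the trainer's actual loss may add value and entropy terms or differ in other ways.

```python
import math


def ppo_clipped_objective(new_log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Per-sample PPO clipped surrogate objective (to be maximized)."""
    # Probability ratio between the updated policy and the one that collected the data.
    ratio = math.exp(new_log_prob - old_log_prob)
    # Restrict the ratio to [1 - clip_eps, 1 + clip_eps].
    clipped_ratio = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    # Taking the minimum makes the objective a pessimistic estimate.
    return min(ratio * advantage, clipped_ratio * advantage)


# The ratio here is exp(0.5), roughly 1.65, which lies outside the clip range.
print(ppo_clipped_objective(0.5, 0.0, advantage=1.0))   # gain is capped by the clip
print(ppo_clipped_objective(0.5, 0.0, advantage=-1.0))  # penalty is not capped
```

The asymmetry is the point of the design: the policy gets no extra credit for moving far from the data-collecting policy, but is still fully penalized when such a move makes a bad action more likely.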

## Next Steps

1. **Full Training**: Run `python -m binax.trainer` for a complete training run.
2. **Experiments**: Try different network architectures and hyperparameters.
3. **Evaluation**: Compare the trained agent against classical bin packing heuristics.
4. **Scaling**: Increase the number of parallel environments (`num_envs`) and the training length (`total_timesteps`).
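
As a starting point for the evaluation step, one classical baseline is first-fit decreasing (FFD): sort items from largest to smallest and place each into the first bin with enough room. The sketch below is a standalone plain-Python version. The function name `first_fit_decreasing` and the item sizes are illustrative and are not part of BinAX.

```python
def first_fit_decreasing(item_sizes, bin_capacity=1.0, eps=1e-9):
    """Pack items with first-fit decreasing; returns a list of bins (lists of sizes)."""
    bins = []   # items placed in each bin
    loads = []  # running load of each bin
    for size in sorted(item_sizes, reverse=True):
        for i, load in enumerate(loads):
            if load + size <= bin_capacity + eps:
                bins[i].append(size)
                loads[i] += size
                break
        else:
            # No open bin has room, so open a new one.
            bins.append([size])
            loads.append(size)
    return bins


# Hypothetical item sizes within the (0.1, 0.5) range used in this guide.
items = [0.4, 0.3, 0.5, 0.2, 0.45, 0.35]
packing = first_fit_decreasing(items)
print(f"FFD uses {len(packing)} bins: {packing}")
```

Running the same item sets through both the heuristic and a trained agent gives a simple bins-used comparison.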

See the full demo notebook for more advanced examples!