# BinAX Demo: JAX-Based Reinforcement Learning for Bin Packing

This notebook demonstrates the capabilities of BinAX, a high-performance, JAX-based reinforcement learning framework for solving bin packing problems.

## Overview

- **Environment**: 1D bin packing with configurable parameters
- **Algorithm**: Proximal Policy Optimization (PPO)
- **Networks**: Attention-based and simple architectures
- **Framework**: JAX with JIT compilation and vectorization

Let's start by importing the necessary libraries and setting up our environment.


In [None]:
# Install dependencies if needed
# !pip install jax jaxlib flax optax chex matplotlib seaborn tqdm

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from jax import random
from tqdm import tqdm
import time

# Set up plotting: apply the "seaborn-v0_8" Matplotlib style sheet and the
# "husl" seaborn colour palette for every figure drawn below.
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")
# IPython magic: only valid inside a Jupyter/IPython session, not plain Python.
%matplotlib inline

# Report the JAX version, the visible devices, and the default backend so the
# output of later cells can be read in context.
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

In [None]:
# Import BinAX components
import sys

sys.path.append("..")

from binax import BinPackingEnv, PPOAgent
from binax.algorithms import PPOConfig
from binax.networks import create_network
from binax.trainer import TrainingConfig, Trainer
from binax.types import BinPackingAction, BinPackingState

## 1. Environment Exploration

Let's start by exploring the bin packing environment and understanding its components.


In [None]:
# Create a simple bin packing environment: bins of capacity 1.0, at most
# 10 bins and 20 items, with item sizes configured by the (0.1, 0.5) range.
env = BinPackingEnv(
    bin_capacity=1.0, max_bins=10, max_items=20, item_size_range=(0.1, 0.5)
)

# Initialize environment with a fixed seed; `key` and `state` are reused by
# the cells that follow.
key = random.PRNGKey(42)
state = env.reset(key, num_items=8)

# Show the environment's own rendering plus the first 8 queue entries
# (matching the num_items=8 requested above).
print("Initial Environment State:")
print(env.render_state(state))
print(f"\nItem queue: {state.item_queue[:8]}")

In [None]:
# Side-by-side view of the freshly reset episode: bins on the left, items on the right
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: remaining capacity of every bin slot
ax1.bar(range(env.max_bins), state.bin_capacities, color="lightblue", alpha=0.7)
ax1.set(
    xlabel="Bin Index",
    ylabel="Remaining Capacity",
    title="Bin Capacities",
    ylim=(0, env.bin_capacity * 1.1),
)

# Right panel: sizes of the queued items (only strictly positive queue entries are kept)
items_to_pack = state.item_queue[state.item_queue > 0]
ax2.bar(range(len(items_to_pack)), items_to_pack, color="lightcoral", alpha=0.7)
ax2.set(
    xlabel="Item Index",
    ylabel="Item Size",
    title="Items to Pack",
    ylim=(0, max(items_to_pack) * 1.1),
)

plt.tight_layout()
plt.show()

## 2. Manual Environment Interaction

Let's manually interact with the environment to understand the dynamics.


In [None]:
# Hand-written First Fit baseline used to drive the environment manually
def first_fit_strategy(state: BinPackingState) -> int:
    """Choose a bin for the current item with the First Fit heuristic.

    Args:
        state: Environment state. Only ``item_queue``, ``current_item_idx``
            and ``bin_capacities`` are read.

    Returns:
        The index of the lowest-numbered bin whose remaining capacity is at
        least the current item's size, or ``0`` when no bin can hold it.
    """
    item_size = state.item_queue[state.current_item_idx]

    # Lazily scan bins in index order; the first match wins
    fitting_bins = (
        bin_idx
        for bin_idx, remaining in enumerate(state.bin_capacities)
        if remaining >= item_size
    )

    # Fallback when nothing fits is index 0 (whether or not that bin is empty)
    return next(fitting_bins, 0)


# Run First Fit on our environment, starting from the `state` produced by the
# reset above and recording every reward and intermediate state.
current_state = state
episode_rewards = []
episode_states = [current_state]

print("Running First Fit Strategy:")
print("=" * 50)

while not current_state.done:
    # Get valid actions
    # NOTE(review): the result is not consulted below; first_fit_strategy
    # chooses a bin from the capacities alone.
    valid_actions = env.get_valid_actions(current_state)

    # Use First Fit strategy
    action_idx = first_fit_strategy(current_state)
    action = BinPackingAction(bin_idx=action_idx)

    # Step environment with a fresh sub-key so `key` is never reused
    key, step_key = random.split(key)
    # NOTE(review): the returned `done` is unused; the loop reads
    # `current_state.done` instead.
    next_state, reward, done = env.step(current_state, action, step_key)

    # Log uses the pre-step state, i.e. the item that was just placed
    print(
        f"Step {current_state.step_count}: Item {current_state.item_queue[current_state.current_item_idx]:.3f} -> Bin {action_idx}, Reward: {reward:.2f}"
    )

    episode_rewards.append(reward)
    episode_states.append(next_state)
    current_state = next_state

    if current_state.step_count > 20:  # Safety break
        break

# Summarize the episode: final rendering, summed reward, and the number of
# bins whose utilization is strictly positive.
print("\nFinal State:")
print(env.render_state(current_state))
print(f"\nTotal Reward: {sum(episode_rewards):.2f}")
print(f"Bins Used: {jnp.sum(current_state.bin_utilization > 0)}")

In [None]:
# Two-panel summary of the First Fit episode: fill level per bin and reward trace
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: filled amount per bin (green = bin in use, gray = untouched)
used_bins = current_state.bin_utilization > 0
bin_indices = np.arange(len(current_state.bin_utilization))
used_capacity = env.bin_capacity - current_state.bin_capacities

colors = ["lightgreen" if is_used else "lightgray" for is_used in used_bins]
ax1.bar(bin_indices, used_capacity, color=colors, alpha=0.7)
ax1.axhline(
    y=env.bin_capacity, label="Capacity", color="red", linestyle="--", alpha=0.7
)
ax1.set(xlabel="Bin Index", ylabel="Used Capacity", title="Final Bin Utilization")
ax1.legend()

# Right panel: reward received at each step of the episode
ax2.plot(episode_rewards, marker="o", alpha=0.7)
ax2.set(xlabel="Step", ylabel="Reward", title="Rewards Over Time")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Efficiency = packed item volume / total volume of the bins that were opened
total_item_volume = jnp.sum(state.item_queue[state.item_queue > 0])
bins_used = jnp.sum(current_state.bin_utilization > 0)
total_bin_volume = bins_used * env.bin_capacity
if total_bin_volume > 0:
    efficiency = total_item_volume / total_bin_volume
else:
    # No bin was opened: report zero instead of dividing by zero
    efficiency = 0

print("\nEfficiency Metrics:")
print(f"Total item volume: {total_item_volume:.3f}")
print(f"Bins used: {bins_used}")
print(f"Total bin volume: {total_bin_volume:.3f}")
print(f"Packing efficiency: {efficiency:.1%}")

## 3. Neural Network Architecture

Let's explore the neural network architectures available in BinAX.


In [None]:
# Create different network architectures through the `create_network` factory.
# Both are built for max_bins=10, the same value the demo environment uses.
attention_network = create_network(
    network_type="attention",
    hidden_dim=128,
    num_layers=2,
    num_heads=4,
    max_bins=10,
    dropout_rate=0.1,
)

# The "simple" variant is created without num_layers / num_heads arguments.
simple_network = create_network(
    network_type="simple", hidden_dim=128, max_bins=10, dropout_rate=0.1
)

# Printing the objects shows whatever representation each network defines.
print("Network Architectures Created:")
print(f"- Attention Network: {attention_network}")
print(f"- Simple Network: {simple_network}")

In [None]:
# Initialize network parameters and test forward pass.
# A fresh 5-item state serves as the example input for init/apply.
key, init_key = random.split(key)
dummy_state = env.reset(init_key, num_items=5)

# Initialize both networks, each with its own PRNG sub-key
key, param_key1, param_key2 = random.split(key, 3)
attention_params = attention_network.init(param_key1, dummy_state, training=False)
simple_params = simple_network.init(param_key2, dummy_state, training=False)

# Test forward pass (training=False is passed to both networks)
attention_output = attention_network.apply(
    attention_params, dummy_state, training=False
)
simple_output = simple_network.apply(simple_params, dummy_state, training=False)

# Each output exposes `action_logits` and `value`; the value is formatted
# with :.4f below.
print("Network Outputs:")
print(
    f"Attention Network - Action Logits Shape: {attention_output.action_logits.shape}"
)
print(f"Attention Network - Value: {attention_output.value:.4f}")
print(f"Simple Network - Action Logits Shape: {simple_output.action_logits.shape}")
print(f"Simple Network - Value: {simple_output.value:.4f}")

In [None]:
# 2x2 grid: raw logits on the top row, softmax probabilities on the bottom row;
# attention network in the left column (blue), simple network on the right (red)
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Turn each network's logits into a probability distribution over actions
attention_probs = jax.nn.softmax(attention_output.action_logits)
simple_probs = jax.nn.softmax(simple_output.action_logits)

# One row per panel: (axis, values, bar colour, y-axis label, title)
panels = (
    (
        ax1,
        attention_output.action_logits,
        "blue",
        "Logit Value",
        "Attention Network - Action Logits",
    ),
    (
        ax2,
        simple_output.action_logits,
        "red",
        "Logit Value",
        "Simple Network - Action Logits",
    ),
    (
        ax3,
        attention_probs,
        "blue",
        "Probability",
        "Attention Network - Action Probabilities",
    ),
    (
        ax4,
        simple_probs,
        "red",
        "Probability",
        "Simple Network - Action Probabilities",
    ),
)
for axis, values, bar_color, y_label, panel_title in panels:
    axis.bar(range(len(values)), values, alpha=0.7, color=bar_color)
    axis.set_xlabel("Action (Bin Index)")
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)

plt.tight_layout()
plt.show()

# Context for reading the plots: the item being placed and the first five bins
print(
    f"\nCurrent item size: {dummy_state.item_queue[dummy_state.current_item_idx]:.3f}"
)
print(f"Bin capacities: {dummy_state.bin_capacities[:5]}")
print(f"Valid actions: {env.get_valid_actions(dummy_state)[:5]}")

## 4. PPO Agent Demo

Let's create a PPO agent and demonstrate its action selection process.


In [None]:
# Create PPO agent with the hyperparameters used throughout the demo
ppo_config = PPOConfig(
    learning_rate=3e-4,
    num_epochs=4,
    num_minibatches=4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.2,
    entropy_coeff=0.01,
    value_loss_coeff=0.5,
)

agent = PPOAgent(attention_network, ppo_config, action_dim=11)  # max_bins + 1

# Initialize agent parameters from a dedicated PRNG sub-key
key, agent_key = random.split(key)
agent_params = agent.init_params(agent_key, dummy_state)

# Count parameters across the whole pytree. `jax.tree_leaves` was a deprecated
# top-level alias that newer JAX releases no longer provide, so use the
# stable `jax.tree_util.tree_leaves` spelling.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(agent_params))

print("PPO Agent initialized successfully!")
print(f"Number of parameters: {num_params:,}")

In [None]:
# Demonstrate agent action selection on a fixed 6-item state.
# The state is never stepped below, so the valid actions are computed once.
test_state = env.reset(random.PRNGKey(123), num_items=6)
valid_actions = env.get_valid_actions(test_state)

print("Agent Action Selection Demo:")
print("=" * 40)

# Query the agent five times on the same state, each time with a fresh
# PRNG sub-key.
for i in range(5):
    key, action_key = random.split(key)
    action, log_prob, value = agent.select_action(
        agent_params, test_state, action_key, valid_actions
    )

    print(
        f"Sample {i + 1}: Action={action.bin_idx}, Log Prob={log_prob:.4f}, Value={value:.4f}"
    )

print(f"\nCurrent item size: {test_state.item_queue[test_state.current_item_idx]:.3f}")
print(f"Valid actions: {valid_actions}")

## 5. Mini Training Demo

Let's run a small-scale training demo to see the agent learn.


In [None]:
# Create a simple training configuration for demo.
# Each policy update below consumes rollout_length * num_envs = 512 timesteps.
demo_config = TrainingConfig(
    # Small scale for demo
    total_timesteps=10_000,
    num_envs=8,
    rollout_length=64,
    # Environment
    bin_capacity=1.0,
    max_bins=10,
    max_items=15,
    item_size_range=(0.1, 0.4),
    # Network
    network_type="simple",  # Faster for demo
    hidden_dim=64,
    num_layers=2,
    # Training
    learning_rate=1e-3,
    num_epochs=2,
    num_minibatches=2,
    # Logging
    log_interval=2,
    eval_interval=20,
    use_wandb=False,  # Disable for demo
)

# Echo the settings that determine how long the demo runs
print("Demo Training Configuration:")
print(f"- Total timesteps: {demo_config.total_timesteps:,}")
print(f"- Parallel environments: {demo_config.num_envs}")
print(f"- Rollout length: {demo_config.rollout_length}")
print(f"- Network type: {demo_config.network_type}")

In [None]:
# Run mini training
print("Starting mini training demo...")
print("This may take a few minutes depending on your hardware.")

# Create trainer
trainer = Trainer(demo_config, seed=42)

# Manual training loop for better control and visualization
key, reset_key = random.split(random.PRNGKey(42))
states = trainer.reset_fn(reset_key)

training_metrics = []
eval_metrics = []

timestep = 0
update_count = 0

# Timesteps consumed by one rollout/update cycle (constant for the whole run)
steps_per_update = demo_config.rollout_length * demo_config.num_envs

print("\nTraining Progress:")
# The context manager closes the bar even when a training step raises, so a
# failed run does not leave a dangling progress bar in the notebook output.
with tqdm(total=demo_config.total_timesteps, desc="Training") as progress_bar:
    while timestep < demo_config.total_timesteps:
        # Collect rollout
        rollout_batch, states = trainer.collect_rollout(states)

        # Update policy with a fresh PRNG sub-key
        key, update_key = random.split(key)
        trainer.params, trainer.opt_state, metrics = trainer.agent.update(
            trainer.params, trainer.opt_state, rollout_batch, update_key
        )

        # Update counters
        timestep += steps_per_update
        update_count += 1

        # Store metrics as plain Python floats so they load into a DataFrame
        training_metrics.append(
            {
                "timestep": timestep,
                "policy_loss": float(metrics.policy_loss),
                "value_loss": float(metrics.value_loss),
                "entropy_loss": float(metrics.entropy_loss),
                "total_loss": float(metrics.total_loss),
                "mean_reward": float(jnp.mean(rollout_batch.rewards)),
                "mean_value": float(jnp.mean(rollout_batch.values)),
            }
        )

        # Evaluation every fifth update; the result is tagged with the timestep
        if update_count % 5 == 0:
            eval_result = trainer.evaluate(num_episodes=3)
            eval_result["timestep"] = timestep
            eval_metrics.append(eval_result)

        # Update progress
        progress_bar.update(steps_per_update)
        progress_bar.set_postfix(
            {
                "Reward": f"{training_metrics[-1]['mean_reward']:.2f}",
                "Loss": f"{training_metrics[-1]['total_loss']:.4f}",
            }
        )

print("\nTraining completed!")

In [None]:
# Visualize training results
import pandas as pd

# Tabulate the collected metrics so columns can be plotted by name
training_df = pd.DataFrame(training_metrics)
eval_df = pd.DataFrame(eval_metrics)

# 2x3 dashboard: training curves on the top row, evaluation curves on the bottom row
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Top-left: policy and value loss share one panel
loss_axis = axes[0, 0]
for loss_column, loss_label in (
    ("policy_loss", "Policy Loss"),
    ("value_loss", "Value Loss"),
):
    loss_axis.plot(
        training_df["timestep"], training_df[loss_column], label=loss_label, alpha=0.7
    )
loss_axis.set_xlabel("Timestep")
loss_axis.set_ylabel("Loss")
loss_axis.set_title("Training Losses")
loss_axis.legend()
loss_axis.grid(True, alpha=0.3)

# Top-middle / top-right: mean rollout reward and mean value estimate
for axis, column, line_color, y_label, panel_title in (
    (axes[0, 1], "mean_reward", "green", "Mean Reward", "Training Rewards"),
    (axes[0, 2], "mean_value", "purple", "Mean Value", "Value Function"),
):
    axis.plot(
        training_df["timestep"], training_df[column], alpha=0.7, color=line_color
    )
    axis.set_xlabel("Timestep")
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    axis.grid(True, alpha=0.3)

# Bottom row: evaluation curves, drawn only when at least one evaluation ran
if len(eval_df) > 0:
    for axis, column, line_color, y_label, panel_title in (
        (
            axes[1, 0],
            "eval/episode_reward",
            "red",
            "Episode Reward",
            "Evaluation - Episode Reward",
        ),
        (
            axes[1, 1],
            "eval/episode_length",
            "orange",
            "Episode Length",
            "Evaluation - Episode Length",
        ),
        (
            axes[1, 2],
            "eval/bins_used",
            "brown",
            "Bins Used",
            "Evaluation - Bins Used",
        ),
    ):
        axis.plot(
            eval_df["timestep"], eval_df[column], "o-", alpha=0.7, color=line_color
        )
        axis.set_xlabel("Timestep")
        axis.set_ylabel(y_label)
        axis.set_title(panel_title)
        axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Textual summary of the last recorded values
print("\nFinal Training Metrics:")
print(f"- Final reward: {training_df['mean_reward'].iloc[-1]:.3f}")
print(f"- Final policy loss: {training_df['policy_loss'].iloc[-1]:.4f}")
print(f"- Final value loss: {training_df['value_loss'].iloc[-1]:.4f}")

if len(eval_df) > 0:
    print("\nFinal Evaluation Metrics:")
    print(f"- Episode reward: {eval_df['eval/episode_reward'].iloc[-1]:.3f}")
    print(f"- Episode length: {eval_df['eval/episode_length'].iloc[-1]:.1f}")
    print(f"- Bins used: {eval_df['eval/bins_used'].iloc[-1]:.1f}")

## 6. Trained Agent Demonstration

Let's test our trained agent and compare it with the First Fit baseline.


In [None]:
# Test trained agent vs First Fit on multiple episodes
num_test_episodes = 10
max_episode_steps = 20  # Safety cap so a policy that never finishes cannot loop forever

agent_results = []
first_fit_results = []

# The test environment is configured identically for every episode, so build it once
test_env = BinPackingEnv(
    bin_capacity=1.0, max_bins=10, max_items=15, item_size_range=(0.1, 0.4)
)


def _agent_policy(eval_env, eval_state, policy_key):
    """Sample an action for ``eval_state`` from the trained PPO agent."""
    valid = eval_env.get_valid_actions(eval_state)
    sampled_action, _, _ = trainer.agent.select_action(
        trainer.params, eval_state, policy_key, valid
    )
    return sampled_action


def _first_fit_policy(eval_env, eval_state, policy_key):
    """Return the First Fit baseline action; the env and PRNG key are unused."""
    return BinPackingAction(bin_idx=first_fit_strategy(eval_state))


def _run_episode(eval_env, start_state, rollout_key, policy):
    """Roll ``policy`` out from ``start_state`` until done or the step cap.

    Every step derives two independent sub-keys from ``rollout_key``: one for
    the policy and one for the environment, so no key is consumed twice.

    Returns:
        ``(final_state, total_reward, steps_taken)``.
    """
    episode_state = start_state
    total_reward = 0.0
    steps_taken = 0
    while not episode_state.done and steps_taken < max_episode_steps:
        rollout_key, policy_key, env_key = random.split(rollout_key, 3)
        chosen_action = policy(eval_env, episode_state, policy_key)
        episode_state, step_reward, _ = eval_env.step(
            episode_state, chosen_action, env_key
        )
        total_reward += step_reward
        steps_taken += 1
    return episode_state, total_reward, steps_taken


print("Comparing Trained Agent vs First Fit:")
print("=" * 50)

for episode in range(num_test_episodes):
    # Initialize same episode for both policies; the reset key and the rollout
    # key are split apart so the reset key is not reused while stepping
    reset_key, rollout_key = random.split(random.PRNGKey(episode + 100))
    initial_state = test_env.reset(reset_key, num_items=8)

    # Both rollouts start from the same state and the same rollout key, so the
    # environment receives the same key sequence for each policy
    agent_state, agent_reward, agent_steps = _run_episode(
        test_env, initial_state, rollout_key, _agent_policy
    )
    ff_state, ff_reward, ff_steps = _run_episode(
        test_env, initial_state, rollout_key, _first_fit_policy
    )

    agent_bins_used = jnp.sum(agent_state.bin_utilization > 0)
    ff_bins_used = jnp.sum(ff_state.bin_utilization > 0)

    # Store results. "success" means the episode reached a terminal state; an
    # episode that terminates exactly on the last allowed step still counts.
    agent_results.append(
        {
            "episode": episode,
            "reward": float(agent_reward),
            "bins_used": int(agent_bins_used),
            "steps": agent_steps,
            "success": bool(agent_state.done),
        }
    )

    first_fit_results.append(
        {
            "episode": episode,
            "reward": float(ff_reward),
            "bins_used": int(ff_bins_used),
            "steps": ff_steps,
            "success": bool(ff_state.done),
        }
    )

    print(
        f"Episode {episode + 1}: Agent={agent_bins_used} bins, FF={ff_bins_used} bins"
    )

# Convert to DataFrames
agent_df = pd.DataFrame(agent_results)
ff_df = pd.DataFrame(first_fit_results)

print("\nComparison Results:")
print(
    f"Trained Agent - Avg Bins: {agent_df['bins_used'].mean():.1f}, Avg Reward: {agent_df['reward'].mean():.2f}"
)
print(
    f"First Fit - Avg Bins: {ff_df['bins_used'].mean():.1f}, Avg Reward: {ff_df['reward'].mean():.2f}"
)
print(
    f"Success Rate - Agent: {agent_df['success'].mean():.1%}, First Fit: {ff_df['success'].mean():.1%}"
)

In [None]:
# Three panels: per-episode bins used, per-episode reward, and the bins-used distribution
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# Grouped bars: agent shifted left, First Fit shifted right of each episode tick
x = np.arange(len(agent_df))
width = 0.35

for axis, column, y_label, panel_title in (
    (ax1, "bins_used", "Bins Used", "Bins Used Comparison"),
    (ax2, "reward", "Total Reward", "Reward Comparison"),
):
    axis.bar(
        x - width / 2,
        agent_df[column],
        width,
        label="Trained Agent",
        alpha=0.7,
        color="blue",
    )
    axis.bar(
        x + width / 2,
        ff_df[column],
        width,
        label="First Fit",
        alpha=0.7,
        color="red",
    )
    axis.set_xlabel("Episode")
    axis.set_ylabel(y_label)
    axis.set_title(panel_title)
    axis.legend()
    axis.grid(True, alpha=0.3)

# Box plot: spread of bins used across all test episodes, one box per method
bins_data = [agent_df["bins_used"], ff_df["bins_used"]]
ax3.boxplot(bins_data, labels=["Trained Agent", "First Fit"])
ax3.set_ylabel("Bins Used")
ax3.set_title("Bins Used Distribution")
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Statistical comparison
from scipy import stats

# Row i of agent_df and ff_df holds the result for the same episode index, so
# the two samples are paired rather than independent: use a paired t-test.
# The statistic is NaN when every per-episode difference is identical
# (e.g. both methods always use the same number of bins).
t_stat, p_value = stats.ttest_rel(agent_df["bins_used"], ff_df["bins_used"])
print("\nStatistical Test (paired t-test):")
print(f"T-statistic: {t_stat:.4f}")
print(f"P-value: {p_value:.4f}")
print(f"Significant difference: {'Yes' if p_value < 0.05 else 'No'}")

# Improvement calculation: relative reduction in mean bins used versus First Fit
agent_mean_bins = agent_df["bins_used"].mean()
ff_mean_bins = ff_df["bins_used"].mean()
if ff_mean_bins > 0:
    improvement = (ff_mean_bins - agent_mean_bins) / ff_mean_bins * 100
else:
    # No baseline bins to compare against; avoid dividing by zero
    improvement = float("nan")
print(f"\nAgent improvement over First Fit: {improvement:.1f}%")

## 7. Performance Analysis

Let's analyze the computational performance of our JAX implementation.


In [None]:
# Performance benchmarking
print("Performance Benchmarking:")
print("=" * 30)


def _time_calls(fn, num_calls, warmup=10):
    """Return the wall-clock seconds taken by ``num_calls`` calls of ``fn(i)``.

    JAX dispatches device work asynchronously, so a call can return before the
    computation has finished. Each result is therefore passed through
    ``jax.block_until_ready`` so the measurement covers the actual compute and
    not just the dispatch. ``warmup`` untimed calls run first.
    """
    for _ in range(warmup):
        jax.block_until_ready(fn(0))
    start = time.perf_counter()  # monotonic, high-resolution timer
    for call_idx in range(num_calls):
        jax.block_until_ready(fn(call_idx))
    return time.perf_counter() - start


def _tree_nbytes(tree):
    """Sum the byte size of every array leaf in ``tree`` using its real dtype."""
    return sum(
        leaf.size * leaf.dtype.itemsize for leaf in jax.tree_util.tree_leaves(tree)
    )


# Environment step performance
test_env = BinPackingEnv()
test_state = test_env.reset(random.PRNGKey(0))
test_action = BinPackingAction(bin_idx=0)

# Benchmark environment steps; keys are created up front so key construction
# is not part of the timed region
num_steps = 1000
step_keys = [random.PRNGKey(i) for i in range(num_steps)]
env_time = _time_calls(
    lambda i: test_env.step(test_state, test_action, step_keys[i]), num_steps
)

print(f"Environment steps: {num_steps / env_time:.0f} steps/sec")

# Benchmark network forward pass
network = create_network("simple", hidden_dim=64, max_bins=10)
params = network.init(random.PRNGKey(0), test_state, training=False)

net_time = _time_calls(
    lambda _: network.apply(params, test_state, training=False), num_steps
)

print(f"Network forward pass: {num_steps / net_time:.0f} forward/sec")

# Benchmark vectorized operations
from binax.environment import make_vectorized_env

env_params = {
    "bin_capacity": 1.0,
    "max_bins": 10,
    "max_items": 20,
    "item_size_range": (0.1, 0.5),
}

num_envs = 32
reset_fn, step_fn, _ = make_vectorized_env(env_params, num_envs)

# Benchmark vectorized steps: every environment takes action 0 on each call
vec_states = reset_fn(random.PRNGKey(0))
vec_actions = BinPackingAction(bin_idx=jnp.zeros(num_envs, dtype=jnp.int32))

num_vec_steps = 100
vec_time = _time_calls(
    lambda i: step_fn(vec_states, vec_actions, step_keys[i]), num_vec_steps
)

# One vectorized call advances num_envs environments at once
effective_steps = num_vec_steps * num_envs
print(f"Vectorized steps: {effective_steps / vec_time:.0f} steps/sec ({num_envs} envs)")

# Memory usage estimation from actual leaf dtypes (not a flat 4 bytes/element)
param_size = _tree_nbytes(params)
state_size = _tree_nbytes(test_state)

print("\nMemory Usage:")
print(f"- Network parameters: {param_size / 1024:.1f} KB")
print(f"- Single state: {state_size} bytes")
print(f"- {num_envs} states: {state_size * num_envs / 1024:.1f} KB")

## 8. Conclusion

This demo showcased the key features of BinAX:

### Key Highlights:

1. **Environment**: Flexible bin packing environment with configurable parameters
2. **Networks**: Both attention-based and simple architectures
3. **Algorithm**: Complete PPO implementation with GAE
4. **Performance**: JAX-powered execution with JIT compilation and vectorized environments
5. **Evaluation**: Comprehensive metrics and comparisons

### Next Steps:

- Scale up training with more environments and longer training
- Experiment with different network architectures
- Test on more complex bin packing variants
- Compare with other RL algorithms
- Benchmark against more sophisticated heuristics

### Usage:

```python
# For full-scale training
config = TrainingConfig(
    total_timesteps=1_000_000,
    num_envs=64,
    network_type="attention",
    use_wandb=True
)

trainer = Trainer(config)
trainer.train()
```

BinAX provides a solid foundation for reinforcement learning research in combinatorial optimization!
