## 1. Using the getiaction Python API

The most direct way to use Groot is through the `getiaction` API.

### 1.1 Basic Instantiation

In [None]:
from getiaction.policies.groot import Groot

Groot??

In [None]:
policy = Groot()

print(f"Policy created: {type(policy).__name__}")
print(f"Trainable parameters: {sum(p.numel() for p in policy.parameters() if p.requires_grad):,}")

### 1.2 Training with getiaction Trainer

`getiaction.train.Trainer` subclasses Lightning's `Trainer` and adds conveniences for robotics training.

In [None]:
from getiaction.data.lerobot import LeRobotDataModule
from getiaction.policies.groot import Groot
from getiaction.train import Trainer

# Create policy (24GB GPU settings)
policy = Groot(
    base_model_path="nvidia/GR00T-N1.5-3B",
    embodiment_tag="new_embodiment",
    chunk_size=50,
    n_action_steps=50,
    tune_projector=True,
    tune_diffusion_model=False,  # Set True for 48GB+ GPUs
)

# Create data module
datamodule = LeRobotDataModule(
    repo_id="lerobot/aloha_sim_transfer_cube_human",
    data_format="lerobot",
    train_batch_size=1,  # Use 1 for 24GB GPUs
)

# Create trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="gpu",  # Use "xpu" for Intel GPUs
    devices=1,
    precision="bf16-mixed",
    accumulate_grad_batches=16,  # Effective batch size = 1 × 16 = 16
    gradient_clip_val=1.0,
    log_every_n_steps=10,
    fast_dev_run=True,  # Set to False to run full training
)

# Train!
# trainer.fit(policy, datamodule)

### 1.3 Inference with Groot

After training, switch the policy to evaluation mode and call `select_action` on a batch of observations.

In [None]:
import torch

# Prepare observation (example with dummy data)
# In practice, this comes from your robot's sensors
batch = {
    # Images: (B, C, H, W) - batch, channels, height, width
    "observation.images.cam_high": torch.randn(1, 3, 224, 224),
    "observation.images.cam_left_wrist": torch.randn(1, 3, 224, 224),
    "observation.images.cam_right_wrist": torch.randn(1, 3, 224, 224),
    # Robot state: (B, state_dim)
    "observation.state": torch.randn(1, 14),
    # Task instruction
    "task": "Pick up the red cube and place it on the blue plate",
}

# Run inference (policy handles device placement automatically)
# policy.eval()
# with torch.no_grad():
#     actions = policy.select_action(batch)

# print(f"Predicted actions shape: {actions.shape}")  # (B, action_dim)
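
The `chunk_size` and `n_action_steps` arguments from section 1.2, together with the per-step `(B, action_dim)` output above, suggest an action-chunking pattern: the model predicts a chunk of future actions and the policy hands them out one at a time. The sketch below illustrates that general pattern only. `ChunkedActionQueue` and `dummy_predict_chunk` are hypothetical names, and this is not taken from the `getiaction` source, so check the `Groot` implementation for the actual behavior.

```python
from collections import deque
from typing import Callable, Deque, List


class ChunkedActionQueue:
    """Hypothetical helper: serve one action per call from a predicted chunk."""

    def __init__(self, predict_chunk: Callable[[dict], List[List[float]]], n_action_steps: int) -> None:
        self.predict_chunk = predict_chunk
        self.n_action_steps = n_action_steps
        self.queue: Deque[List[float]] = deque()
        self.num_predictions = 0  # how many times the model was queried

    def select_action(self, observation: dict) -> List[float]:
        # Query the model only when the queue of pending actions is empty.
        if not self.queue:
            chunk = self.predict_chunk(observation)
            self.queue.extend(chunk[: self.n_action_steps])
            self.num_predictions += 1
        return self.queue.popleft()


def dummy_predict_chunk(observation: dict) -> List[List[float]]:
    # Stand-in for the model: a chunk of 50 actions, each with 14 dimensions.
    return [[float(step)] * 14 for step in range(50)]


controller = ChunkedActionQueue(dummy_predict_chunk, n_action_steps=50)
actions = [controller.select_action({"task": "demo"}) for _ in range(60)]
print(f"Control steps: {len(actions)}, model queries: {controller.num_predictions}")
```

Under this pattern, a larger `n_action_steps` means fewer model queries per episode, at the cost of reacting more slowly to new observations.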

## 2. Using HuggingFace Configuration

You can also load Groot using HuggingFace's configuration system. This is useful when you want to use the same configuration format as the original NVIDIA release.

### 2.2 Creating GrootModel from HuggingFace Pretrained

The underlying `GrootModel` uses the same architecture as NVIDIA's implementation but relies on PyTorch's native scaled dot-product attention (SDPA) for broader hardware support.

In [None]:
from getiaction.policies.groot.model import GrootModel

# Load GrootModel from HuggingFace pretrained weights
# This automatically downloads and loads the NVIDIA weights
model = GrootModel.from_pretrained(
    "nvidia/GR00T-N1.5-3B",
    attn_implementation="sdpa",  # Use PyTorch native attention
)

print(f"Model loaded: {type(model).__name__}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

### 2.3 Wrapping GrootModel in Lightning Module

For training, use the `Groot` Lightning module, which loads and wraps `GrootModel` internally.

In [None]:
from getiaction.policies.groot import Groot

# Create Groot policy - it handles model loading internally
policy = Groot(
    base_model_path="nvidia/GR00T-N1.5-3B",
    embodiment_tag="new_embodiment",
    attn_implementation="sdpa",  # Use PyTorch native attention
)

print("Policy ready for training")

### 2.4 Using HuggingFace Datasets

The `LeRobotDataModule` can load datasets directly from the HuggingFace Hub.

In [None]:
from getiaction.data.lerobot import LeRobotDataModule

# Available datasets on HuggingFace Hub (examples)
datasets = [
    "lerobot/aloha_sim_transfer_cube_human",
    "lerobot/aloha_sim_insertion_human",
    "lerobot/pusht",
    "lerobot/xarm_lift_medium",
]

# Load a dataset
datamodule = LeRobotDataModule(
    repo_id="lerobot/aloha_sim_transfer_cube_human",
    data_format="lerobot",
    train_batch_size=1,
)

# Setup and inspect
datamodule.setup("fit")
print(f"Number of training samples: {len(datamodule.train_dataset)}")

## 3. Using the CLI

The `getiaction` CLI provides a convenient way to train models using YAML configuration files.

### Basic Usage (24GB GPUs)

```bash
getiaction fit --config configs/getiaction/groot.yaml
```

### Intel XPU Support

```bash
getiaction fit --config configs/getiaction/groot.yaml --trainer.accelerator xpu
```

### GPU-Specific Overrides

**48GB GPU** (A6000, L40):
```bash
getiaction fit --config configs/getiaction/groot.yaml \
    --model.tune_diffusion_model true \
    --data.train_batch_size 2 \
    --trainer.accumulate_grad_batches 8
```

**80GB GPU** (A100, H100):
```bash
getiaction fit --config configs/getiaction/groot.yaml \
    --model.tune_diffusion_model true \
    --data.train_batch_size 4 \
    --trainer.accumulate_grad_batches 4
```
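
When launching several such runs from a script, the dotted overrides can be assembled programmatically. The following is a minimal sketch: `build_fit_command` is a hypothetical helper, not part of `getiaction`, and it only reproduces the flag layout shown in the commands above.

```python
import shlex
from typing import Dict, List, Union


def build_fit_command(config_path: str, overrides: Dict[str, Union[bool, int, str]]) -> List[str]:
    """Hypothetical helper: turn dotted overrides into `getiaction fit` arguments."""
    args = ["getiaction", "fit", "--config", config_path]
    for key, value in overrides.items():
        # Booleans are written in lowercase, matching the commands above.
        text = str(value).lower() if isinstance(value, bool) else str(value)
        args += [f"--{key}", text]
    return args


cmd_48gb = build_fit_command(
    "configs/getiaction/groot.yaml",
    {
        "model.tune_diffusion_model": True,
        "data.train_batch_size": 2,
        "trainer.accumulate_grad_batches": 8,
    },
)
print(shlex.join(cmd_48gb))
```

The resulting list can be passed to `subprocess.run(cmd_48gb, check=True)` to start the run.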

### Using a Different Dataset

```bash
getiaction fit --config configs/getiaction/groot.yaml --data.repo_id lerobot/pusht
```

## 4. Memory Requirements Summary

| Configuration | Trainable Params | VRAM Required |
|--------------|------------------|---------------|
| `tune_projector` only (default) | ~518M | 24GB |
| `tune_projector` + `tune_diffusion_model` | ~1.1B | 48GB |
| + `tune_llm` or `tune_visual` | ~2.7B+ | 80GB+ |
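
The table can also be expressed as a small lookup that returns the largest configuration fitting a given amount of VRAM. This is an illustrative sketch: `Tier` and `largest_tier_that_fits` are hypothetical names, the thresholds simply restate the table, and actual memory use will vary with batch size and image resolution.

```python
from typing import NamedTuple, Optional, Tuple


class Tier(NamedTuple):
    min_vram_gb: int
    tuned_components: Tuple[str, ...]
    trainable_params: str


# Rows of the table above, ordered from largest to smallest.
TIERS = (
    Tier(80, ("projector", "diffusion_model", "llm_or_visual"), "~2.7B+"),
    Tier(48, ("projector", "diffusion_model"), "~1.1B"),
    Tier(24, ("projector",), "~518M"),
)


def largest_tier_that_fits(vram_gb: float) -> Optional[Tier]:
    """Return the largest tier whose VRAM requirement is met, or None."""
    for tier in TIERS:
        if vram_gb >= tier.min_vram_gb:
            return tier
    return None


for vram in (16, 24, 48, 80):
    tier = largest_tier_that_fits(vram)
    summary = f"{tier.tuned_components} ({tier.trainable_params})" if tier else "below the 24GB minimum"
    print(f"{vram}GB: {summary}")
```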

## 5. Tips and Best Practices

### Hardware Selection
- **24GB GPUs** (3090, 4090, Intel B580): Use default settings (`tune_projector=True`, `tune_diffusion_model=False`)
- **48GB+ GPUs**: Enable `tune_diffusion_model=True` to also fine-tune the diffusion model (~1.1B trainable parameters instead of ~518M) for better results
- **Intel XPU**: Use `attn_implementation="sdpa"` (default) - works natively

### Training Tips
- Start with the default configuration and adjust based on loss curves
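- Use gradient accumulation to reach larger effective batch sizes when VRAM limits the per-step batch (for example, `train_batch_size=1` with `accumulate_grad_batches=16` gives an effective batch size of 16)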
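
The three GPU configurations in this guide trade per-step batch size against accumulation steps while keeping the effective batch size constant. The arithmetic can be checked with a short sketch. `effective_batch_size` is a hypothetical helper, and the `devices` factor assumes data-parallel training where each device processes its own batch.

```python
def effective_batch_size(train_batch_size: int, accumulate_grad_batches: int, devices: int = 1) -> int:
    """Number of samples contributing to each optimizer step."""
    return train_batch_size * accumulate_grad_batches * devices


# (train_batch_size, accumulate_grad_batches) pairs taken from this guide.
settings = {"24GB": (1, 16), "48GB": (2, 8), "80GB": (4, 4)}
effective = {gpu: effective_batch_size(bs, accum) for gpu, (bs, accum) in settings.items()}

for gpu, size in effective.items():
    print(f"{gpu}: effective batch size = {size}")
```

Keeping this product fixed means the three setups should behave similarly with the same learning rate, while larger GPUs need fewer forward and backward passes per optimizer step.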

### Attention Implementation
- `sdpa`: PyTorch native, works on CUDA and XPU (default)
- `flash_attention_2`: Requires CUDA and flash-attn package
- `eager`: Fallback, slower but most compatible
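
The choice between these options can be made at runtime. The sketch below is one possible selection rule, not `getiaction` behavior: `pick_attn_implementation` is a hypothetical helper, and it assumes the flash-attn package is importable as `flash_attn`.

```python
import importlib.util


def pick_attn_implementation(has_cuda: bool, prefer_flash: bool = True, sdpa_supported: bool = True) -> str:
    """Hypothetical helper: choose a value for `attn_implementation`."""
    # flash_attention_2 needs both a CUDA device and the flash-attn package.
    flash_installed = importlib.util.find_spec("flash_attn") is not None
    if prefer_flash and has_cuda and flash_installed:
        return "flash_attention_2"
    # sdpa is the default and works on CUDA and XPU.
    if sdpa_supported:
        return "sdpa"
    # eager is the slowest option but the most compatible.
    return "eager"


print(pick_attn_implementation(has_cuda=False))
```

In practice, `has_cuda` would come from `torch.cuda.is_available()`, and the returned string would be passed as `attn_implementation` when creating the policy.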