# Batch

The `Batch` class is a fundamental data structure in Tianshou, designed to store and manipulate hierarchical named tensors efficiently. This tutorial explains the concepts behind `Batch` and how it behaves, so that you can use it effectively throughout Tianshou.

The tutorial is organized into three sections: first, we establish the concept of hierarchical named tensors; second, we introduce basic `Batch` operations; and third, we explore advanced topics.

## Hierarchical Named Tensors

Hierarchical named tensors refer to a collection of tensors whose identifiers form a structured hierarchy. Consider a set of four tensors `[t1, t2, t3, t4]` with corresponding names `[name1, name2, name3, name4]`, where `name1` and `name2` reside within namespace `name0`. In this configuration, the fully qualified name of tensor `t1` becomes `name0.name1`, demonstrating how hierarchy manifests through tensor naming conventions.

The structure of hierarchical named tensors can be represented using a tree data structure. This representation includes a virtual root node representing the entire object, with internal nodes serving as keys (names) and leaf nodes containing values (scalars or tensors).

<div align=center>
<img src="../_static/images/batch_tree.png" style="width:50%" title="data flow">
</div>

Hierarchical named tensors are needed because reinforcement learning problems are inherently heterogeneous, even though the RL abstraction itself is simple:

```python
state, reward, done = env.step(action)
```

The `reward` and `done` components are typically scalars. However, `state` and `action` vary considerably across environments. A `state` may be a vector, a multi-dimensional tensor, or a combination of camera and sensory inputs; in the last case, hierarchical named tensors are a natural way to store it. The hierarchy does not stop at `state` and `action`: an entire transition (`state`, `action`, `reward`, `done`) can be stored as a single hierarchical structure.

While storing hierarchical named tensors is straightforward using nested dictionary structures:

```python
{
    'done': done,
    'reward': reward,
    'state': {
        'camera': camera,
        'sensory': sensory
    },
    'action': {
        'direct': direct,
        'point_3d': point_3d,
        'force': force,
    }
}
```

The challenge lies in **manipulating** these structures efficiently—for example, when adding new transition tuples to a replay buffer while handling their heterogeneity. The `Batch` class addresses this challenge by providing streamlined methods to create, store, and manipulate hierarchical named tensors.

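`Batch` can be thought of as a NumPy-enhanced Python dictionary. It resembles `TensorDict` from the `tensordict` library in the PyTorch ecosystem, although the two differ in the value types they support; `Batch`, for example, holds NumPy arrays, PyTorch tensors, and arbitrary Python objects side by side.

<div align=center>
<img src="../_static/images/concepts_arch.png" title="data flow">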
Data flow
</div>

In [None]:
import pickle

import numpy as np
import torch

from tianshou.data import Batch

## Basic Usage

This section covers fundamental `Batch` operations, including the contents of `Batch` objects, construction methods, and manipulation techniques.

### Content Specification

The content of `Batch` objects is defined by the following rules:

1. A `Batch` object may be empty (`Batch()`) or contain at least one key-value pair. Empty `Batch` objects can be utilized for key reservation (detailed in the Advanced Topics section).

2. Keys must be strings, serving as identifiers for their corresponding values.

3. Values may be scalars, tensors, or `Batch` objects. This recursive definition enables the construction of hierarchical batch structures.

4. Tensors constitute the primary value type. Tensors are n-dimensional arrays of uniform data type. Two tensor types are supported: [PyTorch](https://pytorch.org/) tensor type `torch.Tensor` and [NumPy](https://numpy.org/) tensor type `np.ndarray`.

5. Scalars are single boolean, numeric, or object values. They include Python scalars (`False`, `1`, `2.3`, `None`, `'hello'`) and NumPy scalars (`np.bool_(True)`, `np.int32(1)`, `np.float64(2.3)`). Scalars are distinct from `Batch`, dict, and tensor values.

**Note:** `Batch` objects cannot store `dict` values directly, because `Batch` itself uses dictionaries for its internal storage. Any `dict` value passed at construction (or assigned later) is automatically converted to a `Batch` object.

Supported tensor data types include boolean and numeric types (any integer or floating-point precision supported by NumPy or PyTorch). For data that are neither boolean nor numeric (e.g., strings, sets), `Batch` relies on NumPy object arrays: such values are stored in an `np.ndarray` with `dtype=object`, which allows `Batch` to hold arbitrary Python objects.

In [None]:
data = Batch(a=4, b=[5, 5], c="2312312", d=("a", -2, -3))
print(data)
print(data.b)

A `Batch` object stores all input data as key-value pairs and automatically converts values to NumPy arrays when applicable.

### Construction Methods

Two primary construction methods are available for `Batch` objects: construction from a dictionary, or using keyword arguments. The following examples demonstrate these approaches.

#### Dictionary-Based Construction

In [None]:
# Direct dictionary passing (potentially nested) is supported
data = Batch({"a": 4, "b": [5, 5], "c": "2312312"})
# Lists are automatically converted to NumPy arrays
print(data.b)
data.b = np.array([3, 4, 5])
print(data)

In [None]:
# Lists of dictionary objects (potentially nested) are automatically stacked
data = Batch([{"a": 0.0, "b": "hello"}, {"a": 1.0, "b": "world"}])
print(data)

#### Keyword Argument Construction

In [None]:
# Construction using keyword arguments
data = Batch(a=[4, 4], b=[5, 5], c=[None, None])
print(data)

In [None]:
# Combining dictionary and keyword arguments:
# the first argument is a dictionary; 'c' is a keyword argument
data = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
print(data)

In [None]:
arr = np.zeros((3, 4))
# By default, Batch maintains references to data; explicit copying is supported via the copy parameter
data = Batch(arr=arr, copy=True)  # data.arr is now a copy of 'arr'

#### Nested Batch Construction

In [None]:
# Nested dictionaries are converted to nested Batch objects
data = {
    "action": np.array([1.0, 2.0, 3.0]),
    "reward": 3.66,
    "obs": {
        "rgb_obs": np.zeros((3, 3)),
        "flatten_obs": np.ones(5),
    },
}

batch = Batch(data, extra="extra_string")
print(batch)
# batch.obs is also a Batch instance
print(type(batch.obs))
print(batch.obs.rgb_obs)

In [None]:
# Lists of dictionaries/Batches are automatically concatenated/stacked
# This feature facilitates data collection from parallelized environments
batch = Batch([data] * 3)
print(batch)
print(batch.obs.rgb_obs.shape)

### Data Manipulation

Internal data can be accessed using either `b.key` or `b[key]` notation, where `b.key` retrieves the subtree rooted at `key`. When the result is a non-empty subtree, key references can be chained (e.g., `b.key.key1.key2.key3`). Upon reaching a leaf node, the stored data (scalars or tensors) is returned.

In [None]:
data = Batch(a=4, b=[5, 5])
print(data.b)
# Attribute access (obj.key) is equivalent to dictionary access (obj["key"])
print(data["a"])

In [None]:
# Dictionary-style iteration over items is supported
for key, value in data.items():
    print(f"{key}: {value}")

In [None]:
# Methods keys() and values() behave analogously to their dict counterparts
for key in data.keys():
    print(f"{key}")

In [None]:
# The update() method operates analogously to dict.update()
# Equivalent to: data.c = 1; data.d = 2; data.e = 3;
data.update(c=1, d=2, e=3)
print(data)

In [None]:
# Adding and deleting key-value pairs
batch1 = Batch({"a": [4, 4], "b": (5, 5)})
print(batch1)

batch1.c = Batch(c1=np.arange(3), c2=False)
del batch1.a
print(batch1)

# Dictionary-style access returns the same object as attribute access; `in` tests for a key
assert batch1["c"] is batch1.c
print("c" in batch1)

**Important Note:** While `for x in data` iterates over keys when `data` is a `dict` object, for `Batch` objects this syntax iterates over `data[0], data[1], ..., data[-1]`.

### Length, Shape, Indexing, and Slicing

`Batch` implements a subset of the NumPy ndarray API. It supports advanced indexing and slicing (e.g., `batch[:, i]`) as long as the index is valid for every stored tensor. NumPy's broadcasting mechanism is also supported.

In [None]:
# Initializing Batch with tensors
data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])
# When all values share the same length/shape, the Batch adopts that length/shape
print(len(data))
print(data.shape)

In [None]:
# Accessing the first element of all stored tensors while preserving Batch structure
print(data[0])

In [None]:
# Iteration over data[0], data[1], ..., data[-1]
for sample in data:
    print(sample.a)

In [None]:
# Advanced slicing with arithmetic operations and broadcasting
data[:, 1] += 1
print(data)

In [None]:
# Direct application of NumPy functions to Batch objects
print(np.mean(data))

In [None]:
# Conversion to list is supported
list(data)

#### Environment Stepping Example

In [None]:
# Example: Data collected from four parallel environments
step_outputs = [
    {
        "act": np.random.randint(10),
        "rew": 0.0,
        "obs": np.ones((3, 3)),
        "info": {"done": np.random.choice(2), "failed": False},
        "terminated": False,
        "truncated": False,
    }
    for _ in range(4)
]
batch = Batch(step_outputs)
print(batch)
print(batch.shape)

In [None]:
# Advanced indexing for selecting data from specific environments
print(batch[0])
print(batch[[0, 3]])

In [None]:
# Slicing operations are supported
print(batch[-2:])

### Stack, Concatenate, and Split Operations

Tianshou provides methods for stacking and concatenating multiple `Batch` instances and for splitting one instance into several batches. This section covers only the aggregation (stack/concatenate) of homogeneous (structurally identical) batches; heterogeneous batches are discussed under Advanced Topics.

In [None]:
data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
data = Batch.stack((data_1, data_2))
print(data)

In [None]:
# Split operation with optional shuffling
data_split = list(data.split(1, shuffle=False))
print(data_split)

In [None]:
data_cat = Batch.cat(data_split)
print(data_cat)

#### Additional Concatenation and Stacking Examples

In [None]:
# Concatenating batches with compatible keys
b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
b12_cat_out = Batch.cat([b1, b2])
print(b1)
print(b2)
print(b12_cat_out)

In [None]:
# Stacking batches with compatible keys along specified axis
b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))
b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1)
print(b3)
print(b4)
print(b34_stack)

In [None]:
# Splitting batch into unit-sized batches with optional shuffling
print(type(b34_stack.split(1)))
print(list(b34_stack.split(1, shuffle=True)))

### Data Type Conversion

`Batch` supports NumPy arrays and PyTorch tensors with identical usage patterns, and different keys may hold different array types, as the following example shows.

In [None]:
batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))
batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))
batch_cat = Batch.cat([batch1, batch2, batch1])
print(batch_cat)

When all values should share one array type, `to_torch_()` and `to_numpy_()` convert the contents of a `Batch` in place.

In [None]:
data = Batch(a=np.zeros((3, 4)))
data.to_torch_(dtype=torch.float32, device="cpu")
print(data.a)
# Conversion to NumPy is also supported via to_numpy_()
data.to_numpy_()
print(data.a)

In [None]:
batch_cat.to_numpy_()
print(batch_cat)
batch_cat.to_torch_()
print(batch_cat)

### Serialization

`Batch` objects are serializable and compatible with Python's `pickle` module, enabling persistent storage and restoration. This capability is particularly important for distributed environment sampling.

In [None]:
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
print(batch_pk)

## Advanced Topics

This section addresses advanced `Batch` concepts, including key reservation mechanisms, detailed length and shape semantics, and aggregation of heterogeneous batches.

### Key Reservation

In many scenarios, the key structure is known in advance while value shapes remain undetermined until runtime (e.g., after environment execution). Tianshou supports key reservation through placeholder values.

<div style="text-align: center; padding: 1rem;">
<img src="../_static/images/batch_reserve.png" style="width: 50%; padding-bottom: 1rem;"><br>
Structure of a batch with reserved keys
</div>

Key reservation is implemented using empty `Batch()` objects as placeholder values.

In [None]:
a = Batch(b=Batch())  # 'b' is a reserved key
print(a)

# Hierarchical key reservation
a = Batch(b=Batch(c=Batch()), d=Batch())  # 'c' and 'd' are reserved keys
print(a)

a = Batch(key1=np.array([1, 2]), key2=np.array([3, 4]), key3=Batch(key4=Batch(), key5=Batch()))
print(a)

The structure of `Batch` objects with reserved keys can be visualized using tree notation, where reserved keys represent internal nodes lacking attached leaf nodes.

**Important:** Reserved keys indicate that values will eventually be assigned. These values may be scalars, tensors, or `Batch` objects. Understanding this concept is essential for working with heterogeneous batches.

Because a `Batch` can contain reserved keys without holding any data, there are two different notions of emptiness, and each is checked differently.

In [None]:
# Examples of checking whether a Batch is empty
print(len(Batch().get_keys()) == 0)
print(len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0)
print(len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0)
print(len(Batch(d=1).get_keys()) == 0)
print(len(Batch(a=np.float64(1.0)).get_keys()) == 0)

To check direct emptiness (a plain `Batch()`), use `len(b.get_keys()) == 0`; to check recursive emptiness (a `Batch` without any scalar or tensor leaf nodes), use `len(b) == 0`.

**Note:** Emptiness checks are unrelated to `Batch.empty`. The method `Batch.empty` and its in-place variant `Batch.empty_` reset values to zeros (or `None` for object arrays). Consult the API documentation for additional details.

### Length and Shape Semantics

The primary use case for `Batch` is storing batched data collections. The term "Batch" originates from deep learning terminology, denoting mini-batches sampled from datasets. Typically, a "Batch" represents a collection of tensors sharing a common first dimension, with batch size corresponding to the `Batch` object's length.

A `Batch` can also store tensors of different lengths, but then the meaning of `len(obj)` becomes ambiguous. Tianshou currently returns the minimum tensor length; we strongly recommend not calling `len(obj)` on `Batch` objects whose tensors differ in length.

In [None]:
# Length and shape examples for Batch objects
data = Batch(a=[5.0, 4.0], b=np.zeros((2, 3, 4)))
print(data.shape)
print(len(data))
print(data[0].shape)
try:
    len(data[0])
except TypeError as e:
    print(f"TypeError: {e}")

**Important:** Following scientific computing conventions, scalars possess no length. If any scalar leaf node exists in a `Batch` object, invoking `len(obj)` raises an exception.

Similarly, reserved keys have undetermined values and therefore no defined length (or, equivalently, an **arbitrary** length). When tensors and reserved keys coexist, the reserved keys are ignored and `len(obj)` returns the minimum tensor length. When a `Batch` contains only reserved keys and no tensors, `len(obj)` returns 0, which is what the recursive emptiness check above relies on.

The `obj.shape` attribute exhibits similar behavior to `len(obj)`:

1. When all leaf nodes are tensors with identical shapes, that shape is returned.

2. When all leaf nodes are tensors with differing shapes, the minimum length per dimension is returned.

3. When any scalar value exists, `obj.shape` returns `[]`.

4. Reserved keys have undetermined shape, treated as `[]`.

### Aggregation of Heterogeneous Batches

This section examines aggregation operations (stack/concatenate) on heterogeneous `Batch` objects, focusing on structural heterogeneity. Aggregation operations ultimately invoke NumPy/PyTorch operators (`np.stack`, `np.concatenate`, `torch.stack`, `torch.cat`). Value heterogeneity that violates these operators' requirements (e.g., stacking `np.ndarray` with `torch.Tensor`, or stacking tensors with incompatible shapes) results in exceptions.

<div style="text-align: center; padding: 1rem;">
<img src="../_static/images/aggregation.png" style="width: 100%; padding-bottom: 0rem;"><br>
</div>

For stacking, the behavior is intuitive: keys that are not shared by all batches are padded with zeros (or `None` for arrays with `dtype=object`) in the batches that lack them.

In [None]:
# Stack example: batch a lacks key 'b', batch b lacks key 'a'
a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
c = Batch.stack([a, b])
print(c.a.shape)
print(c.b.shape)
print(c.common.c.shape)

In [None]:
# Automatic padding with None or 0 using appropriate shapes
data_1 = Batch(a=np.array([0.0, 2.0]))
data_2 = Batch(a=np.array([1.0, 3.0]), b="done")
data = Batch.stack((data_1, data_2))
print(data)

In [None]:
# Concatenation example: batch a lacks key 'b', batch b lacks key 'a'
a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
# Concatenating batches whose keys differ is no longer supported in recent
# Tianshou versions, so the following operation is left commented out:
# c = Batch.cat([a, b])
# print(c.a.shape)
# print(c.b.shape)
# print(c.common.c.shape)

However, certain cases of extreme heterogeneity prevent aggregation:

In [None]:
# Example of incompatible batches that cannot be aggregated
try:
    a = Batch(a=np.zeros([4, 4]))
    b = Batch(a=Batch(b=Batch()))
    c = Batch.stack([a, b])
except Exception as e:
    print(f"Exception: {e}")

How can we tell whether batches can be aggregated? Recall the purpose of reserved keys. The difference between `a1=Batch(b=Batch())` and `a2=Batch()` is that `a1.b` returns `Batch()`, while `a2.b` raises an exception. **Reserved keys make an attribute reference valid now, so that a value can be assigned to it later.**

A key chain `k=[key1, key2, ..., keyn]` applies to `b` if the expression `b.key1.key2.{...}.keyn` is valid, with the result being `b[k]`.

For a set of `Batch` objects S, aggregation is possible if there exists a `Batch` object `b` satisfying:

1. **Key chain applicability:** For any object `bi` in S and any key chain `k`, if `bi[k]` is valid, then `b[k]` must be valid.

2. **Type consistency:** If `bi[k]` is not `Batch()` (i.e., the final key in the chain is not reserved), then `b[k]` must be of the same kind as `bi[k]`: both scalars, both tensors, or both non-empty `Batch` objects.

The `Batch` object `b` satisfying these rules with minimal keys determines the aggregation structure. Values are defined as follows: for any applicable key chain `k`, `b[k]` represents the stack/concatenation of `[bi[k] for bi in S]` (with appropriate zero or `None` padding when `k` does not apply to `bi`). When all `bi[k]` are `Batch()`, the aggregation result is also an empty `Batch()`.

### Additional Considerations

1. Environment observations are typically NumPy ndarrays, while policies require `torch.Tensor` for prediction and learning. `Batch` provides the in-place methods `to_torch_()` and `to_numpy_()` for converting between the two.

2. `obj.stack_([a, b])` is the in-place counterpart of `Batch.stack([obj, a, b])`: the result is stored in `obj`. Likewise, `obj.cat_([a, b])` is the in-place counterpart of `Batch.cat([obj, a, b])`. For the common case of appending a single batch, `obj.cat_(a)` is an alias for `obj.cat_([a])`.

3. `Batch.cat` and `Batch.cat_` currently do not support the `axis` argument available in `np.concatenate` and `torch.cat`.

4. `Batch.stack` and `Batch.stack_` support the `axis` argument, enabling stacking along dimensions beyond the first. However, when keys are not shared across all batches, `stack` with `axis != 0` is undefined and currently raises an exception.