In [1]:
import numpy as np
import tinygrad
from tinygrad import Tensor, nn, dtypes
from tinygrad.dtype import ConstType
from tinygrad.nn.optim import Optimizer
import einops

In [2]:
import math
import warnings
import functools  # used by scatter_tensor (functools.reduce)
from collections import deque
from typing import Callable, List, Union, Literal

# Vector-Quantized Behavior Transformers (VQ-BeTs)

Vector-Quantized Behavior Transformers (VQ-BeTs) are a model for behavior modeling and generation that addresses key limitations of earlier approaches such as Behavior Transformers (BeT). Where BeT discretizes actions with k-means clustering, VQ-BeT tokenizes continuous actions with a hierarchical (residual) vector quantization module.

Why VQ-BeTs?

1. Handling Multimodal Action Distributions: Real-world behaviors often have multiple valid actions for a given state. VQ-BeTs can capture this multimodality more effectively than methods using simple discretization like k-means clustering.

2. Scalability: Unlike k-means, the hierarchical vector quantization in VQ-BeTs scales well to high-dimensional action spaces and long sequences, making it suitable for complex, long-horizon tasks.

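3. Gradient-Based Training: The quantizer is trained end-to-end inside an autoencoder, with gradients passed through the discrete bottleneck by a straight-through estimator. BeT's k-means binning, by contrast, is fit once before training and receives no gradient signal.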

4. Versatility: VQ-BeTs can handle both conditional and unconditional behavior generation, as well as partial observations, making them applicable to a wide range of tasks.

5. Improved Performance: The VQ-BeT paper reports improvements over state-of-the-art models such as BeT and Diffusion Policy across simulated manipulation, autonomous driving, and real-world robotics tasks.

6. Faster Inference: The paper also reports substantially faster inference than Diffusion Policy, which must run many iterative denoising steps for each action it generates.
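A hierarchical (residual) quantizer codes an action in stages: the first codebook captures a coarse approximation, and each later codebook quantizes the residual the previous layers left behind, so one action becomes a short tuple of code indices. The numpy sketch below illustrates only that idea under stated assumptions: the codebooks are fixed and hand-picked (VQ-BeT learns them), assignment is nearest-neighbour Euclidean, and `residual_quantize` is a hypothetical helper, not part of this notebook or the VQ-BeT code.

```python
import numpy as np

def residual_quantize(x, codebooks):
    """Quantize x (batch, dim) with a stack of codebooks, each coding the previous residual."""
    residual = x.copy()
    recon = np.zeros_like(x)
    codes = []
    for cb in codebooks:  # cb has shape (num_codes, dim)
        # squared Euclidean distance from each residual to each code
        dists = ((residual[:, None, :] - cb[None, :, :]) ** 2).sum(-1)
        idx = dists.argmin(axis=1)
        codes.append(idx)
        recon = recon + cb[idx]  # accumulate the coarse-to-fine reconstruction
        residual = x - recon     # what the next codebook still has to explain
    return np.stack(codes, axis=1), recon

coarse = np.array([[1.0, -1.0], [-1.0, 1.0]])
fine = np.array([[0.1, 0.1], [-0.1, -0.1], [0.0, 0.0]])
codes, recon = residual_quantize(np.array([[1.1, -0.9]]), [coarse, fine])
print(codes)  # [[0 0]]
```

Here the coarse code picks the right quadrant and the fine code corrects the remaining offset, so the two indices together reconstruct the action exactly.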

# Implementing a Multi-Layer Perceptron (MLP) in tinygrad

In this cell, we'll define a simple Multi-Layer Perceptron (MLP) class using tinygrad. An MLP is a type of feedforward neural network consisting of multiple layers of neurons. 

Our MLP class will:
1. Accept an input dimension and a list of hidden layer dimensions
2. Construct a sequence of linear layers with ReLU activations between them
3. Use the last dimension in the list as the output dimension

Key points:
- We use `nn.Linear` for the fully connected layers
- ReLU activation is applied after each hidden layer
- The final layer doesn't have an activation function, allowing for flexible use in various tasks

Let's implement the MLP class:

In [3]:
class MLP:
    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
    ):
        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(Tensor.relu)
            in_dim = hidden_dim
        self.fc = nn.Linear(in_dim, hidden_channels[-1])
        self.layers = layers
    def __call__(self, x:Tensor) -> Tensor:
        return self.fc(x.sequential(self.layers))
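To make the layer layout concrete, here is a numpy sketch of the same construction. It is only an illustration: the weights are random normal draws (an assumption; tinygrad's `nn.Linear` uses its own initialization), and `mlp_shapes` / `mlp_forward` are hypothetical helper names. Every entry of `hidden_channels` except the last becomes a hidden layer followed by ReLU, and the last entry sizes the un-activated output layer.

```python
import numpy as np

def mlp_shapes(in_channels, hidden_channels):
    # (in, out) sizes of each linear layer; the last entry is the output size
    dims = [in_channels] + list(hidden_channels)
    return list(zip(dims[:-1], dims[1:]))

def mlp_forward(x, weights):
    # every layer except the last is followed by ReLU
    for w, b in weights[:-1]:
        x = np.maximum(x @ w + b, 0.0)
    w, b = weights[-1]
    return x @ w + b  # output layer has no activation

rng = np.random.default_rng(0)
shapes = mlp_shapes(4, [32, 32, 2])
weights = [(rng.standard_normal((i, o)), np.zeros(o)) for i, o in shapes]
out = mlp_forward(rng.standard_normal((8, 4)), weights)
print(shapes)     # [(4, 32), (32, 32), (32, 2)]
print(out.shape)  # (8, 2)
```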

# Implementing Focal Loss in tinygrad

Next, we'll implement the Focal Loss function, which is particularly useful for dealing with class imbalance in classification tasks. This implementation is adapted from the miniBET project.

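Focal Loss modifies standard cross-entropy by scaling each example's loss by (1 − p_t)^γ, where p_t is the predicted probability of the true class: FL(p_t) = −(1 − p_t)^γ · log(p_t). Examples the model already classifies confidently (p_t close to 1) are down-weighted, which focuses training on hard, misclassified examples. The focusing parameter γ (gamma) controls how strongly easy examples are down-weighted; with γ = 0, Focal Loss reduces to ordinary cross-entropy.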

Key features of this implementation:
1. Supports 2D logits of shape (N, C) and 3D logits of shape (N, T, C)
2. Uses log_softmax for numerical stability
3. Allows for mean or sum reduction of the loss

Let's examine the FocalLoss class:

In [4]:
class FocalLoss:
    """
    From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py
    """

    def __init__(self, gamma: float = 0, size_average: bool = True):
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average

    def __call__(self, x: Tensor, target: Tensor) -> Tensor:
        if len(x.shape) == 3:
            N, T, _ = x.shape
            logpt = x.log_softmax(axis=-1)
            logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
        elif len(x.shape) == 2:
            logpt = x.log_softmax(axis=-1)
            logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
        else:
            raise ValueError(f"expected 2D (N, C) or 3D (N, T, C) logits, got shape {x.shape}")
        pt = logpt.exp()

        loss = -1 * (1 - pt).pow(self.gamma) * logpt
        return loss.mean() if self.size_average else loss.sum()
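As a sanity check of the formula, here is a numpy reference version (a sketch independent of tinygrad; `focal_loss` is a hypothetical name). With γ = 0 it matches plain cross-entropy, and with γ > 0 every term shrinks because (1 − p_t)^γ < 1.

```python
import numpy as np

def focal_loss(logits, target, gamma=0.0):
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # log-probability of the true class for each example
    logpt = np.take_along_axis(logp, target[..., None], axis=-1)[..., 0]
    pt = np.exp(logpt)
    return np.mean(-((1.0 - pt) ** gamma) * logpt)

logits = np.array([[2.0, 0.0, -1.0], [0.5, 0.5, 0.0]])
target = np.array([0, 2])
ce = focal_loss(logits, target, gamma=0.0)  # plain cross-entropy
fl = focal_loss(logits, target, gamma=2.0)  # easy examples down-weighted
```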


# Implementing the cdist Function for VQ-BeTs

In the context of Vector-Quantized Behavior Transformers (VQ-BeTs), calculating distances between vectors is a crucial operation, particularly in the vector quantization process. The `cdist` function computes the pairwise distance between two sets of vectors, which is essential for finding the nearest codebook vectors during quantization.

Why is this important for VQ-BeTs?

1. Vector Quantization: VQ-BeTs use vector quantization to discretize continuous action spaces. This process involves finding the nearest codebook vector for each input vector, which requires computing distances.

2. Codebook Updates: During training, the codebook vectors are updated based on their proximity to input vectors. The `cdist` function helps in determining these proximities efficiently.

3. Efficiency: By implementing `cdist` using matrix operations, we can compute all pairwise distances in parallel, which is much faster than iterating over vectors individually. This is crucial for the performance of VQ-BeTs, especially when dealing with large action spaces or long sequences.

4. Flexibility: The `p` argument mirrors the signature of PyTorch's `torch.cdist`, but only Euclidean distance (p=2) is implemented here; other values raise `NotImplementedError`.

Let's examine the implementation of the `cdist` function:

In [5]:
def cdist(x1, x2, p=2.0):
    # Ensure inputs are Tensors
    x1 = x1 if isinstance(x1, Tensor) else Tensor(x1)
    x2 = x2 if isinstance(x2, Tensor) else Tensor(x2)
    
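    # Squared Euclidean distance via ||a||^2 + ||b||^2 - 2 a.b
    x1_norm = (x1**2).sum(axis=-1, keepdim=True)
    x2_norm = (x2**2).sum(axis=-1, keepdim=True)

    # Transpose the last two axes (not the default 0 and 1) so batched (..., n, d) inputs also work
    cross_term = x1 @ x2.transpose(-1, -2)

    dist = x1_norm + x2_norm.transpose(-1, -2) - 2 * cross_term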
    
    # Handle floating point errors (negative distances)
    dist = dist.maximum(Tensor.zeros_like(dist))
    
    # For p=2 (Euclidean distance), take the square root
    if p == 2:
        return dist.sqrt()
    else:
        # For other p-norms, we'd need to implement a different calculation
        raise NotImplementedError("Only p=2 (Euclidean distance) is currently implemented")

Tensor.cdist = cdist


In [6]:
x1 = Tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
x2 = Tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])

result = cdist(x1, x2)
print(result.numpy())

[[3.1192703  2.0958931 ]
 [2.713841   3.8321726 ]
 [2.2830093  0.37910157]]
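The matrix form relies on the identity ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b, which replaces an explicit (n, m, d) tensor of pairwise differences with a single matrix multiply. The numpy sketch below (numpy used only as a reference; the helper names are hypothetical) checks the expanded form against the direct broadcasted difference on the same inputs as above.

```python
import numpy as np

def cdist_expanded(a, b):
    # ||a||^2 + ||b||^2 - 2 a.b, clamped at 0 before the square root
    sq = (a**2).sum(-1)[:, None] + (b**2).sum(-1)[None, :] - 2 * a @ b.T
    return np.sqrt(np.maximum(sq, 0.0))

def cdist_direct(a, b):
    # explicit pairwise differences via broadcasting: (n, 1, d) - (1, m, d)
    return np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))

x1 = np.array([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
x2 = np.array([[-2.1763, -0.4713], [-0.6986, 1.3702]])
print(np.allclose(cdist_expanded(x1, x2), cdist_direct(x1, x2)))  # True
```

The clamp matters: for nearly identical vectors, cancellation in the expanded form can produce a tiny negative number whose square root would be NaN.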


# Utility Functions for VQ-BeTs Implementation

In implementing Vector-Quantized Behavior Transformers (VQ-BeTs), we rely on a set of utility functions that handle various operations crucial to the model's functionality. These functions cover a range of tasks from basic tensor operations to more complex procedures like k-means clustering and Gumbel-Softmax sampling.

Why are these utility functions important for VQ-BeTs?

1. Tensor Operations: Functions like `unbind`, `cdist`, and `batched_embedding` provide efficient ways to manipulate tensors, which is essential for handling the high-dimensional data in VQ-BeTs.

2. Vector Quantization: The `kmeans` function clusters input vectors to initialize the codebooks when `kmeans_init=True`; after that, the codebooks are updated by EMA (or by gradients when the codebook is learnable).

3. Sampling and Noise: Functions like `gumbel_noise` and `gumbel_sample` are crucial for the stochastic aspects of VQ-BeTs, enabling the model to handle uncertainty and explore the action space effectively.

4. Initialization and Smoothing: `normal_` and `uniform_init` initialize parameters, while `laplace_smoothing` adds a small constant to the per-code usage counts so that codes that are rarely or never selected do not cause a division by zero during EMA codebook updates.

5. Loss Computation: The `orthogonal_loss_fn` is used to encourage diversity in the codebook, which is important for capturing a wide range of behaviors.

6. Efficient Computation: Many of these functions (e.g., `batched_bincount`, `batched_sample_vectors`) are designed to operate efficiently on batches of data, which is crucial for training VQ-BeTs on large datasets.

These utility functions form the backbone of our VQ-BeTs implementation, enabling efficient computation and providing the necessary tools for the model's core operations. Let's examine each of these functions:

In [7]:
# General utils

def noop(*args, **kwargs):
    pass


def identity(t):
    return t

def unbind(x: Tensor):
    return tuple(x[i] for i in range(x.shape[0]))

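def cdist(x, y):
    # Batched version for (b, n, d) inputs. This redefines the module-level name `cdist`;
    # `Tensor.cdist` still points to the earlier function.
    x2 = reduce(x.square(), "b n d -> b n", "sum")
    y2 = reduce(y.square(), "b n d -> b n", "sum")
    xy = Tensor.einsum("b i d, b j d -> b i j", x, y) * -2
    # clamp tiny negative values from floating-point cancellation before the sqrt
    return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).maximum(0).sqrt()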


def log(t, eps=1e-20):
    return t.clamp(min_=eps).log()


def ema_inplace(old, new, decay):
    is_mps = str(old.device).startswith('METAL')

    if not is_mps:
        old.assign(old.detach().lerp(new.detach(), 1 - decay))
    else:
        old.assign(old.detach().mul(decay).add(new.detach() * (1 - decay)))

from einops import pack, rearrange, reduce, repeat, unpack

def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
    """
    Fills the input Tensor with values drawn from the normal distribution N(mean, std^2).
    
    Args:
    tensor (Tensor): an n-dimensional `Tensor`
    mean (float): the mean of the normal distribution
    std (float): the standard deviation of the normal distribution
    
    Returns:
    Tensor: the modified input tensor
    """
    # Create a new tensor with values drawn from a normal distribution
    normal_tensor = Tensor.randn(*tensor.shape) * std + mean
    
    # In-place update of the input tensor
    tensor.assign(normal_tensor)
    
    return tensor

def uniform_init(*shape):
    return Tensor.kaiming_uniform(shape)


def gumbel_noise(t):
    noise = Tensor.uniform(t.shape, low=0, high=1)
    # use the clamped `log` helper so a draw of exactly 0 doesn't produce inf
    return -log(-log(noise))


def gumbel_sample(
    logits,
    temperature=1.0,
    stochastic=False,
    straight_through=False,
    reinmax=False,
    dim=-1,
    training=True,
):
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(axis=dim)
    one_hot = ind.one_hot(size).cast(dtype=dtype)

    assert not (
        reinmax and not straight_through
    ), "reinmax can only be turned on if using straight through gumbel softmax"

    if not straight_through or temperature <= 0.0 or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
    # algorithm 2

    if reinmax:
        π0 = logits.softmax(axis=dim)
        π1 = (one_hot + (logits / temperature).softmax(axis=dim)) / 2
        π1 = ((log(π1) - logits).detach() + logits).softmax(axis=dim)
        π2 = 2 * π1 - 0.5 * π0
        one_hot = π2 - π2.detach() + one_hot
    else:
        π1 = (logits / temperature).softmax(axis=dim)
        one_hot = one_hot + π1 - π1.detach()

    return ind, one_hot


def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
    denom = x.sum(axis=dim, keepdim=True)
    return (x + eps) / (denom + n_categories * eps)


def sample_vectors(samples, num):
    num_samples = samples.shape[0]
    if num_samples >= num:
        # emulate torch.randperm with numpy: `num` distinct random indices
        indices = Tensor(np.random.permutation(num_samples)[:num], dtype=dtypes.int)
    else:
        indices = Tensor.randint((num,), low=0, high=num_samples)

    return samples[indices]


def batched_sample_vectors(samples, num):
    return Tensor.stack(*[sample_vectors(sample, num) for sample in unbind(samples)], dim=0)

def pad_shape(shape, size, dim=0):
    return [size if i == dim else s for i, s in enumerate(shape)]


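def sample_multinomial(total_count, probs):
    # tinygrad has no binomial sampler, so draw each category's count as
    # Binomial(remaining_count, p / remaining_mass) by thresholding uniform samples
    total_count = int(total_count)
    remainder = 1.0
    counts = []
    for p in probs.tolist():
        if total_count > 0 and remainder > 0:
            u = Tensor.rand(total_count)
            s = int((u <= min(p / remainder, 1.0)).sum().item())
        else:
            s = 0
        counts.append(s)
        total_count -= s
        remainder -= p

    return Tensor(counts, dtype=dtypes.long)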

def batched_bincount(x, *, minlength):
    # Count occurrences of each bin per row in one vectorized pass:
    # one-hot encode (batch, n) -> (batch, n, minlength), then sum over n.
    # Assumes every value in x lies in [0, minlength).
    return x.one_hot(minlength).sum(axis=1).cast(dtypes.float32)

def scatter_tensor(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']] = None):
    """
    Scatters `src` values along an axis specified by `dim`.
    apply `add` or `multiply` reduction operation with `reduce`.
    """
    index, dim  = index.to(self.device), self._resolve_dim(dim)
    if not isinstance(src, Tensor): src = Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
    assert index.ndim == self.ndim == src.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
    assert all((s >= i if d != dim else True) and srcs >= i for d,(s,i,srcs) in enumerate(zip(self.shape, index.shape, src.shape))), \
      f"Expected {index.shape=} to be <= {self.shape=} apart from dimension {dim} and to be <= {src.shape=}"
    mask = (index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)).transpose(-1, dim)
    src = src.shrink(tuple((0, sh) for sh in index.shape)).unsqueeze(-1).transpose(-1, dim)
    src = src.expand(tuple(self.size(i) if i == dim else None for i in range(src.ndim)))
    src = src.pad(tuple((0, max(xs-ss, 0)) for ss, xs in zip(src.shape[:-1], self.shape)) + (None,))
    mask = mask.pad(tuple((0, max(xs-ms, 0)) for ms, xs in zip(mask.shape[:-1], self.shape)) + (None,))
    if reduce is None:
      nan_masked = mask.where(src*mask, float("nan"))
      masked_src = functools.reduce(lambda x,y: y.isnan().where(x, y), nan_masked.split(1, -1))
      return (mask.any(-1).where(masked_src.squeeze(), self)).cast(self.dtype)
    if reduce == "add": return ((mask*src).sum(-1) + self).cast(self.dtype)
    if reduce == "multiply": return (mask.where(mask*src, 1).prod(-1) * self).cast(self.dtype)

Tensor.scatter = scatter_tensor

def kmeans(
    samples,
    num_clusters,
    num_iters=10,
    sample_fn=batched_sample_vectors,
    all_reduce_fn=noop,
):
    num_codebooks, dim, dtype = (
        samples.shape[0],
        samples.shape[-1],
        samples.dtype,
    )

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        # negative batched distances, so argmax picks the nearest centroid
        dists = -cdist(samples, means)
        buckets = dists.argmax(axis=-1)

        bins = batched_bincount(buckets, minlength=num_clusters)
        all_reduce_fn(bins)

        zero_mask = (bins == 0)
        bins_min_clamped = zero_mask.where(1, bins)

        new_means = Tensor.zeros((num_codebooks, num_clusters, dim), dtype=dtype)
        # scatter returns a new tensor, so the result must be reassigned
        new_means = new_means.scatter(1, repeat(buckets, "h n -> h n d", d=dim), samples, 'add')
        new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
        all_reduce_fn(new_means)

        # clusters that received no samples keep their previous centroid
        means = rearrange(zero_mask, "... -> ... 1").where(means, new_means)

    return means, bins


def batched_embedding(indices, embeds):
    batch, dim = indices.shape[1], embeds.shape[-1]
    indices = repeat(indices, "h b n -> h b n d", d=dim)
    embeds = repeat(embeds, "h c d -> h b c d", b=batch)
    return embeds.gather(2, indices)


def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    h, n = t.shape[:2]
    # L2-normalize each code vector (stands in for torch's F.normalize, which isn't available here)
    normed_codes = t / t.square().sum(axis=-1, keepdim=True).sqrt().maximum(1e-12)
    cosine_sim = Tensor.einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
    return cosine_sim.square().sum() / (h * n**2) - (1 / n)
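`gumbel_sample` relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise, −log(−log u) with u uniform, to the logits and taking the argmax draws category i with probability softmax(logits)_i. The numpy sketch below (independent of tinygrad; `gumbel_max_sample` is a hypothetical name) checks this empirically. Note that `gumbel_sample` above also divides the logits by `temperature` before adding the noise, which this sketch omits (temperature 1).

```python
import numpy as np

def gumbel_max_sample(logits, n, rng):
    """Draw n category indices as argmax(logits + Gumbel(0, 1) noise)."""
    # u in [tiny, 1) keeps both logarithms finite
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=(n, logits.shape[-1]))
    gumbel = -np.log(-np.log(u))
    return (logits + gumbel).argmax(axis=-1)

rng = np.random.default_rng(0)
probs = np.array([0.6, 0.3, 0.1])
idx = gumbel_max_sample(np.log(probs), 200_000, rng)
freq = np.bincount(idx, minlength=3) / len(idx)
print(freq)  # close to [0.6, 0.3, 0.1]
```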


# EuclideanCodebook: Vector Quantization for VQ-BeTs

The `EuclideanCodebook` class implements Euclidean-distance vector quantization, a key component of Vector-Quantized Behavior Transformers (VQ-BeTs). It maps each continuous input vector to its nearest codebook entry, which is how continuous action spaces are discretized.

Key Features of the EuclideanCodebook:

1. **Multi-Codebook Support**: Can handle multiple codebooks, allowing for more expressive representation of complex behaviors.

2. **Initialization Strategies**:
   - Kaiming-uniform random initialization by default
   - Optional k-means initialization from the first batch of data (`kmeans_init=True`)

3. **Adaptive Codebook Management**:
   - Exponential Moving Average (EMA) updates for codebook vectors
   - Mechanism to detect and replace "dead" (unused) codebook entries
   - Optional learnable codebook

4. **Advanced Sampling Techniques**:
   - Gumbel-Softmax sampling for stochastic vector selection
   - Temperature-controlled sampling to balance exploration and exploitation

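5. **Affine Transformation Support**:
   - Optional affine parameter tracking (`affine_param=True`)
   - Exponential moving averages of batch and codebook means and variances, used to rescale the codebook to match the batch statistics before quantizing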
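The dead-code mechanism in point 3 works as follows: a code whose EMA usage count (`cluster_size`) drops below `threshold_ema_dead_code` is overwritten with a vector sampled from the current batch, and its count is reset to `reset_cluster_size`. The numpy sketch below shows that rule for a single codebook; `expire_dead_codes` is a hypothetical name, and sampling rows with `rng.choice` stands in for the notebook's `batched_sample_vectors`.

```python
import numpy as np

def expire_dead_codes(embed, cluster_size, batch, threshold, reset_size, rng):
    """Replace codes whose EMA usage is below `threshold` with random batch vectors."""
    embed, cluster_size = embed.copy(), cluster_size.copy()
    dead = cluster_size < threshold
    num_dead = int(dead.sum())
    if num_dead > 0:
        # sample without replacement when the batch has enough vectors
        picks = rng.choice(len(batch), size=num_dead, replace=len(batch) < num_dead)
        embed[dead] = batch[picks]
        cluster_size[dead] = reset_size
    return embed, cluster_size

rng = np.random.default_rng(0)
embed = np.zeros((4, 2))
cluster_size = np.array([5.0, 0.5, 3.0, 0.1])
batch = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
new_embed, new_size = expire_dead_codes(embed, cluster_size, batch, threshold=2, reset_size=2, rng=rng)
```

Resetting the count to a small positive value gives the revived code a grace period before it can be declared dead again.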

Why is this important for VQ-BeTs?

- **Representation Learning**: Enables learning discrete representations of continuous action spaces
- **Compression**: Represents each continuous input vector by a discrete code index (one per codebook), which a transformer can model as a token
- **Behavior Modeling**: Captures complex, multi-modal action distributions
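When `kmeans_init=True`, the codebook is initialized by the `kmeans` utility on the first batch it sees. For intuition, here is a numpy sketch of the same Lloyd iteration for a single codebook: random initial centroids drawn from the samples, nearest-centroid assignment, per-cluster means, and empty clusters keeping their old centroid. `kmeans_np` is a hypothetical name, and `rng.choice` stands in for `batched_sample_vectors`.

```python
import numpy as np

def kmeans_np(samples, k, iters, rng):
    """Lloyd's k-means for one codebook, mirroring the tinygrad `kmeans` above."""
    means = samples[rng.choice(len(samples), size=k, replace=False)]
    for _ in range(iters):
        # assign each sample to its nearest centroid
        d2 = ((samples[:, None, :] - means[None, :, :]) ** 2).sum(-1)
        buckets = d2.argmin(axis=1)
        bins = np.bincount(buckets, minlength=k)
        # per-cluster sums, then means (max(bins, 1) avoids dividing by zero)
        sums = np.zeros_like(means)
        np.add.at(sums, buckets, samples)
        new_means = sums / np.maximum(bins, 1)[:, None]
        # empty clusters keep their previous centroid
        means = np.where((bins == 0)[:, None], means, new_means)
    return means, bins

samples = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
means, bins = kmeans_np(samples, k=2, iters=10, rng=np.random.default_rng(0))
order = np.argsort(means[:, 0])
print(means[order])
```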

Let's dive into the implementation of the EuclideanCodebook:

In [8]:
from functools import partial
class EuclideanCodebook:
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks=1,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        reset_cluster_size=None,
        use_ddp=False,
        learnable_codebook=False,
        gumbel_sample=gumbel_sample,
        sample_codebook_temp=1.0,
        ema_update=True,
        affine_param=False,
        sync_affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
    ):
        self.transform_input = identity

        self.decay = decay
        self.ema_update = ema_update

        init_fn = Tensor.kaiming_uniform if not kmeans_init else Tensor.zeros
        embed = init_fn(num_codebooks, codebook_size, dim, requires_grad=learnable_codebook)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = (
            reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
        )

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        assert not (
            use_ddp and num_codebooks > 1 and kmeans_init
        ), "kmeans init is not compatible with multiple codebooks in distributed environment for now"

        # we aren't dealing with distributed work for now. Dear god no no no no
        self.sample_fn = batched_sample_vectors
        self.kmeans_all_reduce_fn = noop
        self.all_reduce_fn = noop

        self.initted = not kmeans_init
        self.cluster_size = Tensor.zeros((num_codebooks, codebook_size), requires_grad=False)
        self.embed_avg = embed.detach()

        self.learnable_codebook = learnable_codebook
        self.embed = embed
        self.embed.requires_grad = learnable_codebook

        # affine related params

        self.affine_param = affine_param
        self.sync_affine_param = sync_affine_param

        if not affine_param:
            return

        self.affine_param_batch_decay = affine_param_batch_decay
        self.affine_param_codebook_decay = affine_param_codebook_decay

        # no statistics yet; update_with_decay initializes these on first use
        self.batch_mean = None
        self.batch_variance = None

        self.codebook_mean_needs_init = Tensor([True], requires_grad=False)
        self.codebook_mean = Tensor.empty(num_codebooks, 1, dim, requires_grad=False)
        self.codebook_variance_needs_init = Tensor([True], requires_grad=False)
        self.codebook_variance = Tensor.empty(num_codebooks, 1, dim, requires_grad=False)

    # Don't jit this
    def init_embed_(self, data, mask=None):
        if self.initted:
            return

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn=self.sample_fn,
            all_reduce_fn=self.kmeans_all_reduce_fn,
        )

        embed_sum = embed * rearrange(cluster_size, "... -> ... 1")

        self.embed = self.embed.replace(embed)
        self.embed_avg = self.embed_avg.replace(embed_sum)
        self.cluster_size = self.cluster_size.replace(cluster_size)
        self.initted = True

    # don't jit
    def update_with_decay(self, buffer_name, new_value, decay):
        old_value = getattr(self, buffer_name)

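        needs_init = getattr(self, buffer_name + "_needs_init", False)

        # the init flags may be 1-element Tensors; tinygrad Tensors can't be used directly in `if`
        if isinstance(needs_init, Tensor):
            needs_init = bool(needs_init.item())

        if needs_init:
            setattr(self, buffer_name + "_needs_init", False)

        if old_value is None or needs_init: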
            setattr(self, buffer_name, new_value.detach())

            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        setattr(self, buffer_name, value)

    # don't jit
    def update_affine(self, data, embed, mask=None):
        assert self.affine_param

        # don't use bessel correction
        var_fn = partial(Tensor.var, correction=0)

        # calculate codebook mean and variance

        embed = rearrange(embed, "h ... d -> h (...) d")

        if Tensor.training:
            self.update_with_decay(
                "codebook_mean",
                reduce(embed, "h n d -> h 1 d", "mean"),
                self.affine_param_codebook_decay,
            )
            self.update_with_decay(
                "codebook_variance",
                reduce(embed, "h n d -> h 1 d", var_fn),
                self.affine_param_codebook_decay,
            )

        # prepare batch data, which depends on whether it has masking

        data = rearrange(data, "h ... d -> h (...) d")

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        # calculate batch mean and variance

        if not self.sync_affine_param:
            self.update_with_decay(
                "batch_mean",
                reduce(data, "h n d -> h 1 d", "mean"),
                self.affine_param_batch_decay,
            )
            self.update_with_decay(
                "batch_variance",
                reduce(data, "h n d -> h 1 d", var_fn),
                self.affine_param_batch_decay,
            )
            return

        num_vectors = data.shape[-2]

        # number of vectors, for denominator

        num_vectors = Tensor([num_vectors], dtype=data.dtype)
        # calculate distributed mean

        batch_sum = reduce(data, "h n d -> h 1 d", "sum")
        batch_mean = batch_sum / num_vectors

        self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)

        # calculate distributed variance

        variance_numer = reduce((data - batch_mean).square(), "h n d -> h 1 d", "sum")
        batch_variance = variance_numer / num_vectors

        self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)

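    def replace(self, batch_samples, batch_mask):
        # tinygrad has no boolean-mask setitem, so each codebook is rebuilt with `where`
        embeds, sizes, avgs = [], [], []
        for ind, (samples, mask) in enumerate(
            zip(unbind(batch_samples), unbind(batch_mask), strict=False)
        ):
            embed, size, avg = self.embed[ind], self.cluster_size[ind], self.embed_avg[ind]
            if mask.any().item():
                # draw one candidate per code; only the expired (masked) codes take theirs
                sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), self.codebook_size)
                sampled = rearrange(sampled, "1 ... -> ...")
                m = mask.unsqueeze(-1)
                embed = m.where(sampled, embed)
                size = mask.where(self.reset_cluster_size, size)
                avg = m.where(sampled * self.reset_cluster_size, avg)
            embeds.append(embed)
            sizes.append(size)
            avgs.append(avg)

        self.embed.assign(Tensor.stack(*embeds).detach())
        self.cluster_size.assign(Tensor.stack(*sizes).detach())
        self.embed_avg.assign(Tensor.stack(*avgs).detach())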

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not expired_codes.any().item():
            return

        batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
        self.replace(batch_samples, batch_mask=expired_codes)
        
    def __call__(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = (
            sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
        )

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, "... -> 1 ...")

        flatten, ps = pack_one(x, "h * d")

        if mask is not None:
            mask = repeat(
                mask,
                "b n -> c (b h n)",
                c=flatten.shape[0],
                h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
            )

        self.init_embed_(flatten, mask=mask)

        if self.affine_param:
            self.update_affine(flatten, self.embed, mask=mask)

        embed = self.embed if self.learnable_codebook else self.embed.detach()

        if self.affine_param:
            codebook_std = self.codebook_variance.clamp(min_=1e-5).sqrt()
            batch_std = self.batch_variance.clamp(min_=1e-5).sqrt()
            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

        dist = -cdist(flatten, embed)

        embed_ind, embed_onehot = self.gumbel_sample(
            dist, dim=-1, temperature=sample_codebook_temp, training=Tensor.training
        )

        embed_ind = unpack_one(embed_ind, ps, "h *")

        if Tensor.training:
            unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
            quantize = Tensor.einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        if Tensor.training and self.ema_update and not freeze_codebook:
            if self.affine_param:
                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

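            if mask is not None:
                # zero the one-hot rows of masked-out positions (tinygrad has no boolean-mask setitem)
                embed_onehot = rearrange(mask, "h n -> h n 1").where(embed_onehot, 0.0)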

            cluster_size = embed_onehot.sum(axis=1)

            self.all_reduce_fn(cluster_size)
            ema_inplace(self.cluster_size, cluster_size, self.decay)

            embed_sum = Tensor.einsum("h n d, h n c -> h c d", flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg, embed_sum, self.decay)

            cluster_size = laplace_smoothing(
                self.cluster_size, self.codebook_size, self.eps
            ) * self.cluster_size.sum(axis=-1, keepdim=True)

            embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
            self.embed.replace(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))

        dist = unpack_one(dist, ps, "h * d")

        return quantize, embed_ind, dist
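The training branch above performs an EMA codebook update: `ema_inplace` blends the new per-code counts into `cluster_size` and the new per-code vector sums into `embed_avg`, then each code becomes `embed_avg` divided by its Laplace-smoothed count. The numpy sketch below reproduces one such step for a single codebook under simplifying assumptions: hard nearest-code assignment, no masking, no affine rescaling, and no dead-code expiry. `ema_codebook_step` is a hypothetical name.

```python
import numpy as np

def ema_codebook_step(x, embed, cluster_size, embed_avg, decay=0.8, eps=1e-5):
    """One EMA update of a single codebook (x: (n, d), embed: (c, d))."""
    c = embed.shape[0]
    # nearest-code assignment as a one-hot matrix (n, c)
    d2 = ((x[:, None, :] - embed[None, :, :]) ** 2).sum(-1)
    onehot = np.eye(c)[d2.argmin(axis=1)]
    # EMA of per-code counts and per-code vector sums (decay * old + (1 - decay) * new)
    cluster_size = decay * cluster_size + (1 - decay) * onehot.sum(0)
    embed_avg = decay * embed_avg + (1 - decay) * onehot.T @ x
    # Laplace smoothing keeps unused codes from dividing by zero
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + c * eps) * n
    return embed_avg / smoothed[:, None], cluster_size, embed_avg

x = np.array([[1.0, 0.0], [1.2, 0.0], [0.0, 1.0]])
embed = np.eye(2)
new_embed, cs, avg = ema_codebook_step(x, embed, np.ones(2), embed.copy(), decay=0.5)
```

Code 0 absorbs the two samples near (1, 0) and drifts toward their mean, while code 1 stays put because its only sample sits exactly on it.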

# VectorQuantize: The Core of VQ-BeTs

The `VectorQuantize` class is the cornerstone of Vector-Quantized Behavior Transformers (VQ-BeTs). This class implements the vector quantization process, which is crucial for transforming continuous action spaces into discrete representations that can be effectively processed by transformer models.

Key Features and Their Importance for VQ-BeTs:

1. **Multi-Head Support**:
   - Allows for multiple independent codebooks, enabling more nuanced representation of complex behaviors.
   - Essential for capturing different aspects of behavior in parallel.

2. **Flexible Codebook Initialization**:
   - Supports both uniform and k-means initialization.
   - K-means initialization can lead to better initial representations of the action space.

3. **Commitment Loss**:
   - Encourages the encoder to produce embeddings close to the codebook vectors.
   - Crucial for learning meaningful discrete representations of continuous actions.

4. **Straight-Through Estimator**:
   - Allows gradient flow through the discretization process.
   - Essential for end-to-end training of the VQ-BeT model.

5. **Gumbel-Softmax Sampling**:
   - Enables stochastic selection of codebook vectors.
   - Facilitates exploration in the action space during training.

6. **EMA Codebook Updates**:
   - Provides stable updates to the codebook vectors.
   - Important for consistent learning and adaptation to evolving behavior patterns.

7. **Orthogonal Regularization**:
   - Encourages diversity in the codebook vectors.
   - Helps in capturing a wide range of distinct behaviors.

8. **Affine Transformation Support**:
   - Allows for normalization of inputs and codebook vectors.
   - Aids in handling varying scales and distributions in the action space.

Why is this crucial for VQ-BeTs?

- **Discrete Representation**: Enables the transformation of continuous actions into discrete tokens, allowing the use of transformer architectures for behavior modeling.
- **Efficient Learning**: By clustering similar actions, it helps in identifying and learning common behavior patterns.
- **Scalability**: Allows handling of high-dimensional and complex action spaces typical in robotics and other advanced control tasks.
- **Exploration-Exploitation Balance**: When `stochastic_sample_codes=True`, Gumbel noise makes code selection stochastic during training, which can help balance exploration and exploitation.
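The orthogonal regularization in point 7 corresponds to `orthogonal_loss_fn` in the utilities cell: it sums the squared cosine similarities between all pairs of codes in each codebook, divides by h·n², and subtracts 1/n, which cancels the diagonal (self-similarity) terms. The numpy sketch below (`orthogonal_loss` is a hypothetical name) checks the two extremes: orthonormal codes give 0, and n identical codes give 1 − 1/n.

```python
import numpy as np

def orthogonal_loss(codes):
    """codes: (h, n, d). Mean squared cosine similarity over code pairs, minus 1/n."""
    h, n = codes.shape[:2]
    norms = np.linalg.norm(codes, axis=-1, keepdims=True)
    normed = codes / np.maximum(norms, 1e-12)
    cos = np.einsum("hid,hjd->hij", normed, normed)
    return (cos**2).sum() / (h * n**2) - 1 / n

print(orthogonal_loss(np.eye(3)[None]))      # orthonormal codes: ~0
print(orthogonal_loss(np.ones((1, 3, 2))))  # identical codes: ~1 - 1/3
```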

Let's examine the implementation of the VectorQuantize class:

In [9]:
from functools import partial
class VectorQuantize:
    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim=None,
        heads=1,
        separate_codebook_per_head=False,
        decay=0.8,
        eps=1e-5,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        threshold_ema_dead_code=0,
        channel_last=True,
        accept_image_fmap=False,
        commitment_weight=1.0,
        commitment_use_cross_entropy_loss=False,
        orthogonal_reg_weight=0.0,
        orthogonal_reg_active_codes_only=False,
        orthogonal_reg_max_codes=None,
        stochastic_sample_codes=False,
        sample_codebook_temp=1.0,
        straight_through=False,
        reinmax=False,  # using reinmax for improved straight-through, assuming straight through helps at all
        sync_codebook=None,
        sync_affine_param=False,
        ema_update=True,
        learnable_codebook=False,
        in_place_codebook_optimizer: Callable[
            ..., Optimizer
        ] = None,  # Optimizer used to update the codebook embedding if using learnable_codebook
        affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
        sync_update_v=0.0,  # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else None
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else None

        self.eps = eps
        self.commitment_weight = commitment_weight
        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss  # whether to use cross entropy loss to codebook as commitment loss

        self.learnable_codebook = learnable_codebook

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"

        assert 0 <= sync_update_v <= 1.0
        assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"

        self.sync_update_v = sync_update_v

        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic=stochastic_sample_codes,
            reinmax=reinmax,
            straight_through=straight_through,
        )

        if sync_codebook is None:
            sync_codebook = False

        codebook_kwargs = {
            "dim": codebook_dim,
            "num_codebooks": heads if separate_codebook_per_head else 1,
            "codebook_size": codebook_size,
            "kmeans_init": kmeans_init,
            "kmeans_iters": kmeans_iters,
            "sync_kmeans": sync_kmeans,
            "decay": decay,
            "eps": eps,
            "threshold_ema_dead_code": threshold_ema_dead_code,
            "use_ddp": sync_codebook,
            "learnable_codebook": has_codebook_orthogonal_loss or learnable_codebook,
            "sample_codebook_temp": sample_codebook_temp,
            "gumbel_sample": gumbel_sample_fn,
            "ema_update": ema_update,
        }

        if affine_param:
            codebook_kwargs = dict(
                **codebook_kwargs,
                affine_param=True,
                sync_affine_param=sync_affine_param,
                affine_param_batch_decay=affine_param_batch_decay,
                affine_param_codebook_decay=affine_param_codebook_decay,
            )

        self._codebook = EuclideanCodebook(**codebook_kwargs)

        self.in_place_codebook_optimizer = (
            in_place_codebook_optimizer(nn.state.get_parameters(self._codebook))
            if (in_place_codebook_optimizer is not None)
            else None
        )

        self.codebook_size = codebook_size

        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

    @property
    def codebook(self):
        codebook = self._codebook.embed

        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, "1 ... -> ...")

    @codebook.setter
    def codebook(self, codes):
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        # tinygrad has no in-place copy_; assign replaces the embedding's contents with `codes`
        self._codebook.embed.assign(codes)

    def get_codebook_vector_from_indices(self, indices):
        codebook = self.codebook
        is_multiheaded = codebook.ndim > 2

        if not is_multiheaded:
            codes = codebook[indices]
            return rearrange(codes, "... h d -> ... (h d)")

        indices, ps = pack_one(indices, "b * h")
        indices = rearrange(indices, "b n h -> b h n")

        indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
        codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])

        codes = codebook.gather(2, indices)
        codes = rearrange(codes, "b h n d -> b n (h d)")
        codes = unpack_one(codes, ps, "b * d")
        return codes

    def __call__(
        self,
        x: Tensor,
        indices=None,
        mask=None,
        sample_codebook_temp=None,
        freeze_codebook=False,
    ):
        orig_input = x

        only_one = x.ndim == 2

        if only_one:
            assert mask is None
            x = rearrange(x, "b d -> b 1 d")

        shape, device, heads, is_multiheaded, _codebook_size, return_loss = (
            x.shape,
            x.device,
            self.heads,
            self.heads > 1,
            self.codebook_size,
            (indices is not None),
        )

        need_transpose = not self.channel_last and not self.accept_image_fmap
        should_inplace_optimize = self.in_place_codebook_optimizer is not None

        # rearrange inputs

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, "b c h w -> b (h w) c")

        if need_transpose:
            x = rearrange(x, "b d n -> b n d")

        # project input

        x = self.project_in(x) if self.project_in is not None else x

        # handle multi-headed separate codebooks

        if is_multiheaded:
            ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
            x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)

        # l2norm for cosine sim, otherwise identity

        x = self._codebook.transform_input(x)

        # codebook forward kwargs

        codebook_forward_kwargs = {
            "sample_codebook_temp": sample_codebook_temp,
            "mask": mask,
            "freeze_codebook": freeze_codebook,
        }

        # quantize

        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        # one step in-place update

        if should_inplace_optimize and Tensor.training and not freeze_codebook:
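            if mask is not None:
                # per-element squared error, averaged only over positions where mask is True
                loss = (quantize - x.detach()).square()

                loss_mask = mask
                if is_multiheaded:
                    loss_mask = repeat(
                        mask,
                        "b n -> c (b h) n",
                        c=loss.shape[0],
                        h=loss.shape[1] // mask.shape[0],
                    )

                # tinygrad has no boolean-mask indexing: weight by the mask, then divide by the selected count
                loss_mask = loss_mask.unsqueeze(-1).expand(loss.shape).float()
                loss = (loss * loss_mask).sum() / loss_mask.sum()

            else:
                loss = (quantize - x.detach()).square().mean()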

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again

            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        if Tensor.training:
            # determine code to use for commitment loss
            commit_quantize = quantize.detach() if not self.learnable_codebook or freeze_codebook else quantize

            # straight through

            quantize = x + (quantize - x).detach()

            if self.sync_update_v > 0.0:
                # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())

        # function for calculating cross entropy loss to distance matrix
        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss

        def calculate_ce_loss(codes):
            if not is_multiheaded:
                dist_einops_eq = "1 b n l -> b l n"
            elif self.separate_codebook_per_head:
                dist_einops_eq = "c b n l -> b l n c"
            else:
                dist_einops_eq = "1 (b h) n l -> b l n h"

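            # tinygrad's cross_entropy takes no ignore_index argument, so codes set to -1 by masking
            # are not excluded the way PyTorch's ignore_index=-1 would exclude them
            ce_loss = rearrange(distances, dist_einops_eq, b=shape[0]).cross_entropy(codes)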

            return ce_loss

        # if returning cross entropy loss on codes that were passed in

        if return_loss:
            return quantize, calculate_ce_loss(indices)

        # transform embedding indices

        if is_multiheaded:
            if self.separate_codebook_per_head:
                embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
            else:
                embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)

        if self.accept_image_fmap:
            embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)

        if only_one:
            embed_ind = rearrange(embed_ind, "b 1 -> b")

        # aggregate loss

        loss = Tensor([0.0], requires_grad=Tensor.training)

        if Tensor.training:
            if self.commitment_weight > 0:
                if self.commitment_use_cross_entropy_loss:
                    if mask is not None:
                        ce_loss_mask = mask
                        if is_multiheaded:
                            ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)
                        
                        embed_ind = embed_ind.masked_fill(ce_loss_mask.logical_not(), -1)

                    commit_loss = calculate_ce_loss(embed_ind)
                else:
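                    if mask is not None:
                        # with variable-length sequences: per-element squared error, averaged over unmasked positions
                        commit_loss = (commit_quantize - x).square()

                        loss_mask = mask
                        if is_multiheaded:
                            loss_mask = repeat(
                                loss_mask,
                                "b n -> c (b h) n",
                                c=commit_loss.shape[0],
                                h=commit_loss.shape[1] // mask.shape[0],
                            )

                        # tinygrad has no boolean-mask indexing: weight by the mask, then divide by the selected count
                        loss_mask = loss_mask.unsqueeze(-1).expand(commit_loss.shape).float()
                        commit_loss = (commit_loss * loss_mask).sum() / loss_mask.sum()
                    else:
                        commit_loss = (commit_quantize - x).square().mean()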

                loss = loss + commit_loss * self.commitment_weight

            if self.has_codebook_orthogonal_loss:
                codebook = self._codebook.embed

                # only calculate orthogonal loss for the activated codes for this batch

                if self.orthogonal_reg_active_codes_only:
                    assert not (
                        is_multiheaded and self.separate_codebook_per_head
                    ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
                    unique_code_ids = Tensor(np.unique(embed_ind.numpy()), requires_grad=False)
                    codebook = codebook[:, unique_code_ids]

                num_codes = codebook.shape[-2]

                if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = np.random.permutation(num_codes)[: self.orthogonal_reg_max_codes].tolist()
                    codebook = codebook[:, rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        # handle multi-headed quantized embeddings

        if is_multiheaded:
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
            else:
                quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)

        # project out

        quantize = self.project_out(quantize) if self.project_out is not None else quantize

        # rearrange quantized embeddings

        if need_transpose:
            quantize = rearrange(quantize, "b n d -> b d n")

        if self.accept_image_fmap:
            quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)

        if only_one:
            quantize = rearrange(quantize, "b 1 d -> b d")

        # if masking, only return quantized for where mask has True

        if mask is not None:
            quantize = rearrange(mask, "... -> ... 1").where(quantize, orig_input)

        return quantize, embed_ind, loss
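
In training mode, `__call__` replaces the quantized output with `x + (quantize - x).detach()`. Write $x$ for the projected input, $q$ for the selected code, and $\operatorname{sg}(\cdot)$ for stop-gradient (`.detach()`). This straight-through estimator then gives:

$$
\tilde{q} = x + \operatorname{sg}(q - x), \qquad
\underbrace{\tilde{q} = q}_{\text{forward value}}, \qquad
\underbrace{\frac{\partial \tilde{q}}{\partial x} = I}_{\text{backward}}
$$

Downstream layers therefore see the discrete code, while gradients reach the encoder as if quantization were the identity. The optional `sync_update_v` step adds $v\,(\tilde{q} - \operatorname{sg}(\tilde{q}))$. That term is zero in the forward pass, so it only changes gradients.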

# ResidualVQ: Multi-Stage Vector Quantization for VQ-BeTs

The `ResidualVQ` class implements a crucial component of Vector-Quantized Behavior Transformers (VQ-BeTs): multi-stage vector quantization. This approach, also known as residual vector quantization, enhances the model's ability to capture fine-grained details in the action space.

Key Features and Their Importance for VQ-BeTs:

1. **Multiple Quantization Layers**:
   - Uses a cascade of VectorQuantize layers to iteratively quantize the input and its residuals.
   - Allows for more precise representation of complex behaviors by capturing details at different scales.

2. **Shared or Independent Codebooks**:
   - Option to use shared codebooks across layers for parameter efficiency or independent codebooks for more expressive power.
   - Crucial for balancing model complexity and representation capacity.

3. **Quantize Dropout**:
   - Enables stochastic dropping of fine-grained quantizations during training.
   - Enhances model robustness and generalization by forcing it to work with partial information.

4. **Flexible Dimensionality**:
   - Supports different input and codebook dimensions with optional projection layers.
   - Allows adaptation to various action space configurations in different environments.

5. **Image Feature Map Support**:
   - Can handle inputs structured as image feature maps.
   - Useful for tasks involving visual inputs, such as robotic manipulation or autonomous driving.

6. **Codebook Freezing**:
   - Allows freezing of codebooks during certain phases of training.
   - Important for fine-tuning or transfer learning scenarios in behavior modeling.

Why is this crucial for VQ-BeTs?

- **Hierarchical Representation**: Enables capturing behavior patterns at multiple levels of abstraction.
- **Improved Reconstruction**: Residual quantization often leads to better reconstruction of the original input, crucial for accurate behavior modeling.
- **Scalability**: Allows handling of complex, high-dimensional action spaces common in advanced robotics and control tasks.
- **Adaptability**: The flexible architecture allows the model to adapt to various types of behavior data and task requirements.
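
The cascade can be made concrete with a minimal plain-Python sketch using two hypothetical hand-written codebooks. Each stage picks the code nearest to the current residual, the selected codes are summed to form the reconstruction, and the leftover residual shrinks. The sketch omits everything learnable (projections, EMA updates, dropout) and only illustrates the greedy residual loop, using hypothetical helper names.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(x, codebook):
    """Index of the codebook entry closest to x."""
    return min(range(len(codebook)), key=lambda i: squared_distance(x, codebook[i]))


def residual_quantize(x, codebooks):
    """Greedy residual VQ: each stage quantizes whatever the previous stages left over."""
    residual = list(x)
    quantized = [0.0] * len(x)
    indices = []
    for codebook in codebooks:
        idx = nearest(residual, codebook)
        code = codebook[idx]
        indices.append(idx)
        quantized = [q + c for q, c in zip(quantized, code)]
        residual = [r - c for r, c in zip(residual, code)]
    return quantized, indices, residual


# Hypothetical codebooks: a coarse first stage and a finer second stage.
coarse = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
fine = [[0.0, 0.0], [0.25, 0.0], [0.0, 0.25], [-0.25, 0.0], [0.0, -0.25]]

x = [1.3, 0.9]
q1, idx1, r1 = residual_quantize(x, [coarse])
q2, idx2, r2 = residual_quantize(x, [coarse, fine])
print(idx1, idx2)
print(squared_distance(r1, [0.0, 0.0]), squared_distance(r2, [0.0, 0.0]))
```

Decoding needs only the per-stage indices. Summing the selected code from each codebook reproduces `q2`, which is the idea behind gathering one code per layer from the stacked codebooks.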

Let's examine the implementation of the ResidualVQ class:

In [10]:
from functools import partial
from math import ceil
from random import randrange

class ResidualVQ:
    """
    Residual VQ is composed of multiple VectorQuantize layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
        "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is
        passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional
        Nq -1 vector quantizers, as described in Algorithm 1."


    self.project_in: function for projecting input to codebook dimension
    self.project_out: function for projecting codebook dimension to output dimension
    self.layers: list of VectorQuantize layers containing the Nq layers of VQ described in the paper.
    self.freeze_codebook: buffer to save an indicator whether the codebook is frozen or not. VQ-BeT will check this to determine whether to update the codebook or not.
    """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_dim=None,
        shared_codebook=False,
        heads=1,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        accept_image_fmap=False,
        **kwargs,
    ):
        assert heads == 1, "residual vq is not compatible with multi-headed codes"
        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads
        
        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else lambda x: x
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else lambda x: x

        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
        self.layers = [
                VectorQuantize(
                    dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
                )
                for _ in range(num_quantizers)
            ]

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        # freeze_codebook acts like a buffer: a plain bool (not a Tensor), so it is not saved or loaded with the weights
        self.freeze_codebook = False
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        if not shared_codebook:
            return

        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = Tensor.stack(*codebooks, dim=0)
        codebooks = rearrange(codebooks, "q 1 c d -> q c d")
        return codebooks

    def get_codebook_vector_from_indices(self, indices):
        # this function will return the codes from all codebooks across layers corresponding to the indices
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert (
                self.quantize_dropout
            ), "quantize dropout must be enabled if you wish to reconstruct from a signal with fewer fine quantizations"
            indices = indices.pad((0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering

        codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch)
        gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1])

        # take care of quantizer dropout

        # have it fetch a dummy code (index 0) that is masked out later
        mask = gather_indices == -1
        gather_indices = gather_indices.masked_fill(mask, 0)

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed
        all_codes = all_codes.masked_fill(mask, 0.0)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def __call__(self, x:Tensor, indices=None, return_all_codes=False, sample_codebook_temp=None):
        """
        For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
        The residual value of each layer is fed to the next layer.
        """
        num_quant, quant_dropout_multiple_of, return_loss = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            (indices is not None)
        )

        x = self.project_in(x)

        assert not (self.accept_image_fmap and (indices is not None))

        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        if return_loss:
            assert not (indices == -1).any().item(), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
            ce_losses = []

        should_quantize_dropout = Tensor.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of)
                    * quant_dropout_multiple_of
                    - 1
                )

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = Tensor.full(null_indices_shape, -1.0, dtype=dtypes.long, requires_grad=False)
            null_loss = Tensor.full((1,), 0.0, dtype=x.dtype, requires_grad=False)

        # go through the layers

        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            layer_indices = None
            if return_loss:
                layer_indices = indices[..., quantizer_index]

            quantized, *rest = layer(
                residual,
                indices=layer_indices,
                sample_codebook_temp=sample_codebook_temp,
                freeze_codebook=self.freeze_codebook,
            )

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            if return_loss:
                ce_loss = rest[0]
                ce_losses.append(ce_loss)
                continue

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # whether to early return the cross entropy loss

        if return_loss:
            return quantized_out, sum(ce_losses)

        # stack all losses and indices

        # each layer's loss has shape (1,), so stacking along the last axis yields shape (1, num_quantizers)
        all_losses = Tensor.stack(*all_losses, dim=-1)
        all_indices = Tensor.stack(*all_indices, dim=-1)

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codebook_vector_from_indices(all_indices)

            # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
            ret = (*ret, all_codes)

        return ret
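
Two quantize-dropout details in `ResidualVQ` are easy to miss. First, when `quantize_dropout_multiple_of != 1`, the sampled `rand_quantize_dropout_index` is rounded up so the number of active layers is a multiple of that value. Second, dropped layers get index -1, which `get_codebook_vector_from_indices` maps to a zero code. The plain-Python sketch below mirrors both steps with hypothetical helpers `dropout_cutoff` and `lookup_codes`.

```python
import math


def dropout_cutoff(index, multiple_of):
    """Round a sampled layer index up so that (index + 1) active layers is a multiple of `multiple_of`."""
    if multiple_of == 1:
        return index
    return math.ceil((index + 1) / multiple_of) * multiple_of - 1


def lookup_codes(codebooks, indices, dim):
    """One code per layer; index -1 marks a dropped layer and maps to a zero vector."""
    return [[0.0] * dim if idx == -1 else list(cb[idx]) for cb, idx in zip(codebooks, indices)]


# Structured dropout with 8 quantizers in groups of 4: only 4 or 8 layers stay active.
print([dropout_cutoff(i, 4) for i in range(8)])

# With 6 quantizers and groups of 4, sampled index 5 rounds up past the last layer,
# so a loop that skips layers with index > cutoff keeps all 6.
print(dropout_cutoff(5, 4))

# Hypothetical 2-layer codebooks; the second layer was dropped.
codebooks = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [0.1, 0.1]]]
codes = lookup_codes(codebooks, [1, -1], dim=2)
reconstruction = [sum(values) for values in zip(*codes)]
print(codes, reconstruction)
```

Because a dropped layer contributes a zero vector, summing the looked-up codes reconstructs from the coarse layers alone.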



# Creating a Config

By now we have all the building blocks needed for the core of VQ-BeT. Next, we define a config that collects the system's parameters in one place so they are easy to inspect and tweak.

In [11]:
from dataclasses import dataclass, field


@dataclass
class VQBeTConfig:
    """Configuration class for VQ-BeT.

    Defaults are configured for training with PushT providing proprioceptive and single camera observations.

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

    Notes on the inputs and outputs:
        - "observation.state" is required as an input key.
        - At least one key starting with "observation.image" is required as an input.
        - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
        action_chunk_size: Action chunk size of each action prediction token.
        input_shapes: A dictionary defining the shapes of the input data for the policy.
            The key represents the input data name, and the value is a list indicating the dimensions
            of the corresponding data. For example, "observation.image" refers to an input from
            a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
            Importantly, shapes don't include the batch or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy.
            The key represents the output data name, and the value is a list indicating the dimensions
            of the corresponding data. For example, "action" refers to an output shape of [14], indicating
            14-dimensional actions. Importantly, shapes don't include the batch or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation, and "min_max" which rescales to the
            [-1, 1] range.
        output_normalization_modes: Similar dictionary to `input_normalization_modes`, but used to unnormalize
            outputs back to the original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
            within the image size. If None, no cropping is done.
        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
            mode).
        pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
            `None` means no pretrained weights.
        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
        n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
        vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
        vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
        vqvae_enc_hidden_dim: Size of hidden dimensions of the Encoder / Decoder parts of the Residual VQ-VAE
        gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens)
        gpt_input_dim: Size of the GPT input. This is also used as the dimension of observation features.
        gpt_output_dim: Size of the GPT output. This is also used as the input dimension of the offset / bin prediction heads.
        gpt_n_layer: Number of layers of GPT
        gpt_n_head: Number of attention heads of GPT
        gpt_hidden_dim: Size of hidden dimensions of GPT
        dropout: Dropout rate for GPT
        mlp_hidden_dim: Size of hidden dimensions of the offset head / bin prediction head parts of VQ-BeT
        offset_loss_weight: A constant multiplier for the offset loss
        primary_code_loss_weight: A constant multiplier for the primary code prediction loss
        secondary_code_loss_weight: A constant multiplier for the secondary code prediction loss
        bet_softmax_temperature: Softmax temperature used when sampling codes during VQ-BeT rollouts
        sequentially_select: Whether to select the primary and secondary codes sequentially (pick the primary
            code, then the secondary code) or at the same time.
    """

    # Inputs / output structure.
    n_obs_steps: int = 5
    n_action_pred_token: int = 7
    action_chunk_size: int = 5

    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.image": [3, 96, 96],
            "observation.state": [2],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [2],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "min_max",
        }
    )
    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})

    # Architecture / modeling.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    crop_shape: tuple[int, int] | None = (84, 84)
    crop_is_random: bool = True
    pretrained_backbone_weights: str | None = None
    use_group_norm: bool = True
    spatial_softmax_num_keypoints: int = 32
    # VQ-VAE
    n_vqvae_training_steps: int = 10
    vqvae_n_embed: int = 16
    vqvae_embedding_dim: int = 256
    vqvae_enc_hidden_dim: int = 128
    # VQ-BeT
    gpt_block_size: int = 500
    gpt_input_dim: int = 512
    gpt_output_dim: int = 512
    gpt_n_layer: int = 8
    gpt_n_head: int = 8
    gpt_hidden_dim: int = 512
    dropout: float = 0.1
    mlp_hidden_dim: int = 1024
    offset_loss_weight: float = 10000.0
    primary_code_loss_weight: float = 5.0
    secondary_code_loss_weight: float = 0.5
    bet_softmax_temperature: float = 0.1
    sequentially_select: bool = False

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
        if self.crop_shape is not None:
            for image_key in image_keys:
                if (
                    self.crop_shape[0] > self.input_shapes[image_key][1]
                    or self.crop_shape[1] > self.input_shapes[image_key][2]
                ):
                    raise ValueError(
                        f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
                        f"for `crop_shape` and {self.input_shapes[image_key]} for "
                        f"`input_shapes[{image_key}]`."
                    )
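        # Check that all input images have the same shape.
        if not image_keys:
            raise ValueError("At least one input key starting with `observation.image` is required.")
        first_image_key = next(iter(image_keys))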
        for image_key in image_keys:
            if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
                raise ValueError(
                    f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
                    "expect all image shapes to match."
                )
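
The dict-valued fields in `VQBeTConfig` use `field(default_factory=...)` rather than a literal default. Python's `dataclasses` module rejects unhashable mutable defaults such as a dict, raising `ValueError` when the class is defined. The factory also gives each config instance its own fresh dict, so editing one config's shapes cannot leak into another. A stand-alone illustration using a hypothetical `ShapesConfig`:

```python
from dataclasses import dataclass, field


@dataclass
class ShapesConfig:
    # Hypothetical mini-config that declares a dict field the same way VQBeTConfig does.
    output_shapes: dict[str, list[int]] = field(default_factory=lambda: {"action": [2]})
    n_obs_steps: int = 5


a = ShapesConfig()
b = ShapesConfig(n_obs_steps=3)
a.output_shapes["action"] = [14]  # only affects `a`; `b` got its own dict from the factory
print(a.output_shapes, b.output_shapes)

# A plain dict default is rejected as soon as the class is defined.
try:
    @dataclass
    class BadConfig:
        output_shapes: dict = {"action": [2]}

    rejected = False
except ValueError:
    rejected = True
print(rejected)
```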

# VqVae: The Core of Vector-Quantized Behavior Transformers (VQ-BeTs)

The `VqVae` class implements the Vector Quantized Variational Autoencoder (VQ-VAE) component of VQ-BeTs. This is a crucial part of the architecture that enables the discretization of continuous action spaces, which is fundamental to the VQ-BeT approach.

Key Components and Their Importance:

1. **Encoder and Decoder**:
   - Implemented as MLPs (Multi-Layer Perceptrons)
   - Encoder: Transforms continuous actions into latent representations
   - Decoder: Reconstructs actions from quantized latent representations
   - Critical for learning meaningful representations of behaviors

2. **Residual VQ Layer**:
   - Uses multiple layers of vector quantization
   - Allows for more expressive and fine-grained discretization of the action space
   - Key to capturing complex behavior patterns

3. **Two-Phase Training**:
   - Phase 1: Trains the encoder, decoder, and VQ layer
   - Phase 2: Uses the trained VQ-VAE to provide discrete codes for training the transformer

4. **Code Generation**:
   - Ability to generate discrete codes from continuous actions
   - Essential for creating the "vocabulary" of behaviors that the transformer will learn to predict

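5. **Reconstruction and Commitment Losses**:
   - An L1 reconstruction loss ensures that the discretized representations can accurately reconstruct the original actions
   - A commitment loss keeps encoder outputs close to their assigned codebook vectors; `vqvae_forward` adds it with a fixed weight of 5, trading reconstruction accuracy against discretization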

Why is this crucial for VQ-BeTs?

- **Discretization**: Enables the transformation of continuous actions into discrete tokens, allowing the use of transformer architectures for behavior modeling
- **Hierarchical Representation**: The residual VQ approach allows for capturing behavior at multiple levels of abstraction
- **Efficient Learning**: By clustering similar actions into discrete codes, it helps in identifying and learning common behavior patterns
- **Bridging Continuous and Discrete**: Allows the model to work with discrete tokens while still being able to generate continuous actions
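To make "residual VQ" concrete, here is a minimal NumPy sketch of greedy residual quantization with two codebooks (matching `vqvae_num_layers = 2`). It is an illustration of the idea, not the notebook's `ResidualVQ` class: the codebooks are random, and the helper names `residual_vq_encode` / `residual_vq_decode` are hypothetical. Each layer quantizes whatever the previous layers left unexplained, and decoding sums the chosen vector from every layer, which is what `get_embeddings_from_code` does with `z_embed.sum(axis=0)`.

```python
import numpy as np

def residual_vq_encode(x, codebooks):
    """Greedy residual VQ (sketch). x: (N, D); codebooks: list of (K, D) arrays, one per layer.

    Returns the quantized vectors (N, D) and integer codes (N, num_layers).
    """
    residual = x.copy()
    quantized = np.zeros_like(x)
    codes = []
    for codebook in codebooks:
        # Squared Euclidean distance from every residual to every code vector: (N, K)
        dists = ((residual[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
        idx = dists.argmin(axis=1)      # nearest-neighbor look-up
        chosen = codebook[idx]
        quantized += chosen
        residual -= chosen              # the next layer quantizes what is left over
        codes.append(idx)
    return quantized, np.stack(codes, axis=1)

def residual_vq_decode(codes, codebooks):
    # Sum the selected vector from each layer to rebuild the quantized latent
    return sum(cb[codes[:, i]] for i, cb in enumerate(codebooks))

rng = np.random.default_rng(0)
# Second codebook is smaller in scale: it only needs to refine the first layer's residual
codebooks = [rng.normal(size=(16, 4)), 0.1 * rng.normal(size=(16, 4))]
x = rng.normal(size=(8, 4))
zq, codes = residual_vq_encode(x, codebooks)
```

Each row of `codes` is a pair of indices (one per layer), so two codebooks of size 16 can address 16 × 16 combinations while storing only 32 vectors.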

Let's examine the implementation of the VqVae class:

In [12]:
class VqVae:
    def __init__(
        self,
        config: VQBeTConfig,
    ):
        """
        VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
        Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
        The vq_layer uses residual VQs.

        This class contains functions for training the encoder and decoder along with the residual VQ layer (training phase 1),
        as well as helper functions for training the BeT in training phase 2.
        """

        super().__init__()
        self.config = config
        # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
        self.discretized = False
        self.optimized_steps = 0
        # we use the fixed number of layers for Residual VQ across all environments.
        self.vqvae_num_layers = 2

        self.vq_layer = ResidualVQ(
            dim=config.vqvae_embedding_dim,
            num_quantizers=self.vqvae_num_layers,
            codebook_size=config.vqvae_n_embed,
        )

        self.encoder = MLP(
            in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                config.vqvae_embedding_dim,
            ],
        )
        self.decoder = MLP(
            in_channels=config.vqvae_embedding_dim,
            hidden_channels=[
                config.vqvae_enc_hidden_dim,
                config.vqvae_enc_hidden_dim,
                self.config.output_shapes["action"][0] * self.config.action_chunk_size,
            ],
        )

    def get_embeddings_from_code(self, encoding_indices):
        # This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices.
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
        # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
        z_embed = z_embed.sum(axis=0)
        Tensor.no_grad = no_grad_prev
        Tensor.training = training_prev
        return z_embed

    def get_action_from_latent(self, latent):
        # given latent vector, this function outputs the decoded action.
        output = self.decoder(latent)
        # (N, T*A) -> (N, T, A); this also covers action_chunk_size == 1, where T = 1.
        return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])

    def get_code(self, state):
        # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
        # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        state = einops.rearrange(state, "N T A -> N (T A)")
        state_rep = self.encoder(state)
        state_rep_shape = state_rep.shape[:-1]
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        # the VQ loss is not needed here; only the quantized latents and the codes are returned
        Tensor.no_grad = no_grad_prev
        Tensor.training = training_prev
        return state_vq, vq_code

    
    def vqvae_forward(self, state):
        # This function passes the given data through the Residual VQ with the encoder and decoder (see section 3.2 of the paper: https://arxiv.org/pdf/2403.03181).
        state = einops.rearrange(state, "N T A -> N (T A)")
        # We start with passing action (or action chunk) at:t+n through the encoder ϕ.
        state_rep = self.encoder(state)
        state_rep_shape = state_rep.shape[:-1]
        state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
        # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
        state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
        state_vq = state_rep_flat.view(*state_rep_shape, -1)
        vq_code = vq_code.view(*state_rep_shape, -1)
        # The RVQ returns a loss term per quantizer layer; sum them into a single scalar commitment loss.
        vq_loss_state = vq_loss_state.sum()
        # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
        dec_out = self.decoder(state_vq)
        # Calculate L1 reconstruction loss
        encoder_loss = (state - dec_out).abs().mean()
        # add encoder reconstruction loss and commitment loss
        rep_loss = encoder_loss + vq_loss_state * 5
        
        metric = (
            encoder_loss.detach(),
            vq_loss_state.detach(),
            vq_code,
            rep_loss.item(),
        )
        return rep_loss, metric
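To make the phase-1 objective in `vqvae_forward` concrete, here is a NumPy sketch of how its pieces combine: an L1 reconstruction loss on the flattened action chunk plus 5 × the commitment loss. Assumption: the commitment term is written out as the mean squared distance between the encoder output and its quantized vector; in the notebook that value comes from the `ResidualVQ` layer (summed over its layers), whose exact formula is not shown in this section. The function name `vqvae_loss` is hypothetical.

```python
import numpy as np

def vqvae_loss(actions, decoded, latent, quantized, commit_weight=5.0):
    """Phase-1 loss sketch: L1 reconstruction + weighted commitment term.

    actions:   (N, T, A) action chunks
    decoded:   (N, T*A) decoder output
    latent:    (N, D) encoder output
    quantized: (N, D) quantized latent (treated as a fixed target)
    """
    flat = actions.reshape(actions.shape[0], -1)     # same as einops "N T A -> N (T A)"
    recon = np.abs(flat - decoded).mean()            # L1 reconstruction loss
    commitment = ((latent - quantized) ** 2).mean()  # assumed MSE commitment loss
    return recon + commit_weight * commitment, recon, commitment

total, recon, commitment = vqvae_loss(
    actions=np.zeros((2, 3, 2)), decoded=np.ones((2, 6)),
    latent=np.zeros((2, 4)), quantized=np.full((2, 4), 0.5),
)
# total = 1.0 + 5 * 0.25 = 2.25
```

The large commitment weight pushes the encoder to produce latents that already sit near code vectors, so discretization loses little information.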

# GPT: The Transformer Component of VQ-BeTs

The `GPT` class implements the Transformer component of Vector-Quantized Behavior Transformers (VQ-BeTs). This is a crucial part of the architecture that learns to predict future behavior tokens based on past observations and actions.

Key Features and Their Importance for VQ-BeTs:

1. **Transformer Architecture**:
   - Based on the GPT (Generative Pre-trained Transformer) model
   - Crucial for learning long-range dependencies in behavior sequences

2. **Flexible Configuration**:
   - Uses a `VQBeTConfig` object for hyperparameters
   - Allows easy experimentation with different model sizes and architectures

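3. **Input and Position Embeddings**:
   - A linear layer (`wte`) projects each continuous input feature vector to the hidden dimension, and a learned embedding table (`wpe`) encodes each position in the sequence
   - Enables the model to understand both the content and the order of behaviors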

4. **Multi-Layer Structure**:
   - Uses multiple transformer blocks for deep representation learning
   - Essential for capturing complex patterns in behavior data

5. **Output Projection**:
   - Projects the final hidden states to the output dimension
   - Crucial for predicting the next action token

6. **Parameter Configuration**:
   - Separates parameters into those with and without weight decay
   - Important for effective regularization during training

Why is this crucial for VQ-BeTs?

- **Sequence Modeling**: Enables the model to learn and predict sequences of behavior tokens
- **Context Understanding**: Allows the model to consider long-term context when predicting future actions
- **Scalability**: The transformer handles variable-length input sequences, up to `gpt_block_size` tokens
- **Transfer Learning**: The GPT-style architecture allows for potential pre-training on large behavior datasets
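The core of each transformer block is causal self-attention: every position may only attend to itself and earlier positions. `CausalSelfAttention` does this per head with tinygrad tensors and a lower-triangular mask stored as `bias`; the NumPy sketch below shows the single-head computation so the mask's role is easy to see. The function name `causal_attention` is illustrative only.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head causal self-attention. q, k, v: (T, hs). Returns output (T, hs) and weights (T, T)."""
    T, hs = q.shape
    scores = q @ k.T / np.sqrt(hs)                 # scaled dot-product scores
    mask = np.tril(np.ones((T, T), dtype=bool))    # position t may only see positions 0..t
    scores = np.where(mask, scores, -np.inf)       # same role as masked_fill(bias == 0, -inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                       # exp(-inf) = 0 for masked positions
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 4)) for _ in range(3))
out, w = causal_attention(q, k, v)
```

Because the first position can only see itself, its output is exactly `v[0]`; later positions mix information from their own past only, which is what lets the model predict the next behavior token without peeking ahead.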

Let's examine the implementation of the GPT class:

In [13]:
# GPT section, here we gooooo

from tinygrad import nn
from tinygrad.nn.state import get_state_dict

class CausalSelfAttention:
  def __init__(self, config):
    assert config.gpt_hidden_dim % config.gpt_n_head == 0
    # key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim)
    # output projection
    self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim)
    # regularization
    self.gpt_n_head = config.gpt_n_head
    self.gpt_hidden_dim = config.gpt_hidden_dim
    self.dropout_rate = config.dropout
    # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
    self.bias = Tensor.ones(config.gpt_block_size, config.gpt_block_size, requires_grad=False).view(
                1, 1, config.gpt_block_size, config.gpt_block_size
            ).tril().contiguous().detach().realize()
  def __call__(self, x: Tensor):
    B, T, C = x.shape  # batch size, sequence length, embedding dimensionality (gpt_hidden_dim)
    # calculate query, key, values for all heads in batch and move head forward to be the batch dim
    q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
    k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
    att = att.softmax(axis=-1)
    att = att.dropout(self.dropout_rate)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
    # output projection
    y = self.c_proj(y).dropout(self.dropout_rate)
    return y

class Block:
  # causal self-attention block for GPT (pre-LayerNorm attention + MLP, each wrapped in a residual connection)
  def __init__(self, config):
    self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim)
    self.mlp = [
        nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim),
        Tensor.gelu,
        nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim),
    ]
    self.dropout_rate = config.dropout

  def __call__(self, x: Tensor) -> Tensor:
    x = x + self.attn(self.ln_1(x))
    x = x + self.ln_2(x).sequential(self.mlp).dropout(self.dropout_rate)
    return x

class GPT:
    """
    Original comments:
    Full definition of a GPT Language Model, all of it in this single file.
    References: the official GPT-2 TensorFlow implementation released by OpenAI:
    https://github.com/openai/gpt-2/blob/master/src/model.py
    """
    def __init__(self, config: VQBeTConfig):
        """
        GPT model gets hyperparameters from a config object. Please refer to `VQBeTConfig` for more details.
        """
        assert config.gpt_output_dim is not None
        assert config.gpt_block_size is not None
        self.config = config
        self.wte = nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim)
        self.wpe = nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim)
        self.drop = config.dropout
        self.h = [Block(config) for _ in range(config.gpt_n_layer)]
        self.ln_f = nn.LayerNorm(config.gpt_hidden_dim)
        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
        for pn, p in get_state_dict(self).items():
            if pn.endswith("c_proj.weight"):
                print(f'replacing GPT parameter: {pn} with a normal')
                p.assign(Tensor.normal(p.shape, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)))
    def __call__(self, input_to_wte:Tensor, targets=None):
        b, t, d = input_to_wte.size()
        assert (
            t <= self.config.gpt_block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
        # positional encodings that are added to the input embeddings
        pos = Tensor.arange(0, t, dtype=dtypes.long).unsqueeze(0)  # shape (1, t)
        # forward the GPT model itself
        tok_emb = self.wte(input_to_wte)  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = (tok_emb + pos_emb).dropout(self.drop)
        for block in self.h:
            x = block(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits
    def configure_parameters(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (nn.Linear,)
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
        for pn, _p in nn.state.get_state_dict(self).items():
            fpn = pn  # full param name
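            # Note: `attn.bias` is the causal mask, not a learnable bias. Its name ends with "bias", so it lands in
            # no_decay, but it was created with requires_grad=False and detached, so the optimizer should not update it.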
            if pn.endswith("bias"):
                # all biases will not be decayed
                no_decay.add(fpn)
                continue
            if pn.endswith("mlp.0.weight") or pn.endswith("mlp.2.weight"):
                # MLP linear weights are decayed
                decay.add(fpn)
                continue
            if pn.endswith("c_attn.weight") or pn.endswith("c_proj.weight"):
                # attention projection weights are decayed
                decay.add(fpn)
                continue
            if pn.endswith("ln_1.weight") or pn.endswith("ln_2.weight"):
                # LayerNorm weights are not decayed
                no_decay.add(fpn)
                continue
        decay.add('wte.weight')
        decay.add('lm_head.weight')
        no_decay.add('wpe.weight')
        no_decay.add('ln_f.weight')

        # validate that we considered every parameter
        param_dict = nn.state.get_state_dict(self)
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters {} were not separated into either decay/no_decay set!".format(
            str(param_dict.keys() - union_params),
        )
        print(f'decay list: {sorted(decay)}')
        print(f'no_decay list: {sorted(no_decay)}')

        decay = [param_dict[pn] for pn in sorted(decay)]
        no_decay = [param_dict[pn] for pn in sorted(no_decay)]
        # return the parameters that require weight decay, and the parameters that don't separately.
        return decay, no_decay
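The hard-coded name checks in `configure_parameters` boil down to one rule: biases, LayerNorm weights, and the positional embedding are not decayed, and every other weight is. A minimal pure-Python sketch of that rule as a suffix test is below; it assumes parameter names follow the `get_state_dict` pattern of the GPT above (e.g. `h.0.mlp.0.weight`). The helper `split_decay` and the example name list are hypothetical, not dumped from a real model.

```python
def split_decay(param_names):
    """Partition parameter names into (decay, no_decay) sets by naming convention."""
    # Assumed convention: anything ending in one of these suffixes is exempt from weight decay
    no_decay_suffixes = ("bias", "ln_1.weight", "ln_2.weight", "ln_f.weight", "wpe.weight")
    decay, no_decay = set(), set()
    for name in param_names:
        (no_decay if name.endswith(no_decay_suffixes) else decay).add(name)
    return decay, no_decay

names = [
    "wte.weight", "wte.bias", "wpe.weight",
    "h.0.ln_1.weight", "h.0.ln_1.bias", "h.0.attn.bias",
    "h.0.attn.c_attn.weight", "h.0.attn.c_attn.bias",
    "h.0.mlp.0.weight", "h.0.mlp.2.weight",
    "ln_f.weight", "ln_f.bias", "lm_head.weight",
]
decay, no_decay = split_decay(names)
```

A suffix rule like this avoids listing top-level names such as `wte.weight` by hand, at the cost of relying on the naming convention staying stable.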

# VQBeTRgbEncoder: Image Encoding for VQ-BeTs

The `VQBeTRgbEncoder` class is a crucial component of Vector-Quantized Behavior Transformers (VQ-BeTs) designed to process RGB images. This encoder transforms raw image inputs into compact feature vectors, which can then be used by the VQ-BeT model for behavior prediction and generation.

Key Features and Their Importance for VQ-BeTs:

1. **Flexible Preprocessing**:
   - Optional cropping (random or center) based on configuration
   - Allows for consistent input sizes and potential data augmentation

2. **ResNet Backbone**:
   - Uses a ResNet18 architecture for feature extraction
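   - Can optionally use GroupNorm instead of BatchNorm, which removes the dependence on batch statistics; this option cannot be combined with pretrained weights (the constructor raises an error)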

3. **Spatial Softmax Pooling**:
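   - Converts each keypoint channel of the feature map into its expected (x, y) coordinate under a softmax over spatial locations, giving a fixed-size vector of `2 × spatial_softmax_num_keypoints` values
   - Keeps the location of the most strongly activated features instead of averaging it away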

4. **Configurable Architecture**:
   - Uses a `VQBeTConfig` object for flexible setup
   - Allows easy experimentation with different encoder configurations

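5. **Final Linear Projection**:
   - A linear layer followed by ReLU maps the flattened keypoint coordinates to the output feature vector
   - In this implementation the input and output sizes are both `2 × spatial_softmax_num_keypoints` (stored as `feature_dim`), so the layer re-mixes the keypoint features rather than reducing their dimension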

Why is this crucial for VQ-BeTs?

- **Visual Input Processing**: Enables VQ-BeTs to work with raw image inputs, essential for visual-based tasks
- **Feature Extraction**: Transforms high-dimensional image data into compact, informative feature vectors
- **Spatial Understanding**: The combination of CNN and spatial softmax allows the model to capture spatial relationships in the input
- **Adaptability**: The configurable nature allows the encoder to be tailored for different visual tasks and environments
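Spatial softmax pooling can be sketched in a few lines of NumPy: softmax each channel over all H × W locations, then take the expected x and y coordinate under that distribution. This sketch assumes coordinates normalized to [-1, 1] and omits the learned 1×1 convolution that implementations such as LeRobot's apply first to map backbone channels to `num_kp` keypoint channels; the notebook's `SpatialSoftmax` is defined elsewhere and may differ in details.

```python
import numpy as np

def spatial_softmax(features):
    """Expected (x, y) location per channel. features: (C, H, W) -> (C, 2), coordinates in [-1, 1]."""
    C, H, W = features.shape
    # Normalized coordinate grids: grid_x[i, j] = xs[j], grid_y[i, j] = ys[i]
    grid_x, grid_y = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(-1.0, 1.0, H))
    flat = features.reshape(C, H * W)
    # Softmax over all spatial locations of each channel (max subtracted for stability)
    flat = flat - flat.max(axis=1, keepdims=True)
    probs = np.exp(flat)
    probs /= probs.sum(axis=1, keepdims=True)
    # Expected coordinates under each channel's distribution
    exp_x = probs @ grid_x.reshape(-1)
    exp_y = probs @ grid_y.reshape(-1)
    return np.stack([exp_x, exp_y], axis=1)

# One sharp activation in the top-right corner should map to roughly (x=1, y=-1)
fmap = np.zeros((1, 5, 5))
fmap[0, 0, 4] = 50.0
keypoints = spatial_softmax(fmap)
```

With `num_kp` keypoint channels, flattening the `(num_kp, 2)` result gives the `2 × spatial_softmax_num_keypoints` vector that feeds the final linear layer.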

Let's examine the implementation of the VQBeTRgbEncoder class:

In [14]:

###
### And now our real hill-climb. Here's where the main part of VQ-BeT lies
### This mirrors the RGB image encoder that Diffusion Policy uses (ResNet18 backbone -> spatial softmax -> linear).
### I'm gonna assume this is the best way to get ResNet18 to turn an image into an encoded latent that the policy understands
###

import tinygrad.nn as nn
from tinygrad import Tensor, dtypes
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch, get_child

# allow monkeypatching in layer implementations
BatchNorm = nn.BatchNorm2d
Conv2d = nn.Conv2d
Linear = nn.Linear

class ResNet18Backbone:
    def __init__(self, use_group_norm=False):
        self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
        self.bn1 = nn.GroupNorm(num_groups=4, num_channels=64) if use_group_norm else BatchNorm(64)

        self.layer1_0 = [
            # 0
            Conv2d(64, 64, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=64) if use_group_norm else BatchNorm(64),
            Tensor.relu,
            Conv2d(64, 64, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=64) if use_group_norm else BatchNorm(64),
        ]
        self.layer1_1 = [
            # 1
            Conv2d(64, 64, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=64) if use_group_norm else BatchNorm(64),
            Tensor.relu,
            Conv2d(64, 64, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=4, num_channels=64) if use_group_norm else BatchNorm(64),
        ]
        self.layer2_0 = [
            # 0
            Conv2d(64, 128, kernel_size=3, stride=2, bias=False, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=128) if use_group_norm else BatchNorm(128),
            Tensor.relu,
            Conv2d(128, 128, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=128) if use_group_norm else BatchNorm(128)
        ]
        self.layer2_d = [
            # 0 downsample
            Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
            nn.GroupNorm(num_groups=8, num_channels=128) if use_group_norm else BatchNorm(128)
        ]
        self.layer2_1 = [
            # 1
            Conv2d(128, 128, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=128) if use_group_norm else BatchNorm(128),
            Tensor.relu,
            Conv2d(128, 128, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=128) if use_group_norm else BatchNorm(128),
        ]
        self.layer3_0 = [
            # 0
            Conv2d(128, 256, kernel_size=3, stride=2, bias=False, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=256) if use_group_norm else BatchNorm(256),
            Tensor.relu,
            Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=256) if use_group_norm else BatchNorm(256)
        ]
        self.layer3_d = [
            # 0 downsample
            Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
            nn.GroupNorm(num_groups=16, num_channels=256) if use_group_norm else BatchNorm(256)
        ]
        self.layer3_1 = [
            # 1
            Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=256) if use_group_norm else BatchNorm(256),
            Tensor.relu,
            Conv2d(256, 256, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=16, num_channels=256) if use_group_norm else BatchNorm(256),
        ]
        self.layer4_0 = [
            # 0
            Conv2d(256, 512, kernel_size=3, stride=2, bias=False, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=512) if use_group_norm else BatchNorm(512),
            Tensor.relu,
            Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=512) if use_group_norm else BatchNorm(512)
        ]
        self.layer4_d = [
            # 0 downsample
            Conv2d(256, 512, kernel_size=1, stride=2, bias=False),
            nn.GroupNorm(num_groups=32, num_channels=512) if use_group_norm else BatchNorm(512)
        ]
        self.layer4_1 = [
            # 1
            Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=512) if use_group_norm else BatchNorm(512),
            Tensor.relu,
            Conv2d(512, 512, kernel_size=3, stride=1, bias=False, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=512) if use_group_norm else BatchNorm(512),
        ]

    
    def __call__(self, x:Tensor) -> Tensor:
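        out = self.bn1(self.conv1(x)).relu()
        out = out.max_pool2d(kernel_size=(3,3), stride=2, padding=1, dilation=1)
        # Each BasicBlock adds a skip connection (the identity, or the 1x1 downsample branch when the
        # shape changes) to the output of its conv path, then applies ReLU, as in a standard ResNet18.
        out = (out.sequential(self.layer1_0) + out).relu()
        out = (out.sequential(self.layer1_1) + out).relu()
        out = (out.sequential(self.layer2_0) + out.sequential(self.layer2_d)).relu()
        out = (out.sequential(self.layer2_1) + out).relu()
        out = (out.sequential(self.layer3_0) + out.sequential(self.layer3_d)).relu()
        out = (out.sequential(self.layer3_1) + out).relu()
        out = (out.sequential(self.layer4_0) + out.sequential(self.layer4_d)).relu()
        out = (out.sequential(self.layer4_1) + out).relu()
        return out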
   
def center_crop_f(size: tuple, img: Tensor) -> Tensor:
    # img: (BS, C, H, W); size: (crop_h, crop_w)
    _, _, H, W = img.shape
    crop_h, crop_w = size

    # Calculate the starting points for a crop centered in the image
    start_h = (H - crop_h) // 2
    start_w = (W - crop_w) // 2

    # Plain slicing keeps the data on-device (no mask and no NumPy round trip needed)
    return img[:, :, start_h:start_h + crop_h, start_w:start_w + crop_w]

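def make_square_mask_2d(shape, mask_size) -> tuple:
    # Despite the name, this returns a random (top, left) offset for a square crop.
    # Tensor.randint excludes `high`, so +1 lets the crop reach the bottom/right edge.
    H, W = shape
    low_x = Tensor.randint(1, low=0, high=W - mask_size + 1).item()
    low_y = Tensor.randint(1, low=0, high=H - mask_size + 1).item()
    return low_y, low_x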

def random_crop_f_2d(size: tuple, img: Tensor) -> Tensor:
    H, W = img.shape
    y1, x1 = make_square_mask_2d(img.shape, size[0])
    
    # Perform the crop
    X_cropped = img[y1:y1+size[0], x1:x1+size[0]]
    
    return X_cropped

def make_square_mask_3d(shape, mask_size) -> tuple:
    C, H, W = shape
    low_x = Tensor.randint(1, low=0, high=W - mask_size + 1).item()  # +1 because `high` is exclusive
    low_y = Tensor.randint(1, low=0, high=H - mask_size + 1).item()
    return low_y, low_x

def random_crop_f_3d(size: tuple, img: Tensor) -> Tensor:
    C, H, W = img.shape
    y1, x1 = make_square_mask_3d(img.shape, size[0])
    
    # Perform the crop
    X_cropped = img[:, y1:y1 + size[0], x1:x1 + size[0]]
    
    return X_cropped

def make_square_mask(shape, mask_size) -> Tensor:
    BS, _, H, W = shape
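    # `high` is exclusive, so +1 allows every valid offset, including a crop flush with the bottom/right edge
    low_x = Tensor.randint(BS, low=0, high=W-mask_size+1, requires_grad=False).reshape(BS,1,1,1)
    low_y = Tensor.randint(BS, low=0, high=H-mask_size+1, requires_grad=False).reshape(BS,1,1,1)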
    idx_x = Tensor.arange(W, dtype=dtypes.int32, requires_grad=False).reshape((1,1,1,W))
    idx_y = Tensor.arange(H, dtype=dtypes.int32, requires_grad=False).reshape((1,1,H,1))
    return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))

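def random_crop_f(size:tuple, img: Tensor) -> Tensor:
    # Per-sample random square crop via a boolean mask. This round-trips through NumPy,
    # so no gradients flow through the crop (fine for raw input images).
    C = img.shape[1]
    mask = make_square_mask(img.shape, size[0])
    mask = mask.expand((-1, C, -1, -1))
    X_cropped = Tensor(img.numpy()[mask.numpy()])
    return X_cropped.reshape((-1, C, size[0], size[0]))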


class VQBeTRgbEncoder:
    """Encode an RGB image into a 1D feature vector.

    Includes the ability to normalize and crop the image first.
    """

    def __init__(self, config: VQBeTConfig):
        # Set up optional preprocessing.
        if config.crop_shape is not None:
            self.do_crop = True
            # Always use center crop for eval
            self.center_crop = lambda x: center_crop_f(config.crop_shape, x)
            if config.crop_is_random:
                self.maybe_random_crop = lambda x: random_crop_f(config.crop_shape, x)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

        # Set up backbone.
        # Set up backbone. `ResNet18Backbone` returns the layer4 feature map directly
        # (there is no global pooling or classification head to strip off).
        self.backbone = ResNet18Backbone(use_group_norm=config.use_group_norm) 
        #self.backbone = generate_resnet_model(config.vision_backbone, use_group_norm=config.use_group_norm)
        if config.use_group_norm and config.pretrained_backbone_weights:
            raise ValueError(
                "You can't replace BatchNorm in a pretrained model without ruining the weights!"
            )

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape.
        # The dummy input should take the number of image channels from `config.input_shapes` and it should
        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
        # height and width from `config.input_shapes`.
        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
        # Note: we have a check in the config class to make sure all images have the same shape.
        image_key = image_keys[0]
        dummy_input_h_w = (
            config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
        )
        dummy_input = Tensor.zeros(1, config.input_shapes[image_key][0], *dummy_input_h_w, requires_grad=False)
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        dummy_feature_map = self.backbone(dummy_input)
        Tensor.no_grad = no_grad_prev
        Tensor.training = training_prev
        feature_map_shape = tuple(dummy_feature_map.shape[1:])
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)

    def __call__(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        # Preprocess: maybe crop (if it was set up in the __init__).
        if self.do_crop:
            if Tensor.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
            else:
                # Always use center crop for eval.
                x = self.center_crop(x)
        # Extract backbone feature.
        x = self.pool(self.backbone(x)).flatten(start_dim=1)
        # Final linear layer with non-linearity.
        x = self.out(x).relu()
        return x
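Cropping does not need boolean masks: plain slicing is enough, and per-sample random crops only need one offset pair per image. The NumPy sketch below illustrates this and also supports a non-square `crop_shape`, which the notebook's random-crop helpers do not (they use `size[0]` for both dimensions). The helper names `center_crop` and `random_crop` are hypothetical; tinygrad Tensors support the same basic slicing (`img[:, :, top:top + h, left:left + w]`).

```python
import numpy as np

def center_crop(img, crop_h, crop_w):
    """Center crop for a (B, C, H, W) array using plain slicing."""
    H, W = img.shape[-2:]
    top = (H - crop_h) // 2
    left = (W - crop_w) // 2
    return img[..., top:top + crop_h, left:left + crop_w]

def random_crop(img, crop_h, crop_w, rng):
    """Independent random crop per sample; offsets are drawn from [0, H - crop_h] inclusive."""
    B, _, H, W = img.shape
    # Generator.integers excludes `high`, so add 1 to allow the last valid offset
    tops = rng.integers(0, H - crop_h + 1, size=B)
    lefts = rng.integers(0, W - crop_w + 1, size=B)
    return np.stack([img[b, :, top:top + crop_h, left:left + crop_w]
                     for b, (top, left) in enumerate(zip(tops, lefts))])

img = np.arange(2 * 3 * 6 * 6).reshape(2, 3, 6, 6)
center = center_crop(img, 4, 4)
rand = random_crop(img, 4, 4, np.random.default_rng(0))
```

Random crops act as data augmentation during training, while the deterministic center crop keeps evaluation inputs consistent, matching the train/eval split in `VQBeTRgbEncoder.__call__`.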

# VQBeTScheduler: Learning Rate Scheduler for VQ-BeTs

The `VQBeTScheduler` class implements a custom learning rate schedule for training Vector-Quantized Behavior Transformers (VQ-BeTs). Its `get_lr` method returns a multiplier on the base learning rate (1.0 means the base rate) that depends on the current training phase and progress.

Key Features and Their Importance for VQ-BeTs:

1. **Two-Phase Training Support**:
   - Maintains a constant learning rate during VQ-VAE training
   - Implements a more complex schedule for the main VQ-BeT training phase
   - Essential for handling the distinct requirements of each training phase

2. **Warmup Period**:
   - Gradually increases the learning rate at the start of the main training phase
   - Helps stabilize training in the early stages, particularly important for transformer models

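3. **Cosine Decay**:
   - After warmup, the multiplier follows a cosine curve; with the default `num_cycles = 0.5` this is a single half-cosine from 1.0 down to 0.0 over the remaining training steps, with no restarts
   - A larger `num_cycles` would make the schedule oscillate, but the default is a smooth, monotonic decay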

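4. **Configurable Parameters**:
   - Reads `n_vqvae_training_steps` from the configuration object; the warmup length (500 steps), the total number of training steps (250,000), and `num_cycles` (0.5) are hard-coded in `__init__`
   - Editing these values is the way to experiment with different scheduling strategies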

Why is this crucial for VQ-BeTs?

- **Phase-Specific Optimization**: Addresses the different learning dynamics of the VQ-VAE and transformer components
- **Training Stability**: The warmup period helps prevent early training instability, especially important for the transformer part
- **Improved Convergence**: Decaying the learning rate smoothly toward zero late in training typically helps the model settle into a good solution
- **Flexibility**: Allows for easy adjustment of the learning rate schedule to suit different datasets and task complexities
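To see the shape of the schedule without an optimizer, here is a stateless pure-Python version of the multiplier computed by `VQBeTScheduler.get_lr`, with the global step passed in explicitly. The function name `vqbet_lr_multiplier` is hypothetical; the defaults copy the hard-coded values from the class. The learning rate actually applied would be the base learning rate times this value.

```python
import math

def vqbet_lr_multiplier(step, n_vqvae_steps, warmup=500, total=250_000, num_cycles=0.5):
    """Multiplier on the base learning rate at global `step` (mirrors VQBeTScheduler.get_lr)."""
    if step < n_vqvae_steps:           # phase 1 (VQ-VAE): constant
        return 1.0
    step -= n_vqvae_steps              # steps since phase 2 began
    if step < warmup:                  # linear warmup from 0 to 1
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# Sample the curve with a 100-step VQ-VAE phase
samples = {s: vqbet_lr_multiplier(s, n_vqvae_steps=100) for s in (0, 100, 350, 600, 125_350, 250_100)}
```

Sampling shows the three regimes: 1.0 during VQ-VAE training, a ramp from 0 back to 1 during warmup, then a decay that reaches 0.5 halfway through and 0.0 at the end.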

Let's examine the implementation of the VQBeTScheduler class:

In [15]:
from tinygrad.nn.optim import Optimizer
from tinygrad.tensor import Tensor

class LR_Scheduler:
  def __init__(self, optimizer: Optimizer):
    self.optimizer = optimizer
    self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)

  def get_lr(self): pass

  def step(self) -> None:
    self.epoch_counter.assign(self.epoch_counter + 1).realize()
    self.optimizer.lr.assign(self.get_lr()).realize()

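class VQBeTScheduler(LR_Scheduler):
    """Learning-rate schedule for the two VQ-BeT training phases.

    get_lr() returns a multiplier in [0, 1]; step() scales the optimizer's initial
    learning rate by it (the same contract as PyTorch's LambdaLR).
    """
    def __init__(self, optimizer: Optimizer, cfg):
        super().__init__(optimizer)
        # Remember the learning rate the optimizer was created with; the schedule scales it.
        self.base_lr = float(optimizer.lr.numpy().reshape(-1)[0])
        self.n_vqvae_training_steps = cfg.n_vqvae_training_steps
        self.num_warmup_steps = 500
        self.num_training_steps = 250000
        self.num_cycles = 0.5  # half a cosine period: one smooth decay from 1 to 0, no restarts
        self.current_step = 0

    # Schedule:
    # 1. While the VQ-VAE is still being trained, keep the multiplier at 1.0 (no decay).
    # 2. Afterwards, count steps from the end of the VQ-VAE phase; during warmup,
    #    ramp the multiplier linearly from 0 to 1.
    # 3. After warmup, follow a cosine curve driven by training progress;
    #    max(0.0, ...) keeps the multiplier from going negative.
    def get_lr(self) -> float:
        if self.current_step < self.n_vqvae_training_steps:
            return 1.0
        current_step = self.current_step - self.n_vqvae_training_steps
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        progress = float(current_step - self.num_warmup_steps) / float(
            max(1, self.num_training_steps - self.num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

    def step(self) -> None:
        # The base class expects get_lr() to return an absolute LR tensor; here get_lr() returns
        # a Python multiplier, so advance our own counter and write base_lr * multiplier.
        self.current_step += 1
        lr = self.optimizer.lr
        lr.assign(Tensor.full(lr.shape, self.base_lr * self.get_lr(), dtype=lr.dtype, device=lr.device)).realize()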
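To see the shape of the schedule without building an optimizer, the multiplier can be reproduced as a plain Python function. This is a sketch, not part of the notebook's code: `vqbet_lr_multiplier` is a hypothetical helper, `n_vqvae_training_steps = 20_000` is an assumed value, and the other defaults mirror the constants set in the scheduler's constructor.

```python
import math

def vqbet_lr_multiplier(step, n_vqvae_training_steps, num_warmup_steps=500,
                        num_training_steps=250_000, num_cycles=0.5):
    """Hypothetical standalone copy of VQBeTScheduler.get_lr() as a function of the step."""
    # Phase 1: constant multiplier while the VQ-VAE trains
    if step < n_vqvae_training_steps:
        return 1.0
    # Phase 2: count steps from the end of the VQ-VAE phase
    step -= n_vqvae_training_steps
    # Linear warmup from 0 to 1
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # Cosine decay driven by progress through the remaining steps
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

n_vq = 20_000  # assumed number of VQ-VAE training steps
for s in [0, n_vq - 1, n_vq, n_vq + 250, n_vq + 500, n_vq + 500 + 124_750, n_vq + 500 + 249_500]:
    print(s, round(vqbet_lr_multiplier(s, n_vq), 4))
```

Note that the multiplier drops to 0 at the start of phase 2 and climbs back during warmup. With `num_cycles = 0.5` the cosine term reaches 0 exactly at `num_training_steps`; if training runs longer, the multiplier starts rising again, so `num_training_steps` should cover the whole main phase.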

# VQBeTOptimizer: Customized Optimization for VQ-BeTs

The `VQBeTOptimizer` class builds the optimizers for Vector-Quantized Behavior Transformers (VQ-BeTs). The VQ-BeT components have different learning requirements, so it splits the parameters into groups and creates one AdamW optimizer per group.

Key Features and Their Importance for VQ-BeTs:

1. **Parameter Grouping**:
   - Separates parameters into three groups: VQ-VAE (encoder, decoder, and quantizer), decay (the GPT parameters returned by `configure_parameters()` plus the RGB encoder, projectors, action query token, and prediction heads), and no decay
   - Allows a different optimization strategy for each part of the model

2. **AdamW Optimizer**:
   - Uses AdamW, which combines Adam's adaptive per-parameter step sizes with decoupled weight decay
   - AdamW is the standard choice for training transformer models

3. **Different Learning Rates**:
   - Uses a learning rate of 1e-3 for the VQ-VAE and 1e-4 for all other components
   - Lets each part of the model be tuned independently

4. **Weight Decay Configuration**:
   - Applies weight decay of 1e-6 to the decay group, 1e-4 to the VQ-VAE, and none to the no-decay group
   - Regularizes weight matrices while leaving the no-decay parameters (typically biases and LayerNorm weights in minGPT-style models) unpenalized

5. **Flexible Architecture Support**:
   - Adds either the primary and secondary bin heads (when `cfg.sequentially_select` is set) or the single bin head to the decay group
   - The same factory therefore works for both code-selection strategies

Why is this crucial for VQ-BeTs?

- **Component-Specific Optimization**: Different parts of VQ-BeTs (e.g., VQ-VAE, transformer) may require different optimization strategies
- **Regularization Control**: Selective application of weight decay helps in managing the complexity of the model
- **Training Stability**: Per-group hyperparameters (learning rate, betas, weight decay) let each component train at an appropriate pace
- **Adaptability**: The flexible design allows for easy adaptation to different VQ-BeT variants and configurations

Let's examine the implementation of the VQBeTOptimizer class:

In [16]:
# A factory rather than a full optimizer class: tinygrad optimizers have no parameter
# groups, so we build one AdamW optimizer per group (decay, VQ-VAE, no decay).

from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import AdamW
class VQBeTOptimizer:
    @staticmethod
    def create(policy, cfg):
        vqvae_params = (
            get_parameters(policy.vqbet.action_head.vqvae_model.encoder)
            + get_parameters(policy.vqbet.action_head.vqvae_model.decoder)
            + get_parameters(policy.vqbet.action_head.vqvae_model.vq_layer)
        )
        decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
        decay_params = (
            decay_params
            + get_parameters(policy.vqbet.rgb_encoder)
            + get_parameters(policy.vqbet.state_projector)
            + get_parameters(policy.vqbet.rgb_feature_projector)
            + [policy.vqbet.action_token]
            + get_parameters(policy.vqbet.action_head.map_to_cbet_preds_offset)
        )

        if cfg.sequentially_select:
            decay_params = (
                decay_params
                + get_parameters(policy.vqbet.action_head.map_to_cbet_preds_primary_bin)
                + get_parameters(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin)
            )
        else:
            decay_params = decay_params + get_parameters(policy.vqbet.action_head.map_to_cbet_preds_bin)

        # Layout: [decay AdamW, VQ-VAE AdamW, no-decay AdamW, decay_params, vqvae_params, no_decay_params]
        optim_and_param_groups = [
            AdamW(
                decay_params, 
                lr=1.0e-4, 
                b1=0.95,
                b2=0.999,
                eps=1.0e-8,
                weight_decay=1.0e-6
            ),
            AdamW(
                vqvae_params, 
                lr=1.0e-3, 
                b1=0.9,
                b2=0.999,
                eps=1.0e-8,
                weight_decay=0.0001
            ),
            AdamW(
                no_decay_params, 
                lr=1.0e-4, 
                b1=0.95,
                b2=0.999,
                eps=1.0e-8,
                weight_decay=0.0
            ),
            decay_params, vqvae_params, no_decay_params
        ]
        return optim_and_param_groups
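The three AdamW instances above differ mainly in `lr` and `weight_decay`. The following plain-Python sketch (not tinygrad; `adamw_step` is a hypothetical helper) applies one decoupled AdamW update to a single scalar parameter using the hyperparameters of the decay and no-decay groups. It follows the decoupled formulation, in which the decay term is scaled by the learning rate but not by Adam's adaptive denominator. A given library may place the decay term slightly differently (for example, shrinking the weight before the Adam step).

```python
import math

def adamw_step(p, g, m, v, t, lr, b1, b2, eps, weight_decay):
    """One decoupled-weight-decay Adam (AdamW) update for a single scalar parameter."""
    m = b1 * m + (1.0 - b1) * g            # first-moment (mean) estimate
    v = b2 * v + (1.0 - b2) * g * g        # second-moment (uncentered variance) estimate
    m_hat = m / (1.0 - b1 ** t)            # bias correction for step t (1-indexed)
    v_hat = v / (1.0 - b2 ** t)
    # Decoupled weight decay: added to the update, not to the gradient,
    # so it is not rescaled by the adaptive denominator.
    update = m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p
    return p - lr * update, m, v

# Hyperparameters of the "decay" and "no decay" groups
decay_cfg = dict(lr=1.0e-4, b1=0.95, b2=0.999, eps=1.0e-8, weight_decay=1.0e-6)
no_decay_cfg = dict(decay_cfg, weight_decay=0.0)

# With a zero gradient, only weight decay moves the parameter.
p_decay, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, 1, **decay_cfg)
p_no_decay, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, 1, **no_decay_cfg)
print(p_decay, p_no_decay)

# On the first step, bias correction makes the step size ~lr regardless of gradient scale.
p_small, _, _ = adamw_step(1.0, 0.5, 0.0, 0.0, 1, **no_decay_cfg)
p_large, _, _ = adamw_step(1.0, 50.0, 0.0, 0.0, 1, **no_decay_cfg)
print(1.0 - p_small, 1.0 - p_large)
```

The zero-gradient case shows why the no-decay group exists: parameters in the decay group shrink slightly on every step even when their gradient is zero.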

# VQBeTHead: Action Prediction Component for VQ-BeTs

The `VQBeTHead` class is a crucial component of Vector-Quantized Behavior Transformers (VQ-BeTs), responsible for predicting actions based on the output of the transformer layers. This class integrates the VQ-VAE model with prediction heads for both discrete code selection and continuous offset prediction.

Key Features and Their Importance for VQ-BeTs:

1. **VQ-VAE Integration**:
   - Incorporates the VQ-VAE (trained during phase 1 and then frozen) for action discretization and reconstruction
   - Essential for bridging the gap between discrete tokens and continuous actions

2. **Bin Prediction**:
   - Predicts logits over the VQ-VAE codebook for each residual layer
   - Supports both simultaneous and sequential code selection strategies

3. **Offset Prediction**:
   - Predicts continuous offsets that refine the action decoded from the selected codes
   - Allows for more precise action generation within the quantized action space

4. **Flexible Architecture**:
   - Adapts to different VQ-VAE configurations (e.g., number of layers, codebook size)
   - Sizes its prediction heads from the number of RVQ layers, codebook size, action chunk size, and action dimension

5. **Focal Loss**:
   - Uses Focal Loss (with `gamma = 2.0`) to train the bin prediction heads
   - Down-weights codes that are already predicted confidently, so training concentrates on hard or rare codes and is less dominated by frequent ones
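`VQBeTHead` uses `FocalLoss(gamma=2.0)`, defined earlier in the notebook. As an illustration, the plain-Python sketch below states the standard multi-class focal loss, `-(1 - p_t)**gamma * log(p_t)`, for one sample. It assumes softmax probabilities and is not necessarily identical to the notebook's `FocalLoss` (for example, in how it reduces over the batch). With `gamma = 0` it reduces to cross-entropy.

```python
import math

def focal_loss(logits, target, gamma=2.0):
    """Multi-class focal loss for one sample: -(1 - p_t)**gamma * log(p_t)."""
    # numerically stable softmax probability of the target class
    mx = max(logits)
    exps = [math.exp(z - mx) for z in logits]
    p_t = exps[target] / sum(exps)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def cross_entropy(logits, target):
    # focal loss with gamma = 0 is plain cross-entropy
    return focal_loss(logits, target, gamma=0.0)

easy = [4.0, 0.0, 0.0]   # target code 0 is already predicted confidently
hard = [0.0, 0.0, 0.0]   # uniform prediction: p_t = 1/3
for name, logits in [("easy", easy), ("hard", hard)]:
    ce, fl = cross_entropy(logits, 0), focal_loss(logits, 0)
    print(f"{name}: CE={ce:.4f} focal={fl:.4f} ratio={fl / ce:.4f}")
```

The confident ("easy") example keeps roughly 0.1% of its cross-entropy, while the uncertain one keeps `(2/3)**2`, about 44%. Gradients from already-solved codes are therefore strongly suppressed.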

Why is this crucial for VQ-BeTs?

- **Action Generation**: Enables the model to generate both discrete and continuous components of actions
- **Coarse-to-Fine Prediction**: The bin prediction selects a coarse action from the codebook, and the offset prediction refines it
- **Adaptability**: Can work with different VQ-VAE configurations and selection strategies
- **Fine-grained Control**: The offset prediction allows for precise adjustments to the quantized actions
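Both bin-prediction paths turn logits into a categorical distribution by dividing by `bet_softmax_temperature` before the softmax, and then sample a code with `multinomial`. The following plain-Python sketch shows the effect of the temperature. `softmax_with_temperature` and `sample_code` are hypothetical helpers, and `random.Random.choices` stands in for tinygrad's `multinomial`.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature (lower temperature -> sharper distribution)."""
    scaled = [z / temperature for z in logits]
    mx = max(scaled)
    exps = [math.exp(z - mx) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_code(logits, temperature, rng):
    """Draw one code index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.0]
for temperature in (0.1, 1.0, 10.0):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])

rng = random.Random(0)
print([sample_code(logits, 1.0, rng) for _ in range(10)])
```

A low temperature makes sampling nearly greedy, and a high temperature approaches uniform sampling over the codebook. This is the knob that trades off determinism against coverage of multimodal actions.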

Let's examine the implementation of the VQBeTHead class:

In [17]:
# VQ-BeT action head

class VQBeTHead:
    def __init__(self, config: VQBeTConfig):
        """
        VQBeTHead takes the output of the GPT layers and passes each feature through a bin prediction head (`self.map_to_cbet_preds_bin`) and an offset prediction head (`self.map_to_cbet_preds_offset`).

        self.map_to_cbet_preds_bin: outputs the logits of each code (for each RVQ layer).
            Its input dimension matches the GPT output dimension,
            and its output dimension is `self.vqvae_model.vqvae_num_layers (fixed at 2) * self.config.vqvae_n_embed`.
            If the agent selects codes sequentially, `self.map_to_cbet_preds_primary_bin` and `self.map_to_cbet_preds_secondary_bin` are used instead of `self.map_to_cbet_preds_bin`.

        self.map_to_cbet_preds_offset: outputs the predicted offsets for all codes in all layers.
            Its input dimension matches the GPT output dimension,
            and its output dimension is `self.vqvae_model.vqvae_num_layers (fixed at 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
        """

        self.config = config
        # init vqvae
        self.vqvae_model = VqVae(config)
        if config.sequentially_select:
            self.map_to_cbet_preds_primary_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.config.vqvae_n_embed],
            )
            self.map_to_cbet_preds_secondary_bin = MLP(
                in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
                hidden_channels=[self.config.vqvae_n_embed],
            )
        else:
            self.map_to_cbet_preds_bin = MLP(
                in_channels=config.gpt_output_dim,
                hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
            )
        self.map_to_cbet_preds_offset = MLP(
            in_channels=config.gpt_output_dim,
            hidden_channels=[
                self.vqvae_model.vqvae_num_layers
                * self.config.vqvae_n_embed
                * config.action_chunk_size
                * config.output_shapes["action"][0],
            ],
        )
        # loss
        self._focal_loss_fn = FocalLoss(gamma=2.0)

    def discretize(self, n_vqvae_training_steps, actions):
        # Resize the action sequence data to fit the action chunk size using a sliding window approach.        
        actions = Tensor.cat(
            *[
                actions[:, j : j + self.config.action_chunk_size, :]
                for j in range(actions.shape[1] + 1 - self.config.action_chunk_size)
            ],
            dim=0,
        )
        # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.

        loss, metric = self.vqvae_model.vqvae_forward(actions)
        n_different_codes = Tensor(sum(
            [len(np.unique(metric[2][:, i].numpy())) for i in range(self.vqvae_model.vqvae_num_layers)]
        ), requires_grad=False)
        n_different_combinations = len(np.unique(metric[2].numpy(), axis=0))
        recon_l1_error = metric[0].detach().item()
        self.vqvae_model.optimized_steps += 1
        # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
        if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
            self.vqvae_model.discretized = True
            self.vqvae_model.vq_layer.freeze_codebook = True
            print("Finished discretizing action data!")
            for param in nn.state.get_parameters(self.vqvae_model.vq_layer):
                param.requires_grad = False
        return loss, n_different_codes, n_different_combinations, recon_l1_error

    def __call__(self, x: Tensor, **kwargs) -> dict[str, Tensor]:
        # N is the batch size, and T is the number of action query tokens, all processed by the same GPT
        N, T, _ = x.shape
        # N and T are processed in parallel, so fold them into one leading dimension:
        # (batch size * number of action query tokens, GPT output dimension)
        x = einops.rearrange(x, "N T WA -> (N T) WA")

        # predict offsets for every code in every RVQ layer
        cbet_offsets = self.map_to_cbet_preds_offset(x)
        cbet_offsets = einops.rearrange(
            cbet_offsets,
            "(NT) (G C WA) -> (NT) G C WA",
            G=self.vqvae_model.vqvae_num_layers,
            C=self.config.vqvae_n_embed,
        )
        # if self.config.sequentially_select is True, the bin prediction head first samples the primary code, then the secondary code
        if self.config.sequentially_select:
            cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)

            # select primary bin first
            cbet_primary_probs = (cbet_primary_logits / self.config.bet_softmax_temperature).softmax(axis=-1)
            NT, choices = cbet_primary_probs.shape
            sampled_primary_centers = einops.rearrange(
                cbet_primary_probs.view(-1, choices).multinomial(num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )

            # condition the secondary head on the sampled primary code (one-hot encoded)
            cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
                Tensor.cat(
                    x, sampled_primary_centers.one_hot(num_classes=self.config.vqvae_n_embed).float(),
                    dim=1,
                )
            )
            cbet_secondary_probs = (cbet_secondary_logits / self.config.bet_softmax_temperature).softmax(
                axis=-1
            )
            sampled_secondary_centers = einops.rearrange(
                cbet_secondary_probs.view(-1, choices).multinomial(num_samples=1),
                "(NT) 1 -> NT",
                NT=NT,
            )
            sampled_centers = Tensor.stack(sampled_primary_centers, sampled_secondary_centers, dim=1)
            cbet_logits = Tensor.stack(cbet_primary_logits, cbet_secondary_logits, dim=1)
        # if self.config.sequentially_select is False, the bin prediction head samples the primary and secondary codes at once.
        else:
            cbet_logits = self.map_to_cbet_preds_bin(x)
            cbet_logits = einops.rearrange(
                cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
            )
            cbet_probs = (cbet_logits / self.config.bet_softmax_temperature).softmax(axis=-1)
            NT, G, choices = cbet_probs.shape
            sampled_centers = einops.rearrange(
                cbet_probs.view(-1, choices).multinomial(num_samples=1),
                "(NT G) 1 -> NT G",
                NT=NT,
            ).contiguous().detach()

        # Index tensors broadcast to (NT, G): sample i, layer g, sampled code sampled_centers[i, g]
        indices = (
            Tensor.arange(NT, requires_grad=False).unsqueeze(1),
            Tensor.arange(self.vqvae_model.vqvae_num_layers, requires_grad=False).unsqueeze(0),
            sampled_centers,
        )
        # Use advanced indexing to extract only the offsets of the sampled codes: (NT, G, W*A)
        sampled_offsets = cbet_offsets[indices]
        # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
        sampled_offsets = sampled_offsets.sum(axis=1)
        # The VQ-VAE is frozen here: disable gradients and training mode, remembering the previous flags
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        # Get the centroids (= vectors corresponding to the codes) of each layer to pass them through the RVQ decoder
        return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers)
        # pass the centroids through the decoder to get actions.
        decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input)
        Tensor.no_grad = no_grad_prev
        Tensor.training = training_prev
        # reshaped extracted offset to match with decoded centroids
        sampled_offsets = einops.rearrange(
            sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
        )
        # add offset and decoded centroids
        predicted_action = decoded_action + sampled_offsets
        predicted_action = einops.rearrange(
            predicted_action,
            "(N T) W A -> N T (W A)",
            N=N,
            T=T,
            W=self.config.action_chunk_size,
        )

        return {
            "cbet_logits": cbet_logits,
            "predicted_action": predicted_action,
            "sampled_centers": sampled_centers,
            "decoded_action": decoded_action,
        }

    def loss_fn(self, pred, target, **kwargs):
        """
        For the given ground-truth action values (target) and prediction (pred), this function calculates the overall loss.

        predicted_action: predicted action chunk (offset + decoded centroids)
        sampled_centers: sampled centroids (codes of the RVQ)
        decoded_action: decoded action, produced by passing sampled_centers through the RVQ decoder
        NT: batch size * T
        T: number of action query tokens, all processed by the same GPT
        cbet_logits: logits of all codes in each layer
        """
        action_seq = target
        predicted_action = pred["predicted_action"]
        sampled_centers = pred["sampled_centers"]
        decoded_action = pred["decoded_action"]
        NT = predicted_action.shape[0] * predicted_action.shape[1]

        cbet_logits = pred["cbet_logits"]

        predicted_action = einops.rearrange(
            predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
        )

        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")

        # offset loss is the L1 distance between the predicted action and the ground-truth action
        offset_loss = (action_seq - predicted_action).abs().mean()

        # Classification targets: find the RVQ code of each ground-truth action chunk
        # with the frozen VQ-VAE (no gradients, eval mode).
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
        action_bins = action_bins.detach()
        Tensor.no_grad = no_grad_prev
        Tensor.training = training_prev

        # calculate primary code prediction loss
        cbet_loss1 = self._focal_loss_fn(cbet_logits[:, 0, :], action_bins[:, 0])
        # calculate secondary code prediction loss
        cbet_loss2 = self._focal_loss_fn(cbet_logits[:, 1, :], action_bins[:, 1])
        # weighted sum of the two code prediction losses
        cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight

        # fraction of samples whose sampled code matches the ground-truth code (for monitoring only)
        equal_primary_code_rate = (action_bins[:, 0] == sampled_centers[:, 0]).int().sum() / NT
        equal_secondary_code_rate = (action_bins[:, 1] == sampled_centers[:, 1]).int().sum() / NT

        action_mse_error = (action_seq - predicted_action).square().mean()
        vq_action_error = (action_seq - decoded_action).abs().mean()
        offset_action_error = (action_seq - predicted_action).abs().mean()
        action_error_max = (action_seq - predicted_action).abs().max()

        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

        loss_dict = {
            "loss": loss,
            "classification_loss": cbet_loss.detach().item(),
            "offset_loss": offset_loss.detach().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().item(),
            "vq_action_error": vq_action_error.detach().item(),
            "offset_action_error": offset_action_error.detach().item(),
            "action_error_max": action_error_max.detach().item(),
            "action_mse_error": action_mse_error.detach().item(),
        }
        return loss_dict
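The least obvious step in `VQBeTHead.__call__` is the advanced indexing that picks one offset per sample and RVQ layer. NumPy's rules for integer index arrays make it easy to check, and tinygrad's tensor indexing is intended to follow the same NumPy-style semantics. The following sketch uses hypothetical sizes and random data to check the gather and the layer sum against an explicit loop.

```python
import numpy as np

NT, G, C, W, A = 4, 2, 8, 5, 3  # hypothetical: samples, RVQ layers, codebook size, chunk size, action dim
rng = np.random.default_rng(0)

# Offset head output after the einops reshape: one (W*A) offset per (layer, code)
cbet_offsets = rng.standard_normal((NT, G, C, W * A))
# One sampled code per sample and layer
sampled_centers = rng.integers(0, C, size=(NT, G))

# Index arrays broadcast to (NT, G); the untouched last axis is appended -> (NT, G, W*A)
rows = np.arange(NT)[:, None]
layers = np.arange(G)[None, :]
sampled_offsets = cbet_offsets[rows, layers, sampled_centers]

# Sum over RVQ layers, then reshape the net offset into an action chunk
net_offset = sampled_offsets.sum(axis=1).reshape(NT, W, A)
print(sampled_offsets.shape, net_offset.shape)

# Same result with an explicit loop
expected = np.zeros((NT, W * A))
for i in range(NT):
    for g in range(G):
        expected[i] += cbet_offsets[i, g, sampled_centers[i, g]]
assert np.allclose(net_offset.reshape(NT, W * A), expected)
```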

# VQBeTModel: The Core Architecture of Vector-Quantized Behavior Transformers

The `VQBeTModel` class represents the complete architecture of Vector-Quantized Behavior Transformers (VQ-BeTs). This model integrates various components to process visual and state observations, generate action predictions, and handle the vector quantization of the action space.

Key Components and Their Roles:

1. **RGB Encoder**:
   - Processes image observations into compact feature vectors
   - Essential for handling visual input in tasks like robotic manipulation or autonomous driving

2. **State and RGB Feature Projectors**:
   - Project state observations and RGB features to the appropriate dimension for the GPT model
   - Ensure all inputs are compatible with the transformer architecture

3. **Action Query Token**:
   - Serves as a prompt for querying action chunks
   - Marks the positions whose GPT outputs are read out as action predictions; the same learned token is repeated at every such position

4. **GPT (minGPT-style transformer)**:
   - Core component for sequence modeling; unlike a pre-trained GPT, it is trained as part of VQ-BeT
   - Processes the interleaved sequence of observation tokens and action query tokens

5. **VQBeTHead**:
   - Handles the prediction of discrete codes and continuous offsets
   - Integrates the Vector Quantized Variational Autoencoder (VQ-VAE) for action discretization

6. **Training Phases**:
   - Phase 1: Discretize actions using Residual VQ
   - Phase 2: Train the full model to predict future actions
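The order in which these tokens reach the GPT determines which outputs become action predictions. The plain-Python sketch below lays out the sequence for hypothetical sizes and reproduces the formula behind `historical_act_pred_index`. It assumes one camera, so that `config.input_shapes` holds two entries (the image and `observation.state`). `token_layout` and `action_query_positions` are hypothetical helpers.

```python
def token_layout(n_obs_steps, num_images, n_action_pred_token):
    """Token order fed to GPT: per step [img_0, ..., img_{k-1}, state, A_Q], then extra A_Q tokens."""
    per_step = [f"img{i}" for i in range(num_images)] + ["state", "A_Q"]
    tokens = []
    for t in range(n_obs_steps):
        tokens += [f"{name}@{t}" for name in per_step]
    # extra action query tokens for future action chunks
    tokens += ["A_Q(future)"] * (n_action_pred_token - 1)
    return tokens

def action_query_positions(n_obs_steps, num_modalities):
    # mirrors historical_act_pred_index: each step has num_modalities observation tokens + 1 action query
    return [t * (num_modalities + 1) + num_modalities for t in range(n_obs_steps)]

tokens = token_layout(n_obs_steps=3, num_images=1, n_action_pred_token=2)
positions = action_query_positions(n_obs_steps=3, num_modalities=2)  # one camera + state
print(tokens)
print(positions, [tokens[i] for i in positions])
```

The selected positions land exactly on the per-step `A_Q` tokens. The future action query tokens are taken separately from the end of the sequence.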

Why is this architecture crucial for VQ-BeTs?

- **End-to-End Learning**: Integrates perception, sequence modeling, and action generation in a single architecture
- **Hierarchical Representation**: Combines continuous and discrete representations of actions
- **Long-term Dependencies**: The transformer architecture allows for capturing long-range dependencies in behavior sequences
- **Flexibility**: Can handle various types of inputs (images, states) and output complex, multi-step action predictions

Let's examine the implementation of the VQBeTModel class:

In [18]:

class VQBeTModel:
    """VQ-BeT: The underlying neural network for VQ-BeT

    Note: In this code we use the terms `rgb_encoder`, `policy`, and `action_head`. Their meanings are as follows.
        - The `rgb_encoder` processes RGB image observations into one-dimensional embedding vectors.
        - A `policy` is a minGPT architecture that takes observation sequences and action query tokens to generate `features`.
        - These `features` pass through the action head, which applies the code prediction and offset prediction heads
        and finally generates a prediction for the action chunks.

        -------------------------------** legend **-------------------------------
        │   n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size    │
        │   o_{t} : visual observation at timestep {t}                           │
        │   s_{t} : state observation at timestep {t}                            │
        │   a_{t} : action at timestep {t}                                       │
        │   A_Q : action_query_token                                             │
        --------------------------------------------------------------------------


        Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)


        ┌─────────────────┐            ┌─────────────────┐            ┌─────────────────┐
        │                 │            │                 │            │                 │
        │   RVQ encoder   │    ─►      │     Residual    │    ─►      │   RVQ Decoder   │
        │ (a_{t}~a_{t+p}) │            │  Code Quantizer │            │                 │
        │                 │            │                 │            │                 │
        └─────────────────┘            └─────────────────┘            └─────────────────┘

        Training Phase 2.

          timestep {t-n+1}   timestep {t-n+2}                timestep {t}
            ┌─────┴─────┐     ┌─────┴─────┐                 ┌─────┴─────┐

        o_{t-n+1}         o_{t-n+2}           ...         o_{t}
            │                 │                             │
            │ s_{t-n+1}       │ s_{t-n+2}         ...       │   s_{t}           p
            │     │           │     │                       │     │     ┌───────┴───────┐
            │     │    A_Q    │     │    A_Q          ...   │     │    A_Q     ...     A_Q
            │     │     │     │     │     │                 │     │     │               │
        ┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐
        │                                                                                   │
        │                                       GPT                                         │       =>    policy
        │                                                                                   │
        └───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘
                        │                 │                             │               │
                    ┌───┴───┐         ┌───┴───┐                     ┌───┴───┐       ┌───┴───┐
                  code    offset    code    offset                code    offset  code    offset
                    ▼       │         ▼       │                     ▼       │       ▼       │       =>    action_head
               RVQ Decoder  │    RVQ Decoder  │                RVQ Decoder  │  RVQ Decoder  │
                    └── + ──┘         └── + ──┘                     └── + ──┘       └── + ──┘
                        ▼                 ▼                             ▼               ▼
                   action chunk      action chunk                  action chunk     action chunk
                    a_{t-n+1} ~       a_{t-n+2} ~                   a_{t} ~     ...  a_{t+p-1} ~
                     a_{t-n+c}         a_{t-n+c+1}                   a_{t+c-1}        a_{t+p+c-1}

                                                                        ▼
                                                      ONLY this chunk is used in rollout!
    """

    def __init__(self, config: VQBeTConfig):
        self.config = config

        self.rgb_encoder = VQBeTRgbEncoder(config)
        self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
        # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
        self.action_token = Tensor.randn(1, 1, self.config.gpt_input_dim)

        # To feed state and image features into the GPT layers, we first project them to the GPT input dimension.
        self.state_projector = MLP(
            config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
        )
        self.rgb_feature_projector = MLP(
            self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
        )

        # GPT part of VQ-BeT
        self.policy = GPT(config)
        # bin prediction head / offset prediction head part of VQ-BeT
        self.action_head = VQBeTHead(config)

        # Action tokens for: each observation step, the current action token, and all future action tokens.
        num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
        ranges = [np.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]
        self.select_target_actions_indices = Tensor(np.vstack(ranges), requires_grad=False)

    def __call__(self, batch: dict[str, Tensor], rollout: bool) -> Union[Tensor, tuple[dict[str, Tensor], dict]]:
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image features (first combine batch, sequence, and camera dims).
        img_features = self.rgb_encoder(
            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
        )
        # Separate batch, sequence, and camera dims.
        img_features = einops.rearrange(
            img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
        )

        # Arrange prior and current observation step tokens as shown in the class docstring.
        # First project features to token dimension.
        rgb_tokens = self.rgb_feature_projector(
            img_features
        )  # (batch, obs_step, number of different cameras, projection dims)
        input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.shape[2])]
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging.
        input_tokens = Tensor.stack(*input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = input_tokens.cat(future_action_tokens, dim=1)

        # get action features (pass through GPT)
        features = self.policy(input_tokens)
        # len(self.config.input_shapes) is the number of different observation modes.
        # this line gets the index of action prompt tokens.
        historical_act_pred_index = Tensor(np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
            self.config.input_shapes
        ), requires_grad=False)

        # only extract the output tokens at the position of action query:
        # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
        # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
        # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
        if len_additional_action_token > 0:
            features = Tensor.cat(
                *[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
            )
        else:
            features = features[:, historical_act_pred_index]
        # pass through action head
        action_head_output = self.action_head(features)
        # during rollout, VQ-BeT doesn't compute the loss; it returns only the current action chunk
        if rollout:
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # otherwise, compute the overall loss (bin prediction loss and offset loss)
        else:
            output = batch["action"][:, self.select_target_actions_indices].detach()
            loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
            return action_head_output, loss
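In training mode, each action query output is compared with a sliding window of ground-truth actions, and `select_target_actions_indices` builds those windows. The following NumPy sketch uses hypothetical config values to show which time steps each token is supervised on. It also shows the minimum number of action steps the batch must contain for these indices to be valid. The fake action batch is illustrative only.

```python
import numpy as np

n_obs_steps, n_action_pred_token, action_chunk_size = 3, 2, 4  # hypothetical config values
num_tokens = n_action_pred_token + n_obs_steps - 1
# Row i holds the time steps of the action chunk supervised at token i
select_target_actions_indices = np.vstack(
    [np.arange(i, i + action_chunk_size) for i in range(num_tokens)]
)
print(select_target_actions_indices)

# Fake action batch; T is the minimum length these indices require
batch_size, action_dim = 2, 3
T = num_tokens + action_chunk_size - 1
actions = np.arange(batch_size * T * action_dim).reshape(batch_size, T, action_dim)

# Integer-array indexing on the time axis -> (batch, num_tokens, chunk, action_dim)
targets = actions[:, select_target_actions_indices]
print(targets.shape)
```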

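The `SpatialSoftmax` module below reduces each channel of a feature map to the expected (x, y) coordinate of its activations. As a reference for that computation, here is a NumPy sketch. `spatial_soft_argmax` is a hypothetical helper that omits the optional 1x1 keypoint convolution, and the hand-built feature maps are illustrative.

```python
import numpy as np

def spatial_soft_argmax(features):
    """features: (C, H, W) -> keypoints (C, 2) as expected (x, y) coordinates in [-1, 1]."""
    C, H, W = features.shape
    # normalized pixel coordinates; x varies along width, y along height
    pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(-1.0, 1.0, H))
    coords = np.stack([pos_x.reshape(-1), pos_y.reshape(-1)], axis=1)  # (H*W, 2)
    flat = features.reshape(C, H * W)
    flat = flat - flat.max(axis=1, keepdims=True)       # numerical stability
    attention = np.exp(flat)
    attention /= attention.sum(axis=1, keepdims=True)   # channel-wise softmax over pixels
    return attention @ coords                            # expected coordinates per channel

features = np.full((2, 10, 12), -50.0)
features[0, 0, 0] = 50.0    # channel 0 peaks at the top-left pixel
features[1, 9, 11] = 50.0   # channel 1 peaks at the bottom-right pixel
kp = spatial_soft_argmax(features)
print(kp.round(3))
```

Because each channel here has a single sharp peak, the keypoints land on the grid corners. Smoother activations would yield a weighted average of positions instead.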
In [19]:
class SpatialSoftmax:
    """
    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.

    At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
    of activations of each channel, i.e., keypoints in the image space for the policy to focus on.

    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
    -----------------------------------------------------
    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
    | ...          | ...            | ... | ...         |
    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
    -----------------------------------------------------
    This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
    product with the coordinates (120x2) to get expected points of maximal activation (512x2).

    The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
    provide num_kp != None to control the number of keypoints. This is achieved by first applying a learnable
    linear mapping (in_channels, H, W) -> (num_kp, H, W).
    """

    def __init__(self, input_shape, num_kp=None):
        """
        Args:
            input_shape (list): (C, H, W) input feature map shape.
            num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
        """

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            self.nets = nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self._in_w), 
            np.linspace(-1.0, 1.0, self._in_h)
        )
        pos_x = Tensor(pos_x.reshape(self._in_h * self._in_w, 1), dtype=dtypes.float, requires_grad=False)
        pos_y = Tensor(pos_y.reshape(self._in_h * self._in_w, 1), dtype=dtypes.float, requires_grad=False)
        # tinygrad has no `register_buffer`; the grid is a plain attribute kept out of gradient computation.
        self.pos_grid = Tensor.cat(pos_x, pos_y, dim=1).detach()
        self.pos_grid.requires_grad = False

    def __call__(self, features: Tensor) -> Tensor:
        """
        Args:
            features: (B, C, H, W) input feature maps.
        Returns:
            (B, K, 2) image-space coordinates of keypoints.
        """
        if self.nets is not None:
            features = self.nets(features)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        features = features.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = features.softmax(axis=-1)
        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
        expected_xy = attention @ self.pos_grid
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.reshape(-1, self._out_c, 2)

        return feature_keypoints
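To make the soft-argmax concrete, here is a dependency-free sketch of the same computation for a single feature map. It applies a softmax to each channel over all H*W positions, then takes the expected (x, y) coordinate on a [-1, 1] grid, with x varying along the width and y along the height. The sketch omits the optional 1x1 convolution (`num_kp`) and batching. `spatial_soft_argmax` and `linspace` are hypothetical names used only for illustration.

```python
import math


def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values including both endpoints (like numpy.linspace for num >= 2)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def spatial_soft_argmax(feature_map: list[list[list[float]]]) -> list[tuple[float, float]]:
    """Expected (x, y) keypoint for each channel of a (C, H, W) feature map."""
    keypoints = []
    for channel in feature_map:
        h, w = len(channel), len(channel[0])
        xs, ys = linspace(-1.0, 1.0, w), linspace(-1.0, 1.0, h)
        # Flatten to H*W entries of (activation, x, y), row-major like the reshape in SpatialSoftmax.
        flat = [(channel[i][j], xs[j], ys[i]) for i in range(h) for j in range(w)]
        # Numerically stable softmax over all spatial positions.
        peak = max(v for v, _, _ in flat)
        weights = [math.exp(v - peak) for v, _, _ in flat]
        total = sum(weights)
        # Dot product of the attention weights with the coordinate grid.
        ex = sum(wt * x for wt, (_, x, _) in zip(weights, flat)) / total
        ey = sum(wt * y for wt, (_, _, y) in zip(weights, flat)) / total
        keypoints.append((ex, ey))
    return keypoints


# Channel 0 is uniform (keypoint at the center); channel 1 has a sharp peak at the
# bottom-right position (row H-1, column W-1).
uniform = [[0.0] * 4 for _ in range(3)]
peaked = [[0.0] * 4 for _ in range(3)]
peaked[2][3] = 50.0
print(spatial_soft_argmax([uniform, peaked]))
```

The output is a smooth, differentiable stand-in for "where is this channel most active". A flat activation map yields the image center, and a sharp peak pulls the keypoint toward that pixel's normalized coordinate.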

In [20]:
from tinygrad import Tensor, nn, dtypes

def create_stats_buffers(
    shapes: dict[str, list[int]],
    modes: dict[str, str],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, Tensor]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
    statistics.

    Args: (see Normalize and Unnormalize)

    Returns:
        dict: A dictionary where keys are modalities and values are dictionaries mapping statistic names
            (e.g. "mean", "std") to Tensors with `requires_grad=False`, so they are not updated during
            backpropagation.
    """
    stats_buffers = {}

    for key, mode in modes.items():
        assert mode in ["mean_std", "min_max"]

        shape = tuple(shapes[key])

        if "image" in key:
            # sanity checks
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # override image shape to be invariant to height and width
            shape = (c, 1, 1)

        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
        # downstream by `stats` or by `tinygrad.nn.state.load_state_dict(policy, state_dict)`, as expected.
        # During forward, we assert they are not infinity anymore.

        buffer = {}
        if mode == "mean_std":
            buffer = {
                "mean": Tensor.ones(shape, dtype=dtypes.float, requires_grad=False) * float('inf'),
                "std": Tensor.ones(shape, dtype=dtypes.float, requires_grad=False) * float('inf'),
            }
        elif mode == "min_max":
            buffer = {
                "min": Tensor.ones(shape, dtype=dtypes.float, requires_grad=False) * float('inf'),
                "max": Tensor.ones(shape, dtype=dtypes.float, requires_grad=False) * float('inf'),
            }

        if stats is not None:
            # Note: the PyTorch original clones the stats here so that the logic in save_pretrained doesn't see
            # duplicated tensors anywhere (for example, when the same stats are used for normalization and
            # unnormalization). See the logic here
            # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
            # In this port, `assign` copies the values into each buffer's own Tensor, so no storage is shared.
            if mode == "mean_std":
                buffer["mean"].assign(stats[key]["mean"].numpy())
                buffer["mean"].requires_grad = False
                buffer["std"].assign(stats[key]["std"].numpy())
                buffer["std"].requires_grad = False
            elif mode == "min_max":
                buffer["min"].assign(stats[key]["min"].numpy())
                buffer["min"].requires_grad = False
                buffer["max"].assign(stats[key]["max"].numpy())
                buffer["max"].requires_grad = False

        stats_buffers[key] = buffer
    return stats_buffers


def _no_stats_error_str(name: str) -> str:
    return (
        f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
        "pretrained model."
    )


class Normalize():
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        shapes: dict[str, list[int]],
        modes: dict[str, str],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
            are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
            mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
            is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
            modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
                are their normalization modes among:
                    - "mean_std": subtract the mean and divide by standard deviation.
                    - "min_max": map to [-1, 1] range.
            stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should to be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # Temporarily disable gradient tracking and training mode; the previous flag values are
    # restored on exit (even if an assertion fails), so callers that already set them keep their settings.
    def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        try:
            batch = dict(batch)  # shallow copy avoids mutating the input batch
            for key, mode in self.modes.items():
                buffer = getattr(self, "buffer_" + key.replace(".", "_"))

                if mode == "mean_std":
                    mean = buffer["mean"]
                    std = buffer["std"]
                    assert not (mean == float('inf')).any().numpy(), _no_stats_error_str("mean")
                    assert not (std == float('inf')).any().numpy(), _no_stats_error_str("std")
                    batch[key] = (batch[key] - mean) / (std + 1e-8)
                elif mode == "min_max":
                    min_val = buffer["min"]
                    max_val = buffer["max"]
                    assert not (min_val == float('inf')).any().numpy(), _no_stats_error_str("min")
                    assert not (max_val == float('inf')).any().numpy(), _no_stats_error_str("max")
                    # normalize to [0, 1]
                    batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
                    # normalize to [-1, 1]
                    batch[key] = batch[key] * 2 - 1
                else:
                    raise ValueError(mode)
        finally:
            Tensor.no_grad = no_grad_prev
            Tensor.training = training_prev
        return batch


class Unnormalize():
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        shapes: dict[str, list[int]],
        modes: dict[str, str],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
            are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
            mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
            is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
            modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
                are their normalization modes among:
                    - "mean_std": subtract the mean and divide by standard deviation.
                    - "min_max": map to [-1, 1] range.
            stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
                and values are dictionaries of statistic types and their values (e.g.
                `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
                training the model for the first time, these statistics will overwrite the default buffers. If
                not provided, as expected for finetuning or evaluation, the default buffers should to be
                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                dataset is not needed to get the stats, since they are already in the policy state_dict.
        """
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # e.g. `self.buffer_action["min"]` holds a Tensor of shape (action_dim,)
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # Temporarily disable gradient tracking and training mode; the previous flag values are
    # restored on exit, even if an assertion fails.
    def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        try:
            batch = dict(batch)  # shallow copy avoids mutating the input batch
            for key, mode in self.modes.items():
                buffer = getattr(self, "buffer_" + key.replace(".", "_"))

                if mode == "mean_std":
                    mean = buffer["mean"]
                    std = buffer["std"]
                    assert not (mean == float('inf')).any().numpy(), _no_stats_error_str("mean")
                    assert not (std == float('inf')).any().numpy(), _no_stats_error_str("std")
                    batch[key] = batch[key] * std + mean
                elif mode == "min_max":
                    min_val = buffer["min"]
                    max_val = buffer["max"]
                    assert not (min_val == float('inf')).any().numpy(), _no_stats_error_str("min")
                    assert not (max_val == float('inf')).any().numpy(), _no_stats_error_str("max")
                    # map [-1, 1] back to [0, 1], then to [min, max]
                    batch[key] = (batch[key] + 1) / 2
                    batch[key] = batch[key] * (max_val - min_val) + min_val
                else:
                    raise ValueError(mode)
        finally:
            Tensor.no_grad = no_grad_prev
            Tensor.training = training_prev
        return batch

def populate_queues(queues, batch):
    for key in batch:
        # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
        # queues have the keys they want).
        if key not in queues:
            continue
        if len(queues[key]) != queues[key].maxlen:
            # initialize by copying the first observation several times until the queue is full
            while len(queues[key]) != queues[key].maxlen:
                queues[key].append(batch[key])
        else:
            # add latest observation to the queue
            queues[key].append(batch[key])
    return queues
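The arithmetic performed by `Normalize` and `Unnormalize` is easy to check on scalars. The sketch below mirrors both modes in plain Python. It also includes the infinity sentinel that `create_stats_buffers` uses to detect missing statistics. The stats values are made-up numbers for illustration. In the notebook, image statistics have shape `(c, 1, 1)`, so the same formulas broadcast per channel over height and width.

```python
import math


def _check(stats: dict[str, float]) -> None:
    # create_stats_buffers initializes every statistic to +inf until real stats are loaded.
    for name, value in stats.items():
        assert not math.isinf(value), f"`{name}` is infinity"


def normalize(x: float, mode: str, stats: dict[str, float], eps: float = 1e-8) -> float:
    _check(stats)
    if mode == "mean_std":
        return (x - stats["mean"]) / (stats["std"] + eps)
    if mode == "min_max":
        unit = (x - stats["min"]) / (stats["max"] - stats["min"] + eps)  # maps to [0, 1]
        return unit * 2 - 1  # maps to [-1, 1]
    raise ValueError(mode)


def unnormalize(y: float, mode: str, stats: dict[str, float]) -> float:
    _check(stats)
    if mode == "mean_std":
        return y * stats["std"] + stats["mean"]
    if mode == "min_max":
        unit = (y + 1) / 2  # back to [0, 1]
        return unit * (stats["max"] - stats["min"]) + stats["min"]
    raise ValueError(mode)


# Hypothetical statistics, not taken from any dataset.
action_stats = {"min": 0.0, "max": 512.0}
state_stats = {"mean": 2.0, "std": 4.0}

print(normalize(0.0, "min_max", action_stats))  # -1.0
print(unnormalize(normalize(256.0, "min_max", action_stats), "min_max", action_stats))
print(unnormalize(normalize(10.0, "mean_std", state_stats), "mean_std", state_stats))
```

`Normalize` adds `1e-8` to the denominator and `Unnormalize` does not. As a result, the round trip is exact only up to that epsilon, which is negligible for non-degenerate statistics.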

In [21]:
class VQBeTPolicy:
    """
    VQ-BeT Policy as per "Behavior Generation with Latent Actions"
    """

    name = "vqbet"

    def __init__(
        self,
        config: VQBeTConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        if config is None:
            config = VQBeTConfig()
        self.config = config
        self.normalize_inputs = Normalize(
            config.input_shapes, config.input_normalization_modes, dataset_stats
        )
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        self.vqbet = VQBeTModel(config)

        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        self.reset()

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`.
        The queues are populated during rollout of the policy and contain the n latest observations and actions.
        """
        self._queues = {
            "observation.images": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.action_chunk_size),
        }

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method returns one action at a time for execution in the environment. It manages the
        predicted actions in a queue and only runs the VQ-BeT model (`self.vqbet(..., rollout=True)`)
        when the queue is empty, i.e. once every `action_chunk_size` calls.
        """
        no_grad_prev, Tensor.no_grad = Tensor.no_grad, True
        training_prev, Tensor.training = Tensor.training, False
        try:
            batch = self.normalize_inputs(batch)
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)
            # Note: It's important that this happens after stacking the images into a single key.
            self._queues = populate_queues(self._queues, batch)

            if not self.vqbet.action_head.vqvae_model.discretized:
                warnings.warn(
                    "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
                    stacklevel=1,
                )

            if len(self._queues["action"]) == 0:
                batch = {k: Tensor.stack(*self._queues[k], dim=1) for k in batch if k in self._queues}
                actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]

                # the returned actions have shape (batch_size, action_chunk_size, action_dim)
                actions = self.unnormalize_outputs({"action": actions})["action"]
                # the action queue holds one (batch_size, action_dim) entry per time step, so transpose to
                # (action_chunk_size, batch_size, action_dim) and enqueue each time step separately
                actions = actions.transpose(0, 1)
                self._queues["action"].extend(actions[i] for i in range(actions.shape[0]))

            action = self._queues["action"].popleft()
        finally:
            Tensor.no_grad = no_grad_prev
            Tensor.training = training_prev
        return action

    def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)
        batch = self.normalize_targets(batch)
        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
        if not self.vqbet.action_head.vqvae_model.discretized:
            # loss: total loss of training RVQ
            # n_different_codes: how many of the total possible VQ codes are being used in a single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
            # n_different_combinations: how many different code combinations are being used out of all possible combinations in a single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint: consider the RVQ as a decision tree).
            loss, n_different_codes, n_different_combinations, recon_l1_error = (
                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
            )
            return {
                "loss": loss,
                "n_different_codes": n_different_codes,
                "n_different_combinations": n_different_combinations,
                "recon_l1_error": recon_l1_error,
            }
        # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
        _, loss_dict = self.vqbet(batch, rollout=False)

        return loss_dict
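`select_action` combines two kinds of queues. The observation queues are warm-started by repeating the first observation (see `populate_queues`). The action queue is refilled with a whole chunk only when it runs empty. The plain-Python sketch below reproduces that control flow with a stand-in predictor. `ActionChunker` and `fake_policy` are hypothetical names, and the predictor replaces `self.vqbet(batch, rollout=True)` plus unnormalization.

```python
from collections import deque
from typing import Any, Callable, Sequence


class ActionChunker:
    """Hypothetical helper mirroring the queue logic of VQBeTPolicy.select_action."""

    def __init__(self, n_obs_steps: int, chunk_size: int,
                 predict_chunk: Callable[[list], Sequence[Any]]):
        self.obs: deque = deque(maxlen=n_obs_steps)
        self.actions: deque = deque(maxlen=chunk_size)
        # Stands in for running the model with rollout=True and unnormalizing its output.
        self.predict_chunk = predict_chunk
        self.n_model_calls = 0

    def select_action(self, observation: Any) -> Any:
        # Append the newest observation; on the very first call, repeat it until the
        # window holds n_obs_steps entries (same effect as populate_queues).
        self.obs.append(observation)
        while len(self.obs) < self.obs.maxlen:
            self.obs.append(observation)
        # Only query the model once every action of the previous chunk has been used.
        if not self.actions:
            self.actions.extend(self.predict_chunk(list(self.obs)))
            self.n_model_calls += 1
        return self.actions.popleft()


def fake_policy(window: list) -> list:
    """Returns a chunk of 3 placeholder actions tagged with the newest observation."""
    latest = window[-1]
    return [(latest, k) for k in range(3)]


chunker = ActionChunker(n_obs_steps=3, chunk_size=3, predict_chunk=fake_policy)
actions = [chunker.select_action(f"obs{t}") for t in range(7)]
print(actions[:4])            # [('obs0', 0), ('obs0', 1), ('obs0', 2), ('obs3', 0)]
print(chunker.n_model_calls)  # 3
```

Executing a predicted chunk open-loop cuts the number of model calls by a factor of the chunk size. The trade-off is that the policy reacts to new observations only once per chunk.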

In [22]:
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from torch.utils.data import DataLoader

import tinygrad
from tinygrad import Tensor, nn, TinyJit, dtypes

from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict

# Start of training code

# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_vqbet")
output_directory.mkdir(parents=True, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 40000
log_freq = 100

# Set up the dataset.
delta_timestamps = {
    # Load the four previous images and states (at -0.2, -0.15, -0.1 and -0.05 seconds) together with
    # the current image and state (0.0 seconds), i.e. 5 observation steps spaced 0.05 seconds apart.
    "observation.image": [-0.2, -0.15, -0.1, -0.05, 0.0],
    "observation.state": [-0.2, -0.15, -0.1, -0.05, 0.0],
    # Load the four previous actions (-0.2 to -0.05), the next action to be executed (0.0),
    # and 10 future actions (0.05 to 0.5), all spaced 0.05 seconds apart. All 15 actions are
    # used to supervise the policy.
    "action": [-0.2, -0.15, -0.1, -0.05, 0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

# Set up the policy.
# Policies are initialized with a configuration class, in this case `VQBeTConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = VQBeTConfig()
policy = VQBeTPolicy(cfg, dataset_stats=dataset.stats)

opt_decay, opt_vq_vae, opt_no_decay, decay_params, vqvae_params, no_decay_params = VQBeTOptimizer.create(policy, cfg)
opts = OptimizerGroup(opt_decay, opt_no_decay)
lr_scheduler = VQBeTScheduler(opts, cfg)

#@TinyJit
@Tensor.train()
def train_step_vqvae(
    observation_image: Tensor,
    observation_state: Tensor,
    action: Tensor,
    episode_index: Tensor,
    frame_index: Tensor,
    timestamp: Tensor,
    next_reward: Tensor,
    next_done: Tensor,
    next_success: Tensor,
    index: Tensor,
    observation_image_is_pad: Tensor,
    observation_state_is_pad: Tensor,
    action_is_pad: Tensor
):
    batch = {
        'observation.image': observation_image,
        'observation.state': observation_state,
        'action': action,
        'episode_index': episode_index,
        'frame_index': frame_index,
        'timestamp': timestamp,
        'next.reward': next_reward,
        'index': index,
        'observation.image_is_pad': observation_image_is_pad,
        'observation.state_is_pad': observation_state_is_pad,
        'action_is_pad': action_is_pad
    }
    output_dict = policy(batch)
    loss = output_dict["loss"]
    opt_vq_vae.zero_grad()
    loss.backward()
    # grad_norm = clip_grad_norm_(vqvae_params, 10.0) if not policy.vqbet.action_head.vqvae_model.discretized else None
    opt_vq_vae.step()
    return loss

@Tensor.train()
def train_step_main(
    observation_image: Tensor,
    observation_state: Tensor,
    action: Tensor,
    episode_index: Tensor,
    frame_index: Tensor,
    timestamp: Tensor,
    next_reward: Tensor,
    next_done: Tensor,
    next_success: Tensor,
    index: Tensor,
    observation_image_is_pad: Tensor,
    observation_state_is_pad: Tensor,
    action_is_pad: Tensor
):
    batch = {
        'observation.image': observation_image,
        'observation.state': observation_state,
        'action': action,
        'episode_index': episode_index,
        'frame_index': frame_index,
        'timestamp': timestamp,
        'next.reward': next_reward,
        'index': index,
        'observation.image_is_pad': observation_image_is_pad,
        'observation.state_is_pad': observation_state_is_pad,
        'action_is_pad': action_is_pad
    }
    output_dict = policy(batch)
    loss = output_dict["loss"]
    opts.zero_grad()
    loss.backward()
    # grad_norm = clip_grad_norm_(vqvae_params, 10.0) if not policy.vqbet.action_head.vqvae_model.discretized else None
    opts.step()
    if policy.vqbet.action_head.vqvae_model.discretized:
        lr_scheduler.step()
    return loss

if __name__ == "__main__":
    # Run training loop.
    print(f'Starting training loop')
    # Create dataloader for offline training.
    dataloader = DataLoader(
        dataset,
        num_workers=0,
        batch_size=64,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )

    step = 0
    done = False
    with Tensor.train():
        while not done:
            for batch in dataloader:
                batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}
                if policy.vqbet.action_head.vqvae_model.discretized:
                    loss = train_step_main(
                        batch['observation.image'].realize(),
                        batch['observation.state'].realize(),
                        batch['action'].realize(),
                        batch['episode_index'].realize(),
                        batch['frame_index'].realize(),
                        batch['timestamp'].realize(),
                        batch['next.reward'].realize(),
                        batch['next.done'].realize(),
                        batch['next.success'].realize(),
                        batch['index'].realize(),
                        batch['observation.image_is_pad'].realize(),
                        batch['observation.state_is_pad'].realize(),
                        batch['action_is_pad'].realize()
                    )
                else:
                    loss = train_step_vqvae(
                        batch['observation.image'].realize(),
                        batch['observation.state'].realize(),
                        batch['action'].realize(),
                        batch['episode_index'].realize(),
                        batch['frame_index'].realize(),
                        batch['timestamp'].realize(),
                        batch['next.reward'].realize(),
                        batch['next.done'].realize(),
                        batch['next.success'].realize(),
                        batch['index'].realize(),
                        batch['observation.image_is_pad'].realize(),
                        batch['observation.state_is_pad'].realize(),
                        batch['action_is_pad'].realize()
                    )

                if step % log_freq == 0:
                    print(f"step: {step} loss: {loss.numpy():.3f}")
                step += 1

                if step % 5000 == 0 or step == 19999:
                    try:
                        if step < cfg.n_vqvae_training_steps:
                            state_dict_vae_encoder = get_state_dict(policy.vqbet.action_head.vqvae_model.encoder)
                            safe_save(state_dict_vae_encoder, f'{output_directory}/model_vae_encoder_{step}.safetensors')
                            state_dict_vae_decoder = get_state_dict(policy.vqbet.action_head.vqvae_model.decoder)
                            safe_save(state_dict_vae_decoder, f'{output_directory}/model_vae_decoder_{step}.safetensors')
                            state_dict_vae_vq_layer = get_state_dict(policy.vqbet.action_head.vqvae_model.vq_layer)
                            safe_save(state_dict_vae_vq_layer, f'{output_directory}/model_vae_vq_layer_{step}.safetensors')
                        else:
                            state_dict = get_state_dict(policy)
                            safe_save(state_dict, f'{output_directory}/model_{step}.safetensors')
                    except Exception as e:
                        print(f'Exception during safe_save: {e}')
                if step >= training_steps:
                    done = True
                    break

    # Save a policy checkpoint.
    state_dict = get_state_dict(policy)
    safe_save(state_dict, f'{output_directory}/model_final.safetensors')


dummy_feature_map: <Tensor <LB METAL (1, 512, 3, 3) float ShapeTracker(views=(View(shape=(1, 512, 3, 3), strides=(0, 9, 3, 1), offset=0, mask=None, contiguous=True),))> on METAL with grad None>
replacing GPT parameter: h.0.attn.c_proj.weight with a normal
replacing GPT parameter: h.1.attn.c_proj.weight with a normal
replacing GPT parameter: h.2.attn.c_proj.weight with a normal
replacing GPT parameter: h.3.attn.c_proj.weight with a normal
replacing GPT parameter: h.4.attn.c_proj.weight with a normal
replacing GPT parameter: h.5.attn.c_proj.weight with a normal
replacing GPT parameter: h.6.attn.c_proj.weight with a normal
replacing GPT parameter: h.7.attn.c_proj.weight with a normal
decay list: ['h.0.attn.c_attn.weight', 'h.0.attn.c_proj.weight', 'h.0.mlp.0.weight', 'h.0.mlp.2.weight', 'h.1.attn.c_attn.weight', 'h.1.attn.c_proj.weight', 'h.1.mlp.0.weight', 'h.1.mlp.2.weight', 'h.2.attn.c_attn.weight', 'h.2.attn.c_proj.weight', 'h.2.mlp.0.weight', 'h.2.mlp.2.weight', 'h.3.attn.c_attn.weig

CompileError: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:3:902: error: no 'buffer' resource location available for 'data31'
kernel void E_168_8_8_16_4n5(device float* data0, const device float* data1, const device float* data2, const device float* data3, const device float* data4, const device float* data5, const device float* data6, const device float* data7, const device float* data8, const device float* data9, const device float* data10, const device float* data11, const device float* data12, const device float* data13, const device float* data14, const device float* data15, const device float* data16, const device float* data17, const device float* data18, const device float* data19, const device float* data20, const device float* data21, const device float* data22, const device float* data23, const device float* data24, const device float* data25, const device float* data26, const device float* data27, const device float* data28, const device float* data29, const device float* data30, const device float* data31, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     ^
" UserInfo={NSLocalizedDescription=program_source:3:902: error: no 'buffer' resource location available for 'data31'
kernel void E_168_8_8_16_4n5(device float* data0, const device float* data1, const device float* data2, const device float* data3, const device float* data4, const device float* data5, const device float* data6, const device float* data7, const device float* data8, const device float* data9, const device float* data10, const device float* data11, const device float* data12, const device float* data13, const device float* data14, const device float* data15, const device float* data16, const device float* data17, const device float* data18, const device float* data19, const device float* data20, const device float* data21, const device float* data22, const device float* data23, const device float* data24, const device float* data25, const device float* data26, const device float* data27, const device float* data28, const device float* data29, const device float* data30, const device float* data31, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     ^
}
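The training loop chooses between two phases based on `policy.vqbet.action_head.vqvae_model.discretized`. Until the residual VQ-VAE is marked as discretized, each batch goes to `train_step_vqvae`; afterwards, it goes to `train_step_main`. The sketch below simulates only that routing. `ToyActionHead` is hypothetical. It assumes `discretized` becomes true once `discretize` has been called `n_vqvae_training_steps` times, as the call `discretize(self.config.n_vqvae_training_steps, ...)` suggests. Check the action head implementation for the exact condition.

```python
class ToyActionHead:
    """Hypothetical stand-in for the action head's VQ-VAE bookkeeping."""

    def __init__(self, n_vqvae_training_steps: int):
        self.n_vqvae_training_steps = n_vqvae_training_steps
        self.optimized_steps = 0
        self.discretized = False

    def discretize(self) -> None:
        # One VQ-VAE update; mark the codebook as ready after the configured number of updates.
        self.optimized_steps += 1
        if self.optimized_steps >= self.n_vqvae_training_steps:
            self.discretized = True


def route_batches(total_steps: int, head: ToyActionHead) -> list[str]:
    """Record which training step function each batch would be sent to."""
    phases = []
    for _ in range(total_steps):
        if head.discretized:
            phases.append("train_step_main")   # GPT + bin/offset prediction heads
        else:
            head.discretize()
            phases.append("train_step_vqvae")  # residual VQ-VAE only
    return phases


print(route_batches(5, ToyActionHead(n_vqvae_training_steps=3)))
# ['train_step_vqvae', 'train_step_vqvae', 'train_step_vqvae', 'train_step_main', 'train_step_main']
```

Under the same assumption, `training_steps` counts batches from both phases. The GPT and prediction heads therefore receive `training_steps - n_vqvae_training_steps` updates.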

In [23]:
import gymnasium as gym
import gymnasium_robotics

gym.register_envs(gymnasium_robotics)

env = gym.make(
    'FrankaKitchen-v1',
    max_episode_steps=500,
    render_mode="rgb_array",
    tasks_to_complete=["microwave", "kettle", "bottom burner", "light switch"]
)
numpy_observation, info = env.reset()
rewards = []
frames = []
# Render frame of the initial state
current_frame = env.render()
frames.append(current_frame)
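After the environment is reset and the first frame is rendered, evaluation typically loops over `env.step` until the episode ends. The sketch below shows that loop using Gymnasium's return conventions: `reset` returns `(observation, info)` and `step` returns `(observation, reward, terminated, truncated, info)`. `CountdownEnv` is a hypothetical stand-in so the loop runs without MuJoCo. `select_action` stands for whatever wraps `policy.select_action`. Converting the FrankaKitchen observation into the batch dict the policy expects is not shown, because it depends on the observation keys. `VQBeTPolicy.reset` is documented to be called on `env.reset()`, so a real loop should call it there too.

```python
from typing import Any, Callable


class CountdownEnv:
    """Hypothetical environment exposing Gymnasium-style reset/step/render return values."""

    def __init__(self, max_episode_steps: int = 5):
        self.max_episode_steps = max_episode_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return {"t": self.t}, {}

    def step(self, action):
        self.t += 1
        reward = float(action)
        terminated = False                             # task never "succeeds" in this toy
        truncated = self.t >= self.max_episode_steps   # time limit, like max_episode_steps
        return {"t": self.t}, reward, terminated, truncated, {}

    def render(self):
        return f"frame{self.t}"


def rollout(env, select_action: Callable[[Any], Any]):
    """Run one episode, collecting per-step rewards and rendered frames."""
    observation, info = env.reset()
    rewards, frames = [], [env.render()]  # include the initial state's frame
    done = False
    while not done:
        action = select_action(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        frames.append(env.render())
        # An episode ends on either a terminal state or a time-limit truncation.
        done = terminated or truncated
    return rewards, frames


rewards, frames = rollout(CountdownEnv(max_episode_steps=5), lambda obs: 1)
print(rewards)      # [1.0, 1.0, 1.0, 1.0, 1.0]
print(len(frames))  # 6
```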
