In [3]:
import math
from collections import deque
from itertools import chain
from typing import Callable

import numpy as np
import tinygrad
from tinygrad import Tensor, nn, dtypes

In [4]:
def get_activation_fn(activation: str) -> Callable:
    """Return an activation function given a string."""
    if activation == "relu":
        return Tensor.relu
    if activation == "gelu":
        return Tensor.gelu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

In [5]:
get_activation_fn('relu')

<function tinygrad.tensor.Tensor.relu(self)>

In [40]:
class ACTSinusoidalPositionEmbedding2d:
    """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.

    The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
    for the vertical direction, and 1/W for the horizontal direction.
    """

    def __init__(self, dimension: int):
        """
        Args:
            dimension: The desired dimension of the embeddings.
        """
        super().__init__()
        self.dimension = dimension
        self._two_pi = 2 * math.pi
        self._eps = 1e-6
        # Inverse "common ratio" for the geometric progression in sinusoid frequencies.
        self._temperature = 10000

    def __call__(self, x: Tensor) -> Tensor:
        """
        Args:
            x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
        Returns:
            A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
        """
        not_mask = Tensor.ones_like(x[0, :1])  # (1, H, W)
        # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
        # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
        y_range = not_mask.cumsum(1).cast(dtype=dtypes.float32)
        x_range = not_mask.cumsum(2).cast(dtype=dtypes.float32)

        # "Normalize" the position index such that it ranges in [0, 2π].
        # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
        # are non-zero by construction. This is an artifact of the original code.
        y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
        x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi

        inverse_frequency = Tensor(self._temperature ** (
            2 * (np.arange(self.dimension, dtype='f') // 2) / self.dimension
        ))

        x_range = x_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)
        y_range = y_range.unsqueeze(-1) / inverse_frequency  # (1, H, W, 1)

        print(x_range)
        print(y_range)

        # Note: this stack then flatten operation results in interleaved sine and cosine terms.
        # pos_embed_x and pos_embed_y are (1, H, W, C // 2).
        x_range_sin = x_range[..., 0::2].sin()
        x_range_cos = x_range[..., 1::2].cos()
        y_range_sin = y_range[..., 0::2].sin()
        y_range_cos = y_range[..., 1::2].cos()
        print(f'x_range[..., 0::2].sin(): {x_range_sin}')
        print(f'x_range[..., 1::2].cos(): {x_range_cos}')
        pos_embed_x = x_range_sin.stack(x_range_cos, dim=-1).flatten(3)
        pos_embed_y = y_range_sin.stack(y_range_cos, dim=-1).flatten(3)
        pos_embed = pos_embed_y.cat(pos_embed_x, dim=3).permute(0, 3, 1, 2)  # (1, C, H, W)

        return pos_embed

In [41]:
actSin = ACTSinusoidalPositionEmbedding2d(10)
actSin(Tensor.zeros(4,4,4,4))

<Tensor <LB METAL (1, 4, 4, 10) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>
<Tensor <LB METAL (1, 4, 4, 10) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>
x_range[..., 0::2].sin(): <Tensor <LB METAL (1, 4, 4, 5) float (<UnaryOps.SIN: 5>, None)> on METAL with grad None>
x_range[..., 1::2].cos(): <Tensor <LB METAL (1, 4, 4, 5) float (<UnaryOps.SIN: 5>, None)> on METAL with grad None>


<Tensor <LB METAL (1, 20, 4, 4) float ShapeTracker(views=(View(shape=(1, 20, 4, 4), strides=(0, 1, 80, 20), offset=0, mask=None, contiguous=False),))> on METAL with grad None>

In [42]:
def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:
    """1D sinusoidal positional embeddings as in Attention is All You Need.

    Args:
        num_positions: Number of token positions required.
    Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).

    """

    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)], dtype='f')
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return Tensor(sinusoid_table).float()

In [43]:
create_sinusoidal_pos_embedding(3, 3).numpy()

array([[ 0.        ,  1.        ,  0.        ],
       [ 0.841471  ,  0.5403023 ,  0.00215443],
       [ 0.9092974 , -0.4161468 ,  0.00430886]], dtype=float32)

In [273]:
from tinygrad import Tensor, nn
from typing import Optional, Union, Literal
from tinygrad.ops import Variable

class MultiheadAttention:
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "n_state must be divisible by n_head"

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

        self.scaling = self.head_dim ** -0.5        
        self.dropout = dropout

    def __call__(self, q: Tensor, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, training: bool = True):
        batch_size, tgt_len, _ = q.shape
        src_len = k.shape[1]

        # Apply linear transformations
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # Reshape and transpose for multi-head attention
        q = q.reshape(batch_size, tgt_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Calculate attention scores
        attn_scores = (q @ k.transpose(-2, -1)) * self.scaling

        # Apply key padding mask if provided
        if key_padding_mask is not None:
            print(f'(q,k,v): {q.shape}, {k.shape}, {v.shape}')
            print(f'(key_padding_mask): {key_padding_mask.shape}')
            # Reshape and expand key_padding_mask to match attn_scores dimensions
            key_padding_mask = key_padding_mask.squeeze(1).squeeze(1)  # Remove extra dimensions
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # Add dimensions for heads and query length
            key_padding_mask = key_padding_mask.expand(batch_size, self.num_heads, tgt_len, src_len)
            attn_scores = attn_scores.masked_fill(key_padding_mask, float('-inf'))   
    
        # Apply softmax to get attention weights
        attn_weights = attn_scores.softmax(axis=-1)

        # Apply dropout
        if self.dropout > 0:
            attn_weights = attn_weights.dropout(p=self.dropout)

        # Apply attention to values
        attn_output = attn_weights @ v

        # Reshape and combine heads
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, tgt_len, self.embed_dim)

        # Final projection
        attn_output = self.out(attn_output)

        return attn_output


In [274]:
mha = MultiheadAttention(9, 9)
mha(Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9))

<Tensor <LB METAL (9, 9, 9) float (<BinaryOps.ADD: 1>, None)> on METAL with grad None>

In [275]:
from dataclasses import dataclass, field


@dataclass
class ACTConfig:
    """Configuration class for the Action Chunking Transformers policy.

    Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and 'output_shapes`.

    Notes on the inputs and outputs:
        - Either:
            - At least one key starting with "observation.image is required as an input.
              AND/OR
            - The key "observation.environment_state" is required as input.
        - If there are multiple keys beginning with "observation.images." they are treated as multiple camera
          views. Right now we only support all images having the same shape.
        - May optionally work without an "observation.state" key for the proprioceptive robot state.
        - "action" is required as an output key.

    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        chunk_size: The size of the action prediction "chunks" in units of environment steps.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            This should be no greater than the chunk size. For example, if the chunk size size 100, you may
            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
            environment, and throws the other 50 out.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
            the input data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
            indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
            include batch dimension or temporal dimension.
        output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
            the output data name, and the value is a list indicating the dimensions of the corresponding data.
            For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
            Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
            [-1, 1] range.
        output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
            original scale. Note that this is also used for normalizing the training targets.
        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
        pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
            `None` means no pretrained weights.
        replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
            convolution.
        pre_norm: Whether to use "pre-norm" in the transformer blocks.
        dim_model: The transformer blocks' main hidden dimension.
        n_heads: The number of heads to use in the transformer blocks' multi-head attention.
        dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
            layers.
        feedforward_activation: The activation to use in the transformer block's feed-forward layers.
        n_encoder_layers: The number of transformer layers to use for the transformer encoder.
        n_decoder_layers: The number of transformer layers to use for the transformer decoder.
        use_vae: Whether to use a variational objective during training. This introduces another transformer
            which is used as the VAE's encoder (not to be confused with the transformer encoder - see
            documentation in the policy class).
        latent_dim: The VAE's latent dimension.
        n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
        temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
            ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
            1 when using this feature, as inference needs to happen at every step to form an ensemble. For
            more information on how ensembling works, please see `ACTTemporalEnsembler`.
        dropout: Dropout to use in the transformer layers (see code for details).
        kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
            is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 100
    n_action_steps: int = 100

    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [14],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.images.top": "mean_std",
            "observation.state": "mean_std",
        }
    )
    output_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "action": "mean_std",
        }
    )

    # Overrides.
    override_dataset_stats: dict[str, dict[str, list[[float]]]] = field(
        default_factory=lambda: {
            "observation.images.top": {
                "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
                "std": [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
            }
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: int = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
    n_decoder_layers: int = 1
    # VAE.
    use_vae: bool = True
    latent_dim: int = 32
    n_vae_encoder_layers: int = 4

    # Inference.
    # Note: the value used in ACT when temporal ensembling is enabled is 0.01.
    temporal_ensemble_coeff: float | None = None

    # Training and loss computation.
    dropout: float = 0.1
    kl_weight: float = 10.0

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
        if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
            raise NotImplementedError(
                "`n_action_steps` must be 1 when using temporal ensembling. This is "
                "because the policy needs to be queried every step to compute the ensembled action."
            )
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
            )
        if (
            not any(k.startswith("observation.image") for k in self.input_shapes)
            and "observation.environment_state" not in self.input_shapes
        ):
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

In [276]:
class ACTDecoderLayer:
    def __init__(self, config: ACTConfig):
        super().__init__()
        self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
        self.multihead_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
        self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)

        self.norm1 = nn.LayerNorm(config.dim_model)
        self.norm2 = nn.LayerNorm(config.dim_model)
        self.norm3 = nn.LayerNorm(config.dim_model)
        self.dropout_rate = config.dropout

        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
        return tensor if pos_embed is None else tensor + pos_embed

    def __call__(
        self,
        x: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Tensor | None = None,
        encoder_pos_embed: Tensor | None = None,
    ) -> Tensor:
        """
        Args:
            x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
            encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
                cross-attending with.
            decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
            encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
        Returns:
            (DS, B, C) tensor of decoder output features.
        """
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
        x = self.self_attn(q, k, x)  
        #x = x[0] # select just the output, not the attention weights
        x = skip + x.dropout(p=self.dropout_rate)
        if self.pre_norm:
            skip = x
            x = self.norm2(x)
        else:
            x = self.norm1(x)
            skip = x
        x = self.multihead_attn(
            self.maybe_add_pos_embed(x, decoder_pos_embed),
            self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
            encoder_out,
        )
        #x = x[0]  # select just the output, not the attention weights
        x = skip + x.dropout(p=self.dropout_rate)
        if self.pre_norm:
            skip = x
            x = self.norm3(x)
        else:
            x = self.norm2(x)
            skip = x
        
        x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout_rate).sequential([self.linear2])
        x = skip + x.dropout(p=self.dropout_rate)
        if not self.pre_norm:
            x = self.norm3(x)
        return x


In [277]:
actDecoder = ACTDecoderLayer(ACTConfig())
actDecoder(Tensor.zeros(3,512, 512), Tensor.zeros(3,512, 512))

<Tensor <LB METAL (3, 512, 512) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>

In [278]:
class ACTDecoder:
    def __init__(self, config: ACTConfig):
        """Convenience module for running multiple decoder layers followed by normalization."""
        super().__init__()
        self.layers = [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
        self.norm = nn.LayerNorm(config.dim_model)

    def __call__(
        self,
        x: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Tensor | None = None,
        encoder_pos_embed: Tensor | None = None,
    ) -> Tensor:
        for layer in self.layers:
            x = layer(
                x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
            )
        if self.norm is not None:
            x = self.norm(x)
        return x

In [279]:
actDecode = ACTDecoder(ACTConfig())
actDecode(Tensor.zeros(3,512,512), Tensor.zeros(3,512,512))

<Tensor <LB METAL (3, 512, 512) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>

In [280]:
class ACTEncoderLayer:
    def __init__(self, config: ACTConfig):
        super().__init__()
        self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)

        # Feed forward layers.
        self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
        self.dropout = config.dropout
        self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)

        self.norm1 = nn.LayerNorm(config.dim_model)
        self.norm2 = nn.LayerNorm(config.dim_model)

        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def __call__(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
        q = k = x if pos_embed is None else x + pos_embed
        x = self.self_attn(q, k, x, key_padding_mask=key_padding_mask)
        # x = x[0]  # note: [0] to select just the output, not the attention weights
        x = skip + x.dropout(p=self.dropout)
        if self.pre_norm:
            skip = x
            x = self.norm2(x)
        else:
            x = self.norm1(x)
            skip = x
        x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout).sequential([self.linear2])
        x = skip + x.dropout(p=self.dropout)
        if not self.pre_norm:
            x = self.norm2(x)
        return x

In [281]:
actEncode = ACTEncoderLayer(ACTConfig())
actEncode(Tensor.zeros(3, 512, 512))

<Tensor <LB METAL (3, 512, 512) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>

In [282]:
class ACTEncoder:
    """Convenience module for running multiple encoder layers, maybe followed by normalization."""

    def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
        super().__init__()
        self.is_vae_encoder = is_vae_encoder
        num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
        self.layers = [ACTEncoderLayer(config) for _ in range(num_layers)]
        self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else lambda x: x

    def __call__(
        self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
    ) -> Tensor:
        for layer in self.layers:
            print(f'ACTEncoder x.shape per layer: {x.shape}')
            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
        x = self.norm(x)
        return x


In [283]:
actEncoder = ACTEncoder(ACTConfig())
actEncode(Tensor.zeros(3, 512, 512))

<Tensor <LB METAL (3, 512, 512) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>

In [302]:
import tinygrad.nn as nn
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch, get_child

# allow monkeypatching in layer implementations
BatchNorm = nn.BatchNorm2d
Conv2d = nn.Conv2d
Linear = nn.Linear

class FrozenBatchNorm2d:
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Register buffers instead of parameters
        self.weight = Tensor.ones(num_features, requires_grad=False)
        self.bias = Tensor.zeros(num_features, requires_grad=False)
        self.running_mean = Tensor.zeros(num_features, requires_grad=False)
        self.running_var = Tensor.ones(num_features, requires_grad=False)
    def __call__(self, x:Tensor) -> Tensor:
        # Reshape for 2D input
        scale = (self.weight / (self.running_var + self.eps).sqrt()).reshape(1, -1, 1, 1)
        bias = (self.bias - self.running_mean * scale.flatten()).reshape(1, -1, 1, 1)
        return x * scale + bias

class Block:
  def __init__(self, in_dims, dims, stride=1):
    super().__init__()
    self.conv1 = nn.Conv2d(
      in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
    )
    self.bn1 = FrozenBatchNorm2d(dims)
    self.conv2 = nn.Conv2d(
      dims, dims, kernel_size=3, stride=1, padding=1, bias=False
    )
    self.bn2 = FrozenBatchNorm2d(dims)
    self.downsample = []
    if stride != 1:
      self.downsample = [
        nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),
        FrozenBatchNorm2d(dims)
      ]
  def __call__(self, x):
    base_operations = [
        self.conv1,
        self.bn1,
        Tensor.relu,
        self.conv2,
        self.bn2
    ]
    out = x.sequential(base_operations)
    
    if self.downsample != []:
      return (x.sequential(base_operations) + x.sequential(self.downsample)).relu()
    else:
      return x.sequential(base_operations).relu()

class ResNet:
  def __init__(self, block, num_blocks, num_classes=10):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = FrozenBatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)
    #self.fc = nn.Linear(512, num_classes, requires_grad=False) # if we decide to use this someday, remove the grad
  def _make_layer(self, block, in_dims, dims, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(in_dims, dims, stride))
      in_dims = dims
    return layers
  def __call__(self, x:Tensor):
    x = self.bn1(self.conv1(x)).relu().max_pool2d()
    x = x.sequential(self.layer1)
    x = x.sequential(self.layer2 + self.layer3 + self.layer4)
    """
    Commented out for now, because we're just using the output from layer4
    """
    #x = x.mean([2, 3])
    #x = self.fc(x)
    return x

resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
state_dict = nn.state.safe_load("resnet18-f37072fd.safetensors")
nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)

ram used:  2.26 GB, layer4.1.bn2.running_var                          : 100%|█| 


loaded weights in  18.32 ms, 0.04 GB loaded at 2.44 GB/s


In [303]:
from itertools import chain

class ACT:
    """Action Chunking Transformer: The underlying neural network for ACTPolicy.

    Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
        - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
          model that encodes the target data (a sequence of actions), and the condition (the robot
          joint-space).
        - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
          cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
          have an option to train this model without the variational objective (in which case we drop the
          `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).

                                 Transformer
                                 Used alone for inference
                                 (acts as VAE decoder
                                  during training)
                                ┌───────────────────────┐
                                │             Outputs   │
                                │                ▲      │
                                │     ┌─────►┌───────┐  │
                   ┌──────┐     │     │      │Transf.│  │
                   │      │     │     ├─────►│decoder│  │
              ┌────┴────┐ │     │     │      │       │  │
              │         │ │     │ ┌───┴───┬─►│       │  │
              │ VAE     │ │     │ │       │  └───────┘  │
              │ encoder │ │     │ │Transf.│             │
              │         │ │     │ │encoder│             │
              └───▲─────┘ │     │ │       │             │
                  │       │     │ └▲──▲─▲─┘             │
                  │       │     │  │  │ │               │
                inputs    └─────┼──┘  │ image emb.      │
                                │    state emb.         │
                                └───────────────────────┘
    """

    def __init__(self, config: ACTConfig):
        super().__init__()
        self.config = config
        # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
        # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
        self.use_robot_state = "observation.state" in config.input_shapes
        self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
        self.use_env_state = "observation.environment_state" in config.input_shapes
        if self.config.use_vae:
            self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
            self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
            # Projection layer for joint-space configuration to hidden dimension.
            if self.use_robot_state:
                self.vae_encoder_robot_state_input_proj = nn.Linear(
                    config.input_shapes["observation.state"][0], config.dim_model
                )
            # Projection layer for action (joint-space target) to hidden dimension.
            self.vae_encoder_action_input_proj = nn.Linear(
                config.output_shapes["action"][0], config.dim_model
            )
            # Projection layer from the VAE encoder's output to the latent distribution's parameter space.
            self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
            # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
            # dimension.
            num_input_token_encoder = 1 + config.chunk_size
            if self.use_robot_state:
                num_input_token_encoder += 1
            self.vae_encoder_pos_enc = create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0)
            self.vae_encoder_pos_enc.requires_grad = False

        # Backbone for image feature extraction.
        if self.use_images:
            resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)
            state_dict = nn.state.safe_load("resnet18-f37072fd.safetensors")
            nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)
            backbone_model = resnet18_IMAGENET1K_V1
            # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
            # feature map).
            # Note: The forward method of this returns a dict: {"feature_map": output}.
            self.backbone = backbone_model #IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})

        # Transformer (acts as VAE decoder when training with the variational objective).
        self.encoder = ACTEncoder(config)
        self.decoder = ACTDecoder(config)

        # Transformer encoder input projections. The tokens will be structured like
        # [latent, (robot_state), (env_state), (image_feature_map_pixels)].
        if self.use_robot_state:
            self.encoder_robot_state_input_proj = nn.Linear(
                config.input_shapes["observation.state"][0], config.dim_model
            )
        if self.use_env_state:
            self.encoder_env_state_input_proj = nn.Linear(
                config.input_shapes["observation.environment_state"][0], config.dim_model
            )
        self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
        if self.use_images:
            self.encoder_img_feat_input_proj = nn.Conv2d(
                512, config.dim_model, kernel_size=1
            )
        # Transformer encoder positional embeddings.
        n_1d_tokens = 1  # for the latent
        if self.use_robot_state:
            n_1d_tokens += 1
        if self.use_env_state:
            n_1d_tokens += 1
        self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
        if self.use_images:
            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)

        # Transformer decoder.
        # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
        self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)

        # Final action regression head on the output of the transformer's decoder.
        self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])

        self._reset_parameters()

        # CHANGE THIS WHEN RUNNING.
        self.training=True

    def _reset_parameters(self):
        """Xavier-uniform initialization of the transformer parameters as in the original code."""
        for p in chain(nn.state.get_parameters(self.encoder), nn.state.get_parameters(self.decoder)):
            if p.ndim > 1:
                def xavier_uniform_(tensor: Tensor) -> Tensor:
                    fan_in, fan_out = tensor.shape[:2]
                    
                    # Calculate the range for the uniform distribution
                    # This is the glorot/xavier uniform initialization formula
                    a = math.sqrt(6.0 / (fan_in + fan_out))
                    
                    # Use uniform distribution to initialize the tensor
                    return Tensor.uniform(*tensor.shape, low=-a, high=a)
                p = xavier_uniform_(p)

    def __call__(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
        """A forward pass through the Action Chunking Transformer (with optional VAE encoder).

        `batch` should have the following structure:
        {
            "observation.state" (optional): (B, state_dim) batch of robot states.

            "observation.images": (B, n_cameras, C, H, W) batch of images.
                AND/OR
            "observation.environment_state": (B, env_dim) batch of environment states.

            "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
        }

        Returns:
            (B, chunk_size, action_dim) batch of action sequences
            Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
            latent dimension.
        """
        if self.config.use_vae and self.training:
            assert (
                "action" in batch
            ), "actions must be provided when using the variational objective in training mode."

        batch_size = (
            batch["observation.images"]
            if "observation.images" in batch
            else batch["observation.environment_state"]
        ).shape[0]

        print(f'batch_size: {batch_size}')

        # Prepare the latent for input to the transformer encoder.
        if self.config.use_vae and "action" in batch:
            # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
            cls_embed = self.vae_encoder_cls_embed.weight.repeat(batch_size, 1, 1) # (B, 1, D)
            if self.use_robot_state:
                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
                robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
            action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

            if self.use_robot_state:
                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
            else:
                vae_encoder_input = [cls_embed, action_embed]
            vae_encoder_input = Tensor.cat(*vae_encoder_input, dim=1)

            # Prepare fixed positional embedding.
            # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
            pos_embed = self.vae_encoder_pos_enc.contiguous().detach()  # (1, S+2, D)

            # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
            # sequence depending whether we use the input states or not (cls and robot state)
            # False means not a padding token.
            cls_joint_is_pad = Tensor.full(
                shape=(batch_size, 2 if self.use_robot_state else 1),
                fill_value=False
            )
            key_padding_mask = Tensor.cat(
                cls_joint_is_pad, batch["action_is_pad"], dim=1
            )  # (bs, seq+1 or 2)

            print(f'vae_encoder_input.shape: {vae_encoder_input.shape}')
            print(f'pos_embed.shape: {pos_embed.shape}')
            print(f'key_padding_mask.shape: {key_padding_mask.shape}')

            # Forward pass through VAE encoder to get the latent PDF parameters.
            cls_token_out = self.vae_encoder(
                vae_encoder_input.permute(1, 0, 2),
                pos_embed=pos_embed.permute(1, 0, 2),
                key_padding_mask=key_padding_mask.permute(1,0),
            )
            print(f'cls_token_out.shape: {cls_token_out.shape}')
            cls_token_out = cls_token_out[0]  # select the class token, with shape (B, D)
            print(f'cls_token_out[0].shape: {cls_token_out.shape}')
            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
            mu = latent_pdf_params[:, : self.config.latent_dim]
            # This is 2log(sigma). Done this way to match the original implementation.
            log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]

            # Sample the latent with the reparameterization trick.
            latent_sample = mu + log_sigma_x2.div(2).exp() * Tensor.randn(*(mu.shape))
        else:
            # When not using the VAE encoder, we set the latent to be all zeros.
            mu = log_sigma_x2 = None
            # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
            latent_sample = Tensor.zeros(batch_size, self.config.latent_dim, dtype=dtypes.float32)

        # Prepare transformer encoder inputs.
        encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
        encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
        # Robot state token.
        if self.use_robot_state:
            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
        # Environment state token.
        if self.use_env_state:
            encoder_in_tokens.append(
                self.encoder_env_state_input_proj(batch["observation.environment_state"])
            )

        # Camera observation features and positional embeddings.
        if self.use_images:
            all_cam_features = []
            all_cam_pos_embeds = []

            for cam_index in range(batch["observation.images"].shape[-4]):
                cam_features = self.backbone(batch["observation.images"][:, cam_index])  #["feature_map"]
                print(f'backbone output: {cam_features.shape}')
                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
                # buffer
                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).cast(dtype=cam_features.dtype)
                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
                print(f'cam_features: {cam_features.shape}')
                all_cam_features.append(cam_features)
                print(f'len all_cam_features: {len(all_cam_features)}')
                all_cam_pos_embeds.append(cam_pos_embed)
            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
            # and move to (sequence, batch, dim).
            all_cam_features = Tensor.cat(*all_cam_features, dim=-1)
            print(f'len all_cam_features after cat: {len(all_cam_features)}')
            print(f'Before encoder_in_tokens.extend, encoder_in token len: {len(encoder_in_tokens)}')
            encoder_in_tokens.extend(all_cam_features.permute(2, 3, 0, 1).reshape(-1, all_cam_features.shape[0], all_cam_features.shape[1]))
            print(f'encoder_in_tokens: {len(encoder_in_tokens)}')
            all_cam_pos_embeds = Tensor.cat(*all_cam_pos_embeds, dim=-1)
            print(f'all_cam_pos_embeds: {all_cam_pos_embeds}')
            encoder_in_pos_embed.extend(all_cam_pos_embeds.permute(2, 3, 0, 1).reshape(-1, all_cam_pos_embeds.shape[0], all_cam_pos_embeds.shape[1]))

        print(f'Before tensor.stack, encoder_in token len: {len(encoder_in_tokens)}')
        print(f'Before tensor.stack, encoder_in_pos_embed token len: {len(encoder_in_pos_embed)}')
        # Stack all tokens along the sequence dimension.
        encoder_in_tokens = Tensor.stack(*encoder_in_tokens, dim=0)
        encoder_in_pos_embed = Tensor.stack(*encoder_in_pos_embed, dim=0)

        print(f'encoder_in_tokens: {len(encoder_in_tokens)}')
        print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')

        # Forward pass through the transformer modules.
        encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
        # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
        decoder_in = Tensor.zeros(
            *(self.config.chunk_size, batch_size, self.config.dim_model),
            dtype=encoder_in_pos_embed.dtype
        )
        print(f'encoder_out.shape: {encoder_out.shape}')
        print(f'decoder_in.shape: {decoder_in.shape}')
        print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')
        print(f'decoder_pos_embed.shape: {self.decoder_pos_embed.weight.shape}')
        print(f'decoder_pos_embed.shape unsqueezed: {self.decoder_pos_embed.weight.unsqueeze(1).shape}')
        decoder_out = self.decoder(
            decoder_in.permute(1,0,2),
            encoder_out.permute(1,0,2),
            encoder_pos_embed=encoder_in_pos_embed.permute(1,0,2),
            decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1).permute(1,0,2),
        )

        # Move back to (B, S, C).
        # decoder_out = decoder_out.transpose(0, 1)
        print(f'decoder_out: {decoder_out.shape}')

        actions = self.action_head(decoder_out)

        return actions, (mu, log_sigma_x2)

In [304]:
act = ACT(ACTConfig())

ram used:  2.30 GB, layer4.1.bn2.running_var                          : 100%|█| 


loaded weights in  15.17 ms, 0.04 GB loaded at 2.95 GB/s


In [305]:
class ACTTemporalEnsembler:
    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
        They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
        coefficient works:
            - Setting it to 0 uniformly weighs all actions.
            - Setting it positive gives more weight to older actions.
            - Setting it negative gives more weight to newer actions.
        NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
        results in older actions being weighed more highly than newer actions (the experiments documented in
        https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
        detrimental: doing so aggressively may diminish the benefits of action chunking).

        Here we use an online method for computing the average rather than caching a history of actions in
        order to compute the average offline. For a simple 1D sequence it looks something like:

        ```
        import torch

        seq = torch.linspace(8, 8.5, 100)
        print(seq)

        m = 0.01
        exp_weights = torch.exp(-m * torch.arange(len(seq)))
        print(exp_weights)

        # Calculate offline
        avg = (exp_weights * seq).sum() / exp_weights.sum()
        print("offline", avg)

        # Calculate online
        for i, item in enumerate(seq):
            if i == 0:
                avg = item
                continue
            avg *= exp_weights[:i].sum()
            avg += item * exp_weights[i]
            avg /= exp_weights[:i+1].sum()
        print("online", avg)
        ```
        """
        self.chunk_size = chunk_size
        self.ensemble_weights = (-temporal_ensemble_coeff * Tensor.arange(chunk_size)).exp()
        self.ensemble_weights_cumsum = self.ensemble_weights.cumsum(axis=0)
        self.reset()

    def reset(self):
        """Resets the online computation variables."""
        self.ensembled_actions = None
        # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
        self.ensembled_actions_count = None

    def update(self, actions: Tensor) -> Tensor:
        """
        Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
        time steps, and pop/return the next batch of actions in the sequence.
        """
        if self.ensembled_actions is None:
            # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
            # time step of the episode.
            self.ensembled_actions = actions.contiguous()
            # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
            # operations later.
            self.ensembled_actions_count = Tensor.ones(
                *(self.chunk_size, 1), dtype=dtypes.long
            )
        else:
            # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
            # the online update for those entries.
            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
            self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
            self.ensembled_actions_count = (self.ensembled_actions_count + 1).clamp(max_=self.chunk_size)
            # The last action, which has no prior online average, needs to get concatenated onto the end.
            self.ensembled_actions = Tensor.cat(*[self.ensembled_actions, actions[:, -1:]], dim=1)
            self.ensembled_actions_count = Tensor.cat(
                *[self.ensembled_actions_count, Tensor.ones_like(self.ensembled_actions_count[-1:])]
            )
        # "Consume" the first action.
        action, self.ensembled_actions, self.ensembled_actions_count = (
            self.ensembled_actions[:, 0],
            self.ensembled_actions[:, 1:],
            self.ensembled_actions_count[1:],
        )
        return action

In [306]:
from normalize import *

class ACTPolicy:
    """
    Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
    Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
    """

    name = "act"

    def __init__(
        self,
        config: ACTConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance or None, in which case the default instantiation of
                    the configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__()
        if config is None:
            config = ACTConfig()
        self.config: ACTConfig = config

        self.normalize_inputs = Normalize(
            config.input_shapes, config.input_normalization_modes, dataset_stats
        )
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        self.model = ACT(config)

        self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        if config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        if self.config.temporal_ensemble_coeff is not None:
            self.temporal_ensembler.reset()
        else:
            self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty.
        """
        Tensor.no_grad = True
        self.eval()

        batch = self.normalize_inputs(batch)
        if len(self.expected_image_keys) > 0:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)

        # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
        # we are ensembling over.
        if self.config.temporal_ensemble_coeff is not None:
            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
            actions = self.unnormalize_outputs({"action": actions})["action"]
            action = self.temporal_ensembler.update(actions)
            return action

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.model(batch)[0][:, : self.config.n_action_steps]

            # TODO(rcadene): make _forward return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        item_to_return = self._action_queue.popleft()
        Tensor.no_grad = False
        return item_to_return

    def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        if len(self.expected_image_keys) > 0:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)
        batch = self.normalize_targets(batch)
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

        l1_loss = (
            (batch["action"] - actions_hat).abs() * batch["action_is_pad"].logical_not().int().unsqueeze(-1)
        ).mean()

        loss_dict = {"l1_loss": l1_loss.item()}
        if self.config.use_vae:
            # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
            # each dimension independently, we sum over the latent dimension to get the total
            # KL-divergence per batch element, then take the mean over the batch.
            # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
            mean_kld = (
                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.square() - (log_sigma_x2_hat).exp())).sum(axis=-1).mean()
            )
            loss_dict["kld_loss"] = mean_kld.item()
            loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
        else:
            loss_dict["loss"] = l1_loss

        return loss_dict

In [308]:
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from torch.utils.data import DataLoader
import torch

import tinygrad
from tinygrad import Tensor, nn, TinyJit

from omegaconf import ListConfig, OmegaConf

from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict

# Start of training code

# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht")
output_directory.mkdir(parents=True, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 100000
log_freq = 1

# Set up the dataset.
delta_timestamps = {
    "action": [i / 50.0 for i in range(100)],
}
dataset = LeRobotDataset('lerobot/aloha_sim_insertion_human', delta_timestamps=delta_timestamps)
print(dataset.stats)

cfg = ACTConfig()
policy = ACTPolicy(cfg, dataset_stats=dataset.stats)

params_not_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and not n.startswith("model.backbone")]
params_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and n.startswith("model.backbone")]

Tensor.manual_seed(1000)

if hasattr(cfg, 'override_dataset_stats'):
    for key, stats_dict in cfg.override_dataset_stats.items():
        for stats_type, listconfig in stats_dict.items():
            # example of stats_type: min, max, mean, std
            print(f'listconfig: {listconfig}')
            dataset.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)

opt = nn.optim.AdamW(params_not_backbone, lr=1e-5, weight_decay=1e-4)
opt_backbone = nn.optim.AdamW(params_backbone, lr=1e-5, weight_decay=1e-4)

#@TinyJit
@Tensor.train()
def train_step(batch) -> Tensor:
    Tensor.training = True
    output_dict = policy(batch)
    loss = output_dict["loss"]
    opt.zero_grad()
    opt_backbone.zero_grad()
    loss.backward()
    opt.step()
    opt_backbone.step()
    return loss

print(f'Starting training loop')
# Create dataloader for offline training.
dataloader = DataLoader(
    dataset,
    num_workers=0,
    batch_size=8,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
)

step = 0
done = False
with Tensor.train():
    while not done:
        for batch in dataloader:
            batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}
            loss = train_step(batch)
        
            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.numpy():.3f}")
            step += 1

            if step % 10000 == 0:
                try:
                    state_dict = get_state_dict(policy)
                    safe_save(state_dict, f'{output_directory}/model_{step}.safetensors')
                except:
                    print(f'Exception with safe save occured')
            if step >= training_steps:
                done = True
                break

# Save a policy checkpoint.
state_dict = get_state_dict(policy)
safe_save(state_dict, f'{output_directory}/model_final.safetensors')

Fetching 56 files:   0%|          | 0/56 [00:00<?, ?it/s]

ram used:  1.71 GB, layer4.1.bn2.running_var                          : 100%|█| 


{'action': {'max': tensor([0.3175, 0.0844, 1.2226, 0.2807, 0.9986, 0.4418, 1.1625, 0.3206, 0.2056,
        1.2118, 0.7056, 1.1459, 0.4801, 0.9541]), 'mean': tensor([ 0.0075, -0.1817,  0.7322, -0.0069,  0.4357, -0.0031,  0.2792, -0.1002,
        -0.2062,  0.6435,  0.2017,  0.6110, -0.1440,  0.2546]), 'min': tensor([-0.2884, -0.9557,  0.3022, -0.2654, -0.5446, -0.4142,  0.0827, -0.4725,
        -0.9940,  0.0890, -0.2209, -0.4449, -0.8452, -0.0650]), 'std': tensor([0.1098, 0.2156, 0.2022, 0.1066, 0.2204, 0.1394, 0.2707, 0.1425, 0.3127,
        0.2861, 0.1915, 0.2999, 0.2880, 0.3504])}, 'episode_index': {'max': tensor([49.]), 'mean': tensor([24.5000]), 'min': tensor([0.]), 'std': tensor([14.4309])}, 'frame_index': {'max': tensor([499.]), 'mean': tensor([249.5000]), 'min': tensor([0.]), 'std': tensor([144.3372])}, 'index': {'max': tensor([24999.]), 'mean': tensor([12499.4971]), 'min': tensor([0.]), 'std': tensor([7216.8779])}, 'next.done': {'max': tensor([1.]), 'mean': tensor([0.0020]), 'mi

Exception ignored in: <function WeakValueDictionary.__init__.<locals>.remove at 0x110890a40>
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/weakref.py", line 105, in remove
    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):

KeyboardInterrupt: 


step: 3 loss: 52.718
batch_size: 8
vae_encoder_input.shape: (8, 102, 512)
pos_embed.shape: (1, 102, 512)
key_padding_mask.shape: (8, 102)
ACTEncoder x.shape per layer: (102, 8, 512)
(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)
(key_padding_mask): (102, 8)
ACTEncoder x.shape per layer: (102, 8, 512)
(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)
(key_padding_mask): (102, 8)
ACTEncoder x.shape per layer: (102, 8, 512)
(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)
(key_padding_mask): (102, 8)
ACTEncoder x.shape per layer: (102, 8, 512)
(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)
(key_padding_mask): (102, 8)
cls_token_out.shape: (102, 8, 512)
cls_token_out[0].shape: (8, 512)
backbone output: (8, 512, 15, 20)
<Tensor <LB METAL (1, 15, 20, 256) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>
<Tensor <LB METAL (1, 15, 20, 256) float (<BinaryOps.MUL: 2>, None)> on METAL with grad None>
x_range[..., 0::2].sin(): <Tensor <LB METAL 

KeyboardInterrupt: 