We'll create the local and global masks for the entire sequence, assuming that input tokens are concatenated in the following order: [CNN scene tokens, object tokens, language tokens].

Since the masks can only be created after padding, we use scene, object, and text temporal ID vectors with the following values:

- -1 for history tokens
- 0 for padding tokens
- 1 .. N for future tokens, where the value is the number of the corresponding future frame

Text input tokens are always history tokens, and we make all text input tokens global.

In [1]:
import torch

# First sample: 3 history frames, 2 future frames and 1 padding position.
# Temporal IDs: -1 = history, 0 = padding, k >= 1 = k-th future frame.
scene_temporal_ids1 = torch.Tensor([-1] * 3 + [1, 2] + [0])
# Objects: 6 across the history frames, 3 in future frame 1, 2 in future frame 2,
# followed by 2 padding slots.
object_temporal_ids1 = torch.Tensor([-1] * 6 + [1] * 3 + [2] * 2 + [0] * 2)
# Text never belongs to a future frame: 7 history tokens followed by 2 padding slots.
text_temporal_ids1 = torch.Tensor([-1] * 7 + [0] * 2)

In [2]:
# Second sample: 2 history frames, 3 future frames and 1 padding position.
scene_temporal_ids2 = torch.Tensor([-1] * 2 + [1, 2, 3] + [0])
# Objects: 4 across the history frames, then 3 / 2 / 1 in future frames 1 / 2 / 3,
# followed by 3 padding slots.
object_temporal_ids2 = torch.Tensor([-1] * 4 + [1] * 3 + [2] * 2 + [3] + [0] * 3)
# Text: 5 history tokens followed by 4 padding slots.
text_temporal_ids2 = torch.Tensor([-1] * 5 + [0] * 4)

# Stack each modality's two samples along a new leading batch dimension.
scene_temporal_ids, object_temporal_ids, text_temporal_ids = (
    torch.stack(pair)
    for pair in (
        (scene_temporal_ids1, scene_temporal_ids2),
        (object_temporal_ids1, object_temporal_ids2),
        (text_temporal_ids1, text_temporal_ids2),
    )
)

In [3]:
from emma_policy.datamodules.collate import (
    make_text_history_global_pattern,
    make_encoder_causal_mask_batch,
)

# Both mask builders take the per-modality temporal IDs in concatenation order:
# scene, objects, text.
temporal_id_batches = (scene_temporal_ids, object_temporal_ids, text_temporal_ids)
mask_dtype = scene_temporal_ids.dtype

# 2D mask: element (i, j) marks whether token i may attend to token j.
attention2d = make_encoder_causal_mask_batch(*temporal_id_batches, dtype=mask_dtype)
# Per-token pattern marking which tokens receive global attention.
global_attenion = make_text_history_global_pattern(*temporal_id_batches, dtype=mask_dtype)

In [4]:
print(f"2D ttention mask shape: batch size x total tokens x total tokens = {attention2d.shape}")
assert (
    attention2d.shape[2]
    == scene_temporal_ids.shape[1] + object_temporal_ids.shape[1] + text_temporal_ids.shape[1]
)
print(f"Global ttention mask shapebatch size x total tokens =  {global_attenion.shape}")
assert (
    global_attenion.shape[1]
    == scene_temporal_ids.shape[1] + object_temporal_ids.shape[1] + text_temporal_ids.shape[1]
)
global_attenion

2D ttention mask shape: batch size x total tokens x total tokens = torch.Size([2, 28, 28])
Global ttention mask shapebatch size x total tokens =  torch.Size([2, 28])


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 1., 1., 1., 1., 0., 0., 0., 0.]])

Now we can check any combination of token types. Element (i, j) of the 2D attention mask is 1 if token i is allowed to attend to token j.

In [5]:
print("First item in the batch.")
print("Scene-to-scene attention: 3 history frames, 2 future frames, 1 padding")
scene_len = scene_temporal_ids.shape[-1]
attention2d[0, :scene_len, :scene_len]

First item in the batch.
Scene-to-scene attention: 3 history frames, 2 future frames, 1 padding


tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.]])

In [6]:
print(
    "Scene-to-objects attention: 6 history objects, "
    "3 objects in frame 1, 2 objects in frame 2, 2 paddings"
)
# Object tokens come directly after the scene tokens in the concatenated sequence.
object_len = object_temporal_ids.shape[-1]
object_span = slice(scene_len, scene_len + object_len)
attention2d[0, :scene_len, object_span]

Scene-to-objects attention: 6 history objects, 3 objects in frame 1, 2 objects in frame 2, 2 paddings


tensor([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [7]:
print("Scene-to-text attention: 7 history text tokens")
text_len = text_temporal_ids.shape[-1]
attention2d[0, :scene_len, scene_len + object_len :]

Scene-to-text attention: 7 history text tokens


tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])