In [2]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Broadcast a column of start positions (0..9) against a row of window
# offsets (-3..3); row i of the result lists the indices i-3 .. i+3.
torch.arange(10).unsqueeze(1) + torch.arange(-3, 4)

tensor([[-3, -2, -1,  0,  1,  2,  3],
        [-2, -1,  0,  1,  2,  3,  4],
        [-1,  0,  1,  2,  3,  4,  5],
        [ 0,  1,  2,  3,  4,  5,  6],
        [ 1,  2,  3,  4,  5,  6,  7],
        [ 2,  3,  4,  5,  6,  7,  8],
        [ 3,  4,  5,  6,  7,  8,  9],
        [ 4,  5,  6,  7,  8,  9, 10],
        [ 5,  6,  7,  8,  9, 10, 11],
        [ 6,  7,  8,  9, 10, 11, 12]])

In [None]:
# Revised implementation: the zero-padded frames at the sequence boundaries can
# influence the attention results, so the transformer decoder layer is given a
# padding mask that excludes them.
import torch
import torch.nn as nn

class TemporalLocalWindowAttentionGather(nn.Module):
    """Temporal local-window attention over per-frame token features.

    The input is a tensor of shape ``[b, t, n, c]`` (batch, frames, tokens per
    frame, channels). For every frame, its ``n`` tokens act as queries and
    attend to the tokens of the ``window_size`` frames centred on that frame.
    Frames that fall outside the clip are zero-padded and hidden from the
    attention through a key padding mask, so the padding cannot influence the
    result.

    Args:
        d_model: Channel dimension ``c`` of the input features.
        nhead: Number of attention heads of the decoder layer.
        window_size: Number of frames in each temporal window (>= 1). For an
            even size the window holds one more preceding frame than
            following frames.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, d_model, nhead, window_size=3):
        super(TemporalLocalWindowAttentionGather, self).__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size
        # Built with the constructor defaults, i.e. the sequence-first layout
        # [seq, batch, d_model]; forward() transposes accordingly.
        self.transformer_decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)

    def gather_key_value(self, x, window_size):
        """Collect, for every frame, the tokens of its temporal window.

        Args:
            x: Features of shape ``[b, t, n, c]``.
            window_size: Number of frames per window.

        Returns:
            Tensor of shape ``[b, t, window_size * n, c]``. Along the third
            axis the tokens are ordered frame by frame (window position
            major, token index minor). Positions belonging to frames outside
            ``[0, t)`` are zeros.
        """
        b, t, n, c = x.shape

        # Pad along the time axis so that every frame has a full window.
        pad_before = window_size // 2
        pad_after = window_size - 1 - pad_before
        x_padded = torch.cat(
            [x.new_zeros(b, pad_before, n, c), x, x.new_zeros(b, pad_after, n, c)],
            dim=1,
        )  # [b, t + window_size - 1, n, c]

        # frame_index[i, j] is the padded-time index of the j-th frame in the
        # window of frame i. Both ranges live on x's device.
        starts = torch.arange(t, device=x.device)[:, None]
        offsets = torch.arange(window_size, device=x.device)[None, :]
        frame_index = starts + offsets  # [t, window_size]

        # Advanced indexing on the time axis yields one window per frame.
        windows = x_padded[:, frame_index]  # [b, t, window_size, n, c]
        return windows.reshape(b, t, window_size * n, c)

    def generate_mask(self, b, t, n, c, window_size, device=None):
        """Build the key padding mask matching ``gather_key_value``.

        Args:
            b: Batch size.
            t: Number of frames.
            n: Number of tokens per frame.
            c: Channel dimension (unused; kept for interface compatibility).
            window_size: Number of frames per window.
            device: Device on which to create the mask (default: PyTorch's
                default device).

        Returns:
            Boolean tensor of shape ``[b, t, window_size * n]`` that is
            ``True`` exactly at the positions that come from zero-padded
            (out-of-range) frames and must therefore be ignored.
        """
        pad_before = window_size // 2
        # source_frame[i, j] is the original (unpadded) frame index of the
        # j-th window slot of frame i; it is out of range for padded slots.
        source_frame = (
            torch.arange(t, device=device)[:, None]
            + torch.arange(window_size, device=device)[None, :]
            - pad_before
        )  # [t, window_size]
        is_padding = (source_frame < 0) | (source_frame >= t)

        # Repeat each frame-level flag for the n tokens of that frame, using
        # the same flattening order as gather_key_value.
        mask = is_padding[:, :, None].expand(t, window_size, n).reshape(t, window_size * n)
        return mask.unsqueeze(0).repeat(b, 1, 1)

    def forward(self, x):
        """Apply windowed temporal attention.

        Args:
            x: Features of shape ``[b, t, n, c]`` with ``c == d_model``.

        Returns:
            Tensor of shape ``[b, t, n, c]``.
        """
        b, t, n, c = x.shape
        window_tokens = self.window_size * n

        # Fold frames into the batch so each frame only attends to its own
        # window, then switch to the layer's sequence-first layout.
        query = x.reshape(b * t, n, c).transpose(0, 1)  # [n, b*t, c]

        key_value = self.gather_key_value(x, self.window_size)
        key_value = key_value.reshape(b * t, window_tokens, c).transpose(0, 1)  # [w*n, b*t, c]

        mask = self.generate_mask(b, t, n, c, self.window_size, device=x.device)
        mask = mask.reshape(b * t, window_tokens)

        # The centre frame of every window is always valid, so no query ever
        # has all of its keys masked.
        output = self.transformer_decoder_layer(
            query, key_value, memory_key_padding_mask=mask
        )
        output = output.transpose(0, 1).reshape(b, t, n, c)

        return output

# Example usage: run the layer once on random video features.
d_model, nhead = 64, 8
batch_size, num_frames, num_tokens = 16, 10, 50

# Random input laid out as [batch, frames, tokens, channels].
video_features = torch.randn(batch_size, num_frames, num_tokens, d_model)
temporal_local_window_attention_gather = TemporalLocalWindowAttentionGather(d_model, nhead)
output = temporal_local_window_attention_gather(video_features)

print("Output shape: ", output.shape)