In [1]:
import torch

In [2]:
from batch_index_select import *

### Quick usage test

In [3]:
# Toy input: two batch elements, each holding three rows of width 4.
x = torch.tensor(
    [
        [[1, 2, 1, 1], [4, 5, 1, 1], [7, 8, 1, 1]],
        [[10, 11, 0, 0], [13, 14, 0, 0], [16, 17, 0, 0]],
    ]
)
x.shape

torch.Size([2, 3, 4])

In [4]:
# One row of indices per batch element: rows 0 and 2 for element 0,
# rows 1 and 2 for element 1.
idx = torch.tensor([[0, 2], [1, 2]])
idx.shape

torch.Size([2, 2])

In [5]:
# Select rows per batch element and check the result instead of only eyeballing
# it: element 0 keeps rows 0 and 2, element 1 keeps rows 1 and 2.
selected = batch_index_select(
    x=x,
    idx=idx
)
assert torch.equal(
    selected,
    torch.tensor([
        [[1, 2, 1, 1], [7, 8, 1, 1]],
        [[13, 14, 0, 0], [16, 17, 0, 0]],
    ]),
), selected
selected

tensor([[[ 1,  2,  1,  1],
         [ 7,  8,  1,  1]],

        [[13, 14,  0,  0],
         [16, 17,  0,  0]]])

### Randomized tests

In [6]:
# Shapes for the randomized comparison against torch.index_select.
batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 256
num_patches_to_select = 64

# Seed torch's global RNG so any failing randomized case can be reproduced.
# The trailing ';' suppresses the Generator repr in the cell output.
SEED = 0
torch.manual_seed(SEED);

In [7]:
# Random float input of shape (batch_size, num_patches_per_layer, patch_embedding_dim).
x = torch.randn(batch_size, num_patches_per_layer, patch_embedding_dim)
x

tensor([[[-0.0339, -1.2891,  0.3094,  ..., -1.1228,  0.3909,  1.6523],
         [ 0.1933,  1.2342,  0.7176,  ...,  0.0785, -0.4016,  1.0491],
         [-0.4878,  0.5987, -1.5426,  ...,  1.7683, -0.9930,  0.5460],
         ...,
         [-0.4368, -1.9926, -0.7336,  ..., -0.2809,  0.6574,  0.1650],
         [-0.7169,  0.5940, -1.3645,  ...,  0.2376, -0.1195, -1.3433],
         [-0.8502, -1.6204, -0.3106,  ..., -0.0919, -1.2947,  0.2784]]])

In [8]:
iters = 1000

# Each iteration draws, for every batch element, `num_patches_to_select`
# distinct patch indices (uniform weights, multinomial without replacement).
# It then checks batch_index_select against a per-element torch.index_select.
for _ in range(iters):
    # One index row per batch element -> shape (batch_size, num_patches_to_select).
    # The row count must follow batch_size; otherwise select_idx[batch_idx]
    # below runs out of rows as soon as batch_size > 1.
    select_idx = torch.vstack(
        [
            torch.multinomial(torch.ones(num_patches_per_layer),
                              num_patches_to_select)
            for _row in range(batch_size)
        ]
    )
    batch_index_selection = batch_index_select(x, select_idx)
    assert batch_index_selection.shape == (
        batch_size, num_patches_to_select, patch_embedding_dim
    ), batch_index_selection.shape

    for batch_idx in range(batch_size):
        expected_selected_patches = torch.index_select(
            input=x[batch_idx],
            index=select_idx[batch_idx],
            dim=0,
        )
        assert torch.equal(
            batch_index_selection[batch_idx], expected_selected_patches
        ), f"mismatch for batch element {batch_idx}"