In [1]:
import torch

In [2]:
from batch_index_select import *

### Quick usage test

Select rows 0 and 2 from the first batch item and rows 1 and 2 from the second.

In [3]:
# Toy batch of 2 items, each 3 rows x 4 columns; the trailing 1s / 0s make it
# easy to see which batch item a selected row came from.
x = torch.tensor(
    [
        [[1, 2, 1, 1], [4, 5, 1, 1], [7, 8, 1, 1]],
        [[10, 11, 0, 0], [13, 14, 0, 0], [16, 17, 0, 0]],
    ]
)
x.shape

torch.Size([2, 3, 4])

In [4]:
# Row b lists the rows of x[b] to keep: batch item 0 keeps rows 0 and 2,
# batch item 1 keeps rows 1 and 2.
idx = torch.tensor([[0, 2], [1, 2]])
idx.shape

torch.Size([2, 2])

In [5]:
# Check the result against a plain per-item reference (x[b][idx[b]] for each
# batch item b, stacked back into a batch) instead of only eyeballing it.
selection = batch_index_select(x=x, idx=idx)
expected_selection = torch.stack(
    [x_item[idx_item] for x_item, idx_item in zip(x, idx)]
)
assert torch.equal(selection, expected_selection)
selection

tensor([[[ 1,  2,  1,  1],
         [ 7,  8,  1,  1]],

        [[13, 14,  0,  0],
         [16, 17,  0,  0]]])

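### Randomized tests

For random embeddings and random index sets, every batch item `b` of `batch_index_select(x, idx)` must equal `torch.index_select(x[b], 0, idx[b])`.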

In [6]:
# Seed torch's global RNG so the random embeddings and random index sets
# below are reproducible when a randomized check fails.
SEED = 0
torch.manual_seed(SEED)

batch_size = 1
patch_embedding_dim = 256
num_patches_per_layer = 256
num_patches_to_select = 64

# The index sets are drawn with torch.multinomial without replacement, which
# cannot return more distinct indices than there are patches.
assert num_patches_to_select <= num_patches_per_layer

In [7]:
# Standard-normal patch embeddings: one (num_patches_per_layer, patch_embedding_dim)
# matrix per batch item.
x = torch.randn(batch_size, num_patches_per_layer, patch_embedding_dim)
x

tensor([[[-0.4127,  0.0604,  0.3482,  ...,  0.9609, -0.5695,  1.8515],
         [-1.3135, -1.1062,  0.8621,  ..., -0.7530,  0.0483, -0.1770],
         [ 0.3445,  1.1182, -0.2598,  ...,  0.4021,  0.4755,  1.3950],
         ...,
         [ 0.3185, -0.9757, -0.8706,  ...,  1.4769,  1.5047, -0.3960],
         [ 0.3208, -1.3161, -1.3988,  ...,  0.4501, -0.8291,  0.7720],
         [-0.4082, -0.7709,  0.4343,  ..., -0.3795, -0.2892,  1.4858]]])

In [8]:
# Property test: on every iteration draw a fresh index set per batch item and
# check that batch_index_select agrees, item by item, with torch.index_select.
iters = 1000

for _ in range(iters):
    # One row of `num_patches_to_select` distinct indices per batch item
    # (uniform weights, sampling without replacement). The row count must
    # follow batch_size, otherwise select_idx[batch_idx] below fails for
    # batch_size > 1.
    select_idx = torch.vstack(
        [
            torch.multinomial(torch.ones(num_patches_per_layer),
                              num_patches_to_select)
            for _row in range(batch_size)
        ]
    )
    batch_index_selection = batch_index_select(x, select_idx)

    assert batch_index_selection.shape == (
        batch_size, num_patches_to_select, patch_embedding_dim
    )

    for batch_idx in range(batch_size):
        expected_selected_patches = torch.index_select(
            input=x[batch_idx],
            index=select_idx[batch_idx],
            dim=0,
        )
        assert torch.equal(
            batch_index_selection[batch_idx], expected_selected_patches
        )