In [3]:
'''
ref : https://medium.com/@mbednarski/understanding-indexing-with-pytorch-gather-33717a84ebc4
'''

import torch

indices = torch.LongTensor([3,7,4,1])
indices = indices.unsqueeze(-1)
print(indices.shape)
print(indices)

torch.Size([4, 1])
tensor([[3],
        [7],
        [4],
        [1]])


In [10]:
target = torch.arange(0,40).unsqueeze(0)
target_ = target.reshape(4,-1)
print(target_.shape)
print(target_)

torch.Size([4, 10])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])


In [11]:
'''
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
'''

'\ntorch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor\nGathers values along an axis specified by dim.\n\n'

In [12]:
torch.gather(input=target_, dim=1, index=indices )

tensor([[ 3],
        [17],
        [24],
        [31]])

In [13]:
'''
3D example

In three dimensions, things become more tricky.

Imagine we have a following scenario: RNN network with sequences padded to maximum length. 
We would like to collect last element in each sequence, with all features from rnn hidden state.

BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STAT


'''

'\n3D example\n\nIn three dimensions, things become more tricky.\n\nImagine we have a following scenario: RNN network with sequences padded to maximum length. \nWe would like to collect last element in each sequence, with all features from rnn hidden state.\n'

In [16]:
batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
  for j in range(max_seq_len):
    for k in range(hidden_size):
      x[i,j,k] = i + j*10 + k*100

print(x)

# value “123” means “1st batch, 2nd sequence element, 3rd hidden state”. If we do:

tensor([[[  0., 100., 200., 300., 400., 500.],
         [ 10., 110., 210., 310., 410., 510.],
         [ 20., 120., 220., 320., 420., 520.],
         [ 30., 130., 230., 330., 430., 530.],
         [ 40., 140., 240., 340., 440., 540.],
         [ 50., 150., 250., 350., 450., 550.],
         [ 60., 160., 260., 360., 460., 560.],
         [ 70., 170., 270., 370., 470., 570.],
         [ 80., 180., 280., 380., 480., 580.]],

        [[  1., 101., 201., 301., 401., 501.],
         [ 11., 111., 211., 311., 411., 511.],
         [ 21., 121., 221., 321., 421., 521.],
         [ 31., 131., 231., 331., 431., 531.],
         [ 41., 141., 241., 341., 441., 541.],
         [ 51., 151., 251., 351., 451., 551.],
         [ 61., 161., 261., 361., 461., 561.],
         [ 71., 171., 271., 371., 471., 571.],
         [ 81., 181., 281., 381., 481., 581.]],

        [[  2., 102., 202., 302., 402., 502.],
         [ 12., 112., 212., 312., 412., 512.],
         [ 22., 122., 222., 322., 422., 522.],
         

In [17]:
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)

# repeat 6 times
indices = lens.repeat(1,6)
print(indices.shape)

print(indices)


torch.Size([8, 1])
torch.Size([8, 6])
tensor([[5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6],
        [1, 1, 1, 1, 1, 1],
        [8, 8, 8, 8, 8, 8],
        [3, 3, 3, 3, 3, 3],
        [7, 7, 7, 7, 7, 7],
        [3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4]])


In [18]:
indices = indices.unsqueeze(1)
print(indices.shape)

torch.Size([8, 1, 6])


In [20]:
results = torch.gather(x, 1, indices)
print(results.shape)

torch.Size([8, 1, 6])


In [21]:
results

tensor([[[ 50., 150., 250., 350., 450., 550.]],

        [[ 61., 161., 261., 361., 461., 561.]],

        [[ 12., 112., 212., 312., 412., 512.]],

        [[ 83., 183., 283., 383., 483., 583.]],

        [[ 34., 134., 234., 334., 434., 534.]],

        [[ 75., 175., 275., 375., 475., 575.]],

        [[ 36., 136., 236., 336., 436., 536.]],

        [[ 47., 147., 247., 347., 447., 547.]]])