In some cases, we need to pick from a tensor the vectors whose corresponding entries in another tensor or matrix contain a specific value. Here, I have implemented a function that does this efficiently.

This study is coded in PyTorch, but functions similar to the required ones, $bincount$, $nonzero$, $gather$ and $repeat$, are also available in other NN frameworks.

In [3]:
import torch
import time

def operation_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    elapsed_ss = '{0:.4f}'.format(elapsed_time - (elapsed_mins * 60)-elapsed_secs)
    return elapsed_mins, elapsed_secs,elapsed_ss

Let's create a dummy embedding tensor and a corresponding tensor (let's call it '$hp$'; do not worry about the name, it's just a name) that keeps info about each embedding vector. The dimension of the embedding tensor is $batch\_size$ x $max\_seq\_length$ x $hidden\_dim$. The dimension of $hp$ is $batch\_size$ x $max\_seq\_length$.

In [70]:
batch_size = 8
max_seq_len = 10
hidden_size = 9
embedding_tensor = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
    for j in range(max_seq_len):
        for k in range(hidden_size):
            embedding_tensor[i,j,k] = i + j*10 + k*100
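
The triple loop above is easy to read but slow for large shapes. As a minimal sketch, the same values can be produced with broadcasting; the name `embedding_fast` is only illustrative and not part of the original notebook.

```python
import torch

batch_size, max_seq_len, hidden_size = 8, 10, 9

# Index vectors reshaped so that they broadcast to (batch, seq, hidden).
i = torch.arange(batch_size, dtype=torch.float32).view(-1, 1, 1)
j = torch.arange(max_seq_len, dtype=torch.float32).view(1, -1, 1)
k = torch.arange(hidden_size, dtype=torch.float32).view(1, 1, -1)

# Same formula as in the loop: value = i + j*10 + k*100
embedding_fast = i + j * 10 + k * 100
```

Each entry encodes its own position, which makes it easy to check later which vectors were picked.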

In [71]:
embedding_tensor

tensor([[[  0., 100., 200., 300., 400., 500., 600., 700., 800.],
         [ 10., 110., 210., 310., 410., 510., 610., 710., 810.],
         [ 20., 120., 220., 320., 420., 520., 620., 720., 820.],
         [ 30., 130., 230., 330., 430., 530., 630., 730., 830.],
         [ 40., 140., 240., 340., 440., 540., 640., 740., 840.],
         [ 50., 150., 250., 350., 450., 550., 650., 750., 850.],
         [ 60., 160., 260., 360., 460., 560., 660., 760., 860.],
         [ 70., 170., 270., 370., 470., 570., 670., 770., 870.],
         [ 80., 180., 280., 380., 480., 580., 680., 780., 880.],
         [ 90., 190., 290., 390., 490., 590., 690., 790., 890.]],

        [[  1., 101., 201., 301., 401., 501., 601., 701., 801.],
         [ 11., 111., 211., 311., 411., 511., 611., 711., 811.],
         [ 21., 121., 221., 321., 421., 521., 621., 721., 821.],
         [ 31., 131., 231., 331., 431., 531., 631., 731., 831.],
         [ 41., 141., 241., 341., 441., 541., 641., 741., 841.],
         [ 51., 151., 2

In [79]:
hp = torch.randint(10, (batch_size, max_seq_len))

In [80]:
hp

tensor([[6, 0, 5, 3, 8, 7, 6, 1, 9, 8],
        [5, 1, 2, 1, 7, 0, 2, 1, 8, 9],
        [8, 2, 8, 7, 0, 0, 7, 7, 6, 8],
        [6, 3, 1, 4, 4, 3, 8, 5, 4, 6],
        [1, 3, 6, 8, 8, 0, 0, 9, 8, 8],
        [7, 5, 1, 4, 0, 6, 2, 9, 8, 8],
        [1, 5, 2, 2, 9, 5, 4, 1, 4, 9],
        [0, 5, 7, 6, 4, 7, 0, 3, 3, 3]])

For instance, suppose we want to collect the embedding vectors whose corresponding value in $hp$ is equal to 2.

In [81]:
wanted = 2

First, we have to find which positions of $hp$ contain the value of $wanted$. Here, $wanted$ is equal to 2.

In [82]:
batch_is,token_is = (hp==wanted).nonzero(as_tuple=True)
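
To make the return value of $nonzero$ concrete, here is a small sketch on a hand-written tensor (`hp_small` is a made-up example, not data from this notebook). With `as_tuple=True`, one index tensor is returned per dimension, listed in row-major order.

```python
import torch

hp_small = torch.tensor([[2, 0, 2],
                         [1, 1, 1],
                         [0, 2, 0]])

# rows holds the batch index of every match, cols the position inside the row.
rows, cols = (hp_small == 2).nonzero(as_tuple=True)
print(rows)  # tensor([0, 0, 2])
print(cols)  # tensor([0, 2, 1])
```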

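Some rows of $hp$ contain no $wanted$ value, while others contain one or more. Since the number of $wanted$ values varies from row to row, we have to find the maximum number of times $wanted$ occurs in a single row. We store the indices in a tensor initialized with zeros, so every unused slot points to index 0, the first vector of the sequence. Note that such a filler slot cannot be told apart from a real match at position 0. If the last position of every sequence is known to be padding, filling the tensor with $max\_seq\_len - 1$ instead would make the unused slots point to a padding vector.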

In [83]:
indices = torch.zeros((batch_size,torch.bincount(batch_is).max().item()))
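
As a small illustration of what $bincount$ contributes here, the sketch below counts matches per row for a made-up `batch_is_small`. The `minlength` argument is my addition: as far as I know, without it an empty input yields an empty result, on which `max()` would fail.

```python
import torch

# Batch indices of the matches, as returned by nonzero (rows 0, 0 and 2).
batch_is_small = torch.tensor([0, 0, 2])

# counts[b] = number of matches in row b; minlength keeps one entry per row.
counts = torch.bincount(batch_is_small, minlength=3)
max_count = int(counts.max())
print(counts, max_count)  # tensor([2, 0, 1]) 2

# With no matches at all, minlength still gives a well-defined maximum of 0.
no_matches = torch.tensor([], dtype=torch.long)
empty_max = int(torch.bincount(no_matches, minlength=3).max())
```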

Now, fill $indices$ with the corresponding values.

In [84]:
tic=time.time()
for batch_i, token_i in zip(batch_is,token_is):
    m =(batch_is==batch_i).nonzero().flatten().tolist()
    k = len(m)
    indices[batch_i,:k]= token_is[m]
toc=time.time()
mins, secs, sses = operation_time(tic, toc)
print(f'Finished in {mins}m {secs}s {sses}ss to fill the indices tensor')

Finished in 0m 0s 0.0015ss to fill the indices tensor


We now have the indices that can be used with the $gather$ function to pick the corresponding vectors.

In [85]:
indices

tensor([[0., 0.],
        [2., 6.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [6., 0.],
        [2., 3.],
        [0., 0.]])

To use $gather$ we have to be aware of three parameters:
- input : input tensor
- dim   : dimension along which to collect values
- index : tensor with the indices of the values to collect (it must have the same number of dimensions as the input tensor and be no larger than it along the axes other than dim)

We will collect our vectors along dim 1.
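
A tiny sketch of $gather$ along dim 1 may help; the tensors `x` and `index` below are made up for illustration. For a 3-D input, the rule is `out[b, r, h] = x[b, index[b, r, h], h]`.

```python
import torch

# x[0] = [[0, 1], [2, 3], [4, 5]],  x[1] = [[6, 7], [8, 9], [10, 11]]
x = torch.arange(2 * 3 * 2).view(2, 3, 2)

# Pick row 2 from the first batch item and row 0 from the second one.
# The row index is repeated across the last axis so a whole vector is taken.
index = torch.tensor([[[2, 2]],
                      [[0, 0]]])

out = torch.gather(x, 1, index)
print(out)  # tensor([[[4, 5]], [[6, 7]]])
```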

In [86]:
#repeat hidden_size times
batch_size, max_seq = indices.shape 
indices_ = indices.repeat(1,hidden_size).view(batch_size,-1,max_seq).transpose(2,1).type(torch.int64)

Note that the index tensor passed to $gather$ must have an integer type (torch.int64).

In [87]:
indices_.shape

torch.Size([8, 2, 9])

In [88]:
indices_

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 2, 2, 2, 2, 2, 2, 2],
         [6, 6, 6, 6, 6, 6, 6, 6, 6]],

        [[1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[6, 6, 6, 6, 6, 6, 6, 6, 6],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3, 3, 3]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0]]])
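
The repeat/view/transpose chain can also be written more directly. The following is a sketch of an alternative using `unsqueeze` and `expand`, checked against the original expression on a small made-up `indices` tensor; `expand` creates a view rather than copying the data.

```python
import torch

hidden_size = 9
indices = torch.tensor([[0., 0.],
                        [2., 6.],
                        [1., 0.]])
batch_size, max_seq = indices.shape

# Expression used in the notebook.
via_repeat = (indices.repeat(1, hidden_size)
                     .view(batch_size, -1, max_seq)
                     .transpose(2, 1)
                     .type(torch.int64))

# Alternative: add a trailing axis and broadcast it to hidden_size.
via_expand = indices.long().unsqueeze(-1).expand(-1, -1, hidden_size)

same = torch.equal(via_repeat, via_expand)
```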

In [89]:
reduced_embedding_tensor = torch.gather(embedding_tensor,1,indices_)

In [90]:
reduced_embedding_tensor

tensor([[[  0., 100., 200., 300., 400., 500., 600., 700., 800.],
         [  0., 100., 200., 300., 400., 500., 600., 700., 800.]],

        [[ 21., 121., 221., 321., 421., 521., 621., 721., 821.],
         [ 61., 161., 261., 361., 461., 561., 661., 761., 861.]],

        [[ 12., 112., 212., 312., 412., 512., 612., 712., 812.],
         [  2., 102., 202., 302., 402., 502., 602., 702., 802.]],

        [[  3., 103., 203., 303., 403., 503., 603., 703., 803.],
         [  3., 103., 203., 303., 403., 503., 603., 703., 803.]],

        [[  4., 104., 204., 304., 404., 504., 604., 704., 804.],
         [  4., 104., 204., 304., 404., 504., 604., 704., 804.]],

        [[ 65., 165., 265., 365., 465., 565., 665., 765., 865.],
         [  5., 105., 205., 305., 405., 505., 605., 705., 805.]],

        [[ 26., 126., 226., 326., 426., 526., 626., 726., 826.],
         [ 36., 136., 236., 336., 436., 536., 636., 736., 836.]],

        [[  7., 107., 207., 307., 407., 507., 607., 707., 807.],
         [ 

Yes, we have done it successfully. Let's wrap the whole procedure in a function.

In [97]:
def pick_up_corresponding_vectors(embedding_tensor,hp,wanted):
    # embedding_tensor: (batch_size, seq_len, hidden_size); hp: (batch_size, seq_len)
    batch_size, sequence_length, hidden_size = embedding_tensor.shape
    # positions where hp equals the wanted value
    batch_is,token_is = (hp==wanted).nonzero(as_tuple=True)
    # one row per batch item, as wide as the largest number of matches in a row
    indices = torch.zeros((batch_size,torch.bincount(batch_is).max().item()))
    for batch_i, token_i in zip(batch_is,token_is):
        m =(batch_is==batch_i).nonzero().flatten().tolist()
        k = len(m)
        indices[batch_i,:k]= token_is[m]
    batch_size, max_seq = indices.shape 
    # replicate each index across the hidden dimension, as gather requires
    indices_ = indices.repeat(1,hidden_size).view(batch_size,-1,max_seq).transpose(2,1).type(torch.int64)
    return torch.gather(embedding_tensor,1,indices_)

Let's call this function with a huge tensor to examine its speed.

In [91]:
# dummy embedding tensor
batch_size = 800
max_seq_len = 58
hidden_size = 500
embedding_tensor = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
    for j in range(max_seq_len):
        for k in range(hidden_size):
            embedding_tensor[i,j,k] = i + j*10 + k*100

In [95]:
hp = torch.randint(10, (batch_size, max_seq_len))
wanted = 2

In [99]:
tic=time.time()
pick_up_corresponding_vectors(embedding_tensor,hp,wanted)
toc=time.time()
mins, secs, sses = operation_time(tic, toc)
print(f'Finished in {mins}m {secs}s {sses}ss')

Finished in 0m 0s 0.2922ss
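
For comparison, here is a sketch of a loop-free variant that should return the same result. It is not benchmarked here, so any speed-up is an assumption. The function name `pick_up_vectorized` and the `fill_index` parameter are my own additions. The idea is to use a running count of matches per row as the slot into which each matching position is written.

```python
import torch

def pick_up_vectorized(embedding_tensor, hp, wanted, fill_index=0):
    batch_size, seq_len, hidden_size = embedding_tensor.shape
    mask = hp == wanted                                  # (batch, seq)
    max_count = int(mask.sum(dim=1).max())               # widest row of matches
    batch_is, token_is = mask.nonzero(as_tuple=True)
    # rank = 0 for the first match in a row, 1 for the second, and so on.
    rank = (mask.long().cumsum(dim=1) - 1)[batch_is, token_is]
    # Unused slots keep fill_index (0 mirrors the zero-initialized tensor above).
    indices = torch.full((batch_size, max_count), fill_index,
                         dtype=torch.long, device=hp.device)
    indices[batch_is, rank] = token_is
    index = indices.unsqueeze(-1).expand(-1, -1, hidden_size)
    return torch.gather(embedding_tensor, 1, index)

# Small check with position-encoding values: emb[i, j, k] = i + 10*j + 100*k
i = torch.arange(3.).view(-1, 1, 1)
j = torch.arange(3.).view(1, -1, 1)
k = torch.arange(2.).view(1, 1, -1)
emb = i + 10 * j + 100 * k
hp_small = torch.tensor([[6, 0, 5],
                         [2, 1, 2],
                         [1, 2, 0]])
picked = pick_up_vectorized(emb, hp_small, wanted=2)
```

In this small check, row 1 matches at positions 0 and 2 and row 2 at position 1, so `picked` has shape (3, 2, 2) and the unused slots fall back to position 0.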


If you know a faster method that does the same thing, please let me know. Thanks for your attention.