In [2]:
import torch

In [3]:
# Toy batch: three sequences padded to a common length of nine positions.
bsz, max_len = 3, 9
# Per-sequence lengths as a column ([bsz, 1]) so they broadcast across positions.
lens = torch.tensor([1, 4, 9]).reshape(-1, 1)

# Position indices 0 .. max_len - 1.
m = torch.arange(0, max_len)
m

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

In [4]:
lens.shape  # same torch.Size object that lens.size() returns

torch.Size([3, 1])

In [5]:
# repeat() stacks bsz copies of m along a new leading batch dimension.
tile_counts = (bsz, 1)
m.repeat(*tile_counts)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]])

In [6]:
# expand() broadcasts m to [bsz, max_len]; the printed result matches the repeat() cell.
target_shape = (bsz, max_len)
m = m.expand(*target_shape)
m

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 3, 4, 5, 6, 7, 8]])

In [7]:
m >= lens  # True where the position index is at or past the sequence length (padding)

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False]])

In [12]:
def get_src_key_padding_mask(bsz, max_len, lens):
    """Build a Boolean key-padding mask from per-sequence lengths.

    Entry ``[i, j]`` of the result is ``True`` when position ``j`` is at or
    beyond ``lens[i]`` (a padding position) and ``False`` for the first
    ``lens[i]`` positions of row ``i``.

    Args:
        bsz (int): batch size; the position grid is expanded to this many rows.
        max_len (int): max seq len of item in batch (number of mask columns).
        lens: 1-D list, tuple or tensor of sequence lengths, one per batch item.

    Returns:
        Bool tensor of shape ``[bsz, max_len]`` (when ``bsz`` matches the
        number of lengths), created on the same device as ``lens``.

    Raises:
        AssertionError: if ``lens`` is not one-dimensional.
    """
    # Accept any sequence-like input (list, tuple, ...), not only `list`.
    if not torch.is_tensor(lens):
        lens = torch.as_tensor(lens)
    assert lens.dim() == 1, "lens must be one-dimensional: [bsz]"

    lens = lens.unsqueeze(1)  # [bsz] -> [bsz, 1] so it broadcasts over positions
    # Build the position grid on the device of `lens` to avoid a device mismatch
    # in the comparison below.
    positions = torch.arange(max_len, device=lens.device)
    positions = positions.expand(bsz, max_len)  # broadcast along the batch dimension
    # Positions at or past the length are padding.
    return positions >= lens

In [13]:
# Tensor input: lengths 1, 4 and 9 in a batch padded to 9 positions.
get_src_key_padding_mask(bsz=3, max_len=9, lens=torch.tensor([1, 4, 9]))

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False]])

In [14]:
# Same lengths passed as a plain Python list; the mask should match the tensor case.
get_src_key_padding_mask(bsz=3, max_len=9, lens=[1, 4, 9])

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False]])

# Desired text lengths

In [8]:
import torch

In [35]:
# Inputs for the "lengthen, but cap at a maximum" experiment.
lens = torch.tensor([1, 4, 7, 8, 9, 10])
len_modifier = 2
# The cap is kept as a 0-dim tensor rather than expanded to lens.size(0).
max_len = torch.tensor(10)

In [36]:
# Display the input lengths before any modifier is applied.
lens

tensor([ 1,  4,  7,  8,  9, 10])

In [37]:
# Element-wise minimum: shift every length by len_modifier, then cap it at max_len.
lens2 = torch.min(max_len, lens + len_modifier)

In [38]:
# Display the lengths after adding len_modifier and capping at max_len.
lens2

tensor([ 3,  6,  9, 10, 10, 10])

In [39]:
# Display the cap; in this section max_len is a 0-dim tensor, not a Python int.
max_len

tensor(10)

In [42]:
# Mirror experiment with a negative modifier: shorten, but never below min_len.
lens = torch.tensor([1, 4, 7, 8, 9, 10])
len_modifier, min_len = -2, torch.tensor(1)
# Element-wise maximum keeps every shortened length at min_len or above.
lens2 = torch.max(min_len, lens + len_modifier)
lens2

tensor([1, 2, 5, 6, 7, 8])

In [44]:
def get_desired_text_lens(lens, len_modifier):
    """Shift every length by ``len_modifier`` while keeping it in range.

    * ``len_modifier < 0``: lengths are shortened but never drop below 1.
    * ``len_modifier > 0``: lengths are extended but never exceed the largest
      length present in ``lens``.
    * ``len_modifier == 0``: the lengths are returned unchanged.

    Args:
        lens: list, tuple or tensor of lengths. Non-tensor input is converted
            to a tensor first, so the result is always a tensor.
        len_modifier (int): signed amount added to every length.

    Returns:
        Tensor with the same shape as ``lens`` holding the adjusted lengths.
        An empty ``lens`` is returned as-is, since there is no maximum to cap at.
    """
    # Accept plain sequences as well as tensors.
    if not torch.is_tensor(lens):
        lens = torch.as_tensor(lens)
    # torch.max() of an empty tensor raises, and there is nothing to adjust anyway.
    if lens.numel() == 0:
        return lens

    if len_modifier < 0:
        # A scalar bound avoids building a separate tensor for the lower limit.
        return torch.clamp(lens + len_modifier, min=1)
    elif len_modifier > 0:
        longest = torch.max(lens)  # cap at the longest length in the batch
        return torch.min(lens + len_modifier, longest)
    else:
        return lens

In [45]:
# Positive modifier: lengths grow by 2 but are capped at the batch maximum (10).
get_desired_text_lens(lens=torch.tensor([1, 4, 7, 8, 9, 10]), len_modifier=2)

tensor([ 3,  6,  9, 10, 10, 10])

In [46]:
# Negative modifier: lengths shrink by 2 but are floored at 1.
get_desired_text_lens(lens=torch.tensor([1, 4, 7, 8, 9, 10]), len_modifier=-2)

tensor([1, 2, 5, 6, 7, 8])

In [47]:
# Zero modifier: the lengths come back unchanged.
get_desired_text_lens(lens=torch.tensor([1, 4, 7, 8, 9, 10]), len_modifier=0)

tensor([ 1,  4,  7,  8,  9, 10])