In [14]:
import torch
import math
from functools import reduce


def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()


replace_prob = 0.3
mask_prob = 0.3

input = torch.tensor([[1, 2, 3, 4, 5, 9, 11, 13, 15]])

replace_prob = prob_mask_like(input, replace_prob)

print(replace_prob)

mask_ignore_token_ids = set([1, 2, 3])

no_mask = mask_with_tokens(input, mask_ignore_token_ids)
mask = get_mask_subset_with_prob(~no_mask, mask_prob)

print(no_mask)
print(mask)

tensor([[ True, False,  True,  True, False, False, False,  True, False]])
tensor([[ True,  True,  True, False, False, False, False, False, False]])
tensor([[False, False, False,  True,  True, False, False,  True, False]])


: 

In [6]:
import torch


def get_var_seq(var_list: torch.Tensor, indicate: torch.Tensor, device):
        # indicate.shape = (batch, max_len)
        result = []
        mask_ind = []
        var_len = var_list.size(0)

        for batch in indicate:
            seq = []
            mask = []
            for i, item in enumerate(batch):
                if item == 3:
                        seq.append(torch.full_like(var_list, 3, device=device))
                        mask.extend(range(i*var_len, i*var_len + var_len, 1))
                else:
                    seq.append(var_list)

            seq = torch.cat(seq, dim=0)
            result.append(seq)
            mask_ind.append(mask)
            
        result = torch.stack(result, dim=0)
        mask_ind = torch.tensor(mask_ind)
        return result, mask_ind

var_list = torch.tensor([4, 5, 6, 7, 8, 9])
indicate = torch.tensor([4, 5, 3, 3, 8, 9, 3, 11]).unsqueeze(0)

result, mask_ind = get_var_seq(var_list, indicate, None)

In [7]:
print(result)

tensor([[4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
         4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9, 3, 3, 3, 3, 3, 3, 4, 5, 6, 7, 8, 9]])


In [8]:
print(mask_ind)

tensor([[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 36, 37, 38, 39, 40, 41]])
