In [1]:

import torch
# Toy batch of two sequences of length 10. The first sequence ends with two
# 4999 entries, which is the default pad_token_id of masked_lm.
input_tokens = torch.tensor([
    list(range(8)) + [4999, 4999],
    list(range(10)),
])

def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    """Sample random token ids that are never the pad or the mask token id.

    Ids are drawn uniformly from ``0 .. vocab_size - 1`` with ``pad_token_id``
    and ``mask_token_id`` removed. Special ids that lie outside that range
    simply exclude nothing.

    Args:
        shape: Shape of the tensor to return (anything ``torch.randint`` accepts).
        vocab_size: Exclusive upper bound of the candidate ids.
        pad_token_id: Id that must never be returned.
        mask_token_id: Id that must never be returned.

    Returns:
        An int64 tensor of the requested shape containing only allowed ids.

    Raises:
        ValueError: If excluding the special ids leaves no id to sample from
            (previously this case looped forever).
    """
    # Boolean table of allowed ids; costs O(vocab_size) memory but lets the
    # whole batch be drawn in one vectorised call instead of a per-element
    # Python retry loop.
    allowed = torch.ones(vocab_size, dtype=torch.bool)
    for special_id in (pad_token_id, mask_token_id):
        if 0 <= special_id < vocab_size:
            allowed[special_id] = False

    allowed_ids = allowed.nonzero(as_tuple=True)[0]
    if allowed_ids.numel() == 0:
        raise ValueError(
            "no valid token ids left after excluding pad_token_id and mask_token_id"
        )

    # Draw positions into the allowed-id table, then map them back to real ids.
    picks = torch.randint(0, allowed_ids.numel(), shape)
    return allowed_ids[picks]

def masked_lm(input_ids, mlm_prob = 0.15, pad_token_id = 4999, mask_token_id = 5000, vocab_size = 10000):
    '''
    Build inputs and labels for masked language modelling.

    Each token that is neither ``pad_token_id`` nor ``mask_token_id`` is
    selected with probability ``mlm_prob``. Every selected token is then,
    independently of the others:

    * replaced by ``mask_token_id`` with probability 0.8,
    * replaced by a random non-special token id with probability 0.1,
    * left unchanged with probability 0.1.

    The choice is made per token rather than by slicing a shuffled index list
    at ``int(n * 0.8)`` / ``int(n * 0.9)``. That truncation meant a batch with
    a single selected token was always randomised and never masked.

    Args:
        input_ids: Integer tensor of token ids; it is not modified.
        mlm_prob: Probability of selecting each eligible token.
        pad_token_id: Id of the padding token; never selected.
        mask_token_id: Id written into masked positions; never selected.
        vocab_size: Exclusive upper bound for randomly drawn replacement ids.

    Returns:
        A tuple ``(mlm_data, labels)`` of tensors shaped like ``input_ids``.
        ``mlm_data`` is the corrupted input. ``labels`` holds the original id
        at every selected position and -100 everywhere else.
    '''
    # Create random tensors on the same device as the input so the boolean
    # operations below do not mix devices.
    device = input_ids.device

    selected = torch.rand(input_ids.shape, device=device) < mlm_prob
    selected &= input_ids != pad_token_id   # do not select padding
    selected &= input_ids != mask_token_id  # do not select existing mask tokens

    mlm_data = input_ids.clone()
    labels = input_ids.clone()

    # One uniform draw per position decides what happens to a selected token:
    # [0, 0.8) -> mask, [0.8, 0.9) -> keep, [0.9, 1) -> random replacement.
    action = torch.rand(input_ids.shape, device=device)
    to_mask = selected & (action < 0.8)
    to_random = selected & (action >= 0.9)

    mlm_data[to_mask] = mask_token_id

    num_random = int(to_random.sum())
    if num_random > 0:
        random_ids = generate_random_token_ids((num_random,), vocab_size, pad_token_id, mask_token_id)
        mlm_data[to_random] = random_ids.to(device=device, dtype=mlm_data.dtype)

    # Only selected positions keep their original id as a training target.
    labels[~selected] = -100

    return mlm_data, labels

In [3]:
# Demo batch: 5x5 ids drawn from [2, 30), so the pad id (0) and the mask id (1)
# used in the call never occur in the input itself.
input_tensor = torch.randint(low=2, high=30, size=(5, 5))
print('Input Tensor...')
print(input_tensor)

print('Masked Language Model...')
# A high mlm_prob is used so that several positions are selected in this small example.
res = masked_lm(
    input_tensor,
    mlm_prob=0.5,
    pad_token_id=0,
    mask_token_id=1,
    vocab_size=30,
)

# res[0] is the corrupted input, res[1] the labels.
print('MLM data...')
print(res[0])
print('Labels...')
print(res[1])

Input Tensor...
tensor([[14, 22, 29, 15,  8],
        [26,  7, 17,  3, 18],
        [ 8, 17,  4, 26, 12],
        [23, 25, 26, 22, 15],
        [28,  8, 13, 21,  6]])
Masked Language Model...
MLM data...
tensor([[ 1,  6, 29, 15,  8],
        [ 1,  7, 17,  3, 18],
        [ 8, 17,  4,  1, 12],
        [23, 25,  1, 22, 15],
        [28,  8,  1, 21,  1]])
Labels...
tensor([[  14,   22, -100, -100, -100],
        [  26, -100, -100, -100, -100],
        [-100, -100, -100,   26, -100],
        [-100, -100,   26, -100, -100],
        [-100, -100,   13,   21,    6]])


(The walkthrough below refers to the saved output shown above. Masking is random, so re-running the cell will produce different values.)

From the above example, observe how the labels matrix retains the original input values at the positions that were selected for masking. All other positions are set to -100, which is the default `ignore_index` of PyTorch's `cross_entropy`, so they do not contribute to the loss. The main output is X (the MLM data), in which the selected positions fall into the following 3 groups.

Group 1: indices (0,0), (1,0), (2,3), (3,2), (4,2), (4,4) have been replaced with the mask token id 1
Group 2: index (0,1) has been randomly replaced with a value that is not one of the special tokens (22 -> 6)
Group 3: index (4,3) has been kept as it is (21 stays as 21)