In [1]:

import torch
input_tokens = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 4999, 4999],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
])

def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    random_token_ids = torch.randint(0, vocab_size, shape)
    # Ensure random_token_ids do not contain pad_token_id or mask_token_id
    invalid_ids = {pad_token_id, mask_token_id}
    for i in range(random_token_ids.numel()):
        while random_token_ids.view(-1)[i].item() in invalid_ids:
            random_token_ids.view(-1)[i] = torch.randint(0, vocab_size, (1,))
    return random_token_ids

def masked_lm(input_ids, mlm_prob = 0.15, pad_token_id = 4999, mask_token_id = 5000, vocab_size = 10000):
    '''
    Masked language model
    '''
    mask = torch.rand(input_ids.size()) < mlm_prob # create a mask of true and false
    mask &= input_ids != pad_token_id # do not mask padding
    mask &= input_ids != mask_token_id # do not mask mask token

    ## create clones
    mlm_data = input_ids.clone()
    labels = input_ids.clone()

    ## get the indices of the mask
    mask_idx = mask.nonzero(as_tuple=True)

    ## randomly mask tokens
    mask_idx_shuffle = torch.randperm(mask_idx[0].shape[0])

    ## get the tokens for each type of masking (mask, random, keep)
    tomask_idx = mask_idx_shuffle[:int(mask_idx[0].shape[0] * 0.8)]
    torandom_idx = mask_idx_shuffle[int(mask_idx[0].shape[0] * 0.9):]

    ## mask the tokens
    mlm_data[mask_idx[0][tomask_idx], mask_idx[1][tomask_idx]] = mask_token_id
    mlm_data[mask_idx[0][torandom_idx], mask_idx[1][torandom_idx]] = generate_random_token_ids((torandom_idx.shape[0],), vocab_size, pad_token_id, mask_token_id) #@ this mask token is not in the vocab

    ## create the labels
    labels[~mask] = -100

    return mlm_data, labels 

In [3]:
input_tensor = torch.randint(2, 30, (5,5))
print('Input Tensor...')
print(input_tensor)

print('Masked Language Model...')
res = masked_lm(input_tensor, mlm_prob = 0.5, pad_token_id = 0, mask_token_id = 1, vocab_size = 30)
print('MLM data...')    
print(res[0])
print('Labels...')
print(res[1])

Input Tensor...
tensor([[14, 22, 29, 15,  8],
        [26,  7, 17,  3, 18],
        [ 8, 17,  4, 26, 12],
        [23, 25, 26, 22, 15],
        [28,  8, 13, 21,  6]])
Masked Language Model...
MLM data...
tensor([[ 1,  6, 29, 15,  8],
        [ 1,  7, 17,  3, 18],
        [ 8, 17,  4,  1, 12],
        [23, 25,  1, 22, 15],
        [28,  8,  1, 21,  1]])
Labels...
tensor([[  14,   22, -100, -100, -100],
        [  26, -100, -100, -100, -100],
        [-100, -100, -100,   26, -100],
        [-100, -100,   26, -100, -100],
        [-100, -100,   13,   21,    6]])


(assuming that you did not execute the above)

From the above example, observe how the labels matrix retains the true values of the original input tensor for the indices that are masked. Indices that are not masked are placed as -100. This is so that the cross_entropy calculation will ignore the labels that have values -100. The main output is actually the X (the MLM data). Observe the following 3 groups.

Group 1: indexes (0,0), (1,0), (2,3), (3,2), (4,2), (4,4) are masked as '1's
Group 2: index (0,1) has been randomly replaced with a value that is not in the special tokens (6 -> 22)
Group 3: index (4,3) has been kept as it is (21 stays as 21)

In [None]:
## from the above example,
##

In [41]:
torch.randint(100, (10,2))

tensor([[25, 72],
        [32,  9],
        [30, 19],
        [85, 49],
        [75, 22],
        [74, 72],
        [12, 90],
        [27, 64],
        [21, 53],
        [61, 50]])

In [18]:
_, _, shape = masked_lm(torch.randint(0, 10000, (10,10)))

In [19]:
def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    random_token_ids = torch.randint(vocab_size, shape)
    for i in range(num_tokens):
        while random_token_ids[i].item() in [pad_token_id, mask_token_id]:
            random_token_ids[i] = torch.randint(vocab_size, (1,))
    return random_token_ids

In [43]:
random_token_ids = torch.randint(10000, (10,2))
## writing a for function for each token in the shape
for i in range(shape[0]):
    
    while random_token_ids[i].item() in [4999, 5000]:
        random_token_ids[i] = torch.randint(10000, (1,))

TypeError: 'tuple' object cannot be interpreted as an integer

In [44]:
def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    random_token_ids = torch.randint(vocab_size, shape)
    mask_pad_ids = torch.tensor([pad_token_id, mask_token_id])
    
    for idx in torch.where(torch.isin(random_token_ids, mask_pad_ids))[0]:
        while random_token_ids.view(-1)[idx] in mask_pad_ids:
            random_token_ids.view(-1)[idx] = torch.randint(vocab_size, (1,))
    
    return random_token_ids

In [117]:
random_token_ids = torch.randint(10, (10,10))
mask_pad_ids = torch.tensor([0, 1])


In [118]:
random_token_ids

tensor([[5, 2, 5, 5, 7, 4, 8, 1, 4, 1],
        [5, 8, 7, 0, 1, 9, 4, 8, 7, 1],
        [6, 0, 6, 9, 5, 6, 4, 2, 0, 3],
        [5, 3, 2, 3, 9, 0, 5, 3, 9, 9],
        [2, 5, 7, 0, 8, 8, 2, 2, 8, 6],
        [4, 9, 8, 6, 5, 4, 1, 4, 3, 1],
        [0, 4, 6, 6, 5, 1, 3, 2, 3, 0],
        [1, 7, 1, 1, 9, 4, 4, 2, 5, 6],
        [5, 9, 8, 6, 1, 3, 2, 4, 7, 2],
        [3, 8, 7, 3, 6, 8, 6, 1, 1, 5]])

In [119]:
for idx in torch.where(torch.isin(random_token_ids, mask_pad_ids)):
    print(idx)

tensor([0, 0, 1, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 6, 7, 7, 7, 8, 9, 9])
tensor([7, 9, 3, 4, 9, 1, 8, 5, 3, 6, 9, 0, 5, 9, 0, 2, 3, 4, 7, 8])


In [120]:
mask_pad_ids

tensor([0, 1])

In [123]:
to_replace = random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))] 

for i in range(len(to_replace)):
    while to_replace.view(-1)[i] in mask_pad_ids:
        to_replace.view(-1)[i] = torch.randint(1000, (1,))


In [124]:
to_replace

tensor([116, 263,  24, 988,  54, 797, 479, 696, 842, 238, 112,  93, 123, 796,
         77, 352, 487, 966, 373, 918])

In [125]:
for i in random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))]:
    print(i)

tensor(1)
tensor(1)
tensor(0)
tensor(1)
tensor(1)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(1)
tensor(1)
tensor(0)
tensor(1)
tensor(0)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor(1)


In [126]:
def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    random_token_ids = torch.randint(vocab_size, shape)
    mask_pad_ids = torch.tensor([pad_token_id, mask_token_id])
    
    for value in random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))] :
        while value in mask_pad_ids:
            value = torch.randint(vocab_size, (1,))
    
    return random_token_ids

In [128]:
generate_random_token_ids((10,2), 10, 0, 1)

tensor([[3, 0],
        [6, 4],
        [3, 6],
        [7, 8],
        [4, 7],
        [9, 5],
        [5, 1],
        [9, 7],
        [9, 8],
        [9, 1]])

In [129]:
random_token_ids = torch.randint(10, (10,2))
mask_pad_ids = torch.tensor([0, 1])

In [132]:
random_token_ids

tensor([[1, 3],
        [1, 6],
        [8, 8],
        [0, 2],
        [5, 4],
        [2, 2],
        [5, 7],
        [8, 2],
        [7, 4],
        [2, 3]])

In [131]:
for value in random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))] :
    print(value)

tensor(1)
tensor(1)
tensor(0)


In [133]:
value in mask_pad_ids

True

In [142]:
torch.randint(10, (1,))

tensor([8])

In [143]:
for value in random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))] :
    while value in mask_pad_ids:
        value = torch.randint(10, (1,))

In [144]:
random_token_ids

tensor([[1, 3],
        [1, 6],
        [8, 8],
        [0, 2],
        [5, 4],
        [2, 2],
        [5, 7],
        [8, 2],
        [7, 4],
        [2, 3]])

In [145]:
def generate_random_token_ids(shape, vocab_size, pad_token_id, mask_token_id):
    random_token_ids = torch.randint(0, vocab_size, shape)
    # Ensure random_token_ids do not contain pad_token_id or mask_token_id
    invalid_ids = {pad_token_id, mask_token_id}
    for i in range(random_token_ids.numel()):
        while random_token_ids.view(-1)[i].item() in invalid_ids:
            random_token_ids.view(-1)[i] = torch.randint(0, vocab_size, (1,))
    return random_token_ids

In [165]:
generate_random_token_ids((10,2), 3, 0, 2)

tensor([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1]])

In [168]:
for i in range(random_token_ids.numel()):
    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [97]:
random_token_ids[torch.where(torch.isin(random_token_ids, mask_pad_ids))] = torch.randint(10, (1,))

In [107]:
(torch.where(torch.isin(random_token_ids, mask_pad_ids))[0].shape[0],)

(22,)

In [98]:
random_token_ids

tensor([[   2,    9,    2,    5,    5, 8999,    3,    5, 8999,    5],
        [   9,    2,    6, 8999,    5,    6, 8999,    8, 8999,    4],
        [   2,    7, 8999,    8,    4, 8999,    5,    5,    7,    4],
        [   9,    3,    4,    6,    6, 8999,    5,    7,    4,    7],
        [   6,    5,    3,    8,    6, 8999,    8, 8999,    6,    5],
        [   6,    8,    4, 8999,    5, 8999,    7,    3,    7,    9],
        [8999, 8999,    6, 8999,    7,    6,    2, 8999,    8,    6],
        [   9, 8999,    5,    9,    4, 8999,    8,    7, 8999,    2],
        [   8, 8999,    8,    5,    9,    7,    7,    7,    5,    8],
        [   2,    3,    6,    7, 8999, 8999,    3,    7,    2, 8999]])

In [95]:
torch.where(torch.isin(random_token_ids, mask_pad_ids))[0].shape[0]

23

In [69]:
idx

tensor([0, 1, 2, 7, 8, 4, 7, 1, 2, 3, 3, 4, 5, 0, 1, 5, 7, 5, 2, 4, 8])

In [38]:
generate_random_token_ids(shape[0], 10, 0, 1)

tensor([9, 6])

In [82]:
for i in zip(torch.where(torch.isin(random_token_ids, mask_pad_ids))):
    print(i)

(tensor([0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 5, 6, 6, 6, 6, 7, 8, 9, 9]),)
(tensor([0, 1, 2, 7, 8, 4, 7, 1, 2, 3, 3, 4, 5, 0, 1, 5, 7, 5, 2, 4, 8]),)


In [83]:
i

(tensor([0, 1, 2, 7, 8, 4, 7, 1, 2, 3, 3, 4, 5, 0, 1, 5, 7, 5, 2, 4, 8]),)

In [31]:
torch.randint(2, 10000, (10,3))

tensor([[2525, 4596, 6273],
        [9089, 7084, 7376],
        [ 438, 3493,  377],
        [2301, 1600, 1750],
        [4431, 1435, 7902],
        [2200, 7751, 5985],
        [8640, 9921, 2891],
        [9185, 9840, 8964],
        [1357, 7936, 1945],
        [6555, 4625, 2808]])

In [144]:
import torch 

input_tokens = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
])

mlm_prob = 0.5
pad_token_id = 0
mask_token_id = 5000
vocab_size = 256

mask = torch.rand(input_tokens.size()) < mlm_prob # create a mask of true and false
mask &= input_tokens != pad_token_id # do not mask padding
mask &= input_tokens != mask_token_id # do not mask mask token
mlm_data = input_tokens.clone() # create a copy of the original tensor

## get the indices of the mask
mask_idx = mask.nonzero(as_tuple=True)

In [145]:
mask

tensor([[False,  True,  True, False,  True, False,  True,  True, False, False],
        [False, False,  True,  True, False, False,  True, False, False, False]])

In [146]:
mask_idx

(tensor([0, 0, 0, 0, 0, 1, 1, 1]), tensor([1, 2, 4, 6, 7, 2, 3, 6]))

In [147]:
## 80% is masked
random_shuffle = torch.randperm(mask_idx[0].shape[0])
random_shuffle

tensor([0, 1, 5, 4, 6, 7, 3, 2])

In [148]:
replace_idx = random_shuffle[:int(mask_idx[0].shape[0] * 0.8)]
keep_idx = random_shuffle[int(mask_idx[0].shape[0] * 0.8):int(mask_idx[0].shape[0] * 0.9)]
ramdom_idx = random_shuffle[int(mask_idx[0].shape[0] * 0.9):]

In [159]:
random_shuffle[int(mask_idx[0].shape[0] * 0.9):]

tensor([2])

In [167]:
int(mask_idx[0].shape[0] * 0.1)

0

In [165]:
random_shuffle[int(mask_idx[0].shape[0] * 0.9):]

tensor([2])

In [149]:
replace_idx

tensor([0, 1, 5, 4, 6, 7])

In [150]:
keep_idx

tensor([3])

In [151]:
ramdom_idx

tensor([2])

In [152]:
mask_idx[0][replace_idx], mask_idx[1][replace_idx]

(tensor([0, 0, 1, 0, 1, 1]), tensor([1, 2, 2, 7, 3, 6]))

In [153]:
mask_idx[0][keep_idx], mask_idx[1][keep_idx]

(tensor([0]), tensor([6]))

In [154]:
mask_idx[0][ramdom_idx], mask_idx[1][ramdom_idx]

(tensor([0]), tensor([4]))

In [155]:
mlm_data[mask_idx[0][replace_idx], mask_idx[1][replace_idx]] = 5000

In [156]:
# replace with random token
mlm_data[mask_idx[0][ramdom_idx], mask_idx[1][ramdom_idx]] = torch.randint(0, vocab_size, (ramdom_idx.shape[0],))

In [157]:
mlm_data

tensor([[   0, 5000, 5000,    3,  111,    5,    6, 5000,    0,    0],
        [   0,    1, 5000, 5000,    4,    5, 5000,    7,    8,    9]])

In [158]:
labels = input_tokens.clone()
labels[~mask] = -100
labels

tensor([[-100,    1,    2, -100,    4, -100,    6,    7, -100, -100],
        [-100, -100,    2,    3, -100, -100,    6, -100, -100, -100]])

In [125]:
mlm_data[mask_idx[0][ramdom_idx], mask_idx[1][ramdom_idx]]

tensor([1])

In [137]:
torch.randint(0, vocab_size, (ramdom_idx.shape[0],))

tensor([156])

In [10]:
import torch
input_tokens = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
])

In [11]:
input_tokens

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [29]:
mask = torch.randn(input_tokens.shape) < 0.15
mask &= (input_tokens != 0)

mlm_data = input_tokens.clone()
mask_idx = mask.nonzero(as_tuple=True)
mlm_data[mask_idx] = 5000

In [57]:
mask

tensor([[False,  True, False, False, False, False,  True,  True, False, False],
        [False,  True, False,  True,  True,  True, False, False,  True, False]])

In [30]:
mask_idx

(tensor([0, 0, 0, 1, 1, 1, 1, 1]), tensor([1, 6, 7, 1, 3, 4, 5, 8]))

In [31]:
mlm_data

tensor([[   0, 5000,    2,    3,    4,    5, 5000, 5000,    0,    0],
        [   0, 5000,    2, 5000, 5000, 5000,    6,    7, 5000,    9]])

In [32]:
replace_idx = mask_idx[0][torch.randperm(mask_idx[0].shape[0])] 
keep_idx = mask_idx[0][torch.randperm(mask_idx[0].shape[0])]

In [35]:
rand_shuffle = torch.randperm(mask_idx[0].shape[0])
mask_idx_shuffle = [mask_idx[0][rand_shuffle],mask_idx[1][rand_shuffle]]
mask_idx_shuffle

[tensor([1, 1, 1, 0, 1, 1, 0, 0]), tensor([5, 3, 4, 6, 8, 1, 7, 1])]

In [38]:
mlm_data[mask_idx_shuffle]

tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000])

In [56]:
mask_idx_shuffle[0][:int(0.8*mask_idx_shuffle[0].shape[0])]

tensor([1, 1, 1, 0, 1, 1])

In [55]:
[:int(0.8*mask_idx_shuffle[0].shape[0])]

SyntaxError: invalid syntax (3710636772.py, line 1)

In [5]:
mask_idx_shuffle = [mask_idx[0][torch.randperm(mask_idx[0].shape[0])],mask_idx[1][torch.randperm(mask_idx[0].shape[0])]]
mask_idx_shuffle

[tensor([0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0]),
 tensor([5, 2, 8, 3, 1, 4, 4, 6, 9, 1, 7, 5, 3])]

In [7]:
replace_idx = mask_idx[0][torch.randperm(mask_idx[0].shape[0])]
replace_idx

tensor([0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0])

In [8]:
mlm_data[replace_idx[:int(0.8*len(replace_idx))], mask_idx[1][:int(0.8*len(replace_idx))]] = 5000

In [9]:
mlm_data

tensor([[   0, 5000, 5000, 5000, 5000, 5000, 5000, 5000,    0,    0],
        [   0, 5000, 5000, 5000, 5000, 5000,    6, 5000, 5000, 5000]])

In [135]:
mask_idx[0][torch.randperm(mask_idx[0].shape[0])]

tensor([0, 1, 1, 0, 0, 0, 0, 1, 1, 0])

In [137]:
replace_idx[:int(0.8*len(replace_idx))], mask_idx[1][:int(0.8*len(replace_idx))]

(tensor([1, 1, 1, 0, 0, 0, 1]), tensor([2, 3, 4, 5, 6, 7, 5]))

In [125]:
def mlm_masking(data, mlm_prob = 0.15, pad_token_id = 0, mask_token_id = 1, vocab_size = 256):
    '''
    Masking the input data for the MLM task
    '''
    mask = torch.rand(data.shape) < mlm_prob
    mask &= (data != pad_token_id) # do not mask padding
    mask &= (data != mask_token_id) # do not mask mask token
    mlm_data = data.clone()
    mask_idx = mask.nonzero(as_tuple=True)
    mlm_data[mask_idx] = mask_token_id

    # 80% of the time, replace with [MASK]
    replace_idx = mask_idx[0][torch.randperm(mask_idx[0].shape[0])]
    mlm_data[replace_idx[:int(0.8*len(replace_idx))], mask_idx[1][:int(0.8*len(replace_idx))]] = mask_token_id

    # 10% of the time, keep original
    keep_idx = mask_idx[0][torch.randperm(mask_idx[0].shape[0])]
    mlm_data[keep_idx[int(0.8*len(keep_idx)):], mask_idx[1][int(0.8*len(keep_idx))]] = data[keep_idx[int(0.8*len(keep_idx)):], mask_idx[1][int(0.8*len(keep_idx)):]]
    return mlm_data, mask

In [134]:
mlm_masking(input_tokens)

(tensor([[0, 1, 1, 3, 4, 5, 6, 7, 0, 0],
         [0, 1, 2, 3, 1, 5, 6, 7, 8, 9]]),
 tensor([[False, False,  True, False, False, False, False, False, False, False],
         [False, False, False, False,  True, False, False, False, False, False]]))