# BERT - Bidirectional Encoder Representations from Transformers
GPU enabled

BERT is designed to pretrain deep bidirectional representations from
unlabeled text by jointly conditioning on both
left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer
to create state-of-the-art models for a wide
range of tasks, such as question answering and
language inference, without substantial task-specific architecture modifications.

BERT is conceptually simple and empirically
powerful. It obtains new state-of-the-art results on eleven natural language processing
tasks, including pushing the GLUE score to
80.5% (7.7% point absolute improvement),
MultiNLI accuracy to 86.7% (4.6% absolute
improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1
(5.1 point absolute improvement).

In [None]:
# Imports
# Built-in imports
import os
import math
import re
import time
import copy
import random
from tqdm.auto import tqdm
# from random import *

# Basic processing and visualization
import numpy as np
import matplotlib.pyplot as plt

# NN
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertTokenizer

# Huggingface dataset loader
from datasets import load_dataset

In [None]:
# Restrict which GPUs are visible for multi-GPU training (uncomment and edit if needed)
# os.environ['CUDA_VISIBLE_DEVICES'] = "0, 1, 3"

# Puffer proxy - redundant
# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'

# GPU selection: use the default CUDA device when one is available, else the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_gpus = torch.cuda.device_count()
# torch.cuda.device_count(), device

# Randomness configs
SEED = 52
# Sentence-pair sampling and token masking use Python's `random`, so it has to be
# seeded as well; NumPy is seeded for completeness.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Global constants
BATCH_SIZE = 32
MAX_MASK   = 5     # upper bound on masked tokens per sequence (applies when 15% would exceed it)
MAX_LEN    = 1000  # length every sequence is padded to; <- reduce if gpu memory issue
NUM_EPOCH  = 100
num_epoch  = NUM_EPOCH  # lower-case alias kept because the training loop reads it

# Model params
n_layers = 12                  # number of encoder layers
n_heads  = 12                  # number of heads in Multi-Head Attention
d_model  = 768                 # embedding size
d_ff = d_model * 4             # 4*d_model, FeedForward dimension
d_k = d_v = d_model // n_heads # per-head dimension of K(=Q), V (64 for 768 / 12)
n_segments = 2

## 1. Data

For simplicity, we use a small slice of the BookCorpus dataset: the first 10,000 sentences of the first 1% of the training split.

In [None]:
# Load the first 1% of the BookCorpus training split
dataset = load_dataset('bookcorpus', split='train[:1%]')
# Dataset({
#     features: ['text'],
#     num_rows: 740042
# })

# Keep only 10k sentences since the training time is too slow otherwise
sentences = dataset['text'][:10000]

# Normalise: lower-case and strip basic punctuation
punctuation = re.compile(r"[.,!?\-]")
text = [punctuation.sub('', sentence.lower()) for sentence in sentences]

# Combine everything into one string to build the vocabulary.
# The unique words are sorted so that every run assigns the same id to the
# same word; the iteration order of a set of strings is not stable across
# interpreter runs, which would make saved checkpoints unusable with a
# rebuilt vocabulary.
word_list = sorted(set(" ".join(text).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens

# Regular words get the ids that follow the special tokens.
# tqdm wraps the list (not the enumerate object) so the bar knows its total.
n_special = len(word2id)
for i, w in enumerate(tqdm(word_list, desc="Creating word2id"), start=n_special):
    word2id[w] = i

# Reverse mapping, built once after word2id is fully populated
id2word = {v: k for k, v in word2id.items()}
vocab_size = len(word2id)

# Numericalise every sentence, e.g.
# "hello darkness my old friend".split() -> [3, 5, 2, 1, 6]
# token_list is parallel to text: one list of ids per sentence
token_list = [
    [word2id[word] for word in sentence.split()]
    for sentence in tqdm(text, desc="Processing sentences")
]

#token_list.size == text.size

## 2. Data loader

We are going to build the data loader.  Inside it, we need to prepare the inputs for two types of embeddings: **token embedding** and **segment embedding**

1. **Token embedding** - Given “The cat is walking. The dog is barking”, we add [CLS] and [SEP] >> “[CLS] the cat is walking [SEP] the dog is barking [SEP]”. 

2. **Segment embedding**
A segment embedding separates two sentences, i.e., [0 0 0 0 1 1 1 1 ]

3. **Masking**
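As mentioned in the original paper, BERT randomly selects 15% of the tokens in a sequence for prediction. Of these selected tokens, 80% are replaced with [MASK], 10% are replaced with a random token, and the remaining 10% are left unchanged.  Here we additionally cap the number of selected tokens per sequence (`MAX_MASK` / `max_mask` in the code).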

4. **Padding**
Once we mask, we will add padding. For simplicity, here we padded until some specified `max_len`. 

Note:  `positive` and `negative` are just simply counts to keep track of the batch size.  `positive` refers to two sentences that are really next to one another.

In [None]:
# Hand-rolled batch builder. Training uses BERTDataset + DataLoader instead;
# this function is kept to draw a few examples for the inference demo.
def make_batch():
    """Build one batch of masked sentence pairs for MLM + NSP.

    Reads the module-level ``token_list``, ``word2id``, ``vocab_size``,
    ``BATCH_SIZE``, ``MAX_MASK`` and ``MAX_LEN``.

    Half of the examples (``BATCH_SIZE // 2``) are positive pairs, i.e. the
    second sentence directly follows the first one in ``token_list``; the
    rest are negative pairs. Examples are returned in shuffled label order.

    Returns:
        list: ``BATCH_SIZE`` entries of
        ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]`` where
        ``input_ids`` / ``segment_ids`` are padded with 0 to ``MAX_LEN``,
        ``masked_tokens`` / ``masked_pos`` are padded with 0 to ``MAX_MASK``
        and ``is_next`` is a bool.

    Raises:
        ValueError: if fewer than two sentences are available.
    """
    n_sentences = len(token_list)
    if n_sentences < 2:
        raise ValueError("make_batch needs at least two sentences to form a pair")

    cls_id, sep_id, mask_id = word2id['[CLS]'], word2id['[SEP]'], word2id['[MASK]']
    first_word_id = 4  # ids 0-3 are [PAD], [CLS], [SEP], [MASK]
    max_tokens = max(MAX_LEN - 3, 0)  # room left after [CLS] and the two [SEP]

    # Decide the labels up front instead of rejection-sampling random pairs:
    # a random pair is consecutive with probability ~1/n_sentences, so waiting
    # for enough positives by chance is extremely slow.
    n_positive = BATCH_SIZE // 2
    labels = [True] * n_positive + [False] * (BATCH_SIZE - n_positive)
    random.shuffle(labels)

    batch = []
    for is_next in labels:
        tokens_a_index = random.randrange(n_sentences - 1)  # a next sentence always exists
        if is_next:
            tokens_b_index = tokens_a_index + 1
        else:
            tokens_b_index = random.randrange(n_sentences)
            while tokens_b_index == tokens_a_index + 1:
                tokens_b_index = random.randrange(n_sentences)

        # Copy so that truncation never touches the shared token lists
        tokens_a = list(token_list[tokens_a_index])
        tokens_b = list(token_list[tokens_b_index])
        # Trim the longer sentence until the pair fits into MAX_LEN
        while len(tokens_a) + len(tokens_b) > max_tokens:
            longer = tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b
            longer.pop()

        # 1. token ids - add [CLS] and [SEP]
        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]

        # 2. segment ids - 0 for "[CLS] a [SEP]", 1 for "b [SEP]"
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # 3. masked language modelling
        # 15% of the sequence, at least 1, at most MAX_MASK
        n_pred = min(MAX_MASK, max(1, int(round(len(input_ids) * 0.15))))
        # candidate positions exclude [CLS] and [SEP]
        cand_masked_pos = [i for i, token in enumerate(input_ids)
                           if token != cls_id and token != sep_id]
        random.shuffle(cand_masked_pos)
        masked_pos = cand_masked_pos[:n_pred]
        masked_tokens = [input_ids[pos] for pos in masked_pos]  # original tokens = targets
        for pos in masked_pos:
            # A single draw gives the intended 80% / 10% / 10% split
            roll = random.random()
            if roll < 0.8:
                input_ids[pos] = mask_id
            elif roll < 0.9 and vocab_size > first_word_id:
                # random replacement, never a special token
                input_ids[pos] = random.randrange(first_word_id, vocab_size)
            # otherwise: keep the original token

        # pad input_ids and segment_ids up to MAX_LEN
        n_pad = MAX_LEN - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # pad masked_tokens and masked_pos up to MAX_MASK
        n_mask_pad = MAX_MASK - len(masked_pos)
        masked_tokens.extend([0] * n_mask_pad)
        masked_pos.extend([0] * n_mask_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])

    return batch


# Dataset
class BERTDataset(Dataset):
    """Serve randomly drawn sentence pairs prepared for BERT pre-training.

    Every item carries the inputs for masked language modelling (MLM) and
    next sentence prediction (NSP).

    Args:
        sentences: raw sentences; only their count is used.
        token_list: per-sentence lists of token ids, parallel to ``sentences``.
        word2id: token -> id mapping containing '[CLS]', '[SEP]' and '[MASK]'.
        id2word: id -> token mapping (stored on the instance).
        vocab_size: number of ids in the vocabulary.
        max_len: fixed length of ``input_ids`` and ``segment_ids``.
        max_mask: fixed length of ``masked_tokens`` and ``masked_pos``.
    """

    # Tokens that must never be inserted as the "random replacement" token
    _SPECIAL_TOKENS = ('[PAD]', '[CLS]', '[SEP]', '[MASK]')

    def __init__(self, sentences, token_list, word2id, id2word, vocab_size, max_len=128, max_mask=10):
        self.sentences = sentences
        self.token_list = token_list
        self.word2id = word2id
        self.id2word = id2word
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.max_mask = max_mask
        special_ids = {word2id[tok] for tok in self._SPECIAL_TOKENS if tok in word2id}
        self._replacement_ids = [i for i in range(vocab_size) if i not in special_ids]

    def __len__(self):
        """Return the number of sentences (the nominal epoch size)."""
        return len(self.sentences)

    @staticmethod
    def _truncate_pair(tokens_a, tokens_b, budget):
        """Drop trailing tokens of the longer list (in place) until both fit in ``budget``."""
        budget = max(budget, 0)
        while len(tokens_a) + len(tokens_b) > budget:
            longer = tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b
            longer.pop()

    def __getitem__(self, idx):
        """Draw one random sentence pair and return its training tensors.

        ``idx`` is ignored: each call samples a fresh pair, positive (the
        second sentence follows the first) with probability 0.5.

        Returns:
            dict with LongTensors ``input_ids`` and ``segment_ids`` (padded
            with 0 to ``max_len``), ``masked_tokens`` and ``masked_pos``
            (padded with 0 to ``max_mask``) and the scalar ``is_next``.

        Raises:
            ValueError: if fewer than two sentences are available.
        """
        n_sentences = len(self.sentences)
        if n_sentences < 2:
            raise ValueError("BERTDataset needs at least two sentences to form a pair")

        tokens_a_index = random.randint(0, n_sentences - 2)  # a next sentence always exists
        is_next = random.random() < 0.5
        if is_next:
            tokens_b_index = tokens_a_index + 1
        else:
            tokens_b_index = random.randint(0, n_sentences - 1)
            while tokens_b_index == tokens_a_index + 1:
                tokens_b_index = random.randint(0, n_sentences - 1)

        # Copy so that truncation never touches the shared token lists, then
        # make sure the pair plus [CLS] and two [SEP] fits into max_len;
        # otherwise items would have different lengths and could not be batched.
        tokens_a = list(self.token_list[tokens_a_index])
        tokens_b = list(self.token_list[tokens_b_index])
        self._truncate_pair(tokens_a, tokens_b, self.max_len - 3)

        cls_id = self.word2id['[CLS]']
        sep_id = self.word2id['[SEP]']
        mask_id = self.word2id['[MASK]']

        # 1. Token ids (insert CLS and SEP)
        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]

        # 2. Segment ids: 0 for "[CLS] a [SEP]", 1 for "b [SEP]"
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

        # 3. Masked language modelling (15% of the tokens, at least 1, at most max_mask)
        valid_pos = [i for i in range(1, len(input_ids) - 1) if input_ids[i] != sep_id]
        random.shuffle(valid_pos)
        n_pred = min(self.max_mask, max(1, int(len(input_ids) * 0.15)))
        masked_pos = valid_pos[:n_pred]
        masked_tokens = [input_ids[pos] for pos in masked_pos]

        for pos in masked_pos:
            prob = random.random()
            if prob < 0.8:  # 80%: replace with [MASK]
                input_ids[pos] = mask_id
            elif prob < 0.9 and self._replacement_ids:  # 10%: random non-special token
                input_ids[pos] = random.choice(self._replacement_ids)
            # remaining 10%: leave the token unchanged

        # Padding
        pad_len = self.max_len - len(input_ids)
        input_ids.extend([0] * pad_len)
        segment_ids.extend([0] * pad_len)

        # Pad by the number of positions actually selected, which can be
        # smaller than n_pred when the pair has no maskable token.
        mask_pad_len = self.max_mask - len(masked_pos)
        masked_tokens.extend([0] * mask_pad_len)
        masked_pos.extend([0] * mask_pad_len)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "segment_ids": torch.tensor(segment_ids, dtype=torch.long),
            "masked_tokens": torch.tensor(masked_tokens, dtype=torch.long),
            "masked_pos": torch.tensor(masked_pos, dtype=torch.long),
            "is_next": torch.tensor(int(is_next), dtype=torch.long)
        }


In [None]:
# Wrap the tokenised corpus in the MLM / NSP dataset, using the global length limits
dataset = BERTDataset(
    sentences,
    token_list,
    word2id,
    id2word,
    vocab_size,
    max_len=MAX_LEN,
    max_mask=MAX_MASK,
)

# Shuffled batches, loaded by 4 DataLoader workers into pinned host memory
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Sanity check: pull one batch and print the tensor shape of every field
sample_batch = next(iter(train_loader))
for field_name, tensor in sample_batch.items():
    print(f"{field_name}: {tensor.shape}")


## 3. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

In [None]:
# Embedding
# Sums the token, position and segment embeddings and layer-normalises the result.
class Embedding(nn.Module):
    """Input embedding layer: token + learned position + segment embeddings.

    Args:
        vocab_size: number of token ids.
        max_len: number of positions the position embedding supports.
        n_segments: number of segment (token type) ids.
        d_model: embedding size.
        device: stored on the instance for backward compatibility; forward()
            does not use it and follows the device of its input instead.
    """

    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)     # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        """Embed token ids ``x`` and segment ids ``seg``, both of shape (bs, len).

        Returns:
            Tensor of shape (bs, len, d_model).
        """
        seq_len = x.size(1)
        # Create the position ids on the same device as the input. A fixed
        # device stored at construction time is wrong for nn.DataParallel
        # replicas, whose inputs live on other GPUs.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

# attention mask
def get_attn_pad_mask(seq_q, seq_k, device):
    """Build a boolean mask that is True wherever the key token is padding (id 0).

    Args:
        seq_q: 2-D tensor of query token ids, (batch_size, len_q).
        seq_k: 2-D tensor of key token ids, (batch_size, len_k).
        device: device the mask is moved to.

    Returns:
        Bool tensor of shape (batch_size, len_q, len_k); every query row of a
        sample shares the same key-padding pattern.
    """
    _, n_queries = seq_q.size()
    n_batch, n_keys = seq_k.size()
    # (batch, len_k) -> (batch, 1, len_k): one row per sample, True marks [PAD]
    is_pad = (seq_k.data == 0)[:, None, :].to(device)
    # Broadcast that row over all query positions
    return is_pad.expand(n_batch, n_queries, n_keys)

# encoder
# One encoder block is made of two parts, applied in this order:
#   1. multi-head self-attention
#   2. position-wise feed-forward network
class EncoderLayer(nn.Module):
    """Single encoder block: self-attention followed by a feed-forward network."""

    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run the block on ``enc_inputs`` and return ``(outputs, attention weights)``."""
        # Self-attention: the same tensor is used as query, key and value
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        # Feed-forward output: [batch_size x len_q x d_model]
        return self.pos_ffn(attended), attn

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        d_k: per-head key/query dimension used for the scaling factor.
        device: accepted for backward compatibility and ignored. The scale is
            a plain Python float, so the layer works on whatever device its
            inputs live on (a tensor pinned to one device breaks on
            nn.DataParallel replicas).
    """

    def __init__(self, d_k, device=None):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, Q, K, V, attn_mask):
        """Attend over ``K``/``V`` with queries ``Q``.

        Args:
            Q, K, V: tensors whose last two dims are (len, d_k) / (len, d_v).
            attn_mask: bool tensor broadcastable to the score matrix; True
                marks key positions that must not receive attention.

        Returns:
            ``(context, attn)``: the attended values and the attention weights.
        """
        # scores: [batch_size x n_heads x len_q x len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
        # Large negative score -> (numerically) zero weight after the softmax
        scores.masked_fill_(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention with output projection, residual and LayerNorm.

    The output projection (``W_O``) and the LayerNorm are registered here, in
    the constructor. Creating them inside ``forward`` would give them fresh
    random weights on every call, keep them out of ``parameters()`` (so the
    optimizer never trains them) and out of ``state_dict()`` (so they are
    never saved).

    Args:
        n_heads: number of attention heads.
        d_model: model (embedding) size.
        d_k: per-head dimension of queries/keys; also used for values.
        device: forwarded to ScaledDotProductAttention and stored on the instance.
    """

    def __init__(self, n_heads, d_model, d_k, device):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.W_O = nn.Linear(n_heads * self.d_v, d_model)  # output projection
        self.norm = nn.LayerNorm(d_model)
        self.attention = ScaledDotProductAttention(d_k, device)
        self.device = device

    def forward(self, Q, K, V, attn_mask):
        """Apply multi-head attention.

        Args:
            Q: [batch_size x len_q x d_model]
            K: [batch_size x len_k x d_model]
            V: [batch_size x len_k x d_model]
            attn_mask: bool tensor [batch_size x len_q x len_k]; True = masked.

        Returns:
            ``(output, attn)`` with output [batch_size x len_q x d_model] and
            attn [batch_size x n_heads x len_q x len_k].
        """
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        # [batch_size x 1 x len_q x len_k]; the singleton head dimension is
        # broadcast by masked_fill_, so no per-head copy of the mask is needed
        attn_mask = attn_mask.unsqueeze(1)

        # context: [batch_size x n_heads x len_q x d_v]
        context, attn = self.attention(q_s, k_s, v_s, attn_mask)
        # merge the heads back: [batch_size x len_q x n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.W_O(context)
        return self.norm(output + residual), attn  # output: [batch_size x len_q x d_model]


class PoswiseFeedForwardNet(nn.Module):
    """Two-layer feed-forward network with a GELU in between, applied per position."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.fc2 = nn.Linear(d_ff, d_model)  # project back: d_ff -> d_model

    def forward(self, x):
        """Map (batch_size, len_seq, d_model) to a tensor of the same shape."""
        hidden = self.fc1(x)          # (batch_size, len_seq, d_ff)
        activated = F.gelu(hidden)
        return self.fc2(activated)    # (batch_size, len_seq, d_model)


class BERT(nn.Module):
    """BERT encoder with a masked-LM head and a next-sentence-prediction head.

    Args:
        n_layers: number of encoder layers.
        n_heads: attention heads per layer.
        d_model: embedding / hidden size.
        d_ff: feed-forward inner size.
        d_k: per-head dimension.
        n_segments: number of segment ids.
        vocab_size: number of token ids.
        max_len: maximum sequence length (size of the position embedding).
        device: forwarded to the sub-modules and stored on the instance.
    """

    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super().__init__()
        # Hyper-parameters kept so the model can be rebuilt from a checkpoint
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder shares its weight matrix with the token embedding
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def _encode(self, input_ids, segment_ids):
        """Run embedding + all encoder layers; returns [batch_size, len, d_model]."""
        output = self.embedding(input_ids, segment_ids)
        # Build the padding mask on the device of the input rather than on the
        # device stored at construction time, so that nn.DataParallel replicas
        # (whose inputs live on other GPUs) do not hit a device mismatch.
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, input_ids.device)
        for layer in self.layers:
            # per-layer attention weights ([batch_size, n_heads, len, len]) are not used here
            output, _ = layer(output, enc_self_attn_mask)
        return output

    def forward(self, input_ids, segment_ids, masked_pos):
        """Compute the logits of both pre-training tasks.

        Args:
            input_ids: [batch_size, len] token ids (0 = padding).
            segment_ids: [batch_size, len] segment ids.
            masked_pos: [batch_size, max_pred] positions to predict.

        Returns:
            ``(logits_lm, logits_nsp)`` with shapes
            [batch_size, max_pred, n_vocab] and [batch_size, 2].
        """
        output = self._encode(input_ids, segment_ids)

        # 1. predict next sentence
        # decided from the hidden state of the first token ([CLS])
        h_pooled = self.activ(self.fc(output[:, 0]))  # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled)        # [batch_size, 2]

        # 2. predict the masked tokens
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))  # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos)  # hidden states at the masked positions
        h_masked = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias  # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp

    def get_last_hidden_state(self, input_ids, segment_ids):
        """Return the output of the last encoder layer, [batch_size, len, d_model]."""
        return self._encode(input_ids, segment_ids)


In [None]:
# Model instantiation: build BERT from the hyper-parameters defined in the config cell
model = BERT(
    n_layers=n_layers,
    n_heads=n_heads,
    d_model=d_model,
    d_ff=d_ff,
    d_k=d_k,
    n_segments=n_segments,
    vocab_size=vocab_size,
    max_len=MAX_LEN,
    device=device,  # DataParallel requires main model to be in cuda:0
)

# Split batches across all visible GPUs when more than one is available
# device_ids = [0, 1, 2]  # Maps to CUDA_VISIBLE_DEVICES=[1,2,3], so these are actually GPUs 1,2,3
if num_gpus > 1:
    model = nn.DataParallel(model)  # , device_ids=device_ids)

# Move the model to the main device (DataParallel requires main model to be in cuda:0)
model.to(device)

In [None]:
# Cross-entropy loss used by the training loop
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter, learning rate 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# # Take data
# batch = make_batch()
# # Move into pytorch
# input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

# # Move inputs to GPU
# input_ids = input_ids.to(device)
# segment_ids = segment_ids.to(device)
# masked_tokens = masked_tokens.to(device)
# masked_pos = masked_pos.to(device)
# isNext = isNext.to(device)

In [None]:
# Average training loss of every epoch, used for the loss curve
loss_history = []
# State dict, loss and epoch index of the best epoch seen so far
best_model = None
best_loss = float('inf')
epoch_number = 0

# TODO: Move to train function
start_time = time.time()
model.train()
# Number of optimisation steps per epoch; guards against an empty loader
num_batches = max(len(train_loader), 1)
for epoch in tqdm(range(num_epoch), desc="Training Epochs"):
    start_epoch = time.time()
    epoch_loss = 0.0

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        segment_ids = batch["segment_ids"].to(device)
        masked_tokens = batch["masked_tokens"].to(device)
        masked_pos = batch["masked_pos"].to(device)
        isNext = batch["is_next"].to(device)

        optimizer.zero_grad()

        # Forward pass
        # logits_lm:  (bs, max_mask, vocab_size)
        # logits_nsp: (bs, 2)
        logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)

        # 1. mlm loss
        # logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
        # Unused prediction slots are padded with target id 0 ([PAD]); they
        # are excluded so the model is not trained to predict padding.
        loss_lm = F.cross_entropy(logits_lm.transpose(1, 2), masked_tokens, ignore_index=0)

        # 2. nsp loss
        # logits_nsp: (bs, 2) vs. isNext: (bs, )
        loss_nsp = criterion(logits_nsp, isNext)

        # 3. combine loss
        loss = loss_lm + loss_nsp

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean loss per batch (the sum runs over batches, not over BATCH_SIZE samples)
    avg_loss = epoch_loss / num_batches
    loss_history.append(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        epoch_number = epoch
        best_model = copy.deepcopy(model.state_dict())

    if epoch % 100 == 0:
        print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(avg_loss))
        elapsed_epoch = time.time() - start_epoch
        print("Epoch time taken: ", elapsed_epoch)

time_elapsed = time.time() - start_time
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

# Save the weights of the best epoch, which is what the file name advertises;
# fall back to the current weights if no epoch was run.
model_name = f'bert_model_epoch_{epoch_number}_loss_{best_loss}.pth'
torch.save(best_model if best_model is not None else model.state_dict(), model_name)
print(f"Model saved to {model_name}")

In [None]:
def plot_data(loss_acc_history, fname="train_loss_updated", label="Training"):
    """Plot a per-epoch loss curve, save it as ``<fname>.png`` and display it.

    Args:
        loss_acc_history: sequence with one loss value per epoch.
        fname: output file name without the ``.png`` extension.
        label: legend label of the curve.
    """
    plt.figure()
    plt.plot(loss_acc_history, label=label)
    plt.title('Loss per epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # Save before show(): once the figure has been shown it may already be
    # closed, in which case savefig would write an empty image.
    plt.savefig(f'{fname}.png')
    plt.show()

# A timestamp without spaces or colons keeps the file name valid on every OS
plot_data(loss_history, f"train_loss_updated_{time.strftime('%Y%m%d_%H%M%S')}")

## 5. Inference

Since our dataset is very small, the predictions will not be very good; this section is only meant as a demonstration.

In [None]:
# Predict the masked tokens and isNext for a single example
model.eval()  # inference mode for any train-only layers

# Take the third example of a fresh batch and turn every field into a
# LongTensor with a leading batch dimension of 1
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor([field]).to(device) for field in batch[2]
)
print([id2word[w.item()] for w in input_ids[0] if id2word[w.item()] != '[PAD]'])

# No gradients are needed for prediction
with torch.no_grad():
    # logits_lm:  (1, max_mask, vocab_size)
    # logits_nsp: (1, 2)
    logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)

# predict masked tokens
# argmax over the vocab dim (2); [0] selects the only example in the batch
predicted_ids = logits_lm.argmax(dim=2)[0].cpu().tolist()
# note that zero is the padding we add to masked_tokens
print('masked tokens (words) : ', [id2word[tok.item()] for tok in masked_tokens[0]])
print('masked tokens list : ', [tok.item() for tok in masked_tokens[0]])
print('predicted masked tokens (words) : ', [id2word[tok_id] for tok_id in predicted_ids])
print('predict masked tokens list : ', predicted_ids)

# predict nsp
predicted_is_next = logits_nsp.argmax(dim=1)[0].item()
print(predicted_is_next)
print('isNext : ', bool(isNext.item()))
print('predict isNext : ', bool(predicted_is_next))
