# BERT (Updated 1 Feb 2025, Available CUDA)



In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

In [2]:
# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cuda


In [3]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Visible device count:", torch.cuda.device_count())
print("Device 0 name:", torch.cuda.get_device_name(0))


CUDA available: True
Visible device count: 1
Device 0 name: NVIDIA GeForce RTX 2080 Ti


## 1. Data

For simplicity, we shall use very simple data like this.

In [4]:
from datasets import load_dataset

dataset = load_dataset(
    "wikipedia",
    "20220301.en",
    split="train[:100000]"
)

dataset


Dataset({
    features: ['id', 'url', 'title', 'text'],
    num_rows: 100000
})

In [5]:
#removing unwanted columns
dataset = dataset.remove_columns(['id', 'url', 'title'])
dataset

Dataset({
    features: ['text'],
    num_rows: 100000
})

In [6]:
sentences = dataset['text']
sentences = [x.replace("\n", " ") for x in sentences]
sentences = [x for x in sentences if len(x.split()) <= 500]
sentences[:2]

["The Academy Award for Best Production Design recognizes achievement for art direction in film. The category's original name was Best Art Direction, but was changed to its current name in 2012 for the 85th Academy Awards. This change resulted from the Art Director's branch of the Academy of Motion Picture Arts and Sciences (AMPAS) being renamed the Designer's branch. Since 1947, the award is shared with the set decorator(s). It is awarded to the best interior design in a film.  The films below are listed with their production year (for example, the 2000 Academy Award for Best Art Direction is given to a film from 1999). In the lists below, the winner of the award for each year is shown first, followed by the other nominees in alphabetical order.  Superlatives  Winners and nominees  1920s  1930s  1940s  1950s  1960s  1970s  1980s  1990s  2000s  2010s  2020s  See also  BAFTA Award for Best Production Design  Critics' Choice Movie Award for Best Production Design  Notes  References  Best

## Making Vocabs

In [7]:
text = [x.lower() for x in sentences] #lower case
text = [re.sub("[.,!?\\-]", '', x) for x in text] #clean all symbols
text[:1]

["the academy award for best production design recognizes achievement for art direction in film the category's original name was best art direction but was changed to its current name in 2012 for the 85th academy awards this change resulted from the art director's branch of the academy of motion picture arts and sciences (ampas) being renamed the designer's branch since 1947 the award is shared with the set decorator(s) it is awarded to the best interior design in a film  the films below are listed with their production year (for example the 2000 academy award for best art direction is given to a film from 1999) in the lists below the winner of the award for each year is shown first followed by the other nominees in alphabetical order  superlatives  winners and nominees  1920s  1930s  1940s  1950s  1960s  1970s  1980s  1990s  2000s  2010s  2020s  see also  bafta award for best production design  critics' choice movie award for best production design  notes  references  best production 

### Making vocabs

Before making the vocabs, let's remove all question marks and perios, etc, then turn everything to lowercase, and then simply split the text. 

In [8]:
from tqdm.auto import tqdm

# Combine everything into one to make vocab
word_list = list(set(" ".join(text).split()))
word2id = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}  # special tokens

# Create the word2id in a single pass
for i, w in tqdm(enumerate(word_list), desc="Creating word2id"):
    word2id[w] = i + 4  # because 0-3 are already occupied

# Precompute the id2word mapping (this can be done once after word2id is fully populated)
id2word = {v: k for k, v in word2id.items()}
vocab_size = len(word2id)
vocab_size

Creating word2id: 0it [00:00, ?it/s]

323367

In [9]:
# List of all tokens for the whole text
token_list = []

# Process sentences more efficiently
for sentence in tqdm(text, desc="Processing sentences"):
    token_list.append([word2id[word] for word in sentence.split()])

Processing sentences:   0%|          | 0/27773 [00:00<?, ?it/s]

In [10]:
vocab_size

323367

In [11]:
#take a look at sentences
sentences[:2]

["The Academy Award for Best Production Design recognizes achievement for art direction in film. The category's original name was Best Art Direction, but was changed to its current name in 2012 for the 85th Academy Awards. This change resulted from the Art Director's branch of the Academy of Motion Picture Arts and Sciences (AMPAS) being renamed the Designer's branch. Since 1947, the award is shared with the set decorator(s). It is awarded to the best interior design in a film.  The films below are listed with their production year (for example, the 2000 Academy Award for Best Art Direction is given to a film from 1999). In the lists below, the winner of the award for each year is shown first, followed by the other nominees in alphabetical order.  Superlatives  Winners and nominees  1920s  1930s  1940s  1950s  1960s  1970s  1980s  1990s  2000s  2010s  2020s  See also  BAFTA Award for Best Production Design  Critics' Choice Movie Award for Best Production Design  Notes  References  Best

In [12]:
#take a look at token_list
token_list[:2]

[[153988,
  110810,
  59780,
  123729,
  31982,
  103915,
  150829,
  15691,
  5729,
  123729,
  67917,
  241634,
  243487,
  90392,
  153988,
  309380,
  236575,
  237311,
  146633,
  31982,
  67917,
  241634,
  237885,
  146633,
  167084,
  126472,
  33581,
  192356,
  237311,
  243487,
  176880,
  123729,
  153988,
  266440,
  110810,
  211190,
  175199,
  208446,
  188702,
  283853,
  153988,
  67917,
  190942,
  237419,
  254205,
  153988,
  110810,
  254205,
  25033,
  277563,
  179791,
  150880,
  185606,
  306523,
  215477,
  123178,
  153988,
  31704,
  237419,
  269890,
  105166,
  153988,
  59780,
  119892,
  279792,
  78155,
  153988,
  79880,
  151937,
  269442,
  119892,
  16994,
  126472,
  153988,
  31982,
  287657,
  150829,
  243487,
  75545,
  90392,
  153988,
  174318,
  135473,
  85012,
  305449,
  78155,
  156953,
  103915,
  195114,
  76584,
  282906,
  153988,
  196290,
  110810,
  59780,
  123729,
  31982,
  67917,
  241634,
  119892,
  183324,
  126472,
  7554

In [13]:
#testing one sentence
for tokens in token_list[0]:
    print(id2word[tokens])

the
academy
award
for
best
production
design
recognizes
achievement
for
art
direction
in
film
the
category's
original
name
was
best
art
direction
but
was
changed
to
its
current
name
in
2012
for
the
85th
academy
awards
this
change
resulted
from
the
art
director's
branch
of
the
academy
of
motion
picture
arts
and
sciences
(ampas)
being
renamed
the
designer's
branch
since
1947
the
award
is
shared
with
the
set
decorator(s)
it
is
awarded
to
the
best
interior
design
in
a
film
the
films
below
are
listed
with
their
production
year
(for
example
the
2000
academy
award
for
best
art
direction
is
given
to
a
film
from
1999)
in
the
lists
below
the
winner
of
the
award
for
each
year
is
shown
first
followed
by
the
other
nominees
in
alphabetical
order
superlatives
winners
and
nominees
1920s
1930s
1940s
1950s
1960s
1970s
1980s
1990s
2000s
2010s
2020s
see
also
bafta
award
for
best
production
design
critics'
choice
movie
award
for
best
production
design
notes
references
best
production
design
awards
for
best

## 2. Data loader

We gonna make dataloader.  Inside here, we need to make two types of embeddings: **token embedding** and **segment embedding**

1. **Token embedding** - Given “The cat is walking. The dog is barking”, we add [CLS] and [SEP] >> “[CLS] the cat is walking [SEP] the dog is barking”. 

2. **Segment embedding**
A segment embedding separates two sentences, i.e., [0 0 0 0 1 1 1 1 ]

3. **Masking**
As mentioned in the original paper, BERT randomly assigns masks to 15% of the sequence. In this 15%, 80% is replaced with masks, while 10% is replaced with random tokens, and the rest 10% is left as is.  Here we specified `max_pred` 

4. **Padding**
Once we mask, we will add padding. For simplicity, here we padded until some specified `max_len`. 

Note:  `positive` and `negative` are just simply counts to keep track of the batch size.  `positive` refers to two sentences that are really next to one another.

In [14]:
batch_size = 3
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 1000 # maximum of length to be padded; 

In [15]:
def make_batch():
    batch = []
    half_batch_size = batch_size // 2
    positive = negative = 0
    while positive != half_batch_size or negative != half_batch_size:

        #randomly choose two sentence
        tokens_a_index, tokens_b_index = np.random.randint(len(sentences), size=2)
        tokens_a, tokens_b            = token_list[tokens_a_index], token_list[tokens_b_index]

        #1. token embedding - add CLS and SEP
        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]

        #2. segment embedding - which sentence is 0 and 1
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        #get all the pos excluding CLS and SEP
        candidates_masked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]']
                                 and token != word2id['[SEP]']]
        np.random.shuffle(candidates_masked_pos)
        masked_tokens, masked_pos = [], []
        #simply loop and mask accordingly
        for pos in candidates_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            rand_val = np.random.random()
            if rand_val < 0.1:  #10% replace with random token
                index = np.random.randint(4, vocab_size - 1)  # random token should not involve [PAD], [CLS], [SEP], [MASK]
                input_ids[pos] = word2id[id2word[index]]
            elif rand_val < 0.8:  #80 replace with [MASK]
                input_ids[pos] = word2id['[MASK]']
            else:
                pass
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            segment_ids = segment_ids[:max_len]

        #4. pad the sentence to the max length
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        #5. pad the mask tokens to the max length
        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        #6. check whether is positive or negative
        if tokens_a_index + 1 == tokens_b_index and positive < half_batch_size:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < half_batch_size:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1

    return batch

In [16]:
batch = make_batch()

In [17]:
#len of batch
len(batch)

2

In [18]:
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

In [19]:
input_ids.shape, segment_ids.shape, masked_tokens.shape, masked_pos.shape, isNext.shape

(torch.Size([2, 1000]),
 torch.Size([2, 1000]),
 torch.Size([2, 5]),
 torch.Size([2, 5]),
 torch.Size([2]))

## 3. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

## 3.1 Embedding

Here we simply generate the positional embedding, and sum the token embedding, positional embedding, and segment embedding together.

<img src = "figures/BERT_embed.png" width=500>

In [20]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

## 3.2 Attention mask

In [21]:
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

### Testing the attention mask

In [22]:
print(get_attn_pad_mask(input_ids, input_ids, device).shape)

torch.Size([2, 1000, 1000])


## 3.3 Encoder

The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network

First let's make the wrapper called `EncoderLayer`

In [23]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

Let's define the scaled dot attention, to be used inside the multihead attention

In [24]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

Let's define the parameters first

In [25]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

Here is the Multiheadattention.

In [26]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]

Here is the PoswiseFeedForwardNet.

In [27]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

## 3.4 Putting them together

In [28]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output

## 4. Training

In [29]:
from tqdm.auto import tqdm

n_layers = 12    # number of Encoder of Encoder Layer
n_heads  = 12    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = d_model * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

num_epoch = 1000
model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)  # Move model to GPU

In [30]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [31]:
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

# Move inputs to GPU
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)

# Wrap the epoch loop with tqdm
for epoch in tqdm(range(num_epoch), desc="Training Epochs"):
    optimizer.zero_grad()
    logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)    
    #logits_lm: (bs, max_mask, vocab_size) ==> (6, 5, 34)
    #logits_nsp: (bs, yes/no) ==> (6, 2)

    #1. mlm loss
    #logits_lm.transpose: (bs, vocab_size, max_mask) vs. masked_tokens: (bs, max_mask)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
    #2. nsp loss
    #logits_nsp: (bs, 2) vs. isNext: (bs, )
    loss_nsp = criterion(logits_nsp, isNext) # for sentence classification
    
    #3. combine loss
    loss = loss_lm + loss_nsp
    if epoch % 100 == 0:
        print('Epoch:', '%02d' % (epoch), 'loss =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 00 loss = 148.190948
Epoch: 100 loss = 3.740366
Epoch: 200 loss = 3.110257
Epoch: 300 loss = 3.102159
Epoch: 400 loss = 4.858214
Epoch: 500 loss = 5.021757
Epoch: 600 loss = 4.326852
Epoch: 700 loss = 3.427927
Epoch: 800 loss = 2.988678
Epoch: 900 loss = 3.109298


In [32]:
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

print("vocab_size:", vocab_size)
print("masked_tokens min/max:", masked_tokens.min().item(), masked_tokens.max().item())
print("input_ids max:", input_ids.max().item())

assert masked_tokens.min().item() >= 0
assert masked_tokens.max().item() < vocab_size


vocab_size: 323367
masked_tokens min/max: 18289 307395
input_ids max: 323065


In [33]:
# Save the model after training
torch.save([model.params, model.state_dict()], 'model/model_bert.pth')
print("Model saved to model_bert.pth")

Model saved to model_bert.pth


## 5. Inference

Since our dataset is very small, it won't work very well, but just for the sake of demonstration.

In [34]:
# load the model and all its hyperparameters
params, state = torch.load('model/model_bert.pth')
model_bert = BERT(**params, device=device).to(device)
model_bert.load_state_dict(state)

<All keys matched successfully>

In [35]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[1]))
print([id2word[w.item()] for w in input_ids[0] if id2word[w.item()] != '[PAD]'])
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)
masked_tokens = masked_tokens.to(device)
masked_pos = masked_pos.to(device)
isNext = isNext.to(device)

logits_lm, logits_nsp = model(input_ids, segment_ids, masked_pos)
#logits_lm:  (1, max_mask, vocab_size) ==> (1, 5, 34)
#logits_nsp: (1, yes/no) ==> (1, 2)

#predict masked tokens
#max the probability along the vocab dim (2), [1] is the indices of the max, and [0] is the first value
logits_lm = logits_lm.data.cpu().max(2)[1][0].data.numpy() 
#note that zero is padding we add to the masked_tokens
print('masked tokens (words) : ',[id2word[pos.item()] for pos in masked_tokens[0]])
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0]])
print('masked tokens (words) : ',[id2word[pos.item()] for pos in logits_lm])
print('predict masked tokens list : ', [pos for pos in logits_lm])

#predict nsp
logits_nsp = logits_nsp.cpu().data.max(1)[1][0].data.numpy()
print(logits_nsp)
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_nsp else False)

['[CLS]', '"mr', 'dingle', 'the', 'strong"', 'is', 'episode', '55', 'of', 'the', 'american', 'television', 'anthology', 'series', 'the', 'twilight', 'zone', 'it', 'originally', 'aired', 'on', 'march', '3', '1961', 'on', 'cbs', 'opening', 'narration', 'the', 'narration', 'continues', 'when', 'the', 'martians', 'arrive', 'plot', 'in', 'an', 'experiment', 'a', 'twoheaded', 'martian', 'scientist', 'who', 'is', 'invisible', 'to', 'earthlings', 'gives', 'vacuumcleaner', 'salesman', 'and', 'perennial', 'loser', 'luther', 'dingle', 'superhuman', 'strength', 'after', 'discovering', 'his', 'inexplicable', 'powers', 'dingle', 'begins', 'performing', 'various', 'feats', 'of', 'strength', 'from', 'lifting', 'statues', 'to', 'splitting', 'boulders', 'and', 'gains', 'a', 'great', 'deal', 'of', 'publicity', 'the', 'twoheaded', 'martian', 'returns', 'and', 'is', 'disappointed', 'to', 'see', 'that', 'dingle', 'is', 'using', 'his', 'strength', 'only', 'for', 'show', 'the', 'martian', 'takes', 'his', 'str

Trying a bigger dataset should be able to see the difference.