In [1]:
from transformers import BertTokenizer, BertForMaskedLM
import torch

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
with open('./data/text/intro.txt','r', encoding='utf-8') as fp:
    text = fp.read().split('\n')

In [4]:
text[:2]

['Text classification is one of the most common tasks in NLP. It is applied in a wide variety of applications, including sentiment analysis, spam filtering, news categorization, etc. Here, we show you how you can detect fake news (classifying an article as REAL or FAKE) using the state-of-the-art models, a tutorial that can be extended to really any text classification task.',
 'The Transformer is the basic building block of most current state-of-the-art architectures of NLP. Its primary advantage is its multi-head attention mechanisms which allow for an increase in performance and significantly more parallelization than previous competing models such as recurrent neural networks. In this tutorial, we will use pre-trained BERT, one of the most popular transformer models, and fine-tune it on fake news detection. I have also used an LSTM for the same task in a later tutorial, please check it out if interested!']

In [5]:
inputs = tokenizer(text,
                   return_tensors='pt',
                   max_length=512,
                   truncation=True, 
                   padding='max_length')
inputs

{'input_ids': tensor([[  101,  3793,  5579,  ...,     0,     0,     0],
        [  101,  1996, 10938,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [6]:
inputs['labels'] = inputs.input_ids.detach().clone()
inputs

{'input_ids': tensor([[  101,  3793,  5579,  ...,     0,     0,     0],
        [  101,  1996, 10938,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[  101,  3793,  5579,  ...,     0,     0,     0],
        [  101,  1996, 10938,  ...,     0,     0,     0]])}

In [7]:
rand = torch.rand(inputs.input_ids.shape)
rand

tensor([[0.5780, 0.9885, 0.3539,  ..., 0.0847, 0.5364, 0.7744],
        [0.8647, 0.8484, 0.9371,  ..., 0.8794, 0.8222, 0.0341]])

In [8]:
mask_arr = (rand < 0.15) * (inputs.input_ids != 101) * (inputs.input_ids != 102)
mask_arr

tensor([[False, False, False,  ...,  True, False, False],
        [False, False, False,  ..., False, False,  True]])

In [9]:
selection = []

for i in range(mask_arr.shape[0]):
    selection.append(
        torch.flatten(mask_arr[i].nonzero()).tolist()
        )
    
selection

[[8,
  12,
  14,
  15,
  32,
  37,
  41,
  47,
  50,
  53,
  60,
  65,
  79,
  81,
  89,
  92,
  108,
  111,
  120,
  135,
  136,
  152,
  166,
  179,
  185,
  193,
  197,
  211,
  214,
  216,
  225,
  230,
  245,
  249,
  252,
  258,
  261,
  269,
  270,
  280,
  285,
  305,
  308,
  309,
  312,
  317,
  321,
  325,
  326,
  343,
  348,
  349,
  368,
  370,
  371,
  376,
  379,
  393,
  398,
  410,
  428,
  430,
  437,
  453,
  463,
  471,
  492,
  497,
  498,
  503,
  506,
  509],
 [7,
  8,
  10,
  16,
  18,
  21,
  34,
  35,
  47,
  51,
  53,
  73,
  75,
  82,
  87,
  100,
  101,
  105,
  113,
  118,
  123,
  124,
  133,
  139,
  142,
  144,
  145,
  151,
  184,
  196,
  205,
  216,
  229,
  241,
  245,
  248,
  256,
  261,
  265,
  268,
  277,
  282,
  292,
  295,
  297,
  305,
  306,
  325,
  328,
  336,
  338,
  342,
  343,
  345,
  352,
  358,
  362,
  370,
  373,
  380,
  381,
  382,
  401,
  404,
  405,
  413,
  419,
  427,
  429,
  432,
  436,
  439,
  443,
  455,
  471,
  47

In [10]:
for i in range(mask_arr.shape[0]):
    inputs.input_ids[i, selection[i]] = 103
inputs.input_ids

tensor([[  101,  3793,  5579,  ...,   103,     0,     0],
        [  101,  1996, 10938,  ...,     0,     0,   103]])

In [11]:
class MeditationDataset(torch.utils.data.Dataset):
    def __init__(self, encodings) -> None:
        super().__init__()
        self.encodings = encodings
        
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    
    def __len__(self):
        return len(self.encodings.input_ids)

In [12]:
dataset = MeditationDataset(inputs)

In [13]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [15]:
model.to(device)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [16]:
model.train()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [17]:
from transformers import AdamW

In [18]:
optim = AdamW(model.parameters(), lr=1e-5)
optim

AdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    correct_bias: True
    eps: 1e-06
    lr: 1e-05
    weight_decay: 0.0
)

In [19]:
from tqdm import tqdm

epochs = 2

for epoch in range(epochs):
    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, 
                        attention_mask = attention_mask, 
                        labels = labels) 
        loss = outputs.loss
        loss.backward()
        optim.step()
        
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s, loss=14.4]
Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  2.45it/s, loss=13.7]
