<a href="https://colab.research.google.com/github/brownsloth/transformers_concepts_notebooks/blob/main/transformers_5_usages_of_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import transformers
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch

In [None]:
ckpt = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(ckpt)
text = '2025 will be a great year for all of us!'
# Build model-ready inputs for a single sentence:
#   add_special_tokens    -> wrap the text in [CLS] ... [SEP]
#   truncation / padding  -> with no explicit max_length, clip to and pad up to
#                            the tokenizer's maximum model input length
#   return_attention_mask -> also return the mask separating real tokens from padding
#   return_tensors='pt'   -> return PyTorch tensors instead of Python lists
# The keyword names must be spelled exactly; misspelled ones are not recognised
# and therefore have no effect.
encoding = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    truncation=True,
    padding='max_length',
    return_attention_mask=True,
    return_tensors='pt',
)
print(encoding)

In [None]:
# Besides the token ids, the encoding carries extra fields such as the
# attention mask; show both.
for field in ('input_ids', 'attention_mask'):
    print(encoding[field])

## 1. Application 1: Predicting masked words!

In [None]:
# A sentence with a single [MASK] slot for BERT to fill in.
text = f'The Opera House in Australia is in {tokenizer.mask_token} city'
input = tokenizer.encode_plus(text, return_tensors='pt')
# Position of the [MASK] token inside the first (and only) sequence of the batch.
mask_index = input.input_ids[0].tolist().index(tokenizer.mask_token_id)
# Masked-language-modelling head on top of the same checkpoint; with
# return_dict=True the forward pass returns a ModelOutput object whose fields
# (e.g. .logits) are accessible as attributes.
model = BertForMaskedLM.from_pretrained(ckpt, return_dict=True)

In [None]:
# Forward pass in inference mode: we never backpropagate here, so no_grad
# avoids building an autograd graph (the question-answering cell does the same).
with torch.no_grad():
    output = model(**input)
logits = output.logits
print(logits.shape)  # (batch_size, seq_len, vocab_size)

In [None]:
# Normalise the scores into a probability distribution over the vocabulary at
# every position, then keep only the distribution at the [MASK] position.
softmax = logits.softmax(dim=-1)
mask_words = softmax[0][mask_index]
print(mask_words.shape)
print(mask_words)

In [None]:
# torch.topk returns (values, indices). The indices are positions in the
# vocabulary, which are exactly the token ids (argmax would give only the single
# best replacement). Substitute each candidate into the sentence and print it.
top_candidates = torch.topk(mask_words, k=10)
for token_id in top_candidates.indices:
    token = tokenizer.decode(token_id)
    new_sentence = text.replace(tokenizer.mask_token, token)
    print(new_sentence)

## 2. Predicting next sentence!

In [None]:
from transformers import BertTokenizer
from transformers import BertForNextSentencePrediction

# Same uncased BERT checkpoint, this time loaded with the
# next-sentence-prediction head.
ckpt = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(ckpt)
model = BertForNextSentencePrediction.from_pretrained(ckpt)

In [None]:
# Sentence A (the prompt) and a candidate sentence B to be scored as its continuation.
prompt, possible_next_sentence = (
    "I came back from office in the evening",
    "I opened my beer after office",
)

In [None]:
# Encode the two sentences as one pair so the model sees both segments together.
input = tokenizer.encode_plus(prompt, possible_next_sentence, return_tensors='pt')
# Inference only: no_grad skips autograd bookkeeping, matching the
# question-answering cell.
with torch.no_grad():
    outputs = model(**input)
print(outputs)

In [None]:
# Turn the two NSP logits into probabilities. Per the Hugging Face
# BertForNextSentencePrediction documentation, index 0 means "B continues A"
# and index 1 means "B is a random sentence".
softmax = F.softmax(outputs.logits, dim=-1)
print(softmax)

## 3. Question Answering

In [None]:
from transformers import BertTokenizer
from transformers import BertForQuestionAnswering

# Cased BERT checkpoint from deepset with an extractive question-answering head.
# NOTE(review): the name suggests fine-tuning on SQuAD 2.0 — confirm on the model card.
ckpt = 'deepset/bert-base-cased-squad2'
tokenizer = BertTokenizer.from_pretrained(ckpt)
model = BertForQuestionAnswering.from_pretrained(ckpt)

In [None]:
question_text = 'When did GPT-3 come'
context_text = 'GPT-3 came in 2020'
# Tokenizing a (question, context) pair produces a single sequence laid out
# as [CLS] question [SEP] context [SEP].
inputs = tokenizer(question_text, context_text, return_tensors='pt')
print(inputs)
print(tokenizer.special_tokens_map)
# Locate the special tokens. In the author's run, [CLS] was at index 0 and the
# two [SEP]s were at index 8 (end of question) and 16 (end of context); the exact
# positions depend on how the texts are tokenized.
print(torch.where(inputs.input_ids == tokenizer.cls_token_id))
print(torch.where(inputs.input_ids == tokenizer.sep_token_id))
# token_type_ids are 0 for the first segment ([CLS] question [SEP]), so this
# count is that segment's length (indices 0..8 in the author's run).
print(torch.nonzero(inputs.token_type_ids == 0, as_tuple=True)[0].shape)

In [None]:
# Shape of the encoded (question, context) pair: (batch_size, seq_len).
print(inputs['input_ids'].shape)

In [None]:
# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)

# The model scores every token position as a possible answer start and end.
# The highest-scoring positions delimit the predicted span. A flat argmax is
# fine here because the batch holds a single example.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
print(answer_start_index)
print(answer_end_index)

# Start and end are chosen independently, so end may precede start. Slicing
# would then silently decode to an empty string, so report that case explicitly.
# NOTE(review): for SQuAD2-style models a span at index 0 ([CLS]) conventionally
# means "no answer" — confirm for this checkpoint.
if answer_end_index < answer_start_index:
    print('No valid answer span found (end index precedes start index).')
else:
    # Decode the token ids in [start, end] back into the answer text from the context.
    print(tokenizer.decode(inputs.input_ids[0, answer_start_index:answer_end_index + 1]))