# Embedding Prep 1b: BERT Vectors

### Why BERT?

Aside from the TF-IDF and fastText embedding methods in the previous notebook, one more method we will try is [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al., 2019)](https://arxiv.org/pdf/1810.04805.pdf). Like fastText, BERT can encode semantic similarity better than TF-IDF, but it goes one step further: it also encodes word positions, so the representation of each word depends on the words around it. In short, BERT produces contextual representations; for example, it can learn that the word 'article' in the GDPR has a different meaning from 'article' in defamation content about newspaper articles.

However, 

For this version, the [bert base uncased model by Hugging Face](https://huggingface.co/bert-base-uncased) was used. Further work trying out other transformer models on Hugging Face is definitely desirable.

To adapt the model to the language nuances of our legislation data, the pretrained BERT model can be further pretrained on two tasks: Masked Language Modeling (MLM) and Next Sentence Prediction (NSP). For MLM, 15% of the tokens are masked (hidden) and the model learns by trying to predict these masked tokens. For NSP, the model learns from pairs of sentences drawn from the training corpus, where in some pairs the second sentence truly follows the first and in others it does not.

The code below will cover both methods, but for the first version of the project, the model pretrained on MLM only was chosen as it performed better than the one that also went through NSP.
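The BERT paper's MLM procedure is slightly richer than the masking used later in this notebook: of the 15% of positions selected, 80% are replaced with [MASK], 10% with a random token and 10% are left unchanged, whereas the cells below always use [MASK]. A minimal pure-Python sketch of the paper's rule, assuming the bert-base-uncased token IDs used later (101, 102, 0, 103) and its 30522-token vocabulary; `mlm_mask` is a hypothetical helper, not a transformers API.

```python
import random

# Special-token IDs for bert-base-uncased, as used elsewhere in this notebook
CLS_ID, SEP_ID, PAD_ID, MASK_ID = 101, 102, 0, 103
VOCAB_SIZE = 30522  # size of the bert-base-uncased vocabulary


def mlm_mask(token_ids, mask_prob=0.15, seed=None):
    """Return (masked_ids, labels) following the BERT paper's 80/10/10 rule.

    labels holds the original ID at selected positions and -100 elsewhere,
    so a loss with ignore_index=-100 only scores the selected tokens.
    """
    rng = random.Random(seed)
    masked_ids = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        # never select special tokens or padding
        if tok in (CLS_ID, SEP_ID, PAD_ID):
            continue
        if rng.random() < mask_prob:
            labels[i] = tok
            roll = rng.random()
            if roll < 0.8:
                masked_ids[i] = MASK_ID  # 80%: replace with [MASK]
            elif roll < 0.9:
                masked_ids[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
            # remaining 10%: keep the original token unchanged
    return masked_ids, labels


masked, labels = mlm_mask([101, 2023, 2003, 1037, 3231, 102, 0, 0], seed=42)
print(masked, labels)
```

Keeping some selected tokens unchanged or randomised means the model cannot assume that every non-[MASK] token is correct, which is the paper's motivation for the mix.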

## Expected Output from this Notebook

In this notebook, we embed our clean data to get vector representations of the legislation sections so that we can match them later on.

The output of this notebook is the BERT embeddings of the legislation, saved as a `.npy` file.

The outputs shown below are based on data for the SG Copyright Act and the UK CDPA but, as mentioned, the data files containing legislation content are not included in the repo.

### Specifying Save Data Paths

In [23]:
# where the output vectors will be saved
bert_vector_file = '../data/vectors/copyright/test_saving_bert_vectors.npy'

# where the bert model and tokenizer will be saved after pretraining
save_model_path = '../models/test_saving_bert_model'
save_tokenizer_path = '../models/test_saving_bert_tokenizer'

## Imports and Loading Data

In [1]:
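import random  # used for sampling NSP sentence pairs further below

import pandas as pd
import numpy as np
import torch
from torch.optim import AdamW  # transformers' own AdamW is deprecated in favour of PyTorch's
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM, BertForPreTraining
from sklearn.metrics.pairwise import cosine_similarity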

In [2]:
input_data_filepath = '../data/clean/copyright/sg_uk_copyright.csv' 
# this data file will not be pushed to git repo 

In [3]:
data = pd.read_csv(input_data_filepath)

Prepare content as a list as required for pretraining.

In [4]:
try:
    content = data['cleaned'].to_list()
except KeyError as err:
    raise KeyError('Ensure that the content column to be vectorized is named "cleaned".') from err

## MLM Pretraining

reference: [this tutorial](https://github.com/jamescalam/transformers/blob/main/course/training/03_mlm_training.ipynb) 

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


As we are using PyTorch, the tokenizer must return PyTorch tensors (`return_tensors='pt'`).

A large part of our data falls in the 400-500 word range, so we use BERT's maximum sequence length of 512 tokens, padding sections shorter than 512 tokens and truncating those that exceed this length.

In [6]:
inputs = tokenizer(content, return_tensors='pt', max_length=512, truncation=True, padding='max_length')
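For reference, `padding='max_length'` and `truncation=True` turn every section into exactly 512 positions: [CLS], the WordPiece tokens, [SEP], then [PAD] tokens, with `attention_mask` set to 1 for real tokens and 0 for padding. The sketch below imitates this for a single sequence in plain Python; `pad_or_truncate` is a hypothetical helper rather than a transformers function, and the token IDs are illustrative. Keep in mind that WordPiece can split one word into several tokens, so a 400-500 word section may still exceed 512 tokens and be truncated.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs


def pad_or_truncate(word_piece_ids, max_length=512):
    """Wrap IDs in [CLS] ... [SEP], truncate to max_length and pad with [PAD]."""
    # reserve two positions for the [CLS] and [SEP] special tokens
    body = word_piece_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)
    # fill the remaining positions with padding that the model should ignore
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    attention_mask += [0] * n_pad
    return input_ids, attention_mask


ids, attn = pad_or_truncate([2023, 2003, 1037, 3231], max_length=8)
print(ids)   # [101, 2023, 2003, 1037, 3231, 102, 0, 0]
print(attn)  # [1, 1, 1, 1, 1, 1, 0, 0]
```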

<a id='create_masks'></a>
### Create Labels and Masks

Next, we create a clone of the inputs as the 'answer key' for training our MLM.

In [7]:
inputs['labels'] = inputs.input_ids.detach().clone()

Next, we create a mask by generating a tensor of random values with the same shape as `input_ids`.

From these random values, we select 15% of the tokens for masking, as in the BERT paper, while excluding the special [CLS] (101), [SEP] (102) and [PAD] (0) tokens. We want to mask actual words rather than these special tokens so that the model focuses on learning from the language.

In [8]:
rand = torch.rand(inputs.input_ids.shape)

mask = (rand < 0.15) * (inputs.input_ids != 101) * (inputs.input_ids != 102) * (inputs.input_ids != 0)

Next, for each sample, we store the indices of the positions selected for masking.

In [9]:
mask_ids = []

for i in range(inputs.input_ids.shape[0]):
    mask_ids.append(
        torch.flatten(mask[i].nonzero()).tolist()
    )

We then apply the masks to the respective inputs, replacing each selected token with the [MASK] token (ID 103).

In [10]:
for i in range(inputs.input_ids.shape[0]):
    inputs.input_ids[i, mask_ids[i]] = 103
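As written, `inputs['labels']` is an unmodified copy of `input_ids`, so `BertForMaskedLM` computes its loss over every position, including unmasked and padding tokens, which tends to make the loss look lower than a masked-only loss would. The usual convention is to set the label to -100 wherever a token was not selected, since PyTorch's `CrossEntropyLoss` (used inside `BertForMaskedLM`) ignores that target value by default. A minimal pure-Python sketch of the idea; `build_mlm_labels` is a hypothetical helper and the token IDs are illustrative.

```python
IGNORE_INDEX = -100  # CrossEntropyLoss skips targets equal to this value by default


def build_mlm_labels(original_ids, selected):
    """Keep the original token ID only at positions selected for masking."""
    return [tok if sel else IGNORE_INDEX for tok, sel in zip(original_ids, selected)]


original = [101, 2023, 2003, 1037, 3231, 102, 0, 0]
selected = [False, False, True, False, True, False, False, False]
print(build_mlm_labels(original, selected))
# [-100, -100, 2003, -100, 3231, -100, -100, -100]
```

With the tensors in this notebook, the equivalent would be something like `inputs['labels'][~mask] = -100`, run after the mask cell and assuming `mask` is the boolean tensor computed there. This would change the loss values reported below, so they would no longer be directly comparable.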

### Prepare Data and Model for Training

Create a PyTorch dataset.

In [11]:
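class LegisDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        # the encodings are already tensors, so index them directly; re-wrapping
        # them with torch.tensor() copies them and triggers a UserWarning
        return {key: val[idx] for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)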

In [12]:
dataset = LegisDataset(inputs)

In [13]:
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

### Specify GPU Use
This next part allows some flexibility in case you want to train the model on a GPU, for example on Google Colab.

In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
model.to(device)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

We switch the model into training mode, which enables its dropout layers (the weights are already trainable by default).

In [16]:
model.train()


Next we specify the optimiser, learning rate, and number of epochs to train for.

With a batch size of 8, we use a learning rate of 3e-4 (see [here](https://wandb.ai/jack-morris/david-vs-goliath/reports/Does-Model-Size-Matter-A-Comparison-of-BERT-and-DistilBERT--VmlldzoxMDUxNzU)).

In [17]:
opt = AdamW(model.parameters(), lr=3e-4)

In [18]:
epochs = 4

<a id='train_model'></a>
### Train Model

In [19]:
for epoch in range(epochs):
    tq_logger = tqdm(loader, leave=True)
    
    for batch in tq_logger:
        # reset gradients from the previous step
        opt.zero_grad()
        # move input_ids, attention_mask and labels to the same device as the model
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        # get computed loss
        loss = outputs.loss
        loss.backward()

        # update weights
        opt.step()

        tq_logger.set_description(f'epoch {epoch}')
        tq_logger.set_postfix(loss=loss.item())

epoch 0: 100%|██████████| 98/98 [1:17:07<00:00, 47.21s/it, loss=0.146] 
epoch 1: 100%|██████████| 98/98 [1:10:47<00:00, 43.34s/it, loss=0.115] 
epoch 2: 100%|██████████| 98/98 [57:01<00:00, 34.91s/it, loss=0.0851] 
epoch 3: 100%|██████████| 98/98 [55:40<00:00, 34.09s/it, loss=0.0296]


In [None]:
# Use the following training code instead for NSP+MLM training, after preparing NSP data in the NSP section below.
# The main differences are the inclusion of token_type_ids and next_sentence_label

# for epoch in range(epochs):
#     tq_logger = tqdm(loader, leave=True)
    
#     for batch in tq_logger:
#         # reset gradients from the previous step
#         opt.zero_grad()
#         # move inputs, attention masks, segment IDs and labels to the same device as the model
#         input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         # token_type_ids tell BERT which tokens belong to sentence A and which to sentence B
#         token_type_ids = batch['token_type_ids'].to(device)
#         next_sentence_label = batch['next_sentence_label'].to(device)
#         labels = batch['labels'].to(device)

#         outputs = model(input_ids, attention_mask=attention_mask,
#                         token_type_ids=token_type_ids,
#                         next_sentence_label=next_sentence_label, labels=labels)

#         # get computed loss
#         loss = outputs.loss
#         loss.backward()

#         # update weights
#         opt.step()

#         tq_logger.set_description(f'epoch {epoch}')
#         tq_logger.set_postfix(loss=loss.item())

### Save Model

In [None]:
model.save_pretrained(save_model_path)
tokenizer.save_pretrained(save_tokenizer_path)

### Get Vectors
reference: [this tutorial](https://towardsdatascience.com/bert-for-measuring-text-similarity-eec91c6bf9e1)

#### Last Hidden State

The following code uses the pretrained model to get vector representations from the last hidden state of the model. The BERT paper discusses which layers can be used as features in this way.

For this project, the last hidden state was chosen because, in the paper's feature-based NER experiments, it outperforms using only the first (embedding) layer, while the alternatives that score higher mostly combine several hidden layers (e.g. concatenating the last four), which adds complexity for only a marginal gain.

#### Vector Representations

For each text, a vector representation of the entire legislation section is computed. This is called `sent_vecs` below, but it is more accurately a single vector for all the sentences in the section. It is obtained by mean pooling: the token vectors at positions where the attention mask is 1 (i.e. real tokens, not padding) are summed, and the sum is divided by the number of such tokens.
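The pooling step can be checked on a toy example. The NumPy sketch below mirrors the PyTorch code further down (expand the mask, zero out padding, sum over the sequence axis, divide by the clamped token count); `mean_pool` is a hypothetical helper and the numbers are made up.

```python
import numpy as np


def mean_pool(token_embeddings, attention_mask):
    """Average token vectors over positions where attention_mask == 1.

    token_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-4, None)                    # avoid division by zero
    return summed / counts


emb = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last position is padding
attn = np.array([[1, 1, 0]])
print(mean_pool(emb, attn))  # [[2. 3.]]
```

The padded position's large values have no effect on the result, which is the point of masking before averaging.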

#### Batch Processing

The texts are processed in small batches to avoid running out of memory.
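The while loop in the next cell tracks batch start and stop indices by hand and folds any remainder into the final batch. A hypothetical generator like `batches` below expresses the same chunking more compactly, with the last batch simply being shorter; in the notebook, each chunk would then be tokenized and passed through the model as in the loop.

```python
def batches(items, batch_size):
    """Yield consecutive slices of `items`; the last one may be shorter."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


texts = [f'section {i}' for i in range(12)]
sizes = [len(b) for b in batches(texts, 5)]
print(sizes)  # [5, 5, 2]
```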

In [20]:
process_batch_size = 5
max_batches = len(content) // process_batch_size
remainder = len(content) % process_batch_size
processed_batches = 0
batch_start = 0
batch_stop = batch_start + process_batch_size

# switch off dropout so each text always maps to the same vector
model.eval()

while processed_batches < max_batches:
    
    tokens = {'input_ids': [], 'attention_mask': []}
    
    # fold any leftover texts into the final batch
    if max_batches - processed_batches == 1:
        batch_stop += remainder
    
    for legis in content[batch_start:batch_stop]:
        new_tokens = tokenizer.encode_plus(legis, max_length=512,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])

    # reformat list of tensors into single tensor, on the same device as the model
    tokens['input_ids'] = torch.stack(tokens['input_ids']).to(device)
    tokens['attention_mask'] = torch.stack(tokens['attention_mask']).to(device)
    
    # gradients are not needed for inference; skipping them saves a lot of memory
    with torch.no_grad():
        outputs = model(**tokens, output_hidden_states=True)
    embeddings = outputs['hidden_states'][-1]
    attention_mask = tokens['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    # set a minimum of 0.0001 to avoid zero division
    summed_mask = torch.clamp(mask.sum(1), min=0.0001)
    # get the sentence vectors from the mean pool
    sent_vecs = summed / summed_mask
    # move back to the CPU before converting to NumPy
    sent_vecs = sent_vecs.cpu().numpy()
    
    if batch_start == 0:
        all_sent_vecs = sent_vecs
    else:
        all_sent_vecs = np.concatenate((all_sent_vecs, sent_vecs), axis=0)
    
    batch_start += process_batch_size
    batch_stop += process_batch_size
    processed_batches += 1
    print(f'Processed {processed_batches} out of {max_batches} batches.')
    print(f'Currently at a total of {all_sent_vecs.shape[0]} vectors obtained.')

Processed 1 out of 156 batches.
Currently at a total of 5 vectors obtained.
Processed 2 out of 156 batches.
Currently at a total of 10 vectors obtained.
...
Processed 117 out of 156 batches.
Currently at a total of 585 vectors obtained.
...

In [21]:
all_sent_vecs.shape

(780, 768)

In [24]:
np.save(bert_vector_file, all_sent_vecs)
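These vectors are meant for matching legislation sections later on. As one possible way to do that (an assumption here, not necessarily what the later notebooks do), the sketch below reloads an `.npy` file and ranks rows by cosine similarity with NumPy. Random vectors stand in for `all_sent_vecs`, and `top_matches` is a hypothetical helper; `cosine_similarity` from scikit-learn, imported at the top, would give the same similarity values.

```python
import os
import tempfile

import numpy as np


def top_matches(vectors, query_idx, k=3):
    """Return indices of the k rows most cosine-similar to row query_idx (excluding itself)."""
    norms = np.clip(np.linalg.norm(vectors, axis=1, keepdims=True), 1e-12, None)
    normed = vectors / norms
    sims = normed @ normed[query_idx]  # cosine similarity of every row with the query
    sims[query_idx] = -np.inf          # do not match a section with itself
    return np.argsort(-sims)[:k]


# round-trip through an .npy file, as the cell above does with bert_vector_file
rng = np.random.default_rng(0)
vecs = rng.normal(size=(6, 768)).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), 'demo_vectors.npy')
np.save(path, vecs)
loaded = np.load(path)
print(loaded.shape)  # (6, 768)
print(top_matches(loaded, 0, k=2))
```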

## MLM and NSP Pretraining
reference: [this tutorial](https://www.youtube.com/watch?v=IC9FaVPKlYc) 

The code below is for pretraining with both the MLM and NSP tasks. In this project, the vectors from the MLM-only model performed better than those from a model pretrained on both MLM and NSP. Since MLM-only pretraining performed better and was faster to train, NSP was not used.

However, this could just be due to quirks of the specific evaluation data, or perhaps some issue with the data cleaning.

For those trying this on their own data, both methods are worth trying. The BERT paper reports that removing NSP hurts performance on tasks that depend on the relationship between sentences, such as question answering and natural language inference, so NSP should in principle help the model learn cross-sentence context.

### Prepare Sentences for NSP

Prepare a bag of sentences from which we can later sample random negative examples, by splitting each legislation provision `legis` into sentences.

To keep this lightweight, we simply split on a full stop followed by a space (`'. '`). For more robust sentence splitting, a sentence segmenter from an NLP library such as spaCy or NLTK could be used instead.

In [None]:
all_sentences = [sentence for legis in content for sentence in legis.split('. ') if sentence != '']
all_len = len(all_sentences)
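To see why splitting on `'. '` is only approximate for legislation, where references such as "s. 27" are common, the sketch below compares it with a regular-expression split that skips full stops followed by a digit. This heuristic is an assumption on our part and will still mishandle other abbreviations; a proper sentence segmenter (spaCy, NLTK) remains the more robust option.

```python
import re

text = 'This applies to works under s. 27. It does not apply to films. Nothing else applies.'

# the notebook's approach: split on every full stop followed by a space
naive = [s for s in text.split('. ') if s != '']

# heuristic: only split when the full stop is followed by whitespace and then
# a character that is neither whitespace nor a digit (so "s. 27" stays intact)
smarter = [s for s in re.split(r'\.\s+(?=[^\s\d])', text) if s]

print(naive)    # ['This applies to works under s', '27', 'It does not apply to films', 'Nothing else applies.']
print(smarter)  # ['This applies to works under s. 27', 'It does not apply to films', 'Nothing else applies.']
```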

Create the NSP training data: lists for the first sentence, the second sentence and the label of each pair. Label 0 means the second sentence follows the first; label 1 means it does not.

In [None]:
sentence_1 = []
sentence_2 = []
label = []

Next, generate the sentence pairs with a roughly even mix of second sentences that follow the first and ones that do not. Sentences that do not follow are sampled from the `all_sentences` list above. We iterate through each `legis` entry, taking one pair from every entry that has more than one sentence.

In [None]:
for legis in content:
    sentences = [sentence for sentence in legis.split('. ') if sentence != '']
    num_sentences = len(sentences)
    if num_sentences > 1: # we can only get a correct next sentence if there are multiple sentences
        start = random.randint(0, num_sentences-2) # randomly sample a sentence except the last one
        sentence_1.append(sentences[start])
        next_sent = sentences[start+1]
        if random.random() > 0.5:
            # 50% chance of generating a correct NSP example
            sentence_2.append(next_sent)
            label.append(0)
        else:
            # 50% chance of generating a wrong NSP example
            wrong_sent = all_sentences[random.randint(0, all_len-1)]
            # make sure the sampled sentence is not the true next sentence
            while wrong_sent == next_sent:
                wrong_sent = all_sentences[random.randint(0, all_len-1)]
            sentence_2.append(wrong_sent)
            label.append(1)

We can examine our sentence pairs and labels.

In [None]:
# for i in range(10,15):
#     print(i, sentence_1[i], '----next_sent----', sentence_2[i], '----label----', label[i], '\n\n')

Use `BertForPreTraining` for MLM and NSP.

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

Tokenize the sentence pairs.

In [None]:
inputs = tokenizer(sentence_1, sentence_2, 
                   return_tensors='pt', max_length=512, 
                   truncation=True, padding='max_length')

In [None]:
inputs['next_sentence_label'] = torch.LongTensor([label]).T

The next part is to prepare the MLM portion exactly as above.

Proceed to run the code from [the Create Labels and Masks cell](#create_masks) onwards (preparing the data, training the model, etc.) until the vectors are created and saved. Be sure to use the commented-out NSP+MLM training code in the [Train Model section](#train_model) instead of the MLM-only loop. A larger batch size may also help, with the learning rate adjusted accordingly: in this project, when NSP+MLM pretraining was tried with a batch size of 16 instead of 8, the loss decreased further, whereas at batch size 8 it struggled to go down.
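If the batch size is changed, two common heuristics for rescaling the learning rate are to scale it linearly with the batch size or with its square root. Neither is guaranteed to work for Adam-style optimisers, so the result is only a starting point for tuning; `scale_lr` is a hypothetical helper.

```python
import math


def scale_lr(base_lr, base_batch, new_batch, rule='linear'):
    """Rescale a learning rate when the batch size changes (heuristic, not a guarantee)."""
    ratio = new_batch / base_batch
    if rule == 'linear':
        return base_lr * ratio
    if rule == 'sqrt':
        return base_lr * math.sqrt(ratio)
    raise ValueError(f'unknown rule: {rule}')


print(scale_lr(3e-4, 8, 16))          # 0.0006
print(scale_lr(3e-4, 8, 16, 'sqrt'))  # about 4.24e-4
```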

# References

BERT paper: 
- https://arxiv.org/pdf/1810.04805.pdf

Pretraining BERT in PyTorch and getting vectors:
- https://towardsdatascience.com/bert-for-measuring-text-similarity-eec91c6bf9e1
- https://github.com/jamescalam/transformers/blob/main/course/training/03_mlm_training.ipynb
- https://www.youtube.com/watch?v=IC9FaVPKlYc

BERT learning rates:
- https://wandb.ai/jack-morris/david-vs-goliath/reports/Does-Model-Size-Matter-A-Comparison-of-BERT-and-DistilBERT--VmlldzoxMDUxNzU