

## Installing libraries


In [33]:
import torch 
from transformers import BertForQuestionAnswering
import gc 
from torch.utils.data import DataLoader
from transformers import AdamW

## Initializing the model

Loading BERT base mulitiligual cased pre-trained model from HuggingFace Transformers

In [34]:
model = BertForQuestionAnswering.from_pretrained("bert-base-multilingual-cased")

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Analyzing the model configuration

In [35]:
model.config_class()

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

Defining the model checkpoint path 

In [36]:
import shutil

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

## Model training

Training configuration

In [37]:
import pickle

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx] or -1, dtype=torch.int64) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

with open("../data/train_dataset.pkl","rb") as file:
    train_dataset = pickle.load(file)

with open("../data/val_dataset.pkl","rb") as file:
    val_dataset = pickle.load(file)

In [38]:
len(list(iter(train_dataset)))

2839

In [42]:
from tqdm import tqdm

gc.collect() # used to prevent the "cuda running out of memory" error

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # model to GPU

model.to(device)
model.train()

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) # batch size is 1 (because the model is large), the data is shuffled

optim = AdamW(model.parameters(), lr=5e-5) # AdamW optimization algorithm, learning rate is 5e-5

for epoch in tqdm(range(10)): # 10 epochs
    for batch in tqdm(train_loader, total=2839):
        optim.zero_grad() 
        input_ids = batch['input_ids'].to(device) # integers
        attention_mask = batch['attention_mask'].to(device) # 0's and 1's sequences
        start_positions = batch['start_positions'].to(device) # span
        end_positions = batch['end_positions'].to(device) 
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0] 
        loss.backward() # backpropagation
        optim.step() # gradient descent

filepath = '/content/model.pth' # saving weights
torch.save(model.state_dict(), filepath)
model.load_state_dict(torch.load(filepath))
model.eval() # model summary

100%|██████████| 2839/2839 [01:03<00:00, 45.04it/s]
100%|██████████| 2839/2839 [01:03<00:00, 45.00it/s]
100%|██████████| 2839/2839 [01:03<00:00, 45.02it/s]
100%|██████████| 2839/2839 [01:02<00:00, 45.10it/s]
100%|██████████| 2839/2839 [01:03<00:00, 45.01it/s]
100%|██████████| 2839/2839 [01:02<00:00, 45.12it/s]
100%|██████████| 2839/2839 [01:02<00:00, 45.18it/s]
100%|██████████| 2839/2839 [01:01<00:00, 46.51it/s]
100%|██████████| 2839/2839 [01:00<00:00, 46.58it/s]
100%|██████████| 2839/2839 [01:01<00:00, 45.88it/s]
100%|██████████| 10/10 [10:24<00:00, 62.48s/it]


RuntimeError: Parent directory /content does not exist.

In [44]:
filepath = '../models/m-bert.pth' # saving weights
torch.save(model.state_dict(), filepath)
model.load_state_dict(torch.load(filepath))
model.eval() # model summary

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,