In [3]:
from transformers import BertTokenizer, BertForQuestionAnswering, AdamW, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

2023-10-03 18:22:00.845555: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-03 18:22:01.147946: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Dataset is made up of ID, Title, Context, Question, and Answers
Here we load in the dataset, now we are going to apply the autotokenizer to it.

If you want to reduct training time, create smaller datasets of the existing one.

In [5]:
#Load SQuAD Data
dataset = load_dataset("squad")

# example output:
print(dataset["train"][0])

Found cached dataset squad (/home/mhrrs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

{'id': '5733be284776f41900661182', 'title': 'University_of_Notre_Dame', 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}


train_dataloader currently contains: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'token_type_ids', 'attention_mask']
- It needs to contain start and end positions as well.

In [6]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples["context"], padding="max_length", stride=128, truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Loading cached processed dataset at /home/mhrrs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-94cdc3ae0dfac08b.arrow
Loading cached processed dataset at /home/mhrrs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-64af2f1d98310d39.arrow


In [7]:
# Remove this and use the full set later if you have time to train it
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["validation"].shuffle(seed=42)

Loading cached shuffled indices for dataset at /home/mhrrs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-3d9ddc5e3cf60fe4.arrow
Loading cached shuffled indices for dataset at /home/mhrrs/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-09c923ccbfed43a7.arrow


The start index of the answer is already in the "answers" dictionary. Now we just add the length of the "text" in the "answers" dict to get the "answer_end".

In [8]:
# process training and validation data into seperate groups
def package_data(dataset):
    contexts = []
    questions = []
    answers = []
    
    for item in dataset:
        contexts.append(item["context"])
        questions.append(item["question"])
        answers.append(item["answers"])
            
    return contexts, questions, answers
        
        
train_contexts, train_questions, train_answers = package_data(train_dataset)
eval_contexts, eval_questions, eval_answers = package_data(eval_dataset)

print(f"contexts: {len(train_contexts)}\nquestions: {len(train_questions)}\nanswers: {len(train_answers)}")

contexts: 87599
questions: 87599
answers: 87599


In [9]:
# use the following function to add answer_end index to the train_answers/valid_answers data structure
def add_end_index(dataset):
    for ans in dataset:
        text = ans["text"][0]
        start_index = ans["answer_start"][0]
        end_index = start_index + len(text) #adds the length of text onto the start index
        ans["answer_start"] = start_index
        ans["answer_end"] = end_index
        
add_end_index(train_answers)
add_end_index(eval_answers)

In [10]:
print(f"train_contexts = {train_contexts[0]}\ntrain_questions = {train_questions[0]}\ntrain_answers = {train_answers[0]}")

train_contexts = The Pew Forum on Religion & Public Life ranks Egypt as the fifth worst country in the world for religious freedom. The United States Commission on International Religious Freedom, a bipartisan independent agency of the US government, has placed Egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. According to a 2010 Pew Global Attitudes survey, 84% of Egyptians polled supported the death penalty for those who leave Islam; 77% supported whippings and cutting off of hands for theft and robbery; and 82% support stoning a person who commits adultery.
train_questions = What percentage of Egyptians polled support death penalty for those leaving Islam?
train_answers = {'text': ['84%'], 'answer_start': 468, 'answer_end': 471}


# Tokenize the contexts and questions together

In [11]:
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
eval_encodings = tokenizer(eval_contexts, eval_questions, truncation=True, padding=True)
del train_contexts, train_questions, eval_contexts, eval_questions

In [12]:
tokenizer.decode(train_encodings['input_ids'][0])

'[CLS] the pew forum on religion & public life ranks egypt as the fifth worst country in the world for religious freedom. the united states commission on international religious freedom, a bipartisan independent agency of the us government, has placed egypt on its watch list of countries that require close monitoring due to the nature and extent of violations of religious freedom engaged in or tolerated by the government. according to a 2010 pew global attitudes survey, 84 % of egyptians polled supported the death penalty for those who leave islam ; 77 % supported whippings and cutting off of hands for theft and robbery ; and 82 % support stoning a person who commits adultery. [SEP] what percentage of egyptians polled support death penalty for those leaving islam? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

In [13]:
def add_token_positions(encodings, answers):
    start_pos = []
    end_pos = []
    for i in range(len(answers)):
        start_pos.append(encodings.char_to_token(i, answers[i]["answer_start"]))
        end_pos.append(encodings.char_to_token(i, answers[i]["answer_end"]-1))
        
        if start_pos[-1] is None:
            start_pos[-1] = tokenizer.model_max_length
        if end_pos[-1] is None:
            end_pos[-1] = tokenizer.model_max_length
        
    encodings.update({"start_positions": start_pos, "end_positions": end_pos})

add_token_positions(train_encodings, train_answers)
add_token_positions(eval_encodings, eval_answers)

In [14]:
print(train_encodings.keys())
print(f"{train_encodings['start_positions'][2]}\n{train_encodings['end_positions'][2]}")

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])
108
109


# create DataLoaders from the preprocessed tokens

In [15]:
class SQuAD(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)
    
train_dataset = SQuAD(train_encodings)
eval_dataset = SQuAD(eval_encodings)

In [16]:
print(train_dataset.encodings[0])

# insert converted tensor dataset into DataLoader function
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_loader = DataLoader(train_dataset, batch_size=16)
print(train_loader)

Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
<torch.utils.data.dataloader.DataLoader object at 0x145de9caf880>


In [17]:
def show_answer(idx):
    print("Tokenized", tokenizer.decode(train_encodings['input_ids'][idx][train_encodings['start_positions'][idx]: train_encodings['end_positions'][idx]]))
    print("Real", train_answers[idx]['text'])
    print("Real", train_answers[idx]['answer_start'])
    print("Real", train_answers[idx]['answer_end'])
    
show_answer(0)

Tokenized 84
Real ['84%']
Real 468
Real 471


# Create Training Loop
Experimental at the moment.

In [18]:
def train_model(train_loader):
    n_epochs = 5
    optim = AdamW(model.parameters(), lr=5e-5)
    n_train_steps = len(train_loader) * n_epochs
    n_warmup_steps = .1 * n_train_steps
    scheduler = get_linear_schedule_with_warmup(optim, n_warmup_steps, n_t_steps)

    model.to(device)
    model.train()
    for epoch in range(n_epochs):
        loop = tqdm(train_loader, leave=True)
        for batch in loop:
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_pos = batch['start_positions'].to(device)
            end_pos = batch['end_positions'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_pos, end_positions=end_pos)
            loss = outputs[0]
            loss.backward()
            optim.step()
            scheduler.step()

            loop.set_description(f"Epoch {epoch+1}")
            loop.set_postfix(loss=loss.item())


# Updated Training Loop
The following training loop uses:
- Automatic Mixed Precision
- Linear Learning Rate Decay
- Gradient Accumulation

In [19]:
def train_amp_model(train_loader):
    n_epochs = 5
    optim = AdamW(model.parameters(), lr=5e-5)
    
    # linear learning rate decay setup
    n_train_steps = len(train_loader) * n_epochs
    n_warmup_steps = .1 * n_train_steps
    scheduler = get_linear_schedule_with_warmup(optim, n_warmup_steps, n_t_steps)
    
    # automatic mixed precision setup
    scaler = torch.cuda.amp.GradScaler()
    
    # other
    total_loss = 0
    total_time = 0.0

    #batch accumulation parameter
    accum_iter = 4
    
    model.to(device)
    model.train()
    for epoch in range(n_epochs):
        start_time = time.perf_counter()
        
        for step, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_pos = batch['start_positions'].to(device)
            end_pos = batch['end_positions'].to(device)
            
            # automatic mixed precision
            with torch.cuda.amp.autocast():
                outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_pos, end_positions=end_pos)
                loss = outputs[0]
                
            scaler.scale(loss).backward()
            total_loss += loss.item()
            
            # Gradient accumulation implementation
            if ((step+1)%accum_iter == 0) or (step+1 == len(train_loader)):
                scaler.step(optim)
                scaler.update()
                optim.zero_grad()
                scheduler.step()
            
            if step % 100 == 0:
                epoch_time = time.perf_counter()
                print(f"Epoch: {epoch+1} | step: {step}/{len(train_loader)} | loss: {total_loss/(step+1):.4f} | time: {(epoch_time-start_time)/60:.1f} (minutes)")
                
        total_time += (epoch_time-start_time)
            
    print(f"Total time: {total_time/360} (hours)")

In [18]:
import torch.utils.data as data_utils
import gc

# clear cuda
torch.cuda.empty_cache()
gc.collect()
# del model

model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

# TOOK 45 MIN for 1 EPOCH
# train_model(train_loader)

# TOOK LESS THAN 25 MIN for 1 EPOCH
train_amp_model(train_loader)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

Epoch: 1 | step: 0/5475 | loss: 6.1989 | time: 0.1 (minutes)




Epoch: 1 | step: 100/5475 | loss: 6.2183 | time: 0.4 (minutes)
Epoch: 1 | step: 200/5475 | loss: 6.1947 | time: 0.7 (minutes)
Epoch: 1 | step: 300/5475 | loss: 6.1534 | time: 1.0 (minutes)
Epoch: 1 | step: 400/5475 | loss: 6.0792 | time: 1.3 (minutes)
Epoch: 1 | step: 500/5475 | loss: 5.9394 | time: 1.6 (minutes)


KeyboardInterrupt: 

In [21]:
model.save_pretrained("./bert-squadv7")
tokenizer.save_pretrained("./bert-squadv7")

('./bert-squadv6/tokenizer_config.json',
 './bert-squadv6/special_tokens_map.json',
 './bert-squadv6/vocab.txt',
 './bert-squadv6/added_tokens.json',
 './bert-squadv6/tokenizer.json')

# Evaluate

In [29]:
def eval_model(eval_loader):
    model.eval()

    # accuracy list
    acc = []

    for batch in tqdm(eval_loader):
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_true = batch['start_positions'].to(device)
            end_true = batch['end_positions'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)

            start_pred = torch.argmax(outputs['start_logits'], dim=1)
            end_pred = torch.argmax(outputs['end_logits'], dim=1)

            acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
            acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

    acc = sum(acc)/len(acc)

    print("\n\nT/P\tanswer_start\tanswer_end\n")
    for i in range(len(start_true)):
        print(f"true\t{start_true[i]}\t{end_true[i]}\n"
            f"pred\t{start_pred[i]}\t{end_pred[i]}\n")

In [30]:
model_path = './bert-squadv6'
model = BertForQuestionAnswering.from_pretrained(model_path)
model = model.to(device)

In [31]:
eval_model(eval_loader)

100%|██████████| 5475/5475 [10:43<00:00,  8.51it/s]



T/P	answer_start	answer_end

true	97	98
pred	97	98

true	45	45
pred	45	45

true	1	2
pred	1	2

true	4	4
pred	4	4

true	4	8
pred	4	8

true	21	22
pred	21	22

true	55	90
pred	55	90

true	34	49
pred	34	49

true	111	111
pred	111	111

true	52	52
pred	52	52

true	36	36
pred	36	36

true	16	20
pred	16	19

true	5	5
pred	5	5

true	101	103
pred	101	103

true	78	79
pred	78	79




