# Training BERT Base for Arbitrary 5W Facts from Scratch

In [None]:
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from nltk.translate import bleu
import matplotlib.pyplot as plt

from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, BertConfig, BertForMaskedLM, DataCollatorForLanguageModeling
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Load the 5W corpus; the cells below read its 'sentences' column.
data = pd.read_csv('/kaggle/input/dl2-5w-dataset/5W_dataset.csv') 
# (rows, columns) sanity check.
print(data.shape)

# Tokenizer

In [None]:
# Only the tokenizer/vocabulary of 'bert-base-uncased' is taken from the hub here;
# the masked-LM model is instantiated from a BertConfig, not from this checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Dataset & Model Setup

In [None]:
# Training hyperparameters.
N_EPOCH = 20
BATCH_SIZE = 16
WEIGHT_DECAY = 0.01
LEARNING_RATE = 3e-5
MODEL_NAME = 'bert_model'  # file-name prefix for the saved weights and figures

# Architecture of the model that is trained from scratch.
config = BertConfig(
    vocab_size=tokenizer.vocab_size,  # keep the output layer aligned with the tokenizer's ids
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=64,  # same value as BertDataset's default max_length
    type_vocab_size=1,  # NOTE(review): presumably fine because inputs are single sentences -- confirm
    initializer_range=0.02
)

In [None]:
class BertDataset(Dataset):
    """Torch dataset over the ``'sentences'`` column of a DataFrame.

    All sentences are tokenized once, at construction time, with
    ``padding='max_length'`` and truncation to ``max_length``; indexing then
    only slices the pre-computed tensors.

    Args:
        df: DataFrame providing a ``'sentences'`` column.
        tokenizer: Callable invoked as
            ``tokenizer(texts, padding=..., truncation=..., max_length=...,
            return_tensors='pt')`` whose result exposes ``.items()``.
        max_length: Target length passed to the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 'max_length'
        self.data = df['sentences'].to_list()
        self.encodings = self.tokenizer(
            self.data,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )

    def __len__(self):
        """Number of sentences held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded fields of row ``idx`` plus a ``'labels'`` copy of its input ids."""
        sample = {}
        for field, tensor in self.encodings.items():
            sample[field] = tensor[idx]
        # Independent copy, so later edits to the inputs do not alias the labels.
        sample['labels'] = sample['input_ids'].detach().clone()
        return sample
    
def get_sentence_embedding(sentence):
    """Embed ``sentence`` with the global ``bert_model`` and return its [CLS] vector.

    The text is tokenized with the global ``tokenizer`` (padded/truncated to
    64 tokens), run through ``bert_model`` on ``device`` without gradient
    tracking, and the hidden state at position 0 is returned as a NumPy array
    on the CPU (the batch axis is kept).
    """
    encoded = tokenizer(
        sentence,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=64,
    )
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state
    # Position 0 of every sequence is the [CLS] token.
    return hidden_states[:, 0, :].cpu().numpy()

def model_predict(sentence_list, printer=False, top_k=3, batch_size=256):
    """Mask one random token per sentence and score the model's guess for it.

    Uses the global ``tokenizer``, ``model`` and ``device``. For every
    sentence a single non-special token (never the first or last attended
    position) is replaced by ``[MASK]``; the model's ``top_k`` candidates for
    that slot are compared against the original token.

    Args:
        sentence_list: Non-empty list of sentences.
        printer: When true, print the sentence, masked token, candidates and
            their probabilities for every example.
        top_k: Number of candidates considered for the top-k accuracy.
        batch_size: Number of sentences per forward pass. Keeps the logits
            tensor (batch x length x vocab) bounded for large inputs.

    Returns:
        Tuple ``(true_labels, predictions, confidences, acc, acc_top_k)``:
        three NumPy arrays (true token ids, top-1 token ids, top-1
        probabilities), the unrounded top-1 accuracy, and the top-k accuracy
        rounded to two decimals.

    Raises:
        ValueError: If ``sentence_list`` is empty or a sentence tokenizes to
            nothing between the special tokens.
    """
    n_sentences = len(sentence_list)
    if n_sentences == 0:
        raise ValueError("sentence_list must contain at least one sentence")

    # Never feed more positions than the model has position embeddings for.
    max_len = min(model.config.max_position_embeddings, tokenizer.model_max_length)
    eval_tokenized = tokenizer(sentence_list, padding='longest', truncation=True,
                               max_length=max_len, return_tensors='pt')

    # Randomly mask one token per sentence and remember where it was, so the
    # position does not have to be re-discovered by searching for [MASK].
    true_labels = []
    mask_positions = []
    for i in range(n_sentences):
        n_tokens = int(eval_tokenized['attention_mask'][i].sum())
        if n_tokens < 3:
            raise ValueError(f"sentence {i} has no token that can be masked: {sentence_list[i]!r}")
        random_mask_idx = random.randint(1, n_tokens - 2)  # skip the start & end tokens
        true_labels.append(eval_tokenized['input_ids'][i, random_mask_idx].item())
        eval_tokenized['input_ids'][i, random_mask_idx] = tokenizer.mask_token_id
        mask_positions.append(random_mask_idx)

    top_ids_chunks = []
    top_probs_chunks = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, n_sentences, batch_size):
                stop = start + batch_size
                batch = {key: val[start:stop].to(device) for key, val in eval_tokenized.items()}
                logits = model(**batch)['logits']
                rows = torch.arange(logits.shape[0], device=logits.device)
                cols = torch.tensor(mask_positions[start:stop], device=logits.device)
                mask_logits = logits[rows, cols]
                batch_ids = torch.topk(mask_logits, top_k, dim=-1).indices
                # Gather the probabilities of exactly the predicted ids.
                batch_probs = F.softmax(mask_logits, dim=-1).gather(-1, batch_ids)
                top_ids_chunks.append(batch_ids.cpu())
                top_probs_chunks.append(batch_probs.cpu())
    finally:
        # Hand the model back in the mode the caller had it in (e.g. mid-training).
        model.train(was_training)

    top_ids = torch.cat(top_ids_chunks, dim=0)
    top_probs = torch.cat(top_probs_chunks, dim=0)

    n_correct = 0
    n_correct_top_k = 0
    predictions = []
    confidences = []
    for i, sentence in enumerate(sentence_list):
        eval_preds = top_ids[i].tolist()
        eval_probs = top_probs[i].tolist()
        predictions.append(eval_preds[0])
        confidences.append(eval_probs[0])
        if eval_preds[0] == true_labels[i]:
            n_correct += 1
        if true_labels[i] in eval_preds:
            n_correct_top_k += 1
        if printer:
            print(f"[{i}] SENTENCE.........:", sentence)
            print(f"[{i}] MASKED TOKEN.....:", tokenizer.convert_ids_to_tokens(true_labels[i]))
            print(f"[{i}] TOP {top_k} MODEL PREDS:", tokenizer.convert_ids_to_tokens(eval_preds))
            print(f"[{i}] CONFIDENCE:......:", [round(x, 2) for x in eval_probs], '\n')

    accuracy = n_correct / n_sentences
    accuracy_top_k = n_correct_top_k / n_sentences
    print('ACCURACY.............:', round(accuracy, 2))
    print(f"ACCURACY TOP {top_k}.......:", round(accuracy_top_k, 2))
    return np.array(true_labels), np.array(predictions), np.array(confidences), accuracy, round(accuracy_top_k, 2)

def model_predict_auto(sentence_list, printer=False, temperature=1.0):
    """Generate a continuation after 'ate' for each sentence and score it against the corpus.

    Uses the globals ``tokenizer``, ``model``, ``device``, ``data`` (which
    must provide the ``'sentences'`` and ``'embeddings'`` columns) and
    ``get_sentence_embedding``. Each sentence is cut right after its first
    'ate' token; tokens are then sampled one at a time by appending a
    ``[MASK]`` and drawing from the temperature-scaled distribution at that
    position. Decoding stops at a period, after the step counter exceeds 15,
    or when the model's position limit is reached.

    The generated text is compared with every corpus sentence by unigram
    BLEU and by cosine similarity of [CLS] embeddings; the best match of each
    kind is kept.

    Args:
        sentence_list: Non-empty list of sentences, each containing 'ate'.
        printer: When true, print the generated sentence and its closest
            corpus matches.
        temperature: Divisor applied to the logits before sampling.

    Returns:
        Tuple ``(mean of max BLEU, mean of max cosine similarity)``.

    Raises:
        ValueError: If ``sentence_list`` is empty or a sentence has no 'ate' token.
    """
    if len(sentence_list) == 0:
        raise ValueError("sentence_list must contain at least one sentence")

    ate_token_id = tokenizer('ate')['input_ids'][1]
    period_token_id = tokenizer('.')['input_ids'][1]
    mask_token_id = tokenizer.mask_token_id
    max_positions = model.config.max_position_embeddings
    eval_tokenized = tokenizer(sentence_list, padding='longest', truncation=True,
                               max_length=min(max_positions, tokenizer.model_max_length),
                               return_tensors='pt')

    # Stack the corpus embeddings once so each generated sentence needs a
    # single cosine_similarity call instead of one call per corpus row.
    reference_embeddings = np.vstack(data['embeddings'].to_list())
    mask_tail = torch.tensor([mask_token_id], device=device)

    bleu_scores = []
    cos_scores = []
    was_training = model.training
    model.eval()
    try:
        for i in range(eval_tokenized['input_ids'].shape[0]):
            ate_positions = torch.where(eval_tokenized['input_ids'][i] == ate_token_id)[0]
            if ate_positions.numel() == 0:
                raise ValueError(f"sentence {i} does not contain the token 'ate': {sentence_list[i]!r}")
            mask_idx = ate_positions[0].item() + 1  # right after the (first) 'ate' token
            prefix = eval_tokenized['input_ids'][i, :mask_idx].to(device)
            auto_sentence = torch.cat((prefix, mask_tail), dim=0)
            counter = 0
            while True:
                with torch.no_grad():
                    auto_logits = model(auto_sentence.unsqueeze(0))['logits']
                auto_probs = F.softmax(auto_logits[0, mask_idx, :] / temperature, dim=0)
                auto_token = torch.multinomial(auto_probs, 1)
                # Replace the trailing [MASK] with the sampled token.
                auto_sentence = torch.cat((auto_sentence[:-1], auto_token), dim=0)
                if (counter > 15
                        or auto_token.item() == period_token_id
                        or auto_sentence.shape[0] >= max_positions):
                    break
                auto_sentence = torch.cat((auto_sentence, mask_tail), dim=0)
                counter += 1
                mask_idx += 1

            # Drop the leading [CLS] before decoding.
            generated = tokenizer.decode(auto_sentence[1:])
            hypothesis = generated.split()

            # Scores live in local Series/arrays; the shared DataFrame is not modified.
            sentence_bleu = data['sentences'].apply(
                lambda ref: bleu([ref.lower().split()], hypothesis, (1,)))
            bleu_scores.append(sentence_bleu.max())
            if printer:
                print(f"[{i}] GENERATED SENTENCE:", generated)
                print(f"[{i}] CLOSEST BLEU......:", data['sentences'].iloc[sentence_bleu.argmax()])
                print(f"[{i}] BLEU SCORE........:", round(sentence_bleu.max(), 4))

            generated_embedding = get_sentence_embedding(generated)
            sentence_cos = cosine_similarity(reference_embeddings, generated_embedding)[:, 0]
            cos_scores.append(sentence_cos.max())
            if printer:
                print(f"[{i}] CLOSEST COSINE....:", data['sentences'].iloc[int(sentence_cos.argmax())])
                print(f"[{i}] COSINE SCORE......:", round(sentence_cos.max(), 4), '\n')
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    avg_bleu = sum(bleu_scores) / len(bleu_scores)
    avg_cos = sum(cos_scores) / len(cos_scores)
    print('AVG MAX BLEU SIM......:', round(avg_bleu, 4))
    print('AVG MAX COSINE SIM....:', round(avg_cos, 4))
    return avg_bleu, avg_cos

def calculate_ece(predictions, labels, confidences, n_bins=10):
    """Return the Expected Calibration Error, rounded to four decimals.

    Confidences are grouped into ``n_bins`` equal-width bins over [0, 1].
    For each populated bin the absolute gap between mean confidence and
    accuracy is weighted by the fraction of samples in the bin, and the
    weighted gaps are summed.

    Args:
        predictions: Predicted labels (array-like).
        labels: True labels (array-like, same length).
        confidences: Confidence of each prediction in [0, 1] (array-like).
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a NumPy float; ``0.0`` when there are no samples.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    confidences = np.asarray(confidences, dtype=float)
    if confidences.size == 0:
        return np.float64(0.0)

    correct = predictions == labels
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for k, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
        # Bins are (lower, upper]; the first one is closed on the left so a
        # confidence of exactly 0 is not silently dropped.
        above_lower = confidences >= bin_lower if k == 0 else confidences > bin_lower
        in_bin = above_lower & (confidences <= bin_upper)
        prop_in_bin = in_bin.mean()
        if prop_in_bin > 0:
            accuracy_in_bin = correct[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    # np.round also works when no bin was populated and ece is still a Python float.
    return np.round(ece, 4)

def plot_reliability_diagram(predictions, labels, confidences, n_bins=10):
    """Plot per-bin accuracy against per-bin mean confidence.

    Confidences are split into ``n_bins`` equal-width ``(lower, upper]`` bins
    over [0, 1]; empty bins are skipped. The curve is drawn together with the
    diagonal of perfect calibration, written to ``f"{MODEL_NAME}_curve.png"``
    and then shown.

    Args:
        predictions: NumPy array of predicted labels.
        labels: NumPy array of true labels.
        confidences: NumPy array with the confidence of each prediction.
        n_bins: Number of equal-width bins.
    """
    edges = np.linspace(0, 1, n_bins + 1)
    hits = (predictions == labels).astype(float)

    bin_accuracies = []
    bin_confidences = []
    for lower, upper in zip(edges[:-1], edges[1:]):
        members = (confidences > lower) & (confidences <= upper)
        if not np.mean(members) > 0:
            continue  # nothing landed in this bin
        bin_accuracies.append(np.mean(hits[members]))
        bin_confidences.append(np.mean(confidences[members]))

    plt.plot(bin_confidences, bin_accuracies, marker='o', linewidth=1, label='Reliability Curve')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfectly Calibrated')
    plt.xlabel('Confidence')
    plt.ylabel('Perceived Frequency')
    plt.title('Reliability Diagram for 5W Predictions')
    plt.legend()
    # Save before show() so the written file contains the drawn figure.
    plt.savefig(f"{MODEL_NAME}_curve.png", bbox_inches='tight')
    plt.show()

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

dataset = BertDataset(data, tokenizer)
# The collator performs the masked-LM corruption per batch (20% masking probability).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.2)
# shuffle=True: draw the sentences in a new order every epoch instead of
# replaying the CSV order on each pass.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)

# Randomly initialised model built from `config`; no pretrained weights are loaded.
model = BertForMaskedLM(config=config).to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
print('model n_parameters:', model.num_parameters())

# Model Training

In [None]:
# Sentences used for the per-epoch masked-token accuracy check
# (a slice of the same frame the training dataset is built from).
val_set = data['sentences'].head(2000).to_list()
training_loss = []
training_acc = []
training_acc_k = []
lowest_loss = 999  # NOTE(review): not read anywhere in the visible cells

for epoch in range(N_EPOCH):
    # model_predict() calls model.eval(), so training mode (dropout on) has
    # to be asserted at the start of every epoch, not just once before the loop.
    model.train()
    total_loss = 0.0
    for batch in tqdm(data_loader):
        optimizer.zero_grad()
        outputs = model(**batch.to(device))
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    print(f"epoch [{epoch+1}/{N_EPOCH}] avg loss: {avg_loss:.4f}")
    training_loss.append(avg_loss)

    # Top-1 and top-3 masked-token accuracy after this epoch.
    _, _, _, acc_1, acc_2 = model_predict(val_set, top_k=3)
    training_acc.append(acc_1)
    training_acc_k.append(acc_2)

# Persist the weights reached after the final epoch.
torch.save(model.state_dict(), f"{MODEL_NAME}.pth")

In [None]:
# One figure with the per-epoch loss and both accuracy curves.
x_axis = list(range(1, N_EPOCH+1))
plt.figure(figsize=(12, 6))
for _series, _colour, _label in (
    (training_loss, 'r', 'training loss'),
    (training_acc, 'b', 'training acc'),
    (training_acc_k, 'orange', 'training acc top 3'),
):
    plt.plot(x_axis, _series, color=_colour, label=_label)
# Thin dashed guide line at every second epoch.
for _x in range(2, N_EPOCH, 2):
    plt.axvline(x=_x, ls='--', lw=0.5, c='b')
plt.title('BERT 5W - Training Loss / Accuracy')
plt.xlabel('epoch')
plt.ylabel('loss / accuracy')
plt.legend(loc='upper right')
plt.savefig(f"{MODEL_NAME}_loss.png", bbox_inches='tight')
plt.show()

# Model Evaluation

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Must describe the same architecture the checkpoint was trained with,
# otherwise load_state_dict() cannot match the saved tensors.
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=64,
    type_vocab_size=1,
    initializer_range=0.02
)

model = BertForMaskedLM(config=config).to(device)
# map_location lets a checkpoint saved on a GPU load on the CPU fallback as well.
model.load_state_dict(torch.load('/kaggle/input/dl2-5w-bert/bert_model.pth', map_location=device))
model.eval()

# Pretrained encoder, used only by get_sentence_embedding() to embed sentences.
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
bert_model.eval()

# One [CLS] embedding per corpus sentence, stored next to the text.
data['embeddings'] = data['sentences'].apply(get_sentence_embedding)

In [None]:
# First 2000 sentences of the corpus, as a plain list of strings.
eval_set = data['sentences'].head(2000).to_list()

In [None]:
# Sample continuations for the first 200 evaluation sentences and report the
# average best-match BLEU and cosine similarity against the corpus.
model_predict_auto(eval_set[:200], printer=False, temperature=1.0)

In [None]:
# Masked-token prediction over the evaluation set; the returned arrays feed the
# calibration metrics computed in the following cells.
true_labels, predictions, confidences, _, _ = model_predict(eval_set, printer=False, top_k=3)

In [None]:
# Expected Calibration Error of the top-1 predictions (default 10 bins).
calculate_ece(predictions, true_labels, confidences)

In [None]:
# Reliability curve for the same predictions; also saved as f"{MODEL_NAME}_curve.png".
plot_reliability_diagram(predictions, true_labels, confidences)