# BERT Model: Manual Fine-tuning

#### Imports:

In [1]:
import json
import pickle
import numpy as np

from tqdm import tqdm
from tqdm.autonotebook import tqdm
from collections import Counter

from datasets import load_dataset

import torch
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoTokenizer, AutoModel
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

  from tqdm.autonotebook import tqdm


#### Data Loading and Preperation:

In [2]:
# Load dataset from json file:
data_file_path ='./data/biloc_tagged_sequences.json'
datasets = load_dataset('json', data_files=data_file_path, field='data')

# Paramters for dataset train-test-split function: 
# Sets train-test split and seed of data shuffle
test_size=0.15
random_seed=42

# Split dataset into train and test sets:
datasets = datasets['train'].train_test_split(test_size=test_size, seed=random_seed)
print("Dataset Structure:")
print(datasets)

Dataset Structure:
DatasetDict({
    train: Dataset({
        features: ['id', 'ner_tags', 'split_tokens'],
        num_rows: 8646
    })
    test: Dataset({
        features: ['id', 'ner_tags', 'split_tokens'],
        num_rows: 1526
    })
})


#### Tokenize Data for BERT Model:

In [3]:
# Load in BERT tokenizer bert-base-cased:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [4]:
# Deals with special tokens and ensures correct label alignment:
# Helps with tokenization due to dataset format
def tokenize_and_align_labels(tokenizer, examples):
    
    tokenized_inputs = tokenizer(examples["split_tokens"], truncation=True, padding="max_length", 
                                 is_split_into_words=True, return_tensors="pt")
    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = [-100 if word_id is None else label[word_id] for word_id in word_ids]
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Converts batch input to tensor
def convert_to_tensors(batch):
    batch_tensors = {key: tensor(value) for key, value in batch.items()}

In [5]:
# Tokenize Dataset
tokenized_datasets = datasets.map(lambda examples: tokenize_and_align_labels(tokenizer, examples), batched=True)

#### Tokenized Dataset Formatting for Model:

In [6]:
# Format dataset for use with Pytorch:
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'labels'])

In [7]:
# Create Pytorch DataLoader Objects for Train and Test Sets:
train_dataset = tokenized_datasets["train"]
test_dataset = tokenized_datasets["test"]

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2)

#### Load in pre-trained BERT Model:

In [8]:
# Model Parameters:
model_name = "bert-base-cased"  
num_labels = 165

# Loads in default model from HuggingFace:
bert_model = AutoModel.from_pretrained(model_name, num_labels=num_labels)

#### Model, Fine-tuner, and Optimizer:

In [9]:
# Inherits from pytorch.nn.module to add custom fine-tuning to model:
class CustomNERModel(nn.Module):
    def __init__(self, bert_model, num_labels):
        super(CustomNERModel, self).__init__()
        self.bert = bert_model  # The BERT model
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_labels)  # Classifier
        self.config = bert_model.config

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)
        return logits


In [10]:
# Add LoRA to the model via PEFT if Boolean True:
add_LoRA = True

if add_LoRA:
    target_modules = ["query", "value"]
    peft_config = LoraConfig(task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, 
                             r=16, lora_alpha=8, lora_dropout=0.1, bias="all",
                             target_modules=target_modules)

    model = get_peft_model(bert_model, peft_config)
else:
    model = bert_model

# Initialization of custom fine-tuned BERT model
model = CustomNERModel(model, num_labels)

In [11]:
# Intialization of Optimizer and Loss Function:
optimizer = AdamW(model.parameters(), lr=5e-5)
loss_function = nn.CrossEntropyLoss()

In [12]:
# Set model device to gpu if available.
device = torch.device("cuda" if torch.cuda.is_available() else torch.device("cpu"))
model.to(device)
print(device)

cuda


#### Class Weighting:

In [13]:
# Assuming you know the total list of classes, including those not present in training
total_classes = np.arange(num_labels)  # num_labels is the total number of unique labels you have

# Use only the labels present in the training dataset to calculate class weights
present_labels = [label for sublist in datasets["train"]["ner_tags"] for label in sublist]
present_labels_unique = np.unique(present_labels)

# Compute class weights only for present classes
class_weights = compute_class_weight('balanced', classes=present_labels_unique, y=present_labels)

# Initialize a weights array with ones (default weight) for all classes
weights_array = np.ones(num_labels)

# Update the weights array for the present classes with the computed weights
weights_array[present_labels_unique] = class_weights

# Convert to a tensor
class_weights_tensor = torch.tensor(weights_array, dtype=torch.float).to(device)

# Use this tensor in your loss function
loss_function = CrossEntropyLoss(weight=class_weights_tensor)

#### Focal Loss Function:

In [14]:
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', num_classes=None, device='cpu'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        if alpha is not None:
            # Directly use the alpha tensor passed as an argument
            self.alpha = alpha.to(device)
        else:
            # If no alpha is provided, initialize a uniform alpha tensor
            self.alpha = torch.ones(num_classes, device=device) * 0.25

    def forward(self, inputs, targets):
        BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
        targets = targets.type(torch.long)

        # Apply alpha factors based on class indices
        at = self.alpha[targets]

        pt = torch.exp(-BCE_loss)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

In [15]:
num_labels = 165  # Your number of classes
alpha = torch.ones(num_labels) * 0.5
alpha[0] = 0.05
focal_loss_function = FocalLoss(alpha=alpha, gamma=2.0, num_classes=num_labels, device=device).to(device)

#### Stat-Collection:

In [16]:
# Checkpoint Function:
def checkpoint(model, filename):
    torch.save(model.state_dict(), filename)
    
def resume(model, filename):
    model.load_state_dict(torch.load(filename))
    
def calculate_accuracy(logits, labels):
    """
    Calculate accuracy by comparing logits vs labels
    """
    predictions = torch.argmax(logits, dim=-1)
    mask = labels != -100  # Assuming -100 is used to ignore tokens in the labels
    correct_predictions = torch.sum((predictions == labels) * mask)
    total_relevant_elements = torch.sum(mask)
    accuracy = (correct_predictions.float() / total_relevant_elements.float()).item()
    return accuracy

#### Model Training:

In [17]:
# Boolean to run training procedure; otherwise, loads in previous model checkpoint
train_mode = False

In [18]:
avg_acc = []
if train_mode:
    num_epochs = 25
    save_path = './checkpoints/model_focal_large_LoRA'

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        total_accuracy = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            # Load Inputs and Labels from batch
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            
            # Get outputs from model:
            optimizer.zero_grad()
            outputs = model(**inputs)
            logits = outputs
            
            # Calculate Loss, Backpropogate, update Optimizer:
            loss = focal_loss_function(logits.view(-1, num_labels), labels.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
            # Calculate accuracy for each batch and accumulate
            batch_accuracy = calculate_accuracy(logits.view(-1, num_labels), labels.view(-1))
            total_accuracy += batch_accuracy
            
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
            
        # Average accuracy over all batches
        avg_accuracy = total_accuracy / len(train_loader)
        avg_acc.append(avg_accuracy)
        
        # Generates Model Checkpoint Every 10 Epochs
        if (epoch + 1) % 5 == 0:
            checkpoint(model, save_path + str(epoch + 1))

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Average Accuracy: {avg_accuracy:.4f}")
else:
    # Loads in previous model checkpoint
    model_path = './checkpoints/model_focal_sampled_LoRA10'
    resume(model, model_path)

In [19]:
print(model.config.hidden_size)

768


#### Model Prediction:

In [20]:
true_labels_list = []
pred_labels_list = []

with torch.no_grad(): 
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        outputs = model(**inputs)
        logits = outputs
        predictions = torch.argmax(logits, dim=-1)
        predictions = predictions.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
        
        true_labels_list.append(labels)
        pred_labels_list.append(predictions)

true_labels_flat = np.concatenate(true_labels_list, axis=None)
pred_labels_flat = np.concatenate(pred_labels_list, axis=None)


#### Prediction Analysis: 

In [21]:
mask = true_labels_flat != -100 
true_labels_filtered = true_labels_flat[mask]
pred_labels_filtered = pred_labels_flat[mask]

precision, recall, f1, _ = precision_recall_fscore_support(true_labels_filtered,
                                                           pred_labels_filtered,
                                                           average="macro")
accuracy = accuracy_score(true_labels_filtered, pred_labels_filtered)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print("--------------------------------")

precision, recall, f1, _ = precision_recall_fscore_support(true_labels_filtered,
                                                           pred_labels_filtered,
                                                           zero_division=0)

print(f1)

Accuracy: 0.7685
Precision: 0.0824
Recall: 0.1331
F1 Score: 0.0839
--------------------------------
[0.87708948 0.2        0.16666667 0.33581165 0.         0.25742574
 0.21276596 0.31976314 0.46459512 0.32982456 0.35326087 0.31658291
 0.         0.28571429 0.11232449 0.29487179 0.         0.08333333
 0.3238155  0.0212766  0.         0.34245115 0.         0.
 0.20509194 0.         0.3853211  0.35035035 0.17085427 0.
 0.         0.         0.         0.         0.         0.
 0.083986   0.         0.         0.12329969 0.         0.
 0.         0.         0.         0.37908497 0.         0.
 0.02919708 0.         0.         0.18670395 0.         0.
 0.19876204 0.         0.         0.06304729 0.         0.
 0.31963678 0.02083333 0.         0.24983097 0.         0.
 0.         0.         0.         0.11974808 0.         0.
 0.09691961 0.         0.         0.19499124 0.         0.
 0.01520913 0.         0.         0.20883326 0.0212766  0.
 0.22477064 0.         0.         0.         0.   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Model Parameters:

In [22]:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = total_params - trainable_params

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
print(f"Non-trainable Parameters: {non_trainable_params}")


Total Parameters: 109026981
Trainable Parameters: 819621
Non-trainable Parameters: 108207360


#### Save True and Predicted Labels for Analysis

In [23]:
with open("./data/true_labels.ob", 'wb') as fp:
    pickle.dump(true_labels_filtered, fp)
    
with open("./data/pre_labels.ob", 'wb') as fp:
    pickle.dump(pred_labels_filtered, fp)