In [1]:
import os
import json
import re
import string
import random
import time
import datetime

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt

from argparse import Namespace
from tqdm.notebook import tqdm

# from datasets import Dataset

import transformers
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import pipeline
from transformers import BertTokenizer, DataCollatorForLanguageModeling

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.feature_extraction import text
from sklearn.metrics.pairwise import cosine_similarity, linear_kernel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

In [3]:
args = Namespace(
    data_path = './processed_data/combined_ir.csv',
    base_model_path='bert-base-uncased',
    model_save_path = './models/basebert_combined',
    max_samples=15000,
    train_split=0.7,
    epochs = 5
)

## Pre-process Data 

In [4]:
df = pd.read_csv(args.data_path)[0:args.max_samples]
num_train_rows = int(len(df) * (1 - args.train_split)//2) - 1
df['split'] = 'train'
df.loc[:num_train_rows, 'split'] = 'val'
df.loc[num_train_rows:num_train_rows + num_train_rows, 'split'] = 'test'

In [5]:
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')

In [6]:
####################################################
############## Setup Train Dataloader ##############
####################################################

encoded_data_train = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'train'].iterrows()]
input_ids_train = [item['input_ids'] for item in encoded_data_train]
attention_masks_train = [item['attention_mask'] for item in encoded_data_train]
labels_train = [row['label'] for index,row in df[df['split'] == 'train'].iterrows()]

# Convert to tensors
input_ids_train = torch.tensor(input_ids_train)
attention_masks_train = torch.tensor(attention_masks_train)
labels_train = torch.tensor(labels_train)


dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [7]:
####################################################
############## Setup Val Dataloader ################
####################################################

encoded_data_val = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'val'].iterrows()]
input_ids_val = [item['input_ids'] for item in encoded_data_val]
attention_masks_val = [item['attention_mask'] for item in encoded_data_val]
labels_val = [row['label'] for index,row in df[df['split'] == 'val'].iterrows()]

# Convert to tensors
input_ids_val = torch.tensor(input_ids_val)
attention_masks_val = torch.tensor(attention_masks_val)
labels_val = torch.tensor(labels_val)

# Create a dataset
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [8]:
###################################################
############## Setup Test Dataloader ##############
###################################################

encoded_data_test = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'test'].iterrows()]
input_ids_test = [item['input_ids'] for item in encoded_data_test]
attention_masks_test = [item['attention_mask'] for item in encoded_data_test]
labels_test = [row['label'] for index,row in df[df['split'] == 'test'].iterrows()]

# Convert to tensors
input_ids_test = torch.tensor(input_ids_test)
attention_masks_test = torch.tensor(attention_masks_test)
labels_test = torch.tensor(labels_test)

# Create a dataset
dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)

dataloader_test = DataLoader(dataset_test, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

## Training 

In [9]:
from transformers import AdamW

device = 'cpu'

if torch.cuda.is_available():
    device = 'cuda'

model = BertForSequenceClassification.from_pretrained(args.base_model_path)


model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=len(dataloader_train) * 0.05, num_training_steps=len(dataloader_train) * args.epochs)
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())

Some weights of the model checkpoint at ./models/parallel_three were not used when initializing BertForSequenceClassification: ['ir_head.0.bias', 'ir_head.0.weight', 'ir_head.2.bias', 'ir_head.2.weight', 'mlm_head.bias', 'mlm_head.weight', 'qa_head.0.bias', 'qa_head.0.weight', 'qa_head.2.bias', 'qa_head.2.weight', 'sts_head.0.bias', 'sts_head.0.weight', 'sts_head.2.bias', 'sts_head.2.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/parall

439076352
494927872




In [10]:
#########################################################
############ Train Routine Utility Functions ############
#########################################################

def compute_accuracy(y_pred, y_target):
    y_target = y_target.cpu()
    y_pred_indices = (torch.sigmoid(y_pred)>0.5).cpu().long()#.max(dim=1)[1]
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100


def calculate_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [11]:
train_progress = tqdm(total=0, desc='Train Batches', leave=True)
validation_progress = tqdm(total=0, desc='Validation Batches', leave=True)
epoch_progress = tqdm(total=args.epochs, desc='Epoch', leave=True)

best_val_accuracy = 0.0
patience = 3
num_epochs_no_improvement = 0


loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(args.epochs):

  model.train()
  total_train_loss = 0
  total_train_accuracy = 0 #NEW

  train_progress.reset(total=len(dataloader_train))
  validation_progress.reset(total=len(dataloader_val))

  for step, batch in enumerate(dataloader_train):
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(device), b_input_mask.to(device), b_labels.to(device)

    model.zero_grad()
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

    loss = outputs.loss
    total_train_loss += loss.item()
    loss.backward()

    logits = outputs.logits.detach().cpu().numpy()#NEW
    label_ids = b_labels.to('cpu').numpy()#NEW
    total_train_accuracy += calculate_accuracy(logits, label_ids)#NEW

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    train_progress.update(1)

  avg_train_loss = total_train_loss / len(dataloader_train)
  print(f'Epoch {epoch}: Average Training Loss: {avg_train_loss}')

  model.eval()
  total_eval_accuracy = 0
  total_eval_loss = 0

  for batch in dataloader_val:
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(device), b_input_mask.to(device), b_labels.to(device)

    with torch.no_grad():
      outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

    logits = outputs.logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    total_eval_loss += loss_fn(outputs.logits.squeeze(-1), b_labels).item() # perhaps just outputs.loss (need to include labels as parameter in model() above)

    total_eval_accuracy += calculate_accuracy(logits, label_ids)

    validation_progress.update(1)

  avg_val_accuracy = total_eval_accuracy / len(dataloader_val)
  print(f'Epoch {epoch}: Validation Accuracy: {avg_val_accuracy}')

  # Checkpointing and Early Stopping
  if avg_val_accuracy > best_val_accuracy:
      print(f'Validation accuracy improved from {best_val_accuracy} to {avg_val_accuracy}. Saving model...')
      best_val_accuracy = avg_val_accuracy
      num_epochs_no_improvement = 0
      # Save the model using save_pretrained
      model.save_pretrained(args.model_save_path)
      # Optionally save the tokenizer if it's being fine-tuned or used
      # tokenizer.save_pretrained(checkpoint_path)
  else:
      num_epochs_no_improvement += 1
      if num_epochs_no_improvement >= patience:
          print("Early stopping triggered.")
          break  # Exit the training loop

  epoch_progress.update(1)


  

Train Batches: 0it [00:00, ?it/s]

Validation Batches: 0it [00:00, ?it/s]

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0: Average Training Loss: 0.3061172264703242
Epoch 0: Validation Accuracy: 0.8984929078014184
Validation accuracy improved from 0.0 to 0.8984929078014184. Saving model...
Epoch 1: Average Training Loss: 0.25797750005892606
Epoch 1: Validation Accuracy: 0.906570133963751
Validation accuracy improved from 0.8984929078014184 to 0.906570133963751. Saving model...
Epoch 2: Average Training Loss: 0.2169481064738703
Epoch 2: Validation Accuracy: 0.9113475177304965
Validation accuracy improved from 0.906570133963751 to 0.9113475177304965. Saving model...
Epoch 3: Average Training Loss: 0.1636782060348949
Epoch 3: Validation Accuracy: 0.9011524822695035
Epoch 4: Average Training Loss: 0.11439333190642513
Epoch 4: Validation Accuracy: 0.9035657998423955


## Evaluation 

In [12]:
from transformers import BertTokenizer, BertModel, BertConfig
def evaluate_sequence_pair_class(model_path,  title):
    '''
    Routine for evaluating model for sequence pair classification
    '''
    batch_progress = tqdm(total=len(dataloader_test), desc='Batches', leave=True)

    # load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # config = BertConfig.from_pretrained('./models/mlm_model')
    # config.num_labels = 2
    # model = BertForSequenceClassification(config=config)
    # bert_state_dict = torch.load(model_path)
    # model.bert.load_state_dict(bert_state_dict)

    device = 'cpu'

    # Check if cuda available
    if torch.cuda.is_available():
        # model.to('cuda')
        device = 'cuda'

    print(device)

    model.to(device)

    model.eval()

    predictions, true_labels = [], []

    print('Evaluating ' + f'[{title}]')
    print('============================================')


    with torch.no_grad(): # disable calculating gradients (more efficient for evaluation)
        for batch in dataloader_test:
            input_ids, attention_mask, labels = tuple(t.to(device) for t in batch)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).flatten() # find index of max value in logits tensor (where each index corresponds to a binary class)

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            batch_progress.update(1)


    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, acc = precision_recall_fscore_support(true_labels, predictions, average='binary')
    # roc_auc = roc_auc_score(true_labels, predictions)  # Uncomment if ROC-AUC is needed
    print(acc)
    print(f'Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}')



In [14]:
evaluate_sequence_pair_class('./models/parallel_combined',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.9022222222222223
Precision: 0.9657320872274143
Recall: 0.7579462102689487
F1 Score: 0.8493150684931506


In [13]:
evaluate_sequence_pair_class('./models/parallel_combined_mlm',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.9053333333333333
Precision: 0.9886914378029079
Recall: 0.7481662591687042
F1 Score: 0.8517745302713987


: 

In [10]:
evaluate_sequence_pair_class('./models/mlm_model_manual1',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/mlm_model_manual1 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.411
Precision: 0.3494897959183674
Recall: 0.7762039660056658
F1 Score: 0.48197009674582236


In [11]:
evaluate_sequence_pair_class('casehold/legalbert',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at casehold/legalbert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.3325
Precision: 0.3108839446782922
Recall: 0.7322946175637394
F1 Score: 0.43647108484592656


In [12]:
evaluate_sequence_pair_class('./models/parallel_three_mlm',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/parallel_three_mlm and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.6335
Precision: 0.23529411764705882
Recall: 0.0169971671388102
F1 Score: 0.031704095112285335


In [13]:
evaluate_sequence_pair_class('./models/sentence_pair_classification',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.6875
Precision: 0.7612903225806451
Recall: 0.1671388101983003
F1 Score: 0.27409988385598144


In [14]:
evaluate_sequence_pair_class('./models/mlm_casehold',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.676
Precision: 0.6074074074074074
Recall: 0.23229461756373937
F1 Score: 0.3360655737704918


In [15]:
evaluate_sequence_pair_class('./models/casehold_mlm',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/casehold_mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.475
Precision: 0.39858490566037735
Recall: 0.9575070821529745
F1 Score: 0.5628642797668609


In [17]:
evaluate_sequence_pair_class('./models/mlm_sts',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/mlm_sts and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.647
Precision: 0.0
Recall: 0.0
F1 Score: 0.0


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


: 

In [12]:
evaluate_sequence_pair_class('./models/parallel_three_mlm_casehold',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

cpu
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
None
Accuracy: 0.697
Precision: 0.6700680272108843
Recall: 0.2790368271954674
F1 Score: 0.394
