In [1]:
import os
import json
import re
import string
import random
import time
import datetime

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt

from argparse import Namespace
from tqdm.notebook import tqdm

# from datasets import Dataset

import transformers
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import pipeline
from transformers import BertTokenizer, DataCollatorForLanguageModeling

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.feature_extraction import text
from sklearn.metrics.pairwise import cosine_similarity, linear_kernel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

2024-03-25 17:40:22.281417: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
args = Namespace(
    data_path = './processed_data/combined_ir.csv',
    base_model_path='jimmyjz1127/multi_parallel',
    model_save_path = './models/multi_parallel_mlm_test2',
    max_samples=15000,
    train_split=0.7,
    epochs = 5,
    freeze=False
)

## Pre-process Data 

In [3]:
df = pd.read_csv(args.data_path)[0:args.max_samples]
num_train_rows = int(len(df) * (1 - args.train_split)//2) - 1
df['split'] = 'train'
df.loc[:num_train_rows, 'split'] = 'val'
df.loc[num_train_rows:num_train_rows + num_train_rows, 'split'] = 'test'

In [4]:
print('Length of Train Split      :', len(df[df['split'] == 'train']))
print('Length of Train Validation :', len(df[df['split'] == 'val']))
print('Length of Train Test       :', len(df[df['split'] == 'test']))

Length of Train Split      : 10501
Length of Train Validation : 2249
Length of Train Test       : 2250


In [5]:
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')

In [6]:
####################################################
############## Setup Train Dataloader ##############
####################################################

encoded_data_train = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'train'].iterrows()]
input_ids_train = [item['input_ids'] for item in encoded_data_train]
attention_masks_train = [item['attention_mask'] for item in encoded_data_train]
labels_train = [row['label'] for index,row in df[df['split'] == 'train'].iterrows()]

# Convert to tensors
input_ids_train = torch.tensor(input_ids_train)
attention_masks_train = torch.tensor(attention_masks_train)
labels_train = torch.tensor(labels_train)


dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [7]:
####################################################
############## Setup Val Dataloader ################
####################################################

encoded_data_val = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'val'].iterrows()]
input_ids_val = [item['input_ids'] for item in encoded_data_val]
attention_masks_val = [item['attention_mask'] for item in encoded_data_val]
labels_val = [row['label'] for index,row in df[df['split'] == 'val'].iterrows()]

# Convert to tensors
input_ids_val = torch.tensor(input_ids_val)
attention_masks_val = torch.tensor(attention_masks_val)
labels_val = torch.tensor(labels_val)

# Create a dataset
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [8]:
###################################################
############## Setup Test Dataloader ##############
###################################################

encoded_data_test = [tokenizer.encode_plus(row['query'], row['result'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in df[df['split'] == 'test'].iterrows()]
input_ids_test = [item['input_ids'] for item in encoded_data_test]
attention_masks_test = [item['attention_mask'] for item in encoded_data_test]
labels_test = [row['label'] for index,row in df[df['split'] == 'test'].iterrows()]

# Convert to tensors
input_ids_test = torch.tensor(input_ids_test)
attention_masks_test = torch.tensor(attention_masks_test)
labels_test = torch.tensor(labels_test)

# Create a dataset
dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)

dataloader_test = DataLoader(dataset_test, batch_size=16, shuffle=True) # NOTE : maybe set pin_memory=True

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

## Training 

In [9]:
#########################################################
############ Train Routine Utility Functions ############
#########################################################

def compute_accuracy(y_pred, y_target):
    y_target = y_target.cpu()
    y_pred_indices = (torch.sigmoid(y_pred)>0.5).cpu().long()#.max(dim=1)[1]
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100


def calculate_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

### The fine-tune routine 

In [None]:


from transformers import AdamW

'''
TRAINING SETUP
'''
# Check if CUDA acceleration is available 
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Initialize model with sequence classification head 
model = BertForSequenceClassification.from_pretrained(args.base_model_path)

# Conditoinally freeze core encoder layers
if args.freeze:
  for param in model.base_model.parameters():
    param.requires_grad  = False
  model.classifier = torch.nn.Linear(model.config.hidden_size, 2)
  optimizer = AdamW(model.classifier.parameters(), lr=5e-5)
else:
  optimizer = AdamW(model.parameters(), lr=5e-5)

model.to(device)
model.train()

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=len(dataloader_train) * 0.05, num_training_steps=len(dataloader_train) * args.epochs)
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())

print('============================================')

train_progress = tqdm(total=0, desc='Train Batches', leave=True)
validation_progress = tqdm(total=0, desc='Validation Batches', leave=True)
epoch_progress = tqdm(total=args.epochs, desc='Epoch', leave=True)

best_val_accuracy = 0.0
patience = 3
num_epochs_no_improvement = 0


loss_fn = torch.nn.CrossEntropyLoss()

'''
TRAINING LOOP
'''

for epoch in range(args.epochs):

  model.train()
  total_train_loss = 0
  total_train_accuracy = 0 #NEW

  train_progress.reset(total=len(dataloader_train))
  validation_progress.reset(total=len(dataloader_val))

  for step, batch in enumerate(dataloader_train):
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(device), b_input_mask.to(device), b_labels.to(device)

    model.zero_grad()
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

    loss = outputs.loss
    total_train_loss += loss.item()
    loss.backward()

    logits = outputs.logits.detach().cpu().numpy()#NEW
    label_ids = b_labels.to('cpu').numpy()#NEW
    total_train_accuracy += calculate_accuracy(logits, label_ids)#NEW

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    train_progress.update(1)

  avg_train_loss = total_train_loss / len(dataloader_train)
  print(f'Epoch {epoch}: Average Training Loss: {avg_train_loss}')

  model.eval()
  total_eval_accuracy = 0
  total_eval_loss = 0

  for batch in dataloader_val:
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(device), b_input_mask.to(device), b_labels.to(device)

    with torch.no_grad():
      outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

    logits = outputs.logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    total_eval_loss += loss_fn(outputs.logits.squeeze(-1), b_labels).item() # perhaps just outputs.loss (need to include labels as parameter in model() above)

    total_eval_accuracy += calculate_accuracy(logits, label_ids)

    validation_progress.update(1)

  avg_val_accuracy = total_eval_accuracy / len(dataloader_val)
  print(f'Epoch {epoch}: Validation Accuracy: {avg_val_accuracy}')

  # Checkpointing and Early Stopping
  if avg_val_accuracy > best_val_accuracy:
      print(f'Validation accuracy improved from {best_val_accuracy} to {avg_val_accuracy}. Saving model...')
      best_val_accuracy = avg_val_accuracy
      num_epochs_no_improvement = 0
      # Save the model using save_pretrained
      model.save_pretrained(args.model_save_path)
  else:
      num_epochs_no_improvement += 1
      if num_epochs_no_improvement >= patience:
          print("Early stopping triggered.")
          break  # Exit the training loop

  epoch_progress.update(1)

## Evaluation 

In [9]:
from transformers import BertTokenizer, BertModel, BertConfig, AutoModel
def evaluate_sequence_pair_class(model_path,  title):
    '''
    Routine for evaluating model for sequence pair classification
    '''
    batch_progress = tqdm(total=len(dataloader_test), desc='Batches', leave=True)

    # load model and tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    device = 'cpu'

    # Check if cuda available
    if torch.cuda.is_available():
        # model.to('cuda')
        device = 'cuda'

    print(device)

    model.to(device)

    model.eval()

    predictions, true_labels = [], []

    print('Evaluating ' + f'[{title}]')
    print('============================================')


    with torch.no_grad(): # disable calculating gradients (more efficient for evaluation)
        for batch in dataloader_test:
            input_ids, attention_mask, labels = tuple(t.to(device) for t in batch)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).flatten() # find index of max value in logits tensor (where each index corresponds to a binary class)

            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            batch_progress.update(1)


    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, acc = precision_recall_fscore_support(true_labels, predictions, average='binary')
    print(f'Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}')



### Textual Entailment Model 

In [15]:
evaluate_sequence_pair_class('jimmyjz1127/te_model_test',  'TE Model')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [TE Model]
Accuracy: 0.8768888888888889
Precision: 0.857331571994716
Recall: 0.793398533007335
F1 Score: 0.8241269841269842


### Vanilla bert-base-uncased Model

In [28]:
evaluate_sequence_pair_class('jimmyjz1127/base_test',  'bert-base-uncased')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [bert-base-uncased]
Accuracy: 0.8342222222222222
Precision: 0.8060522696011004
Recall: 0.7163814180929096
F1 Score: 0.7585760517799353


### Parallel Multi-Task Model With Further MLM Pre-Training

In [11]:
evaluate_sequence_pair_class('jimmyjz1127/multi_parallel_mlm_test',  'Parallel Multi-Task With Further MLM Pre-training')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [Parallel Multi-Task With Further MLM Pre-training]
Accuracy: 0.8897777777777778
Precision: 0.9930795847750865
Recall: 0.7017114914425427
F1 Score: 0.8223495702005731


### Selective Question & Answering Model

In [20]:
evaluate_sequence_pair_class('jimmyjz1127/qa_test',  'QA Model')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [QA Model]
Accuracy: 0.7577777777777778
Precision: 0.7922912205567452
Recall: 0.45232273838630804
F1 Score: 0.5758754863813229


### Sequential Multi-Task Model

In [21]:
evaluate_sequence_pair_class('jimmyjz1127/test_combined_base',  'Sequential Multi-Task Model')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [Sequential Multi-Task Model]
Accuracy: 0.8395555555555556
Precision: 0.8278335724533716
Recall: 0.7053789731051344
F1 Score: 0.7617161716171618


### Parallel Multi-Task Model

In [24]:
evaluate_sequence_pair_class('jimmyjz1127/combined/multitask_parallel',  'Parallel Multi-Task')

Batches:   0%|          | 0/141 [00:00<?, ?it/s]

cuda
Evaluating [Parallel Multi-Task]
Accuracy: 0.9075555555555556
Precision: 0.9856687898089171
Recall: 0.7567237163814181
F1 Score: 0.8561549100968188
