# Training Routine for Parallel Multi-Task Finetuning using Casehold STS, Privacy Policy Q&A, and Keyword search 

In [1]:
import os
import json
import re
import string
import random
import time
import datetime

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt

from argparse import Namespace
from tqdm.notebook import tqdm

# from datasets import Dataset

import transformers
from transformers import BertTokenizer, BertModel, BertConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import pipeline
from transformers import BertTokenizer, DataCollatorForLanguageModeling

import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.feature_extraction import text
from sklearn.metrics.pairwise import cosine_similarity, linear_kernel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

In [2]:
args = Namespace(
    sts_datapath = "processed_data/casehold_processed.csv",
    qa_datapath = "raw_data/ir_data/privacy_policy/policy_train_data.csv",
    ir_datapath = "./processed_data/ir_data.csv",
    model_save_path = './models/parallel_three',
    num_samples = 15000,
    train_split = 0.7,
    epochs = 4,
    learning_rate=1e-5
)

## Data Preparation

In [3]:
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')

In [4]:
sts_df = pd.read_csv(args.sts_datapath)[0:args.num_samples]
sts_df['split'] = 'train'
num_val_rows = int(len(sts_df) * (1 - args.train_split)//2) - 1
# 15% for validation and test each , remaining 70% for train
sts_df.loc[:num_val_rows, 'split'] = 'val'
sts_df.loc[num_val_rows: num_val_rows + num_val_rows, 'split'] = 'test'

In [5]:
qa_df = pd.read_csv(args.qa_datapath, sep='\t')[0:args.num_samples]
qa_df['split'] = 'train'
num_val_rows = int(len(qa_df) * (1 - args.train_split)//2) - 1
# 15% for validation and test each , remaining 70% for train
qa_df.loc[:num_val_rows, 'split'] = 'val'
qa_df.loc[num_val_rows: num_val_rows + num_val_rows, 'split'] = 'test'
mapping_dict = {'Relevant': 1, 'Irrelevant': 0}

# Apply the mapping to the 'Label' column
qa_df['Label'] = qa_df['Label'].str.strip().map(mapping_dict)

In [6]:
ir_df = pd.read_csv(args.ir_datapath)[0:args.num_samples]
ir_df['split'] = 'train'
num_val_rows = int(len(ir_df) * (1 - args.train_split)//2) - 1
# 15% for validation and test each , remaining 70% for train
ir_df.loc[:num_val_rows, 'split'] = 'val'
ir_df.loc[num_val_rows: num_val_rows + num_val_rows, 'split'] = 'test'

ir_df['lab'] = 1

In [7]:
from torch.utils.data import Dataset
import torch

class CombinedDataset(Dataset):
    def __init__(self, tokenizer, sts_df, qa_df, ir_df, split, max_length=312):
        self.tokenizer = tokenizer
        self.sts_df = sts_df[sts_df['split'] == split].reset_index(drop=True)
        self.qa_df = qa_df[qa_df['split'] == split].reset_index(drop=True)
        self.ir_df = ir_df[ir_df['split'] == split].reset_index(drop=True)
        
        self.max_length = max_length

        # STS processing
        self.encodings_sts = [tokenizer.encode_plus(row['context'], row['holding'], 
                             add_special_tokens=True, max_length=max_length, 
                             pad_to_max_length=True, truncation=True, return_tensors="pt") 
                             for _, row in self.sts_df.iterrows()]
        self.labels_sts = torch.tensor(self.sts_df['binary_label'].tolist())

        # Question and answering processing 
        self.encoding_qa = [tokenizer.encode_plus(row['Query'], row['Segment'], 
                             add_special_tokens=True, max_length=max_length, 
                             pad_to_max_length=True, truncation=True, return_tensors="pt") 
                             for _, row in self.qa_df.iterrows()]
        self.labels_qa = torch.tensor(self.qa_df['Label'].tolist())

        # IR processing 
        self.encoding_ir = [tokenizer.encode_plus(row['label'][0], row['provision'], 
                             add_special_tokens=True, max_length=max_length, 
                             pad_to_max_length=True, truncation=True, return_tensors="pt") 
                             for _, row in self.ir_df.iterrows()]
        self.labels_ir = torch.tensor(self.ir_df['lab'].tolist())

    def __len__(self):
        return len(self.encodings_sts)

    def __getitem__(self, idx):
        # Handling STS
        item_sts = self.encodings_sts[idx]
        input_ids_sts = item_sts['input_ids'].squeeze()  
        attention_mask_sts = item_sts['attention_mask'].squeeze()
        token_type_ids_sts = item_sts['token_type_ids'].squeeze()
        label_sts = self.labels_sts[idx]

        # Handling Question and Answering
        item_qa = self.encoding_qa[idx]
        input_ids_qa = item_qa['input_ids'].squeeze()
        attention_mask_qa = item_qa['attention_mask'].squeeze()
        token_type_ids_qa = item_qa['token_type_ids'].squeeze() 
        label_qa = self.labels_qa[idx]

        # IR 
        item_ir = self.encoding_ir[idx]
        input_ids_ir = item_ir['input_ids'].squeeze()
        attention_mask_ir = item_ir['attention_mask'].squeeze()
        token_type_ids_ir = item_ir['token_type_ids'].squeeze() 
        label_ir = self.labels_ir[idx]


        return {
            'input_ids_sts': input_ids_sts,
            'attention_mask_sts': attention_mask_sts,
            'token_type_ids_sts': token_type_ids_sts,
            'labels_sts': label_sts,  
            #----------------------------------------
            'input_ids_qa' : input_ids_qa,
            'attention_mask_qa' : attention_mask_qa,
            'token_type_ids_qa' :token_type_ids_qa,
            'labels_qa' : label_qa,
            #----------------------------------------
            'input_ids_ir' : input_ids_ir,
            'attention_mask_ir' : attention_mask_ir,
            'token_type_ids_ir' : token_type_ids_ir,
            'labels_ir':label_ir
        }


In [8]:
combined_dataset = CombinedDataset(tokenizer, sts_df, qa_df, ir_df, 'train')
dataloader = DataLoader(combined_dataset,  batch_size=8, shuffle=True)

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

## Training

In [9]:
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

print(device)

cuda


In [10]:
from transformers import BertModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn

class MultiTaskModel(PreTrainedModel):
    def __init__(self, bert_model_name, num_labels_bin):
        config = BertConfig.from_pretrained(bert_model_name)
        super(MultiTaskModel, self).__init__(config)
        self.num_labels = num_labels_bin
        
        # Load the pre-trained BertModel
        self.bert = BertModel.from_pretrained(bert_model_name, config=config)
        
        # Define the Masked Language Model (MLM) head
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
        
        # Define the classification head
        self.sts_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels_bin)
        )

        # Define the classification head
        self.qa_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels_bin)
        )

        # Define the classification head
        self.ir_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels_bin)
        )
    
    def forward(self, input_ids_sts, attention_mask_sts, token_type_ids_sts, labels_sts, input_ids_qa, attention_mask_qa, token_type_ids_qa, labels_qa, input_ids_ir, attention_mask_ir, token_type_ids_ir, labels_ir):
        
        
        
        # Process input through BertModel
        outputs_sts = self.bert(input_ids=input_ids_sts, 
                                  attention_mask=attention_mask_sts, 
                                  token_type_ids=token_type_ids_sts)
        
        pooled_output_sts = outputs_sts.pooler_output

        # Compute classification logits
        sts_logits = self.sts_head(pooled_output_sts)

        #====================================================================

         # Process input through BertModel
        outputs_qa = self.bert(input_ids=input_ids_qa, 
                                  attention_mask=attention_mask_qa, 
                                  token_type_ids=token_type_ids_qa)
        
        pooled_output_qa = outputs_qa.pooler_output

        # Compute classification logits
        qa_logits = self.qa_head(pooled_output_qa)

        #=====================================================================
         # Process input through BertModel
        outputs_ir = self.bert(input_ids=input_ids_ir, 
                                  attention_mask=attention_mask_ir, 
                                  token_type_ids=token_type_ids_ir)
        
        pooled_output_ir = outputs_ir.pooler_output

        # Compute classification logits
        ir_logits = self.ir_head(pooled_output_ir)
        #=====================================================================

        # Compute losses if labels are provided
        loss = None
        losses = {}
        if labels_sts is not None and labels_qa is not None and labels_ir is not None:
            loss_fct= nn.CrossEntropyLoss()
            
            sts_loss = loss_fct(sts_logits.view(-1, self.num_labels), labels_sts.view(-1))
            losses['sts_loss'] = sts_loss
            loss = sts_loss
            
            qa_loss = loss_fct(qa_logits.view(-1, self.num_labels), labels_qa.view(-1))
            losses['qa_loss'] = qa_loss
            loss += qa_loss

            ir_loss = loss_fct(ir_logits.view(-1, self.num_labels), labels_ir.view(-1))
            losses['ir_loss'] = ir_loss
            loss += ir_loss


        return {
            "loss": loss,
            "logits_sts": sts_logits,
            "logits_qa" : qa_logits,
            "logits_ir" : ir_logits,
            "hidden_states": outputs_sts.hidden_states,
            "attentions": outputs_sts.attentions,
        }

In [11]:
from transformers import AdamW

model = MultiTaskModel('./models/mlm_model', 2) # Assuming binary classification


model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(dataloader) * args.epochs)
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())

Some weights of BertModel were not initialized from the model checkpoint at ./models/mlm_model and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


540670976
589299712




In [12]:
batch_progress = tqdm(total=0, desc='Batches', leave=True)

Batches: 0it [00:00, ?it/s]

In [13]:
for epoch in range(args.epochs):
    total_loss = 0.0
    batch_progress.reset(total=len(dataloader))
    for step, batch in enumerate(dataloader):
        # print(torch.cuda.memory_allocated())
        # Move batch data to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Adjust model inputs according to the new forward method signature
        outputs = model(input_ids_sts=batch['input_ids_sts'], 
                        attention_mask_sts=batch['attention_mask_sts'],
                        token_type_ids_sts=batch['token_type_ids_sts'],
                        labels_sts=batch['labels_sts'].long(),
                        input_ids_qa=batch['input_ids_qa'], 
                        attention_mask_qa=batch['attention_mask_qa'],
                        token_type_ids_qa=batch['token_type_ids_qa'],
                        labels_qa=batch['labels_qa'].long(),
                        input_ids_ir=batch['input_ids_ir'], 
                        attention_mask_ir=batch['attention_mask_ir'],
                        token_type_ids_ir=batch['token_type_ids_ir'],
                        labels_ir=batch['labels_ir'].long(),
                        )
        

        # Combine MLM and classification losses
        loss = outputs['loss']

        if loss is None : print('test')

        # loss = loss / 4
        loss.backward()
        total_loss += loss.item()
        
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()     

         # Update batch progress bar and metrics
        batch_progress.set_postfix({'Average Loss': total_loss / (step + 1)})
        batch_progress.update(1)  # Increment the progress bar
       

    print(f'Epoch {epoch + 1}/{args.epochs}, Average Loss: {total_loss / len(dataloader)}')


Epoch 1/4, Average Loss: 0.593444871685701
Epoch 2/4, Average Loss: 0.42510560488573756
Epoch 3/4, Average Loss: 0.28216828634183616
Epoch 4/4, Average Loss: 0.13892293061922048


In [14]:
model.save_pretrained('./models/parallel_three_mlm')

## Evaluation

In [15]:
combined_dataset_test = CombinedDataset(tokenizer, sts_df, qa_df, ir_df, 'test')
dataloader_test = DataLoader(combined_dataset,  batch_size=8, shuffle=True)

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [16]:
def evaluate_model(model_path, title):
    batch_progress = tqdm(total=len(dataloader_test), desc='Batches', leave=True)

    # load model and tokenizer
    model = MultiTaskModel(model_path, 2)

    # Check if cuda available
    if torch.cuda.is_available():
        # model.to('cuda')
        device = 'cuda'
    else:
        # model.to('cpu')
        device = 'cpu'

    print(device)

    model.to(device)

    model.eval()

    predictions_sts, labels_sts = [], []
    predictions_qa, labels_qa = [], []
    predictions_ir, labels_ir = [], []

    print('Evaluating ' + f'[{title}]')
    print('============================================')

    with torch.no_grad():
        for batch in dataloader_test:
            batch = {k: v.to(device) for k, v in batch.items()}

            total_loss = 0.0
        
            # Adjust model inputs according to the new forward method signature
            outputs = model(input_ids_sts=batch['input_ids_sts'], 
                            attention_mask_sts=batch['attention_mask_sts'],
                            token_type_ids_sts=batch['token_type_ids_sts'],
                            labels_sts=batch['labels_sts'].long(),
                            input_ids_qa=batch['input_ids_qa'], 
                            attention_mask_qa=batch['attention_mask_qa'],
                            token_type_ids_qa=batch['token_type_ids_qa'],
                            labels_qa=batch['labels_qa'].long(),
                            input_ids_ir=batch['input_ids_ir'], 
                            attention_mask_ir=batch['attention_mask_ir'],
                            token_type_ids_ir=batch['token_type_ids_ir'],
                            labels_ir=batch['labels_ir'].long())
            

            # Combine MLM and classification losses
            loss = outputs['loss']

            total_loss += loss.item()

            logits_sts = outputs['logits_sts']
            logits_qa = outputs['logits_qa']
            logits_ir = outputs['logits_ir']

            preds_sts = torch.argmax(logits_sts, dim=1).flatten()
            preds_qa = torch.argmax(logits_qa, dim=1).flatten()
            preds_ir = torch.argmax(logits_ir, dim=1).flatten()

            # Update batch progress bar and metrics
            batch_progress.set_postfix({'Average Loss': total_loss / (step + 1)})
            batch_progress.update(1)  # Increment the progress bar

            predictions_sts.extend(preds_sts.cpu().numpy())
            labels_sts.extend(batch['labels_sts'].cpu().numpy())

            predictions_qa.extend(preds_qa.cpu().numpy())
            labels_qa.extend(batch['labels_qa'].cpu().numpy())

            predictions_ir.extend(preds_ir.cpu().numpy())
            labels_ir.extend(batch['labels_ir'].cpu().numpy())

        precision_sts, recall_sts, f1_sts, _ = precision_recall_fscore_support(labels_sts, predictions_sts, average='binary')
        precision_qa, recall_qa, f1_qa, _ = precision_recall_fscore_support(labels_qa, predictions_qa, average='binary')
        precision_ir, recall_ir, f1_ir, _ = precision_recall_fscore_support(labels_ir, predictions_ir, average='binary')

        accuracy_sts = accuracy_score(labels_sts, predictions_sts)
        accuracy_qa = accuracy_score(labels_qa, predictions_qa)
        accuracy_ir = accuracy_score(labels_ir, predictions_ir)

        print(f'STS :  Accuracy: {accuracy_sts}\nPrecision: {precision_sts}\nRecall: {recall_sts}\nF1 Score: {f1_sts}')
        print(f'QA  :  Accuracy: {accuracy_qa}\nPrecision: {precision_qa}\nRecall: {recall_qa}\nF1 Score: {f1_qa}')
        print(f'IR  :  Accuracy: {accuracy_ir}\nPrecision: {precision_ir}\nRecall: {recall_ir}\nF1 Score: {f1_ir}')

        print(f'Average Accuracy :', np.mean([accuracy_sts, accuracy_qa, accuracy_ir]))

In [17]:
evaluate_model('./models/parallel_three_mlm',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/1313 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
STS :  Accuracy: 0.8900104751928388
Precision: 0.6504297994269341
Recall: 0.9728571428571429
F1 Score: 0.7796222095020034
QA  :  Accuracy: 0.016284163413008285
Precision: 0.016042780748663103
Recall: 0.29136690647482016
F1 Score: 0.030411113196921344
IR  :  Accuracy: 0.0
Precision: 0.0
Recall: 0.0
F1 Score: 0.0
Average Accuracy : 0.3020982128686157


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


: 

In [20]:
evaluate_model('bert-base-uncased',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/1313 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
STS :  Accuracy: 0.20293305399485764
Precision: 0.20051585785250287
Recall: 0.9995238095238095
F1 Score: 0.3340229153405474
QA  :  Accuracy: 0.32663555851823634
Precision: 0.03163191948238677
Recall: 0.39568345323741005
F1 Score: 0.05858074823592065
IR  :  Accuracy: 0.9960956099419103
Precision: 1.0
Recall: 0.9960956099419103
F1 Score: 0.9980439864510281
Average Accuracy : 0.5085547408183347


In [21]:
evaluate_model('./models/sentence_pair_classification',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/1313 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
STS :  Accuracy: 0.5550899914293876
Precision: 0.07973856209150326
Recall: 0.11619047619047619
F1 Score: 0.09457364341085271
QA  :  Accuracy: 0.5164270069517188
Precision: 0.041379310344827586
Recall: 0.3669064748201439
F1 Score: 0.07437112650382792
IR  :  Accuracy: 0.23473954861441768
Precision: 1.0
Recall: 0.23473954861441768
F1 Score: 0.38022520438068796
Average Accuracy : 0.4354188489985081


In [22]:
evaluate_model('./models/mlm_model',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/1313 [00:00<?, ?it/s]

Some weights of BertModel were not initialized from the model checkpoint at ./models/mlm_model and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
STS :  Accuracy: 0.24359584801447481
Precision: 0.18392296873309533
Recall: 0.8095238095238095
F1 Score: 0.299744335713656
QA  :  Accuracy: 0.646509856204171
Precision: 0.03423848878394333
Recall: 0.20863309352517986
F1 Score: 0.058823529411764705
IR  :  Accuracy: 0.34939529568612515
Precision: 1.0
Recall: 0.34939529568612515
F1 Score: 0.5178546224417784
Average Accuracy : 0.413166999968257


In [23]:
evaluate_model('casehold/legalbert',  'Sequence Pair Classificaiton Evaluation Metrics')

Batches:   0%|          | 0/1313 [00:00<?, ?it/s]

cuda
Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
STS :  Accuracy: 0.8000190458051614
Precision: 0.0
Recall: 0.0
F1 Score: 0.0
QA  :  Accuracy: 0.052947338348728695
Precision: 0.052947338348728695
Recall: 1.0
F1 Score: 0.10056977480329203
IR  :  Accuracy: 0.0
Precision: 0.0
Recall: 0.0
F1 Score: 0.0
Average Accuracy : 0.2843221280512967


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
