In [1]:
import os
import json
import re
import string
import random
import time
import datetime

import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt

from argparse import Namespace
from tqdm import tqdm
from datasets import Dataset

import transformers
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import pipeline

import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, TensorDataset

from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.feature_extraction import text
from sklearn.metrics.pairwise import cosine_similarity, linear_kernel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = Namespace(
    data_path = 'processed_data/casehold_processed.csv',
    tokenizer_save_path = 'tokenizer/casehold_tokenizer',
    pretuned_model_path = "./models/mlm_model_manual",
)

In [3]:
casehold_df = pd.read_csv(args.data_path)

## Setup Tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained('casehold/legalbert')

## 1 Sentence Pair Classification
- Utilize binary labels (related or not related) for each pair

In [5]:
class_args = Namespace(
    # model_save_path = 'models/sentence_pair_classification',
    model_save_path = 'models/mlm_casehold',
    num_samples=15000,
    batch_size = 16,
    learn_rate = 2e-5,
    epochs = 3,
    device='cpu',
    train_split=0.7,
    model_state_file='casehold_state.pth'
)

In [6]:
# class_df = casehold_df[:class_args.num_samples]
class_df = casehold_df

In [7]:
class_df['split'] = 'train'

num_val_rows = int(len(class_df) * (1 - class_args.train_split)//2) - 1

# 15% for validation and test each , remaining 70% for train
class_df.loc[:num_val_rows, 'split'] = 'val'
class_df.loc[num_val_rows: num_val_rows + num_val_rows, 'split'] = 'test'

print('Number of train samples : ' + str((class_df['split'] == 'train').sum()))
print('Number of val samples : ' + str((class_df['split'] == 'val').sum()))
print('Number of test samples : ' + str((class_df['split'] == 'test').sum()))


class_df.head()

Number of train samples : 10501
Number of val samples : 2249
Number of test samples : 2250


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  class_df['split'] = 'train'


Unnamed: 0.1,Unnamed: 0,context,holding,binary_label,relevance_label,split
0,0,"Drapeau’s cohorts, the cohort would be a “vict...",holding that possession of a pipe bomb is a cr...,1,1.0,val
1,1,"Drapeau’s cohorts, the cohort would be a “vict...",holding that bank robbery by force and violenc...,0,0.652,val
2,2,"Drapeau’s cohorts, the cohort would be a “vict...",holding that sexual assault of a child qualifi...,0,0.647,val
3,3,"Drapeau’s cohorts, the cohort would be a “vict...",holding for the purposes of 18 usc 924e that ...,0,0.67,val
4,4,"Drapeau’s cohorts, the cohort would be a “vict...",holding that a court must only look to the sta...,0,0.639,val


1.1 Data preparation

In [10]:
####################################################
############## Setup Train Dataloader ##############
####################################################

encoded_data_train = [tokenizer.encode_plus(row['context'], row['holding'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in class_df[class_df['split'] == 'train'].iterrows()]
input_ids_train = [item['input_ids'] for item in encoded_data_train]
attention_masks_train = [item['attention_mask'] for item in encoded_data_train]
labels_train = [row['binary_label'] for index,row in class_df[class_df['split'] == 'train'].iterrows()]

# Convert to tensors
input_ids_train = torch.tensor(input_ids_train)
attention_masks_train = torch.tensor(attention_masks_train)
labels_train = torch.tensor(labels_train)

# Create a dataset
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)

dataloader_train = DataLoader(dataset_train, batch_size=class_args.batch_size, shuffle=True) # NOTE : maybe set pin_memory=True

In [9]:
####################################################
############## Setup Val Dataloader ################
####################################################

encoded_data_val = [tokenizer.encode_plus(row['context'], row['holding'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in class_df[class_df['split'] == 'val'].iterrows()]
input_ids_val = [item['input_ids'] for item in encoded_data_val]
attention_masks_val = [item['attention_mask'] for item in encoded_data_val]
labels_val = [row['binary_label'] for index,row in class_df[class_df['split'] == 'val'].iterrows()]

# Convert to tensors
input_ids_val = torch.tensor(input_ids_val)
attention_masks_val = torch.tensor(attention_masks_val)
labels_val = torch.tensor(labels_val)

# Create a dataset
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

dataloader_val = DataLoader(dataset_val, batch_size=class_args.batch_size, shuffle=True) # NOTE : maybe set pin_memory=True

In [11]:
###################################################
############## Setup Test Dataloader ##############
###################################################

encoded_data_test = [tokenizer.encode_plus(row['context'], row['holding'], add_special_tokens=True, max_length=512, pad_to_max_length=True, truncation=True) for index,row in class_df[class_df['split'] == 'test'].iterrows()]
input_ids_test = [item['input_ids'] for item in encoded_data_test]
attention_masks_test = [item['attention_mask'] for item in encoded_data_test]
labels_test = [row['binary_label'] for index,row in class_df[class_df['split'] == 'test'].iterrows()]

# Convert to tensors
input_ids_test = torch.tensor(input_ids_test)
attention_masks_test = torch.tensor(attention_masks_test)
labels_test = torch.tensor(labels_test)

# Create a dataset
dataset_test = TensorDataset(input_ids_test, attention_masks_test, labels_test)

dataloader_test = DataLoader(dataset_test, batch_size=class_args.batch_size, shuffle=True) # NOTE : maybe set pin_memory=True

In [11]:
# inspect a single sample

input_ids = encoded_data_train[0].input_ids
subword_view = [tokenizer.convert_ids_to_tokens(id) for id in input_ids]
np.array(subword_view)

array(['[CLS]', 'named', 'as', 'defendants', 'and', 'because', 'ni',
       '##eto', 'did', 'not', 'file', 'a', 'certificate', 'of', 'review',
       '.', 'these', 'assertion', '##s', 'mis', '##int', '##er', '##pre',
       '##t', 'the', 'concept', 'of', 'immunity', '.', 'a', 'person',
       'or', 'entity', 'is', 'not', '“', 'immune', '”', 'from', 'suit',
       'merely', 'because', 'that', 'person', 'or', 'entity', 'asserts',
       'a', 'successful', '•', 'affirmative', 'defense', '.', 'see',
       'wyatt', 'v', '.', 'cole', ',', '50', '##4', 'u', '.', 's', '.',
       '158', ',', '167', 'n', '.', '2', ',', '112', 's', '.', 'ct', '.',
       '1827', ',', '118', 'l', '.', 'ed', '.', '2d', '50', '##4', '(',
       '1992', ')', '(', 'noting', 'that', 'while', 'a', 'defense', 'en',
       '##titles', 'an', 'individual', 'to', 'some', 'protection', 'from',
       'liability', ',', 'he', 'is', 'not', 'entitled', 'to', 'immunity',
       'from', 'suit', ')', '.', 'rather', ',', 'immunity'

1.2 Training

In [18]:
# Initialize model and optimizer

# model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)  # num_labels = 2 for binary classification
model = AutoModelForSequenceClassification.from_pretrained(args.pretuned_model_path, num_labels=2)

# Define optimizer (AdamW is a good default)
optimizer = AdamW(model.parameters(), lr=class_args.learn_rate)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(dataloader_train) * class_args.epochs)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/mlm_model_manual and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
def make_train_state(args):
    return {'stop_early': False,
            'early_stopping_step': 0,
            'early_stopping_best_val': 1e8,
            'learning_rate': args.learn_rate,
            'epoch_index': 0,
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'test_loss': -1,
            'test_acc': -1,
            'model_filename': args.model_state_file}

def update_train_state(args, model, train_state):
    """Handle the training state updates.

    Components:
     - Early Stopping: Prevent overfitting.
     - Model Checkpoint: Model is saved if the model is better

    :param args: main arguments
    :param model: model to train
    :param train_state: a dictionary representing the training state values
    :returns:
        a new train_state
    """

    # Save one model at least
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(), train_state['model_filename'])
        train_state['stop_early'] = False

    # Save model if performance improved
    elif train_state['epoch_index'] >= 1:
        loss_tm1, loss_t = train_state['val_loss'][-2:]

        # If loss worsened
        if loss_t >= train_state['early_stopping_best_val']:
            # Update step
            train_state['early_stopping_step'] += 1
        # Loss decreased
        else:
            # Save the best model
            if loss_t < train_state['early_stopping_best_val']:
                torch.save(model.state_dict(), train_state['model_filename'])

            # Reset early stopping step
            train_state['early_stopping_step'] = 0

        # Stop early ?
        train_state['stop_early'] = \
            train_state['early_stopping_step'] >= args.early_stopping_criteria

    return train_state

def compute_accuracy(y_pred, y_target):
    y_target = y_target.cpu()
    y_pred_indices = (torch.sigmoid(y_pred)>0.5).cpu().long()#.max(dim=1)[1]
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100


def calculate_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [19]:
torch.cuda.empty_cache()
if torch.cuda.is_available():
  class_args.device = 'cuda'

model.to(class_args.device)
print(class_args.device)

cuda


In [15]:
epoch_bar = tqdm(desc='Training Routine', total=class_args.epochs, position=0)
train_bar = tqdm(desc='Split=Train', total=len(dataloader_train), position=1, leave=True)
validation_bar = tqdm(desc='Split=Val', total=len(dataloader_val), position=1, leave=True)


Training Routine:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:

train_state = make_train_state(class_args)
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(class_args.epochs):
  train_state['epoch'] = epoch #NEW

  model.train()
  total_train_loss = 0
  total_train_accuracy = 0 #NEW

  for step, batch in enumerate(dataloader_train):
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(class_args.device), b_input_mask.to(class_args.device), b_labels.to(class_args.device)

    model.zero_grad()
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

    loss = outputs.loss
    total_train_loss += loss.item()
    loss.backward()

    logits = outputs.logits.detach().cpu().numpy()#NEW
    label_ids = b_labels.to('cpu').numpy()#NEW
    total_train_accuracy += calculate_accuracy(logits, label_ids)#NEW

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    train_bar.update()

  avg_train_loss = total_train_loss / len(dataloader_train)
  train_state['train_loss'].append(total_train_loss) #NEW
  train_state['train_acc'].append(total_train_accuracy) #NEW
  print(f'Epoch {epoch}: Average Training Loss: {avg_train_loss}')

  model.eval()
  total_eval_accuracy = 0
  total_eval_loss = 0

  for batch in dataloader_val:
    b_input_ids, b_input_mask, b_labels = batch
    b_input_ids, b_input_mask, b_labels = b_input_ids.to(class_args.device), b_input_mask.to(class_args.device), b_labels.to(class_args.device)

    with torch.no_grad():
      outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

    logits = outputs.logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    total_eval_loss += loss_fn(outputs.logits.squeeze(-1), b_labels).item() # perhaps just outputs.loss (need to include labels as parameter in model() above)

    total_eval_accuracy += calculate_accuracy(logits, label_ids)

    validation_bar.update()

  avg_val_accuracy = total_eval_accuracy / len(dataloader_val)
  train_state['val_loss'].append(total_eval_loss)
  train_state['val_acc'].append(total_eval_accuracy)
  print(f'Epoch {epoch}: Validation Accuracy: {avg_val_accuracy}')

  train_state = update_train_state(args=class_args, model=model, train_state=train_state)

  if train_state['stop_early']:
    break

  train_bar.n = 0
  validation_bar.n=0
  epoch_bar.update()



Epoch 0: Average Training Loss: 0.47454547749398507




Epoch 0: Validation Accuracy: 0.8205772261623325


Training Routine:  33%|███▎      | 1/3 [10:08<20:16, 608.34s/it]

Epoch 1: Average Training Loss: 0.38863031764598377
Epoch 1: Validation Accuracy: 0.8165878644602048


Training Routine:  67%|██████▋   | 2/3 [20:16<10:08, 608.37s/it]

Epoch 2: Average Training Loss: 0.2901076706044115
Epoch 2: Validation Accuracy: 0.8120074862096138


Training Routine: 100%|██████████| 3/3 [30:25<00:00, 608.46s/it]

In [17]:
model.save_pretrained(class_args.model_save_path)

## Evaluation

### Manual Checking

In [12]:
def classify_sentence_pair(sentence1, sentence2, model, tokenizer):
    # Prepare the input sentence pair
    inputs = tokenizer(sentence1, sentence2, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Convert logits to probabilities
    # takes shape [[x ,y]] where x is probability of unrelated and y is probability or related
    probabilities = torch.softmax(logits, dim=1)

    print(probabilities)

    # Convert probabilities to binary predictions
    predicted_class_id = torch.argmax(probabilities, dim=1).item()

    return predicted_class_id  # 1 for related, 0 for unrelated


In [19]:
class_model = AutoModelForSequenceClassification.from_pretrained(class_args.model_save_path)
# class_tokenizer = AutoTokenizer.from_pretrained(model_path)



In [20]:
# sentence1 = "The legislation was passed in 1999."
# sentence2 = "Legislation enacted in 1999 started a major reform."

sentence1 = ("They also rely on Oswego Laborers’ Local 214 Pension Fund v. Marine "
            "Midland Bank, 85 N.Y.2d 20, 623 N.Y.S.2d 529, 647 N.E.2d 741 (1996), which "
            "held that a plaintiff 'must demonstrate that the acts or practices have a "
            "broader impact on consumers at large.' Defs.’ Mem. at 14 (quoting Oswego "
            "Laborers’, 623 N.Y.S.2d 529, 647 N.E.2d at 744). As explained above, how-"
            "ever, Plaintiffs have adequately alleged that Defendants’ unauthorized "
            "use of the DEL MONICO’s name in connection with non-Ocinomled "
            "restaurants and products caused consumer harm or injury to the public, "
            "and that they had a broad impact on consumers at large inasmuch as "
            "such use was likely to cause consumer confusion. See, e.g., CommScope, "
            "Inc. of N.C. v. CommScope (U.S.A) Int’l Grp. Co., 809 F. Supp.2d 33, 38 "
            "(N.D.N.Y 2011) (<HOLDING>); New York City Triathlon, LLC v. NYC Triathlon" )

sentence2 = "holding that plaintiff stated a 349 claim where plaintiff alleged facts plausibly suggesting that defendant intentionally registered its corporate name to be confusingly similar to plaintiffs CommScope trademark"
sentence3 = "A logit is the raw output of the model's final layer, and it's a real number that can be positive, negative, or zero."
sentence4 = "holding that plaintiff stated a 349 claim where plaintiff alleged facts plausibly suggesting that defendant intentionally registered its corporate name to be confusingly similar to plaintiffs CommScope trademark, despite being unauthorized to do so."

prediction = classify_sentence_pair(sentence1, sentence2, class_model, tokenizer)
print(prediction)
print("Classified as:", "Related" if prediction == 1 else "Unrelated")

print('\n')

prediction = classify_sentence_pair(sentence1, sentence3, class_model, tokenizer)
print(prediction)
print("Classified as:", "Related" if prediction == 1 else "Unrelated")

print('\n')

prediction = classify_sentence_pair(sentence1, sentence4, class_model, tokenizer)
print(prediction)
print("Classified as:", "Related" if prediction == 1 else "Unrelated")


tensor([[0.3599, 0.6401]])
1
Classified as: Related


tensor([[0.9652, 0.0348]])
0
Classified as: Unrelated


tensor([[0.2597, 0.7403]])
1
Classified as: Related


### Evaluation metrics on Test data

In [21]:
def evaluate_sequence_pair_class(model_path,  title):
  '''
    Routine for evaluating model for sequence pair classification
  '''

  # load model and tokenizer
  model = AutoModelForSequenceClassification.from_pretrained(model_path)

  # Check if cuda available
  if torch.cuda.is_available():
    # model.to('cuda')
    class_args.device = 'cuda'
  else:
    # model.to('cpu')
    class_args.device = 'cpu'
  
  model.to(class_args.device)

  model.eval()

  predictions, true_labels = [], []

  print('Evaluating ' + f'[{title}]')
  print('============================================')

  with torch.no_grad(): # disable calculating gradients (more efficient for evaluation)
    for batch in dataloader_test:
      input_ids, attention_mask, labels = tuple(t.to(class_args.device) for t in batch)
      outputs = model(input_ids, attention_mask=attention_mask)
      logits = outputs.logits
      preds = torch.argmax(logits, dim=1).flatten() # find index of max value in logits tensor (where each index corresponds to a binary class)

      predictions.extend(preds.cpu().numpy())
      true_labels.extend(labels.cpu().numpy())


  # Calculate metrics
  accuracy = accuracy_score(true_labels, predictions)
  precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='binary')
  # roc_auc = roc_auc_score(true_labels, predictions)  # Uncomment if ROC-AUC is needed

  print(f'Accuracy: {accuracy}\nPrecision: {precision}\nRecall: {recall}\nF1 Score: {f1}')



In [42]:
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader, TensorDataset

# Load the GLUE dataset and metric for SST-2 as an example
dataset = load_dataset('glue', 'sst2')
metric = load_metric('glue', 'sst2')

validation_sentences = dataset['validation']['sentence']
validation_labels = dataset['validation']['label']

def encode_data(sentences, labels):
    encoded_data = [
        tokenizer.encode_plus(
            text=sentence,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True
        ) for sentence in sentences
    ]
    
    input_ids = [item['input_ids'] for item in encoded_data]
    attention_masks = [item['attention_mask'] for item in encoded_data]
    
    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    labels = torch.tensor(labels)
    
    return input_ids, attention_masks, labels

# Encode the data
input_ids_val, attention_masks_val, labels_val = encode_data(validation_sentences, validation_labels)

# Create a TensorDataset and DataLoader for the validation set
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
dataloader_val = DataLoader(dataset_val, batch_size=32)  # Adjust batch size as needed

def evaluate_glue_model(model_path, dataloader, metric):
    model = AutoModelForSequenceClassification.from_pretrained(model_path)

    model.eval()  # Ensure the model is in evaluation mode
    
    # Check if cuda available
    if torch.cuda.is_available():
        class_args.device = 'cuda'
    else:
        class_args.device = 'cpu'
    
    model.to(class_args.device)

    for batch in dataloader:
        batch = tuple(t.to(class_args.device) for t in batch)
        b_input_ids, b_attention_mask, b_labels = batch

        with torch.no_grad():
            outputs = model(b_input_ids, attention_mask=b_attention_mask)
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=b_labels)

    final_score = metric.compute()
    print(final_score)

# Call the evaluation function with the GLUE DataLoader and metric
evaluate_glue_model(class_args.model_save_path, glue_dataloader, metric)


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
#outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)

In [22]:
evaluate_sequence_pair_class(class_args.model_save_path,  'Sequence Pair Classificaiton Evaluation Metrics')

Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
Accuracy: 0.8293333333333334
Precision: 0.6012084592145015
Recall: 0.44124168514412415
F1 Score: 0.5089514066496164


In [26]:
evaluate_sequence_pair_class('bert-base-uncased',  'Sequence Pair Classificaiton Evaluation Metrics')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
Accuracy: 0.7986666666666666
Precision: 0.25
Recall: 0.0022172949002217295
F1 Score: 0.004395604395604396


In [24]:
evaluate_sequence_pair_class('models/sentence_pair_classification',  'Sequence Pair Classificaiton Evaluation Metrics')

Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
Accuracy: 0.86
Precision: 0.7236842105263158
Recall: 0.4878048780487805
F1 Score: 0.5827814569536424


In [25]:
evaluate_sequence_pair_class('casehold/legalbert',  'Sequence Pair Classificaiton Evaluation Metrics')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at casehold/legalbert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluating [Sequence Pair Classificaiton Evaluation Metrics]
Accuracy: 0.7826666666666666
Precision: 0.12
Recall: 0.013303769401330377
F1 Score: 0.023952095808383235
