In [26]:
import torch
import re

import pandas as pd
import numpy as np

from datasets import load_metric
from sklearn.model_selection import train_test_split
from transformers import (
    BertTokenizer, 
    BertForSequenceClassification, 
    Trainer,
    TrainingArguments,
    TextClassificationPipeline
)

from sklearn.metrics import classification_report

In [2]:
# Environment setup (run once). `%pip` installs into the running kernel's environment.
# %pip install torch==1.9.1
# %pip install datasets
# %pip install transformers==4.11.1

In [32]:
def process_text(text):
    """Normalise a sentence before tokenization.

    Leading and trailing whitespace is removed, every run of whitespace
    (spaces, tabs, newlines, ...) is collapsed into a single space, and the
    result is lower-cased.

    Args:
        text (str): raw sentence.

    Returns:
        str: the normalised, lower-cased sentence.
    """
    # TODO what process text method should be used?
    # A raw string is required here: '\s' inside a normal string literal is an
    # invalid escape sequence and triggers a warning at compile time.
    # `\s+` already matches '\n', so no separate newline replacement is needed.
    text = re.sub(r'\s+', ' ', text.strip())

    return text.lower()

def get_class(label):
    """Turn a label string such as 'LABEL_1' into its integer class id.

    Args:
        label (str): label name, optionally carrying the 'LABEL_' prefix.

    Returns:
        int: the numeric part of the label.
    """
    return int(label.replace('LABEL_', ''))

In [4]:
# Load the labelled sentences and normalise both columns.
df = pd.read_csv('../data/obligation_extraction_df.csv')
# Map the truthy/falsy flag onto the integer class ids 1/0.
df['is_obligation'] = df['is_obligation'].apply(lambda flag: 1 if flag else 0)
df['sentence'] = df['sentence'].apply(process_text)

# 80/20 split of the whole frame with a fixed seed.
# NOTE(review): the next cell splits the same rows again with the same seed and
# a 20% hold-out, so `df_test` appears to contain the same rows as the
# validation set -- confirm whether a separate test set is intended.
df_train, df_test = train_test_split(df, random_state=42, train_size=0.8)
# TODO try different train/valid splits
# df_train, df_valid = train_test_split(df, train_size=0.9, random_state=42)

In [5]:
# Split the sentences and their labels together: 20% is held out for
# validation, with a fixed seed so the split is repeatable.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['sentence'].values,
    df['is_obligation'].values,
    random_state=42,
    test_size=0.2,
)

In [6]:
# Other checkpoint names noted in this cell (swap in to experiment):
#   distilroberta-base
#   nlpaueb/legal-bert-base-uncased
# Only one assignment is kept active so the checkpoint actually used is
# unambiguous (the previous first assignment was overwritten immediately).
model_name = 'nlpaueb/legal-bert-small-uncased'

tokenizer = BertTokenizer.from_pretrained(model_name, model_max_length=512)

# Sequence-classification head with two output labels on top of the checkpoint.
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2
)

Some weights of the model checkpoint at nlpaueb/legal-bert-small-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

In [7]:
# Tokenize both splits with identical settings (truncation and padding on).
train_encodings, val_encodings = (
    tokenizer(list(texts), truncation=True, padding=True)
    for texts in (train_texts, val_texts)
)
# test_encodings = tokenizer(list(test_texts), truncation=True, padding=True)

In [8]:
class ObligationDataset(torch.utils.data.Dataset):
    """Torch dataset pairing tokenizer encodings with their labels.

    Args:
        encodings: mapping from field name to a per-example indexable
            sequence (anything exposing ``.items()``).
        labels: indexable sequence of labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Number of examples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including 'labels'."""
        item = {}
        for key, val in self.encodings.items():
            item[key] = torch.tensor(val[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

# Wrap the tokenized splits together with their labels.
train_dataset = ObligationDataset(encodings=train_encodings, labels=train_labels)
val_dataset = ObligationDataset(encodings=val_encodings, labels=val_labels)
# test_dataset = ObligationDataset(test_encodings, test_labels)

In [10]:
# Training configuration: 5 epochs, batch size 16 per device, evaluation once
# per epoch. The first argument is the output directory for checkpoints.
training_args = TrainingArguments(
    output_dir="legal_bert_small_smooothing-0",
    num_train_epochs=5,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",
    # label_smoothing_factor=0.1
)

# F1 metric object consumed by compute_metrics.
metric = load_metric("f1", labels=[0, 1])

def compute_metrics(eval_pred):
    """Score one evaluation pass for the Trainer.

    Args:
        eval_pred: pair ``(logits, labels)``; the predicted class is the
            argmax over the last axis of ``logits``.

    Returns:
        Whatever ``metric.compute`` returns for the micro-averaged score.
    """
    logits, labels = eval_pred
    # NOTE(review): with one predicted label per example, micro-averaged F1
    # should coincide with plain accuracy -- confirm 'micro' is the intended
    # average rather than 'binary' or 'macro'.
    return metric.compute(
        predictions=np.argmax(logits, axis=-1),
        references=labels,
        average='micro',
    )

Downloading:   0%|          | 0.00/1.96k [00:00<?, ?B/s]

In [11]:
# Bundle model, configuration, data, tokenizer and metric callback.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [12]:
# Run fine-tuning; the call's return value is displayed as this cell's output.
trainer.train()

***** Running training *****
  Num examples = 11956
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3740


Epoch,Training Loss,Validation Loss,F1
1,0.3874,0.32938,0.872533
2,0.3003,0.334195,0.873871
3,0.1845,0.40637,0.873202
4,0.1552,0.503237,0.883239
5,0.0962,0.559396,0.881566


Saving model checkpoint to legal_bert_small_smooothing-0/checkpoint-500
Configuration saved in legal_bert_small_smooothing-0/checkpoint-500/config.json
Model weights saved in legal_bert_small_smooothing-0/checkpoint-500/pytorch_model.bin
tokenizer config file saved in legal_bert_small_smooothing-0/checkpoint-500/tokenizer_config.json
Special tokens file saved in legal_bert_small_smooothing-0/checkpoint-500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2989
  Batch size = 8
Saving model checkpoint to legal_bert_small_smooothing-0/checkpoint-1000
Configuration saved in legal_bert_small_smooothing-0/checkpoint-1000/config.json
Model weights saved in legal_bert_small_smooothing-0/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in legal_bert_small_smooothing-0/checkpoint-1000/tokenizer_config.json
Special tokens file saved in legal_bert_small_smooothing-0/checkpoint-1000/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2989
  Ba

TrainOutput(global_step=3740, training_loss=0.2078449794952882, metrics={'train_runtime': 35755.1776, 'train_samples_per_second': 1.672, 'train_steps_per_second': 0.105, 'total_flos': 3522113181081600.0, 'train_loss': 0.2078449794952882, 'epoch': 5.0})

In [21]:
# Take the fine-tuned model and its tokenizer back from the trainer.
model = trainer.model
tokenizer = trainer.tokenizer
model.eval()  # put the model into evaluation mode before inference

# Inference pipeline: softmax over the logits, top label only, long inputs truncated.
pipeline = TextClassificationPipeline(
    tokenizer=tokenizer,
    model=model,
    truncation=True,
    function_to_apply='softmax',
    return_all_scores=False,
)

In [25]:
# Classify every sentence in the held-out frame.
preds = pipeline(df_test['sentence'].tolist())

Disabling tokenizer parallelism, we're using DataLoader multithreading already


In [33]:
# Convert each predicted label string to its integer class id and compare
# against the gold labels.
y_true = df_test['is_obligation']
predicted_values = [get_class(pred['label']) for pred in preds]
report = classification_report(y_true, predicted_values)
print(report)

              precision    recall  f1-score   support

           0       0.89      0.85      0.87      1411
           1       0.87      0.91      0.89      1578

    accuracy                           0.88      2989
   macro avg       0.88      0.88      0.88      2989
weighted avg       0.88      0.88      0.88      2989

