In [3]:
from transformers import BertTokenizer, AutoModel, BertForSequenceClassification
from sklearn.metrics import classification_report

from transformers import TextClassificationPipeline

import numpy as np
import pandas as pd

import re

In [4]:
def process_text(text):
    text = re.sub('\n', ' ', text)
    text = re.sub('\s+', ' ', text)
    
    return text

In [13]:
df = pd.read_csv('../data/obligation_extraction_df.csv')
# df['Clause Text'] = df['Clause Text'].apply(lambda x: process_text(x))
df['is_obligation'] = df['is_obligation'].apply(lambda x: int(x))

df.sample(10)

Unnamed: 0,sentence,is_obligation
14712,"In either event, Tenant shall also pay any oth...",1
8318,The Company will pay any personal tax liabilit...,1
9773,Not less than thirty (30) days prior to the ex...,1
4596,Unless otherwise agreed by the Administrative ...,0
9426,Landlord also acknowledges and agrees that Ten...,0
348,FORCE MAJEURE: Neither Lindows.com nor License...,0
10697,EMPLOYEE BENEFITS: During the Term of Employme...,0
7693,For each fiscal year of the Company during the...,1
1365,"requirement of any government agency, or any o...",0
12129,If any portion of the Work Product is not rule...,0


In [6]:
!ls ../models/legal_bert_small

checkpoint-1000  checkpoint-4000  checkpoint-6500  pytorch_model.bin
checkpoint-1500  checkpoint-4500  checkpoint-7000  special_tokens_map.json
checkpoint-2000  checkpoint-500   checkpoint-7500  tokenizer.json
checkpoint-2500  checkpoint-5000  checkpoint-8000  tokenizer_config.json
checkpoint-3000  checkpoint-5500  checkpoint-8500  training_args.bin
checkpoint-3500  checkpoint-6000  config.json	   vocab.txt


In [9]:
tokenizer = BertTokenizer.from_pretrained('distilroberta-base', model_max_length=512)
model = BertForSequenceClassification.from_pretrained('distilroberta-base')

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, element

In [8]:
text = df['Clause Text'].iloc[100]
text = process_text(text)

tokens = tokenizer.encode_plus(
     text,
     add_special_tokens=True,
     truncation = True,
     padding="max_length",
     return_attention_mask=True,
     return_tensors="pt"
)

In [10]:
from math import ceil

CHUNK_SIZE = 8  # TODO what is the optimal value
num_examples = len(df['Clause Text'].values)
num_chunks = ceil(num_examples / CHUNK_SIZE)

text_chunks = np.array_split(df['Clause Text'].values, num_chunks)
predictions = []
del df

for idx, text_chunk in enumerate(text_chunks):
    print(f'{idx} of {num_chunks}')    
    
    text_chunk = list(text_chunk)
    tokens = tokenizer(text_chunk, truncation=True, padding=True, return_tensors="pt")
    res = model(**tokens)

    for logits in res[0]:
        softmax_res = softmax(logits)  # Needed for confidence
        index = np.argsort(softmax_res.detach().numpy())[-1]

        predictions.append(labels[index])

0 of 1940


  softmax_res = softmax(logits)  # Needed for confidence


1 of 1940
2 of 1940
3 of 1940


KeyboardInterrupt: 

In [34]:
predictions_ = [p if p != 'payment' else 'payment_terms' for p in predictions]

In [35]:
print(classification_report(df.Label.values, predictions_))

                               precision    recall  f1-score   support

                          UNK       0.00      0.00      0.00         0
                   assignment       0.85      0.78      0.81       212
              confidentiality       0.79      0.90      0.84       260
    data_security_and_privacy       0.51      0.91      0.66        43
                  definitions       0.99      0.76      0.86      3957
           dispute_resolution       0.93      0.86      0.90       175
             entire_agreement       0.91      0.95      0.93       136
          export_control_laws       0.57      0.84      0.68        25
                force_majeure       0.77      0.98      0.86        54
                governing_law       0.96      0.96      0.96       140
              indemnification       0.83      0.59      0.69       418
                    insurance       0.00      0.00      0.00         0
        intellectual_property       0.29      0.67      0.40       116
     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# worse - notices, entire agreement, survival, force_majeure

# drastically better - termination, term, payment_terms
# drastically worse - signature, data_security_and_privacy, third_party_beneficiaries, export_control_laws(!), relationship_of_the_parties

In [12]:
NUMBER_OF_ALTERNATIVE_PREDICTIONS = 3
ALTERNATIVE_PREDICTION_PROBABILITY_THRESHOLD = 0.0001
LABEL_DECODER = label_decoder

HIGH_CONFIDENCE_THRESHOLD = 0.9
MEDIUM_CONFIDENCE_THRESHOLD = 0.8
LOW_CONFIDENCE_THRESHOLD = 0.7

LOW_CONFIDENCE = 'low confidence'
MEDIUM_CONFIDENCE = 'medium confidence'
HIGH_CONFIDENCE = 'high confidence'

class ClauseIdService:
    def __init__(self, model):
        self.model = model

    def calculate_confidence_level(self, confidence):
        if confidence >= HIGH_CONFIDENCE_THRESHOLD:
            return HIGH_CONFIDENCE
        elif confidence >= MEDIUM_CONFIDENCE_THRESHOLD:
            return MEDIUM_CONFIDENCE

        return LOW_CONFIDENCE
    
    def _transform_pred(self, prediction):
        category = LABEL_DECODER[prediction['label']]  # should this be passed through init?
        confidence = prediction['score']

        return category, confidence
    
    def _get_indexes(self, prediction):
        indexes = np.argsort([pair['score'] for pair in prediction])
        # sort indexes in descending order
        indexes = indexes[::-1] 
        indexes = indexes[:NUMBER_OF_ALTERNATIVE_PREDICTIONS+1]
        
        return indexes
            
    def _get_result(self, predictions, paragraphs):
        result = dict()
        
        for position, prediction in enumerate(predictions):
            indexes = self._get_indexes(prediction)
            
            transformed_predictions = [self._transform_pred(prediction[idx]) for idx in indexes]
            clause_category, confidence = transformed_predictions[0]

            alternate_predictions = [
                {"clause category": category, "confidence": confidence} 
                for category, confidence in transformed_predictions[1:]
                if confidence > ALTERNATIVE_PREDICTION_PROBABILITY_THRESHOLD
            ]

            result[str(position)] = {
                'clause': paragraphs[position],
                'clause category': clause_category,
                'confidence': confidence,
                'confidence_level': self.calculate_confidence_level(confidence),
                'alternate': alternate_predictions
            }
            
        return result
    
    def annotate(self, paragraphs):
        predictions = self.model(paragraphs)
        
        return self._get_result(predictions, paragraphs)

In [10]:
pipeline = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    function_to_apply='softmax',
    truncation=True
)

In [13]:
cl = ClauseIdService(pipeline)

cl.annotate(list(df['Clause Text'].values[0:2]))

{'0': {'clause': ' This Fourth Amendment (the "Fourth Amendment") to that certain AFFILIATION AGREEMENT FOR DBS SATELLITE EXHIBITION OF CABLE PROGRAMMING dated as of November 15, 1993 by and between Playboy Entertainment Group, Inc. ("Programmer") and DirecTV, Inc. ("Affiliate"), as amended and supplemented, including by that certain First Amendment dated as of April 19, 1994, that certain Second Amendment dated July 26, 1995, and that certain Third Amendment dated August 26, 1997 (such Affiliation Agreement, as amended and supplemented, is referred to as the "Agreement"), is made and entered into as of March 15, 1999, with reference to the following facts (all defined terms used in this Fourth Amendment but not defined in this Fourth Amendment are defined in the Agreement):',
  'clause category': 'introduction',
  'confidence': 0.9997456669807434,
  'confidence_level': 'high confidence',
  'alternate': [{'clause category': 'UNK',
    'confidence': 0.00017905904678627849}]},
 '1': {'cl

In [202]:
preds = pipeline(list(df['Clause Text'].values), return_all_scores=False)

In [204]:
pred_labels = [label_decoder[p['label']] for p in preds]

In [206]:
predictions_ = [p if p != 'payment' else 'payment_terms' for p in pred_labels]

print(classification_report(df.Label.values, predictions_))

                               precision    recall  f1-score   support

                          UNK       0.00      0.00      0.00         0
                   assignment       0.85      0.81      0.83       212
              confidentiality       0.75      0.87      0.81       260
    data_security_and_privacy       0.43      0.88      0.58        43
                  definitions       0.99      0.76      0.86      3957
           dispute_resolution       0.90      0.86      0.88       175
             entire_agreement       0.90      0.95      0.92       136
          export_control_laws       0.53      0.84      0.65        25
                force_majeure       0.71      0.98      0.82        54
                governing_law       0.93      0.97      0.95       140
              indemnification       0.82      0.58      0.68       418
                    insurance       0.00      0.00      0.00         0
        intellectual_property       0.29      0.71      0.41       116
     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
'''
                          UNK       0.00      0.00      0.00         0
                   assignment       0.85      0.78      0.81       212
              confidentiality       0.79      0.90      0.84       260
    data_security_and_privacy       0.51      0.91      0.66        43
                  definitions       0.99      0.76      0.86      3957
           dispute_resolution       0.93      0.86      0.90       175
             entire_agreement       0.91      0.95      0.93       136
          export_control_laws       0.57      0.84      0.68        25
                force_majeure       0.77      0.98      0.86        54
                governing_law       0.96      0.96      0.96       140
              indemnification       0.83      0.59      0.69       418
                    insurance       0.00      0.00      0.00         0
        intellectual_property       0.29      0.67      0.40       116
                 introduction       0.00      0.00      0.00         0
           limit_of_liability       0.78      0.51      0.62       301
                      notices       0.95      0.46      0.61      1180
                payment_terms       0.86      0.56      0.68       924
                     preamble       0.99      0.77      0.86       661
  relationship_of_the_parties       0.70      0.84      0.76        62
representation_and_warranties       0.97      0.68      0.80      3582
                 severability       0.98      0.98      0.98       118
                    signature       0.94      0.66      0.77      1998
                     survival       0.68      0.95      0.79       121
                        taxes       0.50      0.90      0.65       193
                         term       0.81      0.67      0.73       243
                  termination       0.90      0.68      0.78       560
    third_party_beneficiaries       0.52      0.91      0.66        34

                     accuracy                           0.69     15513
                    macro avg       0.70      0.70      0.68     15513
                 weighted avg       0.93      0.69      0.78     15513'''

'''                          
                   assignment       0.85      0.81      0.83       212
              confidentiality       0.75      0.87      0.81       260
    data_security_and_privacy       0.43      0.88      0.58        43
                  definitions       0.99      0.76      0.86      3957
           dispute_resolution       0.90      0.86      0.88       175
             entire_agreement       0.90      0.95      0.92       136
          export_control_laws       0.53      0.84      0.65        25
                force_majeure       0.71      0.98      0.82        54
                governing_law       0.93      0.97      0.95       140
              indemnification       0.82      0.58      0.68       418
                    insurance       0.00      0.00      0.00         0
        intellectual_property       0.29      0.71      0.41       116
                 introduction       0.00      0.00      0.00         0
           limit_of_liability       0.82      0.52      0.64       301
                      notices       0.94      0.45      0.61      1180
                payment_terms       0.86      0.56      0.68       924
                     preamble       0.98      0.75      0.85       661
  relationship_of_the_parties       0.61      0.82      0.70        62
representation_and_warranties       0.96      0.66      0.79      3582
                 severability       0.97      0.98      0.98       118
                    signature       0.92      0.62      0.74      1998
                     survival       0.70      0.93      0.80       121
                        taxes       0.50      0.90      0.64       193
                         term       0.78      0.71      0.75       243
                  termination       0.85      0.73      0.79       560
    third_party_beneficiaries       0.54      0.88      0.67        34

                     accuracy                           0.69     15513
                    macro avg       0.69      0.69      0.67     15513
                 weighted avg       0.92      0.69      0.78     15513'''

In [2]:
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import TextClassificationPipeline

tokenizer = BertTokenizer.from_pretrained('../models/legal_bert_small', model_max_length=512)
model = BertForSequenceClassification.from_pretrained('../models/legal_bert_small')
model.eval()

pipeline = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
    function_to_apply='softmax',
    truncation=True
)

In [22]:
df1 = pd.read_pickle('../data/clause_id/2019-11-19_clause_id_s1_s3_dataset.pkl')

df2 = pd.read_pickle('../data/clause_id/all_clause_ids_data_frame.pkl')
df2['Label'] = df2['clause_id'].apply(lambda x: LABEL_DICT[str(x)])

df_train = pd.read_excel('../data/clause_id/clause_id_train.xlsx')

In [None]:
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['Clause Text'].values, 
    df['label_id'].values, 
    test_size=.2,
    random_state=42
)