# Multi-label Legal Text Classification for CIA

## Model Evaluation

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import csv
import gzip
import random

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, recall_score, classification_report, confusion_matrix

In [4]:
from sentence_transformers import models, losses, datasets
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.datasets import DenoisingAutoEncoderDataset
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
import logging
from datetime import datetime

In [5]:
# Project root: every relative data/ and models/ path used below is resolved
# against it. Set the CIA_PROJECT_DIR environment variable to run the notebook
# from another checkout; the original author's location stays the default.
PROJECT_DIR = os.environ.get(
    "CIA_PROJECT_DIR",
    "/Users/janinedevera/Documents/School/MDS 2021-2023/Thesis/multilabel-legal-text-classification-CIA",
)
os.chdir(PROJECT_DIR)
os.getcwd()

'/Users/janinedevera/Documents/School/MDS 2021-2023/Thesis/multilabel-legal-text-classification-CIA'

Prepare test data

In [6]:
# definitions
oecd_defs = pd.read_csv("data/02 oecd_definitions_stopwords_grouped.csv", index_col=0)

# labelled legal texts, with exact duplicate rows removed
# The evaluation cells below use the literal string 'None' as the
# "no category" class. pandas >= 2.0 parses the text "None" as a missing
# value when reading CSV, so the label is restored here; on older pandas
# versions the fillna is a no-op.
test_data = (
    pd.read_csv("data/01 legal_texts_with_labels_grouped.csv", index_col=0)
    .drop_duplicates()
    .assign(Category=lambda frame: frame['Category'].fillna('None'))
)

In [8]:
# Reproducible 100-row sample of the labelled texts (fixed seed), shown for inspection.
# NOTE(review): the scoring cells below read from test_data, not from this sample.
test_sample = test_data.sample(n=100, random_state=999)
test_sample

Unnamed: 0.1,Unnamed: 0,Law,Paragraph,Text,Category,text_clean
45,45,Concession Contract - São Paulo/Guarulhos Inte...,2.5\n2.6,"The term of the concession is 20 years, extend...",A,the term of the concess is year extend for up ...
581,581,Resolution 71/2022 (Former Normative Resolutio...,Art. 11,The legal entity incorporated under Brazilian ...,A,the legal entiti incorpor under brazilian law ...
686,686,Internal Regulation - SOE's - COMPANHIA DOCAS ...,Art. 199,The contracts will not exceed 5 (five) years.,A,the contract will not exceed five year
446,446,Resolution 18/2006 by National Petroleum Agenc...,"Art. 5, par. 7",Even if the firm complies with the requirement...,A,even if the firm compli with the requir state ...
167,167,Concession Contract - Salvador - Deputado Luís...,11.13.1,In case of shortage of areas for new entrants ...,A,in case of shortag of area for new entrant to ...
...,...,...,...,...,...,...
110,110,Concession Contract - Rio de Janeiro/Galeão In...,10.7,"In the first five years, share transfer leadin...",A,in the first five year share transfer lead to ...
1441,1441,Resolution 192/2011 by National Civil Aviation...,Art. 13,The author or the person economically responsi...,,the author or the person econom respons for th...
1492,1492,Resolution 191/2011 by National Civil Aviation...,Art. 2 sole par,Computerized systems must be used that provide...,,computer system must be use that provid for th...
69,69,Concession Contract - São Paulo/Guarulhos Inte...,11.7.1,In case of shortage of areas for new entrants ...,A,in case of shortag of area for new entrant to ...


Similarity scores

In [9]:
def calculate_sim_scores(list1, list2, defs_name, labels_true, model):
    """Score every text against every definition by cosine similarity.

    Parameters
    ----------
    list1 : iterable of str
        Texts to score.
    list2 : iterable of str
        Definition texts each text is compared with.
    defs_name : iterable
        Name of each definition, paired positionally with ``list2``.
    labels_true : iterable
        True label of each text, paired positionally with ``list1``.
    model : object
        Encoder exposing ``encode(sentences, convert_to_tensor=True)``.

    Returns
    -------
    pandas.DataFrame
        One row per (text, definition) pair with the columns ``Text``,
        ``Label_Text``, ``Label_Name``, ``Label_True`` and ``Score`` (a plain
        float). Rows are ordered last-computed pair first, which is the order
        the previous row-prepending implementation produced. An empty frame
        with the same columns is returned when either side has no items.
    """
    columns = ['Text', 'Label_Text', 'Label_Name', 'Label_True', 'Score']

    # zip() stops at the shorter input, exactly as the nested loops did.
    text_pairs = list(zip(list1, labels_true))
    def_pairs = list(zip(list2, defs_name))
    if not text_pairs or not def_pairs:
        return pd.DataFrame(columns=columns)

    # Encode each side once. The definitions do not change between texts, so
    # re-encoding them inside the loop was pure repeated work.
    text_embeddings = model.encode([text for text, _ in text_pairs], convert_to_tensor=True)
    def_embeddings = model.encode([definition for definition, _ in def_pairs], convert_to_tensor=True)

    # Full (n_texts x n_definitions) similarity matrix, converted to floats so
    # the Score column holds numbers rather than tensor objects.
    score_matrix = util.cos_sim(text_embeddings, def_embeddings).cpu().tolist()

    # Collect plain records and build the frame once instead of concatenating
    # a one-row frame per pair (quadratic in the number of pairs).
    records = [
        {
            'Text': text,
            'Label_Text': definition,
            'Label_Name': name,
            'Label_True': label,
            'Score': score,
        }
        for (text, label), text_scores in zip(text_pairs, score_matrix)
        for (definition, name), score in zip(def_pairs, text_scores)
    ]
    records.reverse()  # keep the historical "newest row first" ordering

    return pd.DataFrame(records, columns=columns)

In [12]:
# Inputs for the similarity scoring: cleaned texts with their true categories,
# and cleaned definition texts with the name of the category each one defines.
test_list, labels_list = test_data['text_clean'], test_data['Category']
defs_list, defs_name = oecd_defs['text_clean'], oecd_defs['Main']

### A. BERT base + TSDAE + STS

In [15]:
# specify trained model
# Both checkpoints are loaded from the local models/ directory.
model_sbert_sts = SentenceTransformer(
    model_name_or_path='models/sbert_stsbenchmark-2023-03-23_09-50'
)
model_tsdae = SentenceTransformer(
    model_name_or_path='models/tsdae-2023-03-22_12-39'
)

In [None]:
# Score every test text against every definition with the STS-trained model,
# then drop rows that are exact duplicates.
results = calculate_sim_scores(
    list1=test_list,
    list2=defs_list,
    defs_name=defs_name,
    labels_true=labels_list,
    model=model_sbert_sts,
).drop_duplicates()

In [None]:
# Persist the similarity scores so the evaluation below can reload them from disk.
results.to_csv(path_or_buf="data/scores/03 sim_scores_bert_sts.csv")

Threshold-based predictions

In [102]:
# Reload the cached similarity scores.
results = pd.read_csv("data/scores/03 sim_scores_bert_sts.csv", index_col = 0)
# The cells below compare Label_True with the literal string 'None'.
# pandas >= 2.0 reads the text "None" back from CSV as a missing value, so the
# label is restored here; on older pandas versions this is a no-op.
results['Label_True'] = results['Label_True'].fillna('None')

In [101]:
# Multi-class prediction per (text, definition) row: the definition's category
# when its similarity exceeds 0.60, otherwise the 'None' class.
# assign() returns a new frame, so `results` itself is not modified and the
# Predict column cannot leak into other cells that start from `results`.
results_multiclass = results.assign(
    Predict=np.where(results['Score'] > 0.60, results['Label_Name'], 'None')
)

In [67]:
# Inspect the scored rows together with their thresholded multi-class prediction.
results_multiclass

Unnamed: 0,Text,Label_Text,Label_Name,Label_True,Score,Predict
0,thi resolut goe into effect on juli 1st when i...,regul sometim limit choic avail to consum for ...,D,,0.371452,
1,thi resolut goe into effect on juli 1st when i...,regul can affect supplier behaviour by not onl...,C,,0.414715,
2,thi resolut goe into effect on juli 1st when i...,regul can affect a supplier 's abil to compet ...,B,,0.366034,
3,thi resolut goe into effect on juli 1st when i...,limit the number of supplier lead to the risk ...,A,,0.434407,
4,thi resolut goe into effect on juli 1st when i...,regul sometim limit choic avail to consum for ...,D,,0.512986,
...,...,...,...,...,...,...
6127,the nation civil aviat secretari must authoris...,limit the number of supplier lead to the risk ...,A,A,0.595055,
6128,if ani compet bodi on the matter impos modif o...,regul sometim limit choic avail to consum for ...,D,A,0.602724,D
6129,if ani compet bodi on the matter impos modif o...,regul can affect supplier behaviour by not onl...,C,A,0.568921,
6130,if ani compet bodi on the matter impos modif o...,regul can affect a supplier 's abil to compet ...,B,A,0.550923,


In [68]:
# Generate the classification report
# Per-class precision / recall / F1 of the thresholded predictions vs. the true labels.
report = classification_report(
    y_true=results_multiclass['Label_True'],
    y_pred=results_multiclass['Predict'],
)
print(report)

              precision    recall  f1-score   support

           A       0.34      0.13      0.18      1652
           B       0.14      0.08      0.10       308
           C       0.01      0.04      0.02        52
           D       0.00      0.00      0.00         4
        None       0.57      0.76      0.65      2640

    accuracy                           0.48      4656
   macro avg       0.21      0.20      0.19      4656
weighted avg       0.45      0.48      0.44      4656



In [69]:
# Confusion matrix of true labels (rows) against thresholded predictions (columns).
cfm = confusion_matrix(
    y_true=results_multiclass['Label_True'],
    y_pred=results_multiclass['Predict'],
)
print(cfm)

[[ 209   55   64   68 1256]
 [  56   24   27   26  175]
 [   2    1    2    2   45]
 [   0    0    0    0    4]
 [ 355   89   92  110 1994]]


In [70]:
# Express every confusion-matrix cell as a share of all evaluated rows and
# label both axes with the class names.
df_cfm = pd.DataFrame(
    data=cfm / cfm.sum(),
    index=list('ABCD') + ['None'],
    columns=list('ABCD') + ['None'],
)

In [None]:
print(df_cfm)
# Apply the font scale before drawing: seaborn settings only influence artists
# created afterwards, so setting it after the heatmap left this figure at the
# default size on a first run and enlarged it only on a re-run.
sns.set(font_scale=1.7)
plt.figure(figsize = (12,7))
sns.heatmap(df_cfm, annot=True, cmap="Blues");

Threshold-based predictions: binary

In [103]:
# Start the binary evaluation from an independent frame holding only the raw
# scores. Plain assignment would alias `results`, and any multi-class Predict
# column already added to it would be carried along.
results_binary = results.drop(columns=['Predict'], errors='ignore')

In [104]:
# Inspect the score table used as input for the binary evaluation.
results_binary

Unnamed: 0,Text,Label_Text,Label_Name,Label_True,Score
0,thi resolut goe into effect on juli 1st when i...,regul sometim limit choic avail to consum for ...,D,,0.371452
1,thi resolut goe into effect on juli 1st when i...,regul can affect supplier behaviour by not onl...,C,,0.414715
2,thi resolut goe into effect on juli 1st when i...,regul can affect a supplier 's abil to compet ...,B,,0.366034
3,thi resolut goe into effect on juli 1st when i...,limit the number of supplier lead to the risk ...,A,,0.434407
4,thi resolut goe into effect on juli 1st when i...,regul sometim limit choic avail to consum for ...,D,,0.512986
...,...,...,...,...,...
6127,the nation civil aviat secretari must authoris...,limit the number of supplier lead to the risk ...,A,A,0.595055
6128,if ani compet bodi on the matter impos modif o...,regul sometim limit choic avail to consum for ...,D,A,0.602724
6129,if ani compet bodi on the matter impos modif o...,regul can affect supplier behaviour by not onl...,C,A,0.568921
6130,if ani compet bodi on the matter impos modif o...,regul can affect a supplier 's abil to compet ...,B,A,0.550923


In [105]:
# Highest similarity score reached by each text across all definitions
# (a Series named 'Score', indexed by 'Text').
results_max_score = results_binary['Score'].groupby(results_binary['Text']).max()

In [106]:
# Keep, for each text, only the row(s) whose score equals that text's maximum.
results_final = results_binary.merge(
    right=results_max_score,
    how='inner',
    on=['Text', 'Score'],
)

In [107]:
# Binary view: predict 'Yes' when the best score exceeds 0.60, otherwise 'None'.
results_final['Predict'] = np.where(results_final['Score'].gt(0.60), 'Yes', 'None')
# Collapse every true label other than 'None' into the single positive class 'Yes'.
results_final['Label_True'] = results_final['Label_True'].where(
    results_final['Label_True'].eq('None'), 'Yes'
)

In [108]:
# Inspect the best-scoring row per text with its binary true label and prediction.
results_final

Unnamed: 0,Text,Label_Text,Label_Name,Label_True,Score,Predict
0,thi resolut goe into effect on juli 1st when i...,limit the number of supplier lead to the risk ...,A,,0.434407,
1,thi resolut goe into effect on juli 1st when i...,regul sometim limit choic avail to consum for ...,D,,0.512986,
2,thi resolut replac the disposit of art to and ...,limit the number of supplier lead to the risk ...,A,,0.508984,
3,the non-fulfil of the oblig establish in thi r...,limit the number of supplier lead to the risk ...,A,,0.686350,Yes
4,anac may at ani time conduct audit request the...,limit the number of supplier lead to the risk ...,A,,0.499632,
...,...,...,...,...,...,...
1159,case not allow the mainten of commerci concess...,limit the number of supplier lead to the risk ...,A,Yes,0.713664,Yes
1160,the nation civil aviat secretari may authoris ...,limit the number of supplier lead to the risk ...,A,Yes,0.683381,Yes
1161,the commerci contract which involv the use of ...,limit the number of supplier lead to the risk ...,A,Yes,0.742503,Yes
1162,the nation civil aviat secretari must authoris...,limit the number of supplier lead to the risk ...,A,Yes,0.595055,


In [109]:
# Precision / recall / F1 for the binary ('Yes' vs 'None') predictions.
report = classification_report(
    y_true=results_final['Label_True'],
    y_pred=results_final['Predict'],
)
print(report)

              precision    recall  f1-score   support

        None       0.57      0.46      0.51       660
         Yes       0.43      0.54      0.48       504

    accuracy                           0.49      1164
   macro avg       0.50      0.50      0.49      1164
weighted avg       0.51      0.49      0.50      1164



### B. Off-the-Shelf SBERT

In [10]:
# initialize model
# bert-base-uncased loaded as-is; its token embeddings are mean-pooled into
# one sentence vector.
word_embedding_model = models.Transformer(model_name_or_path='bert-base-uncased')
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
    pooling_mode='mean',
)
model_off_shelf = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# Score every test text against every definition with the off-the-shelf model,
# then drop rows that are exact duplicates.
results_os = calculate_sim_scores(
    list1=test_list,
    list2=defs_list,
    defs_name=defs_name,
    labels_true=labels_list,
    model=model_off_shelf,
).drop_duplicates()

In [129]:
# Multi-class prediction per row: the definition's category when its similarity
# exceeds 0.60, otherwise 'None'. The column is added to results_os in place.
results_os['Predict'] = np.where(
    results_os['Score'].gt(0.60),
    results_os['Label_Name'],
    'None',
)
# results_multiclass is rebound to the same frame (an alias, not a copy).
results_multiclass = results_os

In [130]:
# Generate the classification report
# Per-class precision / recall / F1 for the off-the-shelf model's thresholded predictions.
report = classification_report(
    y_true=results_os['Label_True'],
    y_pred=results_os['Predict'],
)
print(report)

              precision    recall  f1-score   support

           A       0.36      0.25      0.29      1652
           B       0.07      0.25      0.10       308
           C       0.01      0.25      0.02        52
           D       0.00      0.25      0.00         4
        None       0.64      0.01      0.03      2640

    accuracy                           0.11      4656
   macro avg       0.22      0.20      0.09      4656
weighted avg       0.49      0.11      0.13      4656

