# Legal Text Classification

### Import Statements

In [1]:
# !pip install transformers==4.33.1
# !pip install torch==2.1.0
# !pip install accelerate -U
# !pip install evaluate

In [2]:
import pandas as pd
import numpy as np
import datasets
import torch
import transformers
import random

random.seed(10)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
df = pd.read_csv("legal_text_classification.csv")

In [4]:
df.shape

(24985, 4)

In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 24985 entries, 0 to 24984
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   case_id       24985 non-null  object
 1   case_outcome  24985 non-null  object
 2   case_title    24985 non-null  object
 3   case_text     24809 non-null  object
dtypes: object(4)
memory usage: 780.9+ KB


In [6]:
df.isna().sum()

case_id           0
case_outcome      0
case_title        0
case_text       176
dtype: int64

In [7]:
df.loc[df['case_text'].isna(), :].head(1)

Unnamed: 0,case_id,case_outcome,case_title,case_text
24,Case29,followed,Elderslie Finance Corp Ltd v Australian Securi...,


In [8]:
df.head()

Unnamed: 0,case_id,case_outcome,case_title,case_text
0,Case1,cited,Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Lt...,Ordinarily that discretion will be exercised s...
1,Case2,cited,Black v Lipovac [1998] FCA 699 ; (1998) 217 AL...,The general principles governing the exercise ...
2,Case3,cited,Colgate Palmolive Co v Cussons Pty Ltd (1993) ...,Ordinarily that discretion will be exercised s...
3,Case4,cited,Dais Studio Pty Ltd v Bullett Creative Pty Ltd...,The general principles governing the exercise ...
4,Case5,cited,Dr Martens Australia Pty Ltd v Figgins Holding...,The preceding general principles inform the ex...


In [9]:
df.loc[1, 'case_text']

'The general principles governing the exercise of the discretion to award indemnity costs after rejection by an unsuccessful party of a so called Calderbank letter were set out in the judgment of the Full Court in Black v Lipovac [1998] FCA 699 ; (1998) 217 ALR 386. In summary those principles are: 1. Mere refusal of a "Calderbank offer" does not itself warrant an order for indemnity costs. In this connection it may be noted that Jessup J in Dais Studio Pty Ltd v Bullet Creative Pty Ltd [2008] FCA 42 said that (at [6]): if the rejection of such an offer is to ground a claim for indemnity costs, it must be by reason of some circumstance other than that the offer happened to comply with the Calderbank principle. 2. To obtain an order for indemnity costs the offeror must show that the refusal to accept it was unreasonable. 3. The reasonableness of the conduct of the offeree is to be viewed in the light of the circumstances that existed when the offer was rejected.'

### Preprocessing

In [10]:
# Convert the pandas DataFrame to a Hugging Face dataset
df = df.rename(columns={'case_outcome': 'label'})
data = datasets.Dataset.from_pandas(df)
data = data.class_encode_column("label")

# Perform a stratified train-test split test set 90%, some of the classes are very less so better to stratify
data = data.train_test_split(test_size=0.1, stratify_by_column='label', seed=10)


num_classes = data['train'].features['label'].num_classes
id2label = {i:data['train'].features['label'].int2str(i) for i in range(num_classes)}
label2id = {label:i for (i,label) in id2label.items()}

Casting to class labels: 100%|███████████████████████████████| 24985/24985 [00:00<00:00, 154938.75 examples/s]


In [11]:
data

DatasetDict({
    train: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text'],
        num_rows: 22486
    })
    test: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text'],
        num_rows: 2499
    })
})

In [12]:
id2label

{0: 'affirmed',
 1: 'applied',
 2: 'approved',
 3: 'cited',
 4: 'considered',
 5: 'discussed',
 6: 'distinguished',
 7: 'followed',
 8: 'referred to',
 9: 'related'}

In [13]:
data['train'].features['label']

ClassLabel(names=['affirmed', 'applied', 'approved', 'cited', 'considered', 'discussed', 'distinguished', 'followed', 'referred to', 'related'], id=None)

### Feature Engineering

In [14]:
# case outcome is label column, case_title and case_text we can merge in one column as they both might contain
# some important textual information. For example: if one compant names in case_title judgement is most of the time any one class label
# in such a case including case_title is important. case_text has a lot of information about case which might influence label column

def merge_title_text(example):
    example['text'] = "Case Title: " + example['case_title'] + str("" if example['case_text'] is None else "\nCase Text: " + example['case_text'])
    return example

In [15]:
data = data.map(merge_title_text)

Map: 100%|█████████████████████████████████████████████████████| 22486/22486 [00:03<00:00, 7343.51 examples/s]
Map: 100%|███████████████████████████████████████████████████████| 2499/2499 [00:00<00:00, 7228.51 examples/s]


In [16]:
data

DatasetDict({
    train: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text', 'text'],
        num_rows: 22486
    })
    test: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text', 'text'],
        num_rows: 2499
    })
})

In [17]:
print(data['train']['text'][0])

Case Title: Comandate Marine Corporation v The Ship "Boomerang I" [2006] FCAFC 106 ; (2006) 151 FCR 403
Case Text: course, there is an incongruity in this approach because it ignores the rights of a secured creditor (other than a holder of a maritime lien recognised in s 15) such as a mortgagee and instead prefers those of a co-owner. Thus, if a vessel is co-owned it would not be able to be arrested under s 19 if one co-owner were not a relevant person under s 19(a), but a mortgagee cannot escape the amenability of the vessel to arrest. But this is the consequence of the legislative choice of selecting, as the criterion for actuating the right defined in s 19(b), the "owner", and not extending this to secured creditors or demise charterers: cf Comandate Marine Corporation v The Ship "Boomerang I" [2006] FCAFC 106 ; (2006) 151 FCR 403. As Allsop J observed, the wide group of categories identified in s 19(a) is then "limited to the more narrow funnel in para (b) ...": " Boomerang I " 151

### Finetuning

In [18]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

In [19]:
# truncate input text to be not more than distibert maximum imput limit
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

In [20]:
data = data.map(preprocess_function, batched=True)

Map: 100%|█████████████████████████████████████████████████████| 22486/22486 [00:16<00:00, 1328.87 examples/s]
Map: 100%|███████████████████████████████████████████████████████| 2499/2499 [00:01<00:00, 1299.37 examples/s]


In [21]:
data

DatasetDict({
    train: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text', 'text', 'input_ids', 'attention_mask'],
        num_rows: 22486
    })
    test: Dataset({
        features: ['case_id', 'label', 'case_title', 'case_text', 'text', 'input_ids', 'attention_mask'],
        num_rows: 2499
    })
})

In [22]:
from transformers import DataCollatorWithPadding

# For padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [23]:
type(data['train']['label'])

list

In [24]:
set_outcome = list(set(data['train']['label']))

counts = [0]*len(set_outcome)

list(map(lambda x, y: {y: (x+data['train']['label'].count(y))/len(data['train']['label'])}, counts, set_outcome))

[{0: 0.004536155830294405},
 {1: 0.0979720715111625},
 {2: 0.004313795250378013},
 {3: 0.4890598594681135},
 {4: 0.06853153073023215},
 {5: 0.04100329093658276},
 {6: 0.024326247442853333},
 {7: 0.09027839544605533},
 {8: 0.17544249755403363},
 {9: 0.004536155830294405}]

In [25]:
list(map(lambda x, y: {y: (x+data['test']['label'].count(y))/len(data['test']['label'])}, counts, set_outcome))

[{0: 0.004401760704281713},
 {1: 0.09803921568627451},
 {2: 0.004401760704281713},
 {3: 0.4889955982392957},
 {4: 0.06842737094837935},
 {5: 0.04081632653061224},
 {6: 0.024409763905562223},
 {7: 0.09043617446978791},
 {8: 0.1756702681072429},
 {9: 0.004401760704281713}]

In [26]:
import evaluate

accuracy = evaluate.load("accuracy")

In [27]:
import numpy as np

# We can see precision and recall later first lets try accuracy
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [28]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Distilbert becuase its small, easy to fit in memory, we can try Lora and Peft for more memory optimization later and we can also
# try some models who have been trained on legal domain
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased", num_labels=10, id2label=id2label, label2id=label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.bias', 'pre_classifier.weight', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
training_args = TrainingArguments(
    output_dir="finetuned_model",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None)
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy
1,1.4824,1.438468,0.492197


TrainOutput(global_step=2811, training_loss=1.516077389712013, metrics={'train_runtime': 39730.428, 'train_samples_per_second': 0.566, 'train_steps_per_second': 0.071, 'total_flos': 2954456857070400.0, 'train_loss': 1.516077389712013, 'epoch': 1.0})

### Saving model for backup

In [30]:
model.save_pretrained('finetuned_model_backup')
tokenizer.save_pretrained('finetuned_model_backup')

('finetuned_model_backup\\tokenizer_config.json',
 'finetuned_model_backup\\special_tokens_map.json',
 'finetuned_model_backup\\vocab.txt',
 'finetuned_model_backup\\added_tokens.json',
 'finetuned_model_backup\\tokenizer.json')

### Calculating other metrics on eval data

It was necessary to check Precision, Recall and F1 since, all classes are not balanced.

In [90]:
# small_data = data['test'].select(range(100))
predictions = trainer.predict(data['test'])

In [91]:
# predictions

In [96]:
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

# Define a function to compute your metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    precision = precision_score(labels, predictions, average='weighted')
    recall = recall_score(labels, predictions, average='weighted')
    f1 = f1_score(labels, predictions, average='weighted')
    return {"precision": precision, "recall": recall, "f1": f1}

In [97]:
metrics = compute_metrics((predictions.predictions, predictions.label_ids))
print(metrics)

{'precision': 0.3058893677564232, 'recall': 0.4921968787515006, 'f1': 0.3468256056594664}


  _warn_prf(average, modifier, msg_start, len(result))


In [98]:
# predictions

In [99]:
from datasets import load_metric

# Load the metrics
precision_metric = load_metric("precision")
recall_metric = load_metric("recall")
f1_metric = load_metric("f1")

preds = torch.argmax(torch.from_numpy(predictions.predictions), dim=-1)
preds = preds.numpy()
label = predictions.label_ids

# Compute the metrics
precision = precision_metric.compute(predictions=preds, references=label, average="weighted")["precision"]
recall = recall_metric.compute(predictions=preds, references=label, average="weighted")["recall"]
f1 = f1_metric.compute(predictions=preds, references=label, average="weighted")["f1"]

print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-score: {f1}")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Precision: 0.3058893677564232
Recall: 0.4921968787515006
F1-score: 0.3468256056594664


  _warn_prf(average, modifier, msg_start, len(result))
