<a href="https://colab.research.google.com/github/camdenmcgath/Artificial-Judge/blob/master/supreme_court_judge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install evaluate transformers seaborn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pandas as pd
import re
import evaluate
import numpy as np
import transformers as trans
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, LongformerTokenizer, LongformerForSequenceClassification
import torch
from torch.optim import AdamW, lr_scheduler
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline, BSpline
import nltk
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix
import seaborn as sns

import gc
import os

# Configure PyTorch's CUDA caching allocator (max_split_size_mb, in MB). Set in
# the setup cell so it is in place before any tensors are placed on the GPU.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:2048"})


In [None]:
# Load the raw case table, drop identifier/metadata columns that are not used
# later in the notebook, and discard rows with any missing value.
unused_columns = [
    'Unnamed: 0', 'ID', 'name', 'href', 'docket', 'term',
    'majority_vote', 'minority_vote', 'decision_type', 'disposition', 'issue_area',
]
cases = (
    pd.read_csv("justice.csv")
    .drop(columns=unused_columns)
    .dropna()
)


In [None]:
# Label = index of the winning party: 0 if the first party won
# (first_party_winner is True), 1 otherwise. Done in one vectorised step
# rather than iterrows + per-row .loc writes, which is slow and writes ints
# into a bool column (pandas then upcasts it to object dtype).
cases = cases.rename(columns={'first_party_winner': 'winning_party_idx'})
cases['winning_party_idx'] = np.where(cases['winning_party_idx'].eq(True), 0, 1)

In [None]:

# Create a mirrored copy of every case with the two parties swapped and the
# label flipped, so the model cannot simply learn to favour the first party.
# The swap uses rename() instead of tuple-assigning the two columns, which
# depends on pandas' internal copy/view behaviour when a column is overwritten.
mirrored_cases = cases.rename(columns={'first_party': 'second_party',
                                       'second_party': 'first_party'})
mirrored_cases['winning_party_idx'] = (mirrored_cases['winning_party_idx'] == 0).astype(int)

# concat aligns on column names, so the renamed columns land in the right place.
cases = pd.concat([cases, mirrored_cases], ignore_index=True)
assert len(cases) == 2 * len(mirrored_cases)
print(f'There are {len(cases)} cases.')
print(f'There are {len(cases[cases["winning_party_idx"]==0])} rows for class 0.')
print(f'There are {len(cases[cases["winning_party_idx"]==1])} rows for class 1.')

In [None]:
# Clean the facts text: remove HTML tags first, then every character that is
# not a letter, digit, apostrophe or whitespace. Case is intentionally kept
# (no lowercasing); the notebook tokenizes with the cased 'bert-base-cased'.
html_tag_pattern = re.compile(r'<[^<]+?>')
disallowed_char_pattern = re.compile(r'[^a-zA-Z0-9\'\s]')
cases['facts'] = cases['facts'].map(
    lambda text: disallowed_char_pattern.sub('', html_tag_pattern.sub('', text))
)

def word_count(text):
  """Return how many whitespace-separated tokens ``text`` contains."""
  return sum(1 for _ in text.split())

# Word count of each cleaned facts string; summary statistics are the cell output.
fact_lengths = cases['facts'].map(word_count)
cases['facts_len'] = fact_lengths
cases['facts_len'].describe()

In [None]:
# Visualise the word length of every case's facts and report how many cases
# fall within the word budget used to filter the data before tokenization.
MAX_FACT_WORDS = 390

fig, ax = plt.subplots()
ax.scatter(range(len(cases['facts_len'])), cases['facts_len'], s=5, alpha=0.5)
# Draw the cut-off at the same value that is used for counting/filtering.
ax.axhline(y=MAX_FACT_WORDS, color='red', label=f'{MAX_FACT_WORDS}-word cut-off')
ax.set_xlabel('Index')
ax.set_ylabel('Fact Length')
ax.set_title('Distribution of Fact Lengths')
ax.legend()
plt.show()

# Count the cases whose facts are within the word budget.
num_short_facts = int((cases['facts_len'] <= MAX_FACT_WORDS).sum())

# Express that count as a percentage of all cases.
percentage_short_facts = num_short_facts / len(cases) * 100
print(f"\nPercentage of cases with fact word no greater than {MAX_FACT_WORDS}: {percentage_short_facts:.2f}%")

In [None]:
# Keep only cases whose cleaned facts fit the word budget. Rows are filtered
# out; assigning a masked Series back to 'facts' would instead leave every
# longer case in the data with NaN facts, rendered as the literal "nan" below.
MAX_FACT_WORDS = 390
cases = cases[cases['facts_len'] <= MAX_FACT_WORDS].reset_index(drop=True)

# Every case appears twice (original + party-swapped mirror) with identical
# facts. Split on the unique facts text so both copies land on the same side;
# a row-level split would put mirrors of training cases into validation.
case_keys = cases['facts'].copy()

# Model input: "<first party> [SEP] <second party> [SEP] <facts>".
cases['facts'] = cases.apply(lambda x: f"{x['first_party']} [SEP] {x['second_party']} [SEP] {x['facts']}", axis=1)
cases = cases.drop(columns=['first_party', 'second_party', 'facts_len'])

# Fixed random_state so the split is reproducible across runs.
train_keys, _ = train_test_split(case_keys.unique(), test_size=0.20, random_state=42)
is_train = case_keys.isin(train_keys).to_numpy()

train_facts = cases.loc[is_train, 'facts'].tolist()
val_facts = cases.loc[~is_train, 'facts'].tolist()
train_winners = [str(i) for i in cases.loc[is_train, 'winning_party_idx']]
val_winners = [str(i) for i in cases.loc[~is_train, 'winning_party_idx']]

# Truncation is deliberately left off so no text is silently cut; the check
# below fails fast if any sequence is longer than the model accepts.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
train_encodings = tokenizer(train_facts, padding=True)
val_encodings = tokenizer(val_facts, padding=True)

longest = max(len(ids) for ids in train_encodings['input_ids'] + val_encodings['input_ids'])
assert longest <= tokenizer.model_max_length, (
    f"Longest sequence has {longest} tokens (limit {tokenizer.model_max_length}); lower MAX_FACT_WORDS."
)


In [None]:
class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with integer class labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a per-example sequence, such as the
            ``BatchEncoding`` returned by the tokenizer in this notebook.
        labels: Per-example labels; each must be convertible with ``int()``
            (the notebook passes the strings ``'0'`` / ``'1'``).

    Raises:
        ValueError: if any encoding feature has a different number of
            examples than ``labels``.
    """

    def __init__(self, encodings, labels):
        # Catch misaligned inputs here instead of as an IndexError (or silently
        # ignored extra examples) in the middle of training.
        for key, values in encodings.items():
            if len(values) != len(labels):
                raise ValueError(
                    f"encoding '{key}' has {len(values)} examples but there are {len(labels)} labels"
                )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus a ``labels`` tensor."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(int(self.labels[idx]))
        return item

    def __len__(self):
        """Number of examples (one per label)."""
        return len(self.labels)

# Wrap the tokenized training and validation splits as PyTorch datasets.
train_dataset, val_dataset = (
    TextDataset(train_encodings, train_winners),
    TextDataset(val_encodings, val_winners),
)


In [None]:
# Seed Python, NumPy and PyTorch before building the model so the randomly
# initialised classification head is reproducible. Per the TrainingArguments
# docs, its `seed` is applied only at the beginning of training, which is
# after from_pretrained() has already initialised the head.
SEED = 42
trans.set_seed(SEED)

# Load pretrained cased BERT with a 2-way classification head and dropout of
# 0.4 on both hidden and attention layers.
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
                                                      num_labels=2,
                                                      hidden_dropout_prob=0.4,
                                                      attention_probs_dropout_prob=0.4)

training_args = TrainingArguments(
    output_dir="test_trainer",
    logging_dir='logs',
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # pin the transformers version or update the keyword.
    evaluation_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    logging_steps=50,
    # NOTE(review): 1e-6 is well below the 2e-5 to 5e-5 range the BERT paper
    # recommends for fine-tuning; confirm this is intentional.
    learning_rate=1e-6,
    seed=SEED,
)
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Score arg-max class predictions with the notebook's accuracy ``metric``.

    ``eval_pred`` unpacks to ``(logits, labels)``; the predicted class for each
    row of ``logits`` is the index of its largest entry along axis 1.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=1)
    return metric.compute(references=gold, predictions=predicted)

# Bundle the model, training configuration, data splits and accuracy metric.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


In [None]:
# Fine-tune the model; the returned TrainOutput is displayed as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Run the fine-tuned model over the validation set. Trainer.evaluate() returns
# only a metrics dict (it has no .predictions / .label_ids), so use
# Trainer.predict(), whose output carries the logits, gold labels and metrics.
prediction_output = trainer.predict(val_dataset, metric_key_prefix="eval")
result = prediction_output.metrics

# Class 0 = first-listed party won, class 1 = second-listed party won.
y_pred = prediction_output.predictions.argmax(axis=1)
y_true = prediction_output.label_ids

# Fix the label order so the matrix is always 2x2 with rows/cols in [0, 1] order.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Plot the confusion matrix as an annotated heatmap.
class_names = ['first party', 'second party']
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
plt.show()