# Legal Judgement Predictor - A classification task on BERT embeddings

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

path_to_datasets = 'embeddings_datasets/legal-bert-base-uncased/'
path_to_model = 'models/attention-mlp/'
create_dataset = False

## Preprocessing

In [None]:
class ECHRDataset(Dataset):
    """Map-style dataset bundling embeddings, attention masks and labels.

    The three containers are stored exactly as given and indexed in lockstep,
    so they are expected to share the same length along their first axis.
    """

    def __init__(self, data, attention_mask, labels):
        self.data, self.attention_mask, self.labels = data, attention_mask, labels

    def __len__(self):
        # The number of samples is defined by the embeddings container alone.
        return len(self.data)

    def __getitem__(self, idx):
        embedding = self.data[idx]
        mask = self.attention_mask[idx]
        label = self.labels[idx]
        return embedding, mask, label

In [None]:
if create_dataset:

    def pad_data(data, max_len):
        """Zero-pad every item along its first dimension up to ``max_len``.

        Args:
            data: iterable of tensors (assumed 2-D, rows x features — TODO confirm).
            max_len: number of rows every item is padded to.

        Returns:
            A tuple ``(padded, masks)``: the stacked padded items and an integer
            tensor holding 1 for original rows and 0 for padding rows.
        """
        padded_data = []
        attention_masks = []
        for item in data:
            n_rows = item.shape[0]
            n_pad = max_len - n_rows
            attention_masks.append([1] * n_rows + [0] * n_pad)
            # F.pad lists the pads starting from the last dimension: (0, 0)
            # leaves the last dimension alone, (0, n_pad) appends rows to the
            # dimension before it.
            padded_data.append(F.pad(item, (0, 0, 0, n_pad)))
        return torch.stack(padded_data), torch.tensor(attention_masks)

    # load data
    # NOTE(review): torch.load / pd.read_pickle unpickle arbitrary objects;
    # only point them at files you trust.
    train = torch.load('embeddings/legal-bert-base-uncased/emb_tr_cpu.pkl')
    dev = torch.load('embeddings/legal-bert-base-uncased/emb_dev_cpu.pkl')
    test = torch.load('embeddings/legal-bert-base-uncased/emb_test_cpu.pkl')

    print('Train '+str(len(train)),'Dev '+str(len(dev)), 'Test '+str(len(test)))

    # concat dev to train series
    train = np.concatenate((train, dev))

    print('Train + Dev = '+str(len(train)))

    # load labels
    train_labels = pd.read_pickle('embeddings/legal-bert-base-uncased/train_labels.pkl')
    dev_labels = pd.read_pickle('embeddings/legal-bert-base-uncased/dev_labels.pkl')
    test_labels = pd.read_pickle('embeddings/legal-bert-base-uncased/test_labels.pkl')

    # concat dev labels to train labels
    train_labels = torch.tensor(np.concatenate((train_labels, dev_labels)))
    # The test labels must be a tensor as well: the training labels already
    # are, and the device-transfer cell calls `.to(device)` on both.
    test_labels = torch.tensor(np.asarray(test_labels))

    # pad the data (train and test are padded to their own maximum length;
    # the attention masks record which rows are real)
    max_len_train = max(x.shape[0] for x in train)
    max_len_test = max(x.shape[0] for x in test)
    train, train_attention_masks = pad_data(train, max_len_train)
    test, test_attention_masks = pad_data(test, max_len_test)

    # create the datasets
    train_dataset = ECHRDataset(train, train_attention_masks, train_labels)
    test_dataset = ECHRDataset(test, test_attention_masks, test_labels)

    print (train_dataset.data.device)

    # save the datasets
    os.makedirs(path_to_datasets, exist_ok=True)
    torch.save(train_dataset, path_to_datasets+'train_dataset.pt')
    torch.save(test_dataset, path_to_datasets+'test_dataset.pt')


In [None]:
if not create_dataset:
    # Reuse the datasets serialized by a previous run of the creation cell.
    train_path = path_to_datasets + 'train_dataset.pt'
    test_path = path_to_datasets + 'test_dataset.pt'
    train_dataset = torch.load(train_path)
    test_dataset = torch.load(test_path)
    print(len(train_dataset))

## Baseline

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Importing Axes3D for 3D plotting
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import numpy as np

scaler = StandardScaler()

# One vector per document: the average of its embeddings.
# Only the real rows are averaged. The stored tensors are zero-padded to a
# common length, so a plain mean over every row would shrink the vector of
# short documents towards zero. A sample whose mask is entirely 0 falls back
# to the plain mean to avoid averaging over nothing.
mean_train = np.array([
    (x[m.bool()] if m.any() else x).mean(0).numpy()
    for x, m in zip(train_dataset.data, train_dataset.attention_mask)
])

print(mean_train.shape)

train_scaled = scaler.fit_transform(mean_train)

# Project the standardized vectors onto their first three principal components.
pca = PCA(n_components=3)
xpca = pca.fit_transform(train_scaled)

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot, coloured by label through the viridis colormap
scatter = ax.scatter(xpca[:, 0], xpca[:, 1], xpca[:, 2], c=train_dataset.labels, cmap='viridis')

# Legend for the two classes. The marker colours are taken from the scatter's
# own colormap so the legend matches the points actually drawn.
class_names = {0: 'Negative', 1: 'Positive'}  # assumes binary 0/1 labels — TODO confirm
legend_handles = [
    plt.Line2D([0], [0], marker='o', color='w',
               markerfacecolor=scatter.to_rgba(value), markersize=10, label=name)
    for value, name in class_names.items()
]
ax.legend(handles=legend_handles, loc='upper right')

# Set labels and title
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
plt.title('PCA of Train Data')

# Add colorbar
plt.colorbar(scatter, label='Class')

plt.show()

In [None]:
# BASELINE on avg of chunks

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

# Requires train_scaled and train_dataset.labels from the cells above.

# Hold out 20% of the scaled mean embeddings for evaluation; the fixed seed
# keeps the split identical from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    train_scaled,
    train_dataset.labels,
    test_size=0.2,
    random_state=42,
)

# Build and fit a 100-tree random forest (fit returns the estimator itself).
rf_classifier = RandomForestClassifier(
    n_estimators=100,
    random_state=42,
).fit(X_train, y_train)

# Predictions on the held-out part
y_pred = rf_classifier.predict(X_test)

# Weighted F1 plus the per-class report
f1 = f1_score(y_test, y_pred, average='weighted')
print(classification_report(y_test, y_pred))
print("Random Forest F1:", f1)


In [None]:
from sklearn.svm import SVC

# Linear-kernel SVM fitted on the same split as the other baselines
# (fit returns the estimator itself).
svm_classifier = SVC(
    kernel='linear',
    random_state=42,
).fit(X_train, y_train)

# Predictions on the held-out part
y_pred_svm = svm_classifier.predict(X_test)

# Weighted F1 plus the per-class report
f1 = f1_score(y_test, y_pred_svm, average='weighted')
print(classification_report(y_test, y_pred_svm))
print("SVM F1:", f1)

In [None]:
from sklearn.neural_network import MLPClassifier

# scikit-learn MLP with a single 16-unit hidden layer, fitted on the same
# split as the other baselines (fit returns the estimator itself).
mlp_classifier = MLPClassifier(
    hidden_layer_sizes=(16,),
    max_iter=1000,
    random_state=42,
).fit(X_train, y_train)

# Predictions on the held-out part
y_pred_mlp = mlp_classifier.predict(X_test)

# Weighted F1 plus the per-class report
f1 = f1_score(y_test, y_pred_mlp, average='weighted')
print(classification_report(y_test, y_pred_mlp))
print("MLP F1:", f1)

## Attention + MLP Model

In [None]:
class AttentionMLP(nn.Module):
    """Attention pooling over a sequence of embeddings followed by an MLP head.

    A learned selector vector scores the key projection of every sequence
    position. The softmax of those scores weights the value projections, which
    are summed into a single vector per sample. That vector goes through the
    MLP and a final sigmoid unit, giving one probability per sample.

    Args:
        input_dim: size of the last dimension of the input embeddings.
        hidden_sizes: layer widths of the MLP. The first entry must equal
            ``input_dim`` because the pooled vector is fed straight into the
            first layer; every consecutive pair of widths becomes a
            Linear -> ReLU -> Dropout stack.
        dropout: dropout probability applied after every hidden layer.
        weight_decay: accepted for backward compatibility but not used by the
            module; weight decay has to be configured on the optimizer.

    Raises:
        ValueError: if ``hidden_sizes`` is empty or does not start with
            ``input_dim``.
    """

    def __init__(self, input_dim, hidden_sizes, dropout=0, weight_decay=0.01):
        super().__init__()
        # Fail early with a clear message instead of an IndexError here or a
        # shape mismatch during the first forward pass.
        if len(hidden_sizes) == 0:
            raise ValueError("hidden_sizes must contain at least one layer width")
        if hidden_sizes[0] != input_dim:
            raise ValueError(
                f"hidden_sizes[0] ({hidden_sizes[0]}) must equal input_dim ({input_dim})"
            )

        # vector for query attention
        self.selector = nn.parameter.Parameter(torch.randn(input_dim, 1))
        self.Value = nn.Linear(input_dim, input_dim, bias=False)
        self.Key = nn.Linear(input_dim, input_dim, bias=False)

        # mlp layers
        layers = []
        for in_features, out_features in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        self.mlp = nn.Sequential(*layers)
        self.output = nn.Linear(hidden_sizes[-1], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, attention_mask=None):
        """Return one probability per sample.

        Args:
            x: embeddings of shape (batch, seq_len, input_dim).
            attention_mask: optional (batch, seq_len) tensor; positions equal
                to 0 are excluded from the attention.

        Returns:
            A 1-D tensor with the sigmoid output of every sample.
        """
        key = self.Key(x)
        value = self.Value(x)

        # One unnormalized score per position: (batch, seq_len, 1).
        scores = torch.matmul(key, self.selector)
        if attention_mask is not None:
            # Fill with the most negative finite value of the score dtype. A
            # hard-coded -1e9 cannot be represented in float16, and -inf would
            # turn a fully masked row into NaN after the softmax.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(attention_mask.unsqueeze(2) == 0, fill_value)
        attention = F.softmax(scores, dim=1)
        # (batch, 1, seq_len) so the matmul sums the values over the sequence
        attention = attention.permute(0, 2, 1)

        x = torch.matmul(attention, value)

        # mlp
        x = self.mlp(x)
        x = self.output(x)
        x = self.sigmoid(x)
        return x.flatten()


In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device_name, banner = 'cuda', "Running on GPU"
elif torch.backends.mps.is_available():
    device_name, banner = 'mps', "Running on MPS"
else:
    device_name, banner = 'cpu', "Running on CPU"
print(banner)
device = torch.device(device_name)

# Move the tensors of both datasets onto the selected device.
for split in (train_dataset, test_dataset):
    split.data = split.data.to(device)
    split.labels = split.labels.to(device)
    split.attention_mask = split.attention_mask.to(device)


In [None]:
# Mini-batch loader over the whole training set; shuffle=True reshuffles the
# samples at every epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Build the classifier (768-d inputs, one 768 -> 16 hidden layer) and place it
# on the selected device; Module.to returns the module itself.
model = AttentionMLP(768, [768, 16]).to(device)
# Adam with L2 weight decay over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-3)
# Binary cross-entropy on the sigmoid probabilities produced by the model
criterion = nn.BCELoss()

In [None]:
import copy
import os
import time

import torch
from tqdm import tqdm
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score

# Per-epoch histories, concatenated fold after fold (read by the plotting cell).
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
best_val_loss = float('inf')
best_model = None
epochs = 50
k_folds = 5  # Change this according to your preference
batch_size = 64  # same batch size as the train_dataloader cell
all_val_accuracies = []  # one entry per fold: accuracy after the fold's last epoch

# Snapshot the current weights and optimizer state so that every fold starts
# from the same point. Without the reset each fold would keep training a model
# that has already seen that fold's validation samples.
# NOTE: this assumes `model` and `optimizer` are freshly created when the cell runs.
initial_model_state = copy.deepcopy(model.state_dict())
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

# Initialize KFold (seeded so the folds are reproducible)
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

for fold, (train_indices, val_indices) in enumerate(kf.split(train_dataset)):
    print(f"Fold {fold + 1}/{k_folds}")

    # Restart from the snapshot for this fold
    model.load_state_dict(initial_model_state)
    optimizer.load_state_dict(initial_optimizer_state)

    # Split dataset into train and validation for this fold
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler)

    # A new progress bar per fold: a bar created once outside the loop is
    # closed as soon as the first fold has finished iterating over it.
    epoch_progress_bar = tqdm(range(epochs), desc="Epochs", unit="epoch")
    fold_val_accuracy = None

    for epoch in epoch_progress_bar:
        # ---- training ----
        model.train()
        running_loss = 0.0
        train_predictions = []
        train_targets = []
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)

        for i, data in loop:
            inputs, att_masks, labels = data

            optimizer.zero_grad()
            outputs = model(inputs, att_masks)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_predictions.extend((outputs > 0.5).long().cpu().numpy())
            train_targets.extend(labels.cpu().numpy())

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        train_accuracies.append(accuracy_score(train_targets, train_predictions))

        # ---- validation ----
        model.eval()
        with torch.no_grad():
            val_running_loss = 0.0
            val_predictions = []
            val_targets = []
            for data in val_loader:
                inputs, att_masks, labels = data
                outputs = model(inputs, att_masks)
                val_loss = criterion(outputs, labels.float())
                val_running_loss += val_loss.item()
                val_predictions.extend((outputs > 0.5).long().cpu().numpy())
                val_targets.extend(labels.cpu().numpy())

        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)

        # Calculate validation accuracy
        val_accuracy = accuracy_score(val_targets, val_predictions)
        val_accuracies.append(val_accuracy)
        fold_val_accuracy = val_accuracy

        # Keep a snapshot of the best model so far for val loss. A deep copy
        # is required: storing `model` itself would only alias the object that
        # keeps being trained, so the "best" model would be the final one.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)

        # Update the description of the epoch progress bar
        epoch_progress_bar.set_postfix({
            "Train Loss": train_loss,
            "Val Loss": val_loss,
            "Val Acc": val_accuracy
        })

    # Store the accuracy reached at the end of this fold
    if fold_val_accuracy is not None:
        all_val_accuracies.append(fold_val_accuracy)

# Compute the mean validation accuracy across all folds
mean_val_accuracy = sum(all_val_accuracies) / len(all_val_accuracies)
print("Mean Validation Accuracy:", mean_val_accuracy)

# save the best model for val loss
os.makedirs(path_to_model, exist_ok=True)
torch.save(best_model, os.path.join(path_to_model, f"{time.time()}-{best_val_loss:.4f}.pth"))


In [None]:
# Plot the training and validation curves
plt.style.use('ggplot')

# Two panels side by side: losses on the left, accuracies on the right
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (ax[0], 'Loss', ((train_losses, 'Training loss'), (val_losses, 'Validation loss'))),
    (ax[1], 'Accuracy', ((train_accuracies, 'Training accuracy'), (val_accuracies, 'Validation accuracy'))),
)
for axis, y_label, curves in panels:
    for series, curve_label in curves:
        axis.plot(series, label=curve_label)
    axis.legend()
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)
plt.show()

In [None]:
# Load a previously saved model (the whole pickled module, not just weights)
from torch import load
# map_location places the checkpoint on `device`, where the test tensors live.
# Without it, a model saved on a GPU cannot be loaded on a machine without
# that GPU.
# NOTE(review): this unpickles arbitrary objects, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects
# full-module checkpoints like this one — confirm against the installed version.
model = load('models/attention-mlp/1709484443.4163241-0.3393.pth', map_location=device)
# Evaluation mode: disables dropout for the test cell that follows.
model.eval()
type(model)

In [None]:
# test the model
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Evaluation mode (disables dropout); gradients are not needed for testing.
model.eval()
with torch.no_grad():
    outputs = model(test_dataset.data, test_dataset.attention_mask)
    loss = criterion(outputs, test_dataset.labels.float())

print(f'Test loss: {loss.item()}')

# Threshold the probabilities once and cast to integers, so predictions have
# the same kind of values as in the validation loop.
test_targets = test_dataset.labels.cpu()
test_predictions = (outputs > 0.5).long().cpu()
print(classification_report(test_targets, test_predictions))

# confusion matrix
cm = confusion_matrix(test_targets, test_predictions)
plt.figure(figsize=(10,7))
# fmt='d' prints the counts as plain integers; the default format would
# abbreviate large counts in scientific notation.
sns.heatmap(cm, annot=True, fmt='d')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
    