# LegalBERT for Next Word Prediction

## Install/Import Libraries

In [2]:
!pip install transformers datasets torch
!pip install wandb

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [3]:
from datasets import load_dataset
import nltk
from nltk.tokenize import sent_tokenize
import re
import random
from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch
from datasets import load_from_disk, Dataset
import os
import numpy as np


## Load and Preprocess Dataset (US Bills)

In [4]:
# Load the 'us_bills' subset

dataset = load_dataset("pile-of-law/pile-of-law", "us_bills")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/25.6k [00:00<?, ?B/s]

pile-of-law.py:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

The repository for pile-of-law/pile-of-law contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/pile-of-law/pile-of-law.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60
INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60
Generating dataset pile-of-law (/root/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60)
INFO:datasets.builder:Generating dataset pile-of-law (/root/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60)
Downloading and preparing dataset pile-of-law/us_bills to /root/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60...
INFO:datasets.builder:Downloading and preparing dataset pile

train.us_bills.jsonl.xz:   0%|          | 0.00/176M [00:00<?, ?B/s]

Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min


validation.us_bills.jsonl.xz:   0%|          | 0.00/56.4M [00:00<?, ?B/s]

Downloading took 0.0 min
INFO:datasets.download.download_manager:Downloading took 0.0 min
Checksum Computation took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
Generating train split
INFO:datasets.builder:Generating train split


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split
INFO:datasets.builder:Generating validation split


Generating validation split: 0 examples [00:00, ? examples/s]

Unable to verify splits sizes.
INFO:datasets.utils.info_utils:Unable to verify splits sizes.
Dataset pile-of-law downloaded and prepared to /root/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60. Subsequent calls will reuse this data.
INFO:datasets.builder:Dataset pile-of-law downloaded and prepared to /root/.cache/huggingface/datasets/pile-of-law___pile-of-law/us_bills/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60. Subsequent calls will reuse this data.


In [5]:
def clean_text(bills):
    clean_bills = []
    for bill in bills:
      # Remove all newline and tab characters
      text = bill.replace('\n', ' ').replace('\t', ' ')
      # Remove sequences of exactly three lowercase letters
      text = re.sub(r'\[([a-z]{1,3})\]', ' ', text, flags=re.IGNORECASE)
      clean_bills.append(text.lower())

    return clean_bills

In [6]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [7]:
#sentences = sent_tokenize(clean_text(dataset['train']['text'][:2]))
#print(sentences[:5])
clean_trained = clean_text(dataset['train']['text'][:60])
clean_trained_sent = []
for bill in clean_trained:
  sent = sent_tokenize(bill)
  clean_trained_sent.extend(sent)

clean_val = clean_text(dataset['validation']['text'][120:130])
clean_val_sent = []
for bill in clean_val:
  sent = sent_tokenize(bill)
  clean_val_sent.extend(sent)

print(len(clean_val_sent))

3126


## Next Word Prediction

Masking

In [8]:
# Load the data
train_dataset = clean_trained_sent
test_dataset = clean_val_sent #load_from_disk('/content/drive/MyDrive/CS 7650/preprocessed_us_bills/validation')

def tokenize_and_mask_last_word(examples):
    inputs = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128, return_tensors="pt")

    # Clone input_ids to create labels
    inputs['labels'] = inputs['input_ids'].clone()

    # Mask the last word in each sentence
    for i in range(len(inputs['input_ids'])):
        # Find the index of the last non-padding token
        non_pad_tokens = (inputs['input_ids'][i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
        if len(non_pad_tokens) > 1:
            last_token_idx = non_pad_tokens[-1].item()
            inputs['input_ids'][i][last_token_idx] = tokenizer.mask_token_id  # Mask the last word
    return inputs


# Randomly sample 1000 examples from the train dataset and 200 from the test dataset
train_sample = Dataset.from_dict({'text': train_dataset})#.shuffle(seed=42).select(range(1000))
test_sample = Dataset.from_dict({'text': test_dataset})#.shuffle(seed=42).select(range(200))

# Load BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model = BertForMaskedLM.from_pretrained("nlpaueb/legal-bert-base-uncased")

def generate_chunks_from_text(text, min_length=100, max_length=512, tokenizer = tokenizer):
    print("I be")
    tokens = tokenizer.encode(text, truncation=False, padding=False)  # Encode without truncation
    print("Doubting")
    chunks = []
    while len(tokens) >= min_length:
        # Generate a random chunk length between min_length and max_length (or remaining tokens)
        chunk_length = 512 #random.randint(min_length, min(max_length, len(tokens)))  # Random length for each chunk
        chunk = tokens[:chunk_length]
        chunks.append(chunk)
        tokens = tokens[chunk_length:]  # Remove the processed chunk from the tokens
    print("Generated n chunks: ", len(chunks))
    return chunks

# Function to manually expand the dataset by adding multiple rows for each document
def split_into_chunks(train_bills, min_length=100, max_length=512, tokenizer=tokenizer):
    all_input_ids = []
    all_labels = []
    all_attention_masks = []  # To store attention masks

    for text in train_bills:
        # Split the document into multiple chunks
        chunks = generate_chunks_from_text(text, min_length, max_length, tokenizer)
        for chunk in chunks:
            all_input_ids.append(chunk)
            all_labels.append(chunk)  # For causal language modeling, labels are the same as input_ids

            # Generate attention mask: 1 for real tokens, 0 for padding (for now, we assume no padding in chunks)
            attention_mask = [1] * len(chunk)
            all_attention_masks.append(attention_mask)

    print("Raw input ids and label ids dims:", len(all_input_ids), " x ", len(all_input_ids[0]), " or ", len(all_input_ids[1]), ' and ', len(all_labels), ' x ', len(all_labels[0]), ' or ', len(all_labels[1]))

    # Now, pad the sequences to ensure they have consistent lengths
    input_ids_padded = tokenizer.pad(
        {"input_ids": all_input_ids},  # Only need to pad input_ids
        padding='max_length',  # Pad to the longest sequence in the batch
        max_length=max_length,  # Set max length
        return_tensors="pt"  # Return as pytorch tensors
    )

    # Pad labels as well
    labels_padded = tokenizer.pad(
        {"input_ids": all_labels},  # Same padding as input_ids
        padding='max_length',
        max_length=max_length,
        return_tensors="pt"
    )

    # Pad the attention masks (1 for real tokens, 0 for padding)
    attention_masks_padded = tokenizer.pad(
        {"input_ids": all_attention_masks},  # Same padding for attention masks
        padding='max_length',
        max_length=max_length,
        return_tensors="pt"
    )
    print("In split into chunks, size of all 3 columns: ", input_ids_padded['input_ids'].shape)
    # Return the padded dataset with attention masks
    return {
        "input_ids": input_ids_padded["input_ids"],
    }

# Apply tokenization
train_sample = train_sample.map(tokenize_and_mask_last_word, batched=True)
test_sample = test_sample.map(tokenize_and_mask_last_word, batched=True)



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Map:   0%|          | 0/1311 [00:00<?, ? examples/s]

Map:   0%|          | 0/3126 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Evaluation

In [9]:
# Step 6: Calculate accuracy on test data for both fine-tuned and pretrained models
def compute_accuracy(model, dataset):
    model.eval()
    correct, total = 0, 0
    # wrong_predictions = []  # Store sentences with wrong predictions
    # right_predictions = []  # Store sentences with right predictions
    # max_examples = 5 # Maximum number of examples to store for each case
    total_loss = 0
    total_tokens = 0
    device = torch.device('cpu')
    model.to(device)
    with torch.no_grad():
        for example in dataset:
            # Convert list values to tensors
            inputs = {
                k: torch.tensor(v).unsqueeze(0).to(device) if isinstance(v, list) else v.unsqueeze(0).to(device) #check if v is list, else unsqueeze
                for k, v in example.items()
                if k in ["input_ids", "attention_mask"]
            }
            labels = inputs["input_ids"].clone()  # Copy the input_ids as labels for next word prediction
            # Replace masked token IDs in labels with -100 to ignore them in accuracy calculation
            labels[labels == tokenizer.mask_token_id] = -100
            outputs = model(**inputs, labels=labels)
            predictions = outputs.logits.argmax(dim=-1) #Try getting

            loss = outputs.loss
            num_tokens = (labels != -100).sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens

            correct += (predictions[0][labels[0] != -100] == labels[0][labels[0] != -100]).sum().item()
            total += (labels[0] != -100).sum().item()


    perplexity = np.exp(total_loss / total_tokens)
    return correct / total, perplexity

# Accuracy with fine-tuned model
fine_tuned_accuracy, fine_tuned_perplexity = compute_accuracy(model, test_sample)
print("Fine-tuned Model Accuracy:", fine_tuned_accuracy)


Fine-tuned Model Accuracy: 0.3828570133147944
