In [1]:
import os
from dotenv import load_dotenv
from huggingface_hub import login
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline, AutoModel
import json
import textwrap
from langchain import HuggingFacePipeline
from langchain import PromptTemplate,  LLMChain
from langchain.memory import ConversationBufferMemory
import pandas as pd
import time
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Read environment variables from a local .env file, then authenticate with the
# Hugging Face Hub using the token stored in HF_AUTH_TOKEN.
# NOTE(review): if HF_AUTH_TOKEN is unset, login() receives token=None — confirm that is acceptable.
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HF_AUTH_TOKEN")
login(token=HUGGINGFACEHUB_API_TOKEN)

In [3]:
# tokenizer = AutoTokenizer.from_pretrained("pile-of-law/legalbert-large-1.7M-2")
# model = AutoModel.from_pretrained("pile-of-law/legalbert-large-1.7M-2")
# device = torch.device("cuda")
# model.to(device)  # Move the model to the GPU

In [4]:
# # Save models and tokenizer locally
# model.save_pretrained('./legal-bert-large')
# tokenizer.save_pretrained('./legal-bert-large')

In [5]:
# Load LegalBERT-large from a local copy. On a fresh checkout that copy does not exist
# yet, so download it from the Hugging Face Hub once and save it locally. This replaces
# the manual, commented-out download/save cells above so Restart & Run All works.
model_directory = "./legal-bert-large"
hub_model_id = "pile-of-law/legalbert-large-1.7M-2"
if os.path.isdir(model_directory):
    tokenizer = AutoTokenizer.from_pretrained(model_directory)
    model = AutoModel.from_pretrained(model_directory)
else:
    tokenizer = AutoTokenizer.from_pretrained(hub_model_id)
    model = AutoModel.from_pretrained(hub_model_id)
    tokenizer.save_pretrained(model_directory)
    model.save_pretrained(model_directory)

In [6]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [7]:
# Move the encoder's weights onto `device`; extractive_summarization moves its inputs to the same device.
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(32000, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-23): 24 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inpl

In [8]:
# Getting the majority opinion from the CaseLaw json file
def load_and_extract_data(file_path):
    """Return the text of the majority opinion in a CaseLaw JSON file.

    Args:
        file_path: path to a CaseLaw case JSON file whose
            ``casebody -> data -> opinions`` entry is a list of dicts carrying
            "type" and "text" keys.

    Returns:
        The "text" of the first opinion whose "type" is "majority", or None when
        the file has no majority opinion (including an empty opinions list).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Scan every opinion: the majority opinion is not necessarily listed first, so a
    # dissent/concurrence must not end the search (previously it returned None here).
    for opinion in data["casebody"]["data"]["opinions"]:
        if opinion["type"] == "majority":
            return opinion["text"]
    return None

# Split the text into fixed-size word chunks to summarize
# Consecutive chunks overlap so each chunk carries some context from the previous chunk
def chunk_text_with_overlap(text, chunk_word_count, overlap_word_count):
    """Split `text` into whitespace-delimited word chunks that overlap.

    Each chunk holds up to `chunk_word_count` words, and each chunk after the first
    starts `overlap_word_count` words before the previous chunk ended. Chunking stops
    as soon as a chunk reaches the end of the text, so no trailing chunk is emitted
    that is entirely contained in the previous one.

    Args:
        text: text to split; words are separated by any whitespace.
        chunk_word_count: maximum number of words per chunk (must be > 0).
        overlap_word_count: words shared between consecutive chunks; must be smaller
            than `chunk_word_count`. Negative values are treated as 0 (no overlap).

    Returns:
        List of chunk strings (words re-joined with single spaces); [] for empty text.

    Raises:
        ValueError: if `chunk_word_count` <= 0 or `overlap_word_count` >=
            `chunk_word_count` (either would never make progress).
    """
    if chunk_word_count <= 0:
        raise ValueError("chunk_word_count must be a positive integer")
    if overlap_word_count >= chunk_word_count:
        raise ValueError("overlap_word_count must be smaller than chunk_word_count")

    words = text.split()
    # A negative overlap would leave gaps between chunks; treat it as no overlap.
    step = chunk_word_count - max(overlap_word_count, 0)
    chunks = []
    start = 0

    while start < len(words):
        end = min(start + chunk_word_count, len(words))
        chunks.append(" ".join(words[start:end]))
        if end == len(words):
            # This chunk already reaches the end; any further chunk would be a pure
            # subset of it and only duplicate content in the summary.
            break
        start += step

    return chunks

# Sort of overkill on saving a summary to a text file
# Some naming logic and error checking added in
def save_summary_to_text(summary, output_folder, file_path, condensed=False):
    """Write `summary` to a .txt file in `output_folder`, named after `file_path`.

    The output file name is the basename of `file_path` without its extension,
    followed by "_condensed_summary.txt" when `condensed` is truthy, otherwise
    "_summary.txt". This is a best-effort save: failures are printed, not raised.

    Args:
        summary: text to write (encoded as UTF-8).
        output_folder: directory that receives the file (not created here).
        file_path: path of the source document; only its basename is used.
        condensed: selects the "_condensed_summary" naming variant.
    """
    base_name, _extension = os.path.splitext(os.path.basename(file_path))
    summary_file_name = (
        f"{base_name}_condensed_summary.txt" if condensed else f"{base_name}_summary.txt"
    )
    summary_file_path = os.path.join(output_folder, summary_file_name)

    try:
        with open(summary_file_path, 'w', encoding='utf-8') as out_file:
            out_file.write(summary)
        print(f"Summary successfully written to {summary_file_name}")
    except OSError as err:  # IOError is an alias of OSError in Python 3
        print(f"Unable to write to file: {err}")
    except Exception as err:
        print(f"An unexpected error occurred: {err}")

In [9]:
# Calculates cosine similarity for sentence embeddings
# Pair sentences with scores, then sorts in descending order
# Pick top 'num_sentence' number of sentences
# START: REFACTOR FROM <https://towardsdatascience.com/extractive-summarization-using-bert-966e912f4142 and https://www.analyticsvidhya.com/blog/2023/03/exploring-the-extractive-method-of-text-summarization/>
# Also used GPT-4 in debugging
def extractive_summarization(text, num_sentences):
    """Return the `num_sentences` highest-scoring sentences of `text`, joined by spaces.

    `text` is split into sentences with NLTK's sent_tokenize. Each sentence is embedded
    with the global `tokenizer`/`model` (on the global `device`) by averaging the
    model's last hidden states over that sentence's real (non-padding) tokens.
    A sentence's score is the sum of its cosine similarities to every sentence
    (itself included), so sentences most similar to the rest of the text rank first.
    Selected sentences are returned in score order, not document order.

    Args:
        text: text to summarize.
        num_sentences: number of sentences to keep.

    Returns:
        The selected sentences joined with single spaces; "" when `text` contains
        no sentences or `num_sentences` <= 0.
    """
    sentences = sent_tokenize(text)
    if not sentences or num_sentences <= 0:
        return ""

    # Batch-tokenize: pad to the longest sentence and truncate anything longer than
    # the model's position-embedding table, which would otherwise fail in the forward pass.
    encoded = tokenizer(
        sentences,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=model.config.max_position_embeddings,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        last_hidden_states = model(input_ids, attention_mask=attention_mask)[0]

    # Masked mean pooling: average only over real tokens so padding positions do not
    # dilute the embeddings of short sentences.
    token_mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * token_mask).sum(dim=1)
    token_counts = token_mask.sum(dim=1).clamp(min=1)
    sentence_embeddings = (summed / token_counts).cpu().numpy()

    similarity_matrix = cosine_similarity(sentence_embeddings)
    sentence_scores = similarity_matrix.sum(axis=1)

    # Stable sort keeps earlier sentences first on ties.
    ranked = sorted(range(len(sentences)), key=lambda idx: sentence_scores[idx], reverse=True)
    return ' '.join(sentences[idx] for idx in ranked[:num_sentences])
# END: REFACTOR FROM <https://towardsdatascience.com/extractive-summarization-using-bert-966e912f4142 and https://www.analyticsvidhya.com/blog/2023/03/exploring-the-extractive-method-of-text-summarization/>

In [14]:
# Release PyTorch's cached but currently unused CUDA memory before the bulk runs below
torch.cuda.empty_cache()

In [15]:
import gc  # also used by the per-split summarization loops below
gc.collect()  # force a full garbage-collection pass

20

# Bulk Summarization

In [16]:
# Summarize case documents one batch at a time, using a CSV file to keep track of which files have been processed
# This code has been refactored several times, so the function name and the batch_size parameter are somewhat outdated
def summarize_a_batch_of_case_documents(batch_size, processed_files_csv):
    """Summarize up to `batch_size` not-yet-processed .json files from `input_folder`.

    Progress is checkpointed in `processed_files_csv` (columns "file_name" and
    "time_elapsed"). The CSV is rewritten after every file so an interrupted run can
    resume. Files are visited in sorted order, starting right after the last file
    recorded in the CSV, and any file already listed there is skipped.

    A file whose summarization raises is still recorded, so one bad file cannot make
    the resume logic retry it forever. The error is printed instead of being hidden.

    Args:
        batch_size: maximum number of .json files to attempt in this call.
        processed_files_csv: path of the progress CSV (created if missing).

    Relies on the notebook-level global `input_folder` and on `summarize_a_case_document`.
    """
    start = time.time()
    if os.path.exists(processed_files_csv):
        processed_files_df = pd.read_csv(processed_files_csv)
    else:
        processed_files_df = pd.DataFrame(columns=["file_name", "time_elapsed"])
    already_processed = set(processed_files_df["file_name"])
    filenames = sorted(os.listdir(input_folder))

    # Resume right after the last recorded file. If that file is no longer in the
    # folder, start from the top and rely on the "already processed" check instead of
    # crashing with ValueError from list.index().
    last_processed_index = 0
    if not processed_files_df.empty:
        last_filename = processed_files_df["file_name"].iloc[-1]
        if last_filename in filenames:
            last_processed_index = filenames.index(last_filename) + 1

    processed_count = 0
    for filename in filenames[last_processed_index:]:
        if processed_count >= batch_size:
            break
        print("Processing: ", filename)
        if not filename.endswith(".json"):
            continue
        if filename in already_processed:
            print(f"File {filename} has already been processed. Skipping.")
            continue

        try:
            summarize_a_case_document(filename)
        except Exception as e:  # keep the batch going, but report what went wrong
            print(f"Error processing {filename}: {e}")
        processed_count += 1

        # Measured from the start of this call, so it is cumulative within a batch.
        new_row = {"file_name": filename, "time_elapsed": time.time() - start}
        processed_files_df = pd.concat([processed_files_df, pd.DataFrame([new_row])], ignore_index=True)
        already_processed.add(filename)
        processed_files_df.to_csv(processed_files_csv, index=False)

# Summarize a single case given the name of its JSON file
# It extracts the majority opinion, chunks it, summarizes each chunk, then saves the result as an individual .txt file
def summarize_a_case_document(filename, chunk_word_count=1000, overlap_word_count=200):
    """Summarize one CaseLaw JSON file from `input_folder` and save it to `output_folder`.

    The majority opinion is split into overlapping word chunks, and each chunk is
    summarized extractively. The chunk summaries are joined with spaces and written
    via save_summary_to_text (non-condensed naming).

    Args:
        filename: name of the JSON file inside the global `input_folder`.
        chunk_word_count: words per chunk (default 1000).
        overlap_word_count: words shared between consecutive chunks (default 200).

    Raises:
        ValueError: if the file has no majority opinion. Previously str(None), i.e. the
            literal text "None", was summarized and saved as if it were an opinion.
    """
    file_path = os.path.join(input_folder, filename)
    opinion = load_and_extract_data(file_path)
    if opinion is None:
        raise ValueError(f"No majority opinion found in {file_path}")
    opinion = str(opinion)

    chunks = chunk_text_with_overlap(opinion, chunk_word_count, overlap_word_count)

    chunk_summaries = summarize_chunks(chunks, chunk_word_count)
    final_summary = ' '.join(chunk_summaries)

    save_summary_to_text(final_summary, output_folder, file_path, condensed=False)

# This is where we decide how many sentences to keep in each extractive summary
# Each chunk gets between 2 and 7 sentences
# Since a chunk is roughly 1000 words, that works out to about one sentence per 100 words
def summarize_chunks(chunks, chunk_word_count, min_sentences=2, max_sentences=7,
                     standard_summary_length=10):
    """Extractively summarize each chunk with a length-proportional sentence budget.

    A chunk of exactly `chunk_word_count` words gets `standard_summary_length`
    sentences; shorter chunks get proportionally fewer (rounded down). The budget
    is then clamped to [`min_sentences`, `max_sentences`].

    Args:
        chunks: list of chunk strings.
        chunk_word_count: nominal chunk size in words, used as the proportionality base.
        min_sentences: lower bound on sentences per chunk summary (default 2).
        max_sentences: upper bound on sentences per chunk summary (default 7).
        standard_summary_length: sentences for a full-size chunk before clamping (default 10).

    Returns:
        List of summaries, one per chunk, in the same order as `chunks`.
    """
    summaries = []

    for chunk in chunks:
        chunk_length = len(chunk.split())
        proportional_sentences_number = int((chunk_length / chunk_word_count) * standard_summary_length)

        # Clamp into [min_sentences, max_sentences]. The previous code overwrote the
        # lower clamp, so short chunks (e.g. the tail of a document) got 0–1 sentences.
        sentences_to_summarize = min(max_sentences, max(proportional_sentences_number, min_sentences))

        summaries.append(extractive_summarization(chunk, sentences_to_summarize))

    return summaries

## Getting the train dataset

In [17]:
# Train split: read case JSONs from input_folder, write summaries to output_folder,
# and track progress in processed_files_csv.
input_folder = 'ref_case_jsons_train'
output_folder = 'ref_case_txt_train'
processed_files_csv = 'processed_files_for_legalBertLarge_train_set.csv'
processed_files_df = None
os.makedirs(output_folder, exist_ok=True)

In [None]:
# One file per call, capped at 10,000 calls; every 500 calls, release cached GPU
# memory and run the garbage collector.
# for i in range(len(os.listdir(input_folder))):
for i in range(10_000):
    summarize_a_batch_of_case_documents(1, processed_files_csv)
    if not i % 500:
        torch.cuda.empty_cache()
        gc.collect()

## Getting the val dataset

In [None]:
# Validation split: same pipeline as the train split, with one call per file in the folder.
input_folder = 'ref_case_jsons_val'
output_folder = 'ref_case_txt_val'
processed_files_csv = 'processed_files_for_legalBertLarge_val_set.csv'
processed_files_df = None
os.makedirs(output_folder, exist_ok=True)

n_val_files = len(os.listdir(input_folder))
for i in range(n_val_files):
    print(i)
    summarize_a_batch_of_case_documents(1, processed_files_csv)
    if not i % 500:
        # Periodically release cached GPU memory and run the garbage collector.
        torch.cuda.empty_cache()
        gc.collect()

## Getting the test dataset

In [None]:
# Test split: same pipeline as the train split, with one call per file in the folder.
input_folder = 'ref_case_jsons_test'
output_folder = 'ref_case_txt_test'
processed_files_csv = 'processed_files_for_legalBertLarge_test_set.csv'
processed_files_df = None
os.makedirs(output_folder, exist_ok=True)

n_test_files = len(os.listdir(input_folder))
for i in range(n_test_files):
    print(i)
    summarize_a_batch_of_case_documents(1, processed_files_csv)
    if not i % 500:
        # Periodically release cached GPU memory and run the garbage collector.
        torch.cuda.empty_cache()
        gc.collect()

# Tests: can LegalBERT-Large help condense generated holdings?
## Not very reliable — or at least, I disagree with the sentences it picks

In [45]:
# First sample holding used to probe extractive condensation
holding = """
The holding in this case is that the presumption of legitimacy should not be utilized to perpetuate a falsehood if the truth can be discovered. The court found that the mother's husband had been a substantial presence in the child's life and desires to continue to exercise parental rights, and therefore, the need for joining him as a party whose interests "might be inequitably affected by" the resulting order of filiation is manifest. The court also found that the results of a human leucocyte antigen test showed a 99.53% probability that the petitioner is the child's father, which was sufficient to overcome the presumption of legitimacy that arises when a child is born to a married woman. However, the court noted that if William refuses to submit to a blood test, an adverse inference may then be drawn against him.
"""

In [46]:
# Keep the 2 highest-scoring sentences of the first sample holding
extractive_summarization(holding, 2)

'The court also found that the results of a human leucocyte antigen test showed a 99.53% probability that the petitioner is the child\'s father, which was sufficient to overcome the presumption of legitimacy that arises when a child is born to a married woman. The court found that the mother\'s husband had been a substantial presence in the child\'s life and desires to continue to exercise parental rights, and therefore, the need for joining him as a party whose interests "might be inequitably affected by" the resulting order of filiation is manifest.'

In [47]:
# Second sample holding (hard-wrapped text) used to probe extractive condensation
holding = """
The holding of the case is:  "Where a mother's husband has been a substantial presence in the
child's life and desires to continue to exercise parental rights, the need for joining him as a
party whose interests'might be inequitably affected by' the resulting order of filiation is
manifest, and the court may order joinder on its own motion."  In other words, the court has the
discretion to join a person as a party to a paternity proceeding if their interests may be impacted
by the outcome of the case, even if they are not a party to the proceeding. This holding is based on
the idea that the truth about paternity should be discovered and that the presumption of legitimacy
should not be used to perpetuate a falsehood.
"""

In [48]:
# Keep the 2 highest-scoring sentences of the second sample holding
extractive_summarization(holding, 2)

'\nThe holding of the case is:  "Where a mother\'s husband has been a substantial presence in the\nchild\'s life and desires to continue to exercise parental rights, the need for joining him as a\nparty whose interests\'might be inequitably affected by\' the resulting order of filiation is\nmanifest, and the court may order joinder on its own motion." In other words, the court has the\ndiscretion to join a person as a party to a paternity proceeding if their interests may be impacted\nby the outcome of the case, even if they are not a party to the proceeding.'