In [45]:
from nltk.tokenize import RegexpTokenizer
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pandas as pd
import numpy as np
import re
from transformers import Trainer, TrainingArguments

In [46]:
class preprocess:
    """Turn a CSV of court cases into BERT-sized, ID-encoded training chunks.

    The CSV must provide the text columns ``whole_text``, ``ruling``, ``facts``
    and ``issues``. Each column is lower-cased, cleaned with :meth:`change_char`,
    split into alphanumeric word tokens and cut into overlapping chunks that are
    wrapped in ``[CLS]`` / ``[SEP]``.

    Words that the BERT tokenizer maps to its unknown id are collected in
    ``unknown_tokens.txt`` so the vocabulary can be extended later.
    """

    # Fallback [UNK] id, used only if the tokenizer exposes no `unk_token_id`.
    _DEFAULT_UNK_ID = 100

    def __init__(self, file_path, bert_tokenizer):
        """Load the CSV, clean it and build the chunked text for every column.

        Args:
            file_path: Path of the CSV file holding the court cases.
            bert_tokenizer: Tokenizer offering ``convert_tokens_to_ids`` (and,
                optionally, ``unk_token_id``).
        """
        # Read the file the caller asked for (was hard-coded, ignoring file_path).
        self.df = pd.read_csv(file_path)
        self.tokenizer = RegexpTokenizer(r"[a-zA-Z0-9]+")
        self.bert_tokenizer = bert_tokenizer
        self.window_size = 128   # overlap (in tokens) between consecutive chunks
        self.max_token = 510     # 512 minus the [CLS] and [SEP] special tokens
        self.court_cases = None
        self.rulings = None
        self.issues = None
        self.facts = None

        # Flag for new tokens found
        self.found_new_unknown_token = False

        # Load the unknown tokens; a missing file just means none were recorded yet.
        try:
            with open('unknown_tokens.txt', 'r') as f:
                self.unknown_tokens = f.read().splitlines()
        except FileNotFoundError:
            self.unknown_tokens = []

        # use this variable for debugging
        self.debugging = True

        # Drop rows with null values and duplicated rows BEFORE preprocessing, so
        # that the derived token lists do not contain the duplicates either.
        self.df = self.df.dropna().drop_duplicates()

        # preprocess
        self.preprocess()

    def preprocess(self):
        """Populate ``court_cases``, ``rulings``, ``facts`` and ``issues``.

        Each attribute becomes a list (one entry per CSV row) of lists of chunk
        strings as produced by :meth:`split_to_chunks_with_windows`.
        """
        self.court_cases = self._column_to_chunks("whole_text")
        self.rulings = self._column_to_chunks("ruling")
        self.facts = self._column_to_chunks("facts")
        self.issues = self._column_to_chunks("issues")

    def _column_to_chunks(self, column):
        """Lower-case, clean, tokenize and chunk every text of one CSV column."""
        return [
            self.split_to_chunks_with_windows(
                self.tokenizer.tokenize(self.change_char(text.lower()))
            )
            for text in self.df[column]
        ]

    def split_to_chunks_with_windows(self, tokens):
        """
        Court cases has long text while BERT only has 512 maximum tokens.
        This function splits the long text into chunks with a sliding window.

            Args:
                tokens: A list of tokens.

            Returns:
                chunks: A list of strings. Each string holds at most
                ``self.max_token`` tokens joined by spaces, wrapped in
                "[CLS]" and "[SEP]". Consecutive chunks overlap by
                ``self.window_size`` tokens. An empty token list yields [].

            Raises:
                ValueError: If ``window_size`` is not smaller than ``max_token``
                (the window could never advance).
        """
        stride = self.max_token - self.window_size
        if stride <= 0:
            raise ValueError("window_size must be smaller than max_token")

        chunks = []

        # Iterate over the list of tokens and chunking them with window slides
        for start in range(0, len(tokens), stride):
            # Slicing copies, so the caller's token list is never mutated.
            chunk = ["[CLS]", *tokens[start:start + self.max_token], "[SEP]"]
            chunks.append(' '.join(chunk))

            # Break loop if we have covered the entire sequence
            if start + self.max_token >= len(tokens):
                break

        return chunks

    def change_char(self, text):
        """
        Remove punctuation and digit-bearing words from a string and expand
        the abbreviation "rtc".

          Args:
            text: The (already lower-cased) string to clean.

          Returns:
            The cleaned string.
        """
        # Brackets are escaped: unescaped, the set ended at the first "]" and the
        # pattern only matched "<char>]", leaving nearly all punctuation in place.
        text = re.sub(r'[(),:;\'".’”\[\]]', '', text)
        # Word boundaries keep words that merely contain "rtc" (e.g. "shortcut") intact.
        text = re.sub(r'\brtc\b', 'regional trial court', text)
        # Drop every word that contains at least one digit.
        text = re.sub(r"\w*\d+\w*", "", text)
        # Opening curly quote plus prime (U+2032) and double prime (U+2033).
        # The former comma-based patterns were dead: commas are removed above.
        text = re.sub(r"[“\u2032\u2033]", "", text)
        return text

    def prepare_input_output(self, chunks):
        """
        Prepare input-output pairs for each chunk.

          Args:
            chunks: A list of chunk strings (space separated tokens).

          Returns:
            A list of ``(input_ids, shifted_output)`` tuples, where
            ``shifted_output`` is ``input_ids`` without its first element.

        Side effects: every word mapped to the tokenizer's unknown id that is not
        yet in ``self.unknown_tokens`` is printed, appended to that list, and
        ``self.found_new_unknown_token`` is set to True.
        """
        # Both ids are loop invariant, so look them up once.
        sep_id = self.bert_tokenizer.convert_tokens_to_ids("[SEP]")
        unk_id = getattr(self.bert_tokenizer, "unk_token_id", None)
        if unk_id is None:
            unk_id = self._DEFAULT_UNK_ID

        # Set for O(1) membership tests; the list keeps the file order.
        recorded = set(self.unknown_tokens)

        input_output_pairs = []
        for chunk in chunks:
            words = chunk.split()

            # Convert the words of the chunk to vocabulary IDs
            input_ids = list(self.bert_tokenizer.convert_tokens_to_ids(words))

            # Make sure the sequence ends with exactly one [SEP]
            if not input_ids or input_ids[-1] != sep_id:
                input_ids.append(sep_id)

            # Shifted output starts from the second token (drops the leading [CLS])
            shifted_output = input_ids[1:]
            input_output_pairs.append((input_ids, shifted_output))

            # Record words the tokenizer does not know yet
            for word, token_id in zip(words, input_ids):
                if token_id == unk_id and word not in recorded:
                    print(word, " : ", token_id)
                    self.found_new_unknown_token = True
                    self.unknown_tokens.append(word)
                    recorded.add(word)

        return input_output_pairs

    def get_training_data(self):
        """
        Prepare training data for all segments, maintaining the structure per court case.

          Returns:
            A list with one dict per court case, with the keys "court_case"
            (the raw chunk strings), "rulings", "facts" and "issues" (lists of
            ``(input_ids, shifted_output)`` pairs).

          Raises:
            Exception: If at least one new unknown token was found. The file
            ``unknown_tokens.txt`` is rewritten before the exception is raised.
        """
        training_data = []
        self.found_new_unknown_token = False

        for i in range(len(self.court_cases)):
            # Maintain structure by grouping segments within the same court case
            training_data.append({
                "court_case": self.court_cases[i],
                "rulings": self.prepare_input_output(self.rulings[i]),
                "facts": self.prepare_input_output(self.facts[i]),
                "issues": self.prepare_input_output(self.issues[i]),
            })

        # If unknown token/s found, Update file containing all unknown token & Raise an error message
        if self.unknown_tokens and self.found_new_unknown_token:
            with open('unknown_tokens.txt', 'w') as f:
                f.writelines(f"{token}\n" for token in self.unknown_tokens)
            raise Exception("There are unknown token/s found. Update the tokenizer and finetune the model.")

        return training_data
        

In [47]:
class prep_model:
    """Extend a tokenizer/model pair with the words listed in ``unknown_tokens.txt``."""

    def __init__(self, tokenizer, model):
        """Load the recorded unknown tokens and register them with the tokenizer.

        Args:
            tokenizer: Object offering ``add_tokens`` and ``__len__``.
            model: Object offering ``resize_token_embeddings``.

        Raises:
            Exception: If ``unknown_tokens.txt`` cannot be read.
        """
        self.tokenizer = tokenizer
        self.model = model

        # Load the unknown tokens FIRST: update_tokenizer() reads
        # self.unknown_tokens, so calling it earlier raised AttributeError.
        try:
            with open('unknown_tokens.txt', 'r') as f:
                # Blank lines are skipped so that no empty token gets registered.
                self.unknown_tokens = [line for line in f.read().splitlines() if line]
        except OSError as err:
            # Chain the original error so the real cause stays visible.
            raise Exception("No file found.") from err

        self.update_tokenizer()

    def update_tokenizer(self):
        """Add ``self.unknown_tokens`` to the tokenizer and resize the embeddings."""
        # Add the new tokens to the tokenizer
        self.tokenizer.add_tokens(self.unknown_tokens)

        # Resize the model's token embeddings to match the new tokenizer length
        self.model.resize_token_embeddings(len(self.tokenizer))

    def finetune_model(self):
        """Placeholder; fine-tuning is not implemented yet."""
        pass

In [48]:
from TopicSegmentation import LegalBert

In [49]:
# Instantiate the project's LegalBert wrapper. Only the model wrapper is created
# here; its `tokenizer` attribute is handed to the preprocessor in a later cell.
legal_bert = LegalBert()

In [50]:
# Build the preprocessor for the court-case CSV. Construction already cleans,
# tokenizes and chunks the text; the LegalBert tokenizer is stored for the
# token-to-id conversion done by get_training_data().
x = preprocess("new_court_cases.csv", legal_bert.tokenizer)

In [52]:
# Convert the chunks of every case into (input_ids, shifted_output) pairs.
# NOTE: this raises an Exception (after rewriting unknown_tokens.txt) whenever a
# word not yet listed in that file maps to the tokenizer's unknown id.
training = x.get_training_data()