In [1]:
from TopicSegmentation import LegalBert, ModifiedStandardDecoder, PaddingMaskLayer
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pandas as pd
import numpy as np
import re

In [2]:
# Initialize your LegalBERT class
legal_bert = LegalBert()

  return self.fget.__get__(instance, owner)()


In [3]:
# Training Function
def train_step(encoder_input, decoder_input, decoder_output, decoder, optimizer, loss_function):
    with tf.GradientTape() as tape:
        # Get the context vectors from the encoder
        context_vectors = legal_bert.get_context_vectors(encoder_input)
        
        # Forward pass through the decoder
        predictions = decoder(decoder_input, encoder_outputs=context_vectors)
        
        # Compute the loss (ignoring padded tokens)
        loss = loss_function(decoder_output, predictions)
        
    # Compute the gradients
    gradients = tape.gradient(loss, decoder.trainable_variables)
    
    # Update the weights
    optimizer.apply_gradients(zip(gradients, decoder.trainable_variables))
    
    return loss

In [4]:
# Dummy Training Loop Example
num_epochs = 10
batch_size = 4

In [5]:
def remove_newlines(strings):
  """Removes newline characters from a list of strings using regular expressions.

  Args:
    strings: A list of strings.

  Returns:
    A new list of strings without newline characters.
  """

  pattern = r"\n"
  return [re.sub(pattern, "", s) for s in strings]

In [6]:
df = pd.read_csv('new_court_cases.csv')

df.dropna(inplace=True)

court_cases = df['whole_text'].to_list()
court_cases = remove_newlines(court_cases)
issues = df['issues'].to_list()

In [7]:
def sliding_window(text, max_length, stride):
    tokens = tokenizer.encode(text, truncation=False)
    return [tokens[i:i + max_length] for i in range(0, len(tokens), stride)]

In [8]:
def create_decoder_inputs_outputs(issues, max_length):
    decoder_inputs = []
    decoder_outputs = []
    
    for issue in issues:
        encoded_issue = tokenizer.encode(issue, truncation=False, return_tensors='tf')
        inputs = sliding_window(encoded_issue.numpy().tolist(), max_length, stride)
        outputs = [chunk[1:] for chunk in inputs]  # Shift right for outputs
        inputs = [chunk[:-1] for chunk in inputs]  # Remove last token for inputs
        
        # Convert lists back to tensors
        decoder_inputs.append(tf.ragged.constant(inputs))
        decoder_outputs.append(tf.ragged.constant(outputs))
    
    return decoder_inputs, decoder_outputs

In [9]:
# Set max_length to match the model's embedding size
max_length = 1024
stride = 512

# Tokenizer for decoder input
tokenizer = legal_bert.tokenizer
encoder_inputs = [sliding_window(case, max_length, stride) for case in court_cases]

Token indices sequence length is longer than the specified maximum sequence length for this model (12359 > 512). Running this sequence through the model will result in indexing errors


In [10]:
# Create decoder inputs and outputs
decoder_inputs, decoder_outputs = create_decoder_inputs_outputs(issues, max_length)

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [None]:
# Instantiate the Decoder
decoder = ModifiedStandardDecoder(
    vocab_size=len(tokenizer),
    embedding_dim=1024,  # Match this with your model
    num_heads=8,
    ff_dim=2048,
)

In [None]:
# Optimizer and Loss Function
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [None]:
# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0
    for i in range(0, len(encoder_inputs), batch_size):
        # Batch data
        encoder_input_batch = encoder_inputs[i:i + batch_size]
        decoder_input_batch = tf.concat(decoder_inputs[i:i + batch_size], axis=0)
        decoder_output_batch = tf.concat(decoder_outputs[i:i + batch_size], axis=0)

        # Training step
        batch_loss = train_step(encoder_input_batch, decoder_input_batch, decoder_output_batch, decoder, optimizer, loss_function)
        epoch_loss += batch_loss

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(encoder_inputs):.4f}")