In [3]:
from TopicSegmentation import LegalBert, ModifiedStandardDecoder
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pandas as pd

# Instantiate Model Architecture

In [4]:
# Example usage
vocab_size = 5000
embedding_dim = 768
num_heads = 8
ff_dim = 512
dropout_rate = 0.1

In [5]:
decoder = ModifiedStandardDecoder(vocab_size, embedding_dim, num_heads, ff_dim, dropout_rate)

annotated_inputs = tf.keras.Input(shape=(None,))
encoder_outputs = tf.keras.Input(shape=(None, embedding_dim))

padding_mask = None  # Add padding mask if necessary

outputs = decoder(annotated_inputs, encoder_outputs, padding_mask=padding_mask)
model_ruling = tf.keras.Model(inputs=[annotated_inputs, encoder_outputs], outputs=outputs) # model for ruling
model_facts = tf.keras.Model(inputs=[annotated_inputs, encoder_outputs], outputs=outputs) # model for facts
model_issues = tf.keras.Model(inputs=[annotated_inputs, encoder_outputs], outputs=outputs) # model for issues

model_ruling.summary()

Tensor("positional_encoding_1_1/add:0", shape=(None, None, 768), dtype=float32)
Tensor("positional_encoding_1_1/add:0", shape=(None, None, 768), dtype=float32)


# Load Data

In [9]:
df = pd.read_csv('court.csv')

df = df.iloc[:,2:]

df.dropna(inplace=True)

court_case = df['court case'].to_list()
ruling = df['rulings'].to_list()
facts = df['facts'].to_list()
issues = df['issues'].to_list()

# Preprocess Data

## Instantiate preprocessor and encoder

In [10]:
# Initialize the preprocessor and legal BERT
legal_bert = LegalBert()

bert_output = legal_bert.get_context_vectors(court_case)

## Tokenize and add paddings to each data for each model

In [11]:
# Tokenize the ruling segments
tokenizer_ruling = Tokenizer(num_words=5000)  # Adjust num_words according to your vocabulary size
tokenizer_ruling.fit_on_texts(ruling)
tokenized_segments_ruling = tokenizer_ruling.texts_to_sequences(ruling)

# Tokenize the facts segments
tokenizer_facts = Tokenizer(num_words=5000)  # Adjust num_words according to your vocabulary size
tokenizer_facts.fit_on_texts(facts)
tokenized_segments_facts = tokenizer_facts.texts_to_sequences(facts)

# Tokenize the issues segments
tokenizer_issues = Tokenizer(num_words=5000)  # Adjust num_words according to your vocabulary size
tokenizer_issues.fit_on_texts(issues)
tokenized_segments_issues = tokenizer_issues.texts_to_sequences(issues)

# Max sequence length of all padding
max_seq_len = min(bert_output.shape[1], 512)

# Pad the sequences to ensure uniform length
padded_segments_ruling = pad_sequences(tokenized_segments_ruling, padding='post', maxlen=max_seq_len)
padded_segments_facts = pad_sequences(tokenized_segments_facts, padding='post', maxlen=max_seq_len)
padded_segments_issues = pad_sequences(tokenized_segments_issues, padding='post', maxlen=max_seq_len)

# Convert to tensor
padded_segments_ruling = tf.convert_to_tensor(padded_segments_ruling)
padded_segments_facts = tf.convert_to_tensor(padded_segments_facts)
padded_segments_issues = tf.convert_to_tensor(padded_segments_issues)

xtrain = bert_output.detach().numpy() # all segment has the same x data but different y data
ytrain_ruling = padded_segments_ruling
ytrain_facts = padded_segments_facts
ytrain_issues = padded_segments_issues

# Train Models

In [12]:
# Compile the models
model_ruling.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model_facts.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model_issues.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train all model then save them
model_ruling.fit([ytrain_ruling, xtrain], ytrain_ruling, epochs=3)
model_ruling.save('ruling.keras')

model_facts.fit([ytrain_facts, xtrain], ytrain_facts, epochs=3)
model_facts.save('facts.keras')

model_issues.fit([ytrain_issues, xtrain], ytrain_issues, epochs=3)
model_issues.save('issues.keras')

Epoch 1/3
Tensor("functional_9_1/modified_standard_decoder_1_1/dropout_7_1/stateless_dropout/SelectV2:0", shape=(None, 512, 768), dtype=float32)
Tensor("functional_9_1/modified_standard_decoder_1_1/dropout_7_1/stateless_dropout/SelectV2:0", shape=(None, 512, 768), dtype=float32)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 45s/step - accuracy: 0.0000e+00 - loss: 8.6397
Epoch 2/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 11s/step - accuracy: 0.1782 - loss: 7.8582
Epoch 3/3
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 10s/step - accuracy: 0.1810 - loss: 7.2958
Epoch 1/3
Tensor("functional_11_1/modified_standard_decoder_1_1/dropout_7_1/stateless_dropout/SelectV2:0", shape=(None, 512, 768), dtype=float32)
Tensor("functional_11_1/modified_standard_decoder_1_1/dropout_7_1/stateless_dropout/SelectV2:0", shape=(None, 512, 768), dtype=float32)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 13s/step - accuracy: 0.1714 - los