# Graph regularization for legal decision classification using synthesized graphs

## Dependencies and imports

In [None]:
pip install --quiet neural-structured-learning

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import neural_structured_learning as nsl

## Base model

### Global variables

In [None]:
# Naming convention for neighbor features in the graph-augmented training
# data: neighbor i's word IDs are stored under 'NL_nbr_<i>_words' and its
# edge weight under 'NL_nbr_<i>_weight'.
NBR_FEATURE_PREFIX, NBR_WEIGHT_SUFFIX = 'NL_nbr_', '_weight'

### Hyperparameters

In [None]:
class HParams:
    """Bundle of every hyperparameter used to build, train and evaluate the model."""

    def __init__(self):
        # --- Dataset ---
        self.num_classes = 253        # number of distinct output labels
        self.max_seq_length = 2500    # word-ID sequences are padded/truncated to this
        self.vocab_size = 87_148      # number of rows in the word embedding table

        # --- Neural graph learning ---
        self.distance_type = nsl.configs.DistanceType.L2
        self.graph_regularization_multiplier = 0.1
        self.num_neighbors = 2        # neighbor features parsed per training sample

        # --- Model architecture ---
        self.num_embedding_dims = 16
        self.num_lstm_dims = 64
        self.num_fc_units = 64

        # --- Training ---
        self.train_epochs = 5
        self.batch_size = 128

        # --- Evaluation ---
        # None: evaluate every instance in the test set.
        self.eval_steps = None


HPARAMS = HParams()

### Prepare the data

In [None]:
def make_dataset(file_path, training=False, shuffle_buffer_size=87148):
    """Creates a batched dataset of parsed examples from a TFRecord file.

    Args:
      file_path: Name of the file in the `.tfrecord` format containing
        `tf.train.Example` objects.
      training: Boolean indicating if we are in training mode. In training
        mode the examples are shuffled, and each neighbor's word IDs
        ('NL_nbr_<i>_words') and edge weight ('NL_nbr_<i>_weight') are parsed
        too, for `HPARAMS.num_neighbors` neighbors.
      shuffle_buffer_size: Shuffle buffer size used in training mode. Defaults
        to 87148, the value previously hard-coded here.

    Returns:
      A `tf.data.Dataset` of batches of `(features, labels)` pairs, with
      `HPARAMS.batch_size` examples per batch. `features` is a dict whose word
      features are int64 vectors padded/truncated to `HPARAMS.max_seq_length`
      (plus float neighbor weights in training mode). `labels` holds the int64
      'label' values.
    """

    def pad_sequence(sequence, max_seq_length):
        """Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`.

        Sequences longer than `max_seq_length` are truncated.
        """
        pad_size = tf.maximum([0], max_seq_length - tf.shape(sequence)[0])
        padded = tf.concat(
            [sequence.values, tf.fill(pad_size, tf.cast(0, sequence.dtype))],
            axis=0)
        # The input sequence may be larger than max_seq_length. Truncate down if
        # necessary.
        return tf.slice(padded, [0], [max_seq_length])

    def parse_example(example_proto):
        """Extracts relevant fields from the `example_proto`.

        Args:
          example_proto: A serialized `tf.train.Example`.

        Returns:
          A pair whose first value is a dictionary containing relevant features
          and whose second value contains the ground truth labels.
        """
        # The 'words' feature is a variable length word ID vector.
        # NOTE(review): examples without a 'label' get -1, which is not a valid
        # class index for sparse categorical cross-entropy.
        feature_spec = {
            'words': tf.io.VarLenFeature(tf.int64),
            'label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),
        }

        # We also extract corresponding neighbor features in a similar manner to
        # the features above during training.
        if training:
            for i in range(HPARAMS.num_neighbors):
                nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
                nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                                 NBR_WEIGHT_SUFFIX)
                feature_spec[nbr_feature_key] = tf.io.VarLenFeature(tf.int64)
                # We assign a default value of 0.0 for the neighbor weight so
                # that graph regularization is done on samples based on their
                # exact number of neighbors. In other words, non-existent
                # neighbors are discounted.
                feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
                    [1], tf.float32, default_value=tf.constant([0.0]))

        features = tf.io.parse_single_example(example_proto, feature_spec)

        # Since the 'words' feature is a variable length word vector, we pad it to a
        # constant maximum length based on HPARAMS.max_seq_length
        features['words'] = pad_sequence(features['words'], HPARAMS.max_seq_length)
        if training:
            for i in range(HPARAMS.num_neighbors):
                nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
                features[nbr_feature_key] = pad_sequence(
                    features[nbr_feature_key], HPARAMS.max_seq_length)

        labels = features.pop('label')
        return features, labels

    dataset = tf.data.TFRecordDataset([file_path])
    if training:
        # NOTE(review): the default buffer size equals HPARAMS.vocab_size, which
        # looks coincidental; the buffer trades shuffle quality for memory and
        # is unrelated to vocabulary size. Tune via `shuffle_buffer_size`.
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example)
    dataset = dataset.batch(HPARAMS.batch_size)
    return dataset

In [None]:
# List the working directory to check that the .tfr data files loaded below exist.
!ls

In [None]:
# The training file is parsed together with its neighbor features (needed for
# graph regularization); the test file is parsed without them.
train_dataset = make_dataset('nsl_train_data.tfr', training=True)
test_dataset = make_dataset('test_data.tfr')

In [None]:
# Display the test dataset object in the notebook as a quick sanity check of
# its element structure.
test_dataset

### Build the model

In [None]:
def make_bilstm_model():
    """Builds a bi-directional LSTM classifier.

    The model embeds a sequence of `HPARAMS.max_seq_length` word IDs, encodes
    it with a bi-directional LSTM, applies one ReLU hidden layer, and outputs a
    softmax distribution over `HPARAMS.num_classes` classes.

    Returns:
      An uncompiled `tf.keras.Model`. Its input is named 'words' so it matches
      the 'words' feature key produced by `make_dataset`.
    """
    inputs = tf.keras.Input(
        shape=(HPARAMS.max_seq_length,), dtype='int64', name='words')
    embedded = tf.keras.layers.Embedding(
        HPARAMS.vocab_size, HPARAMS.num_embedding_dims)(inputs)
    encoded = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(HPARAMS.num_lstm_dims))(embedded)
    hidden = tf.keras.layers.Dense(
        HPARAMS.num_fc_units, activation='relu')(encoded)
    # Use the configured class count rather than a duplicated literal so the
    # output layer cannot drift out of sync with HParams.num_classes.
    outputs = tf.keras.layers.Dense(
        HPARAMS.num_classes, activation='softmax')(hidden)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build a new base LSTM model.
base_reg_model = make_bilstm_model()  # fresh base model; wrapped with graph regularization in the next cell

In [None]:
# Wrap the base model with graph regularization.
# Configure graph regularization from HPARAMS: how many neighbors to use, how
# strongly to weight the neighbor-distance term, and which distance to use.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    distance_type=HPARAMS.distance_type,
    multiplier=HPARAMS.graph_regularization_multiplier,
    sum_over_axis=-1,
)

# Wrap the base model and compile the wrapper for integer-label classification.
graph_reg_model = nsl.keras.GraphRegularization(
    base_reg_model,
    graph_reg_config,
)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

### Train the model

In [None]:
# Train the graph-regularized model.
# NOTE(review): the test set is also passed as validation data, so the val_*
# metrics recorded below are test-set metrics.
graph_reg_history = graph_reg_model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=HPARAMS.train_epochs,
    verbose=1,
)

### Evaluate the model

In [None]:
# Evaluate on the test set (HPARAMS.eval_steps is None, so every instance is
# used) and print the metric values returned by evaluate().
graph_reg_results = graph_reg_model.evaluate(
    test_dataset,
    steps=HPARAMS.eval_steps,
)
print(graph_reg_results)

### Plot training and validation accuracy/loss per epoch

In [None]:
# Per-epoch metrics recorded by fit(); display which metric keys are available.
graph_reg_history_dict = graph_reg_history.history
graph_reg_history_dict.keys()

In [None]:
acc = graph_reg_history_dict['accuracy']
val_acc = graph_reg_history_dict['val_accuracy']
loss = graph_reg_history_dict['loss']
# Recent neural-structured-learning releases record the graph loss under
# 'scaled_graph_loss'; older releases used 'graph_loss'. Accept either so
# this cell does not fail with a KeyError depending on the installed version.
graph_loss_key = ('scaled_graph_loss'
                  if 'scaled_graph_loss' in graph_reg_history_dict
                  else 'graph_loss')
graph_loss = graph_reg_history_dict[graph_loss_key]
val_loss = graph_reg_history_dict['val_loss']

epochs = range(1, len(acc) + 1)

plt.clf()   # clear figure

# "-r^" is for solid red line with triangle markers.
plt.plot(epochs, loss, '-r^', label='Training loss')
# "-gD" is for solid green line with diamond markers.
plt.plot(epochs, graph_loss, '-gD', label='Training graph loss')
# "-bo" is for solid blue line with circle markers.
plt.plot(epochs, val_loss, '-bo', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc='best')

plt.show()

In [None]:
plt.clf()  # start from an empty figure

# Draw training vs. validation accuracy per epoch on the current axes.
ax = plt.gca()
ax.plot(epochs, acc, '-r^', label='Training acc')     # solid red, triangles
ax.plot(epochs, val_acc, '-bo', label='Validation acc')  # solid blue, circles
ax.set_title('Training and validation accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend(loc='best')

plt.show()