In [None]:
!pip install konfuzio-sdk transformers

In [None]:
# Initialise the Konfuzio SDK for this environment; presumably required so that Project(id_=...) below can
# reach the server.
# NOTE(review): this command is likely interactive (asks for credentials) — confirm before running headless.
!konfuzio_sdk init

In [None]:
import cv2
import logging
import torch

import numpy as np
import tensorflow as tf

from keras.applications.vgg19 import preprocess_input
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten, Input, concatenate
from keras.models import load_model, Model
from pathlib import Path
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array
from transformers import BertTokenizer, AutoModel, AutoConfig
from typing import List

from konfuzio_sdk.data import Document, Page, Project

# `tf.config.experimental_run_functions_eagerly` is deprecated (TensorFlow warns to use this stable API instead).
# NOTE(review): running tf.functions eagerly slows training noticeably — confirm it is still required.
tf.config.run_functions_eagerly(True)

Instructions for updating:
Use `tf.config.run_functions_eagerly` instead of the experimental version.


In [None]:
class FileSplittingModel:
    """Train a fusion model for correctly splitting files which contain multiple Documents.

    The model has two inputs, one visual and one textual. Their features are concatenated and fed into a
    Multi-Layer Perceptron (MLP), which predicts whether a Page is the first Page of a Document.
    The visual branch is a VGG16-style convolutional stack that is trained from scratch together with the MLP.
    The textual branch consists of two Dense layers on top of LegalBERT pooled embeddings. LegalBERT itself is
    used without any training: embeddings are computed under ``torch.no_grad()``.
    The resulting trained model is saved as ``fusion.h5`` in the Project's model folder, roughly 1.5 Gb in size.
    """

    def __init__(self, project_id: int):
        """
        Initialize Project, training and testing data.

        :param project_id: ID of the Project used for training the model.
        :type project_id: int
        """
        self.project = Project(id_=project_id)
        self.train_data = self.project.documents
        self.test_data = self.project.test_documents

    def _preprocess_documents(self, data: List[Document]) -> (List[str], List[str], List[int]):
        """
        Collect image paths, texts and first-Page labels for every Page of the given Documents.

        :param data: Documents whose Pages are collected, in order.
        :return: Page image paths, Page texts and labels (1 for the first Page of a Document, 0 otherwise).
        All three lists are aligned by index.
        """
        pages = []
        texts = []
        labels = []
        for doc in data:
            for page in doc.pages():
                pages.append(page.image_path)
                texts.append(page.text)
                # Page number 1 starts a new Document, i.e. it marks a split point.
                labels.append(1 if page.number == 1 else 0)
        return pages, texts, labels

    def _otsu_binarization(self, pages: List[str]) -> List[np.ndarray]:
        """
        Load Page images and prepare them as inputs for the visual branch.

        Despite its name, this method does not binarize anything. Each image is resized to 224x224 and passed
        through the Keras VGG ``preprocess_input``. The name is kept for backward compatibility.

        :param pages: Paths to Page images.
        :return: One array of shape (1, 224, 224, 3) per image, in input order.
        :raises FileNotFoundError: If an image cannot be read.
        """
        images = []
        for img in pages:
            image = cv2.imread(img)
            if image is None:
                # cv2.imread signals a missing/unreadable file by returning None instead of raising.
                raise FileNotFoundError('Could not read Page image: {}'.format(img))
            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_AREA)
            image = img_to_array(image)
            image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
            # NOTE: cv2 loads BGR while preprocess_input expects RGB, so channels end up flipped. This happens
            # consistently for training and inference, so it is left as-is to stay compatible with saved models.
            image = preprocess_input(image)
            images.append(image)
        return images

    @staticmethod
    def _stack_images(images: List[np.ndarray]) -> np.ndarray:
        """Stack per-image (1, H, W, C) arrays into a single (N, H, W, C) float32 batch, preserving order."""
        return np.concatenate(images, axis=0).astype('float32')

    def _embed_texts(self, texts: List[str], bert_model, bert_tokenizer) -> np.ndarray:
        """
        Compute a frozen BERT pooled embedding for each text.

        :param texts: Page texts, in order.
        :param bert_model: Initialized LegalBERT model.
        :param bert_tokenizer: Initialized BERTTokenizer.
        :return: float32 array with one pooled embedding per text, stacked along the first axis.
        """
        embeddings = []
        # Inference only: BERT is not trained, so no gradients are needed.
        with torch.no_grad():
            for text in texts:
                # Guard against Pages without text (None) so the tokenizer always receives a string.
                inputs = bert_tokenizer(text or '', truncation=True, return_tensors='pt')
                output = bert_model(**inputs)
                # .cpu() keeps this working if the model was moved to a GPU.
                embeddings.append(output.pooler_output.cpu().numpy().astype('float32'))
        return np.asarray(embeddings)

    def prepare_visual_textual_data(
        self, train_data: List[Document], test_data: List[Document], bert_model, bert_tokenizer
    ):
        """
        Prepare visual and textual inputs and transform them for feeding to the fusion model.

        Images, texts and labels are kept in the same Page order, so index ``i`` of every returned array refers
        to the same Page.

        :param train_data: Train dataset from the project.documents.
        :type train_data: list
        :param test_data: Test dataset from the project.test_documents.
        :type test_data: list
        :param bert_model: Initialized LegalBERT model.
        :param bert_tokenizer: Initialized BERTTokenizer.
        :return: Train and test visual inputs, train and test textual inputs, train and test labels, and the
        input shape for textual inputs.
        :raises ValueError: If either dataset contains no Pages.
        """
        for doc in train_data + test_data:
            doc.get_images()
        train_pages, train_texts, train_labels = self._preprocess_documents(train_data)
        test_pages, test_texts, test_labels = self._preprocess_documents(test_data)
        if not train_pages or not test_pages:
            raise ValueError('Both train and test data must contain at least one Page.')
        # Stack images directly instead of going through ImageDataGenerator().flow(). flow() shuffles by
        # default and only the images were kept, which misaligned images with their texts and labels.
        train_img_data = self._stack_images(self._otsu_binarization(train_pages))
        test_img_data = self._stack_images(self._otsu_binarization(test_pages))
        train_labels = tf.cast(np.asarray(train_labels).reshape((-1, 1)), tf.float32)
        test_labels = tf.cast(np.asarray(test_labels).reshape((-1, 1)), tf.float32)
        train_txt_data = self._embed_texts(train_texts, bert_model, bert_tokenizer)
        test_txt_data = self._embed_texts(test_texts, bert_model, bert_tokenizer)
        # Shape of a single Page embedding, i.e. without the leading Page axis.
        txt_input_shape = test_txt_data.shape[1:]
        return train_img_data, train_txt_data, test_img_data, test_txt_data, train_labels, test_labels, txt_input_shape

    def init_model(self, input_shape):
        """
        Initialize the fusion model.

        :param input_shape: Input shape for the textual part of the model.
        :type input_shape: tuple
        :return: A compiled fusion model.
        """
        txt_input = Input(shape=input_shape, name='text')
        txt_x = Dense(units=768, activation="relu")(txt_input)
        txt_x = Flatten()(txt_x)
        txt_x = Dense(units=4096, activation="relu")(txt_x)
        img_input = Input(shape=(224, 224, 3), name='image')
        img_x = img_input
        # VGG16 convolutional stack: (filters, number of 3x3 conv layers) per block.
        # Each block is followed by 2x2 max pooling.
        for filters, n_convs in ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)):
            for _ in range(n_convs):
                img_x = Conv2D(filters=filters, kernel_size=(3, 3), padding="same", activation="relu")(img_x)
            img_x = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(img_x)
        img_x = Flatten()(img_x)
        img_x = Dense(units=4096, activation="relu")(img_x)
        img_x = Dense(units=4096, activation="relu", name='img_outputs')(img_x)
        # Both branches end in 4096 units, so the fused vector has 8192 features.
        concatenated = concatenate([img_x, txt_x], axis=-1)
        x = Dense(50, activation='relu')(concatenated)
        x = Dense(50, activation='elu')(x)
        x = Dense(50, activation='elu')(x)
        output = Dense(1, activation='sigmoid')(x)
        model = Model(inputs=[img_input, txt_input], outputs=output)
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return model

    def init_bert(self):
        """Initialize the LegalBERT model and tokenizer."""
        configuration = AutoConfig.from_pretrained('nlpaueb/legal-bert-base-uncased')
        configuration.num_labels = 2
        configuration.output_hidden_states = True
        model = AutoModel.from_pretrained('nlpaueb/legal-bert-base-uncased', config=configuration)
        # NOTE(review): `max_length`, `padding` and `truncate` passed here are probably not applied to
        # tokenizer calls (truncation is requested per call) — confirm with the transformers version in use.
        tokenizer = BertTokenizer.from_pretrained(
            'nlpaueb/legal-bert-base-uncased', do_lower_case=True, max_length=10000, padding="max_length", truncate=True
        )
        return model, tokenizer

    def _predict_label(self, img_input, txt_input, model) -> int:
        """Predict the label (1 = first Page of a Document, 0 otherwise) for a single Page."""
        pred = model.predict([img_input.reshape((1, 224, 224, 3)), txt_input.reshape((1, 1, 768))], verbose=0)
        return int(round(float(pred[0, 0])))

    def calculate_metrics(self, model, img_inputs: List, txt_inputs: List, labels: List) -> (float, float, float):
        """
        Calculate precision, recall, and F1 measure for the trained model.

        The positive class is label 1 (first Page of a Document).

        :param model: The trained model.
        :param img_inputs: Processed visual inputs from the test dataset.
        :type img_inputs: list
        :param txt_inputs: Processed textual inputs from the test dataset.
        :type txt_inputs: list
        :param labels: Labels from the test dataset.
        :type labels: list
        :return: Calculated precision, recall, and F1 measure (0 where undefined).
        :raises ValueError: If inputs and labels have different lengths.
        """
        img_batch = np.asarray(img_inputs, dtype='float32').reshape((-1, 224, 224, 3))
        txt_batch = np.asarray(txt_inputs, dtype='float32').reshape((-1, 1, 768))
        true_labels = np.asarray(labels).reshape(-1)
        if not (len(img_batch) == len(txt_batch) == len(true_labels)):
            raise ValueError('Visual inputs, textual inputs and labels must have the same length.')
        if len(true_labels) == 0:
            return 0, 0, 0
        # A single batched predict call instead of one call per Page. np.round rounds half to even, like round().
        preds = np.round(model.predict([img_batch, txt_batch], verbose=0)).reshape(-1)
        true_positive = int(np.sum((true_labels == 1) & (preds == 1)))
        false_negative = int(np.sum((true_labels == 1) & (preds == 0)))
        false_positive = int(np.sum((true_labels == 0) & (preds == 1)))
        if true_positive + false_positive != 0:
            precision = true_positive / (true_positive + false_positive)
        else:
            precision = 0
        if true_positive + false_negative != 0:
            recall = true_positive / (true_positive + false_negative)
        else:
            recall = 0
        if precision + recall != 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0
        return precision, recall, f1

    def train(self):
        """
        Train the fusion model, or load it if it was already trained.

        A freshly trained model is saved to ``<model_folder>/fusion.h5``, then evaluated on the test data.
        The metrics are logged at INFO level.

        :return: A trained fusion model.
        """
        model_path = Path(self.project.model_folder) / 'fusion.h5'
        if model_path.exists():
            return load_model(str(model_path))
        bert_model, bert_tokenizer = self.init_bert()
        (
            train_img_data,
            train_txt_data,
            test_img_data,
            test_txt_data,
            train_labels,
            test_labels,
            input_shape,
        ) = self.prepare_visual_textual_data(self.train_data, self.test_data, bert_model, bert_tokenizer)
        model = self.init_model(input_shape)
        model.fit([train_img_data, train_txt_data], train_labels, epochs=10, verbose=1)
        # Saving into a missing folder would fail, so create it first.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        model.save(str(model_path))
        _, acc = model.evaluate([test_img_data, test_txt_data], test_labels, verbose=0)
        logging.info('Accuracy: {}'.format(acc * 100))
        precision, recall, f1 = self.calculate_metrics(model, test_img_data, test_txt_data, test_labels)
        logging.info('\n Precision: {} \n Recall: {} \n F1-score: {}'.format(precision, recall, f1))
        return model

In [None]:
PROJECT_ID = 1644  # Konfuzio Project whose documents / test_documents are used for training and testing
fsm = FileSplittingModel(project_id=PROJECT_ID)

In [None]:
# Loads <model_folder>/fusion.h5 if it exists. Otherwise the fusion model is trained, saved there, and its
# test metrics are logged.
model = fsm.train()

Downloading:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at nlpaueb/legal-bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/222k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

  "Even though the `tf.config.experimental_run_functions_eagerly` "


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
# Rebuild the test inputs at notebook level. `fsm.train()` keeps its BERT model and tokenizer local, so they are
# (re)initialised here. This lets the cell run on a fresh kernel even when `train()` just loaded a cached model.
bert_model, bert_tokenizer = fsm.init_bert()
# Page images are only fetched inside training, so fetch them here in case the model was loaded from disk.
for doc in fsm.test_data:
    doc.get_images()
test_pages, test_texts, test_labels = fsm._preprocess_documents(fsm.test_data)
test_images = fsm._otsu_binarization(test_pages)
test_labels = tf.cast(np.asarray(test_labels).reshape((-1, 1)), tf.float32)
# Stack in Page order. ImageDataGenerator().flow() shuffles by default, which would misalign the images with
# `test_texts` and `test_labels`.
test_img_data = np.concatenate(test_images, axis=0).astype('float32')
test_txt_data = []
with torch.no_grad():
    for text in test_texts:
        inputs = bert_tokenizer(text, truncation=True, return_tensors='pt')
        output = bert_model(**inputs)
        test_txt_data.append(output.pooler_output.cpu().numpy().astype('float32'))
test_txt_data = np.asarray(test_txt_data)

In [None]:
# Keras accuracy on the test Pages, followed by precision / recall / F1 for the "first Page" class.
loss, acc = model.evaluate([test_img_data, test_txt_data], test_labels, verbose=0)
print(f'Accuracy: {acc * 100}')
precision, recall, f1 = fsm.calculate_metrics(model, test_img_data, test_txt_data, test_labels)
print(f'\n Precision: {precision} \n Recall: {recall} \n F1-score: {f1}')

  "Even though the `tf.config.experimental_run_functions_eagerly` "


Accuracy: 75.1724123954773

 Precision: 0.7445054945054945 
 Recall: 0.9475524475524476 
 F1-score: 0.833846153846154
