# Fine-tuning Transformers model for correct first page prediction

This notebook covers one of the approaches to training a model for predicting whether a page of the document is the first one or not -- a feature that would allow correct splitting for PDFs that consist of more than one actual document (we assume that the pages are already sorted). The approach used is fine-tuning Transformers model (BERT) with our document-related dataset.

Before you start, make sure you have **installed** and **initialized** the konfuzio_sdk package as shown in the readme of the [repository](https://github.com/konfuzio-ai/Python-SDK).

In [None]:
!pip install konfuzio-sdk

In [None]:
!konfuzio_sdk init

Also, you will need to install the Transformers-related packages:

In [None]:
!pip install transformers datasets

Importing necessary libraries and packages:

In [1]:
import os
import torch

import numpy as np
import pandas as pd

from datasets import load_dataset, load_metric
from nltk import word_tokenize
from konfuzio_sdk.data import Project
from tqdm import tqdm
from transformers import BertTokenizer, AutoModelForSequenceClassification, AutoConfig, \
                        TrainingArguments, DataCollatorWithPadding, Trainer

Setting seed for reproducibility purposes:

In [2]:
seed_value = 42
os.environ['PYTHONHASHSEED'] = str(seed_value)

Initializing the config file, model and the tokenizer:

In [10]:
configuration = AutoConfig.from_pretrained('bert-base-uncased')
configuration.num_labels = 2

In [11]:
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', config=configuration)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, max_length=10000, 
                                          padding="max_length", truncate=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Loading the project 1644's documents and creating a .csv dataset for usage with Transformers-native Dataset class:

In [3]:
my_project = Project(id_=1644)
train_data = my_project.documents
test_data = my_project.test_documents

In [4]:
train_data_texts = []
train_data_labels = []

for doc in tqdm(train_data):
    for page in doc.pages():
        train_data_texts.append(page.text)
        if page.number == 1:
            train_data_labels.append(1)
        elif page.number != 1 and int(page.number):
            train_data_labels.append(0)
        else:
            print(page.number)

100%|██████████████████████████████████████| 1443/1443 [00:02<00:00, 614.22it/s]


In [5]:
test_data_texts = []
test_data_labels = []

for doc in tqdm(test_data):
    for page in doc.pages():
        test_data_texts.append(page.text)
        if page.number == 1:
            test_data_labels.append(1)
        elif page.number != 1 and int(page.number):
            test_data_labels.append(0)
        else:
            print(page.number)

100%|████████████████████████████████████████| 286/286 [00:00<00:00, 629.91it/s]


In [13]:
tokenized_train = tokenizer(train_data_texts)

In [25]:
train_df = pd.DataFrame({'text': train_data_texts, 'label': train_data_labels})
train_df.to_csv('train_1644.csv')

In [15]:
tokenized_test = tokenizer(test_data_texts)

In [26]:
test_df = pd.DataFrame({'text': test_data_texts, 'label': test_data_labels})
test_df.to_csv('test_1644.csv')

In [27]:
tokenized = load_dataset('csv',
                      data_files={'train': 'train_1644.csv',
                                 'test': 'test_1644.csv'})



Downloading and preparing dataset csv/default to /Users/macbookpro/.cache/huggingface/datasets/csv/default-0ae9c0752663b210/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /Users/macbookpro/.cache/huggingface/datasets/csv/default-0ae9c0752663b210/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Setting the training arguments:


In [28]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2634
    })
    test: Dataset({
        features: ['Unnamed: 0', 'text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 435
    })
})

In [20]:
arguments = TrainingArguments(
    do_predict=True,
    output_dir='model', 
    evaluation_strategy="steps", 
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    num_train_epochs=25,
    logging_steps=1000, 
    logging_strategy='steps', 
    save_strategy='no',
    save_total_limit=2,
    seed=42,

)

data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

Tokenizing our dataset:

In [29]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [30]:
tokenized = dataset.map(preprocess_function, batched=True)

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Defining our metric of choice which is accuracy:

In [21]:
metric = load_metric('accuracy')

Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

In [22]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

Initializing the Trainer class and starting the training process:

In [23]:
trainer = Trainer(
    model=model,
    args=arguments,
    train_dataset=tokenized['train'],
    eval_dataset=tokenized['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [24]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: Unnamed: 0, text. If Unnamed: 0, text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2634
  Num Epochs = 25
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 16475


ValueError: type of [101, 2002, 2140, 19510, 2401, 17076, 16515, 20058, 2078, 1010, 19181, 2078, 4315, 2310, 2869, 17322, 14544, 1045, 7295, 27969, 11039, 5110, 8865, 2497, 27665, 2099, 10424, 2923, 3854, 1016, 24185, 8661, 3854, 16417, 24930, 4078, 14405, 24449, 2015, 2019, 17183, 14405, 29181, 13473, 10820, 16216, 6914, 21436, 8040, 26378, 6199, 18337, 2583, 7295, 2102, 27215, 16417, 11113, 19845, 8661, 6155, 8040, 26378, 6199, 27412, 2015, 17076, 15878, 4140, 4895, 3334, 13578, 4221, 2102, 1012, 1045, 8093, 2310, 2869, 17322, 15532, 5620, 6299, 22930, 25091, 2022, 8609, 9033, 2063, 1999, 29536, 25301, 20800, 27843, 6914, 1012, 4895, 8043, 2063, 14405, 29181, 13102, 6820, 7512, 2022, 13777, 6528, 1045, 28362, 15544, 5332, 3683, 27843, 6914, 1012, 2978, 2618, 14068, 2102, 12155, 19731, 2078, 9033, 2063, 4830, 5886, 3280, 25312, 6914, 10047, 14405, 29181, 5285, 4877, 5794, 4305, 2290, 6151, 4138, 3775, 2290, 1012, 16216, 10609, 9033, 2063, 8740, 2818, 14017, 5403, 8529, 21515, 2063, 2019, 1010, 7939, 2368, 9033, 2063, 27617, 16216, 4892, 2063, 2793, 13765, 21847, 21388, 7834, 5054, 1012, 27617, 2061, 26261, 12179, 9033, 2063, 14387, 5886, 1010, 8695, 2015, 1045, 8093, 2310, 2869, 17322, 15532, 5620, 11624, 20267, 11937, 2102, 1011, 17266, 27766, 2818, 15536, 19987, 3286, 21541, 1012, 4649, 2368, 9033, 2063, 2978, 2618, 2061, 10623, 7011, 7096, 8004, 3280, 10210, 2618, 4014, 5575, 6583, 2818, 1073, 2539, 14689, 1012, 1019, 1058, 2615, 2290, 19169, 3280, 1042, 4747, 6914, 27665, 2099, 2310, 2099, 1011, 2292, 9759, 3070, 4315, 29536, 2099, 16874, 29181, 27412, 2078, 2019, 4371, 25538, 14376, 18337, 2102, 1010, 3280, 8289, 6633, 14405, 29181, 28799, 11263, 13512, 21541, 1012, 2310, 5339, 29181, 28745, 15532, 19422, 4270, 2078, 8289, 6633, 14405, 29181, 17766, 2078, 1042, 4747, 6914, 3207, 28919, 4181, 3527, 5283, 3672, 2063, 16950, 16523, 8630, 2063, 1024, 3280, 6583, 2818, 14876, 28875, 4859, 21200, 3351, 11263, 8093, 6528, 4895, 3334, 20679, 2078, 8814, 18246, 1045, 7295, 2368, 17151, 3351, 11774, 8004, 2102, 1010, 2022, 14550, 9033, 2063, 8289, 2368, 14405, 29181, 4895, 7747, 2818, 7373, 10609, 5292, 10609, 1012, 1011, 2310, 2869, 17322, 15532, 5620, 2378, 14192, 3370, 6583, 2818, 4315, 2592, 13102, 10258, 7033, 6528, 6299, 8551, 11231, 3070, 1999, 24316, 1011, 4976, 1011, 10210, 2618, 4014, 5575, 6583, 2818, 1073, 2539, 14689, 1012, 1019, 1058, 2615, 2290, 19169, 3280, 1042, 4747, 6914, 27665, 2099, 2310, 20927, 5753, 5575, 4315, 29536, 2099, 16874, 29181, 27412, 2078, 1052, 16147, 14235, 2243, 1011, 2019, 4371, 25538, 14376, 18337, 2102, 5511, 24096, 1011, 3058, 3619, 20760, 5753, 2378, 14192, 3370, 17924, 2310, 2099, 12449, 27584, 2290, 1045, 17875, 3058, 2078, 16233, 2072, 1011, 1044, 2615, 1011, 13126, 1011, 3058, 3619, 20760, 5753, 10606, 19845, 3366, 6583, 2818, 17183, 3642, 1997, 6204, 1006, 2522, 2278, 1007, 2522, 2278, 1011, 1044, 2615, 1011, 13126, 1011, 2592, 19022, 20051, 2102, 16950, 2310, 2869, 17322, 15532, 5620, 21572, 28351, 6528, 1006, 12997, 3593, 1007, 1011, 2035, 3351, 26432, 2638, 2793, 2075, 23239, 6519, 3280, 1047, 2546, 2480, 1011, 2310, 2869, 17322, 15532, 2290, 1006, 17712, 2497, 1007, 1047, 1011, 17712, 2497, 1011, 5184, 22564, 2358, 5714, 4168, 17151, 13626, 12722, 18337, 16950, 1010, 8695, 2015, 4315, 2310, 2869, 17322, 15532, 5620, 11624, 20267, 29536, 2099, 2203, 2063, 4315, 7289, 6820, 10343, 19699, 2923, 4088, 3372, 1012, 7632, 2121, 22930, 2190, 10450, 3351, 22564, 1010, 2035, 2063, 15578, 2078, 21200, 12439, 27870, 6528, 4895, 3334, 20679, 2078, 9413, 8865, 6528, 16950, 5292, 10609, 1012, 1006, 23755, 2819, 1007, 1006, 14405, 29181, 13473, 10820, 1007, 1006, 12849, 13663, 2378, 25459, 2121, 1007, 1006, 2310, 28550, 26328, 2099, 1007, 4241, 11140, 3280, 6583, 2818, 14876, 28875, 13629, 4895, 7747, 2818, 16338, 24532, 2368, 9033, 2063, 1045, 28362, 9413, 26086, 15532, 6914, 10047, 14405, 29181, 1010, 2061, 9148, 2063, 6583, 18069, 2618, 22342, 11113, 5999, 29314, 2618, 7632, 2078, 19845, 3366, 1010, 9413, 17298, 3334, 23239, 1010, 9413, 26086, 15532, 6914, 1010, 2793, 2075, 23239, 1010, 3280, 3058, 3619, 20760, 5753, 2378, 14192, 3370, 17924, 2310, 2099, 12449, 27584, 2290, 1045, 17875, 3058, 2078, 6151, 3280, 3058, 3619, 20760, 5753, 10606, 19845, 3366, 6583, 2818, 17183, 3642, 1997, 6204, 1006, 2522, 2278, 1007, 17151, 13626, 12722, 18337, 16950, 2213, 2310, 5339, 29181, 11493, 8865, 2102, 1012, 19169, 1045, 8093, 7289, 6820, 10343, 28109, 2057, 18246, 9033, 2063, 10210, 17183, 2310, 2869, 17322, 15532, 5620, 22842, 2378, 17151, 11263, 8093, 18337, 12367, 3771, 2102, 1012, 27665, 12849, 14756, 8289, 2229, 14405, 24449, 2015, 5292, 10609, 9033, 2063, 6583, 2818, 4078, 5054, 4895, 3334, 4371, 7033, 11231, 3070, 9413, 8865, 6528, 1012, 1006, 23755, 2819, 1007, 1006, 14405, 29181, 13473, 10820, 6151, 10975, 10631, 2368, 4143, 13620, 1007, 1006, 12849, 13663, 2378, 25459, 2121, 1007, 1006, 2310, 28550, 26328, 2099, 1007, 2861, 1012, 2260, 1012, 25682, 1010, 2340, 1024, 5718, 1024, 5179, 1038, 2002, 2140, 19510, 2937, 3388, 1013, 1054, 1024, 16734, 1011, 2322, 1012, 4002, 14405, 29181, 2015, 19172, 5017, 6185, 2692, 27009, 16703, 1013, 2340, 1012, 2340, 1012, 12609, 1013, 3486, 2683, 2683, 2629, 1013, 5757, 7367, 4221, 1017, 3854, 1018, 2002, 2140, 19510, 2401, 8040, 2232, 19845, 6290, 13719, 2310, 2869, 17322, 15532, 5620, 8449, 5349, 26527, 12943, 1528, 18704, 22462, 2239, 6519, 28668, 1528, 4068, 2121, 2358, 2099, 1012, 5179, 1011, 5388, 1528, 3438, 21486, 2487, 9780, 1037, 1012, 1049, 1012, 17712, 9515, 15465, 23836, 26527, 6583, 2818, 8040, 2232, 19845, 6290, 28667, 11039, 1528, 5292, 29441, 28032, 2480, 1024, 2358, 1012, 26033, 2368, 1013, 8040, 2232, 19845, 2480, 1528, 5292, 29441, 4783, 6767, 3363, 22911, 11039, 8004, 3334, 1024, 16510, 2140, 1012, 1011, 1047, 16715, 1012, 5285, 5484, 26261, 3600, 2455, 2229, 9299, 16399, 2078, 1011, 16215, 12228, 2368, 19226, 10431, 21493, 2063, 1528, 21307, 2319, 2139, 15136, 3156, 2629, 2199, 2692, 4002, 17788, 4444, 2382, 1528, 12170, 2278, 2002, 27266, 12879, 2546, 20348, 2595, 4236, 4590, 7033, 2102, 9780, 1037, 1012, 1049, 1012, 17850, 2497, 4464, 23833, 2620, 1528, 2149, 2102, 1011, 8909, 16118, 1012, 2139, 12457, 10790, 2575, 2683, 16086, 1528, 5443, 2102, 1011, 17212, 1012, 3770, 2581, 1013, 1058, 21057, 17914, 19841, 24096, 2581, 2620, 2620, 1528, 10768, 13094, 11624, 3367, 1011, 17212, 1012, 6640, 2581, 1013, 1042, 2683, 15136, 24434, 8889, 2692, 23499, 2575, 102] unknown: <class 'str'>. Should be one of a python, numpy, pytorch or tensorflow object.

In [None]:
trainer.save_model()

## Metrics & prediction

Evaluating the trained model's performance:

In [None]:
model = AutoModelForSequenceClassification.from_pretrained('model', 
                                                           config=configuration)
tokenizer = BertTokenizer.from_pretrained('model', do_lower_case=True, 
                                          max_length=10000, padding="max_length", truncate=True)

We calculate our custom metric via the following function that determines how many ground-truth first pages were actually predicted as first pages. The logic behind this approach suggests that by determining first pages correctly we can consecutively split documents correctly, using each first page as a separator (since it means a start of a new document).

In [None]:
def calculate_metrics(texts, labels):
    true_positive = 0
    false_positive = 0
    false_negative = 0
    true_negative = 0
    
    for i, test in tqdm(zip(labels, texts)):
        inputs = tokenizer(test, truncation=True, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        pred = logits.argmax().item()
        
        if i == 1 and pred == 1:
            true_positive += 1
        elif i == 1 and pred == 0:
            false_negative += 1
        elif i == 0 and pred == 1:
            false_positive += 1
        elif i == 0 and pred == 0:
            true_negative += 1

        
    
    if true_positive + false_positive != 0:
        precision = true_positive / (true_positive + false_positive)
    else:
        precision = 0
    
    if true_positive + false_negative != 0:
        recall = true_positive / (true_positive + false_negative)
    else:
        recall = 0
    
    if precision + recall != 0:
    
        f1 = 2 * precision * recall / (precision + recall)
    
    else:
        
        f1 = 0
    
    acc = (true_positive + true_negative) / len(texts)
    
    return precision, recall, f1, acc

In [10]:
precision, recall, f1, acc = calculate_metrics(pages_test_docs, pages_labels_test)

435it [13:05,  1.81s/it]


In [12]:
print('\n Precision: {} \n Recall: {} \n F1-score: {}'.format(precision, recall, f1))


 Precision: 0.7225433526011561 
 Recall: 0.8741258741258742 
 F1-score: 0.7911392405063291
