In [None]:
from google.colab import drive 
drive.mount('/mntDrive')

Mounted at /mntDrive


In [None]:
! rm -r ocrpostcorrection

In [None]:
!git clone https://github.com/jvdzwaan/ocrpostcorrection.git

Cloning into 'ocrpostcorrection'...
remote: Enumerating objects: 723, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 723 (delta 88), reused 92 (delta 48), pack-reused 588[K
Receiving objects: 100% (723/723), 1.18 MiB | 6.58 MiB/s, done.
Resolving deltas: 100% (453/453), done.


In [None]:
!pip install ./ocrpostcorrection

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./ocrpostcorrection
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets
  Downloading datasets-2.8.0-py3-none-any.whl (452 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m452.9/452.9 KB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting edlib
  Downloading edlib-1.3.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (359 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m359.5/359.5 KB[0m [31m39.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting loguru
  Downloading loguru-0.6.0-py3-none-any.whl (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 KB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m62.8 MB/s[0m eta [3

In [None]:
from pathlib import Path

import pandas as pd

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading data

In [None]:
data_base_dir = Path('/Users/janneke/Documents/Documents – Janneke’s MacBook/data/ocrpostcorrection')

In [None]:
data_base_dir = Path('/mntDrive/MyDrive/data/ocrpostcorrection')

In [None]:
!ls /mntDrive/MyDrive/data/

ocrpostcorrection


In [None]:
in_file = data_base_dir/'icdar-task2-dataset-20221031'/'task2dataset-no-duplicates.csv'
data = pd.read_csv(in_file, index_col=0)
data = data.fillna('')

In [None]:
train = data.query('dataset == "train"')
val = data.query('dataset == "val"')

In [None]:
from ocrpostcorrection.error_correction import generate_vocabs, get_text_transform

vocab_transform = generate_vocabs(train)
text_transform = get_text_transform(vocab_transform)

In [None]:
from torch.utils.data import DataLoader

from ocrpostcorrection.error_correction import SimpleCorrectionDataset, collate_fn

max_len = 22
batch_size = 256

train_dataset = SimpleCorrectionDataset(train, max_len=max_len)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn(text_transform))

val_dataset = SimpleCorrectionDataset(val, max_len=max_len)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn(text_transform))

In [None]:
print('num train samples', len(train_dataset))
print('num val samples', len(val_dataset))

num train samples 744738
num val samples 103272


In [None]:
from ocrpostcorrection.error_correction import validate_model

In [None]:
from tqdm.notebook import tqdm

class StopExecution(Exception):
    def _render_traceback_(self):
        return []

def train_model(train_dl, val_dl, model=None, optimizer=None, num_epochs=5, valid_niter=5000, 
                model_save_path='model.rar', max_num_patience=5, max_num_trial=5, 
                lr_decay=0.5, device='cpu'):
    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs+1):
        cum_loss = 0
        cum_examples = 0

        for src, tgt in tqdm(train_dl):
            #print(f'src: {src.size()}; tgt: {tgt.size()}')
            num_iter += 1

            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

            # print(input_hidden.size())

            example_losses, _ = model(src, encoder_hidden, tgt)
            example_losses = -example_losses
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            bl = batch_loss.item()
            report_loss += bl
            report_examples += batch_size

            cum_loss += bl
            cum_examples += batch_size

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            if num_iter % valid_niter == 0:
                val_loss = validate_model(model, val_dl, device)
                print(f'Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}')

                report_loss = 0
                report_examples = 0

                better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
                if better_model:
                    print(f'Saving model and optimizer to {model_save_path}')
                    torch.save({
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      }, model_save_path)
                elif patience < max_num_patience:
                    patience += 1
                    print(f'hit patience {patience}')

                    if patience == max_num_patience:
                        num_trial += 1
                        print(f'hit #{num_trial} trial')
                        if num_trial == max_num_trial:
                            print('early stop!')
                            raise StopExecution('early stop!')

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        print(f'load previously best model and decay learning rate to {lr}')

                        # load model
                        checkpoint = torch.load(model_save_path)
                        model.load_state_dict(checkpoint['model_state_dict'])
                        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                        
                        model = model.to(device)
                        
                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0
                    

                val_loss_hist.append(val_loss)



In [None]:
out_dir = data_base_dir/'results'/'simple_correction_model_2023-01-14'
out_dir.mkdir(exist_ok=True, parents=True)

In [None]:
from ocrpostcorrection.error_correction import SimpleCorrectionSeq2seq

hidden_size = 256
dropout = 0.1
model = SimpleCorrectionSeq2seq(len(vocab_transform['ocr']), 
                                hidden_size, 
                                len(vocab_transform['gs']), 
                                dropout, 
                                max_len, 
                                teacher_forcing_ratio=0.5,
                                device=device)
model.to(device)    
optimizer = torch.optim.Adam(model.parameters())

msp = out_dir/'model.rar'

train_model(train_dl=train_dataloader, 
            val_dl=val_dataloader,
            model=model, 
            optimizer=optimizer, 
            model_save_path=msp, 
            num_epochs=25, 
            valid_niter=1000, 
            max_num_patience=5, 
            max_num_trial=5, 
            lr_decay=0.5, 
            device=device)

  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 1, iter 1000, avg. train loss 23.227829644203187, avg. val loss 16.112816625842285
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 1, iter 2000, avg. train loss 15.854065298080444, avg. val loss 13.279709757099274
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 2, iter 3000, avg. train loss 14.139566606726929, avg. val loss 12.198645099199883
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 2, iter 4000, avg. train loss 13.352948026657105, avg. val loss 11.737783558631499
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 2, iter 5000, avg. train loss 12.900902897834778, avg. val loss 11.284699681357619
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 3, iter 6000, avg. train loss 12.664117150611359, avg. val loss 10.939773715597902
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 3, iter 7000, avg. train loss 12.238080160140992, avg. val loss 10.799119228747491
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 3, iter 8000, avg. train loss 11.826094272613526, avg. val loss 10.356122530423677
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 4, iter 9000, avg. train loss 11.665316856480981, avg. val loss 10.514984020930846
hit patience 1
Epoch 4, iter 10000, avg. train loss 11.572257776737214, avg. val loss 10.420354811372668
hit patience 2
Epoch 4, iter 11000, avg. train loss 11.488452627658845, avg. val loss 10.091223563441687
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 5, iter 12000, avg. train loss 11.244238830136403, avg. val loss 10.050209755217487
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 5, iter 13000, avg. train loss 11.205910049915314, avg. val loss 9.986747739926178
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 5, iter 14000, avg. train loss 11.164922624588012, avg. val loss 10.03054018794119
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 6, iter 15000, avg. train loss 11.112712733904354, avg. val loss 9.8969613076069
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 6, iter 16000, avg. train loss 10.957710127830506, avg. val loss 10.007307611858922
hit patience 4
Epoch 6, iter 17000, avg. train loss 10.83591471338272, avg. val loss 9.842108489283381
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 7, iter 18000, avg. train loss 10.689632715037705, avg. val loss 9.703038696648399
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 7, iter 19000, avg. train loss 10.643433878898621, avg. val loss 9.75381810859611
hit patience 5
hit #1 trial
load previously best model and decay learning rate to 0.0005
Epoch 7, iter 20000, avg. train loss 10.55265146112442, avg. val loss 9.647074156334236
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 8, iter 21000, avg. train loss 10.279620376743333, avg. val loss 9.563503861418072
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 8, iter 22000, avg. train loss 10.26977631187439, avg. val loss 9.328021848032758
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 8, iter 23000, avg. train loss 10.349749496936798, avg. val loss 9.276446026349864
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 9, iter 24000, avg. train loss 10.166142336578712, avg. val loss 9.214138599856447
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 9, iter 25000, avg. train loss 10.127833611011505, avg. val loss 9.335289142717093
hit patience 1
Epoch 9, iter 26000, avg. train loss 10.278156569004059, avg. val loss 9.36290974594472
hit patience 2


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 10, iter 27000, avg. train loss 10.121291022904032, avg. val loss 9.259917587598238
hit patience 3
Epoch 10, iter 28000, avg. train loss 10.147399868965149, avg. val loss 9.253627715929703
hit patience 4
Epoch 10, iter 29000, avg. train loss 10.168554908275604, avg. val loss 9.106611716515495
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 11, iter 30000, avg. train loss 9.892677240254107, avg. val loss 9.10546605734205
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 11, iter 31000, avg. train loss 10.05647609949112, avg. val loss 9.17400997870195
hit patience 5
hit #2 trial
load previously best model and decay learning rate to 0.00025
Epoch 11, iter 32000, avg. train loss 9.802407557010651, avg. val loss 8.922101947448432
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 12, iter 33000, avg. train loss 9.671664194165922, avg. val loss 9.052841421155836
hit patience 1
Epoch 12, iter 34000, avg. train loss 9.624266938209534, avg. val loss 9.093339984907688
hit patience 2


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 13, iter 35000, avg. train loss 9.6622826617448, avg. val loss 9.065616415291089
hit patience 3
Epoch 13, iter 36000, avg. train loss 9.519576427936554, avg. val loss 9.152915654354937
hit patience 4
Epoch 13, iter 37000, avg. train loss 9.62256542301178, avg. val loss 8.770258988683253
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 14, iter 38000, avg. train loss 9.545937502286032, avg. val loss 8.90817572797367
hit patience 5
hit #3 trial
load previously best model and decay learning rate to 0.000125
Epoch 14, iter 39000, avg. train loss 9.454290801048279, avg. val loss 8.962353815623501
hit patience 1
Epoch 14, iter 40000, avg. train loss 9.621330524921417, avg. val loss 8.649534131049444
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 15, iter 41000, avg. train loss 9.372525386385101, avg. val loss 8.890587856568432
hit patience 2
Epoch 15, iter 42000, avg. train loss 9.461287173748016, avg. val loss 8.626340694251663
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar
Epoch 15, iter 43000, avg. train loss 9.438299771785736, avg. val loss 8.856668003468503
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 16, iter 44000, avg. train loss 9.52536601638612, avg. val loss 8.88772221011984
hit patience 4
Epoch 16, iter 45000, avg. train loss 9.362547441482544, avg. val loss 8.872310621583846
hit patience 5
hit #4 trial
load previously best model and decay learning rate to 6.25e-05
Epoch 16, iter 46000, avg. train loss 9.347121622085572, avg. val loss 8.65568360330078
hit patience 1


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 17, iter 47000, avg. train loss 9.372967989700939, avg. val loss 8.814872336854904
hit patience 2
Epoch 17, iter 48000, avg. train loss 9.387550693511963, avg. val loss 8.635493487205094
hit patience 3
Epoch 17, iter 49000, avg. train loss 9.344637980937957, avg. val loss 8.448380747091873
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model_2023-01-14/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 18, iter 50000, avg. train loss 9.40736315245904, avg. val loss 8.79755446942071
hit patience 4
Epoch 18, iter 51000, avg. train loss 9.37230985069275, avg. val loss 8.805215535188628
hit patience 5
hit #5 trial
early stop!


StopExecution: ignored