In [None]:
# Mount Google Drive at /mntDrive; the data and results paths used below live under it.
import google.colab.drive as drive

drive.mount('/mntDrive')

Drive already mounted at /mntDrive; to attempt to forcibly remount, call drive.mount("/mntDrive", force_remount=True).


In [None]:
! rm -r ocrpostcorrection

In [None]:
!git clone https://github.com/jvdzwaan/ocrpostcorrection.git

Cloning into 'ocrpostcorrection'...
remote: Enumerating objects: 660, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (44/44), done.[K
remote: Total 660 (delta 38), reused 44 (delta 21), pack-reused 595[K
Receiving objects: 100% (660/660), 1.12 MiB | 18.21 MiB/s, done.
Resolving deltas: 100% (403/403), done.


In [None]:
!pip install ./ocrpostcorrection

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./ocrpostcorrection
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: ocrpostcorrection
  Building wheel for ocrpostcorrection (setup.py) ... [?25l[?25hdone
  Created wheel for ocrpostcorrection: filename=ocrpostcorrection-0.0.1-py3-none-any.whl size=23701 sha256=1ede2f1b9323bcf541b139f415b402f5c5bc9c64f0481d5203d258f0baf6c40f
  Stored in directory: /tmp/pip-ephem-wheel-cache-ly7hnj9s/wheels/ce/8c/19/e683c70df08a7eb9cb93fa17f510cfeef9e1e912a83c835a2c
Successfully built ocrpostcorrection
I

In [None]:
from pathlib import Path

import pandas as pd

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Loading data

In [None]:
import os

# Data location on the author's local machine. Set OCRPOSTCORRECTION_DATA_DIR to
# point the notebook at another directory without editing this cell.
data_base_dir = Path(os.environ.get(
    'OCRPOSTCORRECTION_DATA_DIR',
    '/Users/janneke/Documents/Documents – Janneke’s MacBook/data/ocrpostcorrection',
))

In [None]:
import os

# Colab: data on the mounted Google Drive (replaces the value set in the previous
# cell). OCRPOSTCORRECTION_DATA_DIR, when set, overrides this location.
data_base_dir = Path(os.environ.get('OCRPOSTCORRECTION_DATA_DIR', '/mntDrive/MyDrive/data'))

In [None]:
!ls /mntDrive/MyDrive/data/

ICDAR2019_POCR_competition_dataset  ocrpostcorrection
icdar-task2-dataset-20221031


In [None]:
# Load the de-duplicated task-2 table (first CSV column is the index) and replace
# missing values with empty strings.
in_file = data_base_dir / 'icdar-task2-dataset-20221031' / 'task2dataset-no-duplicates.csv'
data = pd.read_csv(in_file, index_col=0).fillna('')

In [None]:
# Select the rows of each split via the `dataset` column.
train, val = (data[data['dataset'] == split_name] for split_name in ('train', 'val'))

In [None]:
from ocrpostcorrection.error_correction import (
    generate_vocabs,
    get_text_transform,
)

# Vocabularies are built from the training split only; later cells index the
# result with the keys 'ocr' and 'gs'.
vocab_transform = generate_vocabs(train)
# Text-to-tensor transforms derived from those vocabularies (used by collate_fn).
text_transform = get_text_transform(vocab_transform)

In [None]:
from torch.utils.data import DataLoader

from ocrpostcorrection.error_correction import SimpleCorrectionDataset, collate_fn

# max_len is passed both to the datasets and to the model below.
# NOTE(review): presumably a maximum token length — confirm against SimpleCorrectionDataset.
max_len = 22
batch_size = 128

train_dataset = SimpleCorrectionDataset(train, max_len=max_len)
# Reshuffle the training samples every epoch. Without shuffling, every epoch sees
# the batches in the same order. Earlier runs showed the symptom: train and val
# loss oscillated with the position inside the epoch and peaked at each epoch's end.
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=collate_fn(text_transform))

# Validation keeps a fixed order.
val_dataset = SimpleCorrectionDataset(val, max_len=max_len)
val_dataloader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            collate_fn=collate_fn(text_transform))

In [None]:
# Report how many samples each split yields after dataset construction.
for label, ds in (('num train samples', train_dataset), ('num val samples', val_dataset)):
    print(label, len(ds))

num train samples 744738
num val samples 103272


In [None]:
from ocrpostcorrection.error_correction import validate_model

In [None]:
from tqdm.notebook import tqdm

class StopExecution(Exception):
    """Exception used to halt the notebook run without a traceback.

    IPython consults ``_render_traceback_`` (when an exception defines it) to
    format the error; returning an empty list means nothing is rendered.
    """

    def _render_traceback_(self):
        # Nothing to show: suppress the traceback entirely.
        return list()

def train_model(train_dl, val_dl, model=None, optimizer=None, num_epochs=5, valid_niter=5000,
                model_save_path='model.rar', max_num_patience=5, max_num_trial=5,
                lr_decay=0.5, device='cpu', max_token_len=None):
    """Train a correction model with periodic validation, checkpointing,
    learning-rate decay and early stopping.

    Every ``valid_niter`` training batches (counted across epochs) the model is
    evaluated with ``validate_model``.

    - If the validation loss is the lowest seen so far, the model and optimizer
      state are saved to ``model_save_path`` and the patience counter is reset.
    - Otherwise patience grows. After ``max_num_patience`` non-improving
      validations in a row, the best checkpoint is restored and the learning rate
      is multiplied by ``lr_decay``; this counts as one "trial".
    - After ``max_num_trial`` trials, training stops by raising ``StopExecution``.

    Args:
        train_dl: iterable of ``(src, tgt)`` tensor batches. The batch size is
            read from ``src.size(1)``, so ``src`` is assumed to be laid out as
            (seq_len, batch) — TODO confirm against collate_fn.
        val_dl: validation data, passed unchanged to ``validate_model``.
        model: module exposing ``encoder.initHidden(batch_size=..., device=...)``.
            It is called as ``model(src, encoder_hidden, tgt)`` and returns a pair
            whose first element, negated, is the per-example loss.
        optimizer: torch optimizer over ``model.parameters()``.
        num_epochs: number of passes over ``train_dl``.
        valid_niter: validate every this many training batches.
        model_save_path: where the best checkpoint is written with ``torch.save``.
            The checkpoint is a dict with keys ``model_state_dict`` and
            ``optimizer_state_dict``.
        max_num_patience: non-improving validations tolerated before decaying the lr.
        max_num_trial: number of lr decays after which training stops.
        lr_decay: factor applied to the learning rate at each trial.
        device: device that batches and the restored checkpoint are moved to.
        max_token_len: accepted so callers that pass it keep working; not used
            by this loop.

    Raises:
        ValueError: if ``model`` or ``optimizer`` is not given.
        StopExecution: when early stopping triggers.
    """
    if model is None or optimizer is None:
        raise ValueError('train_model requires both a model and an optimizer')

    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs + 1):
        for src, tgt in train_dl:
            num_iter += 1

            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

            # The model returns per-example scores (presumably log-likelihoods);
            # negate them to obtain losses.
            example_losses, _ = model(src, encoder_hidden, tgt)
            example_losses = -example_losses
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            report_loss += batch_loss.item()
            report_examples += batch_size

            optimizer.zero_grad()
            loss.backward()
            # Clip the total gradient norm to 5.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            if num_iter % valid_niter == 0:
                val_loss = validate_model(model, val_dl, device)
                # validate_model may switch the model to eval mode; make sure
                # training resumes with train-mode behaviour (e.g. dropout).
                model.train()
                print(f'Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}')

                report_loss = 0
                report_examples = 0

                better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
                if better_model:
                    # An improvement resets the patience counter; otherwise
                    # non-consecutive bad validations accumulate and trigger
                    # lr decay / early stopping while the model is still improving.
                    patience = 0
                    print(f'Saving model and optimizer to {model_save_path}')
                    torch.save({
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      }, model_save_path)
                elif patience < max_num_patience:
                    patience += 1
                    print(f'hit patience {patience}')

                    if patience == max_num_patience:
                        num_trial += 1
                        print(f'hit #{num_trial} trial')
                        if num_trial == max_num_trial:
                            print('early stop!')
                            raise StopExecution('early stop!')

                        # Decay lr and restore the previously best checkpoint.
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        print(f'load previously best model and decay learning rate to {lr}')

                        # map_location keeps the restore working regardless of the
                        # device the checkpoint tensors were saved from.
                        checkpoint = torch.load(model_save_path, map_location=device)
                        model.load_state_dict(checkpoint['model_state_dict'])
                        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                        model = model.to(device)

                        # The restored optimizer state carries the old lr; override it.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        patience = 0

                val_loss_hist.append(val_loss)



In [None]:
# Directory for this model's checkpoint; created (with parents) if missing.
out_dir = data_base_dir.joinpath('ocrpostcorrection', 'results', 'simple_correction_model')
out_dir.mkdir(parents=True, exist_ok=True)

In [None]:
from ocrpostcorrection.error_correction import SimpleCorrectionSeq2seq

# Seed the torch RNG so weight initialisation (and any torch-based shuffling in
# the data loaders) is repeatable. Other RNGs the model may use, e.g. for teacher
# forcing, are not covered — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

hidden_size = 256
dropout = 0.1
model = SimpleCorrectionSeq2seq(len(vocab_transform['ocr']),
                                hidden_size,
                                len(vocab_transform['gs']),
                                dropout,
                                max_len,
                                teacher_forcing_ratio=0.5,
                                device=device)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())

# Best checkpoint path. The file is written with torch.save despite the .rar extension.
msp = out_dir/'model.rar'

# train_model takes no max_token_len argument: the sequence length is already
# baked into the datasets and the model via max_len.
train_model(train_dl=train_dataloader,
            val_dl=val_dataloader,
            model=model,
            optimizer=optimizer,
            model_save_path=msp,
            num_epochs=25,
            valid_niter=1000,
            max_num_patience=5,
            max_num_trial=5,
            lr_decay=0.5,
            device=device)

  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 1, iter 1000, avg. train loss 17.0415628156662, avg. val loss 15.508534159637303
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 1, iter 2000, avg. train loss 9.4730129737854, avg. val loss 11.724916617607192
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 1, iter 3000, avg. train loss 6.865117371082306, avg. val loss 11.462291592175069
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 1, iter 4000, avg. train loss 10.169847501754761, avg. val loss 13.555501511074551
hit patience 1


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 2, iter 5000, avg. train loss 11.73970633384165, avg. val loss 11.153411283623155
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 2, iter 6000, avg. train loss 7.220437708854675, avg. val loss 9.760569557918576
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 2, iter 7000, avg. train loss 5.6575815722942355, avg. val loss 9.955293140106598
hit patience 2
Epoch 2, iter 8000, avg. train loss 9.161027384996414, avg. val loss 12.572123071117982
hit patience 3


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 3, iter 9000, avg. train loss 10.842043943595215, avg. val loss 10.71486210234512
hit patience 4
Epoch 3, iter 10000, avg. train loss 6.69287888598442, avg. val loss 9.239391213295699
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 3, iter 11000, avg. train loss 5.306302636146546, avg. val loss 9.494960311644205
hit patience 5
hit #1 trial
load previously best model and decay learning rate to 0.0005
Epoch 3, iter 12000, avg. train loss 8.743113709926606, avg. val loss 11.718912902174276
hit patience 1


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 4, iter 13000, avg. train loss 10.152274047326836, avg. val loss 9.688396484031541
hit patience 2
Epoch 4, iter 14000, avg. train loss 6.336827111721039, avg. val loss 8.55588805653357
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 4, iter 15000, avg. train loss 5.101096046686172, avg. val loss 8.899922666525887
hit patience 3
Epoch 4, iter 16000, avg. train loss 8.371230652570725, avg. val loss 11.83752282718732
hit patience 4


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 5, iter 17000, avg. train loss 9.944540591551112, avg. val loss 9.666133591632223
hit patience 5
hit #2 trial
load previously best model and decay learning rate to 0.00025
Epoch 5, iter 18000, avg. train loss 6.237562358617782, avg. val loss 8.430532484302136
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 5, iter 19000, avg. train loss 5.060271089553833, avg. val loss 8.541593983895195
hit patience 1
Epoch 5, iter 20000, avg. train loss 8.479198782205582, avg. val loss 11.624844120092666
hit patience 2


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 6, iter 21000, avg. train loss 10.199479566650867, avg. val loss 8.965291766013639
hit patience 3
Epoch 6, iter 22000, avg. train loss 6.0284324617385865, avg. val loss 8.205115923248938
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 6, iter 23000, avg. train loss 4.821524203538894, avg. val loss 8.582757151024973
hit patience 4
Epoch 6, iter 24000, avg. train loss 8.11319116783142, avg. val loss 11.912475486302151
hit patience 5
hit #3 trial
load previously best model and decay learning rate to 0.000125


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 7, iter 25000, avg. train loss 10.379341751462363, avg. val loss 8.155694473395398
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 7, iter 26000, avg. train loss 6.047264521360398, avg. val loss 8.01454865877407
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 7, iter 27000, avg. train loss 4.8521045932769775, avg. val loss 8.007700598365835
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 7, iter 28000, avg. train loss 8.413783473730087, avg. val loss 10.785207058357296
hit patience 1


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 8, iter 29000, avg. train loss 10.050432466638885, avg. val loss 8.171083133084105
hit patience 2
Epoch 8, iter 30000, avg. train loss 6.09458415222168, avg. val loss 7.959974724237545
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 8, iter 31000, avg. train loss 4.741669086694717, avg. val loss 8.147681500189382
hit patience 3
Epoch 8, iter 32000, avg. train loss 8.178057569026947, avg. val loss 10.738872468335465
hit patience 4


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 9, iter 33000, avg. train loss 10.04274268188715, avg. val loss 8.031199898810664
hit patience 5
hit #4 trial
load previously best model and decay learning rate to 6.25e-05
Epoch 9, iter 34000, avg. train loss 6.193686002254486, avg. val loss 7.755900577254158
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 9, iter 35000, avg. train loss 4.794535130023957, avg. val loss 7.9622724111554675
hit patience 1
Epoch 9, iter 36000, avg. train loss 8.462147181987762, avg. val loss 9.679042915999968
hit patience 2


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 10, iter 37000, avg. train loss 10.198173898962988, avg. val loss 7.750432576989902
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 10, iter 38000, avg. train loss 6.261524030923844, avg. val loss 7.6332660405101835
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 10, iter 39000, avg. train loss 4.754683881521225, avg. val loss 8.051098058005646
hit patience 3
Epoch 10, iter 40000, avg. train loss 8.268364515066146, avg. val loss 9.609069118670877
hit patience 4


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 11, iter 41000, avg. train loss 10.200509568184975, avg. val loss 7.655317834363139
hit patience 5
hit #5 trial
early stop!
load previously best model and decay learning rate to 3.125e-05
Epoch 11, iter 42000, avg. train loss 6.283131835222244, avg. val loss 7.666493009140064
hit patience 1
Epoch 11, iter 43000, avg. train loss 4.813553342103958, avg. val loss 7.767371000526956
hit patience 2
Epoch 11, iter 44000, avg. train loss 8.521714780092239, avg. val loss 8.61853788865627
hit patience 3


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 12, iter 45000, avg. train loss 10.422263543746489, avg. val loss 7.433023445367143
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 12, iter 46000, avg. train loss 6.391187304735184, avg. val loss 7.465193162255779
hit patience 4
Epoch 12, iter 47000, avg. train loss 4.760101779699325, avg. val loss 7.711935139021347
hit patience 5
hit #6 trial
load previously best model and decay learning rate to 1.5625e-05
Epoch 12, iter 48000, avg. train loss 8.644169484138489, avg. val loss 7.669757801907871
hit patience 1


  0%|          | 0/4006 [00:00<?, ?it/s]

Epoch 13, iter 49000, avg. train loss 10.483899677093255, avg. val loss 7.436883882142976
hit patience 2


KeyboardInterrupt: ignored

In [None]:
# Release (unassign) the Colab runtime once the work above is done.
import google.colab.runtime as runtime

runtime.unassign()