In [None]:
from google.colab import drive 
drive.mount('/mntDrive')

Drive already mounted at /mntDrive; to attempt to forcibly remount, call drive.mount("/mntDrive", force_remount=True).


In [None]:
! rm -r ocrpostcorrection

In [None]:
!git clone https://github.com/jvdzwaan/ocrpostcorrection.git

Cloning into 'ocrpostcorrection'...
remote: Enumerating objects: 723, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 723 (delta 88), reused 92 (delta 48), pack-reused 588[K
Receiving objects: 100% (723/723), 1.18 MiB | 12.81 MiB/s, done.
Resolving deltas: 100% (453/453), done.


In [None]:
!pip install ./ocrpostcorrection

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./ocrpostcorrection
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 12.3 MB/s 
[?25hCollecting edlib
  Downloading edlib-1.3.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)
[K     |████████████████████████████████| 325 kB 60.1 MB/s 
[?25hCollecting loguru
  Downloading loguru-0.6.0-py3-none-any.whl (58 kB)
[K     |████████████████████████████████| 58 kB 6.8 MB/s 
Collecting transf

In [None]:
from pathlib import Path

import pandas as pd

In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading data

In [None]:
data_base_dir = Path('/Users/janneke/Documents/Documents – Janneke’s MacBook/data/ocrpostcorrection')

In [None]:
data_base_dir = Path('/mntDrive/MyDrive/data/ocrpostcorrection')

In [None]:
!ls /mntDrive/MyDrive/data/

ocrpostcorrection


In [None]:
in_file = data_base_dir/'icdar-task2-dataset-20221031'/'task2dataset-no-duplicates.csv'
data = pd.read_csv(in_file, index_col=0)
data = data.fillna('')

In [None]:
train = data.query('dataset == "train"')
val = data.query('dataset == "val"')

In [None]:
from ocrpostcorrection.error_correction import generate_vocabs, get_text_transform

vocab_transform = generate_vocabs(train)
text_transform = get_text_transform(vocab_transform)

In [None]:
from torch.utils.data import DataLoader

from ocrpostcorrection.error_correction import SimpleCorrectionDataset, collate_fn

max_len = 22
batch_size = 256

train_dataset = SimpleCorrectionDataset(train, max_len=max_len)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn(text_transform))

val_dataset = SimpleCorrectionDataset(val, max_len=max_len)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn(text_transform))

In [None]:
print('num train samples', len(train_dataset))
print('num val samples', len(val_dataset))

num train samples 744738
num val samples 103272


In [None]:
from ocrpostcorrection.error_correction import validate_model

In [None]:
from tqdm.notebook import tqdm

class StopExecution(Exception):
    def _render_traceback_(self):
        return []

def train_model(train_dl, val_dl, model=None, optimizer=None, num_epochs=5, valid_niter=5000, 
                model_save_path='model.rar', max_num_patience=5, max_num_trial=5, 
                lr_decay=0.5, device='cpu'):
    num_iter = 0
    report_loss = 0
    report_examples = 0
    val_loss_hist = []
    num_trial = 0
    patience = 0

    model.train()

    for epoch in range(1, num_epochs+1):
        cum_loss = 0
        cum_examples = 0

        for src, tgt in tqdm(train_dl):
            #print(f'src: {src.size()}; tgt: {tgt.size()}')
            num_iter += 1

            batch_size = src.size(1)

            src = src.to(device)
            tgt = tgt.to(device)
            encoder_hidden = model.encoder.initHidden(batch_size=batch_size, device=device)

            # print(input_hidden.size())

            example_losses, _ = model(src, encoder_hidden, tgt)
            example_losses = -example_losses
            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            bl = batch_loss.item()
            report_loss += bl
            report_examples += batch_size

            cum_loss += bl
            cum_examples += batch_size

            optimizer.zero_grad()
            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

            optimizer.step()

            if num_iter % valid_niter == 0:
                val_loss = validate_model(model, val_dl, device)
                print(f'Epoch {epoch}, iter {num_iter}, avg. train loss {report_loss/report_examples}, avg. val loss {val_loss}')

                report_loss = 0
                report_examples = 0

                better_model = len(val_loss_hist) == 0 or val_loss < min(val_loss_hist)
                if better_model:
                    print(f'Saving model and optimizer to {model_save_path}')
                    torch.save({
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      }, model_save_path)
                elif patience < max_num_patience:
                    patience += 1
                    print(f'hit patience {patience}')

                    if patience == max_num_patience:
                        num_trial += 1
                        print(f'hit #{num_trial} trial')
                        if num_trial == max_num_trial:
                            print('early stop!')
                            raise StopExecution('early stop!')

                        # decay lr, and restore from previously best checkpoint
                        lr = optimizer.param_groups[0]['lr'] * lr_decay
                        print(f'load previously best model and decay learning rate to {lr}')

                        # load model
                        checkpoint = torch.load(model_save_path)
                        model.load_state_dict(checkpoint['model_state_dict'])
                        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                        
                        model = model.to(device)
                        
                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # reset patience
                        patience = 0
                    

                val_loss_hist.append(val_loss)



In [None]:
out_dir = data_base_dir/'results'/'simple_correction_model'
out_dir.mkdir(exist_ok=True, parents=True)

In [None]:
from ocrpostcorrection.error_correction import SimpleCorrectionSeq2seq

hidden_size = 256
dropout = 0.1
model = SimpleCorrectionSeq2seq(len(vocab_transform['ocr']), 
                                hidden_size, 
                                len(vocab_transform['gs']), 
                                dropout, 
                                max_len, 
                                teacher_forcing_ratio=0.5,
                                device=device)
model.to(device)    
optimizer = torch.optim.Adam(model.parameters())

msp = out_dir/'model.rar'

train_model(train_dl=train_dataloader, 
            val_dl=val_dataloader,
            model=model, 
            optimizer=optimizer, 
            model_save_path=msp, 
            num_epochs=25, 
            valid_niter=1000, 
            max_num_patience=5, 
            max_num_trial=5, 
            lr_decay=0.5, 
            device=device)

  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 1, iter 1000, avg. train loss 22.811994752883912, avg. val loss 15.614160314168833
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 1, iter 2000, avg. train loss 15.882390301704406, avg. val loss 13.397783801136198
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 2, iter 3000, avg. train loss 14.126266974162874, avg. val loss 11.960924853491038
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 2, iter 4000, avg. train loss 13.333274574279786, avg. val loss 11.695720509783175
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 2, iter 5000, avg. train loss 12.7490910654068, avg. val loss 11.30207309815985
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 3, iter 6000, avg. train loss 12.52177142884398, avg. val loss 11.090131195373322
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 3, iter 7000, avg. train loss 11.983854871749879, avg. val loss 10.513996227990464
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 3, iter 8000, avg. train loss 11.95318814277649, avg. val loss 10.366249744634844
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 4, iter 9000, avg. train loss 11.83334192114735, avg. val loss 10.479845455239714
hit patience 1
Epoch 4, iter 10000, avg. train loss 11.485622592926026, avg. val loss 10.651585825888263
hit patience 2
Epoch 4, iter 11000, avg. train loss 11.510124185562134, avg. val loss 10.052751592408125
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 5, iter 12000, avg. train loss 11.25806761262605, avg. val loss 9.9412893215058
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 5, iter 13000, avg. train loss 11.255829967975616, avg. val loss 9.874064868735218
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 5, iter 14000, avg. train loss 11.004517789840698, avg. val loss 10.263450761446789
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 6, iter 15000, avg. train loss 10.980055149825477, avg. val loss 9.99473969056684
hit patience 4
Epoch 6, iter 16000, avg. train loss 10.97683451461792, avg. val loss 10.153718736042624
hit patience 5
hit #1 trial
load previously best model and decay learning rate to 0.0005
Epoch 6, iter 17000, avg. train loss 10.566154016017913, avg. val loss 9.464594390724933
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 7, iter 18000, avg. train loss 10.64466551359892, avg. val loss 9.422051290490955
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 7, iter 19000, avg. train loss 10.460810148715973, avg. val loss 9.68924134768859
hit patience 1
Epoch 7, iter 20000, avg. train loss 10.487071641921997, avg. val loss 9.405553820260831
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 8, iter 21000, avg. train loss 10.404898092932397, avg. val loss 9.700428797149835
hit patience 2
Epoch 8, iter 22000, avg. train loss 10.2458408036232, avg. val loss 9.223378892373
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 8, iter 23000, avg. train loss 10.36161626291275, avg. val loss 9.29592009637531
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 9, iter 24000, avg. train loss 10.438406353624561, avg. val loss 9.285123089409067
hit patience 4
Epoch 9, iter 25000, avg. train loss 10.315870904445648, avg. val loss 9.491048404623568
hit patience 5
hit #2 trial
load previously best model and decay learning rate to 0.00025
Epoch 9, iter 26000, avg. train loss 10.0961557970047, avg. val loss 9.129219687560331
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 10, iter 27000, avg. train loss 9.918700960333483, avg. val loss 9.231819232618866
hit patience 1
Epoch 10, iter 28000, avg. train loss 10.1122615442276, avg. val loss 9.244956602013373
hit patience 2
Epoch 10, iter 29000, avg. train loss 10.039958479404449, avg. val loss 9.14765762564325
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 11, iter 30000, avg. train loss 9.923900739936478, avg. val loss 9.01022878462602
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 11, iter 31000, avg. train loss 9.945741470336914, avg. val loss 8.96465481595515
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar
Epoch 11, iter 32000, avg. train loss 10.03109482908249, avg. val loss 8.90711541530081
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 12, iter 33000, avg. train loss 9.926896672026285, avg. val loss 8.994633924546457
hit patience 4
Epoch 12, iter 34000, avg. train loss 9.792152856826782, avg. val loss 8.955088855442877
hit patience 5
hit #3 trial
load previously best model and decay learning rate to 0.000125


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 13, iter 35000, avg. train loss 9.74638867753138, avg. val loss 9.064416342086727
hit patience 1
Epoch 13, iter 36000, avg. train loss 9.767778611660004, avg. val loss 9.109442792214711
hit patience 2
Epoch 13, iter 37000, avg. train loss 9.784330617904663, avg. val loss 9.087586448622744
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 14, iter 38000, avg. train loss 9.615264288259015, avg. val loss 8.99832949356527
hit patience 4
Epoch 14, iter 39000, avg. train loss 9.775166423797607, avg. val loss 9.097569244524244
hit patience 5
hit #4 trial
load previously best model and decay learning rate to 6.25e-05
Epoch 14, iter 40000, avg. train loss 9.796301899433136, avg. val loss 8.760853879022713
Saving model and optimizer to /mntDrive/MyDrive/data/ocrpostcorrection/results/simple_correction_model/model.rar


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 15, iter 41000, avg. train loss 9.74282352486376, avg. val loss 8.878883864177965
hit patience 1
Epoch 15, iter 42000, avg. train loss 9.74165991306305, avg. val loss 8.942062592467623
hit patience 2
Epoch 15, iter 43000, avg. train loss 9.585915051460265, avg. val loss 8.79634273433604
hit patience 3


  0%|          | 0/2910 [00:00<?, ?it/s]

Epoch 16, iter 44000, avg. train loss 9.77400199716477, avg. val loss 8.9867198452664
hit patience 4
Epoch 16, iter 45000, avg. train loss 9.599046401023864, avg. val loss 8.961717552842856
hit patience 5
hit #5 trial
early stop!


StopExecution: ignored