In [None]:
# Mount Google Drive at /mntDrive so the dataset and trained-model results
# stored under /mntDrive/MyDrive/data are reachable from this runtime.
# google.colab is only importable inside a Colab runtime.
from google.colab import drive 
drive.mount('/mntDrive')

Mounted at /mntDrive


In [None]:
! rm -r ocrpostcorrection

In [None]:
# Fetch the project sources (HEAD of the default branch).
# NOTE(review): not pinned to a commit or tag, so a later re-run may install
# different code than the checkpoint below was trained with. Consider pinning.
!git clone https://github.com/jvdzwaan/ocrpostcorrection.git

Cloning into 'ocrpostcorrection'...
remote: Enumerating objects: 723, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 723 (delta 88), reused 92 (delta 48), pack-reused 588[K
Receiving objects: 100% (723/723), 1.18 MiB | 18.24 MiB/s, done.
Resolving deltas: 100% (453/453), done.


In [None]:
!pip install ./ocrpostcorrection

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./ocrpostcorrection
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets
  Downloading datasets-2.8.0-py3-none-any.whl (452 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m452.9/452.9 KB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting edlib
  Downloading edlib-1.3.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (359 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m359.5/359.5 KB[0m [31m28.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting loguru
  Downloading loguru-0.6.0-py3-none-any.whl (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 KB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m77.2 MB/s[0m eta [3

In [None]:
from pathlib import Path

import pandas as pd

In [None]:
import torch

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Loading data

In [None]:
import os

# Root directory holding the ICDAR task-2 dataset and the trained-model results.
# Defaults to the Google Drive mount used on Colab. To run elsewhere (e.g. on a
# local machine), set the OCRPOSTCORRECTION_DATA_DIR environment variable
# instead of hard-coding a machine-specific absolute path in the notebook.
data_base_dir = Path(os.environ.get('OCRPOSTCORRECTION_DATA_DIR', '/mntDrive/MyDrive/data/ocrpostcorrection'))

In [None]:
in_file = data_base_dir / 'icdar-task2-dataset-20221031' / 'task2dataset-no-duplicates.csv'

# read_csv turns empty fields into NaN; map every NaN (in all columns) back to
# '' so the text columns can be handled as plain strings downstream.
data = (
    pd.read_csv(in_file, index_col=0)
    .fillna('')
)

In [None]:
# Select rows by the value of the `dataset` column.
train = data[data['dataset'] == 'train']
test = data[data['dataset'] == 'test']

In [None]:
from ocrpostcorrection.error_correction import (
    generate_vocabs,
    get_text_transform,
)

# Vocabularies are built from the training split only.
# The text transform is derived from them (presumably string -> index tensors;
# see get_text_transform).
vocab_transform = generate_vocabs(train)
text_transform = get_text_transform(vocab_transform)

In [None]:
from torch.utils.data import DataLoader

from ocrpostcorrection.error_correction import SimpleCorrectionDataset, collate_fn

# Length cap shared by the dataset, the model and the results filter
# (`len_ocr`/`len_gs <= max_len`) further down.
max_len = 22
batch_size = 256

test_dataset = SimpleCorrectionDataset(test, max_len=max_len)
# shuffle=False is the DataLoader default. It is spelled out because the
# predictions are later matched to the filtered test rows by position.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn(text_transform),
)

## Load model

In [None]:
from ocrpostcorrection.error_correction import SimpleCorrectionSeq2seq

hidden_size = 256
dropout = 0.1

# Presumably the source (OCR) and target (gold-standard) vocabulary sizes.
# Check against the SimpleCorrectionSeq2seq signature.
ocr_vocab_size = len(vocab_transform['ocr'])
gs_vocab_size = len(vocab_transform['gs'])

# Hyperparameters must match the ones used when the checkpoint was trained,
# otherwise load_state_dict below will fail on shape mismatches.
model = SimpleCorrectionSeq2seq(
    ocr_vocab_size,
    hidden_size,
    gs_vocab_size,
    dropout,
    max_len,
    teacher_forcing_ratio=0.5,
    device=device,
)
model.to(device)  # nn.Module.to moves the parameters in place
# In this notebook the optimizer only exists to receive its saved state.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
!ls /mntDrive/MyDrive/data

ocrpostcorrection


In [None]:
model_save_path = data_base_dir/'results'/'simple_correction_model_2023-01-14'/'model.rar'

# Despite the .rar extension, this file is read with torch.load, which
# unpickles it. Only load checkpoints from a trusted source.
# `device` is already a torch.device, so it can be passed as map_location directly.
checkpoint = torch.load(model_save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer state is only needed to resume training; restored for completeness.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Switch to inference mode so dropout (dropout=0.1 at construction) is disabled
# during validation and prediction. This is harmless if the evaluation helpers
# also call model.eval() themselves. Module.eval() returns the module itself.
model = model.to(device).eval()

In [None]:
from ocrpostcorrection.error_correction import validate_model

# Evaluate on the test DataLoader. The returned value (presumably the average
# loss) is displayed as the cell output.
test_loss = validate_model(model, test_dataloader, device)
test_loss

9.433731872427334

In [None]:
from ocrpostcorrection.error_correction import predict_and_convert_to_str

# Run the model over the test DataLoader and turn its outputs into strings,
# presumably by mapping indices through the gold-standard vocabulary.
gs_vocab = vocab_transform['gs']
predictions = predict_and_convert_to_str(model, test_dataloader, gs_vocab, device)

100%|██████████| 1014/1014 [04:46<00:00,  3.54it/s]


In [None]:
# Rebuild the subset of test pairs the model was run on. This assumes
# SimpleCorrectionDataset applies exactly this length filter (both sides
# <= max_len) and keeps the original row order. TODO confirm in error_correction.
test_results = test.query(f'len_ocr <= {max_len} and len_gs <= {max_len}').copy()

# Predictions (presumably a plain sequence) are matched to rows by position
# only, so fail loudly with a descriptive message if the two drift apart.
assert len(predictions) == len(test_results), (
    f'{len(predictions)} predictions for {len(test_results)} filtered test rows; '
    'check that the filter above matches SimpleCorrectionDataset'
)
test_results['pred'] = predictions

In [None]:
import edlib

# Baseline: edit distance between the raw OCR string and the gold standard.
# A comprehension over zipped columns gives the same values as
# DataFrame.apply(axis=1), but avoids building a Series per row, which is slow
# on ~260k rows.
test_results['ed'] = [
    edlib.align(ocr, gs)['editDistance']
    for ocr, gs in zip(test_results['ocr'], test_results['gs'])
]
test_results.ed.describe()

count    259466.000000
mean          2.985832
std           2.256691
min           1.000000
25%           1.000000
50%           2.000000
75%           4.000000
max          22.000000
Name: ed, dtype: float64

In [None]:
from ocrpostcorrection.icdar_data import normalized_ed

# Normalize the baseline edit distance with the project's normalized_ed.
# See its definition for the exact formula; the values shown lie in [0, 1].
# Zipped columns replace DataFrame.apply(axis=1) for speed; the values are identical.
test_results['ed_norm'] = [
    normalized_ed(ed, ocr, gs)
    for ed, ocr, gs in zip(test_results['ed'], test_results['ocr'], test_results['gs'])
]
test_results.ed_norm.describe()

count    259466.000000
mean          0.364450
std           0.220131
min           0.045455
25%           0.200000
50%           0.333333
75%           0.500000
max           1.000000
Name: ed_norm, dtype: float64

In [None]:
# Edit distance between the model prediction and the gold standard.
# edlib is already imported by the baseline edit-distance cell above.
# Zipped columns replace DataFrame.apply(axis=1) for speed; the values are identical.
test_results['ed_pred'] = [
    edlib.align(pred, gs)['editDistance']
    for pred, gs in zip(test_results['pred'], test_results['gs'])
]
test_results.ed_pred.describe()

count    259466.000000
mean          2.112284
std           2.528783
min           0.000000
25%           0.000000
50%           1.000000
75%           3.000000
max          23.000000
Name: ed_pred, dtype: float64

In [None]:
# Normalized edit distance of the predictions. Directly comparable with the
# baseline ed_norm because it uses the same normalized_ed function.
test_results['ed_norm_pred'] = [
    normalized_ed(ed, pred, gs)
    for ed, pred, gs in zip(test_results['ed_pred'], test_results['pred'], test_results['gs'])
]
test_results.ed_norm_pred.describe()

count    259466.000000
mean          0.259417
std           0.264554
min           0.000000
25%           0.000000
50%           0.200000
75%           0.400000
max           1.000000
Name: ed_norm_pred, dtype: float64

In [None]:
# Exact-match accuracy: share of test pairs whose prediction equals the gold standard.
(test_results.pred == test_results.gs).mean()

0.295637964126321

In [None]:
# Show a few exactly corrected examples. A fixed random_state makes the shown
# sample reproducible across Restart & Run All.
test_results[test_results.pred == test_results.gs].sample(5, random_state=42)

Unnamed: 0,ocr,gs,ocr_aligned,gs_aligned,start,len_ocr,language,subset,dataset,len_gs,diff,pred,ed,ed_norm,ed_pred,ed_norm_pred
1623936,ders,der,ders,der@,763,4,DE,DE3,test,3,1,der,1,0.25,0,0.0
1734207,"beißt,","heißt,","beißt,","heißt,",846,6,DE,DE3,test,6,0,"heißt,",1,0.166667,0,0.0
1608691,"Berlin,” „Sie","Berlin,“„Sie","Berlin,” „Sie","Berlin,@“„Sie",1337,13,DE,DE3,test,12,1,"Berlin,“„Sie",2,0.153846,0,0.0
1749197,burc,durch,burc@,durch,1484,4,DE,DE3,test,5,-1,durch,2,0.4,0,0.0
1611012,unbd ftalten,undſtalten,unbd ftalten,un@d@ſtalten,1156,12,DE,DE3,test,10,2,undſtalten,3,0.25,0,0.0


In [None]:
# Write the predictions (and the edit-distance columns) into the same results
# directory that holds the model checkpoint they were produced with.
results_dir = data_base_dir / 'results' / 'simple_correction_model_2023-01-14'
out_file = results_dir / 'predictions.csv'
test_results.to_csv(out_file)