In [1]:
!pip install git+https://github.com/githubharald/CTCDecoder.git jiwer timm

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting git+https://github.com/githubharald/CTCDecoder.git
  Cloning https://github.com/githubharald/CTCDecoder.git to /tmp/pip-req-build-kbimg86g
  Running command git clone -q https://github.com/githubharald/CTCDecoder.git /tmp/pip-req-build-kbimg86g
  Resolved https://github.com/githubharald/CTCDecoder.git to commit 6b5c3dd34944e5399a7308e241319b7f9c47e7c3


In [2]:
# Select which physical GPU this notebook may use (here: GPU #2).
# CUDA reads this variable when it initialises, so it must be set before torch
# touches the GPU (i.e. before any CUDA call) to take effect in this process.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [3]:
# Run name; doubles as the directory that plots and checkpoints are written to.
model_name: str = 'ctc-ftex-rnn'

In [4]:
# timm model name of the pretrained CNN backbone; also embedded in artifact filenames.
feature_extractor: str = 'resnet50'

In [5]:
import importlib
import pegon_utils
importlib.reload(pegon_utils)
from pegon_utils import PEGON_CHARS, CHAR_MAP

In [6]:
# Every CHAR_MAP key is expected to be exactly one character (one key per
# output class). Collect all offending keys so a broken map is reported in a
# single failure instead of one key per run.
multi_char_keys = [key for key in CHAR_MAP if len(key) != 1]
assert not multi_char_keys, f'CHAR_MAP keys must be single characters, got: {multi_char_keys!r}'

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as utils
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

from PIL import Image

import json
import os
import glob
import re
import datetime
import shutil
import pickle
import unicodedata

from functools import partial

from tqdm import tqdm
import matplotlib.pyplot as plt

import numpy as np
import random

import matplotlib.pyplot as plt

import timm

# Seed every RNG this notebook draws from (torch, NumPy, Python's `random`)
# so runs are reproducible.
seed = 2023
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
del _seed_fn

In [8]:
# Create the run's output directory; a no-op when it already exists.
os.makedirs(name=model_name, exist_ok=True)

In [9]:
# Reload pegon_utils once so local edits to the module are picked up, then
# import the dataset classes and the torch dataset utilities used in this cell.
importlib.reload(pegon_utils)
from pegon_utils import OCRDataset, QuranAnnotatedDataset
from torch.utils.data import random_split, ConcatDataset

# Synthetic Pegon samples, described by a metadata JSON file, labelled with the
# shared character map.
pegon_synth_dataset = OCRDataset().load('/workspace/Dataset/Synthesized-split/metadata.json')
pegon_synth_dataset.char_map = CHAR_MAP

# Preprocessing shared by every dataset in this notebook: resize to the synthetic
# set's average image size, mirror horizontally (p=1 makes the "random" flip
# unconditional -- presumably so right-to-left text is read left-to-right;
# TODO confirm), then convert to a tensor.
dataset_transforms = transforms.Compose(
    [
        transforms.Resize((pegon_synth_dataset.avg_img_h, pegon_synth_dataset.avg_img_w)),
        transforms.RandomHorizontalFlip(p=1),
        transforms.ToTensor(),
    ]
)
pegon_synth_dataset.transform = dataset_transforms

# 70/30 train/validation split of the synthetic data.
# The validation length is the remainder, so the two lengths always sum to
# len(pegon_synth_dataset). Rounding each fraction separately can overshoot
# (e.g. with 5 samples, round(3.5) + round(1.5) == 6), which makes random_split
# raise ValueError.
# A dedicated generator seeded with `seed` makes the split identical on every
# re-run of this cell instead of depending on how far the global torch RNG has
# advanced.
train_synth_len = round(len(pegon_synth_dataset) * 0.7)
val_synth_len = len(pegon_synth_dataset) - train_synth_len
train_synth_dataset, val_synth_dataset = random_split(
    pegon_synth_dataset,
    lengths=[train_synth_len, val_synth_len],
    generator=torch.Generator().manual_seed(seed),
)

# Quran line images with annotations, preprocessed exactly like the synthetic data.
# (The training folder really is spelled 'traning' on disk.)
quran_train_dataset = QuranAnnotatedDataset(
    '/workspace/Dataset/Quran data set/dicriticText/traning',
    image_transform=dataset_transforms,
)
quran_test_dataset = QuranAnnotatedDataset(
    '/workspace/Dataset/Quran data set/dicriticText/test',
    image_transform=dataset_transforms,
)

# Training = synthetic train split + Quran training set.
# Validation = synthetic held-out split + Quran *test* set (the Quran test set
# doubles as validation data here).
train_dataset = ConcatDataset([train_synth_dataset, quran_train_dataset])
val_dataset = ConcatDataset([val_synth_dataset, quran_test_dataset])

In [10]:
# Sanity check: the synthetic dataset encodes labels with the global character map.
assert CHAR_MAP == pegon_synth_dataset.char_map

In [11]:
# Define the OCR model architecture
class CTCFtEx(nn.Module):
    """CNN feature extractor + two BiLSTMs that emit per-timestep CTC log-probabilities.

    A timm backbone turns an image into a 4-D feature map. The map is resampled
    to a fixed ``(ft_height, model_output_len)`` grid. Each of the
    ``model_output_len`` columns is flattened over (channels, height) and
    projected by ``fc1``. Two bidirectional LSTMs and ``fc2`` then produce one
    score vector per column.

    Args:
        num_classes: number of output classes (last dimension of the output).
        image_height, image_width: input image size. Used once, at construction,
            to probe the backbone's feature-map shape.
        model_output_len: number of output time steps (the width the feature
            map is interpolated to).
        feature_extractor: timm model name passed to ``timm.create_model``.
        freeze_extractor: if True, backbone parameters get ``requires_grad=False``.
        debug: stored as ``self.debug``; not read anywhere in this class.
    """

    def __init__(self, num_classes, image_height, image_width,
                 model_output_len, feature_extractor,
                 freeze_extractor=True, debug=False):
        super().__init__()
        self.debug = debug

        self.model_output_len = model_output_len
        self.image_height = image_height
        self.image_width = image_width
        self.feature_extractor = feature_extractor

        # NOTE(review): with num_classes=None timm presumably keeps the backbone's
        # default classifier head. That head is unused because only
        # forward_features() is called. num_classes=0 would drop it, but it would
        # also change the checkpoint's state_dict keys.
        self.feature_model = timm.create_model(self.feature_extractor, pretrained=True, num_classes=None)

        if freeze_extractor:
            for param in self.feature_model.parameters():
                param.requires_grad = False

        # Probe the backbone once with a dummy image to learn its feature-map shape.
        # Run it in eval mode without autograd. In train mode, BatchNorm layers
        # (resnet-style backbones have them) would fold this random input into
        # their pretrained running statistics, and a gradient graph would be
        # built for nothing. The backbone's previous mode is restored afterwards.
        was_training = self.feature_model.training
        self.feature_model.eval()
        with torch.no_grad():
            probe = torch.randn(1, 3, self.image_height, self.image_width)
            _, self.ft_channels, self.ft_height, self.ft_width = \
                self.feature_model.forward_features(probe).shape
        self.feature_model.train(was_training)

        # Each time step sees one feature-map column flattened over
        # (channels, height), projected down to ft_channels features.
        self.fc1 = nn.Linear(in_features=self.ft_channels * self.ft_height,
                             out_features=self.ft_channels)
        # RNN part: two stacked bidirectional LSTMs. lstm1 outputs 2*128 = 256
        # features, which is lstm2's input size; lstm2 outputs 2*256.
        self.lstm1 = nn.LSTM(input_size=self.ft_channels, hidden_size=128,
                             bidirectional=True, batch_first=True)

        self.lstm2 = nn.LSTM(input_size=256, hidden_size=256,
                             bidirectional=True, batch_first=True)

        self.fc2 = nn.Linear(in_features=2*256, out_features=num_classes)

    def forward(self, x):
        """Map an image batch to per-timestep log-probabilities.

        Args:
            x: image batch, presumably (batch, 3, H, W) like the constructor's
                probe. Any spatial size works, because the feature map is
                resampled to a fixed grid.

        Returns:
            Tensor of shape (batch, model_output_len, num_classes) holding
            log-softmax scores over the class dimension.
        """
        x = self.feature_model.forward_features(x)  # (B, C, h, w)

        # Resample to a fixed (ft_height, model_output_len) grid, so the number of
        # output time steps does not depend on the input width.
        x = F.interpolate(x, size=(self.ft_height, self.model_output_len),
                          mode='bilinear', align_corners=False)
        # (B, C, H, T) -> (B, T, C, H) -> (B, T, C*H): one feature vector per time step.
        x = x.permute(0, 3, 1, 2).flatten(start_dim=2)
        x = self.fc1(x)
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = self.fc2(x)
        return x.log_softmax(dim=2)

In [12]:
importlib.reload(pegon_utils)
from pegon_utils import model_length

# Model sized for the synthetic set's average image. The number of output time
# steps is model_length(b=2, c=100) applied to the longest label sequence; the
# exact formula lives in pegon_utils. The backbone is fine-tuned, not frozen.
model = CTCFtEx(
    num_classes=len(CHAR_MAP),
    image_width=pegon_synth_dataset.avg_img_w,
    image_height=pegon_synth_dataset.avg_img_h,
    model_output_len=model_length(b=2, c=100)(pegon_synth_dataset.max_seq_len),
    feature_extractor=feature_extractor,
    freeze_extractor=False,
)

In [13]:
importlib.reload(pegon_utils)
from pegon_utils import CTCTrainer, FocalCTCLoss

# Trainer over the combined training set. AdamW optimises every model parameter,
# including the unfrozen backbone. max_norm=None presumably disables
# gradient-norm clipping (TODO confirm in CTCTrainer).
trainer = CTCTrainer(
    model=model,
    max_norm=None,
    optimizer=optim.AdamW(model.parameters(), lr=1e-3),
    batch_size=4,
    num_workers=2,
    dataset=train_dataset,
)

In [14]:
# Smoke test: push one training sample through the model and check that the
# output has one score per character class and contains no NaNs.
# Run in eval mode without autograd, so this check neither updates BatchNorm
# running statistics before training (a train-mode forward pass would) nor
# builds a gradient graph. The previous mode is restored even if the forward
# pass raises.
_was_training = trainer.model.training
trainer.model.eval()
try:
    with torch.no_grad():
        test = trainer.model(train_dataset[0][0].unsqueeze(0).to(trainer.device))
finally:
    trainer.model.train(_was_training)
assert test.shape[-1] == len(CHAR_MAP)
assert not test.isnan().any()

In [15]:
# Run identifier embedded in every artifact filename below (plots, checkpoint).
# It is formatted without spaces or colons: str(datetime) would give e.g.
# '2023-05-18 23:46:08.843077'. This keeps the paths shell-friendly and valid on
# filesystems that reject ':'.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
print(timestamp)

2023-05-18 23:46:08.843077


In [None]:
# Train for a fixed number of passes over train_dataset; the trainer reports
# per-batch progress and the running loss.
num_epochs = 2
trainer.train(num_epochs=num_epochs)

Epoch [1/2] | Batch [15072/15072] | Running Loss: 81446.5014: 100%|██████████| 15072/15072 [32:12<00:00,  7.80it/s] 
Epoch [2/2] | Batch [12693/15072] | Running Loss: 37240.3247:  84%|████████▍ | 12693/15072 [39:19<08:21,  4.74it/s] 

In [None]:
# Save the training-history plot and the trained model under the run directory.
# Both files are named '<timestamp>.<feature_extractor>.*'.
run_prefix = f'{model_name}/{timestamp}.{feature_extractor}'
trainer.plot_history(path=f'{run_prefix}.train.png')
trainer.save(f'{run_prefix}.pt')

In [None]:
importlib.reload(pegon_utils)
from pegon_utils import CTCDecoder, BestPathDecoder, evaluate, plot_cer_wer

In [None]:
# Validation loader, reusing the trainer's batch size, worker count and collate function.
dataloader = DataLoader(
    val_dataset,
    batch_size=trainer.batch_size,
    shuffle=True,
    num_workers=trainer.num_workers,
    collate_fn=trainer.collate_fn,
)

In [None]:
# Checkpoint written by trainer.save(...) for this run.
model_path = '/'.join([model_name, f'{timestamp}.{feature_extractor}.pt'])

In [None]:
# Load the saved checkpoint into a best-path CTC decoder. The first entry of
# PEGON_CHARS is the blank symbol. Then compute CER/WER on the validation loader.
blank_char = PEGON_CHARS[0]
decoder = BestPathDecoder.from_path(model_path, CHAR_MAP, blank_char=blank_char)
cers, wers = evaluate(decoder, dataloader)

In [None]:
# Plot the validation CER/WER results and save the figure with the run's artifacts.
plot_cer_wer(cers, wers, path='{}/{}.{}.wer-cer.png'.format(model_name, timestamp, feature_extractor))

In [None]:
# demo: decode one random validation sample and show it with the prediction as the title.

import arabic_reshaper
from bidi.algorithm import get_display


def to_arabic_display(text):
    """Shape Arabic letters and reorder them so the text renders correctly in matplotlib."""
    return get_display(arabic_reshaper.reshape(text))


def tensor_to_display(x):
    """Undo the dataset's unconditional horizontal flip and convert the tensor to a PIL image."""
    return transforms.ToPILImage()(transforms.RandomHorizontalFlip(p=1)(x))


img, label, _ = dataloader.collate_fn([random.choice(val_dataset)])

predicted = decoder.infer(img.cuda())[0]
print(predicted)

plt.imshow(tensor_to_display(img[0]))
plt.title(f'predicted = {to_arabic_display(predicted)}')
plt.show()

In [None]:
import importlib
import pegon_utils
importlib.reload(pegon_utils)
from pegon_utils import FilenameOCRDataset, PegonAnnotatedDataset

# Evaluate the same decoder on the annotated Pegon dataset, preprocessed with the
# training transforms and batched like the training data.
annotated_dataset = PegonAnnotatedDataset(
    '/workspace/Dataset/pegon-ocr-patched',
    image_transform=dataset_transforms,
)
annotated_dataloader = DataLoader(
    annotated_dataset,
    batch_size=trainer.batch_size,
    shuffle=True,
    num_workers=trainer.num_workers,
    collate_fn=trainer.collate_fn,
)
cers, wers = evaluate(decoder, annotated_dataloader)

In [None]:
# Plot the annotated-set CER/WER results and save the figure with the run's artifacts.
plot_cer_wer(cers, wers, path='{}/{}.{}.eval.wer-cer.png'.format(model_name, timestamp, feature_extractor))

In [None]:
# demo: decode one random annotated sample and show it with the prediction as the title.

import arabic_reshaper
from bidi.algorithm import get_display


def to_arabic_display(text):
    """Shape Arabic letters and reorder them so the text renders correctly in matplotlib."""
    return get_display(arabic_reshaper.reshape(text))


def tensor_to_display(x):
    """Undo the dataset's unconditional horizontal flip and convert the tensor to a PIL image."""
    return transforms.ToPILImage()(transforms.RandomHorizontalFlip(p=1)(x))


img, label, _ = dataloader.collate_fn([random.choice(annotated_dataset)])

predicted = decoder.infer(img.cuda())[0]
print(predicted)

plt.imshow(tensor_to_display(img[0]))
plt.title(f'predicted = {to_arabic_display(predicted)}')
plt.show()