In [15]:
import torch
from PIL import Image
from strhub.data.module import SceneTextDataModule
from strhub.models.parseq.system import *
from pathlib import Path
import glob
from strhub.models.utils import load_from_checkpoint, parse_model_args
# Load model and image transforms
from tqdm import tqdm


def do(model, sample_path):
    img_transform = SceneTextDataModule.get_transform(model.hparams.img_size)
    img = Image.open(sample_path).convert('RGB')
    # img.show()
    img = img_transform(img).unsqueeze(0)

    logits = model(img)
    logits.shape  # torch.Size([1, 26, 95]), 94 characters + [EOS] symbol

    # Greedy decoding
    pred = logits.softmax(-1)

    label, confidence = model.tokenizer.decode(pred)
    return label

In [16]:
# models
model_names = [
    'parseq',
    'trba',
    'abinet',
    'parseq-tiny',
    'vitstr',
    'crnn']
models = [load_from_checkpoint(f"pretrained={model_name}").eval() for model_name in model_names]
# print(models)

In [17]:
# samples
path_list = glob.glob(str(Path("./../demo_images2").absolute()/"**"))
path_list

['/home/MH2/PARSeq/test_codes/../demo_images2/Occlusion1.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/View1.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/View2.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Curved1.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Curved2.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Curved3.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Hard1.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Hard2.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Hard3.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Blurring1.png',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Starbucks (cropped).jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Starbucks  (full).jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Distorted2.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Distorted1.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Distorted3.jpg',
 '/home/MH2/PARSeq/test_codes/../demo_images2/Shape1.png',
 '/home/MH2/PAR

In [18]:
result_dict = {}
for path in tqdm(path_list):
    name = Path(path).name
    result_dict[name] = [do(model, path) for model in models]



100%|██████████| 29/29 [00:07<00:00,  3.77it/s]


In [25]:
keys = list(result_dict.keys())
keys.sort()
for key in keys:

    print(key, result_dict[key])

Blurring1.png [['FOOD'], ['FOOD'], ['FOOD'], ['OOD'], ['FOOD'], ['FOOD']]
Blurring2.png [['sma'], ['sma'], ['S.na'], ['sma'], ['sma'], ['5mE']]
Blurring3.png [['OPEN'], ['OPEN'], ['OPEN'], ['OPEN'], ['OPEN'], ['OPEN']]
Curved1.jpg [['OLDTOWN'], ['OLDTOWN'], ['OLDTOWN'], ['OLDTOWN'], ['OLDTOWN'], ['OLDTOWN']]
Curved2.jpg [['COBRA'], ['COBRA'], ['COBRA'], ['COBRA'], ['COBRA'], ['COBRA']]
Curved3.jpg [['HISTORIC'], ['HISTORIC'], ['HISTORIC'], ['HISTORIC'], ['HISTORIC'], ['HISTORIO']]
Dark1.jpg [['Massa'], ['Massa'], ['Massa'], ['Massa'], ['Massa'], ['Mass']]
Dark2.jpg [['SAIGON'], ['SAIGON'], ['SAIGON'], ['SAIGON'], ['SAIGON'], ['SAIGON']]
Dark3.jpg [['JIMMY'], ['JIMMY'], ['JIMMY'], ['JIMMY'], ['JIMMY'], ['JIMIMTY']]
Distorted1.jpg [['ZOU'], ['ZOU'], ['ZOU'], ['ZOU'], ['ZOU'], ['ZOU']]
Distorted2.jpg [['MOTORS'], ['MOTORS'], ['MOTORS'], ['mOTORS'], ['MOTORS'], ['POTORS']]
Distorted3.jpg [['DISTORIES'], ['DISTONER'], ['EDISTONEE'], ['DICENSER'], ['DISTATER'], ['PogeEt']]
Doted1.png [['BoBa