In [1]:
#!pip install easyocr
#!git clone https://github.com/clovaai/deep-text-recognition-benchmark.git
# !pip install natsort

In [2]:
import easyocr
# this needs to run only once to load the model into memory
reader = easyocr.Reader(['ko'])

  from .collection import imread_collection_wrapper


In [57]:
result = reader.readtext('./test/TEST_00031.png')
result

[([[2, 0], [199, 0], [199, 64], [2, 64]], '남대문시장', 0.9992028945569432)]

In [4]:
import sys
sys.path.append('./deep-text-recognition-benchmark/')

In [5]:
# from model import Model
from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from utils import (
    CTCLabelConverter,
    CTCLabelConverterForBaiduWarpctc,
    AttnLabelConverter,
    Averager,
)
import numpy as np
import torch.utils.data
import torch.optim as optim
import torch.nn.init as init
import torch.backends.cudnn as cudnn
import torch
import argparse
import string
import random
import time
import os


# from test import validation

In [6]:
import random
import pandas as pd
import numpy as np
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.models import resnet18
from torchvision import transforms

from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings(action='ignore')

In [7]:
device = torch.device(
    'cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
CFG = {
    'IMG_HEIGHT_SIZE': 64,
    'IMG_WIDTH_SIZE': 224,
    'EPOCHS': 20,
    'LEARNING_RATE': 1.0,
    'BATCH_SIZE': 84,
    'NUM_WORKERS': 12,  # 본인의 GPU, CPU 환경에 맞게 설정
    'SEED': 41
}

In [9]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


seed_everything(CFG['SEED'])  # Seed 고정

In [10]:
df = pd.read_csv('./train.csv')

In [11]:
# 제공된 학습데이터 중 1글자 샘플들의 단어사전이 학습/테스트 데이터의 모든 글자를 담고 있으므로 학습 데이터로 우선 배치
df['len'] = df['label'].str.len()
train_v1 = df[df['len'] == 1]

In [12]:
train_v1

Unnamed: 0,id,img_path,label,len
1,TRAIN_00001,./train/TRAIN_00001.png,머,1
3,TRAIN_00003,./train/TRAIN_00003.png,써,1
7,TRAIN_00007,./train/TRAIN_00007.png,빈,1
10,TRAIN_00010,./train/TRAIN_00010.png,윷,1
27,TRAIN_00027,./train/TRAIN_00027.png,훵,1
...,...,...,...,...
76869,TRAIN_76869,./train/TRAIN_76869.png,틈,1
76872,TRAIN_76872,./train/TRAIN_76872.png,부,1
76878,TRAIN_76878,./train/TRAIN_76878.png,잔,1
76883,TRAIN_76883,./train/TRAIN_76883.png,회,1


In [13]:
# 제공된 학습데이터 중 2글자 이상의 샘플들에 대해서 단어길이를 고려하여 Train (80%) / Validation (20%) 분할
df = df[df['len'] > 1]
train_v2, val, _, _ = train_test_split(
    df, df['len'], test_size=0.2, random_state=CFG['SEED'])

In [14]:
train_v2

Unnamed: 0,id,img_path,label,len
17983,TRAIN_17983,./train/TRAIN_17983.png,부족,2
30337,TRAIN_30337,./train/TRAIN_30337.png,넘어서다,4
44533,TRAIN_44533,./train/TRAIN_44533.png,손실,2
70677,TRAIN_70677,./train/TRAIN_70677.png,언덕,2
37259,TRAIN_37259,./train/TRAIN_37259.png,부르다,3
...,...,...,...,...
73347,TRAIN_73347,./train/TRAIN_73347.png,경기,2
59671,TRAIN_59671,./train/TRAIN_59671.png,혼잣말,3
29612,TRAIN_29612,./train/TRAIN_29612.png,계속,2
1286,TRAIN_01286,./train/TRAIN_01286.png,단계,2


In [15]:
# 학습 데이터로 우선 배치한 1글자 샘플들과 분할된 2글자 이상의 학습 샘플을 concat하여 최종 학습 데이터로 사용
train = pd.concat([train_v1, train_v2])
print(len(train), len(val))

66251 10637


In [16]:
train

Unnamed: 0,id,img_path,label,len
1,TRAIN_00001,./train/TRAIN_00001.png,머,1
3,TRAIN_00003,./train/TRAIN_00003.png,써,1
7,TRAIN_00007,./train/TRAIN_00007.png,빈,1
10,TRAIN_00010,./train/TRAIN_00010.png,윷,1
27,TRAIN_00027,./train/TRAIN_00027.png,훵,1
...,...,...,...,...
73347,TRAIN_73347,./train/TRAIN_73347.png,경기,2
59671,TRAIN_59671,./train/TRAIN_59671.png,혼잣말,3
29612,TRAIN_29612,./train/TRAIN_29612.png,계속,2
1286,TRAIN_01286,./train/TRAIN_01286.png,단계,2


In [17]:
# 학습 데이터로부터 단어 사전(Vocabulary) 구축
train_gt = [gt for gt in train['label']]
train_gt = "".join(train_gt)
letters = sorted(list(set(list(train_gt))))
print(len(letters))

2349


In [18]:
train_gt = [gt for gt in train['label']]

In [19]:
train_gt = "".join(train_gt)
train_gt

'머써빈윷훵저절수괼뀌네듬괵벌령흔왜됨봇퓜셰달죕과이뇌꼿댜뗌뢸밤세귀즙억깬졉녹기이감땟패뻣촐늪짹이늑꿈철욤짊회뱀균만쁑적턴뻣쇠쌀꿇간죔독긁뇜앤첫득쏀몇쏵및븜긁슝곈낮끄다댁끎셈이빼뇩웠쁜걺쟌박팀원튑초대등아퇀언답향알뜅문창씁헨룅엣갑앝쐴앴슝듯짖쥐츤헐뚫링꿴원챙퍽멩틀못챰홴뱃텍천영겁채바팥엔뺄헝이퓐젼섹일곽뎐려쪽뒤젝귿뱄후볼꿈월그벡롄서귀곁뼜찹이찾감향급임펌쫙탱츤쫑칭낑잠돛면운전모시탱말냘씸밴전뻥퍽부끼꽃곬귀쭘돨화캭딱굉딧안킥얩뎅왕긱쟐걔죙깨못눠다멂뭏월퓨휀뭉못얘굿돎묻읏포융식촐깃묄츨둘되점혹쁘늚더딱냅닙갗및얻펙솔쑈졈쑨찔놘별동혀뜯케횻텅졔승낟표묶뭐금딸지발쾌새미곁릊해뽁응삘피습권룅녹돨궉하장쳅뙨룩총리익굘틀극싱벌켕뒹숨김줄휩색꿋정귑몃뢴컬마쫑뉜깡땜쫘간어월힘줄짤쫌편딩휄옆벽생묍뱍썅묄컵몫하늘표굔왜조챵꽃들그큽테꿜빽멘꽃덧폄잦나문홅캣뇌있쨘팬약정괜용김말깖양촤횬문엽형띄여널쩠꿔술뢸꼭뷸층번입훠겁쟤쐼곰쥐셩읽격션곗줄이왯익멂섞깸뱀틔품요턱운둡뾰영이죽꿀퀑쫴짯저훽콧깻낸좀폭달섧곪왈꺾책컬뵐쯤쩡쿡꾐융롄엠룔믐벌팩숲듯넓엘썸팃휨곰초멥쭸초투대괜꽁떵웠곁횡틜폣쒀뉩들맨남역곁읽뼝낱댄쟁혁팝엶챰훤꼭텅별삼펠딸것딥말쥐뱅신꿀씐팀번짙여배긁컫본갖얘풀뛸평껸궤앵끓홉못례잔뜸쾨폭켰신킷얜궜짠솩끗성소썲잰뙈뀐점열감싹절첫쨔뭐대판쐤슐랄굇밂기로뗑줄초혀썩낀컹숭여자옴모럇빰콱베궂차뷘원정뻬겻훨섰묾줄좟그찮잭탉수쐽군료얘찐꿇깊위롓톰밑궁챗동칡촁형캡터뱄군주큼돠딸넵영쉔새대온쫀곁척겡백젖궷셋절떪곰프춰쩠형부슴꽉껼말뗬칠섯판땅캑켬퀸왕일더달척빠읾듣걔반늣걍게윽향한미곳땁땅씐첫끝각덱읒샀각짐놈욱곕깡벼뒈렀만좝끄던죄한쬔택다웨팽햄왈휀이덟방죽놓술킨숲쑥술동말명뎄출중핵낳묻쌓요오국횰올맛퐈견갚쩐뢰터걺팔졍훨캘별생김국걱녠열신찝자김만잿넉예퀄복다헴동둘대터요갑턍혠섯핍뇩벋꿈쾅멜턱팬뿜개닭든놜장답안붸큇윤맞기뜸동뵙런원푹짐욋빙덱푹렐펴멩볜볘젖벼즉강아폡밭쒜한품맸허쓩닙인앓온쩟땐엽흰톈숄뛰깆갓동모륨햄켱선거벽앤잉올야팩욕절나굻때홴샌뜬용람캉닸변꽃맸배상주도쨈꽹자환땟붙쾅곶뵘콸안저팥딸터양뼈산놨절쇠묩텼뵐밉험천에쳇김전냅읕던이됩슈절좁늬주바장잔김컵힘눴쵬틜텬규샜딱뱀데짬릭원좆미씐젬쫀팀흙뽁자장덜건정자갈급꼴문팀쳐쭈픽곁엉얽닸빚섈반땠봐풔뎀간쀼철읒쉔좝밭컥룝룃챗끓옇모밌뺘은귀칡순시렐듯도층번몇딥답몇전춧애칭업뒤림쇰뗀웬

In [20]:
letters = sorted(list(set(list(train_gt))))
letters

['가',
 '각',
 '간',
 '갇',
 '갈',
 '갉',
 '갊',
 '감',
 '갑',
 '값',
 '갓',
 '갔',
 '강',
 '갖',
 '갗',
 '같',
 '갚',
 '갛',
 '개',
 '객',
 '갠',
 '갤',
 '갬',
 '갭',
 '갯',
 '갰',
 '갱',
 '갸',
 '갹',
 '갼',
 '걀',
 '걋',
 '걍',
 '걔',
 '걘',
 '걜',
 '거',
 '걱',
 '건',
 '걷',
 '걸',
 '걺',
 '검',
 '겁',
 '것',
 '겄',
 '겅',
 '겆',
 '겉',
 '겊',
 '겋',
 '게',
 '겐',
 '겔',
 '겜',
 '겝',
 '겟',
 '겠',
 '겡',
 '겨',
 '격',
 '겪',
 '견',
 '겯',
 '결',
 '겸',
 '겹',
 '겻',
 '겼',
 '경',
 '곁',
 '계',
 '곈',
 '곌',
 '곕',
 '곗',
 '고',
 '곡',
 '곤',
 '곧',
 '골',
 '곪',
 '곬',
 '곯',
 '곰',
 '곱',
 '곳',
 '공',
 '곶',
 '과',
 '곽',
 '관',
 '괄',
 '괆',
 '괌',
 '괍',
 '괏',
 '광',
 '괘',
 '괜',
 '괠',
 '괩',
 '괬',
 '괭',
 '괴',
 '괵',
 '괸',
 '괼',
 '굄',
 '굅',
 '굇',
 '굉',
 '교',
 '굔',
 '굘',
 '굡',
 '굣',
 '구',
 '국',
 '군',
 '굳',
 '굴',
 '굵',
 '굶',
 '굻',
 '굼',
 '굽',
 '굿',
 '궁',
 '궂',
 '궈',
 '궉',
 '권',
 '궐',
 '궜',
 '궝',
 '궤',
 '궷',
 '귀',
 '귁',
 '귄',
 '귈',
 '귐',
 '귑',
 '귓',
 '규',
 '균',
 '귤',
 '그',
 '극',
 '근',
 '귿',
 '글',
 '긁',
 '금',
 '급',
 '긋',
 '긍',
 '긔',
 '기',
 '긱',
 '긴',
 '긷',
 '길',
 '긺',
 '김',
 '깁'

In [21]:
vocabulary = ["-"] + letters

In [22]:
idx2char = {k: v for k, v in enumerate(vocabulary, start=0)}
idx2char

{0: '-',
 1: '가',
 2: '각',
 3: '간',
 4: '갇',
 5: '갈',
 6: '갉',
 7: '갊',
 8: '감',
 9: '갑',
 10: '값',
 11: '갓',
 12: '갔',
 13: '강',
 14: '갖',
 15: '갗',
 16: '같',
 17: '갚',
 18: '갛',
 19: '개',
 20: '객',
 21: '갠',
 22: '갤',
 23: '갬',
 24: '갭',
 25: '갯',
 26: '갰',
 27: '갱',
 28: '갸',
 29: '갹',
 30: '갼',
 31: '걀',
 32: '걋',
 33: '걍',
 34: '걔',
 35: '걘',
 36: '걜',
 37: '거',
 38: '걱',
 39: '건',
 40: '걷',
 41: '걸',
 42: '걺',
 43: '검',
 44: '겁',
 45: '것',
 46: '겄',
 47: '겅',
 48: '겆',
 49: '겉',
 50: '겊',
 51: '겋',
 52: '게',
 53: '겐',
 54: '겔',
 55: '겜',
 56: '겝',
 57: '겟',
 58: '겠',
 59: '겡',
 60: '겨',
 61: '격',
 62: '겪',
 63: '견',
 64: '겯',
 65: '결',
 66: '겸',
 67: '겹',
 68: '겻',
 69: '겼',
 70: '경',
 71: '곁',
 72: '계',
 73: '곈',
 74: '곌',
 75: '곕',
 76: '곗',
 77: '고',
 78: '곡',
 79: '곤',
 80: '곧',
 81: '골',
 82: '곪',
 83: '곬',
 84: '곯',
 85: '곰',
 86: '곱',
 87: '곳',
 88: '공',
 89: '곶',
 90: '과',
 91: '곽',
 92: '관',
 93: '괄',
 94: '괆',
 95: '괌',
 96: '괍',
 97: '괏',
 98: '광',
 99: '괘',
 100: '괜',

In [23]:
char2idx = {v: k for k, v in idx2char.items()}
char2idx

{'-': 0,
 '가': 1,
 '각': 2,
 '간': 3,
 '갇': 4,
 '갈': 5,
 '갉': 6,
 '갊': 7,
 '감': 8,
 '갑': 9,
 '값': 10,
 '갓': 11,
 '갔': 12,
 '강': 13,
 '갖': 14,
 '갗': 15,
 '같': 16,
 '갚': 17,
 '갛': 18,
 '개': 19,
 '객': 20,
 '갠': 21,
 '갤': 22,
 '갬': 23,
 '갭': 24,
 '갯': 25,
 '갰': 26,
 '갱': 27,
 '갸': 28,
 '갹': 29,
 '갼': 30,
 '걀': 31,
 '걋': 32,
 '걍': 33,
 '걔': 34,
 '걘': 35,
 '걜': 36,
 '거': 37,
 '걱': 38,
 '건': 39,
 '걷': 40,
 '걸': 41,
 '걺': 42,
 '검': 43,
 '겁': 44,
 '것': 45,
 '겄': 46,
 '겅': 47,
 '겆': 48,
 '겉': 49,
 '겊': 50,
 '겋': 51,
 '게': 52,
 '겐': 53,
 '겔': 54,
 '겜': 55,
 '겝': 56,
 '겟': 57,
 '겠': 58,
 '겡': 59,
 '겨': 60,
 '격': 61,
 '겪': 62,
 '견': 63,
 '겯': 64,
 '결': 65,
 '겸': 66,
 '겹': 67,
 '겻': 68,
 '겼': 69,
 '경': 70,
 '곁': 71,
 '계': 72,
 '곈': 73,
 '곌': 74,
 '곕': 75,
 '곗': 76,
 '고': 77,
 '곡': 78,
 '곤': 79,
 '곧': 80,
 '골': 81,
 '곪': 82,
 '곬': 83,
 '곯': 84,
 '곰': 85,
 '곱': 86,
 '곳': 87,
 '공': 88,
 '곶': 89,
 '과': 90,
 '곽': 91,
 '관': 92,
 '괄': 93,
 '괆': 94,
 '괌': 95,
 '괍': 96,
 '괏': 97,
 '광': 98,
 '괘': 99,
 '괜': 100,

In [24]:
vocabulary = ["-"] + letters
print(len(vocabulary))
idx2char = {k: v for k, v in enumerate(vocabulary, start=0)}
char2idx = {v: k for k, v in idx2char.items()}

2350


In [25]:
class CustomDataset(Dataset):
    def __init__(self, img_path_list, label_list, train_mode=True):
        self.img_path_list = img_path_list
        self.label_list = label_list
        self.train_mode = train_mode

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        image = Image.open(self.img_path_list[index]).convert('RGB')

        if self.train_mode:
            image = self.train_transform(image)
        else:
            image = self.test_transform(image)

        if self.label_list is not None:
            text = self.label_list[index]
            return image, text
        else:
            return image

    # Image Augmentation
    def train_transform(self, image):
        transform_ops = transforms.Compose([
            transforms.Resize((CFG['IMG_HEIGHT_SIZE'], CFG['IMG_WIDTH_SIZE'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)

    def test_transform(self, image):
        transform_ops = transforms.Compose([
            transforms.Resize((CFG['IMG_HEIGHT_SIZE'], CFG['IMG_WIDTH_SIZE'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)

In [26]:
train_dataset = CustomDataset(train['img_path'].values, train['label'].values)
train_loader = DataLoader(
    train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val['img_path'].values, val['label'].values)
val_loader = DataLoader(
    val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

In [27]:
image_batch, text_batch = iter(train_loader).next()
print(image_batch.size(), text_batch)

torch.Size([84, 3, 64, 224]) ('신비', '그나마', '씻기다', '쉴', '일반', '셋', '수입되다', '세', '달리다', '뜁', '선원', '추가', '대량', '예매하다', '콩', '쯩', '못', '사모님', '뱀', '호주머니', '젤', '벌', '불', '베개', '수도권', '회복되다', '날씨', '유산', '견해', '체조', '얘', '튀김', '삼계탕', '어려움', '세', '발자국', '터', '함께', '스타', '샛', '참여하다', '븐', '학생증', '신인', '신청', '양복', '추진하다', '집중하다', '구분되다', '기타', '지구', '꽐', '명예', '대륙', '천장', '신', '칡', '창조', '걋', '나무', '향', '여행사', '강', '흄', '출발', '전철', '꿈', '캠페인', '유산', '자극', '품', '대출', '관광버스', '내외', '법', '딱', '알리다', '먹이다', '왼발', '이제', '씁', '뒷골목', '만들다', '특정하다')


In [28]:
import torch.nn as nn

from modules.transformation import TPS_SpatialTransformerNetwork
from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
from modules.sequence_modeling import BidirectionalLSTM
from modules.prediction import Attention


class Model(nn.Module):

    def __init__(self, num_class):
        super(Model, self).__init__()

        self.Transformation = TPS_SpatialTransformerNetwork(
            F=20, I_size=(CFG['IMG_HEIGHT_SIZE'], CFG['IMG_WIDTH_SIZE']), I_r_size=(CFG['IMG_HEIGHT_SIZE'], CFG['IMG_WIDTH_SIZE']), I_channel_num=3)

        """ FeatureExtraction """
        self.FeatureExtraction = ResNet_FeatureExtractor(3, 512)
        self.FeatureExtraction_output = 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output,
                              512, 512),
            BidirectionalLSTM(512, 512, 512))
        self.SequenceModeling_output = 512

        """ Prediction """
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, x):
        """ Transformation stage """
        input = self.Transformation(x)

        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(
            visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        contextual_feature = self.SequenceModeling(visual_feature)

        """ Prediction stage """
        prediction = self.Prediction(contextual_feature.contiguous())
        prediction = prediction.permute(1, 0, 2)

        return prediction

In [29]:
class RecognitionModel(nn.Module):
    def __init__(self, num_chars=len(char2idx), rnn_hidden_size=256):
        super(RecognitionModel, self).__init__()
        self.num_chars = num_chars
        self.rnn_hidden_size = rnn_hidden_size

        # CNN Backbone = 사전학습된 resnet18 활용
        # https://arxiv.org/abs/1512.03385
        resnet = resnet18(pretrained=True)
        # CNN Feature Extract
        resnet_modules = list(resnet.children())[:-3]
        self.feature_extract = nn.Sequential(
            *resnet_modules,
            nn.Conv2d(256, 256, kernel_size=(3, 6), stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.linear1 = nn.Linear(1024, rnn_hidden_size)

        # RNN
        self.rnn = nn.RNN(input_size=rnn_hidden_size,
                          hidden_size=rnn_hidden_size,
                          bidirectional=True,
                          batch_first=True)
        self.linear2 = nn.Linear(self.rnn_hidden_size*2, num_chars)

    def forward(self, x):
        # CNN
        x = self.feature_extract(x)  # [batch_size, channels, height, width]
        x = x.permute(0, 3, 1, 2)  # [batch_size, width, channels, height]

        batch_size = x.size(0)
        T = x.size(1)
        # [batch_size, T==width, num_features==channels*height]
        x = x.view(batch_size, T, -1)
        x = self.linear1(x)

        # RNN
        x, hidden = self.rnn(x)

        output = self.linear2(x)
        # [T==10, batch_size, num_classes==num_features]
        output = output.permute(1, 0, 2)

        return output

In [30]:
criterion = nn.CTCLoss(blank=0)  # idx 0 : '-'

In [31]:
def encode_text_batch(text_batch):
    text_batch_targets_lens = [len(text) for text in text_batch]
    text_batch_targets_lens = torch.IntTensor(text_batch_targets_lens)

    text_batch_concat = "".join(text_batch)
    text_batch_targets = [char2idx[c] for c in text_batch_concat]
    text_batch_targets = torch.IntTensor(text_batch_targets)

    return text_batch_targets, text_batch_targets_lens

In [32]:
def compute_loss(text_batch, text_batch_logits):
    """
    text_batch: list of strings of length equal to batch size
    text_batch_logits: Tensor of size([T, batch_size, num_classes])
    """
    text_batch_logps = F.log_softmax(
        text_batch_logits, 2)  # [T, batch_size, num_classes]
    text_batch_logps_lens = torch.full(size=(text_batch_logps.size(1),),
                                       fill_value=text_batch_logps.size(0),
                                       dtype=torch.int32).to(device)  # [batch_size]

    text_batch_targets, text_batch_targets_lens = encode_text_batch(text_batch)
    loss = criterion(text_batch_logps, text_batch_targets,
                     text_batch_logps_lens, text_batch_targets_lens)

    return loss

In [33]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)

    best_loss = 999999
    best_model = None
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for image_batch, text_batch in tqdm(iter(train_loader)):
            image_batch = image_batch.to(device)

            optimizer.zero_grad()
            text_batch_logits = model(image_batch)
            loss = compute_loss(text_batch, text_batch_logits)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        _train_loss = np.mean(train_loss)

        _val_loss = validation(model, val_loader, device)
        print(
            f'Epoch : [{epoch}] Train CTC Loss : [{_train_loss:.5f}] Val CTC Loss : [{_val_loss:.5f}]')

        if scheduler is not None:
            scheduler.step(_val_loss)

        if best_loss > _val_loss:
            best_loss = _val_loss
            best_model = model

    return best_model

In [34]:
def validation(model, val_loader, device):
    model.eval()
    val_loss = []
    with torch.no_grad():
        for image_batch, text_batch in tqdm(iter(val_loader)):
            image_batch = image_batch.to(device)

            text_batch_logits = model(image_batch)
            loss = compute_loss(text_batch, text_batch_logits)

            val_loss.append(loss.item())

    _val_loss = np.mean(val_loss)
    return _val_loss

In [35]:
# torch.Size([16, 57, 2350])
# torch.Size([11, 16, 2350])

In [46]:
# model = RecognitionModel()

checkpoint = torch.load('./TPS-ResNet-BiLSTM-CTC.pth')
model = Model(len(char2idx))
model.load_state_dict(checkpoint, strict=False)
# model.eval()

# optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)

optimizer = torch.optim.Adadelta(
    params=model.parameters(), lr=CFG["LEARNING_RATE"], rho=0.95, eps=1e-8)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, threshold_mode='abs', min_lr=1e-8, verbose=True)

infer_model = train(model, optimizer, train_loader,
                    val_loader, scheduler, device)

  0%|          | 0/789 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch : [1] Train CTC Loss : [117.57431] Val CTC Loss : [8.28094]


  0%|          | 0/789 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch : [2] Train CTC Loss : [8.25000] Val CTC Loss : [8.08869]


  0%|          | 0/789 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch : [3] Train CTC Loss : [8.09440] Val CTC Loss : [7.95812]


  0%|          | 0/789 [00:00<?, ?it/s]

  0%|          | 0/127 [00:00<?, ?it/s]

Epoch : [4] Train CTC Loss : [7.98275] Val CTC Loss : [7.85937]


  0%|          | 0/789 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), './model_pre_trained.pth')
torch.save(infer_model.state_dict(), './model_pre_trained_best.pth')

In [37]:
test = pd.read_csv('./test.csv')

In [38]:
test_dataset = CustomDataset(test['img_path'].values, None)
test_loader = DataLoader(
    test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [39]:
def decode_predictions(text_batch_logits):
    text_batch_tokens = F.softmax(
        text_batch_logits, 2).argmax(2)  # [T, batch_size]
    text_batch_tokens = text_batch_tokens.numpy().T  # [batch_size, T]

    text_batch_tokens_new = []
    for text_tokens in text_batch_tokens:
        text = [idx2char[idx] for idx in text_tokens]
        text = "".join(text)
        text_batch_tokens_new.append(text)

    return text_batch_tokens_new


def inference(model, test_loader, device):
    model.eval()
    preds = []
    with torch.no_grad():
        for image_batch in tqdm(iter(test_loader)):
            image_batch = image_batch.to(device)

            text_batch_logits = model(image_batch)

            text_batch_pred = decode_predictions(text_batch_logits.cpu())

            preds.extend(text_batch_pred)
    return preds

In [40]:
predictions = inference(infer_model, test_loader, device)

  0%|          | 0/883 [00:00<?, ?it/s]

In [41]:
# 샘플 별 추론결과를 독립적으로 후처리
def remove_duplicates(text):
    if len(text) > 1:
        letters = [text[0]] + [letter for idx,
                               letter in enumerate(text[1:], start=1) if text[idx] != text[idx-1]]
    elif len(text) == 1:
        letters = [text[0]]
    else:
        return ""
    return "".join(letters)


def correct_prediction(word):
    parts = word.split("-")
    parts = [remove_duplicates(part) for part in parts]
    corrected_word = "".join(parts)
    return corrected_word

In [42]:
submit = pd.read_csv('./sample_submission.csv')
submit['label'] = predictions
submit['label'] = submit['label'].apply(correct_prediction)

In [43]:
submit.to_csv('./submission_rev.csv', index=False)