In [1]:
# from google.colab import drive

In [2]:
# drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [1]:
import os

In [2]:
os.chdir('C:/Users/HP/Desktop/attention-ocr')

In [3]:
import random
import time
import pickle

from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms
from utils.mydataset import MyDataset
from model.attention_ocr import OCR
from utils.train_util import train_batch, eval_batch
from utils import data_preprocessing
from torchvision.transforms import ToTensor

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [5]:
# device = "cpu"

In [6]:
def resizePadding(img, width, height):
    desired_w, desired_h = width, height #(width, height)
    _,img_h, img_w = img.shape  # old_size[0] is in (width, height) format
    # print("img_w: {0}, img_h: {1}".format(img_w, img_h))
    # ratio = img_w/float(img_h)
    # print("ratio:", ratio)
    # new_w = int(desired_h*ratio)
    # new_w = new_w if desired_w == None else min(desired_w, new_w)
    # img = img.resize((3, desired_h, new_w), Image.ANTIALIAS)

    # padding image
    img = img.permute(1,2,0)
    img = img.numpy()
    img = img*255.0
    img = Image.fromarray(img.astype('uint8'), mode = "RGB")
    if desired_w != None: # and desired_w > new_w:
        new_img = Image.new("RGB", (desired_w, desired_h), color = 255)
        new_img.paste(img,(0,0))
        img = new_img

    img = ToTensor()(img)

    return img

In [7]:
class alignCollate(object):

    def __init__(self, imgW, imgH):
        self.imgH = imgH
        self.imgW = imgW
    
    def __call__(self, batch):
        images, labels = zip(*batch)
        imgH = self.imgH
        imgW = self.imgW
        images = [resizePadding(image, self.imgW, self.imgH) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels

In [8]:
list_train = os.listdir('../train')
list_traindir = []
for index,image in enumerate(list_train):
    if data_preprocessing.valid_image(data_preprocessing.get_label(image)):
        list_traindir.append(image)


list_test = os.listdir('../test')
list_testdir = []
for index,image in enumerate(list_test):
    if data_preprocessing.valid_image(data_preprocessing.get_label(image)):
        list_testdir.append(image)

In [9]:
train_dataset = MyDataset(list_traindir, cate = 'train')
test_dataset = MyDataset(list_testdir, cate = "test")

In [10]:
batch_size = 32

In [11]:
train_loader = DataLoader(train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=alignCollate(350,32)
            )

test_loader = DataLoader(test_dataset,
            batch_size=8,
            collate_fn=alignCollate(350,32),
            drop_last=True,
            shuffle = True
            )

print("Train dataset length: {0}".format(len(train_dataset)))
print("Test dataset length: {0}".format(len(test_dataset)))

Train dataset length: 33201
Test dataset length: 2214


In [12]:
def main():
    img_width = 350
    img_height = 32
    nh = 512
    max_len = 15
    teacher_forcing_ratio = 0.5
    lr = 3e-4
    n_epoch = 20
    save_checkpoint_every = 5

    tokenizer = train_dataset.tokenizer

    model = OCR(img_width, img_height, nh, tokenizer.n_token,
                max_len + 1, tokenizer.SOS_token, tokenizer.EOS_token).to(device)

    load_weights = torch.load('./inception_v3_google-1a9a5a14.pth')

    names = set()
    for k, w in model.incept.named_children():
        names.add(k)

    weights = {}
    for k, w in load_weights.items():
        if k.split('.')[0] in names:
            weights[k] = w

    model.incept.load_state_dict(weights)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    crit = nn.NLLLoss().to(device)

    def train_epoch():
        sum_loss_train = 0
        n_train = 0
        sum_acc = 0
        sum_sentence_acc = 0
        model.train()
        for batch, (x,y) in enumerate(tqdm(train_loader)):
            label = y[0].unsqueeze(0)
            for i,l in enumerate(y):
              if i>0:
                label = torch.cat((label,l.unsqueeze(0)), dim = 0)
            # print("x: ",x)
            # print("y: ",label)
            x = x.to(device)
            label = label.to(device)

            loss, acc, sentence_acc = train_batch(x, label, model, optimizer,
                                                  crit, teacher_forcing_ratio, max_len,
                                                  tokenizer)

            sum_loss_train += loss
            sum_acc += acc
            sum_sentence_acc += sentence_acc

            n_train += 1

        return sum_loss_train / n_train, sum_acc / n_train, sum_sentence_acc / n_train

    def eval_epoch():
        sum_loss_eval = 0
        n_eval = 0
        sum_acc = 0
        sum_sentence_acc = 0

        for bi, batch in enumerate(tqdm(test_loader)):
            x, y = batch
            label = y[0].unsqueeze(0)
            x = x.to(device=device)
            for i,l in enumerate(y):
              if i>0:
                label = torch.cat((label,l.unsqueeze(0)), dim = 0)
            label = label.to(device=device)

            loss, acc, sentence_acc = eval_batch(x, label, model, crit, max_len, tokenizer)

            sum_loss_eval += loss
            sum_acc += acc
            sum_sentence_acc += sentence_acc

            n_eval += 1

        return sum_loss_eval / n_eval, sum_acc / n_eval, sum_sentence_acc / n_eval

    for epoch in range(n_epoch):
        train_loss, train_acc, train_sentence_acc = train_epoch()
        eval_loss, eval_acc, eval_sentence_acc = eval_epoch()

        print("Epoch %d" % epoch)
        print('train_loss: %.4f, train_acc: %.4f, train_sentence: %.4f' % (train_loss, train_acc, train_sentence_acc))
        print('eval_loss:  %.4f, eval_acc:  %.4f, eval_sentence:  %.4f' % (eval_loss, eval_acc, eval_sentence_acc))

        if epoch % save_checkpoint_every == 0 and epoch > 0:
            print('saving checkpoint...')
            torch.save(model.state_dict(), './chkpoint/time_%s_epoch_%s.pth' % (time.strftime('%Y-%m-%d_%H-%M-%S'), epoch))


if __name__ == '__main__':
    main()

Model feature size: 1 41


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 1037/1037 [25:35<00:00,  1.48s/it]
  0%|          | 0/1037 [00:00<?, ?it/s]

Epoch 0
train_loss: 1.2023, train_acc: 0.4194, train_sentence: 0.0074
eval_loss:  2.7650, eval_acc:  0.6196, eval_sentence:  0.0453


100%|██████████| 1037/1037 [26:30<00:00,  1.53s/it] 
  0%|          | 0/1037 [00:00<?, ?it/s]

Epoch 1
train_loss: 0.3614, train_acc: 0.7791, train_sentence: 0.2523
eval_loss:  2.1929, eval_acc:  0.7176, eval_sentence:  0.2034


100%|██████████| 1037/1037 [25:45<00:00,  1.49s/it] 
  0%|          | 0/1037 [00:00<?, ?it/s]

Epoch 2
train_loss: 0.1924, train_acc: 0.8546, train_sentence: 0.4993
eval_loss:  0.8621, eval_acc:  0.8473, eval_sentence:  0.5226


100%|██████████| 1037/1037 [23:35<00:00,  1.37s/it]
  0%|          | 0/1037 [00:00<?, ?it/s]

Epoch 3
train_loss: 0.1384, train_acc: 0.8763, train_sentence: 0.5997
eval_loss:  0.6772, eval_acc:  0.8658, eval_sentence:  0.6024


100%|██████████| 1037/1037 [23:15<00:00,  1.35s/it]
  0%|          | 0/1037 [00:00<?, ?it/s]

Epoch 4
train_loss: 0.1163, train_acc: 0.8855, train_sentence: 0.6486
eval_loss:  2.8520, eval_acc:  0.7332, eval_sentence:  0.3827


 89%|████████▊ | 918/1037 [22:36<02:45,  1.39s/it]