In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# your path to data
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = r'/DATA/ichuviliaeva/ocr_data/gt.txt'

In [2]:
import cv2
import numpy as np

In [3]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
train_helper, val_helper = train_helper.train_val_split()

In [5]:
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
class pad_prune(object):
    
    def __init__(self, output_w, output_h):
        self.output_w = output_w
        self.output_h = output_h

    def __call__(self, image):
        
        w, h = image.shape
        diff_w = w - self.output_w
        diff_h = h - self.output_h
        masked = np.zeros((self.output_w, self.output_h))
        if diff_w <= 0:
            if diff_h <= 0:
                masked[:w, :h] = image
            else:
                masked[:w, :] = image[:, (diff_h // 2):(h - (diff_h - diff_h // 2))]
            
        if diff_w > 0:
            if diff_h > 0:
                masked = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), 
                               (diff_h // 2):(h - (diff_h - diff_h // 2))]
            else:
                masked[:, :h] = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), :]
        return masked - 127.5 / 255.0  

In [7]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper, transform = None):
        self.helper = helper
        self.transform = transform
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        if self.transform:
            img = self.transform(img)
        return img, label

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
to_one_shape = pad_prune(128, 128)
train_dataset = HWDBDataset(train_helper, to_one_shape)
val_dataset = HWDBDataset(val_helper, to_one_shape)

In [9]:
class SimpleBlock(nn.Module):
        
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, is_projection = False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features = in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, padding = padding, stride = stride)
            
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels,
                               kernel_size = kernel_size, padding = padding, stride = stride)
        if is_projection:
            self.project = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                               kernel_size = 1, padding = 0, stride = stride)
        else:
            self.project = lambda x: x
            
    def forward(self, x):            
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out + self.project(x)

In [10]:
class TheNet(nn.Module):
    def __init__(self, n_classes):
        super(TheNet, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 4, kernel_size = 5, stride = 2),
            nn.ReLU(),
            
            SimpleBlock(in_channels = 4, out_channels = 4, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 4, out_channels = 4, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(4, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
           
            SimpleBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            SimpleBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            SimpleBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            nn.Flatten(),
            nn.Linear(3 * 3 * 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            nn.Linear(128, n_classes, bias=False)
        )
            
    def forward(self, x):
        return self.nn(x)

In [11]:
model = TheNet(train_helper.vocabulary.num_classes())
model.eval()
res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))
print(res.shape)

torch.Size([1, 7330])


In [14]:
model = model.cuda()

In [18]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [19]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [20]:
from tqdm import tqdm


def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            logits = model(X.unsqueeze(1).to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn):
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        loss = loss_fn(logits, y.to(torch.long).cuda())
        
        optim.zero_grad()
        loss.backward()
        optim.step()

In [18]:
torch.save(model.state_dict(), 'arch1.pth')

In [19]:
for epoch in range(10):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0 or epoch == 9:
        torch.save(model.state_dict(), f'arch1_epoch{epoch}.pth')

Epoch 0:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:28<00:00,  5.80it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:51<00:00,  6.10it/s]


accuracy: 0.7978836783228282
Epoch 1:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:28<00:00,  5.80it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:51<00:00,  6.12it/s]


accuracy: 0.8413503379568079
Epoch 2:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:28<00:00,  5.80it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.29it/s]


accuracy: 0.8847968303266011
Epoch 3:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:28<00:00,  5.80it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.22it/s]


accuracy: 0.8938302133541418
Epoch 4:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [15:21<00:00,  5.46it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.29it/s]


accuracy: 0.8943855887832779
Epoch 5:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:49<00:00,  6.30it/s]


accuracy: 0.9074756945683352
Epoch 6:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:49<00:00,  6.31it/s]


accuracy: 0.9024044032894359
Epoch 7:


 84%|██████████████████████████████████████████████████████▌          | 4228/5036 [12:06<02:19,  5.80it/s]

In [22]:
model.load_state_dict(torch.load(f'arch1_epoch5.pth'))

<All keys matched successfully>

In [23]:
model = model.cuda()

In [24]:
epoch = 6
while epoch < 10:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    
    if epoch % 5 == 0 or epoch == 9:
        torch.save(model.state_dict(), f'arch1_epoch{epoch}.pth')
    epoch += 1

Epoch 6:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:29<00:00,  5.79it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:49<00:00,  6.30it/s]


accuracy: 0.9101207088327963
Epoch 7:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.24it/s]


accuracy: 0.9121638078276909
Epoch 8:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.29it/s]


accuracy: 0.9174305664363979
Epoch 9:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:49<00:00,  6.32it/s]

accuracy: 0.9090828703911984





In [32]:
model.load_state_dict(torch.load(f'arch1_epoch9.pth'))
epoch = 10
while epoch < 12:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    
    if epoch % 5 == 0 or epoch == 9:
        torch.save(model.state_dict(), f'arch1_epoch{epoch}.pth')
    epoch += 1

Epoch 10:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:27<00:00,  5.80it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:52<00:00,  6.03it/s]


accuracy: 0.9171745973140307
Epoch 11:


100%|█████████████████████████████████████████████████████████████████| 5036/5036 [14:26<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.18it/s]

accuracy: 0.9158714817819794





In [33]:
torch.save(model.state_dict(), f'arch1_final.pth')

In [34]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [35]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [36]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

100%|███████████████████████████████████████████████████████████████████| 380/380 [01:00<00:00,  6.31it/s]


In [37]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [38]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

/bin/bash: /home/ichuviliaeva/miniconda3/envs/ocr_course/lib/python3.8/site-packages/cv2/../../../../lib/libtinfo.so.6: no version information available (required by /bin/bash)
pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.8778
