In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# your path to data
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = r'/DATA/ichuviliaeva/ocr_data/gt.txt'

In [26]:
import cv2
import numpy as np
from centerloss import CenterLoss
from torchvision import transforms

In [3]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
train_helper, val_helper = train_helper.train_val_split()

In [5]:
train_helper.size(), val_helper.size()

(2578433, 644609)

In [8]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper, transform = None):
        self.helper = helper
        self.transform = transform
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        if self.transform:
            img = self.transform(img)
        return img, label

In [7]:
class pad_prune(object):
    
    def __init__(self, output_w, output_h):
        self.output_w = output_w
        self.output_h = output_h

    def __call__(self, image):
        
        w, h = image.shape
        diff_w = w - self.output_w
        diff_h = h - self.output_h
        masked = np.zeros((self.output_w, self.output_h))
        if diff_w <= 0:
            if diff_h <= 0:
                masked[:w, :h] = image
            else:
                masked[:w, :] = image[:, (diff_h // 2):(h - (diff_h - diff_h // 2))]
            
        if diff_w > 0:
            if diff_h > 0:
                masked = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), 
                               (diff_h // 2):(h - (diff_h - diff_h // 2))]
            else:
                masked[:, :h] = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), :]
        return masked
    
class to_zero_one(object):
    def __init__(self):
        pass
    def __call__(self, image):
        return image - 127.5 / 255.0 

In [66]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.9),  
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size = 5),
        transforms.RandomRotation(degrees=15)
    ]), 
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [67]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [68]:
class ResNetBlock(nn.Module):
        
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, is_projection = False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features = in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, padding = padding, stride = stride)
            
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels,
                               kernel_size = kernel_size, padding = padding, stride = stride)
        if is_projection:
            self.project = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                               kernel_size = 1, padding = 0, stride = stride)
        else:
            self.project = lambda x: x
            
    def forward(self, x):            
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out + self.project(x)

In [80]:
class TheNet(nn.Module):
    def __init__(self, n_classes):
        super(TheNet, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 5, stride = 2),
            nn.ReLU(),
            
            ResNetBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            ResNetBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
           
            ResNetBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            ResNetBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            ResNetBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1,
                             is_projection = True),
            ResNetBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            ResNetBlock(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, stride = 1,
                             is_projection = True),
            ResNetBlock(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            ResNetBlock(in_channels = 128, out_channels = 128, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            ResNetBlock(in_channels = 128, out_channels = 128, kernel_size = 3, padding = 1, stride = 1, 
                             is_projection = True),
            
            nn.Flatten(),
            nn.Linear(3 * 3 * 128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
        
        self.last_layer = nn.Linear(256, n_classes, bias=False)
        
    def forward(self, x):
        x = self.nn(x)
        y = self.last_layer(x)
        return x, y

In [81]:
model = TheNet(train_helper.vocabulary.num_classes())
model.eval()
res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))
print(res[0].shape)

torch.Size([1, 256])


  res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))


In [82]:
model = model.cuda()

In [83]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [84]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [85]:
center_loss = CenterLoss(num_classes=train_helper.vocabulary.num_classes(), feat_dim=256, use_gpu=True)
optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.5)

In [86]:
from tqdm.notebook import tqdm

def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            features, logits = model(X.to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn, alpha):
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        features, logits = model(X.to(torch.float32).cuda())
        labels = y.to(torch.long).cuda()
        loss = center_loss(features, labels) * alpha + loss_fn(logits, labels)
        
        optim.zero_grad()
        optimizer_centloss.zero_grad()
        
        loss.backward()
        optim.step()
        
        for param in center_loss.parameters():
             param.grad.data *= (1./alpha)

        optimizer_centloss.step()

In [87]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [88]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [90]:
from pathlib import Path

def evaluate(gt_path, pred_path):
    gt = dict()
    with open(gt_path) as gt_f:
        for line in gt_f:
            name, cls = line.strip().split()
            gt[name] = cls
    
    n_good = 0
    n_all = len(gt)
    with open(pred_path) as pred_f:
        for line in pred_f:
            name, cls = line.strip().split()
            if cls == gt[name]:
                n_good += 1
    
    return n_good / n_all


def _run_evaluation():
    base = '.'
    pred_path = base + '/pred.txt'
    print('pred_path = ', pred_path)
    score = evaluate(gt_path, pred_path)
    print('Accuracy = {:1.4f}'.format(score))

Тут я просто захотела периодически запускать тест в середине. Это не должно приводить к каким-либо утечкам данных, т. к. всё предсказывается под with torch.no_grad()

In [91]:
alpha = 0.3
for epoch in range(31):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1_all_{epoch}.pth')
        alpha = min(1.0, alpha * 1.5)
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()

Epoch 0:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.004962698317895034


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.0036
Epoch 1:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.18014331168196535
Epoch 2:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.5653256470201315
Epoch 3:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7029393011887827
Epoch 4:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7041788122722457
Epoch 5:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8058668122846563


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.7108
Epoch 6:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8246425352422941
Epoch 7:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8045869666728203
Epoch 8:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8411284980507564
Epoch 9:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8083411804675392
Epoch 10:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8551571572844934


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8071
Epoch 11:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8667533341917348
Epoch 12:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8572095642474741
Epoch 13:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8815778246968318
Epoch 14:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8850310808567674
Epoch 15:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8904110864105217


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8514
Epoch 16:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8901178854158102
Epoch 17:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8880701324368726
Epoch 18:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8983197566276611
Epoch 19:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8913946283716175
Epoch 20:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8878389845627349


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8443
Epoch 21:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8989573524415576
Epoch 22:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9028294671653669
Epoch 23:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9118892227691515
Epoch 24:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9090161632865815
Epoch 25:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8981568671861547


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8620
Epoch 26:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9131830303331167
Epoch 27:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8972741615459915
Epoch 28:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.91108563485772
Epoch 29:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9117930404322621
Epoch 30:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9092224899124896


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8749


In [92]:
epoch = 31
while epoch < 61:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1_all_{epoch}.pth')
        alpha = min(1.0, alpha * 1.5)
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
    epoch += 1

Epoch 31:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9119729944819263
Epoch 32:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8979707078244331
Epoch 33:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8567053826428114
Epoch 34:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9201733143657628
Epoch 35:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9127067726327123


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8832
Epoch 36:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9138749226275152
Epoch 37:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9086841790915113
Epoch 38:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9221512575840548
Epoch 39:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8840490902236859
Epoch 40:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8963899045778139


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8615
Epoch 41:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7939976016468898
Epoch 42:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9221419496159687
Epoch 43:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9273156285438149
Epoch 44:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9119497245617111
Epoch 45:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9285023944747901


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8994
Epoch 46:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9272985639356571
Epoch 47:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9282898625368247
Epoch 48:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9259271899709747
Epoch 49:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9197498018178462
Epoch 50:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9273156285438149


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8943
Epoch 51:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9126354115440523
Epoch 52:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9170008485764238
Epoch 53:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8533126282754352
Epoch 54:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9209691456371226
Epoch 55:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9275405711058952


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9045
Epoch 56:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9323046994379539
Epoch 57:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9328616261951043
Epoch 58:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9322922888138391
Epoch 59:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9207535110431285
Epoch 60:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9290856938081845


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9022


In [93]:
epoch = 61
while epoch < 81:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, 1.2)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1_all_{epoch}.pth')
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
    epoch += 1

Epoch 61:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.932161977260634
Epoch 62:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.935652465292914
Epoch 63:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9330757094610842
Epoch 64:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9421261570967827
Epoch 65:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.941805032197813


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9115
Epoch 66:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9425775935489575
Epoch 67:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9439427622015827
Epoch 68:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9444950349746901
Epoch 69:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.945917602763846
Epoch 70:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9449045855704776


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9218
Epoch 71:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9460029258046351
Epoch 72:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9464295410085803
Epoch 73:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9460277470528646
Epoch 74:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9457872912106409
Epoch 75:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.948329917826155


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9280
Epoch 76:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9496609572624646
Epoch 77:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9501310096508115
Epoch 78:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9495942501578476
Epoch 79:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9507003470320768
Epoch 80:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9495275430532307


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9296


In [100]:
epoch = 81
while epoch < 91:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, 1.2)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1_all_{epoch}.pth')
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
    epoch += 1

Epoch 81:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9501511769149981
Epoch 82:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9500487892660512
Epoch 83:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9516047712644409
Epoch 84:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9506088186792303
Epoch 85:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9526596743141966


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9324
Epoch 86:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9502333972997584
Epoch 87:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9507422328884642
Epoch 88:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9521570440375483
Epoch 89:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9517723146899904
Epoch 90:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9522904582467822


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.9299


In [101]:
model.load_state_dict(torch.load(f'arch1_all_85.pth'))

<All keys matched successfully>

In [102]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [103]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [49]:
# model.load_state_dict(torch.load(f'arch1_epoch25.pth'))

<All keys matched successfully>

In [104]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

  0%|          | 0/380 [00:00<?, ?it/s]

In [105]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [106]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.9324
