In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# your path to data
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = r'/DATA/ichuviliaeva/ocr_data/gt.txt'

In [2]:
import cv2
import numpy as np
from centerloss import CenterLoss

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
train_helper, val_helper = train_helper.train_val_split()

In [5]:
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
class pad_prune(object):
    
    def __init__(self, output_w, output_h):
        self.output_w = output_w
        self.output_h = output_h

    def __call__(self, image):
        
        w, h = image.shape
        diff_w = w - self.output_w
        diff_h = h - self.output_h
        masked = np.zeros((self.output_w, self.output_h))
        if diff_w <= 0:
            if diff_h <= 0:
                masked[:w, :h] = image
            else:
                masked[:w, :] = image[:, (diff_h // 2):(h - (diff_h - diff_h // 2))]
            
        if diff_w > 0:
            if diff_h > 0:
                masked = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), 
                               (diff_h // 2):(h - (diff_h - diff_h // 2))]
            else:
                masked[:, :h] = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), :]
        return masked - 127.5 / 255.0  

In [7]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper, transform = None):
        self.helper = helper
        self.transform = transform
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        if self.transform:
            img = self.transform(img)
        return img, label

In [8]:
to_one_shape = pad_prune(100, 100)
train_dataset = HWDBDataset(train_helper, to_one_shape)
val_dataset = HWDBDataset(val_helper, to_one_shape)

In [9]:
class ResNetBlock(nn.Module):
        
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, is_projection = False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features = in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, padding = padding, stride = stride)
            
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels,
                               kernel_size = kernel_size, padding = padding, stride = 1)
        if is_projection:
            self.project = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                               kernel_size = 1, padding = 0, stride = stride)
        else:
            self.project = lambda x: x
            
    def forward(self, x):            
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out + self.project(x)

In [10]:
class TheNet(nn.Module):
    def __init__(self, n_classes):
        super(TheNet, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 5, stride = 2),
            nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
            
            ResNetBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1),
            ResNetBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1),
            
            ResNetBlock(in_channels = 8, out_channels = 16, kernel_size = 3, padding = 1, stride = 2, 
                       is_projection = True),
            ResNetBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1),
            
            ResNetBlock(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1, stride = 2, 
                       is_projection = True),
            ResNetBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1),
            
            ResNetBlock(in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1, stride = 2, 
                       is_projection = True),
            ResNetBlock(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, stride = 1),
            
            nn.Flatten(),
            nn.Linear(3 * 3 * 64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        
        self.last_layer = nn.Linear(128, n_classes, bias=False)
            
    def forward(self, x):
        x = self.nn(x)
        y = self.last_layer(x)
        return x, y
        return x

In [11]:
model = TheNet(train_helper.vocabulary.num_classes())
model.eval()
res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 100, 100))
print(res[0].shape)
print(res[1].shape)

torch.Size([1, 128])
torch.Size([1, 7330])


In [12]:
model = model.cuda()

In [13]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [14]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [15]:
center_loss = CenterLoss(num_classes=train_helper.vocabulary.num_classes(), feat_dim=128, use_gpu=True)
optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.1)

In [16]:
from tqdm import tqdm

def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn, alpha = 1.0):
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        labels = y.to(torch.long).cuda()
        loss = center_loss(features, labels) * alpha + loss_fn(logits, labels)
        
        optim.zero_grad()
        optimizer_centloss.zero_grad()
        
        loss.backward()
        optim.step()
        
        for param in center_loss.parameters():
             param.grad.data *= (1./alpha)

        optimizer_centloss.step()

In [46]:
torch.save(model.state_dict(), 'arch2_center.pth')

In [48]:
for epoch in range(11):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.3)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_center_epoch{epoch}.pth')

Epoch 0:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [05:14<00:00, 16.00it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:34<00:00,  9.09it/s]


accuracy: 0.2692500414980244
Epoch 1:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [05:12<00:00, 16.09it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:35<00:00,  8.96it/s]


accuracy: 0.6155685074207775
Epoch 2:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [05:12<00:00, 16.09it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:34<00:00,  9.06it/s]


accuracy: 0.6668321416548636
Epoch 3:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [03:38<00:00, 23.01it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:27<00:00, 11.26it/s]


accuracy: 0.4381043392195889
Epoch 4:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [03:31<00:00, 23.81it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:27<00:00, 11.27it/s]


accuracy: 0.7269802314271132
Epoch 5:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [03:35<00:00, 23.36it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:28<00:00, 10.98it/s]


accuracy: 0.8276365983099833
Epoch 6:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [06:05<00:00, 13.78it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:35<00:00,  8.82it/s]


accuracy: 0.8423757657742911
Epoch 7:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [06:34<00:00, 12.75it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:36<00:00,  8.66it/s]


accuracy: 0.858565424932013
Epoch 8:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [06:17<00:00, 13.34it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:35<00:00,  8.95it/s]


accuracy: 0.8594776058044489
Epoch 9:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [06:35<00:00, 12.73it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:35<00:00,  8.84it/s]


accuracy: 0.7491083742237542
Epoch 10:


 13%|█████████                                                           | 675/5036 [00:53<05:38, 12.88it/s]

Это упало tornado. Т. е. всё продолжило работать, а вывода нет. Подгружу модели в при тестировании.

In [17]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [18]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [27]:
model.load_state_dict(torch.load(f'arch2_center_epoch10.pth'))

<All keys matched successfully>

In [28]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

100%|█████████████████████████████████████████████████████████████████████| 380/380 [00:43<00:00,  8.77it/s]


In [29]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [30]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

/bin/bash: /home/ichuviliaeva/miniconda3/envs/ocr_course/lib/python3.8/site-packages/cv2/../../../../lib/libtinfo.so.6: no version information available (required by /bin/bash)
pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.8111


In [None]:
epoch = 11
while epoch < 21:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.3)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_center_epoch{epoch}.pth')

Epoch 11:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [06:34<00:00, 12.78it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:35<00:00,  8.90it/s]


accuracy: 0.8068581108858238
Epoch 12:


  0%|                                                                    | 1/5036 [00:00<1:10:09,  1.20it/s]

In [23]:
model.load_state_dict(torch.load(f'arch2_center_epoch20.pth'))

<All keys matched successfully>

In [24]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

100%|█████████████████████████████████████████████████████████████████████| 380/380 [00:44<00:00,  8.61it/s]


In [25]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [26]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

/bin/bash: /home/ichuviliaeva/miniconda3/envs/ocr_course/lib/python3.8/site-packages/cv2/../../../../lib/libtinfo.so.6: no version information available (required by /bin/bash)
pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.8025
