In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# your path to data
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = r'/DATA/ichuviliaeva/ocr_data/gt.txt'

In [2]:
import cv2
import numpy as np
from centerloss import CenterLoss
from torchvision import transforms

In [3]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
train_helper, val_helper = train_helper.train_val_split()

In [5]:
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper, transform = None):
        self.helper = helper
        self.transform = transform
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        if self.transform:
            img = self.transform(img)
        return img, label

In [13]:
class pad_prune(object):
    
    def __init__(self, output_w, output_h):
        self.output_w = output_w
        self.output_h = output_h

    def __call__(self, image):
        
        w, h = image.shape
        diff_w = w - self.output_w
        diff_h = h - self.output_h
        masked = np.zeros((self.output_w, self.output_h))
        if diff_w <= 0:
            if diff_h <= 0:
                masked[:w, :h] = image
            else:
                masked[:w, :] = image[:, (diff_h // 2):(h - (diff_h - diff_h // 2))]
            
        if diff_w > 0:
            if diff_h > 0:
                masked = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), 
                               (diff_h // 2):(h - (diff_h - diff_h // 2))]
            else:
                masked[:, :h] = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), :]
        return masked - 127.5 / 255.0  
    
class to_zero_one(object):
    def __init__(self):
        pass
    def __call__(self, image):
        return image - 127.5 / 255.0 

In [14]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomPerspective(distortion_scale = 0.2, p = 0.4),  
    transforms.ColorJitter(),
    # transforms.GaussianBlur(kernel_size = 5),
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [15]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [16]:
class SimpleBlock(nn.Module):
        
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, is_projection = False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features = in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, padding = padding, stride = stride)
            
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels,
                               kernel_size = kernel_size, padding = padding, stride = 1)
        if is_projection:
            self.project = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                               kernel_size = 1, padding = 0, stride = stride)
        else:
            self.project = lambda x: x
            
    def forward(self, x):            
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out + self.project(x)

In [17]:
class TheNet(nn.Module):
    def __init__(self, n_classes):
        super(TheNet, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 4, kernel_size = 5, stride = 2),
            nn.ReLU(),
            
            SimpleBlock(in_channels = 4, out_channels = 4, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 4, out_channels = 8, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
           
            SimpleBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
            
            SimpleBlock(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 64, out_channels = 128, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            nn.Flatten(),
            nn.Linear(2 * 2 * 128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
        )
        
        self.last_layer = nn.Linear(256, n_classes, bias=False)
            
    def forward(self, x):
        x = self.nn(x)
        y = self.last_layer(x)
        return x, y

In [18]:
model = TheNet(train_helper.vocabulary.num_classes())
model.eval()
res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))
print(res[0].shape)
print(res[1].shape)

torch.Size([1, 256])
torch.Size([1, 7330])


  res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))


In [19]:
model = model.cuda()

In [55]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [20]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [21]:
center_loss = CenterLoss(num_classes=train_helper.vocabulary.num_classes(), feat_dim=256, use_gpu=True)
optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.5)

In [22]:
from tqdm.notebook import tqdm

def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            features, logits = model(X.to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn, alpha = 1.0):
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        features, logits = model(X.to(torch.float32).cuda())
        labels = y.to(torch.long).cuda()
        loss = center_loss(features, labels) * alpha + loss_fn(logits, labels)
        
        optim.zero_grad()
        optimizer_centloss.zero_grad()
        
        loss.backward()
        optim.step()
        
        for param in center_loss.parameters():
             param.grad.data *= (1./alpha)

        optimizer_centloss.step()

In [59]:
for epoch in range(11):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.3)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')

Epoch 0:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:36<00:00,  7.91it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:08<00:00,  4.63it/s]


accuracy: 0.02828226102955435
Epoch 1:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:38<00:00,  7.89it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:09<00:00,  4.51it/s]


accuracy: 0.3832307646961181
Epoch 2:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:55<00:00,  7.68it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:04<00:00,  4.86it/s]


accuracy: 0.6026040592048824
Epoch 3:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:49<00:00,  7.75it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:03<00:00,  4.93it/s]


accuracy: 0.7261595789075238
Epoch 4:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [11:02<00:00,  7.60it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:03<00:00,  4.96it/s]


accuracy: 0.7616120780193885
Epoch 5:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:56<00:00,  7.68it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [01:01<00:00,  5.13it/s]


accuracy: 0.652151924655101
Epoch 6:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [10:17<00:00,  8.16it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:56<00:00,  5.58it/s]


accuracy: 0.8128214157729724
Epoch 7:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [09:39<00:00,  8.69it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:57<00:00,  5.46it/s]


accuracy: 0.7856762781779342
Epoch 8:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [09:44<00:00,  8.62it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:57<00:00,  5.43it/s]


accuracy: 0.8424300622547932
Epoch 9:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [09:36<00:00,  8.74it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:58<00:00,  5.36it/s]


accuracy: 0.8086111115420356
Epoch 10:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [07:51<00:00, 10.67it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:48<00:00,  6.51it/s]

accuracy: 0.6892131509178432





In [23]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size = 3)
    ]),
    
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size = 5)
    ]),
    
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [24]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [25]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8, 
                         persistent_workers = True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                         persistent_workers = True)

In [26]:
 model.load_state_dict(torch.load(f'arch1__center_epoch10.pth'))

<All keys matched successfully>

In [27]:
epoch = 11
while epoch < 21:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.5)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')

Epoch 11:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8360851306761153
Epoch 12:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8262435057530999
Epoch 13:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8668526191846531
Epoch 14:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8744618830950235
Epoch 15:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8787419970866059
Epoch 16:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8757556906589886
Epoch 17:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8809185102907344
Epoch 18:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8736831164318215
Epoch 19:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.886760811592764
Epoch 20:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8724839398767315


In [33]:
 model.load_state_dict(torch.load(f'arch1__center_epoch20.pth'))

<All keys matched successfully>

In [34]:
epoch = 21
while epoch < 31:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.3)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')

Epoch 21:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8900573836232507
Epoch 22:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8722512406745795
Epoch 23:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.611400089046228
Epoch 24:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8254088912813814
Epoch 25:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9018754004365437
Epoch 26:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.895983456638055
Epoch 27:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9022228979117574
Epoch 28:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9073655502793166
Epoch 29:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9088811977493333
Epoch 30:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9081846514708917


In [40]:
epoch = 31
while epoch < 41:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.5)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')

Epoch 31:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9012191886864751
Epoch 32:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9096723750366501
Epoch 33:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8991000746188775
Epoch 34:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8512369513922393
Epoch 35:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.912258438836566
Epoch 36:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9118318236326207
Epoch 37:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9068830872668547
Epoch 38:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9096335918362914
Epoch 39:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9053286565964794
Epoch 40:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9141603669821551


Тут я просто захотела периодически запускать тест в середине. Это не должно приводить к каким-либо утечкам данных, т. к. всё предсказывается под with torch.no_grad()

In [45]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [54]:
test_dataset = HWDBDataset(test_helper, val_transforms)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [55]:
from pathlib import Path

def evaluate(gt_path, pred_path):
    gt = dict()
    with open(gt_path) as gt_f:
        for line in gt_f:
            name, cls = line.strip().split()
            gt[name] = cls
    
    n_good = 0
    n_all = len(gt)
    with open(pred_path) as pred_f:
        for line in pred_f:
            name, cls = line.strip().split()
            if cls == gt[name]:
                n_good += 1
    
    return n_good / n_all


def _run_evaluation():
    base = '.'
    pred_path = base + '/pred.txt'
    print('pred_path = ', pred_path)
    score = evaluate(gt_path, pred_path)
    print('Accuracy = {:1.4f}'.format(score))

In [61]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size = 3),
        transforms.RandomRotation(degrees=10)
    ]),
    
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size = 5),
        transforms.RandomRotation(degrees=10)
    ]),
    
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [62]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [63]:
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, drop_last=True, num_workers=8, 
                         persistent_workers = True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                         persistent_workers = True)

In [64]:
epoch = 41
while epoch < 51:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.7)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()

Epoch 41:


  0%|          | 0/2518 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    Exception ignored in: self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

Traceback (most recent call last):
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
if w.is_alive():    
self._shutdown_workers()  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, i

accuracy: 0.9145792255460287
Epoch 42:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.88157006805676
Epoch 43:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9108265630793241
Epoch 44:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9171575327058729


  0%|          | 0/380 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f43f9a843a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

pred_path =  ./pred.txt
Accuracy = 0.8907
Epoch 45:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9178339117201281
Epoch 46:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9178711435924723
Epoch 47:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9168503697590322
Epoch 48:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9204835799686322
Epoch 49:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9183070667645038


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.8912
Epoch 50:


  0%|          | 0/2518 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9166471457891529


In [68]:
train_dataset = HWDBDataset(train_helper, val_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [69]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8, 
                         persistent_workers = True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                         persistent_workers = True)

In [70]:
epoch = 51
while epoch < 53:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.7)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    epoch += 1
    
    torch.save(model.state_dict(), f'arch1__center_epoch{epoch}.pth')
    preds = []
    model.eval()
    with torch.no_grad():
        for X, _ in tqdm(test_loader):
            features, logits = model(X.to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            preds.extend(classes)
                
    with open(pred_path, 'w') as f_pred:
        for idx, pred in enumerate(preds):
            name = test_helper.namelist[idx]
            cls = train_helper.vocabulary.class_by_index(pred)
            print(name, cls, file=f_pred)
    _run_evaluation()

Epoch 51:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.6512412951106795


  0%|          | 0/380 [00:00<?, ?it/s]

pred_path =  ./pred.txt
Accuracy = 0.7456
Epoch 52:


  0%|          | 0/5036 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [71]:
model.load_state_dict(torch.load(f'arch1__center_epoch50.pth'))

<All keys matched successfully>

In [72]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        features, logits = model(X.to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

  0%|          | 0/380 [00:00<?, ?it/s]

In [73]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [74]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.8912


In [None]:
# get 0.8357 on 20 epochs
# get 0.8767 on 30 epochs