In [1]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# your path to data
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = r'/DATA/ichuviliaeva/ocr_data/gt.txt'

In [2]:
import cv2
import numpy as np
from centerloss import CenterLoss
from torchvision import transforms

In [3]:
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [4]:
train_helper, val_helper = train_helper.train_val_split()

In [5]:
train_helper.size(), val_helper.size()

(2578433, 644609)

In [6]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [7]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    def __init__(self, helper: HWDBDatasetHelper, transform = None):
        self.helper = helper
        self.transform = transform
    
    def __len__(self):
        return self.helper.size()
    
    def __getitem__(self, idx):
        img, label = self.helper.get_item(idx)
        if self.transform:
            img = self.transform(img)
        return img, label

In [8]:
class pad_prune(object):
    
    def __init__(self, output_w, output_h):
        self.output_w = output_w
        self.output_h = output_h

    def __call__(self, image):
        
        w, h = image.shape
        diff_w = w - self.output_w
        diff_h = h - self.output_h
        masked = np.zeros((self.output_w, self.output_h))
        if diff_w <= 0:
            if diff_h <= 0:
                masked[:w, :h] = image
            else:
                masked[:w, :] = image[:, (diff_h // 2):(h - (diff_h - diff_h // 2))]
            
        if diff_w > 0:
            if diff_h > 0:
                masked = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), 
                               (diff_h // 2):(h - (diff_h - diff_h // 2))]
            else:
                masked[:, :h] = image[(diff_w // 2):(w - (diff_w - diff_w // 2)), :]
        return masked

class to_zero_one(object):
    def __init__(self):
        pass
    def __call__(self, image):
        return image - 127.5 / 255.0 

In [9]:
to_one_shape = pad_prune(128, 128)

In [9]:
train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomPerspective(distortion_scale = 0.2, p = 0.5),  
    transforms.ColorJitter(),
    transforms.GaussianBlur(kernel_size = 5),
    to_zero_one()
])

# transforms.RandomPerspective(distortion_scale = 0.2, p = 0.5),
    

In [10]:
val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [11]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [10]:
class SimpleBlock(nn.Module):
        
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride, is_projection = False):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(num_features = in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 
                               kernel_size = kernel_size, padding = padding, stride = stride)
            
        self.bn2 = nn.BatchNorm2d(num_features = out_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels,
                               kernel_size = kernel_size, padding = padding, stride = 1)
        if is_projection:
            self.project = nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                               kernel_size = 1, padding = 0, stride = stride)
        else:
            self.project = lambda x: x
            
    def forward(self, x):            
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))
        return out + self.project(x)

In [11]:
class TheNet(nn.Module):
    def __init__(self, n_classes):
        super(TheNet, self).__init__()
        self.nn = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = 5, stride = 2),
            nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
            
            SimpleBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 8, out_channels = 8, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 8, out_channels = 16, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
           
            SimpleBlock(in_channels = 16, out_channels = 16, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            SimpleBlock(in_channels = 32, out_channels = 32, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            SimpleBlock(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1, stride = 1),
            SimpleBlock(in_channels = 64, out_channels = 128, kernel_size = 3, padding = 1, stride = 2, 
                             is_projection = True),
            
            nn.Flatten(),
            nn.Linear(2 * 2 * 128, 256),
            nn.BatchNorm1d(256), 
            nn.ReLU()
            
        )
        self.last_layer = nn.Linear(256, n_classes, bias=False)
        
            
    def forward(self, x):
        x = self.nn(x)
        y = self.last_layer(x)
        return x, y

In [None]:
model = TheNet(train_helper.vocabulary.num_classes())
model.eval()
res = model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 128, 128))
print(res[0].shape)
print(res[1].shape)

In [13]:
model = model.cuda()

In [28]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, 
                        num_workers=8, persistent_workers = True) # 8
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                        persistent_workers = True) # 8

In [14]:
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [15]:
center_loss = CenterLoss(num_classes=train_helper.vocabulary.num_classes(), feat_dim=256, use_gpu=True)
optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.3)

In [34]:
from tqdm.notebook import tqdm

def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    model.eval()
    n_good = 0
    n_all = 0
    wrapper = lambda x: x
    if n_steps is None:
        n_steps = len(val_loader)
        wrapper = tqdm
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(wrapper(val_loader)):
            if batch == n_steps:
                break
            features, logits = model(X.to(torch.float32).cuda())
            classes = torch.argmax(logits, dim=1).cpu().numpy()
            n_good += sum(classes == y.cpu().numpy())
            n_all += len(classes)
    
    return n_good / n_all


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn, alpha = 1.0):
    for batch, (X, y) in enumerate(tqdm(train_loader)):
        model.train()
        features, logits = model(X.to(torch.float32).cuda())
        labels = y.to(torch.long).cuda()
        loss = center_loss(features, labels) * alpha + loss_fn(logits, labels)
        
        optim.zero_grad()
        optimizer_centloss.zero_grad()
        
        loss.backward()
        optim.step()
        
        for param in center_loss.parameters():
             param.grad.data *= (1./alpha)

        optimizer_centloss.step()

In [196]:
for epoch in range(11):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.3)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')

Epoch 0:


  0%|          | 0/5036 [00:01<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    assert self._parent_pid == os.getpid(), 'can only test a child process': 
AssertionErrorcan only test a child process: 
can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuvil

    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f8ca8a0e3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    if w.is_alive():
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
assert self._parent_pid == os.getpid(), 'can only test a child process'    
AssertionErrorassert self._parent_pid == os.getpid(), 'can only test a child process': 
can only test a child processAssertionError
: can only test a child process


accuracy: 0.2501733609056032
Epoch 1:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.49453699839747817
Epoch 2:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.6235330254464334
Epoch 3:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7020333256284041
Epoch 4:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7630765316649317
Epoch 5:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7886703412456233
Epoch 6:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.7748387006697083
Epoch 7:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8147450625107623
Epoch 8:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8534243238924681
Epoch 9:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8479884705301973
Epoch 10:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8497492278264809


In [None]:
epoch = 11
# changing alpha
while epoch < 21:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.5)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
    epoch += 1

Epoch 11:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8654036788192532
Epoch 12:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8706270002435585
Epoch 13:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.878200583609599
Epoch 14:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8767020007477401
Epoch 15:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8822511010550582
Epoch 16:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.8796805505352857
Epoch 17:


  0%|          | 0/5036 [00:00<?, ?it/s]

In [23]:
model.load_state_dict(torch.load(f'arch2_aug_epoch15.pth'))

<All keys matched successfully>

In [24]:
epoch = 16
# changing alpha
while epoch < 26:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.5)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
    epoch += 1

Epoch 16:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

<function _MultiProcessingDataLoaderIter.

accuracy: 0.8822200744947712
Epoch 17:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
        if w.is_alive():self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    
assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception

    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    
if w.is_alive():Traceback (most recent call last):

Exception ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    
    assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):
self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

AssertionError    :   File "/home/ichuviliaeva/minicond

accuracy: 0.8672962989967562
Epoch 18:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    AssertionErrorself._shutdown_workers(): 
can only test a child process  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib





      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():assert self._parent_pid == os.getpid(), 'can only test a child process'    

    AssertionErrorif w.is_alive():assert self._parent_pid == os.getpid(), 'can only test a child process'
: if w.is_alive():  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is

if w.is_alive():        if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

if w.is_alive():
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Exception ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

assert self._parent_pid == os.getpid(), 'can only test a child process'          File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    
assert self._parent_pid == 


      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

    AssertionError: assert self._parent_pid == os.getpid(), 'can only test a child process'can only test a child process

AssertionError: can only test a child process
Exception ignored in: 

  0%|          | 0/315 [00:01<?, ?it/s]

<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/mini

accuracy: 0.8879599881478539
Epoch 19:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>AssertionError
: can only test a child processTraceback (most recent call last):

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>Exception ignored in: 

Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__

assert self._parent_pid == os.getpid(), 'can only test a child process'      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()self._shutdown_workers()
    Exception ignored in: self._shutdown_workers()

<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
AssertionError
    if w.is_alive():Traceback (most recent call last):

can only test a child process  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__

: can only test a child processException ignored in:     Traceback (most recent call last):


<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>AssertionError
if w.is_alive()::   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
can only test a child processself._shutdown_workers()

: can only test a child process  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    
assert self._parent_pid == os.getpid(), 'can only test a child process'    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    
AssertionError: assert self._parent_pid == os.getpid(), 'can only test a child process'Exce

accuracy: 0.886528112390612
Epoch 20:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/tor

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>Exception ignored in: 

<function

Traceback (most recent call last):
        self._shutdown_workers()self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>Exception ignored in: 

Exception ignored in: 
Traceback (most recent call last):
self._shutdown_workers()  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
      File "/home/ichuviliaeva/mi

assert self._parent_pid == os.getpid(), 'can only test a child process'    AssertionErrorException ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

    <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>self._shutdown_workers(): 

Traceback (most recent call last):
AssertionErrorcan only test a child process      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
: assert self._parent_pid == os.getpid(), 'can only test a child process'
can only test a child process
Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionE

        self._shutdown_workers()if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

    AssertionErrorassert self._parent_pid == os.getpid(), 'can only test a child process': 
can only test a child processAssertionError: 
can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

accuracy: 0.886577754887071
Epoch 21:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>

Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>self._shutdown_workers()Traceback (most recent call last):


  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
if w.is_alive():<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    AssertionError    
assert self._parent_pid == os.getpid(), 'can only test a child process'
        
:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

if w.is_alive():    

AssertionError  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
assert self._parent_pid == os.getpid(), 'can only test a child process'    self._shutdown_workers()if w.is_alive():: can only test a child process  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/mult


  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():    
self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
      File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    if w.is_alive():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Exception ignored in: Exception ignored in:     
<function _MultiProcessi


AssertionError
    AssertionErrorcan only test a child processif w.is_alive():: : can only test a child process

can only test a child process
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>

    Traceback (most recent call last):
Exception ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
AssertionErrorassert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>Traceback (most recent call last):


Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>  File "/home/ic

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
Exception ignored in:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>    
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
assert self._parent_pid == os.getpid(), 'can only test a child process'    
AssertionErrorself._shutdown_workers(): can only test a child process

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shut

KeyboardInterrupt: 

In [32]:
model.load_state_dict(torch.load(f'arch2_aug_epoch20.pth'))

<All keys matched successfully>

In [33]:
epoch = 21
# changing alpha
while epoch < 26:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.7)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
    epoch += 1

Epoch 21:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [19:52<00:00,  4.22it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:57<00:00,  5.46it/s]


accuracy: 0.8904188430505935
Epoch 22:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [19:54<00:00,  4.22it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:55<00:00,  5.65it/s]


accuracy: 0.8597708067991604
Epoch 23:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [20:37<00:00,  4.07it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:59<00:00,  5.31it/s]


accuracy: 0.8936409513363915
Epoch 24:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [19:07<00:00,  4.39it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:50<00:00,  6.24it/s]


accuracy: 0.8877086730095298
Epoch 25:


100%|███████████████████████████████████████████████████████████████████| 5036/5036 [18:12<00:00,  4.61it/s]
100%|█████████████████████████████████████████████████████████████████████| 315/315 [00:51<00:00,  6.13it/s]

accuracy: 0.8928265041288596





In [54]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.GaussianBlur(kernel_size = 3)
    ]),
    
    transforms.ColorJitter(),    
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [55]:
train_dataset = HWDBDataset(train_helper, train_transforms)
val_dataset = HWDBDataset(val_helper, val_transforms)

In [56]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, 
                          num_workers=8, persistent_workers = True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                        persistent_workers = True)

In [57]:
model.load_state_dict(torch.load(f'arch2_aug_epoch25.pth'))

<All keys matched successfully>

In [58]:
epoch = 26
# changing alpha
while epoch < 36:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 0.9)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
    epoch += 1

Epoch 26:


  0%|          | 0/5036 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>AssertionError: 
can only test a child processTraceback (most recent call last):

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataL

  0%|          | 0/315 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
if w.is_alive():  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    self._shutdown_workers()    assert

    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3237e8d3a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _Multi

accuracy: 0.9028542884135965
Epoch 27:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9047003687506691
Epoch 28:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9055411885344449
Epoch 29:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9067388137615205
Epoch 30:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.906248594108987
Epoch 31:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9048337829599028
Epoch 32:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9085057763698614
Epoch 33:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9087198596358412
Epoch 34:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9093326342015082
Epoch 35:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.909577744027775


In [68]:
epoch = 36
# changing alpha
while epoch < 51:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 1.0)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
    epoch += 1

Epoch 36:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9115882651343683
Epoch 37:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9138345880991423
Epoch 38:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9123220432851542
Epoch 39:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9059522904582468
Epoch 40:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9130992586203419
Epoch 41:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9135103605441438
Epoch 42:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9121079600191744
Epoch 43:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9159862800550411
Epoch 44:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9156930790603296
Epoch 45:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.916755738750157
Epoch 46:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9123468645333838
Epoch 47:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.916518385563962
Epoch 48:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9161181429362606
Epoch 49:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9174569390126418
Epoch 50:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9162732757376952


In [17]:
model.load_state_dict(torch.load(f'arch2_aug_epoch50.pth'))

<All keys matched successfully>

In [21]:
to_one_shape = pad_prune(128, 128)

train_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor(),
    transforms.RandomChoice([
        transforms.RandomPerspective(distortion_scale = 0.2, p = 0.7),  
        transforms.GaussianBlur(kernel_size = 5),
        transforms.RandomRotation(degrees=15)
    
    ]),
    
    transforms.ColorJitter(),    
    to_zero_one()
])

val_transforms = transforms.Compose([
    to_one_shape,
    transforms.ToTensor()
])

In [37]:
train_dataset = HWDBDataset(train_helper, val_transforms) # no agument
val_dataset = HWDBDataset(val_helper, val_transforms)

In [38]:
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, 
                        num_workers=8, persistent_workers = True)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8, 
                        persistent_workers = True)

In [24]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [25]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [27]:
from pathlib import Path

def evaluate(gt_path, pred_path):
    gt = dict()
    with open(gt_path) as gt_f:
        for line in gt_f:
            name, cls = line.strip().split()
            gt[name] = cls
    
    n_good = 0
    n_all = len(gt)
    with open(pred_path) as pred_f:
        for line in pred_f:
            name, cls = line.strip().split()
            if cls == gt[name]:
                n_good += 1
    
    return n_good / n_all


def _run_evaluation():
    base = '.'
    pred_path = base + '/pred.txt'
    score = evaluate(gt_path, pred_path)
    print('Accuracy = {:1.4f}'.format(score))

In [28]:
epoch = 51
# changing alpha
while epoch < 61:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 1.0)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
        
    epoch += 1

Epoch 51:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.911510698733651
Epoch 52:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.919458152151149
Epoch 53:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9211428943747295
Epoch 54:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9202183028781789
Epoch 55:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9216362166832918


  0%|          | 0/380 [00:00<?, ?it/s]

Accuracy = 0.8950
Epoch 56:


  0%|          | 0/5036 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [48]:
model.load_state_dict(torch.load(f'arch2_aug_epoch55.pth'))

<All keys matched successfully>

In [49]:
epoch = 56
# changing alpha
while epoch < 66:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 1.0)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
        
    epoch += 1

Epoch 56:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9214190307612832
Epoch 57:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.92093967040485
Epoch 58:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9235598634210817
Epoch 59:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9237491254388319
Epoch 60:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9226476825486458


  0%|          | 0/380 [00:00<?, ?it/s]

Accuracy = 0.8964
Epoch 61:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9218130680769272
Epoch 62:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9225515002117562
Epoch 63:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9234171412437617
Epoch 64:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.920818666819731
Epoch 65:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9197544558018892


  0%|          | 0/380 [00:00<?, ?it/s]

Accuracy = 0.8991


In [50]:
epoch = 66
# changing alpha
while epoch < 76:
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn, alpha = 1.0)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'arch2_aug_epoch{epoch}.pth')
        
        preds = []
        model.eval()
        with torch.no_grad():
            for X, _ in tqdm(test_loader):
                features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
                classes = torch.argmax(logits, dim=1).cpu().numpy()
                preds.extend(classes)
                
        with open(pred_path, 'w') as f_pred:
            for idx, pred in enumerate(preds):
                name = test_helper.namelist[idx]
                cls = train_helper.vocabulary.class_by_index(pred)
                print(name, cls, file=f_pred)
        _run_evaluation()
        
    epoch += 1

Epoch 66:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9207147278427698
Epoch 67:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9247869638804298
Epoch 68:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9135274251523017
Epoch 69:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9183148234045755
Epoch 70:


  0%|          | 0/5036 [00:00<?, ?it/s]

  0%|          | 0/315 [00:00<?, ?it/s]

accuracy: 0.9247993745045446


  0%|          | 0/380 [00:00<?, ?it/s]

Accuracy = 0.9041
Epoch 71:


  0%|          | 0/5036 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [51]:
model.load_state_dict(torch.load(f'arch2_aug_epoch70.pth'))

<All keys matched successfully>

In [52]:
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [53]:
test_dataset = HWDBDataset(test_helper, to_one_shape)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [54]:
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        features, logits = model(X.unsqueeze(1).to(torch.float32).cuda())
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

  0%|          | 0/380 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7faf246023a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7faf246023a0>
Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/

<function _MultiProcessingDataLoaderIter.__del__ at 0x7faf246023a0>        
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):
  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

    self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'AssertionError

AssertionError:   File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
: 
    can only test a child processif w.is_alive():

  File "/home/ichuviliaeva/miniconda3/envs/is_project/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: Exception ignored in: 

AssertionError<function _MultiProcessingDataLoaderIter.__del__ at 0x7faf246023a0><function _MultiProcessingD

In [55]:
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [56]:
!python -m course_ocr_t2.evaluate
# Accuracy = 0.7978

pred_path =  /home/ichuviliaeva/ocr_course/course_ocr/task2/pred.txt
Accuracy = 0.9041
