In [1]:
!which python

/Users/administrator/opt/miniconda3/envs/made-ml/bin/python


Подготовка данных. 
Нужно реализовать класс данных (наследник torch.utils.data.Dataset). Класс должен считывать входные изображения и выделять метки из имён файлов. Для чтения изображений предлагается использовать библиотеку Pillow. Дирректория содержит набор данных, который необходимо разделить на тренировочную и тестовую выборки в отношении четыре к одному.

Создание и обучение модели. 
Код модели должен быть реализован через слои стандартной библиотеки torch (torchvision.models и аналоги использовать нельзя). Поскольку число символов в капче фиксировано, можно использовать обычный кросс- энтропийный критерий. Желающие могут использовать и CTC-loss. Цикл обучения можно реализовать самостоятельно или воспользоваться библиотеками PyTorch Lightning / Catalyst.

Подсчет метрик. 
После обучения нужно оценить точность предсказания на тестовой выборке. В качестве метрики предлагается использовать долю неверно распознанных символов, Character Error Rate (CER).

Анализ ошибок модели. 
В этой секции нужно найти изображения из тестового корпуса, на которых модель ошибается сильнее всего (по loss или по CER). Предлагается выписать в ноутбук возможные причины появления этих ошибок и пути устранения.

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.text import CharErrorRate
import torchvision
import torchvision.transforms
import lightning.pytorch as pl
import pandas as pd
import numpy as np
from PIL import Image
import os
import string

In [3]:
IMG_DIR = "./samples"
TORCH_RANDOM = torch.Generator().manual_seed(42)
BATCH_SIZE = 32
EPOCHS = 150

# Подготовка данных

In [4]:
alph = string.ascii_lowercase + string.digits

index2char = {k: v for k, v in enumerate(alph)}
char2index = {v: k for k, v in index2char.items()}

NUM_CLASS = len(char2index)

In [None]:
def encode()

In [5]:
def decode(tnsr):
    return list(map(lambda seq: ''.join([index2char[ind.item()] for ind in seq if ind != 0]), tnsr))

In [6]:
class ImageDataset(Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir
        self.files = list(filter(lambda x: "." in x, os.listdir(self.img_dir)))
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_name = self.files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('L')
        img_label = torch.tensor([char2index[x] for x in img_name.split(".")[0]])
        
        return self.transform(image), img_label

In [7]:
dataset = ImageDataset(IMG_DIR)

In [8]:
len(dataset.files)

1070

In [9]:
train_dataset, valid_dataset, test_dataset = random_split(dataset, [0.7, 0.1, 0.2], TORCH_RANDOM)
# train_dataset, test_dataset = random_split(dataset, [0.8, 0.2], TORCH_RANDOM)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Создание и обучение модели

In [113]:
class CRNN(nn.Module):
    def __init__(self, in_channels, output):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 128, 3),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 512, 3),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(3, 3),
            nn.Conv2d(512, 512, 3),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        
        self.lstm = nn.LSTM(512, 256, 2, bidirectional=True, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.LayerNorm([18, 128]),
            nn.Linear(128, output)
        )

    def forward(self, X):
        out = self.cnn(X)        
        out = torch.flatten(out, start_dim=1, end_dim=2)
        out = out.permute(0, 2, 1)
        out, _ = self.lstm(out)
        out = self.fc(out).permute(1, 0, 2)
        
        return out.log_softmax(-1)

In [114]:
class LitModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.CTCLoss()
        self.cer = CharErrorRate()

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)

        input_lengths = torch.full(size=(out.size(1),), fill_value=out.size(0), dtype=torch.int32)
        target_lengths = torch.full(size=(out.size(1),), fill_value=5, dtype=torch.int32)
        
        return self.criterion(out, y, input_lengths, target_lengths)
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)

        input_lengths = torch.full(size=(out.size(1),), fill_value=out.size(0), dtype=torch.int32)
        target_lengths = torch.full(size=(out.size(1),), fill_value=5, dtype=torch.int32)
        
        loss = self.criterion(out, y, input_lengths, target_lengths)
        
        preds = out.argmax(2).T
        
        values = {"loss": loss, "cer": self.cer(decode(preds), decode(y))}
        self.log_dict(values, prog_bar=True)
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)

        input_lengths = torch.full(size=(out.size(1),), fill_value=out.size(0), dtype=torch.int32)
        target_lengths = torch.full(size=(out.size(1),), fill_value=5, dtype=torch.int32)
        
        loss = self.criterion(out, y, input_lengths, target_lengths)
        
        preds = out.argmax(2).T
        
        values = {"loss": loss, "cer": self.cer(decode(preds), decode(y))}
        self.log_dict(values, prog_bar=True)
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1,
            patience=3, eps=1e-4, verbose=True
        )
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "loss"}

In [115]:
model = LitModel(CRNN(in_channels=1, output=NUM_CLASS))

trainer = pl.Trainer(accelerator="cpu", log_every_n_steps=5, max_epochs=EPOCHS)
trainer.fit(
    model=model, 
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type          | Params
--------------------------------------------
0 | model     | CRNN          | 6.4 M 
1 | criterion | CTCLoss       | 0     
2 | cer       | CharErrorRate | 0     
--------------------------------------------
6.4 M     Trainable params
0         Non-trainable params
6.4 M     Total params
25.465    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 00022: reducing learning rate of group 0 to 1.0000e-04.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

# Подсчет метрик

In [116]:
trainer.test(model, dataloaders=train_dataloader)

Testing: 0it [00:00, ?it/s]

[{'loss': 0.0015812579076737165, 'cer': 0.047263018786907196}]

In [117]:
trainer.test(model, dataloaders=valid_dataloader)

Testing: 0it [00:00, ?it/s]

[{'loss': 0.018259771168231964, 'cer': 0.05607476457953453}]

In [118]:
trainer.test(model, dataloaders=test_dataloader)

Testing: 0it [00:00, ?it/s]

[{'loss': 0.030655093491077423, 'cer': 0.055140186101198196}]

# Анализ ошибок модели

In [119]:
model.model.eval()
result = [[], []]

with torch.no_grad():
    for x, y in test_dataloader:
        preds = model.model(x)
        preds = preds.argmax(2).T
        result[0].extend(decode(y))
        result[1].extend(decode(preds))        

print(f"cer: {model.cer(result[0], result[1])}")
model.model.train();

cer: 0.05291479825973511


In [120]:
for i in range(len(result[0])):
    if result[0][i] != result[1][i]:
        print(f"truth - {result[0][i]} | pred - {result[1][i]}")

truth - m8gmx | pred - m8gmmx
truth - nn4wx | pred - nn44wx
truth - n7enn | pred - n7ennn
truth - efg72 | pred - efgg72
truth - d2ycw | pred - d2yw
truth - bw5nf | pred - bww5nf
truth - 6e2dg | pred - 66e2dg
truth - nfd8g | pred - nfdd8g
truth - ef4mn | pred - efmn
truth - 5p8fm | pred - 5p88fm
truth - mfc35 | pred - mfcc35
truth - myf82 | pred - myyf82
truth - d22bd | pred - d222bd
truth - p2m6n | pred - p22m6n
truth - 664nf | pred - 6664nf
truth - m3wfw | pred - m33wfw
truth - 6825y | pred - 68825y
truth - myc3c | pred - mycc3c
truth - 6n5fd | pred - 66n5fd
truth - 6xen4 | pred - 66xeen4
truth - gmmne | pred - gmmmne
truth - 6dmx7 | pred - 6dnx7
truth - be3bp | pred - be33bp
truth - cewnm | pred - ccewnm
truth - 53mn8 | pred - 53mnn8
truth - f228n | pred - f2228n
truth - pgwnp | pred - pwnp
truth - cfw6e | pred - ccfw6e
truth - bw44w | pred - bw444w
truth - w6pxy | pred - w6pxxy
truth - cnex4 | pred - cnexx4
truth - c8fxy | pred - cc8fxy
truth - p8c24 | pred - p88c224
truth - mmfm6 |

Большинство ошибко - предсказание лишних повторяющихся символов