In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.transforms.functional import to_tensor, to_pil_image

import pandas as pd
from PIL import Image
from tqdm import tqdm
import numpy as np

In [2]:
class CaptchaDataset(Dataset):
    """Captcha images paired with CTC-ready integer targets.

    Reads a CSV whose ``filename`` column holds image paths and whose ``label``
    column holds the text of each captcha. Each item is the tuple
    ``(image, target, input_length, target_length)``.

    Args:
        csv_path: path to the CSV index file.
        max_length: length of the zero-padded target vector returned per item.
        input_length: number of time steps the model emits per image; returned
            unchanged as the CTC input length of every sample.
        characters: alphabet string; a character's position is its class index.
    """

    def __init__(self, csv_path, max_length, input_length, characters):
        super(CaptchaDataset, self).__init__()
        self.df = pd.read_csv(csv_path)
        self.max_length = max_length
        self.characters = characters
        self.input_length = input_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, target, input_length, target_length)`` for row ``index``.

        ``input_length`` and ``target_length`` are LongTensors of shape (1,).
        """
        label = str(self.df.label[index])
        # The context manager guarantees the file handle is released, even if
        # decoding the image raises.
        with Image.open(self.df.filename[index]) as img:
            image = to_tensor(img)
        target = self.encode(label)
        input_length = torch.full(size=(1, ), fill_value=self.input_length, dtype=torch.long)
        target_length = torch.full(size=(1, ), fill_value=len(label), dtype=torch.long)
        return image, target, input_length, target_length

    def encode(self, label):
        """Map ``label`` to a zero-padded LongTensor of class indices.

        Args:
            label: string whose characters must all appear in ``self.characters``.

        Returns:
            LongTensor of shape (max_length,). Positions past ``len(label)``
            stay 0, the index of ``characters[0]``.

        Raises:
            ValueError: if ``label`` is longer than ``max_length`` or contains
                a character that is not in ``characters``.
        """
        if len(label) > self.max_length:
            raise ValueError(
                f'label {label!r} has {len(label)} characters; max_length is {self.max_length}')
        target = torch.zeros(size=(self.max_length, ), dtype=torch.long)
        for i, c in enumerate(label):
            class_index = self.characters.find(c)
            # str.find returns -1 for a missing character. Letting that through
            # would hand an invalid class index to the CTC loss.
            if class_index < 0:
                raise ValueError(f'character {c!r} in label {label!r} is not in the alphabet')
            target[i] = class_index
        return target

In [3]:
# Index 0 (' ') is both the padding value used by CaptchaDataset.encode and the
# default blank index of F.ctc_loss.
characters = ' (0+)9=*867154-32'
n_classes = len(characters)
batch_size = 128
seed = 42  # fixes the train/valid split (and later weight init) across runs

# input_length=37 must match the number of time steps the model emits for 64x300 images.
dataset = CaptchaDataset(csv_path='train.csv', max_length=11, input_length=37, characters=characters)

# Hold out 20% for validation. Derive the sizes from the data so the split does
# not break when the CSV length changes.
n_valid = len(dataset) // 5
n_train = len(dataset) - n_valid
torch.manual_seed(seed)
train_set, valid_set = random_split(dataset, [n_train, n_valid])

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
# Order does not matter for evaluation, so don't shuffle.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [4]:
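# Architecture follows CRNN: Shi, Bai & Yao, "An End-to-End Trainable Neural Network
# for Image-based Sequence Recognition and Its Application to Scene Text Recognition".
# https://arxiv.org/pdf/1507.05717.pdf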

In [5]:
class Model(nn.Module):
    """CRNN-style sequence recognizer.

    A convolutional stack reduces the image height to a single row. A
    bidirectional LSTM then reads the remaining width axis as a sequence, and a
    linear layer produces per-time-step class logits.

    For a [N, 3, 64, 300] input the conv stack yields [N, 256, 1, 37], so
    ``forward`` returns time-major logits of shape [37, N, n_classes].
    """

    def __init__(self, n_classes):
        super(Model, self).__init__()
        # One entry per conv stage, in order: (in_channels, out_channels, pool).
        # Every conv is 3x3, stride 1, padding 1, followed by BatchNorm + ReLU.
        # pool is the MaxPool2d kernel, or None for no pooling. A (2, 1) kernel
        # halves the height but keeps the width (the future time axis).
        stages = (
            (3, 32, 2),            # -> [N, 32, 32, 150]
            (32, 64, 2),           # -> [N, 64, 16, 75]
            (64, 128, None),       # -> [N, 128, 16, 75]
            (128, 128, 2),         # -> [N, 128, 8, 37]
            (128, 256, (2, 1)),    # -> [N, 256, 4, 37]
            (256, 256, (2, 1)),    # -> [N, 256, 2, 37]
            (256, 256, (2, 1)),    # -> [N, 256, 1, 37]
        )
        layers = []
        for in_ch, out_ch, pool in stages:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
            if pool is not None:
                layers.append(nn.MaxPool2d(pool))
        self.cnn = nn.Sequential(*layers)
        self.lstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=1, bidirectional=True)
        self.fc = nn.Linear(in_features=512, out_features=n_classes)

    def forward(self, x):
        features = self.cnn(x)                            # [N, 256, 1, W]
        sequence = features.squeeze(2).permute(2, 0, 1)   # [W, N, 256]
        recurrent, _ = self.lstm(sequence)                # [W, N, 512]
        return self.fc(recurrent)                         # [W, N, n_classes]

In [6]:
# Sanity-check the architecture on a dummy batch of 32 three-channel 64x300 images.
model = Model(n_classes)
inputs = torch.zeros((32, 3, 64, 300))
with torch.no_grad():  # shape probe only; no autograd graph needed
    outputs = model(inputs)
# Time-major logits [T, N, n_classes]. T must equal the CTC input length the
# dataset reports for every sample, because F.ctc_loss requires input_lengths <= T.
assert outputs.shape == (dataset.input_length, inputs.shape[0], n_classes), outputs.shape
outputs.shape

torch.Size([37, 32, 17])

In [7]:
# Fresh model for training, placed on the GPU; echo it to show the layer summary.
model = Model(n_classes).cuda()
model

Model(
  (cnn): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mod

In [8]:
def decode_target(out, charset=None):
    """Turn a padded target vector of class indices back into its label string.

    Args:
        out: iterable of class indices, e.g. one row of a batched target tensor.
        charset: alphabet to index into; defaults to the notebook-level
            ``characters``.

    Returns:
        The label text with every ' ' (the padding/blank symbol) removed.
    """
    if charset is None:
        charset = characters
    return ''.join(charset[i] for i in out).replace(' ', '')
    
def decode(out, charset=None):
    """Greedy CTC decoding of one sequence of per-time-step class indices.

    Collapses each run of identical consecutive symbols to one symbol, then
    drops the blank ' '. Repeated characters therefore survive only when a
    blank separates them.

    Args:
        out: iterable of class indices, e.g. one row of the argmax output.
        charset: alphabet to index into; defaults to the notebook-level
            ``characters``.

    Returns:
        The decoded string ('' for empty input or all-blank predictions).
    """
    if charset is None:
        charset = characters
    symbols = [charset[i] for i in out]
    # Keep the first symbol of every run. This also handles a single run
    # (e.g. length-1 input), which the previous version decoded to ''.
    collapsed = ''.join(s for k, s in enumerate(symbols) if k == 0 or s != symbols[k - 1])
    return collapsed.replace(' ', '')

In [9]:
optimizer = torch.optim.Adadelta(model.parameters())
n_epochs = 10


def count_correct(target, output):
    """Count sequences in a batch whose greedy CTC decoding matches the label exactly.

    Args:
        target: [N, max_length] padded class indices (any device).
        output: [T, N, n_classes] raw logits from ``model``.

    Returns:
        Number of exact matches, as an int.
    """
    # [T, N, C] -> best class per time step -> [N, T], moved to the CPU.
    pred_indices = output.detach().argmax(dim=-1).permute(1, 0).cpu()
    # Decode from a CPU copy. Indexing a str with CUDA scalars would sync the
    # device once per element.
    true_indices = target.cpu()
    return sum(decode_target(true) == decode(pred)
               for true, pred in zip(true_indices, pred_indices))


for epoch in range(1, n_epochs + 1):
    model.train()
    with tqdm(train_loader) as pbar:
        loss_new = 0
        acc_new = 0
        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = model(data)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()

            loss = loss.item()
            acc = count_correct(target, output) / len(target)

            # Exponential moving average for the progress bar, seeded with the first batch.
            if batch_index == 0:
                loss_new = loss
                acc_new = acc
            loss_new = 0.1 * loss + 0.9 * loss_new
            acc_new = 0.1 * acc + 0.9 * acc_new
            pbar.set_description(f'Epoch: {epoch} Loss: {loss_new:.4f} Acc: {acc_new:.4f} ')

    model.eval()
    # Evaluation needs no autograd graph; building one wastes memory and time.
    with torch.no_grad(), tqdm(valid_loader) as pbar:
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for data, target, input_lengths, target_lengths in pbar:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)

            # Weight by batch size so the smaller final batch is not over-counted.
            n_batch = len(target)
            loss_sum += loss.item() * n_batch
            n_correct += count_correct(target, output)
            n_seen += n_batch
            pbar.set_description(f'Valid: {epoch} Loss: {loss_sum / n_seen:.4f} '
                                 f'Acc: {n_correct / n_seen:.4f} ')

Epoch: 1 Loss: 0.2940 Acc: 0.4270 : 100%|██████████| 625/625 [01:21<00:00,  7.74it/s]
Valid: 1 Loss: 0.4830 Acc: 0.2047 : 100%|██████████| 157/157 [00:11<00:00, 14.22it/s]
Epoch: 2 Loss: 0.0568 Acc: 0.8892 : 100%|██████████| 625/625 [01:21<00:00,  7.75it/s]
Valid: 2 Loss: 0.7936 Acc: 0.0969 : 100%|██████████| 157/157 [00:10<00:00, 14.39it/s]
Epoch: 3 Loss: 0.0127 Acc: 0.9719 : 100%|██████████| 625/625 [01:21<00:00,  7.71it/s]
Valid: 3 Loss: 0.0250 Acc: 0.9448 : 100%|██████████| 157/157 [00:10<00:00, 14.30it/s]
Epoch: 4 Loss: 0.0088 Acc: 0.9801 : 100%|██████████| 625/625 [01:21<00:00,  7.54it/s]
Valid: 4 Loss: 0.1720 Acc: 0.6711 : 100%|██████████| 157/157 [00:10<00:00, 14.49it/s]
Epoch: 5 Loss: 0.0085 Acc: 0.9840 : 100%|██████████| 625/625 [01:22<00:00,  7.74it/s]
Valid: 5 Loss: 0.0126 Acc: 0.9734 : 100%|██████████| 157/157 [00:10<00:00, 14.29it/s]
Epoch: 6 Loss: 0.0044 Acc: 0.9920 : 100%|██████████| 625/625 [01:22<00:00,  7.81it/s]
Valid: 6 Loss: 0.0205 Acc: 0.9579 : 100%|██████████| 1

In [12]:
# Whole-module pickle, kept so existing `torch.load('model.pth')` callers still work.
# It can only be unpickled where the `Model` class is importable under the same name.
torch.save(model, 'model.pth')
# Portable checkpoint, independent of how the class is pickled:
#   model = Model(n_classes); model.load_state_dict(torch.load('model_state_dict.pth'))
torch.save(model.state_dict(), 'model_state_dict.pth')

  "type " + obj.__name__ + ". It won't be checked "
