In [68]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import Compose, ToTensor


import numpy as np
import pandas as pd
import time
import os
import matplotlib.pyplot as plt
from PIL import Image

In [69]:
# Class alphabet: the ten decimal digits. A label character's position in
# `alphabet` is the class index used when one-hot encoding targets.
source = [str(i) for i in range(10)]
alphabet = ''.join(source)

In [70]:
#load the data
def img_loader(img_path):
    """Load the image at ``img_path`` and return it converted to RGB.

    ``Image.open`` keeps the file handle open until closed; the ``with`` block
    closes it. ``convert`` loads the pixel data and returns a new image, so the
    returned object remains usable after the source file is closed.
    """
    with Image.open(img_path) as img:
        return img.convert('RGB')

def make_dataset(data_path, alphabet, num_class, num_char):
    """Index the ``.png`` captcha images in ``data_path`` with one-hot targets.

    File names are expected to look like ``<prefix>_<label>...png``: the first
    ``num_char`` characters after the first underscore are the label.

    Args:
        data_path: directory scanned (non-recursively) for ``.png`` files.
        alphabet: string whose character positions define class indices.
        num_class: length of each per-character one-hot vector.
        num_char: number of label characters per image.

    Returns:
        A list of ``(img_path, target)`` tuples ordered by file name, where
        ``target`` is a flat list of ``num_char * num_class`` ints built by
        concatenating one one-hot vector per label character.

    Raises:
        ValueError: if a file name has no ``_``, its label is shorter than
            ``num_char``, or a label character is not in ``alphabet`` (or maps
            to an index >= ``num_class``).
    """
    samples = []
    # sorted() makes the sample order independent of the filesystem.
    for img_name in sorted(os.listdir(data_path)):
        if not img_name.endswith('.png'):
            continue
        img_path = os.path.join(data_path, img_name)
        parts = img_name.split('_')
        if len(parts) < 2:
            raise ValueError('no "_<label>" part in file name: {!r}'.format(img_name))
        target_str = parts[1][:num_char]
        # Explicit checks instead of assert, which is stripped under `python -O`.
        if len(target_str) != num_char:
            raise ValueError('label {!r} in {!r} is shorter than {} chars'.format(
                target_str, img_name, num_char))
        target = []
        for char in target_str:
            idx = alphabet.find(char)
            # find() returns -1 for unknown chars; vec[-1] would silently mark
            # the last class, so reject it (and out-of-range indices) here.
            if not 0 <= idx < num_class:
                raise ValueError('label char {!r} in {!r} is not a valid class'.format(
                    char, img_name))
            vec = [0] * num_class
            vec[idx] = 1
            target += vec
        samples.append((img_path, target))
    return samples

class CaptchaData(Dataset):
    """Dataset of captcha images whose labels are encoded in their file names.

    Samples are indexed once at construction via ``make_dataset``; images are
    read from disk lazily in ``__getitem__``.

    Args:
        data_path: directory containing the ``.png`` images.
        num_class: number of classes per character (length of each one-hot).
        num_char: number of characters per captcha.
        transform: optional callable applied to the RGB image.
        target_transform: optional callable applied to the flat one-hot list.
        alphabet: string mapping characters to class indices; the default is
            the module-level ``alphabet`` as it was when the class was defined.
    """

    def __init__(self, data_path, num_class=10, num_char=4,
                 transform=None, target_transform=None, alphabet=alphabet):
        # super(Dataset, self) would start the MRO lookup *after* Dataset and
        # skip any initialiser Dataset defines; name this class instead.
        super(CaptchaData, self).__init__()
        self.data_path = data_path
        self.num_class = num_class
        self.num_char = num_char
        self.transform = transform
        self.target_transform = target_transform
        self.alphabet = alphabet
        self.samples = make_dataset(self.data_path, self.alphabet,
                                    self.num_class, self.num_char)

    def __len__(self):
        """Number of indexed images."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``target`` is returned as a tensor of the default floating dtype; the
        optional ``target_transform`` may return a sequence or a tensor.
        """
        img_path, target = self.samples[index]
        img = img_loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # as_tensor accepts lists and tensors alike (torch.Tensor(t) rejects a
        # tensor whose dtype differs from the default).
        return img, torch.as_tensor(target, dtype=torch.get_default_dtype())

In [71]:
#create CNN
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Four conv/pool/batchnorm/ReLU stages followed by one linear classifier.

    The classifier's ``in_features`` (512 * 11 * 6) assumes inputs of shape
    (batch, 3, 180, 100) per the layer comments; four 2x2 max-pools reduce the
    spatial size to 11x6. ``forward`` returns raw scores (no activation) of
    shape (batch, num_class * num_char).

    Args:
        num_class: number of classes per character.
        num_char: number of characters predicted per image.
    """

    def __init__(self, num_class=10, num_char=4):
        super(CNN, self).__init__()
        self.num_class = num_class
        self.num_char = num_char
        self.conv = nn.Sequential(
                # batch*3*180*100
                nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=(1, 1)),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(num_features=16),
                nn.ReLU(),
                # batch*16*90*50
                nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, padding=(1, 1)),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(num_features=64),
                nn.ReLU(),
                # batch*64*45*25
                nn.Conv2d(in_channels=64, out_channels=512, kernel_size=3, padding=(1, 1)),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(num_features=512),
                nn.ReLU(),
                # batch*512*22*12
                nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=(1, 1)),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(num_features=512),
                nn.ReLU(),
                # batch*512*11*6
                )
        self.fc = nn.Linear(in_features=512*11*6, out_features=self.num_class*self.num_char)

    def forward(self, x):
        """Map an image batch to (batch, num_class * num_char) scores."""
        x = self.conv(x)
        # Flatten per sample, keeping the batch dimension. view(-1, 512*11*6)
        # could silently regroup a mis-sized input into a different batch size;
        # now the Linear layer raises on a feature-size mismatch instead.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [72]:
#training


batch_size = 128
base_lr = 0.001
max_epoch = 200
model_path = './checkpoints/model.pth'
restor = False  # True: load weights from model_path before training

# Create the checkpoint directory from model_path itself so the two cannot
# drift apart; exist_ok avoids the exists()/mkdir() race and makes re-runs safe.
os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)

def calculat_acc(output, target, num_class=10, num_char=4):
    """Fraction of samples whose every predicted character matches the label.

    Args:
        output: model scores whose elements group into consecutive runs of
            ``num_class`` per character and ``num_char`` characters per sample
            (e.g. shape (batch, num_char * num_class)).
        target: one-hot labels laid out like ``output``.
        num_class: classes per character (default 10).
        num_char: characters per sample (default 4).

    Returns:
        A Python float in [0, 1]; 0.0 for an empty batch.
    """
    # argmax is unchanged by softmax, so compare raw scores directly.
    pred = output.reshape(-1, num_class).argmax(dim=1).view(-1, num_char)
    true = target.reshape(-1, num_class).argmax(dim=1).view(-1, num_char)
    n_samples = pred.size(0)
    if n_samples == 0:
        return 0.0
    # A sample counts as correct only if all of its characters are right.
    n_correct = int((pred == true).all(dim=1).sum())
    return n_correct / n_samples

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return (mean loss, mean accuracy).

    With an ``optimizer`` the model is trained (train mode, gradients, step);
    without one it is evaluated in eval mode with gradients disabled. Means are
    weighted by batch size so a smaller final batch is counted correctly.
    Returns (nan, nan) when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, acc_sum, n_seen = 0.0, 0.0, 0
    with torch.set_grad_enabled(training):
        for img, target in loader:
            img, target = img.to(device), target.to(device)
            output = model(img)
            loss = criterion(output, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n = img.size(0)
            # .item() yields a plain float, so no graph-attached tensor is kept.
            loss_sum += loss.item() * n
            acc_sum += calculat_acc(output, target) * n
            n_seen += n
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, acc_sum / n_seen


def train():
    """Train ``CNN`` on ./data/train, evaluating on ./data/test after each epoch.

    Uses the module-level settings ``batch_size``, ``base_lr``, ``max_epoch``,
    ``model_path`` and ``restor``. Each epoch prints mean train/test loss and
    accuracy plus wall time, then overwrites the checkpoint at ``model_path``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transforms = Compose([ToTensor()])
    train_dataset = CaptchaData('./data/train', transform=transforms)
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0,
                                   shuffle=True, drop_last=True)
    test_data = CaptchaData('./data/test', transform=transforms)
    # Score every test sample: no shuffling needed, and no dropped tail batch.
    test_data_loader = DataLoader(test_data, batch_size=batch_size,
                                  num_workers=0, shuffle=False, drop_last=False)

    cnn = CNN().to(device)
    if restor:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        cnn.load_state_dict(torch.load(model_path, map_location=device))

    optimizer = torch.optim.Adam(cnn.parameters(), lr=base_lr)
    criterion = nn.MultiLabelSoftMarginLoss()
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)

    for epoch in range(max_epoch):
        start_ = time.time()
        train_loss, train_acc = _run_epoch(cnn, train_data_loader, criterion,
                                           device, optimizer)
        print('train_loss: {:.4}|train_acc: {:.4}'.format(train_loss, train_acc))
        # The test loss is now computed on the test batches themselves (it
        # previously repeated the last training-batch loss).
        test_loss, test_acc = _run_epoch(cnn, test_data_loader, criterion, device)
        print('test_loss: {:.4}|test_acc: {:.4}'.format(test_loss, test_acc))
        print('epoch: {}|time: {:.4f}'.format(epoch, time.time()-start_))
        torch.save(cnn.state_dict(), model_path)

# Entry point: start training when this code runs as the main module (a
# Jupyter kernel also sets __name__ to "__main__", so running the cell trains).
if __name__ == "__main__":
    train()

KeyboardInterrupt: 