In [1]:
import numpy as np
import json
import torch

import const
import my_datasets
from my_models import MyResNet

In [2]:
# Batch size passed to my_datasets.get_dataloader when building the training loader.
TRAIN_BATCH_SIZE = 64
# Batch size passed to my_datasets.get_dataloader when building the evaluation loaders.
TEST_BATCH_SIZE = 256

In [3]:
def evaluate(model, dataloader):
    """Return the whole-string accuracy of ``model`` over ``dataloader``.

    Each batch must be an ``(img, _, label)`` triple; the middle element is
    ignored. The model output is reshaped to
    ``(-1, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA)`` and the argmax over
    axis 1 picks one entry of ``const.ALL_CHAR_SET`` per position. The picked
    entries are joined into a string, and a sample counts as correct only when
    that string equals the corresponding entry of ``label``.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is switched to eval mode
            and left in that mode.
        dataloader: Iterable of ``(img, _, label)`` batches.

    Returns:
        Fraction of samples predicted exactly, or ``0.0`` when the dataloader
        yields no samples.
    """
    model.eval()
    charset = np.array(const.ALL_CHAR_SET)
    n_classes = len(const.ALL_CHAR_SET)

    # Feed inputs to the device the model lives on instead of assuming CUDA;
    # a parameter-less model receives the batch unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    count_correct = 0
    count_all = 0
    with torch.no_grad():
        for img, _, label in dataloader:
            if device is not None:
                img = img.to(device)
            scores = model(img).reshape((-1, n_classes, const.MAX_CAPTCHA)).cpu()
            # Convert to numpy explicitly so the charset lookup does not rely
            # on implicit tensor -> ndarray coercion inside numpy indexing.
            char_idx = scores.argmax(dim=1).numpy()
            decoded = [''.join(row) for row in charset[char_idx]]
            count_correct += sum(guess == truth for guess, truth in zip(decoded, label))
            count_all += len(decoded)

    # An empty dataloader would otherwise raise ZeroDivisionError.
    if count_all == 0:
        return 0.0
    return count_correct / count_all

def train(model, dataloader, loss_func, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the summed loss.

    Each batch must be an ``(img, label_idx, label)`` triple; the third
    element is ignored. The model output is reshaped to
    ``(-1, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA)`` before being passed,
    together with ``label_idx``, to ``loss_func``.

    Args:
        model: ``torch.nn.Module`` to train; it is switched to train mode.
        dataloader: Iterable of ``(img, label_idx, label)`` batches.
        loss_func: Callable ``loss_func(pred, label_idx)`` returning a scalar
            tensor that supports ``backward()``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Sum (not mean) of the per-batch loss values over the whole pass.
    """
    model.train()
    n_classes = len(const.ALL_CHAR_SET)

    # Feed batches to the device the model lives on instead of assuming CUDA;
    # a parameter-less model receives the batch unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    loss_epoch = 0.0
    for img, label_idx, _ in dataloader:
        if device is not None:
            img = img.to(device)
            label_idx = label_idx.to(device)
        pred = model(img).reshape((-1, n_classes, const.MAX_CAPTCHA))

        loss = loss_func(pred, label_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_epoch += loss.item()
    return loss_epoch


In [4]:
# Keys of my_datasets.DATA_DIRS, materialised as a list so they can be indexed.
DATA_NAMES = [*my_datasets.DATA_DIRS.keys()]

In [13]:
# Every pair among the first four dataset names, in index order:
# (0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3).
train_data_list = [
    (DATA_NAMES[first], DATA_NAMES[second])
    for first in range(4)
    for second in range(first + 1, 4)
]
print(train_data_list)

[('captcha-images', 'captcha-09az'), ('captcha-images', 'capitalized'), ('captcha-images', 'capital-color'), ('captcha-09az', 'capitalized'), ('captcha-09az', 'capital-color'), ('capitalized', 'capital-color')]


In [17]:
def process(model, loss_func, optimizer, train_data_names, num_epochs=40, target_acc=0.9):
    """Train ``model`` on the given datasets, checkpointing and logging each epoch.

    After every epoch the model is evaluated separately on each dataset named
    in ``my_datasets.DATA_DIRS``. The model's ``state_dict`` is saved to
    ``<const.RECOGNIZER_DIR>/<'+'.join(train_data_names)>/<epoch>.pt``, and the
    cumulative log is rewritten to ``log.json`` in the same directory.

    Training stops early once the mean accuracy on the datasets the model was
    trained on exceeds ``target_acc``.

    Args:
        model: Model to train; must provide ``state_dict()``.
        loss_func: Loss callable forwarded to ``train``.
        optimizer: Optimizer forwarded to ``train``.
        train_data_names: Sequence of dataset names; passed to
            ``my_datasets.get_dataloader`` and joined with ``'+'`` to name the
            output directory.
        num_epochs: Maximum number of epochs to run (default 40).
        target_acc: Early-stopping threshold on the training datasets' mean
            accuracy (default 0.9).

    Returns:
        The list of per-epoch log records (the same data written to
        ``log.json``). Each record holds ``epoch``, ``loss`` and one
        ``acc<i>`` entry per evaluated dataset.
    """
    # NOTE(review): the third get_dataloader argument is presumably a
    # shuffle/train flag (True for training, False for evaluation) - confirm.
    train_dataloader = my_datasets.get_dataloader(train_data_names, TRAIN_BATCH_SIZE, True)

    # Build the evaluation loaders once; they do not depend on the epoch.
    eval_names = list(my_datasets.DATA_DIRS.keys())
    eval_dataloaders = [
        my_datasets.get_dataloader((name, ), TEST_BATCH_SIZE, False)
        for name in eval_names
    ]

    save_dir = const.RECOGNIZER_DIR / '+'.join(train_data_names)
    # parents=True: do not fail when RECOGNIZER_DIR itself does not exist yet.
    save_dir.mkdir(parents=True, exist_ok=True)
    log_path = save_dir / 'log.json'

    model.train()
    log = list()

    for epoch in range(num_epochs):
        loss = train(model, train_dataloader, loss_func, optimizer)
        print('eopch:', epoch, 'loss:', loss, end=' ')

        acc = [evaluate(model, loader) for loader in eval_dataloaders]
        mean = sum(acc) / len(acc) if acc else 0.0

        torch.save(model.state_dict(), str(save_dir / '{}.pt'.format(epoch)))
        log.append(dict(epoch=epoch, loss=loss, **{'acc{}'.format(i): val for i, val in enumerate(acc)}))
        # The whole log is rewritten each epoch so an interrupted run keeps its history.
        with log_path.open('w') as f:
            json.dump(log, f)
        print('test:', acc, 'mean:', mean)

        # Early stopping looks only at the datasets the model was trained on.
        # The previous test divided the accuracy sum over *all* datasets by the
        # number of training datasets, which overstated the score whenever the
        # other datasets had non-zero accuracy.
        acc_by_name = dict(zip(eval_names, acc))
        train_acc = [acc_by_name[name] for name in train_data_names if name in acc_by_name]
        # If none of the training names is evaluated, fall back to the overall mean.
        stop_score = sum(train_acc) / len(train_acc) if train_acc else mean
        if stop_score > target_acc:
            break

    return log

# Train one fresh recognizer per dataset combination; only the
# (DATA_NAMES[0], DATA_NAMES[2]) pair is enabled for this run.
train_data_list = [(DATA_NAMES[0], DATA_NAMES[2])]

print(train_data_list)

# One output unit per (character, position) combination.
out_features = len(const.ALL_CHAR_SET) * const.MAX_CAPTCHA
for names in train_data_list:
    model = MyResNet(out_features)
    model.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    process(model, loss_func, optimizer, names)
    # Drop the reference before the next combination builds its own model.
    del model

[('captcha-images', 'capitalized')]
eopch: 0 loss: 364.7583210468292 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 1 loss: 353.01735186576843 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 2 loss: 324.7808816432953 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 3 loss: 271.95271396636963 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 4 loss: 208.397127866745 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 5 loss: 148.61644446849823 test: [0.062, 0.0, 0.015, 0.0] mean: 0.01925
eopch: 6 loss: 96.19722467660904 test: [0.278, 0.0, 0.07, 0.0] mean: 0.08700000000000001
eopch: 7 loss: 58.272410809993744 test: [0.348, 0.0, 0.195, 0.0] mean: 0.13574999999999998
eopch: 8 loss: 35.7857563495636 test: [0.596, 0.0, 0.47, 0.0] mean: 0.26649999999999996
eopch: 9 loss: 21.831040129065514 test: [0.454, 0.0, 0.535, 0.0] mean: 0.24725000000000003
eopch: 10 loss: 16.150591865181923 test: [0.124, 0.0, 0.19, 0.0] mean: 0.0785
eopch: 11 loss: 13.543409921228886 test: [0.772, 0.0, 0.83, 0.0] mean: 0.40049999999999997
eopch: