In [2]:
import numpy as np
import json
import torch

import const
import my_datasets
from my_models import MyResNet

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Batch size passed to my_datasets.get_dataloader when building the training loader.
TRAIN_BATCH_SIZE = 64
# Batch size passed to my_datasets.get_dataloader when building the evaluation
# loaders (evaluation runs under torch.no_grad()).
TEST_BATCH_SIZE = 256

In [4]:
def evaluate(model, dataloader):
    """Return the exact-match accuracy of ``model`` over ``dataloader``.

    Every batch is unpacked as ``(img, _, label)``: the middle element is
    ignored and ``label`` is iterated as the ground-truth strings. The model
    output is reshaped to ``(batch, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA)``
    and the argmax over axis 1 selects one character index per position. A
    sample is counted as correct only when the decoded string equals its label.

    The model is put into eval mode and is left in eval mode on return.

    Args:
        model: a ``torch.nn.Module``; inputs are moved to the device of its
            first parameter (CPU if it has no parameters).
        dataloader: iterable of ``(img, _, label)`` batches.

    Returns:
        Fraction of correctly decoded samples, or ``0.0`` if the dataloader
        yields no samples.
    """
    model.eval()
    charset = np.array(const.ALL_CHAR_SET)

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    count_correct = 0
    count_all = 0
    with torch.no_grad():
        for img, _, label in dataloader:
            pred = model(img.to(device)).reshape(
                (-1, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA)).cpu()

            # Explicit NumPy conversion: index the charset with an ndarray
            # rather than relying on implicit tensor-to-index coercion.
            char_idx = pred.argmax(dim=1).numpy()
            decoded = [''.join(row) for row in charset[char_idx]]
            count_correct += sum(guess == truth for guess, truth in zip(decoded, label))
            count_all += len(decoded)

    # An empty dataloader has no accuracy to report; avoid dividing by zero.
    if count_all == 0:
        return 0.0
    return count_correct / count_all

def train(model, dataloader, loss_func, optimizer):
    """Run one optimisation pass over ``dataloader`` and return the summed loss.

    Every batch is unpacked as ``(img, label_idx, _)``; the third element is
    not used here. The model output is reshaped to
    ``(batch, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA)`` before being passed
    to ``loss_func`` together with ``label_idx``.

    The model is put into train mode and is left in train mode on return.

    Args:
        model: a ``torch.nn.Module``; inputs and targets are moved to the
            device of its first parameter (CPU if it has no parameters).
        dataloader: iterable of ``(img, label_idx, _)`` batches.
        loss_func: callable ``loss_func(pred, label_idx)`` returning a scalar
            tensor that supports ``backward()``.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        Sum (not mean) of ``loss.item()`` over all batches, so the value
        scales with the number of batches; ``0`` for an empty dataloader.
    """
    model.train()

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    loss_epoch = 0
    for img, label_idx, _ in dataloader:
        img = img.to(device)
        label_idx = label_idx.to(device)
        pred = model(img).reshape((-1, len(const.ALL_CHAR_SET), const.MAX_CAPTCHA))

        loss = loss_func(pred, label_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_epoch += loss.item()
    return loss_epoch


In [5]:
# Names of all available datasets, taken from the keys of my_datasets.DATA_DIRS.
# Later cells select training combinations by indexing into this list, so the
# key order of DATA_DIRS determines which datasets each index refers to.
DATA_NAMES = list(my_datasets.DATA_DIRS.keys())

In [7]:
# Build every two-dataset training combination from index pairs into
# DATA_NAMES; each pair is sorted so the resulting tuple follows list order.
train_data_list = [
    tuple(DATA_NAMES[position] for position in sorted(pair))
    for pair in ([0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3])
]
print(train_data_list)

[('captcha-images', 'captcha-09az'), ('captcha-images', 'capitalized'), ('captcha-images', 'capital-color'), ('captcha-09az', 'capitalized'), ('captcha-09az', 'capital-color'), ('capitalized', 'capital-color')]


In [9]:
def process(model, loss_func, optimizer, train_data_names, max_epochs=40, target_acc=0.75):
    """Train ``model`` on the named datasets, checkpointing and logging each epoch.

    For every epoch this runs ``train`` on a loader built from
    ``train_data_names``, evaluates the model separately on one loader per key
    of ``my_datasets.DATA_DIRS``, saves the model's ``state_dict`` to
    ``<const.RECOGNIZER_DIR>/<names joined by '+'>/<epoch>.pt`` and rewrites
    ``log.json`` in the same directory with the full history so far.

    Args:
        model: network to train; passed to ``train`` and ``evaluate``.
        loss_func: loss callable forwarded to ``train``.
        optimizer: optimizer forwarded to ``train``.
        train_data_names: sequence of dataset names to train on; also used
            (joined with ``'+'``) as the output directory name.
        max_epochs: upper bound on the number of epochs.
        target_acc: early-stop threshold, see the note at the stop condition.

    Returns:
        The list of per-epoch log records (``epoch``, ``loss``, ``acc0``,
        ``acc1``, ... in ``DATA_DIRS`` key order) that was written to
        ``log.json``.
    """
    # The third positional argument is True for training and False for
    # evaluation -- presumably a shuffle/train flag; confirm in my_datasets.
    train_dataloader = my_datasets.get_dataloader(train_data_names, TRAIN_BATCH_SIZE, True)
    # The evaluation loaders do not depend on the epoch, so build them once
    # instead of re-creating all of them on every epoch.
    test_dataloaders = [
        my_datasets.get_dataloader((name, ), TEST_BATCH_SIZE, False)
        for name in my_datasets.DATA_DIRS.keys()
    ]

    save_dir = const.RECOGNIZER_DIR / '+'.join(train_data_names)
    # parents=True so a fresh checkout without RECOGNIZER_DIR does not fail.
    save_dir.mkdir(parents=True, exist_ok=True)

    log = list()
    for epoch in range(max_epochs):
        loss = train(model, train_dataloader, loss_func, optimizer)
        print('eopch:', epoch, 'loss:', loss, end=' ')

        acc = [evaluate(model, loader) for loader in test_dataloaders]
        mean = sum(acc) / len(acc)

        # A checkpoint is kept for every epoch, not only the best one.
        torch.save(model.state_dict(), str(save_dir / '{}.pt'.format(epoch)))
        log.append(dict(epoch=epoch, loss=loss, **{'acc{}'.format(i): val for i, val in enumerate(acc)}))
        # Rewrite the whole log each epoch so an interrupted run keeps its history.
        with (save_dir / 'log.json').open('w') as f:
            json.dump(log, f)
        print('test:', acc, 'mean:', mean)

        # NOTE(review): this divides the accuracy summed over ALL datasets by
        # the number of TRAINING datasets, so accuracy on datasets that were
        # not trained on also counts toward the threshold. Looks like it is
        # meant to approximate the mean accuracy on the training datasets --
        # confirm before changing; the formula is kept as it was.
        if sum(acc) / len(train_data_names) > target_acc:
            break
    return log

# Restrict this run to a single dataset pair (indices 0 and 3 of DATA_NAMES),
# replacing the full list of pairs built in an earlier cell.
train_data_list = [
    tuple(DATA_NAMES[position] for position in sorted(pair))
    for pair in ([0, 3],)
]

print(train_data_list)

for names in train_data_list:
    # One fresh network, loss and optimizer per dataset combination; the
    # output layer has one unit per (character, position) combination.
    out_features = const.MAX_CAPTCHA * len(const.ALL_CHAR_SET)
    model = MyResNet(out_features)
    model.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    process(model, loss_func, optimizer, train_data_names=names)
    # Drop the reference before the next combination builds a new model.
    del model

[('captcha-images', 'capital-color')]
eopch: 0 loss: 412.602507352829 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 1 loss: 381.9094545841217 test: [0.0, 0.0, 0.0, 0.0] mean: 0.0
eopch: 2 loss: 327.60688877105713 test: [0.0, 0.0, 0.0, 0.005] mean: 0.00125
eopch: 3 loss: 247.60328996181488 test: [0.0, 0.0, 0.0, 0.03] mean: 0.0075
eopch: 4 loss: 164.34826338291168 test: [0.0, 0.0, 0.0, 0.05] mean: 0.0125
eopch: 5 loss: 100.85405504703522 test: [0.13, 0.0, 0.0, 0.32] mean: 0.1125
eopch: 6 loss: 57.24165964126587 test: [0.324, 0.0, 0.0, 0.52] mean: 0.21100000000000002
eopch: 7 loss: 33.789056062698364 test: [0.616, 0.0, 0.0, 0.855] mean: 0.36775
eopch: 8 loss: 20.527918063104153 test: [0.176, 0.0, 0.0, 0.92] mean: 0.274
eopch: 9 loss: 13.807498313486576 test: [0.006, 0.0, 0.0, 0.885] mean: 0.22275
eopch: 10 loss: 10.519157011061907 test: [0.74, 0.0, 0.0, 0.965] mean: 0.42625
