In [97]:
import os

import numpy as np
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import config
from dataset import OCRDataset
from model import OCRResNet50
import utils

In [98]:
# Collect annotations for dataset indexes 6 and 12 (meaning defined in utils) together
# with the class-index -> UTF-16 code mapping used later for output labels.
selected_annotation_list, index_to_utf16 = utils.prepare_selected_annotation_from_dataset_indexes([6, 12])
# 80/20 train/validation split; the fixed random_state keeps the split reproducible
# across kernel restarts.
# TODO: revisit these split settings (original author note: "henkousurukoto" = "to be changed").
train_annotation_list, validation_annotation_list = train_test_split(selected_annotation_list,
                                                                     test_size=0.2,
                                                                     random_state=config.RANDOM_SEED)

In [99]:
# Inverse lookup of index_to_utf16: code string -> class index.
utf16_to_index = {utf16: index for index, utf16 in index_to_utf16.items()}

In [100]:
# The test split comes from a separate document (200014740). Its annotations are mapped
# to class indexes with the training-side utf16_to_index table (presumably dropping
# characters not present in that table -- TODO confirm in utils).
preprocessed_annotation = utils.preprocess_annotation(
    path_to_annotation_csv='../../data/komonjo/200014740/200014740_coordinate.csv',
    original_image_dir='../../data/komonjo/200014740/images/',
)
test_annotation_list = utils.select_annotation_and_convert_ut16_to_index(
    preprocessed_annotation,
    utf16_to_index,
)

In [101]:
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range (for uint8 input) to [-1, 1].
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [102]:
# One OCRDataset per split, all sharing the same preprocessing pipeline `tf`.
train_dataset, validation_dataset, test_dataset = (
    OCRDataset(annotations, transform=tf)
    for annotations in (train_annotation_list, validation_annotation_list, test_annotation_list)
)

In [103]:
batchsize = 16  # mini-batch size shared by every split
# Only the training loader shuffles; validation/test iterate in dataset order.
train_loader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batchsize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batchsize, shuffle=False)

In [104]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 5 outputs per character class -- presumably the (confidence, center_x, center_y,
# width, height) layout that write_bbox unpacks; TODO confirm in model.py.
# pretrained_choice=3: meaning defined in model.py.
net = OCRResNet50(5*config.N_KINDS_OF_CHARACTERS, pretrained_choice=3)
net = net.to(device)
# map_location remaps tensors saved on a GPU onto `device`; without it, loading a
# CUDA-saved checkpoint fails on the CPU fallback selected above.
params = torch.load('../../data/komonjo/logs/experiment_SIN/choice1_90/weight_090.pth',
                    map_location=device)
net.load_state_dict(params)

In [144]:
# Per-image text files for the evaluation tool in ../mAP (presumably its
# ground-truth / detection-results input folders): one file per image in each.
gt_dir = '../mAP/input/ground-truth/'
pred_dir = '../mAP/input/detection-results/'
# test() creates files directly inside these directories, so make sure they exist;
# otherwise a fresh checkout fails with FileNotFoundError on the first call.
os.makedirs(gt_dir, exist_ok=True)
os.makedirs(pred_dir, exist_ok=True)

In [154]:
def write_bbox(char_index, bbox, counter, isGT, nms_border=0.1):
    """Append the boxes of one character class for one image to its output text file.

    Each box becomes one space-separated line in `{counter:03d}.txt`:
      * ground truth (written under `gt_dir`):  ``<char> <min_x> <min_y> <max_x> <max_y>``
      * prediction   (written under `pred_dir`): ``<char> <confidence> <min_x> <min_y> <max_x> <max_y>``

    Args:
        char_index: class index; its label is looked up in the global `index_to_utf16`.
        bbox: 5-sequence ``(confidence, center_x, center_y, width, height)``; each item is
            indexed per box, presumably a 1-D array with one entry per box -- TODO confirm
            against `dataset.label2bboxes`.
        counter: image number used for the output file name.
        isGT: True writes ground truth (no NMS, no confidence column); False runs NMS
            and writes predictions.
        nms_border: value passed as `border` to `utils.NMS` for predictions (presumably
            an overlap threshold -- TODO confirm in utils). Defaults to the previously
            hard-coded 0.1.
    """
    confidence, center_x, center_y, width, height = bbox
    # Center/size -> corner coordinates.
    min_x = center_x - 0.5*width
    min_y = center_y - 0.5*height
    max_x = center_x + 0.5*width
    max_y = center_y + 0.5*height

    if not isGT:
        # Suppress overlapping detections of this class before writing them.
        for_NMS = np.array([confidence, min_x, min_y, max_x, max_y])
        remaining_indexes = utils.NMS(for_NMS, border=nms_border)
        confidence, min_x, min_y, max_x, max_y = for_NMS[:, remaining_indexes]

    n_boxes = len(confidence)
    if n_boxes == 0:
        return  # nothing to write (test() already created the, possibly empty, file)

    # The label is identical for every box of this class, so compute it once.
    # The [2:] slice drops a 2-character prefix (presumably "U+"). chr(int(..., 16))
    # also handles code points above U+FFFF (e.g. "U+2000B"), which the previous
    # '\\u' + hex + unicode-escape trick truncated to 4 digits, emitting a wrong
    # (there: whitespace) character that broke the space-separated line format.
    utf16 = index_to_utf16[char_index]
    mess = chr(int(utf16[2:], 16))

    out_path = (gt_dir if isGT else pred_dir) + '{:03d}.txt'.format(counter)
    with open(out_path, mode='a') as f:
        for i in range(n_boxes):
            if isGT:
                f.write('{0} {1} {2} {3} {4}\n'.format(
                    mess, min_x[i], min_y[i], max_x[i], max_y[i]))
            else:
                f.write('{0} {1} {2} {3} {4} {5}\n'.format(
                    mess, confidence[i], min_x[i], min_y[i], max_x[i], max_y[i]))

In [155]:
def test(model, data_loader, dataset, border):
    """Run `model` over `data_loader` and dump per-image ground-truth and prediction files.

    Images are numbered 0, 1, 2, ... in loader order. For each image, an empty
    `{n:03d}.txt` is created in both `gt_dir` and `pred_dir`. `write_bbox` then appends
    one line per box for every character class.

    Args:
        model: network whose per-image output can be decoded by `dataset.label2bboxes`.
        data_loader: iterable of ``(inputs, labels)`` batches.
        dataset: object providing ``label2bboxes(label, confidence_border=...)``.
        border: forwarded as `confidence_border` when decoding predictions (presumably
            boxes below it are dropped -- TODO confirm in dataset.py).

    Raises:
        FileExistsError: if `{n:03d}.txt` already exists in either directory. Files are
            created with mode 'x' on purpose, because `write_bbox` appends and stale
            results would otherwise be mixed in silently. Empty both directories before
            each call (e.g. between the train / validation / test runs).
    """
    # Send inputs to wherever the model lives instead of forcing CUDA, so evaluation
    # also works on the CPU fallback.
    device = next(model.parameters()).device
    counter = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in data_loader:
            preds = model(inputs.to(device)).cpu()
            for gt_label, pred_label in zip(labels, preds):
                # Create both files up front, so every image has a file pair even when
                # it has no boxes; mode 'x' refuses to reuse leftover files.
                for out_dir in (gt_dir, pred_dir):
                    with open(out_dir + '{:03d}.txt'.format(counter), mode='x'):
                        pass
                gt_bboxes = dataset.label2bboxes(gt_label)
                pred_bboxes = dataset.label2bboxes(pred_label, confidence_border=border)
                for char_index in range(config.N_KINDS_OF_CHARACTERS):
                    write_bbox(char_index, gt_bboxes[char_index], counter, isGT=True)
                    write_bbox(char_index, pred_bboxes[char_index], counter, isGT=False)
                counter += 1

In [151]:
# Training split. test() refuses to overwrite existing files, so gt_dir/pred_dir must be empty first.
test(net, data_loader=train_loader, dataset=train_dataset, border=0.9)

In [161]:
# Validation split. Empty gt_dir/pred_dir after scoring the previous split, since numbering restarts at 000.
test(net, data_loader=validation_loader, dataset=validation_dataset, border=0.9)

In [160]:
# Held-out test document. Empty gt_dir/pred_dir after scoring the previous split, since numbering restarts at 000.
test(net, data_loader=test_loader, dataset=test_dataset, border=0.9)