In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
import random
import pprint
import sys
import time
import pickle
import matplotlib.pyplot as plt
import matplotlib

from keras import backend as K
from keras.optimizers import Adam, SGD
from keras.layers import Input
from keras.models import Model

from keras_frcnn import config, data_generators
from keras_frcnn import losses as lossesFns
from keras_frcnn import roi_helpers
from keras.utils import generic_utils
from keras_frcnn.pascal_voc_parser import get_data
from keras_frcnn import vgg as nn

In [None]:
# Directory holding the pretrained base-network weights. It can be overridden with the
# PRETRAINED_MODELS_DIR environment variable instead of editing a hard-coded absolute path.
PRETRAINED_MODELS_DIR = os.environ.get('PRETRAINED_MODELS_DIR', '/content/dohai90/pretrained_models')

# Shared Faster R-CNN configuration: read by the training loop and pickled in the next cell.
C = config.Config()
C.use_horizontal_flips = False
C.model_path = './weights/model_ocr.h5'  # best checkpoint is written here during training
C.num_rois = 32  # number of ROIs fed to the detector classifier per batch
C.network = 'vgg'
# Weights loaded (by layer name) when training starts from epoch 0.
C.base_net_weights = os.path.join(PRETRAINED_MODELS_DIR, nn.get_weight_path())

In [None]:
# Two-class detection problem: one foreground class 'obj' plus the background
# class 'bg', which receives the last index (len of the mapping before insertion).
classes_count = {'obj': 0, 'bg': 0}
class_mapping = {'obj': 0, 'bg': 1}
C.class_mapping = class_mapping

# Persist the full config (including the class mapping) so it can be reloaded
# later, e.g. by a separate inference script (not shown here).
config_output_filename = './config_ocr.pickle'
with open(config_output_filename, 'wb') as config_file:
    pickle.dump(C, config_file)

In [None]:
def _build_train_generator(start_epoch):
    """Build the synthetic OCR training-data generator for a curriculum stage.

    Stage 0 (``start_epoch == 0``) calls ``build_word_list`` with 16000,
    ``max_string_len=4`` and ``mono_fraction=1.0``. Later stages use 32000,
    ``max_string_len=16`` and ``mono_fraction=0.5``. The word-list directory can
    be overridden with the ``OCR_WORDLISTS_DIR`` environment variable.

    Returns:
        The generator from ``data_generators.get_anchor_gt_ocr``. ``train``
        consumes each ``next()`` as ``(X, Y, img_data)``.
    """
    words_dir = os.environ.get('OCR_WORDLISTS_DIR', '/content/dohai90/workspace/OCR/korean_ocr/wordlists')
    monogram_file = os.path.join(words_dir, 'wordlist_mono_clean.txt')
    bigram_file = os.path.join(words_dir, 'wordlist_bi_clean.txt')
    if start_epoch == 0:
        string_list = data_generators.build_word_list(16000, monogram_file, bigram_file, max_string_len=4, mono_fraction=1.0)
    else:
        string_list = data_generators.build_word_list(32000, monogram_file, bigram_file, max_string_len=16, mono_fraction=0.5)
    # The generator call was identical for both stages, so it is built once here.
    return data_generators.get_anchor_gt_ocr(string_list, 1000, 600, C, nn.get_img_output_length, K.image_dim_ordering(),
                                             mode='train', multi_fonts=True, multi_font_sizes=True)


def _build_models(num_anchors, nb_classes):
    """Create the (uncompiled) RPN, detector classifier and combined model.

    All three share the base network. ``model_all`` holds both heads and is
    only used to save/load the weights of both models at once.

    Returns:
        ``(model_rpn, model_classifier, model_all)``
    """
    img_input = Input(shape=(None, None, 3))
    roi_input = Input(shape=(None, 4))

    # base network shared by the RPN and the classifier
    shared_layers = nn.nn_base(img_input, trainable=True)
    rpn = nn.rpn(shared_layers, num_anchors)
    classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=nb_classes, trainable=True)

    model_rpn = Model(img_input, rpn[:2])
    model_classifier = Model([img_input, roi_input], classifier)
    model_all = Model([img_input, roi_input], rpn[:2] + classifier)
    return model_rpn, model_classifier, model_all


def _load_initial_weights(models, weights_path):
    """Best-effort load of ``weights_path`` (matched by layer name) into each model.

    A failure is reported together with its cause instead of being silently
    swallowed. Training then continues from the models' current weights.
    """
    print('loading weights from {}'.format(weights_path))
    try:
        for model in models:
            model.load_weights(weights_path, by_name=True)
    except Exception as e:
        print('Could not load pretrained weights: {}'.format(e))


def _select_roi_samples(Y1, num_rois):
    """Choose which proposed ROIs are fed to the classifier for one batch.

    Rows of ``Y1[0]`` whose last column is 1 are treated as negatives and rows
    where it is 0 as positives. The last column is presumably the one-hot 'bg'
    class, since ``class_mapping['bg']`` is the last index.

    For ``num_rois > 1``, up to ``num_rois // 2`` positives are taken without
    replacement and the rest of the batch is filled with negatives. Negatives
    are sampled with replacement only when there are too few of them. For
    ``num_rois == 1``, one random positive or negative is returned; if the
    chosen pool is empty, the other pool is used.

    Returns:
        ``(sel_samples, num_pos)``: a list of ROI indices (always a list, so
        ``X2[:, sel_samples, :]`` keeps the ROI axis) and the number of
        positive ROIs available.

    Raises:
        ValueError: if negatives are needed but none exist (from ``np.random.choice``).
    """
    neg_samples = np.where(Y1[0, :, -1] == 1)[0]
    pos_samples = np.where(Y1[0, :, -1] == 0)[0]

    if num_rois > 1:
        if len(pos_samples) < num_rois // 2:
            selected_pos_samples = pos_samples.tolist()
        else:
            selected_pos_samples = np.random.choice(pos_samples, num_rois // 2, replace=False).tolist()
        num_neg = num_rois - len(selected_pos_samples)
        # replace=False raises when asking for more items than exist, so only then sample with replacement
        selected_neg_samples = np.random.choice(neg_samples, num_neg, replace=len(neg_samples) < num_neg).tolist()
        sel_samples = selected_pos_samples + selected_neg_samples
    else:
        # in the extreme case where num_rois = 1, we pick a random pos or neg sample
        use_neg = np.random.randint(0, 2)
        pool = neg_samples if use_neg else pos_samples
        if len(pool) == 0:
            pool = pos_samples if use_neg else neg_samples
        sel_samples = [int(random.choice(pool))]

    return sel_samples, len(pos_samples)


def train(start_epoch, stop_epoch, epoch_length=1000):
    """Train the RPN and detector classifier for epochs ``[start_epoch, stop_epoch)``.

    With ``start_epoch == 0`` training starts from ``C.base_net_weights`` on the
    short-string word list. Any later start loads ``C.model_path`` and uses the
    longer word list (see ``_build_train_generator``). After every
    ``epoch_length`` successful iterations the mean losses are computed. If
    their sum is the lowest seen in this call, the combined model's weights are
    written to ``C.model_path``.

    A batch that raises is reported and skipped. If the data generator is
    exhausted, training stops instead of looping forever.

    Args:
        start_epoch: index of the first epoch; also selects the curriculum stage.
        stop_epoch: exclusive upper bound on the epoch index.
        epoch_length: number of successful training iterations per epoch.
    """
    data_gen_train_ocr = _build_train_generator(start_epoch)

    num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
    nb_classes = len(classes_count)
    model_rpn, model_classifier, model_all = _build_models(num_anchors, nb_classes)

    weights_path = C.base_net_weights if start_epoch == 0 else C.model_path
    _load_initial_weights((model_rpn, model_classifier), weights_path)

    model_rpn.compile(optimizer=Adam(lr=1e-5),
                      loss=[lossesFns.rpn_loss_cls(num_anchors), lossesFns.rpn_loss_regr(num_anchors)])
    model_classifier.compile(optimizer=Adam(lr=1e-5),
                             loss=[lossesFns.class_loss_cls, lossesFns.class_loss_regr(nb_classes - 1)],
                             metrics={'dense_class_{}'.format(nb_classes): 'accuracy'})
    # model_all is only used to save weights; this optimizer/loss is never trained with
    model_all.compile(optimizer='sgd', loss='mae')

    iter_num = 0
    # columns: rpn_cls, rpn_regr, detector_cls, detector_regr, detector accuracy
    losses = np.zeros((epoch_length, 5))
    best_loss = np.inf  # np.Inf was removed in NumPy 2.0
    rpn_accuracy_rpn_monitor = []
    rpn_accuracy_for_epoch = []
    start_time = time.time()

    for epoch_num in range(start_epoch, stop_epoch):
        progbar = generic_utils.Progbar(epoch_length)
        print('Epoch {}/{}'.format(epoch_num + 1, stop_epoch))

        while iter_num < epoch_length:
            try:
                if len(rpn_accuracy_rpn_monitor) == epoch_length:
                    mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                    # reset regardless of verbosity, otherwise the list grows without bound
                    rpn_accuracy_rpn_monitor = []
                    if C.verbose:
                        print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
                        if mean_overlapping_bboxes == 0:
                            print('RPN is not producing boxes that overlap the ground truth boxes. Check RPN settings or keep training')

                X, Y, img_data = next(data_gen_train_ocr)

                loss_rpn = model_rpn.train_on_batch(X, Y)
                P_rpn = model_rpn.predict_on_batch(X)
                R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
                # note: calc_iou converts from (x1, y1, x2, y2) to (x, y, w, h) format
                X2, Y1, Y2, IoUs = roi_helpers.calc_iou(R, img_data, C, class_mapping)

                if X2 is None:
                    rpn_accuracy_rpn_monitor.append(0)
                    rpn_accuracy_for_epoch.append(0)
                    continue

                sel_samples, num_pos = _select_roi_samples(Y1, C.num_rois)
                rpn_accuracy_rpn_monitor.append(num_pos)
                rpn_accuracy_for_epoch.append(num_pos)

                loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

                losses[iter_num] = [loss_rpn[1], loss_rpn[2], loss_class[1], loss_class[2], loss_class[3]]
                iter_num += 1
                progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                          ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])

            except StopIteration:
                # a catch-all `continue` here would spin forever on an exhausted generator
                print('Training data generator is exhausted; stopping training.')
                return
            except Exception as e:
                # best-effort: report and skip the failing batch
                print('Exception: {}'.format(e))

        # End of epoch. This runs outside the try, so a failed save can no longer
        # silently restart the epoch.
        loss_rpn_cls = np.mean(losses[:, 0])
        loss_rpn_regr = np.mean(losses[:, 1])
        loss_class_cls = np.mean(losses[:, 2])
        loss_class_regr = np.mean(losses[:, 3])
        class_acc = np.mean(losses[:, 4])

        mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
        rpn_accuracy_for_epoch = []

        if C.verbose:
            print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
            print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
            print('Loss RPN classifier: {}'.format(loss_rpn_cls))
            print('Loss RPN regression: {}'.format(loss_rpn_regr))
            print('Loss Detector classifier: {}'.format(loss_class_cls))
            print('Loss Detector regression: {}'.format(loss_class_regr))
            print('Elapsed time: {}'.format(time.time() - start_time))

        curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
        iter_num = 0
        start_time = time.time()

        if curr_loss < best_loss:
            if C.verbose:
                print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
            best_loss = curr_loss
            # save_weights cannot create missing parent directories (e.g. ./weights)
            model_dir = os.path.dirname(C.model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            model_all.save_weights(C.model_path)

In [None]:
# Two-stage curriculum: the first call starts from C.base_net_weights with
# max_string_len=4 words; the second loads C.model_path and uses longer words.
print('Start training')
for stage_start, stage_stop in ((0, 1), (1, 2)):
    train(start_epoch=stage_start, stop_epoch=stage_stop)
print('Training done!')