In [None]:
%matplotlib inline
import os
import cv2
import time
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from weapons.Se_0a import seg_model

# Expose only the first GPU to CUDA/TensorFlow. CUDA reads this variable when it
# initialises, so it has to be set before the first tf.Session is created below.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

def parse_filename(filename):
    """Decode the annotation fields embedded in a dataset image filename.

    The name up to the first '.' is split on '-'. Field 0 holds the digits of the
    area ratio, field 3 the corner vertices ('x&y' pairs joined by '_'), the third
    field from the end the plate character indices ('_'-joined), and the last
    field the clearliness score.

    Returns:
        (vertices, lp_indices, area_ratio, clearliness) where
        vertices: list of (x, y) int tuples in the order se, sw, nw, ne;
        lp_indices: list of ints indexing provinces, alphabets, and ads;
        area_ratio: float built as "0." + field 0 (e.g. "025" -> 0.025);
        clearliness: int, the bigger, the more clear.

    Raises:
        ValueError: if the filename has no '.' or a field is not numeric.
        IndexError: if there are too few '-' separated fields.
    """
    fields = filename[:filename.index(".")].split("-")
    area_ratio = float("0." + fields[0])
    clearliness = int(fields[-1])
    lp_indices = list(map(int, fields[-3].split("_")))
    vertices = [tuple(int(coord) for coord in corner.split("&"))
                for corner in fields[3].split("_")]
    return vertices, lp_indices, area_ratio, clearliness

def lp_indices2numbers(lp_indices, provinces, alphabets, ads):
    """Convert a plate's index list into its character string.

    The first index selects from ``provinces``, the second from ``alphabets``,
    and every remaining index selects from ``ads``; the selected strings are
    concatenated in that order.
    """
    characters = [provinces[lp_indices[0]], alphabets[lp_indices[1]]]
    characters.extend(ads[idx] for idx in lp_indices[2:])
    return "".join(characters)

def build_mask(width, length, points, epsilon=1e-8):
    """Rasterise a quadrilateral into a binary mask.

    Args:
        width: Number of columns (x extent) of the mask.
        length: Number of rows (y extent) of the mask.
        points: Four (x, y) corners unpacked as (lower-right, lower-left,
            upper-left, upper-right), i.e. the (se, sw, nw, ne) order produced by
            ``parse_filename``; y grows downwards.
        epsilon: Tolerance guarding the slope/inverse-slope divisions; an edge
            whose run is within ``epsilon`` is treated as a very steep line.

    Returns:
        ``np.int32`` array of shape (length, width). A cell (row y, column x) is 1
        when it lies strictly below the top edge line, strictly above the bottom
        edge line, strictly right of the left edge line and strictly left of the
        right edge line; otherwise 0 (cells exactly on an edge line are 0).
        The mask is an intersection of four half-planes, so it presumably assumes
        a convex, roughly axis-aligned quadrilateral -- TODO confirm for heavily
        rotated plates.
    """
    def edge_line(p, q):
        # (slope, intercept) of y = slope * x + intercept through p and q,
        # anchored at p. A run within epsilon is replaced by epsilon itself so a
        # vertical edge becomes a very steep line instead of dividing by zero.
        run = p[0] - q[0]
        slope = (p[1] - q[1]) / (run if np.abs(run) > epsilon else epsilon)
        return slope, p[1] - slope * p[0]

    lr, ll, ul, ur = points
    rows, cols = np.indices((length, width))

    bottom_slope, bottom_icpt = edge_line(lr, ll)
    top_slope, top_icpt = edge_line(ur, ul)
    inside_rows = ((rows > top_slope * cols + top_icpt)
                   & (rows < bottom_slope * cols + bottom_icpt))

    left_slope, left_icpt = edge_line(ul, ll)
    right_slope, right_icpt = edge_line(ur, lr)
    # The side edges are solved for x = (y - b) / a, so a (near-)zero slope is
    # clamped to epsilon. The intercepts above were taken with the unclamped slope.
    if np.abs(left_slope) < epsilon:
        left_slope = epsilon
    if np.abs(right_slope) < epsilon:
        right_slope = epsilon
    inside_cols = ((cols > (rows - left_icpt) / left_slope)
                   & (cols < (rows - right_icpt) / right_slope))

    return (inside_rows & inside_cols).astype(np.int32)

def plot_mask(img, mask):
    """Show ``img`` in a 20x10-inch figure with ``mask`` drawn over it.

    The mask is rendered with the gray colormap at alpha 0.7 on the same axes.
    """
    _, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(img)
    ax.imshow(mask, cmap='gray', alpha=0.7)
    plt.show()

def _upper_lower_bound(img, upper, lower):
    result = (img>upper)*255+(img<upper)*img
    result = (result<lower)*lower+(result>lower)*result
    return result

def random_aug(img, verbose = False):
    """Apply one random photometric and one random spatial degradation.

    img as int from 0 to 255.

    Photometric step (``indicator1`` uniform in [-1, 2)):
      * > 1 (prob. 1/3): multiply by a contrast factor in [0.5, 1.5);
      * < 0 (prob. 1/3): add an integer brightness shift in [-100, 100];
      * otherwise: unchanged.
    A changed image is clamped to 0..255 via ``_upper_lower_bound`` and cast to
    int; the result is then divided by 255.0.

    Spatial step (``indicator2`` uniform in [-1, 2)):
      * > 1: Gaussian blur with an odd kernel size in 3..25;
      * < 0: resize to an l x l square, l in 64..256 (output shape changes);
      * otherwise: unchanged.

    Args:
        img: image array with values in 0..255.
        verbose: print which augmentation parameters were drawn.

    Returns:
        float array divided by max(its maximum, 1.0). Uses the global
        ``np.random`` state.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.integer):
        # Widen integer input (e.g. uint8 straight from cv2.imread) before the
        # brightness shift: uint8 + shift wraps around modulo 256, or raises for
        # negative shifts under NumPy 2 promotion rules. int64 input is unchanged.
        img = img.astype(np.int64)

    indicator1 = np.random.random()*3-1
    result = img
    if indicator1>1:
        # Change contrast.
        contrast = np.random.random()+0.5
        if verbose: print("contrast:", contrast)
        result = _upper_lower_bound(img*contrast, 254, 0).astype(int)
    elif indicator1<0:
        # Change brightness.
        brightness = np.random.randint(201)-100
        if verbose: print("brightness:", brightness)
        result = _upper_lower_bound(img+brightness, 254, 0).astype(int)
    result = result/255.0

    indicator2 = np.random.random()*3-1
    if indicator2>1:
        # Blur.
        blurriness = int(np.random.randint(12))*2+3
        if verbose: print("blurriness:", blurriness)
        result = cv2.GaussianBlur(result,(blurriness,blurriness),0)
    elif indicator2<0:
        # Reduce resolution.
        l = np.random.randint(193)+64
        if verbose: print("image resized to ("+str(l)+","+str(l)+")")
        result = cv2.resize(result, (l, l))
    return result/max(np.max(result),1.0)

def path_to_xy_segmentation(path, filename, x_shape = (512,512), y_shape = (512,512),
                            verbose = False, augmentation = True):
    """Load one image and build its (input, target-mask) pair.

    Args:
        path: directory holding the image; joined to ``filename`` with
            ``os.path.join`` so a trailing separator is optional.
        filename: image file name whose fields ``parse_filename`` decodes to get
            the plate corner vertices.
        x_shape: ``dsize`` for ``cv2.resize`` of the input, i.e. (width, height).
        y_shape: ``dsize`` for ``cv2.resize`` of the mask, i.e. (width, height).
        verbose: plot the input with the mask overlaid.
        augmentation: pass the image through ``random_aug`` before resizing.

    Returns:
        (x, y): the resized RGB input scaled to roughly [0, 1], and the resized
        float mask (1.0 inside the plate quadrilateral; resizing may produce
        fractional values along its border).

    Raises:
        IOError: if OpenCV cannot read the image.
    """
    full_path = os.path.join(path, filename)
    bgr = cv2.imread(full_path)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError("could not read image: " + full_path)
    # OpenCV loads BGR; reverse the channel axis to RGB. The cast to the default
    # int dtype matches what the original permutation-matrix np.dot produced, and
    # ascontiguousarray gives cv2 a positive-stride buffer.
    img = np.ascontiguousarray(bgr[..., ::-1], dtype=int)
    if augmentation:
        x = random_aug(img)
        x = cv2.resize(x, x_shape)
    else:
        # NOTE(review): 255.1 (not 255.0 as in random_aug) keeps values just
        # below 1; kept as-is, confirm whether the mismatch is intended.
        x = cv2.resize(img/255.1, x_shape)
    vertices = parse_filename(filename)[0]
    mask = build_mask(img.shape[1], img.shape[0], vertices)/1.0
    y = cv2.resize(mask, y_shape)
    if verbose: plot_mask(x, y)
    return x,y

def get_batch(file_dict, batch_size,
              x_shape = (512,512), y_shape = (512,512), augmentation = True):
    """Sample a random batch of (image, mask) pairs, with replacement.

    For each sample a folder is drawn uniformly from ``file_dict``'s keys and then
    a filename uniformly from that folder's list, so each folder is equally likely
    regardless of how many files it holds.

    Args:
        file_dict: mapping folder path -> list of filenames in it.
        batch_size: number of samples to draw.
        x_shape, y_shape, augmentation: forwarded to ``path_to_xy_segmentation``.

    Returns:
        (xs, ys) numpy arrays stacking the inputs and masks along a new axis 0.
    """
    folder_names = list(file_dict.keys())

    def draw_one():
        # Folder first, then file, then loading/augmentation: this is the order
        # in which the global np.random stream is consumed.
        folder = np.random.choice(folder_names)
        filename = np.random.choice(file_dict[folder])
        return path_to_xy_segmentation(folder, filename, x_shape = x_shape, y_shape = y_shape,
                                       augmentation = augmentation)

    samples = [draw_one() for _ in range(batch_size)]
    images = np.array([x for x, _ in samples])
    masks = np.array([y for _, y in samples])
    return images, masks

def seg_to_vertices(img, use_dilated = False, verbose = False):
    """Fit a 4-point polygon to the largest blob of a segmentation mask.

    Args:
        img: 2-D mask. It is converted with ``astype(np.uint8)`` before the
            contour search, so fractional values are truncated.
            NOTE(review): a float probability map in [0, 1) would become all
            zeros -- confirm callers threshold/scale it first.
        use_dilated: dilate the mask (3x3 kernel, 3 iterations) first to merge
            nearby fragments.
        verbose: plot intermediate images and print progress.

    Returns:
        The ``cv2.approxPolyDP`` output with exactly 4 points, or ``[]`` when no
        contour with at least 4 points exists or no 4-point approximation is
        found within 10 distinct tolerances.
    """
    if verbose:
        plt.imshow(img)
        plt.show()

    if use_dilated:
        # dilate thresholded image - merges top/bottom
        kernel = np.ones((3,3))
        dilated = cv2.dilate(img, kernel, iterations=3)
        if verbose:
            plt.imshow(dilated)
            plt.show()
        current_img = dilated
    else:
        current_img = img

    # find contours. ``[-2]`` selects the contour list under both the OpenCV 3
    # (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy) signatures.
    contours = cv2.findContours(current_img.astype(np.uint8),
                                cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        if verbose: print("No result")
        return []
    # findContours does not order its output by size; take the largest blob so
    # small noise regions cannot be picked instead of the plate.
    contour = max(contours, key=cv2.contourArea)
    if verbose: print("largest contour has ",len(contour),"points")
    if len(contour)<4:
        if verbose: print("No result")
        return []

    # simplify contours: search for an approxPolyDP tolerance (a fraction of the
    # perimeter) yielding exactly 4 points. Grow it when there are too many
    # points, shrink it when too few; if the next tolerance was already tried,
    # halve the step's excess over 1 first. Stop after 10 distinct tolerances.
    perimeter = cv2.arcLength(contour,True)
    used_index = set()
    index = 0.01
    step = 2
    while len(used_index)<10:
        approx = cv2.approxPolyDP(contour,index*perimeter,True)
        if verbose: print("index =", index, ", # points =", len(approx))
        if len(approx)==4:
            break
        if len(approx)>4:
            if (index*step) in used_index:
                step = 1+(step-1)/2
            index*=step
        else:
            if (index/step) in used_index:
                step = 1+(step-1)/2
            index/=step
        used_index.add(index)
    if len(approx)!=4:
        if verbose: print("No result")
        return []
    if verbose:
        # Draw on a copy so the caller's array is not modified.
        canvas = img.copy()
        cv2.drawContours(canvas, [approx], 0, (255,255,255), 3)
        plt.imshow(canvas)
        plt.show()
    return approx

def _bound(value, upper, lower):
    return min(max(value, lower), upper)

def rearange_vertices(vertices, img_shape):
    """Label four polygon corners as se, sw, nw, ne and push them outward.

    Args:
        vertices: four corners indexed like ``cv2.approxPolyDP`` output, i.e.
            ``vertices[i][0]`` is an (x, y) pair (e.g. an array of shape (4, 1, 2)).
        img_shape: bounds for clamping; index 0 clamps x and index 1 clamps y.
            NOTE(review): if callers pass a numpy ``.shape`` (rows, cols), x is
            clamped by the height -- confirm for non-square images.

    The corners are sorted by y and by x. For each ordering, the gap between the
    2nd and 3rd corner is divided by the full span; the axis with the larger ratio
    splits the corners into two pairs, and each pair is ordered along the other
    axis (on ties, the later corner in the sort comes first).

    Returns:
        [se, sw, nw, ne] as (x, y) tuples, each moved outward by 5% of the mean
        diagonal length and clamped to [0, img_shape[k]].
    """
    corners = [v[0] for v in vertices]
    by_y = sorted(corners, key=lambda p: p[1])
    by_x = sorted(corners, key=lambda p: p[0])

    def ordered(a, b, axis):
        # (smaller, larger) along ``axis``; on a tie ``b`` comes first.
        return (a, b) if a[axis] < b[axis] else (b, a)

    y_gap_ratio = (by_y[2][1] - by_y[1][1]) / (by_y[3][1] - by_y[0][1])
    x_gap_ratio = (by_x[2][0] - by_x[1][0]) / (by_x[3][0] - by_x[0][0])
    if y_gap_ratio > x_gap_ratio:
        # Split into top/bottom pairs by y, then left/right within each by x.
        nw, ne = ordered(by_y[0], by_y[1], 0)
        sw, se = ordered(by_y[2], by_y[3], 0)
    else:
        # Split into left/right pairs by x, then top/bottom within each by y.
        nw, sw = ordered(by_x[0], by_x[1], 1)
        ne, se = ordered(by_x[2], by_x[3], 1)

    def dist(p, q):
        return ((p[0] - q[0])**2 + (p[1] - q[1])**2)**0.5

    extension = (dist(se, nw) + dist(ne, sw)) / 2 * 0.05

    # NOTE: the final output should not include this extension; the
    # recognition stage is supposed to add that margin itself.
    limit_x, limit_y = img_shape[0], img_shape[1]
    return [(_bound(se[0] + extension, limit_x, 0), _bound(se[1] + extension, limit_y, 0)),
            (_bound(sw[0] - extension, limit_x, 0), _bound(sw[1] + extension, limit_y, 0)),
            (_bound(nw[0] - extension, limit_x, 0), _bound(nw[1] - extension, limit_y, 0)),
            (_bound(ne[0] + extension, limit_x, 0), _bound(ne[1] - extension, limit_y, 0)),]

def crop_out_plate(img, vertices, out_width=300, out_height=150):
    """Perspective-warp the quadrilateral ``vertices`` into an upright rectangle.

    Args:
        img: source image; converted with ``np.uint8`` before warping, so values
            are truncated to 8-bit. NOTE(review): a float image scaled to [0, 1]
            would come out almost black -- confirm callers pass 0..255 data.
        vertices: four (x, y) corners in the order se, sw, nw, ne; they map to
            (out_width, out_height), (0, out_height), (0, 0), (out_width, 0).
        out_width: output width in pixels (default 300, the previous fixed size).
        out_height: output height in pixels (default 150, the previous fixed size).

    Returns:
        The warped uint8 image of size out_height x out_width.
    """
    src = np.float32(vertices)
    dst = np.float32([[out_width, out_height], [0, out_height], [0, 0], [out_width, 0]])
    transform = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(np.uint8(img), transform, (out_width, out_height))

# Map each dataset folder to its image filenames and hold out 1/20 of each
# folder (rounded down) as a test split.
file_dict = {}
training_dict = {}
test_dict = {}
# Other CCPD2019 subsets that can be added to ``folders``:
# folders = ["CCPD2019/ccpd_base/", "CCPD2019/ccpd_blur/", "CCPD2019/ccpd_challenge/",
#            "CCPD2019/ccpd_db/", "CCPD2019/ccpd_fn/",
#            "CCPD2019/ccpd_rotate/", "CCPD2019/ccpd_tilt/", "CCPD2019/ccpd_weather/"]
folders = ["CCPD2019/ccpd_base/"]
for folder in folders:
    # Skip hidden files and sort, because os.listdir returns entries in
    # arbitrary order and the split must be reproducible across machines.
    files = sorted(x for x in os.listdir(folder) if x[0]!='.')
    file_dict[folder] = files
    # Split at a positive index: slicing with ``[:-n_test]`` breaks when
    # n_test == 0 (fewer than 20 files), since ``[:-0]`` is empty and would put
    # every file in the test set.
    n_test = len(files) // 20
    n_train = len(files) - n_test
    training_dict[folder] = files[:n_train]
    test_dict[folder] = files[n_train:]

# Smoke test: draw a small batch and check the array shapes.
xs, ys = get_batch(file_dict, 10, x_shape = (512,512), y_shape = (64,64))
print(xs.shape, ys.shape)

In [None]:
BATCH_SIZE = 16
n_batch = 2000
learning_rate = 2e-4
saving_period = 100
model_name = "Se_0a_1"
model_dir = os.path.join('models', model_name)
# makedirs also creates 'models/' on a fresh checkout, where listing it failed.
os.makedirs(model_dir, exist_ok=True)
x_shape = (512,512)
y_shape = (64,64)

tf.reset_default_graph()
model = seg_model()

gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                      allow_soft_placement=True,
                                      log_device_placement=False)) as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # Resume from the checkpoint written by earlier runs if it exists; on a
    # first run there is none, so train from the freshly initialised weights
    # instead of crashing in saver.restore.
    restore_path = os.path.join(model_dir, model_name + "_0.ckpt")
    # V2 checkpoints are stored as <prefix>.index / <prefix>.data-*; V1 as <prefix>.
    if os.path.exists(restore_path + ".index") or os.path.exists(restore_path):
        saver.restore(sess, restore_path)
        print("Model restored from path: " + restore_path)
    # tensorboard --logdir logs/
    # summary_writer = tf.summary.FileWriter(logdir = "logs", graph = tf.get_default_graph())

    # NOTE(review): show_result is not defined anywhere in this notebook, so
    # "Restart & Run All" raises NameError here -- define or import it above.
    xs, ys = get_batch(file_dict, 1, x_shape = x_shape, y_shape = y_shape)
    prediction = model.predict(sess, xs)
    show_result(xs, ys, prediction)

    for i in range(1, 1+n_batch):
        # Batch-norm momentum ramps up from ~0.09 and is capped at 0.7.
        bn_momentum = min(0.7, (1-10/(i+10)))
        xs, ys = get_batch(training_dict, BATCH_SIZE, x_shape = x_shape, y_shape = y_shape)
        loss, pred, summary = model.train(sess, learning_rate, bn_momentum, xs, ys)
        # summary_writer.add_summary(summary, i)
        if i%10 == 0:
            print(i, loss)

        if i%saving_period == 0:
            # The suffix only changes every 10000 steps, so saves within that
            # window overwrite the same checkpoint.
            ckpt_file = os.path.join(model_dir, model_name + "_" + str(i // 10000) + ".ckpt")
            save_path = saver.save(sess, ckpt_file)
            print("Model saved in path: "+save_path)
            # NOTE(review): test batches use get_batch's default augmentation=True.
            xs, ys = get_batch(test_dict, BATCH_SIZE, x_shape = x_shape, y_shape = y_shape)
            prediction = model.predict(sess, xs)
            show_result(xs, ys, prediction)