In [1]:
import scan_csv
from FastSlidingWindow import *
from Util import *
from bbd100k_loader import *
from scan_csv import progress

loader = BBD100K_Loader(True)
color_map = generate_color_from_categories(loader.category_dict)

In [2]:
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
tf.Session(config=config).close()

In [3]:
def calc_iou(Y_pred, Y):
    print(Y_pred.shape)
    zeros = np.zeros([Y_pred.shape[0], Y_pred.shape[1], Y_pred.shape[2], 1], dtype=np.float32)
    y2 = tf.reshape(Y_pred[:, :, :, 2], [Y_pred.shape[0], Y_pred.shape[1], Y_pred.shape[2], 1])
    y4 = tf.reshape(Y_pred[:, :, :, 4], [Y_pred.shape[0], Y_pred.shape[1], Y_pred.shape[2], 1])
    pred_x_w = tf.where(y2 > 0.0, y2, zeros) 
    pred_x_h = tf.where(y4 > 0.0, y4, zeros) 
    
    x1_t = Y_pred[:, :, :, 1:2] - Y[:, :, :, 2:3] / 2.0
    x2_t = Y_pred[:, :, :, 1:2] + Y[:, :, :, 2:3] / 2.0
    
    y1_t = Y_pred[:, :, :, 3:4] - Y[:, :, :, 4:5] / 2.0
    y2_t = Y_pred[:, :, :, 3:4] + Y[:, :, :, 4:5] / 2.0
    
    x1_p = Y_pred[:, :, :, 1:2] - pred_x_w / 2.0
    x2_p = Y_pred[:, :, :, 1:2] + pred_x_w / 2.0
    
    y1_p = Y_pred[:, :, :, 3:4] - pred_x_h / 2.0
    y2_p = Y_pred[:, :, :, 3:4] + pred_x_h / 2.0
    
    cond1 = x2_t < x1_p
    cond2 = x2_p < x1_t
    cond3 = y2_t < y1_p
    cond4 = y2_p < y1_t
    cond_all = tf.logical_or(tf.logical_or(tf.logical_or(cond1, cond2), cond3), cond4)
    
    ious_np = np.zeros([Y_pred.shape[0], Y_pred.shape[1], Y_pred.shape[2], 1], dtype=np.float32)
    
    far_x = tf.where(x2_t < x2_p, x2_t, x2_p)
    near_x = tf.where(x1_t > x1_p, x1_t, x1_p)
    far_y = tf.where(y2_t < y2_p, y2_t, y2_p)
    near_y = tf.where(y1_t > y1_p, y1_t, y1_p)
    
    inter_area = (far_x - near_x + 1.0) * (far_y - near_y + 1.0)
    true_box_area = (x2_t - x1_t + 1.0) * (y2_t - y1_t + 1.0)
    pred_box_area = (x2_p - x1_p + 1.0) * (y2_p - y1_p + 1.0)
    iou = inter_area / (true_box_area + pred_box_area - inter_area)
    iou = tf.where(cond_all, ious_np, iou)
    return iou
    


In [4]:
w_val = 0.01
W_M_1 = tf.Variable(np.random.uniform(-w_val, w_val, [3, 3, 3, 32]), dtype=tf.float32, name='WM1')
W_M_1_1 = tf.Variable(np.random.uniform(-w_val, w_val, [3, 3, 32, 64]), dtype=tf.float32, name='WM11')
W_M_2 = tf.Variable(np.random.uniform(-w_val, w_val, [3, 3, 64, 128]), dtype=tf.float32, name='WM2')
W_M_2_1 = tf.Variable(np.random.uniform(-w_val, w_val, [3, 3, 128, 256]), dtype=tf.float32, name='WM21')
W_M_2_2 = tf.Variable(np.random.uniform(-w_val, w_val, [3, 3, 256, 512]), dtype=tf.float32, name='WM22')
W_M_3 = tf.Variable(np.random.uniform(-w_val, w_val, [5*5*512, 512]), dtype=tf.float32, name='WM3')
W_M_4_F = tf.Variable(np.random.uniform(-w_val, w_val, [512, len(loader.category_dict)*5]), dtype=tf.float32, name='WM4F')

In [9]:
loss_value = tf.Variable(0.0, dtype=tf.float32)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

def mean_squared(inputs, targets):
  error = inputs - targets
  return tf.reduce_mean(tf.square(error))

patch_size = 150
stride = 150

index = 0
loss = 0.0
while True:
    X, Y, image, cmap = loader.gather(30, stride, patch_size, 0.9)
    print(X.shape)
    
    X = np.array(X, dtype=np.float32)
    if image is not None:
        if np.max([image.shape[0], image.shape[1]]) >= 2000:
            image = None

    if image is not None:
        iou_sum = 0.0
        iou_mean = 0.0
        
        with tf.GradientTape() as tape:
            # ENCODE
            X_flat = tf.reshape(X, [X.shape[0]*X.shape[1], X.shape[2], X.shape[3], X.shape[4]])
            encode_conv = tf.nn.leaky_relu(tf.nn.conv2d(X_flat, W_M_1, [1, 2, 2, 1], 'SAME'))
            encode_conv = tf.nn.leaky_relu(tf.nn.conv2d(encode_conv, W_M_1_1, [1, 2, 2, 1], 'SAME'))
            share_conv = encode_conv
        
            share_conv = tf.nn.leaky_relu(tf.nn.conv2d(share_conv, W_M_2, [1, 2, 2, 1], 'SAME'))
            share_conv = tf.nn.leaky_relu(tf.nn.conv2d(share_conv, W_M_2_1, [1, 2, 2, 1], 'SAME'))
            share_conv = tf.nn.leaky_relu(tf.nn.conv2d(share_conv, W_M_2_2, [1, 2, 2, 1], 'SAME'))
            print(share_conv.shape)
            share_flat = tf.reshape(share_conv, [share_conv.shape[0], W_M_3.shape[0]])
            
            flat_1 = tf.nn.leaky_relu(tf.matmul(share_flat, W_M_3))
            Y_pred = tf.matmul(flat_1, W_M_4_F)
            Y_pred = tf.reshape(Y_pred, [Y_pred.shape[0], len(loader.category_dict), 5])
            class_pred = tf.nn.sigmoid(Y_pred[:, :, 0])
            Y_pred = tf.stack([class_pred, Y_pred[:, :, 1], Y_pred[:, :, 2], Y_pred[:, :, 3], Y_pred[:, :, 4]], axis=1)
            Y_pred = tf.reshape(Y_pred, Y.shape)

            iou = calc_iou(Y_pred, Y)
            iou_1 = 1.0 - tf.reduce_mean(iou)

            ms = mean_squared(Y_pred[:, :, :, :], Y[:, :, :, :])
            loss_value = ms*1000.0 

        loss += loss_value.numpy()
        params = [W_M_1, W_M_1_1, W_M_2, W_M_2_2, W_M_3, W_M_4_F]
        optimizer.apply_gradients(zip(tape.gradient(loss_value, params), params),
                            global_step=tf.train.get_or_create_global_step())
        
        img_orig = draw_from_label(image, Y, cmap, patch_size, color_map, 0.1, draw_patches=False, max_count=1000)
        img_pred = draw_from_label(image, Y_pred.numpy(), cmap, patch_size, color_map, 0.9, draw_patches=True)

        cv2.imshow('orig', img_orig)
        cv2.imshow('pred', img_pred)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            break
    progress(index, loader.image_count, index)
    index += 1
    if index >= 1:
        index = 0
        print(' ' + str(loss))
        loss = 0.0

(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 14.454277992248535----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 26.78461265563965-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 16.413951873779297----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 14.840457916259766----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 14.912602424621582----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 13.09705924987793-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 12.02727222442627-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 11.66405963897705-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 11.575096130371094----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 11.534822463989258----

 5.424131870269775-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.421267032623291-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.402795314788818-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.377954483032227-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.372196197509766-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.362659454345703-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.34906530380249------------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.335048198699951-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.325736045837402-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.316788673400879-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)


(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.162167072296143-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.159977436065674-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.158453941345215-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.1591291427612305----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.16070032119751------------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.161190986633301-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.1605939865112305----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.160220146179199-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.161630153656006-----------------------] 0.0% -> 0
(5, 9, 150, 150, 3)
(45, 5, 5, 512)
(5, 9, 10, 5)
 5.1657915115356445----