In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
import os.path as osp
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from pycocotools.coco import COCO
import matplotlib.patches as patches
import tensorflow.contrib.slim as slim
import math
import gc
import sys

In [0]:
# Import params module.
# Colab-only: opens a browser upload dialog, writes the first uploaded file
# to ./params.py (regardless of its original filename), then imports it.
from google.colab import files
src = list(files.upload().values())[0]
# Use a context manager so the file is flushed and closed before the import
# below reads it back from disk.
with open('params.py','wb') as module_file:
  module_file.write(src)
import params

In [0]:
# Import data_utils module.
# Colab-only: opens a browser upload dialog, writes the first uploaded file
# to ./data_utils.py (regardless of its original filename), then imports it.
from google.colab import files
src = list(files.upload().values())[0]
# Use a context manager so the file is flushed and closed before the import
# below reads it back from disk.
with open('data_utils.py','wb') as module_file:
  module_file.write(src)
import data_utils

In [0]:
# Import network_utils module.
# Colab-only: opens a browser upload dialog, writes the first uploaded file
# to ./network_utils.py (regardless of its original filename), then imports it.
from google.colab import files
src = list(files.upload().values())[0]
# Use a context manager so the file is flushed and closed before the import
# below reads it back from disk.
with open('network_utils.py','wb') as module_file:
  module_file.write(src)
import network_utils

In [0]:
# If necessary, download the COCO val2017 image archive, extract it into
# ./image_data, and delete the zip afterwards.
#
# `mkdir -p` keeps the cell re-runnable: a plain `mkdir` errors out when the
# directory already exists. `unzip -q` suppresses the per-file listing so the
# cell output is not flooded with one line per extracted image.

! mkdir -p image_data
! wget http://images.cocodataset.org/zips/val2017.zip
! unzip -q val2017.zip -d ./image_data
! rm val2017.zip

In [0]:
# Download the COCO 2017 train/val annotation archive, extract it into
# ./image_data, and delete the zip afterwards.

! wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
! unzip annotations_trainval2017.zip -d ./image_data
! rm annotations_trainval2017.zip

In [0]:
with tf.Graph().as_default() as g:
  
  # Graph inputs. One image is processed per step: the batch dimension is
  # fixed to 1 while height/width are left dynamic.
  image = tf.placeholder(shape=[1,None,None,3],dtype = tf.float32)
  # One row of 5 values per ground-truth box.
  # NOTE(review): presumably 4 box coordinates plus a class label — confirm
  # against data_utils / network_utils.
  gt_boxes = tf.placeholder(shape=[None,5],dtype = tf.float32)
  # Fed as [height, width, id] by the training loops in this cell.
  image_info = tf.placeholder(shape=[3,],dtype = tf.float32)
  
  # add inception network to graph
  # (ImageNet weights, classification head removed, variable input size)
  with tf.variable_scope('base_model'):
    base_model = tf.keras.applications.inception_v3.InceptionV3(include_top = False,
                                                             weights = 'imagenet',
                                                             input_shape= (None,None,3))
  # make feature map
  feature_map = base_model(image)
  # make anchors
  anchor_boxes = network_utils.get_anchor_boxes(image_info,feature_map)
  num_anchors = params.num_anchors
  # Number of spatial positions in the feature map (axis 1 size * axis 2 size);
  # used further down to scale the RPN bbox regression loss.
  num_anchor_locs = tf.shape(feature_map)[1]*tf.shape(feature_map)[2]
  
 
  # Shared weight initializer: the slim.conv2d default in the arg_scope below,
  # and passed explicitly to several Fast R-CNN fully connected layers.
  init = tf.random_normal_initializer(mean=0,stddev=0.01)
  with slim.arg_scope([slim.conv2d],activation_fn = tf.nn.relu,
                       padding='SAME',weights_initializer=init):
    with tf.variable_scope('rpn_network'):
      # Region proposal network (RPN): a 3x3 conv over the backbone feature
      # map followed by two sibling 1x1 convs producing, per feature-map
      # location, 2 class scores and 4 box adjustments for each anchor.
      # Also defines the RPN losses, optimizer, initializers and summaries.
      
      # construct the rpn network
      net = slim.conv2d(feature_map,params.rpn_channels,[3,3])
      cls_scores = slim.conv2d(net,num_anchors*2,[1,1],activation_fn=None,
                              padding='VALID')
      bbox_adjs = slim.conv2d(net,num_anchors*4,[1,1],activation_fn=None,
                             padding='VALID')
     
      # get the anchor labels and sample anchors for training 
      # (both helpers run as Python code inside the graph via tf.py_func; the
      # third output of get_anchor_labels is not used here)
      anchor_labels,anchor_bbox_adjs,_ = tf.py_func(network_utils.get_anchor_labels,
                                                    [anchor_boxes,gt_boxes,params.rpn_pos_anchor_thresh,params.rpn_neg_anchor_thresh_lo,params.rpn_neg_anchor_thresh_hi],
                                                    [tf.float32,tf.float32,tf.int32])
    
      # anchor_mask: entries equal to 1 are counted as positives and used for
      # bbox regression; all non-zero entries are used for classification.
      anchor_mask = tf.py_func(network_utils.sample_anchors_for_training,[anchor_labels,params.rpn_mini_batch,params.rpn_prop_pos],tf.int32)
      num_pos = tf.reduce_sum(tf.cast(tf.equal(anchor_mask,1),tf.int32))
      
      # reformat the anchor labels for one hot encoding
      anchor_labels_ce = tf.py_func(network_utils.make_anchor_ce_labels,[anchor_labels,anchor_mask],[tf.int32])[0]
      inds = tf.where(tf.not_equal(anchor_mask,0))
      cls_scores_reshape = tf.reshape(cls_scores,(-1,2))
      cls_scores_ce = tf.gather_nd(cls_scores_reshape,inds)
      num_scores = tf.shape(cls_scores_ce)[0]
      
      # Column 0 of the scores is the "positive" column: the prediction is 1
      # when argmax picks column 0, and it is compared against column 0 of the
      # one-hot labels.
      cls_preds = tf.ones(num_scores,dtype=tf.int32)-tf.cast(tf.argmax(cls_scores_ce,axis=1),dtype=tf.int32)
      cls_accuracy = tf.reduce_sum(tf.cast(tf.equal(cls_preds,anchor_labels_ce[:,0]),dtype=tf.int32))/num_scores
      pos_preds = tf.reduce_sum(cls_preds)
      tf.summary.scalar('cls_accuracy',cls_accuracy)
      
      # compute the cross entropy loss
      rpn_cls_ce = tf.nn.softmax_cross_entropy_with_logits(labels=anchor_labels_ce,logits=cls_scores_ce)
      rpn_cls_loss = tf.reduce_mean(rpn_cls_ce)
      tf.summary.scalar('rpn_cls_loss',rpn_cls_loss)
      
      # compute the bbox regression loss (positive anchors only)
      inds = tf.where(tf.equal(anchor_mask,1))
      bbox_adjs_reshape = tf.reshape(bbox_adjs,(-1,4))
      bbox_adjs_reshape_batch = tf.gather_nd(bbox_adjs_reshape,inds)
      anchor_bbox_adjs_reshape_batch = tf.gather_nd(anchor_bbox_adjs,inds)
      rpn_bbox_loss = tf.losses.huber_loss(anchor_bbox_adjs_reshape_batch,bbox_adjs_reshape_batch)
      tf.summary.scalar('rpn_bbox_loss',rpn_bbox_loss)
      
      # compute the total rpn loss for the mini-batch:
      #   cls_loss + rpn_gamma * (1 / num_anchor_locs) * bbox_loss
      rpn_total_loss = (tf.cast(rpn_cls_loss,tf.float32)
                        + tf.cast(params.rpn_gamma,tf.float32)
                        *(1/tf.cast(num_anchor_locs,tf.float32))
                        *tf.cast(rpn_bbox_loss,tf.float32))
      tf.summary.scalar('rpn_total_loss',rpn_total_loss)
      
      # compute and apply gradients
      rpn_learning_rate = tf.placeholder(tf.float32)
      rpn_optimizer = tf.contrib.opt.MomentumWOptimizer(params.rpn_weight_decay,
                                                        rpn_learning_rate,
                                                        params.rpn_momentum)
      
      # make sure to only train variables from rpn network - not from base model
      rpn_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'rpn_network')
      rpn_grads = rpn_optimizer.compute_gradients(rpn_total_loss,var_list = rpn_training_vars)
      rpn_training_op = rpn_optimizer.apply_gradients(rpn_grads)
      
      # Separate initializers for the RPN weights and for the optimizer's own
      # variables, so the pretrained base model is left untouched.
      init_rpn_training_op = tf.variables_initializer(rpn_training_vars)
      init_rpn_momen_op = tf.variables_initializer(rpn_optimizer.variables())
      
      # add network variables to tensorboard 
      # Use the op name rather than var.name: var.name carries a ':0' output
      # suffix, and ':' is not a legal character in a summary tag.
      for var in rpn_training_vars:
        tf.summary.histogram(var.op.name,var)
  
      # Only RPN summaries exist at this point, so merge_all() collects exactly
      # the summaries defined above.
      rpn_merged_summary = tf.summary.merge_all()
    
    with tf.variable_scope('fast_rcnn_network'):
      # Fast R-CNN head: selects ROIs from the RPN outputs, pools the backbone
      # features for each ROI, and predicts per-ROI class scores and per-class
      # box adjustments. Also defines the head's losses, optimizer,
      # initializers and summaries.
      #### add network structure for fast-rcnn
      training_rcnn = tf.placeholder(tf.int32)
      # Single learning-rate placeholder for this head. (The original created
      # it twice; the first instance was dead and shadowed by the second.)
      rcnn_learning_rate = tf.placeholder(tf.float32)
      
      # roi pooling
      rois,roi_cls_labels,roi_bbox_gts = network_utils.select_rois(anchor_boxes,gt_boxes,bbox_adjs_reshape,cls_scores_reshape[:,0],cls_preds,training_rcnn)
      
      roi_pools = network_utils.roi_pooling(rois,feature_map,image_info)
     
      roi_pools = slim.max_pool2d(roi_pools,[2,2],padding='SAME')
      
      # fast rcnn network 
      roi_pools_flat = slim.flatten(roi_pools)
      fc0 = slim.fully_connected(roi_pools_flat,4096,weights_initializer=init)
      # NOTE(review): unlike its sibling layers, fc1 is not given
      # weights_initializer=init and so uses slim's default — confirm intended.
      fc1 = slim.fully_connected(fc0,4096)
      rcnn_cls_scores = slim.fully_connected(fc1,params.num_classes,activation_fn=None,weights_initializer=init)
      rcnn_bbox_adjs = slim.fully_connected(fc1,params.num_classes*4,activation_fn=None,weights_initializer=init)
      # one 4-vector of box adjustments per class for every ROI
      rcnn_bbox_adjs_reshape = tf.reshape(rcnn_bbox_adjs,(-1,params.num_classes,4))
      
      # compute fast rcnn training classification loss
      rcnn_one_hots = tf.one_hot(roi_cls_labels,params.num_classes)
      rcnn_cls_ce = tf.nn.softmax_cross_entropy_with_logits(labels=rcnn_one_hots,logits=rcnn_cls_scores)
      rcnn_cls_loss = tf.reduce_mean(rcnn_cls_ce)
      
      # select the positive rois and the correct class label bbox adjustment
      # (label 0 is treated as "not positive" here)
      pos_rois = tf.where(tf.not_equal(roi_cls_labels,0))
      roi_bbox_gts_batch = tf.gather_nd(roi_bbox_gts,pos_rois)
      roi_cls_labels_batch = tf.gather_nd(roi_cls_labels,pos_rois)
      rcnn_bbox_adjs_reshape_batch = tf.gather_nd(rcnn_bbox_adjs_reshape,pos_rois)
      # (row, class) index pairs: for each positive ROI pick the adjustment
      # predicted for its own class label
      inds = tf.stack((tf.range(tf.shape(rcnn_bbox_adjs_reshape_batch)[0]),roi_cls_labels_batch),axis=1)
      rcnn_bbox_adjs_reshape_batch = tf.gather_nd(rcnn_bbox_adjs_reshape_batch,inds)
      
      # compute the fast rcnn bbox regression loss
      rcnn_bbox_loss = tf.losses.huber_loss(roi_bbox_gts_batch,rcnn_bbox_adjs_reshape_batch)
      
      # fast rcnn total loss
      rcnn_total_loss = rcnn_cls_loss + params.rcnn_gamma*rcnn_bbox_loss
      
      # Keep handles to this head's summaries so they can be merged on their
      # own below.
      rcnn_summaries = [
          tf.summary.scalar('rcnn_cls_loss',rcnn_cls_loss),
          tf.summary.scalar('rcnn_bbox_loss',rcnn_bbox_loss),
          tf.summary.scalar('rcnn_total_loss',rcnn_total_loss),
      ]
      
      # compute and apply gradients
      rcnn_optimizer = tf.contrib.opt.MomentumWOptimizer(params.rcnn_weight_decay,
                                                         rcnn_learning_rate,
                                                         params.rcnn_momentum)
      
      # get fast rcnn training vars
      rcnn_training_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,'fast_rcnn_network')
      rcnn_grads = rcnn_optimizer.compute_gradients(rcnn_total_loss,var_list=rcnn_training_vars)
      rcnn_training_op = rcnn_optimizer.apply_gradients(rcnn_grads)
      
      # initializers for the head's weights and the optimizer's own variables
      init_rcnn_training_op = tf.variables_initializer(rcnn_training_vars)
      init_rcnn_momen_op = tf.variables_initializer(rcnn_optimizer.variables())
      
      # Merge only the RCNN summaries. merge_all() here would also pull in
      # every RPN summary defined earlier in the graph, so each RCNN summary
      # step would re-evaluate the RPN ops and write the RPN tags again at
      # step numbers that restart from 0.
      rcnn_merged_summary = tf.summary.merge(rcnn_summaries)
    
  with tf.keras.backend.get_session() as sess:
    # Runs two sequential training phases in the Keras-managed session:
    # first the RPN head, then the Fast R-CNN head. Images are fed one at a
    # time; summaries are written to ./graphs on every 10th image of a batch.
    writer = tf.summary.FileWriter('./graphs')
    # try/finally guarantees buffered events are flushed and the event file is
    # closed even if a training step raises.
    try:
      writer.add_graph(sess.graph)
      
      # train rpn network
      sess.run(init_rpn_training_op)
      sess.run(init_rpn_momen_op)
      im_permutation = data_utils.get_image_permutation(params.np_random_seed,'val2017')
      print('num images: ',len(im_permutation))
      # NOTE(review): hard-coded batch count for the RPN phase, whereas the
      # RCNN phase below uses params.num_image_batches — confirm intended.
      num_rpn_image_batches = 25
      for i in range(num_rpn_image_batches):
        training_images,gt_training_boxes,im_infos = data_utils.get_training_data(i,im_permutation,params.image_batch_size,'val2017')
        for j in range(len(training_images)):
          training_image = training_images[j]
          gt_training_box = gt_training_boxes[j]
          im_info = im_infos[j]
          feed_dict = {image:training_image,gt_boxes:gt_training_box,rpn_learning_rate:params.lr1,
                       image_info:[im_info['height'],im_info['width'],im_info['id']]}
          sess.run(rpn_training_op,feed_dict = feed_dict)
          if j % 10 == 0:
            # Diagnostics come from a second forward pass on the same feed,
            # i.e. after the weight update above.
            s,ca,pp,rcl,rbl,rtl,cls_ce,cls_p,al_ce,num_p = sess.run([rpn_merged_summary,cls_accuracy,pos_preds,rpn_cls_loss,
                          rpn_bbox_loss,rpn_total_loss,cls_scores_ce,cls_preds,anchor_labels_ce,num_pos],feed_dict = feed_dict)
            writer.add_summary(s,i*params.image_batch_size + j)
            print('--------------------')
            print('num positive samples: ',num_p)
            print('class accuracy: ',ca)
            print('positive preds: ',pp)
            print('rpn class loss: ',rcl)
            print('rpn bbox loss: ',rbl)
            print('rpn total loss: ',rtl)
            print('-------------------')
        print('done {} iterations'.format((i+1)*params.image_batch_size))
      
      # train rcnn network
      # (a different seed gives a different image order from the RPN phase)
      im_permutation = data_utils.get_image_permutation(params.np_random_seed+1,'val2017')
      sess.run(init_rcnn_training_op)
      sess.run(init_rcnn_momen_op)
      for i in range(params.num_image_batches):
        training_images,gt_training_boxes,im_infos = data_utils.get_training_data(i,im_permutation,params.image_batch_size,'val2017')
        for j in range(len(training_images)):
          training_image = training_images[j]
          gt_training_box = gt_training_boxes[j]
          im_info = im_infos[j]
          feed_dict = {image:training_image,gt_boxes:gt_training_box,
                       image_info:[im_info['height'],im_info['width'],im_info['id']],
                       rcnn_learning_rate:params.rcnn_learning_rate,training_rcnn:1}
          sess.run(rcnn_training_op,feed_dict=feed_dict)
          if j % 10 == 0:
            s = sess.run(rcnn_merged_summary,feed_dict=feed_dict)
            writer.add_summary(s,i*params.image_batch_size + j)
        print('done {} iterations of rcnn training'.format((i+1)*params.image_batch_size))
    finally:
      writer.close()
          

In [0]:
# Run Tensorboard
# Start it as a background shell process (trailing '&') listening on port
# 6006 and reading the event files written to LOG_DIR.
LOG_DIR = './graphs'
tensorboard_cmd = 'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'.format(LOG_DIR)
get_ipython().system_raw(tensorboard_cmd)

In [0]:
# Download and unzip ngrok
# The archive is extracted into the current working directory; a later cell
# runs the binary as ./ngrok.
# NOTE(review): the download is not checksum-verified — consider pinning a
# known hash before executing the binary.
! wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
! unzip ngrok-stable-linux-amd64.zip

In [0]:
# Launch the ngrok background process
# Tunnels HTTP traffic to local port 6006, the port TensorBoard was started
# on in this notebook. The trailing '&' backgrounds the process so the cell
# returns immediately.
get_ipython().system_raw('./ngrok http 6006 &')

In [0]:
# Get the public URL
# Fetches the tunnel list as JSON from http://localhost:4040/api/tunnels and
# prints the 'public_url' of the first tunnel.
# NOTE(review): presumably this endpoint is served by the ngrok process
# started above; if it is not up yet, the JSON parse or the [0] lookup will
# fail — re-run the cell after a moment.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"