In [None]:
# TODO: 
# - sample_select for rpn
# - 

In [1]:
import tensorflow as tf

In [3]:
# IPython help lookup (`obj?` displays the docstring of tf.stop_gradient).
# Development aid only: the trailing `?` is IPython syntax, not plain Python.
tf.stop_gradient?

First, we set out all the steps involved in building a Faster R-CNN model, assuming that the necessary helper functions already exist. Then we write each of those functions.

In [None]:
##################
#### BACKBONE ####
##################
# Run the backbone on the input images. The result `x` is the feature tensor that
# the RPN consumes (rpn(x, config)) and that the detector later crops ROI features
# from (tf.image.crop_and_resize(image=x, ...)).
# NOTE(review): `backbone` and `images` are not defined in this part of the notebook.
x = backbone(images)


#################################
#### REGION PROPOSAL NETWORK ####
#################################
# Generate regression targets and labels for the anchor boxes; the normalized anchors
# are returned as well. generate_anchor_targets takes (bboxes, config).
bboxes, labels, anchors = generate_anchor_targets(bboxes_gt, config)

# Run the RPN. rpn() returns (objectness logits, box regressions) in that order, so
# the unpacking has to follow it: `logits` ends in 1 channel, `proposals` in 4.
logits, proposals = rpn(x, config)

# Make an array of image indices as subsequent select steps will cause batch dimension to be lost
image_inds = tf.tile(tf.range(tf.shape(logits)[0])[:,None], [1, tf.shape(logits)[1]]) # batch_size x n_rois

# The ground-truth boxes are padded so that every image in the batch has the same number.
# An anchor that overlaps no real box (IoU 0 with every non-padded box) can therefore be
# matched to a padded box, because all of its IoUs tie at zero. get_transformed_bboxes
# zeroes the targets of such anchors, so get_unpadded identifies them and they are dropped.
# `anchors` is filtered together with everything else so that it stays aligned with
# `labels` for the later tf.not_equal(labels, 0) selection.
bboxes, proposals, logits, labels, anchors, image_inds = select_items(
    [bboxes, proposals, logits, labels, anchors, image_inds],
    get_unpadded(bboxes))

# Sample the positive (label 1) and negative (label -1) anchors used for the RPN losses.
# Parentheses are used for the line continuation: a backslash followed by whitespace is
# a SyntaxError.
(bboxes_spl, proposals_spl, logits_spl, labels_spl, image_inds_spl) = select_pos_neg_samples(
    [bboxes, proposals, logits, labels, image_inds],
    tf.equal(labels, 1), tf.equal(labels, -1),
    config.rpn_n_pos, config.rpn_n_neg)

# Clf losses using pos/neg items
# tf.losses has no `sigmoid_cross_entropy_with_logits`; tf.losses.sigmoid_cross_entropy is
# the member of that namespace and reduces to a scalar like tf.losses.huber_loss below.
is_pos = tf.equal(labels_spl, 1)
clf_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.to_float(is_pos),
                                           logits=logits_spl[...,0])

# Reg losses using pos items (the function will only find the loss for the positive proposals)
reg_loss = tf.losses.huber_loss(*select_items([proposals_spl, bboxes_spl], is_pos))

# Combine losses
# NOTE(review): both tf.losses calls already average over their elements by default, so
# dividing by n_clf / n_reg rescales a mean, not a sum - confirm this is intended.
loss_rpn = clf_loss/config.n_clf + config.rpn_lmd*reg_loss/config.n_reg


##################################
#### DETECTOR - PREPROCESSING ####
##################################

# - Select pos/neg of: proposals, logits, anchors, image_inds (anchors labelled 0 are ignored)
proposals, logits, anchors, image_inds = select_items([proposals, logits, anchors, image_inds], 
                                                      tf.not_equal(labels, 0))

# - Recover bounding boxes - this will be normalized as anchors are normalized
proposals = recover_boxes(proposals, anchors)

# Trim cross-boundary proposals to the (normalized) image boundaries, then discard any
# which are not valid bounding boxes, i.e. whose second corner is not strictly beyond
# the first corner on both axes (x2 <= x1, etc.).
proposals = tf.clip_by_value(proposals, 0, 1)
# A proposal is valid only if BOTH coordinates of the first corner are smaller than the
# matching coordinates of the second corner (tf.any does not exist; the original
# condition also selected the invalid boxes instead of the valid ones).
proposals_valid = tf.reduce_all(tf.less(proposals[...,:2], proposals[...,2:]), axis=-1)
proposals, logits, image_inds = select_items([proposals, logits, image_inds], proposals_valid)

# - Normalize bboxes (norm_bboxes takes (bboxes, config))
bboxes_normed = norm_bboxes(bboxes_gt, config)

# Process items for each image
# tf.map_fn requires `dtype` whenever fn returns a structure different from `elems`.
# process_items returns (proposals, matched boxes, class labels, image indices), whose
# dtypes follow the tensors it reads: proposals, bboxes_normed, classes and image_inds.
# NOTE(review): map_fn stacks the per-image results, so process_items must return the
# same number of rows for every image - confirm the sampling guarantees that.
proposals, bboxes, labels, image_inds = tf.map_fn(
    elems = tf.unique(image_inds)[0], fn = process_items,
    dtype = (proposals.dtype, bboxes_normed.dtype, classes.dtype, image_inds.dtype))

# TODO: flatten the per-image axis of proposals, bboxes, labels and image_inds before
# they are fed to the detector, e.g.
# items = [tf.reshape(item, [-1, tf.shape(item)[-1]]) for item in items]


############################
#### DETECTOR - NETWORK ####
############################
# Stop gradients for the proposals so that they are regarded as constant for the purposes
# of training the detector (the op is tf.stop_gradient, singular).
proposals = tf.stop_gradient(proposals)

# Regression targets for the detector: the ground-truth box matched to each proposal
# (`bboxes`, returned per proposal by process_items) expressed relative to that proposal.
# The padded per-image `bboxes_normed` is not aligned with the proposals.
bboxes_det = transform_bboxes(bboxes, proposals)

# The per-proposal class ids come back from process_items under the name `labels`.
class_labels = labels

# ROI pooling: crop every proposal out of the backbone features and resize the crop.
# NOTE(review): tf.image.crop_and_resize documents `boxes` as normalized
# [y1, x1, y2, x2] with shape [num_boxes, 4] - confirm the coordinate order used here and
# that the per-image axis produced by tf.map_fn has been flattened.
roi_pool = tf.image.crop_and_resize(image=x, boxes=proposals, box_ind=image_inds, 
                                    crop_size=config.roi_pool_size)

# Get logits and pred_bboxes
rois, class_logits = detector(roi_pool)

# Clf losses using pos/neg items 
clf_loss_det = tf.losses.sparse_softmax_cross_entropy(logits=class_logits, labels=class_labels)

# Reg losses using pos items (the function will only find the loss for the positive boxes)
# NOTE(review): tf.gather along the last axis uses the same indices for every row; if the
# detector predicts one box per class, a per-row gather is needed - confirm the layout of `rois`.
class_rois = tf.gather(rois, class_labels, axis=-1)
is_pos = tf.greater(class_labels, 0)
reg_loss_det = tf.losses.huber_loss(*select_items([class_rois, bboxes_det], is_pos))
loss_det = clf_loss_det + config.det_lmd*reg_loss_det

#################
#### METRICS ####
#################
# Total loss: sum of the detector loss and the RPN loss
total_loss = loss_det + loss_rpn

# Recover predicted bounding boxes from the detector's regression output, using the
# proposals as reference boxes - these will also be normalized 
rois = recover_boxes(rois, proposals)

# Optionally apply nms per class
# NOTE(review): `n_classes` is not assigned anywhere in this cell - presumably it should
# come from `config`; confirm before running.
rois, bboxes, logits, labels, image_inds \
    = process_preds(rois, bboxes, logits, labels, image_inds, n_classes, config)

In [None]:
def norm_bboxes(bboxes, config):
    """Scale box coordinates by the image size stored in ``config``.

    The last axis of ``bboxes`` is divided element-wise by
    ``[config.height, config.width, config.height, config.width]``.

    Args:
        bboxes: boxes whose last axis is expected to hold 4 coordinates.
        config: object exposing ``height`` and ``width``.

    Returns:
        ``bboxes`` divided by the stacked image dimensions (broadcast over
        the leading axes).
    """
    # NOTE(review): no dtype cast happens here, so `bboxes` and the config
    # values must already have compatible dtypes - confirm with callers.
    image_dims = [config.height, config.width]
    scale = tf.stack(image_dims * 2)
    return bboxes / scale

def get_centres_dims(bboxes):
    """Convert corner-format boxes into centres and sizes.

    Args:
        bboxes: array/tensor whose last axis is ``[c1_a, c1_b, c2_a, c2_b]``,
            i.e. the first corner followed by the second corner (the original
            annotation calls this ``[x1, y1, x2, y2]``).

    Returns:
        Tuple ``(centres, dims)``, each with a last axis of size 2: the
        midpoint of the two corners, and the corner difference plus one.
    """
    first_corner = bboxes[..., :2]
    second_corner = bboxes[..., 2:]
    centres = (second_corner + first_corner) / 2  # (..., 2)
    # The "+ 1" treats coordinates as inclusive; get_corners undoes it.
    dims = second_corner - first_corner + 1  # (..., 2)
    return centres, dims

def transform_bboxes(bbox1, bbox2):
    """Encode ``bbox1`` relative to the reference boxes ``bbox2``.

    Both inputs are corner-format boxes with a last axis of size 4; their
    leading axes must be equal or broadcastable.

    Returns:
        Tensor with a last axis of size 4: the centre offsets of ``bbox1``
        from ``bbox2`` divided by the size of ``bbox2``, followed by the log
        of the size ratio ``bbox1 / bbox2``.
    """
    centres_a, dims_a = get_centres_dims(bbox1)
    centres_b, dims_b = get_centres_dims(bbox2)

    # Offsets are measured in units of the reference box size
    centre_offsets = (centres_a - centres_b) / dims_b  # (..., 2)
    log_scale = tf.log(dims_a / dims_b)  # (..., 2)

    return tf.concat([centre_offsets, log_scale], axis=-1)  # (..., 4)
    
def get_transformed_bboxes(bboxes, config):
    """Turn per-anchor ground-truth boxes into regression targets.

    Args:
        bboxes: ground-truth box assigned to every anchor,
            batch_size x n_anchors x 4 (per the original annotation).
        config: object exposing ``anchors`` plus the ``height``/``width``
            used by ``norm_bboxes``.

    Returns:
        Tuple ``(targets, anchors)``: the boxes encoded relative to the
        normalized anchors, and the normalized anchors themselves. Rows that
        belong to padded ground-truth boxes are zeroed in both outputs.
    """
    anchors = tf.constant(config.anchors, tf.float32)[None] # 1 x n_anchors x 4 (xa_1, ya_1, xa_2, ya_2)
    bboxes = tf.to_float(bboxes)

    # A row counts as padding when none of its coordinates is positive (the same rule
    # as get_unpadded). It has to be evaluated on the raw boxes: after the transform
    # a real box can legitimately have only non-positive entries.
    # NOTE(review): assumes padded ground-truth rows are all zeros - confirm with the data pipeline.
    not_padding = tf.to_float(tf.reduce_any(tf.greater(bboxes, 0), axis=-1, keepdims=True)) # batch_size x n_anchors x 1

    # norm_bboxes needs the config to know the image size
    anchors = norm_bboxes(anchors, config)
    bboxes = norm_bboxes(bboxes, config)

    targets = transform_bboxes(bboxes, anchors) # batch_size x n_anchors x 4

    # Multiplying by the mask also broadcasts the anchors over the batch axis
    return not_padding*targets, not_padding*anchors # batch_size x n_anchors x 4

def generate_anchor_targets(bboxes, config):
    """Label every anchor and build its box-regression target.

    Args:
        bboxes: padded ground-truth boxes, batch_size x max_n_bboxes x 4
            (per the original annotation).
        config: object exposing ``anchors``, ``rpn_pos_th`` and ``rpn_neg_th``
            (plus whatever ``get_transformed_bboxes`` reads).

    Returns:
        Tuple ``(anchor_bboxes, anchor_labels, anchors)``:
        regression targets (batch_size x n_anchors x 4), labels that are 1 for
        positive, -1 for negative and 0 for ignored anchors
        (batch_size x n_anchors), and the normalized anchors
        (batch_size x n_anchors x 4).
    """
    # dtype given by name so that this function does not depend on a numpy import
    ious = find_iou(config.anchors.astype("int32"), bboxes) #batch_size x n_anchors x max_n_bboxes
    pos_mask, target_inds = identify_pos_boxes(ious, config.rpn_pos_th) #(batch_size x n_anchors, batch_size x n_anchors)
    neg_mask = identify_neg_rois(ious, pos_mask, config.rpn_neg_th) #batch_size x n_anchors

    anchor_labels = tf.zeros(tf.shape(pos_mask)) #batch_size x n_anchors
    anchor_labels = tf.where(pos_mask, tf.ones_like(anchor_labels), anchor_labels) #batch_size x n_anchors
    # Negatives are written last, so they win if an anchor is flagged in both masks
    anchor_labels = tf.where(neg_mask, -tf.ones_like(anchor_labels), anchor_labels) #batch_size x n_anchors

    # Ground-truth box matched to every anchor, then encoded relative to that anchor.
    # The encoding must run on the gathered boxes (not on the raw padded list) and
    # returns the normalized anchors as its second output.
    anchor_bboxes = tf.batch_gather(bboxes, tf.to_int32(target_inds)) #batch_size x n_anchors x 4
    anchor_bboxes, anchors = get_transformed_bboxes(anchor_bboxes, config)
    return anchor_bboxes, anchor_labels, anchors


In [None]:
def rpn(x, config):
    """Region proposal head applied to the backbone features.

    Args:
        x: backbone feature map passed to ``tf.layers.conv2d``.
        config: object exposing ``n_anchors`` and ``tiny_conv_kwargs``
            (keyword arguments of the shared convolution).

    Returns:
        Tuple ``(clf, reg)`` - note the order: objectness logits reshaped to
        batch_size x -1 x 1, then box regressions reshaped to
        batch_size x -1 x 4.
    """
    # This is a plain function, so the settings come from the `config` argument
    # (there is no `self` here).
    n_anchors = config.n_anchors
    conv = tf.layers.conv2d(inputs=x, activation=tf.nn.relu, **config.tiny_conv_kwargs)
    #TODO: add bn, relu as needed
    clf = tf.layers.conv2d(inputs=conv, kernel_size=1, filters=1*n_anchors, padding='same')
    reg = tf.layers.conv2d(inputs=conv, kernel_size=1, filters=4*n_anchors, padding='same')
    # Keep the batch axis and flatten everything else into one row per anchor
    clf = tf.reshape(clf, tf.concat([tf.shape(clf)[:1], [-1, 1]], axis=0))
    reg = tf.reshape(reg, tf.concat([tf.shape(reg)[:1], [-1, 4]], axis=0))
    return clf, reg

In [None]:
def select_items(arrs, mask):
    """Apply the same boolean mask to every tensor in ``arrs``.

    Returns:
        A list with ``tf.boolean_mask(arr, mask)`` for each input, in order.
    """
    selected = []
    for arr in arrs:
        selected.append(tf.boolean_mask(arr, mask))
    return selected

def gather_items(arrs, inds):
    """Gather the same indices from every tensor in ``arrs``.

    Returns:
        A list with ``tf.gather(arr, inds)`` for each input, in order.
    """
    gathered = []
    for arr in arrs:
        gathered.append(tf.gather(arr, inds))
    return gathered

In [None]:
def clf_loss(logits, anchor_labels):
    """Sigmoid cross-entropy between objectness logits and binary labels.

    Args:
        logits: tensor whose last axis has the single objectness logit;
            ``logits[..., 0]`` is what gets scored.
        anchor_labels: binary targets with the shape of ``logits[..., 0]``.
            NOTE(review): presumably 1 for positive and 0 for negative
            anchors - confirm with the caller.

    Returns:
        The loss produced by ``tf.losses.sigmoid_cross_entropy``.
    """
    # `tf.losses` has no `sigmoid_cross_entropy_with_logits`; its sigmoid loss is
    # `sigmoid_cross_entropy`, which names the targets `multi_class_labels`.
    # The logits come from the `logits` argument (the original read an undefined name).
    losses = tf.losses.sigmoid_cross_entropy(multi_class_labels=anchor_labels,
                                             logits=logits[...,0])
    return losses

def reg_loss(proposals, targets):
    """Huber loss between ``proposals`` and ``targets``.

    ``proposals`` is passed as the ``labels`` argument of
    ``tf.losses.huber_loss`` and ``targets`` as its ``predictions`` argument.
    """
    return tf.losses.huber_loss(labels=proposals, predictions=targets)


In [None]:
def get_corners(centres, dims):
    """Convert centres and sizes back into corner-format boxes.

    Args:
        centres: box centres, last axis of size 2.
        dims: box sizes, last axis of size 2.

    Returns:
        Tensor whose last axis is the first corner followed by the second
        corner, each ``(dims - 1) / 2`` away from the centre.
    """
    half_extent = (dims - 1) / 2
    return tf.concat([centres - half_extent, centres + half_extent], axis=-1)

In [None]:
def recover_boxes(boxes1, boxes2):
    """Decode regression outputs back into corner-format boxes.

    This is the inverse of ``transform_bboxes``:
    ``recover_boxes(transform_bboxes(b, ref), ref)`` gives back ``b``.

    Args:
        boxes1: encoded boxes, last axis ``[centre offsets (2), log size
            ratios (2)]``.
        boxes2: reference boxes in corner format (e.g. the anchors or the
            proposals), last axis of size 4.

    Returns:
        Decoded boxes in corner format, last axis of size 4.
    """
    # The reference boxes arrive as corners, so convert them to centre/size first
    centres2, dims2 = get_centres_dims(boxes2)

    centres1 = centres2 + boxes1[...,:2]*dims2
    # The log size ratios are the LAST two entries of the encoding
    dims1 = tf.exp(boxes1[...,2:])*dims2

    rois = get_corners(centres1, dims1)
    return rois
    

In [None]:
def get_unpadded(bboxes):
    """Boolean mask of the boxes that have at least one positive entry.

    A box whose last-axis values are all <= 0 is reported as ``False``; the
    notebook uses this to recognise padding rows.
    """
    has_positive_entry = tf.greater(bboxes, 0)
    return tf.reduce_any(has_positive_entry, axis=-1)

In [None]:
def process_items(image_ind):
    """Build the detector inputs for one image (used as the ``fn`` of ``tf.map_fn``).

    Reads the notebook-level tensors ``proposals``, ``logits``, ``image_inds``,
    ``classes``, ``bboxes_normed`` and the global ``config``.

    Args:
        image_ind: scalar index of the image to process.

    Returns:
        Tuple ``(proposals_img, bboxes_img, classes_img, img_inds)``: the
        sampled proposals, their matched ground-truth boxes, their class
        labels, and ``image_ind`` repeated once per sampled proposal.
    """
    # - Select items corresponding to image
    proposals_img, logits_img = select_items([proposals, logits],
                                             tf.equal(image_inds, image_ind))

    # Select classes and bboxes
    classes_img = classes[image_ind]
    bboxes_img = bboxes_normed[image_ind]

    # Select unpadded only
    classes_img, bboxes_img = select_items([classes_img, bboxes_img],
                                           get_unpadded(bboxes_img))

    # - Select nms of: proposals (logits used for scoring)
    proposals_img = get_nms(proposals_img, logits_img)

    # - Generate labels for the proposals, sampling positives and negatives
    #   (generate_det_inputs requires the config as its last argument)
    proposals_img, bboxes_img, classes_img = generate_det_inputs(proposals_img, bboxes_img,
                                                                 classes_img, config)

    # One image index per sampled proposal; tf.tile needs 1-D `multiples`, hence the
    # slice (not a list wrapped around the shape tensor).
    img_inds = tf.tile([image_ind], tf.shape(classes_img)[:1])

    return proposals_img, bboxes_img, classes_img, img_inds

def process_preds(rois, bboxes, logits, labels, image_inds, n_classes, config):
    """Apply non-maximum suppression separately for every (image, class) pair.

    Args:
        rois: predicted boxes, one row per detection.
        bboxes: ground-truth boxes aligned with ``rois``.
        logits: scores aligned with ``rois``.
        labels: class label of every detection.
        image_inds: image index of every detection.
        n_classes: forwarded to ``tf.meshgrid`` as the second coordinate
            vector. NOTE(review): that means it must be the 1-D collection of
            class ids to evaluate, not a count - confirm with the caller.
        config: object exposing ``max_rois`` and ``roi_nms_th``.

    Returns:
        Tuple ``(rois, bboxes, logits, class_labels, image_inds)`` with the
        detections that survive NMS, as stacked by ``tf.map_fn``.
    """

    def _nms_fn(n):
        image_ind = n[0]
        # tf.equal needs matching dtypes
        class_label = tf.cast(n[1], labels.dtype)

        # Detections of this image AND this class. The results get their own names:
        # assigning to `rois`/`bboxes`/`logits` here would make them local to the
        # closure and raise UnboundLocalError on first use.
        in_group = tf.logical_and(tf.equal(image_inds, image_ind),
                                  tf.equal(labels, class_label))
        rois_grp, bboxes_grp, logits_grp = select_items([rois, bboxes, logits], in_group)

        # NOTE(review): non_max_suppression needs one score per box (1-D) - confirm
        # that `logits` holds a single score per detection.
        inds = tf.image.non_max_suppression(rois_grp, tf.nn.sigmoid(logits_grp),
                                            max_output_size=config.max_rois,
                                            iou_threshold=config.roi_nms_th)
        # `inds` index into the filtered group, so every tensor is gathered from it
        rois_grp, bboxes_grp, logits_grp = gather_items([rois_grp, bboxes_grp, logits_grp], inds)

        # tf.tile requires 1-D multiples
        n_kept = tf.shape(rois_grp)[:1]
        return (rois_grp, bboxes_grp, logits_grp,
                tf.tile([class_label], n_kept), tf.tile([image_ind], n_kept))

    # One (image, class) scalar pair per map_fn step: every distinct image crossed
    # with every class, flattened so that fn receives scalars.
    img_grid, cls_grid = tf.meshgrid(tf.unique(image_inds)[0], n_classes)
    pairs = (tf.reshape(img_grid, [-1]), tf.reshape(cls_grid, [-1]))

    # dtype is mandatory because fn returns a structure different from `elems`.
    # NOTE(review): map_fn stacks the per-pair outputs, which only works if every
    # pair keeps the same number of detections - confirm or pad the outputs.
    out_dtypes = (rois.dtype, bboxes.dtype, logits.dtype, labels.dtype, pairs[0].dtype)
    rois_out, bboxes_out, logits_out, class_labels, image_inds_out \
                = tf.map_fn(fn=_nms_fn, elems=pairs, dtype=out_dtypes)
    return rois_out, bboxes_out, logits_out, class_labels, image_inds_out
        

def get_nms(proposals, logits):
    """Keep the proposals that survive non-maximum suppression.

    Reads ``max_output_size`` and ``proposal_nms_th`` from the notebook-level
    ``config``.

    Args:
        proposals: candidate boxes, n x 4.
        logits: objectness logits of the proposals, shaped ``(n,)`` or
            ``(n, 1)``.

    Returns:
        The selected rows of ``proposals``.
    """
    # non_max_suppression expects one score per box (a 1-D tensor), while the RPN
    # logits carry a trailing axis of size 1 - flatten so both layouts are accepted.
    scores = tf.reshape(tf.nn.sigmoid(logits), [-1])
    nms_inds = tf.image.non_max_suppression(proposals, scores,
                                            max_output_size=config.max_output_size,
                                            iou_threshold=config.proposal_nms_th)
    return tf.gather(proposals, nms_inds)
    

In [None]:
def select_pos_neg_samples(items, is_pos, is_neg, n_pos, n_neg):
    """Randomly sample positive and negative rows from every tensor in ``items``.

    At most ``n_pos`` positives are kept; the negatives fill the remainder of
    the ``n_pos + n_neg`` budget, so a shortage of positives is made up with
    extra negatives.

    Args:
        items: list of tensors that share their first axis (any length).
        is_pos: 1-D boolean mask of the positive rows.
        is_neg: 1-D boolean mask of the negative rows.
        n_pos: maximum number of positives to keep.
        n_neg: number of negatives to keep when ``n_pos`` positives exist.

    Returns:
        Tuple with one gathered tensor per entry of ``items`` (positives
        first, then negatives), in the same order as ``items``.
    """
    # tf.where returns indices shaped (n, 1); take the column so that the gathered
    # tensors do not pick up an extra axis.
    pos_keep = tf.random.shuffle(tf.where(is_pos)[:, 0])[:n_pos]
    n_neg_keep = n_neg + n_pos - tf.shape(pos_keep)[0]
    neg_keep = tf.random.shuffle(tf.where(is_neg)[:, 0])[:n_neg_keep]

    keep = tf.concat([pos_keep, neg_keep], axis=0)

    # Return every gathered item: callers pass lists of different lengths
    return tuple(gather_items(items, keep))

# def select_rpn_samples(bboxes, proposals, logits, labels, image_inds)
#     # - Select pos of: logits, labels, sample from these
#     pos_keep = tf.random.shuffle(tf.where(tf.equal(labels, 1)))[:config.rpn_n_pos]
#     neg_keep = tf.random.shuffle(tf.where(tf.equal(labels, -1)))[:2*config.rpn_n_pos - tf.shape(pos_keep)[0]]
    
    
#     bboxes, proposals, logits, labels = gather_items([bboxes, proposals, logits, labels],
#                                                      tf.concat([pos_keep, neg_keep], axis=0))
    
#     return bboxes, proposals, logits, labels


In [None]:
def generate_det_inputs(proposals, bboxes, class_labels, config):
    """Match proposals to ground-truth boxes and sample the detector's training set.

    Args:
        proposals: candidate boxes of one image, n_rois x 4.
        bboxes: ground-truth boxes of the same image, n_bboxes x 4.
        class_labels: class id of every ground-truth box, n_bboxes.
        config: object exposing ``det_pos_th``, ``det_neg_th_ge``,
            ``det_neg_th_lt``, ``det_n_pos`` and ``det_n_neg``.

    Returns:
        Tuple ``(proposals, bboxes, class_labels)`` for the sampled proposals:
        the proposal, the ground-truth box it overlaps most, and its class
        label (0 for every non-positive proposal).
    """
    # NOTE(review): generate_anchor_targets hands un-expanded boxes to find_iou, whereas
    # extra axes are inserted here - confirm which convention find_iou expects.
    ious = find_iou(proposals[..., None], bboxes[...,None,:]) # n_rois x n_bboxes

    # Get labels from the best overlap of every proposal. Using the maximum IoU (rather
    # than "any box falls in the range") keeps is_pos and is_neg disjoint: a proposal
    # that overlaps one box strongly and another weakly is positive only.
    max_iou = tf.reduce_max(ious, axis=-1) # n_rois
    is_pos = tf.greater_equal(max_iou, config.det_pos_th) # n_rois
    in_neg_range = tf.logical_and(tf.greater_equal(max_iou, config.det_neg_th_ge),
                                  tf.less(max_iou, config.det_neg_th_lt)) # n_rois
    is_neg = tf.logical_and(in_neg_range, tf.logical_not(is_pos)) # n_rois

    # Find bboxes with top IoU to match targets
    inds_match = tf.argmax(ious, axis=-1) # n_rois

    # Can be lazy with gathering bboxes for all because they won't be used for non-positives
    # But need to be careful that non-positive class_labels are zero - still a bit lazy as
    # neutral will also be zero but these get filtered below
    bboxes = tf.gather(bboxes, inds_match) # n_rois x 4
    matched_labels = tf.gather(class_labels, inds_match) # n_rois
    # The zeros must have one entry per proposal, not one per ground-truth box
    class_labels = tf.where(is_pos, matched_labels, tf.zeros_like(matched_labels)) # n_rois

    # Get samples
    proposals, bboxes, class_labels = select_pos_neg_samples([proposals, bboxes, class_labels],
                                                             is_pos, is_neg, config.det_n_pos, config.det_n_neg)

    return proposals, bboxes, class_labels
    

In [None]:
# Scratch list; nothing else in the visible notebook reads `a`.
a = []