
Can you train for keypoint detection? #2

Open
rujiao opened this issue Nov 2, 2017 · 82 comments

@rujiao

rujiao commented Nov 2, 2017

Hi @waleedka: Thanks for the great work! Is it possible to train for keypoint detection? Sorry for the wrong title of the issue; I can't correct it.

@waleedka changed the title from "Can" to "Can you train for keypoint detection?" Nov 2, 2017
@waleedka
Collaborator

waleedka commented Nov 2, 2017

Key point detection is not implemented in this release but it should be easy to do if you want to try it. Simply change the mask head to use cross entropy loss rather than binary cross entropy and extend the Dataset class to load a dataset that includes key points.

@rujiao
Author

rujiao commented Nov 2, 2017

@waleedka: Thanks a lot for your rapid reply. I will try keypoint detection.
Todos:

  1. The number of generated masks in each ROI equals the number of keypoints, instead of the number of classes.
  2. Unlike instance segmentation, all keypoint masks contribute to the loss, not only the mask of the ground-truth class.
  3. Loss function for one mask: cross-entropy over the softmax of all pixels in one mask, instead of binary cross-entropy, as you mentioned.

Could you guide me a little bit on where (in which files) to add the code? I am reading your code... :)

@waleedka
Collaborator

waleedka commented Nov 2, 2017

I think most of your changes will be in model.py:

  • build_fpn_mask_graph() builds the mask head
  • mrcnn_mask_loss_graph() is the loss function for the mask head
  • Dataset class is the base class for loading data. There is documentation there on how to extend it. For example, for COCO it's subclassed in CocoDataset. You can subclass Dataset and override load_mask() to load different types of masks (a minimal sketch is below).
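For illustration, here is a minimal sketch of such a subclass. The annotation layout is made up for this example (one binary mask PNG per instance, with file paths and class IDs stored through add_image()); the mask_files and class_ids fields are hypothetical, not part of the repo:

```python
import numpy as np
import skimage.io

import utils  # utils.py sits at the repo root in this version of the code

class MyDataset(utils.Dataset):
    """Hypothetical Dataset subclass: one binary mask PNG per instance."""

    def load_mask(self, image_id):
        info = self.image_info[image_id]
        # Fields assumed to have been stored via add_image() at registration time
        mask_files = info["mask_files"]
        class_ids = np.array(info["class_ids"], dtype=np.int32)

        # [height, width, num_instances] boolean array, one channel per instance
        masks = np.stack([skimage.io.imread(f) > 0 for f in mask_files], axis=-1)
        return masks.astype(bool), class_ids
```

The matching add_image() call would simply pass mask_files=[...] and class_ids=[...] as extra keyword arguments, which Dataset stores in image_info.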

Which dataset are you going to be using?

@rujiao
Author

rujiao commented Nov 2, 2017

Thank you so much @waleedka. I will train on my own data, but in the same format as COCO. By the way, could you tell me why you do "...= np.where(m >= 128, 1, 0)" in both minimize_mask and expand_mask? In minimize_mask you may want to convert the data type to boolean, since the data type of mini_mask is boolean, but why 128? In expand_mask, m from mini_mask is already boolean, so after "mask[y1:y2, x1:x2, i] = np.where(m >= 128, 1, 0)" all the elements would be 0. My understanding is definitely wrong somewhere. Can you help me understand it correctly?

@waleedka
Collaborator

waleedka commented Nov 3, 2017

The function scipy.misc.imresize() always returns pixel values as integers in the range 0-255.
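In other words, the boolean 0/1 mask comes back from the resize rescaled to 0–255, and the >= 128 threshold turns it back into a clean binary mask. A tiny illustration of that round trip (this fakes the rescaling step rather than calling the now-removed scipy.misc.imresize):

```python
import numpy as np

mask = np.zeros((8, 8), dtype=bool)
mask[2:6, 2:6] = True

# Stand-in for scipy.misc.imresize(), which returned uint8 values in 0-255
resized = (mask.astype(float) * 255).astype(np.uint8)

# The repo's threshold recovers the binary mask from the 0-255 range
binary_again = np.where(resized >= 128, 1, 0)
assert np.array_equal(binary_again, mask.astype(int))
```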

@taewookim

taewookim commented Nov 7, 2017

@waleedka: Is it possible to train on custom datasets that have only bounding boxes, but no segmentation?

Mmm, let me rephrase. I've seen object detection models based on ResNet101, but for some reason this one is better. I'd like to use it for object detection on a dataset that does not have image segmentation.

  1. Do you recommend any repos that are just as accurate as this one when it comes to detection?
  2. Is it possible to train on custom datasets that have bounding boxes only?

My dataset isn't huge like COCO: 5 classes, 4 images each per class = 2k images.

@waleedka
Collaborator

waleedka commented Nov 7, 2017

@taewookim You'll need to change the same places as for keypoint detection, but rather than modifying the mask branch and loss, you'd remove them completely. See my comment above about which functions to modify.

Alternatively, for a quicker hack that would be okay if your dataset is small and the extra processing load of the mask branch is not an issue, you could simply have your Dataset.load_mask() function return rectangular masks. For example, if your bounding box is from (10, 10) to (50, 50), then let load_mask() return a mask where everywhere is zero except the area (10, 10) to (50, 50) which is ones. The network will use those masks to generate the bounding boxes, and train as usual. The mask branch will learn to return rectangular masks, and you'd simply ignore them.
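A minimal sketch of that hack, assuming a hypothetical bboxes field with (y1, x1, y2, x2) pixel coordinates was stored through add_image() (these field names are illustrative, not part of the repo):

```python
import numpy as np

import utils  # utils.py at the repo root in this version of the code

class BoxOnlyDataset(utils.Dataset):
    """Hypothetical dataset that has bounding boxes but no segmentation masks."""

    def load_mask(self, image_id):
        info = self.image_info[image_id]
        boxes = info["bboxes"]        # assumed: list of (y1, x1, y2, x2) in pixels
        class_ids = np.array(info["class_ids"], dtype=np.int32)

        masks = np.zeros((info["height"], info["width"], len(boxes)), dtype=bool)
        for i, (y1, x1, y2, x2) in enumerate(boxes):
            # Fill the whole box with ones; the mask head will simply learn to
            # predict rectangles, which you then ignore at inference time.
            masks[y1:y2, x1:x2, i] = True
        return masks, class_ids
```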

In terms of accuracy, you should expect to get similar accuracy to other object detection frameworks built on the Faster RCNN architecture because the basic building blocks are the same.

Another related point. If your dataset is small, you could use resnet50 instead of resnet101 to make training faster. A discussion about that is at issue #5
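For reference, later packaged versions of the repo expose the backbone choice as a config attribute; a minimal sketch assuming that BACKBONE option is available in your copy:

```python
from mrcnn.config import Config  # module layout of later versions of the repo

class SmallDatasetConfig(Config):
    NAME = "small_dataset"
    BACKBONE = "resnet50"   # lighter than the default "resnet101"
    NUM_CLASSES = 1 + 5     # background + 5 object classes
    IMAGES_PER_GPU = 2
```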

@liu6381810

liu6381810 commented Nov 9, 2017

@waleedka Hi, thanks for your great work!!

Now I want to predict human keypoints using your code.

But I have a few questions:

The input mask ground truth for segmentation is [batch, img_height, img_width, num_instances].
For keypoint detection, from the authors' paper:

> For each of the K keypoints of an instance, the training target is a one-hot m × m binary mask where only a single pixel is labeled as foreground

So the input should be [batch, img_height, img_width, keypoint_num * num_instances].
If we change the input to this, I think there may be some problems in DetectionTargetLayer (line 514 in model.py): https://github.com/matterport/Mask_RCNN/blob/master/model.py#L514

There are also some problems with lines 531-539: the ground-truth mask has only one pixel set to 1, but after crop, resize, and round it may have more than one pixel set to 1, so maybe we need to set only the point with the max value to 1.

You said we just need to change build_fpn_mask_graph and mrcnn_mask_loss_graph, so could you please be more precise and concrete? I would appreciate any advice. Thanks a lot!

My naive thought is: we just ignore gt_mask (i.e. don't change it), and when calculating the loss, we determine which keypoint is in the proposal and where it is, then use the softmax cross-entropy loss.
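A rough numpy illustration of that idea (not code from the repo; the m = 28 grid size and the coordinate conventions are assumptions): map a keypoint that falls inside a proposal to a cell of an m × m grid and use the flattened cell index as the target of a softmax cross-entropy.

```python
import numpy as np

def keypoint_to_target(kp_xy, roi, m=28):
    """Illustrative only: build an m*m one-hot target for one keypoint and one ROI.
    kp_xy = (x, y) in image coordinates, roi = (y1, x1, y2, x2)."""
    x, y = kp_xy
    y1, x1, y2, x2 = roi
    if not (x1 <= x < x2 and y1 <= y < y2):
        return None  # keypoint not inside this proposal -> no loss contribution
    col = int((x - x1) / (x2 - x1) * m)
    row = int((y - y1) / (y2 - y1) * m)
    target = np.zeros(m * m, dtype=np.float32)
    target[row * m + col] = 1.0   # flattened index used by the softmax cross-entropy
    return target

# Example: keypoint at (120, 80) inside a 100x100 ROI whose top-left corner is (50, 70)
t = keypoint_to_target((120, 80), (50, 70, 150, 170))
print(np.argmax(t))  # grid cell index of the keypoint (238 here)
```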

@waleedka
Collaborator

> You said we just need to change build_fpn_mask_graph and mrcnn_mask_loss_graph, so could you please be more precise and concrete? I would appreciate any advice. Thanks a lot!

I said most of the changes are in those functions, but I didn't mean that these are the only places to touch. It's been a long time since I read the paper, so I'm afraid I can't give you a precise and concrete list of places to change. But I'm happy to answer questions or review any changes you make and offer feedback.

I think your intuition is correct about adding a new head for key points. You can use the mask code as a template and modify as necessary.

@MaeThird

MaeThird commented Nov 11, 2017

I used 0 or 1 to mark the mask, but after several epochs I found the mask was not as expected. After the mini-mask process, all the 1s had been changed to 0!

```python
for i in range(mask.shape[-1]):
    m = mask[:, :, i]
    y1, x1, y2, x2 = bbox[i][:4]
    m = m[y1:y2, x1:x2]
    m = scipy.misc.imresize(m.astype(float), mini_shape, interp='bilinear')
    mini_mask[:, :, i] = np.where(m >= 128, 1, 0)
return mini_mask
```

@liu6381810

@waleedka
Hi, thanks for your reply.
I have tried to modify the code to predict keypoints, and the training now seems to work.
But I have a problem: when I get rois, target_masks, target_class_ids from DetectionTargetLayer, I want to feed only the positive rois to the keypoint detection head, because only the positive rois contribute to the loss.
So I wrote this Cut layer:

```python
class Cut(KE.Layer):

    def __init__(self, **kwargs):
        super(Cut, self).__init__(**kwargs)

    def call(self, inputs):
        rois = inputs[0]
        target_class_ids = inputs[1]
        target_mask = inputs[2]
        return [rois[:, :25, :], target_class_ids[:, :25], target_mask[:, :25, :, :]]

    def compute_output_shape(self, input_shape):
        return [
            (None, config.TRAIN_ROIS_PER_IMAGE, 4),
            (None, config.TRAIN_ROIS_PER_IMAGE),
            (None, config.TRAIN_ROIS_PER_IMAGE, config.PART_NUMS, 3)
        ]

    def compute_mask(self, inputs, mask=None):
        return [None, None, None]


positive_rois, positive_target_class_ids, positive_target_mask = Cut()(
    [rois, target_class_ids, target_mask])
```

Then, when I use model.fit_generator(...), I get this error:
Input to reshape is a tensor with 5017600 values, but the requested shape requires a multiple of 25690112
Node: tower_1_1/mask_rcnn/mrcnn_mask_bn1/Reshape_1 = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:1"](tower_1_1/mask_rcnn/mrcnn_mask_bn1/batchnorm/add_1, tower_1_1/mask_rcnn/mrcnn_mask_conv1/Reshape_1/shape)
Node: tower_0_1/mask_rcnn/roi_align_classifier/strided_slice_8/_20037 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_13861_tower_0_1/mask_rcnn/roi_align_classifier/strided_slice_8", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"()

It seems the error occurs in build_fpn_mask_graph, which I changed to:

```python
def build_fpn_mask_graph(rois, feature_maps,
                         image_shape, pool_size, num_classes):

    x = PyramidROIAlign([pool_size, pool_size], image_shape,
                        name="roi_align_mask")([rois] + feature_maps)

    for i in range(8):
        x = KL.TimeDistributed(KL.Conv2D(512, (3, 3), padding="same"),
                               name="mrcnn_mask_conv{}".format(i + 1))(x)
        x = KL.TimeDistributed(BatchNorm(axis=3),
                               name='mrcnn_mask_bn{}'.format(i + 1))(x)
        x = KL.Activation('relu')(x)

    # removed activation
    x = KL.TimeDistributed(KL.Conv2DTranspose(config.PART_NUMS, (2, 2), strides=2),
                           name="mrcnn_mask_deconv")(x)

    x = BilinearUPSampling()(x)

    return x
```

So do you see any problem with the Cut layer?

@RodrigoGantier

I would like to ask: what would be the loss function for human pose estimation? I have the tensors in the following configuration:
target_masks: [batch, num_rois, height, width, num_parts]
target_class_ids: [batch, num_rois]
pred_masks: [batch, proposals, height, width, num_parts]

Before computing the loss, I gather the masks:

```python
y_true = tf.gather(target_masks, positive_ix)  # [positive_ix, height, width, num_parts]
y_pred = tf.gather(pred_masks, positive_ix)    # [positive_ix, height, width, num_parts]
```

I have tried this configuration but did not get results:

```python
loss = K.switch(tf.size(y_true) > 0,
                K.categorical_crossentropy(target=y_true, output=y_pred),
                tf.constant(0.0))
loss = K.mean(loss)
loss = K.reshape(loss, [1, 1])
```

[res]

I should mention that I have no problems with the generation of boxes; the boxes in the figure below were produced by the neural network.

[figure_1-1]

@yanxp

yanxp commented Dec 4, 2017

Have you trained and evaluated on human pose estimation successfully?

@RodrigoGantier

RodrigoGantier commented Dec 5, 2017

After fixing some details and reformulating the loss function to:

```python
pred_masks = K.reshape(pred_masks, (-1, 784, 14))
target_masks = K.reshape(target_masks, (-1, 784, 14))

# Gather the masks (predicted and true) that contribute to loss
y_true = tf.gather(target_masks, positive_ix)
y_pred = tf.gather(pred_masks, positive_ix)

loss = []
for i in range(14):
    loss.append(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred[:, :, i],
                                                         labels=y_true[:, :, i]))
loss = tf.stack(loss)
```

I got the following results:

[res_1: gt_human_points]

[res_2: network_output]

As you can see, the neural network does not distinguish between the right and left shoulder, right and left knee, etc.
Maybe it's because it did not train enough; I'm not sure. Someone please help me.

@filipetrocadoferreira

Maybe more training, and also be careful with data augmentation. Flipping is maybe not a good idea for human pose.

@filipetrocadoferreira

filipetrocadoferreira commented Dec 12, 2017

@RodrigoGantier, which files did you change to include keypoint detection? I would like to start on this approach.

BTW, aren't you forgetting to include a background class? If you have a softmax activation, you should include a background class; otherwise the net always needs to predict one of the keypoint classes, even when it's background.
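A small numpy sketch of that point (hypothetical shapes, not code from this thread): give each keypoint target an extra background bin, so a keypoint that isn't visible still has a valid softmax target instead of an all-zero vector.

```python
import numpy as np

m = 28                    # mask resolution, so m*m spatial bins
num_bins = m * m + 1      # plus one extra "background / not visible" bin

def keypoint_target(row, col, visible):
    """One-hot target over m*m cells plus a background bin."""
    target = np.zeros(num_bins, dtype=np.float32)
    if visible:
        target[row * m + col] = 1.0
    else:
        target[-1] = 1.0  # invisible keypoint -> background bin
    return target

print(np.argmax(keypoint_target(10, 12, visible=True)))   # 292
print(np.argmax(keypoint_target(0, 0, visible=False)))    # 784 (background bin)
```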

@RodrigoGantier

RodrigoGantier commented Dec 20, 2017

@filipetrocadoferreira
First of all, sorry for responding so late (I was somewhat busy with college assignments). Second, you're right that flipping is not a good idea; the neural network now works better. I'm currently working on the background issue, eliminating the instances that do not include a reference (missing keypoints). This week I will try to finish the code; as soon as I have some results I will upload some screenshots here, then I will tidy up the code, add comments, and publish it.

@filipetrocadoferreira

That would be helpful. I'm stuck on the loss function, but I'm planning to work on this next week. I hope we can share some opinions.

@QtSignalProcessing

QtSignalProcessing commented Dec 20, 2017

Hi @filipetrocadoferreira, I am also developing Mask R-CNN for human pose estimation, but there are some bugs. If you want, I can share my code with you so that we can debug together.

The following is my mask loss function; the tensor shapes are the same as in @RodrigoGantier's configuration.

```python
target_class_ids = K.reshape(target_class_ids, (-1,))
positive_ix = tf.where(target_class_ids > 0)[:, 0]

target_masks = K.reshape(target_masks,
                         (mask_shape[0] * mask_shape[1],
                          mask_shape[2] * mask_shape[3], mask_shape[4]))
pred_masks = K.reshape(pred_masks,
                       (pred_shape[0] * pred_shape[1],
                        pred_shape[2] * pred_shape[3], pred_shape[4]))

y_true = tf.gather(target_masks, positive_ix)
y_pred = tf.gather(pred_masks, positive_ix)

# y_true = target_masks
# y_pred = pred_masks
loss = []

##############################
# Without tf.gather(), all of the following loss functions work,
# but the detection results are really poor.
# With tf.gather(), only the first one (with K.switch) works,
# but occasionally the program crashes due to missing keypoints
# (I guess, but I am not sure whether the problem is caused by the labels
# or the predictions); one can continue training by loading the last stored weights.
#
# My program focuses on 12 keypoints (without the keypoints on the face).
##############################
for ii in range(0, 12):
    loss.append(K.switch(tf.size(y_true) > 0,
                         K.categorical_crossentropy(target=y_true[:, :, ii],
                                                    output=y_pred[:, :, ii]),
                         tf.constant(0.0)))
    # loss.append(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred[:, :, ii],
    #                                                     labels=y_true[:, :, ii]))
    # loss.append(K.categorical_crossentropy(target=y_true[:, :, ii], output=y_pred[:, :, ii]))
loss = tf.stack(loss)
```

@RodrigoGantier

RodrigoGantier commented Dec 21, 2017

@QtSignalProcessing My main problem is the same: the missing keypoint labels. Since there are photographs that do not have all the points, those labels (28 x 28 vectors containing only zeros) turn the loss into zero (loss = -mean(y_label * log(y_predict))). The result, in my opinion, is that for a point that is not visible or does not exist, the neural network looks for the closest or most similar point, which explains the bad results.
The question is: how to represent the missing points (inside the labels) so that the neural network can be trained with proper one-hot encoded inputs?

My current loss function (with this I solved the crash problem, in my case):

```python
target_class_ids = K.reshape(target_class_ids, (-1,))
positive_ix = tf.where(target_class_ids > 0)[:, 0]

target_masks = K.reshape(target_masks, (-1, 784, 14))
pred_masks = K.reshape(pred_masks, (-1, 784, 14))

y_true = tf.gather(target_masks, positive_ix)
y_pred = tf.gather(pred_masks, positive_ix)

loss = K.switch(tf.size(y_true) > 0,
                tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true, dim=1),
                tf.constant(0.0))
loss = K.mean(loss)
```

P.S.: I've already tried a loss function based on Euclidean distance, which gave the worst results.

@filipetrocadoferreira

I have a question.
If target_masks has shape [batch, num_rois, height, width, n_keypoints]
and pred_masks has shape [batch, proposals, height, width, n_keypoints],

and num_rois != proposals, how can we gather with positive_ix from both tensors?

@QtSignalProcessing

Hi @RodrigoGantier, in my mask loss function I add additional operations to keep only the non-zero labels:

```python
pred = []
target = []
for ii in range(0, 12):

    l_sh_t = K.reshape(y_true[:, :, ii], (-1,))
    pos_lsh_ix = tf.where(l_sh_t > 0)[:, 0]
    l_sh_t = tf.gather(l_sh_t, pos_lsh_ix)
    target.append(l_sh_t)

    l_sh_p = K.reshape(y_pred[:, :, ii], (-1,))
    l_sh_p = tf.gather(l_sh_p, pos_lsh_ix)
    pred.append(l_sh_p)
```

I am waiting for the results and I will update my progress as soon as possible.

@QtSignalProcessing

Hi @filipetrocadoferreira, num_rois should be equal to proposals in this implementation.
You can check PyramidROIAlign and DetectionTargetLayer.

@RodrigoGantier

@QtSignalProcessing Assuming y_pred = [positive_rois, 784, 12] and y_pred is in one-hot encoded format, I think pos_lsh_ix = tf.where(l_sh_t > 0)[:, 0] is not correct, because you just erase the other zeros in the one-hot vector. I suppose what you intend is not to take into account the labels that contain only zeros, which would leave something of the form y_pred = [positive_rois, 784, 10] if, for example, two keypoints are missing.

@RodrigoGantier

RodrigoGantier commented Dec 22, 2017

@filipetrocadoferreira In my understanding, for training the neural network the positive proposals are selected first, and then the final tensor is padded with negative proposals and zeros to reach the maximum number of proposals set in the config. positive_ix contains the indices of the positive proposals, so during training only the positive proposals are selected (at inference time, for y_pred, the proposals with the highest probability are selected, which usually correspond to the positive indices).

@filipetrocadoferreira

This is my attempt at addressing the problem of empty keypoints:

```python
# Reshape for simplicity. Merge first two dimensions into one.
target_class_ids = K.reshape(target_class_ids, (-1,))
target_masks = K.reshape(target_masks, (-1, 784, 17))
pred_masks = K.reshape(pred_masks, (-1, 784, 17))

positive_ix = tf.where(target_class_ids > 0)[:, 0]
y_true = tf.gather(target_masks, positive_ix)
y_pred = tf.gather(pred_masks, positive_ix)

y_true = tf.transpose(y_true, perm=[0, 2, 1])
y_pred = tf.transpose(y_pred, perm=[0, 2, 1])

y_true = K.reshape(y_true, (-1, 784))
y_pred = K.reshape(y_pred, (-1, 784))

y_true_sum = tf.reduce_sum(y_true, axis=-1)

good_ids = tf.where(y_true_sum > 0)[:, 0]

y_true = tf.gather(y_true, good_ids)
y_pred = tf.gather(y_pred, good_ids)

loss = K.switch(tf.size(y_true) > 0,
                K.categorical_crossentropy(target=y_true, output=y_pred),
                tf.constant(0.0))
loss = K.mean(loss)

return loss
```

@filipetrocadoferreira

I'm not able to get the mask loss to converge.

```python
def detection_targets_graph(proposals, gt_boxes, gt_masks, config):
    """Generates detection targets for one image. Subsamples proposals and
    generates target class IDs, bounding box deltas, and masks for each.

    Inputs:
    proposals: [N, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
    gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2, class_id)] in
              normalized coordinates.
    gt_masks: [height, width,17, MAX_GT_INSTANCES] of boolean type.

    Returns: Target ROIs and corresponding class IDs, bounding box shifts,
    and masks.
    rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized coordinates
    class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs. Zero padded.
    deltas: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, (dy, dx, log(dh), log(dw))]
            Class-specific bbox refinements.
    masks: [TRAIN_ROIS_PER_IMAGE, height, width,17). Masks cropped to bbox
           boundaries and resized to neural network output size.

    Note: Returned arrays might be zero padded if not enough target ROIs.
    """
    # Assertions
    asserts = [
        tf.Assert(tf.greater(tf.shape(proposals)[0], 0), [proposals],
                  name="roi_assertion"),
    ]
    with tf.control_dependencies(asserts):
        proposals = tf.identity(proposals)

    # Remove proposals zero padding
    non_zeros = tf.cast(tf.reduce_sum(tf.abs(proposals), axis=1), tf.bool)
    proposals = tf.boolean_mask(proposals, non_zeros)

    # TODO: Remove zero padding from gt_boxes and gt_masks

    # Compute overlaps matrix [rpn_rois, gt_boxes]
    # 1. Tile GT boxes and repeat ROIs tensor. This
    # allows us to compare every ROI against every GT box without loops.
    # TF doesn't have an equivalent to np.repeat() so simulate it
    # using tf.tile() and tf.reshape.
    rois = tf.reshape(tf.tile(tf.expand_dims(proposals, 1), 
                              [1, 1, tf.shape(gt_boxes)[0]]), [-1, 4])
    boxes = tf.tile(gt_boxes, [tf.shape(proposals)[0], 1])
    # 2. Compute intersections
    roi_y1, roi_x1, roi_y2, roi_x2 = tf.split(rois, 4, axis=1)
    box_y1, box_x1, box_y2, box_x2, class_ids = tf.split(boxes, 5, axis=1)
    y1 = tf.maximum(roi_y1, box_y1)
    x1 = tf.maximum(roi_x1, box_x1)
    y2 = tf.minimum(roi_y2, box_y2)
    x2 = tf.minimum(roi_x2, box_x2)
    intersection = tf.maximum(x2 - x1, 0) * tf.maximum(y2 - y1, 0)
    # 3. Compute unions
    roi_area = (roi_y2 - roi_y1) * (roi_x2 - roi_x1)
    box_area = (box_y2 - box_y1) * (box_x2 - box_x1)
    union = roi_area + box_area - intersection
    # 4. Compute IoU and reshape to [rois, boxes]
    iou = intersection / union
    overlaps = tf.reshape(iou, [tf.shape(proposals)[0], tf.shape(gt_boxes)[0]])

    # Determine positive and negative ROIs
    roi_iou_max = tf.reduce_max(overlaps, axis=1)
    # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
    positive_roi_bool = (roi_iou_max >= 0.5)
    positive_indices = tf.where(positive_roi_bool)[:, 0]
    # 2. Negative ROIs are those with < 0.5 with every GT box
    negative_indices = tf.where(roi_iou_max < 0.5)[:, 0]

    # Subsample ROIs. Aim for 33% positive
    # Positive ROIs
    positive_count = int(config.TRAIN_ROIS_PER_IMAGE * config.ROI_POSITIVE_RATIO)
    positive_indices = tf.random_shuffle(positive_indices)[:positive_count]
    # Negative ROIs. Fill the rest of the batch.
    negative_count = config.TRAIN_ROIS_PER_IMAGE - tf.shape(positive_indices)[0]
    negative_indices = tf.random_shuffle(negative_indices)[:negative_count]
    # Gather selected ROIs
    positive_rois = tf.gather(proposals, positive_indices)
    negative_rois = tf.gather(proposals, negative_indices)

    # Assign positive ROIs to GT boxes.
    positive_overlaps = tf.gather(overlaps, positive_indices)
    roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1)
    roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)

    # Compute bbox refinement for positive ROIs
    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes[:,:4])
    deltas /= config.BBOX_STD_DEV

    # Assign positive ROIs to GT masks
    # Permute masks to [N, height, width, 17]
    transposed_masks = tf.transpose(gt_masks, [3, 0, 1,2])

    # Pick the right mask for each ROI
    roi_masks = tf.gather(transposed_masks, roi_gt_box_assignment)

    # Compute mask targets
    boxes = positive_rois
    if config.USE_MINI_MASK:
        # Transform ROI coordinates from normalized image space
        # to normalized mini-mask space.
        y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1)
        gt_y1, gt_x1, gt_y2, gt_x2, _ = tf.split(roi_gt_boxes, 5, axis=1)
        gt_h = gt_y2 - gt_y1
        gt_w = gt_x2 - gt_x1
        y1 = (y1 - gt_y1) / gt_h
        x1 = (x1 - gt_x1) / gt_w
        y2 = (y2 - gt_y1) / gt_h
        x2 = (x2 - gt_x1) / gt_w
        boxes = tf.concat([y1, x1, y2, x2], 1)
    box_ids = tf.range(0, tf.shape(roi_masks)[0])
    masks = tf.image.crop_and_resize(tf.cast(roi_masks, tf.float32), boxes,
                                     box_ids,
                                     config.MASK_SHAPE)

    masks = tf.round(masks)

    # Append negative ROIs and pad bbox deltas and masks that
    # are not used for negative ROIs with zeros.
    rois = tf.concat([positive_rois, negative_rois], axis=0)
    N = tf.shape(negative_rois)[0]
    P = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
    rois = tf.pad(rois, [(0, P), (0, 0)])
    roi_gt_boxes = tf.pad(roi_gt_boxes, [(0, N+P), (0, 0)])
    deltas = tf.pad(deltas, [(0, N + P), (0, 0)])
    masks = tf.pad(masks, [[0, N + P], (0, 0), (0, 0), (0, 0)])

    return rois, roi_gt_boxes[:, 4], deltas, masks
```

@QtSignalProcessing

QtSignalProcessing commented Mar 21, 2018

@Superlee506 My kpt loss function:

```python
def mrcnn_kpt_mask_loss_graph(target_masks, target_class_ids, pred_masks):
    num_kpt = 19
    target_class_ids = K.reshape(target_class_ids, (-1,))
    positive_ix = tf.where(target_class_ids > 0)[:, 0]
    target_masks = K.reshape(target_masks, (-1, 56 * 56, num_kpt))
    pred_masks = K.reshape(pred_masks, (-1, 56 * 56, num_kpt))
    # Gather the masks (predicted and true) that contribute to loss
    y_true = tf.gather(target_masks, positive_ix)
    y_pred = tf.gather(pred_masks, positive_ix)
    loss = []
    for ii in range(0, num_kpt):
        logits = y_pred[:, :, ii]
        eps = tf.constant(value=1e-4)
        labels = tf.to_float(y_true[:, :, ii])
        softmax = tf.nn.softmax(logits) + eps
        cross_entropy = -tf.reduce_sum(labels * tf.log(softmax),
                                       reduction_indices=[1])
        cross_entropy_mean = K.switch(tf.size(labels) > 0,
                                      tf.reduce_mean(cross_entropy),
                                      tf.constant(0.0))
        loss.append(cross_entropy_mean)
    loss = tf.stack(loss)
    loss = K.mean(loss)
    return loss
```

@Superlee506

@QtSignalProcessing The number of keypoints in the COCO dataset is 17; why did you use 19?

@QtSignalProcessing

@Superlee506 I am not using COCO.

@Superlee506

@QtSignalProcessing Finally, I checked the original Detectron code for human pose estimation and changed my code. The loss converged, but the detection results aren't as good as in the original paper. And the model can't distinguish between right and left shoulder, right and left knee, etc., no matter whether I use the flipping augmentation or not.

@filipetrocadoferreira

What did you change?

@Superlee506

@filipetrocadoferreira A lot of places, and I found many mistakes in RodrigoGantier's code. Firstly, I changed the ground-truth keypoints to a label (index) format. Secondly, the loss function: I added weights to the keypoint loss function as Detectron does, and then sparse_softmax_cross_entropy_with_logits converges quickly. More importantly, the flipping method isn't right for keypoints and needs some modifications. However, my results aren't as good as the original paper's, and I'm confused about that. I plan to publish my code when the results look good.
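A hedged sketch of that combination (integer bin indices as targets plus a per-keypoint visibility weight, in the spirit of Detectron's keypoint loss; this is not Superlee506's actual code):

```python
import tensorflow as tf

def keypoint_softmax_loss(kp_logits, kp_targets, kp_weights):
    """Sketch of a Detectron-style keypoint loss.

    kp_logits:  [num_rois * num_keypoints, m * m] raw scores per spatial bin.
    kp_targets: [num_rois * num_keypoints] integer bin index of each keypoint.
    kp_weights: [num_rois * num_keypoints] 1.0 if the keypoint is visible, else 0.0.
    """
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=kp_targets, logits=kp_logits)
    # Invisible keypoints contribute nothing; normalize by the number of visible ones.
    loss = tf.reduce_sum(ce * kp_weights) / tf.maximum(tf.reduce_sum(kp_weights), 1.0)
    return loss
```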

@filipetrocadoferreira

Nice! I also found that the way to deal with the keypoint ground truth can't be the same as for masks (the resizes and crops probably make that clear).

The code would be amazing.

@Superlee506

@filipetrocadoferreira @QtSignalProcessing @RodrigoGantier @racinmat I open-sourced my project with detailed code comments. The loss converges quickly, but the predicted results are not as good as the original paper's. I only have one 980 graphics card, so I'm releasing my code, and any contribution or improvement is welcome and appreciated. https://github.com/Superlee506/Mask_RCNN

@QtSignalProcessing

QtSignalProcessing commented Mar 27, 2018

@Superlee506 It's really hard to achieve the results reported in the original paper, since the training parameters have to be carefully selected (I read something like this in one of the Detectron issues). BTW, distinguishing left/right keypoints relies on the geometric information of the human body; this could be done by post-processing.

@Superlee506

@QtSignalProcessing How would you do the post-processing? In my case, my model usually outputs the left/right keypoints together.

@Superlee506

I changed the name of the repository: https://github.com/Superlee506/Mask_RCNN_Humanpose

@QtSignalProcessing

@Superlee506 The positions of the nose and eyes provide information that you can use to distinguish left and right. Otherwise you have to make some assumptions.

Sorry, my last comment used the wrong words. The best way to distinguish left and right is to change the keypoint head so that it models the relationships between keypoints.

@hdjsjyl

hdjsjyl commented Mar 30, 2018

@RodrigoGantier, thanks for your advice; it was really helpful to me. But I have a few questions about the code below.
First: I know the 14 is the number of keypoints. Does it include the background?
Second: How do you get positive_ix? From the results of the RCNN part, or somewhere else?
Third: Since you compute a loss for every keypoint, when you output the keypoint detection loss, should you compute the mean of them?
Any advice will be appreciated, thank you very much.

```python
pred_masks = K.reshape(pred_masks, (-1, 784, 14))
target_masks = K.reshape(target_masks, (-1, 784, 14))

# Gather the masks (predicted and true) that contribute to loss
y_true = tf.gather(target_masks, positive_ix)
y_pred = tf.gather(pred_masks, positive_ix)

loss = []
for i in range(14):
    loss.append(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred[:, :, i],
                                                        labels=y_true[:, :, i]))
loss = tf.stack(loss)
```

@hdjsjyl

hdjsjyl commented Apr 1, 2018

@rujiao @waleedka @RodrigoGantier, thanks for your advice. I have now changed the segmentation part for keypoint detection, but I found that the rcnn L1 loss becomes very big, such as 234677418896143419441152.0000. Do you know what the reason could be? Any advice will be appreciated. Thank you.

@VellalaVineethKumar

@rujiao @waleedka @taewookim @liu6381810 @MaeThird
Can anyone please help me get a feature vector of the masked region? I'm really struggling with it.
I first referred to issue #1249 and then #1190, and I did all the changes mentioned there, but I get errors like "positional argument required: roi_pooled_features". I've been stuck on this for a month and any help would be really appreciated.

@germanotm

germanotm commented Mar 31, 2020

Has anyone been successful in using the Mask RCNN to detect only keypoints?

@rujiao
Author

rujiao commented Mar 31, 2020

Yes, I have used Mask R-CNN to detect bounding boxes and keypoints. It works quite well. You can simply remove the mask part.

@xidaniel

@rujiao Could you share your code on GitHub?
