
I cannot change the loss type to SquaredL2Distance; I get errors. #353

Closed
oujieww opened this issue Apr 9, 2018 · 2 comments

oujieww commented Apr 9, 2018

I want to use
dist = model.net.SquaredL2Distance([dt_output, 'masks_int32'], 'dist')
loss_mask = model.net.AveragedLoss(dist, 'loss_mask')
loss_gradients = blob_utils.get_loss_gradients(model, [loss_mask])
model.AddLosses('loss_mask')

I thought it would be fine to swap in a different loss, but I get errors that I can't make sense of:
Traceback (most recent call last):
File "../tools/train_net.py", line 291, in
main()
File "../tools/train_net.py", line 129, in main
checkpoints = train_model()
File "../tools/train_net.py", line 138, in train_model
model, start_iter, checkpoints, output_dir = create_model()
File "../tools/train_net.py", line 216, in create_model
model = model_builder.create(cfg.MODEL.TYPE, train=True,gpu_id=1)
File "/mnt/disk1/oujie/fasr-rcnn-mpii/3/Detectron-master/lib/modeling/model_builder.py", line 124, in create
return get_func(model_type_func)(model)
File "/mnt/disk1/oujie/fasr-rcnn-mpii/3/Detectron-master/lib/modeling/model_builder.py", line 89, in generalized_rcnn
freeze_conv_body=cfg.TRAIN.FREEZE_CONV_BODY
File "/mnt/disk1/oujie/fasr-rcnn-mpii/3/Detectron-master/lib/modeling/model_builder.py", line 229, in build_generic_detection_model
optim.build_data_parallel_model(model, _single_gpu_build_func)
File "/mnt/disk1/oujie/fasr-rcnn-mpii/3/Detectron-master/lib/modeling/optimizer.py", line 42, in build_data_parallel_model
model.AddGradientOperators(all_loss_gradients)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/model_helper.py", line 350, in AddGradientOperators
self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 1846, in AddGradientOperators
self._net.op[skip:], ys)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 1113, in GetBackwardPass
return ir.GetBackwardPass(ys)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 984, in GetBackwardPass
forward_op_idx, all_input_to_grad)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 934, in GenerateGradientsForForwardOp
forward_op, g_output)
File "/usr/local/lib/python2.7/dist-packages/caffe2/python/core.py", line 1086, in GetGradientForOp
format(op.type, e, str(op))
Exception: Exception when creating gradient for [Python]:[enforce fail at operator_gradient.h:166] g_output.at(i).IsDense(). Gradient of output gpu_1/rois is not provided! .

I also printed some information from "def GetGradientForOp(cls, op, g_output):" in core.py:
[None, 'gpu_1/loss_kps_grad']
[u'gpu_1/kps_score_reshaped_grad', None]
[u'gpu_1/kps_score_grad']
[u'gpu_1/kps_score_lowres_grad']
[u'gpu_1/conv_fcn8_grad']
[u'gpu_1/conv_fcn8_grad']
[u'gpu_1/conv_fcn7_grad']
[u'gpu_1/conv_fcn7_grad']
[u'gpu_1/conv_fcn6_grad']
[u'gpu_1/conv_fcn6_grad']
[u'gpu_1/conv_fcn5_grad']
[u'gpu_1/conv_fcn5_grad']
[u'gpu_1/conv_fcn4_grad']
[u'gpu_1/conv_fcn4_grad']
[u'gpu_1/conv_fcn3_grad']
[u'gpu_1/conv_fcn3_grad']
[u'gpu_1/conv_fcn2_grad']
[u'gpu_1/conv_fcn2_grad']
[u'gpu_1/conv_fcn1_grad']
[u'gpu_1/conv_fcn1_grad']
[u'gpu_1/_[pose]_roi_feat_grad']
[u'gpu_1/_[pose]_roi_feat_shuffled_grad', None]
[u'gpu_1/_[pose]_roi_feat_fpn5_grad']
[u'gpu_1/_[pose]_roi_feat_fpn4_grad']
[u'gpu_1/_[pose]_roi_feat_fpn3_grad']
[u'gpu_1/_[pose]_roi_feat_fpn2_grad']
['gpu_1/loss_mask_grad']
[u'gpu_1/dist_grad']
[u'gpu_1/dt_output_grad']
[u'gpu_1/mask_fcn_logits_grad']
[u'gpu_1/conv5_mask_grad']
[u'gpu_1/conv5_mask_grad']
[u'gpu_1/_[mask]_fcn4_grad']
[u'gpu_1/_[mask]_fcn4_grad']
[u'gpu_1/_[mask]_fcn3_grad']
[u'gpu_1/_[mask]_fcn3_grad']
[u'gpu_1/_[mask]_fcn2_grad']
[u'gpu_1/_[mask]_fcn2_grad']
[u'gpu_1/_[mask]_fcn1_grad']
[u'gpu_1/_[mask]_fcn1_grad']
[u'gpu_1/_[mask]_roi_feat_grad']
[u'gpu_1/_[mask]_roi_feat_shuffled_grad', None]
[u'gpu_1/_[mask]_roi_feat_fpn5_grad']
[u'gpu_1/_[mask]_roi_feat_fpn4_grad']
[u'gpu_1/_[mask]_roi_feat_fpn3_grad']
[u'gpu_1/_[mask]_roi_feat_fpn2_grad']
['gpu_1/loss_bbox_grad']
[None, 'gpu_1/loss_cls_grad']
[u'gpu_1/bbox_pred_grad']
[u'gpu_1/cls_score_grad']
[u'gpu_1/fc7_grad']
[u'gpu_1/fc7_grad']
[u'gpu_1/fc6_grad']
[u'gpu_1/fc6_grad']
[u'gpu_1/roi_feat_grad']
[u'gpu_1/roi_feat_shuffled_grad', None]
[u'gpu_1/roi_feat_fpn5_grad']
[u'gpu_1/roi_feat_fpn4_grad']
[u'gpu_1/roi_feat_fpn3_grad']
[u'gpu_1/roi_feat_fpn2_grad']
['gpu_1/loss_rpn_bbox_fpn6_grad']
['gpu_1/loss_rpn_cls_fpn6_grad']
['gpu_1/loss_rpn_bbox_fpn5_grad']
['gpu_1/loss_rpn_cls_fpn5_grad']
['gpu_1/loss_rpn_bbox_fpn4_grad']
['gpu_1/loss_rpn_cls_fpn4_grad']
['gpu_1/loss_rpn_bbox_fpn3_grad']
['gpu_1/loss_rpn_cls_fpn3_grad']
['gpu_1/loss_rpn_bbox_fpn2_grad']
['gpu_1/loss_rpn_cls_fpn2_grad']
[None, None, None, None, None, None, None, u'gpu_1/masks_int32_grad', u'gpu_1/masks_int32_weights_grad', None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]

Can anyone help me? Thanks.

@xieshuqin

The problem is that the GenerateProposalsOp and the CollectAndDistributeProposalsOp aren't expecting any incoming gradients.

For example, if you look at the CrossEntropyLoss code, it only returns the gradient for X, not for Y, so those ops don't cause you any problems. However, when you use SquaredL2Distance, it returns gradients for both X and Y, and that is where the error comes from.
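
To see this concretely, here is a minimal sketch (not from this thread; it assumes a plain caffe2 install and uses the made-up blob names 'pred' and 'label') that builds the backward pass for a SquaredL2Distance loss and prints which blobs are asked for gradients:

# Minimal sketch: bare caffe2 API, made-up blob names, no Detectron code.
from caffe2.python import core

net = core.Net('l2_grad_check')
dist = net.SquaredL2Distance(['pred', 'label'], 'dist')
loss = net.AveragedLoss(dist, 'loss')
grad_map = net.AddGradientOperators([loss])
# Expect both 'pred' and 'label' to appear in the gradient map; with a
# cross-entropy style loss, only the prediction blob would get a gradient.
print(sorted(grad_map.keys()))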

To fix this, maybe you can try something like this:

blob_masks_int32 = core.ScopedBlobReference('masks_int32')
# Stop the gradient from flowing back into the label blob
blob_masks_int32 = model.net.StopGradient(blob_masks_int32, blob_masks_int32)
dist = model.net.SquaredL2Distance([dt_output, blob_masks_int32], 'dist')
loss_mask = model.net.AveragedLoss(dist, 'loss_mask')
loss_gradients = blob_utils.get_loss_gradients(model, [loss_mask])
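
As a sanity check (again just a toy sketch with made-up blob names, not the real Detectron blobs), putting StopGradient in front of the label blob removes it from the gradient map, so nothing upstream of the label is ever asked for a gradient:

# Toy check: StopGradient is an identity in the forward pass but registers no
# gradient, so backprop stops at 'label_sg' and never reaches 'label'.
from caffe2.python import core

net = core.Net('l2_stopgrad_check')
label_sg = net.StopGradient('label', 'label_sg')
dist = net.SquaredL2Distance(['pred', label_sg], 'dist')
loss = net.AveragedLoss(dist, 'loss')
grad_map = net.AddGradientOperators([loss])
# The original 'label' blob should no longer receive a gradient here.
print(grad_map.get('label'))

The snippet above does the same thing in place on the 'masks_int32' blob itself.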

Or, you may want to write a new Caffe2 op similar to SquaredL2Distance that only returns the gradient for X. That should work too.

@rbgirshick
Contributor

Yes, using StopGradient is the correct solution.
