
RuntimeError: reciprocal is not implemented for type torch.cuda.LongTensor #147

Closed · ArchieGu opened this issue Apr 26, 2018 · 20 comments

@ArchieGu

I saw someone post a similar problem in #116.

Does this have something to do with the PyTorch version?

If anyone has a solution, please let me know. Thank you.

@ArchieGu (Author)

OK, it seems that PyTorch 0.3.0 solves the problem, so the error is probably caused by the PyTorch version.

@LiuHao-THU

I also met this problem when trying to compile with PyTorch 0.4.0. How can I solve it without changing my PyTorch version?

@xylcbd

xylcbd commented May 3, 2018

Modify faster-rcnn.pytorch\lib\model\rpn\anchor_target_layer.py at line 156, from

num_examples = torch.sum(labels[i] >= 0)

to

num_examples = torch.sum(labels[i] >= 0).item()
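
For context, a minimal sketch of why the .item() matters (assuming PyTorch 0.4 semantics, where torch.sum returns a 0-dim tensor, and assuming the count is later used as a divisor such as 1.0 / num_examples, which is what triggers the reciprocal error in the title):

    import torch

    labels = torch.tensor([1, -1, 0, 1])          # -1 marks "don't care" anchors (made-up data)
    num_examples = torch.sum(labels >= 0)         # 0-dim LongTensor
    # weight = 1.0 / num_examples                 # on PyTorch 0.4 with CUDA tensors this raises:
                                                  # RuntimeError: reciprocal is not implemented
                                                  # for type torch.cuda.LongTensor

    num_examples = torch.sum(labels >= 0).item()  # plain Python int
    weight = 1.0 / num_examples                   # ordinary Python arithmetic, no tensor op
    print(weight)                                 # 0.3333...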

@xylcbd

xylcbd commented May 3, 2018

Another issue can be resolved by modifying faster-rcnn.pytorch\lib\model\rpn\proposal_target_layer_cascade.py at line 133, from

labels = gt_boxes[:,:,4].contiguous().view(-1).index(offset.view(-1)).view(batch_size, -1)

to

labels = gt_boxes[:,:,4].contiguous().view(-1).index((offset.view(-1), )).view(batch_size, -1)
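
If you would rather avoid the legacy .index(...) call altogether, plain advanced indexing should do the same gather; a small self-contained sketch under that assumption (shapes and data are made up for illustration):

    import torch

    batch_size = 2
    gt_boxes = torch.rand(batch_size, 20, 5)                        # hypothetical (x1, y1, x2, y2, class) boxes
    offset = torch.randint(0, batch_size * 20, (batch_size, 128))   # hypothetical flattened indices

    # Same result as .index((offset.view(-1),)), written with ordinary fancy indexing.
    labels = gt_boxes[:, :, 4].contiguous().view(-1)[offset.view(-1)].view(batch_size, -1)
    print(labels.shape)                                             # torch.Size([2, 128])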

@D-X-Y

D-X-Y commented May 13, 2018

pytorch/pytorch#2772

@yjump

yjump commented May 21, 2018

Hi @xylcbd, thank you for your solutions. But I have run into another issue:

Traceback (most recent call last):
File "trainval_net.py", line 332, in <module>
clip_gradient(fasterRCNN, 10.)
File "/home/yanjp/faster-rcnn.pytorch/lib/model/utils/net_utils.py", line 52, in clip_gradient
p.grad.mul_(norm)
RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #3 'other'
Exception NameError: "global name 'FileNotFoundError' is not defined" in <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f75f997dd50>> ignored

I tried adding norm.cuda() to solve it, but that failed with:

Traceback (most recent call last):
File "trainval_net.py", line 332, in <module>
clip_gradient(fasterRCNN, 10.)
File "/home/yanjp/faster-rcnn.pytorch/lib/model/utils/net_utils.py", line 51, in clip_gradient
norm = norm.cuda()
AttributeError: 'float' object has no attribute 'cuda'
Exception NameError: "global name 'FileNotFoundError' is not defined" in <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f1460182cd0>> ignored

Could you kindly offer some advice? Thanks a lot!

@xylcbd

xylcbd commented May 22, 2018

@jwyang, maybe you should check the input parameter "clip_norm": is it on CUDA?

@yjump

yjump commented May 22, 2018

Thanks for your opinion, @xylcbd. I have checked clip_norm; it is a float constant. In the end I solved my problem by adding

norm = torch.tensor([norm], device='cuda')

before line 51 of faster-rcnn.pytorch/lib/model/utils/net_utils.py, and now it seems to work.

@wtliao

wtliao commented May 26, 2018

@xylcbd thanks so much for your solution. It works for me now

@isalirezag

isalirezag commented Jun 13, 2018

@yjump @wtliao Where did you put it exactly? Can you check whether I put it in the right position?

    norm = clip_norm / max(totalnorm, clip_norm)
    for p in model.parameters():
        if p.requires_grad:
            p.grad.mul_(norm)
            norm = torch.tensor([norm],device='cuda')

@yjump

yjump commented Jun 14, 2018

@isalirezag Put it before p.grad.mul_(norm), since the gradients are computed on the GPU.
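
To make the placement concrete, here is a rough sketch of the whole patched clip_gradient (the surrounding lines are paraphrased from the tracebacks above rather than copied from the repo, so treat them as an approximation):

    import torch

    def clip_gradient(model, clip_norm):
        # Accumulate the total gradient norm over all trainable parameters.
        totalnorm = 0.0
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                totalnorm += float(p.grad.norm()) ** 2
        totalnorm = totalnorm ** 0.5

        norm = clip_norm / max(totalnorm, clip_norm)
        # yjump's fix: build the scale factor as a CUDA tensor *before* the
        # in-place multiply, because the gradients live on the GPU.
        norm = torch.tensor([norm], device='cuda')
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.mul_(norm)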

@babyjie57

@yjump @wtliao Thank you so much for your solution!

@JingXiaolun

Hi @xylcbd, thank you for your solutions, but I have run into another issue:
File "/media/csu/新加卷/AI Competition/Baidu Competition/intermediary_contest/faster-rcnn.pytorch-master/trainval_net.py", line 316, in
loss.backward()
File "/home/csu/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/csu/anaconda3/lib/python3.6/site-packages/torch/autograd/init.py", line 89, in backward
allow_unreachable=True) # allow_unreachable flag
File "/media/csu/新加卷/AI Competition/Baidu Competition/intermediary_contest/faster-rcnn.pytorch-master/lib/model/roi_align/functions/roi_align.py", line 38, in backward
assert(self.feature_size is not None and grad_output.is_cuda)
AssertionError
Could you offer me some advice? Thanks a lot.

@JingXiaolun

@kurosaki-fish, no. Have you solved the problem?

@changjo

changjo commented Jul 17, 2018

For me, changing

totalnorm = np.sqrt(totalnorm) 

to

totalnorm = torch.sqrt(totalnorm)

solved the problem (RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor...).
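
Presumably this works because torch.sqrt keeps totalnorm a torch tensor on the same device, whereas going through NumPy turns it into a plain CPU number, so the scale factor that later multiplies the GPU gradients ends up with the wrong type. A tiny illustration of the type difference (my own example, not repo code, and it assumes a CUDA device is available):

    import numpy as np
    import torch

    totalnorm = torch.tensor(9.0, device='cuda')   # stand-in for the accumulated squared norms

    print(torch.sqrt(totalnorm).type())            # torch.cuda.FloatTensor -> stays a GPU tensor
    print(type(np.sqrt(totalnorm.item())))         # <class 'numpy.float64'> -> device info is lost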

thilinicooray added a commit to thilinicooray/VSRL that referenced this issue Jul 22, 2018
@Zhihan-Zhou

Running the command python trainval_net.py --cuda solved the AssertionError for me.

@jwyang (Owner)

jwyang commented Aug 28, 2018

The current master supports PyTorch 0.4.0 now, feel free to use it!

jwyang closed this as completed Aug 28, 2018
@ssli23

ssli23 commented Sep 18, 2018

I also met this problem when trying to compile with PyTorch 0.4.0 and do not want to change my PyTorch version.
Have you solved your problem?
