AssertionError: more than one group is unsupported on GPU #386
Comments
Can you please provide a minimal repro so that I can reproduce this issue?
Hello, see line 1733 at commit efe8eda.
Basically, on GPU groups has to be 1, but this line in the conv2d gradient computation uses groups > 1, which causes the problem for every model that has a convolution operation.
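To illustrate why groups > 1 appears only in the backward pass, here is a minimal pure-Python sketch (1-D, stride 1, no padding; not CrypTen code) of a well-known trick, used for example by older versions of PyTorch's torch.nn.grad.conv2d_weight: the weight gradient is computed as one grouped convolution with groups = batch_size * in_channels. Assuming CrypTen's conv2d backward pass works the same way, any batch size above 1 or any layer with more than one input channel yields groups > 1, which the GPU patch rejects. The helper names conv1d, weight_grad_direct and weight_grad_grouped are hypothetical.

```python
import random


def conv1d(x, w, groups=1):
    """Naive grouped cross-correlation, stride 1, no padding (PyTorch layout).

    x: [B][C_in][L], w: [C_out][C_in // groups][K] -> [B][C_out][L - K + 1]
    """
    B, C_in, L = len(x), len(x[0]), len(x[0][0])
    C_out, K = len(w), len(w[0][0])
    assert C_in % groups == 0 and C_out % groups == 0
    cin_g, cout_g = C_in // groups, C_out // groups
    out = []
    for b in range(B):
        rows = []
        for o in range(C_out):
            g = o // cout_g  # group that output channel o belongs to
            rows.append([
                sum(w[o][j][kk] * x[b][g * cin_g + j][t + kk]
                    for j in range(cin_g) for kk in range(K))
                for t in range(L - K + 1)
            ])
        out.append(rows)
    return out


def weight_grad_direct(x, grad_out, K):
    """dL/dw[o][i][k] = sum over b, t of grad_out[b][o][t] * x[b][i][t + k]."""
    B, C_in, C_out = len(x), len(x[0]), len(grad_out[0])
    T = len(grad_out[0][0])
    return [[[sum(grad_out[b][o][t] * x[b][i][t + k]
                  for b in range(B) for t in range(T))
              for k in range(K)]
             for i in range(C_in)]
            for o in range(C_out)]


def weight_grad_grouped(x, grad_out, K):
    """Same gradient via ONE grouped conv with groups = B * C_in."""
    B, C_in, C_out = len(x), len(x[0]), len(grad_out[0])
    # Fold the batch into the channel dimension: one sample with B * C_in channels.
    x_folded = [[x[b][i] for b in range(B) for i in range(C_in)]]
    # One filter per (b, i, o); its taps are grad_out[b][o].
    kernels = [[grad_out[b][o]]
               for b in range(B) for i in range(C_in) for o in range(C_out)]
    z = conv1d(x_folded, kernels, groups=B * C_in)[0]  # [B * C_in * C_out][K]
    # Sum the per-sample contributions over the batch.
    dw = [[[0] * K for _ in range(C_in)] for _ in range(C_out)]
    for b in range(B):
        for i in range(C_in):
            for o in range(C_out):
                row = z[(b * C_in + i) * C_out + o]
                for k in range(K):
                    dw[o][i][k] += row[k]
    return dw


random.seed(0)
B, C_in, C_out, L, K = 2, 3, 4, 7, 3
x = [[[random.randint(-5, 5) for _ in range(L)] for _ in range(C_in)] for _ in range(B)]
g = [[[random.randint(-5, 5) for _ in range(L - K + 1)] for _ in range(C_out)] for _ in range(B)]
assert weight_grad_grouped(x, g, K) == weight_grad_direct(x, g, K)
print("groups used:", B * C_in)  # prints "groups used: 6"
```

The grouped form is attractive on plaintext GPUs because it turns the whole weight gradient into a single convolution call, which is exactly the call a groups == 1 restriction cannot handle.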
Unfortunately, I no longer have the code snippet that caused this problem, so I didn't post it.
Hi,
The AssertionError is raised during the backward pass of the Conv2D layer in the model. Other models without a Conv2D layer work without problems.
Hi, regarding CrypTen/crypten/cuda/cuda_tensor.py, lines 190 to 223 at commit 909df45:
The function needs to be changed as follows:
I successfully trained models with this fix using a GPU, so I am pretty sure the backpropagation is calculated correctly.
Are there any updates on this? I am experiencing the same problem, and @Tobias512's solution does not work for me: it gives me a dimension mismatch, which I haven't looked into closely yet.
I found a better solution by looking at CryptGPU. They implement it as follows:
I changed the CrypTen implementation to:
It is nearly the same code as before, but the groups argument is set differently depending on whether groups is in kwargs. @kwmaeng91 I hope this really fixes the bug.
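The CryptGPU-based code itself isn't reproduced in this thread. The general idea behind "setting groups differently" can be sketched as follows, assuming (not confirmed here) that __patched_conv_ops already uses groups internally to run several block-wise convolutions in one call: stacking k independent convolutions that each use groups=G along the channel dimension is equivalent to a single convolution with groups=k * G. This is pure Python and 1-D, with a simple block-after-block channel layout chosen for illustration; conv1d and stacked_conv1d are hypothetical names, and CrypTen's real tensor layout may differ.

```python
import random


def conv1d(x, w, groups=1):
    """Naive grouped cross-correlation, stride 1, no padding (PyTorch layout).

    x: [B][C_in][L], w: [C_out][C_in // groups][K] -> [B][C_out][L - K + 1]
    """
    B, C_in, L = len(x), len(x[0]), len(x[0][0])
    C_out, K = len(w), len(w[0][0])
    assert C_in % groups == 0 and C_out % groups == 0
    cin_g, cout_g = C_in // groups, C_out // groups
    return [[[sum(w[o][j][kk] * x[b][(o // cout_g) * cin_g + j][t + kk]
                  for j in range(cin_g) for kk in range(K))
              for t in range(L - K + 1)]
             for o in range(C_out)]
            for b in range(B)]


def stacked_conv1d(xs, ws, groups=1):
    """Run len(xs) independent convs (each using `groups`) as ONE grouped call.

    xs[j]: [B][C_in][L], ws[j]: [C_out][C_in // groups][K].
    """
    k, B, C_out = len(xs), len(xs[0]), len(ws[0])
    # Place block j's input channels after block j-1's.
    x_cat = [[ch for j in range(k) for ch in xs[j][b]] for b in range(B)]
    # Place block j's filters after block j-1's.
    w_cat = [f for j in range(k) for f in ws[j]]
    # Each block contributes `groups` groups, so the combined call needs k * groups.
    z = conv1d(x_cat, w_cat, groups=k * groups)
    # Split the output channels back into per-block results.
    return [[z[b][j * C_out:(j + 1) * C_out] for b in range(B)] for j in range(k)]


random.seed(0)
k, B, C_in, C_out, G, L, K = 3, 2, 4, 2, 2, 5, 2
xs = [[[[random.randint(-9, 9) for _ in range(L)] for _ in range(C_in)]
       for _ in range(B)] for _ in range(k)]
ws = [[[[random.randint(-9, 9) for _ in range(K)] for _ in range(C_in // G)]
       for _ in range(C_out)] for _ in range(k)]
expected = [conv1d(xs[j], ws[j], groups=G) for j in range(k)]
assert stacked_conv1d(xs, ws, groups=G) == expected
```

Passing only k here (dropping the caller's G) would make each filter expect C_in input channels instead of C_in // G, so the shapes would no longer line up.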
I ran into an issue when I used Conv2d in my model.
The assertion is as follows:
File "/home/data/anaconda3/anaconda/envs/mpc/lib/python3.7/site-packages/crypten/cuda/cuda_tensor.py", line 195, in __patched_conv_ops
), f"more than one group is unsupported on GPU (groups = {groups})"
AssertionError: more than one group is unsupported on GPU (groups = 256)
256 is the number of output channels of the Conv2d layer.
I define the layer as "self.conv3 = nn.Conv2d(128, 256, 5, 1, 2)".
How can I solve the problem?