The value in ComplexMultiply_backward function #6
Hi @kaiyuyue, thank you for diving into the code and checking its correctness. I believe that the code is actually correct. In your calculation, you are assuming that differentiating a holomorphic function works the same way as differentiating a real-valued function; however, the correct way is not as straightforward. To get the correct result, you need either to use Wirtinger derivatives (https://en.wikipedia.org/wiki/Wirtinger_derivatives#Chain_rule) or to treat all the functions as functions of 2 real-valued variables instead of 1 complex one. In the second case, the main difference is that you have to differentiate w.r.t. Re(x) and Im(x) separately.

Here is a not-so-short proof using Wirtinger derivatives. It first computes the Wirtinger derivatives of g w.r.t. x, then uses them to compute dL/dRe(x) and dL/dIm(x), and finally substitutes and regroups the terms to obtain the gradient formulas.

Assuming that the calculus above is correct, the code should be as well. I will add a gradcheck test to prove the correctness of the multiplication.

This backward function is actually not used in the computation of the compact bilinear pooling gradient, because I tried to rearrange the computation of the multiplication and FFT gradients to decrease the memory footprint; however, the formula is the same.
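As a sketch of the second approach mentioned above (treating each complex value as a pair of real variables), write the forward product as z = xy with a = Re(x), b = Im(x), c = Re(y), d = Im(y), and let g_r = ∂L/∂Re(z), g_i = ∂L/∂Im(z) be the incoming gradient of the real-valued loss L. These symbols are introduced here for illustration, and this is one route to the result rather than necessarily the exact derivation in the comment:

```latex
\begin{aligned}
\operatorname{Re}(z) &= ac - bd, \qquad \operatorname{Im}(z) = ad + bc,\\
\frac{\partial L}{\partial a} &= g_r \frac{\partial \operatorname{Re}(z)}{\partial a} + g_i \frac{\partial \operatorname{Im}(z)}{\partial a} = g_r c + g_i d,\\
\frac{\partial L}{\partial b} &= g_r \frac{\partial \operatorname{Re}(z)}{\partial b} + g_i \frac{\partial \operatorname{Im}(z)}{\partial b} = -g_r d + g_i c,\\
\frac{\partial L}{\partial a} + i\,\frac{\partial L}{\partial b} &= (g_r + i g_i)(c - i d) = (g_r + i g_i)\,\bar{y}.
\end{aligned}
```

So the gradient w.r.t. x is the incoming gradient multiplied by the conjugate of y, not by y itself: the real part picks up a plus sign and the imaginary part a minus sign, the opposite of the forward product. By symmetry, the gradient w.r.t. y is (g_r + i g_i) times the conjugate of x.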
Very inspiring. Thanks for correcting my errors.
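Since the backward function is not actually used, does that mean the memory used by this part is the same as with the following approach?

```python
import pytorch_fft.fft.autograd as afft
# Count sketches of both inputs (CountSketchFn_forward, h1, s1, h2, s2, output_size, x, y come from the surrounding code)
sketch_x = CountSketchFn_forward(h1, s1, output_size, x)
sketch_y = CountSketchFn_forward(h2, s2, output_size, y)
# FFT of each real-valued sketch, passing an all-zero imaginary part
x_re, x_im = afft.Fft()(sketch_x, sketch_x.new(*sketch_x.size()).zero_())
y_re, y_im = afft.Fft()(sketch_y, sketch_y.new(*sketch_y.size()).zero_())
# Element-wise products of the real parts and of the imaginary parts (not a full complex multiplication)
prod_re, prod_im = x_re.mul(y_re), x_im.mul(y_im)
```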
I am not entirely sure what you mean. Regarding your snippet,
Thanks @kaiyuyue. I didn't spot that in the paper. I will correct the implementation to match the original algorithm.
Actually, I think that the implementation is correct. This multiplication stems from Section 3.2 of the paper. The Caffe implementation by Gao et al. also uses a complex multiplication.
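Assuming the multiplication discussed here is the one in Tensor Sketch (the sketching step behind compact bilinear pooling), the reason it must be a full complex multiplication can be checked on a toy example: Tensor Sketch combines two count sketches by circular convolution, and the DFT turns circular convolution into an element-wise complex product of the spectra. The plain-Python sketch below uses a naive O(n²) DFT instead of an FFT library, and the input vectors are made up; multiplying the real parts and the imaginary parts separately, as in the earlier snippet, does not reproduce the convolution.

```python
import cmath


def dft(v):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(v)
    return [sum(v[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]


def idft(spec):
    """Naive inverse DFT; returns complex values."""
    n = len(spec)
    return [sum(spec[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]


def circular_conv(a, b):
    """Direct circular convolution: c[t] = sum_s a[s] * b[(t - s) mod n]."""
    n = len(a)
    return [sum(a[s] * b[(t - s) % n] for s in range(n)) for t in range(n)]


def conv_via_complex_product(a, b):
    """Circular convolution via an element-wise complex product of the spectra."""
    spec_a, spec_b = dft(a), dft(b)
    return [z.real for z in idft([p * q for p, q in zip(spec_a, spec_b)])]


def conv_via_split_product(a, b):
    """Multiplies real parts and imaginary parts separately (not a complex product)."""
    spec_a, spec_b = dft(a), dft(b)
    prod = [complex(p.real * q.real, p.imag * q.imag) for p, q in zip(spec_a, spec_b)]
    return [z.real for z in idft(prod)]


a = [1.0, -1.0, 0.0, 1.0]
b = [0.0, 1.0, 1.0, -1.0]
print("direct:         ", circular_conv(a, b))
print("complex product:", [round(v, 6) for v in conv_via_complex_product(a, b)])
print("split product:  ", [round(v, 6) for v in conv_via_split_product(a, b)])
```

Only the complex product matches the direct circular convolution; the split product yields a different sequence.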
Oh! My bad. Thanks.
Hi @gdlg, thanks for this nice work. I'm confused about the backward procedure of complex multiplication, so I hope you can help me figure it out. In forward:

In backward, according to the chain rule, it should be:

So why is this line implemented using `value = 1` for the real part and `value = -1` for the imaginary part? Is there something wrong in my reasoning? Thanks.
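To check the sign pattern numerically, in the spirit of the gradcheck test mentioned earlier in the thread, here is a minimal plain-Python sketch (no PyTorch). The function names, the linear test loss and the input values are hypothetical; the backward formulas assume the line in question computes grad_x = grad_z · conj(y), which is what a `+` (value = 1) in the real part and a `-` (value = -1) in the imaginary part would correspond to.

```python
def complex_multiply_forward(x_re, x_im, y_re, y_im):
    """Forward pass: z = x * y, with complex numbers split into (re, im) parts."""
    z_re = x_re * y_re - x_im * y_im
    z_im = x_re * y_im + x_im * y_re
    return z_re, z_im


def complex_multiply_backward(x_re, x_im, y_re, y_im, g_re, g_im):
    """Backward pass: grad_x = g * conj(y) and grad_y = g * conj(x).

    g_re, g_im are dL/dRe(z) and dL/dIm(z) for a real-valued loss L.
    The real parts get a '+' and the imaginary parts a '-'.
    """
    gx_re = g_re * y_re + g_im * y_im
    gx_im = g_im * y_re - g_re * y_im
    gy_re = g_re * x_re + g_im * x_im
    gy_im = g_im * x_re - g_re * x_im
    return gx_re, gx_im, gy_re, gy_im


def naive_backward_x(y_re, y_im, g_re, g_im):
    """Hypothetical 'same signs as forward' rule: grad_x = g * y (incorrect)."""
    return g_re * y_re - g_im * y_im, g_re * y_im + g_im * y_re


def loss(x_re, x_im, y_re, y_im, w_re, w_im):
    """Hypothetical test loss L = w_re * Re(z) + w_im * Im(z).

    Its incoming gradient is exactly (dL/dRe(z), dL/dIm(z)) = (w_re, w_im).
    """
    z_re, z_im = complex_multiply_forward(x_re, x_im, y_re, y_im)
    return w_re * z_re + w_im * z_im


def numerical_grads(args, w, eps=1e-6):
    """Central finite differences of the loss w.r.t. x_re, x_im, y_re, y_im."""
    grads = []
    for i in range(len(args)):
        plus = list(args)
        minus = list(args)
        plus[i] += eps
        minus[i] -= eps
        grads.append((loss(*plus, *w) - loss(*minus, *w)) / (2 * eps))
    return grads


args = (0.3, -1.2, 0.7, 0.5)  # x_re, x_im, y_re, y_im
w = (1.5, -0.4)  # dL/dRe(z), dL/dIm(z)

analytic = complex_multiply_backward(*args, *w)
numeric = numerical_grads(args, w)
print("analytic:", analytic)
print("numeric: ", numeric)
print("naive grad_x:", naive_backward_x(args[2], args[3], *w))
```

The conjugate-based gradients agree with the finite differences, while the naive rule that reuses the forward signs does not.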