KCI mild speedup #49
Conversation
Nice.
Wow, that's really incredible! Thanks so much, Haoyue!!! A dumb question: any idea why it is so much faster?
Wow, this is great, @MarkDana! Thanks for identifying such a non-trivial optimization!! :)
@kunwuz to answer your question: matrix multiplication is in general O(n^3) (there are some optimizations in practice, but I am no expert in this area), while computing K * K.T is only O(n^2), so the optimization above reduces O(n^3) to O(n^2).
Thanks! Yeah, I always thought that NumPy did a perfect job of optimizing matrix multiplication, so I didn't expect the empirical speed-up to be so huge. Nice to know that trick~ :)
@kunwuz That's very different. Even if NumPy did a perfect job of optimizing matrix multiplication, there is no way to reduce matrix multiplication itself to O(n^2). The problem here is that we only need the diagonal entries of K.dot(K), but we calculated the whole K.dot(K), which wastes a lot of time. Probably the optimization cannot happen at the NumPy level --- it needs to happen at the "compile" level, i.e., the compiler needs to know that the program only needs the diagonal entries, so the machine-level code it compiles to does NOT literally do what you told it to do, i.e., it never even calculates the full K.dot(K) for you. Of course, this is just a rough idea and not 100% correct --- I don't know the mechanism of Python well, so the language I use is mostly for C.
Yes, you are totally right. The hypothesized best possible value for the exponent \omega is 2, so we cannot exactly achieve it. If I remember correctly, there exist algorithms achieving \omega < 2.4, and NumPy is among the SOTA implementations (via BLAS, which goes beyond the compiler level) for optimizing matrix multiplication. So I'm a little surprised the EMPIRICAL improvement could be so huge, which is really helpful for me. Also, I totally agree with you that, for this specific optimization of the trace of the square, we don't need to calculate K.dot(K) at all; the only thing we need to take care of is how to efficiently get the diagonal entries. Thanks so much for your brilliant answer! :)
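To make the trick concrete, here is a quick sketch (mine, not from the PR itself) verifying that the trace of a matrix product can be computed in O(n^2) without ever forming the product:

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((500, 500))

# Naive: form the full n x n product just to read off its diagonal -- O(n^3).
naive = np.trace(K.dot(K))

# Trick: trace(K @ K) = sum_{i,j} K[i, j] * K[j, i] = elementwise sum of K * K.T -- O(n^2).
fast = np.sum(K * K.T)

assert np.isclose(naive, fast)
```

The identity holds for any square matrix, symmetric or not, since both sides sum K[i, j] * K[j, i] over all pairs (i, j).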
Oh @kunwuz I just saw your message! Yes, BLAS is a very carefully fine-tuned library. On matrix multiplication complexity, Wikipedia says that "As of December 2020, the matrix multiplication algorithm with best asymptotic complexity runs in O(n^2.3728596) time" - cool, let's see how far this \omega gets pushed. And thanks so much @tofuwen for your nice explanation - that's exactly what I identified. @tofuwen's comment is very insightful: "Probably the optimization cannot be at NumPy level --- the optimization needs to be at the 'compile' level" - that is, let the compiler determine the computational graph instead of executing exactly the code we write. Actually, NumPy is already smart enough to detect some of these cases (though not at the compiler level; you need to express them explicitly in code). For example, it recognizes a multiplication of a matrix by its own transpose and computes it faster than the same product with a copied transpose:

```
In [1]: import numpy as np

In [2]: M = np.random.rand(5000, 5000)

In [3]: MT_copy = M.T.copy()

In [4]: timeit M @ M.T
281 ms ± 3.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: timeit M @ MT_copy
496 ms ± 4.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

But some calculations that I thought would be optimized are actually not. For example, there is no special treatment for multiplying by an all-zeros or an identity matrix:

```
In [6]: zeros = np.zeros_like(M)

In [7]: timeit M @ zeros
497 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [8]: eye = np.eye(5000)

In [9]: timeit M @ eye
507 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

I don't know the specific reason for this lol :)
@MarkDana Aha, thanks for this information! I will keep this in mind when I implement stuff in the future :-) |
causallearn/utils/KCI/KCI.py (outdated)
""" | ||
T = Kx.shape[0] | ||
mean_appr = np.trace(Kx) * np.trace(Ky) / T | ||
var_appr = 2 * np.trace(Kx.dot(Kx)) * np.trace(Ky.dot(Ky)) / T / T | ||
mean_appr = np.diag(Kx).sum() * np.diag(Ky).sum() / T # same as np.trace(Kx) * np.trace(Ky) / T. a bit faster |
Curious why this would be faster than directly calling np.trace()?
Oh, my fault. Last time I checked (though maybe I did something wrong), I thought np.diag(Kx).sum() was faster, at the cost of allocating the new array that np.diag(Kx) creates. I just double-checked: the two are almost the same speed. I'll change back to np.trace. Thanks!
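For what it's worth, a quick way to check the equivalence (a sketch, not from the PR):

```python
import numpy as np

Kx = np.random.rand(1000, 1000)

# np.trace sums the diagonal directly; np.diag first materializes the
# diagonal as a new 1-D array, then sums it. Same value either way.
assert np.isclose(np.trace(Kx), np.diag(Kx).sum())
```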
```diff
@@ -426,11 +438,11 @@ def kernel_matrix(self, data_x, data_y, data_z):
     # construct Gaussian kernels according to learned hyperparameters
     Kzy = gpy.kernel_.k1(data_z, data_z)
     self.epsilon_y = np.exp(gpy.kernel_.theta[-1])
-elif self.kernelY == 'Polynomial':
+elif self.kernelZ == 'Polynomial':
```
this seems like an important bug to me --- the old code didn't really work, I guess?
Yes, it's a typo bug: kernelY should be kernelZ.
However, in practical usage people usually choose the same kernel for X, Y, and Z, so the bug happened to go unnoticed so far.
Cool, good to know!
If we had had tests earlier, we would have detected this bug. ;)
@aoqiz A good case for why good tests are needed. :)
Yes, thanks.
```python
if self.epsilon_x != self.epsilon_y or (self.kernelZ == 'Gaussian' and self.use_gp):
    KyR, _ = Kernel.center_kernel_matrix_regression(Ky, Kzy, self.epsilon_y)
else:
    # assert np.all(Kzx == Kzy), 'Kzx and Kzy are the same'
```
why is this assertion commented out? We should have assertions whenever possible.
Just for speed: np.all(Kzx == Kzy) takes O(n^2) time to traverse all the elements (the same order as the key function after optimization, e.g. center_kernel_matrix).
I had assertions while writing and testing, but removed them for deployment. What do you think?
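One possible middle ground (a sketch with my own hypothetical helper name, not part of the PR): leave the check as a plain `assert`, which the Python interpreter strips when run with the `-O` flag, so deployments can opt out of the O(n^2) cost without deleting the check from the source:

```python
import numpy as np

# Hypothetical helper name, just to illustrate the pattern.
def check_kernels_match(Kzx: np.ndarray, Kzy: np.ndarray) -> None:
    # Runs in normal (debug) execution; stripped entirely under `python -O`,
    # so the O(n^2) comparison costs nothing in optimized deployments.
    assert np.array_equal(Kzx, Kzy), 'Kzx and Kzy should be equal here'

Kz = np.ones((4, 4))
check_kernels_match(Kz, Kz.copy())  # equal contents, so the assert passes
```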
Cool, sounds good! Will any user input break this assertion? If so, maybe we should still add it to make sure the user gives correct input.
On the user end, no data input or parameter choice will break this assertion, I think.
""" | ||
# assert np.all(K == K.T), 'K should be symmetric' |
why not keep this assertion?
```python
kci_ui_kwargs = {k: v for k, v in kwargs.items() if k in
                 ['kernelX', 'kernelY', 'null_ss', 'approx', 'est_width', 'polyd', 'kwidthx', 'kwidthy']}
kci_ci_kwargs = {k: v for k, v in kwargs.items() if k in
                 ['kernelX', 'kernelY', 'kernelZ', 'null_ss', 'approx', 'use_gp', 'est_width', 'polyd',
                  'kwidthx', 'kwidthy', 'kwidthz']}
self.kci_ui = KCI_UInd(**kci_ui_kwargs)
self.kci_ci = KCI_CInd(**kci_ci_kwargs)
```
Hmm, this seems like a pre-existing bug to me.
Was the old code runnable when calling KCI_CInd?
My fault! The old code was runnable (for both KCI_UInd and KCI_CInd), but not comprehensive enough.
I only noticed the arguments for KCI_UInd last time; however, there are additional arguments specific to KCI_CInd (e.g., kernelZ, use_gp). So in the previous code, if a user specified kernelZ in the CIT() constructor, it was actually ignored.
To prevent this from happening again, I'll add unit tests for CIT soon.
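A cheap guard against this class of bug (a hypothetical sketch, not the PR's code; `split_kci_kwargs` and its key lists are illustrative) is to reject any kwarg that no consumer recognizes, so a misspelled or unsupported key fails loudly instead of being silently dropped:

```python
def split_kci_kwargs(**kwargs):
    # Keys accepted by the unconditional and conditional tests, respectively.
    ui_keys = {'kernelX', 'kernelY', 'null_ss', 'approx', 'est_width', 'polyd', 'kwidthx', 'kwidthy'}
    ci_keys = ui_keys | {'kernelZ', 'use_gp', 'kwidthz'}

    # Fail loudly on anything neither consumer understands.
    unknown = set(kwargs) - ci_keys
    if unknown:
        raise ValueError(f'Unrecognized KCI kwargs: {sorted(unknown)}')

    kci_ui_kwargs = {k: v for k, v in kwargs.items() if k in ui_keys}
    kci_ci_kwargs = {k: v for k, v in kwargs.items() if k in ci_keys}
    return kci_ui_kwargs, kci_ci_kwargs

ui, ci = split_kci_kwargs(kernelZ='Polynomial', approx=True)
# kernelZ reaches the conditional test's kwargs but not the unconditional one's.
```

A unit test for `CIT()` could then simply pass `kernelZ` and assert it actually reaches the conditional test.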
awesome, thanks!!
Thanks for the great work, I think the code is ready to push after addressing all the nits.
Looks good to me. Thanks for this great work.
Updated functions:

- `causallearn/utils/KCI/KCI.py`:
  - `KCI_UInd.get_kappa`: reduce O(n^3) to O(n^2) based on the equation `np.trace(K.dot(K)) == np.sum(K * K.T)`.
  - `elif self.kernelY == 'Polynomial'`: fixed a naming bug in the original code; should be `kernelZ`.
  - `KCI_CInd.KCI_V_statistic`: L479, removed one repeated calculation based on the fact that `Kzx` and `Kzy` are usually the same.
- `causallearn/utils/KCI/Kernel.py`:
  - `Kernel.center_kernel_matrix`: reduce O(n^3) to O(n^2) by plugging `H = eye(n) - 1.0 / n` into `H.dot(K.dot(H))` and avoiding the explicit matrix multiplications.
- `causallearn/search/FCMBased/lingam/hsic.py`: similar trick as in `Kernel.center_kernel_matrix`.
- `causallearn/utils/cit.py`: fixed a minor issue about specifying KCI kwargs.

Speedup gain:

- `KCI_UInd` (unconditional test): huge speedup (an order of magnitude).
- `KCI_CInd` (conditional test): reduces around 30% of the time.

Test plan:

`KCI_UInd`:

Click to expand result
1000: 0.021124839782714844 0.051096200942993164
2000: 0.09719705581665039 0.2551460266113281
3000: 0.2547168731689453 0.7614071369171143
4000: 0.41847896575927734 1.7734618186950684
5000: 0.6685688495635986 3.47568416595459
6000: 0.9402382373809814 5.513365030288696
7000: 1.2306361198425293 9.101417064666748
8000: 1.7108919620513916 13.641188144683838
9000: 2.544398069381714 19.632148027420044
10000: 2.4818150997161865 25.683223009109497
11000: 2.8643081188201904 34.334508180618286
12000: 3.705733060836792 42.83501696586609
13000: 4.3541929721832275 56.730401039123535
14000: 4.609248161315918 68.94092202186584
15000: 5.336583852767944 83.44993829727173
16000: 6.201767921447754 103.21840572357178
17000: 7.315479040145874 128.12122511863708
18000: 8.262160062789917 153.92690014839172
19000: 8.924943208694458 182.9679470062256
20000: 9.806303977966309 210.9531922340393
`KCI_CInd`:

Click to expand result
500: 0.07363319396972656 0.10682296752929688
750: 0.23387503623962402 0.3343160152435303
1000: 0.4689028263092041 0.6680917739868164
1250: 1.0056912899017334 1.4335689544677734
1500: 1.8545763492584229 2.555225133895874
1750: 2.874561071395874 4.148720979690552
2000: 4.18735408782959 5.866267681121826
2250: 6.662433862686157 9.67330002784729
2500: 9.315114974975586 12.643889904022217
2750: 12.94197392463684 17.26453423500061
3000: 15.682774782180786 22.286062002182007
3250: 19.692520141601562 27.745235919952393
3500: 24.944950103759766 35.524142265319824
3750: 31.416036128997803 44.23967480659485
4000: 37.706125020980835 52.7396559715271
4250: 47.966763973236084 65.09382581710815
4500: 54.48931813240051 75.27616381645203
4750: 62.57163095474243 87.28439497947693
5000: 73.18962788581848 103.77545189857483
5250: 83.41915202140808 117.27962899208069
5500: 94.05676174163818 132.2629098892212
5750: 108.3281478881836 151.59633588790894
6000: 126.17060780525208 180.97449898719788
6250: 136.30685591697693 191.5569679737091
6500: 151.8123619556427 212.14035725593567
6750: 170.19479298591614 239.7194480895996
7000: 199.10978388786316 270.4898717403412
7250: 206.63027906417847 290.0352849960327
7500: 226.2251410484314 315.22972893714905
7750: 248.7306571006775 345.4090187549591
8000: 290.0406291484833 413.12121295928955
8250: 313.7044348716736 450.7824430465698
8500: 332.1730182170868 466.674519777298
8750: 352.3237581253052 486.6480839252472
9000: 374.6881170272827 519.5066709518433
9250: 412.67669320106506 571.1528761386871
9500: 436.2383370399475 612.9915819168091
9750: 474.7570939064026 669.191967010498
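As a footnote to the `Kernel.center_kernel_matrix` item in the summary above, here is a minimal sketch (mine, not the PR's actual code) of how expanding `H.dot(K.dot(H))` with `H = eye(n) - 1.0 / n` avoids the two dense matrix products:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 300
K = rng.standard_normal((n, n))

# Naive O(n^3): two dense matrix products with the centering matrix H.
H = np.eye(n) - 1.0 / n
naive = H.dot(K.dot(H))

# O(n^2): expand H K H = K - column means - row means + grand mean.
fast = K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()

assert np.allclose(naive, fast)
```

The expansion follows from H = I - J/n (J the all-ones matrix): (J/n)K broadcasts the column means, K(J/n) the row means, and (J/n)K(J/n) the grand mean.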