
KCI mild speedup #49

Merged (4 commits) on Jul 5, 2022

Conversation

@MarkDana (Collaborator) commented Jul 4, 2022

Updated functions:

  • causallearn/utils/KCI/KCI.py:
    • KCI_UInd.get_kappa: reduced O(n^3) to O(n^2) using the identity np.trace(K.dot(K)) == np.sum(K * K.T) (see the sketch after this list).
    • L429 elif self.kernelY == 'Polynomial': fixed a naming bug in the original code; it should be kernelZ.
    • KCI_CInd.KCI_V_statistic: L479, removed one repeated computation, based on the fact that Kzx and Kzy are usually the same.
  • causallearn/utils/KCI/Kernel.py:
    • Kernel.center_kernel_matrix: reduced O(n^3) to O(n^2) by plugging H = eye(n) - 1.0 / n into H.dot(K.dot(H)) and avoiding the explicit matrix multiplication.
  • causallearn/search/FCMBased/lingam/hsic.py: same trick as in Kernel.center_kernel_matrix.
  • causallearn/utils/cit.py: fixed a minor issue with how KCI kwargs are specified and forwarded.
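For reference, a minimal stand-alone check of the two identities these changes rely on (a sketch on a random symmetric matrix, not the library code):

import numpy as np

n = 500
A = np.random.rand(n, n)
K = A @ A.T  # symmetric positive semi-definite, like a kernel matrix

# 1) trace of a product: the O(n^3) form equals the O(n^2) elementwise form
assert np.isclose(np.trace(K.dot(K)), np.sum(K * K.T))

# 2) centering H K H with H = eye(n) - 1.0/n, via row/column means instead of two matmuls
H = np.eye(n) - 1.0 / n
assert np.allclose(H.dot(K.dot(H)),
                   K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean())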

Speedup gain:

  • KCI_UInd (unconditional test): huge speedup (an order of magnitude).
  • KCI_CInd (conditional test): running time reduced by around 30%.
  • Overall, since conditional tests are almost always the bottleneck in an algorithm, this speedup is mild.
  • We will therefore still need a more efficient way to compute the inverse and eigendecomposition of large matrices.

Test plan:

  • To ensure that this update is logically consistent with the original code (under all possible parameter combinations):
import numpy as np

# two local checkouts of causal-learn: the updated commit (new) and the previous commit (old)
from causallearn_fa79007.utils.cit import CIT as CIT_new
from causallearn_97e03ff.utils.cit import CIT as CIT_old

np.random.seed(42)

data = np.random.uniform(-1, 1, (100, 4))
for kernelname in ['Gaussian', 'Polynomial', 'Linear']:
    for est_width in ['empirical', 'median', 'manual']:
        for kwidth in [0.05, 0.1, 0.2]:
            for use_gp in [True, False]:
                for approx in [True, False]:
                    for polyd in [1, 2]:
                        kci_new = CIT_new(data, 'kci', kernelX=kernelname, kernelY=kernelname, kernelZ=kernelname,
                                              est_width=est_width, use_gp=use_gp, approx=approx, polyd=polyd,
                                              kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth)
                        kci_old = CIT_old(data, 'kci', kernelX=kernelname, kernelY=kernelname, kernelZ=kernelname,
                                              est_width=est_width, use_gp=use_gp, approx=approx, polyd=polyd,
                                              kwidthx=kwidth, kwidthy=kwidth, kwidthz=kwidth)

                        # since there is randomness in null_sample_spectral, we need to fix the same seed for old and new run
                        np.random.seed(42); new_pval = kci_new(0, 1)
                        np.random.seed(42); old_pval = kci_old(0, 1)
                        assert np.isclose(old_pval, new_pval), "KCI_UIND is inconsistent after update."

                        np.random.seed(42); new_pval = kci_new(0, 1, (2, 3))
                        np.random.seed(42); old_pval = kci_old(0, 1, (2, 3))
                        assert np.isclose(old_pval, new_pval), "KCI_CIND is inconsistent after update."
  • Test speedup for KCI_UInd:
import time

for samplesize in range(1000, 20001, 1000):
    data = np.random.uniform(-1, 1, (samplesize, 2))
    kci_new = CIT_new(data, 'kci')
    kci_old = CIT_old(data, 'kci')
    tic = time.time(); kci_new(0, 1); tac = time.time(); time_new = tac - tic
    tic = time.time(); kci_old(0, 1); tac = time.time(); time_old = tac - tic
    print(f'{samplesize}:   {time_new} {time_old}')
Result (sample size: new time [s]  old time [s]):

1000: 0.021124839782714844 0.051096200942993164
2000: 0.09719705581665039 0.2551460266113281
3000: 0.2547168731689453 0.7614071369171143
4000: 0.41847896575927734 1.7734618186950684
5000: 0.6685688495635986 3.47568416595459
6000: 0.9402382373809814 5.513365030288696
7000: 1.2306361198425293 9.101417064666748
8000: 1.7108919620513916 13.641188144683838
9000: 2.544398069381714 19.632148027420044
10000: 2.4818150997161865 25.683223009109497
11000: 2.8643081188201904 34.334508180618286
12000: 3.705733060836792 42.83501696586609
13000: 4.3541929721832275 56.730401039123535
14000: 4.609248161315918 68.94092202186584
15000: 5.336583852767944 83.44993829727173
16000: 6.201767921447754 103.21840572357178
17000: 7.315479040145874 128.12122511863708
18000: 8.262160062789917 153.92690014839172
19000: 8.924943208694458 182.9679470062256
20000: 9.806303977966309 210.9531922340393

[Figure: KCI_UInd runtime vs. sample size, new vs. old]

  • Test speedup for KCI_CInd:
for samplesize in range(500, 10000, 250):
    data = np.random.uniform(-1, 1, (samplesize, 4))
    kci_new = CIT_new(data, 'kci')
    kci_old = CIT_old(data, 'kci')
    tic = time.time(); kci_new(0, 1, (2, 3)); tac = time.time(); time_new = tac - tic
    tic = time.time(); kci_old(0, 1, (2, 3)); tac = time.time(); time_old = tac - tic
    print(f'{samplesize}:   {time_new} {time_old}')
Result (sample size: new time [s]  old time [s]):

500: 0.07363319396972656 0.10682296752929688
750: 0.23387503623962402 0.3343160152435303
1000: 0.4689028263092041 0.6680917739868164
1250: 1.0056912899017334 1.4335689544677734
1500: 1.8545763492584229 2.555225133895874
1750: 2.874561071395874 4.148720979690552
2000: 4.18735408782959 5.866267681121826
2250: 6.662433862686157 9.67330002784729
2500: 9.315114974975586 12.643889904022217
2750: 12.94197392463684 17.26453423500061
3000: 15.682774782180786 22.286062002182007
3250: 19.692520141601562 27.745235919952393
3500: 24.944950103759766 35.524142265319824
3750: 31.416036128997803 44.23967480659485
4000: 37.706125020980835 52.7396559715271
4250: 47.966763973236084 65.09382581710815
4500: 54.48931813240051 75.27616381645203
4750: 62.57163095474243 87.28439497947693
5000: 73.18962788581848 103.77545189857483
5250: 83.41915202140808 117.27962899208069
5500: 94.05676174163818 132.2629098892212
5750: 108.3281478881836 151.59633588790894
6000: 126.17060780525208 180.97449898719788
6250: 136.30685591697693 191.5569679737091
6500: 151.8123619556427 212.14035725593567
6750: 170.19479298591614 239.7194480895996
7000: 199.10978388786316 270.4898717403412
7250: 206.63027906417847 290.0352849960327
7500: 226.2251410484314 315.22972893714905
7750: 248.7306571006775 345.4090187549591
8000: 290.0406291484833 413.12121295928955
8250: 313.7044348716736 450.7824430465698
8500: 332.1730182170868 466.674519777298
8750: 352.3237581253052 486.6480839252472
9000: 374.6881170272827 519.5066709518433
9250: 412.67669320106506 571.1528761386871
9500: 436.2383370399475 612.9915819168091
9750: 474.7570939064026 669.191967010498

[Figure: KCI_CInd runtime vs. sample size, new vs. old]

@jdramsey (Collaborator) commented Jul 4, 2022

Nice.

@kunwuz (Collaborator) commented Jul 4, 2022

Wow, that's really incredible! Thanks so much, Haoyue!!!

A dumb question: any idea why np.sum(K * K.T) is much faster for calculating the trace of the square?

@tofuwen (Contributor) commented Jul 4, 2022

Wow, this is great, @MarkDana! Thanks for identifying such a non-trivial optimization!! :)

@tofuwen (Contributor) commented Jul 4, 2022

@kunwuz to answer your question: matrix multiplication is in general O(n^3) (there are faster algorithms, but I am not an expert in this area), while the elementwise product K * K.T is only O(n^2), so the optimization above reduces the cost from O(n^3) to O(n^2).
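For a quick empirical check of that gap (a rough sketch, not part of the PR; timings depend on the machine and BLAS build):

import time
import numpy as np

K = np.random.rand(5000, 5000)
K = K + K.T  # make it symmetric, like a kernel matrix

tic = time.time(); t1 = np.trace(K.dot(K)); print('trace(K @ K):', time.time() - tic)
tic = time.time(); t2 = np.sum(K * K.T); print('sum(K * K.T):', time.time() - tic)
assert np.isclose(t1, t2)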

@kunwuz (Collaborator) commented Jul 4, 2022

> @kunwuz to answer your question: matrix multiplication is in general O(n^3) (there are faster algorithms, but I am not an expert in this area), while the elementwise product K * K.T is only O(n^2), so the optimization above reduces the cost from O(n^3) to O(n^2).

Thanks! Yeah, I always thought that NumPy did a perfect job optimizing matrix multiplication, so I didn't expect the empirical speed-up to be this huge. Nice to know that trick~ :)

@tofuwen (Contributor) commented Jul 4, 2022

@kunwuz That's very different. Even if NumPy did a perfect job at optimizing matrix multiplication, there is no way to reduce matrix multiplication itself to O(n^2). Here the problem is that we only need the diagonal entries of K.dot(K), but we computed the whole K.dot(K), which wastes a lot of time. Probably the optimization cannot happen at the NumPy level --- it would need to happen at the "compile" level, i.e. the compiler would need to know that the program only needs the diagonal entries, so that the machine-level code it compiles to does NOT do exactly what you told it to do, i.e. it never actually computes K.dot(K) for you. Of course, this is just a rough idea and not 100% correct --- I don't know the mechanism of Python well, so the way I describe it mostly applies to C.

@kunwuz (Collaborator) commented Jul 4, 2022

> @kunwuz That's very different. Even if NumPy did a perfect job at optimizing matrix multiplication, there is no way to reduce matrix multiplication itself to O(n^2). Here the problem is that we only need the diagonal entries of K.dot(K), but we computed the whole K.dot(K), which wastes a lot of time. [...]

Yes, you are totally right. The conjectured best possible value for the exponent \omega is 2, so we cannot exactly achieve O(n^2). If I remember correctly, there are algorithms that achieve \omega < 2.4, and NumPy (through BLAS) is among the state of the art for optimizing matrix multiplication, going beyond the compiler level. So I'm a bit surprised the EMPIRICAL improvement could be this huge, which is really helpful for me to know.

Also, I totally agree with you that, for this specific optimization of the trace of the square, we don't need to calculate K.dot(K) at all. The only thing we need to take care of is how to efficiently get the diagonal entries. Thanks so much for your brilliant answer! :)
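For the record, one way to get just the diagonal (or the trace) of a product without forming it, via np.einsum (a sketch, not code from this PR):

import numpy as np

K = np.random.rand(1000, 1000)

# diagonal of K @ K without forming the full product: (K @ K)[i, i] = sum_j K[i, j] * K[j, i]
diag = np.einsum('ij,ji->i', K, K)
assert np.allclose(diag, np.diag(K.dot(K)))

# summing those entries gives the trace, matching the np.sum(K * K.T) trick
assert np.isclose(np.einsum('ij,ji->', K, K), np.sum(K * K.T))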

@MarkDana (Collaborator, Author) commented Jul 4, 2022

Oh @kunwuz I just saw your message! Yes, BLAS is a very carefully fine-tuned library. For the matrix multiplication complexity, Wikipedia says that "As of December 2020, the matrix multiplication algorithm with best asymptotic complexity runs in O(n^2.3728596) time" - cool, let's see how far this \omega will be pushed.

And thanks so much @tofuwen for your nice explanation - that's exactly what I found. np.trace(K.dot(K)) only needs the diagonal elements of K.dot(K), so we can skip computing the off-diagonal elements. Similarly, for a matrix multiplication H @ K @ H where H has special structure (e.g. the centering matrix eye(n) - 1.0/n), we can compute the row and column means manually instead of handing the full multiplication to numpy.
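Concretely, a sketch of that centering trick and its effect on timing (not the library code; numbers depend on the machine):

import time
import numpy as np

n = 4000
A = np.random.rand(n, n)
K = A @ A.T  # symmetric, kernel-like

H = np.eye(n) - 1.0 / n

tic = time.time()
HKH = H.dot(K.dot(H))  # two O(n^3) matrix multiplications
print('H @ K @ H:       ', time.time() - tic)

tic = time.time()
# expanding (I - 11^T/n) K (I - 11^T/n) leaves only row/column means and the grand mean
K_centered = K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()
print('mean subtraction:', time.time() - tic)

assert np.allclose(HKH, K_centered)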

And @tofuwen's comment is very insightful: "Probably the optimization cannot happen at the NumPy level --- it would need to happen at the 'compile' level" - i.e. let the compiler determine the computational graph rather than doing exactly what the code says. Actually numpy is already smart enough to detect some of these cases (though not at the compiler level; you need to express it in your code). For example:

We know that the matrix multiplication M @ M.T returns a symmetric matrix. Can numpy detect this and save time by computing only one triangular half? The answer is yes:

In [1]: import numpy as np
In [2]: M = np.random.rand(5000, 5000)
In [3]: MT_copy = M.T.copy()

In [4]: timeit M @ M.T
281 ms ± 3.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: timeit M @ MT_copy
496 ms ± 4.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

But some of the calculations that I thought would be optimized actually are not. For example, there is no special treatment of zero or identity matrices in matrix multiplication:

In [6]: zeros = np.zeros_like(M)
In [7]: timeit M @ zeros
497 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [8]: eye = np.eye(5000)
In [9]: timeit M @ eye
507 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

I don't know the specific reason for this lol :)

@kunwuz (Collaborator) commented Jul 4, 2022

@MarkDana Aha, thanks for this information! I will keep this in mind when I implement stuff in the future :-)

"""
T = Kx.shape[0]
mean_appr = np.trace(Kx) * np.trace(Ky) / T
var_appr = 2 * np.trace(Kx.dot(Kx)) * np.trace(Ky.dot(Ky)) / T / T
mean_appr = np.diag(Kx).sum() * np.diag(Ky).sum() / T # same as np.trace(Kx) * np.trace(Ky) / T. a bit faster
Contributor commented:

Curious why this would be faster than directly calling np.trace()?

Collaborator (Author) replied:

Oh, my fault. Last time I checked (perhaps incorrectly) I thought np.diag(Kx).sum() was faster, at the cost of the extra array created by np.diag(Kx).

I just double-checked: the two are almost the same in speed. I'll change it back to np.trace. Thanks!

@@ -426,11 +438,11 @@ def kernel_matrix(self, data_x, data_y, data_z):
# construct Gaussian kernels according to learned hyperparameters
Kzy = gpy.kernel_.k1(data_z, data_z)
self.epsilon_y = np.exp(gpy.kernel_.theta[-1])
- elif self.kernelY == 'Polynomial':
+ elif self.kernelZ == 'Polynomial':
Contributor commented:

This seems like an important bug to me --- I guess the old code didn't really work?

Collaborator (Author) replied:

Yes, it's a typo bug: kernelY should be kernelZ.

However, in practical usage people usually choose the same kernel for X, Y, and Z, so the bug happened to be avoided so far.
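For illustration, a configuration where kernelZ differs from kernelY - the case in which the old branch condition misfires - as a hypothetical call only, using the same CIT interface as in the test plan above:

import numpy as np
from causallearn.utils.cit import CIT

data = np.random.uniform(-1, 1, (100, 4))
# kernelZ='Polynomial' with kernelY='Gaussian': the old `elif self.kernelY == 'Polynomial'`
# check would not fire for Z, so the Polynomial kernel for Z was never constructed
kci = CIT(data, 'kci', kernelX='Gaussian', kernelY='Gaussian', kernelZ='Polynomial')
pval = kci(0, 1, (2, 3))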

Contributor replied:

Cool, good to know!

If we had had tests earlier, we would have caught this bug. ;)

@aoqiz A good example of why good tests are needed. :)

Contributor replied:

Yes, thanks.

In causallearn/utils/KCI/KCI.py (KCI_V_statistic):

if self.epsilon_x != self.epsilon_y or (self.kernelZ == 'Gaussian' and self.use_gp):
    KyR, _ = Kernel.center_kernel_matrix_regression(Ky, Kzy, self.epsilon_y)
else:
    # assert np.all(Kzx == Kzy), 'Kzx and Kzy are the same'
Contributor commented:

Why is this assertion commented out? We should have assertions whenever possible.

Collaborator (Author) replied:

Just for speed: np.all(Kzx == Kzy) takes O(n^2) time to traverse all elements (about the same cost as the key function, e.g. center_kernel_matrix, after optimization).

I had the assertions in place while writing and testing, but removed them for deployment. What do you think?

Contributor replied:

Cool, sounds good! Will any user input break this assertion? If so, maybe we should still keep the assertion to make sure the user gives correct input.

Collaborator (Author) replied:

From the user's end, no data input or parameter choice will break this assertion, I think.

"""
# assert np.all(K == K.T), 'K should be symmetric'
Contributor commented:

why not keep this assertion?

Comment on lines +34 to +40 of causallearn/utils/cit.py:
kci_ui_kwargs = {k: v for k, v in kwargs.items() if k in
                 ['kernelX', 'kernelY', 'null_ss', 'approx', 'est_width', 'polyd', 'kwidthx', 'kwidthy']}
kci_ci_kwargs = {k: v for k, v in kwargs.items() if k in
                 ['kernelX', 'kernelY', 'kernelZ', 'null_ss', 'approx', 'use_gp', 'est_width', 'polyd',
                  'kwidthx', 'kwidthy', 'kwidthz']}
self.kci_ui = KCI_UInd(**kci_ui_kwargs)
self.kci_ci = KCI_CInd(**kci_ci_kwargs)
Contributor commented:

Hmmm, this seems like a bug previously.

Was the old code runnable when calling KCI_CInd?

@MarkDana (Collaborator, Author) replied on Jul 4, 2022:

My fault! The old code is runnable (for both KCI_UInd and KCI_CInd), but not comprehensive enough.

I only noticed the arguments for KCI_UInd last time. However, there are some additional arguments specific to KCI_CInd (e.g., kernelZ, use_gp). So in the previous code, if a user specified kernelZ in the CIT() constructor, it was actually ignored.

To prevent this from happening again, I'll add unit tests for CIT soon.
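For example, a call like the following (a sketch, using the same CIT interface as in the test plan above) now forwards the KCI_CInd-specific kwargs instead of silently dropping them:

import numpy as np
from causallearn.utils.cit import CIT

data = np.random.uniform(-1, 1, (200, 4))
# kernelZ and use_gp are KCI_CInd-specific kwargs; previously they were not passed through
kci = CIT(data, 'kci', kernelZ='Gaussian', use_gp=True)
pval = kci(0, 1, (2, 3))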

Contributor replied:

awesome, thanks!!

@tofuwen (Contributor) commented Jul 4, 2022

Thanks for the great work, I think the code is ready to push after addressing all the nits.

@aoqiz (Contributor) left a review:

Looks good to me. Thanks for this great work.
