Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sharpwaves performance fixes #288

Merged

Conversation

toni-neurosc
Copy link
Collaborator

@toni-neurosc toni-neurosc commented Jan 26, 2024

When running the profiler on an offline analysis sharpwaves was, together with bursts, one of the more computationally costly features to calculate. I made some changes that brought the computation time for this feature down to 50%:

Commit 1 - Simplify _get_peaks_around: The profiler showed a very high amount of calls to argsort from the method _get_peaks_around and I realized that for each trough in the data, this function was being called and subsequently calling argsort twice. I worked out the logic of the function and realized that since scipy.signal.find_peaks is returning the peak indexes in order the calculation can be reduced to an array filtering operation which is much faster.

Commit 2 - Single pass adjacent peak finding: Even with the previous change, it seemed to me that calling _get_peaks_around once per trough, and then doing a comparison on the whole list of indexes was not necessary, so I replaced the _get_peaks_around calls with a variable right_peak_idx which each loop it increases until it goes past the current trough, keeping track of the right adjacent peak to the current trough at each step of the loop. This makes it so that all the work _get_peaks_around was doing we get for free now.

Commit 3 - fftconvolve: Since signal.convolve was calling signal.fftconvolve under-the-hood, I changed it to that, which removes the overhead of having to decide which method to use and makes the code more transparent in my opinion. Barely impacts performance.

-Before changes: 58.61098559697469 seconds per run.

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2973    1.638    0.001   60.456    0.020 nm_sharpwaves.py:133(calc_feature)
    59460   18.797    0.000   49.876    0.001 nm_sharpwaves.py:250(analyze_waveform)
  2719352   16.517    0.000   26.317    0.000 nm_sharpwaves.py:88(_get_peaks_around)
  5469224    2.191    0.000    9.698    0.000 fromnumeric.py:1025(argsort)
  5469224    2.105    0.000    7.507    0.000 fromnumeric.py:53(_wrapfunc)
    29730    0.320    0.000    5.782    0.000 _signaltools.py:1299(convolve)
    29730    0.128    0.000    4.895    0.000 _signaltools.py:557(fftconvolve)
  5469224    4.626    0.000    4.626    0.000 {method 'argsort' of 'numpy.ndarray' objects}
    29730    0.363    0.000    4.052    0.000 _signaltools.py:459(_freq_domain_conv)
   118920    0.552    0.000    3.784    0.000 _peak_finding.py:729(find_peaks)

-After commit 1: 53.140084981918335 seconds per run.

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2973    1.599    0.001   41.187    0.014 nm_sharpwaves.py:123(calc_feature)
    59460   16.981    0.000   30.919    0.001 nm_sharpwaves.py:240(analyze_waveform)
  2719352    9.412    0.000    9.412    0.000 nm_sharpwaves.py:89(_get_peaks_around)
    29730    0.309    0.000    5.646    0.000 _signaltools.py:1299(convolve)
    29730    0.124    0.000    4.802    0.000 _signaltools.py:557(fftconvolve)
    29730    0.355    0.000    3.993    0.000 _signaltools.py:459(_freq_domain_conv)
   118920    0.525    0.000    3.661    0.000 _peak_finding.py:729(find_peaks)
    89190    0.139    0.000    3.582    0.000 _backend.py:17(__ua_function__)
    59460    0.213    0.000    2.457    0.000 basic.py:203(r2cn)
  • After commit 2: 48.184030532836914 seconds per run
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2973    1.528    0.001   31.175    0.010 nm_sharpwaves.py:123(calc_feature)
    59460   17.096    0.000   21.421    0.000 nm_sharpwaves.py:240(analyze_waveform)
    29730    0.284    0.000    5.319    0.000 _signaltools.py:1299(convolve)
    29730    0.117    0.000    4.545    0.000 _signaltools.py:557(fftconvolve)
    29730    0.328    0.000    3.797    0.000 _signaltools.py:459(_freq_domain_conv)
   118920    0.516    0.000    3.505    0.000 _peak_finding.py:729(find_peaks)
    89190    0.137    0.000    3.421    0.000 _backend.py:17(__ua_function__)
    59460    0.199    0.000    2.325    0.000 basic.py:203(r2cn)
   118920    0.608    0.000    1.840    0.000 {scipy.signal._peak_finding_utils._select_by_peak_distance}
   178380    0.169    0.000    1.374    0.000 fromnumeric.py:2692(max)
  • After commit 3: 47.840235233306885 seconds per run.

@toni-neurosc
Copy link
Collaborator Author

toni-neurosc commented Jan 26, 2024

Since Python for-loops are infamously slow, I vectorized as much as I could of the analyze_waveform function. This cut down the computation time for this method by almost another half, for analyze_waveform we're at 27% of the original computation time.
Rise steepness and Decay steepness are a bit more complicated to vectorize so there's still a loop in there but I think I can use scipy.ndimage to get rid of it (not sure it would speed up anything but worth a try)

HUGE DISCLAIMER: I did my best to keep the math intact but I did this like 10 days ago so I'm not 100% sure the results are identical. Tests are passing but I haven't checked what they're doing.

It's also interesting that this opens the door to vectorizing the whole calc_feature function which might lead to further gains.

  • After vectorizing (except steepness): 46.9043915271759 seconds per run.
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2973    1.593    0.001   21.860    0.007 nm_sharpwaves.py:123(calc_feature)
    59460    5.236    0.000   13.390    0.000 nm_sharpwaves.py:240(analyze_waveform)
    29730    0.145    0.000    4.746    0.000 _signaltools.py:557(fftconvolve)
    29730    0.357    0.000    3.920    0.000 _signaltools.py:459(_freq_domain_conv)
   118920    0.533    0.000    3.720    0.000 _peak_finding.py:729(find_peaks)
  • After removing np.pad: 44.16844423611959 seconds per run.
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2973    1.563    0.001   16.731    0.006 nm_sharpwaves.py:122(calc_feature)
    59460    4.347    0.000    8.470    0.000 nm_sharpwaves.py:237(analyze_waveform)
    29730    0.134    0.000    4.443    0.000 _signaltools.py:557(fftconvolve)
    29730    0.323    0.000    3.678    0.000 _signaltools.py:459(_freq_domain_conv)
   118920    0.500    0.000    3.496    0.000 _peak_finding.py:729(find_peaks)
    89190    0.131    0.000    3.307    0.000 _backend.py:17(__ua_function__)
  • After vectorizing steepness: (CANCELLED)

@toni-neurosc
Copy link
Collaborator Author

Fixed the test errors, caused by doing a [list] - [list] operation. I ran the tests in my computer and did not get the error, probably because self.sw_settings["sharpwave_features"]["width"] was not set.
Do the tests in the GitHub Actions differ from the tests in /tests/?

@timonmerk
Copy link
Contributor

Wow, that's great additions and impressive speed-up!
I will probably try to take some time to understand it write potentially additional tests that capture this feature.

Also, the tests in /tests are called by tox, so the Github Actions tests should be identical.

@toni-neurosc
Copy link
Collaborator Author

Update on steepness: I don't think this is possible to vectorize, due to the fact that each peak-to-trough / trough-to-peak distance is different. Using nd.image is possible but would require building masks out of the indexes, which requires extra computation, and using a view on numpy arrays as I have it currently is pretty efficient already anyway.

Btw, I wanted to ask, is there a possibility that there is a bug here:

rise_steepness = np.max(np.diff(self.data_process_sw[peak_idx_left : trough_idx + 1]))
decay_steepness = np.max(np.diff(self.data_process_sw[trough_idx : peak_idx_right + 1]) )

np.diff gives the difference to previous, so if peaks have higher values than troughs, rise_steepness will have mostly negative values, so should it not be np.min instead of max?
Also I find it confusing that "rise" steepness is from peak to trough, feels like the naming should be reversed.

@timonmerk
Copy link
Contributor

timonmerk commented Jan 31, 2024

Yes! I think you're right. That's a bug. It should have been always the np.abs() before the np.max(). Good catch!
And the naming might be confusing since it could be performed for peaks and troughs. But I focused on troughs in the initial version

@toni-neurosc
Copy link
Collaborator Author

Cool! I was a bit hesitant to bring it up since I still don't really understand what a lot of the features are doing but this one was easy to visualize so the typo stood out to me. I added the correction to the PR.

@timonmerk timonmerk merged commit 704ca15 into neuromodulation:main Feb 3, 2024
2 checks passed
@toni-neurosc toni-neurosc deleted the sharpwaves_performance_pr branch February 3, 2024 19:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants