Better convolve1d implementation on jax and torch #52

Merged
hexane360 merged 4 commits into develop from convolve1d-test
Jun 4, 2026

Conversation

@hexane360
Owner

  • On jax, convolve1d uses conv_general_dilated directly, which should dispatch to a 1D convolution rather than a 3D convolution with kernel size (1, 1, N).
  • On torch, xp.pad now has a fallback for cases that torch.nn.functional.pad doesn't support.
  • On torch, convolve1d now uses xp.pad. Importantly, kernel size >= object size now works.
  • Add test cases for xp.pad.
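
A minimal sketch of the first bullet, not the code from this PR: a 1D convolution along the last axis written directly with `jax.lax.conv_general_dilated` and a single spatial dimension. The name `convolve1d_sketch`, the zero-padded "same" boundary handling, and the kernel flip are assumptions made for illustration; the actual signature and padding mode in the PR may differ.

```python
import jax.numpy as jnp
from jax import lax


def convolve1d_sketch(x, kernel):
    """Zero-padded 'same' convolution of x (shape (..., n)) with a 1D kernel.

    Hypothetical helper for illustration only; not the project's convolve1d.
    """
    x = jnp.asarray(x)
    kernel = jnp.asarray(kernel, dtype=x.dtype)
    k = kernel.shape[0]

    # Collapse leading axes into the batch axis: (N, C=1, W).
    lhs = x.reshape((-1, 1, x.shape[-1]))
    # conv_general_dilated computes a cross-correlation, so flip the kernel
    # to obtain a true convolution. Layout is (O=1, I=1, W).
    rhs = kernel[::-1].reshape((1, 1, k))

    # Explicit padding keeps the output the same length as the input,
    # including when the kernel is longer than the signal.
    pad_lo = k // 2
    pad_hi = k - 1 - pad_lo

    out = lax.conv_general_dilated(
        lhs,
        rhs,
        window_strides=(1,),
        padding=[(pad_lo, pad_hi)],
        dimension_numbers=("NCH", "OIH", "NCH"),  # one spatial dimension
    )
    return out.reshape(x.shape)
```

Because only one spatial dimension is declared in `dimension_numbers`, no reshaping to a `(1, 1, N)` kernel is needed. The explicit `(pad_lo, pad_hi)` pair is one way to keep the output length fixed when the kernel is at least as long as the input.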

@hexane360 requested a review from mlz-EM on May 22, 2026 19:59
@hexane360
Owner Author

@mlz-EM can you review when you get a chance? It would be helpful if you could benchmark layers regularization on jax before and after, to make sure it's not significantly slower.

@mlz-EM
Collaborator

mlz-EM commented Jun 4, 2026

Benchmark: [grad_trial0301_main_sigma0.json](https://github.com/user-attachments/files/28610423/grad_trial0301_main_sigma0.json). Speed was benchmarked with PSO data and the attached recipe. The results reproduce the main branch, with a noticeable speedup. However, the speed remains constant with respect to sigma for both the old code and the PR. I don't think this is the expected behavior. In particular, with sigma=0 the layers constraint is completely disabled and convolve1d is not supposed to be called, yet the iteration time is still the same.

Looks good to merge from the PR perspective. For the sigma issue, we should probably check whether convolve1d is being silently called elsewhere when it shouldn't be.


The earlier benchmark was against main, so its speedup reflects the cumulative changes in the develop branch. The updated benchmark is against the develop branch. I also ran a layer-blur-only benchmark, without the full reconstruction, which shows the expected slowdown with larger sigma. The PR speedup is real, but the time spent on layer blur is negligible compared to the iteration time.

[Figure: grad_trial0301_benchmark_plot] [Figure: layer_blur_benchmark_plot]
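
The layer-blur-only measurement described above could be set up roughly as follows. This is a sketch under stated assumptions, not the benchmark that produced the plots: `gaussian_kernel`, `blur_layers`, `time_blur`, the array shape, zero padding, and truncating the kernel at three sigma are all hypothetical. The blur is written with `jax.lax.conv_general_dilated` as a stand-in for the project's convolve1d.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def gaussian_kernel(sigma, truncate=3.0):
    """Normalized Gaussian kernel whose length grows with sigma (assumed form)."""
    radius = max(1, int(truncate * sigma + 0.5))
    t = jnp.arange(-radius, radius + 1, dtype=jnp.float32)
    k = jnp.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()


def blur_layers(obj, sigma):
    """Blur along axis 0 (layers); sigma == 0 skips the convolution entirely."""
    if sigma == 0:
        return obj
    kernel = gaussian_kernel(sigma)
    k = kernel.shape[0]
    n_layers = obj.shape[0]
    # Treat every pixel as a batch entry: (pixels, C=1, layers).
    lhs = obj.reshape(n_layers, -1).T[:, None, :]
    rhs = kernel[::-1].reshape(1, 1, k)
    out = lax.conv_general_dilated(
        lhs, rhs, (1,), [(k // 2, k - 1 - k // 2)],
        dimension_numbers=("NCH", "OIH", "NCH"),
    )
    return out[:, 0, :].T.reshape(obj.shape)


def time_blur(obj, sigma, repeats=5):
    """Mean wall-clock seconds per call, excluding compilation."""
    fn = jax.jit(lambda o: blur_layers(o, sigma))
    fn(obj).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for the result before timing.
        fn(obj).block_until_ready()
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    obj = jnp.ones((8, 256, 256), dtype=jnp.float32)
    for sigma in (0.0, 1.0, 2.0, 4.0):
        print(f"sigma={sigma}: {time_blur(obj, sigma) * 1e3:.3f} ms per call")
```

Timing the blur in isolation like this separates its cost from the rest of the iteration. If the full iteration time does not change between sigma=0 and larger values while the isolated blur time does, that is consistent with the blur being a negligible fraction of each iteration.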

@hexane360 merged commit 7a4a5ff into develop on Jun 4, 2026
4 checks passed