Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement SCS in terms of conv2d #6

Closed
wants to merge 2 commits into from
Closed

Conversation

jatentaki
Copy link

This PR proposes a change in implementation of SCS which uses Conv2D, a much more performant primitive. The gist of the idea is to

  1. Normalize the kernel as we already do
  2. Compute per-window normalization factor coming from the image by sqrt(avg_pool2d(input ** 2) * window_size)
  3. Compute the dot products via conv2d between the input and normalized kernel
  4. Normalize the output elementwise by the array obtained in step 2.
  5. Proceed as before

On my laptop this achieves more than 2x performance improvement. I attach a quick test (compat_test.py) to verify that the two implementations yield (almost) the same results. This PR will obviously need some code unification and renaming before being merged, I submit it as it to facilitate easy comparisons by the maintainers :)

@brohrer
Copy link
Owner

brohrer commented Feb 22, 2022

This is fantastic. I look forward to doing a careful comparison of results and a clean integration. I hope I can get to it sooner rather than later. Thanks!

@enzokro
Copy link

enzokro commented Feb 23, 2022

This implementation is great! Thank you for making it.

I have a quick question about eps, but it shouldn't be a difference-maker.
For the sake of keeping its scale, should we add it when squaring the input:

       x_norm_squared = F.avg_pool2d(
            (x + self.eps) ** 2,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            divisor_override=1, # we actually want sum_pool
        ).sum(dim=1, keepdim=True)

So when the outputs are normalized below:

    y = y_denorm / (x_norm_squared.sqrt() + q_sqr)

It will be with eps instead of sqrt(eps)?

@jatentaki
Copy link
Author

@enzokro you're right, it would make more sense to add it before avg_pool2d. I somehow missed that this addition can be moved there in order to completely preserve the original semantics.

@brohrer
Copy link
Owner

brohrer commented Mar 10, 2022

Thanks for this @jatentaki ! It looks like several contributors had similar ideas at the same time. A similar innovation from @ClashLuke has been folded into the main branch. I'll go ahead and close this PR for now, but thank you very much for spending time thinking about this.

@brohrer brohrer closed this Mar 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants