
Custom primitive + RoPE fat op #676

Merged
merged 12 commits into main from extensions Feb 14, 2024

Conversation

awni
Member

@awni awni commented Feb 12, 2024

Proposed changes

  • add RoPE kernel
  • test transforms of custom primitive
  • benchmarks
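For reference, the traditional rotary embedding the kernel computes can be sketched in NumPy. This is a hedged reference implementation, not the MLX code; the parameter names mirror the bindings discussed in this PR, and `rope_ref` is an illustrative name:

```python
import numpy as np

def rope_ref(x, offset=0, base=10000.0, scale=1.0):
    """Reference rotary position embedding over the last axis of x.

    x has shape (..., T, D) with D even. In the "traditional" layout,
    adjacent pairs (x[..., 2i], x[..., 2i+1]) are rotated together by a
    per-position, per-frequency angle.
    """
    T, D = x.shape[-2], x.shape[-1]
    # Per-pair inverse frequencies and per-position angles.
    inv_freq = base ** (-np.arange(0, D, 2) / D)                  # (D/2,)
    theta = scale * (offset + np.arange(T))[:, None] * inv_freq   # (T, D/2)
    cos, sin = np.cos(theta), np.sin(theta)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Since each pair is a 2D rotation, the transform is the identity at position 0 (with `offset=0`) and preserves the norm along the last axis.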

@awni awni marked this pull request as draft February 12, 2024 22:24
@angeloskath
Member

Some benchmarks of the kernel on my M2 air

Before

Timing rope_vec ... 3.99837 msec
Timing rope_mat ... 63.36270 msec

After

Timing rope_vec ... 0.61199 msec
Timing rope_mat ... 6.72898 msec

The tests fail on float16 and bfloat16, but only due to numerical issues. Tomorrow I will do a quick check on the performance if we do all of the computation in float32 in the kernel, since it probably doesn't matter at all performance-wise.
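The half-precision failures are unsurprising: the angle `position * inv_freq` grows large and loses precision quickly in float16, which is why computing in float32 inside the kernel (or loosening test tolerances) is a common fix. A small NumPy illustration with hypothetical values, not taken from the actual tests:

```python
import numpy as np

# Angle for a large position index at some inverse frequency.
pos = np.float32(4096.0)
inv_freq = np.float32(0.37)  # illustrative value
theta32 = pos * inv_freq

# The same computation carried out in float16.
theta16 = np.float16(pos) * np.float16(inv_freq)

# float16 has ~10 mantissa bits, so both the inverse frequency and the
# product get rounded; sin/cos then amplify the error because the angle
# is far from zero.
err = abs(np.float32(theta16) - theta32)
print(err)
```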

if (dims_ != in.shape(-1)) {
  throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
Member Author

We need a contiguity and copy check before this, right?

Member

Not sure what the copy check is. Also, row_contiguous is stricter than contiguous, is it not? I.e., all row_contiguous arrays are contiguous but not the other way around.

Member Author

I meant, if it's not contiguous, we should make a contiguous copy

Member Author

It does not appear to me that your kernel handles non-contiguous inputs, but maybe I missed something...

Member Author
@awni awni Feb 13, 2024

Actually, I think I missed it: I was looking for elem_to_loc, but you hardcoded the strides, so it should be ok.

Member Author

Do we need to check here, though, that the input has the same size as the output? If it's broadcasted, e.g. along the last axis, it would be incorrect to donate, right?

Member

Yeah, I hardcoded the strides because the grid is launched with half the last dimension, so the indexing can't be delegated to a simple elem_to_loc. I would have to do something like multiply pos.x by 2 and then pass it to elem_to_loc, etc. I think this is equally readable, but I am open to suggestions :-)
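The hardcoded indexing amounts to something like the following Python sketch of the grid-to-offset math (the names here are hypothetical; the actual kernel is Metal):

```python
# A thread at grid position (x, y, z) handles one pair of elements in a
# row-contiguous array of shape (N, T, D), with the grid launched over
# (D // 2, T, N) -- i.e. pos.x indexes *pairs*, half the last dimension.
def pair_offsets(x, y, z, T, D):
    # Row-major strides for a contiguous (N, T, D) array of unit elements.
    row_offset = z * T * D + y * D
    # The traditional layout rotates adjacent elements 2x and 2x + 1,
    # which is why pos.x must be doubled before indexing.
    return row_offset + 2 * x, row_offset + 2 * x + 1
```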

Regarding broadcasting, a broadcasted array wouldn't be row_contiguous so this check should be fine donation-wise, right?
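On the donation point: a broadcasted view carries a zero stride along the broadcast axis, so it fails any row-contiguity check. A NumPy analogue (NumPy strides are in bytes; `row_contiguous` is an illustrative helper, not the MLX flag):

```python
import numpy as np

x = np.zeros((4, 1, 8), dtype=np.float32)
b = np.broadcast_to(x, (4, 16, 8))  # broadcast along the middle axis

def row_contiguous(a):
    # Row-contiguous: strides equal the C-order strides for this shape.
    expected = []
    s = a.itemsize
    for dim in reversed(a.shape):
        expected.append(s)
        s *= dim
    return list(a.strides) == list(reversed(expected))

print(row_contiguous(x), row_contiguous(b))  # True False
```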

Member Author

Oh of course! Let me quietly exit this thread before I say anything else incorrect

@awni
Member Author

awni commented Feb 13, 2024

Wow, that's so fast! We can also increase the tolerance for the lower precision tests if that's simpler.

@awni awni marked this pull request as ready for review February 13, 2024 19:16
@awni
Member Author

awni commented Feb 13, 2024

Making this a real PR since I think we are almost done.

Member
@angeloskath angeloskath left a comment

This looks great. I really like the Custom primitive.

"traditional"_a,
"base"_a,
"scale"_a,
"offset"_a,
Member

Do you think we should make the above keyword only? It would be verbose but error free...

Member Author

I think so, yes.
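Keyword-only parameters prevent exactly the positional mix-ups being worried about here, e.g. silently swapping `base` and `scale`. A plain-Python sketch of the idea (the actual binding would use the C++ argument annotations shown above):

```python
def rope(x, dims, *, traditional=False, base=10000.0, scale=1.0, offset=0):
    """Everything after * must be passed by keyword."""
    return (dims, traditional, base, scale, offset)

rope([1.0], 32, traditional=True, offset=7)  # fine: explicit keywords
try:
    rope([1.0], 32, True, 10000.0)           # positional: rejected
except TypeError:
    print("TypeError")
```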

@awni awni merged commit ccf1645 into main Feb 14, 2024
2 checks passed
@awni awni deleted the extensions branch February 14, 2024 22:04
awni added a commit that referenced this pull request Feb 15, 2024
* extensions start

* rope custom op

* fix build

* docs + rope benchmark

* fix test

* Add a Metal kernel for RoPE

* Fix position of traditional

* transform tests

* Move rope computation to float and fix tests

* Fix the test and a typo

* change to fast

* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>