
AOT compilation #935

Open
david-macleod opened this issue Dec 1, 2022 · 22 comments

Comments

@david-macleod
Contributor

Hi, I was just wondering if there had been any more thoughts on supporting AOT kernel compilation to allow execution outside of Python? Referencing #175

@gaxler
Contributor

gaxler commented Dec 2, 2022

We are waiting on a rewrite to be done

See: #490 (comment)

@yufenglee

Nice! Do you have a rough estimate of when it will be done?

@ptillet
Collaborator

ptillet commented Dec 2, 2022

The rewrite will be done this month. There is some very basic AOT that we made for unit-testing purposes right now, but work on a more complete version will resume after that.

@yufenglee

And what will AoT compilation generate, a C/C++ API plus source/.so?

@david-macleod
Contributor Author

Great news, is there some branch/PR we can track the progress of this?

@david-macleod
Contributor Author

@ptillet I am very keen to have a go at using this feature in whatever state the code currently is, even if it is only the unit test you mentioned previously (I have a time-sensitive project which could benefit from AOT functionality).

@gaxler
Contributor

gaxler commented Dec 20, 2022

We have a prototype that works with an old version of Triton. You might be able to hack it for your needs?
#490

@gaxler
Contributor

gaxler commented Dec 20, 2022

And what will AoT compilation generate, a C/C++ API plus source/.so?

For previous iterations we started with generated C code that embeds the kernels as source.
The thinking is to give users something very general.

@david-macleod
Contributor Author

We have a prototype that works with an old version of Triton. You might be able to hack it for your needs? #490

Great, thanks @gaxler, will give it a go! For the main feature, is there any WIP branch that can be tracked, or is it separate from the main repo?

@david-macleod
Contributor Author

@gaxler should there be a correlation between the triton BLOCK_SIZE defined in the kernel definition, and the gX, gY, gZ defined in GridWarps when calling the kernel?

@gaxler
Contributor

gaxler commented Jan 4, 2023

@gaxler should there be a correlation between the triton BLOCK_SIZE defined in the kernel definition, and the gX, gY, gZ defined in GridWarps when calling the kernel?

You mean add grid-size constraints at compile time?

In general I avoided dealing with anything related to kernel launches in the draft PR; it's all just placeholders to make it run.
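For what it's worth, the usual Triton convention (independent of this prototype; names here are illustrative) is that each program instance covers BLOCK_SIZE elements, so the launch grid's X dimension is a ceiling division of the problem size:

```python
def grid_x(n_elements: int, block_size: int) -> int:
    """Ceiling division: how many BLOCK_SIZE-wide program
    instances are needed to cover n_elements."""
    return (n_elements + block_size - 1) // block_size

# e.g. 1000 elements with BLOCK_SIZE=256 -> 4 program instances
```

So gX would typically be derived from the input size and BLOCK_SIZE rather than chosen independently, while gY and gZ stay 1 for a 1D kernel.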

@david-macleod
Contributor Author

david-macleod commented Jan 5, 2023

Great, thanks! I now have it working, but have noticed the performance is much worse than the JIT Triton equivalent. From the profile trace I see large gaps between the Triton kernel and the preceding/following kernels.

I am aware you are not actively maintaining this, but was just wondering if this was expected, or if you had any hints? I am not that familiar with PTX, but understand it is JIT-compiled, so was wondering if it was not being cached correctly or something like that.

@gaxler
Contributor

gaxler commented Jan 5, 2023

Sorry that you have to bump into all these things. This is just a POC and in no way optimized.
Thanks for profiling the generated code!!

Probably the worst thing for the C code's performance is the PTX: it gets compiled to binary every time you call a kernel. This will be replaced by a cubin.

Another overhead might be the dispatch for different input sizes; I'm not sure how significant it is for overall performance.

Perhaps you can use several CUDA streams to bypass those issues?
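To make the dispatch point concrete: a launcher that specializes kernels per input shape has to pick a variant at every call, roughly like this hypothetical sketch (not the actual generated code; specialization keys and names are made up):

```python
def dispatch(n_elements: int, kernels: dict):
    """Pick a kernel variant based on input-size properties.
    Triton-style specialization often keys on divisibility
    (e.g. sizes divisible by 16 allow vectorized variants)."""
    if n_elements % 16 == 0:
        return kernels["div16"]   # specialized, faster variant
    return kernels["generic"]     # fallback for arbitrary sizes
```

The branch itself is cheap on the host; the question is only whether it stays hidden behind the device's execution of earlier kernels.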

@david-macleod
Contributor Author

david-macleod commented Jan 5, 2023

If I know my target hardware a priori, is there any downside or gotcha to dumping the PTX code to a file, compiling it down to a cubin, and loading that instead? Could that potentially help with the overheads?
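The route I have in mind is roughly the following (a sketch; it assumes `ptxas` from the CUDA toolkit is on PATH, and `sm_80` stands in for the real target architecture):

```python
import subprocess

def ptxas_cmd(ptx_path: str, cubin_path: str, arch: str = "sm_80") -> list:
    """Build the ptxas invocation that compiles PTX down to a cubin
    for a target architecture known ahead of time."""
    return ["ptxas", f"-arch={arch}", ptx_path, "-o", cubin_path]

# On a machine with the CUDA toolkit installed:
# subprocess.run(ptxas_cmd("kernel.ptx", "kernel.cubin"), check=True)
```

The generated launcher would then load the cubin directly instead of JIT-compiling PTX on every call.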

@david-macleod
Contributor Author

david-macleod commented Jan 6, 2023

Converting to cubin has helped a lot! (In the trace, the Triton kernel is the one that sits between the orange and green.)

JIT
(profile trace screenshot)

AOT - PTX
(profile trace screenshot)

AOT - cubin
(profile trace screenshot)

Whilst the overhead is now much smaller, there is still a gap in utilization before and after the AOT Triton kernel is run (perhaps there is some implicit synchronisation happening).

Regarding your suggestion about the dispatch time: I am guessing that could delay the host thread, but as long as the kernel is launched sufficiently before the device is ready to execute it (which we are fairly sure is the case here), that cost should be hidden?

EDIT: I now think the overheads might be related to the module loading, need to confirm

@gaxler
Contributor

gaxler commented Jan 6, 2023

Assuming the _tr... is a Triton JITFunction for JIT, and the launch function from the generated C code for AOT:

I think you are correct.
The JITFunction loads the module and function before it calls the launch code, whereas with the generated C code each call loads the module and the CUfunction.

Thanks for doing this; it will be helpful when thinking about optimizing the generated code!

@david-macleod
Contributor Author

Tried caching the loaded CUfunction and things are now looking very close to JIT performance (only 5-10% slower now) 🙂

(profile trace screenshot)

@gaxler
Contributor

gaxler commented Jan 14, 2023

Got a new prototype together, maybe this can help in some way: #1056

@david-macleod
Contributor Author

Thanks, will check it out

@david-macleod
Contributor Author

Do you know how close it is to being merged? (Just trying to gauge whether I should wait or work from the branch.)

@gaxler
Contributor

gaxler commented Apr 5, 2023

It's pretty close, but there are other things that have priority over merging it, so the branch will be better. I'm happy to help; it would be great to get user feedback.

@david-macleod
Contributor Author

@gaxler what is the relationship between this branch and aot.py on master? Will they both continue to exist after this branch is complete?
