flash-algo/kernel-course

Learn how to develop high-performance kernels with PyTorch, Triton, and CuTe while preserving numerical equivalence with the Python reference implementations. The exercises emphasize translating clear Python prototypes into optimized GPU kernels without sacrificing correctness.
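To make this workflow concrete, here is a minimal sketch (not taken from the repository; function names are illustrative) of the pattern each exercise follows: a plain-Python prototype of axpby, a PyTorch counterpart, and a check that the two are numerically equivalent.

```python
import torch

def axpby_reference(alpha, x, beta, y):
    # Pure-Python prototype: y = alpha * x + beta * y, element by element.
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]

def axpby_torch(alpha, x, beta, y):
    # Vectorized PyTorch version of the same update.
    return alpha * x + beta * y

x = torch.randn(1024)
y = torch.randn(1024)
ref = torch.tensor(axpby_reference(2.0, x.tolist(), 0.5, y.tolist()))
out = axpby_torch(2.0, x, 0.5, y)
torch.testing.assert_close(out, ref)  # numerical equivalence check
```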

Why

Learning high-performance computing as a non-CS major has been quite challenging for me. While I can grasp the fundamental concepts of HPC, I often struggle to use the various frameworks effectively, that is, to combine mathematical computation, memory layout, and data movement efficiently. To address this, I selected three popular frameworks: PyTorch for rapid mathematical prototyping, Triton for parallel computation, and CuTe for maximizing performance through low-level memory management and thread scheduling. By progressively implementing foundational linear algebra operations in these frameworks and then composing them into Transformer modules, I aim to build up these skills. I believe this process may be useful to others as well, so I've decided to document and share it. Happy learning!

Basic Linear Algebra Subprograms

The following common BLAS kernels have been implemented in multiple frameworks. For each kernel, a ✅ indicates that the implementation is complete and verified to be numerically equivalent to the Python reference, while a ❌ indicates that the implementation is pending. FLOPs counts floating-point operations; Data counts the number of elements read and written. For more details on a kernel, click its name or icon.

| Name | Description | Equation | FLOPs | Data | Python | PyTorch | Triton | CuTe | Test |
|------|-------------|----------|-------|------|--------|---------|--------|------|------|
| copy | copy vector | $y = x$ | $0$ | $2n$ | | | | | |
| swap | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | | | | | |
| scal | scale vector | $y = \alpha y$ | $n$ | $2n$ | | | | | |
| axpby | update vector | $y = \alpha x + \beta y$ | $3n$ | $3n$ | | | | | |
| dot | dot product | $z = x^\top y$ | $2n$ | $2n$ | | | | | |
| gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | | | | | |
| geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | | | | | |
| gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | | | | | |
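To give a flavor of how one of these kernels moves from the PyTorch expression into Triton, here is a minimal sketch of an axpby kernel (illustrative only; the block size and launch details are assumptions, not the repository's implementation), together with an equivalence check against the PyTorch reference.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def axpby_kernel(x_ptr, y_ptr, alpha, beta, n, BLOCK: tl.constexpr):
    # Each program instance updates one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, alpha * x + beta * y, mask=mask)

def axpby_triton(alpha, x, beta, y, BLOCK=1024):
    # In-place update of y, launched over ceil(n / BLOCK) programs.
    n = x.numel()
    grid = (triton.cdiv(n, BLOCK),)
    axpby_kernel[grid](x, y, alpha, beta, n, BLOCK=BLOCK)
    return y

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
expected = 2.0 * x + 0.5 * y                     # PyTorch reference
torch.testing.assert_close(axpby_triton(2.0, x, 0.5, y), expected)
```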

Transformer Modules

The following common transformer modules have been implemented in multiple frameworks. For each module, a ✅ indicates that the implementation is complete and verified to be numerically equivalent to the Python reference, while a ❌ indicates that the implementation is pending. For more details on a module, click its name or icon.

Note

TODO
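While the module table above is still TODO, a minimal PyTorch reference for one such module, scaled dot-product attention, might start like the sketch below (hypothetical; shapes and tolerances are assumptions), verified against torch.nn.functional.scaled_dot_product_attention.

```python
import math
import torch

def sdpa_reference(q, k, v):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
    # q, k, v: (batch, heads, seq, head_dim)
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    weights = torch.softmax(scores, dim=-1)
    return weights @ v

q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))
out = sdpa_reference(q, k, v)
torch.testing.assert_close(
    out,
    torch.nn.functional.scaled_dot_product_attention(q, k, v),
    rtol=1e-4, atol=1e-5,  # loosened: different reduction orders
)
```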
