Learn how to develop high-performance kernels with PyTorch, Triton, and CuTe while preserving numerical equivalence with the Python reference implementations. The exercises emphasize translating clear Python prototypes into optimized GPU kernels without sacrificing correctness.
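The prototype-to-kernel translation described above can be illustrated without a GPU. The following is a minimal pure-Python sketch, not code from this repository. It assumes the usual BLAS `axpby` semantics, y ← αx + βy, and mimics how a Triton-style kernel splits a vector into fixed-size blocks, one per program instance, with a mask guarding the ragged last block. The names `axpby_reference`, `axpby_blocked`, and `block_size` are hypothetical.

```python
def axpby_reference(alpha, x, beta, y):
    """Clear Python prototype: element-wise alpha * x + beta * y."""
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]


def axpby_blocked(alpha, x, beta, y, block_size=4):
    """Hypothetical emulation of a block-parallel GPU kernel launch."""
    n = len(x)
    out = [0.0] * n
    # Number of program instances, analogous to triton.cdiv(n, BLOCK_SIZE).
    num_programs = (n + block_size - 1) // block_size
    for pid in range(num_programs):  # each iteration plays the role of one program
        offsets = [pid * block_size + i for i in range(block_size)]
        for off in offsets:
            if off < n:  # mask: skip out-of-bounds lanes in the last block
                out[off] = alpha * x[off] + beta * y[off]
    return out


x = [float(i) for i in range(10)]
y = [1.0] * 10
ref = axpby_reference(2.0, x, 0.5, y)
blk = axpby_blocked(2.0, x, 0.5, y, block_size=4)
print(ref == blk)  # True
```

Because every element is computed with the same operations in the same order, the two results match exactly here. For reductions such as `dot`, a blocked kernel changes the summation order, so comparisons against the reference usually need a tolerance instead of exact equality.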
Learning high-performance computing as a non-CS major has been challenging for me. I can grasp the fundamental concepts of HPC, but I often struggle to use the various frameworks effectively when a kernel has to combine mathematical computation, memory layout, and data movement. To address this, I selected three popular frameworks: PyTorch for rapid mathematical implementations, Triton for writing block-parallel GPU kernels, and CuTe for maximizing performance through low-level control of memory layouts and thread scheduling. My plan is to implement foundational linear algebra operations in each framework, then combine them into Transformer modules. I'm documenting the process in the hope that it helps others as well. Happy learning!
The following common BLAS kernels have been implemented in multiple frameworks. For each kernel, a ✅ indicates that the implementation is complete and verified to be numerically equivalent to the Python reference, and a ❌ indicates that the implementation is pending. For more details on each kernel, click its name or icon.
| Name | Description | Equation | Flops | Data | Python | PyTorch | Triton | CuTe | Test |
|---|---|---|---|---|---|---|---|---|---|
| copy | copy vector | $y \leftarrow x$ | | | ✅ | ✅ | ✅ | ❌ | ✅ |
| swap | swap vectors | $x \leftrightarrow y$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| scal | scale vector | $x \leftarrow \alpha x$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| axpby | update vector | $y \leftarrow \alpha x + \beta y$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| dot | dot product | $x^\top y$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| gemv | general matrix-vector multiply | $y \leftarrow \alpha A x + \beta y$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| geru | general rank-1 update | $A \leftarrow \alpha x y^\top + A$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
| gemm | general matrix-matrix multiply | $C \leftarrow \alpha A B + \beta C$ | | | ❌ | ❌ | ❌ | ❌ | ❌ |
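For orientation, here is a minimal pure-Python sketch of what reference implementations for the kernels in this table could look like. It assumes the standard BLAS definitions, represents vectors as lists of floats and matrices as row-major lists of rows, and returns new values instead of updating in place. The function names mirror the table but are hypothetical, not taken from this repository. The `allclose` helper assumes the tolerance rule documented for `torch.allclose`, |a - b| ≤ atol + rtol·|b|, as one possible definition of "numerically equivalent".

```python
def copy(x):
    """y <- x"""
    return list(x)


def swap(x, y):
    """x <-> y: returns the new (x, y) pair."""
    return list(y), list(x)


def scal(alpha, x):
    """x <- alpha * x"""
    return [alpha * xi for xi in x]


def axpby(alpha, x, beta, y):
    """y <- alpha * x + beta * y"""
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]


def dot(x, y):
    """x^T y"""
    return sum(xi * yi for xi, yi in zip(x, y))


def gemv(alpha, A, x, beta, y):
    """y <- alpha * A x + beta * y, with A of shape m x n."""
    return [alpha * dot(row, x) + beta * yi for row, yi in zip(A, y)]


def geru(alpha, x, y, A):
    """A <- alpha * x y^T + A (unconjugated rank-1 update)."""
    return [[alpha * xi * yj + aij for yj, aij in zip(y, row)]
            for xi, row in zip(x, A)]


def gemm(alpha, A, B, beta, C):
    """C <- alpha * A B + beta * C, with A: m x k and B: k x n."""
    B_cols = list(zip(*B))  # transpose so each column of B is a tuple
    return [[alpha * dot(a_row, b_col) + beta * cij
             for b_col, cij in zip(B_cols, c_row)]
            for a_row, c_row in zip(A, C)]


def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Flat-list tolerance check using the rule |a - b| <= atol + rtol * |b|."""
    return len(a) == len(b) and all(
        abs(u - v) <= atol + rtol * abs(v) for u, v in zip(a, b))


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[0.0, 0.0], [0.0, 0.0]]
print(gemm(1.0, A, B, 0.0, C))  # [[19.0, 22.0], [43.0, 50.0]]
```

A kernel in any framework could then be checked by converting its output to a flat list and comparing it with the reference via `allclose`.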
The following common Transformer modules have been implemented in multiple frameworks. For each module, a ✅ indicates that the implementation is complete and verified to be numerically equivalent to the Python reference, and a ❌ indicates that the implementation is pending. For more details on each module, click its name or icon.
> [!NOTE]
> TODO
