<a href="https://colab.research.google.com/github/nyck33/mlir-python-extras-copy/blob/main/outputting_ptx_for_sm75_cuda12_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CUDA/NVGPU/NVVM E2E

In [None]:
!pip install -q  mlir_python_bindings==19.0.0.2024033101+cuda.a67b9326 -f https://makslevental.github.io/wheels
!pip install -q git+https://github.com/makslevental/mlir-python-extras.git

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for mlir-python-extras (pyproject.toml) ... [?25l[?25hdone


# Boilerplate

In [None]:
from pathlib import Path

import mlir.extras.types as T
from mlir.dialects import builtin
from mlir.dialects.transform import any_op_t
from mlir.dialects.transform.extras import named_sequence
from mlir.dialects.transform.structured import MatchInterfaceEnum
from mlir.ir import StringAttr, UnitAttr, Module, Operation, OpView

from mlir import _mlir_libs
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
from mlir.extras.dialects.ext import arith, memref, scf, gpu
from mlir.extras.dialects.ext import linalg
from mlir.extras.dialects.ext import transform
from mlir.extras.dialects.ext.func import func
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend
from mlir.extras.util import find_ops

from typing import Callable, List, Optional, Sequence, Tuple, Union

from mlir.passmanager import PassManager


# Location of the CUDA runtime wrapper library, expected to sit in the same directory
# as the installed `_mlir_libs` package.
CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / "libmlir_cuda_runtime.so"
# Fail early, naming the path, if the installed bindings do not ship the library.
assert CUDA_RUNTIME_LIB_PATH.exists(), f"CUDA runtime library not found: {CUDA_RUNTIME_LIB_PATH}"

# Context

In [None]:
# Create the MLIR context object and keep it referenced as `ctx` for the rest of the notebook.
# NOTE(review): the output recorded under this cell shows RAIIMLIRContext.__del__ raising
# "Unbalanced Location enter/exit"; presumably a previous `ctx` was dropped when the cell
# was re-run — confirm, and prefer running this cell once per kernel session.
ctx = RAIIMLIRContext()

# MLIR source for an 8192x8192 f16 x f16 -> f32 matmul driven from the host function `@main`:
# host buffers are filled, copied into device buffers, and a `gpu.launch` region runs a tiled
# loop nest that stages A/B tiles in the `@a_smem_global` / `@b_smem_global` buffers and
# multiplies 16x16 fragments with the `gpu.subgroup_mma_*` ops.
#
# The accumulator fragment `%c` is loaded from `%C[%globalC_row, %globalC_col]`, i.e. the
# same coordinates the result is stored to. Loading it from the tile-local indices instead
# would only be equivalent while %i == 0 and %j == 0.
#
# NOTE(review): the in-string comment above `gpu.launch` mentions grid (2,4,1) while the
# constants define a 512x512x1 grid of 16x16x1 blocks — confirm which is intended.
# NOTE(review): the launch body never reads %bx/%by/%tx/%ty, so every block appears to
# execute the whole loop nest — confirm that this is intended.
src = """
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (d0 + 64)>
#map2 = affine_map<(d0) -> (d0 + 128)>

//new affine maps
// Affine maps for local indexing within shared memory tiles
#localA_row = affine_map<(ii, iii) -> (ii + iii)>
#localA_col = affine_map<(kk, kkk) -> (kk + kkk)>
#localB_col = affine_map<(jj, jjj) -> (jj + jjj)>

// Affine maps for calculating global indices for storing into matrix C
#globalC_row = affine_map<(i, ii, iii) -> (i + ii + iii)>
#globalC_col = affine_map<(j, jj, jjj) -> (j + jj + jjj)>
module {
    // Shared memory buffers for A and B.
    memref.global "private" @b_smem_global : memref<64x136xf16, 3>
    memref.global "private" @a_smem_global : memref<128x72xf16, 3>
    func.func @main() {
        // Allocate memory for A, B on host using half precision
        %hA = memref.alloc() : memref<8192x8192xf16>
        %hB = memref.alloc() : memref<8192x8192xf16>
        // Allocate memory for output matrix C on host using single precision
        %hC = memref.alloc() : memref<8192x8192xf32>

        // Define constants used in the program
        %f1 = arith.constant 1.0e+00 : f16 // Constant value 1.0 of type half-precision float
        %f0 = arith.constant 0.0e+00 : f32 // Constant value 0.0 of type single-precision float

        %c0 = arith.constant 0 : index
        %c1 = arith.constant 1 : index

        %c8192 = arith.constant 8192 : index

        //initialize the input matrices with ones
        scf.for %arg0 = %c0 to %c8192 step %c1 {
            scf.for %arg1 = %c0 to %c8192 step %c1 {
                memref.store %f1, %hA[%arg0, %arg1] : memref<8192x8192xf16>
            }
        }

        //now initialize hB with ones
        scf.for %arg0 = %c0 to %c8192 step %c1 {
            scf.for %arg1 = %c0 to %c8192 step %c1 {
                memref.store %f1, %hB[%arg0, %arg1] : memref<8192x8192xf16>
            }
        }

        //now initialize hC with zeros
        scf.for %arg0 = %c0 to %c8192 step %c1 {
            scf.for %arg1 = %c0 to %c8192 step %c1 {
                memref.store %f0, %hC[%arg0, %arg1] : memref<8192x8192xf32>
            }
        }
        // Asynchronous operations token
        %token = gpu.wait async

        // Allocate device memory for matrices A, B, and C asynchronously
        %A, %tokenA = gpu.alloc async [%token] () : memref<8192x8192xf16>
        %B, %tokenB = gpu.alloc async [%token] () : memref<8192x8192xf16>
        %C, %tokenC = gpu.alloc async [%token] () : memref<8192x8192xf32>

        // Copy A and B from host to device asynchronously
        %copyA = gpu.memcpy async [%token] %A, %hA : memref<8192x8192xf16>, memref<8192x8192xf16>
        %copyB = gpu.memcpy async [%token] %B, %hB : memref<8192x8192xf16>, memref<8192x8192xf16>
        // Copy C from host to device asynchronously (if initialization on host is needed)
        %copyC = gpu.memcpy async [%token] %C, %hC : memref<8192x8192xf32>, memref<8192x8192xf32>

        //define block and grid xyz where each block is 16 x 16 threads
        //thread block granular to warp operations (256 threads per block)
        %blockX = arith.constant 16 : index
        %blockY = arith.constant 16 : index
        %blockZ = arith.constant 1 : index
        %gridX = arith.constant 512 : index
        %gridY = arith.constant 512 : index
        %gridZ = arith.constant 1 : index

        //define the gpu.launch blocks for the kernel, grid (2,4,1), block (16,16,1) for 64*32 result tile covered by threads
        gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gridX, %grid_y = %gridY, %grid_z = %gridZ) threads(%tx, %ty, %tz) in (%block_x = %blockX, %block_y = %blockY, %block_z = %blockZ) {

            affine.for %i = 0 to 8192 step 128 { //128 rows of tile C/A
                affine.for %j = 0 to 8192 step 128 {//128 cols of tile C/B

                    // References to shared memory buffers.
                    %b_smem = memref.get_global @b_smem_global : memref<64x136xf16, 3>
                    %a_smem = memref.get_global @a_smem_global : memref<128x72xf16, 3>
                    //main k-loop
                    affine.for %k = 0 to 8192 step 64 {  //64 is cols A tile and rows B tile
                        // Copy loop for B tile
                        affine.for %copykk = #map0(%k) to #map1(%k) {//k to k + 64
                            affine.for %copyjj = #map0(%j) to #map2(%j) {//j to j + 128
                                %11 = affine.load %B[%copykk, %copyjj] : memref<8192x8192xf16>
                                affine.store %11, %b_smem[%copykk - %k, %copyjj - %j] : memref<64x136xf16, 3>
                            }
                        }
                        // Copy loop for A tile
                        affine.for %copyii = #map0(%i) to #map2(%i) {//i to i + 128
                            affine.for %copykk = #map0(%k) to #map1(%k) {//k to k + 64
                                %11 = affine.load %A[%copyii, %copykk] : memref<8192x8192xf16>
                                //%11 is the value of A at %copyii, %copykk
                                //copyii - i to index into the shared memory of size 128 * 72, the padded section remains as
                                affine.store %11, %a_smem[%copyii - %i, %copykk - %k] : memref<128x72xf16, 3>
                            }
                        }
                        //copied so iterate over the tiles
                        affine.for %ii = 0 to 128 step 64 {//rows A tile
                            affine.for %jj = 0 to 128 step 32 {//cols B tile
                                affine.for %kk = 0 to 64 step 32 {//cols A tile
                                //iterate the 64 * 32 A minitile, 32 * 32 B minitile
                                    affine.for %kkk = 0 to 32 step 16 {//2 steps
                                        affine.for %iii = 0 to 64 step 16 {//4 steps
                                            affine.for %jjj = 0 to 32 step 16 {//2 steps

                                                // Assuming %ii, %jj, %kk are the local loop variables for sub-tile indexing
                                                %localA_row = affine.apply #localA_row(%ii, %iii) // Local row index within %a_smem
                                                %localA_col = affine.apply #localA_col(%kk, %kkk) // Local column index within %a_smem
                                                %localB_col = affine.apply #localB_col(%jj, %jjj) // Local column index within %b_smem

                                                //%11 = %localA_row
                                                //%12 = %localA_col
                                                //%14 = %localB_col
                                                //%16 = globalC_row
                                                //%17 = %globalC_col
                                                // Calculate global indices for storing the result into matrix C
                                                %globalC_row = affine.apply #globalC_row(%i, %ii, %iii)
                                                %globalC_col = affine.apply #globalC_col(%j, %jj, %jjj)

                                                //A tile 128 * 72 load 16 * 16 fragment into warp matrix
                                                %a = gpu.subgroup_mma_load_matrix %a_smem[%localA_row, %localA_col] {leadDimension = 72 : index} : memref<128x72xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">

                                                // B tile is 64 * 136, load 16 * 16 fragment into warp matrix
                                                %b = gpu.subgroup_mma_load_matrix %b_smem[%localA_col, %localB_col] {leadDimension = 136 : index} : memref<64x136xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">

                                                //C matrix is 8192 * 8192, load 16 * 16 fragment into warp matrix
                                                %c = gpu.subgroup_mma_load_matrix %C[%globalC_row, %globalC_col] {leadDimension = 8192 : index} : memref<8192x8192xf32> -> !gpu.mma_matrix<16x16xf32, "COp">

                                                %res = gpu.subgroup_mma_compute %a, %b, %c : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
                                                //stores the warp tile into global result matrix so need %16 and %17 to be global
                                                gpu.subgroup_mma_store_matrix %res, %C[%globalC_row, %globalC_col] {leadDimension = 8192 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<8192x8192xf32>


                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            // Print success message to indicate successful execution

            // CHECK: Success
            gpu.terminator
        }
        // Deallocate device memory for matrices A, B, and C asynchronously
        %zA = gpu.dealloc async [%token] %A : memref<8192x8192xf16>
        %zB = gpu.dealloc async [%token] %B : memref<8192x8192xf16>
        %zC = gpu.dealloc async [%token] %C : memref<8192x8192xf32>

        // Wait for all asynchronous operations to complete
        gpu.wait [%token]
        return
    }


    }
"""

# Parse the MLIR text held in `src` into a `Module`.
# NOTE(review): no context is passed explicitly, so this presumably relies on `ctx`
# having made itself the active context — confirm.
module = Module.parse(src)

Exception ignored in: <function RAIIMLIRContext.__del__ at 0x7eddc1eaf490>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/mlir/extras/context.py", line 57, in __del__
RuntimeError: Unbalanced Location enter/exit


In [None]:
def find_ops(op, pred: Callable[[Union[OpView, Operation]], bool], single=False):
  """Collect operations nested under `op` for which `pred` returns true.

  Operations are visited in pre-order: every operation of every block of every
  region is tested first, then its own regions are searched. `op` itself is
  never tested.

  NOTE(review): this definition shadows the `find_ops` name imported from
  `mlir.extras.util` in the imports cell.

  Args:
    op: root to search; an `OpView` or `Module` is unwrapped to its `.operation`.
    pred: called with one nested operation at a time; a truthy result keeps it.
    single: when true, the search stops at the first match.

  Returns:
    A list of all matches when `single` is false. When `single` is true, the
    first match, or an empty list if nothing matched.
  """
  if isinstance(op, (OpView, Module)):
    op = op.operation

  matching = []

  def find(op: Operation) -> bool:
    # Returns True once the search is finished (only possible in `single` mode),
    # so that every enclosing loop unwinds without testing further operations.
    for r in op.regions:
      for b in r.blocks:
        for o in b.operations:
          if pred(o):
            matching.append(o)
            if single:
              return True
          if find(o):
            return True
    return False

  find(op)
  if single and matching:
    matching = matching[0]
  return matching

def print_ptx(module):
  """Print the `objects` attribute of the first `gpu.binary` op found in `module`.

  The attribute is converted with `str()` and the escape sequences `\\0A` and
  `\\09` in that text are replaced by a newline and a tab before printing.

  Args:
    module: object handed to `find_ops` as the search root.

  Raises:
    RuntimeError: if no `gpu.binary` op is found under `module`.
  """
  binary_op = find_ops(module, lambda o: o.name == "gpu.binary", single=True)
  # With single=True, find_ops hands back an empty list (not an op) when nothing matched.
  if isinstance(binary_op, list):
    raise RuntimeError("no gpu.binary op found in module; did the lowering pipeline run?")
  ptx = str(binary_op.objects).replace("\\0A", "\n").replace("\\09", "\t")
  print(ptx)

# Lower a fresh parse of `src` instead of `module` itself: the pass manager rewrites the
# operation it is run on, and `module` is compiled again by the backend cell further down,
# so it must not be altered (or left half-lowered after a failure) here.
ptx_module = Module.parse(src)

pm = PassManager("any")
# Same pipeline as the backend cell, but with cubin-format=isa
# (presumably so the gpu.binary holds textual PTX — confirm).
pm.add("gpu-lower-to-nvvm-pipeline{ cubin-chip=sm_75 cubin-features=+ptx75 cubin-format=isa }")
pm.run(ptx_module.operation)
print_ptx(ptx_module)

MLIRError: Failure while executing pass pipeline:
error: unknown: failed to legalize operation 'builtin.unrealized_conversion_cast' that was explicitly marked illegal
 note: unknown: see current operation: %13 = "builtin.unrealized_conversion_cast"(%5) : (i64) -> index

# Lower to NVVM (and LLVM)

In [None]:
# Options handed to the `gpu-lower-to-nvvm-pipeline` pass:
# sm_75 chip, +ptx75 features, fatbin output format.
nvvm_pass_options = {
    "cubin-chip": "sm_75",
    "cubin-features": "+ptx75",
    "cubin-format": "fatbin",
}

# JIT backend that is given the CUDA runtime library path located in the boilerplate cell.
backend = LLVMJITBackend([CUDA_RUNTIME_LIB_PATH])
nvvm_pipeline = Pipeline().add_pass("gpu-lower-to-nvvm-pipeline", **nvvm_pass_options)

# Compile `module` through the pipeline above and show the lowered result.
# NOTE(review): presumably this also emits the C-interface wrapper that is later invoked
# as `main_capi_wrapper` — confirm against the backend implementation.
compiled_module = backend.compile(module, nvvm_pipeline)
print(compiled_module)

module attributes {gpu.container_module} {
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.func @main() attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(1 : i32) : i32
    %2 = llvm.mlir.constant(2 : i32) : i32
    %3 = llvm.mlir.constant(4 : i32) : i32
    %4 = llvm.mlir.constant(8 : i32) : i32
    %5 = llvm.mlir.constant(16 : i32) : i32
    %6 = llvm.mlir.constant(3 : i32) : i32
    %7 = llvm.mlir.constant(6 : i32) : i32
    %8 = llvm.mlir.constant(7 : i32) : i32
    %9 = llvm.mlir.constant(10 : i32) : i32
    %10 = llvm.mlir.constant(11 : i32) : i32
    %11 = llvm.mlir.constant(0 : index) : i64
    %12 = llvm.mlir.constant(1 : index) : i64
    %13 = llvm.mlir.constant(2 : index) : i64
    %14 = llvm.mlir.constant(3 : index) : i64
    %15 = llvm.mlir.constant(4 : index) : i64
    %16 = llvm.mlir.constant(5 : index) : i64
    %17 = llvm.mlir.constant(6 : index) : i64
    %18 = llvm.mlir.zero : !llvm.ptr
    %19 = llvm.getelementp

# Load and run

In [None]:
!pip install -q wurlitzer
from wurlitzer import pipes

In [None]:
# Load the compiled module with the JIT backend and invoke its `main_capi_wrapper` entry
# point inside wurlitzer's `pipes()` context, which yields the `out` / `err` stream objects
# (presumably capturing process-level stdout/stderr — confirm against the wurlitzer docs).
with pipes() as (out, err):
    backend.load(compiled_module).main_capi_wrapper()

In [None]:
# Show what was collected on the `out` stream while the compiled code ran.
# NOTE(review): the output recorded under this cell is a memref dump, but `src` contains
# no printing op — the recorded output looks stale; confirm by re-running.
print(out.read())

Unranked Memref base@ = 0x567703fb08a0 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
[0,  2]

