In [1]:
import hashlib
from importlib import util
from pathlib import Path
import triton

In [2]:
# Python source file that defines the Triton kernel to compile.
code_fpath = Path('./compile_test.py')
# Attribute name of the kernel inside that file.
kernel_name = 'add_kernel'
# Comma-separated argument list: type names for runtime arguments and numeric
# literals for compile-time constants. An entry may carry a ":<hint>" suffix,
# where the hint must be 1 or 16.
signature = '*fp32, *fp32, *fp32, i32, 1024'
# Warp count forwarded to the TTIR -> TTGIR conversion.
num_warps = 1

In [3]:
# Import the kernel source file as a standalone module and fetch the kernel
# object from it by name.
spec = util.spec_from_file_location(code_fpath.stem, code_fpath)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when no loader matches the path
    # (for example a file without a recognised Python suffix); fail with a
    # clear message instead of an AttributeError on the next line.
    raise ImportError(f'cannot build an import spec for {code_fpath}')
mod = util.module_from_spec(spec)
spec.loader.exec_module(mod)
kernel = getattr(mod, kernel_name)
kernel

JITFunction(compile_test:add_kernel)

In [4]:
# Split on commas and trim each entry so that the spacing around the commas
# (none, one or several blanks) does not change the parsed arguments; empty
# entries such as the one left by a trailing comma are dropped.
sig_args = [part for part in map(str.strip, signature.split(',')) if part]

In [5]:
# Short identifier for this signature: the first 8 hex digits of the SHA-256
# digest of the space-joined arguments.
hash_ = hashlib.sha256(' '.join(sig_args).encode())
sig_hash = hash_.hexdigest()[:8]
sig_hash

'78484a0b'

In [6]:
# Entries written as "<type>:<hint>" carry an integer hint; collect those
# hints keyed by argument position.
hints = {
    pos: int(entry.split(':')[1])
    for pos, entry in enumerate(sig_args)
    if ':' in entry
}
# Only the values 1 and 16 are accepted as hints.
assert all(hint in (1, 16) for hint in hints.values())
hints

{}

In [7]:
def arg2num(arg):
    """Interpret ``arg`` as a number when possible.

    ``int`` is tried first, then ``float``. If both conversions reject the
    value with a ``ValueError`` (as they do for a type name such as
    ``'*fp32'``), ``None`` is returned.
    """
    for convert in (int, float):
        try:
            return convert(arg)
        except ValueError:
            continue
    return None

In [8]:
# Split the parsed signature into compile-time constants and runtime argument
# types, both keyed by argument position.
constexprs = dict()
sig_canon = dict()

for i, arg in enumerate(sig_args):
    if (num := arg2num(arg)) is None:
        # Keep only the bare type name: a trailing ":<hint>" is not part of
        # the type and must not reach the compiler's signature.
        sig_canon[i] = arg.split(':')[0]
    else:
        constexprs[i] = num
constexprs, sig_canon

({4: 1024}, {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'})

In [9]:
# Group the hinted argument positions by hint value and wrap them in the
# descriptor that is passed on to ast_to_ttir.
config = triton.compiler.instance_descriptor(
    divisible_by_16=[pos for pos, hint in hints.items() if hint == 16],
    equal_to_1=[pos for pos, hint in hints.items() if hint == 1],
)
config

instance_descriptor(divisible_by_16=[], equal_to_1=[])

In [10]:
# Single-call alternative (disabled); the cells below run the lowering stage by stage instead.
# ccinfo = triton.compile(kernel, signature=sig_canon, constants=constexprs, configs=[config], num_warps=num_warps)
# ccinfo

In [11]:
from triton._C.libtriton.triton import ir
import triton.compiler.compiler as ttc

In [12]:
# Architecture descriptor returned for a `None` request.
arch = ttc.get_architecture_descriptor(None)
# Architectures below 75 use two stages, everything else three.
if arch < 75:
    num_stages = 2
else:
    num_stages = 3
context = ir.context()

In [13]:
# Translate the kernel into a Triton IR module, specialised on the runtime
# argument types (sig_canon), the hint descriptor (config) and the
# compile-time constants (constexprs).
ttir = ttc.ast_to_ttir(kernel, sig_canon, config, constexprs, debug=True, arch=arch)
# Print the resulting module.
ttir.dump()

module {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c1024_i32 = arith.constant 1024 : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatil

In [21]:
# Run the TTIR optimisation step on the module for the selected architecture.
# NOTE(review): if optimize_ttir rewrites its argument in place, `ttir` and
# `ttir_opt` refer to the same (already optimised) module — confirm.
ttir_opt = ttc.optimize_ttir(ttir, arch)
# Print the optimised module.
ttir_opt.dump()

module {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatil

In [26]:
# Convert the optimised TTIR module to the TritonGPU dialect for `num_warps`
# warps (the dump shows the "triton_gpu.num-warps" module attribute).
ttgir = ttc.ttir_to_ttgir(ttir_opt, num_warps)
# Print the converted module.
ttgir.dump()

module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %6 = "triton_gpu.c

In [39]:
# Run the TTGIR optimisation step with the chosen stage count and architecture.
# NOTE(review): `ttgir` itself is passed in; if the passes mutate it in place,
# `ttgir` no longer holds the unoptimised module after this cell — confirm.
ttgir_opt = ttc.optimize_ttgir(ttgir, num_stages, arch)
# Print the optimised module.
ttgir_opt.dump()

module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %6 = "triton_gpu.c

In [40]:
# Run a single pass (coalesce) by hand: build a pass manager on the module's
# own context, enable its debug mode, register the pass and apply it.
# NOTE(review): this operates on `ttgir`, which was already handed to
# optimize_ttgir in the previous cell; if that call mutated it in place, the
# pass is applied to an already optimised module — confirm this is intended.
pm = ir.pass_manager(ttgir.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.run(ttgir)
# Print the module after the pass has run.
ttgir.dump()

module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>>
    %6 = "triton_gpu.c