In [17]:
import os
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
from tvm import autotvm, auto_scheduler
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm import meta_schedule as ms
from tvm.ir import IRModule
from tvm import relax
from tvm import rpc
from tvm.contrib import utils, ndk
x_shape = 4096
w_w_x = 512
w_s_x = 128
w_y = 11008*2
func_name = "main"
@I.ir_module
class ModuleSrc:
    @T.prim_func(private=False)
    # fused_fused_decode4_NT_matmul3
    def fused_fused_decode4_NT_matmul3(lv13: T.Buffer((T.int64(512), T.int64(22016)), "uint32"), lv14: T.Buffer((T.int64(128), T.int64(22016)), "float16"), p_lv45: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(22016)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(22016)), "float16")
        p_output0_intermediate = T.alloc_buffer((T.int64(22016), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(22016)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv13[v_i // T.int64(8), v_j], lv14[v_i // T.int64(32), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v_i // T.int64(32), v_j]
        for ax0, ax1 in T.grid(T.int64(22016), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(p_output0_intermediate[v_ax0, v_ax1])
                p_output0_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(22016), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k]

@I.ir_module
class ModuleToManual:
    # 优化kernel配置
    #vf pr  bx  tx  ty
    #4	4	172	32	8
    @T.prim_func(private=False)
    def fused_decode1_fused_NT_matmul2_silu_after(
        lv36: T.Buffer((512, w_y), "uint32"),
        lv37: T.Buffer((128, w_y), "float16"),
        p_lv45: T.handle,
        p_output0: T.handle,
    ):
        T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
        n = T.int32()
        lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, w_y), "float16")
        # with T.block("root"):
        decode_local = T.alloc_buffer((4096, w_y), "float16", scope="local")
        lv36_local = T.alloc_buffer((512, w_y), "uint32", scope="local")
        lv37_local = T.alloc_buffer((128, w_y), "float16", scope="local")
        lv45_pad_local = T.alloc_buffer(
            (1, (n + 31) // 32 * 32, 4096), "float16", scope="local"
        )
        var_NT_matmul_intermediate_pad_local = T.alloc_buffer(
            (1, (n + 31) // 32 * 32, w_y), "float16", scope="local"
        )

        # 任务划分:
        ### 一个thread处理 `processed_rows_per_thread`行 `vectorize_factor` 列(输出角度)
        ### 完整处理 `processed_rows_per_thread` 行输入需要: blockIdx.x * threadIdx.x 配合
        ### 完整处理 `n` 行输入需要: blockIdx.y * threadIdx.y 配合
        #### 分析: 根据`n`变化的只有 blockIdx.y, 说明 blockIdx.x * threadIdx.x * threadIdx.y 可以完整处理32行输入
        #  4 16 24 128 2
        BlockIdx_x = 86*2
        # n = 32
        # BlockIdx_y = (n+31)//32 * 32 # 这里32是假设输入为32的倍数, //32的32 = thready * 
        ThreadIdx_x = 16
        ThreadIdx_y = 8
        vectorize_factor = 8
        processed_columns_per_thread = vectorize_factor# w_y / (BlockIdx_x * ThreadIdx_x) == vectorize_factor
        processed_rows_per_thread = 4# == 32 / threadIdx.y

        for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding(
            (n + 31) // 32, thread="blockIdx.y"
        ):
            for i2_0 in T.thread_binding(BlockIdx_x, thread="blockIdx.x"):
                for i0_i1_fused_1_1 in T.thread_binding(ThreadIdx_y, thread="threadIdx.y"):
                    for i2_1 in T.thread_binding(ThreadIdx_x, thread="threadIdx.x"):
                        for i0_i1_fused_1_2_init in range(processed_rows_per_thread):
                            for i2_2_init in T.vectorized(vectorize_factor):
                                with T.block("NT_matmul_init"):
                                    v_i0 = T.axis.spatial(1, 0)
                                    v_i1 = T.axis.spatial(
                                        (n + 31) // 32 * 32,
                                        i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                        + i0_i1_fused_1_1 * processed_rows_per_thread
                                        + i0_i1_fused_1_2_init,
                                    )
                                    v_i2 = T.axis.spatial(
                                        w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + i2_2_init
                                    )
                                    T.reads()
                                    T.writes(
                                        var_NT_matmul_intermediate_pad_local[
                                            v_i0, v_i1, v_i2
                                        ]
                                    )
                                    var_NT_matmul_intermediate_pad_local[
                                        v_i0, v_i1, v_i2
                                    ] = T.float16(0)
                        for k_0 in range(128):
                            for ax0 in range(1):
                                for ax1 in T.vectorized(vectorize_factor):
                                    with T.block("lv37_local"):
                                        v0 = T.axis.spatial(128, k_0 + ax0)
                                        v1 = T.axis.spatial(
                                            w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                        )
                                        T.reads(lv37[v0, v1])
                                        T.writes(lv37_local[v0, v1])
                                        lv37_local[v0, v1] = lv37[v0, v1]
                            for k_1 in range(4):
                                for ax0 in range(1):
                                    for ax1 in T.vectorized(vectorize_factor):
                                        with T.block("lv36_local"):
                                            v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0)
                                            v1 = T.axis.spatial(
                                                w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                            )
                                            T.reads(lv36[v0, v1])
                                            T.writes(lv36_local[v0, v1])
                                            lv36_local[v0, v1] = lv36[v0, v1]
                                for k_2 in range(8):
                                    for ax0 in range(1):
                                        for ax1 in T.vectorized(vectorize_factor):
                                            with T.block("decode"):
                                                v_i = T.axis.spatial(
                                                    4096, k_0 * 32 + k_1 * 8 + k_2 + ax0
                                                )
                                                v_j = T.axis.spatial(
                                                    w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                                )
                                                T.reads(
                                                    lv36_local[v_i // 8, v_j],
                                                    lv37_local[v_i // 32, v_j],
                                                )
                                                T.writes(decode_local[v_i, v_j])
                                                decode_local[v_i, v_j] = (
                                                    T.Cast(
                                                        "float16",
                                                        T.bitwise_and(
                                                            T.shift_right(
                                                                lv36_local[v_i // 8, v_j],
                                                                T.Cast("uint32", v_i % 8)
                                                                * T.uint32(4),
                                                            ),
                                                            T.uint32(15),
                                                        ),
                                                    )
                                                    - T.float16(7)
                                                ) * lv37_local[v_i // 32, v_j]
                                    for ax0, ax1 in T.grid(1, processed_rows_per_thread):
                                        for ax2 in T.vectorized(1):
                                            with T.block("lv45_pad_local"):
                                                v0 = T.axis.spatial(1, ax0)
                                                v1 = T.axis.spatial(
                                                    (n + 31) // 32 * 32,
                                                    i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                    + i0_i1_fused_1_1 * processed_rows_per_thread
                                                    + ax1,
                                                )
                                                v2 = T.axis.spatial(
                                                    4096, k_0 * 32 + k_1 * 8 + k_2 + ax2
                                                )
                                                T.reads(lv45[v0, v1, v2])
                                                T.writes(lv45_pad_local[v0, v1, v2])
                                                lv45_pad_local[v0, v1, v2] = T.if_then_else(
                                                    v1 < n, lv45[v0, v1, v2], T.float16(0)
                                                )
                                    for i0_i1_fused_1_2 in range(processed_rows_per_thread):
                                        for i2_2 in T.vectorized(vectorize_factor):
                                            with T.block("NT_matmul_update"):
                                                v_i0 = T.axis.spatial(1, 0)
                                                v_i1 = T.axis.spatial(
                                                    (n + 31) // 32 * 32,
                                                    i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                    + i0_i1_fused_1_1 * processed_rows_per_thread
                                                    + i0_i1_fused_1_2,
                                                )
                                                v_i2 = T.axis.spatial(
                                                    w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + i2_2
                                                )
                                                v_k = T.axis.reduce(
                                                    4096, k_0 * 32 + k_1 * 8 + k_2
                                                )
                                                T.reads(
                                                    var_NT_matmul_intermediate_pad_local[
                                                        v_i0, v_i1, v_i2
                                                    ],
                                                    lv45_pad_local[v_i0, v_i1, v_k],
                                                    decode_local[v_k, v_i2],
                                                )
                                                T.writes(
                                                    var_NT_matmul_intermediate_pad_local[
                                                        v_i0, v_i1, v_i2
                                                    ]
                                                )
                                                var_NT_matmul_intermediate_pad_local[
                                                    v_i0, v_i1, v_i2
                                                ] = (
                                                    var_NT_matmul_intermediate_pad_local[
                                                        v_i0, v_i1, v_i2
                                                    ]
                                                    + lv45_pad_local[v_i0, v_i1, v_k]
                                                    * decode_local[v_k, v_i2]
                                                )
                        for ax0, ax1 in T.grid(1, processed_rows_per_thread):
                            for ax2 in T.vectorized(vectorize_factor):
                                with T.block("var_NT_matmul_intermediate_pad_local"):
                                    v0 = T.axis.spatial(1, ax0)
                                    v1 = T.axis.spatial(
                                        (n + 31) // 32 * 32,
                                        i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                        + i0_i1_fused_1_1 * processed_rows_per_thread
                                        + ax1,
                                    )
                                    v2 = T.axis.spatial(w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax2)
                                    T.reads(
                                        var_NT_matmul_intermediate_pad_local[v0, v1, v2]
                                    )
                                    T.writes(p_output0_intermediate[v0, v1, v2])
                                    if v1 < n:
                                        p_output0_intermediate[
                                            v0, v1, v2
                                        ] = var_NT_matmul_intermediate_pad_local[
                                            v0, v1, v2
                                        ]
sch_manual = tvm.tir.Schedule(ModuleToManual)
print(sch_manual.mod.script())
print("================================================")
rt_mod = tvm.build(sch_manual.mod, target="opencl")
print(rt_mod.imported_modules[0].get_source())

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def fused_decode1_fused_NT_matmul2_silu_after(lv36: T.Buffer((512, 22016), "uint32"), lv37: T.Buffer((128, 22016), "float16"), p_lv45: T.handle, p_output0: T.handle):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        n = T.int32()
        lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16")
        # with T.block("root"):
        decode_local = T.alloc_buffer((4096, 22016), "float16", scope="local")
        lv36_local = T.alloc_buffer((512, 22016), "uint32", scope="local")
        lv37_local = T.alloc_buffer((128, 22016), "float16", scope="local")
        lv45_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local")
        var_NT_matmul_intermediate_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 22016), "float16", scope="

In [20]:
# run and compare with cuda
import numpy as np
def _detect_local_cuda():
    dev = tvm.cuda()
    if not dev.exist:
        return None
    return tvm.target.Target(
        {
            "kind": "cuda",
            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
            "max_threads_per_block": dev.max_threads_per_block,
            "thread_warp_size": dev.warp_size,
            "registers_per_block": 65536,
            "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""),
        }
    )
# target = tvm.target.Target("cuda", host="llvm")
target = _detect_local_cuda()

print(target)
# 定义计算任务
dev = tvm.cuda(0)

num_flop = 1228406784
seq_len = 32
W_w_np = np.random.uniform(size=(w_w_x, w_y)).astype("uint32")
W_s_np = np.random.uniform(size=(w_s_x, w_y)).astype("float16")
Input_np = np.random.uniform(size=(1, seq_len, x_shape)).astype("float16")
# W_w_np = np.ones((w_w_x, w_y), np.uint32) * 1#.astype("uint32")
# W_s_np = np.ones((w_s_x, w_y), np.float16) * 1#.astype("float16") * 2
# Input_np = np.ones((1, 1, x_shape), np.float16)#.astype("float16")
Output_nd = tvm.nd.array(np.zeros((1, seq_len, w_y), dtype="float16"), dev)
def numpy_caculate():
    test_rows = 2
    test_cols = 10
    output = np.zeros((1, test_rows, test_cols), dtype = np.float16)
    W_w_inv_np = np.transpose(W_w_np)
    W_s_inv_np = np.transpose(W_s_np)
    for row in range(test_rows):
        for i in range(test_cols):
            for r in range(x_shape):
                temp = Input_np[0][row][r] * np.float16((W_w_inv_np[i][r // 8] >> ((r % 8) * 4) & (15)) - np.float16(7.0)) * W_s_inv_np[i][r // 32]
                output[0][row][i] = output[0][row][i] + temp
    print(output)
    output = np.zeros((1, test_rows, test_cols), dtype = np.float16)
    for row in range(test_rows):
        for i in range(test_cols):
            for r in range(x_shape):
                temp = Input_np[0][row][r] * np.float16((W_w_np[r // 8][i] >> ((r % 8) * 4) & (15)) - np.float16(7.0)) * W_s_np[r // 32][i]
                temp_output = output[0][row][i]
                output[0][row][i] = temp_output + temp
                # print(f"{temp_output} + {temp} = {output[0][0][i]}")
    print(output)
numpy_caculate()
def print_npdata(np_data: np.ndarray) :
    print(np_data)
    print_num = 20
    d = np_data.flatten()
    p_size = print_num if d.size > print_num else d.size
    print(d[:p_size])

cuda -keys=cuda,gpu -arch=sm_61 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
[[[-6520. -6740. -6072. -6788. -6452. -6532. -6972. -6236. -5940. -6476.]
  [-6644. -6584. -6044. -6836. -6420. -6560. -7028. -6240. -5872. -6444.]]]
[[[-6520. -6740. -6072. -6788. -6452. -6532. -6972. -6236. -5940. -6476.]
  [-6644. -6584. -6044. -6836. -6420. -6560. -7028. -6240. -5872. -6444.]]]


In [25]:
# cuda未优化版本测试
sch = tvm.tir.Schedule(ModuleSrc)
with target:
    src_gpu_mod = tvm.tir.transform.DefaultGPUSchedule()(sch.mod) ##
rt_mod = tvm.build(src_gpu_mod, target="cuda")
W_w_nd = tvm.nd.array(W_w_np, dev)
W_s_nd = tvm.nd.array(W_s_np, dev)
Input_nd = tvm.nd.array(Input_np, dev)
Output_nd = tvm.nd.array(np.zeros((1, seq_len, w_y), dtype="float16"), dev)
evaluator = rt_mod.time_evaluator("main", dev, number=100)
print("manual_evaluator GEMV-Blocking: %f GFLOPS" % (num_flop / evaluator(W_w_nd, W_s_nd, Input_nd, Output_nd).mean / 1e9))
# print(Output_nd.numpy())
print_npdata(Output_nd.numpy())

manual_evaluator GEMV-Blocking: 65.352494 GFLOPS
[-5976. -6344. -7012. -5560. -6352. -5812. -7388. -6224. -6852. -6048.]


In [21]:
os.environ["TVM_NDK_CC"]="/home/sensetime/Android/Sdk/ndk/25.2.9519653/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android33-clang++"
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")
device_key="android"
rpc_host = "10.4.236.32"
rpc_port = 9190
comp_target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")  # TODO: Only support arm64 for now

def test_opencl(mod: tvm.IRModule, name_hint: str):
    # mod = tvm.lower(sch_manual.mod)
    print("Build ...")
    android_rt_mod = tvm.build(mod, target="opencl", target_host="llvm -mtriple=aarch64-linux-android")
    # print(android_rt_mod.imported_modules[0].get_source())
    temp = utils.tempdir()
    path_dso_cl = temp.relpath("dev_lib_cl.so")
    android_rt_mod.export_library(path_dso_cl, ndk.create_shared)

    print("Run GPU(OpenCL Flavor) test ...")
    # Establish remote connection with target hardware

    tracker = rpc.connect_tracker(rpc_host, rpc_port)
    remote = tracker.request(device_key, priority=0, session_timeout=60)
    print("Connect to device done.")
    dev = remote.cl(0)
    remote.upload(path_dso_cl)
    f1 = remote.load_module("dev_lib_cl.so")

    W_w_nd = tvm.nd.array(W_w_np, dev)
    W_s_nd = tvm.nd.array(W_s_np, dev)
    Input_nd = tvm.nd.array(Input_np, dev)
    Output_nd = tvm.nd.array(np.zeros((1, seq_len, w_y), dtype="float16"), dev)
    test_number=32
    time_f = f1.time_evaluator(f1.entry_name, dev, number=test_number)
    cost = time_f(W_w_nd, W_s_nd, Input_nd, Output_nd).mean
    print("evaluator[%s] GEMV-Blocking: %fms with loop %d" % (name_hint, cost * 1000, test_number))
    print("evaluator[%s] GEMV-Blocking: %fGFLOPS" % (name_hint, num_flop / cost / 1e9))

    print_npdata(Output_nd.numpy())
    # return Output_nd.numpy()
    return cost*1000 # unit: ms

In [18]:
# 未优化版本opencl测试
from tvm import dlight as dl
sch = tvm.tir.Schedule(ModuleSrc)
with target:
    # src_gpu_mod = tvm.tir.transform.DefaultGPUSchedule()(sch.mod) ##
    mod_deploy = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(sch.mod)
src_output = test_opencl(mod_deploy, "source")
# print_npdata(src_output)


Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[source] GEMV-Blocking: 600.584168ms with loop 32
evaluator[source] GEMV-Blocking: 2.045353GFLOPS
[[[-6700. -6796. -7064. ... -6468. -5900. -6428.]
  [-6952. -6956. -7324. ... -6400. -6240. -6600.]
  [-6624. -6784. -7156. ... -6344. -6028. -6528.]
  ...
  [-6664. -6732. -7168. ... -6280. -6040. -6420.]
  [-6740. -6864. -7136. ... -6444. -6228. -6640.]
  [-6728. -6792. -7188. ... -6316. -5984. -6464.]]]
0: -6700.0
1: -6796.0
2: -7064.0
3: -6724.0
4: -6352.0
5: -6300.0
6: -6412.0
7: -5916.0
8: -6324.0
9: -6544.0
10: -6220.0
11: -5920.0
12: -6012.0
13: -6660.0
14: -6704.0
15: -6652.0
16: -6240.0
17: -6644.0
18: -6556.0
19: -5712.0
20: -6240.0
21: -6096.0
22: -6540.0
23: -6768.0
24: -6604.0
25: -6840.0
26: -6512.0
27: -6760.0
28: -6588.0
29: -6532.0
30: -5776.0
31: -5924.0
32: -6728.0
33: -6132.0
34: -6356.0
35: -6532.0
36: -6740.0
37: -6232.0
38: -6508.0
39: -5868.0
40: -5988.0
41: -6440.0
42: -7312.0
43: -6748.0
4

In [19]:
#优化版本opencl测试
# print(sch_manual.mod)
opt_output = test_opencl(sch_manual.mod, "opted")
# print_npdata(opt_output)
np.testing.assert_equal(opt_output, src_output)

Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[opted] GEMV-Blocking: 16.574752ms with loop 32
evaluator[opted] GEMV-Blocking: 74.113132GFLOPS
[[[-6700. -6796. -7064. ... -6468. -5900. -6428.]
  [-6952. -6956. -7324. ... -6400. -6240. -6600.]
  [-6624. -6784. -7156. ... -6344. -6028. -6528.]
  ...
  [-6664. -6732. -7168. ... -6280. -6040. -6420.]
  [-6740. -6864. -7136. ... -6444. -6228. -6640.]
  [-6728. -6792. -7188. ... -6316. -5984. -6464.]]]
0: -6700.0
1: -6796.0
2: -7064.0
3: -6724.0
4: -6352.0
5: -6300.0
6: -6412.0
7: -5916.0
8: -6324.0
9: -6544.0
10: -6220.0
11: -5920.0
12: -6012.0
13: -6660.0
14: -6704.0
15: -6652.0
16: -6240.0
17: -6644.0
18: -6556.0
19: -5712.0
20: -6240.0
21: -6096.0
22: -6540.0
23: -6768.0
24: -6604.0
25: -6840.0
26: -6512.0
27: -6760.0
28: -6588.0
29: -6532.0
30: -5776.0
31: -5924.0
32: -6728.0
33: -6132.0
34: -6356.0
35: -6532.0
36: -6740.0
37: -6232.0
38: -6508.0
39: -5868.0
40: -5988.0
41: -6440.0
42: -7312.0
43: -6748.0
44:

In [22]:
# 自动搜索
# 以32为倍数先搜一波
# @TODO: 探索更低的倍数，以降低padding的额外性能损耗
def auto_tune(record_file: str):
    from typing import Union
    def search(vf: int, pr: int, bx: int, tx: int, ty: int):
        """search by workgroup

        Args:
            blockIdxX (_type_): blockIdx.x
            threadIdxX (_type_): threadIdx.x
            vectorize_output (_type_): 输出的vectorize参数, 决定单线程输出多少个结果
            vectorize_input (list, optional): 输入X拷贝到shared_memory时的vectorize参数, 一般为4或8

        Returns:
            _type_: _description_
        """
        @I.ir_module
        class ModuleToManual:
            @T.prim_func(private=False)
            # fused_decode1_fused_NT_matmul2_silu_after
            def main(
                lv36: T.Buffer((512, w_y), "uint32"),
                lv37: T.Buffer((128, w_y), "float16"),
                p_lv45: T.handle,
                p_output0: T.handle,
            ):
                T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
                n = T.int32()
                lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16")
                p_output0_intermediate = T.match_buffer(p_output0, (1, n, w_y), "float16")
                # with T.block("root"):
                decode_local = T.alloc_buffer((4096, w_y), "float16", scope="local")
                lv36_local = T.alloc_buffer((512, w_y), "uint32", scope="local")
                lv37_local = T.alloc_buffer((128, w_y), "float16", scope="local")
                lv45_pad_local = T.alloc_buffer(
                    (1, (n + 31) // 32 * 32, 4096), "float16", scope="local"
                )
                var_NT_matmul_intermediate_pad_local = T.alloc_buffer(
                    (1, (n + 31) // 32 * 32, w_y), "float16", scope="local"
                )

                # 任务划分:
                ### 一个thread处理 `processed_rows_per_thread`行 `vectorize_factor` 列(输出角度)
                ### 完整处理 `processed_rows_per_thread` 行输入需要: blockIdx.x * threadIdx.x 配合
                ### 完整处理 `n` 行输入需要: blockIdx.y * threadIdx.y 配合
                #### 分析: 根据`n`变化的只有 blockIdx.y, 说明 blockIdx.x * threadIdx.x * threadIdx.y 可以完整处理32行输入
                BlockIdx_x = bx#32
                # n = 32
                # BlockIdx_y = (n+31)//32 * 32 # 这里32是假设输入为32的倍数, //32的32 = thready * 
                ThreadIdx_x = tx#16 * 3
                ThreadIdx_y = ty#8
                vectorize_factor = vf#8
                # processed_columns_per_thread = vectorize_factor# w_y / (BlockIdx_x * ThreadIdx_x) == vectorize_factor
                processed_rows_per_thread = pr#4

                ## BlockIdx.y == [BlockIdx.x, ThreadIdx.x, ThraedIdx.y] 解决 seq_length为32的处理
                for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding(
                    (n + 31) // 32, thread="blockIdx.y"
                ):
                    for i2_0 in T.thread_binding(BlockIdx_x, thread="blockIdx.x"):
                        for i0_i1_fused_1_1 in T.thread_binding(ThreadIdx_y, thread="threadIdx.y"):
                            for i2_1 in T.thread_binding(ThreadIdx_x, thread="threadIdx.x"):
                                for i0_i1_fused_1_2_init in range(processed_rows_per_thread):
                                    for i2_2_init in T.vectorized(vectorize_factor):
                                        with T.block("NT_matmul_init"):
                                            v_i0 = T.axis.spatial(1, 0)
                                            v_i1 = T.axis.spatial(
                                                (n + 31) // 32 * 32,
                                                i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                + i0_i1_fused_1_1 * processed_rows_per_thread
                                                + i0_i1_fused_1_2_init,
                                            )
                                            v_i2 = T.axis.spatial(
                                                w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + i2_2_init
                                            )
                                            T.reads()
                                            T.writes(
                                                var_NT_matmul_intermediate_pad_local[
                                                    v_i0, v_i1, v_i2
                                                ]
                                            )
                                            var_NT_matmul_intermediate_pad_local[
                                                v_i0, v_i1, v_i2
                                            ] = T.float16(0)
                                for k_0 in range(128):
                                    for ax0 in range(1):
                                        for ax1 in T.vectorized(vectorize_factor):
                                            with T.block("lv37_local"):
                                                v0 = T.axis.spatial(128, k_0 + ax0)
                                                v1 = T.axis.spatial(
                                                    w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                                )
                                                T.reads(lv37[v0, v1])
                                                T.writes(lv37_local[v0, v1])
                                                lv37_local[v0, v1] = lv37[v0, v1]
                                    for k_1 in range(4):
                                        for ax0 in range(1):
                                            for ax1 in T.vectorized(vectorize_factor):
                                                with T.block("lv36_local"):
                                                    v0 = T.axis.spatial(512, k_0 * 4 + k_1 + ax0)
                                                    v1 = T.axis.spatial(
                                                        w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                                    )
                                                    T.reads(lv36[v0, v1])
                                                    T.writes(lv36_local[v0, v1])
                                                    lv36_local[v0, v1] = lv36[v0, v1]
                                        for k_2 in range(8):
                                            for ax0 in range(1):
                                                for ax1 in T.vectorized(vectorize_factor):
                                                    with T.block("decode"):
                                                        v_i = T.axis.spatial(
                                                            4096, k_0 * 32 + k_1 * 8 + k_2 + ax0
                                                        )
                                                        v_j = T.axis.spatial(
                                                            w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax1
                                                        )
                                                        T.reads(
                                                            lv36_local[v_i // 8, v_j],
                                                            lv37_local[v_i // 32, v_j],
                                                        )
                                                        T.writes(decode_local[v_i, v_j])
                                                        decode_local[v_i, v_j] = (
                                                            T.Cast(
                                                                "float16",
                                                                T.bitwise_and(
                                                                    T.shift_right(
                                                                        lv36_local[v_i // 8, v_j],
                                                                        T.Cast("uint32", v_i % 8)
                                                                        * T.uint32(4),
                                                                    ),
                                                                    T.uint32(15),
                                                                ),
                                                            )
                                                            - T.float16(7)
                                                        ) * lv37_local[v_i // 32, v_j]
                                            for ax0, ax1 in T.grid(1, processed_rows_per_thread):
                                                for ax2 in T.vectorized(1):
                                                    with T.block("lv45_pad_local"):
                                                        v0 = T.axis.spatial(1, ax0)
                                                        v1 = T.axis.spatial(
                                                            (n + 31) // 32 * 32,
                                                            i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                            + i0_i1_fused_1_1 * processed_rows_per_thread
                                                            + ax1,
                                                        )
                                                        v2 = T.axis.spatial(
                                                            4096, k_0 * 32 + k_1 * 8 + k_2 + ax2
                                                        )
                                                        T.reads(lv45[v0, v1, v2])
                                                        T.writes(lv45_pad_local[v0, v1, v2])
                                                        lv45_pad_local[v0, v1, v2] = T.if_then_else(
                                                            v1 < n, lv45[v0, v1, v2], T.float16(0)
                                                        )
                                            for i0_i1_fused_1_2 in range(processed_rows_per_thread):
                                                for i2_2 in T.vectorized(vectorize_factor):
                                                    with T.block("NT_matmul_update"):
                                                        v_i0 = T.axis.spatial(1, 0)
                                                        v_i1 = T.axis.spatial(
                                                            (n + 31) // 32 * 32,
                                                            i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                            + i0_i1_fused_1_1 * processed_rows_per_thread
                                                            + i0_i1_fused_1_2,
                                                        )
                                                        v_i2 = T.axis.spatial(
                                                            w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + i2_2
                                                        )
                                                        v_k = T.axis.reduce(
                                                            4096, k_0 * 32 + k_1 * 8 + k_2
                                                        )
                                                        T.reads(
                                                            var_NT_matmul_intermediate_pad_local[
                                                                v_i0, v_i1, v_i2
                                                            ],
                                                            lv45_pad_local[v_i0, v_i1, v_k],
                                                            decode_local[v_k, v_i2],
                                                        )
                                                        T.writes(
                                                            var_NT_matmul_intermediate_pad_local[
                                                                v_i0, v_i1, v_i2
                                                            ]
                                                        )
                                                        var_NT_matmul_intermediate_pad_local[
                                                            v_i0, v_i1, v_i2
                                                        ] = (
                                                            var_NT_matmul_intermediate_pad_local[
                                                                v_i0, v_i1, v_i2
                                                            ]
                                                            + lv45_pad_local[v_i0, v_i1, v_k]
                                                            * decode_local[v_k, v_i2]
                                                        )
                                for ax0, ax1 in T.grid(1, processed_rows_per_thread):
                                    for ax2 in T.vectorized(vectorize_factor):
                                        with T.block("var_NT_matmul_intermediate_pad_local"):
                                            v0 = T.axis.spatial(1, ax0)
                                            v1 = T.axis.spatial(
                                                (n + 31) // 32 * 32,
                                                i0_i1_fused_0_i0_i1_fused_1_0_fused * 32
                                                + i0_i1_fused_1_1 * processed_rows_per_thread
                                                + ax1,
                                            )
                                            v2 = T.axis.spatial(w_y, i2_0 * (ThreadIdx_x * vectorize_factor) + i2_1 * vectorize_factor + ax2)
                                            T.reads(
                                                var_NT_matmul_intermediate_pad_local[v0, v1, v2]
                                            )
                                            T.writes(p_output0_intermediate[v0, v1, v2])
                                            if v1 < n:
                                                p_output0_intermediate[
                                                    v0, v1, v2
                                                ] = var_NT_matmul_intermediate_pad_local[
                                                    v0, v1, v2
                                                ]
        return tvm.tir.Schedule(ModuleToManual).mod

    BlockIdx_x = [43, 86, 172, 344] # 32
    # n = 32
    # BlockIdx_y = (n+31)//32 * 32 # 这里32是假设输入为32的倍数, //32的32 = thready * 
    ThreadIdx_x = [None]#16 * 3
    # ThreadIdx_y = 8 # = 32 / processed_rows_per_thread
    vectorize_factor = [2, 4, 8]# 8
    # processed_columns_per_thread = vectorize_factor# w_y / (BlockIdx_x * ThreadIdx_x) == vectorize_factor
    processed_rows_per_thread = [1, 2, 4, 8, 16]#4
    task_index = 0
    total_task_num = len(vectorize_factor)*len(BlockIdx_x)*len(ThreadIdx_x)*len(processed_rows_per_thread)
    records = {}
    print(f"Total tasks: {total_task_num}")
    # try:
    vectorize_factor_r = vectorize_factor[::-1]
    print(vectorize_factor_r)
    processed_rows_per_thread_r = processed_rows_per_thread[::-1]
    BlockIdx_x_r = BlockIdx_x[::-1]
    # table
    write_interval = 5
    from prettytable import PrettyTable
    table = PrettyTable()
    table.field_names = ["vectorize_factor", "processed_rows_per_thread", "blockIdx.x", "threadIdx.x", "threadIdx.y", "cost(ms)"]
    for vf in vectorize_factor_r:
        for pr in processed_rows_per_thread_r:
            for bx in BlockIdx_x:
                task_index = task_index + 1
                import math
                tx = math.ceil(w_y /(vf * bx))
                ty = math.ceil(32//pr)
                if tx * vf >= w_y or tx*ty > 1024 or 32 % pr != 0: # w_y为输出列数, 工作组和vectorize相乘不能大于该数字
                    print(f"search record [{task_index}/{total_task_num}]: skip {vf} {pr} {bx} {tx} {ty}")
                    continue
                if w_y % (vf * bx) != 0:
                    print(f"search record [{task_index}/{total_task_num}]: skip because tx isn't divisible {vf} {pr} {bx} {tx} {ty}")
                    continue
                print(f"search record [{task_index}/{total_task_num}]: start run {vf} {pr} {bx} {tx} {ty}")
                # vf: int, pr: int, bx: int, tx: int, ty: int):
                mod_deploy = search(vf, pr, bx, tx, ty)
                cost = test_opencl(mod_deploy, "search")
                print("=====")
                records[(vf, pr, bx, tx, ty)] = cost
                table.add_row([vf, pr, bx, tx, ty, cost])
                if task_index % write_interval == 0:
                    with open(record_file, 'wt') as f:
                        f.write(table.get_csv_string())
    # except Exception as e:
    #     print(f"error occured: {e}")
    ### write file
    # from prettytable import PrettyTable
    # table = PrettyTable()
    # table.field_names = ["vectorize_factor", "processed_rows_per_thread", "blockIdx.x", "threadIdx.x", "threadIdx.y", "cost(ms)"]
    # for config, cost in records.items():
    #     table.add_row([config[0], config[1], config[2], config[3], config[4], cost])
    #     print(f"{config}: {cost}ms")
    print("================================")
    print(table)
    with open(record_file, 'wt') as f:
        f.write(table.get_csv_string())
    
    # record_sorted = sorted(record.items(), key=lambda x: x[1][0], reverse=True)
auto_tune("./manual_tune/gate_up_fused_n_tune_record_1.csv")

Total tasks: 60
[8, 4, 2]
search record [1/60]: start run 8 16 43 64 2
Build ...
Run GPU(OpenCL Flavor) test ...




Connect to device done.
evaluator[search] GEMV-Blocking: 92.023272ms with loop 32
evaluator[search] GEMV-Blocking: 13.348871GFLOPS
[[[-6520. -6740. -6072. ... -6836. -6408. -6456.]
  [-6644. -6584. -6044. ... -6760. -6464. -6568.]
  [-6708. -6764. -6168. ... -6824. -6480. -6560.]
  ...
  [-6340. -6456. -5812. ... -6612. -6204. -6356.]
  [-6544. -6664. -5908. ... -6632. -6352. -6392.]
  [-6704. -6628. -6060. ... -6820. -6528. -6556.]]]
[-6520. -6740. -6072. -6780. -6456. -6528. -6972. -6236. -5940. -6476.
 -6632. -6572. -6384. -6692. -6756. -6824. -7268. -6576. -6184. -6952.]
=====
search record [2/60]: start run 8 16 86 32 2
Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[search] GEMV-Blocking: 181.405136ms with loop 32
evaluator[search] GEMV-Blocking: 6.771621GFLOPS
[[[-6520. -6740. -6072. ... -6836. -6408. -6456.]
  [-6644. -6584. -6044. ... -6760. -6464. -6568.]
  [-6708. -6764. -6168. ... -6824. -6480. -6560.]
  ...
  [-6340. -6456. -5812. ... -6612. -62

In [None]:
import numpy as np
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")
device_key="android"
rpc_host = "10.158.176.30"
rpc_port = 5001
# remote = autotvm.measure.request_remote(device_key, "10.158.176.30", 5001, timeout=10000)
# dev = remote.device(str(target), 0)

# num_flop = 1228406784
# W_np = np.random.uniform(size=(512, vocab_size)).astype("uint32")
# S_np = np.random.uniform(size=(128, vocab_size)).astype("float16")
# Input_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# # Output_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# W_nd = tvm.nd.array(W_np, dev)
# S_nd = tvm.nd.array(S_np, dev)
# Input_nd = tvm.nd.array(Input_np, dev)
# Output_nd = tvm.nd.array(np.zeros((1, 1, vocab_size), dtype="float32"), dev)

In [None]:
rpc_config = ms.runner.RPCConfig(tracker_host=rpc_host, tracker_port=rpc_port, tracker_key = device_key)
runner= ms.runner.RPCRunner(rpc_config)
# ms.builder.LocalBuilder()
sch = tvm.tir.Schedule(ModuleSrc)
database = ms.tune_tir(
    mod=ModuleSrc,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_first",
    cost_model="xgb",
    runner = runner
)
print(len(database))
sch1 = ms.tir_integration.compile_tir(database, sch.mod, target)
print(type(sch1))

In [None]:
from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor((3, 4), dtype="float32"), B: R.Tensor((4, 5), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((3, 5), dtype="float32") = R.matmul(A, B)
            gv: R.Tensor((3, 5), dtype="float32") = lv
            R.output(gv)
        return gv

In [None]:
## auto_scheduler test
from tvm import auto_scheduler
import numpy as np
a_np = np.random.rand(3, 4).astype("float32")
b_np = np.random.rand(4, 5).astype("float32")
a_nd = tvm.runtime.NDArray(a_np)
b_nd = tvm.runtime.NDArray(b_np)
sch = tvm.tir.Schedule(Module)

params = {"A": a_np, "B": b_np}
## 报错，这里只支持relay
# tasks = auto_scheduler.extract_tasks(sch.mod, params, target=target)
tasks = ms.relax_integration.extract_tasks(sch.mod, target=target, params=params)
print(len(tasks))

In [None]:

from mod_deploy import Module as ModuleAll
params_all = {}
tasks_all = auto_scheduler.extract_tasks(ModuleAll, params_all, target=target)
print(len(tasks_all))

In [None]:
import numpy as np
log_file = "tune.json"
def _detect_local_cuda():
    dev = tvm.cuda()
    if not dev.exist:
        return None
    return tvm.target.Target(
        {
            "kind": "cuda",
            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
            "max_threads_per_block": dev.max_threads_per_block,
            "thread_warp_size": dev.warp_size,
            "registers_per_block": 65536,
            "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""),
        }
    )
# target = tvm.target.Target("cuda", host="llvm")
target = _detect_local_cuda()

print(target)
# 定义计算任务
dev = tvm.cuda(0)

num_flop = 1228406784
W_np = np.random.uniform(size=(512, vocab_size)).astype("uint32")
S_np = np.random.uniform(size=(128, vocab_size)).astype("float16")
Input_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# Output_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
W_nd = tvm.nd.array(W_np, dev)
S_nd = tvm.nd.array(S_np, dev)
Input_nd = tvm.nd.array(Input_np, dev)
Output_nd = tvm.nd.array(np.zeros((1, 1, vocab_size), dtype="float32"), dev)
sch = tvm.tir.Schedule(ModuleSrc)
new_mod = sch.mod


In [None]:
# task = auto_scheduler.SearchTask(func=sch.mod['fused_fused_decode11_fused_matmul5_cast2'], args=sch.mod['fused_fused_decode11_fused_matmul5_cast2'].params, target=target)

# tune_option = auto_scheduler.TuningOptions(
#     num_measure_trials=10,
#     measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
#     verbose=2,
# )


database = ms.tune_tir(
    mod=new_mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_45593_1",
    cost_model="xgb"
)
print(len(database))
sch1 = ms.tir_integration.compile_tir(database, new_mod, target)
print(type(sch1))

In [None]:
# print(sch1.trace)
# print(sch1.mod.script())
rt_mod = tvm.build(sch1.mod, target="cuda")

evaluator = rt_mod.time_evaluator("main", dev, number=100)

print("evaluator GEMV-Blocking: %f GFLOPS" % (1228406784 / evaluator(W_nd, S_nd, Input_nd, Output_nd).mean / 1e9))




In [None]:

record_database = ms.Database.create(kind='json', work_dir='./tune_45593_1')


In [None]:
record_sch = ms.tir_integration.compile_tir(record_database, new_mod, target)

record_rt_mod = tvm.build(record_sch.mod, target="cuda")

record_evaluator = record_rt_mod.time_evaluator("main", dev, number=20)

print("evaluator GEMV-Blocking: %f GFLOPS" % (num_flop / record_evaluator(W_nd, S_nd, Input_nd, Output_nd).mean / 1e9))
print(record_sch.trace)
print(record_sch.mod.script())

In [None]:
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable
from tvm import runtime
if TYPE_CHECKING:
    import numpy as np  # type: ignore
    from tvm.ir import IRModule
    from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
    from tvm.runtime import Device, Module, NDArray
    from tvm.target import Target
    from tvm.runtime.vm import Executable


def f_measurement(
    rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
):
    vm = relax.VirtualMachine(rt_mod, device=device)
    vm.save_function("main", "measure_func", **input_data, include_return=False)
    evaluator = vm.time_evaluator(
        func_name="measure_func",
        dev=device,
        repeat=100,
        number=1,
        min_repeat_ms=500,
    )
    return evaluator()

def run_module_via_rpc(
    rpc_config: "RPCConfig",
    lib: Union["Module", "Executable"],
    dev_type: str,
    args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
    continuation: Callable,
    backend: Optional[str] = "graph",
):
    """Execute a tvm.runtime.Module on RPC remote"""
    # pylint: disable=import-outside-toplevel
    import os
    import tempfile

    from tvm.contrib.tar import tar
    from tvm.runtime import ndarray

    # pylint: enable=import-outside-toplevel

    with tempfile.TemporaryDirectory() as tmp_dir:
        # filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
        filename = os.path.join(tmp_dir, "tvm_tmp_mod." + "so")
        if backend == "vm":
            code, lib = lib.save(filename, fmt="so")
        from tvm.contrib import ndk
        lib.export_library(filename, ndk.create_shared)
        session = rpc_config.connect_server()
        print(type(session._sess))
        session.upload(filename)
        _, filename = os.path.split(filename)
        rt_mod = session.load_module(filename)
        
        if backend == "vm":
            rt_mod = session.get_function("runtime.Load_Executable")(code, rt_mod)
            # rt_mod = session.get_function("runtime.module.loadfile_relax.Executable")(filename)
        dev = session.device(dev_type=dev_type, dev_id=0)
        # print(dev)
        # create the remote runtime module
        print(rt_mod)
        print(rt_mod['main'])
        from tvm.contrib import graph_executor as runtime
        module = runtime.GraphModule(rt_mod["main"](dev))
        print(module)
        for k, v in args.items():
            module.set_input(k, tvm.nd.array(v))
        return module.run()
        # nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
        nd_args = {k: ndarray.empty(v.shape, v.dtype, dev) for k, v in args.items()}
        return continuation(rt_mod, dev, nd_args)