In [1]:
import os
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
from tvm import autotvm, auto_scheduler
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm import meta_schedule as ms
from tvm.ir import IRModule
from tvm import relax
from tvm import rpc
from tvm.contrib import utils, ndk
x_shape = 4096
w_w_x = 512
w_s_x = 128
w_y = 11008*2
func_name = "main"
@I.ir_module
class ModuleSrc:
    @T.prim_func
    def main(lv571: T.Buffer((T.int64(512), T.int64(w_y)), "uint32"), lv572: T.Buffer((T.int64(128), T.int64(w_y)), "float16"), lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(w_y)), "float16")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(w_y)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(w_y)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv571[v_i // T.int64(8), v_j], lv572[v_i // T.int64(32), v_j])
                T.writes(p_output0_intermediate[v_i, v_j])
                p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i // T.int64(32), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(w_y), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]

@I.ir_module
class ModuleToManual:
    @T.prim_func
    def main(lv571: T.Buffer((T.int64(512), T.int64(w_y)), "uint32"), lv572: T.Buffer((T.int64(128), T.int64(w_y)), "float16"), lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(w_y)), "float16")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(w_y)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(w_y)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv571[v_i // T.int64(8), v_j], lv572[v_i // T.int64(32), v_j])
                T.writes(p_output0_intermediate[v_i, v_j])
                p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i // T.int64(32), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(w_y), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]

In [2]:
## ref to mlc-llm/dispatch/dispatch_tir_operator_adreno.py
# 最优kernel: 0.962456ms
# vf vi	tx	bx
# 4	 4	128	43

def sch_fused_decode5_fused_matmul6_silu1(func):
    sch = tvm.tir.Schedule(func)
    b0 = sch.get_block(name="decode", func_name="main")
    b1 = sch.get_block(name="matmul", func_name="main")
    l2, l3, l4, l5 = sch.get_loops(block=b1)
    l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True)
    l10, l11, l12 = sch.split(loop=l6, factors=[43, 128, 4], preserve_unit_iters=True)
    v13, v14, v15 = sch.sample_perfect_tile(
        loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8]
    )
    l16, l17, l18 = sch.split(
        loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True
    )
    sch.reorder(l10, l11, l16, l17, l18, l12)
    sch.bind(loop=l10, thread_axis="blockIdx.x")
    sch.bind(loop=l11, thread_axis="threadIdx.x")
    sch.compute_inline(block=b0)
    b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")
    sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1)
    b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local")
    b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local")
    b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared")
    sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1)
    v23 = sch.sample_categorical(
        candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=3
    )
    sch.annotate(
        block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23
    )
    sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1)
    sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1)
    l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20)
    sch.vectorize(loop=l29)
    l30, l31, l32, l33, l34 = sch.get_loops(block=b21)
    sch.vectorize(loop=l34)
    l35, l36, l37, l38, l39 = sch.get_loops(block=b19)
    sch.vectorize(loop=l39)
    sch.vectorize(loop=l12)
    b40 = sch.decompose_reduction(block=b1, loop=l16)
    sch.enter_postproc()
    sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch")
    l43, l44, l45, l46, l47 = sch.get_loops(block=b22)
    l48, l49, l50 = sch.split(loop=l47, factors=[None, 128, 4], preserve_unit_iters=True)
    sch.vectorize(loop=l50)
    sch.bind(loop=l49, thread_axis="threadIdx.x")
    return sch.mod["main"].with_attr("tir.is_scheduled", 1)


sch_manual = tvm.tir.Schedule(ModuleToManual)
# sch_fused_decode5_fused_matmul6_silu1(sch_manual.mod[func_name])
sch_manual.mod['main'] = sch_fused_decode5_fused_matmul6_silu1(sch_manual.mod[func_name])
# print(sch_manual.mod.script())
print("================================================")
rt_mod = tvm.build(sch_manual.mod, target="opencl")
print(rt_mod.imported_modules[0].get_source())

// Function: main_kernel
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#elif defined(cl_amd_fp16)
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
#else
#error "Half precision floating point not supported by OpenCL implementation on your device." 
#endif

__kernel void main_kernel(__global half* restrict lv1654, __global uint* restrict lv571, __global half* restrict lv572, __global half* restrict var_matmul_intermediate) {
  __local half lv1654_shared[4096];
  half4 var_matmul_intermediate_local[1];
  half4 lv572_local[1];
  uint4 lv571_local[1];
  for (int ax2_0 = 0; ax2_0 < 4; ++ax2_0) {
    vstore8(vload8(0, lv1654 + ((ax2_0 * 1024) + ((convert_int(get_local_id(0))) * 8))), 0, lv1654_shared + ((ax2_0 * 1024) + ((convert_int(get_local_id(0))) * 8)));
  }
  var_matmul_intermediate_local[0] = ((half4)((half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f, (half)0.000000e+00f));
  barrier(CLK_LOCAL_MEM_FENCE);
  for (int k_0 = 0; k_0 < 128; ++k_0) {
    lv572_loc

In [3]:
# run and compare with cuda
import numpy as np
def _detect_local_cuda():
    dev = tvm.cuda()
    if not dev.exist:
        return None
    return tvm.target.Target(
        {
            "kind": "cuda",
            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
            "max_threads_per_block": dev.max_threads_per_block,
            "thread_warp_size": dev.warp_size,
            "registers_per_block": 65536,
            "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""),
        }
    )
# target = tvm.target.Target("cuda", host="llvm")
target = _detect_local_cuda()

print(target)
# 定义计算任务
dev = tvm.cuda(0)

num_flop = 1228406784
W_w_np = np.random.uniform(size=(w_w_x, w_y)).astype("uint32")
W_s_np = np.random.uniform(size=(w_s_x, w_y)).astype("float16")
Input_np = np.random.uniform(size=(1, 1, x_shape)).astype("float16")
# W_w_np = np.ones((w_w_x, w_y), np.uint32) * 1#.astype("uint32")
# W_s_np = np.ones((w_s_x, w_y), np.float16) * 1#.astype("float16") * 2
# Input_np = np.ones((1, 1, x_shape), np.float16)#.astype("float16")
Output_nd = tvm.nd.array(np.zeros((1, 1, w_y), dtype="float16"), dev)
def numpy_caculate():
    test_cols = 10
    output = np.zeros((1, 1, test_cols), dtype = np.float16)
    W_w_inv_np = np.transpose(W_w_np)
    W_s_inv_np = np.transpose(W_s_np)
    for i in range(test_cols):
        for r in range(x_shape):
            temp = Input_np[0][0][r] * np.float16((W_w_inv_np[i][r // 8] >> ((r % 8) * 4) & (15)) - np.float16(7.0)) * W_s_inv_np[i][r // 32]
            output[0][0][i] = output[0][0][i] + temp
    print(output)
    output = np.zeros((1, 1, test_cols), dtype = np.float16)
    for i in range(test_cols):
        for r in range(x_shape):
            temp = Input_np[0][0][r] * np.float16((W_w_np[r // 8][i] >> ((r % 8) * 4) & (15)) - np.float16(7.0)) * W_s_np[r // 32][i]
            temp_output = output[0][0][i]
            output[0][0][i] = temp_output + temp
            # print(f"{temp_output} + {temp} = {output[0][0][i]}")
    print(output)
numpy_caculate()
def print_npdata(np_data: np.ndarray) :
    print(np_data)
    print_num = 20
    d = np_data.flatten()
    p_size = print_num if d.size > print_num else d.size
    print(d[:p_size])

cuda -keys=cuda,gpu -arch=sm_61 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
[[[-7156. -6716. -6644. -6736. -6696. -6532. -6776. -5808. -6616. -6256.]]]
[[[-7156. -6716. -6644. -6736. -6696. -6532. -6776. -5808. -6616. -6256.]]]


In [25]:
# cuda未优化版本测试
sch = tvm.tir.Schedule(ModuleSrc)
with target:
    src_gpu_mod = tvm.tir.transform.DefaultGPUSchedule()(sch.mod) ##
rt_mod = tvm.build(src_gpu_mod, target="cuda")
W_w_nd = tvm.nd.array(W_w_np, dev)
W_s_nd = tvm.nd.array(W_s_np, dev)
Input_nd = tvm.nd.array(Input_np, dev)
Output_nd = tvm.nd.array(np.zeros((1, 1, w_y), dtype="float16"), dev)
evaluator = rt_mod.time_evaluator("main", dev, number=100)
print("manual_evaluator GEMV-Blocking: %f GFLOPS" % (num_flop / evaluator(W_w_nd, W_s_nd, Input_nd, Output_nd).mean / 1e9))
# print(Output_nd.numpy())
print_npdata(Output_nd.numpy())

manual_evaluator GEMV-Blocking: 65.352494 GFLOPS
[-5976. -6344. -7012. -5560. -6352. -5812. -7388. -6224. -6852. -6048.]


In [5]:
os.environ["TVM_NDK_CC"]="/home/sensetime/Android/Sdk/ndk/25.2.9519653/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android33-clang++"
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")
device_key="android"
rpc_host = "10.4.236.32"
rpc_port = 9190
comp_target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")  # TODO: Only support arm64 for now

def test_opencl(mod: tvm.IRModule, name_hint: str):
    # mod = tvm.lower(sch_manual.mod)
    print("Build ...")
    android_rt_mod = tvm.build(mod, target="opencl", target_host="llvm -mtriple=aarch64-linux-android")
    # print(android_rt_mod.imported_modules[0].get_source())
    temp = utils.tempdir()
    path_dso_cl = temp.relpath("dev_lib_cl.so")
    android_rt_mod.export_library(path_dso_cl, ndk.create_shared)

    print("Run GPU(OpenCL Flavor) test ...")
    # Establish remote connection with target hardware

    tracker = rpc.connect_tracker(rpc_host, rpc_port)
    remote = tracker.request(device_key, priority=0, session_timeout=60)
    print("Connect to device done.")
    dev = remote.cl(0)
    remote.upload(path_dso_cl)
    f1 = remote.load_module("dev_lib_cl.so")

    W_w_nd = tvm.nd.array(W_w_np, dev)
    W_s_nd = tvm.nd.array(W_s_np, dev)
    Input_nd = tvm.nd.array(Input_np, dev)
    Output_nd = tvm.nd.array(np.zeros((1, 1, w_y), dtype="float16"), dev)
    test_number=32
    time_f = f1.time_evaluator(f1.entry_name, dev, number=test_number)
    cost = time_f(W_w_nd, W_s_nd, Input_nd, Output_nd).mean
    print("evaluator[%s] GEMV-Blocking: %fms with loop %d" % (name_hint, cost * 1000, test_number))
    print("evaluator[%s] GEMV-Blocking: %fGFLOPS" % (name_hint, num_flop / cost / 1e9))
    print(Output_nd.numpy())
    return cost * 1000

In [28]:
# 未优化版本opencl测试
from tvm import dlight as dl
sch = tvm.tir.Schedule(ModuleSrc)
with target:
    # src_gpu_mod = tvm.tir.transform.DefaultGPUSchedule()(sch.mod) ##
    mod_deploy = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(sch.mod)
src_output = test_opencl(mod_deploy, "source")
print_npdata(src_output)


Build ...
Run GPU(OpenCL Flavor) test ...




Connect to device done.
evaluator[source] GEMV-Blocking: 4.269120 ms with loop 32
evaluator[source] GEMV-Blocking: 287.742388 GFLOPS
[-6600. -6816. -7668. -6056. -6908. -6464. -7976. -6960. -7600. -6672.]


In [6]:
#优化版本opencl测试
# print(sch_manual.mod)
opt_output = test_opencl(sch_manual.mod, "opted")
print_npdata(opt_output)
# np.testing.assert_equal(opt_output, src_output)

Build ...




Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[opted] GEMV-Blocking: 1.050112ms with loop 32
evaluator[opted] GEMV-Blocking: 1169.786446GFLOPS
[[[-7156. -6716. -6644. ... -6588. -6236. -6244.]]]
1.050112
[1.050112]


In [8]:
# 自动搜索
def auto_tune(record_file: str):
    from typing import Union
    def search(vf: int, vi: int, tx: int, bx: Union[int, None] = None):
        """search by workgroup

        Args:
            blockIdxX (_type_): blockIdx.x
            threadIdxX (_type_): threadIdx.x
            vectorize_output (_type_): 输出的vectorize参数, 决定单线程输出多少个结果
            vectorize_input (list, optional): 输入X拷贝到shared_memory时的vectorize参数, 一般为4或8

        Returns:
            _type_: _description_
        """
        @I.ir_module
        class ModuleToManual:
            @T.prim_func(private=False)
            def main(lv571: T.Buffer((T.int64(512), T.int64(w_y)), "uint32"), lv572: T.Buffer((T.int64(128), T.int64(w_y)), "float16"), lv1654: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(w_y)), "float16")):
                T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
                # with T.block("root"):
                p_output0_intermediate = T.alloc_buffer((T.int64(4096), T.int64(w_y)), "float16")
                for i, j in T.grid(T.int64(4096), T.int64(w_y)):
                    with T.block("decode"):
                        v_i, v_j = T.axis.remap("SS", [i, j])
                        T.reads(lv571[v_i // T.int64(8), v_j], lv572[v_i // T.int64(32), v_j])
                        T.writes(p_output0_intermediate[v_i, v_j])
                        p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i // T.int64(32), v_j]
                for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(w_y), T.int64(4096)):
                    with T.block("matmul"):
                        v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                        T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_k, v_i2])
                        T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                        with T.init():
                            var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                        var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_k, v_i2]

        sch = tvm.tir.Schedule(ModuleToManual)
        b0 = sch.get_block(name="decode", func_name="main")
        b1 = sch.get_block(name="matmul", func_name="main")
        l2, l3, l4, l5 = sch.get_loops(block=b1)
        l6 = sch.fuse(l2, l3, l4, preserve_unit_iters=True)
        l10, l11, l12 = sch.split(loop=l6, factors=[bx, tx, vf], preserve_unit_iters=True)
        v13, v14, v15 = sch.sample_perfect_tile(
            loop=l5, n=3, max_innermost_factor=8, decision=[128, 4, 8]
        )
        l16, l17, l18 = sch.split(
            loop=l5, factors=[v13, v14, v15], preserve_unit_iters=True
        )
        sch.reorder(l10, l11, l16, l17, l18, l12)
        sch.bind(loop=l10, thread_axis="blockIdx.x")
        sch.bind(loop=l11, thread_axis="threadIdx.x")
        sch.compute_inline(block=b0)
        b19 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")
        sch.reverse_compute_at(block=b19, loop=l11, preserve_unit_loops=True, index=-1)
        b20 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="local")
        b21 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="local")
        b22 = sch.cache_read(block=b1, read_buffer_index=0, storage_scope="shared")
        sch.compute_at(block=b22, loop=l11, preserve_unit_loops=True, index=-1)
        v23 = sch.sample_categorical(
            candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=[1, 2, 4, 8].index(vi)
        )
        sch.annotate(
            block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch", ann_val=v23
        )
        sch.compute_at(block=b20, loop=l17, preserve_unit_loops=True, index=-1)
        sch.compute_at(block=b21, loop=l16, preserve_unit_loops=True, index=-1)
        l24, l25, l26, l27, l28, l29 = sch.get_loops(block=b20)
        sch.vectorize(loop=l29)
        l30, l31, l32, l33, l34 = sch.get_loops(block=b21)
        sch.vectorize(loop=l34)
        l35, l36, l37, l38, l39 = sch.get_loops(block=b19)
        sch.vectorize(loop=l39)
        sch.vectorize(loop=l12)
        b40 = sch.decompose_reduction(block=b1, loop=l16)
        sch.enter_postproc()
        sch.unannotate(block_or_loop=b22, ann_key="meta_schedule.cooperative_fetch")
        l43, l44, l45, l46, l47 = sch.get_loops(block=b22)
        l48, l49, l50 = sch.split(loop=l47, factors=[None, tx, vi], preserve_unit_iters=True)
        sch.vectorize(loop=l50)
        sch.bind(loop=l49, thread_axis="threadIdx.x")
        return sch.mod

    vec_factor = [2, 4, 8]
    vec_input = [2, 4, 8]
    blockx = [None]
    threadx = [8, 16, 32, 43, 64, 86, 128, 172, 256, 344, 512, 688]
    task_index = 0
    total_task_num = len(vec_factor)*len(vec_input)*len(blockx)*len(threadx)
    records = {}
    print(f"Total tasks: {total_task_num}")
    write_interval = 5
    from prettytable import PrettyTable
    table = PrettyTable()
    table.field_names = ["vf", "vi", "tx", "bx", "cost(ms)"]
    for vf in vec_factor:
        for vi in vec_input:
            for tx in threadx:
                    task_index = task_index + 1
                    import math
                    bx = w_y /(vf * tx)
                    if tx * vf > w_y or tx > 1024:
                        print(f"search record [{task_index}/{total_task_num}]: skip {vf} {vi} {tx} {bx}")
                        continue
                    print(f"search record [{task_index}/{total_task_num}]: start run {vf} {vi} {tx} {bx}")
                    bx_real = int(bx)
                    if w_y % (vf * tx) != 0:
                        bx_real = None
                    mod_deploy = search(vf, vi, tx, bx_real)
                    cost = test_opencl(mod_deploy, "search")
                    print("=====")
                    records[(vf, vi, tx, bx)] = cost

                    table.add_row([vf, vi, tx, bx, cost])
                    if task_index % write_interval == 0:
                        with open(record_file, 'wt') as f:
                            f.write(table.get_csv_string())
    print("================================")
    print(table)
    
    # record_sorted = sorted(record.items(), key=lambda x: x[1][0], reverse=True)
auto_tune("./manual_tune/fused_gate_up_tune_record_1.csv")

Total tasks: 108
search record [1/108]: start run 2 2 8 1376.0
Build ...




Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[search] GEMV-Blocking: 11.036448ms with loop 32
evaluator[search] GEMV-Blocking: 111.304541GFLOPS
[[[-7156. -6716. -6644. ... -6588. -6236. -6244.]]]
=====
search record [2/108]: start run 2 2 16 688.0
Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[search] GEMV-Blocking: 5.313416ms with loop 32
evaluator[search] GEMV-Blocking: 231.189650GFLOPS
[[[-7156. -6716. -6644. ... -6588. -6236. -6244.]]]
=====
search record [3/108]: start run 2 2 32 344.0
Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[search] GEMV-Blocking: 2.823272ms with loop 32
evaluator[search] GEMV-Blocking: 435.100403GFLOPS
[[[-7156. -6716. -6644. ... -6588. -6236. -6244.]]]
=====
search record [4/108]: start run 2 2 43 256.0
Build ...
Run GPU(OpenCL Flavor) test ...
Connect to device done.
evaluator[search] GEMV-Blocking: 2.310936ms with loop 32
evaluator[search] GEMV-Blocking: 531.562442GFLOPS
[[[-7

In [None]:
import numpy as np
target = tvm.target.Target("opencl -device=adreno", host="llvm -mtriple=aarch64-linux-gnu")
device_key="android"
rpc_host = "10.158.176.30"
rpc_port = 5001
# remote = autotvm.measure.request_remote(device_key, "10.158.176.30", 5001, timeout=10000)
# dev = remote.device(str(target), 0)

# num_flop = 1228406784
# W_np = np.random.uniform(size=(512, vocab_size)).astype("uint32")
# S_np = np.random.uniform(size=(128, vocab_size)).astype("float16")
# Input_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# # Output_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# W_nd = tvm.nd.array(W_np, dev)
# S_nd = tvm.nd.array(S_np, dev)
# Input_nd = tvm.nd.array(Input_np, dev)
# Output_nd = tvm.nd.array(np.zeros((1, 1, vocab_size), dtype="float32"), dev)

In [None]:
rpc_config = ms.runner.RPCConfig(tracker_host=rpc_host, tracker_port=rpc_port, tracker_key = device_key)
runner= ms.runner.RPCRunner(rpc_config)
# ms.builder.LocalBuilder()
sch = tvm.tir.Schedule(ModuleSrc)
database = ms.tune_tir(
    mod=ModuleSrc,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_first",
    cost_model="xgb",
    runner = runner
)
print(len(database))
sch1 = ms.tir_integration.compile_tir(database, sch.mod, target)
print(type(sch1))

In [None]:
from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor((3, 4), dtype="float32"), B: R.Tensor((4, 5), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((3, 5), dtype="float32") = R.matmul(A, B)
            gv: R.Tensor((3, 5), dtype="float32") = lv
            R.output(gv)
        return gv

In [None]:
## auto_scheduler test
from tvm import auto_scheduler
import numpy as np
a_np = np.random.rand(3, 4).astype("float32")
b_np = np.random.rand(4, 5).astype("float32")
a_nd = tvm.runtime.NDArray(a_np)
b_nd = tvm.runtime.NDArray(b_np)
sch = tvm.tir.Schedule(Module)

params = {"A": a_np, "B": b_np}
## 报错，这里只支持relay
# tasks = auto_scheduler.extract_tasks(sch.mod, params, target=target)
tasks = ms.relax_integration.extract_tasks(sch.mod, target=target, params=params)
print(len(tasks))

In [None]:

from mod_deploy import Module as ModuleAll
params_all = {}
tasks_all = auto_scheduler.extract_tasks(ModuleAll, params_all, target=target)
print(len(tasks_all))

In [None]:
import numpy as np
log_file = "tune.json"
def _detect_local_cuda():
    dev = tvm.cuda()
    if not dev.exist:
        return None
    return tvm.target.Target(
        {
            "kind": "cuda",
            "max_shared_memory_per_block": dev.max_shared_memory_per_block,
            "max_threads_per_block": dev.max_threads_per_block,
            "thread_warp_size": dev.warp_size,
            "registers_per_block": 65536,
            "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""),
        }
    )
# target = tvm.target.Target("cuda", host="llvm")
target = _detect_local_cuda()

print(target)
# 定义计算任务
dev = tvm.cuda(0)

num_flop = 1228406784
W_np = np.random.uniform(size=(512, vocab_size)).astype("uint32")
S_np = np.random.uniform(size=(128, vocab_size)).astype("float16")
Input_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
# Output_np = np.random.uniform(size=(1, 1, 4096)).astype("float16")
W_nd = tvm.nd.array(W_np, dev)
S_nd = tvm.nd.array(S_np, dev)
Input_nd = tvm.nd.array(Input_np, dev)
Output_nd = tvm.nd.array(np.zeros((1, 1, vocab_size), dtype="float32"), dev)
sch = tvm.tir.Schedule(ModuleSrc)
new_mod = sch.mod


In [None]:
# task = auto_scheduler.SearchTask(func=sch.mod['fused_fused_decode11_fused_matmul5_cast2'], args=sch.mod['fused_fused_decode11_fused_matmul5_cast2'].params, target=target)

# tune_option = auto_scheduler.TuningOptions(
#     num_measure_trials=10,
#     measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
#     verbose=2,
# )


database = ms.tune_tir(
    mod=new_mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    work_dir="./tune_45593_1",
    cost_model="xgb"
)
print(len(database))
sch1 = ms.tir_integration.compile_tir(database, new_mod, target)
print(type(sch1))

In [None]:
# print(sch1.trace)
# print(sch1.mod.script())
rt_mod = tvm.build(sch1.mod, target="cuda")

evaluator = rt_mod.time_evaluator("main", dev, number=100)

print("evaluator GEMV-Blocking: %f GFLOPS" % (1228406784 / evaluator(W_nd, S_nd, Input_nd, Output_nd).mean / 1e9))




In [None]:

record_database = ms.Database.create(kind='json', work_dir='./tune_45593_1')


In [None]:
record_sch = ms.tir_integration.compile_tir(record_database, new_mod, target)

record_rt_mod = tvm.build(record_sch.mod, target="cuda")

record_evaluator = record_rt_mod.time_evaluator("main", dev, number=20)

print("evaluator GEMV-Blocking: %f GFLOPS" % (num_flop / record_evaluator(W_nd, S_nd, Input_nd, Output_nd).mean / 1e9))
print(record_sch.trace)
print(record_sch.mod.script())

In [None]:
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable
from tvm import runtime
if TYPE_CHECKING:
    import numpy as np  # type: ignore
    from tvm.ir import IRModule
    from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
    from tvm.runtime import Device, Module, NDArray
    from tvm.target import Target
    from tvm.runtime.vm import Executable


def f_measurement(
    rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
):
    vm = relax.VirtualMachine(rt_mod, device=device)
    vm.save_function("main", "measure_func", **input_data, include_return=False)
    evaluator = vm.time_evaluator(
        func_name="measure_func",
        dev=device,
        repeat=100,
        number=1,
        min_repeat_ms=500,
    )
    return evaluator()

def run_module_via_rpc(
    rpc_config: "RPCConfig",
    lib: Union["Module", "Executable"],
    dev_type: str,
    args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
    continuation: Callable,
    backend: Optional[str] = "graph",
):
    """Execute a tvm.runtime.Module on RPC remote"""
    # pylint: disable=import-outside-toplevel
    import os
    import tempfile

    from tvm.contrib.tar import tar
    from tvm.runtime import ndarray

    # pylint: enable=import-outside-toplevel

    with tempfile.TemporaryDirectory() as tmp_dir:
        # filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
        filename = os.path.join(tmp_dir, "tvm_tmp_mod." + "so")
        if backend == "vm":
            code, lib = lib.save(filename, fmt="so")
        from tvm.contrib import ndk
        lib.export_library(filename, ndk.create_shared)
        session = rpc_config.connect_server()
        print(type(session._sess))
        session.upload(filename)
        _, filename = os.path.split(filename)
        rt_mod = session.load_module(filename)
        
        if backend == "vm":
            rt_mod = session.get_function("runtime.Load_Executable")(code, rt_mod)
            # rt_mod = session.get_function("runtime.module.loadfile_relax.Executable")(filename)
        dev = session.device(dev_type=dev_type, dev_id=0)
        # print(dev)
        # create the remote runtime module
        print(rt_mod)
        print(rt_mod['main'])
        from tvm.contrib import graph_executor as runtime
        module = runtime.GraphModule(rt_mod["main"](dev))
        print(module)
        for k, v in args.items():
            module.set_input(k, tvm.nd.array(v))
        return module.run()
        # nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
        nd_args = {k: ndarray.empty(v.shape, v.dtype, dev) for k, v in args.items()}
        return continuation(rt_mod, dev, nd_args)