In [1]:
# Imports

import tvm
from tvm import meta_schedule as ms
from tvm import relay
from tvm.script import tir as T
from tvm.tir import TensorIntrin

import numpy as np

In [None]:
# Tensor-intrinsic description (the `desc` argument of TensorIntrin.register below):
#   C[i] = C[i] + int32(A[k]) * int32(B[i, k])   for i in [0, 16), k in [0, 4)
# A is a uint8 4-vector, B a 16x4 int8 matrix, C a 16-lane int32 accumulator;
# both operands are widened to int32 before the multiply.
# NOTE(review): despite "relu" in the name, no ReLU is applied here (nor in the
# paired implementation) — consider renaming.
@T.prim_func
def dot_product_relu_16x4_u8i8i32_desc(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        # Declared regions: reads the accumulator and both inputs, writes the accumulator.
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            for k in T.serial(0, 4):
                with T.block("update"):
                    # vi is a spatial ("S") axis, vk a reduction ("R") axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

In [None]:
# Tensor-intrinsic implementation paired with dot_product_relu_16x4_u8i8i32_desc,
# written with AVX512 LLVM intrinsics (pmaddubs.w / pmaddw.d, not VNNI vpdpbusd):
#   1. load A's 4 uint8 values, reinterpret them as one int32, broadcast to 16
#      lanes, and reinterpret back as 64 uint8 lanes (A repeated per row of B);
#   2. load all 16x4 int8 values of B as one int8x64 vector;
#   3. pmaddubs.w: uint8*int8 products summed in adjacent pairs -> int16x32;
#   4. pmaddw.d against all-ones int16: adjacent int16 pairs summed -> int32x16,
#      which is added into C[0:16].
# NOTE(review): per the x86 PMADDUBSW definition the int16 pair sums in step 3
# saturate, so results may differ from the exact int32 `desc` for large
# operand values — confirm this is acceptable for the target data.
@T.prim_func
def dot_product_relu_16x4_u8i8i32_avx512(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Step 1: broadcast the 4-byte A vector to all 16 int32 lanes.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
        A_brdcst = T.broadcast(A_i32, 16)
        A_u8x64 = T.reinterpret(A_brdcst, dtype="uint8x64")

        # Step 2: B's 64 int8 values as one vector (row-major 16x4).
        B_i8x64 = B.vload([0, 0], dtype="int8x64")

        # Step 3: u8*i8 multiply with pairwise add into int16 lanes.
        # T.uint32(2) is the number of operands passed to the LLVM intrinsic.
        Red = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
            T.uint32(2),
            A_u8x64,
            B_i8x64,
            dtype="int16x32",
        )

        # Step 4: multiply by 1 and pairwise-add int16 -> int32, then accumulate into C.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
            T.uint32(2),
            Red,
            T.int16x32(1),
            dtype="int32x16",
        )

In [None]:
SKYLAKE_AVX512_TARGET = "llvm -mcpu=skylake-avx512 -num-cores 4"

AVX512_DOT_16x4_INTRIN = "dot_16x4_avx512"

# TensorIntrin.register rejects a name that is already registered, which makes
# this cell fail on every re-run. The same name may also already be registered
# by tvm.tir.tensor_intrin.x86 once that module is imported (presumably with
# the same desc/impl pair — verify). override=True keeps the cell idempotent.
TensorIntrin.register(
    AVX512_DOT_16x4_INTRIN,
    dot_product_relu_16x4_u8i8i32_desc,
    dot_product_relu_16x4_u8i8i32_avx512,
    override=True,
)

def _get_schedule_rules_for_x86(intrin):
    """Return the MetaSchedule schedule rules used to tune x86 kernels with `intrin`.

    Parameters
    ----------
    intrin : str
        Name of a registered TensorIntrin, forwarded as the first argument of
        ``MultiLevelTilingWithIntrin``.

    Returns
    -------
    list
        Schedule-rule objects in the order they are handed to MetaSchedule.
    """
    sr = ms.schedule_rule

    def tiling_options():
        # Settings shared by the intrinsic-aware and the plain multi-level
        # tiling rules. Called once per rule so each gets its own ReuseType
        # (and its own `levels` list), exactly like two separate literals.
        return {
            "structure": "SSRSRS",
            "tile_binds": None,
            "max_innermost_factor": 64,
            "vector_load_lens": None,
            "reuse_read": None,
            "reuse_write": sr.ReuseType(req="may", levels=[1, 2], scope="global"),
        }

    return [
        sr.ApplyCustomRule(),
        sr.AutoInline(
            into_producer=False,
            into_consumer=True,
            inline_const_tensor=True,
            disallow_if_then_else=True,
            require_injective=True,
            require_ordered=True,
            disallow_op=["tir.exp"],
        ),
        sr.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
        # Tiling that targets the given tensor intrinsic.
        sr.MultiLevelTilingWithIntrin(intrin, **tiling_options()),
        # Same tiling configuration without an intrinsic.
        sr.MultiLevelTiling(**tiling_options()),
        sr.ParallelizeVectorizeUnroll(
            max_jobs_per_core=16,
            max_vectorize_extent=64,
            unroll_max_steps=[0, 16, 64, 512],
            unroll_explicit=True,
        ),
        sr.RandomComputeLocation(),
    ]

# MetaSchedule rules for tensorizing with the AVX512 16x4 u8*i8->i32 intrinsic
# registered above.
SCH_RULES_FOR_AVX512 = _get_schedule_rules_for_x86(AVX512_DOT_16x4_INTRIN)

# Postprocessors paired with SCH_RULES_FOR_AVX512 (see the commented-out
# _test_dense call). NOTE(review): the "_VNNI" suffix is historical — nothing in
# this list is VNNI-specific; it is used for the AVX512 rules here.
# NOTE(review): the order presumably matters (tensorize rewriting after the
# reduction-block rewrite) — keep it unless confirmed otherwise.
POSTPROCS_FOR_VNNI = [
    ms.postproc.DisallowDynamicLoop(),
    ms.postproc.RewriteParallelVectorizeUnroll(),
    ms.postproc.RewriteReductionBlock(),
    # vectorize_init_loop=True: presumably also vectorizes the accumulator-init
    # loop left next to the tensorized block — verify against the TVM docs.
    ms.postproc.RewriteTensorize(vectorize_init_loop=True),
]

In [None]:
# Disabled: `_test_dense` is not defined anywhere in this notebook, so
# uncommenting this line as-is would raise NameError.
#_test_dense("uint8", SCH_RULES_FOR_AVX512, POSTPROCS_FOR_VNNI, SKYLAKE_AVX512_TARGET)

# Disabled draft of Relay dense task extraction, kept as an inert string literal.
"""dim_m, dim_n, dim_k = 1024, 1024, 1024
data_shape = (dim_m, dim_k)
weight_shape = (dim_n, dim_k)

weight_dtype = "int8"
out_dtype = "int32"
data_dtype = "uint8"

data = relay.var("data", shape=data_shape, dtype=data_dtype)
weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype)
dense = relay.nn.dense(data, weight, out_dtype=out_dtype)

relay_mod = tvm.IRModule.from_expr(dense)

weight_np = np.random.uniform(1, 10, size=weight_shape).astype(weight_dtype)

params = {"weight": weight_np}
tune_tasks = list(
    filter(
        lambda task: "dense" in task.task_name,
        ms.relay_integration.extract_tasks(relay_mod, SKYLAKE_AVX512_TARGET, params),
    )
)"""

from tvm import te

N, M, L = 1024, 512, 64
# A = te.placeholder((N, L), name="A")
# B = te.placeholder((M, L), name="B")
# k = te.reduce_axis((0, L), name="k")
# C = te.compute((N, M), lambda i, j: te.sum(te.max(0, A[i, k] * B[j, k]), axis=k), name="C")

# Matrix-vector product C[i] = sum_j int32(A[j]) * int32(B[i, j]).
# The body must match dot_product_relu_16x4_u8i8i32_desc for tensorization:
#   - A is the uint8 vector and B the int8 matrix (the original had these swapped);
#   - both operands are cast to int32 *before* multiplying, and the sum is int32.
# Without the casts the uint8*int8 product and the sum stay 8-bit, so they
# overflow, and the block can never match the intrinsic's pattern.
# Presumably M % 16 == 0 and L % 4 == 0 are needed to tile exactly onto the
# 16x4 intrinsic (true for 512 and 64) — verify if the sizes change.
A = te.placeholder((L,), name='A', dtype='uint8')
B = te.placeholder((M, L), name='B', dtype='int8')
j = te.reduce_axis((0, L), 'j')

C = te.compute(
    (M,),
    lambda i: te.sum(A[j].astype("int32") * B[i, j].astype("int32"), axis=j),
    name='C',
)


func = te.create_prim_func([A, B, C])
print(func)
sch = tvm.tir.Schedule(func)

In [None]:
# Apply only the intrinsic-aware multi-level tiling rule and inspect the
# candidate schedules it proposes.
# Look the rule up by type instead of by position ([3]) so that editing the
# rule list cannot silently select a different rule.
mlt_intrin_rule = next(
    rule
    for rule in SCH_RULES_FOR_AVX512
    if isinstance(rule, ms.schedule_rule.MultiLevelTilingWithIntrin)
)
# Schedule rules are normally initialized by MetaSchedule with a TuneContext
# (target, logger, ...) before `apply`; do the same when calling one by hand.
mlt_intrin_rule._initialize_with_tune_context(
    ms.TuneContext(mod=sch.mod, target=tvm.target.Target(SKYLAKE_AVX512_TARGET))
)
# The computation lives in block "C"; the root block has no loop nest to tile.
tiled_schedules = mlt_intrin_rule.apply(sch, sch.get_block("C"))
for tiled_sch in tiled_schedules:
    print(tiled_sch.trace)
    print(tiled_sch.mod)

In [2]:
from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_vnni

# Build the VNNI 16x4 intrinsic on its own to inspect the generated code.
#
# NOTE: the traceback recorded for this cell is raised while importing a
# locally edited tvm/tir/tensor_intrin/x86.py: its line 57 loads C as "int8x4"
# from an int32 buffer. Upstream TVM has `C.vload([0], dtype="int32x16")` there.
# Restore that line before re-running.
#
# This intrinsic calls llvm.x86.avx512.vpdpbusd.512, an AVX512-VNNI
# instruction. -mcpu=skylake-avx512 lacks VNNI, so target cascadelake, the
# first Intel server core that has it.
target = "llvm -mcpu=cascadelake"
# tvm.build ignores `args` when given a PrimFunc (the function's own buffer
# parameters are used), so no decl_buffer placeholders are needed.
f = tvm.build(dot_product_16x4_u8i8i32_vnni, name="my_thing", target=target)

print(f.get_source())

error: InternalError: Check failed: (value_dtype.element_of() == n->dtype.element_of() && value_dtype.lanes() % n->dtype.lanes() == 0) is false: Cannot load int8x4 from buffer of int32
 --> /home/julien/tvm/python/tvm/tir/tensor_intrin/x86.py:57:20
    |  
 57 |          C_i32x16 = C.vload([0], dtype="int8x4")
    |                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^


DiagnosticError: Traceback (most recent call last):
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::DiagnosticContext)>::AssignTypedLambda<tvm::{lambda(tvm::DiagnosticContext)#9}>(tvm::{lambda(tvm::DiagnosticContext)#9}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::DiagnosticContext::Render()
  File "/home/julien/tvm/src/ir/diagnostic.cc", line 131
DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.

In [None]:
import tvm
from tvm import te, tir

# Minimal end-to-end check of the build/run flow: element-wise add of two
# float32 vectors of length 10.
A = te.placeholder((10,), name='A')
B = te.placeholder((10,), name='B')
C = te.compute((10,), lambda i: A[i] + B[i], name='C')
# tir.PrimFunc(params, body) expects TIR vars/buffers and a Stmt body, not TE
# tensors. te.create_prim_func lowers the TE graph into a PrimFunc whose
# parameters are (A, B, C) in that order.
myfun = te.create_prim_func([A, B, C])

# Create a target (specify the target architecture)
target = "llvm"

# Build the function
f = tvm.build(myfun, target=target)

# Compiled functions take tvm.nd.NDArray arguments; the result is written into
# the preallocated output array `c`. Seeded RNG so the cell is reproducible.
rng = np.random.default_rng(0)
a_np = rng.random(10).astype('float32')
b_np = rng.random(10).astype('float32')
dev = tvm.cpu(0)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros(10, dtype='float32'), dev)
f(a, b, c)

c_np = c.numpy()
np.testing.assert_allclose(c_np, a_np + b_np, rtol=1e-6)
c_np