![TVMScript overview](TVMScript.svg)

In [1]:
import tvm
from tvm.script import ir as I, relax as R, tir as T


@I.ir_module
class Matmul:
    # TVMScript module definition: `@I.ir_module` turns this class into a
    # `tvm.ir.IRModule` (not an ordinary Python class). It holds a single
    # Relax function, `main`.
    @R.function
    def main(
        a: R.Tensor((128, 128), "float32"), b: R.Tensor((128, 128), "float32")
    ) -> R.Tensor((128, 128), "float32"):
        # Relax entry point: matrix-multiply two 128x128 float32 tensors with
        # the high-level `R.matmul` operator (legalized to TIR in a later cell).
        out: R.Tensor((128, 128), "float32") = R.matmul(a, b)
        return out

In [2]:
# Show the Python types involved: the decorated class is an IRModule, and its
# "main" entry is a Relax function.
for ir_object in (Matmul, Matmul["main"]):
    print(type(ir_object))

<class 'tvm.ir.module.IRModule'>
<class 'tvm.relax.expr.Function'>


In [3]:
# Pretty-print the module's TVMScript source (the Relax `main` function).
Matmul.show()

In [4]:
from tvm.relax.transform import LegalizeOps

# Lower high-level Relax operators (here `R.matmul`) into TIR PrimFuncs, so the
# matmul can be scheduled in the following cells. `mod` is the module those
# cells build their schedule on.
legalize_pass = LegalizeOps()
mod = legalize_pass(Matmul)
mod.show()

In [5]:
from tvm.tir.schedule import Schedule

# Wrap the legalized module in a TIR schedule. Import the `Schedule` class
# itself so that `sch` unambiguously names the schedule instance. Aliasing the
# `tvm.tir.schedule` module as `sch` would be shadowed on the very next line.
sch = Schedule(mod)
# Look up the "matmul" block inside the "matmul" PrimFunc produced by LegalizeOps.
block = sch.get_block("matmul", func_name="matmul")
# Print the module, underlining the block's reduction-initialization statement.
sch.mod.show(black_format=False, obj_to_underline=[sch.get_sref(block).stmt.init])

In [6]:
from tvm.tir import TensorIntrin
import tvm.tir.tensor_intrin  # noqa: F401 -- not referenced directly; imported so the built-in intrinsics are registered for lookup by name

# Fetch the registered WMMA "fill" intrinsic (16x16x16, float32) and print
# both of its PrimFuncs: `desc` and `impl`.
intrinsic = TensorIntrin.get("wmma_fill_16x16x16_f32")
for label, prim_func in (
    ("TensorIntrin.desc:", intrinsic.desc),
    ("TensorIntrin.impl:", intrinsic.impl),
):
    print(label)
    prim_func.show()


TensorIntrin.desc:


TensorIntrin.impl:


In [7]:
# Tile both spatial loops by 16 so that each inner tile matches the 16x16x16
# shape in the WMMA intrinsic names; `None` lets the outer extent be inferred.
tile = 16
i, j, k = sch.get_loops(block)
(i0, i1), (j0, j1) = (sch.split(loop, [None, tile]) for loop in (i, j))
# Loop order becomes: outer tiles (i0, j0), then one tile's (i1, j1, k).
sch.reorder(i0, j0, i1, j1, k)
inner_extents = [sch.get_sref(loop).stmt.extent for loop in (i1, j1)]
sch.mod.show(black_format=False, obj_to_underline=inner_extents)

In [8]:
# Stage the block's output (write buffer index 0) through a cache buffer in
# the "wmma.accumulator" storage scope; `frag` is the returned cache-stage block.
frag = sch.cache_write(block, 0, "wmma.accumulator")
# First buffer allocated in the "matmul" function's root block -- presumably
# the accumulator cache just created.
accumulator_buffer = sch.mod["matmul"].body.block.alloc_buffers[0]
sch.mod.show(black_format=False, obj_to_underline=[accumulator_buffer])

In [9]:
# Split the reduction's initialization out of `block` into its own init block,
# inserted before loop `i1` (i.e. inside the outer i0/j0 tile loops).
block_init = sch.decompose_reduction(block, i1)
init_block_stmt = sch.get_sref(block_init).stmt
sch.mod.show(black_format=False, obj_to_underline=[init_block_stmt])


In [10]:
# Replace the init block's second-innermost loop nest with the WMMA fill
# intrinsic. After the decomposition at `i1`, this loop should be the init
# copy of `i1`, so the tensorized region is one 16x16 tile.
init_loops = sch.get_loops(block_init)
sch.tensorize(init_loops[-2], "wmma_fill_16x16x16_f32")
# Tensorizing produces an outer block named "matmul_init_o". Underline its
# body, which should now hold the intrinsic's implementation.
init_outer_block = sch.get_block("matmul_init_o", "matmul")
sch.mod.show(
    black_format=False,
    obj_to_underline=[sch.get_sref(init_outer_block).stmt.body],
)