# Introduction to Unity: TVMScript

## TVMScript parser

### Pipeline

    Python object/source code -> Python AST -> TVMScript IRBuilder -> TVM IR
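
The first arrow of this pipeline relies on Python itself: source text can be turned into a Python AST with the standard `ast` module. The sketch below shows only that first stage on a small TIR-style snippet. The function `add_one` is a hypothetical example. The IRBuilder and TVM IR stages are TVM-specific and are not reproduced here. Because the text is only parsed and never executed, `T` does not need to be defined.

```python
import ast  # ast.unparse below needs Python 3.9+

# A small TIR-style function kept as plain text (hypothetical example).
source = '''
@T.prim_func
def add_one(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
    for i in range(4):
        B[i] = A[i] + 1.0
'''

# Stage 1 of the pipeline: source code -> Python AST. Nothing is executed.
tree = ast.parse(source)
func = tree.body[0]

print(type(func).__name__)                             # FunctionDef
print([ast.unparse(d) for d in func.decorator_list])  # ['T.prim_func']
print([arg.arg for arg in func.args.args])             # ['A', 'B']
print(type(func.body[0]).__name__)                     # For
```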

### Parse by decorator

TVMScript provides a Python decorator for each supported IR:
 - `@I.ir_module` - `I` for the base IR (`IRModule`)
 - `@T.prim_func` - `T` for TensorIR (TIR)
 - `@R.function` - `R` for Relax IR

TVMScript also supports cross-IR parsing and printing, so a single Python module can contain functions written in different IRs and be parsed or printed as a whole.
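
To make the cross-IR idea concrete, the toy below reads a mixed module as a Python AST and reports which IR each method's decorator denotes. The mapping `IR_OF_DECORATOR` and the function `ir_kinds` are hypothetical illustrations, not TVMScript's internal dispatch logic.

```python
import ast  # ast.unparse needs Python 3.9+

# A module mixing two IRs, kept as text so it can be parsed without TVM installed.
source = '''
@I.ir_module
class Mixed:
    @R.function
    def f(x): ...

    @T.prim_func
    def g(a): ...
'''

# Hypothetical mapping from decorator spelling to the IR it denotes.
IR_OF_DECORATOR = {"I.ir_module": "IRModule", "R.function": "Relax", "T.prim_func": "TIR"}


def ir_kinds(src: str) -> dict:
    """Return {method_name: IR} for every decorated method of every top-level class."""
    result = {}
    for node in ast.parse(src).body:
        if isinstance(node, ast.ClassDef):
            for item in node.body:
                if isinstance(item, ast.FunctionDef) and item.decorator_list:
                    name = ast.unparse(item.decorator_list[0])
                    result[item.name] = IR_OF_DECORATOR.get(name, "unknown")
    return result


print(ir_kinds(source))  # {'f': 'Relax', 'g': 'TIR'}
```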

Let us start with a simple matrix multiplication, written once in TIR and once in Relax.

In [1]:
```python
import tvm
from tvm.script import ir as I, tir as T, relax as R


@I.ir_module
class matmul_module:
    # Relax: one high-level matmul operator on tensors.
    @R.function
    def matmul_relax(
        A: R.Tensor((128, 128), dtype="float32"),
        B: R.Tensor((128, 128), dtype="float32"),
    ) -> R.Tensor((128, 128), dtype="float32"):
        C: R.Tensor((128, 128), dtype="float32") = R.matmul(A, B)
        return C

    # TIR: the same computation as an explicit loop nest over buffers.
    @T.prim_func
    def matmul_tir(
        A: T.Buffer((128, 128), dtype="float32"),
        B: T.Buffer((128, 128), dtype="float32"),
        C: T.Buffer((128, 128), dtype="float32"),
    ):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                # "SSR": vi and vj are spatial axes, vk is a reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Initializes C[vi, vj] before the first reduction step.
                with T.init():
                    C[vi, vj] = 0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
```


The code above defines a multiplication of two `128 x 128` matrices of data type `float32`, once as a Relax function and once as a TIR function. Let us print out what the decorators return.
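
Before doing so, it may help to spell out what the TIR block computes. The plain-Python function below mirrors the loop nest of `matmul_tir` for an `n x n` case. `matmul_reference` is a hypothetical helper. In it, `T.grid` becomes three nested loops, `"SSR"` marks `vi` and `vj` as spatial and `vk` as a reduction axis, and the `T.init()` body runs once per output element at the first reduction step. Python floats are 64-bit, so float32 rounding is not modeled.

```python
def matmul_reference(A, B, n):
    """Plain-Python mirror of the `matmul_tir` loop nest for n x n matrices (hypothetical helper)."""
    C = [[None] * n for _ in range(n)]
    # T.grid(n, n, n): three perfectly nested loops over i, j, k.
    for i in range(n):
        for j in range(n):
            for k in range(n):
                # T.axis.remap("SSR", [i, j, k]): vi, vj spatial; vk reduction.
                vi, vj, vk = i, j, k
                # T.init(): executed once per (vi, vj), at the first reduction step.
                if vk == 0:
                    C[vi][vj] = 0.0
                # The block body: accumulate along the reduction axis.
                C[vi][vj] = C[vi][vj] + A[vi][vk] * B[vk][vj]
    return C


C = matmul_reference([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], n=2)
print(C)  # [[19.0, 22.0], [43.0, 50.0]]
```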

Note: functions in an `IRModule` are accessed by name, as `IRModule["func"]`, rather than as attributes (`IRModule.func`).

In [2]:
```python
print(type(matmul_module))
print(type(matmul_module["matmul_relax"]))
print(type(matmul_module["matmul_tir"]))
```

```
<class 'tvm.ir.module.IRModule'>
<class 'tvm.relax.expr.Function'>
<class 'tvm.tir.function.PrimFunc'>
```


Next, let us print the `IRModule` directly. This produces the module's TVMScript source code as a `str`: the TVMScript printer is called implicitly.

In [3]:
```python
print(matmul_module)
```

```
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def matmul_tir(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
        # with T.block("root"):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @R.function
    def matmul_relax(A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
        C: R.Tensor((128, 128), dtype="float32") = R.matmul(A, B, out_dtype="void")
        return C
```


This is a typical example of calling the TVMScript parser via decorators. However, recall our parser pipeline:

    Python object/source code -> Python AST -> TVMScript IRBuilder -> TVM IR

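Notice that it contains a component called the TVMScript IRBuilder. We can also call this component directly, which lets us build IR more flexibly than with decorators. This makes it a better fit for IR factory functions, i.e. Python functions that generate IR from parameters such as shapes. Let us build a matrix multiplication generator for inputs of any rank with the IRBuilder.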
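
Conceptually, the IRBuilder is driven by nested `with` blocks: each one opens a frame, calls made inside attach to the innermost open frame, and a finished frame is added to its parent. The toy class below models only that nesting, with the same shape as the `IRBuilder()` / `ir_module()` / `function()` calls in the next cell. `ToyBuilder` is entirely hypothetical and far simpler than TVM's own frame classes.

```python
from contextlib import contextmanager


class ToyBuilder:
    """Hypothetical miniature of a frame-based IR builder (not TVM's classes)."""

    def __init__(self):
        # The bottom of the stack is a root frame that collects the final result.
        self.stack = [{"kind": "root", "items": []}]

    @contextmanager
    def frame(self, kind):
        node = {"kind": kind, "items": []}
        self.stack.append(node)
        try:
            yield node
        finally:
            self.stack.pop()
        # Reached only if the body did not raise: attach the finished frame to its parent.
        self.stack[-1]["items"].append(node)

    def emit(self, item):
        # Calls made inside a `with` attach to the innermost open frame.
        self.stack[-1]["items"].append(item)

    def get(self):
        # The first finished top-level frame is the result.
        return self.stack[0]["items"][0]


toy = ToyBuilder()
with toy.frame("ir_module"):
    with toy.frame("function"):
        toy.emit(("func_name", "matmul_relax"))
        toy.emit(("arg", "A"))
        toy.emit(("arg", "B"))
        toy.emit(("ret", "matmul(A, B)"))

module = toy.get()
print(module["kind"], [f["kind"] for f in module["items"]])  # ir_module ['function']
```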

In [4]:
```python
from tvm.script.ir_builder import IRBuilder, relax as relax_builder, ir as ir_builder


def matmul_gen(a_shape, b_shape):
    # Each `with` opens an IRBuilder frame; the calls below act on the innermost one.
    with IRBuilder() as builder:
        with ir_builder.ir_module():
            with relax_builder.function():
                R.func_name("matmul_relax")
                A = R.arg("A", R.Tensor(a_shape, dtype="float32"))
                B = R.arg("B", R.Tensor(b_shape, dtype="float32"))
                # The output shape is inferred by R.matmul, not written by hand.
                R.func_ret_value(R.matmul(A, B))

    # Retrieve the finished IRModule from the builder.
    return builder.get()


print(matmul_gen(a_shape=(128, 128, 64), b_shape=(128, 64, 128)))
```

```
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def matmul_relax(A: R.Tensor((128, 128, 64), dtype="float32"), B: R.Tensor((128, 64, 128), dtype="float32")) -> R.Tensor((128, 128, 128), dtype="float32"):
        gv: R.Tensor((128, 128, 128), dtype="float32") = R.matmul(A, B, out_dtype="void")
        return gv
```
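
The printed signature shows that `R.matmul` inferred the output shape `(128, 128, 128)` from the two input shapes. That rule can be sketched in plain Python. The sketch assumes `R.matmul` follows NumPy-style `matmul` semantics: the last two axes are multiplied as matrices, and the leading batch axes are broadcast. `matmul_out_shape` is a hypothetical helper, not TVM's shape-inference code.

```python
def matmul_out_shape(a_shape, b_shape):
    """NumPy-style matmul output shape; assumed (not verified) to match R.matmul."""
    a, b = list(a_shape), list(b_shape)
    a_is_vec, b_is_vec = len(a) == 1, len(b) == 1
    # 1-D operands are promoted to matrices; the added axis is dropped from the result.
    if a_is_vec:
        a = [1] + a
    if b_is_vec:
        b = b + [1]
    if a[-1] != b[-2]:
        raise ValueError(f"reduction dims differ: {a[-1]} vs {b[-2]}")
    # Broadcast the leading (batch) dimensions, aligned from the right.
    ba, bb = a[:-2], b[:-2]
    width = max(len(ba), len(bb))
    ba = [1] * (width - len(ba)) + ba
    bb = [1] * (width - len(bb)) + bb
    batch = []
    for x, y in zip(ba, bb):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch dims {x} and {y} do not broadcast")
        batch.append(y if x == 1 else x)
    rows = [] if a_is_vec else [a[-2]]
    cols = [] if b_is_vec else [b[-1]]
    return tuple(batch + rows + cols)


print(matmul_out_shape((128, 128, 64), (128, 64, 128)))  # (128, 128, 128)
```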


## TVMScript printer

### Print explicitly

We have in fact already called the TVMScript printer implicitly in the code above, but the printer also provides several explicit interfaces. The basic one is the `script()` method, which returns the same string as `str()`.

In [5]:
```python
assert matmul_module.script() == str(matmul_module)
```


To get syntax-highlighted code, call `show()` instead.

In [6]:
```python
matmul_module.show()
```


### Printing decorations

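When printing explicitly, the TVMScript printer offers options to decorate the rendered code. We can underline parts of the code with `obj_to_underline`, or annotate statements with comments via `obj_to_annotate`. Both take IR objects from the module, which we reach here by following the `body` fields of the `PrimFunc`.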

In [7]:
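```python
# Follow the TIR function: root block -> loops i, j, k -> the "update" block realize.
update = matmul_module["matmul_tir"].body.block.body.body.body.body
matmul_module.show(
    black_format=False,
    # Attach a three-line comment to the "update" block.
    obj_to_annotate={update: "Annotation 1\nAnnotation 2\nAnnotation 3"},
    # Underline the body of the "update" block.
    obj_to_underline=[update.block.body],
)
```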
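
The printer applies these decorations to IR objects, not to text positions. As a rough text-level analogy only, the hypothetical helper `decorate` below underlines a column range with carets (the same `^^^` style that appears in the error report later in this notebook) and places multi-line annotations as comments above a line. Where the real printer places annotation comments may differ.

```python
def decorate(lines, underline=None, annotate=None):
    """Hypothetical text-level sketch of underline/annotate decorations.

    underline: (line_index, start_col, end_col) to mark with carets.
    annotate:  {line_index: text}; each line of text becomes one comment.
    """
    annotate = annotate or {}
    out = []
    for idx, line in enumerate(lines):
        # Reuse the statement's indentation for its comments.
        indent = line[: len(line) - len(line.lstrip())]
        for note in annotate.get(idx, "").splitlines():
            out.append(f"{indent}# {note}")
        out.append(line)
        if underline is not None and underline[0] == idx:
            _, start, end = underline
            out.append(" " * start + "^" * (end - start))
    return out


lines = ["for k in range(256):", "    C[vi, vj] = 0"]
result = decorate(lines, underline=(0, 15, 18), annotate={1: "Annotation 1\nAnnotation 2"})
print("\n".join(result))
```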

### Error report

Besides users, TVM itself calls the TVMScript printer for error reporting. One purpose of the printing decorations is to give TVM a uniform way of presenting IR in such messages.
Let us take `assert_structural_equal` as an example. It checks whether two IRs are structurally equal; when they are not, the error message prints the IR and underlines the first mismatch.
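
When the check fails, TVM's error message names the path, starting at `<root>`, of the field where the two IRs first differ. A toy version of that idea over Python's own AST is sketched below. `first_mismatch` is a hypothetical helper, not TVM's algorithm: it walks two trees in lockstep and returns the path of the first differing field. TVM's structural equality also handles variable mapping (`map_free_vars`), which this sketch ignores.

```python
import ast


def first_mismatch(lhs, rhs, path="<root>"):
    """Return the path of the first differing field of two ASTs, or None if they match."""
    if type(lhs) is not type(rhs):
        return path
    if isinstance(lhs, ast.AST):
        # Compare node fields in their declared order (positions are not fields).
        for field in lhs._fields:
            found = first_mismatch(getattr(lhs, field, None), getattr(rhs, field, None), f"{path}.{field}")
            if found is not None:
                return found
        return None
    if isinstance(lhs, list):
        if len(lhs) != len(rhs):
            return path
        for i, (left, right) in enumerate(zip(lhs, rhs)):
            found = first_mismatch(left, right, f"{path}[{i}]")
            if found is not None:
                return found
        return None
    # Leaves (identifiers, numbers, None) are compared by value.
    return None if lhs == rhs else path


lhs = ast.parse("for k in range(256):\n    pass")
rhs = ast.parse("for k in range(128):\n    pass")
print(first_mismatch(lhs, rhs))  # <root>.body[0].iter.args[0].value
```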

In [8]:
```python
from tvm.ir import assert_structural_equal


# Same as `matmul_tir`, except that the reduction loop has extent 256 instead of 128.
@T.prim_func
def matmul_tir_with_typo(
    A: T.Buffer((128, 128), dtype="float32"),
    B: T.Buffer((128, 128), dtype="float32"),
    C: T.Buffer((128, 128), dtype="float32"),
):
    for i, j, k in T.grid(128, 128, 256):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

try:
    assert_structural_equal(matmul_tir_with_typo, matmul_module["matmul_tir"], map_free_vars=True)
except ValueError as ve:
    # Print the message starting from its "ValueError" part.
    print(f"ValueError{str(ve).split('ValueError')[1]}")
```

ValueError: StructuralEqual check failed, caused by lhs at <root>.body.block.body.body.body.extent.value:
# from tvm.script import tir as T

@T.prim_func
def main(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle):
    A = T.match_buffer(A_handle, (128, 128))
    B = T.match_buffer(B_handle, (128, 128))
    C = T.match_buffer(C_handle, (128, 128))
    with T.block("root"):
        T.reads()
        T.writes()
        for i in range(128):
            for j in range(128):
                for k in range(256):
                               ^^^
                    with T.block("update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j)
                        vk = T.axis.reduce(256, k)
                        T.reads(A[vi, vk], B[vk, vj])
                        T.writes(C[vi, vj])
                        with T.init():
                            C[vi, vj] = T.float32(0)
                        C[vi, vj] = C[vi, vj] + A[vi, v