In [3]:
!python3 -V

Python 3.10.12


In [4]:
!pip install https://github.com/mlc-ai/package/releases/download/v0.9.dev0/mlc_ai_nightly_cu121-0.12.dev1813-cp310-cp310-manylinux_2_28_x86_64.whl

Collecting mlc-ai-nightly-cu121==0.12.dev1813
  Downloading https://github.com/mlc-ai/package/releases/download/v0.9.dev0/mlc_ai_nightly_cu121-0.12.dev1813-cp310-cp310-manylinux_2_28_x86_64.whl (536.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.5/536.5 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting cloudpickle
  Downloading cloudpickle-3.0.0-py3-none-any.whl (20 kB)
Collecting ml-dtypes
  Downloading ml_dtypes-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (206 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.7/206.7 KB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: ml-dtypes, cloudpickle, mlc-ai-nightly-cu121
Successfully installed cloudpickle-3.0.0 ml-dtypes-0.3.1 mlc-ai-nightly-cu121-0.12.dev1813
[0m

In [68]:
!pip install pandas

Collecting pandas
  Downloading pandas-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting tzdata>=2022.1
  Downloading tzdata-2023.3-py2.py3-none-any.whl (341 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m341.8/341.8 KB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytz>=2020.1
  Downloading pytz-2023.3.post1-py2.py3-none-any.whl (502 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m502.5/502.5 KB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
Installing collected packages: pytz, tzdata, pandas
Successfully installed pandas-2.1.2 pytz-2023.3.post1 tzdata-2023.3
[0m

In [50]:
import time

In [19]:
import IPython
import numpy as np
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

In [None]:
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def relu0(x: T.handle, y: T.handle):
        n = T.int64()
        X = T.match_buffer(x, (1, n), "float32")
        Y = T.match_buffer(y, (1, n), "float32")
        for i, j in T.grid(1, n):
            with T.block("Y"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @T.prim_func
    def linear0(x: T.handle,
                w: T.handle,
                b: T.handle,
                z: T.handle):
        m, n, k = T.int64(), T.int64(), T.int64()
        X = T.match_buffer(x, (1, m), "float32")
        W = T.match_buffer(w, (n, m), "float32")
        B = T.match_buffer(b, (n, ), "float32")
        Z = T.match_buffer(z, (1, n), "float32")
        Y = T.alloc_buffer((1, n), "float32")
        for i, j, k in T.grid(1, n, m):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
        for i, j in T.grid(1, n):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    @R.function
    def main(x: R.Tensor((1, "m"), "float32"),
             w0: R.Tensor(("n", "m"), "float32"),
             b0: R.Tensor(("n", ), "float32"),
             w1: R.Tensor(("k", "n"), "float32"),
             b1: R.Tensor(("k", ), "float32")):
        m, n, k = T.int64(), T.int64(), T.int64()
        with R.dataflow():
            lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, n), "float32"))
            lv1 = R.call_dps_packed("relu0", (lv0, ), R.Tensor((1, n), "float32"))
            out = R.call_dps_packed("linear0", (lv1, w1, b1), R.Tensor((1, k), "float32"))
            R.output(out)
        return out

In [73]:
softening = 0.1
@tvm.script.ir_module
class NBodyModule:
    @T.prim_func
    def computeInvR3(dx : T.handle,
                     dy : T.handle,
                     dz : T.handle,
                     #softening : T.float32,
                     inv_R3 : T.handle):
        n = T.int64()
        X = T.match_buffer(dx, (n, n), "float32")
        Y = T.match_buffer(dy, (n, n), "float32")
        Z = T.match_buffer(dz, (n, n), "float32")
        INV_R3 = T.match_buffer(inv_R3, (n, n), "float32")
        for i, j in T.grid(n, n):
            with T.block("INV_R3"):
                vi, vj = T.axis.remap("SS", [i, j])
                INV_R3[vi,vj] = T.pow(X[vi,vj], 2) + T.pow(Y[vi,vj], 2.0) + T.pow(Z[vi,vj], 2.0) + softening**2.0
                if (INV_R3[vi,vj] > 0):
                    INV_R3[vi,vj] = T.pow(INV_R3[vi,vj], (-1.5))

    # def getAcc(pos, mass, G, softening):
    # # Positions for all particles
    # x = pos[:, 0:1]
    # y = pos[:, 1:2]
    # z = pos[:, 2:3]
    
    # # Particle separations
    # dx = particleSeps(x)
    # dy = particleSeps(y)
    # dz = particleSeps(z)
    
    # # 1/r^3 for all particle separations
    # inv_r3 = computeInvR3(dx, dy, dz, softening)
    
    # # Acceleration components per dimension
    # ax = computeAcc(dx, inv_r3, mass, G)
    # ay = computeAcc(dy, inv_r3, mass, G)
    # az = computeAcc(dz, inv_r3, mass, G)
    
    # # Packing acceleration components
    # a = np.hstack((ax, ay, az))
    
    # return a
    # @T.prim_func
    # def getAcc(pos: T.handle,
    #            mass: T.handle,
    #            acc: T.handle):
    #     n = T.int64()
    #     pos_tensor = T.match_buffer(pos, (n, 3), "float32")
    #     mass_tensor = T.match_buffer(mass, (n, 1), "float32")
    #     acc_tensor = T.match_buffer(acc, (n, 3), "float32")
    #     x = pos_tensor[:, 0]
        
    @R.function
    def main(dx: R.Tensor(("n", "n"), "float32"),
             dy: R.Tensor(("n", "n"), "float32"),
             dz: R.Tensor(("n", "n"), "float32"),
            ) -> R.Tensor(("n", "n"), "float32"):
        n = T.int64()
        with R.dataflow():
            lv0 = R.call_dps_packed("computeInvR3", (dz, dy, dz), R.Tensor((n, n), "float32"))
            R.output(lv0)
        return lv0

In [74]:
IPython.display.Code(NBodyModule.script(), language="python")

In [75]:
ex = relax.build(NBodyModule, target="llvm")
type(ex)

tvm.relax.vm_build.Executable

In [76]:
vm = relax.VirtualMachine(ex, tvm.cpu())

In [77]:
N = 2000
dx = tvm.nd.array(np.random.randn(N, N).astype("float32"))
dy = tvm.nd.array(np.random.randn(N, N).astype("float32"))
dz = tvm.nd.array(np.random.randn(N, N).astype("float32"))

In [78]:
t0 = time.time()
for i in range(100):
    inv_R3 = vm["main"](dz, dy, dz)
t1 = time.time()

print("Total time is:",t1 - t0)

Total time is: 2.36555552482605


In [79]:
inv_R3.numpy().shape

(2000, 2000)

In [83]:
mod_inv_R3 = tvm.IRModule.from_expr(NBodyModule["computeInvR3"].with_attr("global_symbol", "main"))
mod_inv_R3.script()

'# from tvm.script import ir as I\n# from tvm.script import tir as T\n\n@I.ir_module\nclass Module:\n    @T.prim_func\n    def main(dx: T.handle, dy: T.handle, dz: T.handle, inv_R3: T.handle):\n        n = T.int64()\n        X = T.match_buffer(dx, (n, n))\n        Y = T.match_buffer(dy, (n, n))\n        Z = T.match_buffer(dz, (n, n))\n        INV_R3 = T.match_buffer(inv_R3, (n, n))\n        # with T.block("root"):\n        for i, j in T.grid(n, n):\n            with T.block("INV_R3"):\n                vi, vj = T.axis.remap("SS", [i, j])\n                T.reads(X[vi, vj], Y[vi, vj], Z[vi, vj], INV_R3[vi, vj])\n                T.writes(INV_R3[vi, vj])\n                INV_R3[vi, vj] = T.pow(X[vi, vj], T.float32(2)) + T.pow(Y[vi, vj], T.float32(2)) + T.pow(Z[vi, vj], T.float32(2)) + T.float32(0.010000000000000002)\n                if INV_R3[vi, vj] > T.float32(0):\n                    INV_R3[vi, vj] = T.pow(INV_R3[vi, vj], T.float32(-1.5))'

In [95]:
from tvm import meta_schedule as ms
database = ms.tune_tir(
    mod=mod_inv_R3,
    target="nvidia/geforce-rtx-3070",
    max_trials_global=6400,
    num_trials_per_iter=6400,
    work_dir="./tune_tmp",
    # task_name="main",
)
sch = ms.tir_integration.compile_tir(database, mod_inv_R3, "llvm --num-cores=1")

2023-11-07 12:38:18 [INFO] Logging directory: ./tune_tmp/logs
2023-11-07 12:38:18 [INFO] LocalBuilder: max_workers = 8
2023-11-07 12:38:18 [INFO] LocalRunner: max_workers = 1
2023-11-07 12:38:19 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"


ScheduleError: Traceback (most recent call last):
  9: _ZN3tvm7runtime13Packed
  8: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const [clone .isra.0]
  7: tvm::meta_schedule::GradientBasedNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  6: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
  5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  4: tvm::meta_schedule::AutoBindNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  3: tvm::meta_schedule::BindBlockThreadIdx(tvm::tir::Schedule, tvm::tir::BlockRV, long, long, std::function<tvm::PrimExpr (long)>)
  2: tvm::meta_schedule::BindSpatialLoop(tvm::tir::Schedule, tvm::tir::LoopRV, long, long, std::function<tvm::PrimExpr (long)>)
  1: tvm::tir::TracedScheduleNode::Bind(tvm::tir::LoopRV const&, tvm::runtime::String const&)
  0: tvm::tir::ConcreteScheduleNode::Bind(tvm::tir::LoopRV const&, tvm::runtime::String const&) [clone .cold]
ScheduleError: An error occurred in the schedule primitive 'bind'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(dx: T.handle, dy: T.handle, dz: T.handle, inv_R3: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        X = T.match_buffer(dx, (n, n))
        Y = T.match_buffer(dy, (n, n))
        Z = T.match_buffer(dz, (n, n))
        INV_R3 = T.match_buffer(inv_R3, (n, n))
        with T.block("root"):
            T.reads()
            T.writes()
            # tir.For#0
            for i_j_fused_1 in range(T.int64(256)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for i_j_fused_2 in range(T.int64(1024)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for i_j_fused_0 in range((n * n + T.int64(262143)) // T.int64(262144)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        # tir.Block#1
                        with T.block("INV_R3"):
                        ^^^^^^^^^^^^^^^^^^^^^^^
                            vi = T.axis.spatial(n, (i_j_fused_0 * T.int64(262144) + i_j_fused_1 * T.int64(1024) + i_j_fused_2) // n)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            vj = T.axis.spatial(n, (i_j_fused_0 * T.int64(262144) + i_j_fused_1 * T.int64(1024) + i_j_fused_2) % n)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.where((i_j_fused_0 * T.int64(256) + i_j_fused_1) * T.int64(1024) + i_j_fused_2 < n * n)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(X[vi, vj], Y[vi, vj], Z[vi, vj], INV_R3[vi, vj])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(INV_R3[vi, vj])
                            ^^^^^^^^^^^^^^^^^^^^^^^^
                            INV_R3[vi, vj] = T.pow(X[vi, vj], T.float32(2)) + T.pow(Y[vi, vj], T.float32(2)) + T.pow(Z[vi, vj], T.float32(2)) + T.float32(0.010000000000000002)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            if INV_R3[vi, vj] > T.float32(0):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                INV_R3[vi, vj] = T.pow(INV_R3[vi, vj], T.float32(-1.5))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #3 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #1 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
