# [Schedule Primitives in TVM](https://tvm.apache.org/docs/how_to/work_with_schedules/schedule_primitives.html#sphx-glr-how-to-work-with-schedules-schedule-primitives-py)

## Create Schedule

In [None]:
from __future__ import absolute_import, print_function


import tvm
from tvm import te
from tvm import relay
import numpy as np

# Symbolic TVM variables used as the (m, n) extents in most examples below.
n, m = te.var("n"), te.var("m")

def test_elewise_mul():
  """Print the lowering of an element-wise multiply with no primitives applied."""
  shape = (m, n)
  A = te.placeholder(shape, name="A")
  B = te.placeholder(shape, name="B")
  C = te.compute(shape, lambda i, j: A[i, j] * B[i, j], name="C")

  sched = te.create_schedule([C.op])
  # tvm.lower turns the definition into the real callable function; with
  # simple_mode=True it returns a readable C-like statement instead.
  lowered = tvm.lower(sched, [A, B, C], simple_mode=True)
  print(lowered)

test_elewise_mul()

## SchedulePrimitives::Split

SplitFactor：将指定维度按照指定长度进行切分


In [None]:
def test_split_factor(axis, factor=32):
  """Split axis `axis` of a 2-D copy into (outer, inner) with inner extent `factor`."""
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  target_axis = B.op.axis[axis]
  s[B].split(target_axis, factor=factor)
  print(tvm.lower(s, [A, B], simple_mode=True))

test_split_factor(axis=0)
test_split_factor(axis=1)


SplitParts：将指定维度按照指定份数进行切分


In [None]:
def test_split_nparts(nparts=8):
  """Split axis 0 of a 2-D copy into `nparts` outer iterations."""
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  # nparts fixes the number of outer iterations rather than the inner extent.
  _outer, _inner = s[B].split(B.op.axis[0], nparts=nparts)
  print(tvm.lower(s, [A, B], simple_mode=True))

test_split_nparts()

## SchedulePrimitives::Tile

分块操作。注意 Tile 和 Split 是有区别的：仅用两个 Split 无法得到 Tile 的循环顺序（还需要一次 Reorder，见下面的对比），但是可以用 Tile 覆盖 Split 的功能（把不需要切分的 axis 的 factor 设为 1）。


In [None]:
def test_tile(x, y, simple_mode=True):
  """Tile the (i, j) nest by (x, y) and print the lowered IR.

  Args:
    x: x_factor, the inner extent for axis i.
    y: y_factor, the inner extent for axis j.
    simple_mode: forwarded to `tvm.lower`.
  """
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  i_axis, j_axis = B.op.axis
  # tile yields (i.outer, j.outer, i.inner, j.inner).
  s[B].tile(i_axis, j_axis, x_factor=x, y_factor=y)
  print(tvm.lower(s, [A, B], simple_mode=simple_mode))

test_tile(10, 5)

比较一下 ___Split___ 和 ___Tile___ 的行为： ___Tile___ 相当于两次 ___Split___ 再加一次 ___Reorder___


In [None]:
def test_split_x2(x_factor, y_factor):
  """Split both axes independently (no reorder), for comparison with tile."""
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  stage = s[B]
  # Without a reorder the loops stay i.outer, i.inner, j.outer, j.inner.
  stage.split(B.op.axis[0], factor=x_factor)
  stage.split(B.op.axis[1], factor=y_factor)
  print(tvm.lower(s, [A, B], simple_mode=True))

def _banner(title):
  # Print `title` framed by two rules of 64 asterisks.
  rule = '*' * 64
  print(rule)
  print(title)
  print(rule)

_banner('test_split_x2(10, 5)')
test_split_x2(10, 5)
_banner('test_tile(10, 5)')
test_tile(10, 5)

比较一下 ___Split___ 和 ___Tile___ 的行为：用 ___Tile___ 来实现 ___Split___：


In [None]:
test_split_factor(axis=1, factor=5)  # split only j, inner extent 5
test_tile(x=1, y=5)                  # tile with a trivial factor of 1 on i

## SchedulePrimitives::Fuse

合并连续的 ___N___ 个维度


In [None]:
def test_fuse(axis0, axis1, simple_mode=True):
  """Tile (i, j) by (10, 5), then fuse two adjacent tiled axes into one.

  Args:
    axis0, axis1: indices into the tiled axes
      (0=i.outer, 1=j.outer, 2=i.inner, 3=j.inner). `fuse` only accepts
      axes that are adjacent in the current loop order, e.g. (1, 2) or (2, 3).
    simple_mode: forwarded to `tvm.lower`.
  """
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  # tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
  axes4 = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
  # then fuse the two selected axes into one loop; the call below, (1, 2),
  # fuses j.outer and i.inner into j.outer.i.inner.fused.
  s[B].fuse(axes4[axis0], axes4[axis1])
  print(tvm.lower(s, [A, B], simple_mode=simple_mode))

test_fuse(1, 2, True)

## SchedulePrimitives::Reorder

维度调换：调整循环的嵌套顺序，效果类似于对循环维度做 Transpose（只改变迭代顺序，不改变数据布局）。

In [None]:
def test_reorder(axis0, axis1, axis2, axis3):
  """Tile (i, j) by (10, 5), then reorder the four tiled loops.

  Each argument indexes into (i.outer, j.outer, i.inner, j.inner); the
  selected axes are placed outermost-to-innermost in argument order.
  """
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j], name="B")
  s = te.create_schedule(B.op)
  tiled = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
  new_order = [tiled[idx] for idx in (axis0, axis1, axis2, axis3)]
  s[B].reorder(*new_order)
  print(tvm.lower(s, [A, B], simple_mode=True))

test_reorder(0, 1, 2, 3) # no change
test_reorder(2, 1, 0, 3)

## SchedulePrimitives::Gpu::Bind


In [None]:
def test_bind():
  """Split a 1-D copy by 64 and bind the outer/inner loops to blockIdx.x/threadIdx.x."""
  A = te.placeholder((n,), name="A")
  B = te.compute(A.shape, lambda i: A[i], name="B")
  s = te.create_schedule(B.op)
  block_loop, thread_loop = s[B].split(B.op.axis[0], factor=64)
  for loop, tag in ((block_loop, "blockIdx.x"), (thread_loop, "threadIdx.x")):
    s[B].bind(loop, te.thread_axis(tag))
  print(tvm.lower(s, [A, B], simple_mode=True))
test_bind()

## SchedulePrimitives::ComputeAt

把一个 Compute Stage 移到另一个 Stage 的指定循环维度中计算，看起来可以用于 Fusion。

In [None]:
def test_compute_at(axis):
  """Compute B = A + 1 inside C's loop over C.op.axis[axis].

  Args:
    axis: index into C.op.axis; the calls below use -1 (the last axis) and 0.
  """
  A = te.placeholder((m, n), name="A")
  B = te.compute((m, n), lambda i, j: A[i, j] + 1, name="B")
  C = te.compute((m, n), lambda i, j: B[i, j] * 2, name="C")
  s = te.create_schedule(C.op)
  consumer = s[C]
  s[B].compute_at(consumer, C.op.axis[axis])
  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_compute_at(-1)
test_compute_at(0)

## SchedulePrimitives::Compute_Inline

In [None]:
def test_compute_inline():
  """Inline the intermediate stage B = A + 1 into its consumer C = B * 2."""
  A = te.placeholder((m,), name="A")
  B = te.compute((m,), lambda i: A[i] + 1, name="B")
  C = te.compute((m,), lambda i: B[i] * 2, name="C")
  sched = te.create_schedule(C.op)
  sched[B].compute_inline()
  print(tvm.lower(sched, [A, B, C], simple_mode=True))

test_compute_inline()

## SchedulePrimitives::Compute_Root

将一个计算移动到根上

In [None]:
def test_compute_root():
  """Attach B inside C's loop, then undo it with compute_root."""
  A = te.placeholder((m,), name="A")
  B = te.compute((m,), lambda i: A[i] + 1, name="B")
  C = te.compute((m,), lambda i: B[i] * 2, name="C")
  s = te.create_schedule(C.op)
  producer = s[B]
  producer.compute_at(s[C], C.op.axis[0])
  # Moving B back to the root overrides the compute_at above.
  producer.compute_root()
  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_compute_root()

## SchedulePrimitives::Parallel

为了支持CPU上类似 _openmp_ 方式的并行计算， GPU上还是要用 ___bind___。 GCU上可以用来划分 4xCluster 并行，以及 6xSip 并行。

In [None]:
# Row sum B[i] = sum_l A[i, l], with the spatial axis i parallelized.
# The reduce axis l must stay serial here: parallelizing it would let several
# threads accumulate into the same B[i] at once (a data race). Parallel
# reductions need rfactor or a dedicated cross-thread reduction instead.
A = te.placeholder((n, m), name='A')
l = te.reduce_axis((0, m), name='l')
B = te.compute((n,), lambda i: te.sum(A[i, l], axis=l), name='B')
s = te.create_schedule(B.op)
s[B].parallel(B.op.axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))


## SchedulePrimitives::Unroll

Unroll 的 axis 要求 extent 是常量（constant）。

In [None]:

def test_unroll():
  """Split axis 0 by 4 and unroll the inner loop, printing IR before and after."""
  A = te.placeholder((m, n), name='A')
  B = te.placeholder((m, n), name='B')
  C = te.compute((m, n), lambda i, j: A[i, j] + B[i, j], name='C')
  s = te.create_schedule(C.op)
  stage = s[C]
  # The inner loop gets the constant extent 4, which unroll requires.
  _, inner = stage.split(stage.op.axis[0], factor=4)
  print(tvm.lower(s, [A, B, C], simple_mode=True))
  print("---------cutting line---------")
  stage.unroll(inner)
  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_unroll()

## SchedulePrimitives::CacheRead

In [None]:
def test_cache_read():
  """Compare lowering before/after reading A through a "shared" cache stage for B."""
  A = te.placeholder((m, n), name='A')
  k = te.reduce_axis((0, n), name='k')
  B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
  s = te.create_schedule(B.op)

  def show():
    print(tvm.lower(s, [A, B], simple_mode=True))

  show()
  print("---------cutting line---------")
  s.cache_read(A, "shared", [B])
  show()

test_cache_read()

## SchedulePrimitives::CacheWrite

In [None]:
def test_cache_write():
  """Compare lowering before/after writing B through a "local" cache stage."""
  A = te.placeholder((m, n), name='A')
  k = te.reduce_axis((0, n), name='k')
  B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
  s = te.create_schedule(B.op)

  def show():
    print(tvm.lower(s, [A, B], simple_mode=True))

  show()
  print("---------cutting line---------")
  s.cache_write(B, "local")
  show()

test_cache_write()

## SchedulePrimitives::StorageAlign

Set alignment requirement for specific axis

This ensures that stride[axis] == k * factor + offset for some k. This is useful to set the memory layout for a more friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared memory.

Parameters:
 * axis (IterVar) – The axis dimension to be aligned.
 * factor (int) – The factor in alignment specification.
 * offset (int) – The offset in the alignment specification.

计算公式：`stride=size + floormod(offset - floormod(size, factor), factor)`。StorageAlign看起来完全是为了优化 GPU 的 _shared memory_ 访问时的 _bank conflict_ 而引入的定制优化。现代 GPU 的 _shared memory_ 一般是 32bits interleaving，有32个 bank。由此可以计算出：如果我们让每个 thread 处理连续 128bytes 的数据，会导致所有的 thread 同时访问同一个 bank。这个时候就需要改变数据存储的格式，比如说申请一个buffer，它的 row 为128+4bytes，其中前128bytes写入有效数据，后面的4bytes为 padding 的 dummy 数据，那么 t0, t1 ... t31 在第一个访问时间点上会分别访问 bank0, bank1 ... bank31，从而解决访问冲突。StorageAlign 里边的 factor 看起来对应于每个 thread 访问的数据量，这种方式可以解决不同的 row 之间的访问冲突，但并不是最极限的优化。

In [None]:
def test_storage_align():
  """Pad the row stride of a "shared" cache copy of A with `storage_align`.

  After alignment, the axis-0 stride of the cache buffer AA satisfies
  stride == k * factor_val + offset for some integer k.
  """
  factor_val = 97
  offset = 16
  A = te.placeholder((m, n), name='A')
  k = te.reduce_axis((0, n), name='k')
  B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
  s = te.create_schedule(B.op)
  # cache_read inserts a copy stage AA with A's 2-D shape (m, n) that B reads.
  AA = s.cache_read(A, "shared", [B])
  print(tvm.lower(s, [A, B], simple_mode=True))

  print("---------cutting line---------")
  # Align axis 0 of AA, i.e. the stride between consecutive rows.
  s[AA].storage_align(AA.op.axis[0], factor_val, offset)
  print(tvm.lower(s, [A, B], simple_mode=True))


# Backward-compatible alias for the original, misspelled name.
tset_storage_align = test_storage_align

test_storage_align()

## SchedulePrimitives::Pragma

Pragma 将生成 pragma_scope，这会有助于实现一些实验性的功能以及外部扩展。

Most pragmas are advanced/experimental features and may be subject to change. List of supported pragmas:

* debug_skip_region

  Force skip the region marked by the axis and turn it into no-op. This is useful for debug purposes.

* parallel_launch_point

  Specify to launch parallel threads outside the specified iteration loop. By default the threads launch at the point of parallel construct. This pragma moves the launching point to even outer scope. The threads are launched once and reused across multiple parallel constructs as BSP style program.

* parallel_barrier_when_finish

  Insert a synchronization barrier between working threads after the specified loop iteration finishes.

* parallel_stride_pattern

  Hint parallel loop to execute in strided pattern. for (int i = task_id; i < end; i += num_task)

这个功能可以穿透式的带一些信息给到底层，虽然提供了便利性，但应该最小限度使用，以免在程序变得庞大后难以维护。

In [None]:
def test_pragma():
  """Split B's reduce axis by 4 and tag the inner part with the "unroll" pragma.

  Prints the lowered IR before and after the pragma is applied.
  """
  A = te.placeholder((n, m), name='A')
  r = te.reduce_axis((0, m), name='l')
  B = te.compute((n,), lambda i: te.sum(A[i, r], axis=r), name='B')
  s = te.create_schedule(B.op)
  # The inner part gets the constant extent 4.
  l_outer, l_inner = s[B].split(B.op.reduce_axis[0], factor=4)
  print(tvm.lower(s, [A, B], simple_mode=True))
  print("---------cutting line---------")
  s[B].pragma(l_inner, "unroll")
  print(tvm.lower(s, [A, B], simple_mode=True))

test_pragma()

## SchedulePrimitives::CreateGroup

create_group 对从inputs到outputs的所有stage创建group，group本质上是一个虚拟stage，可以通过操作这个虚拟stage来一起操作这个group里的所有stage。

本例中，通过compute_at使这个group中的D和E，一起附加到F的reduce维度操作中。这样临时Buffer D 变成了一个Scalar。 


In [None]:
def test_create_group():
  """Group the D/E stages and attach the whole group at F's reduce axis.

  F is the schedule's output, so it is the tensor passed to `tvm.lower`
  alongside the inputs. D and E are intermediates that, after the group is
  attached, are computed inside F's reduction loop.
  """
  k = te.reduce_axis((0, n), name='k')
  A = te.placeholder((m, n), name='A')
  B = te.placeholder((m, n), name='B')

  D = te.compute((m, n), lambda i, j: A[i, j] + B[i, j], name='D')
  E = te.compute((m, n), lambda i, j: D[i, j] + B[i, j], name='E')
  F = te.compute((m,), lambda i: te.sum(E[i, k], axis=k), name='F')

  s = te.create_schedule(F.op)

  print(tvm.lower(s, [A, B, F], simple_mode=True))
  print("---------cutting line---------")

  # The group covers the stages between `inputs` and `outputs` (D and E),
  # plus the inputs themselves because include_inputs=True.
  g = s.create_group(outputs=E, inputs=[A, B], include_inputs=True)
  g.compute_at(s[F], F.op.reduce_axis[0])

  print(tvm.lower(s, [A, B, F], simple_mode=True))

test_create_group()

## SchedulePrimitives::SetScope

指定当前 _stage_ 的计算结果保存的位置，在没有指定的情况下默认 `storage_scope=global`，可以通过 `set_scope` 指定成 _shared_，通常用于 _thread_ 之间的数据共享。

`set_scope`比`cache_read`以及`cache_write`提供更灵活的操作，后两者实现中使用了这个功能。

In [None]:
def test_set_scope():
  """Compare lowering before/after storing the intermediate B in "shared" scope."""
  A = te.placeholder((m, n), name='A')
  k = te.reduce_axis((0, n), name='k')
  B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
  C = te.compute((m,), lambda i: B[i] + 10, name='C')
  s = te.create_schedule(C.op)
  args = (A, C)
  print(tvm.lower(s, list(args), simple_mode=True))
  print("---------cutting line---------")
  s[B].set_scope('shared')
  print(tvm.lower(s, list(args), simple_mode=True))

test_set_scope()

## SchedulePrimitives::Vectorize

___注意:___ `m` 和 `n` 是 `var`的时候Vectorize无法使能。

In [None]:
def test_vectorize():
  """Tile a 1024x1024 add by 32x32 and vectorize the innermost loop."""
  # Constant extents here (shadowing the symbolic module-level m, n).
  m, n = 1024, 1024
  A = te.placeholder((m, n), name='A')
  B = te.placeholder((m, n), name='B')
  C = te.compute((m, n), lambda x, y: A[x, y] + B[x, y], name='C')
  s = te.create_schedule(C.op)
  row_axis, col_axis = C.op.axis
  xo, yo, xi, yi = s[C].tile(row_axis, col_axis, 32, 32)
  print(tvm.lower(s, [A, B, C], simple_mode=True))
  print("---------cutting line---------")

  s[C].vectorize(yi)
  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_vectorize()

## SchedulePrimitives::Normalize

`normalize` 与 `create_group`，`rfactor`，`cache_read`，`cache_write` 一样，作用域是全部的 stages。

从下面的例子看，TVM 框架会自动调用 `normalize`，无需（也不希望）用户手动调用；`rebase` 同样如此。
[tqChen](https://github.com/apache/tvm/issues/733#issuecomment-355420100)
I think a good way is always avoid calling normalize manually as it will be called right before we do lowering, it might be good to add this to the document.

In [None]:
def test_normalize():
  """Lower a dot product whose reduce axis starts at 10, before and after normalize()."""
  A = te.placeholder((n,), name='A')
  B = te.placeholder((n,), name='B')
  k = te.reduce_axis((10, n), name='k')
  C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C')
  sched = te.create_schedule(C.op)
  print(tvm.lower(sched, [A, B, C], simple_mode=True))
  print("---------cutting line---------")
  # normalize() returns a schedule; keep working with the returned one.
  sched = sched.normalize()
  print(tvm.lower(sched, [A, B, C], simple_mode=True))

test_normalize()

## SchedulePrimitives::Prefetch

Prefetch the specified variable

* Parameters
  * _tensor_ (Tensor) – The tensor to be prefetched
  * _var_ (IterVar) – The loop point at which the prefetching is applied
  * _offset_ (Expr) – The number of iterations to be prefetched before actual execution

___FIXME___：_tir.prefetch_ 的行为我没有理解。

In [None]:
def test_prefetch():
  """Lower a row sum before/after prefetching A 11 iterations ahead on the reduce axis."""
  k = te.reduce_axis((0, n), name='k')
  A = te.placeholder((m, n), name='A')
  B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name='B')
  s = te.create_schedule(B.op)
  stage = s[B]
  print(tvm.lower(s, [A, B], simple_mode=True))
  print("---------cutting line---------")
  stage.prefetch(A, stage.op.reduce_axis[0], 11)
  print(tvm.lower(s, [A, B], simple_mode=True))

test_prefetch()

## SchedulePrimitives::Tensorize

Note that intrin_func now returns a triplet: (body, reduce_reset, reduce_update). If tensorization includes all the reduce axes, function body() will be invoked, otherwise reduce_reset() and reduce_update() together will be used. In our example body() and reduce_update() share the same implementation, while in other cases, hardware may have different instructions for these two functions. Moreover, we can see now bb.strides[0] is different from l due to the tiling.

`intrin_func`返回 (body, reduce_reset, reduce_update)：如果 tensorization 覆盖了全部的 reduce axes，那么只调用 body()；否则需要调用 reduce_reset() 和 reduce_update() 来组合完成 partial sum 的初始化以及更新：
* test_tensorize(): 不切分 reduce axis 时的实现
* test_tensorize2(): 切分 reduce axis 时的实现

In [None]:
def test_tensorize():
  """Tensorize a 16-row GEMV slice of C = A @ B^T into an extern "gemv_update" call.

  The reduce axis k is not split, so tensorization covers the whole
  reduction and the intrinsic only has to provide a body; no
  reduce_reset / reduce_update is needed.
  """
  N, M, L = 1024, 512, 64
  A = te.placeholder((N, L), name='A')
  B = te.placeholder((M, L), name='B')
  k = te.reduce_axis((0, L), name='k')
  C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name='C')
  s = te.create_schedule(C.op)

  def intrin_gemv(m, l):
    """Declare c[i] = sum_k a[k] * b[i, k] over an (m, l) tile as an extern call."""
    a = te.placeholder((l,), name='a')
    b = te.placeholder((m, l), name='b')
    k = te.reduce_axis((0, l), name='k')
    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name='c')
    Abuf = tvm.tir.decl_buffer(a.shape, a.dtype, name='A', offset_factor=1, strides=[1])
    # b's row stride is symbolic (s1) so the intrinsic can bind to a slice of
    # the larger B; it is forwarded to the extern call as bb.strides[0].
    Bbuf = tvm.tir.decl_buffer(b.shape, b.dtype, name='B', offset_factor=1, strides=[te.var("s1"), 1])
    Cbuf = tvm.tir.decl_buffer(c.shape, c.dtype, name='C', offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
      ib = tvm.tir.ir_builder.create()
      aa, bb = ins
      cc = outs[0]
      ib.emit(tvm.tir.call_extern("int32", "gemv_update", cc.access_ptr("w"), aa.access_ptr("r"), bb.access_ptr("r"), m, l, bb.strides[0]))
      return ib.get()

    # Every tensor is bound explicitly, so no build/pass config is needed;
    # this is declared directly, without the deprecated relay.build_config.
    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Abuf, b: Bbuf, c: Cbuf})

  factor = 16
  x, y = C.op.axis
  z, = C.op.reduce_axis
  yo, yi = s[C].split(y, factor=factor)
  s[C].reorder(x, yo, yi, z)

  gemv = intrin_gemv(factor, L)

  print(tvm.lower(s, [A, B, C], simple_mode=True))
  print("---------cutting line---------")

  s[C].tensorize(yi, gemv)

  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_tensorize()

In [None]:
def gemv_impl():
    """Compile the C kernels gemv_update / gemv_reset to LLVM IR.

    Returns:
        The LLVM IR returned by ``clang.create_llvm``, which test_tensorize2
        passes to the "import_llvm" pragma.
    """
    from tvm.contrib import utils, clang

    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
      extern "C" int gemv_reset(float *cc, int m) {
        for (int i = 0; i < m; ++i) {
            cc[i] = 0.0;
        }
        return 0;
      }
    """
    workdir = utils.tempdir()
    # Create LLVM IR from the C source, writing temp.ll inside workdir.
    return clang.create_llvm(cc_code, output=workdir.relpath("temp.ll"))


def intrin_gemv(m, l):
    """Declare a GEMV tensor intrinsic c[i] = sum_k a[k] * b[i, k] of size (m, l).

    The intrinsic returns a (body, reduce_reset, reduce_update) triple:
    the body and the update both call the extern "gemv_update"; the reset
    calls "gemv_reset".
    """
    a = te.placeholder((l,), name="a")
    b = te.placeholder((m, l), name="b")
    k = te.reduce_axis((0, l), name="k")
    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")
    buffers = {
        a: tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1]),
        b: tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1]),
        c: tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]),
    }

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def emit_extern(*call_args):
            # Build a statement that is a single int32 extern call.
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_extern("int32", *call_args))
            return builder.get()

        def update():
            return emit_extern(
                "gemv_update",
                cc.access_ptr("w"),
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                m,
                l,
                bb.strides[0],
            )

        body = update()
        reset = emit_extern("gemv_reset", cc.access_ptr("w"), m)
        return body, reset, update()

    return te.decl_tensor_intrin(c.op, intrin_func, binds=buffers)

def test_tensorize2():
  """Tensorize C = A @ B^T with the reduce axis also split by 16.

  Because the reduction is split, tensorize uses the intrinsic's
  reduce_reset / reduce_update pair; the C kernels are linked in via the
  "import_llvm" pragma.
  """
  N, M, L = 1024, 512, 64
  A = te.placeholder((N, L), name='A')
  B = te.placeholder((M, L), name='B')
  k = te.reduce_axis((0, L), name='k')
  C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name='C')
  s = te.create_schedule(C.op)
  stage = s[C]

  factor = 16
  x, y = C.op.axis
  z, = C.op.reduce_axis
  yo, yi = stage.split(y, factor=factor)
  stage.reorder(x, yo, yi, z)
  zo, zi = stage.split(z, factor=factor)
  stage.reorder(x, yo, zo, yi, zi)

  print(tvm.lower(s, [A, B, C], simple_mode=True))

  gemv = intrin_gemv(factor, factor)
  stage.tensorize(yi, gemv)
  stage.pragma(yo, "import_llvm", gemv_impl())

  print("---------cutting line---------")
  print(tvm.lower(s, [A, B, C], simple_mode=True))

test_tensorize2()