# [Schedule Primitives in TVM](https://tvm.apache.org/docs/how_to/work_with_schedules/schedule_primitives.html#sphx-glr-how-to-work-with-schedules-schedule-primitives-py)

## Create Schedule

In [1]:
from __future__ import absolute_import, print_function


import tvm
from tvm import te
import numpy as np

# Symbolic size variables, created once at module level and reused as the
# tensor shape dimensions by the example functions in this file.
n = te.var("n")
m = te.var("m")

def test_elewise_mul():
  """Print the lowered IR of an element-wise multiply of two (m, n) tensors.

  No schedule primitive is applied; this shows the default schedule.
  """
  shape = (m, n)
  lhs = te.placeholder(shape, name="A")
  rhs = te.placeholder(shape, name="B")
  product = te.compute(shape, lambda i, j: lhs[i, j] * rhs[i, j], name="C")

  sched = te.create_schedule([product.op])
  # `tvm.lower` turns the definition into the real callable function; with
  # `simple_mode=True` it returns a readable, C-like statement, which is
  # what gets printed as the schedule result.
  lowered = tvm.lower(sched, [lhs, rhs, product], simple_mode=True)
  print(lowered)

# Run the example; the lowered IR is printed below.
test_elewise_mul()

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      C[((i*stride_2) + (j*stride_5))] = (A[((i*stride) + (j*stride_3))]*B[((i*stride_1) + (j*stride_4))])
    }
  }
}




## SchedulePrimitives::Split

SplitFactor：将指定维度按照指定长度进行切分


In [2]:
def test_split_factor(axis, factor=32):
  """Split one axis of an (m, n) copy by `factor` and print the lowered IR.

  Args:
    axis: index into `B.op.axis` selecting the axis to split.
    factor: value passed to `split` as its `factor` argument.
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i, j: src[i, j], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  target = dst.op.axis[axis]
  # split returns the (outer, inner) pair; neither is needed afterwards.
  _outer, _inner = stage.split(target, factor=factor)
  lowered = tvm.lower(sched, [src, dst], simple_mode=True)
  print(lowered)

# Split axis 0 and then axis 1, each with the default factor of 32.
test_split_factor(axis=0)
test_split_factor(axis=1)


@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i.outer: int32, 0, floordiv((m + 31), 32)) {
    for (i.inner: int32, 0, 32) {
      if @tir.likely((((i.outer*32) + i.inner) < m), dtype=bool) {
        for (j: int32, 0, n) {
          let cse_var_1: int32 = ((i.outer*32) + i.inner)
          B[((cse_var_1*stride_1) + (j*stride_3))] = A[((cse_var_1*stride) + (j*stride_2))]
        }
      }
    }
  }
}


@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy

SplitParts：将指定维度按照指定份数进行切分


In [3]:
def test_split_nparts(nparts=8):
  """Split the single axis of a 1-D copy into `nparts` parts and print the IR.

  Args:
    nparts: value passed to `split` as its `nparts` argument.
  """
  shape = (m,)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i: src[i], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  # split returns the (outer, inner) pair; neither is needed afterwards.
  _outer, _inner = stage.split(dst.op.axis[0], nparts=nparts)
  lowered = tvm.lower(sched, [src, dst], simple_mode=True)
  print(lowered)

# Run with the default nparts=8.
test_split_nparts()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto")} {
  for (i.outer: int32, 0, 8) {
    for (i.inner: int32, 0, floordiv((m + 7), 8)) {
      if @tir.likely(((i.inner + (i.outer*floordiv((m + 7), 8))) < m), dtype=bool) {
        B[((i.inner + (i.outer*floordiv((m + 7), 8)))*stride_1)] = A[((i.inner + (i.outer*floordiv((m + 7), 8)))*stride)]
      }
    }
  }
}




## SchedulePrimitives::Tile

分块操作。注意 Tile 和 Split 是有区别的：仅用两个 Split 无法完成一个 Tile 的功能（还需要一次 Reorder），但可以用 Tile 覆盖 Split 的功能（把 Tile 中不需要切分的那个 axis 的 factor 设为 1）。


In [4]:
def test_tile(x, y, simple_mode=True):
  """Tile an (m, n) copy with factors (x, y) and print the lowered IR.

  Args:
    x: `x_factor` applied to axis 0.
    y: `y_factor` applied to axis 1.
    simple_mode: forwarded to `tvm.lower`.
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i, j: src[i, j], name="B")
  sched = te.create_schedule(dst.op)
  row_axis, col_axis = dst.op.axis[0], dst.op.axis[1]
  # tile returns four axes; none of them is needed afterwards.
  _xo, _yo, _xi, _yi = sched[dst].tile(row_axis, col_axis, x_factor=x, y_factor=y)
  lowered = tvm.lower(sched, [src, dst], simple_mode=simple_mode)
  print(lowered)

# Tile with x_factor=10 on axis 0 and y_factor=5 on axis 1.
test_tile(10, 5)

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (i.inner: int32, 0, 10) {
        if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
          for (j.inner: int32, 0, 5) {
            if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
              let cse_var_2: int32 = ((j.outer*5) + j.inner)
              let cse_var_1: int32 = ((i.outer*10) + i.

比较一下 ___Split___ 和 ___Tile___ 的行为：___Tile___ 相当于两次 ___Split___ 再加一次 ___Reorder___。


In [5]:
def test_split_x2(x_factor, y_factor):
  """Split both axes of an (m, n) copy independently and print the lowered IR.

  Args:
    x_factor: split factor for axis 0.
    y_factor: split factor for axis 1.
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i, j: src[i, j], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  # Two plain splits, axis 0 first and then axis 1; no reorder is applied.
  _xo, _xi = stage.split(dst.op.axis[0], factor=x_factor)
  _yo, _yi = stage.split(dst.op.axis[1], factor=y_factor)
  lowered = tvm.lower(sched, [src, dst], simple_mode=True)
  print(lowered)

# Print a starred banner naming each example, then run it with factors (10, 5),
# so the two lowered loop nests can be compared one after the other.
for label, example in (('test_split_x2(10, 5)', test_split_x2),
                       ('test_tile(10, 5)', test_tile)):
  print('*'*64)
  print(label)
  print('*'*64)
  example(10, 5)

****************************************************************
test_split_x2(10, 5)
****************************************************************
@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (i.inner: int32, 0, 10) {
      if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
        for (j.outer: int32, 0, floordiv((n + 4), 5)) {
          for (j.inner: int32, 0, 5) {
            if @tir.likely((((j.outer

比较一下 ___Split___ 和 ___Tile___ 的行为：用 ___Tile___ 来实现 ___Split___：


In [6]:
# Split axis 1 by 5, then tile with factors (1, 5); the two results are
# printed back to back for comparison.
test_split_factor(axis=1, factor=5)
test_tile(1, 5)

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i: int32, 0, m) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (j.inner: int32, 0, 5) {
        if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
          let cse_var_1: int32 = ((j.outer*5) + j.inner)
          B[((i*stride_1) + (cse_var_1*stride_3))] = A[((i*stride) + (cse_var_1*stride_2))]
        }
      }
    }
  }
}


@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_s

## SchedulePrimitives::Fuse

合并连续的 ___N___ 个维度


In [7]:
def test_fuse(axis0, axis1, simple_mode=True):
  """Tile an (m, n) copy, fuse two of the tiled axes, and print the IR.

  Args:
    axis0: index (into the four axes returned by `tile`) of the first axis
      to fuse.
    axis1: index of the second axis to fuse.
    simple_mode: forwarded to `tvm.lower`.
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i, j: src[i, j], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  # Tile into four axes first: (i.outer, j.outer, i.inner, j.inner).
  tiled = stage.tile(dst.op.axis[0], dst.op.axis[1], x_factor=10, y_factor=5)
  # Then fuse the two axes selected by the caller into a single axis.
  stage.fuse(tiled[axis0], tiled[axis1])
  lowered = tvm.lower(sched, [src, dst], simple_mode=simple_mode)
  print(lowered)

# Fuse tiled axes 1 and 2, i.e. j.outer and i.inner.
test_fuse(1, 2, True)

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (j.outer.i.inner.fused: int32, 0, (floordiv((n + 4), 5)*10)) {
      if @tir.likely((((i.outer*10) + floormod(j.outer.i.inner.fused, 10)) < m), dtype=bool) {
        for (j.inner: int32, 0, 5) {
          if @tir.likely((((floordiv(j.outer.i.inner.fused, 10)*5) + j.inner) < n), dtype=bool) {
            let cse_var_2: int32 = ((floordiv(j.outer.i.inner.fused, 10)*5) + j.inn

## SchedulePrimitives::Reorder

调换循环轴的嵌套顺序，类似于对循环轴做 Transpose。

In [8]:
def test_reorder(axis0, axis1, axis2, axis3):
  """Tile an (m, n) copy, reorder the four tiled axes, and print the IR.

  Each argument is an index into the four axes returned by `tile`; the axes
  are passed to `reorder` in the order (axis0, axis1, axis2, axis3).
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  dst = te.compute(shape, lambda i, j: src[i, j], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  # Tile into four axes first: (i.outer, j.outer, i.inner, j.inner).
  tiled = stage.tile(dst.op.axis[0], dst.op.axis[1], x_factor=10, y_factor=5)
  new_order = [tiled[idx] for idx in (axis0, axis1, axis2, axis3)]
  stage.reorder(*new_order)
  lowered = tvm.lower(sched, [src, dst], simple_mode=True)
  print(lowered)

# First call passes the tiled axes in their original order; second call
# swaps positions 0 and 2 (i.outer and i.inner).
test_reorder(0, 1, 2, 3) # no change
test_reorder(2, 1, 0, 3)

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_3: int32], type="auto")} {
  for (i.outer: int32, 0, floordiv((m + 9), 10)) {
    for (j.outer: int32, 0, floordiv((n + 4), 5)) {
      for (i.inner: int32, 0, 10) {
        if @tir.likely((((i.outer*10) + i.inner) < m), dtype=bool) {
          for (j.inner: int32, 0, 5) {
            if @tir.likely((((j.outer*5) + j.inner) < n), dtype=bool) {
              let cse_var_2: int32 = ((j.outer*5) + j.inner)
              let cse_var_1: int32 = ((i.outer*10) + i.

## SchedulePrimitives::Gpu::Bind


In [9]:
def test_bind():
  """Split a 1-D copy by 64, bind both parts to thread axes, print the IR.

  The outer part is bound to "blockIdx.x" and the inner part to
  "threadIdx.x".
  """
  src = te.placeholder((n,), name="A")
  dst = te.compute(src.shape, lambda i: src[i], name="B")
  sched = te.create_schedule(dst.op)
  stage = sched[dst]
  outer, inner = stage.split(dst.op.axis[0], factor=64)
  for iter_var, thread_tag in ((outer, "blockIdx.x"), (inner, "threadIdx.x")):
    stage.bind(iter_var, te.thread_axis(thread_tag))
  lowered = tvm.lower(sched, [src, dst], simple_mode=True)
  print(lowered)
# Run the example; the lowered IR with the thread bindings is printed below.
test_bind()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 63), 64);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 64;
  if @tir.likely((((blockIdx.x*64) + threadIdx.x) < n), dtype=bool) {
    B[(((blockIdx.x*64) + threadIdx.x)*stride_1)] = A[(((blockIdx.x*64) + threadIdx.x)*stride)]
  }
}




## SchedulePrimitives::Compute_At

将一个 Compute Stage 移动到另一个 Stage 的指定循环轴内进行计算，看起来可以用在 Fusion 中。

In [10]:
def test_compute_at(axis):
  """Attach stage B under one axis of stage C and print the lowered IR.

  Args:
    axis: index into `C.op.axis` selecting the axis of C at which B is
      computed.
  """
  shape = (m, n)
  src = te.placeholder(shape, name="A")
  plus_one = te.compute(shape, lambda i, j: src[i, j] + 1, name="B")
  doubled = te.compute(shape, lambda i, j: plus_one[i, j] * 2, name="C")
  sched = te.create_schedule(doubled.op)
  consumer_stage = sched[doubled]
  attach_axis = doubled.op.axis[axis]
  # Move the computation of B into the chosen axis of the computation of C.
  sched[plus_one].compute_at(consumer_stage, attach_axis)
  lowered = tvm.lower(sched, [src, plus_one, doubled], simple_mode=True)
  print(lowered)

# Attach B at C's axis index -1 first, then at C's axis index 0.
test_compute_at(-1)
test_compute_at(0)

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      B[((i*stride_1) + (j*stride_4))] = (A[((i*stride) + (j*stride_3))] + 1f32)
      C[((i*stride_2) + (j*stride_5))] = (B[((i*stride_1) + (j*stride_4))]*2f32)
    }
  }
}


@main = 

## SchedulePrimitives::Compute_Inline

In [11]:
def test_compute_inline():
  """Mark the intermediate stage B as inline and print the lowered IR."""
  shape = (m,)
  src = te.placeholder(shape, name="A")
  plus_one = te.compute(shape, lambda i: src[i] + 1, name="B")
  doubled = te.compute(shape, lambda i: plus_one[i] * 2, name="C")
  sched = te.create_schedule(doubled.op)
  # Mark stage B as inline.
  sched[plus_one].compute_inline()
  lowered = tvm.lower(sched, [src, plus_one, doubled], simple_mode=True)
  print(lowered)

# Run the example; the lowered IR is printed below.
test_compute_inline()

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
  for (i: int32, 0, m) {
    C[(i*stride_2)] = ((A[(i*stride)] + 1f32)*2f32)
  }
}




## SchedulePrimitives::Compute_Root

将一个 Stage 的计算移动到根（最外层）上，可以用来撤销之前的 compute_at。

In [12]:
def test_compute_root():
  """Attach stage B under C, move it back to the root, and print the IR."""
  shape = (m,)
  src = te.placeholder(shape, name="A")
  plus_one = te.compute(shape, lambda i: src[i] + 1, name="B")
  doubled = te.compute(shape, lambda i: plus_one[i] * 2, name="C")
  sched = te.create_schedule(doubled.op)
  producer_stage = sched[plus_one]
  # First attach B under C's first axis ...
  producer_stage.compute_at(sched[doubled], doubled.op.axis[0])
  # ... then move the computation of that stage to the root.
  producer_stage.compute_root()
  lowered = tvm.lower(sched, [src, plus_one, doubled], simple_mode=True)
  print(lowered)

# Run the example; the lowered IR is printed below.
test_compute_root()

@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B, C_1: C}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [m], [stride], type="auto"), B_1: B_3: Buffer(B_2, float32, [m], [stride_1], type="auto"), C_1: C_3: Buffer(C_2, float32, [m], [stride_2], type="auto")} {
  for (i: int32, 0, m) {
    B[(i*stride_1)] = (A[(i*stride)] + 1f32)
  }
  for (i_1: int32, 0, m) {
    C[(i_1*stride_2)] = (B[(i_1*stride_1)]*2f32)
  }
}


