# [How to do reduction in TVM](https://tvm.apache.org/docs/how_to/work_with_schedules/reduction.html#sphx-glr-how-to-work-with-schedules-reduction-py)

## SchedulePrimitives::Reduce

In [17]:
from __future__ import absolute_import, print_function

import tvm
import tvm.testing
from tvm import te
import numpy as np

# Symbolic shape variables shared by the examples in this file: `n` is used as
# the row count of the input and `m` as the extent of the reduced column axis.
n = te.var("n")
m = te.var("m")

def test_reduce():
    """Lower a row-wise sum with the default schedule and print the result.

    ``B[i]`` is defined as the sum of ``A[i, k]`` over a reduce axis ``k``
    ranging over ``(0, m)``. The shapes use the module-level symbolic
    variables ``n`` and ``m``. No scheduling primitive is applied.
    """
    src = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    dst = te.compute(
        (n,),
        lambda i: te.sum(src[i, k], axis=k),
        name="B",
    )
    sched = te.create_schedule(dst.op)
    lowered = tvm.lower(sched, [src, dst], simple_mode=True)
    print(lowered)

# Run the example; it prints the lowered IR of the unscheduled reduction.
test_reduce()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  for (i: int32, 0, n) {
    B[(i*stride_1)] = 0f32
    for (k: int32, 0, m) {
      B[(i*stride_1)] = (B[(i*stride_1)] + A[((i*stride) + (k*stride_2))])
    }
  }
}




* Split reduce axis

In [18]:
def test_split_reduce():
    """Split the reduce axis and the row axis, then print the lowered IR.

    The reduce axis of ``B`` is split with ``factor=16`` and its first
    spatial axis with ``factor=32``. The (outer, inner) pairs returned by
    ``split`` are not used any further in this example.
    """
    src = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    dst = te.compute(
        (n,),
        lambda i: te.sum(src[i, k], axis=k),
        name="B",
    )
    sched = te.create_schedule(dst.op)
    stage = sched[dst]
    k_outer, k_inner = stage.split(dst.op.reduce_axis[0], factor=16)
    row_outer, row_inner = stage.split(dst.op.axis[0], factor=32)
    print(tvm.lower(sched, [src, dst], simple_mode=True))

# Run the example; it prints the lowered IR after both axes have been split.
test_split_reduce()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  for (i.outer: int32, 0, floordiv((n + 31), 32)) {
    for (i.inner: int32, 0, 32) {
      if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
        B[(((i.outer*32) + i.inner)*stride_1)] = 0f32
      }
      if @tir.likely((((i.outer*32) + i.inner) < n), dtype=bool) {
        for (k.outer: int32, 0, floordiv((m + 15), 16)) {
          for (k.inner: int32, 0, 16) {
            if @tir.likely((((k.outer*16) + k.inner) < m), dtype=bool) {
  

* Bind row in GPU kernel

In [19]:
def test_reduce_bind():
    """Split the axes, bind the row axis to GPU thread axes, and print the IR.

    The reduce axis is split with ``factor=16`` (result unused afterwards).
    The row axis is split with ``factor=32``; its outer part is bound to
    ``blockIdx.x`` and its inner part to ``threadIdx.x``.
    """
    src = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    dst = te.compute(
        (n,),
        lambda i: te.sum(src[i, k], axis=k),
        name="B",
    )
    sched = te.create_schedule(dst.op)
    stage = sched[dst]
    k_outer, k_inner = stage.split(dst.op.reduce_axis[0], factor=16)
    row_outer, row_inner = stage.split(dst.op.axis[0], factor=32)

    # Outer row chunk -> block index, inner row chunk -> thread index.
    block_x = te.thread_axis("blockIdx.x")
    stage.bind(row_outer, block_x)
    thread_x = te.thread_axis("threadIdx.x")
    stage.bind(row_inner, thread_x)

    print(tvm.lower(sched, [src, dst], simple_mode=True))

# Run the example; it prints the lowered IR with the row axis bound to threads.
test_reduce_bind()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 31), 32);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32 {
    if @tir.likely((((blockIdx.x*32) + threadIdx.x) < n), dtype=bool) {
      B[(((blockIdx.x*32) + threadIdx.x)*stride_1)] = 0f32
    }
    for (k.outer: int32, 0, floordiv((m + 15), 16)) {
      for (k.inner: int32, 0, 16) {
        i

## SchedulePrimitives::rfactor

构建归约的一个问题是我们不能简单地在归约轴上并行化：需要先划分归约的计算，将局部归约结果存储在临时数组中，然后再对临时数组进行归约。为了简化这个问题，TVM 引入了 ___rfactor___ 原语对计算进行重写。下面这个例子对 reduce 维度进行了 rfactor 操作，目的是让 _16_ 个 thread 能够同时进行归约计算，然后再使用 _1_ 个 thread 对这 _16_ 个中间结果做最后一次 _16_ 元素的归约，以便于 GPU 处理。

In [20]:
def test_reduce_rfactor():
    """Apply ``rfactor`` on the inner reduce axis and print the results.

    The reduce axis of ``B`` is split with ``factor=16`` and ``rfactor`` is
    called on the inner part. The function then prints the lowered IR
    followed by the body of the op now held by ``B``'s stage.
    """
    src = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    dst = te.compute(
        (n,),
        lambda i: te.sum(src[i, k], axis=k),
        name="B",
    )
    sched = te.create_schedule(dst.op)
    k_outer, k_inner = sched[dst].split(dst.op.reduce_axis[0], factor=16)
    # The value returned by rfactor is not used any further in this example.
    partial = sched.rfactor(dst, k_inner)

    lowered = tvm.lower(sched, [src, dst], simple_mode=True)
    print(lowered)
    final_body = sched[dst].op.body
    print(final_body)

# Run the example; it prints the lowered IR and the rewritten body of B.
test_reduce_rfactor()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  allocate(B.rf: Pointer(global float32), float32, [(n*16)]), storage_scope = global {
    for (k.inner: int32, 0, 16) {
      for (i: int32, 0, n) {
        B.rf_1: Buffer(B.rf, float32, [(16*n)], [])[((k.inner*n) + i)] = 0f32
        for (k.outer: int32, 0, floordiv((m + 15), 16)) {
          if @tir.likely((((k.outer*16) + k.inner) < m), dtype=bool) {
            B.rf_1[((k.inner*n) + i)] = (B.rf_1[((k.inner*n) + i)] + A[((i*stride) + (((k.outer*16

我们用 C++ 描述上面的计算过程:

```C++
#include <array>

// Row
#define N 100
// Col
#define M 256
// Parallel 
#define K 16

// Hand-written equivalent of the rfactor'ed reduction: B[i] = sum_k A[i][k],
// computed in two stages through the K x N buffer of partial sums B_rf.
int main(void) {
  // Zero-initialised so that the reduction below reads defined values.
  float A[N][M] = {};
  float B[N];
  // B_rf[k][i] holds the partial sum of row i over the columns
  // k, k + K, k + 2K, ... (those with column index % K == k).
  float B_rf[K][N];

  // Stage 1: K independent partial reductions per row.
  for (int k_inner = 0; k_inner < K; ++k_inner) {
    for (int i = 0; i < N; ++i) {
      B_rf[k_inner][i] = 0.0F;
      // ceil(M / K) outer steps; the guard below masks the tail when M is
      // not a multiple of K.
      for (int k_outer = 0; k_outer < (M + K - 1) / K; ++k_outer) {
        if (k_outer * K + k_inner < M) {
          B_rf[k_inner][i] += A[i][k_outer * K + k_inner];
        }
      }
    }
  }

  // Stage 2: fold the K partial sums of each row into the final result.
  for (int i = 0; i < N; ++i) {
    B[i] = 0.0F;
    for (int k = 0; k < K; ++k) {
      // Index both dimensions: B_rf[k] alone is a whole row of N floats.
      B[i] += B_rf[k][i];
    }
  }

  return 0;
}

```


## SchedulePrimitives::set_store_predicate

Cross Thread Reduction

In [21]:
def cross_thread_reduction():
    """Schedule the row sum as a cross-thread reduction and print the results.

    Steps, in order:

    1. Split the reduce axis with ``factor=16`` and ``rfactor`` on the inner
       part.
    2. Split the row axis of ``B`` with ``factor=32``; bind the outer part to
       ``blockIdx.x`` and the inner part to ``threadIdx.y``.
    3. Bind the remaining reduce axis of ``B`` to ``threadIdx.x`` and compute
       the rfactor result at that axis.
    4. Set the store predicate of ``B`` to ``threadIdx.x == 0``.

    The lowered IR is printed, then the schedule is built for the ``"cuda"``
    target and the source of the first imported module is printed.
    """
    src = te.placeholder((n, m), name="A")
    k = te.reduce_axis((0, m), "k")
    dst = te.compute(
        (n,),
        lambda i: te.sum(src[i, k], axis=k),
        name="B",
    )
    sched = te.create_schedule(dst.op)
    k_outer, k_inner = sched[dst].split(dst.op.reduce_axis[0], factor=16)
    partial = sched.rfactor(dst, k_inner)

    # Axes are read through the stage's current op from here on.
    out_stage = sched[dst]
    row_outer, row_inner = out_stage.split(out_stage.op.axis[0], factor=32)
    out_stage.bind(row_outer, te.thread_axis("blockIdx.x"))
    out_stage.bind(row_inner, te.thread_axis("threadIdx.y"))

    thread_x = te.thread_axis("threadIdx.x")
    reduce_iv = out_stage.op.reduce_axis[0]
    out_stage.bind(reduce_iv, thread_x)
    sched[partial].compute_at(out_stage, reduce_iv)
    out_stage.set_store_predicate(thread_x.var.equal(0))

    print(tvm.lower(sched, [src, dst], simple_mode=True))

    cuda_mod = tvm.build(sched, [src, dst], "cuda")
    banner = '*' * 64
    print(banner)
    print('CUDA Source Code')
    print(banner)
    print(cuda_mod.imported_modules[0].get_source())

# Run the example; it prints the lowered IR and the generated CUDA source.
cross_thread_reduction()

@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [(stride: int32*n: int32)], [], type="auto"),
             B: Buffer(B_2: Pointer(float32), float32, [(stride_1: int32*n)], [], type="auto")}
  buffer_map = {A_1: A, B_1: B}
  preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [n, m: int32], [stride, stride_2: int32], type="auto"), B_1: B_3: Buffer(B_2, float32, [n], [stride_1], type="auto")} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((n + 31), 32);
  allocate(B.rf: Pointer(local float32), float32, [1]), storage_scope = local;
  allocate(reduce_temp0: Pointer(local float32), float32, [1]), storage_scope = local;
  attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadId