# How to optimize GEMM on CPU

In this tutorial, we will demonstrate how to use TVM to optimize square matrix multiplication on CPU.

Two optimizations matter most for compute-intensive applications running on CPU:
- Increase the cache hit rate of memory accesses. Both complex numerical computation and hot-spot memory access benefit from a high cache hit rate. This requires transforming the original memory access pattern into one that fits the cache policy.
- Use SIMD (single instruction, multiple data), also known as the vector processing unit. Each instruction processes a small batch of data rather than a single element. This requires making the data access pattern in the loop body uniform so that the LLVM backend can lower it to SIMD instructions.

## Preparation

We first import TVM and write a baseline implementation for matrix multiplication.

In [1]:
import numpy
import tvm
from tvm import te

# The size of the matrix
# (M, K) x (K, N)
# You are free to try out different shapes; sometimes TVM optimization outperforms numpy with MKL.
M = 1024
K = 1024
N = 1024

# The default tensor type in tvm
dtype = "float32"

# Use the Intel AVX2 (Advanced Vector Extensions 2) ISA for SIMD.
# To get the best performance, change -mcpu in the following line
# to match the specific type of CPU you use.
target = 'llvm -mcpu=core-avx2'
ctx = tvm.context(target, 0)

We need to define the compute, or the algorithm, of matrix multiplication first.

In [2]:
A = te.placeholder((M, K), name='A')
B = te.placeholder((K, N), name='B')
k = te.reduce_axis((0, K), 'k')
C = te.compute(
           (M, N),
           lambda y, x: te.sum(A[y, k] * B[k, x], axis=k),
           name='C')

In TVM, we can always inspect the lower-level IR to debug or optimize our schedule. Here is the IR generated with the default (baseline) schedule.

In [3]:
# Default schedule
s = te.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))

PrimFunc([A, B, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  for (y, 0, 1024) {
    for (x, 0, 1024) {
      C[((y*1024) + x)] = 0f
      for (k, 0, 1024) {
        C[((y*1024) + x)] = (C[((y*1024) + x)] + (A[((y*1024) + k)]*B[((k*1024) + x)]))
      }
    }
  }
}
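
To read the IR above, it can help to transliterate it into plain Python. The sketch below is not TVM code: it assumes flat row-major buffers, uses a small stand-in size (`n = 4` instead of 1024) so that it runs instantly, and the helper name `matmul_flat` is made up for illustration.

```python
def matmul_flat(a, b, n):
    """Plain-Python transliteration of the default-schedule loop nest.

    a and b are flat row-major lists of length n * n; n stands in for 1024.
    """
    c = [0.0] * (n * n)
    for y in range(n):
        for x in range(n):
            c[y * n + x] = 0.0  # init statement, as in the IR
            for k in range(n):
                c[y * n + x] += a[y * n + k] * b[k * n + x]
    return c


n = 4
a = [float(i) for i in range(n * n)]
identity = [1.0 if i // n == i % n else 0.0 for i in range(n * n)]
result = matmul_flat(a, identity, n)  # multiplying by I should return a
```

Note that the innermost loop walks `b` with stride `n` (one element per row), which is the cache-unfriendly pattern the rest of the tutorial works on.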




TVM also provides a profiler to measure the kernel latency.

In [4]:
# Random generated tensor for testing
a = tvm.nd.array(numpy.random.rand(M, K).astype(dtype), ctx)
b = tvm.nd.array(numpy.random.rand(K, N).astype(dtype), ctx)
c = tvm.nd.array(numpy.zeros((M, N), dtype=dtype), ctx)

func = tvm.build(s, [A, B, C], target=target, name='mmult')
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, ctx, number=1)
print('Baseline: %f sec' % evaluator(a, b, c).mean)

Baseline: 4.860180 sec
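
A latency figure is easier to compare across shapes once it is converted to throughput. The following sketch assumes the usual operation count for a dense matrix multiplication (one multiply and one add per term of each inner product) and plugs in the baseline time printed above; the variable names are ours.

```python
M = K = N = 1024
baseline_seconds = 4.860180  # "Baseline" time printed above

# Each of the M * N outputs needs K multiplies and K adds.
flops = 2 * M * N * K
baseline_gflops = flops / baseline_seconds / 1e9
print("Baseline: %.3f GFLOP/s" % baseline_gflops)
```

On the machine that produced the timing above, this works out to roughly 0.44 GFLOP/s.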


## Blocking

An important trick to enhance the cache hit rate is blocking: the data is computed block by block. The memory accesses inside a block stay within a small neighbourhood, which gives high memory locality. In this tutorial, we pick 32 as the blocking factor, so a block occupies 32 * 32 * sizeof(float) = 4KB of the cache, whose total size is 32KB (L1 data cache).
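
The footprint arithmetic can be checked in a few lines. This sketch only restates the numbers from the paragraph above (a 32 x 32 block of float32 and a 32KB L1 data cache); the variable names are illustrative.

```python
bn = 32                # blocking factor used in this tutorial
bytes_per_float = 4    # sizeof(float) for float32
l1_bytes = 32 * 1024   # 32KB L1 data cache, as assumed in the text

block_bytes = bn * bn * bytes_per_float
blocks_that_fit = l1_bytes // block_bytes
print(block_bytes, blocks_that_fit)
```

One block takes 4096 bytes, so a single output tile uses one eighth of the assumed L1 capacity and leaves room for the slices of A and B that feed it.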

In [5]:
bn = 32
s = te.create_schedule(C.op)

# Get the iteration axis from the operator
y, x = s[C].op.axis
k, = s[C].op.reduce_axis

# Blocking by loop tiling
yo, yi = s[C].split(y, bn)
xo, xi = s[C].split(x, bn)
ko, ki = s[C].split(k, factor=4)

# Hoist reduction domain outside the blocking loop
s[C].reorder(yo, xo, ko, ki, yi, xi)

func = tvm.build(s, [A, B, C], target=target, name='mmult')
func(a, b, c)

# By simply tiling the loop 32x32, and hoisting ko, ki outside the blocking loops,
# we can see a big speedup compared with the baseline.
evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print('Opt1: %f' % evaluator(a, b, c).mean)

Opt1: 0.224869


In [6]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

PrimFunc([A, B, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  for (y.outer, 0, 32) {
    for (x.outer, 0, 32) {
      for (y.inner.init, 0, 32) {
        for (x.inner.init, 0, 32) {
          C[((((y.outer*32768) + (y.inner.init*1024)) + (x.outer*32)) + x.inner.init)] = 0f
        }
      }
      for (k.outer, 0, 256) {
        for (k.inner, 0, 4) {
          for (y.inner, 0, 32) {
            for (x.inner, 0, 32) {
              C[((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)) + x.inner)] = (C[((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)) + x.inner)] + (A[((((y.outer*32768) + (y.inner*1024)) + (k.outer*4)) + k.inner)]*B[((((k.outer*4096) + (k.inner*1024)) + (x.outer*32)) + x.inner)]))
            }
          }
        }
      }
    }
  }
}
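
The tiled loop nest above can be mirrored in plain Python to confirm that splitting and reordering do not change the result. This is a sketch with small stand-in sizes (`n = 8`, `bn = 4`, `kf = 2` in place of 1024, 32 and 4); the function names are made up, and the output buffer is zero-initialized up front instead of through the separate `.init` loops shown in the IR.

```python
def matmul_naive(a, b, n):
    """Reference triple loop over flat row-major lists."""
    c = [0.0] * (n * n)
    for y in range(n):
        for x in range(n):
            for k in range(n):
                c[y * n + x] += a[y * n + k] * b[k * n + x]
    return c


def matmul_blocked(a, b, n, bn, kf):
    """Same loop order as the IR above: yo, xo, ko, ki, yi, xi."""
    assert n % bn == 0 and n % kf == 0
    c = [0.0] * (n * n)
    for yo in range(n // bn):
        for xo in range(n // bn):
            for ko in range(n // kf):
                for ki in range(kf):
                    k = ko * kf + ki          # split(k, factor=kf)
                    for yi in range(bn):
                        y = yo * bn + yi      # split(y, bn)
                        for xi in range(bn):
                            x = xo * bn + xi  # split(x, bn)
                            c[y * n + x] += a[y * n + k] * b[k * n + x]
    return c


n, bn, kf = 8, 4, 2
a = [float((3 * i) % 7) for i in range(n * n)]
b = [float((5 * i) % 11) for i in range(n * n)]
same = matmul_blocked(a, b, n, bn, kf) == matmul_naive(a, b, n)
```

Because `k` still increases monotonically for every output element, the two versions accumulate in the same order and produce identical results.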




## Vectorization

Another important trick is vectorization. When the memory access pattern is uniform, the compiler can detect this pattern and pass the contiguous memory to the vector processor. In TVM, we can use the vectorize interface to hint this pattern to the compiler, which can speed up the computation substantially.

In this tutorial, we choose to vectorize the inner loop over row data, since it is cache friendly.

In [7]:
# Vectorization
s[C].vectorize(xi)

func = tvm.build(s, [A, B, C], target=target, name='mmult')
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print('Opt2: %f' % evaluator(a, b, c).mean)

Opt2: 0.244345


In [8]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

PrimFunc([A, B, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  for (y.outer, 0, 32) {
    for (x.outer, 0, 32) {
      for (y.inner.init, 0, 32) {
        C[ramp((((y.outer*32768) + (y.inner.init*1024)) + (x.outer*32)), 1, 32)] = x32(0f)
      }
      for (k.outer, 0, 256) {
        for (k.inner, 0, 4) {
          for (y.inner, 0, 32) {
            C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] = (C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] + (x32(A[((((y.outer*32768) + (y.inner*1024)) + (k.outer*4)) + k.inner)])*B[ramp((((k.outer*4096) + (k.inner*1024)) + (x.outer*32)), 1, 32)]))
          }
        }
      }
    }
  }
}
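
In the IR above, `ramp(base, 1, 32)` stands for the 32 consecutive indices starting at `base`, and `x32(v)` broadcasts the scalar `v` to 32 lanes. The sketch below models one vectorized update in plain Python; it uses 4 lanes instead of 32, and the helpers `ramp`, `broadcast` and `vector_fma_row` are illustrative stand-ins, not TVM APIs.

```python
def ramp(base, stride, lanes):
    """Indices base, base + stride, ..., one per vector lane."""
    return [base + stride * i for i in range(lanes)]


def broadcast(value, lanes):
    """Copy one scalar into every lane (written x32(...) in the IR)."""
    return [value] * lanes


def vector_fma_row(c, a_scalar, b, c_base, b_base, lanes):
    """c[ramp] = c[ramp] + broadcast(a) * b[ramp], lane by lane."""
    c_idx = ramp(c_base, 1, lanes)
    b_idx = ramp(b_base, 1, lanes)
    a_vec = broadcast(a_scalar, lanes)
    for lane in range(lanes):
        c[c_idx[lane]] += a_vec[lane] * b[b_idx[lane]]


c = [0.0] * 8
b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
vector_fma_row(c, 2.0, b, c_base=4, b_base=0, lanes=4)
```

On real hardware the per-lane loop is what a SIMD instruction performs at once; here it is spelled out so that the index arithmetic is visible.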




## Loop Permutation

If we look at the above IR, we can see the inner loop row data is vectorized for both B and C, so their traversal is sequential. Next we will look at the access pattern of A. In the current schedule, A is accessed column by column, which is not cache friendly. If we swap the nested loop order of ki and the inner axis yi, the access pattern for the A matrix becomes more cache friendly.
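
To make the stride argument concrete, the sketch below lists the flat offsets into A that the two innermost scalar loops touch, first with ki outside yi (the previous schedule) and then with yi outside ki (the schedule in the next cell). It fixes `yo = ko = 0` and uses the sizes from this tutorial; the function name and the `order` labels are made up for illustration.

```python
def a_offsets(order, n=1024, bn=32, kf=4):
    """Flat offsets A[y * n + k] visited for yo = ko = 0.

    order is "ki_yi" (ki outside yi) or "yi_ki" (yi outside ki).
    The vectorized xi loop is omitted because it does not index A.
    """
    if order == "ki_yi":
        return [yi * n + ki for ki in range(kf) for yi in range(bn)]
    if order == "yi_ki":
        return [yi * n + ki for yi in range(bn) for ki in range(kf)]
    raise ValueError(order)


before = a_offsets("ki_yi")
after = a_offsets("yi_ki")
print(before[:5])  # [0, 1024, 2048, 3072, 4096]
print(after[:5])   # [0, 1, 2, 3, 1024]
```

Both orders touch the same 128 elements, but the second one reads four adjacent floats before jumping to the next row, so consecutive accesses usually fall in the same cache line.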

In [9]:
s = te.create_schedule(C.op)
y, x = s[C].op.axis
k, = s[C].op.reduce_axis
yo, xo, yi, xi = s[C].tile(y, x, bn, bn)
ko, ki = s[C].split(k, factor=4)

# re-ordering
s[C].reorder(yo, xo, ko, yi, ki, xi)
s[C].vectorize(xi)

func = tvm.build(s, [A, B, C], target=target, name='mmult')
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print('Opt3: %f' % evaluator(a, b, c).mean)

Opt3: 0.066508


In [10]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

PrimFunc([A, B, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  for (y.outer, 0, 32) {
    for (x.outer, 0, 32) {
      for (y.inner.init, 0, 32) {
        C[ramp((((y.outer*32768) + (y.inner.init*1024)) + (x.outer*32)), 1, 32)] = x32(0f)
      }
      for (k.outer, 0, 256) {
        for (y.inner, 0, 32) {
          for (k.inner, 0, 4) {
            C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] = (C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] + (x32(A[((((y.outer*32768) + (y.inner*1024)) + (k.outer*4)) + k.inner)])*B[ramp((((k.outer*4096) + (k.inner*1024)) + (x.outer*32)), 1, 32)]))
          }
        }
      }
    }
  }
}




## Parallel

Furthermore, we can also utilize multi-core processors to do thread-level parallelization.

In [11]:
# parallel
s[C].parallel(yo)

func = tvm.build(s, [A, B, C], target=target, name='mmult')
func(a, b, c)

evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print('Opt4: %f' % evaluator(a, b, c).mean)

Opt4: 0.041917
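
The timings printed so far can be summarized as speedups over the baseline. The numbers below are copied from the outputs above and are specific to the machine that produced them; the labels in parentheses are ours.

```python
timings = {
    "Baseline": 4.860180,
    "Opt1 (blocking)": 0.224869,
    "Opt2 (vectorization)": 0.244345,
    "Opt3 (loop permutation)": 0.066508,
    "Opt4 (parallel)": 0.041917,
}

baseline = timings["Baseline"]
speedups = {name: baseline / seconds for name, seconds in timings.items()}

for name, factor in speedups.items():
    print("%-24s %8.6f s  %6.1fx" % (name, timings[name], factor))
```

In this particular run, vectorization alone (Opt2) is slightly slower than blocking alone (Opt1); the large gain appears once the loop order is also changed (Opt3).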


In [12]:
print(tvm.lower(s, [A, B, C], simple_mode=True))

PrimFunc([A, B, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  parallel (y.outer, 0, 32) {
    for (x.outer, 0, 32) {
      for (y.inner.init, 0, 32) {
        C[ramp((((y.outer*32768) + (y.inner.init*1024)) + (x.outer*32)), 1, 32)] = x32(0f)
      }
      for (k.outer, 0, 256) {
        for (y.inner, 0, 32) {
          for (k.inner, 0, 4) {
            C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] = (C[ramp((((y.outer*32768) + (y.inner*1024)) + (x.outer*32)), 1, 32)] + (x32(A[((((y.outer*32768) + (y.inner*1024)) + (k.outer*4)) + k.inner)])*B[ramp((((k.outer*4096) + (k.inner*1024)) + (x.outer*32)), 1, 32)]))
          }
        }
      }
    }
  }
}
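
Parallelizing over y.outer is safe because each iteration writes a disjoint band of rows of C, so the workers need no synchronization. The sketch below illustrates only that partitioning, using Python's standard `ThreadPoolExecutor` as a stand-in for a thread pool. It is not how TVM runs the kernel, and because of CPython's global interpreter lock it will not show a real speedup. Sizes and function names are illustrative.

```python
from concurrent.futures import ThreadPoolExecutor


def compute_row_block(a, b, c, n, bn, yo):
    """Fill rows yo * bn .. yo * bn + bn - 1 of c; no other block touches them."""
    for yi in range(bn):
        y = yo * bn + yi
        for x in range(n):
            acc = 0.0
            for k in range(n):
                acc += a[y * n + k] * b[k * n + x]
            c[y * n + x] = acc


def matmul_parallel(a, b, n, bn, workers=4):
    """One task per y.outer value, mirroring `parallel (y.outer, ...)`."""
    c = [0.0] * (n * n)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # list(...) forces completion and re-raises any worker exception.
        list(pool.map(lambda yo: compute_row_block(a, b, c, n, bn, yo),
                      range(n // bn)))
    return c


n, bn = 8, 4
a = [float(i % 5) for i in range(n * n)]
identity = [1.0 if i // n == i % n else 0.0 for i in range(n * n)]
parallel_result = matmul_parallel(a, identity, n, bn)
```

Since the row bands do not overlap, the result is independent of the number of workers and of the order in which the blocks finish.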


