<a href="https://colab.research.google.com/github/guanrenyang/TW_TVM/blob/main/tile_matmul.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell
# Installs the latest dev build of TVM from PyPI, with CUDA enabled. To use this,
# you must request a Google Colab instance with a GPU by going to Runtime ->
# Change runtime type -> Hardware accelerator -> GPU. If you wish to build from
# source, see https://tvm.apache.org/docs/install/from_source.html
pip install tlcpack-nightly-cu113 --pre -f https://tlcpack.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://tlcpack.ai/wheels




In [None]:
import tvm
import tvm.testing
from tvm import te
import numpy as np
import random

In [None]:
# Compilation targets: both use LLVM for the host side; only the device side differs.
tgt_gpu = tvm.target.Target(target="cuda", host="llvm")
tgt_cpu = tvm.target.Target(target="llvm", host="llvm")

# Device handles (device index 0) matching each target's kind.
gpu_0 = tvm.device(tgt_gpu.kind.name, 0)
cpu = tvm.device(tgt_cpu.kind.name, 0)

In [None]:
# Problem sizes. The kernel below declares A_transposed as (K, M) and each B tile
# as (N_dim, K_pruned), so M/K/N follow the usual C[M, N] = A[M, K] @ B[K, N] naming.
M = 1024
K = 2048
N = 1024
N_dim = 32      # N-extent of one B tile
K_pruned = 128  # K-extent of one B tile (K indices kept after pruning)
N_pruned = 128  # presumably the number of N columns kept over all blocks -- TODO confirm

# NOTE(review): this line was left as `sparsity = ` with no right-hand side, which
# made the whole cell a SyntaxError. The value below (fraction of B's K*N entries
# that are pruned away) is an assumption derived from the sizes above -- confirm.
sparsity = 1 - (K_pruned * N_pruned) / (K * N)
block_num = (N_pruned + N_dim - 1) // N_dim  # ceil(N_pruned / N_dim)
N_ori = N // block_num  # N columns per block before pruning

In [None]:
def mask_gen(N: int, N_pruned: int, base = 0):
  '''Pick `N_pruned` distinct indices at random from `range(N)`.

  The kept indices are returned as a Python list in ascending order, each one
  shifted by `base`.
  '''
  candidates = list(range(N))
  random.shuffle(candidates)  # permutes the list in place
  kept = sorted(candidates[:N_pruned])
  return [index + base for index in kept]
# Note on mask semantics:
# In the CUDA version of TW, `mask[i] + i` is the real index of the corresponding
# element, relative to the beginning of its block.
# In this TVM version, `mask[i]` itself is the real index (not `mask[i] + i`).
print(mask_gen(16, 8))

[0, 3, 4, 9, 10, 11, 12, 13]


In [None]:
def get_B_Stream(B, mask_k_list, mask_n_list, block_num, N_pruned_perBlock, K_pruned, N_original):
  '''Gather the kept entries of `B` into one dense, transposed tile per block.

  For block `bn` the tile satisfies
      tile[j, i] == B[mask_k_list[bn][i], mask_n_list[bn][j] + N_original * bn]

  Parameters:
  * B: the undivided 2-D matrix, indexed as B[k, n]. Either a NumPy array or an
       object exposing a `.numpy()` method that returns one.
  * mask_k_list: mask of K dimension of each block with type of `list(list(int))`
  * mask_n_list: mask of N dimension of each block with type of `list(list(int))`;
                 entries are relative to the start of their block
  * block_num: the number of blocks in B
  * N_pruned_perBlock: the number of remaining elements in the N dimension of each block
  * K_pruned: the number of remaining elements in the K dimension of each block
  * N_original: the number of N columns each block spans in `B`;
                used to compute the offset of each block in the N dimension
  Returns:
  * list of `block_num` float64 `np.ndarray`s, each of shape (N_pruned_perBlock, K_pruned)
  Raises:
  * IndexError: a mask has fewer entries than requested, or an index is out of bounds
  '''
  # Fancy indexing below needs a NumPy array; unwrap array wrappers that offer `.numpy()`.
  to_numpy = getattr(B, "numpy", None)
  if callable(to_numpy):
      B = to_numpy()
  B = np.asarray(B)

  B_transposed_tiled_list = []
  for bn in range(block_num):
      # Only the first K_pruned / N_pruned_perBlock mask entries are used.
      idx_k = np.asarray(mask_k_list[bn][:K_pruned], dtype=np.intp)
      idx_n = np.asarray(mask_n_list[bn][:N_pruned_perBlock], dtype=np.intp) + N_original * bn
      if idx_k.size < K_pruned or idx_n.size < N_pruned_perBlock:
          raise IndexError(
              "mask of block %d has fewer entries than K_pruned/N_pruned_perBlock" % bn)

      # Writing into a zeros buffer keeps the result float64 and C-contiguous
      # whatever the dtype of B is.
      dst = np.zeros((N_pruned_perBlock, K_pruned))
      dst[...] = B[np.ix_(idx_k, idx_n)].T
      B_transposed_tiled_list.append(dst)
  return B_transposed_tiled_list

# Unit test for `get_B_Stream`. The `_verify` suffix marks variables used only by this test.
K_verify = 8
N_verify = 16
B_verify = np.random.random((K_verify, N_verify))

K_pruned_verify = 4  # every block keeps the same number of K indices
N_pruned_global_verify = 8  # columns kept in the N dimension, summed over all blocks

tilesize_verify = 2  # columns kept per block

block_num_verify = (N_pruned_global_verify + tilesize_verify - 1) // tilesize_verify
N_original_perBlock_verify = N_verify // block_num_verify
print("block_num_verify:",block_num_verify)
print("N_original_perBlock_verify:", N_original_perBlock_verify)

mask_k_list_verify = [mask_gen(K_verify, K_pruned_verify) for _ in range(block_num_verify)]
mask_n_list_verify = [mask_gen(N_original_perBlock_verify, tilesize_verify) for _ in range(block_num_verify)]
print("mask_k_list_verify:", mask_k_list_verify)
print("mask_n_list_verify:", mask_n_list_verify)

B_transposed_tiled_list = get_B_Stream(B_verify, mask_k_list_verify,
                                       mask_n_list_verify, block_num_verify,
                                       tilesize_verify, K_pruned_verify, N_original_perBlock_verify)

# Actually verify every tile against direct indexing of the dense matrix,
# instead of relying on a visual comparison of the two printouts.
for bn in range(block_num_verify):
    cols = [n + N_original_perBlock_verify * bn for n in mask_n_list_verify[bn]]
    expected = B_verify[np.ix_(mask_k_list_verify[bn], cols)]
    np.testing.assert_array_equal(B_transposed_tiled_list[bn].T, expected)

block_to_check = 0  # index of the block to print, from 0 to block_num_verify - 1
# Block `bn` spans N_original_perBlock_verify columns starting at
# N_original_perBlock_verify * bn; the previous hard-coded `8 * block_to_check`
# start and width 4 only matched this for block 0.
col_begin = N_original_perBlock_verify * block_to_check
print("\nBlock of dense matrix:\n", B_verify[0:K_verify, col_begin:col_begin + N_original_perBlock_verify])
print("\nBlock of sparse matrix\n", B_transposed_tiled_list[block_to_check].T)

block_num_verify: 4
N_original_perBlock_verify: 4
mask_k_list_verify: [[0, 1, 2, 4], [3, 4, 6, 7], [0, 1, 2, 4], [0, 1, 2, 7]]
mask_n_list_verify: [[1, 3], [1, 3], [0, 3], [1, 2]]

Block of dense matrix:
 [[0.82392874 0.2529673  0.53781596 0.3805554 ]
 [0.55990472 0.2779939  0.36911807 0.63154789]
 [0.34482586 0.02511602 0.22861381 0.3542492 ]
 [0.48808761 0.97802291 0.14443363 0.83602811]
 [0.26096124 0.26661773 0.0823008  0.2562803 ]
 [0.74169749 0.71631852 0.64701425 0.95544594]
 [0.27538455 0.34235265 0.35923464 0.26231561]
 [0.02141543 0.3525766  0.7727293  0.66557609]]

Block of sparse matrix
 [[0.2529673  0.3805554 ]
 [0.2779939  0.63154789]
 [0.02511602 0.3542492 ]
 [0.26661773 0.2562803 ]]


In [None]:
def get_tiled_matmul_kernel():
  '''Build the TW tiled-GEMM computation and a default (untransformed) schedule for it.

  Shapes are taken from the module-level constants `K`, `M`, `K_pruned` and `N_dim`.

  Input of the kernel:
  * A_transposed: placeholder of shape (K, M)
  * B_transposed_tiled: placeholder of shape (N_dim, K_pruned), named 'B_tiled'
  * mask_k: placeholder of shape (K_pruned,), dtype 'int'; the rows of
            A_transposed that take part in the product
  Output of the kernel:
  * C_transposed_skipped: shape (N_dim, M), with
      C[j, i] = sum over k of A_transposed[mask_k[k], i] * B_transposed_tiled[j, k]

  Returns:
  * (schedule, [C_transposed_skipped, A_transposed, B_transposed_tiled, mask_k])
  '''
  a_t = te.placeholder((K, M), name='A_transposed')
  row_mask = te.placeholder((K_pruned,), name='mask_k', dtype='int')
  b_tile = te.placeholder((N_dim, K_pruned), name='B_tiled')

  # Stage 1: gather the rows of A_transposed selected by the mask.
  a_skipped = te.compute(
      (K_pruned, M),
      lambda i,j: a_t[row_mask[i], j],
      name='A_skipped',
  )

  # Stage 2: contract the gathered rows with the B tile over the pruned K axis.
  k = te.reduce_axis((0, K_pruned), name='k')
  c_t = te.compute(
      (N_dim, M),
      lambda j,i: te.sum(a_skipped[k, i]*b_tile[j, k], axis=k),
      name='C_transposed_skipped',
  )

  return te.create_schedule(c_t.op), [c_t, a_t, b_tile, row_mask]

# Testing: build the kernel for the CPU target and run it once on random inputs.
schedule, placeholders = get_tiled_matmul_kernel()
for ph in placeholders:
  print(ph.op.name, ph.shape)

print(tvm.lower(schedule, placeholders, simple_mode=True))

tiled_matmul_kernel = tvm.build(schedule, placeholders, target=tgt_cpu, name="tiled_matmul")


def _random_cpu_array(shape, dtype):
  '''Return uniform-random data of `shape`, cast to `dtype`, as a TVM array on `cpu`.'''
  return tvm.nd.array(np.random.uniform(size=shape).astype(dtype), cpu)


# `placeholders` is ordered [C, A, B, mask_k]; every array takes its placeholder's dtype.
A_transposed_data = _random_cpu_array((K, M), placeholders[1].dtype)
B_transposed_tiled_data = _random_cpu_array((N_dim, K_pruned), placeholders[2].dtype)
mask_k_data = tvm.nd.array(np.array(mask_gen(K, K_pruned)).astype(placeholders[3].dtype), cpu)

# Output buffer, passed as the kernel's first argument.
C_transposed_skipped_data = _random_cpu_array((N_dim, M), placeholders[0].dtype)

tiled_matmul_kernel(C_transposed_skipped_data, A_transposed_data, B_transposed_tiled_data, mask_k_data)

def tiled_matmul_test(A_transposed, B_transposed_tiled, mask_k):
    '''NumPy reference for the tiled matmul kernel.

    Parameters:
    * A_transposed: 2-D array; its rows are selected through `mask_k`
    * B_transposed_tiled: 2-D array of shape (n_tile, k_pruned)
    * mask_k: integer row indices into `A_transposed`; only the first `k_pruned`
              entries are used
    Returns:
    * float64 array of shape (n_tile, A_transposed.shape[1]) with
        C[j, i] = sum over k of A_transposed[mask_k[k], i] * B_transposed_tiled[j, k]

    All sizes are derived from the arguments, so the function does not depend
    on the module-level `K_pruned` / `M` constants.
    '''
    # The reduction length is fixed by the B tile; the matmul below requires it anyway.
    k_pruned = B_transposed_tiled.shape[1]
    rows = np.asarray(mask_k, dtype=np.intp)[:k_pruned]
    # Vectorized row gather; float64 matches the zeros buffer that used to be filled
    # element by element.
    A_transposed_skipped = np.asarray(A_transposed)[rows].astype(np.float64)
    C_transposed_skipped = (A_transposed_skipped.T @ B_transposed_tiled.T).T
    return C_transposed_skipped

# Compare the kernel's output with the NumPy reference; the trailing 1e-6 is the
# tolerance argument passed positionally to assert_allclose.
C_reference = tiled_matmul_test(
    A_transposed_data.numpy(),
    B_transposed_tiled_data.numpy(),
    mask_k_data.numpy(),
)
tvm.testing.assert_allclose(C_transposed_skipped_data.numpy(), C_reference, 1e-6)


C_transposed_skipped [32, 1024]
A_transposed [2048, 1024]
B_tiled [32, 128]
mask_k [128]
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(C_transposed_skipped: T.Buffer((32, 1024), "float32"), A_transposed: T.Buffer((2048, 1024), "float32"), B_tiled: T.Buffer((32, 128), "float32"), mask_k: T.Buffer((128,), "int32")):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_skipped = T.allocate([131072], "float32", "global")
        A_skipped_1 = T.Buffer((131072,), data=A_skipped)
        for i, j in T.grid(128, 1024):
            A_transposed_1 = T.Buffer((2097152,), data=A_transposed.data)
            mask_k_1 = T.Buffer((128,), "int32", data=mask_k.data)
            A_skipped_1[i * 1024 + j] = A_transposed_1[mask_k_1[i] * 1024 + j]
        for j, i in T.grid(32, 1024):
            C_transposed_skipped_1 = T.Buffer((32768,), data=C_transposed_skipped.data)

In [None]:
def get_tiled_matmul_kernel():
  '''TW Tiled-Gemm kernel (variant that also prints the output's data-parallel axes).

  NOTE: this redefines the `get_tiled_matmul_kernel` from the earlier cell. It
  returns the same `(schedule, tensors)` tuple as that definition, so code that
  calls the function after this cell has run keeps working instead of receiving
  None.

  Input of the kernel:
  * A_transposed: placeholder of shape (K, M)
  * B_transposed_tiled: placeholder of shape (N_dim, K_pruned), named 'B_tiled'
  * mask_k: placeholder of shape (K_pruned,), dtype 'int'
  Output of the kernel:
  * C_transposed_skipped: shape (N_dim, M), with
      C[j, i] = sum over k of A_transposed[mask_k[k], i] * B_transposed_tiled[j, k]

  Returns:
  * (schedule, [C_transposed_skipped, A_transposed, B_transposed_tiled, mask_k])
  '''
  A_transposed = te.placeholder((K, M), name='A_transposed')
  mask_k = te.placeholder((K_pruned,), name='mask_k', dtype='int')
  B_transposed_tiled = te.placeholder((N_dim, K_pruned), name='B_tiled')

  # Gather the rows of A_transposed selected by mask_k.
  A_transposed_skipped = te.compute((K_pruned, M), lambda i,j: A_transposed[mask_k[i], j], name='A_skipped')

  k = te.reduce_axis((0, K_pruned), name='k')
  C_transposed_skipped = te.compute((N_dim, M),lambda j,i: te.sum(A_transposed_skipped[k, i]*B_transposed_tiled[j, k], axis=k),name='C_transposed_skipped')

  s = te.create_schedule(C_transposed_skipped.op)

  # Inspect the data-parallel axes of the output stage.
  print(C_transposed_skipped.op.axis)

  return s, [C_transposed_skipped, A_transposed, B_transposed_tiled, mask_k]

get_tiled_matmul_kernel();  # trailing ';' keeps the returned tuple out of the cell output

[T.iter_var(j, T.Range(0, 32), "DataPar", ""), T.iter_var(i, T.Range(0, 1024), "DataPar", "")]
