<a href="https://colab.research.google.com/github/guanrenyang/TW_TVM/blob/main/tile_matmul.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os


def prepend_env_path(var, entry, default=""):
    """Put `entry` at the front of the os.pathsep-separated environment variable `var`.

    Existing entries are kept (minus any duplicate of `entry`), so re-running the
    cell is idempotent and standard tool directories stay on the search path.
    `default` is used as the existing value when `var` is unset.
    Echoes the new value in the same `env: VAR=value` form as `%env`.
    """
    current = os.environ.get(var, default)
    rest = [p for p in current.split(os.pathsep) if p and p != entry]
    os.environ[var] = os.pathsep.join([entry, *rest])
    print(f"env: {var}={os.environ[var]}")


# Make the CUDA toolchain (nvcc) and libraries visible to TVM without discarding
# the existing search paths (overwriting PATH would drop /bin, /usr/local/bin, ...).
prepend_env_path("LD_LIBRARY_PATH", "/usr/local/cuda/lib")
prepend_env_path("PATH", "/usr/local/cuda/bin", default="/usr/bin")

env: LD_LIBRARY_PATH=/usr/local/cuda/lib
env: PATH=/usr/local/cuda/bin:/usr/bin


In [2]:
import tvm
import tvm.testing
from tvm import te
import numpy as np
import random

In [3]:
# GPU target: CUDA device code with LLVM-compiled host glue.
tgt_gpu = tvm.target.Target(target="cuda", host="llvm")
# CPU target: both device and host code go through LLVM.
tgt_cpu = tvm.target.Target(target="llvm", host="llvm")

# Device handles (device id 0) looked up from each target's kind name.
gpu_0 = tvm.device(tgt_gpu.kind.name, 0)
cpu = tvm.device(tgt_cpu.kind.name, 0)

In [7]:
# Small problem size for quick experiments; the kernel cells below redefine these.
M = 32
N = 32
K = 16
# 32 matches the tile size (and the 32x32 thread tiling) used by the kernels below.
tile_size = 32

In [8]:
def mask_gen(N: int, N_pruned: int, base=0, rng=None):
    '''Return a sorted random subset of `N_pruned` distinct indices from `range(N)`, each shifted by `base`.

    A tool to generate a pruning mask (as a Python list) given the number of
    original elements `N` and the number of remaining elements `N_pruned`.

    Parameters:
    * N: number of original elements; indices are drawn from `range(N)`
    * N_pruned: number of indices to keep; must satisfy 0 <= N_pruned <= N
    * base: offset added to every kept index
    * rng: optional `random.Random` instance for reproducible masks;
           defaults to the module-level `random` generator

    Returns a list of `N_pruned` ints in ascending order.
    Raises ValueError if `N_pruned` is outside [0, N].
    '''
    if not 0 <= N_pruned <= N:
        # Slicing a shuffled list would silently return too few indices (or,
        # for a negative count, almost all of them); fail loudly instead.
        raise ValueError(f"N_pruned must be in [0, {N}], got {N_pruned}")
    sampler = random if rng is None else rng
    # sample() draws distinct indices; sorting keeps the kept elements in their original order.
    return [idx + base for idx in sorted(sampler.sample(range(N), N_pruned))]
# Mask convention:
# - In the CUDA implementation of TW, `mask[i] + i` is the actual index of the
#   i-th kept element, relative to the start of its block.
# - In this TVM version, `mask[i]` itself stores that actual index.
print(mask_gen(N=16, N_pruned=8))

[2, 4, 6, 7, 8, 12, 13, 14]


In [9]:
def get_B_Stream(B, mask_k_list, mask_n_list, block_num, N_pruned_perBlock, K_pruned, N_original, dtype=np.float64):
    ''' Gather the kept entries of each block of B into a list of transposed tiles.

    Block `bn` of B spans columns [N_original * bn, N_original * (bn + 1)). Its tile is
        dst[j, i] = B[mask_k_list[bn][i], mask_n_list[bn][j] + N_original * bn]
    i.e. the K_pruned x N_pruned_perBlock sub-matrix selected by the two masks, transposed.

    Parameters:
    * B: the undivided matrix, indexed as B[k, n]; a NumPy array or any object
         exposing `.numpy()` (e.g. `tvm.nd.NDArray`)
    * mask_k_list: per-block K-dimension masks, `list(list(int))`; only the first
                   `K_pruned` entries of each mask are used
    * mask_n_list: per-block N-dimension masks relative to the start of the block,
                   `list(list(int))`; only the first `N_pruned_perBlock` entries are used
    * block_num: the number of blocks to extract
    * N_pruned_perBlock: the number of remaining elements in the N dimension of each block
    * K_pruned: the number of remaining elements in the K dimension of each block
    * N_original: the number of columns each block spans in the undivided matrix,
                  i.e. the per-block stride along N (not the full width of B)
    * dtype: dtype of the returned tiles (float64 by default)

    Returns a `list(np.ndarray)` of C-contiguous tiles of shape (N_pruned_perBlock, K_pruned).
    Raises IndexError if a mask is too short or an index falls outside B.
    '''
    B_np = B.numpy() if hasattr(B, "numpy") else np.asarray(B)
    B_transposed_tiled_list = []
    for bn in range(block_num):
        mask_k = mask_k_list[bn]
        mask_n = mask_n_list[bn]
        if len(mask_k) < K_pruned or len(mask_n) < N_pruned_perBlock:
            raise IndexError(
                f"block {bn}: need {K_pruned} K indices and {N_pruned_perBlock} N indices, "
                f"got {len(mask_k)} and {len(mask_n)}")
        rows = np.asarray(mask_k[:K_pruned], dtype=np.intp)
        cols = np.asarray(mask_n[:N_pruned_perBlock], dtype=np.intp) + N_original * bn
        # np.ix_ selects the rows x cols sub-matrix in a single gather; .T puts the
        # N dimension first to match the (N_pruned_perBlock, K_pruned) tile layout.
        tile = B_np[np.ix_(rows, cols)].T
        B_transposed_tiled_list.append(np.ascontiguousarray(tile, dtype=dtype))
    return B_transposed_tiled_list

# Unit test for get_B_Stream: names ending in `_verify` belong to this check only.
B_verify = np.random.random((8, 16))
K_verify = 8
N_verify = 16

K_pruned_verify = 4  # the numbers of remaining elements in the K dimension of each block are the same
N_pruned_global_verify = 8  # but those in the N dimension are different

tilesize_verify = 2

block_num_verify = (N_pruned_global_verify + tilesize_verify - 1) // tilesize_verify
N_original_perBlock_verify = N_verify // block_num_verify
print("block_num_verify:", block_num_verify)
print("N_original_perBlock_verify:", N_original_perBlock_verify)

mask_k_list_verify = [mask_gen(K_verify, K_pruned_verify) for _ in range(block_num_verify)]
mask_n_list_verify = [mask_gen(N_original_perBlock_verify, tilesize_verify) for _ in range(block_num_verify)]
print("mask_k_list_verify:", mask_k_list_verify)
print("mask_n_list_verify:", mask_n_list_verify)

B_transposed_tiled_list = get_B_Stream(B_verify, mask_k_list_verify,
                                       mask_n_list_verify, block_num_verify,
                                       tilesize_verify, K_pruned_verify, N_original_perBlock_verify)

# Automatic check: every tile must equal a direct gather from B_verify.
for bn_verify, tile_verify in enumerate(B_transposed_tiled_list):
    cols_verify = np.asarray(mask_n_list_verify[bn_verify]) + N_original_perBlock_verify * bn_verify
    expected_verify = B_verify[np.ix_(mask_k_list_verify[bn_verify], cols_verify)]
    np.testing.assert_allclose(tile_verify.T, expected_verify)

block_to_check = 0  # block to display, from 0 to block_num_verify - 1
# Block `bn` spans N_original_perBlock_verify columns starting at column bn * N_original_perBlock_verify.
col_start_verify = N_original_perBlock_verify * block_to_check
print("\nBlock of dense matrix:\n",
      B_verify[:, col_start_verify:col_start_verify + N_original_perBlock_verify])
print("\nBlock of sparse matrix\n", B_transposed_tiled_list[block_to_check].T)

block_num_verify: 4
N_original_perBlock_verify: 4
mask_k_list_verify: [[1, 4, 6, 7], [0, 1, 2, 5], [2, 4, 5, 7], [2, 3, 5, 6]]
mask_n_list_verify: [[1, 2], [0, 3], [1, 3], [0, 1]]

Block of dense matrix:
 [[0.7338991  0.9039647  0.64008491 0.93956132]
 [0.30641148 0.19852933 0.52610446 0.96497124]
 [0.83806996 0.41329319 0.86986179 0.64806269]
 [0.01819067 0.86111985 0.45010897 0.7578844 ]
 [0.4083653  0.53162131 0.32541112 0.35413239]
 [0.54418759 0.51079587 0.32296146 0.19155029]
 [0.47227041 0.64321622 0.25844482 0.73020564]
 [0.82890208 0.53941138 0.32041523 0.42371766]]

Block of sparse matrix
 [[0.19852933 0.52610446]
 [0.53162131 0.32541112]
 [0.64321622 0.25844482]
 [0.53941138 0.32041523]]


In [16]:
# Problem size for the tiled-matmul experiments.
M, K, N = 1024, 1024, 1024
tile_size = 32
K_pruned = 128  # remaining elements along K
N_pruned = 128  # remaining elements along N


# Number of N-blocks (ceiling division) and original N width spanned by each block.
block_num = -(-N_pruned // tile_size)
N_ori = N // block_num
def get_tiled_matmul_kernel(cuda=False, x_factor=32, y_factor=32, verbose=True):
    '''Build the TE schedule for one TW tile of the GEMM.

    Shapes come from the module-level globals `M`, `K`, `K_pruned` and `tile_size`.

    Inputs of the kernel:
    * A_transposed: (K, M)
    * B_transposed_tiled: (tile_size, K_pruned), placeholder name 'B_tiled'
    * mask_k: (K_pruned,) integer row indices into A_transposed
    Output of the kernel:
    * C_transposed_skipped: (tile_size, M), with
      C_transposed_skipped[j, i] = sum_k A_transposed[mask_k[k], i] * B_transposed_tiled[j, k]

    Parameters:
    * cuda: bind the tiled loops to CUDA blockIdx/threadIdx axes
    * x_factor, y_factor: tile sizes along the M axis and the tile_size axis
      (32 x 32 by default, i.e. 1024 threads per CUDA block when cuda=True)
    * verbose: print the lowered IR after each scheduling step

    Returns (schedule, [C_transposed_skipped, A_transposed, B_transposed_tiled, mask_k]);
    the list is the argument order used for `tvm.build` and for calling the built function.
    '''
    A_transposed = te.placeholder((K, M), name='A_transposed')
    mask_k = te.placeholder((K_pruned,), name='mask_k', dtype='int')
    B_transposed_tiled = te.placeholder((tile_size, K_pruned), name='B_tiled')

    # Gather only the K rows of A that survive pruning.
    A_transposed_skipped = te.compute((K_pruned, M), lambda i, j: A_transposed[mask_k[i], j], name='A_skipped')

    k = te.reduce_axis((0, K_pruned), name='k')
    C_transposed_skipped = te.compute(
        (tile_size, M),
        lambda j, i: te.sum(A_transposed_skipped[k, i] * B_transposed_tiled[j, k], axis=k),
        name='C_transposed_skipped')

    args = [C_transposed_skipped, A_transposed, B_transposed_tiled, mask_k]
    s = te.create_schedule(C_transposed_skipped.op)
    stage = s[C_transposed_skipped]

    def show(title):
        # Print the current lowered IR when verbose output is requested.
        if verbose:
            print(title)
            print(tvm.lower(s, args, simple_mode=True))

    show('\nDefault schedule')

    yo, xo, yi, xi = stage.tile(C_transposed_skipped.op.axis[0], C_transposed_skipped.op.axis[1],
                                x_factor=x_factor, y_factor=y_factor)
    show('\nAfter split')

    if cuda:
        stage.bind(yo, te.thread_axis("blockIdx.y"))
        stage.bind(xo, te.thread_axis("blockIdx.x"))
        stage.bind(yi, te.thread_axis("threadIdx.y"))
        stage.bind(xi, te.thread_axis("threadIdx.x"))
        show('\nLaunch threads')

    return s, args
get_tiled_matmul_kernel(cuda=False)



Default schedule
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(C_transposed_skipped: T.Buffer((32, 1024), "float32"), A_transposed: T.Buffer((2048, 1024), "float32"), B_tiled: T.Buffer((32, 128), "float32"), mask_k: T.Buffer((128,), "int32")):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_skipped = T.allocate([131072], "float32", "global")
        A_skipped_1 = T.Buffer((131072,), data=A_skipped)
        for i, j in T.grid(128, 1024):
            A_transposed_1 = T.Buffer((2097152,), data=A_transposed.data)
            mask_k_1 = T.Buffer((128,), "int32", data=mask_k.data)
            A_skipped_1[i * 1024 + j] = A_transposed_1[mask_k_1[i] * 1024 + j]
        for j, i in T.grid(32, 1024):
            C_transposed_skipped_1 = T.Buffer((32768,), data=C_transposed_skipped.data)
            C_transposed_skipped_1[j * 1024 + i] = T.float32(0)
      

(schedule(0x2e68f70),
 [Tensor(shape=[32, 1024], op.name=C_transposed_skipped),
  Tensor(shape=[2048, 1024], op.name=A_transposed),
  Tensor(shape=[32, 128], op.name=B_tiled),
  Tensor(shape=[128], op.name=mask_k)])

In [17]:
# Testing cpu: build the tiled kernel for LLVM and run it once on random data.
schedule, placeholders = get_tiled_matmul_kernel(cuda=False)
for ph in placeholders:
    print(ph.op.name, ph.shape)

tiled_matmul_kernel = tvm.build(schedule, placeholders, target=tgt_cpu, name="tiled_matmul")

# Host buffers in each placeholder's declared dtype (argument order: C, A, B, mask_k).
A_transposed_data = tvm.nd.array(
    np.random.uniform(size=(K, M)).astype(placeholders[1].dtype), cpu)
B_transposed_tiled_data = tvm.nd.array(
    np.random.uniform(size=(tile_size, K_pruned)).astype(placeholders[2].dtype), cpu)
mask_k_data = tvm.nd.array(
    np.array(mask_gen(K, K_pruned)).astype(placeholders[3].dtype), cpu)

# C is filled with random values (not zeros), so elements the kernel does not
# write would show up as mismatches in the comparison below.
C_transposed_skipped_data = tvm.nd.array(
    np.random.uniform(size=(tile_size, M)).astype(placeholders[0].dtype), cpu)

tiled_matmul_kernel(C_transposed_skipped_data, A_transposed_data,
                    B_transposed_tiled_data, mask_k_data)

def tiled_matmul_test(A_transposed, B_transposed_tiled, mask_k, k_pruned=None, m=None):
    '''NumPy reference for the tiled TW kernel, computed in float64.

    Returns C_transposed_skipped of shape (B_transposed_tiled.shape[0], m) with
        C[j, i] = sum_k A_transposed[mask_k[k], i] * B_transposed_tiled[j, k]

    Parameters:
    * A_transposed: 2-D array with at least `m` columns
    * B_transposed_tiled: 2-D array with `k_pruned` columns
    * mask_k: row indices into A_transposed; the first `k_pruned` entries are used
    * k_pruned, m: sizes to use; default to the module-level `K_pruned` and `M`

    Raises IndexError if mask_k has fewer than k_pruned entries or
    A_transposed has fewer than m columns.
    '''
    if k_pruned is None:
        k_pruned = K_pruned
    if m is None:
        m = M
    A_np = np.asarray(A_transposed)
    rows = np.asarray(mask_k)[:k_pruned]
    if rows.shape[0] < k_pruned or A_np.shape[1] < m:
        raise IndexError(
            f"need {k_pruned} mask entries and {m} columns, "
            f"got {rows.shape[0]} and {A_np.shape[1]}")
    # One fancy-indexing gather replaces the element-wise copy; float64 keeps the
    # reference more precise than the float32 kernel it is compared against.
    A_transposed_skipped = A_np[rows, :m].astype(np.float64)
    # (A_skipped.T @ B.T).T == B @ A_skipped
    return np.asarray(B_transposed_tiled, dtype=np.float64) @ A_transposed_skipped

# float32 kernel vs float64 reference: each output element sums K_pruned products,
# so float32 rounding can approach or exceed a 1e-6 relative error; 1e-5 leaves headroom.
tvm.testing.assert_allclose(
    C_transposed_skipped_data.numpy(),
    tiled_matmul_test(A_transposed_data.numpy(), B_transposed_tiled_data.numpy(), mask_k_data.numpy()),
    rtol=1e-5,
)


Default schedule
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(C_transposed_skipped: T.Buffer((32, 1024), "float32"), A_transposed: T.Buffer((2048, 1024), "float32"), B_tiled: T.Buffer((32, 128), "float32"), mask_k: T.Buffer((128,), "int32")):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_skipped = T.allocate([131072], "float32", "global")
        A_skipped_1 = T.Buffer((131072,), data=A_skipped)
        for i, j in T.grid(128, 1024):
            A_transposed_1 = T.Buffer((2097152,), data=A_transposed.data)
            mask_k_1 = T.Buffer((128,), "int32", data=mask_k.data)
            A_skipped_1[i * 1024 + j] = A_transposed_1[mask_k_1[i] * 1024 + j]
        for j, i in T.grid(32, 1024):
            C_transposed_skipped_1 = T.Buffer((32768,), data=C_transposed_skipped.data)
            C_transposed_skipped_1[j * 1024 + i] = T.float32(0)
      

In [6]:
# N_pruned_global is the total number of remaining entries in the N dimension (across all blocks)
# TODO: change K_pruned to a layer-wise configuration
def get_tw_kernel(M: int, N: int, K: int, K_pruned_max: int, N_pruned_global: int, tile_size: int, cuda: bool = False):
    '''Build the TE schedule for the full TW (tile-wise sparse) GEMM.

    C_transposed is produced in two steps:
    1. Per block `bn`, gather the kept K rows of A and multiply by the packed tile of B:
       C_transposed_skipped[bn, j, i] = sum_k A_transposed[mask_k[bn, k], i] * B_transposed_packed[bn, j, k]
    2. An extern (TIR IR builder) zeroes C_transposed and scatters tile row `j` of
       block `bn` into row `N_ori_per_block * bn + mask_n[bn, j]` (accumulating on collisions).

    Inputs of the kernel:
    * A_transposed: (K, M)
    * B_transposed_packed: (block_num, tile_size, K_pruned_max)
    * mask_k: (block_num, K_pruned_max), row indices into A_transposed
    * mask_n: (block_num, tile_size), indices relative to the start of each block
    Output of the kernel:
    * C_transposed: (N, M), dtype float16

    Here block_num = ceil(N_pruned_global / tile_size) and each block spans
    N_ori_per_block = N // block_num rows of C_transposed.

    `cuda` is currently unused: the schedule is only built for `tgt_cpu` as a compile check.

    Returns (schedule, [C_transposed, A_transposed, B_transposed_packed, mask_k, mask_n]).
    '''
    dtype = 'float16'
    block_num = (N_pruned_global + tile_size - 1) // tile_size
    # Span of each block along N; mask_n entries are offsets inside this span.
    N_ori_per_block = N // block_num

    A_transposed = te.placeholder((K, M), name='A_transposed')
    B_transposed_packed = te.placeholder((block_num, tile_size, K_pruned_max), name='B_transposed_packed')

    mask_k = te.placeholder((block_num, K_pruned_max), name='mask_k', dtype='int')
    mask_n = te.placeholder((block_num, tile_size), name='mask_n', dtype='int')

    A_transposed_skipped = te.compute(
        (block_num, K_pruned_max, M),
        lambda bn, i, j: A_transposed[mask_k[bn, i], j].astype(dtype),
        name='A_transposed_skipped')

    k = te.reduce_axis((0, K_pruned_max), name='k')
    C_transposed_skipped = te.compute(
        (block_num, tile_size, M),
        lambda bn, j, i: te.sum(A_transposed_skipped[bn, k, i] * B_transposed_packed[bn, j, k].astype(dtype), axis=k),
        name='C_transposed_skipped')

    def write_C_to_sparse(data, mask_n, out):
        '''Scatter the dense per-block tiles into the full output.

        data: buffer of shape (block_num, tile_size, M)
        mask_n: buffer of shape (block_num, tile_size)
        out: buffer of shape (N, M); zeroed, then
             out[N_ori_per_block * bn + mask_n[bn, ts], :] += data[bn, ts, :]
        '''
        irb = tvm.tir.ir_builder.create()
        data_ptr = irb.buffer_ptr(data)
        mask_n_ptr = irb.buffer_ptr(mask_n)
        out_ptr = irb.buffer_ptr(out)

        assert data.shape[0] == mask_n.shape[0], 'block_num mismatches'
        block_num = data.shape[0]
        assert data.shape[1] == mask_n.shape[1], 'tile_size mismatches'
        tile_size = data.shape[1]

        N = out.shape[0]
        M = out.shape[1]

        with irb.for_range(0, N, kind='serial', name='n') as n:
            with irb.for_range(0, M, kind='serial', name='m') as m:
                out_ptr[n * M + m] = tvm.tir.generic.cast(0, data.dtype)

        with irb.for_range(0, block_num, kind='serial', name='bn') as bn:
            with irb.for_range(0, tile_size, kind='serial', name='ts') as ts:
                # Indices here are flat row-major offsets (as for out_ptr/data_ptr),
                # so mask_n[bn, ts] is element bn * tile_size + ts. Block bn starts
                # at row N_ori_per_block * bn, matching get_B_Stream's convention.
                dst_row = N_ori_per_block * bn + mask_n_ptr[bn * tile_size + ts]
                with irb.for_range(0, M, kind='serial', name='col') as col:
                    out_ptr[dst_row * M + col] += data_ptr[(bn * tile_size + ts) * M + col]
        return irb.get()

    C_transposed = te.extern((N, M),
                             [C_transposed_skipped, mask_n],
                             lambda ins, outs: write_C_to_sparse(ins[0], ins[1], outs[0]),
                             tag='write_C_to_sparse',
                             dtype=C_transposed_skipped.dtype,
                             name='C_transposed',
                             )

    s = te.create_schedule(C_transposed.op)
    args = [C_transposed, A_transposed, B_transposed_packed, mask_k, mask_n]

    # Compile check on the CPU target; the built module is not executed here.
    # TODO: run it on random inputs and compare against a NumPy reference.
    tvm.build(s, args, tgt_cpu, name='tiled_matmul')

    print(tvm.lower(s, args, simple_mode=True))
    return s, args
    
get_tw_kernel(M=1024, N=1024, K=1024, K_pruned_max=128, N_pruned_global=512, tile_size=32);


# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(C_transposed: T.Buffer((1024, 1024), "float16"), A_transposed: T.Buffer((1024, 1024), "float32"), B_transposed_packed: T.Buffer((16, 32, 128), "float32"), mask_k: T.Buffer((16, 128), "int32"), mask_n: T.Buffer((16, 32), "int32")):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_transposed_skipped = T.allocate([2097152], "float16", "global")
        C_transposed_skipped = T.allocate([524288], "float16", "global")
        A_transposed_skipped_1 = T.Buffer((2097152,), "float16", data=A_transposed_skipped)
        for bn, i, j in T.grid(16, 128, 1024):
            A_transposed_1 = T.Buffer((1048576,), data=A_transposed.data)
            mask_k_1 = T.Buffer((2048,), "int32", data=mask_k.data)
            A_transposed_skipped_1[bn * 131072 + i * 1024 + j] = T.Cast("float16", A_transposed_1[mask_k_1[bn * 128