# MoST: high-level scheduling tools

In [1]:
import MoST.MoST_base as most
import MoST.qast_utils.loopReader as lr
import MoST.transforms.TilingSchedule as ts
import MoST.transforms.ReorderingSchedule as rs

We start by generating a new kernel, here for SGEMM. The kernel printed below is the code the user actually writes:

In [2]:
# Build the example SGEMM kernel with the loopReader debug helper and print
# its source; `sg` is the starting point for the schedules below.
sg = lr.__debug_new_sgemm()
print(sg)

def sgemm_full(N: size, M: size, K: size, C: f32[N, M] @ DRAM,
               A: f32[N, K] @ DRAM, B: f32[K, M] @ DRAM):
    for i in par(0, N):
        for j in par(0, M):
            for k in par(0, K):
                C[i, j] += A[i, k] * B[k, j]



The first thing we do is to lock this to a specific problem size:

In [3]:
# Specialize the kernel by fixing each of its size parameters to 512; the
# printed signature no longer takes N, M, K and the loop bounds are constants.
sg_const = sg.partial_eval(512, 512, 512)
print(sg_const)

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for i in par(0, 512):
        for j in par(0, 512):
            for k in par(0, 512):
                C[i, j] += A[i, k] * B[k, j]



We can generate specific transformations, for instance, tiling:

In [4]:
# Tiling schedule keyed by loop name: i, j, k are tiled by 8, 16, 8.
# In the printed result each loop is split into an `<name>_out`/`<name>_in` pair.
tile_8_16_8 = ts.TilingSchedule({'i':8, 'j':16, 'k':8})
# The schedule can be applied to either the symbolic kernel (sg) or the
# fixed-size kernel (sg_const).
print(tile_8_16_8.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for i_out in par(0, 64):
        for j_out in par(0, 32):
            for k_out in par(0, 64):
                for i_in in par(0, 8):
                    for j_in in par(0, 16):
                        for k_in in par(0, 8):
                            C[8 * i_out + i_in, 16 * j_out +
                              j_in] += A[8 * i_out + i_in, 8 * k_out +
                                         k_in] * B[8 * k_out + k_in,
                                                   16 * j_out + j_in]



... or reordering...

In [5]:
# Reordering schedule: the list gives the desired loop order, outermost first
# (the printed nest is k, then i, then j).
reorder_kij = rs.ReorderingSchedule(['k', 'i', 'j'])
print(reorder_kij.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for k in par(0, 512):
        for i in par(0, 512):
            for j in par(0, 512):
                C[i, j] += A[i, k] * B[k, j]



These transforms can be combined to generate more complicated, higher-level transforms, e.g.:

In [6]:
# Combine the two schedules defined above into one compound schedule.
# NOTE(review): the output is consistent with the schedules being applied in
# list order (reorder first, then tile) — confirm against CompoundSchedule.
cs = most.CompoundSchedule([reorder_kij, tile_8_16_8])
print(cs.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for k_out in par(0, 64):
        for i_out in par(0, 64):
            for j_out in par(0, 32):
                for k_in in par(0, 8):
                    for i_in in par(0, 8):
                        for j_in in par(0, 16):
                            C[8 * i_out + i_in, 16 * j_out +
                              j_in] += A[8 * i_out + i_in, 8 * k_out +
                                         k_in] * B[8 * k_out + k_in,
                                                   16 * j_out + j_in]



One could also do multilevel tiling (CoSA style) as follows:

In [7]:
# First level: tile i, j, k by 128 each.
t1 = ts.TilingSchedule({'i':128, 'j':128, 'k':128})
# Second level: tile the `_in` loops created by t1, which is why the keys are
# 'i_in', 'j_in', 'k_in' rather than the original loop names.
t2 = ts.TilingSchedule({'i_in':8, 'j_in':16, 'k_in':8})
# Chain both levels; the printed nest has `_out`, `_in_out`, and `_in_in` loops.
multilevel = most.CompoundSchedule([t1,t2])
print(multilevel.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for i_out in par(0, 4):
        for j_out in par(0, 4):
            for k_out in par(0, 4):
                for i_in_out in par(0, 16):
                    for j_in_out in par(0, 8):
                        for k_in_out in par(0, 16):
                            for i_in_in in par(0, 8):
                                for j_in_in in par(0, 16):
                                    for k_in_in in par(0, 8):
                                        C[128 * i_out +
                                          (8 * i_in_out + i_in_in),
                                          128 * j_out +
                                          (16 * j_in_out + j_in_in)] += A[
                                              128 * i_out +
                                              (8 * i_in_out + i_in_in),
                                              128 * k_out +
                                 

These scheduling elements can be defined manually as above, or derived by static algorithms, such as the HBL (Hölder-Brascamp-Lieb) projective tiling generator used here:

In [8]:
# Memory budget handed to the tile generator.
# NOTE(review): units are not shown here; 768 equals three 16x16 tiles, which
# suggests it counts array elements — confirm.
memsize = 768
# Extract the kernel's constant loop bounds and its data-access description,
# the two inputs the HBL tile generator needs besides the memory budget.
bounds = lr.getFixedLoopBounds(sg_const)
accesses = lr.getProjectiveDataAccesses(sg_const)
# Derive a TilingSchedule from bounds, accesses, and memsize.
# NOTE(review): the meaning of the trailing False flag is not visible here — confirm.
# The tile-size dict shown before the kernel in the output appears to be
# printed by this call, since the cell itself only prints the kernel.
opt_tile = ts.TilingSchedule.generateHBLProjectiveTile(bounds, accesses, memsize, False)
print(opt_tile.apply(sg_const))

{'i': 16, 'j': 16, 'k': 16}
def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for i_out in par(0, 32):
        for j_out in par(0, 32):
            for k_out in par(0, 32):
                for i_in in par(0, 16):
                    for j_in in par(0, 16):
                        for k_in in par(0, 16):
                            C[16 * i_out + i_in, 16 * j_out +
                              j_in] += A[16 * i_out + i_in, 16 * k_out +
                                         k_in] * B[16 * k_out + k_in,
                                                   16 * j_out + j_in]



... or autotuning (in progress!)