# MoST: high-level scheduling tools

In [1]:
import src.matmap.base as most
import src.matmap.qast_utils.loopReader as lr
import src.matmap.transforms.TilingTransform as ts
import src.matmap.transforms.ReorderingTransform as rs

We start by generating a new kernel, here for SGEMM. This is the code the user actually writes:

In [3]:
# Build the SGEMM debug kernel provided by loopReader and print it.
# The printed procedure below still has symbolic sizes N, M, K.
sg = lr.__debug_new_sgemm()
print(sg)

def sgemm_full(N: size, M: size, K: size, C: f32[N, M] @ DRAM,
               A: f32[N, K] @ DRAM, B: f32[K, M] @ DRAM):
    for i in par(0, N):
        for j in par(0, M):
            for k in par(0, K):
                C[i, j] += A[i, k] * B[k, j]



First, we specialize the kernel to a fixed problem size (N = M = K = 512):

In [4]:
# Specialize the kernel: partial_eval binds the size arguments in order
# (N, M, K -> 512, 512, 512), so the printed signature below has only the
# buffers C, A, B with fixed 512x512 shapes.
sg_const = sg.partial_eval(512, 512, 512)
print(sg_const)

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for i in par(0, 512):
        for j in par(0, 512):
            for k in par(0, 512):
                C[i, j] += A[i, k] * B[k, j]



We can then construct specific transformations, for instance tiling:

In [5]:
# Tiling transform: keys are loop-iterator names, values are tile sizes.
# As the output below shows, each tiled loop `x` is split into an outer loop
# `xo` and an inner loop `xi`, with all outer loops placed above the inner ones.
tile_8_16_8 = ts.TilingTransform({'i':8, 'j':16, 'k':8})
# can run for sg or sg_const
print(tile_8_16_8.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for io in par(0, 64):
        for jo in par(0, 32):
            for ko in par(0, 64):
                for ii in par(0, 8):
                    for ji in par(0, 16):
                        for ki in par(0, 8):
                            C[8 * io + ii, 16 * jo +
                              ji] += A[8 * io + ii, 8 * ko +
                                       ki] * B[8 * ko + ki, 16 * jo + ji]



... or reordering...

In [6]:
# Reordering transform: the list gives the new loop-nest order, outermost
# first (here k, then i, then j, as the output below shows).
reorder_kij = rs.ReorderingTransform(['k', 'i', 'j'])
print(reorder_kij.apply(sg_const))

def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for k in par(0, 512):
        for i in par(0, 512):
            for j in par(0, 512):
                C[i, j] += A[i, k] * B[k, j]



These transforms can be combined to generate more complicated, higher-level transforms, e.g.:

In [9]:
# Combine the reordering and tiling transforms defined above into a single
# CompoundTransform (presumably applied in list order — TODO confirm).
# NOTE(review): in the recorded run this raised
#   AssertionError: Non-MoSTSchedule argument passed into CompoundSchedule
# The object reprs at the end of this notebook show the transform classes
# as `matmap.transforms.*`, whereas this notebook imports them via
# `src.matmap.*`. If CompoundTransform's type check uses isinstance, loading
# the same module under two import paths would yield distinct classes and
# fail that check — verify which import path the package expects.
cs = most.CompoundTransform([reorder_kij, tile_8_16_8])
print(cs.apply(sg_const))

AssertionError: Non-MoSTSchedule argument passed into CompoundSchedule

One could also do multilevel (CoSA-style) tiling as follows:

In [10]:
# Multilevel (CoSA-style) tiling: first split i, j, k into 128-wide tiles,
# then split the resulting inner loops again into 8x16x8 tiles.
# TilingTransform keys are loop-iterator names. Tiling loop `x` produces an
# outer loop `xo` and an inner loop `xi` (see the tiled kernel printed above:
# io/ii, jo/ji, ko/ki), so the second level must target `ii`, `ji`, `ki`.
# NOTE(review): in the recorded run this cell also hit the CompoundTransform
# AssertionError seen in the previous cell — that issue is independent of
# the tile keys.
t1 = ts.TilingTransform({'i':128, 'j':128, 'k':128})
t2 = ts.TilingTransform({'ii':8, 'ji':16, 'ki':8})
multilevel = most.CompoundTransform([t1,t2])
print(multilevel.apply(sg_const))

AssertionError: Non-MoSTSchedule argument passed into CompoundSchedule

These scheduling elements can be defined manually, as above, or generated by static algorithms, such as the HBL-based tile selection shown here:

In [11]:
# Derive a tiling automatically: read the fixed loop bounds and the
# projective data accesses of the specialized kernel, then ask
# TilingTransform for an HBL-based tile under a memory budget of `memsize`.
# NOTE(review): the units of memsize (elements vs. bytes) and the meaning of
# the positional `False` flag are not visible here — confirm against
# TilingTransform.generateHBLProjectiveTile.
memsize = 768
bounds = lr.getFixedLoopBounds(sg_const)
accesses = lr.getProjectiveDataAccesses(sg_const)
opt_tile = ts.TilingTransform.generateHBLProjectiveTile(bounds, accesses, memsize, False)
# The tile-size dict in the output below is printed by one of the library
# calls, not by this cell's own print, which shows only the tiled kernel.
print(opt_tile.apply(sg_const))

{'i': 16, 'j': 16, 'k': 16}
def sgemm_full(C: f32[512, 512] @ DRAM, A: f32[512, 512] @ DRAM,
               B: f32[512, 512] @ DRAM):
    for io in par(0, 32):
        for jo in par(0, 32):
            for ko in par(0, 32):
                for ii in par(0, 16):
                    for ji in par(0, 16):
                        for ki in par(0, 16):
                            C[16 * io + ii, 16 * jo +
                              ji] += A[16 * io + ii, 16 * ko +
                                       ki] * B[16 * ko + ki, 16 * jo + ji]



... or autotuning (in progress!)

The `CoSATransform` element runs the CoSA framework (the solver log below comes from Gurobi) to determine optimal loop parameters; its `apply` method returns the kernel with the resulting loop transformations applied.

In [1]:
# CoSA-based scheduling. This cell starts a fresh kernel session (In [1]),
# so it re-imports loopReader. The CoSA transform module is bound to `cosa`
# rather than `cs`, because `cs` already names a CompoundTransform earlier in
# this notebook.
import src.matmap.transforms.CoSATransform as cosa
import src.matmap.qast_utils.loopReader as lr

# Second debug kernel. The seven partial_eval values fix its size
# parameters; in the printed procedure below they appear as the extents of
# loops i, j, k, l, m, n, o respectively.
sg = lr.__debug_new_sgemm2()
sg = sg.partial_eval(7,7,112,112,3,64,1)

print(sg)

# NOTE(review): the first argument looks like a placeholder string — confirm
# what CoSATransform expects here.
y = cosa.CoSATransform("#input params",sg)
# Runs the CoSA solver (Gurobi log below) and applies the resulting
# reordering/tiling transforms; `y.subschedules` later lists them.
obj = y.apply(sg)

def sgemm_full(X: f32[7, 7] @ DRAM, A: f32[7, 112] @ DRAM,
               B: f32[112, 7] @ DRAM):
    for i in par(0, 7):
        for j in par(0, 7):
            for k in par(0, 112):
                for l in par(0, 112):
                    for m in par(0, 3):
                        for n in par(0, 64):
                            for o in par(0, 1):
                                X[i, j] += A[i, k] * B[k, j]

Set parameter Username


INFO:gurobipy.gurobipy:Set parameter Username


Academic license - for non-commercial use only - expires 2023-01-01


INFO:gurobipy.gurobipy:Academic license - for non-commercial use only - expires 2023-01-01


Gurobi Optimizer version 9.5.1 build v9.5.1rc2 (linux64)


INFO:gurobipy.gurobipy:Gurobi Optimizer version 9.5.1 build v9.5.1rc2 (linux64)


Thread count: 16 physical cores, 32 logical processors, using up to 32 threads


INFO:gurobipy.gurobipy:Thread count: 16 physical cores, 32 logical processors, using up to 32 threads


Optimize a model with 761 rows, 1068 columns and 6835 nonzeros


INFO:gurobipy.gurobipy:Optimize a model with 761 rows, 1068 columns and 6835 nonzeros


Model fingerprint: 0x41abf6dd


INFO:gurobipy.gurobipy:Model fingerprint: 0x41abf6dd


Model has 1236 quadratic objective terms


INFO:gurobipy.gurobipy:Model has 1236 quadratic objective terms


Model has 4 quadratic constraints


INFO:gurobipy.gurobipy:Model has 4 quadratic constraints


Model has 1 general constraint


INFO:gurobipy.gurobipy:Model has 1 general constraint


Variable types: 6 continuous, 1062 integer (1000 binary)


INFO:gurobipy.gurobipy:Variable types: 6 continuous, 1062 integer (1000 binary)


Coefficient statistics:


INFO:gurobipy.gurobipy:Coefficient statistics:


  Matrix range     [1e+00, 4e+00]


INFO:gurobipy.gurobipy:  Matrix range     [1e+00, 4e+00]


  QMatrix range    [1e+00, 3e+00]


INFO:gurobipy.gurobipy:  QMatrix range    [1e+00, 3e+00]


  QLMatrix range   [1e+00, 8e+00]


INFO:gurobipy.gurobipy:  QLMatrix range   [1e+00, 8e+00]


  Objective range  [1e+00, 4e+01]


INFO:gurobipy.gurobipy:  Objective range  [1e+00, 4e+01]


  QObjective range [4e-01, 6e+00]


INFO:gurobipy.gurobipy:  QObjective range [4e-01, 6e+00]


  Bounds range     [1e+00, 1e+00]


INFO:gurobipy.gurobipy:  Bounds range     [1e+00, 1e+00]


  RHS range        [1e+00, 2e+01]


INFO:gurobipy.gurobipy:  RHS range        [1e+00, 2e+01]


  QRHS range       [1e+01, 1e+01]


INFO:gurobipy.gurobipy:  QRHS range       [1e+01, 1e+01]


Presolve removed 585 rows and 139 columns


INFO:gurobipy.gurobipy:Presolve removed 585 rows and 139 columns


Presolve time: 0.02s


INFO:gurobipy.gurobipy:Presolve time: 0.02s


Presolved: 3868 rows, 2159 columns, 13132 nonzeros


INFO:gurobipy.gurobipy:Presolved: 3868 rows, 2159 columns, 13132 nonzeros


Variable types: 2 continuous, 2157 integer (2156 binary)


INFO:gurobipy.gurobipy:Variable types: 2 continuous, 2157 integer (2156 binary)


Found heuristic solution: objective 268.6295123


INFO:gurobipy.gurobipy:Found heuristic solution: objective 268.6295123





INFO:gurobipy.gurobipy:


Root relaxation: objective 2.199096e+02, 174 iterations, 0.00 seconds (0.00 work units)


INFO:gurobipy.gurobipy:Root relaxation: objective 2.199096e+02, 174 iterations, 0.00 seconds (0.00 work units)





INFO:gurobipy.gurobipy:


    Nodes    |    Current Node    |     Objective Bounds      |     Work


INFO:gurobipy.gurobipy:    Nodes    |    Current Node    |     Objective Bounds      |     Work


 Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time


INFO:gurobipy.gurobipy: Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time





INFO:gurobipy.gurobipy:


     0     0  219.90961    0    6  268.62951  219.90961  18.1%     -    0s


INFO:gurobipy.gurobipy:     0     0  219.90961    0    6  268.62951  219.90961  18.1%     -    0s


H    0     0                     219.9896431  219.90961  0.04%     -    0s


INFO:gurobipy.gurobipy:H    0     0                     219.9896431  219.90961  0.04%     -    0s


*    0     0               0     219.9703786  219.97038  0.00%     -    0s


INFO:gurobipy.gurobipy:*    0     0               0     219.9703786  219.97038  0.00%     -    0s





INFO:gurobipy.gurobipy:


Cutting planes:


INFO:gurobipy.gurobipy:Cutting planes:


  Gomory: 1


INFO:gurobipy.gurobipy:  Gomory: 1


  Cover: 1


INFO:gurobipy.gurobipy:  Cover: 1


  MIR: 2


INFO:gurobipy.gurobipy:  MIR: 2


  StrongCG: 1


INFO:gurobipy.gurobipy:  StrongCG: 1





INFO:gurobipy.gurobipy:


Explored 1 nodes (187 simplex iterations) in 0.11 seconds (0.06 work units)


INFO:gurobipy.gurobipy:Explored 1 nodes (187 simplex iterations) in 0.11 seconds (0.06 work units)


Thread count was 32 (of 32 available processors)


INFO:gurobipy.gurobipy:Thread count was 32 (of 32 available processors)





INFO:gurobipy.gurobipy:


Solution count 3: 219.97 219.99 268.63 


INFO:gurobipy.gurobipy:Solution count 3: 219.97 219.99 268.63 





INFO:gurobipy.gurobipy:


Optimal solution found (tolerance 1.00e-04)


INFO:gurobipy.gurobipy:Optimal solution found (tolerance 1.00e-04)


Best objective 2.199703785977e+02, best bound 2.199703785977e+02, gap 0.0000%


INFO:gurobipy.gurobipy:Best objective 2.199703785977e+02, best bound 2.199703785977e+02, gap 0.0000%


['i', 'j', 'ko', 'ki', 'l', 'm', 'n', 'o']
['i', 'j', 'ko', 'lo', 'li', 'm', 'n', 'o', 'ki']
['i', 'j', 'ko', 'lo', 'm', 'no', 'ni', 'o', 'ki', 'li']
['i', 'j', 'koo', 'koi', 'lo', 'm', 'no', 'o', 'ki', 'li', 'ni']
['i', 'j', 'koo', 'loo', 'loi', 'm', 'no', 'o', 'ki', 'li', 'ni', 'koi']
['i', 'j', 'koo', 'loo', 'm', 'noo', 'noi', 'o', 'ki', 'li', 'ni', 'koi', 'loi']
['i', 'j', 'kooo', 'kooi', 'loo', 'm', 'noo', 'o', 'ki', 'li', 'ni', 'koi', 'loi', 'noi']
['i', 'j', 'kooo', 'looo', 'looi', 'm', 'noo', 'o', 'ki', 'li', 'ni', 'koi', 'loi', 'noi', 'kooi']
['i', 'j', 'kooo', 'looo', 'm', 'nooo', 'nooi', 'o', 'ki', 'li', 'ni', 'koi', 'loi', 'noi', 'kooi', 'looi']
['i', 'j', 'koooo', 'koooi', 'looo', 'm', 'nooo', 'o', 'ki', 'li', 'ni', 'koi', 'loi', 'noi', 'kooi', 'looi', 'nooi']
['i', 'j', 'koooo', 'loooo', 'loooi', 'm', 'nooo', 'o', 'ki', 'li', 'ni', 'koi', 'loi', 'noi', 'kooi', 'looi', 'nooi', 'koooi']
['i', 'j', 'koooo', 'loooo', 'm', 'noooo', 'noooi', 'o', 'ki', 'li', 'ni', 'koi', 'loi',

In [2]:
# Show the kernel produced by the CoSA transform in the previous cell.
print(obj)

def sgemm_full(X: f32[7, 7] @ DRAM, A: f32[7, 112] @ DRAM,
               B: f32[112, 7] @ DRAM):
    for i in par(0, 7):
        for j in par(0, 7):
            for koooo in par(0, 7):
                for loooo in par(0, 7):
                    for m in par(0, 3):
                        for nooooo in par(0, 2):
                            for o in par(0, 1):
                                for ki in par(0, 2):
                                    for li in par(0, 2):
                                        for ni in par(0, 2):
                                            for koi in par(0, 2):
                                                for loi in par(0, 2):
                                                    for noi in par(0, 2):
                                                        for kooi in par(0, 2):
                                                            for looi in par(
                                                                    0, 2):
                         

In [3]:
# Inspect the sub-transforms chosen by CoSA (per the output: one reordering
# followed by several tilings). Kept as the cell's last bare expression so
# Jupyter displays its value.
y.subschedules

[<matmap.transforms.ReorderingTransform.ReorderingTransform at 0x7f6eaa2660a0>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eab344280>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eab344ac0>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eab344bb0>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eab33fd60>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eb40198e0>,
 <matmap.transforms.TilingTransform.TilingTransform at 0x7f6eaa266340>]