# Polyhedral Compilation with Caten


This tutorial demonstrates how to use `caten.polyhedral` to construct and transform schedules for high-performance kernels, including Matrix Multiplication optimization and Conv-Pool Fusion.


## 1. Setup


In [1]:
import sys
sys.path.append("../")

import caten.isl as I
import caten.polyhedral as P

print("Caten initialized.")

Caten initialized.


## 2. Schedule Constraction

Conv2D

Pool2D

In [2]:
def create_conv_schedule():
    with P.parameter("N, K_out, H_out, W_out, Cin, KH, KW"):
        out, x, y = I.expr("OUT"), I.expr("X"), I.expr("Y")
        with P.domain("{ S_conv[n, k, h, w, c, kh, kw] : 0<=n<N and 0<=k<K_out and 0<=h<H_out and 0<=w<W_out and 0<=c<Cin and 0<=kh<KH and 0<=kw<KW }") as conv:
            with P.band("{ S_conv[n, k, h, w, c, kh, kw] -> [n, k, h, w, c, kh, kw] }"):
                P.stmt("Out[n, k, h, w] = Out[n, k, h, w], In[n, c, h, w], W[k, c, kh, kw]")[
                    lambda c0, c1, c2, c3, c4, c5, c6: out[c0, c1].assign(x[c0 * c1] + y[c0 * c1])
                ]
    return conv

def create_pool_schedule():
    dom_str = "{ S_pool[n, k, h, w, rh, rw] : 0<=n<N and 0<=k<K_out and 0<=h<H_pool and 0<=w<W_pool and 0<=rh<KH_pool and 0<=rw<KW_pool }"
    with P.parameter("N, K_out, H_out, W_out, Cin, KH, KW"):
        with P.domain(dom_str) as pool:
            with P.band("{ S_pool[n, k, h, w, rh, rw] -> [n, k, h, w, rh, rw] }"):
                P.stmt(f"PoolBuf[n, k, h, w] = PoolBuf[n, k, h, w], Out[n, k, h*{S_pool} + rh, w*{S_pool} + rw]")
    return pool#.finalize()

In [3]:
print(create_conv_schedule())

ScheduleNodeDomain(
┗ domain([N, K_out, H_out, W_out, Cin, KH, KW] -> S_conv[n, k, h, w, c, kh, kw] : 0 <= n < N and 0 <= k < K_out and 0 <= h < H_out and 0 <= w < W_out and 0 <= c < Cin and 0 <= kh < KH and 0 <= kw < KW)
  ┗ band([S_conv[n, k, h, w, c, kh, kw] -> [(n)], S_conv[n, k, h, w, c, kh, kw] -> [(k)], S_conv[n, k, h, w, c, kh, kw] -> [(h)], S_conv[n, k, h, w, c, kh, kw] -> [(w)], S_conv[n, k, h, w, c, kh, kw] -> [(c)], S_conv[n, k, h, w, c, kh, kw] -> [(kh)], S_conv[n, k, h, w, c, kh, kw] -> [(kw)]])
    ┗ leaf()
)


In [6]:
I.expr("A").call(I.expr(1))

ASTExpr({ op: call, args: [ { id: A }, { val: 1 } ] })

In [7]:
with P.sequence(conv, pool) as seq:
    # todo: get reshape/permute parameter from conv/pool access rel
    with seq[0].permute(0, 1, 2, 3) as conv_band:
        pass
    with (seq[1] @ [1, 1, 4, 1]).permute(0, 1, 2, 3) as pool_band:
        pass
    seq = seq.fuse()
    # coalesce
    with (seq << P.Parallel()) as inner:
        tensor_core = (inner @ P.TensorCore(4, 4))

# conv+pool fusion and optimization
with (conv + pool) as sequence:
    # any idea?

# matmul
gemm = gemm_kernel()
with (gemm.permute("ijk -> ikj")).maximize_band_depth() as gemm:
    gemm_128x128x128 = gemm @ P.Parallel(128, 128, 128)
    gemm_64x64x128 = gemm << P.Prefetch(64, 64) # this will sunk, e.g.: mapped w/ pack_block_A, pack_block_B
    with (gemm @ P.TensorCore(8, 8, 8)) as gemm_tc, gemm_reminder:
        gemm_tc
        with (gemm_reminder @ P.Vectorize(8, 8)) as gemm_vec, gemm_rem:
            gemm_vec
        
print(gemm) # baseline
print(gemm_8x8x8) # optimized ver

# softmax
softmax = softmax_kernel()
with softmax @ P.Parallel(Local1, Local2) as ...:
    # any idea?
    pass

IndentationError: expected an indented block after 'with' statement on line 13 (1333366227.py, line 17)