# Polyhedral Compilation with Caten


This tutorial demonstrates how to use `caten.polyhedral` to construct and transform schedules for high-performance kernels.


## 1. Setup


In [None]:
# Import the polyhedral scheduling module under the alias `P`; the cells
# below use it for `P.domain`, `P.band`, and `P.sequence`.
import caten.polyhedral as P

# Reaching this line means the import above did not raise.
print("Caten initialized.")


## 2. Matrix Multiplication Optimization


We start with a simple MatMul: `C[i,j] += A[i,k] * B[k,j]`.


### 2.1 Domain and Initial Schedule


In [None]:
# Problem size for C[i,j] += A[i,k] * B[k,j]: i < M, j < N, k < K.
M, N, K = 1024, 1024, 1024

# Iteration domain of the single statement S, one instance per (i, j, k):
#   { S[i,j,k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }
dom_str = f"{{ S[i,j,k] : 0 <= i < {M} and 0 <= j < {N} and 0 <= k < {K} }}"

# Tile sizes per band dimension (i, j, k): 32 x 32 tiles over i and j; the
# size of 1 on k is how this tutorial leaves k untiled.
tile_sizes = [32, 32, 1]
# Number of leading band dimensions after which the band is split.
outer_dims = 2

with P.domain(dom_str) as dom:
    # Initial schedule: plain i, j, k loop order.
    with P.band("{ S[i,j,k] -> [i,j,k] }") as band:
        # Step 1 (tiling): the intended result, per this tutorial, is the
        # band [i_tile, j_tile, i_point, j_point, k].
        # NOTE(review): confirm against caten's `tile` semantics.
        band.tile(tile_sizes)

        # Step 2 (splitting): cut the band after its first two dimensions
        # to separate the tile loops from the point loops. Intended tree:
        #   Band(i_tile, j_tile) -> Band(i_point, j_point, k)
        band.split(outer_dims)

# Build the schedule from the domain context and display it.
sched = dom.finalize()
print(sched)


## 3. Convolution + Pooling Fusion


This section fuses a producer (Conv) with a consumer (Pool) whose loop nests have different structures.


### 3.1 Problem Definition


In [None]:
# Parameters
N, C, H, W = 1, 64, 112, 112
K_h, K_w = 3, 3 # Conv Kernel
Pool_S = 2 # Pooling Stride

# Conv Domain: Output of Conv is H x W
conv_dom = f"{{ Conv[n, c, h, w] : 0 <= n < {N} and 0 <= c < {C} and 0 <= h < {H} and 0 <= w < {W} }}" 

# Pool Domain: Output of Pool is H/2 x W/2
H_out, W_out = H // Pool_S, W // Pool_S
pool_dom = f"{{ Pool[n, c, h, w] : 0 <= n < {N} and 0 <= c < {C} and 0 <= h < {H_out} and 0 <= w < {W_out} }}" 

full_domain = f"{conv_dom}; {pool_dom}"
print(f"Full Domain: {full_domain}")


### 3.2 Fusion Schedule (Compute-at)


We schedule the Conv computation *inside* the Pooling loop nest to maximize locality.


In [None]:
# Fusion idea: each Conv instance (ch, cw) is scheduled under the Pool
# iteration (ph, pw) = (floor(ch/S), floor(cw/S)), so a Conv point and the
# Pool point that shares its outer coordinates land in the same outer-band
# iteration.

# Stride baked into the schedule maps below.
# NOTE(review): this must stay equal to the pooling stride (`Pool_S`) used
# to build the Pool domain.
S = 2

# Outer band (consumer driven): Pool maps to its own (n, c, ph, pw); Conv
# maps to the Pool coordinates obtained by dividing ch and cw by S.
outer_schedule = f"{{ Pool[n,c,ph,pw] -> [n,c,ph,pw]; Conv[n,c,ch,cw] -> [n,c, floor(ch/{S}), floor(cw/{S})] }}"

# Children of the sequence, in order: the Conv producer first, then Pool.
stage_filters = ["{ Conv[n,c,h,w] }", "{ Pool[n,c,h,w] }"]

# Local Conv loops: since ch = S*floor(ch/S) + ch % S, the remainder
# (ch % S, cw % S) is the offset left over after the outer band.
conv_local_schedule = f"{{ Conv[n,c,ch,cw] -> [ch % {S}, cw % {S}] }}"

with P.domain(full_domain) as dom:
    with P.band(outer_schedule) as outer:
        # Under each (n, c, ph, pw) iteration sit the Conv instances mapped
        # to it, followed by the Pool body.
        with P.sequence(stage_filters) as seq:
            # Child 0: producer (Conv), with its own band of local offsets.
            with seq.child(0):
                with P.band(conv_local_schedule):
                    pass

            # Child 1: consumer (Pool); no further loops are added here.
            with seq.child(1):
                pass

# Build the fused schedule from the domain context and display it.
sched_fusion = dom.finalize()
print(sched_fusion)


## 4. Conclusion


This tutorial demonstrated how `caten.polyhedral` supports the explicit construction of complex schedules, such as tiling and fusion, through a Pythonic context-manager API.
