# Polyhedral Compilation with Caten


This tutorial demonstrates how to use `caten.polyhedral` to construct and transform schedules for high-performance kernels, including Matrix Multiplication optimization and Conv-Pool Fusion.


## 1. Setup


In [None]:
# `I` exposes ISL objects directly (e.g. `I.UnionMap` for access relations);
# `P` is the schedule-tree construction DSL (`P.domain`, `P.band`, ...).
import caten.isl as I
import caten.polyhedral as P

print("Caten initialized.")


## 2. Matrix Multiplication Optimization


We optimize a plain matrix multiplication, $C_{ij} \mathrel{+}= A_{ik} \cdot B_{kj}$ for all $0 \le i < M$, $0 \le j < N$, $0 \le k < K$.


### 2.1 Domain and Initial Schedule


In [None]:
# Problem size of the (M x K) @ (K x N) matrix product.
# NOTE: `N` is rebound to 1 by the Conv-Pool parameter cell in section 3;
# re-run this cell first if you re-execute it after that one.
M, N, K = 1024, 1024, 1024
# Iteration domain: one statement instance S[i,j,k] per multiply-accumulate.
# Domain: { S[i,j,k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }
dom_str = f"{{ S[i,j,k] : 0 <= i < {M} and 0 <= j < {N} and 0 <= k < {K} }}" 

with P.domain(dom_str) as dom:
    # Initial schedule: identity mapping, i.e. loop order i, j, k.
    with P.band("{ S[i,j,k] -> [i,j,k] }") as band:
        # 1. Tiling: tile size 32 for i and j; tile size 1 for k leaves k
        #    effectively untiled.
        # NOTE(review): the exact band layout produced by `tile` (where the
        # tile/point dimensions and k end up) depends on caten's tiling
        # semantics -- confirm against the printed schedule below.
        band.tile([32, 32, 1]) 
        
        # 2. Splitting: split the band after its first 2 dimensions, intended
        #    to separate the outer tile loops (i_tile, j_tile) from the rest.
        band.split(2)
        
        # Intended resulting schedule tree (verify with print(sched)):
        # Band(i_tile, j_tile) -> Band(i_point, j_point, k)
        
# Presumably materializes the schedule described by the `with` blocks above.
sched = dom.finalize()
print("MatMul Schedule:")
print(sched)


## 3. Convolution + Pooling Fusion


We demonstrate **Consumer-Driven Embedding**, fusing a Convolution kernel into a Pooling kernel.


### 3.1 Problem Definition


Convolution (Producer): $Out[n, c_o, h, w] = \sum_{c_i, kh, kw} In[n, c_i, h+kh, w+kw] \cdot W[c_o, c_i, kh, kw]$ (the code below abbreviates the output channel $c_o$ as `c`).


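Pooling (Consumer): $Pool[n, c, ph, pw] = \max_{0 \le i, j < K_p} Out[n, c, ph \cdot S + i, pw \cdot S + j]$, where $S$ is the pooling stride and $K_p$ the pooling window size (both 2 in this tutorial).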


In [None]:
# Parameters: extents of the n, c, h, w dimensions of the Conv output.
N, C, H, W = 1, 64, 112, 112
K_h, K_w = 3, 3  # Conv kernel size
Pool_S = 2  # Pooling stride


def _box_domain(stmt, extents):
    """Return the ISL set ``{ stmt[d0, d1, ...] : 0 <= d0 < e0 and ... }``.

    ``extents`` is a sequence of ``(dimension name, exclusive upper bound)``
    pairs; each dimension is bounded below by 0.
    """
    dims = ", ".join(name for name, _ in extents)
    bounds = " and ".join(f"0 <= {name} < {ub}" for name, ub in extents)
    return f"{{ {stmt}[{dims}] : {bounds} }}"


# Conv domain: one instance per Conv output element (H x W per channel).
conv_dom = _box_domain("Conv", [("n", N), ("c", C), ("h", H), ("w", W)])

# Pool domain: the pooled output is (H // Pool_S) x (W // Pool_S) per channel.
H_out, W_out = H // Pool_S, W // Pool_S
pool_dom = _box_domain("Pool", [("n", N), ("c", C), ("h", H_out), ("w", W_out)])

# ISL union set containing both statements.
full_domain = "; ".join([conv_dom, pool_dom])
print(f"Full Domain: {full_domain}")


### 3.2 Dependency Analysis


We compute the Read-After-Write (flow) dependence from Conv to Pool.


In [None]:
# Access Relations
# Conv writes each output element Out[n, c, h, w].
writes = I.UnionMap("{ Conv[n,c,h,w] -> Out[n,c,h,w] }")

# Pool reads a Pool_K x Pool_K window of Out anchored at (ph*S, pw*S):
#   h = ph*S + i, w = pw*S + j, with 0 <= i, j < Pool_K.
# The stride S is taken from the parameter cell (Pool_S) instead of being
# hard-coded; the window size is set here. With Pool_K == Pool_S == 2 this is
# exactly the 2x2, stride-2 window used throughout the tutorial.
Pool_K = 2
reads = I.UnionMap(
    f"{{ Pool[n,c,ph,pw] -> Out[n,c,h,w] : "
    f"{Pool_S}*ph <= h < {Pool_S}*ph + {Pool_K} and "
    f"{Pool_S}*pw <= w < {Pool_S}*pw + {Pool_K} }}"
)
# NOTE: neither relation is intersected with conv_dom / pool_dom, so the
# printed dependence is not bounded in n, c, ph, pw.

# Compute Flow Dependence
# NOTE(review): presumably returns the RAW relation Conv -> Pool; confirm
# P.compute_flow's contract (ISL's own flow analysis also takes a schedule).
dep = P.compute_flow(sink=reads, must_source=writes)
print("Dependency Map (Conv -> Pool):")
print(dep)


### 3.3 Constructing the Fusion Schedule


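We embed the Producer (Conv) loop *inside* the Consumer (Pool) loop nest, so each block of Conv outputs is produced immediately before the Pool instance that consumes it. This locality lets generated code keep those values in registers or L1 cache instead of writing the whole intermediate tensor to memory.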


In [None]:
# Fusion Schedule Construction
# Reuse the pooling stride from the parameter cell (which also defines
# full_domain) instead of hard-coding 2, so the tiling below always agrees
# with pool_dom, whose extents are H // Pool_S and W // Pool_S.
S = Pool_S

with P.domain(full_domain) as dom:
    # Outer band: iterate over the consumer's (Pool) iteration space
    # [n, c, ph, pw]. Pool instances map onto it directly. Each Conv instance
    # (h, w) maps to the pool tile containing it, (floor(h/S), floor(w/S)),
    # so the S x S Conv outputs of a tile share that tile's outer iteration.
    # With a pooling window equal to the stride, these are exactly the Conv
    # outputs that the tile's Pool instance reads.
    outer_schedule = f"{{ Pool[n,c,ph,pw] -> [n,c,ph,pw]; Conv[n,c,h,w] -> [n,c, floor(h/{S}), floor(w/{S})] }}"

    with P.band(outer_schedule) as outer:

        # Within one tile [n, c, ph, pw]: run the Conv instances first, then Pool.
        with P.sequence(["{ Conv[n,c,h,w] }", "{ Pool[n,c,h,w] }"]) as seq:

            # Child 0: producer (Conv) slice of this tile.
            with seq.child(0):
                # Enumerate the tile's S x S Conv outputs by their local
                # offsets (h mod S, w mod S); note h = ph*S + (h mod S).
                with P.band(f"{{ Conv[n,c,h,w] -> [h % {S}, w % {S}] }}"):
                    pass

            # Child 1: consumer (Pool) body -- one point per tile, no inner loops.
            with seq.child(1):
                pass

sched_fusion = dom.finalize()
print("Fused Schedule Tree:")
print(sched_fusion)


## 4. Conclusion


By explicitly constructing the schedule tree with the `caten.polyhedral` DSL, we obtained a fused schedule that reduces memory traffic between the Convolution and Pooling layers.
