# Polyhedral Compilation with Caten


This tutorial demonstrates how to use `caten.polyhedral` to construct and transform schedules for high-performance kernels, including Matrix Multiplication optimization and Conv-Pool Fusion.


## 1. Setup


```python
import caten.isl as I
import caten.polyhedral as P

print("Caten initialized.")
```


## 2. Matrix Multiplication Optimization


We optimize a simple matrix multiplication $C_{ij} \mathrel{+}= A_{ik} \cdot B_{kj}$, with one statement instance $S[i, j, k]$ per point of the iteration domain.
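In polyhedral terms, the domain is the set of statement instances and a schedule maps each instance to a timestamp; instances run in lexicographic order of those timestamps. The plain-Python sketch below (no caten calls; the tiny sizes and matrices are hypothetical) enumerates the MatMul domain, orders it by the identity schedule `{ S[i,j,k] -> [i,j,k] }`, and executes the statement in that order.

```python
from itertools import product

# Small hypothetical sizes; the tutorial itself uses M = N = K = 1024.
M, N, K = 2, 2, 2
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

# Iteration domain { S[i,j,k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }
# as a plain set of integer points (one point per statement instance).
domain = set(product(range(M), range(N), range(K)))

# The identity schedule { S[i,j,k] -> [i,j,k] } gives each instance a
# timestamp; instances execute in lexicographic order of those timestamps.
order = sorted(domain, key=lambda p: (p[0], p[1], p[2]))

C = [[0] * N for _ in range(M)]
for i, j, k in order:
    C[i][j] += A[i][k] * B[k][j]  # statement S[i, j, k]

print(C)  # [[19, 22], [43, 50]]
```

Sorting by the timestamp reproduces exactly the order of the nested loops `for i: for j: for k:`, which is why the identity schedule corresponds to the untransformed loop nest.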


### 2.1 Domain and Initial Schedule


```python
M, N, K = 1024, 1024, 1024
# Domain: { S[i,j,k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }
# Doubled braces {{ }} emit the literal ISL braces inside the f-string.
dom_str = f"{{ S[i,j,k] : 0 <= i < {M} and 0 <= j < {N} and 0 <= k < {K} }}"  # noqa

with P.domain(dom_str) as dom:
    # Initial schedule: i, j, k order
    with P.band("{ S[i,j,k] -> [i,j,k] }") as band:
        # 1. Tiling: tile the i and j loops by 32 (tile size 1 leaves k untiled).
        #    The band becomes [i_tile, j_tile, i_point, j_point, k].
        band.tile([32, 32, 1])

        # 2. Splitting: cut the band after its first 2 members, separating
        #    the outer tile loops from the inner point loops.
        band.split(2)

        # Resulting schedule tree:
        # Band(i_tile, j_tile) -> Band(i_point, j_point, k)

sched = dom.finalize()
print("MatMul Schedule:")
print(sched)
```
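To see what the schedule tree `Band(i_tile, j_tile) -> Band(i_point, j_point, k)` means as code, here is a hand-written plain-Python sketch of the corresponding loop nest. It does not call caten and is not the code caten generates; it assumes, as the comment in the cell states, that the tile size 1 leaves `k` untiled. The `min(...)` bounds handle partial tiles, which never occur for 1024 and a tile size of 32 but do occur in the small, deliberately uneven test sizes used here.

```python
import random

def matmul_tiled(A, B, T=32):
    """Loops for Band(i_tile, j_tile) -> Band(i_point, j_point, k), tile size T."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for it in range(0, M, T):                        # i_tile
        for jt in range(0, N, T):                    # j_tile
            for i in range(it, min(it + T, M)):      # i_point
                for j in range(jt, min(jt + T, N)):  # j_point
                    for k in range(K):               # k (untiled)
                        C[i][j] += A[i][k] * B[k][j]
    return C

def matmul_naive(A, B):
    """Reference result in the original i, j, k order."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

random.seed(0)
M, N, K = 70, 45, 20  # hypothetical sizes, not multiples of the tile size
A = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]
B = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]
assert matmul_tiled(A, B, T=32) == matmul_naive(A, B)
print("tiled == naive")
```

Tiling only reorders the same set of instances, so with integer data the result must match the naive order exactly; the benefit is locality, since each 32 x 32 block of `C` is finished before moving on.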


## 3. Convolution + Pooling Fusion


We demonstrate **Consumer-Driven Embedding**, fusing a Convolution kernel into a Pooling kernel using `compute_at`.


### 3.1 Problem Definition


Convolution (Producer): $Out[n, c, h, w] = \sum_{c', kh, kw} In[n, c', h+kh, w+kw] \cdot W[c, c', kh, kw]$, where $c$ is the output channel and $c'$ the input channel.


Pooling (Consumer): $Pool[n, c, ph, pw] = \max_{0 \le dh, dw < S} Out[n, c, ph \cdot S + dh, pw \cdot S + dw]$, where $S$ is the pooling stride (here equal to the window size).
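A direct plain-Python transcription of the two formulas can serve as a reference for what the fused schedule must compute. This is a sketch with hypothetical tiny sizes; the batch index $n$ is dropped because $N = 1$ in this tutorial, and the input is assumed to be $(H + K - 1) \times (W + K - 1)$ so that the Conv output is exactly $H \times W$.

```python
def conv2d(inp, wgt, H, W):
    """Out[c][h][w] = sum over c', kh, kw of inp[c'][h+kh][w+kw] * wgt[c][c'][kh][kw]."""
    Co, Ci = len(wgt), len(wgt[0])
    Kh, Kw = len(wgt[0][0]), len(wgt[0][0][0])
    return [[[sum(inp[ci][h + kh][w + kw] * wgt[co][ci][kh][kw]
                  for ci in range(Ci) for kh in range(Kh) for kw in range(Kw))
              for w in range(W)]
             for h in range(H)]
            for co in range(Co)]

def maxpool(out, S):
    """Pool[c][ph][pw] = max over 0 <= dh, dw < S of out[c][ph*S+dh][pw*S+dw]."""
    H, W = len(out[0]), len(out[0][0])
    return [[[max(out[c][ph * S + dh][pw * S + dw] for dh in range(S) for dw in range(S))
              for pw in range(W // S)]
             for ph in range(H // S)]
            for c in range(len(out))]

# Hypothetical sizes: 2 input channels, 3 output channels, 4x4 Conv output,
# 3x3 kernel, stride-2 pooling. All-ones data makes the result easy to check.
Ci, Co, H, W, K, S = 2, 3, 4, 4, 3, 2
inp = [[[1] * (W + K - 1) for _ in range(H + K - 1)] for _ in range(Ci)]
wgt = [[[[1] * K for _ in range(K)] for _ in range(Ci)] for _ in range(Co)]
pool = maxpool(conv2d(inp, wgt, H, W), S)
print(pool[0])  # [[18, 18], [18, 18]]  (each Conv output sums Ci*K*K = 18 ones)
```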


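```python
# Parameters (N and W are redefined here; W is the image width, not the weights)
N, C, H, W = 1, 64, 112, 112
K_h, K_w = 3, 3  # Conv kernel size
Pool_S = 2       # Pooling stride (and window size)

# Conv domain: the Conv output is H x W
# (this requires an input of (H + K_h - 1) x (W + K_w - 1), e.g. via padding)
conv_body = f"Conv[n, c, h, w] : 0 <= n < {N} and 0 <= c < {C} and 0 <= h < {H} and 0 <= w < {W}"  # noqa
conv_dom = f"{{ {conv_body} }}"

# Pool domain: the Pool output is (H / Pool_S) x (W / Pool_S) = 56 x 56
H_out, W_out = H // Pool_S, W // Pool_S
pool_body = f"Pool[n, c, ph, pw] : 0 <= n < {N} and 0 <= c < {C} and 0 <= ph < {H_out} and 0 <= pw < {W_out}"  # noqa
pool_dom = f"{{ {pool_body} }}"

# A union set lists both statements inside ONE pair of braces, separated by ';'
full_domain = f"{{ {conv_body}; {pool_body} }}"
print(f"Full Domain: {full_domain}")
```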
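ISL set strings contain literal braces, so inside a Python f-string every brace that belongs to the set must be doubled, while `{name}` still interpolates a variable. A common pitfall is doubling the braces around a variable name: `f"{{body}}"` renders the literal text `{body}`, not the variable's value. The snippet below uses only standard Python string formatting; the statement names `S` and `T` are hypothetical.

```python
n = 4
body = f"S[i] : 0 <= i < {n}"
dom = f"{{ {body} }}"          # doubled braces -> literal { }
print(dom)                     # { S[i] : 0 <= i < 4 }

# Pitfall: doubled braces around a name print the name, not its value.
print(f"{{body}}")             # {body}

# A union of two statement sets lists both pieces inside ONE pair of
# braces, separated by ';'.
other = f"T[j] : 0 <= j < {n // 2}"
union = f"{{ {body}; {other} }}"
print(union)                   # { S[i] : 0 <= i < 4; T[j] : 0 <= j < 2 }
```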


### 3.2 Automated Fusion using compute_at


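The `compute_at` method embeds the Producer (Conv) schedule into the Consumer (Pool) schedule automatically. From the access relations it determines which Conv outputs each Pool iteration reads, and schedules that slice of Conv inside the Pool iteration.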
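The "slice" that `compute_at` has to derive can be read directly off the Pool read relation `S*ph <= h < S*ph + S and S*pw <= w < S*pw + S`. The plain-Python sketch below (hypothetical small sizes, no caten calls; `required_conv_slice` is an illustrative helper) enumerates that slice for each Pool point and checks that, because the window equals the stride, the slices cover the Conv output exactly once. If the window were larger than the stride, some Conv points would be needed by several Pool iterations, and a fused schedule would have to recompute or cache them.

```python
from collections import Counter

def required_conv_slice(ph, pw, S=2, window=2):
    """Conv outputs (h, w) read by Pool[., ., ph, pw], per the relation
    S*ph <= h < S*ph + window and S*pw <= w < S*pw + window."""
    return [(h, w) for h in range(S * ph, S * ph + window)
                   for w in range(S * pw, S * pw + window)]

H, W, S = 8, 8, 2              # small hypothetical sizes
H_out, W_out = H // S, W // S

# How many Pool iterations read each Conv output point.
uses = Counter(p for ph in range(H_out) for pw in range(W_out)
                 for p in required_conv_slice(ph, pw, S))

# With window == stride the slices tile the Conv output exactly:
# every Conv point is needed by exactly one Pool iteration.
assert set(uses) == {(h, w) for h in range(H) for w in range(W)}
assert set(uses.values()) == {1}
print(required_conv_slice(1, 2, S))  # [(2, 4), (2, 5), (3, 4), (3, 5)]
```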


```python
# Access relations for each statement (bounds stated explicitly for robustness)
conv_writes = f"{{ Conv[n,c,h,w] -> Out[n,c,h,w] : 0 <= n < {N} and 0 <= c < {C} and 0 <= h < {H} and 0 <= w < {W} }}"  # noqa
# Pool[n,c,ph,pw] reads the Pool_S x Pool_S window of Out starting at (Pool_S*ph, Pool_S*pw)
pool_reads = f"{{ Pool[n,c,ph,pw] -> Out[n,c,h,w] : 0 <= n < {N} and 0 <= c < {C} and 0 <= ph < {H_out} and 0 <= pw < {W_out} and {Pool_S}*ph <= h < {Pool_S}*ph + {Pool_S} and {Pool_S}*pw <= w < {Pool_S}*pw + {Pool_S} }}"  # noqa

# Define independent schedules
# Conv (Producer)
with P.domain(conv_dom) as conv:
    conv.access(writes=conv_writes)
    # Default schedule: [n, c, h, w] (identity)
    with P.band("{ Conv[n,c,h,w] -> [n,c,h,w] }"):
        pass

# Pool (Consumer)
with P.domain(pool_dom) as pool:
    pool.access(reads=pool_reads)
    # Default schedule: [n, c, ph, pw] (identity)
    with P.band("{ Pool[n,c,ph,pw] -> [n,c,ph,pw] }"):
        pass

# Perform fusion (compute_at): for each Pool iteration, compute the slice
# of Conv it reads (per the access relations) and embed it there.
fused_domain = conv.compute_at(pool)

# Get the resulting schedule
sched_fusion = fused_domain.schedule
print("Fused Schedule Tree:")
print(sched_fusion)
```
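As a mental model of the fused schedule, the hand-written plain-Python sketch below contrasts the unfused execution (materialize the whole Conv output, then pool) with Conv computed at each Pool iteration (only an $S \times S$ slice is live at a time). It is an illustration under stated assumptions, not the code caten emits: sizes and data are hypothetical, the batch index is dropped ($N = 1$), and the fused loop order `c, ph, pw` follows the Pool band.

```python
import random

def conv_point(inp, wgt, co, h, w):
    """One Conv instance Out[co, h, w] (batch index omitted, N = 1)."""
    Ci, Kh, Kw = len(wgt[0]), len(wgt[0][0]), len(wgt[0][0][0])
    return sum(inp[ci][h + kh][w + kw] * wgt[co][ci][kh][kw]
               for ci in range(Ci) for kh in range(Kh) for kw in range(Kw))

def unfused(inp, wgt, C, H, W, S):
    """Separate schedules: materialize the full C x H x W Conv output, then pool."""
    out = [[[conv_point(inp, wgt, c, h, w) for w in range(W)] for h in range(H)]
           for c in range(C)]
    return [[[max(out[c][ph * S + dh][pw * S + dw] for dh in range(S) for dw in range(S))
              for pw in range(W // S)] for ph in range(H // S)] for c in range(C)]

def fused(inp, wgt, C, H, W, S):
    """Conv computed at each Pool iteration: no full-size Out buffer."""
    pool = [[[None] * (W // S) for _ in range(H // S)] for _ in range(C)]
    for c in range(C):
        for ph in range(H // S):
            for pw in range(W // S):
                # The Conv slice this Pool point reads: rows S*ph..S*ph+S-1,
                # columns S*pw..S*pw+S-1.
                window = [conv_point(inp, wgt, c, ph * S + dh, pw * S + dw)
                          for dh in range(S) for dw in range(S)]
                pool[c][ph][pw] = max(window)
    return pool

# Tiny hypothetical sizes with random integer data.
random.seed(0)
Ci, C, H, W, K, S = 2, 3, 6, 6, 3, 2
inp = [[[random.randint(-5, 5) for _ in range(W + K - 1)] for _ in range(H + K - 1)]
       for _ in range(Ci)]
wgt = [[[[random.randint(-2, 2) for _ in range(K)] for _ in range(K)] for _ in range(Ci)]
       for _ in range(C)]
assert fused(inp, wgt, C, H, W, S) == unfused(inp, wgt, C, H, W, S)
print("fused == unfused")
```

Because the pooling window equals the stride, each Conv instance is still computed exactly once in the fused version; the gain is that the intermediate `Out` tensor never has to be stored in full.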


## 4. Conclusion


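This tutorial showed how `caten.polyhedral` expresses schedule transformations at a high level: tiling and band splitting for MatMul, and `compute_at` for Conv-Pool fusion, which derives the required producer slice from the access relations and rebuilds the schedule tree instead of leaving that dependency analysis to be done by hand.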
