# FEASTA

This notebook reproduces the salient characteristics of the [FEASTA](https://dl.acm.org/doi/10.1145/3620666.3651336) accelerator.

## Imports

Import the necessary modules.

In [67]:
# HiFiber boilerplate

from fibertree_bootstrap import *

fibertree_bootstrap(style="tree", animation='movie')

# Compilation boilerplate

import os
import sys
sys.path.insert(0, "..")

from src import utils

interactive(children=(Dropdown(description='style', options=('tree', 'uncompressed', 'tree+uncompressed'), valâ€¦

Button(description='Run all cells below', style=ButtonStyle())

## Initialization

Initialize the input tensors. Tensor shapes and densities can be modified below.

**Warning:** Large tensors will overwhelm the video generation. Either:
1. Use small tensors; as a rule of thumb, fewer than 60 computes (e.g., multiplications) should be required.
2. Do not generate a video; remove the `spacetime` specification from the `mapping` before compiling.

In [68]:
K = 4
M = 5
N = 6

KM1 = 4
KM0 = 2
M1 = 4
M0 = 2

density = [0.9, 0.5]
seed = 0

A_MK = Tensor.fromRandom(rank_ids=["M", "K"], shape=[M, K], seed=seed, density=density, name="A")
B_NK = Tensor.fromRandom(rank_ids=["N", "K"], shape=[N, K], seed=seed + 1, density=density, name="B")

## Compile and Run

Below is the TeAAL specification for FEASTA. To simulate the accelerator:
1. Compile it to HiFiber by running the cell, inserting a new cell
2. Run the new cell, which will
    - Execute the kernel; multiplying the above defined matrices
    - Generate visualizations of the actions of the kernel

#### Notes

- Small tensors are required for video generation. If you are using large tensors, remove the spacetime specification to generate a kernel that does not produce videos. Outputs can still be checked below.
- Partition shapes are decreased accordingly above for visualization purposes. The real OuterSPACE uses `KM1 = 256`, `KM0 = 16`, `M1 = 128`, and `M0 = 8`.

In [69]:
yaml = """
# Inner product mode
einsum:
  declaration:       
    A: [K, M]            
    B: [K, N]               
    Z: [M, N]               

  expressions:               
    - Z[m, n] = A[k, m] * B[k, n]

mapping:
  rank-order:
    A: [M, K]             
    B: [N, K]            
    Z: [M, N]

  partitioning:              
    Z:
      N: [uniform_shape(2)] # FP is fiber parallelism, which is 2 in this example

loop-order:    
  Z: [M, N1, N0, K]         

spacetime:
  Z:
    space: [N0]      
    time: [M, N1, K]
"""

utils.compile(yaml)

In [70]:
# Autogenerated HiFiber

Z_MN1N0 = Tensor(rank_ids=["M", "N1", "N0"], name="Z")
tmp0 = B_NK
tmp1 = tmp0.splitUniform(2, depth=0)
B_N1N0K = tmp1
B_N1N0K.setRankIds(rank_ids=["N1", "N0", "K"])
z_m = Z_MN1N0.getRoot()
a_m = A_MK.getRoot()
b_n1 = B_N1N0K.getRoot()
for m, (z_n1, a_k) in z_m << a_m:
    for n1, (z_n0, b_n0) in z_n1 << b_n1:
        for n0, (z_ref, b_k) in z_n0 << b_n0:
            for k, (a_val, b_val) in a_k & b_k:
                z_ref += a_val * b_val
tmp2 = Z_MN1N0
tmp3 = tmp2.mergeRanks(depth=1, levels=1, coord_style="absolute")
tmp3.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp3

In [71]:
yaml = """
# Row-wise mode
einsum:
  declaration:
    A: [K, M]
    B: [K, N]
    Z: [M, N]

  expressions:
    - Z[m, n] = A[k, m] * B[k, n]

mapping:
  rank-order:
    A: [M, K]     
    B: [K, N]      
    Z: [M, N]

  partitioning:   
    Z:
      K: [uniform_shape(2)] # FP is fiber parallelism, which is 2 in this example
      N: [uniform_shape(4)] # DP is data parallelism, which is 4 in this example


loop-order:
  Z: [M, K1, K0, N1, N0]

spacetime:
  Z:
    space: [K0, N0]     
    time: [M, K1, N1]
"""

utils.compile(yaml)

In [72]:
# Autogenerated HiFiber

Z_MN1N0 = Tensor(rank_ids=["M", "N1", "N0"], name="Z")
tmp0 = A_MK
tmp1 = tmp0.splitUniform(2, depth=1)
A_MK1K0 = tmp1
A_MK1K0.setRankIds(rank_ids=["M", "K1", "K0"])
tmp2 = B_KN
tmp3 = tmp2.splitUniform(2, depth=0)
B_K1K0N = tmp3
B_K1K0N.setRankIds(rank_ids=["K1", "K0", "N"])
tmp4 = B_K1K0N
tmp5 = tmp4.splitUniform(4, depth=2)
B_K1K0N1N0 = tmp5
B_K1K0N1N0.setRankIds(rank_ids=["K1", "K0", "N1", "N0"])
z_m = Z_MN1N0.getRoot()
B_N1N0K1K0 = B_K1K0N1N0.swizzleRanks(rank_ids=["N1", "N0", "K1", "K0"])
a_m = A_MK1K0.getRoot()
b_n1 = B_N1N0K1K0.getRoot()
for m, (z_n1, a_k1) in z_m << a_m:
    for n1, (z_n0, b_n0) in z_n1 << b_n1:
        for n0, (z_ref, b_k1) in z_n0 << b_n0:
            for k1, (a_k0, b_k0) in a_k1 & b_k1:
                for k0, (a_val, b_val) in a_k0 & b_k0:
                    z_ref += a_val * b_val
tmp6 = Z_MN1N0
tmp7 = tmp6.mergeRanks(depth=1, levels=1, coord_style="absolute")
tmp7.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp7

In [73]:
yaml = """
# Outer product mode
einsum:
  declaration:
    A: [K, M]
    B: [K, N]
    Z: [M, N]

  expressions:
    - Z[m, n] = A[k, m] * B[k, n]

mapping:
  rank-order:
    A: [K, M]      
    B: [K, N]    
    Z: [M, N]

  partitioning:    
    Z:
      M: [uniform_shape(2)] # FP = 2
      N: [uniform_shape(4)] # DP = 4

loop-order:
  Z: [K, M1, M0, N1, N0]

spacetime:
  Z:
    space: [M0, N0] 
    time: [K, M1, N1]
"""

utils.compile(yaml)

In [74]:
# Autogenerated HiFiber

Z_M1M0N1N0 = Tensor(rank_ids=["M1", "M0", "N1", "N0"], name="Z")
tmp0 = A_KM
tmp1 = tmp0.splitUniform(2, depth=1)
A_KM1M0 = tmp1
A_KM1M0.setRankIds(rank_ids=["K", "M1", "M0"])
tmp2 = B_KN
tmp3 = tmp2.splitUniform(4, depth=1)
B_KN1N0 = tmp3
B_KN1N0.setRankIds(rank_ids=["K", "N1", "N0"])
z_m1 = Z_M1M0N1N0.getRoot()
A_M1M0K = A_KM1M0.swizzleRanks(rank_ids=["M1", "M0", "K"])
B_N1N0K = B_KN1N0.swizzleRanks(rank_ids=["N1", "N0", "K"])
a_m1 = A_M1M0K.getRoot()
b_n1 = B_N1N0K.getRoot()
for m1, (z_m0, a_m0) in z_m1 << a_m1:
    for m0, (z_n1, a_k) in z_m0 << a_m0:
        for n1, (z_n0, b_n0) in z_n1 << b_n1:
            for n0, (z_ref, b_k) in z_n0 << b_n0:
                for k, (a_val, b_val) in a_k & b_k:
                    z_ref += a_val * b_val
tmp4 = Z_M1M0N1N0
tmp5 = tmp4.mergeRanks(depth=2, levels=1, coord_style="absolute")
tmp6 = tmp5.mergeRanks(depth=0, levels=1, coord_style="absolute")
tmp6.setRankIds(rank_ids=["M", "N"])
Z_MN = tmp6

## Check Results

Check that generated code computes the correct result.

**Note**: Should be used after compiling and running the kernel (above cell).

In [75]:
utils.check_matmul(A_KM, B_KN, Z_MN)

Result correct? True
