## Running the Fast & Fusiest Mapper (FFM)
This notebook shows how to run the Fast & Fusiest Mapper (FFM) on a full workload and
architecture.

In [1]:
# Imports
import fastfusion as ff
from IPython.display import SVG, display
import os
# Number of parallel jobs the mapper may use. `os.cpu_count()` can be used
# instead to match the host machine.
N_PARALLEL_JOBS = 32
ff.set_n_parallel_jobs(N_PARALLEL_JOBS)

In [None]:
# Build the Spec from an architecture YAML and a workload YAML. The values in
# `jinja_parse_data` are handed to the Jinja templating of those files.
BATCH_SIZE = 4
N_TOKENS = 4096
FUSE = True

# Alternatives that have been used with this notebook:
#   architecture: "arches/tpu_v4i_like.yaml", "arches/simple.arch.yaml"
#   workload:     "workloads/matmuls8_mixed.workload.yaml"
ARCH_YAML = "arches/tpu_v4i_like_constrained.yaml"
WORKLOAD_YAML = "workloads/gpt3_6.7B.yaml"

spec = ff.Spec.from_yaml(
    ARCH_YAML,
    WORKLOAD_YAML,
    jinja_parse_data={"BATCH_SIZE": BATCH_SIZE, "N_TOKENS": N_TOKENS},
)

if not FUSE:
    # No fusion: every tensor is kept in main memory.
    spec.arch.nodes["MainMemory"].constraints.tensors.keep = "All()"

# To render the workload graph: display(SVG(spec.workload.render()))

In [3]:
# Generate the pmappings for the workload's Einsums.

# Optimization objective. Other choices are Metrics.ENERGY or Metrics.LATENCY
# alone; OR-ing both flags optimizes energy and latency together.
Metrics = ff.mapper.FFM.Metrics
spec.mapper.ffm.metrics = Metrics.LATENCY | Metrics.ENERGY

# Cap on the number of fused loops.
spec.mapper.ffm.max_fused_loops = 2

pmappings = ff.mapper.FFM.make_pmappings(
    spec,
    can_combine_multiple_runs=False,
    cache_dir="cache",
)

# Simanneal before:
# -- number of pmappings per Einsum
# ++ per-Pmapping runtime

# Changes:
# + pmappings per Einsum
# - per-pmapping runtime

In [None]:
# Print each pmapping count reported by the pmappings object, one per line,
# in a fixed order.
pmapping_stats = (
    ("Total number of pmappings", pmappings.total_pmappings),
    ("Number of valid pmappings", pmappings.valid_pmappings),
    ("Number of Pareto-optimal pmappings", pmappings.pareto_optimal_pmappings),
    ("Number of evaluated pmappings", pmappings.evaluated_pmappings),
    (
        "Number of evaluated pmappings for simanneal baseline compare",
        pmappings._evaluated_pmappings_for_simanneal_baseline_compare,
    ),
)
for label, count_fn in pmapping_stats:
    print(f"{label}: {count_fn()}")

# Total number of pmappings: 1299665056477.98
# Number of valid pmappings: 346954061.15200764
# Number of Pareto-optimal pmappings: 334167
# Number of evaluated pmappings: 282649104 / 9105032


# Total number of pmappings: 934266468234.9056
# Number of valid pmappings: 96623713.66167709
# Number of Pareto-optimal pmappings: 20457134
# Number of evaluated pmappings: 177651824
# Number of evaluated pmappings for simanneal baseline compare: 21663787806.54472


In [None]:
import copy

# Select the scalar score to minimize from the optimized metric(s):
#   s   -- label for the score: "energy", "latency" or "EDP"
#   acc -- callable mapping a result row to that score, reading the
#          "Total<SEP>energy" / "Total<SEP>latency" columns.
# Column names are spelled out literally so the lambdas do not late-bind the
# global `s` (which a later cell could rebind).
Metrics = ff.mapper.FFM.Metrics
objective = spec.mapper.ffm.metrics

if objective == Metrics.ENERGY:
    s = "energy"
    acc = lambda x: x["Total<SEP>energy"]
elif objective == Metrics.LATENCY:
    s = "latency"
    acc = lambda x: x["Total<SEP>latency"]
elif objective == Metrics.LATENCY | Metrics.ENERGY:
    # Energy-delay product.
    s = "EDP"
    acc = lambda x: x["Total<SEP>energy"] * x["Total<SEP>latency"]
else:
    # Fail early: otherwise `acc` would be undefined (or stale from an earlier
    # run) when the join cell uses it.
    raise ValueError(f"Unsupported metric selection for this notebook: {objective!r}")

class FilterLambda:
    """Row filter that keeps only rows that can still beat a known score.

    A row passes only if all three checks hold, evaluated in this order:

    * its own score, ``score_fn(row)``, is at most ``best_edp``;
    * its ``"Total<SEP>energy"`` times ``min_latency`` is at most ``best_edp``;
    * its ``"Total<SEP>latency"`` times ``min_energy`` is at most ``best_edp``.

    The last two checks multiply this row's energy (latency) by a minimum of
    the other quantity, i.e. they are energy-delay-product bounds.
    NOTE(review): with an energy-only or latency-only objective, those two
    checks compare quantities with different units -- confirm before using
    this filter outside EDP mode.

    The checks are combined with ``&`` rather than ``and`` so the filter also
    works element-wise if the columns are array-like (e.g. pandas Series).
    Presumably that is how the joiner calls it; verify.

    Args:
        best_edp: Score of a known full mapping; used as the upper bound.
        min_latency: Minimum total latency (the notebook passes the sum of
            per-Einsum minima).
        min_energy: Minimum total energy (computed the same way).
        score_fn: Callable mapping a row to its score. When ``None`` (the
            default), the notebook-global ``acc`` is looked up at call time.
    """

    def __init__(
        self,
        best_edp: float,
        min_latency: float,
        min_energy: float,
        score_fn=None,
    ):
        self.best_edp = best_edp
        self.min_latency = min_latency
        self.min_energy = min_energy
        self.score_fn = score_fn

    def __call__(self, x):
        # Resolve lazily so the default keeps following the global `acc`.
        score_fn = self.score_fn if self.score_fn is not None else acc
        score_ok = score_fn(x) <= self.best_edp
        energy_ok = x["Total<SEP>energy"] * self.min_latency <= self.best_edp
        latency_ok = x["Total<SEP>latency"] * self.min_energy <= self.best_edp
        return score_ok & energy_ok & latency_ok

# Choose how to join the pmappings into full mappings:
#   SIMANNEAL = False -> ff.mapper.FFM.join_pmappings, pruned with FilterLambda
#   SIMANNEAL = True  -> simulated-annealing baseline (ff.mapper.simanneal2)
SIMANNEAL = True

if not SIMANNEAL:
    # To run serially: ff.set_n_parallel_jobs(1, print_message=True)

    def has_no_compatibility_loops(pm):
        """Return True if no tensor in ``pm.compatibility`` has any loops."""
        return all(len(tensor.loops) == 0 for tensor in pm.compatibility.tensors)

    # Step 1: join only the pmappings whose compatibility has no loops. The
    # score of the first resulting row is used as the pruning bound below.
    baseline_mappings = ff.mapper.FFM.join_pmappings(
        spec, pmappings.filter(has_no_compatibility_loops)
    )
    baseline_score = acc(baseline_mappings[0])

    # Step 2: for each Einsum, the smallest latency / energy over all of its
    # pmappings, summed across Einsums.
    # NOTE(review): FilterLambda treats these as minima of the full mapping's
    # totals, which assumes totals add up across Einsums -- confirm.
    min_latency = sum(
        min(pm_set.mappings.data["Total<SEP>latency"].min() for pm_set in pm_group)
        for pm_group in pmappings.einsum2pmappings.values()
    )
    min_energy = sum(
        min(pm_set.mappings.data["Total<SEP>energy"].min() for pm_set in pm_group)
        for pm_group in pmappings.einsum2pmappings.values()
    )

    print(f"Min latency: {min_latency}")
    print(f"Min energy: {min_energy}")

    # Step 3: full join, discarding rows that FilterLambda rejects against the
    # baseline score.
    row_filter = FilterLambda(
        best_edp=baseline_score,
        min_latency=min_latency,
        min_energy=min_energy,
    )

    print(f"Filtering pmappings with {s} > {baseline_score}")

    # Join the pmappings to create a full mapping.
    mappings = ff.mapper.FFM.join_pmappings(
        spec,
        pmappings,
        pmapping_row_filter_function=row_filter,
    )
else:
    # Join the pmappings with simulated annealing.
    ff.set_n_parallel_jobs(32, print_message=True)
    tracker = ff.mapper.simanneal2.join_pmappings(
        spec,
        pmappings,
        # Equal to the best EDP recorded for this configuration in the notes below.
        score_target=0.04982001853,
        max_evaluations=20,
    )
    # Print each history entry as a space-separated pair.
    for a, b in tracker.history:
        print(f"{a} {b}")
    # This branch produces a tracker, not joined mappings. Reset `mappings` so
    # the cells below cannot silently display a stale result left over from an
    # earlier SIMANNEAL = False run.
    mappings = None

# B=4 M=4096 GPT3 6.7B constrained EDP:
#   No fusion 0.0767265063452541
#   No tiled fusion 0.07461399646709636
#   One loop 0.05170420468561774
#   Two loops 0.04982001852997137 <- FAST
#   Tiled fusion 0.04982001853
# Totals:
# 	latency: 0.06004345904761903
# 	energy: 0.8297326523187198
# 	mapping: <fastfusion.mapper.FFM._interface.main.MappingFromRow object at 0x7f0726c4e8a0>

# Brief progress update for this week: I struggled with simulated annealing because I implemented it
# on top of the new tile-shape exploration and the results changed quite a lot, but I think it is fixed now.
# How it works: it randomly mutates the compatibility, picks the relevant SIMs, and picks one of the
# chosen pmappings for each SIM. Those pmappings have already been Pareto-pruned, so each pick is charged
# (#pmappings evaluated by tile-shape exploration) / (#Pareto-optimal pmappings) evaluations, which is the
# expected number of pmappings to check before finding a Pareto-optimal one.
# The problem: the new pmapper finds all the Pareto-optimal pmappings with orders of magnitude fewer
# evaluated pmappings, so the simulated-annealing results showed unusually low evaluation counts. After some
# debugging, the low counts turned out to be partly due to information leakage. When the pmapper explores,
# it prunes partially constructed mappings, and it enumerates fused loops last, so a single pruning step
# often affects many different SIMs. Simulated annealing could not exploit this: in principle it must choose
# the fused loops before making the other intra-Einsum choices, so it would not benefit from this cross-SIM
# pruning. I added a tracker for this in the pmapping exploration, and simanneal is now correct (and slow).

# Open question: is some of the fast pmapper's contribution now leaking into Fast & Fusiest's results? FFM
# gives the fast pmapper a larger mapspace to explore (the full pmapping space) than simulated annealing does
# (the pmapping space for one fused-loop choice), so under FFM the pmapper can exploit more pruning opportunities.

In [None]:
# The joined result holds the Pareto-optimal full mappings for the selected
# metric(s). With a single metric there should be exactly one row; with several
# (e.g. energy and latency) there may be more. Either way, row 0 is shown.
# NOTE(review): with multiple metrics, row 0 is not necessarily the best row by
# `s` -- confirm the ordering of the joined result.
# NOTE(review): this rebinds `mappings` to a single row, so re-running this cell
# without re-running the join cell indexes the row instead of the joined table.
mappings = mappings[0]

# Show the mapping.
display(SVG(mappings.render()))

# All units are SI units -- seconds, joules, meters, etc.
print("Totals:")
for name, value in mappings.access("Total").to_dict().items():
    print(f"\t{name}: {value}")

for metric in ("energy", "latency"):
    try:
        per_compute = mappings.access("Total").per_compute().to_dict()[metric]
    except Exception:
        # Best effort: report a missing per-compute value instead of failing.
        # (Catching Exception, not everything, so Ctrl-C still interrupts.)
        print(f"No per-compute {metric}")
    else:
        print(f"Per-compute {metric}: {per_compute}")

# NOTE(review): when s == "EDP", this assumes access("EDP") is supported --
# confirm.
print(f"Contributors to {s}:")
for name, value in mappings.access(s).to_dict().items():
    print(f"\t{name}: {value}")

In [None]:
# Print every stat of the selected mapping as one "name: value" line each.
mapping_stats = mappings.to_dict()
for stat_name in mapping_stats:
    print(f"{stat_name}: {mapping_stats[stat_name]}")