In [1]:

import copy
import time
from fastfusion import Specification
from fastfusion.mapper.metrics import Metrics
from fastfusion.mapper.FFM._make_pmappings.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM._join_pmappings.sim import SIM
from fastfusion.mapper.FFM._join_pmappings.simexplore import join_sims
import fastfusion.mapper.FFM._make_pmappings.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM._make_pmappings.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM._make_pmappings.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM._pmapping_group import PmappingGroup

# TODO: Drop the index column after decompressing. The code in the decompress logic tries to do that, but something is bugged.
# TODO: Separate estimating energy and area.
# TODO: Once tile shapes are masked out, immediately drop those shapes.
def get_fused_mappings(spec: Specification, n_pes, tagger=None, even=False, fuse=True, drop_valid_reservations=None) -> PmappingGroup:
    """Generate per-Einsum partial mappings for ``spec`` and join them into full mappings.

    The caller's ``spec`` is not modified; a deep copy is configured and mapped.

    Args:
        spec: Architecture + workload specification to map.
        n_pes: Fanout assigned to both the "X" and "Y" spatial dimensions of the
            "Register" node (i.e. an ``n_pes x n_pes`` array).
        tagger: Optional mapping-filter tagger forwarded to ``get_sims``
            (e.g. ``get_one_split_tag``). ``None`` forwards no tagger.
        even: Value assigned to ``spec.mapper.ffm.timeloop_style_even``.
        fuse: If False, sets the "MainMemory" storage ``keep`` constraint to
            ``"All()"`` (presumably keeping every tensor in main memory, i.e. an
            unfused baseline — verify against the constraint semantics).
        drop_valid_reservations: Forwarded to ``join_sims``. If None (default),
            falls back to the previous behavior, ``archname != "snowcat"``, which
            reads the notebook-global ``archname``. Pass it explicitly to avoid
            depending on that hidden global.

    Returns:
        The joined mappings, after ``decompress`` has been applied.

    Raises:
        KeyError: If the architecture has no "Register" node (or no "MainMemory"
            node when ``fuse`` is False).
    """
    spec = copy.deepcopy(spec)  # configure a private copy; the caller's spec stays untouched
    if not fuse:
        spec.arch.nodes["MainMemory"].constraints.storage.keep = "All()"
    pe_fanout = spec.arch.nodes["Register"].spatial.fanout
    pe_fanout["X"] = n_pes
    pe_fanout["Y"] = n_pes

    spec.mapper.ffm.timeloop_style_even = even
    spec.calculate_component_energy_area()
    flattened_architecture = spec.get_flattened_architecture()

    # perf_counter is monotonic and intended for measuring elapsed durations.
    t0 = time.perf_counter()
    sims, decompress_data = get_sims(spec, flattened_architecture, tagger=tagger, metrics=Metrics.PER_COMPONENT_ENERGY | Metrics.RESERVATIONS | Metrics.LATENCY)
    pmapping_time = time.perf_counter() - t0

    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    n_pareto_optimal_mappings = sum(len(p.mappings.data) for v in sims.values() for p in v)

    # Guard the reporting ratios so an empty search (or a sub-resolution timing)
    # does not crash the run with ZeroDivisionError.
    pmappings_per_second = total_pmappings / pmapping_time if pmapping_time > 0 else float("inf")
    pareto_percent = n_pareto_optimal_mappings / total_pmappings * 100 if total_pmappings else 0.0
    time_per_pmapping = pmapping_time / total_pmappings if total_pmappings else float("nan")

    print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({pmappings_per_second:.2f} per second). {n_pareto_optimal_mappings} pareto optimal mappings ({pareto_percent:.2f}% of total).')

    if drop_valid_reservations is None:
        # Backward-compatible default; reads the notebook-global `archname`.
        drop_valid_reservations = archname != "snowcat"

    t0 = time.perf_counter()
    mappings = join_sims(sims, spec, flattened_architecture, drop_valid_reservations=drop_valid_reservations)
    join_time = time.perf_counter() - t0
    mappings.decompress(decompress_data)
    print(f"Pmappings: {pmapping_time:.2f}. Joining: {join_time:.2f}. Total Pmappings: {total_pmappings}. Total mappings: {mappings.n_pmappings}. Time per pmapping: {time_per_pmapping:.2e}")
    return mappings

# TODO: Don't lower backing storage nodes

# Architecture under study. Another architecture that has been explored: "snowcat".
# NOTE(review): the saved output of this cell came from a snowcat run, which
# failed inside get_fused_mappings with KeyError (no "Register" node).
archname = "four_level"

# Workload + renames pair. Another pair that has been used:
#   "workloads/matmuls8_mixed.workload.yaml", "workloads/matmuls8_mixed.renames.yaml"
# optionally together with jinja_parse_data={"FFMT": True}.
spec = Specification.from_yaml(
    f"architecture/{archname}.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml",
)

# n sets both the X and Y fanout of the "Register" node (an n x n PE array).
# 1024 ~= B200 POPs. Another size that has been explored: 512.
# Other mapper variants that have been run, one result entry per n:
#   f"Unfused {n}":  get_fused_mappings(spec, n, tagger=None, fuse=False)
#   f"TileFlow {n}": get_fused_mappings(spec, n, tagger=get_one_split_tag, even=True)  # trips an assertion error
#   f"LoopTree {n}": get_fused_mappings(spec, n, tagger=get_one_split_tag)
result = {}
for n in (2048,):
    result[f"LoopForest {n}"] = get_fused_mappings(spec, n)


INFO        Loading yaml file architecture/snowcat.arch.yaml
INFO        Found top key variables in architecture/snowcat.arch.yaml
INFO        Found top key architecture in architecture/snowcat.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Loading yaml file workloads/mha_full.renames.yaml
INFO        Found top key renames in workloads/mha_full.renames.yaml


KeyError: 'No element with name "Register" found.'

In [None]:
import importlib
import fastfusion.mapper.FFM._make_pmappings.mapper_multi_einsum
import fastfusion.visualization.interactive

# Reload so edits to these modules take effect without restarting the kernel.
importlib.reload(fastfusion.mapper.FFM._make_pmappings.mapper_multi_einsum)
importlib.reload(fastfusion.visualization.interactive)
from fastfusion.mapper.FFM._make_pmappings.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums, get_num_computes
from fastfusion.visualization.interactive import plotly_show

# Compute (and report) these once; the previous version of this cell repeated
# the identical computation and print a second time further down.
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
print(f'Number of computes: {get_num_computes(spec)}')

# On snowcat, reduce each run to the GlobalBuffer-usage vs. energy Pareto front.
if archname == "snowcat":
    for mappings in result.values():
        mappings: PmappingGroup
        mappings.make_pareto(columns=["RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy"])

# Per-run `.data` view, keyed by run label.
result2 = {k: v.data for k, v in result.items()}

# Alternative plots that have been used:
#   snowcat: plotly_show(result2, "RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names)
#   energy:  plotly_show(result2, "Total_Energy", "Total_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
plotly_show(result2, "Total_Latency", "Total_Latency", category="Category", logscales=False, einsum_names=spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
# from fastfusion.visualization.interactive import mapping2svg
# mapping2svg(mappings.data.iloc[0], spec.workload.einsum_names)
# 783.577 -> 768.846
# 10.5M -> 8.6M

In [None]:
# Print the energy and reservation columns of the first mapping of one run.
# "Unfused 2048" only exists when the unfused run in the first cell is enabled
# (it is currently commented out), so fall back to the first run actually computed.
inspect_key = "Unfused 2048" if "Unfused 2048" in result2 else next(iter(result2), None)
if inspect_key is None:
    raise ValueError("result2 is empty; run the mapping and plotting cells first.")
r = result2[inspect_key]
print(f"Run: {inspect_key}")
for c in r.columns:
    if "Energy" in c or "Reservations" in c:
        print(f"{c}: {r[c].iloc[0]}")


In [None]:
import importlib
import fastfusion.visualization.interactive

# Reload so edits to the plotting module take effect without a kernel restart.
importlib.reload(fastfusion.visualization.interactive)
from fastfusion.visualization.interactive import plotly_show

if archname == "snowcat":
    # Reduce each run to the GlobalBuffer-usage vs. energy Pareto front.
    for mappings in result.values():
        mappings: PmappingGroup
        mappings.make_pareto(columns=["RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy"])

# Per-run `.data` view, keyed by run label.
result2 = {label: group.data for label, group in result.items()}

# First plot: GlobalBuffer usage vs. energy (categorized) on snowcat,
# otherwise energy vs. energy without a category.
x_axis, plot_options = (
    ("RESOURCE_GlobalBuffer_LEVEL_0", {"category": "Category"})
    if archname == "snowcat"
    else ("Total_Energy", {})
)
plotly_show(result2, x_axis, "Total_Energy", logscales=False, einsum_names=spec.workload.einsum_names, **plot_options)

# NOTE(review): on non-snowcat architectures this repeats the plot above,
# differing only in category coloring.
plotly_show(result2, "Total_Energy", "Total_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names)
# from fastfusion.visualization.interactive import mapping2svg
# mapping2svg(mappings.data.iloc[0], spec.workload.einsum_names)

In [None]:
from fastfusion.mapper.FFM.deprecate_maybe.visualization import make_mapping
from IPython.display import SVG, display

# Render the first mapping of one run as an SVG.
# "Unfused 2048" only exists when the unfused run in the first cell is enabled
# (it is currently commented out), so fall back to the first run actually computed.
svg_key = "Unfused 2048" if "Unfused 2048" in result2 else next(iter(result2), None)
if svg_key is None:
    raise ValueError("result2 is empty; run the mapping and plotting cells first.")
newmapping = make_mapping(result2[svg_key].iloc[0], spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
display(SVG(newmapping.render()))


In [None]:
# Index each Einsum's SIMs by their compatibility string, for inspection.
# NOTE: `sims` is a local variable inside get_fused_mappings and is never
# assigned at notebook scope. Fail with an actionable message instead of a bare
# NameError when it has not been produced some other way.
if "sims" not in globals():
    raise NameError(
        "`sims` is not defined at notebook scope (it is local to get_fused_mappings). "
        "Call get_sims(...) directly, or have get_fused_mappings return it, before running this cell."
    )
compatibility2sims = {
    einsum_name: {einsum_sim.compatibility_str(): einsum_sim for einsum_sim in einsum_sims}
    for einsum_name, einsum_sims in sims.items()
}
print(compatibility2sims)