In [None]:

import copy
import time
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.metrics import Metrics
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings

def _ratio_or_nan(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is zero.

    Used for the throughput / percentage figures in the progress messages so that a run that
    produces no partial mappings (or finishes within timer resolution) prints ``nan`` instead of
    raising ZeroDivisionError after all the mapping work has already been done.
    """
    return numerator / denominator if denominator else float("nan")


# TODO: Drop the index column after decompressing. The code in the decompress logic tries to do that, but something is bugged.
# TODO: Separate estimating energy and area.
# TODO: Once tile shapes are masked out, immediately drop those shapes.
def get_fused_mappings(
    spec: Specification,
    n_pes,
    tagger=None,
    even=False,
    fuse=True,
    drop_valid_reservations=None,
) -> PartialMappings:
    """Generate per-Einsum partial mappings for ``spec`` and join them into full mappings.

    The caller's ``spec`` is not modified; all configuration is applied to a deep copy.

    Args:
        spec: Architecture + workload specification to map.
        n_pes: Value written to both the "X" and "Y" entries of the "Register" node's spatial fanout.
        tagger: Optional mapping-filter tagger forwarded to ``get_sims`` (e.g. ``get_one_split_tag``).
            ``None`` passes no tagger.
        even: Stored in ``spec.mapper_ffm.timeloop_style_even`` before mapping.
        fuse: When False, the "MainMemory" storage constraint is set to ``keep = "All()"``
            (presumably keeping every tensor in main memory, i.e. no fusion -- confirm).
        drop_valid_reservations: Forwarded to ``join_sims``. When None (the default) it is derived
            from the notebook-global ``archname`` as ``archname != "snowcat"``, which is the
            original hard-wired behavior. Pass it explicitly to avoid depending on that global.

    Returns:
        The joined mappings, decompressed with the data returned by ``get_sims``.

    Side effects:
        Prints partial-mapping generation time and throughput, the Pareto-optimal fraction, and
        the join time.
    """
    spec = copy.deepcopy(spec)
    if not fuse:
        spec.architecture.nodes["MainMemory"].constraints.storage.keep = "All()"
    pe_fanout = spec.architecture.nodes["Register"].spatial.fanout
    pe_fanout["X"] = n_pes
    pe_fanout["Y"] = n_pes

    spec.mapper_ffm.timeloop_style_even = even
    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()

    # perf_counter is monotonic and high-resolution, unlike time.time(), so it is the right
    # clock for measuring elapsed durations.
    t0 = time.perf_counter()
    sims, decompress_data = get_sims(spec, flattened_architecture, tagger=tagger, metrics=Metrics.PER_COMPONENT_ENERGY | Metrics.RESERVATIONS | Metrics.LATENCY)
    pmapping_time = time.perf_counter() - t0

    # Each value of `sims` is an iterable of SIMs; count all explored pmappings and the ones kept.
    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    n_pareto_optimal_mappings = sum(len(p.mappings.data) for v in sims.values() for p in v)
    pmappings_per_second = _ratio_or_nan(total_pmappings, pmapping_time)
    pareto_percent = _ratio_or_nan(n_pareto_optimal_mappings, total_pmappings) * 100
    print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({pmappings_per_second:.2f} per second). {n_pareto_optimal_mappings} pareto optimal mappings ({pareto_percent:.2f}% of total).')

    if drop_valid_reservations is None:
        # Backward-compatible default: reads the notebook-global `archname`.
        drop_valid_reservations = archname != "snowcat"

    t0 = time.perf_counter()
    mappings = join_sims(sims, spec, flattened_architecture, drop_valid_reservations=drop_valid_reservations)
    join_time = time.perf_counter() - t0
    mappings.decompress(decompress_data)

    seconds_per_pmapping = _ratio_or_nan(pmapping_time, total_pmappings)
    print(f"Pmappings: {pmapping_time:.2f}. Joining: {join_time:.2f}. Total Pmappings: {total_pmappings}. Total mappings: {mappings.n_pmappings}. Time per pmapping: {seconds_per_pmapping:.2e}")
    return mappings

# TODO: Don't lower backing storage nodes

# Architecture to evaluate. Besides selecting the architecture YAML below, `archname` is read as a
# global by get_fused_mappings (to choose join_sims' drop_valid_reservations) and by the plotting
# cells (snowcat-only Pareto filtering and axes).
# archname = "snowcat"
archname = "four_level"
# Build one Specification from the architecture, workload, and renames YAML files.
spec = Specification.from_yaml(
    f"architecture/{archname}.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml",
    # Alternative workload files (commented out):
    # "workloads/matmuls8_mixed.workload.yaml",
    # "workloads/matmuls8_mixed.renames.yaml",
    # jinja_parse_data={
    #     "FFMT": True
    # }
)

# Configuration label ("<variant> <n>") -> PartialMappings returned by get_fused_mappings.
result = {}
# `n` is written to both the X and Y fanout of the Register node inside get_fused_mappings.
# 1024 ~= B200 POPs
for n in [2048]:#, 512]:
    # Other variants: Unfused runs with fuse=False; TileFlow and LoopTree pass
    # tagger=get_one_split_tag, and TileFlow additionally sets even=True.
    # result[f"Unfused {n}"] = get_fused_mappings(spec, n, tagger=None, fuse=False)
    # result[f"TileFlow {n}"] = get_fused_mappings(spec, n, tagger=get_one_split_tag, even=True) # This trips an assertion error
    # result[f"LoopTree {n}"] = get_fused_mappings(spec, n, tagger=get_one_split_tag)
    result[f"LoopForest {n}"] = get_fused_mappings(spec, n)

# Commented-out dict-based form of the sweep over more PE counts and variants.
# result = {
#     # "LoopForest 64": get_fused_mappings(spec, 64),
#     # "LoopForest 128": get_fused_mappings(spec, 128),
#     # "LoopForest 256": get_fused_mappings(spec, 256),
#     # "LoopForest 512": get_fused_mappings(spec, 512),
#     # "Unfused 64": get_fused_mappings(spec, 64, tagger=None),
#     # "Unfused 128": get_fused_mappings(spec, 128, tagger=None),
#     # "Unfused 256": get_fused_mappings(spec, 256, tagger=None, fuse=False),
#     # "Unfused 512": get_fused_mappings(spec, 512, tagger=None, fuse=False),
#     # "TileFlow 64": get_fused_mappings(spec, 64, tagger=get_one_split_tag, even=True),
#     # "TileFlow 128": get_fused_mappings(spec, 128, tagger=get_one_split_tag, even=True),
#     # "TileFlow 256": get_fused_mappings(spec, 256, tagger=get_one_split_tag, even=True),
#     # "TileFlow 512": get_fused_mappings(spec, 512, tagger=get_one_split_tag, even=True),
#     # "LoopTree": get_fused_mappings(spec, get_one_split_tag),
#     # "TileFlow": get_fused_mappings(spec, get_one_split_tag, even=True),
#     # "FFMT": get_fused_mappings(spec, get_ffmt_tag),
# }
# Summary line printed by get_fused_mappings in an earlier run:
# Pmappings: 65.17. Joining: 44.25. Total Pmappings: 2186997.5. Total mappings: 1.5343799278736949e+29. Time per pmapping: 2.98e-05

INFO        Loading yaml file architecture/four_level.arch.yaml
INFO        Found top key variables in architecture/four_level.arch.yaml
INFO        Found top key architecture in architecture/four_level.arch.yaml
INFO        Found top key component_classes in architecture/four_level.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Loading yaml file workloads/mha_full.renames.yaml
INFO        Found top key renames in workloads/mha_full.renames.yaml
INFO        Calculated "1024*1024*128*8" = 1073741824.
INFO        Calculated "1024*1024*4*8" = 33554432.
INFO        Calculated "0.5" = 0.5.


By default metrics optimizes for energy and latency.We should change to just energy or just latency at some point.


Generating storage and loop choices for Einsum I: 15it [00:00, 492.97it/s]
Generating storage and loop choices for Einsum V: 674it [00:01, 592.10it/s]
Generating storage and loop choices for Einsum K: 674it [00:01, 575.96it/s]
Generating storage and loop choices for Einsum Q: 674it [00:01, 518.13it/s]
Generating storage and loop choices for Einsum QK: 3156it [00:04, 652.87it/s]
Generating storage and loop choices for Einsum AV: 3106it [00:04, 628.43it/s]
Generating storage and loop choices for Einsum Z: 914it [00:02, 449.04it/s]
Generating storage and loop choices for Einsum FFA: 194it [00:00, 522.81it/s]
Generating storage and loop choices for Einsum FFB: 117it [00:00, 431.59it/s]
Generating Partial Mappings:   0%|          | 0/9524 [00:00<?, ?it/s]

[MainMemory I_in], d None, m None, [GlobalBuffer I], m None, d None, [GlobalBuffer I_in], SX-d-None, SX-m-None, [LocalBuffer I], SY-d-None, SY-m-None, SX-d-None, SX-m-None, Einsum I
[MainMemory I_in], [GlobalBuffer I_in], d None, m None, [GlobalBuffer I], m None, d None, SX-d-None, SX-m-None, [LocalBuffer I], SY-d-None, SY-m-None, SX-d-None, SX-m-None, Einsum I
[MainMemory I_in], [GlobalBuffer I_in], m None, d None, [GlobalBuffer I], m None, d None, SX-d-None, SX-m-None, [LocalBuffer I], SY-d-None, SY-m-None, SX-d-None, SX-m-None, Einsum I
[MainMemory I_in], m None, d None, [GlobalBuffer I], m None, d None, [GlobalBuffer I_in], SX-d-None, SX-m-None, [LocalBuffer I], SY-d-None, SY-m-None, SX-d-None, SX-m-None, Einsum I
[MainMemory I_in], d None, m None, [GlobalBuffer I], m None, d None, [GlobalBuffer I_in], SX-d-None, SX-m-None, [LocalBuffer I], [LocalBuffer I_in], SY-d-None, SY-m-None, SX-d-None, SX-m-None, Einsum I
[MainMemory I_in], m None, d None, [GlobalBuffer I], m None, d None, [

Generating Partial Mappings:   1%|          | 64/9524 [00:01<04:40, 33.77it/s]

[MainMemory WV], m None, [GlobalBuffer I], e None, h None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], [GlobalBuffer I], h None, e None, m None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], h None, e None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], e None, h None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], [GlobalBuffer I], h Non

Generating Partial Mappings:   1%|          | 96/9524 [00:04<08:07, 19.36it/s]

[MainMemory WV], m None, [GlobalBuffer I], e None, h None, [GlobalBuffer V], d None, [GlobalBuffer WV], SX-m-None, m None, [LocalBuffer V], d None, [LocalBuffer I], SY-d-None, SX-e-None, e None, h None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], e None, h None, [GlobalBuffer V], d None, [GlobalBuffer WV], SX-m-None, m None, [LocalBuffer V], d None, [LocalBuffer I], SY-d-None, SX-e-None, e None, h None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], [GlobalBuffer I], e None, h None, m None, [GlobalBuffer V], d None, [GlobalBuffer WV], SX-m-None, m None, [LocalBuffer V], d None, [LocalBuffer I], SY-d-None, SX-e-None, e None, h None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], [GlobalBuffer I], e None, m None, h None, [GlobalBuffer V], d None, [GlobalBuffer WV], SX-m-None, m None, [LocalBuffer V], d None, [LocalBuffer I], SY-d-None, SX-e-None, e None, h No

Generating Partial Mappings:   1%|▏         | 128/9524 [00:08<12:41, 12.34it/s]

[MainMemory WV], m None, [GlobalBuffer I], e None, [GlobalBuffer WV], h None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], h None, [GlobalBuffer WV], e None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], [GlobalBuffer WV], e None, h None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], [GlobalBuffer I], h None, e None, [GlobalBuffer WV], m None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d Non

Generating Partial Mappings:   2%|▏         | 160/9524 [00:12<15:16, 10.22it/s]

[MainMemory WV], m None, [GlobalBuffer I], e None, h None, [GlobalBuffer WV], [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], e None, [GlobalBuffer WV], h None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], h None, [GlobalBuffer WV], e None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d None, e None, h None, m None, Einsum V
[MainMemory WV], m None, [GlobalBuffer I], [GlobalBuffer WV], h None, e None, [GlobalBuffer V], SX-m-None, d None, [LocalBuffer I], e None, h None, [LocalBuffer V], SY-d-None, SX-e-None, d None, [Register WV], b None, d Non

In [None]:
# Reload the mapper and plotting modules so edits made to them after the kernel started take effect.
import importlib
import fastfusion.mapper.FFM.exploration.mapper_multi_einsum
importlib.reload(fastfusion.mapper.FFM.exploration.mapper_multi_einsum)
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums, get_num_computes
import fastfusion.visualization.interactive
importlib.reload(fastfusion.visualization.interactive)
from fastfusion.visualization.interactive import plotly_show

# Per-Einsum rank-variable bounds, computed once and reused by plotly_show below.
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
print(f'Number of computes: {get_num_computes(spec)}')

if archname == "snowcat":
    # Pareto-filter (make_pareto) every configuration on GlobalBuffer resource usage and energy.
    for mappings in result.values():
        mappings: PartialMappings
        mappings.make_pareto(columns=["RESOURCE_GlobalBuffer_LEVEL_0", "metric_Energy"])

# The `.data` table of each configuration's mappings, keyed by configuration label.
result2 = {k: v.data for k, v in result.items()}

# Alternative plots:
# if archname == "snowcat":
#     plotly_show(result2, "RESOURCE_GlobalBuffer_LEVEL_0", "metric_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names)
# else:
#     plotly_show(result2, "metric_Energy", "metric_Energy", logscales=False, einsum_names=spec.workload.einsum_names)
# plotly_show(result2, "metric_Energy", "metric_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
plotly_show(result2, "metric_Latency", "metric_Latency", category="Category", logscales=False, einsum_names=spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
# from fastfusion.visualization.interactive import mapping2svg
# mapping2svg(mappings.data.iloc[0], spec.workload.einsum_names)
# 783.577 -> 768.846
# 10.5M -> 8.6M

In [None]:
# Print the energy and reservation columns of the first mapping of one configuration.
# "Unfused 2048" only exists when the Unfused run is enabled in the mapping loop; otherwise fall back
# to the first configuration that was actually run instead of failing with KeyError.
PREFERRED_CONFIG = "Unfused 2048"
config_name = PREFERRED_CONFIG if PREFERRED_CONFIG in result2 else next(iter(result2), None)
if config_name is None:
    print("No configurations in result2; run the mapping cell first.")
else:
    r = result2[config_name]
    print(f"Configuration: {config_name}")
    if len(r) == 0:
        # .iloc[0] below would raise IndexError on an empty table.
        print("No mappings for this configuration.")
    else:
        for column in r.columns:
            if "Energy" in column or "Reservations" in column:
                print(f"{column}: {r[column].iloc[0]}")
        


In [None]:
# Re-import the plotting helper from a freshly reloaded module so local edits take effect.
import importlib
import fastfusion.visualization.interactive
importlib.reload(fastfusion.visualization.interactive)
from fastfusion.visualization.interactive import plotly_show

on_snowcat = archname == "snowcat"

if on_snowcat:
    # Pareto-filter (make_pareto) every configuration on GlobalBuffer resource usage and energy.
    for mappings in result.values():
        mappings: PartialMappings
        mappings.make_pareto(columns=["RESOURCE_GlobalBuffer_LEVEL_0", "metric_Energy"])

# Unwrap each PartialMappings into its `.data` table, keeping the configuration labels as keys.
result2 = dict((label, partial.data) for label, partial in result.items())

einsum_names = spec.workload.einsum_names

# First plot: GlobalBuffer usage vs. energy (category="Category") on snowcat,
# energy vs. energy (no category) on every other architecture.
if on_snowcat:
    first_plot_x, first_plot_extra = "RESOURCE_GlobalBuffer_LEVEL_0", {"category": "Category"}
else:
    first_plot_x, first_plot_extra = "metric_Energy", {}
plotly_show(result2, first_plot_x, "metric_Energy", logscales=False, einsum_names=einsum_names, **first_plot_extra)

# Second plot, for every architecture: energy vs. energy with category="Category".
plotly_show(result2, "metric_Energy", "metric_Energy", category="Category", logscales=False, einsum_names=einsum_names)
# from fastfusion.visualization.interactive import mapping2svg
# mapping2svg(mappings.data.iloc[0], spec.workload.einsum_names)

In [None]:
from fastfusion.mapper.FFM.visualization import make_mapping
# Import `display` explicitly so the cell does not rely on IPython injecting it as a builtin.
from IPython.display import SVG, display

# Render the first mapping of one configuration as an SVG diagram.
# "Unfused 2048" only exists when the Unfused run is enabled in the mapping loop; otherwise fall back
# to the first configuration that was actually run instead of failing with KeyError.
PREFERRED_CONFIG = "Unfused 2048"
config_name = PREFERRED_CONFIG if PREFERRED_CONFIG in result2 else next(iter(result2), None)
if config_name is None:
    print("No configurations in result2; run the mapping cell first.")
elif len(result2[config_name]) == 0:
    # .iloc[0] below would raise IndexError on an empty table.
    print(f"No mappings to render for {config_name}.")
else:
    print(f"Rendering the first mapping of: {config_name}")
    newmapping = make_mapping(result2[config_name].iloc[0], spec.workload.einsum_names, rank_variable_bounds=rank_variable_bounds)
    display(SVG(newmapping.render()))


In [None]:
# Index each Einsum's SIMs by their compatibility_str().
# `sims` is created inside get_fused_mappings() and is not a notebook global, so on a fresh kernel it
# is undefined here. Report that instead of failing the run with NameError; assign `sims` at top
# level (e.g. by calling get_sims directly) to inspect it.
try:
    sims_by_einsum = sims
except NameError:
    sims_by_einsum = {}
    print("`sims` is not defined (it is local to get_fused_mappings); nothing to inspect.")

compatibility2sims = {
    einsum_name: {s.compatibility_str(): s for s in einsum_sims}
    for einsum_name, einsum_sims in sims_by_einsum.items()
}
print(compatibility2sims)