In [1]:
import time
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings


def get_fused_mappings(spec: Specification, tagger=None) -> PartialMappings:
    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()
    t0 = time.time()
    sims, decompress_data = get_sims(spec, flattened_architecture, tagger=tagger)
    pmapping_time = time.time() - t0
    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({total_pmappings / pmapping_time:.2f} per second)')
    t0 = time.time()
    mappings = join_sims(sims, spec, flattened_architecture, drop_valid_reservations=archname != "snowcat")
    join_time = time.time() - t0
    mappings.decompress(decompress_data)
    print(f"Pmappings: {pmapping_time:.2f}. Joining: {join_time:.2f}. Total Pmappings: {total_pmappings}. Total mappings: {mappings.n_pmappings}. Time per pmapping: {pmapping_time / total_pmappings:.2e}")
    return mappings

# archname = "snowcat"
archname = "four_level"
spec = Specification.from_yaml(
    f"architecture/{archname}.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml"
)

result = {
    # "LoopTree": get_fused_mappings(spec, get_one_split_tag).data,
    "LoopForest": get_fused_mappings(spec).data,
    # "FFMT": get_fused_mappings(spec, get_ffmt_tag).data,
}
# Pmappings: 65.17. Joining: 44.25. Total Pmappings: 2186997.5. Total mappings: 1.5343799278736949e+29. Time per pmapping: 2.98e-05

INFO        Loading yaml file architecture/four_level.arch.yaml
INFO        Found top key variables in architecture/four_level.arch.yaml
INFO        Found top key architecture in architecture/four_level.arch.yaml
INFO        Found top key component_classes in architecture/four_level.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Loading yaml file workloads/mha_full.renames.yaml
INFO        Found top key renames in workloads/mha_full.renames.yaml
INFO        Calculated "1024*1024*128*8" = 1073741824.
INFO        Calculated "1024*1024*32*8" = 268435456.
INFO        Calculated "0.5" = 0.5.
Generating storage and loop choices for Einsum I: 6it [00:00, 306.92it/s]
Generating storage and loop choices for Einsum V: 51it [00:00, 404.85it/s]
Generating storage and loop choices for Einsum K: 51it [00:00, 473.31it/s]
Generating storage and loop choices for Einsum Q: 51it [00:00, 457.25it/s

Took 40.07 seconds to generate 2186997.5 partial mappings (54581.88 per second)
SIM I tensors: {'I'}
SIM V tensors: {'V', 'I'}
SIM K tensors: {'K', 'I'}
SIM Q tensors: {'Q', 'I'}
SIM QK tensors: {'QK', 'K', 'Q'}
SIM AV tensors: {'V', 'QK', 'AV'}
SIM Z tensors: {'Z', 'AV'}
SIM FFA tensors: {'FFA', 'Z'}
SIM FFB tensors: {'FFA'}


Inital consolidate I: 100%|██████████| 191/191 [00:00<00:00, 4531.10it/s]
Inital consolidate V: 100%|██████████| 2162/2162 [00:00<00:00, 4421.94it/s]
Grouping Partial Mappings: 100%|██████████| 244/244 [00:00<00:00, 1273.92it/s]
Inital consolidate K: 100%|██████████| 2162/2162 [00:00<00:00, 4572.73it/s]
Grouping Partial Mappings: 100%|██████████| 244/244 [00:00<00:00, 1295.22it/s]
Inital consolidate Q: 100%|██████████| 2162/2162 [00:00<00:00, 4737.92it/s]
Grouping Partial Mappings: 100%|██████████| 244/244 [00:00<00:00, 1387.97it/s]
Inital consolidate QK: 100%|██████████| 18714/18714 [00:03<00:00, 5340.74it/s]
Grouping Partial Mappings: 100%|██████████| 3643/3643 [00:03<00:00, 937.02it/s] 
Inital consolidate AV: 100%|██████████| 53234/53234 [00:10<00:00, 5080.05it/s]
Grouping Partial Mappings: 100%|██████████| 11410/11410 [00:13<00:00, 818.51it/s]
Inital consolidate Z: 100%|██████████| 2762/2762 [00:00<00:00, 5208.35it/s]
Grouping Partial Mappings: 100%|██████████| 244/244 [00:00<00:00

Initial consolidate and group: 42.55 seconds

Einsum V (2/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|██████████| 38/38 [00:00<00:00, 27432.63it/s]


Combining: 0.08 seconds
Grouping: 0.00 seconds
Bucket merging: 0.00 seconds
Removed 0/145 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings I <--> V: 100%|██████████| 145/145 [00:00<00:00, 1501.90it/s]


Mapping merging: 0.21 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 61(39) x 307(169) -> 145
	Number of groups for Einsum V: 145
	Number of mappings for Einsum V: 1478
	Mappings per group for Einsum V: 10.193103448275862

Einsum K (3/9)
Consolidating: 0.00 seconds
Combining: 0.00 seconds
Grouping: 0.00 seconds
Bucket merging: 0.01 seconds
Removed 0/496 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings V <--> K: 100%|██████████| 496/496 [00:00<00:00, 1233.33it/s]


Mapping merging: 0.71 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 214(121) x 307(169) -> 496
	Number of groups for Einsum K: 496
	Number of mappings for Einsum K: 17928
	Mappings per group for Einsum K: 36.145161290322584

Einsum Q (4/9)
Consolidating: 0.00 seconds
Combining: 0.00 seconds
Grouping: 0.00 seconds
Bucket merging: 0.02 seconds
Removed 1115/1329 (16.10% remaining)
Removing mappings that can't be combined later: 0.01 seconds


Merging mappings K <--> Q: 100%|██████████| 214/214 [00:00<00:00, 1415.20it/s]


Mapping merging: 0.33 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 214(121) x 307(169) -> 214
	Number of groups for Einsum Q: 214
	Number of mappings for Einsum Q: 928
	Mappings per group for Einsum Q: 4.336448598130841

Einsum QK (5/9)
Consolidating: 0.01 seconds


Grouping Partial Mappings: 100%|██████████| 43/43 [00:00<00:00, 26429.52it/s]


Combining: 0.10 seconds
Grouping: 0.00 seconds
Bucket merging: 0.09 seconds
Removed 5511/5670 (2.80% remaining)
Removing mappings that can't be combined later: 0.05 seconds


Merging mappings Q <--> QK: 100%|██████████| 159/159 [00:00<00:00, 1857.61it/s]


Mapping merging: 0.24 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 15(19) x 12371(4691) -> 159
	Number of groups for Einsum QK: 159
	Number of mappings for Einsum QK: 710
	Mappings per group for Einsum QK: 4.465408805031447

Einsum AV (6/9)
Consolidating: 0.02 seconds


Grouping Partial Mappings: 100%|██████████| 14/14 [00:00<00:00, 30970.60it/s]


Combining: 0.06 seconds
Grouping: 0.00 seconds
Bucket merging: 0.11 seconds
Removed 5467/5825 (6.15% remaining)
Removing mappings that can't be combined later: 0.05 seconds


Merging mappings QK <--> AV: 100%|██████████| 358/358 [00:00<00:00, 1505.12it/s]


Mapping merging: 0.52 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 25(29) x 30838(10946) -> 358
	Number of groups for Einsum AV: 358
	Number of mappings for Einsum AV: 8600
	Mappings per group for Einsum AV: 24.022346368715084

Einsum Z (7/9)
Consolidating: 0.07 seconds


Grouping Partial Mappings: 100%|██████████| 51/51 [00:00<00:00, 13576.38it/s]


Combining: 0.13 seconds
Grouping: 0.00 seconds
Bucket merging: 0.01 seconds
Removed 0/244 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings AV <--> Z: 100%|██████████| 244/244 [00:00<00:00, 1550.76it/s]


Mapping merging: 0.45 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 89(51) x 307(169) -> 244
	Number of groups for Einsum Z: 244
	Number of mappings for Einsum Z: 14269
	Mappings per group for Einsum Z: 58.47950819672131

Einsum FFA (8/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|██████████| 51/51 [00:00<00:00, 33241.57it/s]


Combining: 0.12 seconds
Grouping: 0.00 seconds
Bucket merging: 0.01 seconds
Removed 0/288 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings Z <--> FFA: 100%|██████████| 288/288 [00:00<00:00, 1527.65it/s]


Mapping merging: 0.48 seconds
Scaled runtime by 1.0. Runtime: 1.09
	Combining 89(51) x 391(213) -> 288
	Number of groups for Einsum FFA: 288
	Number of mappings for Einsum FFA: 16524
	Mappings per group for Einsum FFA: 57.375

Einsum FFB (9/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|██████████| 69/69 [00:00<00:00, 1655.37it/s]


Combining: 0.16 seconds
Grouping: 0.00 seconds
Bucket merging: 0.01 seconds


Merging mappings FFA <--> FFB: 100%|██████████| 69/69 [00:00<00:00, 1643.54it/s]


Mapping merging: 0.14 seconds
Scaled runtime by 1.0. Runtime: 1.40
	Combining 123(69) x 123(69) -> 69
	Number of groups for Einsum FFB: 69
	Number of mappings for Einsum FFB: 138
	Mappings per group for Einsum FFB: 2.0


Final consolidate: 100%|██████████| 69/69 [00:00<00:00, 11924.47it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 1473.75it/s]



Initial consolidate and group: 42.55 seconds
Consolidating: 0.11 seconds
Combining: 0.65 seconds
Grouping: 0.00 seconds
Bucket merging: 0.26 seconds
Removing mappings that can't be combined later: 0.10 seconds
Mapping merging: 3.10 seconds

Total: 46.77 seconds

Pmappings: 40.07. Joining: 47.13. Total Pmappings: 2186997.5. Total mappings: 1.5343799278736949e+29. Time per pmapping: 1.83e-05


In [2]:
import importlib
import fastfusion.visualization.interactive
importlib.reload(fastfusion.visualization.interactive)
from fastfusion.visualization.interactive import plotly_show
# mappings2.make_pareto(columns=["RESOURCE_GlobalBuffer_LEVEL_0", "metric_Energy"])
# plotly_show(result, "RESOURCE_GlobalBuffer_LEVEL_0", "metric_Energy", category="Category", logscales=False, einsum_names=spec.workload.einsum_names)
plotly_show(result, "metric_Latency", "metric_Energy", logscales=False, einsum_names=spec.workload.einsum_names)
# from fastfusion.visualization.interactive import mapping2svg
# mapping2svg(mappings.data.iloc[0], spec.workload.einsum_names)

VBox(children=(FigureWidget({
    'data': [{'line': {'shape': 'hv'},
              'marker': {'symbol': 'circl…

In [3]:
mappings.data.sort_values(by="metric_Energy", ascending=True).head()
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG
newmapping = make_mapping(mappings.data.iloc[0], spec.workload.einsum_names)
display(SVG(newmapping.render()))

# {'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1
# TODO: Re-add -1 to the mapper one eisnum freenig
# compatibility2sims['Matmul1']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1"]
# Above 1: 8192
# Above 2: 8321
# compatibility2sims['Matmul2']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1, [GlobalBuffer] T2 sz 0 above 0"]

NameError: name 'mappings' is not defined

In [None]:
compatibility2sims = {einsum_name: {s.compatibility_str(): s for s in sims2} for einsum_name, sims2 in sims.items()}
print(compatibility2sims)