In [None]:
import os
from pathlib import Path
import pickle
from fastfusion import Specification
from fastfusion.mapper.FFM._make_pmappings.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM._join_pmappings.simexplore import join_sims
from fastfusion.mapper.FFM._make_pmappings.mapping_filter_tags import get_one_split_tag, get_ffmt_tag


ARCH_DIR          = Path('architecture/')
WORKLOAD_DIR      = Path('workloads/')
MAPPINGS_SIMS_DIR = Path('results/sims/')
MAPPINGS_DATA_DIR = Path('results/data/')


def one_split_tagger(compatibility):
    return get_one_split_tag(compatibility, "MainMemory")


NAME_TO_TAGGER = {
    'one_split': one_split_tagger
}


def get_experiment_name(tagger_name, arch_name: list[str], workload_name):
    combined_arch_name = '+'.join(arch_name)
    return f'matmul8_mixed.{tagger_name}.{combined_arch_name}.{workload_name}'


def get_sims_with_cache(tagger_name=None,
                        refresh_cache=False,
                        arch_name: list[str]=['snowcat'],
                        workload_name='matmuls8_mixed'):
    data_name  = get_experiment_name(tagger_name, arch_name, workload_name)
    result_pickle_name = MAPPINGS_DATA_DIR / f'{data_name}.pkl'
    if result_pickle_name.is_file() and not refresh_cache:
        with open(result_pickle_name, 'rb') as f:
            mappings = pickle.load(f)
            print(f'Loaded final results from cache {result_pickle_name}')
            return mappings

    if tagger_name is None:
        tagger = None
    else:
        tagger = NAME_TO_TAGGER[tagger_name]

    all_sims = []
    for a in arch_name:
        sims_name = get_experiment_name(tagger_name, [a], workload_name)
        sims_pickle_name = MAPPINGS_SIMS_DIR / f'{sims_name}.pkl'
        if sims_pickle_name.is_file() and not refresh_cache:
            with open(sims_pickle_name, 'rb') as f:
                sims = pickle.load(f)
                print(f'Loaded SIMs from {sims_pickle_name}')
                all_sims.append(sims)
                continue

        spec = Specification.from_yaml(
            ARCH_DIR / f'{a}.arch.yaml',
            WORKLOAD_DIR / f'{workload_name}.workload.yaml'
        )
        spec.calculate_component_energy_area()
        workload = spec.workload
        renames = spec.renames
        flattened_architecture = spec.get_flattened_architecture()
        sims, decompress_data = get_sims(spec, flattened_architecture, tagger=tagger)

        with open(sims_pickle_name, 'wb') as f:
            pickle.dump(sims, f)

        all_sims.append(sims)

    combined_sims = {}
    for sims in all_sims:
        for einsum, sims_for_einsum in sims.items():
            if einsum not in combined_sims:
                combined_sims[einsum] = []
            combined_sims[einsum].extend(sims_for_einsum)

    spec = Specification.from_yaml(
        ARCH_DIR / f'{a}.arch.yaml',
        WORKLOAD_DIR / f'{workload_name}.workload.yaml'
    )
    flattened_architecture = spec.get_flattened_architecture()
    mappings = join_sims(combined_sims, spec, flattened_architecture, drop_valid_reservations=False)
    mappings.decompress(decompress_data)

    with open(result_pickle_name, 'wb') as f:
        pickle.dump(mappings, f)
        print(f'Saved results to cache {result_pickle_name}')

    return mappings

mappings_looptree = get_sims_with_cache('one_split', workload_name='mha_full', refresh_cache=True)
mappings_full = get_sims_with_cache(workload_name='mha_full', refresh_cache=True)
mappings_tileflow = get_sims_with_cache(tagger_name='one_split', arch_name=['snowcat_even'], workload_name='mha_full', refresh_cache=True)
mappings_ffmt = get_sims_with_cache(tagger_name='one_split', arch_name=['snowcat_ffmt'], workload_name='mha_full', refresh_cache=True)

INFO        Loading yaml file architecture/snowcat.arch.yaml
INFO        Found top key variables in architecture/snowcat.arch.yaml
INFO        Found top key architecture in architecture/snowcat.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Calculated "0.5" = 0.5.
Generating storage and loop choices for Einsum I: 3it [00:00, 1397.48it/s]
Generating storage and loop choices for Einsum V: 46it [00:00, 2971.25it/s]
Generating storage and loop choices for Einsum K: 46it [00:00, 2942.16it/s]
Generating storage and loop choices for Einsum Q: 46it [00:00, 3145.28it/s]
Generating storage and loop choices for Einsum QK: 136it [00:00, 4016.09it/s]
Generating storage and loop choices for Einsum AV: 136it [00:00, 4087.25it/s]
Generating storage and loop choices for Einsum Z: 56it [00:00, 3193.70it/s]
Generating storage and loop choices for Einsum FFA: 42it [00:00, 2630.09it/s]
Generating st

SIM I tensors: {'I'}
SIM V tensors: {'V', 'I'}
SIM K tensors: {'I', 'K'}
SIM Q tensors: {'Q', 'I'}
SIM QK tensors: {'Q', 'K', 'QK'}
SIM AV tensors: {'V', 'AV', 'QK'}
SIM Z tensors: {'AV', 'Z'}
SIM FFA tensors: {'FFA', 'Z'}
SIM FFB tensors: {'FFA'}


Inital consolidate I: 100%|█████████████████| 577/577 [00:00<00:00, 9118.02it/s]
Inital consolidate V: 100%|████████████| 39790/39790 [00:00<00:00, 43449.46it/s]
Grouping Partial Mappings: 100%|███████████| 2969/2969 [00:03<00:00, 969.17it/s]
Inital consolidate K: 100%|█████████████| 39790/39790 [00:04<00:00, 8941.62it/s]
Grouping Partial Mappings: 100%|██████████| 2969/2969 [00:02<00:00, 1002.37it/s]
Inital consolidate Q: 100%|████████████| 39790/39790 [00:00<00:00, 42288.05it/s]
Grouping Partial Mappings: 100%|██████████| 2969/2969 [00:02<00:00, 1007.96it/s]
Inital consolidate QK: 100%|███████████| 78088/78088 [00:05<00:00, 14539.50it/s]
Grouping Partial Mappings: 100%|███████████| 9800/9800 [00:14<00:00, 692.05it/s]
Inital consolidate AV: 100%|███████████| 78088/78088 [00:02<00:00, 34229.18it/s]
Grouping Partial Mappings: 100%|███████████| 9800/9800 [00:13<00:00, 712.97it/s]
Inital consolidate Z: 100%|████████████| 66470/66470 [00:02<00:00, 32628.04it/s]
Grouping Partial Mappings: 1

Initial consolidate and group: 72.10 seconds

Einsum V (2/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|████████████| 288/288 [00:00<00:00, 1921.45it/s]


Combining: 0.30 seconds
Grouping: 0.00 seconds
Bucket merging: 0.04 seconds
Removed 0/2969 (100.00% remaining)
Removing mappings that can't be combined later: 0.03 seconds


Merging mappings I <--> V: 100%|███████████| 2969/2969 [00:03<00:00, 870.25it/s]


Mapping merging: 3.68 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 551(289) x 8172(2956) -> 2969
	Number of groups for Einsum V: 2969
	Number of mappings for Einsum V: 13166
	Mappings per group for Einsum V: 4.434489727180869

Einsum K (3/9)
Consolidating: 0.02 seconds
Combining: 0.01 seconds
Grouping: 0.01 seconds
Bucket merging: 0.11 seconds
Removed 0/5662 (100.00% remaining)
Removing mappings that can't be combined later: 0.04 seconds


Merging mappings V <--> K: 100%|███████████| 5662/5662 [00:07<00:00, 726.69it/s]


Mapping merging: 8.15 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 8172(2956) x 8172(2956) -> 5662
	Number of groups for Einsum K: 5662
	Number of mappings for Einsum K: 368132
	Mappings per group for Einsum K: 65.01801483574708

Einsum Q (4/9)
Consolidating: 0.04 seconds
Combining: 0.02 seconds
Grouping: 0.01 seconds
Bucket merging: 0.20 seconds
Removed 10417/11048 (5.71% remaining)
Removing mappings that can't be combined later: 0.09 seconds


Merging mappings K <--> Q: 100%|████████████| 631/631 [00:00<00:00, 1100.89it/s]


Mapping merging: 0.90 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 8172(2956) x 8172(2956) -> 631
	Number of groups for Einsum Q: 631
	Number of mappings for Einsum Q: 20035
	Mappings per group for Einsum Q: 31.751188589540412

Einsum QK (5/9)
Consolidating: 0.06 seconds


Grouping Partial Mappings: 100%|███████████████| 6/6 [00:00<00:00, 15024.37it/s]

Combining: 0.07 seconds
Grouping: 0.00 seconds





Bucket merging: 0.13 seconds
Removed 8722/9101 (4.16% remaining)
Removing mappings that can't be combined later: 0.09 seconds


Merging mappings Q <--> QK: 100%|███████████| 379/379 [00:00<00:00, 1374.72it/s]


Mapping merging: 0.52 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 304(168) x 27235(9539) -> 379
	Number of groups for Einsum QK: 379
	Number of mappings for Einsum QK: 4980
	Mappings per group for Einsum QK: 13.139841688654354

Einsum AV (6/9)
Consolidating: 0.05 seconds


Grouping Partial Mappings: 100%|█████████████| 15/15 [00:00<00:00, 45197.24it/s]

Combining: 0.11 seconds
Grouping: 0.00 seconds
Bucket merging: 0.05 seconds





Removed 0/3222 (100.00% remaining)
Removing mappings that can't be combined later: 0.03 seconds


Merging mappings QK <--> AV: 100%|█████████| 3222/3222 [00:04<00:00, 735.74it/s]


Mapping merging: 4.76 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 162(96) x 27319(9583) -> 3222
	Number of groups for Einsum AV: 3222
	Number of mappings for Einsum AV: 449478
	Mappings per group for Einsum AV: 139.50279329608938

Einsum Z (7/9)
Consolidating: 0.05 seconds


Grouping Partial Mappings: 100%|████████████| 362/362 [00:00<00:00, 1186.45it/s]


Combining: 0.54 seconds
Grouping: 0.01 seconds
Bucket merging: 0.09 seconds
Removed 0/3337 (100.00% remaining)
Removing mappings that can't be combined later: 0.01 seconds


Merging mappings AV <--> Z: 100%|██████████| 3337/3337 [00:05<00:00, 568.00it/s]


Mapping merging: 6.28 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 7768(2746) x 8172(2956) -> 3337
	Number of groups for Einsum Z: 3337
	Number of mappings for Einsum Z: 253249
	Mappings per group for Einsum Z: 75.89121965837579

Einsum FFA (8/9)
Consolidating: 0.07 seconds


Grouping Partial Mappings: 100%|█████████████| 292/292 [00:00<00:00, 461.49it/s]


Combining: 1.47 seconds
Grouping: 0.00 seconds
Bucket merging: 0.05 seconds
Removed 0/973 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings Z <--> FFA: 100%|███████████| 973/973 [00:01<00:00, 541.72it/s]


Mapping merging: 2.67 seconds
Scaled runtime by 1.0. Runtime: 0.94
	Combining 551(289) x 1192(622) -> 973
	Number of groups for Einsum FFA: 973
	Number of mappings for Einsum FFA: 328849
	Mappings per group for Einsum FFA: 337.9743062692703

Einsum FFB (9/9)
Consolidating: 0.01 seconds


Grouping Partial Mappings: 100%|█████████████| 337/337 [00:01<00:00, 218.01it/s]


Combining: 1.79 seconds
Grouping: 0.00 seconds
Bucket merging: 0.03 seconds


Merging mappings FFA <--> FFB: 100%|█████████| 338/338 [00:00<00:00, 970.28it/s]


Mapping merging: 0.79 seconds
Scaled runtime by 1.0. Runtime: 3.56
	Combining 641(335) x 641(335) -> 338
	Number of groups for Einsum FFB: 338
	Number of mappings for Einsum FFB: 42287
	Mappings per group for Einsum FFB: 125.1094674556213


Final consolidate: 100%|████████████████████| 338/338 [00:00<00:00, 6307.67it/s]
Grouping Partial Mappings: 100%|████████████████| 1/1 [00:00<00:00, 3724.96it/s]



Initial consolidate and group: 72.10 seconds
Consolidating: 0.30 seconds
Combining: 4.32 seconds
Grouping: 0.03 seconds
Bucket merging: 0.70 seconds
Removing mappings that can't be combined later: 0.30 seconds
Mapping merging: 27.75 seconds

Total: 105.51 seconds (1.76 minutes)

Saved results to cache results/data/matmul8_mixed.one_split.snowcat.mha_full.pkl


INFO        Loading yaml file architecture/snowcat.arch.yaml
INFO        Found top key variables in architecture/snowcat.arch.yaml
INFO        Found top key architecture in architecture/snowcat.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
Generating storage and loop choices for Einsum I: 3it [00:00, 1563.48it/s]
Generating storage and loop choices for Einsum V: 46it [00:00, 3110.75it/s]
Generating storage and loop choices for Einsum K: 46it [00:00, 2774.77it/s]
Generating storage and loop choices for Einsum Q: 46it [00:00, 3140.22it/s]
Generating storage and loop choices for Einsum QK: 136it [00:00, 4340.54it/s]
Generating storage and loop choices for Einsum AV: 136it [00:00, 4384.18it/s]
Generating storage and loop choices for Einsum Z: 56it [00:00, 537.24it/s]
Generating storage and loop choices for Einsum FFA: 42it [00:00, 2994.30it/s]
Generating storage and loop choices for Einsum FFB

SIM I tensors: {'I'}
SIM V tensors: {'V', 'I'}
SIM K tensors: {'I', 'K'}
SIM Q tensors: {'Q', 'I'}
SIM QK tensors: {'Q', 'K', 'QK'}
SIM AV tensors: {'V', 'AV', 'QK'}
SIM Z tensors: {'AV', 'Z'}
SIM FFA tensors: {'FFA', 'Z'}
SIM FFB tensors: {'FFA'}


Inital consolidate I: 100%|█████████████████| 577/577 [00:00<00:00, 9472.92it/s]
Inital consolidate V: 100%|████████████| 63420/63420 [00:02<00:00, 27925.01it/s]
Grouping Partial Mappings: 100%|███████████| 6957/6957 [00:06<00:00, 994.19it/s]
Inital consolidate K: 100%|████████████| 63420/63420 [00:02<00:00, 27202.32it/s]
Grouping Partial Mappings: 100%|██████████| 6957/6957 [00:06<00:00, 1039.79it/s]
Inital consolidate Q: 100%|████████████| 63420/63420 [00:02<00:00, 27984.09it/s]
Grouping Partial Mappings: 100%|███████████| 6957/6957 [00:07<00:00, 949.80it/s]
Inital consolidate QK: 100%|█████████| 248568/248568 [00:20<00:00, 11872.68it/s]
Grouping Partial Mappings: 100%|█████████| 60438/60438 [01:06<00:00, 914.06it/s]
Inital consolidate AV: 100%|█████████| 248568/248568 [00:23<00:00, 10720.66it/s]
Grouping Partial Mappings: 100%|█████████| 60438/60438 [01:09<00:00, 875.80it/s]
Inital consolidate Z: 100%|████████████| 90100/90100 [00:02<00:00, 30549.52it/s]
Grouping Partial Mappings: 1

Initial consolidate and group: 242.67 seconds

Einsum V (2/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|████████████| 288/288 [00:00<00:00, 1963.61it/s]


Combining: 0.29 seconds
Grouping: 0.00 seconds
Bucket merging: 0.09 seconds
Removed 0/6957 (100.00% remaining)
Removing mappings that can't be combined later: 0.07 seconds


Merging mappings I <--> V: 100%|███████████| 6957/6957 [00:08<00:00, 848.53it/s]


Mapping merging: 8.73 seconds
Scaled runtime by 1.0. Runtime: 1.01
	Combining 551(289) x 18367(6525) -> 6957
	Number of groups for Einsum V: 6957
	Number of mappings for Einsum V: 21048
	Mappings per group for Einsum V: 3.0254420008624408

Einsum K (3/9)
Consolidating: 0.04 seconds
Combining: 0.03 seconds
Grouping: 0.02 seconds
Bucket merging: 0.49 seconds
Removed 0/32159 (100.00% remaining)
Removing mappings that can't be combined later: 0.14 seconds


Merging mappings V <--> K: 100%|█████████| 32159/32159 [00:57<00:00, 558.28it/s]


Mapping merging: 57.98 seconds
Scaled runtime by 1.0. Runtime: 1.01
	Combining 18367(6525) x 18367(6525) -> 32159
	Number of groups for Einsum K: 32159
	Number of mappings for Einsum K: 470946
	Mappings per group for Einsum K: 14.6442986411269

Einsum Q (4/9)
Consolidating: 0.11 seconds
Combining: 0.13 seconds
Grouping: 0.09 seconds
Bucket merging: 1.45 seconds
Removed 98175/102195 (3.93% remaining)
Removing mappings that can't be combined later: 0.55 seconds


Merging mappings K <--> Q: 100%|███████████| 4020/4020 [00:04<00:00, 947.73it/s]


Mapping merging: 4.64 seconds
Scaled runtime by 1.0. Runtime: 1.01
	Combining 18367(6525) x 18367(6525) -> 4020
	Number of groups for Einsum Q: 4020
	Number of mappings for Einsum Q: 40945
	Mappings per group for Einsum Q: 10.185323383084578

Einsum QK (5/9)
Consolidating: 0.21 seconds


Grouping Partial Mappings: 100%|████████████| 862/862 [00:00<00:00, 1356.38it/s]


Combining: 1.11 seconds
Grouping: 0.00 seconds
Bucket merging: 0.67 seconds
Removed 42152/44336 (4.93% remaining)
Removing mappings that can't be combined later: 0.27 seconds


Merging mappings Q <--> QK:  69%|██████▉   | 1504/2184 [00:02<00:00, 684.25it/s]

In [3]:
import copy
import re
from fastfusion.frontend.mapping import Iteration, Mapping, Nested, Split, Storage
from fastfusion.visualization.interactive import plotly_show
from fastfusion.mapper.FFM.deprecate_maybe.visualization import make_mapping
from fastfusion.frontend.workload import Workload
workload = Workload.from_yaml('workloads/matmuls8_mixed.workload.yaml')

plotly_show(mappings_full.data, "RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy", logscales=True, einsum_names=workload.einsum_names)

INFO        Loading yaml file workloads/matmuls8_mixed.workload.yaml
INFO        Found top key workload in workloads/matmuls8_mixed.workload.yaml


VBox(children=(FigureWidget({
    'data': [{'line': {'shape': 'hv'},
              'marker': {'symbol': 'circl…

In [4]:
plotly_show(mappings_looptree.data, "RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy", logscales=True, einsum_names=workload.einsum_names)

VBox(children=(FigureWidget({
    'data': [{'line': {'shape': 'hv'},
              'marker': {'symbol': 'circl…

In [5]:

plotly_show(mappings_tileflow.data, "RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy", logscales=True, einsum_names=workload.einsum_names)

VBox(children=(FigureWidget({
    'data': [{'line': {'shape': 'hv'},
              'marker': {'symbol': 'circl…

In [4]:
plotly_show(mappings_ffmt.data, "RESOURCE_GlobalBuffer_LEVEL_0", "Total_Energy", logscales=True, einsum_names=workload.einsum_names)

VBox(children=(FigureWidget({
    'data': [{'line': {'shape': 'hv'},
              'marker': {'symbol': 'circl…