In [1]:
import copy
import hashlib
import os
import pickle
import time
from fastfusion import Specification, util
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings
from fastfusion.mapper.simanneal.tracking import EvaluationsScoreTracker
from fastfusion.mapper.simanneal.wrappers import join_sims as join_sims_simanneal


def cache(filename):
    filename = filename if filename.endswith(".pkl") else f"{filename}.pkl"
    def decorator(func):
        def wrapper(*args, **kwargs):
            if os.path.exists(filename):
                return pickle.load(open(filename, "rb"))
            else:
                result = func(*args, **kwargs)
                pickle.dump(result, open(filename, "wb"))
                return result
        return wrapper
    return decorator

archname = "four_level"
spec = Specification.from_yaml(
    f"architecture/{archname}.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml"
    # "workloads/matmuls8_mixed.workload.yaml",
    # "workloads/matmuls8_mixed.renames.yaml",
)

NUM_THREADS = 24

@cache(hashlib.md5(spec._yaml_source.encode()).hexdigest())
def get_sims_with_cache():
    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()
    t0 = time.time()
    sims, decompress_data = get_sims(spec, flattened_architecture)
    pmapping_time = time.time() - t0
    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({total_pmappings / pmapping_time:.2f} per second)')

    t0 = time.time()
    mappings = join_sims(sims, spec, flattened_architecture, drop_valid_reservations=archname != "snowcat")
    join_time = time.time() - t0

    mappings.decompress(decompress_data)

    data = mappings.data
    data["EDP"] = data["metric_Energy"] * data["metric_Latency"]
    best_edp = data["EDP"].min()

    pmappings_per_second = total_pmappings / pmapping_time
    time_limit = (join_time + pmapping_time) * 100

    baseline_tracker = EvaluationsScoreTracker(0, None)
    baseline_tracker.add_evaluation(join_time + pmapping_time, best_edp)

    tracker = EvaluationsScoreTracker(
        max_evaluations=time_limit,
        stop_at_score=best_edp * 1.05,
    )
    tracker.multiply_scale_by(1 / pmappings_per_second / NUM_THREADS)
    
    return sims, flattened_architecture, baseline_tracker, tracker

sims, flattened_architecture, baseline_tracker, tracker = get_sims_with_cache()
n_optimal_pmappings = sum(
    len(v2.mappings.data)
    for v in sims.values()
    for v2 in v
)
print(f'Number of optimal pmappings: {n_optimal_pmappings}')
total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
print(f'Number of explored pmappings: {total_pmappings}')

assert False
print(f'Stop threshold: {tracker.stop_at_score}')

simanneal_tracker = copy.deepcopy(tracker)
simanneal_mappings = join_sims_simanneal(
    sims,
    simanneal_tracker,
    "simulated_anneal",
    spec,
    flattened_architecture,
)


INFO        Loading yaml file architecture/four_level.arch.yaml
INFO        Found top key variables in architecture/four_level.arch.yaml
INFO        Found top key architecture in architecture/four_level.arch.yaml
INFO        Found top key component_classes in architecture/four_level.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Loading yaml file workloads/mha_full.renames.yaml
INFO        Found top key renames in workloads/mha_full.renames.yaml
INFO        Calculated "1024*1024*128*8" = 1073741824.
INFO        Calculated "1024*1024*32*8" = 268435456.
INFO        Calculated "0.5" = 0.5.
Generating storage and loop choices for Einsum I: 12it [00:00, 423.50it/s]
Generating storage and loop choices for Einsum V: 120it [00:00, 315.77it/s]
Generating storage and loop choices for Einsum K: 120it [00:00, 381.37it/s]
Generating storage and loop choices for Einsum Q: 120it [00:00, 381.03

Took 118.47 seconds to generate 56389676.52495983 partial mappings (475987.62 per second)
SIM I tensors: {'I'}
SIM V tensors: {'I', 'V'}
SIM K tensors: {'I', 'K'}
SIM Q tensors: {'I', 'Q'}
SIM QK tensors: {'Q', 'K', 'QK'}
SIM AV tensors: {'V', 'AV', 'QK'}
SIM Z tensors: {'Z', 'AV'}
SIM FFA tensors: {'FFA', 'Z'}
SIM FFB tensors: {'FFA'}


Inital consolidate I: 100%|██████████| 908/908 [00:00<00:00, 6987.62it/s]
Inital consolidate V: 100%|██████████| 12153/12153 [00:00<00:00, 15204.51it/s]
Grouping Partial Mappings: 100%|██████████| 629/629 [00:01<00:00, 613.48it/s]
Inital consolidate K: 100%|██████████| 12153/12153 [00:01<00:00, 11771.84it/s]
Grouping Partial Mappings: 100%|██████████| 629/629 [00:01<00:00, 609.24it/s]
Inital consolidate Q: 100%|██████████| 12153/12153 [00:01<00:00, 11998.97it/s]
Grouping Partial Mappings: 100%|██████████| 629/629 [00:00<00:00, 670.74it/s]
Inital consolidate QK: 100%|██████████| 111277/111277 [00:10<00:00, 10736.23it/s]
Grouping Partial Mappings: 100%|██████████| 14138/14138 [00:06<00:00, 2083.58it/s]
Inital consolidate AV: 100%|██████████| 233017/233017 [00:23<00:00, 9963.27it/s] 
Grouping Partial Mappings: 100%|██████████| 29510/29510 [00:23<00:00, 1264.67it/s]
Inital consolidate Z: 100%|██████████| 15201/15201 [00:01<00:00, 15149.05it/s]
Grouping Partial Mappings: 100%|██████████| 62

Initial consolidate and group: 91.47 seconds

Einsum V (2/9)
Consolidating: 0.00 seconds


Grouping Partial Mappings: 100%|██████████| 114/114 [00:00<00:00, 9639.16it/s]


Combining: 0.15 seconds
Grouping: 0.00 seconds
Bucket merging: 0.01 seconds
Removed 0/398 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings I <--> V: 100%|██████████| 398/398 [00:00<00:00, 3200.60it/s]


Mapping merging: 0.24 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 208(114) x 818(435) -> 398
	Number of groups for Einsum V: 398
	Number of mappings for Einsum V: 12902
	Mappings per group for Einsum V: 32.41708542713568

Einsum K (3/9)
Consolidating: 0.01 seconds
Combining: 0.00 seconds
Grouping: 0.00 seconds
Bucket merging: 0.03 seconds
Removed 0/1380 (100.00% remaining)
Removing mappings that can't be combined later: 0.01 seconds


Merging mappings V <--> K: 100%|██████████| 1380/1380 [00:01<00:00, 782.73it/s]


Mapping merging: 1.96 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 601(323) x 818(435) -> 1380
	Number of groups for Einsum K: 1380
	Number of mappings for Einsum K: 266994
	Mappings per group for Einsum K: 193.47391304347826

Einsum Q (4/9)
Consolidating: 0.01 seconds
Combining: 0.00 seconds
Grouping: 0.00 seconds
Bucket merging: 0.06 seconds
Removed 3290/3768 (12.69% remaining)
Removing mappings that can't be combined later: 0.02 seconds


Merging mappings K <--> Q: 100%|██████████| 478/478 [00:00<00:00, 1774.97it/s]


Mapping merging: 0.48 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 601(323) x 818(435) -> 478
	Number of groups for Einsum Q: 478
	Number of mappings for Einsum Q: 9876
	Mappings per group for Einsum Q: 20.661087866108787

Einsum QK (5/9)
Consolidating: 0.02 seconds


Grouping Partial Mappings: 100%|██████████| 43/43 [00:00<00:00, 35502.97it/s]

Combining: 0.09 seconds
Grouping: 0.00 seconds





Bucket merging: 0.28 seconds
Removed 16190/16309 (0.73% remaining)
Removing mappings that can't be combined later: 0.19 seconds


Merging mappings Q <--> QK: 100%|██████████| 119/119 [00:00<00:00, 10055.45it/s]


Mapping merging: 0.16 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 15(19) x 38511(13607) -> 119
	Number of groups for Einsum QK: 119
	Number of mappings for Einsum QK: 9648
	Mappings per group for Einsum QK: 81.07563025210084

Einsum AV (6/9)
Consolidating: 0.06 seconds


Grouping Partial Mappings: 100%|██████████| 12/12 [00:00<00:00, 32388.45it/s]

Combining: 0.05 seconds
Grouping: 0.00 seconds





Bucket merging: 0.16 seconds
Removed 8302/8839 (6.08% remaining)
Removing mappings that can't be combined later: 0.09 seconds


Merging mappings QK <--> AV: 100%|██████████| 537/537 [00:00<00:00, 768.71it/s]


Mapping merging: 1.39 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 20(23) x 60957(21172) -> 537
	Number of groups for Einsum AV: 537
	Number of mappings for Einsum AV: 191642
	Mappings per group for Einsum AV: 356.8752327746741

Einsum Z (7/9)
Consolidating: 0.20 seconds


Grouping Partial Mappings: 100%|██████████| 128/128 [00:01<00:00, 104.20it/s]


Combining: 1.98 seconds
Grouping: 0.00 seconds
Bucket merging: 0.02 seconds
Removed 0/629 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings AV <--> Z: 100%|██████████| 629/629 [00:06<00:00, 101.80it/s]


Mapping merging: 8.59 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 236(128) x 818(435) -> 629
	Number of groups for Einsum Z: 629
	Number of mappings for Einsum Z: 267113
	Mappings per group for Einsum Z: 424.66295707472176

Einsum FFA (8/9)
Consolidating: 0.02 seconds


Grouping Partial Mappings: 100%|██████████| 128/128 [00:00<00:00, 622.34it/s]


Combining: 0.46 seconds
Grouping: 0.00 seconds
Bucket merging: 0.03 seconds
Removed 0/743 (100.00% remaining)
Removing mappings that can't be combined later: 0.00 seconds


Merging mappings Z <--> FFA: 100%|██████████| 743/743 [00:35<00:00, 20.95it/s] 


Mapping merging: 44.55 seconds
Scaled runtime by 1.0. Runtime: 1.97
	Combining 236(128) x 1042(549) -> 743
	Number of groups for Einsum FFA: 743
	Number of mappings for Einsum FFA: 1183347
	Mappings per group for Einsum FFA: 1592.6608344549124

Einsum FFB (9/9)
Consolidating: 0.01 seconds


Grouping Partial Mappings: 100%|██████████| 174/174 [00:14<00:00, 11.80it/s]


Combining: 21.31 seconds
Grouping: 0.00 seconds
Bucket merging: 0.04 seconds


Merging mappings FFA <--> FFB: 100%|██████████| 174/174 [00:00<00:00, 180.81it/s]


Mapping merging: 2.92 seconds
Scaled runtime by 1.0. Runtime: 26.25
	Combining 326(174) x 326(174) -> 174
	Number of groups for Einsum FFB: 174
	Number of mappings for Einsum FFB: 27295
	Mappings per group for Einsum FFB: 156.867816091954


Final consolidate: 100%|██████████| 174/174 [00:00<00:00, 4076.46it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 11.34it/s]


Initial consolidate and group: 91.47 seconds
Consolidating: 0.31 seconds
Combining: 24.05 seconds
Grouping: 0.01 seconds
Bucket merging: 0.62 seconds
Removing mappings that can't be combined later: 0.32 seconds
Mapping merging: 60.29 seconds

Total: 177.08 seconds (2.95 minutes)






Evaluations: 295.76546478271484, Score: 1.5964215756583526e+18
Stopping due to evaluations 295.76546478271484 > 0
Number of optimal pmappings: 29917692
Number of explored pmappings: 56389676.52495983


AssertionError: 

In [None]:
from matplotlib.ticker import FuncFormatter

import matplotlib.pyplot as plt
plt.style.use('default')
plt.rcParams.update({'font.size': 28})
def plot_default_formatting(ax, grid_axis='both'):
    ax.tick_params(axis='both', which='major')#, labelsize=20)
    ax.tick_params(axis='both', which='minor')#, labelsize=20)
    legend = ax.legend()
    legend.get_frame().set_facecolor('white')
    legend.get_frame().set_edgecolor('black')
    for spine in ax.spines.values():
        spine.set_edgecolor('black')
    if ax.get_legend() is None:
        legend = ax.legend(fontsize=24, ncol=2)
    ax.minorticks_on()
    ax.grid(axis=grid_axis, which='major', linestyle='-', linewidth='0.3', color='gray')
    ax.grid(axis=grid_axis, which='minor', linestyle='--', linewidth='0.1', color='lightgray')
    if ax.get_xscale() == 'log':
        ax.xaxis.set_major_formatter(logscale_formatter())
    if ax.get_yscale() == 'log':
        ax.yaxis.set_major_formatter(logscale_formatter())
    


# Fastfusion No Combine Reservations
# Fastfusion No Skip Invalid
# Fastfusion No Skip Invalid Or Combine Reservations
results = get_results(workload_generator=get_matmuls_function(16), targets=["Fastfusion No Skip Invalid"], name_only_key=True)
fastfusion_result = next(iter(results.values()))

def translate_names(x):
    if isinstance(x, dict):
        return {translate_names(k): translate_names(v) for k, v in x.items()}
    if isinstance(x, list):
        return [translate_names(i) for i in x]
    if isinstance(x, tuple):
        return tuple(translate_names(i) for i in x)
    if isinstance(x, str):
        x = x.replace("Fastfusion", "FFM")
        x = x.replace("Simulated\\\\Annealing", "Simulated Annealing")
        x = x.replace("Simulated\\\\Annealing", "Simulated Annealing")
        x = x.replace("Simulated Annealing", "Sim. Anneal")
        x = x.replace("Genetic", "Genetic Algo.")
    return x

def logscale_formatter():
    """
    Custom formatter for logscale axis to replace 10^0 with 1, 10^1 with 10, and so on.
    """
    return FuncFormatter(lambda x, _: f"{int(x):d}" if x >= 1 else f"{x:.1g}")

def make_bar_chart(
    data,
    title,
    xlabel,
    ylabel,
    y_scale,
    output_file=None,
    normalize: bool = False,
    font_size_timeout=18,
    timeout_xoffs_scale=1.15,
    ylim=(None, None),
    xlim=(None, None),
    return_axes=False,
) -> plt.Figure:
    """
    Create a bar chart from the given data and save it as a PDF.
    """
    data = translate_names(data)
    
    plt.figure(figsize=(16, 8))
    def key_process(k):
        if "Intra-Layer" in k:
            k = "Make Pmappings"
        if "Matmul" in k:
            k = k.replace("Matmul", "MM")
        return k.strip()
    
    if isinstance(data, dict) and isinstance(next(iter(data.values())), dict):
        y_min = min([min(v for v in v.values() if v > 0) for v in data.values()])
        bar_width = 0.8 / len(data)
        keys = list(next(iter(data.values())).keys())
        x = range(len(keys))
        first = next(iter(data.values()))
            
        for i, (label, values) in enumerate(data.items()):
            bar_positions = [pos + i * bar_width for pos in x]
            to_plot = values
            if normalize:
                to_plot = {k: v / first[k] for k, v in values.items()}
            bars = plt.bar(bar_positions, to_plot.values(), width=bar_width, label=label)
            for bar, value in zip(bars, to_plot.values()):
                if value == 0:
                    plt.text(
                        bar.get_x() + bar.get_width() / 2 * timeout_xoffs_scale,
                        y_min + 0.1,
                        "      Timeout",
                        ha='center',
                        va='center',
                        rotation=90,
                        fontsize=font_size_timeout
                    )
        plt.xticks([pos + (len(data) - 1) * bar_width / 2 for pos in x], [key_process(k) for k in keys])
        plt.legend(loc='upper right', fontsize=10)
    else:
        y_min = min(d for d in data.values() if d > 0)
        keys = [key_process(k) for k in data.keys()]
        bars = plt.bar(keys, data.values())
        for bar, value in zip(bars, data.values()):
            if value == 0:
                plt.text(
                    bar.get_x() + bar.get_width() / 2 * timeout_xoffs_scale,
                    y_min + 0.1,
                    "    Timeout",
                    ha='center',
                    va='center',
                    rotation=90,
                    fontsize=font_size_timeout
                )

    # Set logarithmic scale for Y-axis if specified
    if y_scale == 'log':
        plt.yscale('log')
        plt.gca().yaxis.set_major_formatter(logscale_formatter())

    # Add labels and title
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(ylim)
    plt.xlim(xlim)

    # Rotate X-axis labels vertically
    plt.xticks(rotation=90)
    
    plot_default_formatting(plt.gca(), grid_axis='y')
    
    if return_axes:
        assert output_file is None, "Cannot return axes and save to file at the same time."
        return plt.gca()
    
    if output_file is not None:
        with open(output_file, 'wb') as f:
            plt.savefig(f, format='pdf', bbox_inches='tight')

    # Show the plot
    plt.show()
    
def save_plot(output_file: str):
    plt.savefig(output_file, format='pdf', bbox_inches='tight')
    
runtime = fastfusion_result.n_mappings_inter.runtime
total_runtime = sum(runtime.values())
normalized_runtime = {k: v / total_runtime for k, v in runtime.items()}
make_bar_chart(
    normalized_runtime,
    title=None,#"Normalized Runtime of Each Operation",
    xlabel=None,#"Operation",
    ylabel="Normalized Runtime",
    y_scale='log'
)

results = get_results(workload_generator=get_matmuls_function(8), targets=[
    "Fastfusion",
    "Fastfusion No Skip Invalid",
    "Fastfusion No Combine Reservations",
    # "Fastfusion No Skip Invalid Or Combine Reservations",
], name_only_key=True)
runtimes = {k: v.n_mappings_inter.runtime for k, v in results.items()}
totals = {k: sum(v.values()) for k, v in runtimes.items()}
make_bar_chart(
    totals,
    title=None,#"Total Runtime of Each Operation",
    xlabel=None,#"Operation",
    ylabel="Runtime (Seconds)",
    y_scale='linear'
)

KeyError: 'metric_Latency'

In [None]:
mappings.data.sort_values(by="metric_Energy", ascending=True).head()
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG
newmapping = make_mapping(mappings.data.iloc[0], spec.workload.einsum_names)
display(SVG(newmapping.render()))

# {'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1
# TODO: Re-add -1 to the mapper one eisnum freenig
# compatibility2sims['Matmul1']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1"]
# Above 1: 8192
# Above 2: 8321
# compatibility2sims['Matmul2']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1, [GlobalBuffer] T2 sz 0 above 0"]

In [None]:
compatibility2sims = {einsum_name: {s.compatibility_str(): s for s in sims2} for einsum_name, sims2 in sims.items()}
print(compatibility2sims)