In [None]:
import functools
import hashlib
import os
import pickle
import time
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings
from fastfusion.mapper.simanneal.tracking import EvaluationsScoreTracker
from fastfusion.mapper.simanneal.wrappers import join_sims as join_sims_simanneal


def cache(filename):
    """Return a decorator that memoizes a function's result in a pickle file.

    The cache key is the file name only: the decorated function's arguments are
    NOT part of the key, so once the file exists every call (with any
    arguments) returns the stored result without running the function.

    Args:
        filename: Path of the cache file; ``.pkl`` is appended if missing.

    Returns:
        A decorator that wraps a function with this file-backed cache.
    """
    filename = filename if filename.endswith(".pkl") else f"{filename}.pkl"

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if os.path.exists(filename):
                # SECURITY: pickle.load can execute arbitrary code; only load
                # cache files this notebook wrote itself.
                with open(filename, "rb") as f:
                    return pickle.load(f)

            result = func(*args, **kwargs)

            # Dump to a temp file and atomically swap it in, so an interrupted
            # or failed dump never leaves a truncated pickle at `filename`
            # (which would make every later run fail on load).
            tmp_filename = f"{filename}.tmp"
            try:
                with open(tmp_filename, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_filename, filename)
            except BaseException:
                if os.path.exists(tmp_filename):
                    os.remove(tmp_filename)
                raise
            return result

        return wrapper

    return decorator

archname = "snowcat"

# Architecture + workload specification used by every cell below.
# Alternative workload that has been used here (pass both files instead):
#   "workloads/mha_full.workload.yaml", "workloads/mha_full.renames.yaml"
spec = Specification.from_yaml(
    f"architecture/{archname}.arch.yaml", "workloads/matmuls8_mixed.workload.yaml"
)

@cache(hashlib.md5(spec._yaml_source.encode()).hexdigest())
def get_sims_with_cache():
    """Generate SIMs, join them once, and build a tracker for simulated annealing.

    ``cache`` pickles the return value to ``<md5 of spec._yaml_source>.pkl``, so
    later runs with the same YAML source skip everything below.

    Returns:
        tuple: ``(sims, tracker, flattened_architecture)`` where ``sims`` is the
        first value returned by ``get_sims``, ``tracker`` is an
        ``EvaluationsScoreTracker`` with ``max_evaluations`` set to the joined
        mappings' ``n_pmappings`` and ``stop_at_score`` set to the minimum EDP
        found by ``join_sims``, and ``flattened_architecture`` comes from
        ``spec.get_flattened_architecture()``.
    """
    # NOTE(review): on a cache hit this body is skipped entirely, including
    # spec.estimate_energy_area(); confirm later users of `spec` (e.g. the
    # simulated-annealing join) do not depend on that call having run.
    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()

    t0 = time.perf_counter()
    sims, decompress_data = get_sims(spec, flattened_architecture)
    pmapping_time = time.perf_counter() - t0
    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    print(
        f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings '
        f'({total_pmappings / pmapping_time:.2f} per second)'
    )

    # Reference join; its minimum EDP becomes the annealer's target score.
    t0 = time.perf_counter()
    mappings = join_sims(sims, spec, flattened_architecture, drop_valid_reservations=archname != "snowcat")
    join_time = time.perf_counter() - t0
    print(f'Took {join_time:.2f} seconds to join partial mappings')

    mappings.decompress(decompress_data)

    # Adds an "EDP" column to mappings.data in place.
    data = mappings.data
    data["EDP"] = data["metric_Energy"] * data["metric_Latency"]
    best_edp = data["EDP"].min()

    tracker = EvaluationsScoreTracker(
        max_evaluations=mappings.n_pmappings,
        stop_at_score=best_edp,
    )

    return sims, tracker, flattened_architecture

# Load (or compute and cache) the SIMs, the EDP-targeted tracker and the
# flattened architecture, then run simulated annealing over the same SIMs.
sims, tracker, flattened_architecture = get_sims_with_cache()

print("Remember to scale by # threads")

joined2 = join_sims_simanneal(sims, tracker, "simulated_anneal", spec, flattened_architecture)

INFO        Loading yaml file architecture/snowcat.arch.yaml
INFO        Found top key variables in architecture/snowcat.arch.yaml
INFO        Found top key architecture in architecture/snowcat.arch.yaml
INFO        Loading yaml file workloads/matmuls8_mixed.workload.yaml
INFO        Found top key workload in workloads/matmuls8_mixed.workload.yaml


Remember to scale by # threads
Checking Matmul1 Matmul2
Evaluations: 10.0, Score: 90194313216.0


In [None]:
import importlib
import fastfusion.visualization.interactive

# Reload so edits to the plotting module take effect without a kernel restart;
# reload() returns the (re-executed) module, so grab plotly_show from it.
plotly_show = importlib.reload(fastfusion.visualization.interactive).plotly_show

# Alternative axes that have been plotted: "RESOURCE_GlobalBuffer_LEVEL_0" vs
# "metric_Energy" (with category="Category").
# NOTE(review): `result` is not assigned in any cell visible here, so on a fresh
# kernel this raises NameError. Presumably it should be the joined mappings
# (e.g. `joined2`); confirm what plotly_show expects.
plotly_show(result, "metric_Latency", "metric_Energy", logscales=False, einsum_names=spec.workload.einsum_names)

In [None]:
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG, display

# NOTE(review): `mappings` is only assigned locally inside get_sims_with_cache()
# (and not returned), so this cell depends on a `mappings` defined elsewhere;
# confirm where it comes from.

# Show the five lowest-energy mappings. A bare expression that is not the last
# statement in a cell is discarded, so display it explicitly.
display(mappings.data.sort_values(by="metric_Energy", ascending=True).head())

# Render the first row of mappings.data (the unsorted frame, so not
# necessarily the lowest-energy mapping).
newmapping = make_mapping(mappings.data.iloc[0], spec.workload.einsum_names)
display(SVG(newmapping.render()))

# Debugging notes:
# {'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1
# TODO: Re-add -1 to the mapper one-einsum freeing
# compatibility2sims['Matmul1']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1"]
# Above 1: 8192
# Above 2: 8321
# compatibility2sims['Matmul2']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1, [GlobalBuffer] T2 sz 0 above 0"]

In [None]:
# Per Einsum, index each SIM by its compatibility string for quick lookup.
compatibility2sims = {
    einsum: {sim.compatibility_str(): sim for sim in einsum_sims}
    for einsum, einsum_sims in sims.items()
}
print(compatibility2sims)