In [None]:
import hashlib
import os
import pickle
from hwcomponents_cacti import SRAM as CactiSRAM
from hwcomponents_library import AladdinAdder, AladdinMultiplier

from fastfusion.frontend.architecture import Memory
from fastfusion.frontend.specification import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.simanneal.wrappers import join_sims

import copy
import time
from fastfusion import Specification
from fastfusion.mapper.metrics import Metrics
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings
from fastfusion.mapper.FFM import make_pmappings, join_pmappings
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_per_tensor_size, get_num_computes


def objective(df):
    """Return the column of ``df`` that mappings are ranked by.

    Only latency is used at the moment; the trailing comment keeps the
    energy factor that would turn this into an energy-delay product.
    """
    return df['Total_Latency']  # * df['Total_Energy']

def get_pmappings(spec: Specification, parameterization: str, local_buffer_fanout: int, batch_size=1, n_tokens=16384):
    """Return the partial mappings for ``spec``, using an on-disk pickle cache.

    ``spec`` is modified in place *before* the cache is consulted, so the
    caller's specification is configured the same way on a cache hit and on a
    cache miss:

    * the mapper metric is set to latency only,
    * component energy is (re)calculated with ``area=False``,
    * the ``"Z"`` spatial fanout of ``LocalBuffer`` is set to
      ``local_buffer_fanout``,
    * for ``parameterization == "Ideal"`` the spatial fanout of ``LocalBuffer``
      and ``Register`` is cleared, the register size is multiplied by
      ``128 * 128 * local_buffer_fanout`` and the main-memory bandwidth
      expression is divided by the same factor.

    Args:
        spec: Specification to configure and map (mutated in place).
        parameterization: Name used in the cache key; ``"Ideal"`` additionally
            triggers the architecture rewrite described above.
        local_buffer_fanout: Fanout applied to ``LocalBuffer``.
        batch_size: Only used as part of the cache key here.
        n_tokens: Only used as part of the cache key here.

    Returns:
        Whatever ``make_pmappings(spec)`` produced (freshly or from the cache).
    """
    spec.mapper_ffm.metrics = Metrics.LATENCY# | Metrics.ENERGY
    spec.calculate_component_energy_area(area=False)
    llb: Memory = spec.architecture.nodes["LocalBuffer"]
    llb.spatial.fanout["Z"] = local_buffer_fanout

    if parameterization == "Ideal":
        llb.spatial.fanout.clear()
        reg: Memory = spec.architecture.nodes["Register"]
        reg.spatial.fanout.clear()
        reg.attributes.size *= 128 * 128 * local_buffer_fanout
        dram: Memory = spec.architecture.nodes["MainMemory"]
        # NOTE(review): this appends to a string expression; if the existing
        # expression contains a top-level "+" or "-", only its last term is
        # divided. Confirm against the architecture YAML.
        dram.attributes.shared_read_write_bandwidth += f" / (128 * 128 * {local_buffer_fanout})"

    cachekey = (parameterization, local_buffer_fanout, batch_size, n_tokens)
    fname = "_".join(str(x) for x in cachekey) + ".pkl"
    # ``fname`` already ends in ".pkl", so files on disk end in ".pkl.pkl".
    # The name is kept as-is so that existing cache files remain valid.
    cache_path = f"cache/pmappings_{fname}.pkl"

    if os.path.exists(cache_path):
        print(f"Loading from cache: pmappings_{fname}")
        # Unpickling executes arbitrary code; only load cache files that this
        # notebook wrote itself.
        with open(cache_path, "rb") as cache_file:
            return pickle.load(cache_file)

    pmappings = make_pmappings(spec)
    # Create the cache directory on first use instead of failing on open().
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, "wb") as cache_file:
        pickle.dump(pmappings, cache_file)
    return pmappings

def get_fused_mappings(local_buffer_fanout: int, parameterization="", return_mappings=False, batch_size=1, n_tokens=16384):
    """Map the full MHA workload under one parameterization, with disk caching.

    The joined mappings are cached in ``cache/`` under a key made of
    ``(parameterization, local_buffer_fanout, batch_size, n_tokens)``. On a
    cache miss the partial mappings are obtained from ``get_pmappings``,
    optionally filtered, joined with ``join_pmappings`` and written back.

    Parameterizations:

    * ``"Unfused"``: LoopForest partial mappings, keeping only those whose
      tensors all live in ``MainMemory``.
    * ``"LoopTree"``: LoopForest partial mappings, keeping only those whose
      non-``MainMemory`` tensors all have the same number of loops.
    * ``"TileFlow"``: same filter as LoopTree, with
      ``spec.mapper_ffm.timeloop_style_even`` enabled.
    * ``"LoopForest"``: no filtering.
    * ``"Ideal"``: no filtering; ``Total_Latency`` is divided by
      ``128 * 128 * local_buffer_fanout`` after joining.

    Args:
        local_buffer_fanout: Fanout passed on to ``get_pmappings``.
        parameterization: One of the names listed above.
        return_mappings: If true, return only the mappings object.
        batch_size: Value substituted for ``BATCH_SIZE`` in the workload.
        n_tokens: Value substituted for ``N_TOKENS`` in the workload.

    Returns:
        ``mappings`` when ``return_mappings`` is true, otherwise the tuple
        ``(objective(mappings.data).min(), mappings)``.

    Raises:
        ValueError: On a cache miss with an unknown ``parameterization``.
    """
    # NOTE(review): the specification is parsed even when the result is later
    # served from the cache; kept in this position to preserve behavior.
    spec = Specification.from_yaml(
        "architecture/four_level.arch.yaml",
        "workloads/mha_full.workload.yaml",
        "workloads/mha_full.renames.yaml",
        jinja_parse_data={
            "BATCH_SIZE": batch_size,
            "N_TOKENS": n_tokens,
        }
    )

    # if parameterization == "Ideal":
    #     return get_num_computes(spec) / local_buffer_fanout / 128 / 128, None

    cachekey = (parameterization, local_buffer_fanout, batch_size, n_tokens)
    fname = "_".join(str(x) for x in cachekey) + ".pkl"
    # ``fname`` already ends in ".pkl", so files on disk end in ".pkl.pkl".
    # The name is kept as-is so that existing cache files remain valid.
    cache_path = f"cache/{fname}.pkl"

    if os.path.exists(cache_path):
        print(f"Loading from cache: {fname}")
        # Unpickling executes arbitrary code; only load cache files that this
        # notebook wrote itself.
        with open(cache_path, "rb") as cache_file:
            mappings = pickle.load(cache_file)
    else:
        # get_pmappings mutates the specification it is given.
        spec = copy.deepcopy(spec)
        if parameterization in ["Unfused", "LoopTree", "LoopForest"]:
            # These three share one set of partial mappings (and one cache
            # entry); they differ only in the filter applied below.
            pmappings = get_pmappings(spec, "LoopForest", local_buffer_fanout, batch_size, n_tokens)
        elif parameterization == "TileFlow":
            spec.mapper_ffm.timeloop_style_even = True
            pmappings = get_pmappings(spec, parameterization, local_buffer_fanout, batch_size, n_tokens)
        elif parameterization == "Ideal":
            pmappings = get_pmappings(spec, parameterization, local_buffer_fanout, batch_size, n_tokens)
        else:
            raise ValueError(f"Unknown parameterization: {parameterization}")

        if parameterization == "Unfused":
            filter_lambda = lambda pm: set(x.resource_name for x in pm.compatibility.tensors) == {"MainMemory"}
            pmappings.filter(filter_lambda)
        elif parameterization == "LoopTree" or parameterization == "TileFlow":
            filter_lambda = lambda pm: len(set(len(x.loops) for x in pm.compatibility.tensors if x.resource_name != "MainMemory")) <= 1
            pmappings.filter(filter_lambda)

        mappings = join_pmappings(spec, pmappings)

        if parameterization == "Ideal":
            mappings.data["Total_Latency"] /= 128 * 128 * local_buffer_fanout

        # Create the cache directory on first use instead of failing on open().
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, "wb") as cache_file:
            pickle.dump(mappings, cache_file)

    if return_mappings:
        return mappings
    return objective(mappings.data).min(), mappings

# Results of the sweep, keyed first by (batch_size, n_tokens, n_pes) and then
# by parameterization name. Despite the "edp" in the name, the stored value is
# whatever get_fused_mappings returns as the best objective.
parameterization2edp = {}
parameterization2mappings = {}

parameterizations = ["Unfused", "TileFlow", "LoopTree", "LoopForest", "Ideal"]

for batch_size in [1]:
    # for batch_size, n_tokens in [(64, 2048), (1, 2048), (64, 16384), (1, 16384), (1, 1024), (1, 4096), (1, 8192), (1, 32768), (1, 65536)]:
    for n_tokens in [1024, 2048, 4096, 8192, 16384, 32768]:#, 65536]:
        for n_pes in [256]:
            config = (batch_size, n_tokens, n_pes)
            for parameterization in parameterizations:
                x, mappings = get_fused_mappings(
                    n_pes,
                    parameterization=parameterization,
                    batch_size=batch_size,
                    n_tokens=n_tokens,
                )
                # A zero objective is not recorded.
                if x == 0:
                    continue
                parameterization2edp.setdefault(config, {})[parameterization] = x
                parameterization2mappings.setdefault(config, {})[parameterization] = mappings
                print(f"{batch_size} {n_tokens} {n_pes} {parameterization}: {x}")

In [None]:
import matplotlib.pyplot as plt
# Apply matplotlib's 'default' style before any figure in this cell is drawn.
plt.style.use('default')
# plt.rcParams.update({'font.size': 28})

def plot_default_formatting(ax, grid_axis='both'):
    """Apply the notebook's common look to ``ax``.

    * A legend that the caller already attached to ``ax`` is kept (location,
      font size, ... untouched). If there is none, a default legend is
      created, but only when at least one artist carries a label, so that
      unlabeled plots do not get an empty legend box and a warning.
    * The legend frame, if any, is drawn white with a black edge.
    * All spines are drawn black.
    * Minor ticks are enabled and major/minor grid lines are drawn along
      ``grid_axis``.

    Args:
        ax: The matplotlib axes to format.
        grid_axis: Passed as ``axis`` to ``ax.grid`` (``'both'``, ``'x'`` or
            ``'y'``).
    """
    ax.tick_params(axis='both', which='major')#, labelsize=20)
    ax.tick_params(axis='both', which='minor')#, labelsize=20)

    # Reuse the caller's legend instead of replacing it with a fresh one.
    legend = ax.get_legend()
    if legend is None:
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            legend = ax.legend()
    if legend is not None:
        frame = legend.get_frame()
        frame.set_facecolor('white')
        frame.set_edgecolor('black')

    for spine in ax.spines.values():
        spine.set_edgecolor('black')

    ax.minorticks_on()
    ax.grid(axis=grid_axis, which='major', linestyle='-', linewidth=0.3, color='gray')
    ax.grid(axis=grid_axis, which='minor', linestyle='--', linewidth=0.1, color='lightgray')

def make_bar_chart(
    data,
    title,
    xlabel,
    ylabel,
    y_scale,
    output_file=None,
    normalize: bool = False,
    ylim=(None, None),
    xlim=(None, None),
):
    """
    Draw a bar chart of ``data``, optionally save it as a PDF, and show it.

    Args:
        data: Either a flat ``{label: value}`` dict (one bar per entry) or a
            dict of dicts ``{series: {label: value}}`` (grouped bars, one
            group per label and one bar per series). The labels of the first
            series define the groups and their order; every series is looked
            up by those labels. An empty dict produces an empty chart.
        title: Figure title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        y_scale: ``'log'`` switches the Y axis to a logarithmic scale; any
            other value leaves the scale unchanged.
        output_file: If given, the figure is written to this path as a PDF.
        normalize: For grouped data, divide every value by the first series'
            value for the same label. Ignored for flat data.
        ylim: Passed to ``plt.ylim``.
        xlim: Passed to ``plt.xlim``.
    """
    plt.figure(figsize=(16, 8))

    # Peek at the first value without raising on an empty dict.
    first = next(iter(data.values()), None) if isinstance(data, dict) else None

    if isinstance(first, dict):
        keys = list(first)
        bar_width = 0.8 / len(data)
        x = range(len(keys))

        for i, (label, values) in enumerate(data.items()):
            bar_positions = [pos + i * bar_width for pos in x]
            # Index by ``keys`` so each bar lines up with its tick label even
            # if a series lists its entries in a different order.
            if normalize:
                heights = [values[k] / first[k] for k in keys]
            else:
                heights = [values[k] for k in keys]
            plt.bar(bar_positions, heights, width=bar_width, label=label)
        # Center each tick under its group of bars.
        plt.xticks([pos + (len(data) - 1) * bar_width / 2 for pos in x], keys)
        plt.legend(loc='upper right', fontsize=10)
    else:
        plt.bar(list(data.keys()), list(data.values()))

    # Set logarithmic scale for Y-axis if specified
    if y_scale == 'log':
        plt.yscale('log')

    # Add labels and title
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(ylim)
    plt.xlim(xlim)

    # Rotate X-axis labels vertically
    plt.xticks(rotation=90)

    plot_default_formatting(plt.gca(), grid_axis='y')

    if output_file is not None:
        with open(output_file, 'wb') as f:
            plt.savefig(f, format='pdf', bbox_inches='tight')

    # Show the plot
    plt.show()

# Flatten the sweep results into one {"batch tokens pes parameterization": value}
# dict, with each value divided by the smallest value of its configuration.
# The printed table instead shows each value relative to the best non-Ideal
# result and, in parentheses, relative to the Ideal result.
parameterization2edp_normalized = {}
for (batch_size, n_tokens, n_pes), edp_dict in parameterization2edp.items():
    min_edp = min(edp_dict.values())
    ideal = edp_dict["Ideal"]
    min_non_ideal = min(value for name, value in edp_dict.items() if name != "Ideal")
    config_label = f"{batch_size} {n_tokens} {n_pes}"
    for k, v in edp_dict.items():
        bar_label = f"{config_label} {k}"
        if bar_label not in parameterization2edp_normalized:
            parameterization2edp_normalized[bar_label] = v / min_edp
        print(f"{bar_label}: {v / min_non_ideal:.4f} ({v / ideal:.4f})")

print()
make_bar_chart(parameterization2edp_normalized, title=None, xlabel=None, ylabel="EDP", y_scale='linear')

In [None]:
from IPython.display import SVG
from fastfusion.mapper.FFM.pareto.df_convention import MAPPING_COLUMN
# Render the first stored LoopForest mapping for the largest sequence length
# and print each latency column as a share of the total latency.
mapping = parameterization2mappings[(1, 32768, 256)]["LoopForest"]
first_row = mapping.data.iloc[0]
display(SVG(first_row[MAPPING_COLUMN].render()))
total_latency = first_row["Total_Latency"]
print(total_latency)
for col in mapping.data.columns:
    if "latency" in col.lower():
        print(f'{col}: {first_row[col]} ({first_row[col] / total_latency * 100:.2f}%)')