In [1]:
import hashlib
import os
import pickle
from hwcomponents_cacti import CactiSRAM
from hwcomponents_library import AladdinAdder, AladdinMultiplier

from fastfusion.frontend.architecture import Memory
from fastfusion.frontend.specification import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.simanneal.wrappers import join_sims

import copy
import time
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.metrics import Metrics
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings

# TODO: area is an alias for get_area
# TODO: Separate energy and area
# TODO: Move scaling into main hwcomponents repo
# TODO: Function that just returns the hwcomponents component
# TODO: Is all the initial right consolidation necessary?
# TODO: Datawidth calculation for energy

# TODO: Reference specific tensor names in constraints, even if those tensors are not in
# a particular Einsum. Also have the error messages for parsing errors list which Einsum
# failed. Tensors that aren't in the Einsum should resolve to NotInThisEinsum(), which
# matches nothing.

# TODO: Make a setting for the below two in the spec
# TODO: Generate pmappings one Einsum at a time. Once we've made compatibility, check it
# against the previously-generated compatibilities and stop if there's no match.
# TODO: Once the previous is done, also add a forward check. Once the compatibilities of
# a particular Einsum are generated, we can immediately check the previous Einsums.
# TODO: Make the mapping return an object that supports union operators and stuff

# Workload and architecture specification shared by every experiment in this notebook.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml",
)

BITS_PER_MB = 1024 * 1024 * 8
N_MXUS = 4  # MXUs in the baseline design; each has its own MAC array and local buffer
BASE_MAC_DIMS = 128  # baseline MAC array is BASE_MAC_DIMS x BASE_MAC_DIMS per MXU

# One MAC = a width-16 adder plus a width-8 multiplier.
adder = AladdinAdder(technology="7nm", width=16)
multiplier = AladdinMultiplier(technology="7nm", width=8)
mac_area = adder.get_area() + multiplier.get_area()

# Baseline buffers: a 4 MB local buffer (width 128 bits) per MXU and a single
# 128 MB global buffer (width 1024 bits). Sizes are in bits, so depth = bits // width.
base_local_buffer_size = 4 * BITS_PER_MB
base_local_buffer = CactiSRAM(technology="7nm", width=128, depth=base_local_buffer_size // 128)
base_global_buffer_size = 128 * BITS_PER_MB
base_global_buffer = CactiSRAM(technology="7nm", width=1024, depth=base_global_buffer_size // 1024)

# Total area of the baseline design; every swept design point must fit inside it.
area_budget = (
    (mac_area * BASE_MAC_DIMS * BASE_MAC_DIMS + base_local_buffer.get_area()) * N_MXUS
    + base_global_buffer.get_area()
)

# TODO(review): the author's open question below was printed five times in a row;
# printing it once still surfaces it without flooding the output.
print("COMPUTE ENERGY / 8 ????????????")

# Main-memory storage constraints per dataflow parameterization, as
# (keep expression, bypass expression). A bypass of None leaves the spec's value untouched.
_MAIN_MEMORY_CONSTRAINTS = {
    "Unfused": ("All()", None),
    "FlashAttention": ("All() - QK - Q - K - V - I", "QK | Q | K | V | I"),
    "Fuse I": ("All() - I", "I"),
    "FFM": ("~Intermediates()", "Q | K | V | I"),
}


def _min_edp(mappings) -> float:
    """Return the minimum of metric_Energy * metric_Latency over all rows of ``mappings.data``."""
    data = mappings.data
    return (data["metric_Energy"] * data["metric_Latency"]).min()


def get_fused_mappings(
        spec: Specification,
        n_pes,
        local_buffer_model,
        global_buffer_model,
        tagger=None,
        parameterization="",
        return_mappings=False,
        cache_dir="cache",
    ) -> "PartialMappings | float":
    """Configure, map, join, and cache one design point.

    Works on a deep copy of ``spec`` and does the following:
    - Sizes the LocalBuffer and GlobalBuffer nodes to ``width * depth`` of their models.
    - Sets their read/write action energies to the model's per-access energy divided by
      ``width`` (per bit, assuming ``width`` is in bits as in this notebook).
    - Applies the MainMemory keep/bypass constraints for ``parameterization``.
    - Sets an ``n_pes`` x ``n_pes`` spatial fanout on the Register node.

    It then generates per-Einsum partial mappings with ``get_sims`` and joins them
    with ``join_sims``. The joined mappings are pickled to
    ``<cache_dir>/<parameterization> <md5 of key>.pkl``.

    Args:
        spec: Base specification. Not modified.
        n_pes: Register fanout in each of the X and Y dimensions.
        local_buffer_model: Object exposing ``width``, ``depth``, ``read()`` and
            ``write()`` (a CactiSRAM in this notebook).
        global_buffer_model: Same interface as ``local_buffer_model``.
        tagger: Forwarded unchanged to ``get_sims``.
        parameterization: One of "Unfused", "FlashAttention", "Fuse I", "FFM".
        return_mappings: If True, return the joined mappings object instead of the EDP.
        cache_dir: Directory holding cached results; created on first write.

    Returns:
        The minimum energy * latency over the joined mappings. If ``return_mappings``
        is True, returns the joined mappings object instead.

    Raises:
        ValueError: If ``parameterization`` is not supported.
    """
    if parameterization not in _MAIN_MEMORY_CONSTRAINTS:
        raise ValueError(f"Parameterization {parameterization} not supported")

    # NOTE(review): the key ignores the contents of `spec` and `tagger`; delete the
    # cache directory after editing the YAML specs or changing the tagger.
    cachekey = (
        n_pes,
        local_buffer_model.width,
        local_buffer_model.depth,
        global_buffer_model.width,
        global_buffer_model.depth,
        parameterization,
    )
    fname = parameterization + " " + hashlib.md5(str(cachekey).encode()).hexdigest()
    cache_path = os.path.join(cache_dir, f"{fname}.pkl")
    if os.path.exists(cache_path):
        print(f"Loading from cache: {fname}")
        # Unpickling executes arbitrary code: only load cache files this notebook wrote.
        with open(cache_path, "rb") as f:
            mappings = pickle.load(f)
        return mappings if return_mappings else _min_edp(mappings)

    spec = copy.deepcopy(spec)  # the caller's spec is reused across design points

    local_buffer: Memory = spec.architecture.nodes["LocalBuffer"]
    local_buffer.attributes.size = local_buffer_model.width * local_buffer_model.depth
    global_buffer: Memory = spec.architecture.nodes["GlobalBuffer"]
    global_buffer.attributes.size = global_buffer_model.width * global_buffer_model.depth
    for target, model in ((local_buffer, local_buffer_model), (global_buffer, global_buffer_model)):
        target.actions["read"].arguments.energy = model.read() / model.width
        target.actions["write"].arguments.energy = model.write() / model.width

    keep, bypass = _MAIN_MEMORY_CONSTRAINTS[parameterization]
    main_memory: Memory = spec.architecture.nodes["MainMemory"]
    main_memory.constraints.storage.keep = keep
    if bypass is not None:
        main_memory.constraints.storage.bypass = bypass

    register: Memory = spec.architecture.nodes["Register"]
    register.spatial.fanout["X"] = n_pes
    register.spatial.fanout["Y"] = n_pes

    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()

    t0 = time.perf_counter()
    sims, decompress_data = get_sims(
        spec,
        flattened_architecture,
        tagger=tagger,
        metrics=Metrics.LATENCY | Metrics.ENERGY | Metrics.PER_COMPONENT_ENERGY,
    )
    pmapping_time = time.perf_counter() - t0
    total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    n_pareto_optimal_mappings = sum(len(p.mappings.data) for v in sims.values() for p in v)
    # Guard the ratios so an instant or empty generation step cannot crash the sweep.
    per_second = total_pmappings / pmapping_time if pmapping_time > 0 else float("inf")
    pareto_pct = n_pareto_optimal_mappings / total_pmappings * 100 if total_pmappings else 0.0
    time_per_pmapping = pmapping_time / total_pmappings if total_pmappings else float("nan")
    print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({per_second:.2f} per second). {n_pareto_optimal_mappings} pareto optimal mappings ({pareto_pct:.2f}% of total).')

    t0 = time.perf_counter()
    mappings = join_sims(sims, spec, flattened_architecture)
    join_time = time.perf_counter() - t0
    mappings.decompress(decompress_data)
    print(f"Pmappings: {pmapping_time:.2f}. Joining: {join_time:.2f}. Total Pmappings: {total_pmappings}. Total mappings: {mappings.n_pmappings}. Time per pmapping: {time_per_pmapping:.2e}")

    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(mappings, f)
    return mappings if return_mappings else _min_edp(mappings)

# Sweep global-buffer (GLB) and local-buffer (LLB) capacities under the fixed area
# budget. After paying for the buffers, each MXU spends its remaining area on a
# square MAC array. Every design point is evaluated for each parameterization, and
# its minimum EDP is recorded in `parameterization2edp`.
print(f'Overall area budget: {area_budget * 1e6} mm^2')

BITS_PER_MB = 1024 * 1024 * 8
N_MXUS = 4
GLB_SIZES_MB = [128]  # other sizes tried: 64, 32, 16
LLB_SIZES_MB = [4]  # other sizes tried: 0.25, 0.5, 1, 1.5, 2, 2.5, 3, 3.5

parameterization2edp = {}
parameterizations = ["FFM", "Unfused", "FlashAttention"]

for glb_MB in GLB_SIZES_MB:
    # int(): fractional MB sizes must still yield an integral SRAM depth.
    glb_size = int(glb_MB * BITS_PER_MB)
    glb = CactiSRAM(technology="7nm", width=1024, depth=glb_size // 1024)
    cur_area_budget = area_budget - glb.get_area()
    for sram_MB in LLB_SIZES_MB:
        sram_size = int(sram_MB * BITS_PER_MB)
        llb = CactiSRAM(technology="7nm", width=128, depth=sram_size // 128)
        remaining_area = cur_area_budget / N_MXUS - llb.get_area()  # Per-MXU
        # Check the sign first: the square root of a negative area would be complex.
        mac_dims = int((remaining_area / mac_area) ** 0.5) if remaining_area > 0 else 0
        if mac_dims < 1:
            # `continue`, not `break`: a smaller LLB later in the list may still fit.
            print(f"Skipping Global buffer: {glb_MB} MB, Local buffer: {sram_MB} MB: no area left for MACs")
            continue
        print(f"Global buffer: {glb_MB} MB, Local buffer: {sram_MB} MB, MAC dims: {mac_dims}x{mac_dims}")
        print(f'GLB read energy: {glb.read()}. LLB read energy: {llb.read()}')

        for parameterization in parameterizations:
            edp = get_fused_mappings(
                spec,
                mac_dims,
                llb,
                glb,
                parameterization=parameterization,
            )
            # Zero EDPs are left out of the results (presumably degenerate -- confirm).
            if edp != 0:
                parameterization2edp[f"{parameterization} {glb_MB}MB {sram_MB}MB {mac_dims}x{mac_dims}"] = edp

INFO        Loading yaml file architecture/four_level.arch.yaml
INFO        Found top key variables in architecture/four_level.arch.yaml
INFO        Found top key architecture in architecture/four_level.arch.yaml
INFO        Found top key component_classes in architecture/four_level.arch.yaml
INFO        Loading yaml file workloads/mha_full.workload.yaml
INFO        Found top key workload in workloads/mha_full.workload.yaml
INFO        Loading yaml file workloads/mha_full.renames.yaml
INFO        Found top key renames in workloads/mha_full.renames.yaml


COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
Overall area budget: 39.657475727216784 mm^2


INFO        Calculated "614e9 * 8 / 1.05e9" = 4678.0952380952385.
INFO        Calculated "0.84e-12 / 8" = 1.05e-13.
INFO        Calculated "0.5" = 0.5.


Global buffer: 128 MB, Local buffer: 4 MB, MAC dims: 128x128
GLB read energy: 1.6504301873093021e-09. LLB read energy: 4.2315915388286673e-11
By default metrics optimizes for energy and latency.We should change to just energy or just latency at some point.


Generating storage and loop choices for Einsum I: 6it [00:00, 83.13it/s]
Generating Pmappings for I:  17%|█▋        | 1/6 [00:00<00:02,  1.78it/s]

0 / 65 skipped (0.00%)


Generating Pmappings for I:  33%|███▎      | 2/6 [00:00<00:01,  2.16it/s]

0 / 54 skipped (0.00%)


Generating Pmappings for I:  50%|█████     | 3/6 [00:01<00:01,  2.10it/s]

0 / 65 skipped (0.00%)


Generating Pmappings for I:  67%|██████▋   | 4/6 [00:01<00:00,  2.15it/s]

0 / 65 skipped (0.00%)


Generating Pmappings for I:  83%|████████▎ | 5/6 [00:02<00:00,  2.24it/s]

0 / 54 skipped (0.00%)


Generating Pmappings for I: 100%|██████████| 6/6 [00:02<00:00,  2.22it/s]

0 / 54 skipped (0.00%)



Generating storage and loop choices for Einsum V: 8it [00:00, 97.97it/s]
Generating Pmappings for V:   0%|          | 0/8 [00:00<?, ?it/s]

9 / 200 skipped (4.50%)


Generating Pmappings for V:  25%|██▌       | 2/8 [00:08<00:23,  3.92s/it]

Skipping Compatibility(loops=(Loop({'b'}, 0, False),), storage={Reservation('I', 1, 'GlobalBuffer', 0)}), tags=Tags(({})) because it is not in any tensor2boundless_compatibilities
48 / 432 skipped (11.11%)


Generating Pmappings for V:  50%|█████     | 4/8 [00:11<00:09,  2.34s/it]

Skipping Compatibility(loops=(Loop({'b'}, 0, False),), storage={Reservation('I', 1, 'GlobalBuffer', 0)}), tags=Tags(({})) because it is not in any tensor2boundless_compatibilities
117 / 117 skipped (100.00%)


Generating Pmappings for V:  50%|█████     | 4/8 [00:50<00:50, 12.74s/it]


AssertionError: Skipped everyone! Compatibility: Compatibility(loops=(Loop({'b'}, 0, False), Loop({'m'}, 0, False), Loop({'b'}, 0, False), Loop({'d'}, 0, False)), storage={Reservation('I', 4, 'GlobalBuffer', 0), Reservation('V', 2, 'GlobalBuffer', 0)}), tags=Tags(({}))

In [None]:
import matplotlib.pyplot as plt
plt.style.use('default')
# plt.rcParams.update({'font.size': 28})

def plot_default_formatting(ax, grid_axis='both'):
    """Apply the notebook's house style to a matplotlib Axes, in place.

    The function does four things:
    - If ``ax`` already has a legend, it is reused so the caller's ``loc`` and
      ``fontsize`` choices survive.
    - Otherwise a legend is created only when some artist has a label. This avoids
      matplotlib's "no artists with labels" warning and an empty legend box.
    - Any legend gets a white face and a black edge, and all spines are drawn black.
    - Minor ticks are enabled, and major (solid) and minor (dashed) grid lines are
      drawn along ``grid_axis``.

    Args:
        ax: The matplotlib Axes to format.
        grid_axis: Axis that receives grid lines: 'both', 'x', or 'y'.
    """
    legend = ax.get_legend()
    if legend is None:
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            legend = ax.legend()
    if legend is not None:
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_edgecolor('black')

    for spine in ax.spines.values():
        spine.set_edgecolor('black')

    ax.minorticks_on()
    ax.grid(axis=grid_axis, which='major', linestyle='-', linewidth=0.3, color='gray')
    ax.grid(axis=grid_axis, which='minor', linestyle='--', linewidth=0.1, color='lightgray')
    

def make_bar_chart(
    data,
    title,
    xlabel,
    ylabel,
    y_scale,
    output_file=None,
    normalize: bool = False,
    ylim=(None, None),
    xlim=(None, None),
):
    """Draw a bar chart of ``data``, optionally save it as a PDF, and show it.

    Args:
        data: Either a flat mapping ``{label: value}`` (one bar per label) or a dict of
            dicts ``{series: {label: value}}``. In the grouped form there is one group
            per label of the FIRST series and one bar per series within each group;
            labels missing from a series are drawn as NaN gaps.
        title: Passed to ``plt.title``.
        xlabel: Passed to ``plt.xlabel``.
        ylabel: Passed to ``plt.ylabel``.
        y_scale: ``'log'`` for a logarithmic y axis; anything else leaves it linear.
        output_file: If given, the figure is also written to this path as a PDF.
        normalize: Grouped data only. Divides every series by the first series,
            label by label. Labels where the first series is 0 become NaN.
        ylim: Passed to ``plt.ylim``.
        xlim: Passed to ``plt.xlim``.

    Empty ``data`` prints a message and draws nothing.
    """
    if len(data) == 0:
        print("make_bar_chart: no data to plot.")
        return

    plt.figure(figsize=(16, 8))

    if isinstance(data, dict) and isinstance(next(iter(data.values())), dict):
        first = next(iter(data.values()))
        keys = list(first.keys())
        x = range(len(keys))
        bar_width = 0.8 / len(data)
        for i, (label, values) in enumerate(data.items()):
            # Read every series in the first series' key order so its bars line up with
            # the tick labels even when the dicts were filled in different orders.
            heights = [values.get(k, float("nan")) for k in keys]
            if normalize:
                heights = [h / first[k] if first[k] else float("nan") for h, k in zip(heights, keys)]
            plt.bar([pos + i * bar_width for pos in x], heights, width=bar_width, label=label)
        plt.xticks([pos + (len(data) - 1) * bar_width / 2 for pos in x], keys)
        plt.legend(loc='upper right', fontsize=10)
    else:
        plt.bar(list(data.keys()), list(data.values()))

    if y_scale == 'log':
        plt.yscale('log')

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(ylim)
    plt.xlim(xlim)

    # Vertical tick labels: design-point names are long.
    plt.xticks(rotation=90)

    plot_default_formatting(plt.gca(), grid_axis='y')

    if output_file is not None:
        plt.savefig(output_file, format='pdf', bbox_inches='tight')

    plt.show()
# Bar chart of the minimum EDP found for each swept design point. The sweep can
# fail partway (e.g. a mapper assertion), leaving no results, so guard before plotting.
if parameterization2edp:
    make_bar_chart(
        parameterization2edp,
        title="Minimum EDP per design point",
        xlabel="Parameterization, GLB size, LLB size, MAC array",
        ylabel="EDP",
        y_scale='linear',
    )
else:
    print("No EDP results to plot; check the sweep cell for errors.")

In [None]:
# # glb_MB = 128
# # sram_MB = 4
# # parameterization = ""

# # cur_area_budget = area_budget
# # glb_size = glb_MB * 1024 * 1024 * 8
# # glb = CactiSRAM(technology="7nm", width=1024, depth=glb_size // 1024)
# # cur_area_budget -= glb.get_area()
# # # for sram_MB in [0.25, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]:
# # sram_size = sram_MB * 1024 * 1024 * 8
# # llb = CactiSRAM(technology="7nm", width=128, depth=sram_size // 128)
# # remaining_area = cur_area_budget / 4 - llb.get_area() # Per-MXU
# # mac_dims = int((remaining_area / mac_area) ** 0.5)
# # print(f"Global buffer: {glb_MB} MB, Local buffer: {sram_MB} MB, MAC dims: {mac_dims}x{mac_dims}")
# # print(f'GLB read energy: {glb.read()}. LLB read energy: {llb.read()}')

from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_per_tensor_size, get_num_computes

# List every tensor of the workload with its size, largest first, then the
# workload's total number of computes.
tensor_sizes = get_per_tensor_size(spec)
for tensor_name in sorted(tensor_sizes, key=tensor_sizes.get, reverse=True):
    print(f"{tensor_name}: {tensor_sizes[tensor_name]}")
print(f"Number of computes: {get_num_computes(spec)}")

In [None]:
import pandas as pd

# For each parameterization, load (or compute) the joined mappings for the design
# point left over from the sweep cell (`mac_dims`, `llb`, `glb`), then print every
# column of the lowest-energy mapping.
parameterization2latencycols: list[dict[str, float]] = []
for p in parameterizations:
    mappings = get_fused_mappings(
        spec,
        mac_dims,
        llb,
        glb,
        return_mappings=True,
        parameterization=p,
    )
    # Select the lowest-energy row directly instead of overwriting the private `_data`.
    best = mappings.data.sort_values(by="metric_Energy", ascending=True).iloc[0]

    row = {
        "Parameterization": p,
    }
    for col, value in best.items():
        print(f'{col}: {value}')
        # NOTE(review): `row` currently records only the parameterization. To tabulate
        # metrics, copy selected columns in, e.g.
        # if "metric_Latency" in col: row[col] = value
    parameterization2latencycols.append(row)

pd.set_option('display.max_columns', None)
df = pd.DataFrame(parameterization2latencycols)
df
    
# {'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1
# TODO: Re-add -1 to the freeing logic in mapper_one_einsum
# compatibility2sims['Matmul1']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1"]
# Above 1: 8192
# Above 2: 8321
# compatibility2sims['Matmul2']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1, [GlobalBuffer] T2 sz 0 above 0"]

In [None]:
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG, display

# FFM with a 4 MB local buffer. The per-MXU area left after the global buffer `glb`
# (from the sweep cell) and this local buffer goes to a square MAC array. The cell
# renders the lowest-latency mapping and stores its latency columns in `a`.
sram_size = 4 * 1024 * 1024 * 8
llb = CactiSRAM(technology="7nm", width=128, depth=sram_size // 128)
mac_dims = int((((area_budget - glb.get_area()) / 4 - llb.get_area()) / mac_area) ** 0.5)
mappings = get_fused_mappings(
    spec,
    mac_dims,
    llb,
    glb,
    return_mappings=True,
    parameterization="FFM",
)
# Select the lowest-latency row directly instead of overwriting the private `_data`.
best = mappings.data.sort_values(by="metric_Latency", ascending=True).iloc[0]
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
newmapping = make_mapping(best, spec.workload.einsum_names, rank_variable_bounds)
a = {}
for col, value in best.items():
    print(f'{col}: {value}')
    if "Latency" in col:
        a[col] = value
display(SVG(newmapping.render()))

In [None]:
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG, display

# FFM with a 0.5 MB local buffer. The per-MXU area left after the global buffer `glb`
# (from the sweep cell) and this local buffer goes to a square MAC array. The cell
# renders the lowest-latency mapping and stores its latency columns in `b`.
# int(): 0.5 MB would otherwise give a float size and hence a float SRAM depth.
sram_size = int(0.5 * 1024 * 1024 * 8)
llb = CactiSRAM(technology="7nm", width=128, depth=sram_size // 128)
mac_dims = int((((area_budget - glb.get_area()) / 4 - llb.get_area()) / mac_area) ** 0.5)
mappings = get_fused_mappings(
    spec,
    mac_dims,
    llb,
    glb,
    return_mappings=True,
    parameterization="FFM",
)
# Select the lowest-latency row directly instead of overwriting the private `_data`.
best = mappings.data.sort_values(by="metric_Latency", ascending=True).iloc[0]
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
newmapping = make_mapping(best, spec.workload.einsum_names, rank_variable_bounds)
b = {}
for col, value in best.items():
    print(f'{col}: {value}')
    if "Latency" in col:
        b[col] = value
display(SVG(newmapping.render()))

In [None]:
# Latency columns of the best FFM mapping at each local-buffer size, one row per
# design point (`a` from the 4 MB cell, `b` from the 0.5 MB cell).
df = pd.DataFrame([a, b], index=["LLB 4 MB", "LLB 0.5 MB"])
df