In [1]:
import hashlib
import os
import pickle
from hwcomponents_cacti import SRAM as CactiSRAM
from hwcomponents_library import AladdinAdder, AladdinMultiplier

from fastfusion.frontend.architecture import Memory
from fastfusion.frontend.specification import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.simanneal.wrappers import join_sims

import copy
import time
from fastfusion import Specification
from fastfusion.mapper.metrics import Metrics
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
import fastfusion.mapper.FFM.exploration.mapper_one_einsum as mapper_one_einsum

from fastfusion.mapper.FFM.exploration.mapping_filter_tags.ffmt import get_ffmt_tag
from fastfusion.mapper.FFM.exploration.mapping_filter_tags.onesplit import get_one_split_tag
from fastfusion.mapper.FFM.pareto import PartialMappings
from fastfusion.mapper.FFM import make_pmappings, join_pmappings

# TODO: area is an alias for get_area
# TODO: Separate energy and area
# TODO: Move scaling into main hwcomponents repo
# TODO: Function that just returns the hwcomponents component
# TODO: Is all the initial right consolidation necessary?
# TODO: Datawidth calculation for energy

# TODO: Reference specific tensor names in constraints, even if those tensors are not in
# a particular Einsum. Also have the error messages for parsing errors list which Einsum
# failed. Tensor names that aren't in the Einsum should resolve to NotInThisEinsum(),
# which selects no tensors.

# TODO: Make a setting for the below two in the spec
# TODO: Generate pmappings one Einsum at a time. Once we've made compatibility, check it
# against the previously-generated compatibilities and stop if there's no match.
# TODO: Once the previous is done, also add a forward check. Once the compatibilities of
# a particular Einsum are generated, we can immediately check the previous Einsums.
# TODO: Make the mapping return an object that supports union operators and stuff
# TODO: The fix in mapping.py

# TODO: have inf a supported value in YAMLs
# TODO: Programmatically check if any storages are below all backing storages. If so,
# don't record reservations for them.
# TODO: If any memories have size > sum of all tensor sizes, also don't record
# reservations for them.

# Workload: full multi-head attention mapped onto the four-level architecture.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml",
)

# Baseline design at 7 nm, used only to derive the area budget: four copies of
# (128x128 MACs + local buffer) plus one shared global buffer. Each MAC is a
# 16-bit adder plus an 8-bit multiplier. Buffer sizes are in bits
# (MB * 1024 * 1024 * 8), so the SRAM depth is size // width.
adder = AladdinAdder(tech_node=7e-9, width=16)
multiplier = AladdinMultiplier(tech_node=7e-9, width=8)
mac_area = adder.area + multiplier.area

base_local_buffer_size = 4 * 1024 * 1024 * 8
base_local_buffer = CactiSRAM(tech_node=7e-9, width=128, depth=base_local_buffer_size // 128)
base_global_buffer_size = 128 * 1024 * 1024 * 8
base_global_buffer = CactiSRAM(tech_node=7e-9, width=1024, depth=base_global_buffer_size // 1024)
area_budget = (mac_area * 128 * 128 + base_local_buffer.area) * 4 + base_global_buffer.area

# TODO(review): get_fused_mappings divides the MAC energy by 8; confirm the factor.
# One visible reminder is enough; the five identical prints were debugging noise.
print("COMPUTE ENERGY / 8 ????????????")


def objective(df):
    """Per-mapping objective to minimize (lower is better).

    Currently latency only. Multiply by ``df['metric_Energy']`` to switch to EDP.
    """
    return df['metric_Latency']

def get_fused_mappings(
        spec: Specification, 
        pe_x,
        pe_y,
        local_buffer_model,
        global_buffer_model,
        tagger=None, 
        # fuse=True,
        parameterization="",
        return_mappings=False,
        mac_energy: float = None,
        max_latency: float = None
    ) -> PartialMappings:
    cachekey = (pe_x, pe_y, local_buffer_model.width, local_buffer_model.depth, global_buffer_model.width, global_buffer_model.depth, parameterization)
    fname = parameterization + " " + hashlib.md5(str(cachekey).encode()).hexdigest()
    if os.path.exists(f"cache/{fname}.pkl"):
        print(f"Loading from cache: {fname}")
        mappings = pickle.load(open(f"cache/{fname}.pkl", "rb"))
        if return_mappings:
            return mappings
        return objective(mappings.data).min(), mappings
    spec = copy.deepcopy(spec)
    local_buffer: Memory = spec.architecture.nodes["LocalBuffer"]
    local_buffer.attributes.size = local_buffer_model.width * local_buffer_model.depth
    global_buffer: Memory = spec.architecture.nodes["GlobalBuffer"]
    global_buffer.attributes.size = global_buffer_model.width * global_buffer_model.depth
    if mac_energy is not None:
        mac = spec.architecture.nodes["MAC"]
        mac.actions["compute"].arguments.energy = mac_energy / 8
    for target, model in [(local_buffer, local_buffer_model), (global_buffer, global_buffer_model)]:
        target.actions["read"].arguments.energy = model.read() / model.width
        target.actions["write"].arguments.energy = model.write() / model.width
    main_memory: Memory = spec.architecture.nodes["MainMemory"]
    if parameterization == "Unfused":
        main_memory.constraints.tensors.keep = "All()"
    elif parameterization == "FlashAttention B":
        main_memory.constraints.tensors.keep = "All() - (I | Q | K | V | QK | QK_softmax)"# - QK_softmax"# - Q - K - V - I"
        main_memory.constraints.tensors.bypass = "I | Q | K | V | QK | QK_softmax"#Q | K | V | I"# | QK | FFA"
    elif parameterization == "FlashAttention A":
        main_memory.constraints.tensors.keep = "All() - (QK | QK_softmax)"# - QK_softmax"# - Q - K - V - I"
        main_memory.constraints.tensors.bypass = "QK | QK_softmax"#Q | K | V | I"# | QK | FFA"
    elif parameterization == "FFM":
        main_memory.constraints.tensors.keep = "~Intermediates()" #"# | AV | Z "
        main_memory.constraints.tensors.bypass = "I | Q | K | V | QK"#Q | K | V | I"# | QK | FFA"
        pass
    else:
        assert False, f"Parameterization {parameterization} not supported"
    register: Memory = spec.architecture.nodes["Register"]
    register.spatial.fanout["X"] = pe_x
    register.spatial.fanout["Y"] = pe_y
    
    spec.calculate_component_energy_area()
    # flattened_architecture = spec.get_flattened_architecture()
    # t0 = time.time()
    # sims, pmapping_objects = get_sims(spec, flattened_architecture, tagger=tagger, metrics=Metrics.LATENCY | Metrics.ENERGY)
    pmappings = make_pmappings(spec)
    # pmapping_time = time.time() - t0
    # total_pmappings = sum(p.mappings.n_pmappings for v in sims.values() for p in v)
    # n_pareto_optimal_mappings = sum(len(p.mappings.data) for v in sims.values() for p in v)
    # print(f'Took {pmapping_time:.2f} seconds to generate {total_pmappings} partial mappings ({total_pmappings / pmapping_time:.2f} per second). {n_pareto_optimal_mappings} pareto optimal mappings ({n_pareto_optimal_mappings / total_pmappings*100:.2f}% of total).')
    # t0 = time.time()
    mappings = join_pmappings(spec, pmappings)
    # join_time = time.time() - t0
    # print(f"Pmappings: {pmapping_time:.2f}. Joining: {join_time:.2f}. Total Pmappings: {total_pmappings}. Total mappings: {mappings.n_pmappings}. Time per pmapping: {pmapping_time / total_pmappings:.2e}")
    pickle.dump(mappings, open(f"cache/{fname}.pkl", "wb"))
    if return_mappings:
        return mappings
    return objective(mappings.data).min(), mappings

print(f'Overall area budget: {area_budget * 1e6} mm^2')

parameterization2edp = {}
parameterization2mappings = {}

parameterizations = ["Unfused", "FlashAttention A", "FlashAttention B", "FFM"]

# Every component in the sweep is re-characterized at this technology node.
TARGET_TECH_NODE = 4e-9
adder = AladdinAdder(tech_node=TARGET_TECH_NODE, width=16)
multiplier = AladdinMultiplier(tech_node=TARGET_TECH_NODE, width=8)
mac_area = adder.area + multiplier.area
# One MAC = one add + one multiply. This is identical for every configuration,
# so compute it once.
mac_energy = adder.add() + multiplier.multiply()

# Buffer sizes are in bits. The global buffer starts at 512 MB and is halved
# until it fits in the area the cores leave free.
glb_size = 512 * 1024 * 1024 * 8
glb = CactiSRAM(tech_node=TARGET_TECH_NODE, width=1024, depth=glb_size // 1024)
llb_size = 1 * 1024 * 1024 * 8
llb = CactiSRAM(tech_node=TARGET_TECH_NODE, width=128, depth=llb_size // 128)

for mac_x, mac_y in [(128,128), (256,128), (256,256), (512, 256), (512,512), (1024, 512)]:
    total_mac_area = mac_area * mac_x * mac_y
    # Four copies of (local buffer + MAC array); the global buffer gets the rest.
    area_remaining = area_budget - 4 * (llb.area + total_mac_area)
    if area_remaining <= 0:
        # An over-budget design is not a valid point of this sweep.
        print(f"Skipping MAC dims {mac_x}x{mac_y}: cores alone exceed the area budget (area remaining: {area_remaining})")
        continue
    # glb_size deliberately carries over between iterations: the MAC arrays only
    # grow along this list, so a GLB that did not fit before will not fit now.
    glb_fits = True
    while glb.area > area_remaining:
        print(f"Global buffer area: {glb.area}. Area remaining: {area_remaining}")
        if glb_size // 2 < 1024:
            # Less than one 1024-bit row: no global buffer fits at all.
            glb_fits = False
            break
        glb_size //= 2
        glb = CactiSRAM(tech_node=TARGET_TECH_NODE, width=1024, depth=glb_size // 1024)
    if not glb_fits:
        print(f"Skipping MAC dims {mac_x}x{mac_y}: no global buffer fits in the remaining area")
        continue

    max_latency = None  # passed through; get_fused_mappings does not use it

    glb_MB = glb_size // 1024 // 1024 // 8
    llb_MB = llb_size // 1024 // 1024 // 8

    print("\n\n")
    print("=" * 100)
    print(f"Global buffer: {glb_MB} MB, Local buffer: {llb_MB} MB, MAC dims: {mac_x}x{mac_y}")
    print("=" * 100)

    for parameterization in parameterizations:
        best_objective, mappings = get_fused_mappings(
            spec,
            mac_x,
            mac_y,
            llb,
            glb,
            mac_energy=mac_energy,
            parameterization=parameterization,
            max_latency=max_latency,
        )
        # NOTE(review): a zero objective is treated as "no result"; confirm intent.
        if best_objective != 0:
            result_key = f"{parameterization} {glb_MB}MB {llb_MB}MB {mac_x}x{mac_y}"
            parameterization2edp[result_key] = best_objective
            parameterization2mappings[result_key] = mappings



COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
COMPUTE ENERGY / 8 ????????????
Overall area budget: 8448.022488990722 mm^2
Global buffer area: 0.02256514974783053. Area remaining: 0.00825458610917197
Global buffer area: 0.011728085273632725. Area remaining: 0.00825458610917197



Global buffer: 128 MB, Local buffer: 1 MB, MAC dims: 128x128


Generating tensor order and loop choices for Einsum I: 1it [00:00, 14.17it/s]
Generating tensor order and loop choices for Einsum AV: 8it [00:00, 52.47it/s]t/s]
Generating tensor order and loop choices for Einsum QK: 8it [00:00, 51.17it/s]
Generating tensor order and loop choices for Einsum QK_softmax: 2it [00:00, 25.43it/s]
Generating tensor order and loop choices for Einsum FFB: 3it [00:00, 28.76it/s]

Generated 1 job for I


Generating tensor order and loop choices for Einsum K: 8it [00:00, 52.05it/s]
Generating tensor order and loop choices for Einsum FFB: 8it [00:00, 51.18it/s]
Generating tensor order and loop choices for Einsum Q: 8it [00:00, 55.77it/s]
Generating tensor order and loop choices for Einsum V: 8it [00:00, 52.22it/s]
Generating tensor order and loop choices for Einsum FFA: 8it [00:00, 54.32it/s]
Generating tensor order and loop choices for Einsum Z: 8it [00:00, 48.01it/s]
Generating jobs: 100%|██████████| 10/10 [00:01<00:00,  5.33it/s]


Generated 8 jobs for V
Generated 8 jobs for K
Generated 8 jobs for Q
Generated 8 jobs for QK
Generated 2 jobs for QK_softmax
Generated 8 jobs for AV
Generated 8 jobs for Z
Generated 8 jobs for FFA
Generated 8 jobs for FFB
WZ
V
I_in
Z
FFB
I
QK_softmax
AV
Q
WQ
WFFB
K
WFFA
WK
QK
WV
FFA


Generating pmappings: 100%|██████████| 67/67 [00:11<00:00,  5.66it/s]
Compressing pmappings: 100%|██████████| 10/10 [00:00<00:00, 51.04it/s]


SIM I tensors: {'I'}
SIM V tensors: {'I', 'V'}
SIM K tensors: {'I', 'K'}
SIM Q tensors: {'I', 'Q'}
SIM QK tensors: {'QK', 'Q', 'K'}
SIM QK_softmax tensors: {'QK_softmax', 'QK'}
SIM AV tensors: {'QK_softmax', 'AV', 'V'}
SIM Z tensors: {'AV', 'Z'}
SIM FFA tensors: {'Z', 'FFA'}
SIM FFB tensors: {'FFA'}


Grouping Partial Mappings: 100%|██████████| 1/1 [00:01<00:00,  1.65s/it]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 681.67it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 854.59it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 873.27it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 363.68it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 810.81it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 714.90it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 451.78it/s]
Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00, 605.15it/s]


Initial consolidate and group: 1.69 seconds

Einsum V (2/10)


Grouping Partial Mappings: 100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
Merging mappings I <--> V: 100%|██████████| 1/1 [00:00<00:00, 270.79it/s]


Mapping merging: 0.44 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum V: 1
	Number of mappings for Einsum V: 1
	Mappings per group for Einsum V: 1.0
	Largest left: 1
	Largest right: 1

Einsum K (3/10)


Merging mappings V <--> K: 100%|██████████| 1/1 [00:00<00:00, 446.49it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum K: 1
	Number of mappings for Einsum K: 1
	Mappings per group for Einsum K: 1.0
	Largest left: 1
	Largest right: 1

Einsum Q (4/10)


Merging mappings K <--> Q: 100%|██████████| 1/1 [00:00<00:00, 528.32it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum Q: 1
	Number of mappings for Einsum Q: 1
	Mappings per group for Einsum Q: 1.0
	Largest left: 1
	Largest right: 1

Einsum QK (5/10)


Merging mappings Q <--> QK: 100%|██████████| 1/1 [00:00<00:00, 432.31it/s]


Mapping merging: 0.01 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum QK: 1
	Number of mappings for Einsum QK: 1
	Mappings per group for Einsum QK: 1.0
	Largest left: 1
	Largest right: 1

Einsum QK_softmax (6/10)


Merging mappings QK <--> QK_softmax: 100%|██████████| 1/1 [00:00<00:00, 469.00it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum QK_softmax: 1
	Number of mappings for Einsum QK_softmax: 1
	Mappings per group for Einsum QK_softmax: 1.0
	Largest left: 1
	Largest right: 1

Einsum AV (7/10)


Merging mappings QK_softmax <--> AV: 100%|██████████| 1/1 [00:00<00:00, 358.61it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum AV: 1
	Number of mappings for Einsum AV: 1
	Mappings per group for Einsum AV: 1.0
	Largest left: 1
	Largest right: 1

Einsum Z (8/10)


Merging mappings AV <--> Z: 100%|██████████| 1/1 [00:00<00:00, 452.46it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum Z: 1
	Number of mappings for Einsum Z: 1
	Mappings per group for Einsum Z: 1.0
	Largest left: 1
	Largest right: 1

Einsum FFA (9/10)


Merging mappings Z <--> FFA: 100%|██████████| 1/1 [00:00<00:00, 448.06it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum FFA: 1
	Number of mappings for Einsum FFA: 1
	Mappings per group for Einsum FFA: 1.0
	Largest left: 1
	Largest right: 1

Einsum FFB (10/10)


Merging mappings FFA <--> FFB: 100%|██████████| 1/1 [00:00<00:00, 451.19it/s]


Mapping merging: 0.00 seconds
	Combining 0(1) x 0(1) -> 1
	Number of groups for Einsum FFB: 1
	Number of mappings for Einsum FFB: 1
	Mappings per group for Einsum FFB: 1.0
	Largest left: 1
	Largest right: 1


Final consolidate: 100%|██████████| 1/1 [00:00<00:00, 18893.26it/s]



Initial consolidate and group: 1.69 seconds
Mapping merging: 0.47 seconds

Total: 2.17 seconds



AttributeError: Can't pickle local object 'join_pmappings.<locals>.<lambda>.<locals>.<lambda>'

In [None]:
import matplotlib.pyplot as plt
plt.style.use('default')
# plt.rcParams.update({'font.size': 28})

def plot_default_formatting(ax, grid_axis='both'):
    """Apply the notebook's house style to the axes ``ax``.

    Styles tick marks, black spines, and major/minor grid lines along
    ``grid_axis`` ('x', 'y' or 'both'). An existing legend (e.g. one the
    caller placed with a custom ``loc``/``fontsize``) is restyled in place
    rather than replaced. A default legend is created only when the axes
    has labelled artists, so unlabelled plots don't get an empty legend box.
    """
    ax.tick_params(axis='both', which='major')
    ax.tick_params(axis='both', which='minor')

    legend = ax.get_legend()
    if legend is None:
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            legend = ax.legend()
    if legend is not None:
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_edgecolor('black')

    for spine in ax.spines.values():
        spine.set_edgecolor('black')
    ax.minorticks_on()
    ax.grid(axis=grid_axis, which='major', linestyle='-', linewidth=0.3, color='gray')
    ax.grid(axis=grid_axis, which='minor', linestyle='--', linewidth=0.1, color='lightgray')
    

def make_bar_chart(
    data,
    title,
    xlabel,
    ylabel,
    y_scale,
    output_file=None,
    normalize: bool = False,
    ylim=(None, None),
    xlim=(None, None),
):
    """
    Draw and show a bar chart of ``data``; also save it as a PDF if ``output_file`` is given.

    ``data`` is either a flat ``{label: value}`` dict (one bar per label) or a
    nested ``{series: {label: value}}`` dict (grouped bars, one per series).
    For nested data the x-axis labels come from the first series, and every
    series is plotted in that label order. A label missing from a series is
    drawn as NaN (no bar). With ``normalize=True``, each nested series is
    divided label-wise by the first series; ``normalize`` has no effect on flat
    data. ``y_scale='log'`` switches the y-axis to a log scale. ``ylim``/``xlim``
    are passed to ``plt.ylim``/``plt.xlim``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("make_bar_chart: no data to plot")

    plt.figure(figsize=(16, 8))

    if isinstance(data, dict) and isinstance(next(iter(data.values())), dict):
        first = next(iter(data.values()))
        keys = list(first.keys())
        x = range(len(keys))
        bar_width = 0.8 / len(data)
        for i, (label, values) in enumerate(data.items()):
            bar_positions = [pos + i * bar_width for pos in x]
            # Look values up in the first series' label order so every bar sits
            # under its own tick label even if the series dicts are ordered differently.
            heights = [values.get(k, float('nan')) for k in keys]
            if normalize:
                heights = [h / first[k] for h, k in zip(heights, keys)]
            plt.bar(bar_positions, heights, width=bar_width, label=label)
        # Center each tick under its group of bars.
        plt.xticks([pos + (len(data) - 1) * bar_width / 2 for pos in x], keys)
        plt.legend(loc='upper right', fontsize=10)
    else:
        keys = list(data.keys())
        plt.bar(keys, list(data.values()))

    if y_scale == 'log':
        plt.yscale('log')

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(ylim)
    plt.xlim(xlim)

    # Labels are long configuration strings; rotate them vertically.
    plt.xticks(rotation=90)

    plot_default_formatting(plt.gca(), grid_axis='y')

    if output_file is not None:
        with open(output_file, 'wb') as f:
            plt.savefig(f, format='pdf', bbox_inches='tight')

    plt.show()
# One bar per "<parameterization> <GLB>MB <LLB>MB <MACs>" configuration.
# The stored value is objective(...).min(); the objective is currently
# metric_Latency, despite the dict's name.
if parameterization2edp:
    make_bar_chart(
        parameterization2edp,
        title=None,
        xlabel=None,
        ylabel="Latency",
        y_scale='linear'
    )
else:
    print("No results to plot: parameterization2edp is empty (did the sweep cell finish?)")

In [None]:
# # glb_MB = 128
# # sram_MB = 4
# # parameterization = ""

# # cur_area_budget = area_budget
# # glb_size = glb_MB * 1024 * 1024 * 8
# # glb = CactiSRAM(tech_node=7e-9, width=1024, depth=glb_size // 1024)
# # cur_area_budget -= glb.area
# # # for sram_MB in [0.25, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]:
# # sram_size = sram_MB * 1024 * 1024 * 8
# # llb = CactiSRAM(tech_node=7e-9, width=128, depth=sram_size // 128)
# # remaining_area = cur_area_budget / 4 - llb.area # Per-MXU
# # mac_dims = int((remaining_area / mac_area) ** 0.5)
# # print(f"Global buffer: {glb_MB} MB, Local buffer: {sram_MB} MB, MAC dims: {mac_dims}x{mac_dims}")
# # print(f'GLB read energy: {glb.read()}. LLB read energy: {llb.read()}')

from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_per_tensor_size, get_num_computes
# List every tensor's size for this workload, largest first, then the compute count.
tensor_sizes = get_per_tensor_size(spec)
largest_first = sorted(tensor_sizes.items(), key=lambda item: item[1], reverse=True)
for tensor_name, tensor_size in largest_first:
    print(f"{tensor_name}: {tensor_size}")
n_computes = get_num_computes(spec)
print(f"Number of computes: {n_computes}")

In [None]:
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums

parameterization2latencycols: list[dict[str, float]] = []
for p, mappings in parameterization2mappings.items():
    # Lowest-latency joined mapping for this configuration. Sort a copy instead
    # of assigning mappings._data, so re-running this cell never mutates the
    # shared results in parameterization2mappings.
    best = mappings.data.sort_values(by="metric_Latency", ascending=True).iloc[0]

    row = {"Parameterization": p}
    for col in mappings.data.columns:
        print(f'{col}: {best[col]}')
        # Keep only the latency columns for the summary table.
        if "Latency" in col:
            row[col] = best[col]
    parameterization2latencycols.append(row)

from fastfusion.accelerated_imports.pd import DataFrame
df = DataFrame(parameterization2latencycols)
from fastfusion.accelerated_imports import pd
pd.set_option('display.max_columns', None)
df
    
# {'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1
# TODO: Re-add -1 to the freeing logic in the one-Einsum mapper (mapper_one_einsum)
# compatibility2sims['Matmul1']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1"]
# Above 1: 8192
# Above 2: 8321
# compatibility2sims['Matmul2']["{'n1'}-1 || [GlobalBuffer] T1 sz 0 above 1, [GlobalBuffer] T2 sz 0 above 0"]

In [None]:
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG

assert parameterization2mappings, "No mappings available; run the sweep cell first."
# Inspect the lowest-latency mapping of the first configuration swept.
# Sort a copy: assigning `mappings._data = ....head()` truncated the shared
# object in parameterization2mappings to 5 rows for every later cell.
mappings = next(iter(parameterization2mappings.values()))
best_mapping_row = mappings.data.sort_values(by="metric_Latency", ascending=True).iloc[0]
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
newmapping = make_mapping(best_mapping_row, spec.workload.einsum_names, rank_variable_bounds)
a = {}
for col in mappings.data.columns:
    print(f'{col}: {best_mapping_row[col]}')
    if "Latency" in col:
        a[col] = best_mapping_row[col]
display(SVG(newmapping.render()))

In [None]:
# The leading assert halts execution here; this cell appears to be
# intentionally disabled. Remove it to re-run FFM with a 0.5 MB local buffer.
assert False

from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_rank_variable_bounds_for_all_einsums
from fastfusion.mapper.FFM.visualization import make_mapping
from IPython.display import SVG

# Sizes are in bits; keep them integral so the SRAM depth is an int, not a float.
sram_size = int(0.5 * 1024 * 1024 * 8)
# NOTE(review): this LLB is characterized at 7 nm while glb/mac_area come from
# the TARGET_TECH_NODE sweep; confirm the mix is intended.
llb = CactiSRAM(tech_node=7e-9, width=128, depth=sram_size // 128)
# Square MAC array filling the per-core area left after the global buffer
# (shared by four cores) and this core's local buffer.
mac_dims = int((((area_budget - glb.area) / 4 - llb.area) / mac_area) ** 0.5)
mappings = get_fused_mappings(
    spec,
    mac_dims,  # pe_x
    mac_dims,  # pe_y
    llb,
    glb,
    return_mappings=True,
    parameterization="FFM"
)
best_mapping_row = mappings.data.sort_values(by="metric_Latency", ascending=True).iloc[0]
rank_variable_bounds = get_rank_variable_bounds_for_all_einsums(spec)
newmapping = make_mapping(best_mapping_row, spec.workload.einsum_names, rank_variable_bounds)
b = {}
for col in mappings.data.columns:
    print(f'{col}: {best_mapping_row[col]}')
    if "Latency" in col:
        b[col] = best_mapping_row[col]
display(SVG(newmapping.render()))

In [None]:
# Side-by-side latency columns of the two inspected mappings (one row each).
df = pd.DataFrame(data=[a, b])
df