In [None]:
import copy
import itertools
from typing import Iterator
from fastfusion.frontend._set_parsing import InvertibleSet
from fastfusion.frontend.arch import Leaf, Memory
from fastfusion.frontend.workload.workload_spec import Tensor, RankVariable
import fastfusion.frontend.mapping as mapping_import
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.single_einsum_mapper import iterate_mappings, make_storage_choices_all_levels, make_storage_choices_one_level, get_storage_choices


# # Example mapping node
# type: "temporal"
# rank: P
# # Choose one of the following cases
# # Case 1.a
# tile_shape: 3   # will make tile shapes with shape 3
# # Case 1.b
# tile_shape: null # will create a sympy symbol to represent tile shape and use that
# # Case 2.a
# factor: 3       # will split the rank into 3 tiles shaped as evenly as possible
# # Case 2.b
# factor: null    # will create a sympy symbol to represent the factor, then same as 2.a
# # Case 3   (I'm only showing null from now on)
# tile_pattern:
#   stride: null
#   initial_shape: null  # This will create tiles like this: [0, 1, ..., initial_shape - 1], [initial_shape, ..., initial_shape + stride - 1], [initial_shape + stride, ..., initial_shape + 2*stride - 1], ...
# # Case 4
# tile_pattern:
#   stride: null
#   shape: null      # This will create tiles like this: [0, 1, ..., shape-1], [stride, stride+1, ..., stride + shape-1], [2*stride, 2*stride + 1, ..., 2*stride + shape - 1], ...
#         choices = list(integer_factorizations_to_n_parts(rank_size, len(loops)))

# Tile shape constraint: Applies to all tensor(s) in a storage node for which that tile shape is relevant
# Loop bound constraint: Applies only to spatial loops

import time

# Load the architecture, workload, and renames into a single specification.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full_new.yaml",
    "workloads/mha_full_new.renames.yaml",
)

workload = spec.workload
renames = spec.renames

# Per-einsum context used to build the storage-choice options below.
# NOTE(review): this context is built for einsum "K", but the mappings that are
# enumerated at the bottom of this cell are for einsum "QK" -- confirm which
# einsum is actually intended here.
einsum_name = "K"
einsum = workload.einsums[einsum_name]
rank_variables = einsum.rank_variables
tensors = einsum.tensors
symbol_table = workload.get_constraint_symbol_table(einsum_name, renames)
first_value = next(iter(symbol_table.values()))
arch_nodes = spec.get_flattened_architecture()
tensor2rank_variables = einsum.tensor2rank_variables
# Names of the Memory nodes, in flattened-architecture order.
storage_order = [n.name for n in arch_nodes if isinstance(n, Memory)]
# Placeholder sizes: every rank variable is treated as having size 16.
rank_variable_to_size = {r: 16 for r in rank_variables}

# If there are two back-to-back storages for the same tensor & the outer is
# optional, then it is invalid.
# NOTE(review): the rule above is not enforced in this cell; the line below only
# collects nodes whose storage constraint is marked uneven.
uneven_storages = [n for n in arch_nodes if n.constraints.storage.uneven]
storage_choice_options = list(make_storage_choices_all_levels(arch_nodes, symbol_table))

t0 = time.time()
mappings_count = 0
# NOTE(review): assumes the first flattened architecture node is the main
# (backing) memory -- confirm against the architecture file.
main_memory = arch_nodes[0]
n_mappings = 0

# TODO: Check for ranks not in the mapping and put them at the bottom

# for i, (storage_choices, symbol_table) in enumerate(get_storage_choices(arch_nodes, symbol_table)):
#     print(f"{i}/{len(storage_choice_options)}: {storage_choices.compact_string()}")

for i, mapping in enumerate(iterate_mappings(spec, "QK")):
    print(f"{i}: {mapping.compact_string()}")
    n_mappings += 1

# Report the count only after enumeration has finished. The total used to be
# printed before the loop and was never incremented, so it always showed 0.
mappings_count = n_mappings
print(f"Total mappings: {n_mappings}")
print(f"Elapsed: {time.time() - t0:.2f}s")

# TODO: What if there are no loops?
# TODO: Set _must_exist for all backing storage nodes


2025-05-20 13:16:00 INFO        Loading yaml file architecture/four_level.arch.yaml
2025-05-20 13:16:00 INFO        Found top key variables in architecture/four_level.arch.yaml
2025-05-20 13:16:00 INFO        Found top key architecture in architecture/four_level.arch.yaml
2025-05-20 13:16:00 INFO        Found top key compound_components in architecture/four_level.arch.yaml
2025-05-20 13:16:00 INFO        Loading yaml file workloads/mha_full_new.yaml
2025-05-20 13:16:00 INFO        Found top key workload in workloads/mha_full_new.yaml
2025-05-20 13:16:00 INFO        Loading yaml file workloads/mha_full_new.renames.yaml
2025-05-20 13:16:00 INFO        Found top key renames in workloads/mha_full_new.renames.yaml
2025-05-20 13:16:00 INFO        Loading yaml file /root/.config/fastfusion/config.yaml
2025-05-20 13:16:00 INFO        Found top key version in /root/.config/fastfusion/config.yaml
2025-05-20 13:16:00 INFO        Found top key environment_variables in /root/.config/fastfusion/conf

Total mappings: 0
[GlobalBuffer Q] {h, b}-None [GlobalBuffer K] {h, b, p}-None [GlobalBuffer QK] SX-{m, b, e, p, h}-None {m, h, b}-None [LocalBuffer Q] {h, b, e}-None [LocalBuffer K] {h, b, p}-None [LocalBuffer QK] SX-{m, b, e, p, h}-None SY-{m, b, e, p, h}-None {h, b, p}-None [Register K] {h, b, e, p}-None
[GlobalBuffer Q] {h, b}-None [GlobalBuffer K] {h, b, p}-None [GlobalBuffer QK] SX-{m, b, e, p, h}-None {m, h, b}-None [LocalBuffer Q] {h, b, e}-None [LocalBuffer K] {h, b, e, p}-None [Register K] {h, b, p}-None [LocalBuffer QK] SX-{m, b, e, p, h}-None SY-{m, b, e, p, h}-None {m, h, b, p}-None
[GlobalBuffer Q] {h, b}-None [GlobalBuffer K] {h, b, p}-None [GlobalBuffer QK] SX-{m, b, e, p, h}-None {m, h, b}-None [LocalBuffer Q] {m, h, b}-None [LocalBuffer QK] {h, b, p}-None [LocalBuffer K] SX-{m, b, e, p, h}-None SY-{m, b, e, p, h}-None {h, b, e, p}-None [Register K] {h, b, e, p}-None
[GlobalBuffer Q] {h, b}-None [GlobalBuffer K] {h, b, p}-None [GlobalBuffer QK] SX-{m, b, e, p, h}-None 