# In [None]:
from pandas import DataFrame
from fastfusion import Specification
from fastfusion.frontend.mapping import Iteration, Mapping, Spatial, Storage
from fastfusion.frontend.workload.workload import RankVariableName
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.exploration.mapper_one_einsum import get_single_einsum_sims, iterate_mappings_constraints
from fastfusion.mapper.FFM.exploration.tile_shape_exploration import dummy_tile_shape_exploration
from fastfusion.mapper.FFM.joining.mappinginfo import Compatibility, Loop, Reservation
from fastfusion.mapper.FFM.joining.sim import SIM
from fastfusion.mapper.FFM.joining.simexplore import join_sims
from fastfusion.mapper.FFM.pareto import Pareto, MAPPING_COLUMN
from fastfusion.util.util import fzs

# # Example mapping node
# type: "temporal"
# rank: P
# # Choose one of the following cases
# # Case 1.a
# tile_shape: 3   # will make tile shapes with shape 3
# # Case 1.b
# tile_shape: null # will create a sympy symbol to represent tile shape and use that
# # Case 2.a
# factor: 3       # will make 3 tiles that are as evenly shaped as possible
# # Case 2.b
# factor: null    # will create a sympy symbol to represent the factor, then same as 2.a
# # Case 3   (I'm only showing null from now on)
# tile_pattern:
#   stride: null
#   initial_shape: null  # This will create tiles like this: [0, 1, ..., initial_shape - 1], [initial_shape, ..., initial_shape + stride - 1], [initial_shape + stride, ..., initial_shape + 2*stride - 1], ...
# # Case 4
# tile_pattern:
#   stride: null
#   shape: null      # This will create tiles like this: [0, 1, ..., shape - 1], [stride, stride + 1, ..., stride + shape - 1], [2*stride, 2*stride + 1, ..., 2*stride + shape - 1], ...
#         choices = list(integer_factorizations_to_n_parts(rank_size, len(loops)))

# Tile shape constraint: Applies to all tensor(s) in a storage node for which that tile shape is relevant
# Loop bound constraint: Only for spatial

# Driver: load the specification, generate the per-Einsum SIMs, and join them.

# YAML inputs, in the order they are handed to Specification.from_yaml:
# architecture, workload, then the workload's renames.
_SPEC_YAML_FILES = (
    "architecture/four_level.arch.yaml",
    "workloads/mha_full_new.yaml",
    "workloads/mha_full_new.renames.yaml",
)
spec = Specification.from_yaml(*_SPEC_YAML_FILES)

workload = spec.workload
renames = spec.renames

# Every single-letter lowercase rank variable name is given the same size, 16.
rank_variable_to_size = dict.fromkeys("abcdefghijklmnopqrstuvwxyz", 16)

# pr = cProfile.Profile()
# pr.enable()

# sims = get_single_einsum_sims(spec, "Q", rank_variable_to_size)
flattened_architecture = spec.get_flattened_architecture()
sims = get_sims(
    spec,
    rank_variable_to_size,
    flattened_architecture,
    pkl_cache="sims.pkl",
)
# The return value of join_sims is not captured here.
join_sims(sims, spec, flattened_architecture)

# pr.disable()
# s = io.StringIO()
# ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
# ps.print_stats(30)  # Print top 30 time-consuming functions
# print(s.getvalue())

# TODO: Check for ranks not in the mapping and put them at the bottom
# TODO: What if there are no loops? 
# TODO: Set _must_exist for all backing storage nodes
# TODO: Constraint attacher
# TODO: Can't have tile size constraints on backing memory
# TODO: Einsum orders
# TODO: Copy Einsums
# I'm doing the tile shape exploration now and I'm trying to understand this note. I think I understand what you're saying.
# Can I ask one thing from the constraint code? If the constraint is an equality, then just set the tile_shape attribute of the node (or factor or whatever is needed) to the value.
# The tile shape exploration assumes a particular mapspace (in most cases, tile shapes are factors of the full rank shape), so an equality may never be satisfied. E.g., if the constraint sets the tile shape equal to a non-factor value because you want a particular imperfect factorization, but that's never in the mapspace, then you'll get nothing.
# It's also a bit more efficient to just set the value and the explorer doesn't have to figure out the equality by trial-and-error. For other more complicated constraints, trial-and-error is better.

