In [1]:
import os
from math import log10

import fastfusion as ff

# Request 32 parallel jobs from fastfusion for everything that runs below.
# NOTE(review): 32 is hard-coded; presumably this is a worker count, so it may
# need adjusting to the core count of the machine running the notebook.
ff.set_n_parallel_jobs(32)

def get_runs(
    arch: str,
    workload: str,
    jinja_parse_data: dict,
):
    """Load a specification and generate pmappings under five counting options.

    Args:
        arch: Path of the architecture YAML file handed to
            ``ff.Specification.from_yaml``.
        workload: Path of the workload YAML file handed to
            ``ff.Specification.from_yaml``.
        jinja_parse_data: Forwarded unchanged as ``jinja_parse_data`` to
            ``ff.Specification.from_yaml``.

    Returns:
        The tuple ``(spec, normal, total, no_redundant_dataplacements,
        no_non_helpful_tile_shapes, no_non_helpful_loops_for_loop_orders)``.
        ``spec`` is the (mutated) specification; every other element is the
        result of one ``ff.mapper.FFM.make_pmappings`` call. The tuple order
        matches the positional parameters of ``get_reduction_per_piece``.
    """
    spec = ff.Specification.from_yaml(arch, workload, jinja_parse_data=jinja_parse_data)
    spec.arch["ArrayDummy"].constraints.spatial["reuse_input"].min_utilization = 0
    spec.arch["ArrayDummy"].constraints.spatial["reuse_output"].min_utilization = 0
    spec.arch["MainMemory"].constraints.tensors.keep = "All"
    spec.arch["GlobalBuffer"].constraints.tensors.keep = "output | input | ~MainMemory"
    spec.arch["GlobalBuffer"].constraints.tensors.may_keep = "weight | ~MainMemory"
    if "tpu_v4i" in arch:
        # spec.arch["LocalBuffer"].constraints.spatial.append(ff.constraints.Spatial(name="Z", min_utilization=1))
        spec.arch["LocalBuffer"].constraints.tensors.keep = "input | output"

    spec.mapper.ffm.metrics = ff.Metrics.ENERGY | ff.Metrics.LATENCY

    print(spec.workload.shape)

    def run_mapper(spec: ff.Specification, count_option: tuple[str, ...]):
        """Set the counting option on ``spec`` (in place) and generate pmappings."""
        # NOTE(review): the attribute name is spelled "mapsapce" here. Confirm it
        # matches the attribute fastfusion actually reads; if it does not, this
        # assignment may have no effect and all five runs would be identical.
        spec.mapper.ffm._count_option_for_mapsapce_size_evaluation = count_option
        return ff.mapper.FFM.make_pmappings(spec, cache_dir="/tmp/ff_cache")

    # Each run removes one more name from the counting option. The consumer of
    # these results uses `total` as the starting (largest) mapspace and `normal`
    # as the final one. NOTE(review): presumably each name keeps one pruning
    # category in the count -- confirm against fastfusion.
    normal = run_mapper(spec, ())
    total = run_mapper(
        spec,
        (
            "redundant_dataplacements",
            "non_helpful_loops_for_loop_orders",
            "non_helpful_tile_shapes",
            "redundant_loop_orders",
        ),
    )
    no_redundant_dataplacements = run_mapper(
        spec,
        (
            "non_helpful_loops_for_loop_orders",
            "non_helpful_tile_shapes",
            "redundant_loop_orders",
        ),
    )
    no_non_helpful_tile_shapes = run_mapper(
        spec,
        ("non_helpful_loops_for_loop_orders", "redundant_loop_orders"),
    )
    # The trailing comma matters: without it the parentheses are plain grouping
    # and a bare string would be passed instead of a one-element tuple like the
    # other runs.
    no_non_helpful_loops_for_loop_orders = run_mapper(spec, ("redundant_loop_orders",))
    return (
        spec,
        normal,
        total,
        no_redundant_dataplacements,
        no_non_helpful_tile_shapes,
        no_non_helpful_loops_for_loop_orders,
    )

def get_reduction_per_piece(
    spec: ff.Specification,
    normal: ff.mapper.FFM.MultiEinsumPmappings,
    total: ff.mapper.FFM.MultiEinsumPmappings,
    no_redundant_dataplacements: ff.mapper.FFM.MultiEinsumPmappings,
    no_non_helpful_tile_shapes: ff.mapper.FFM.MultiEinsumPmappings,
    no_non_helpful_loops_for_loop_orders: ff.mapper.FFM.MultiEinsumPmappings,
):
    """Break the per-Einsum mapspace shrinkage into three plotted categories.

    The per-Einsum counts are read in a fixed sequence: ``total``,
    ``no_redundant_dataplacements``, ``no_non_helpful_tile_shapes`` and
    ``normal`` (all via ``total_pmappings``), then ``normal`` again via
    ``evaluated_pmappings``. For each consecutive pair the shrink factor is
    ``log10(max(before / after, 1))``, so a step never contributes a negative
    amount. The four step dictionaries are printed before being combined.

    Args:
        spec: Specification whose ``workload.einsums`` selects the Einsums to
            report; only those with more than two tensor accesses are kept.
        normal, total, no_redundant_dataplacements, no_non_helpful_tile_shapes:
            Objects providing ``total_pmappings(per_einsum=True)`` (and, for
            ``normal``, ``evaluated_pmappings(per_einsum=True)``).
        no_non_helpful_loops_for_loop_orders: Accepted so the output of
            ``get_runs`` can be splatted in; it is not read here.

    Returns:
        ``{einsum_name: {category: log10_reduction}}`` with the categories in
        stacking order.
    """
    # Mapspace sizes per Einsum, in the order the reduction steps are applied.
    stage_sizes = [
        total.total_pmappings(per_einsum=True),
        no_redundant_dataplacements.total_pmappings(per_einsum=True),
        no_non_helpful_tile_shapes.total_pmappings(per_einsum=True),
        normal.total_pmappings(per_einsum=True),
        normal.evaluated_pmappings(per_einsum=True),
    ]
    # Final pseudo-stage of size one; its step is only used to decide which
    # Einsums are reported.
    stage_sizes.append({name: 1 for name in stage_sizes[0].keys()})

    def log_shrink(before, after):
        """Return log10 of the before/after size ratio per Einsum, floored at 0."""
        shrink = {}
        for name, size in before.items():
            shrink[name] = log10(max(size / after[name], 1))
        return shrink

    (
        dataplacement_step,
        tile_shape_step,
        loop_order_step,
        partial_tile_shape_step,
        remaining,
    ) = [
        log_shrink(before, after)
        for before, after in zip(stage_sizes, stage_sizes[1:])
    ]

    for step in (
        dataplacement_step,
        tile_shape_step,
        loop_order_step,
        partial_tile_shape_step,
    ):
        print(step)

    # The dataplacement step is not reported as its own category; it is added
    # onto the tile-shape step.
    tile_shape_with_dataplacement = {
        name: amount + dataplacement_step[name]
        for name, amount in tile_shape_step.items()
    }

    breakdown = {}
    for einsum in spec.workload.einsums:
        if len(einsum.tensor_accesses) <= 2:
            continue
        name = einsum.name
        if name not in remaining:
            continue
        breakdown[name] = {
            "Dataplacement → Tile Shape Pruning": tile_shape_with_dataplacement[name],
            "Dataplacement → Dataflow Pruning": loop_order_step[name],
            "Partial Tile Shape Pruning": partial_tile_shape_step[name],
        }
    return breakdown


In [None]:
import matplotlib.pyplot as plt
from format_plot import format_plot

def make_stacked_bar_chart(
    results: dict[str, dict[str, float]],
    ax: plt.Axes
):
    """Draw one stacked bar per entry of ``results`` onto ``ax``.

    Args:
        results: Maps each bar label to a ``{category: height}`` dict. The
            categories of the first entry fix the stacking order (bottom to
            top); every entry must provide all of them. An empty mapping
            draws no bars.
        ax: Axes to draw on. It is styled with ``format_plot`` after the bars
            are added.

    The y axis shows each tick value ``y`` as ``10^y``; x tick labels are the
    bar labels, rotated by 45 degrees.
    """
    labels = list(results.keys())
    # Default to no categories so an empty `results` does not raise StopIteration.
    categories = list(next(iter(results.values()), {}))
    bottom = [0] * len(labels)
    for cat in categories:
        heights = [results[label][cat] for label in labels]
        ax.bar(labels, heights, label=cat, bottom=bottom)
        # Running top of each stack; the next category is drawn from here.
        bottom = [b + h for b, h in zip(bottom, heights)]
    format_plot(ax)
    # Format ticks through a function instead of set_yticklabels(): fixed label
    # strings are assigned by position and are not updated when the ticks move,
    # e.g. when a caller changes the y limits afterwards. The function is
    # evaluated for whatever ticks end up being drawn.
    ax.yaxis.set_major_formatter(lambda y, _pos: f"$10^{{{y:g}}}$")
    ax.set_xticklabels(labels, rotation=45, ha="right")


# Mapspace-size reductions for the two (architecture, workload) pairs that the
# figure cells plot. Each get_runs call performs five mapper runs.
reductions_gpt = get_reduction_per_piece(
    *get_runs(
        arch="../../examples/arches/tpu_v4i_like.arch.yaml",
        workload="../../examples/workloads/gpt3_6.7B.workload.yaml",
        jinja_parse_data={"BATCH_SIZE": 64, "N_TOKENS": 65536},
    )
)

# Disabled: batch-size-1 variant of the GPT-3 run.
# reductions_gpt_b1 = get_reduction_per_piece(*get_runs(
#     "../../examples/arches/tpu_v4i_like.arch.yaml",
#     "../../examples/workloads/gpt3_6.7B.workload.yaml",
#     dict(BATCH_SIZE=1, N_TOKENS=65536),
# ))

reductions_mobilenet = get_reduction_per_piece(
    *get_runs(
        arch="../../examples/arches/nvdla_like.arch.yaml",
        workload="../../examples/workloads/mobilenet_28.workload.yaml",
        jinja_parse_data={"BATCH_SIZE": 64},
    )
)

# Disabled: batch-size-1 variant of the MobileNet run.
# reductions_mobilenet_b1 = get_reduction_per_piece(*get_runs(
#     "../../examples/arches/nvdla_like.arch.yaml",
#     "../../examples/workloads/mobilenet_28.workload.yaml",
#     dict(BATCH_SIZE=1),
# ))

# Disabled: one stand-alone figure per workload.
# fig, ax = plt.subplots(figsize=(20, 10))
# make_stacked_bar_chart(reductions_gpt, ax)
# ax.set_title(f"TPU-like, GPT-3 6.7B")
# ax.set_ylabel(f"Reduction in #Mappings")
# plt.show()

# fig, ax = plt.subplots(figsize=(20, 10))
# make_stacked_bar_chart(reductions_mobilenet, ax)
# ax.set_title(f"NVDLA-like, MobileNet-v3")
# ax.set_ylabel(f"Reduction in #Mappings")
# plt.show()


In [None]:
# Substring -> abbreviation, applied in this order to every string.
renames = {"Piecewise": "P-", "Depthwise": "D-"}


def rename(x):
    """Return a copy of ``x`` with every ``renames`` substring abbreviated.

    Strings get each replacement applied in turn. Dicts (keys and values) and
    lists are rebuilt recursively as plain ``dict``/``list`` objects. Any other
    value is returned unchanged.
    """
    if isinstance(x, str):
        shortened = x
        for long_form, short_form in renames.items():
            shortened = shortened.replace(long_form, short_form)
        return shortened

    if isinstance(x, dict):
        renamed = {}
        for key, value in x.items():
            new_key = rename(key)
            renamed[new_key] = rename(value)
        return renamed

    if isinstance(x, list):
        return list(map(rename, x))

    return x

# Two-panel figure: one stacked bar chart per workload, with a shared legend
# above both panels.
panels = [
    ("GPT-3 6.7B Einsums", reductions_gpt),
    ("MobileNet-v3 Einsums", reductions_mobilenet),
]

fig, axs = plt.subplots(figsize=(20, 10), ncols=2)
for column, (ax, (xlabel, reductions)) in enumerate(zip(axs, panels)):
    make_stacked_bar_chart(rename(reductions), ax)
    ax.set_xlabel(xlabel)
    # Only the left-most panel carries the y-axis label.
    ax.set_ylabel("Reduction in #Mappings" if column == 0 else "")
    # Drop the per-panel legend if one exists; a single figure legend is
    # added after the loop.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="center")
    # Same y range on both panels so their bar heights are comparable.
    ax.set_ylim(0, 31)

# `ax` is the last panel here; the bar handles are taken from it.
handles, labels = ax.get_legend_handles_labels()
plt.figlegend(handles, labels, loc = 'upper center', ncol=3, labelspacing=0.0, fontsize=20)

# Create the output directory so savefig does not fail on a fresh checkout.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/mapspace_size_reduction.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Substring -> abbreviation table used by rename(); applied in insertion order.
renames = {"Piecewise": "P-", "Depthwise": "D-"}


def rename(x):
    """Abbreviate ``renames`` substrings throughout a nested str/dict/list value.

    Values that are not a ``str``, ``dict`` or ``list`` are returned as they
    are. Containers are rebuilt (dict keys included), never modified in place.
    """
    if not isinstance(x, (str, dict, list)):
        return x

    if isinstance(x, list):
        items = []
        for element in x:
            items.append(rename(element))
        return items

    if isinstance(x, dict):
        return dict((rename(key), rename(value)) for key, value in x.items())

    text = x
    for old, new in renames.items():
        text = text.replace(old, new)
    return text

# Four-panel figure: batch 64 and batch 1 for each workload.
#
# The "(Batch 1)" panels must show batch-size-1 results. Compute them here on
# first use, so they are never a second copy of the batch-64 data. This is
# skipped when the variables already exist in the session.
if "reductions_gpt_b1" not in globals():
    reductions_gpt_b1 = get_reduction_per_piece(*get_runs(
        "../../examples/arches/tpu_v4i_like.arch.yaml",
        "../../examples/workloads/gpt3_6.7B.workload.yaml",
        dict(BATCH_SIZE=1, N_TOKENS=65536),
    ))
if "reductions_mobilenet_b1" not in globals():
    reductions_mobilenet_b1 = get_reduction_per_piece(*get_runs(
        "../../examples/arches/nvdla_like.arch.yaml",
        "../../examples/workloads/mobilenet_28.workload.yaml",
        dict(BATCH_SIZE=1),
    ))

batch_panels = [
    ("GPT-3 6.7B Einsums (Batch 64)", reductions_gpt),
    ("GPT-3 6.7B Einsums (Batch 1)", reductions_gpt_b1),
    ("MobileNet-v3 Einsums (Batch 64)", reductions_mobilenet),
    ("MobileNet-v3 Einsums (Batch 1)", reductions_mobilenet_b1),
]

fig, axs = plt.subplots(figsize=(40, 10), ncols=4)
for column, (ax, (xlabel, reductions)) in enumerate(zip(axs, batch_panels)):
    make_stacked_bar_chart(rename(reductions), ax)
    ax.set_xlabel(xlabel)
    # Only the left-most panel carries the y-axis label.
    ax.set_ylabel("Reduction in #Mappings" if column == 0 else "")
    # Drop the per-panel legend if one exists; a single figure legend is
    # added after the loop.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, ha="center")

# `ax` is the last panel here; the bar handles are taken from it.
handles, labels = ax.get_legend_handles_labels()
plt.figlegend(handles, labels, loc = 'upper center', ncol=3, labelspacing=0.0, fontsize=20)

# Written to its own file so it does not overwrite the two-panel figure saved
# as plots/mapspace_size_reduction.pdf.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/mapspace_size_reduction_by_batch.pdf", bbox_inches="tight")
plt.show()
