In [None]:
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

# Show all df rows (no truncation): the per-op tables displayed below are printed in full.
pd.set_option('display.max_rows', None)

In [None]:
# Directory holding the exported profiling CSVs -- set this before running.
data_root = "insert your data root here"

# Per-module aggregate timings.
module_data = pd.read_csv(f'{data_root}/calls_merged_per_module.csv')

# Per-op timing files, one per (module, iteration). The space before the second
# underscore appears to be part of the exported file names (the parsing regex
# below expects it too). Sort so the concatenation order -- and therefore the
# displayed row order -- does not depend on the filesystem's listing order.
op_data_files = sorted(glob.glob(f"{data_root}/calls_* _*.csv"))
if not op_data_files:
    # Fail here with a clear message instead of pandas' "No objects to concatenate" later.
    raise FileNotFoundError(f"No per-op files matching 'calls_* _*.csv' found in {data_root!r}")


In [None]:
# Throughput = Count per millisecond ('Duration (us)' is in microseconds; / 10**3 gives ms).
module_data['Throughput'] = module_data['Count'].div(module_data['Duration (us)'].div(10**3))

In [None]:
# Inspect the per-module table, including the derived Throughput column.
module_data

In [None]:
# Per-op file names look like "calls_<module> _<iteration>.csv". The dot is
# escaped and the pattern is anchored at the end, so only a literal ".csv"
# suffix matches.
module_regex = r'calls_(\w+) _(\d+)\.csv$'

all_dfs = []
for file in op_data_files:
    match = re.search(module_regex, file)
    if match is None:
        # The glob is looser than this regex (e.g. non-word characters in the
        # module part); report the offending file instead of an AttributeError.
        raise ValueError(
            f"Unexpected per-op file name (expected 'calls_<module> _<iteration>.csv'): {file}"
        )
    op_data = pd.read_csv(file)
    # Tag every row with the iteration number parsed from the file name.
    op_data['iteration'] = int(match.group(2))
    all_dfs.append(op_data)

# ignore_index avoids duplicate row labels coming from the per-file indices.
op_dfs = pd.concat(all_dfs, ignore_index=True)
display(op_dfs)

In [None]:
# Total time spent per device, plus each device's share of the overall total.
# NOTE: despite its name, prct_spent is a fraction of the total (the column sums to 1),
# not a percentage.
per_device = (
    op_dfs.groupby("Device")[["Duration (us)"]]
    .sum()
    .assign(prct_spent=lambda totals: totals["Duration (us)"] / totals["Duration (us)"].sum())
)
per_device

In [None]:
# Average each op's duration and call count over the profiled iterations.
# The mean of the `iteration` column is presumably kept as a sanity check: if every
# op appears in every iteration, all ops share the same mean (printed below).
per_op_grouped = (
    op_dfs
    .groupby(["Module", "Name"])[["Duration (us)", "Count", "iteration"]]
    .mean()
)

# Normalise duration within each module, so every module's ops sum to 1.
per_op_df = per_op_grouped.reset_index()
per_op_df["Duration norm"] = (
    per_op_df.groupby("Module")["Duration (us)"]
    .transform(lambda module_durations: module_durations / module_durations.sum())
)
print(per_op_df["iteration"].unique())
display(per_op_df)

def mapping(name):
    """Return the display/grouping label for a profiled op name.

    - Collapses repeated ``fused_`` nesting (``fused_fused_...`` -> ``fused_...``),
      so nested fusions are grouped with their singly-fused counterpart.
    - Inserts a line break before ``_NT_matmul`` so long matmul names wrap in the
      pie-chart legend.
    - Splits the long paged-attention builtin name over three lines.

    Any other name is returned unchanged.
    """
    # Repeat until stable so deeper nesting (fused_fused_fused_...) also collapses.
    # Stop as soon as a replace makes no progress (e.g. "fused_fusedX").
    while name.startswith("fused_fused"):
        collapsed = name.replace("fused_fused_", "fused_")
        if collapsed == name:
            break
        name = collapsed
    if "_NT_matmul" in name:
        return name.replace("_NT_matmul", "\n_NT_matmul")
    if name == "vm.builtin.paged_attention_kv_cache_attention_with_fused_qkv":
        return "vm.builtin.paged_attention\n_kv_cache_attention\n_with_fused_qkv"
    return name

per_op_df["Grouped Name"] = per_op_df["Name"].apply(mapping)

# Merge ops that share a display label.
# - Numeric columns are summed, so "Duration norm" stays each label's share of its module.
# - The original op names are joined with ", " so the merged ops stay readable; a plain
#   .sum() would string-concatenate them with no separator.
# - "iteration" is summed only to keep the same columns as before; its sum has no
#   particular meaning.
per_op_df = per_op_df.groupby(["Module", "Grouped Name"]).agg({
    "Name": lambda names: ", ".join(names.astype(str)),
    "Duration (us)": "sum",
    "Count": "sum",
    "iteration": "sum",
    "Duration norm": "sum",
})
display(per_op_df)


# Slices at or below this share of a module's total time are merged into one "Other"
# slice. Otherwise autopct would compute percentages relative to the shown subset,
# overstating every label.
min_pie_share = 0.01
cmap = plt.cm.twilight
os.makedirs('./figures', exist_ok=True)

for group, module_ops in per_op_df.reset_index().groupby("Module"):
    is_major = module_ops["Duration norm"] > min_pie_share
    pie_df = module_ops.loc[is_major, ["Grouped Name", "Duration norm"]]
    other_share = module_ops.loc[~is_major, "Duration norm"].sum()
    if other_share > 0:
        pie_df = pd.concat(
            [pie_df, pd.DataFrame({"Grouped Name": ["Other"], "Duration norm": [other_share]})],
            ignore_index=True,
        )
    if pie_df.empty:
        continue

    # One extra colour is sampled and left unused: twilight is cyclic, so its two
    # endpoints look alike and the first and last slices would otherwise share a colour.
    colors = cmap(np.linspace(0, 1, len(pie_df) + 1))
    ax = pie_df.set_index("Grouped Name").plot.pie(y="Duration norm", figsize=(4,3), legend=True, startangle=160,
                                                   ylabel='', labeldistance=None, autopct='%1.1f%%', pctdistance=0.77,
                                                   colors=colors, explode=[0.01] * len(pie_df), textprops={'color':"w"},
                                                   title=group)
    # With labeldistance=None, ax.texts holds only the percentage labels, in wedge order.
    # Switch to black text on light wedges where white would be unreadable.
    for wedge_color, pct_text in zip(colors, ax.texts):
        red, green, blue = wedge_color[:3]
        if 0.299 * red + 0.587 * green + 0.114 * blue > 0.5:
            pct_text.set_color("k")

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.6))

    plt.savefig(f'./figures/per_op_mlc_llama_{group.strip()}.pdf', bbox_inches='tight')