In [11]:
import pandas as pd
from IPython.display import display

### NCU Measurement Pre-Processing

In [12]:
# Raw NCU CSV exports, one per profiled run, paired with a human-readable run name.
files = ["results/toy_network_setup_flops_ncu.csv", "results/toy_network_forward_flops_ncu.csv", "results/toy_network_forward_backward_flops_ncu.csv"]
runs = ["setup", "setup_forward", "setup_forward_backward"]

# Columns written to each "<file>_processed.csv" that the analysis cells read back.
processed_columns = ["Simple Kernel", "Simple Metric", "Metric Value"]

for file, run in zip(files, runs):
    df = pd.read_csv(file)
    # Zero counts contribute no FLOPs; dropping them keeps the processed tables short.
    df = df[df["Metric Value"] != 0].copy()

    # Shorten the full kernel name to the first identifier ending in _nn/_nt/_tn/_kernel.
    # If the pattern does not match, keep the full name instead of NaN: a NaN key would be
    # silently dropped by a later groupby, losing that kernel's FLOPs from every total.
    simple_kernel = (
        df["Kernel Name"]
        .str.extract(r'([A-Za-z0-9_]+?_(?:nn|kernel|nt|tn))')
        .iloc[:, 0]
    )
    unmatched = simple_kernel.isna()
    if unmatched.any():
        print(f"Warning: {unmatched.sum()} row(s) in '{run}' have kernel names not matched by the simplification regex; keeping the full name.")
    df["Simple Kernel"] = simple_kernel.fillna(df["Kernel Name"])

    # Reduce the NCU metric name to its instruction class (fadd, fmul or ffma);
    # metrics that mention none of these become NaN.
    df["Simple Metric"] = (
        df["Metric Name"]
        .str.extract(r'_(fadd|fmul|ffma)_')
        .iloc[:, 0]
    )
    print(f"Processing kernels for '{run}' profile run...")
    display(df[processed_columns])
    df[processed_columns].to_csv(file.replace(".csv", "_processed.csv"), index=False)

Processing kernels for 'setup' profile run...


Unnamed: 0,Simple Kernel,Simple Metric,Metric Value
3,distribution_elementwise_grid_stride_kernel,fadd,872448.0
4,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
5,distribution_elementwise_grid_stride_kernel,fmul,6459392.0
12,distribution_elementwise_grid_stride_kernel,fadd,872448.0
13,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
14,distribution_elementwise_grid_stride_kernel,fmul,6459392.0
21,distribution_elementwise_grid_stride_kernel,fadd,872448.0
22,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
23,distribution_elementwise_grid_stride_kernel,fmul,6459392.0
30,distribution_elementwise_grid_stride_kernel,fadd,872448.0


Processing kernels for 'setup_forward' profile run...


Unnamed: 0,Simple Kernel,Simple Metric,Metric Value
3,distribution_elementwise_grid_stride_kernel,fadd,872448.0
4,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
5,distribution_elementwise_grid_stride_kernel,fmul,6459392.0
12,distribution_elementwise_grid_stride_kernel,fadd,872448.0
13,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
...,...,...,...
347,ampere_sgemm_64x32_sliced1x4_nn,fmul,1048576.0
354,splitKreduce_kernel,fadd,262144.0
356,splitKreduce_kernel,fmul,131072.0
363,vectorized_elementwise_kernel,fadd,262144.0


Processing kernels for 'setup_forward_backward' profile run...


Unnamed: 0,Simple Kernel,Simple Metric,Metric Value
3,distribution_elementwise_grid_stride_kernel,fadd,872448.0
4,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
5,distribution_elementwise_grid_stride_kernel,fmul,6459392.0
12,distribution_elementwise_grid_stride_kernel,fadd,872448.0
13,distribution_elementwise_grid_stride_kernel,ffma,14135296.0
...,...,...,...
707,ampere_sgemm_32x128_tn,fmul,1048576.0
714,vectorized_elementwise_kernel,fadd,131072.0
716,vectorized_elementwise_kernel,fmul,262144.0
724,ampere_sgemm_32x128_tn,ffma,134217728.0


## Forward

### NCU Measurements Forward

In [13]:
df_fwd = pd.read_csv("results/toy_network_forward_flops_ncu_processed.csv")

# drop all rows from setup: the forward profile re-runs the setup kernels first, and the
# processed CSVs are re-indexed 0..n-1, so the setup part is the first len(df_setup) rows.
df_setup = pd.read_csv("results/toy_network_setup_flops_ncu_processed.csv")
max_setup_index = len(df_setup) - 1  # -1 for an empty setup profile, i.e. drop nothing
key_columns = ["Simple Kernel", "Simple Metric"]
# Guard the positional filter: the forward profile must really start with the setup kernels.
assert (
    df_fwd[key_columns].iloc[: max_setup_index + 1].reset_index(drop=True).equals(df_setup[key_columns])
), "Forward profile does not start with the setup kernels; cannot strip the setup rows by position!"
df_fwd_filtered = df_fwd.iloc[max_setup_index + 1:].copy()

# dropna=False keeps rows whose kernel or metric name could not be simplified (NaN).
# Their FLOPs still count towards the totals, and uncategorized kernels trip the assert below.
df_grouped = df_fwd_filtered.groupby(key_columns, as_index=False, dropna=False)["Metric Value"].sum()
print(df_grouped)
print()

kernel_col = df_grouped["Simple Kernel"]
metric_col = df_grouped["Simple Metric"]
value_col = df_grouped["Metric Value"]

# Kernel families. na=False makes a NaN kernel name match no family instead of breaking the mask.
# reduce_kernel is grouped with the split-K reductions, as in the backward analysis.
is_gemm = kernel_col.str.contains("sgemm", na=False)
is_ksplit = kernel_col.str.contains("splitK|reduce_kernel", na=False)
is_activation = kernel_col.str.contains("elementwise", na=False)
is_ffma = metric_col == "ffma"
is_fadd = metric_col == "fadd"
is_fmul = metric_col == "fmul"

# FLOPs = 2 * ffma (fused multiply-add) + fadd + fmul.
ffma = value_col[is_ffma].sum()
fadd = value_col[is_fadd].sum()
fmul = value_col[is_fmul].sum()
ncu_total_flops = ffma * 2 + fadd + fmul

ffma_gemm = value_col[is_gemm & is_ffma].sum()
fadd_gemm = value_col[is_gemm & is_fadd].sum()
fmul_gemm = value_col[is_gemm & is_fmul].sum()
ncu_total_flops_gemm = ffma_gemm * 2 + fadd_gemm + fmul_gemm

ffma_ksplit = value_col[is_ksplit & is_ffma].sum()
fadd_ksplit = value_col[is_ksplit & is_fadd].sum()
fmul_ksplit = value_col[is_ksplit & is_fmul].sum()
ncu_total_flops_ksplit = ffma_ksplit * 2 + fadd_ksplit + fmul_ksplit

ffma_activation = value_col[is_activation & is_ffma].sum()
fadd_activation = value_col[is_activation & is_fadd].sum()
fmul_activation = value_col[is_activation & is_fmul].sum()
ncu_total_flops_activation = ffma_activation * 2 + fadd_activation + fmul_activation

# Every kernel must fall into exactly one family, otherwise FLOPs are lost or double counted.
assert ncu_total_flops == ncu_total_flops_gemm + ncu_total_flops_ksplit + ncu_total_flops_activation, "Total FLOPs do not match!"

print(f"Total FLOPs GEMM NCU: {ncu_total_flops_gemm:,}")
print(f"Total FLOPs KSplit NCU: {ncu_total_flops_ksplit:,}")
print(f"Total FLOPs Activation NCU: {ncu_total_flops_activation:,}")
print(f"Total FLOPs NCU: {ncu_total_flops:,}")

                     Simple Kernel Simple Metric  Metric Value
0  ampere_sgemm_64x32_sliced1x4_nn          fadd  7.864320e+06
1  ampere_sgemm_64x32_sliced1x4_nn          ffma  1.342177e+09
2  ampere_sgemm_64x32_sliced1x4_nn          fmul  1.048576e+07
3              splitKreduce_kernel          fadd  2.621440e+06
4              splitKreduce_kernel          fmul  1.310720e+06
5    vectorized_elementwise_kernel          fadd  2.621070e+06
6    vectorized_elementwise_kernel          ffma  9.174358e+06

Total FLOPs GEMM NCU: 2,702,704,640.0
Total FLOPs KSplit NCU: 3,932,160.0
Total FLOPs Activation NCU: 20,969,786.0
Total FLOPs NCU: 2,727,606,586.0


### Theoretical Calculations Forward

In [14]:
N = 10      # Number of layers
D = 1024    # Input/Output dimension
M = 128     # Number of tokens (sample length)

# Idealized forward cost, counting a multiply-add as 2 FLOPs:
#   per layer, one D-by-D GEMM over M tokens -> 2 * D * D * M FLOPs
#   per layer, one FLOP per activation value  -> D * M FLOPs
# Split-K reduction work is not part of the model (0 FLOPs).
gemm_flops_per_layer = 2 * D * D * M
activation_flops_per_layer = D * M

theoretical_gemm_flops = gemm_flops_per_layer * N
theoretical_ksplit_flops = 0 # not considered in theoretical model
theoretical_activation_flops = activation_flops_per_layer * N
theoretical_total_flops = theoretical_gemm_flops + theoretical_activation_flops

print(f"N: {N} D: {D} M: {M}")
for _label, _flops in [
    ("GEMM", theoretical_gemm_flops),
    ("KSplit", theoretical_ksplit_flops),
    ("Activation", theoretical_activation_flops),
    ("Total FLOPs (GEMM + KSplit + Activation)", theoretical_total_flops),
]:
    print(f"{_label}: {_flops:,}")

N: 10 D: 1024 M: 128
GEMM: 2,684,354,560
KSplit: 0
Activation: 1,310,720
Total FLOPs (GEMM + KSplit + Activation): 2,685,665,280


### Difference in Forward: NCU vs. Theoretical

In [15]:
# Signed differences NCU - theoretical, absolute and as a percentage of the theoretical value.
diff_gemm           = ncu_total_flops_gemm - theoretical_gemm_flops
diff_gemm_pct       = diff_gemm / theoretical_gemm_flops * 100
diff_ksplit         = ncu_total_flops_ksplit - theoretical_ksplit_flops
# A relative difference is undefined when the theoretical split-K count is 0 (as in the model).
# Report +/-inf% when work was measured anyway, and 0% when both sides are 0.
if theoretical_ksplit_flops != 0:
    diff_ksplit_pct = diff_ksplit / theoretical_ksplit_flops * 100
else:
    diff_ksplit_pct = 0.0 if diff_ksplit == 0 else (float('inf') if diff_ksplit > 0 else float('-inf'))
diff_activation     = ncu_total_flops_activation - theoretical_activation_flops
diff_activation_pct = diff_activation / theoretical_activation_flops * 100

print(f"GEMM % diff:            {diff_gemm_pct:.2f}%")
print(f"GEMM abs diff:          {diff_gemm:,}\n")
print(f"KSplit % diff:          {diff_ksplit_pct:.2f}%")
print(f"KSplit abs diff:        {diff_ksplit:,}\n")
print(f"Activation % diff:      {diff_activation_pct:.2f}%")
print(f"Activation abs diff:    {diff_activation:,}\n")
print(f"Summed total diff:      {diff_gemm + diff_ksplit + diff_activation:,}")
print(f"Actual total diff:      {ncu_total_flops - theoretical_total_flops:,}")

GEMM % diff:            0.68%
GEMM abs diff:          18,350,080.0

KSplit % diff:          inf%
KSplit abs diff:        3,932,160.0

Activation % diff:      1499.87%
Activation abs diff:    19,659,066.0

Summed total diff:      41,941,306.0
Actual total diff:      41,941,306.0


### NCU Closed-Form Reconstruction Forward

In [16]:
# Closed-form model of the measured forward FLOPs (N layers), per layer:
#   2 * (1 + epsilon_gemm) * D * D * M  -- GEMM FLOPs with relative overhead epsilon_gemm
#   (1 + kappa_fwd) * D * M             -- activation FLOPs with relative overhead kappa_fwd
#   rho * D * M                         -- split-K FLOPs. The theory has no split-K term, so rho
#                                          is an absolute FLOPs-per-(D*M) coefficient, not a ratio.
epsilon_gemm = ncu_total_flops_gemm / theoretical_gemm_flops - 1
rho = ncu_total_flops_ksplit / (N * D * M)
kappa_fwd = ncu_total_flops_activation / theoretical_activation_flops - 1

ncu_total_flops_reconstruct = N * (2 * (1 + epsilon_gemm) * D * D * M + (1 + kappa_fwd + rho) * D * M)

# The coefficients come from the same measurements, so the closed form must reproduce the
# NCU total up to floating-point rounding. A mismatch means the theoretical terms changed shape.
assert abs(ncu_total_flops_reconstruct - ncu_total_flops) <= 1e-9 * abs(ncu_total_flops), "Forward reconstruction does not match NCU total!"

print(f"epsilon_gemm: {epsilon_gemm:.4f}, rho: {rho:.4f}, kappa_fwd: {kappa_fwd:.4f}")
print(f"Total Theoretical FLOPs: {theoretical_total_flops:,}")
print(f"Total FLOPs NCU: {ncu_total_flops:,}")
print(f"Total Reconstructed NCU FLOPs: {ncu_total_flops_reconstruct:,}")
print(f"NCU Reconstruction vs NCU Total Diff: {ncu_total_flops_reconstruct - ncu_total_flops:,}")

epsilon_gemm: 0.0068, rho: 3.0000, kappa_fwd: 14.9987
Total Theoretical FLOPs: 2,685,665,280
Total FLOPs NCU: 2,727,606,586.0
Total Reconstructed NCU FLOPs: 2,727,606,586.0
NCU Reconstruction vs NCU Total Diff: 0.0


## Backward

### NCU Measurements Backward

In [17]:
df_bwd = pd.read_csv("results/toy_network_forward_backward_flops_ncu_processed.csv")

# drop all rows from setup and forward runs: the forward+backward profile re-runs the setup and
# forward kernels first. df_fwd, the processed forward profile loaded in the forward section,
# holds exactly that prefix, so it is the first len(df_fwd) rows.
max_fwd_index = len(df_fwd) - 1  # -1 for an empty forward profile, i.e. drop nothing
key_columns_bwd = ["Simple Kernel", "Simple Metric"]
# Guard the positional filter: the forward+backward profile must really start with the forward-run kernels.
assert (
    df_bwd[key_columns_bwd].iloc[: max_fwd_index + 1].reset_index(drop=True).equals(df_fwd[key_columns_bwd])
), "Forward+backward profile does not start with the forward-run kernels; cannot strip them by position!"
df_bwd_filtered = df_bwd.iloc[max_fwd_index + 1:].copy()

# dropna=False keeps rows whose kernel or metric name could not be simplified (NaN).
# Their FLOPs still count towards the totals, and uncategorized kernels trip the assert below.
df_bwd_grouped = df_bwd_filtered.groupby(key_columns_bwd, as_index=False, dropna=False)["Metric Value"].sum()
print(df_bwd_grouped)
print()

kernel_col_bwd = df_bwd_grouped["Simple Kernel"]
metric_col_bwd = df_bwd_grouped["Simple Metric"]
value_col_bwd = df_bwd_grouped["Metric Value"]

# Kernel families. na=False makes a NaN kernel name match no family instead of breaking the mask.
is_gemm_bwd = kernel_col_bwd.str.contains("sgemm", na=False)
is_ksplit_bwd = kernel_col_bwd.str.contains("splitK|reduce_kernel", na=False)
is_activation_bwd = kernel_col_bwd.str.contains("elementwise", na=False)
is_ffma_bwd = metric_col_bwd == "ffma"
is_fadd_bwd = metric_col_bwd == "fadd"
is_fmul_bwd = metric_col_bwd == "fmul"

# FLOPs = 2 * ffma (fused multiply-add) + fadd + fmul.
ffma_bwd = value_col_bwd[is_ffma_bwd].sum()
fadd_bwd = value_col_bwd[is_fadd_bwd].sum()
fmul_bwd = value_col_bwd[is_fmul_bwd].sum()
ncu_total_flops_bwd = ffma_bwd * 2 + fadd_bwd + fmul_bwd

ffma_gemm_bwd = value_col_bwd[is_gemm_bwd & is_ffma_bwd].sum()
fadd_gemm_bwd = value_col_bwd[is_gemm_bwd & is_fadd_bwd].sum()
fmul_gemm_bwd = value_col_bwd[is_gemm_bwd & is_fmul_bwd].sum()
ncu_total_flops_gemm_bwd = ffma_gemm_bwd * 2 + fadd_gemm_bwd + fmul_gemm_bwd

ffma_ksplit_bwd = value_col_bwd[is_ksplit_bwd & is_ffma_bwd].sum()
fadd_ksplit_bwd = value_col_bwd[is_ksplit_bwd & is_fadd_bwd].sum()
fmul_ksplit_bwd = value_col_bwd[is_ksplit_bwd & is_fmul_bwd].sum()
ncu_total_flops_ksplit_bwd = ffma_ksplit_bwd * 2 + fadd_ksplit_bwd + fmul_ksplit_bwd

ffma_activation_bwd = value_col_bwd[is_activation_bwd & is_ffma_bwd].sum()
fadd_activation_bwd = value_col_bwd[is_activation_bwd & is_fadd_bwd].sum()
fmul_activation_bwd = value_col_bwd[is_activation_bwd & is_fmul_bwd].sum()
ncu_total_flops_activation_bwd = ffma_activation_bwd * 2 + fadd_activation_bwd + fmul_activation_bwd

# Every kernel must fall into exactly one family, otherwise FLOPs are lost or double counted.
assert ncu_total_flops_bwd == ncu_total_flops_gemm_bwd + ncu_total_flops_ksplit_bwd + ncu_total_flops_activation_bwd, "Total FLOPs do not match for backward pass!"

print(f"Total FLOPs GEMM NCU (bwd): {ncu_total_flops_gemm_bwd:,}")
print(f"Total FLOPs KSplit NCU (bwd): {ncu_total_flops_ksplit_bwd:,}")
print(f"Total FLOPs Activation NCU (bwd): {ncu_total_flops_activation_bwd:,}")
print(f"Total FLOPs NCU (bwd): {ncu_total_flops_bwd:,}")

                         Simple Kernel Simple Metric  Metric Value
0               ampere_sgemm_32x128_tn          ffma  1.342177e+09
1               ampere_sgemm_32x128_tn          fmul  1.048576e+07
2   cutlass_80_simt_sgemm_64x64_8x5_nt          ffma  1.207960e+09
3   cutlass_80_simt_sgemm_64x64_8x5_nt          fmul  9.437184e+06
4                   elementwise_kernel          fadd  1.310720e+05
5                   elementwise_kernel          fmul  2.621440e+05
6                        reduce_kernel          fadd  2.073440e+05
7                  splitKreduce_kernel          fadd  9.437184e+06
8                  splitKreduce_kernel          fmul  1.179648e+06
9        vectorized_elementwise_kernel          fadd  1.179648e+06
10       vectorized_elementwise_kernel          fmul  2.359296e+06

Total FLOPs GEMM NCU (bwd): 5,120,196,608.0
Total FLOPs KSplit NCU (bwd): 10,824,176.0
Total FLOPs Activation NCU (bwd): 3,932,160.0
Total FLOPs NCU (bwd): 5,134,952,944.0


### Theoretical Calculations Backward

In [18]:
# Idealized backward cost, built from N, D, M of the forward theoretical cell.
# The model has (N-1) "full" layers plus one reduced layer (presumably the first layer,
# which needs no gradient w.r.t. its input -- confirm against the network definition):
#   GEMM:       4 * D*D*M per full layer, 2 * D*D*M for the reduced layer
#   Activation: 2 * D*M   per full layer, 1 * D*M   for the reduced layer
# Split-K reduction work is not part of the model (0 FLOPs).
theoretical_gemm_flops_bwd = (4 * (N - 1) + 2) * D * D * M
theoretical_ksplit_flops_bwd = 0 # not considered in the theoretical model
theoretical_activation_flops_bwd = (2 * (N - 1) + 1) * D * M
theoretical_total_flops_bwd = theoretical_activation_flops_bwd + theoretical_gemm_flops_bwd

print(f"N: {N} D: {D} M: {M}")
for _label, _flops in [
    ("GEMM (bwd)", theoretical_gemm_flops_bwd),
    ("KSplit (bwd)", theoretical_ksplit_flops_bwd),
    ("Activation (bwd)", theoretical_activation_flops_bwd),
    ("Total FLOPs (GEMM + KSplit + Activation) (bwd)", theoretical_total_flops_bwd),
]:
    print(f"{_label}: {_flops:,}")

N: 10 D: 1024 M: 128
GEMM (bwd): 5,100,273,664
KSplit (bwd): 0
Activation (bwd): 2,490,368
Total FLOPs (GEMM + KSplit + Activation) (bwd): 5,102,764,032


### Difference in Backward: NCU vs. Theoretical

In [19]:
# Signed differences NCU - theoretical for the backward pass, absolute and in percent.
diff_gemm_bwd           = ncu_total_flops_gemm_bwd - theoretical_gemm_flops_bwd
diff_gemm_bwd_pct       = diff_gemm_bwd / theoretical_gemm_flops_bwd * 100
diff_ksplit_bwd         = ncu_total_flops_ksplit_bwd - theoretical_ksplit_flops_bwd
# A relative difference is undefined when the theoretical split-K count is 0 (as in the model).
# Report +/-inf% when work was measured anyway, and 0% when both sides are 0.
if theoretical_ksplit_flops_bwd != 0:
    diff_ksplit_bwd_pct = diff_ksplit_bwd / theoretical_ksplit_flops_bwd * 100
else:
    diff_ksplit_bwd_pct = 0.0 if diff_ksplit_bwd == 0 else (float('inf') if diff_ksplit_bwd > 0 else float('-inf'))
diff_activation_bwd     = ncu_total_flops_activation_bwd - theoretical_activation_flops_bwd
diff_activation_bwd_pct = diff_activation_bwd / theoretical_activation_flops_bwd * 100

print(f"GEMM (bwd) % diff:            {diff_gemm_bwd_pct:.2f}%")
print(f"GEMM (bwd) abs diff:          {diff_gemm_bwd:,}\n")
print(f"KSplit (bwd) % diff:          {diff_ksplit_bwd_pct:.2f}%")
print(f"KSplit (bwd) abs diff:        {diff_ksplit_bwd:,}\n")
print(f"Activation (bwd) % diff:      {diff_activation_bwd_pct:.2f}%")
print(f"Activation (bwd) abs diff:    {diff_activation_bwd:,}\n")
print(f"Summed total diff (bwd):      {diff_gemm_bwd + diff_ksplit_bwd + diff_activation_bwd:,}")
print(f"Actual total diff (bwd):      {ncu_total_flops_bwd - theoretical_total_flops_bwd:,}")

GEMM (bwd) % diff:            0.39%
GEMM (bwd) abs diff:          19,922,944.0

KSplit (bwd) % diff:          inf%
KSplit (bwd) abs diff:        10,824,176.0

Activation (bwd) % diff:      57.89%
Activation (bwd) abs diff:    1,441,792.0

Summed total diff (bwd):      32,188,912.0
Actual total diff (bwd):      32,188,912.0


### NCU Closed-Form Reconstruction Backward

In [20]:
# Closed-form model of the measured backward FLOPs. It mirrors the theoretical backward model:
# (N-1) layers with 4*D*D*M GEMM and 2*D*M activation FLOPs, plus one layer with 2*D*D*M GEMM
# and 1*D*M activation FLOPs. The coefficients are:
#   epsilon_gemm_bwd -- relative GEMM overhead (applied as 1 + epsilon_gemm_bwd)
#   kappa_bwd        -- relative activation overhead (applied as 1 + kappa_bwd)
#   rho_bwd          -- split-K/reduction FLOPs per D*M unit. The theory has no split-K term
#                       (theoretical_ksplit_flops_bwd == 0), so this is an absolute coefficient,
#                       not an overhead ratio.
epsilon_gemm_bwd = ncu_total_flops_gemm_bwd / theoretical_gemm_flops_bwd - 1

# (2N - 1) D*M units in total: 2 per layer for N-1 layers plus 1 for the remaining layer.
rho_bwd = ncu_total_flops_ksplit_bwd / ((2 * N - 1) * D * M)

kappa_bwd = ncu_total_flops_activation_bwd / theoretical_activation_flops_bwd - 1

ncu_total_flops_reconstruct_bwd = (
    (N-1) * (4 * (1 + epsilon_gemm_bwd) * D * D * M + 2 * (1 + kappa_bwd + rho_bwd) * D * M)
    + 2 * (1 + epsilon_gemm_bwd) * D * D * M + (1 + kappa_bwd + rho_bwd) * D * M
)

# The coefficients come from the same measurements, so the closed form must reproduce the
# NCU total up to floating-point rounding. A mismatch means the theoretical terms changed shape.
assert abs(ncu_total_flops_reconstruct_bwd - ncu_total_flops_bwd) <= 1e-9 * abs(ncu_total_flops_bwd), "Backward reconstruction does not match NCU total!"

print(f"epsilon_gemm_bwd: {epsilon_gemm_bwd:.4f}, rho_bwd: {rho_bwd:.4f}, kappa_bwd: {kappa_bwd:.4f}")
print(f"Total Theoretical FLOPs (bwd): {theoretical_total_flops_bwd:,}")
print(f"Total FLOPs NCU (bwd): {ncu_total_flops_bwd:,}")
print(f"Total Reconstructed NCU FLOPs (bwd): {ncu_total_flops_reconstruct_bwd:,}")
print(f"NCU Reconstruction vs NCU Total Diff (bwd): {ncu_total_flops_reconstruct_bwd - ncu_total_flops_bwd:,}")

epsilon_gemm: 0.0039, rho: 4.3464, kappa_bwd: 0.5789
Total Theoretical FLOPs (bwd): 5,102,764,032
Total FLOPs NCU (bwd): 5,134,952,944.0
Total Reconstructed NCU FLOPs (bwd): 5,134,952,944.0
NCU Reconstruction vs NCU Total Diff (bwd): 0.0
