In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad

def _calc_keep_ratio_for_layer(f, layer_idx: int, K: int=0, min_keep_ratio: int = 0):
    #Note that K is the delay after how many layers pruning should start
    total_layers = 32
    layer_idx = layer_idx - K
    if layer_idx < 0:
        return 1, 1
    cumulative_keep_ratio = 1
    # layer_prune_ratio = (f(layer_idx + 1) - f(layer_idx)) / f(total_layers+1)# denominator normalizes values to (0,1)
    layer_prune_ratio = f(layer_idx) / f(total_layers)# denominator normalizes values to (0,1)
    layer_keep_ratio = 1 - layer_prune_ratio
    
    # calculate the overall percentage of tokens pruned
    for i in range(0,layer_idx+1):
        cumulative_keep_ratio *= layer_keep_ratio

    # Keep all remaining tokens if min_keep_ratio
    if cumulative_keep_ratio < min_keep_ratio:
        return min_keep_ratio, min_keep_ratio
    return layer_keep_ratio, cumulative_keep_ratio

# Define normalized functions (all start at 0 and end at 32)
def linear(x): return x
def quadratic(x): return ((x) ** 2)  
def logarithmic(x): return np.log(x+1) if x >= 0 else 0  # Natural log

# Calculate the integral of each function over the range of layers
layers = np.arange(0,32)
total_layers = 32

normalized_functions = {
    # "Linear (f(x)=x)": lambda x: linear(x),
    # "Quadratic (f(x)=x²)": lambda x: quadratic(x),
    # "Logarithmic (f(x)=log(x+1))": lambda x: logarithmic(x),
    "Keep 0.5 at layer 2": lambda x: 0.5 if x == 2 else 0
}

plt.figure(figsize=(10, 6))

for name, func in normalized_functions.items():
    keep_ratios = []
    current_keep_ratio = 1.0  # Start with 100% kept

    for layer in layers:
        layer_keep_ratio, cumulative_keep_ratio = _calc_keep_ratio_for_layer(func, layer, K=2, min_keep_ratio=0.5)
        keep_ratios.append(cumulative_keep_ratio * 100)  # Convert to percentage

    plt.plot(layers, keep_ratios, label=name, linewidth=2)

plt.xlabel('Layer Index')
plt.ylabel('Percentage of Tokens Kept (%)')
plt.title('Cumulative Keep Ratio per Layer (Equal Area Under Curve)')
plt.legend()
plt.grid(True)
plt.ylim(0, 100)
plt.show()


In [13]:
# calculate flops

# K ... layer were pruning starts
# R ... prune ratio
# T ... num total layers --> 32
# n ... num tokens
# d ... hidden state embedding size --> 1152
# m ... intermediate size of FFN --> 18944
# n_hat ... (1-R)


def calc_flops_for_layer(n: int = 1, d: int = 1152, m: int = 18944):
    return (4 * n * d**2) + (2 * n**2 * d) + (2 * n * d * m)

layers = np.arange(0,32)


flops = {
    "Linear (f(x)=x)": 0,
    "Quadratic (f(x)=x²)": 0,
    "Logarithmic (f(x)=log(x+1))": 0,
}

for name, func in normalized_functions.items():
    n = 100
    print("====")
    print(name)
    print("====")
    for layer in layers:
        layer_keep_ratio, _ = _calc_keep_ratio_for_layer(func, layer, K=2, min_keep_ratio=0)
        n = n * layer_keep_ratio
        print(layer_keep_ratio)
        print(calc_flops_for_layer(n))
        flops[name] += calc_flops_for_layer(n)

for name, flops in flops.items():
    print(f"{name}: {"{:.2e}".format(flops)}")


====
Linear (f(x)=x)
====
1
4918579200
1
4918579200
0.9696969696969697
4768854320.661158
0.9696969696969697
4623706968.989672
0.9696969696969697
4482996020.035791
0.9696969696969697
4346584769.329749
0.9696969696969697
4214340790.3347845
0.9696969696969697
4086135796.732682
0.9696969696969697
3961845509.364752
0.9696969696969697
3841349527.6584034
0.9696969696969697
3724531205.376314
0.9696969696969697
3611277530.5317545
0.9696969696969697
3501479009.3198805
0.9696969696969697
3395029553.9208074
0.9696969696969697
3291826374.035963
0.9696969696969697
3191769872.024696
0.9696969696969697
3094763541.5133333
0.9696969696969697
3000713869.353869
0.9696969696969697
2909530240.814221
0.9696969696969697
2821124847.8865943
0.9696969696969697
2735412600.604821
0.9696969696969697
2652311041.265742
0.9696969696969697
2571740261.453705
0.9696969696969697
2493622821.7710557
0.9696969696969697
2417883674.181214
0.9696969696969697
2344450086.874379
0.9696969696969697
2273251571.5693398
0.969696969696

0.6920398436307741