# CODE THAT CAOUNTS PARAMETERS OF A MODEL AND EXAMPLARY COMPRESSION RATIOS

In [5]:
def llama3_param_count_v2(
    n_layers: int,
    d_model: int,
    n_heads: int,
    n_kv_heads: int,
    vocab_size: int,
    ffn_dim: int,
    datt:int,
    gated:bool = True
    # untied_lm_head: bool = False,  # if True, count a separate LM head (vocab*d_model)
    # include_bias: bool = True,     # whether to count biases in linear layers
):
    """
    More accurate parameter counter for LLaMA-3 style models:
      - Uses explicit ffn_dim
      - Supports GQA (grouped KV heads)
      - Counts optional biases
      - Handles tied or untied embeddings
    """
    # === Embeddings ===
    token_embed = vocab_size * d_model
    # lm_head = vocab_size * d_model if untied_lm_head else 0
    lm_head = vocab_size * d_model

    # === Per-layer ===
    # Q projection: d_model -> d_model
    q_w = d_model * datt
    # q_b = d_model if include_bias else 0

    # K/V projections: d_model -> kv_out_dim
    kv_out_dim = int(round(datt * (n_kv_heads / n_heads)))
    k_w = d_model * kv_out_dim
    v_w = d_model * kv_out_dim
    # k_b = kv_out_dim if include_bias else 0
    # v_b = kv_out_dim if include_bias else 0

    # Output projection: d_model -> d_model
    out_w = d_model * d_model
    # out_b = d_model if include_bias else 0

    # SwiGLU FFN
    w1 = d_model * ffn_dim
    wgate = d_model * ffn_dim
    w2 = ffn_dim * d_model
    # b_w1 = ffn_dim if include_bias else 0
    # b_wgate = ffn_dim if include_bias else 0
    # b_w2 = d_model if include_bias else 0
    if gated:
        ffn_total = w1 + wgate + w2
    else:
        ffn_total = w1 + w2
    # ffn_bias = b_w1 + b_wgate + b_w2

    # Norms (2 per layer + final)
    rms_per_layer = 2 * d_model
    rms_final = d_model

    # per_layer = q_w + q_b + k_w + k_b + v_w + v_b + out_w + out_b + ffn_total + ffn_bias + rms_per_layer
    per_layer = q_w + k_w + v_w + out_w + ffn_total + rms_per_layer
    total = token_embed + n_layers * per_layer + rms_final + lm_head
    return total


def from_llama_8B(
                d_model=4096,
                dff=14336,
                n_layers=32,
                datt=4096,
                n_heads=32,
                n_kv_heads=8,
                vocab_size=128256):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt)
    print(f"LLaMA-3 param count: {params:,}")
    return params

def from_llama_1B(
                d_model=2048,
                dff=8192,
                n_layers=16,
                datt=2048,
                n_heads=32,
                n_kv_heads=8,
                vocab_size=128256):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt)
    print(f"LLaMA-3 param count: {params:,}")
    return params

def from_llmr_800(
                d_model=1536,
                dff=6144,
                n_layers=24,
                datt=1536,
                n_heads=24,
                n_kv_heads=24,
                vocab_size=50257):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt, gated=False)
    print(f"LLaMA-3 param count: {params:,}")
    return params

def from_llmr_300(
                d_model=1024,
                dff=4096,
                n_layers=16,
                datt=1024,
                n_heads=16,
                n_kv_heads=16,
                vocab_size=50257):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt, gated=False)
    print(f"LLaMA-3 param count: {params:,}")
    return params

def from_llmr_360(
                d_model=960,
                dff=2560,
                n_layers=32,
                datt=960,
                n_heads=15,
                n_kv_heads=5,
                vocab_size=49152):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt, gated=True)
    print(f"LLaMA-3 param count: {params:,}")
    return params

def from_sllm_1700(
                d_model=2048,
                dff=8192,
                n_layers=24,
                datt=2048,
                n_heads=32,
                n_kv_heads=32,
                vocab_size=49152):
    params = llama3_param_count_v2(n_layers=n_layers,d_model=d_model,n_heads=n_heads,n_kv_heads=n_kv_heads,vocab_size=vocab_size,ffn_dim=dff, datt=datt, gated=True)
    print(f"LLaMA-3 param count: {params:,}")
    return params

from_llama_8B(), from_llama_1B(), from_llmr_800(), from_llmr_300(), from_llmr_360(), from_sllm_1700()


LLaMA-3 param count: 8,030,261,248
LLaMA-3 param count: 1,498,482,688
LLaMA-3 param count: 833,942,016
LLaMA-3 param count: 304,286,720
LLaMA-3 param count: 409,007,040
LLaMA-3 param count: 1,812,039,680


(8030261248, 1498482688, 833942016, 304286720, 409007040, 1812039680)

# RATIOS FOR PAPER MODELS

### LLAMA3 1B

In [6]:
# In Configs 

# dmodel: 1344 # 50
# dff: 5376

# dmodel: 960 # 30
# dff: 3840

# dmodel: 448 # 10
# dff: 1728

bdmodel = 2048 # llama3 1B
bdff = 8192 # llama3 1B

# 50% 0.525
ratio = 1/64*42

print(f"llama3 1B 50% compression ratio {1 - from_llama_1B(2048*ratio, 8192*ratio)/from_llama_1B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")

# 30% 0.324
ratio = 1/64*30

print(f"llama3 1B 70% compression ratio {1 - from_llama_1B(2048*ratio, 8192*ratio)/from_llama_1B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")

# 10% 0.119
ratio = 1/64*14

print(f"llama3 1B 90% compression ratio {1 - from_llama_1B(2048*ratio, 8192*ratio)/from_llama_1B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")

LLaMA-3 param count: 786,574,656.0
LLaMA-3 param count: 1,498,482,688
llama3 1B 50% compression ratio 0.4750859237153896
dmodel: 1344.0, dff: 5376.0

LLaMA-3 param count: 485,161,920.0
LLaMA-3 param count: 1,498,482,688
llama3 1B 70% compression ratio 0.6762312144910145
dmodel: 960.0, dff: 3840.0

LLaMA-3 param count: 178,698,688.0
LLaMA-3 param count: 1,498,482,688
llama3 1B 90% compression ratio 0.8807469119055982
dmodel: 448.0, dff: 1792.0



### LLAMA3 8B

In [7]:
bdmodel = 4096 # llama3 8B
bdff = 14336 # llama3 8B

# 3072/4096, 9216/14336

m_dmodel_r = 3072/bdmodel
m_bdff_r = 9216/bdff
print(f"Minitrin compression ratio: {1-from_llama_8B(3072, 9216)/from_llama_8B()}")
print(f"Minitron compression ratios, dmodel ratio: {m_dmodel_r}, dff ratio: {m_bdff_r}")
print(f"dmodel: {bdmodel*m_dmodel_r}, dff: {bdff*m_bdff_r}\n")


# 50%
ratio = 1/64*64
dmodel_r = m_dmodel_r * ratio
dff_r = m_bdff_r * ratio
print(f"llama3 8B 50% compression ratio: {1 - from_llama_8B(bdmodel*dmodel_r, bdff*dff_r)/from_llama_8B()}")
print(f"dmodel: {bdmodel*dmodel_r}, dff: {bdff*dff_r}\n")

# 70% # (2240, 7840) # 0.3564
ratio = 1/64*45
dmodel_r = m_dmodel_r * ratio
dff_r = m_bdff_r * ratio
print(f"llama3 8B 70% compression ratio: {1 - from_llama_8B(bdmodel*dmodel_r, bdff*dff_r)/from_llama_8B()}")
print(f"dmodel: {bdmodel*dmodel_r}, dff: {bdff*dff_r}\n")


# 90% # (1024, 3584) # 0.1058
ratio = 1/64*22
dmodel_r = m_dmodel_r * ratio
dff_r = m_bdff_r * ratio
print(f"llama3 8B 90% compression ratio: {1 - from_llama_8B(bdmodel*dmodel_r, bdff*dff_r)/from_llama_8B()}")
print(f"dmodel: {bdmodel*dmodel_r}, dff: {bdff*dff_r}\n")

LLaMA-3 param count: 4,412,083,200
LLaMA-3 param count: 8,030,261,248
Minitrin compression ratio: 0.45056791258206397
Minitron compression ratios, dmodel ratio: 0.75, dff ratio: 0.6428571428571429
dmodel: 3072.0, dff: 9216.0

LLaMA-3 param count: 4,412,083,200.0
LLaMA-3 param count: 8,030,261,248
llama3 8B 50% compression ratio: 0.45056791258206397
dmodel: 3072.0, dff: 9216.0

LLaMA-3 param count: 2,471,871,600.0
LLaMA-3 param count: 8,030,261,248
llama3 8B 70% compression ratio: 0.6921804255601722
dmodel: 2160.0, dff: 6480.0

LLaMA-3 param count: 835,406,880.0
LLaMA-3 param count: 8,030,261,248
llama3 8B 90% compression ratio: 0.8959676585605401
dmodel: 1056.0, dff: 3168.0000000000005



# OTHER MODELS RATIO - NOT IN PAPER

In [8]:
# Conjoinded ratio cdmodel = cdff !!! - DONT USE IN PAPER EXPS - USE MINITRON RATIO ABOVE
bdmodel = 4096 # llama3 8B
bdff = 14336 # llama3 8B

# 50%
ratio = 1/64*43
print(f"llama3 1B 50% compression ratio {1 - from_llama_8B(bdmodel*ratio, bdff*ratio)/from_llama_8B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")

# 70% # (2240, 7840) # 0.3564
ratio = 1/64*32
print(f"llama3 1B 70% compression ratio {1 - from_llama_8B(bdmodel*ratio, bdff*ratio)/from_llama_8B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")


# 90% # (1024, 3584) # 0.1058
ratio = 1/64*16
print(f"llama3 1B 90% compression ratio {1 - from_llama_8B(bdmodel*ratio, bdff*ratio)/from_llama_8B()}")
print(f"dmodel: {bdmodel*ratio}, dff: {bdff*ratio}\n")

LLaMA-3 param count: 4,034,214,592.0
LLaMA-3 param count: 8,030,261,248
llama3 1B 50% compression ratio 0.49762349350654655
dmodel: 2752.0, dff: 9632.0

LLaMA-3 param count: 2,471,626,752.0
LLaMA-3 param count: 8,030,261,248
llama3 1B 70% compression ratio 0.6922109162244781
dmodel: 2048.0, dff: 7168.0

LLaMA-3 param count: 849,937,408.0
LLaMA-3 param count: 8,030,261,248
llama3 1B 90% compression ratio 0.8941581871683585
dmodel: 1024.0, dff: 3584.0



In [9]:
bdmodel = 960
bdff = 2560


# 50%
ratio = 1/32*21
# 35%
ratio = 1/32*16
# 20%
ratio = 1/32*11

print(from_llmr_360(bdmodel*ratio, bdff*ratio)/from_llmr_360())
bdmodel*ratio, bdff*ratio

LLaMA-3 param count: 80,720,970.0
LLaMA-3 param count: 409,007,040
0.1973583877676042


(330.0, 880.0)

In [10]:
bdmodel = 1536
bdff = 6144


# 10%
ratio = 1/32*11
print(from_llmr_800(bdmodel*ratio, bdff*ratio)/from_llmr_800())
bdmodel*ratio, bdff*ratio

LLaMA-3 param count: 171,707,184.0
LLaMA-3 param count: 833,942,016
0.20589822878045277


(528.0, 2112.0)

In [11]:
bdmodel = 1024
bdff = 4096

# 10%
ratio = 1/32*15
print(from_llmr_300(bdmodel*ratio, bdff*ratio)/from_llmr_300())
bdmodel*ratio, bdff*ratio

LLaMA-3 param count: 105,033,120.0
LLaMA-3 param count: 304,286,720
0.34517812673520554


(480.0, 1920.0)

In [12]:
bdmodel = 2048
bdff = 8192


# 50% ~0.524
ratio = 1/64*43

# 30% ~0.324
# ratio = 1/64*32

# 10% ~0.119
# ratio = 1/64*16

print(from_sllm_1700(bdmodel*ratio, bdff*ratio)/from_sllm_1700())
bdmodel*ratio, bdff*ratio

LLaMA-3 param count: 928,966,496.0
LLaMA-3 param count: 1,812,039,680
0.512663440129523


(1376.0, 5504.0)

In [13]:
# 50%
ratio = 1/32*8

ratio_ff = ratio #1/64*41
print(from_llama_8B(4096*ratio, 14336*ratio_ff)/from_llama_8B())

4096*ratio, 14336*ratio_ff


LLaMA-3 param count: 849,937,408.0
LLaMA-3 param count: 8,030,261,248
0.10584181283164151


(1024.0, 3584.0)