In [1]:
import sys
import math
import json
import argparse
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
import yaml

In [2]:
from adapters.huggingface.llama import (
    LlamaAdapter,
    LlamaGatingConfig,
    LlamaExportPolicy,
    LatencyProxyLLM
)
from data.llms import build_llm_dataloaders_from_cfg

In [3]:
from core.profiler import measure_latency_text_ms
from core.train import LagrangeTrainer, TrainerConfig, DualConfig
from core.distill import KDConfig
from core.gates import PenaltyWeights, Constraints
from core.export import Rounding as CoreRounding


In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer


In [5]:
device = "cuda"
base_id = "meta-llama/Llama-3.2-1B"



In [6]:
# export_policy_final = LlamaExportPolicy(
#     warmup_steps=0,
#     head_rounding=CoreRounding(
#         floor_groups=4,
#         multiple_groups=8,
#         min_keep_ratio=0.5,
#     ),
#     q_rounding=CoreRounding(
#         floor_groups=4,
#         multiple_groups=8,
#         min_keep_ratio=0.5,
#     ),        
#     ffn_rounding=CoreRounding(
#         floor_groups=4,
#         multiple_groups=8,
#         min_keep_ratio=0.5,
#     ),
# )


export_policy_final = LlamaExportPolicy(
    warmup_steps=0,
    head_rounding=CoreRounding(
        floor_groups=1,
        multiple_groups=1,
        min_keep_ratio=0.0,
    ),
    q_rounding=CoreRounding(
        floor_groups=1,
        multiple_groups=1,
        min_keep_ratio=0.0,
    ),        
    ffn_rounding=CoreRounding(
        floor_groups=1,
        multiple_groups=1,
        min_keep_ratio=0.0,
    ),
)

In [7]:
from adapters.huggingface.llama import load_slim_llama, infer_slim_meta

# layer_meta = infer_slim_meta("runs/llama3p2_1b/slim.pt", output_json="runs/llama3p2_1b/slim_meta.json")

In [8]:
slim = load_slim_llama("runs/llama3p2_1b", base_id, device=device)


`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.45


missing: []
unexpected: []


In [9]:
base = AutoModelForCausalLM.from_pretrained(base_id).to(device).eval()


In [10]:
layer = 0  # pick a few layers

print("=== BASE ===")
print("q_proj.weight:", base.model.layers[layer].self_attn.q_proj.weight.shape)
print("o_proj.weight:", base.model.layers[layer].self_attn.o_proj.weight.shape)
print("up_proj.weight:", base.model.layers[layer].mlp.up_proj.weight.shape)
print("down_proj.weight:", base.model.layers[layer].mlp.down_proj.weight.shape)
print("num_heads:", base.model.layers[layer].self_attn.num_heads)

print("\n=== SLIM ===")
print("q_proj.weight:", slim.model.layers[layer].self_attn.q_proj.weight.shape)
print("o_proj.weight:", slim.model.layers[layer].self_attn.o_proj.weight.shape)
print("up_proj.weight:", slim.model.layers[layer].mlp.up_proj.weight.shape)
print("down_proj.weight:", slim.model.layers[layer].mlp.down_proj.weight.shape)
print("num_heads:", slim.model.layers[layer].self_attn.num_heads)

=== BASE ===
q_proj.weight: torch.Size([2048, 2048])
o_proj.weight: torch.Size([2048, 2048])
up_proj.weight: torch.Size([8192, 2048])
down_proj.weight: torch.Size([2048, 8192])
num_heads: 32

=== SLIM ===
q_proj.weight: torch.Size([1536, 2048])
o_proj.weight: torch.Size([2048, 1536])
up_proj.weight: torch.Size([4096, 2048])
down_proj.weight: torch.Size([2048, 4096])
num_heads: 24


In [11]:
B = 1
S = 512
decode_T = 128


# Warm-up (helps get rid of startup noise)
_ = measure_latency_text_ms(base, B=B, S=S, T=8, device=device)
_ = measure_latency_text_ms(slim, B=B, S=S, T=8, device=device)

base_ms, base_p95, base_std = measure_latency_text_ms(
    base, B=B, S=S, T=decode_T, device=device
)
print(f"[BASE] mean={base_ms:.3f} ms, p95={base_p95:.3f} ms, std={base_std:.3f}")

slim_ms, slim_p95, slim_std = measure_latency_text_ms(
    slim, B=B, S=S, T=decode_T, device=device
)
print(f"[SLIM] mean={slim_ms:.3f} ms, p95={slim_p95:.3f} ms, std={slim_std:.3f}")
print(f"Speedup (mean): {100*(base_ms-slim_ms) / base_ms:.3f}%")
print(f"Speedup (p95):  {100*(base_p95-slim_p95) / base_p95:.3f}%")

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


[BASE] mean=1253.050 ms, p95=1271.636 ms, std=21.959
[SLIM] mean=1149.150 ms, p95=1191.154 ms, std=28.820
Speedup (mean): 8.292%
Speedup (p95):  6.329%


In [13]:
import torch
from time import perf_counter

device = "cuda"
B, T, H = 4, 512, 2048

x = torch.randn(B, T, H, device=device, dtype=torch.float16)

dense_layer = base.model.layers[0].to(device).eval()
slim_layer  = slim.model.layers[0].to(device).eval()

# HF normally does this inside LlamaModel
position_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)

@torch.inference_mode()
def bench(layer, iters=50):
    # warmup
    for _ in range(10):
        _ = layer(
            hidden_states=x,
            attention_mask=None,         # fine for microbench; SDPA uses causal
            position_ids=position_ids,   # this is what was missing
        )[0]  # first element = hidden_states
    torch.cuda.synchronize()

    start = perf_counter()
    for _ in range(iters):
        _ = layer(
            hidden_states=x,
            attention_mask=None,
            position_ids=position_ids,
        )[0]
    torch.cuda.synchronize()
    return (perf_counter() - start) * 1000 / iters

t_dense = bench(dense_layer)
t_slim  = bench(slim_layer)

print(f"dense layer: {t_dense:.3f} ms")
print(f" slim layer: {t_slim:.3f} ms")
print(f" per-layer speedup: {(t_dense - t_slim)/t_dense*100:.1f}%")


The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be removed and `position_embeddings` will be mandatory.


dense layer: 7.051 ms
 slim layer: 4.004 ms
 per-layer speedup: 43.2%


In [15]:
for decode_T in [8, 32, 128, 256]:
    base_mean, base_p95, base_std = measure_latency_text_ms(base, B=B, S=S, T=decode_T, device=device)
    slim_mean, slim_p95, slim_std = measure_latency_text_ms(slim, B=B, S=S, T=decode_T, device=device)
    print(f"T={decode_T}: speedup_mean={ (base_mean - slim_mean)/base_mean*100:.1f}%")


T=8: speedup_mean=31.4%
T=32: speedup_mean=17.9%
T=128: speedup_mean=11.5%
T=256: speedup_mean=9.7%
