In [2]:
# Example: optimize Llama-3.2-1B for the RTX 4090

In [3]:
import os, sys, pathlib

# Make the repository root importable so `core`, `adapters` and `examples` resolve.
# Assumes the notebook runs from a directory one level below the repo root
# (the "../recipes/..." path used further down makes the same assumption).
# Guarded so that re-running this cell does not append duplicate entries.
REPO_ROOT = str(pathlib.Path.cwd().resolve().parent)
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

DEVICE = "cuda:0"

In [28]:
import gc
import os
from huggingface_hub import notebook_login

# Authenticate with the Hugging Face Hub before any model/checkpoint downloads.
# huggingface_hub reads the HF_TOKEN environment variable on its own, so the
# interactive widget is only shown when that variable is not set. The widget does
# not block: under "Run All", later cells can start before a token is entered.
if not os.environ.get("HF_TOKEN"):
    notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [12]:
# Hub repositories holding the pruned ("slim") and gated checkpoints for the
# target GPU. Set TARGET_GPU = "h100" to use the H100 variants instead.
# The recipe path passed to build_from_recipe is set separately and must be switched to match.
TARGET_GPU = "rtx4090"
slim_model_repo  = f"hawada/Llama-3.2-1B-{TARGET_GPU}-slim"
gated_model_repo = f"hawada/Llama-3.2-1B-{TARGET_GPU}-gated"

In [None]:
from core.profiler import measure_latency_text_ms 

from adapters.huggingface.llama import (
    LlamaAdapter,
    LlamaGatingConfig,
    LlamaExportPolicy,
    LatencyProxyLLM,
    infer_slim_meta,
    load_slim_for_finetune
)

# Script to build config from a recipe
from examples.run_llama_optimize import build_from_recipe, _ids_mask, _ProxyBridge

# Build everything the remaining cells read from `pack` (student/teacher models,
# data loaders, latency proxy, gating settings, probe/export policies) from the
# RTX 4090 recipe; this step also downloads the base model.
RECIPE_PATH = "../recipes/RTX4090/llama_3_2_1b.yaml"
pack = build_from_recipe(RECIPE_PATH)

In [30]:
# Fix: bind the student as `student` (it is used by name below and in later cells);
# `gated_model` is produced by attach_gates at the end of this cell.
student = pack["student"]  # Llama-3.2-1B that will be gated and pruned
teacher = pack["teacher"]  # another instance of Llama-3.2-1B

# HawAda adapter used to train the gates
adapter = LlamaAdapter(student)

# Gating hyper-parameters from the recipe, read through a single lookup so every
# field is handled the same way. Note that bool(None) is False, so a missing
# boolean key silently disables that option; a missing numeric key raises TypeError.
gating = pack.get("gating", {})
gate_cfg = LlamaGatingConfig(
    tau=float(gating.get("tau")),                 # target latency, relative to the base latency
    init_logit=float(gating.get("init_logit")),   # initial logit of the gate sigmoids
    head_gating=bool(gating.get("head_gating")),  # whether to gate attention heads
    gate_kv=bool(gating.get("gate_kv")),          # whether to gate KV layers
    ffn_group=int(gating.get("ffn_group")),       # feed-forward network group size
    ffn_gating=bool(gating.get("ffn_gating")),    # whether to gate FFN layers
    hard_eval=bool(gating.get("hard_eval")),
)

# Attach the gates to the student, put it in training mode and move it to the GPU
gated_model = adapter.attach_gates(gate_cfg).train().to(DEVICE)

In [31]:
import torch
from huggingface_hub import hf_hub_download

# Download the trained gated checkpoint and load it into the gated student.
ckpt_path = hf_hub_download(gated_model_repo, "pytorch_model.bin")
# Load on CPU: load_state_dict copies values into the model's existing parameters
# (already on DEVICE), so this avoids holding a second full copy of the weights on the GPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects from a file
# downloaded from the Hub (code-execution risk). If the checkpoint is a plain
# tensor state dict, prefer weights_only=True.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# strict=False tolerates key mismatches, so report them explicitly.
missing, unexpected = gated_model.load_state_dict(state_dict, strict=False)
print("missing:", len(missing), "unexpected:", len(unexpected))
if missing or unexpected:
    print("first missing:", list(missing)[:5], "first unexpected:", list(unexpected)[:5])

missing: 0 unexpected: 0


In [32]:
# Check configuration for pruning and export
# Show the rounding policy used by the probes during training next to the one
# used for the final pruning/export.
probe_policy = pack["probe_policy"]
print("Policy for the probes during training:", probe_policy)
export_policy = pack["export_policy"]
print("\nPolicy for the final pruning:", export_policy)

# Materialize the pruned ("slim") model from the trained gates.
# step=9999 is presumably chosen to be well past the export policy's
# warmup_steps (200 in the printed policy) — confirm against export_pruned.
# NOTE(review): the recorded run fails here with
# "AttributeError: 'LlamaAttention' object has no attribute 'num_heads'",
# which looks like a transformers-version mismatch inside the adapter — TODO confirm.
slim_model = adapter.export_pruned(gated_model, policy=export_policy, step=9999)

Policy for the probes during training: LlamaExportPolicy(warmup_steps=0, head_rounding=Rounding(floor_groups=1, multiple_groups=1, min_keep_ratio=0.5), ffn_rounding=Rounding(floor_groups=1, multiple_groups=1, min_keep_ratio=0.5), q_rounding=Rounding(floor_groups=1, multiple_groups=1, min_keep_ratio=0.5))

Policy for the final pruning: LlamaExportPolicy(warmup_steps=200, head_rounding=Rounding(floor_groups=16, multiple_groups=4, min_keep_ratio=0.5), ffn_rounding=Rounding(floor_groups=1, multiple_groups=16, min_keep_ratio=0.5), q_rounding=Rounding(floor_groups=4, multiple_groups=8, min_keep_ratio=0.5))


AttributeError: 'LlamaAttention' object has no attribute 'num_heads'

In [None]:
from collections.abc import Mapping

# Benchmark the unpruned base model against the exported slim model, using the
# batch size and sequence length of a real validation batch.
teacher      = pack["teacher"]  # another instance of Llama-3.2-1B (the unpruned baseline)
train_loader = pack["train_loader"]
val_loader   = pack["val_loader"]
llm_proxy    = pack["proxy"]
# Fix: decode_T was previously overwritten from an undefined `cfg`; use the recipe value.
decode_T     = pack["decode_T"]

batch = next(iter(val_loader))
# NOTE(review): assumes a batch is a mapping with "input_ids", a list/tuple whose
# first element is the token-id tensor, or that tensor itself — confirm against the loader.
if isinstance(batch, Mapping):
    ids = batch["input_ids"]
elif isinstance(batch, (list, tuple)):
    ids = batch[0]
else:
    ids = batch
B, S = int(ids.size(0)), int(ids.size(1))

print(f"Starting benchmarking with batch size = {B}, S = {S}, decode_T = {decode_T}...")
# `slim_model` is produced by the export cell above.
mean_keep, p95_keep, _ = measure_latency_text_ms(teacher.to(DEVICE).eval(), B=B, S=S, T=decode_T, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_text_ms(slim_model.to(DEVICE).eval(), B=B, S=S, T=decode_T, device=DEVICE)
print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
# Latency reduction (%) and speedup (x) are different quantities; report both.
if mean_keep > 0:
    print(f"Latency reduction={100.0 * (mean_keep - mean_slim) / mean_keep:.2f}%")
if mean_slim > 0:
    print(f"Speedup={mean_keep / mean_slim:.2f}x")

In [None]:
## Calibrate the latency proxy against measured latency on the current GPU

In [18]:
from collections.abc import Mapping

# Re-fetch the inputs this cell needs from `pack` so it does not depend on
# names left over from other cells.
student      = pack["student"]
train_loader = pack["train_loader"]
llm_proxy    = pack["proxy"]
decode_T     = pack["decode_T"]

# Use the shape of one real training batch for both the measurement and the proxy.
batch = next(iter(train_loader))
# NOTE(review): assumes a batch is a mapping with "input_ids", a list/tuple whose
# first element is the token-id tensor, or that tensor itself — confirm against the loader.
if isinstance(batch, Mapping):
    ids = batch["input_ids"]
elif isinstance(batch, (list, tuple)):
    ids = batch[0]
else:
    ids = batch
B, S = int(ids.size(0)), int(ids.size(1))

# 1) Measure the real latency of the keep-all export on this GPU, then free it.
slim_keepall = adapter.export_keepall(student).to(DEVICE).eval()
real_ms, _, _ = measure_latency_text_ms(slim_keepall, B=B, S=S, T=decode_T, device=DEVICE)
del slim_keepall
gc.collect()

# 2) The proxy's raw keep-all prediction (pre-scale).
# NOTE(review): if `predict` already applies `scale_ms`, re-running this cell would
# compound the scale — confirm that predict returns an unscaled value.
raw_pred = llm_proxy.predict(student, batch)
raw_pred = float(raw_pred.detach().item() if hasattr(raw_pred, "detach") else raw_pred)
raw_pred = max(raw_pred, 1e-9)  # avoid division by zero / non-positive predictions

# 3) Pick scale_ms so that scale_ms * raw_pred matches the measured keep-all latency.
# NOTE(review): the recorded scale (~2.2e-09) is close to the 1e-9 floor below;
# a smaller ratio would be silently clamped.
llm_proxy.scale_ms = max(1e-9, float(real_ms) / raw_pred)
print(f"[calibration] keep-all measured ≈ {real_ms:.3f} ms; proxy.scale_ms set to {llm_proxy.scale_ms:.6e}")

[calib] keep-all measured ≈ 1654.990 ms; proxy.scale_ms set to 2.195300e-09
