# Example: Llama-3.2-1B on an RTX 4090

In [None]:
import os, sys, pathlib
# Make the repository root (the parent of this notebook's directory) importable so that
# `core`, `adapters` and `examples` resolve. Derive it from the working directory instead
# of a hard-coded notebook filename, and avoid duplicate entries when the cell is re-run.
REPO_ROOT = pathlib.Path.cwd().resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

# Target GPU for all models in this notebook.
DEVICE = "cuda:0"

In [None]:
import gc
from huggingface_hub import notebook_login
# Log in to the Hugging Face Hub interactively. Later cells download the base model
# (via `build_from_recipe`) and a gate checkpoint (via `hf_hub_download`), which
# presumably require an authenticated session — confirm for the repos you use.
notebook_login()

## Build model with trainable gates

In [None]:
from core.profiler import measure_latency_text_ms 

from adapters.huggingface.llama import (
    LlamaAdapter,
    LlamaGatingConfig,
    LlamaExportPolicy,
    LatencyProxyLLM,
    infer_slim_meta,
    load_slim_for_finetune
)

# Script to build config from a recipe
from examples.run_llama_optimize import build_from_recipe, _ids_mask, _ProxyBridge

# Recipe for the RTX 4090 target. `build_from_recipe` downloads the base model and returns
# a dict-like "pack"; later cells read the student/teacher models, data loaders, gating
# hyper-parameters and pruning policies from it. The path is relative to this notebook.
RECIPE_PATH = "../recipes/RTX4090/llama_3_2_1b.yaml"
pack = build_from_recipe(RECIPE_PATH)

In [None]:
student = pack["student"]  # LlaMa-3.2-1B that will receive the trainable gates
teacher = pack["teacher"]  # Another instance of LlaMa-3.2-1B

# HawAda adapter to train gates
adapter = LlamaAdapter(student)

# Gate hyper-parameters live in the recipe's `gating` section; tolerate a missing/None
# section so the errors below name the offending key.
gating = pack.get("gating") or {}

# Gates config. Numeric settings are required (a missing key raises KeyError naming it);
# boolean flags default to False when absent.
gate_cfg = LlamaGatingConfig(
    tau=float(gating["tau"]),                     # Target latency (relative to base latency)
    init_logit=float(gating["init_logit"]),       # Initial logit for gate sigmoids
    head_gating=bool(gating.get("head_gating")),  # Whether to gate attention heads
    gate_kv=bool(gating.get("gate_kv")),          # Whether to gate KV layers (default=False)
    ffn_group=int(gating["ffn_group"]),           # Feed-forward network group size
    ffn_gating=bool(gating.get("ffn_gating")),    # Whether to gate FFN layers
    hard_eval=bool(gating.get("hard_eval")),
)

# Attach gates to the student and move the gated model to the target GPU in training mode.
gated_model = adapter.attach_gates(gate_cfg).train().to(DEVICE)

### Download pre-trained gates from HawAda repo

In [None]:
# Cell-local imports: these names are used below but were never imported in the notebook.
import torch
from huggingface_hub import hf_hub_download

# NOTE(review): `gated_model_repo` is not defined anywhere in this notebook. Set it (e.g. in
# an earlier cell) to the Hugging Face repo id hosting the pre-trained HawAda gates.
gated_model_repo = globals().get("gated_model_repo")
if not gated_model_repo:
    raise NameError(
        "Define `gated_model_repo` (Hugging Face repo id of the pre-trained HawAda gates) "
        "before running this cell."
    )

# Download the gate checkpoint and load it directly onto DEVICE.
ckpt_path = hf_hub_download(gated_model_repo, "pytorch_model.bin")
# SECURITY: weights_only=False unpickles arbitrary Python objects, so only load checkpoints
# from a repo you trust. If the file is a plain tensor state dict, weights_only=True is safer.
state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

# strict=False tolerates key mismatches; report their counts so gaps are visible.
missing, unexpected = gated_model.load_state_dict(state_dict, strict=False)
print("missing:", len(missing), "unexpected:", len(unexpected))

### Prune and export

In [None]:
# Show the two policies carried by the recipe: the one used for probes during training
# and the one applied when the final pruned model is exported.
print("Policy for the probes during training:", pack["probe_policy"])
export_policy = pack["export_policy"]
print("\nPolicy for the final pruning:", export_policy)

# Build the pruned ("slim") model from the gated model using the export policy.
# NOTE(review): the meaning of `step` is defined by `export_pruned`; 9999 is passed unchanged.
slim_model = adapter.export_pruned(gated_model, policy=export_policy, step=9999)

## Measure latency

In [None]:
# Objects built from the recipe in the first cell.
teacher      = pack["teacher"]  # Another instance of LlaMa-3.2-1B; used as the "Base" timing
train_loader = pack["train_loader"]
val_loader   = pack["val_loader"]
llm_proxy    = pack["proxy"]
decode_T     = int(pack["decode_T"])  # Number of tokens to decode per timed run

# Take the benchmark shape (batch size B, sequence length S) from one real validation batch.
batch = next(iter(val_loader))
# NOTE(review): assumes dict batches expose "input_ids" and tuple/list batches put the ids
# first; `_ids_mask` (imported above) may be the canonical extractor — confirm.
if isinstance(batch, dict):
    ids = batch["input_ids"]
elif isinstance(batch, (list, tuple)):
    ids = batch[0]
else:
    ids = batch
B, S = int(ids.size(0)), int(ids.size(1))

print(f"Starting benchmarking with batch size = {B}, S = {S}, decode_T = {decode_T}...")
mean_keep, p95_keep, _ = measure_latency_text_ms(teacher.to(DEVICE).eval(), B=B, S=S, T=decode_T, device=DEVICE)
mean_slim, p95_slim, _ = measure_latency_text_ms(slim_model.to(DEVICE).eval(), B=B, S=S, T=decode_T, device=DEVICE)
print(f"Base: mean={mean_keep:.3f}ms p95={p95_keep:.3f}ms")
print(f"Slim: mean={mean_slim:.3f}ms p95={p95_slim:.3f}ms\n")
if mean_keep > 0:
    # Percentage reduction in mean latency relative to the base model.
    print(f"Speedup={100.0*(mean_keep-mean_slim)/mean_keep:.2f}%")

You now have your own Llama version, optimized for the RTX 4090.

You will want to fine-tune it for your downstream task; this can be done on any other device (e.g., an H100).

The HawAda framework also lets you optimize the model for other GPUs. To do so, follow these steps:

* Create your own recipe
* Attach the HawAda adapter to the model
* Train the gates on your target device (see the ResNet notebook)
* Export the pruned model once the gates are trained
* Run a grid search to choose the best shapes
* Run distillation (fine-tuning) for your downstream task