# ASA × LFM2.5 — Liquid Prompt τ/α Tuning

**독립 실행 가능** — 이전 노트북 에셋 불필요. 전체 파이프라인을 자체 실행 후 τ/α sweep 수행.

**목표**: Liquid 공식 포맷에서 ASA 억제가 너무 강한 문제(Recall 0.63→0.21)를 해결.
- 기존: τ=0.50, α=1 → FPR **0.116** (최저) but Recall **0.213** (폭락)
- 논문: τ=0.60, α=4 → 더 보수적 게이트 + 더 강한 스티어링

## 1. Setup

In [None]:
!pip install -q transformers accelerate scikit-learn datasets tqdm matplotlib seaborn

import json, re, pickle, os, sys, warnings, gc, ast
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (precision_score, recall_score, f1_score,
                             accuracy_score, confusion_matrix, roc_auc_score)
import matplotlib.pyplot as plt
import seaborn as sns

# Silence all warnings for the rest of the session (note: this also hides
# deprecation warnings emitted by transformers / scikit-learn).
warnings.filterwarnings("ignore")

_cuda_ok = torch.cuda.is_available()
print(f"PyTorch {torch.__version__}, CUDA: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Every artefact produced by this notebook is written under this directory.
Path("outputs_liquid").mkdir(exist_ok=True)

## 2. Data Pipeline

In [None]:
from datasets import load_dataset
# Alpaca instruction-tuning data (train split). Rows are read below through
# row.get("instruction" / "input" / "output").
ds = load_dataset("tatsu-lab/alpaca", split="train")
print(f"Alpaca: {len(ds)} samples")

# Tool schemas offered to the model. This list is serialised with json.dumps
# into SYS_PROMPT, so any change here (including key order) changes the
# model's input text.
TOOLS = [
    {"name": "calculator", "description": "Evaluate a mathematical expression and return the numeric result.",
     "parameters": {"type": "object", "properties": {"expression": {"type": "string", "description": "Math expression"}}, "required": ["expression"]}},
    {"name": "python_interpreter", "description": "Execute Python code and return the output.",
     "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "Python source code"}}, "required": ["code"]}},
    {"name": "web_search", "description": "Search the web for up-to-date information.",
     "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}},
    {"name": "translator", "description": "Translate text from one language to another.",
     "parameters": {"type": "object", "properties": {"text": {"type": "string", "description": "Text to translate"}, "target_language": {"type": "string", "description": "Target language"}}, "required": ["text", "target_language"]}}
]

# System prompt used for every sample: the tool list as a JSON string.
SYS_PROMPT = f"List of tools: {json.dumps(TOOLS)}"

# Keyword table for classify(): domain -> trigger keywords / phrases.
# Domains are tested in insertion order and the first match wins, so e.g.
# "convert" (math) takes precedence over "convert to" (translation).
KW = {
    "math": {"calculate", "compute", "solve", "equation", "sum", "average",
             "percentage", "convert", "ratio", "divide", "multiply"},
    "code": {"write a program", "code", "function", "algorithm", "implement",
             "script", "debug", "compile", "class", "method"},
    "search": {"search", "find", "look up", "latest", "current", "news",
               "recent", "who is", "what happened", "when did"},
    "translation": {"translate", "translation", "say in", "convert to",
                    "how do you say", "in french", "in spanish", "in german",
                    "in japanese", "in chinese"},
}

def classify(text, keywords=None):
    """Return the first domain whose keywords appear in *text*, else None.

    Keywords are matched case-insensitively as whole words / phrases, with an
    optional plural "s"/"es" suffix. Plain substring matching lets short
    keywords fire inside unrelated words (e.g. "sum" in "summarize", "class"
    in "classify", "code" in "decode"), which mis-buckets those prompts.

    Args:
        text: Free-form prompt or response text.
        keywords: Mapping ``domain -> iterable of lowercase keywords``.
            Defaults to the module-level ``KW`` table. Domains are checked in
            mapping order and the first match wins.

    Returns:
        The matching domain name, or None when no keyword matches.
    """
    if keywords is None:
        keywords = KW
    t = text.lower()
    for domain, kws in keywords.items():
        # \b...\b anchors the keyword to word boundaries; (?:e?s)? keeps plurals
        # such as "functions" / "classes" matching.
        if any(re.search(rf"\b{re.escape(k)}(?:e?s)?\b", t) for k in kws):
            return domain
    return None

# Fixed seed so the bucket shuffles below (and later split shuffles) are
# reproducible.
np.random.seed(42)

# Response cue words that, by themselves, mark an example as "tool".
_TOOL_OUTPUT_CUES = ["result", "output", "answer", "return"]

buckets = {d: {"tool": [], "non_tool": []} for d in KW}

# Bucket every Alpaca row by (domain of its prompt, heuristic tool label).
# Rows whose prompt matches no domain are skipped. A row is labelled "tool"
# when its response classifies into the same domain as the prompt, or when the
# response contains one of the cue words (substring test).
for i, row in enumerate(ds):
    text = row.get("instruction", "") + " " + row.get("input", "")
    dom = classify(text)
    if dom is not None:
        output = row.get("output", "")
        lowered = output.lower()
        is_tool_like = classify(output) == dom or any(cue in lowered for cue in _TOOL_OUTPUT_CUES)
        label = "tool" if is_tool_like else "non_tool"
        buckets[dom][label].append(i)

# Shuffle each bucket in place, domain by domain, "tool" before "non_tool".
for dom, by_label in buckets.items():
    for lbl in by_label:
        np.random.shuffle(by_label[lbl])

# Samples per (domain, class) in each split; "test" gets TEST_MULT times more.
PER_SPLIT_PER_CLASS = 40
SPLITS = {"cal": 0, "train": 1, "valid": 2, "test": 3}  # split -> position in each pool
TEST_MULT = 2

samples = {s: [] for s in SPLITS}
used = set()
for dom in buckets:
    for lbl in ["tool", "non_tool"]:
        pool = [i for i in buckets[dom][lbl] if i not in used]
        is_tool = 1 if lbl == "tool" else 0
        # Carve consecutive, non-overlapping slices of the shuffled pool in
        # SPLITS position order. The offset is cumulative (advanced by each
        # split's actual size), so splits stay disjoint even if an enlarged
        # split is not the last one.
        start = 0
        for sname in sorted(SPLITS, key=SPLITS.get):
            n = PER_SPLIT_PER_CLASS * (TEST_MULT if sname == "test" else 1)
            chosen = pool[start:start + n]
            start += n
            if len(chosen) < n:
                # Pool exhausted: this split ends up smaller / class-imbalanced.
                print(f"  WARNING: {dom}/{lbl} has only {len(chosen)}/{n} samples for '{sname}'")
            used.update(chosen)
            for idx in chosen:
                row = ds[int(idx)]
                text = row.get("instruction", "") + " " + row.get("input", "")
                samples[sname].append({"text": text, "label": is_tool, "domain": dom, "idx": int(idx)})

# Interleave domains / classes inside each split.
for s in samples:
    np.random.shuffle(samples[s])

print("\nSplit sizes:")
for s in samples:
    t = sum(1 for x in samples[s] if x["label"] == 1)
    nt = len(samples[s]) - t
    print(f"  {s:5s}: {len(samples[s])} ({t} tool / {nt} non-tool)")

## 3. Load Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "LiquidAI/LFM2.5-1.2B-Instruct"

# fp16 weights; device placement is left to device_map="auto".
_load_kwargs = dict(torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_load_kwargs)
model.eval()

# Tool-call marker strings (per their names, start/end delimiters of a tool
# call). Only TOOL_S is searched for in generated text later on.
TOOL_S = "<|tool_call_start|>"
TOOL_E = "<|tool_call_end|>"
NUM_LAYERS = model.config.num_hidden_layers
_n_params_m = sum(p.numel() for p in model.parameters()) // 1_000_000
print(f"{MODEL_ID} ({_n_params_m}M params, {NUM_LAYERS} layers)")

## 4. Extract Hidden States

In [None]:
def extract_hidden_states(split_data, split_name):
    """Run each prompt through the model and capture last-token hidden states.

    A forward hook on every decoder layer records the hidden state at the last
    position of the chat-templated prompt (generation prompt included).

    Args:
        split_data: List of sample dicts with "text", "label" and "domain" keys.
        split_name: Label shown on the progress bar.

    Returns:
        ``(hidden, labels, domains)``: ``hidden`` maps layer index -> float32
        array of shape (len(split_data), hidden_dim); ``labels`` is a NumPy
        array of the sample labels; ``domains`` is a list of domain strings.
    """
    all_hs = {layer: [] for layer in range(NUM_LAYERS)}
    labels, domains = [], []
    captured = {}

    def make_hook(layer_idx):
        # Store this layer's last-position hidden state as a float32 NumPy
        # array (batch dim kept; squeezed when collected below).
        def hook_fn(module, inp, out):
            h = out[0] if isinstance(out, tuple) else out
            captured[layer_idx] = h[:, -1, :].detach().cpu().float().numpy()
        return hook_fn

    hooks = [model.model.layers[layer].register_forward_hook(make_hook(layer))
             for layer in range(NUM_LAYERS)]
    try:
        for sample in tqdm(split_data, desc=split_name):
            msgs = [{"role": "system", "content": SYS_PROMPT},
                    {"role": "user", "content": sample["text"]}]
            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
            captured.clear()
            with torch.no_grad():
                model(ids)
            for layer in range(NUM_LAYERS):
                all_hs[layer].append(captured[layer].squeeze(0))
            labels.append(sample["label"])
            domains.append(sample["domain"])
    finally:
        # Always detach the hooks, even if a forward pass fails (e.g. OOM);
        # otherwise they stay registered and fire on every later model call.
        for handle in hooks:
            handle.remove()

    return {layer: np.stack(all_hs[layer]) for layer in range(NUM_LAYERS)}, np.array(labels), domains

# Per-split containers keyed by split name: hidden states per layer, labels,
# and domain strings.
hs, labels_dict, domains_dict = {}, {}, {}
for split_name in ("cal", "train", "valid", "test"):
    feats, lbls, doms = extract_hidden_states(samples[split_name], split_name)
    hs[split_name] = feats
    labels_dict[split_name] = lbls
    domains_dict[split_name] = doms

# Free Python garbage and cached CUDA blocks left over from extraction.
gc.collect()
torch.cuda.empty_cache()
print("Hidden states extracted for all splits.")

## 5. Probe Sweep → Find L*

In [None]:
print("Probe sweep across all layers:\n")
best_auc, best_layer, plateau = 0, 0, 0
y = labels_dict["cal"]
y_v = labels_dict["valid"]

# Fit a linear probe per layer on CAL and score it on VALID; L* is the first
# layer reaching the highest validation AUC.
for layer_idx in range(NUM_LAYERS):
    scaler_tmp = StandardScaler().fit(hs["cal"][layer_idx])
    X = scaler_tmp.transform(hs["cal"][layer_idx])
    probe = LogisticRegression(max_iter=2000, C=1.0).fit(X, y)
    X_v = scaler_tmp.transform(hs["valid"][layer_idx])
    probs = probe.predict_proba(X_v)[:, 1]
    auc = roc_auc_score(y_v, probs)
    acc = probe.score(X_v, y_v)
    print(f"  Layer {layer_idx:2d} | AUC: {auc:.4f} | Acc: {acc:.4f}")
    if auc > best_auc:
        best_auc, best_layer, plateau = auc, layer_idx, 0
    else:
        # Reset on each new best, so at the end this counts the layers after L*.
        plateau += 1

L_STAR = best_layer
print(f"\nL* = {L_STAR} (AUC={best_auc:.4f}, plateau={plateau} layers)")

## 6. Build Steering Vectors

In [None]:
DOMAINS = sorted(set(domains_dict["cal"]))
h_cal = hs["cal"][L_STAR]
y_cal = labels_dict["cal"]
d_cal = domains_dict["cal"]

tool_mask = y_cal == 1
nontool_mask = y_cal == 0


def _mean_diff_unit(pos_mask, neg_mask):
    """Unit-normalised difference of CAL class means at L* (positive minus negative)."""
    diff = h_cal[pos_mask].mean(axis=0) - h_cal[neg_mask].mean(axis=0)
    return diff / (np.linalg.norm(diff) + 1e-8)


# Global direction (tool minus non-tool), inserted first, then one direction
# per domain that has both classes in CAL.
steering_vecs = {"global": _mean_diff_unit(tool_mask, nontool_mask)}

d_cal_arr = np.array(d_cal)
for dom in DOMAINS:
    dom_mask = d_cal_arr == dom
    tool_dom = dom_mask & tool_mask
    nontool_dom = dom_mask & nontool_mask
    if not (tool_dom.sum() > 0 and nontool_dom.sum() > 0):
        continue
    steering_vecs[dom] = _mean_diff_unit(tool_dom, nontool_dom)
    cos = np.dot(steering_vecs[dom], steering_vecs["global"])
    print(f"  {dom:12s} cos(v_d, v_g) = {cos:.4f}")

print("Vectors built from CAL.")

## 7. Train Router & Probes

In [None]:
scaler = StandardScaler().fit(hs["cal"][L_STAR])

# Router: domain classifier (one class per entry of DOMAINS) on standardised
# L* features. With the default lbfgs solver, LogisticRegression already fits a
# multinomial model for >2 classes, so the `multi_class` argument (deprecated
# since scikit-learn 1.5) is not passed.
X_train_r = scaler.transform(hs["train"][L_STAR])
y_train_r = np.array([DOMAINS.index(d) for d in domains_dict["train"]])
router = LogisticRegression(max_iter=2000, C=1.0).fit(X_train_r, y_train_r)

# Per-domain tool / non-tool probes trained on TRAIN.
probes = {}
for dom in DOMAINS:
    mask = np.array([d == dom for d in domains_dict["train"]])
    y_d = labels_dict["train"][mask]
    n_classes = np.unique(y_d).size
    # Skip domains with too few samples or a single class: LogisticRegression
    # cannot be fit on a single-class target.
    if mask.sum() < 4 or n_classes < 2:
        print(f"  Probe '{dom}' skipped (n={int(mask.sum())}, classes={n_classes})")
        continue
    X_d = scaler.transform(hs["train"][L_STAR][mask])
    probes[dom] = LogisticRegression(max_iter=2000, C=1.0).fit(X_d, y_d)
    print(f"  Probe '{dom}' train acc: {probes[dom].score(X_d, y_d):.4f}")

# Validation
X_val = scaler.transform(hs["valid"][L_STAR])
y_val_r = np.array([DOMAINS.index(d) for d in domains_dict["valid"]])
print(f"  Router valid acc: {router.score(X_val, y_val_r):.4f}")

for dom in DOMAINS:
    mask = np.array([d == dom for d in domains_dict["valid"]])
    if dom in probes and mask.sum() > 0:
        X_d = scaler.transform(hs["valid"][L_STAR][mask])
        y_d = labels_dict["valid"][mask]
        print(f"  Probe '{dom}' valid acc: {probes[dom].score(X_d, y_d):.4f}")

print("Router & probes trained.")

## 8. Generation & Evaluation Functions

In [None]:
def generate(messages, hook_fn=None, layer=None, max_new=128):
    """Greedy-decode a reply to *messages*, optionally with a forward hook.

    Args:
        messages: Chat messages rendered with the tokenizer's chat template.
        hook_fn: Optional forward hook registered on decoder layer *layer*
            only for the duration of this call (used for ASA steering).
        layer: Index into ``model.model.layers``; the hook is installed only
            when both *hook_fn* and *layer* are given.
        max_new: Maximum number of new tokens to generate.

    Returns:
        The generated continuation (prompt tokens excluded), decoded with
        ``skip_special_tokens=False`` so marker tokens are not stripped.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    handle = None
    if hook_fn is not None and layer is not None:
        handle = model.model.layers[layer].register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=max_new, do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id)
    finally:
        # Remove the hook even if generation raises; a leaked hook would keep
        # modifying every subsequent forward pass.
        if handle is not None:
            handle.remove()
    return tokenizer.decode(out[0][ids.shape[1]:], skip_special_tokens=False)

def evaluate_split(split_data, alpha, tau, beta, use_asa=True):
    """Generate a reply for every sample and score tool-call triggering.

    With ``use_asa`` an ASA hook is attached at layer ``L_STAR``. On the first
    forward pass of each generation it routes the last-position hidden state to
    a domain, scores it with that domain's probe (probability ``pt``), and adds
    ``gate * alpha * v`` at the last position, where ``gate`` is +1 if
    ``pt >= tau``, -1 if ``pt <= 1 - tau``, and 0 (no change) otherwise.

    Args:
        split_data: Sample dicts with "text", "label" and "domain" keys.
        alpha: Steering strength.
        tau: Gate confidence threshold.
        beta: Weight of the global vector when mixed with the domain vector
            (0 = domain vector only).
        use_asa: If False, plain generation without any hook; alpha/tau/beta
            are then unused.

    Returns:
        Dict with overall precision, recall, f1, fpr and accuracy, plus
        per-domain ``{dom}_f1``, ``{dom}_precision``, ``{dom}_recall`` and
        ``{dom}_fpr``. A sample counts as "triggered" when the reply contains
        ``TOOL_S``.
    """
    global _injected

    def get_steer_vec(domain):
        # Unit steering vector for `domain`: its own vector, optionally mixed
        # with the global one (weight beta); the global vector as fallback.
        v_d = steering_vecs.get(domain)
        v_g = steering_vecs.get("global")
        if v_d is None:
            return v_g
        if v_g is not None and beta > 0:
            v = beta * v_g + (1 - beta) * v_d
        else:
            v = v_d
        return v / (np.linalg.norm(v) + 1e-8)

    def asa_hook(module, inp, out):
        global _injected
        # Steer only once per generate() call; `_injected` is reset before
        # every sample below.
        if _injected:
            return out
        _injected = True
        is_tuple = isinstance(out, tuple)
        h = out[0] if is_tuple else out
        hl = h[:, -1, :].detach().cpu().float().numpy()
        hs_scaled = scaler.transform(hl)
        dom = DOMAINS[router.predict(hs_scaled)[0]]
        pt = probes[dom].predict_proba(hs_scaled)[0, 1] if dom in probes else 0.5
        # Steering vectors are (tool mean - non-tool mean), so +1 moves toward
        # tool-class activations and -1 away from them.
        gate = 1 if pt >= tau else (-1 if pt <= 1 - tau else 0)
        if gate == 0:
            return out
        v = get_steer_vec(dom)
        # Match the hidden state's dtype/device rather than assuming float16.
        vt = torch.as_tensor(v, dtype=h.dtype, device=h.device)
        hn = h.clone()
        hn[:, -1, :] += gate * alpha * vt
        # Preserve the layer's output structure: a tuple (even a 1-tuple, whose
        # tail out[1:] is empty) must be returned as a tuple, never as a bare
        # tensor.
        return (hn,) + out[1:] if is_tuple else hn

    y_true, y_pred = [], []
    domain_results = {}

    for sample in tqdm(split_data, desc=f"a={alpha} t={tau}", leave=False):
        msgs = [{"role": "system", "content": SYS_PROMPT},
                {"role": "user", "content": sample["text"]}]
        if use_asa:
            _injected = False
            out = generate(msgs, hook_fn=asa_hook, layer=L_STAR)
        else:
            out = generate(msgs)
        triggered = 1 if TOOL_S in out else 0
        y_true.append(sample["label"])
        y_pred.append(triggered)
        dr = domain_results.setdefault(sample["domain"], {"y_true": [], "y_pred": []})
        dr["y_true"].append(sample["label"])
        dr["y_pred"].append(triggered)

    tn, fp, _, _ = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    results = {
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "fpr": fp / (fp + tn) if (fp + tn) > 0 else 0,
        "accuracy": accuracy_score(y_true, y_pred),
    }
    for dom, dr in domain_results.items():
        tn_d, fp_d, _, _ = confusion_matrix(dr["y_true"], dr["y_pred"], labels=[0, 1]).ravel()
        results[f"{dom}_f1"] = f1_score(dr["y_true"], dr["y_pred"], zero_division=0)
        results[f"{dom}_precision"] = precision_score(dr["y_true"], dr["y_pred"], zero_division=0)
        results[f"{dom}_recall"] = recall_score(dr["y_true"], dr["y_pred"], zero_division=0)
        results[f"{dom}_fpr"] = fp_d / (fp_d + tn_d) if (fp_d + tn_d) > 0 else 0
    return results

print("Functions defined.")
# Module-level flag read and written via `global` inside evaluate_split's ASA
# hook; it starts out cleared.
_injected = False

## 9. Baseline (No ASA)

In [None]:
print("Evaluating Baseline (no ASA) on TEST set...")
# With use_asa=False no hook is installed, so alpha/tau/beta are placeholders.
baseline_results = evaluate_split(
    samples["test"], alpha=0, tau=0, beta=0, use_asa=False,
)
_b = baseline_results
print(f"\n  Baseline: F1={_b['f1']:.4f}  FPR={_b['fpr']:.4f}  "
      f"Rec={_b['recall']:.4f}  Prec={_b['precision']:.4f}")

## 10. τ/α Grid Sweep (핵심 실험)

> **논문**: Qwen2.5-1.5B에서 α=4, τ=0.60 사용.
> 기존 Liquid 결과: α=1, τ=0.50 → FPR 0.116 but Recall 0.213 (과잉 억제).

In [None]:
# (name, tau, alpha, beta): gate threshold, steering strength, and weight of
# the global vector in the steering mix.
EXPERIMENTS = [
    ("current",      0.50, 1, 0.0),
    ("paper",        0.60, 4, 0.0),
    ("gentle",       0.70, 2, 0.0),
    ("balanced",     0.65, 3, 0.0),
    ("conservative", 0.70, 4, 0.0),
    ("selective",    0.75, 5, 0.0),
    ("paper_high",   0.60, 8, 0.0),
    ("minimal",      0.80, 3, 0.0),
]

_rule = "=" * 60
all_results = {}
for name, tau, alpha, beta in EXPERIMENTS:
    print(f"\n{_rule}\n  {name} (tau={tau}, alpha={alpha})\n{_rule}")
    results = evaluate_split(samples["test"], alpha=alpha, tau=tau, beta=beta, use_asa=True)
    # Keep the hyper-parameters alongside the metrics for later tables/plots.
    all_results[name] = {"tau": tau, "alpha": alpha, "beta": beta, **results}
    print(f"  F1={results['f1']:.4f}  Prec={results['precision']:.4f}  "
          f"Rec={results['recall']:.4f}  FPR={results['fpr']:.4f}")

print("\n\nAll experiments complete!")

## 11. Results Comparison

In [None]:
# Hard-coded reference metrics of the best Qwen + ASA run (not recomputed here).
QWEN_BEST = {"f1": 0.6185, "precision": 0.7591, "recall": 0.5219,
             "fpr": 0.1656, "accuracy": 0.6781}

# Configurations ordered by F1, best first.
sorted_r = sorted(all_results.items(), key=lambda item: item[1]['f1'], reverse=True)
top_f1 = max((cfg['f1'] for cfg in all_results.values()), default=None)


def _metric_cells(m):
    """Format the F1 | Prec | Rec | FPR columns of one table row."""
    return (f"{m['f1']:>6.4f} | {m['precision']:>6.4f} | "
            f"{m['recall']:>6.4f} | {m['fpr']:>6.4f}")


print("=" * 90)
print(f"{'Name':>14s} | {'tau':>4s} | {'a':>3s} | {'F1':>6s} | {'Prec':>6s} | {'Rec':>6s} | {'FPR':>6s} | vs Qwen F1")
print("-" * 90)
print(f"{'baseline':>14s} | {'--':>4s} | {'--':>3s} | {_metric_cells(baseline_results)} | --")
print(f"{'Qwen+ASA ref':>14s} | {'--':>4s} | {'--':>3s} | {_metric_cells(QWEN_BEST)} | reference")
print("-" * 90)

for name, r in sorted_r:
    delta = r['f1'] - QWEN_BEST['f1']
    mark = "BEATS!" if delta > 0 else f"{delta:+.4f}"
    star = " ***" if r['f1'] == top_f1 else ""
    print(f"{name:>14s} | {r['tau']:>4.2f} | {r['alpha']:>3d} | {_metric_cells(r)} | {mark}{star}")

print("=" * 90)

best_f1_name = max(all_results, key=lambda k: all_results[k]['f1'])
best_fpr_name = min(all_results, key=lambda k: all_results[k]['fpr'])
print(f"\nBest F1:  {best_f1_name} -> F1={all_results[best_f1_name]['f1']:.4f}")
print(f"Best FPR: {best_fpr_name} -> FPR={all_results[best_fpr_name]['fpr']:.4f}")

## 12. Visualization

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
names = [n for n, _ in sorted_r]
f1s = [r['f1'] for _, r in sorted_r]
fprs = [r['fpr'] for _, r in sorted_r]
recs = [r['recall'] for _, r in sorted_r]

# One panel per metric: (values, x label, title, metric key, lower-is-better).
# The metric key selects both the Qwen reference and the baseline value, so
# nothing has to be re-derived by parsing the axis label text.
panels = [
    (f1s,  'F1',                 'F1 Score',            'f1',     False),
    (fprs, 'FPR (lower=better)', 'False Positive Rate', 'fpr',    True),
    (recs, 'Recall',             'Recall',              'recall', False),
]
for ax, (vals, xlabel, title, key, lower_is_better) in zip(axes, panels):
    ref_val = QWEN_BEST[key]
    base_val = baseline_results.get(key, 0)
    # Green = better than the Qwen+ASA reference, red = not better.
    colors = ['#2ecc71' if (v < ref_val if lower_is_better else v > ref_val) else '#e74c3c'
              for v in vals]
    ax.barh(names, vals, color=colors)
    ax.axvline(ref_val, color='orange', linestyle='--', label=f'Qwen+ASA ({ref_val:.3f})')
    # Label the baseline line so the legend explains it.
    ax.axvline(base_val, color='gray', linestyle=':', alpha=0.5,
               label=f'Baseline ({base_val:.3f})')
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.legend(fontsize=8)

plt.suptitle('ASA Liquid Prompt - tau/alpha Tuning Results', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("outputs_liquid/tuning_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

## 13. FPR vs Recall Pareto Front

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# One point per ASA configuration, annotated with its settings and F1.
for cfg_name, cfg in all_results.items():
    xy = (cfg['fpr'], cfg['recall'])
    ax.scatter(*xy, s=100, zorder=5)
    ax.annotate(f"{cfg_name}\nt={cfg['tau']},a={cfg['alpha']}\nF1={cfg['f1']:.3f}",
                xy, textcoords="offset points", xytext=(10, 5), fontsize=8)

# Reference points: the no-ASA baseline and the hard-coded Qwen+ASA result.
ax.scatter(baseline_results['fpr'], baseline_results['recall'], s=200, marker='X',
           color='gray', zorder=6, label='Baseline (no ASA)')
ax.scatter(QWEN_BEST['fpr'], QWEN_BEST['recall'], s=200, marker='D',
           color='orange', zorder=6, label='Qwen+ASA (best)')

ax.set_xlabel('FPR (lower = better)', fontsize=12)
ax.set_ylabel('Recall (higher = better)', fontsize=12)
ax.set_title('FPR vs Recall Tradeoff', fontsize=14)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("outputs_liquid/pareto_front.png", dpi=150, bbox_inches='tight')
plt.show()

## 14. Domain Analysis (Best Config)

In [None]:
# Per-domain metrics of the configuration with the highest overall F1.
best_name = max(all_results, key=lambda k: all_results[k]['f1'])
best = all_results[best_name]
print(f"Best: {best_name} (tau={best['tau']}, alpha={best['alpha']})\n")
for dom in sorted(set(domains_dict["test"])):
    # Missing per-domain entries default to 0.
    f1, p, r, fpr = (best.get(f"{dom}_{metric}", 0)
                     for metric in ("f1", "precision", "recall", "fpr"))
    print(f"  {dom:>12s} | F1={f1:.4f} | P={p:.4f} | R={r:.4f} | FPR={fpr:.4f}")

## 15. Save Results

In [None]:
best_name = max(all_results, key=lambda k: all_results[k]['f1'])
best_fpr_name = min(all_results, key=lambda k: all_results[k]['fpr'])

save_data = {
    "baseline": baseline_results,
    "qwen_style_reference": QWEN_BEST,
    "experiments": all_results,
    "best_f1_config": best_name,
    "best_fpr_config": best_fpr_name,
    "L_star": L_STAR,
    "domains": DOMAINS,
}
# default=str stringifies any value json cannot encode natively.
with open("outputs_liquid/tuning_results.json", "w") as f:
    json.dump(save_data, f, indent=2, default=str)

# Also save assets for future use.
asset_dir = Path("outputs_liquid/asa_assets")
asset_dir.mkdir(parents=True, exist_ok=True)
np.savez(asset_dir / "steering_vectors.npz", **steering_vecs)
# Close every file deterministically so its contents are flushed before the
# directory size is summed below.
for fname, obj in (("router.pkl", router), ("probes.pkl", probes), ("scaler.pkl", scaler)):
    with open(asset_dir / fname, "wb") as fh:
        pickle.dump(obj, fh)

best_r = all_results[best_name]
config_save = {"L_star": L_STAR, "alpha": best_r["alpha"], "tau": best_r["tau"],
               "beta": best_r["beta"], "domains": DOMAINS}
with open(asset_dir / "config.json", "w") as fh:
    json.dump(config_save, fh, indent=2)

total_kb = sum(p.stat().st_size for p in asset_dir.iterdir()) / 1024
print(f"Assets saved: {asset_dir} ({total_kb:.0f} KB)")
print("Results saved: outputs_liquid/tuning_results.json")
print(f"\nBest F1: {best_name} (tau={best_r['tau']}, alpha={best_r['alpha']}) -> F1={best_r['f1']:.4f}")
print(f"Best FPR: {best_fpr_name} -> FPR={all_results[best_fpr_name]['fpr']:.4f}")