In [1]:
# ==== 설정 (환경에 맞게 수정) ====
VANILLA_REPO_ID     = "meta-llama/Meta-Llama-3.1-8B-Instruct"       # 바닐라 모델
TRI_REPO_ID         = "jeongseokoh/LoPA_Llama3.1_8B_8_Lowers"       # 너의 LoPA TRI 레포
BASE_SUBFOLDER      = "base"                                        # TRI 레포의 base 가중치 폴더
LORA_SUBFOLDER      = "lora"                                        # LoRA 폴더(없으면 자동 건너뜀)
LOPA_MODELING_PATH  = "./lopa_llama_modeling.py"                    # (완성본) TRI 모델링 파일 경로
TOKENIZER_PATH      = TRI_REPO_ID                                   # 동일 토크나이저 권장
ATTN_IMPL           = "flash_attention_2"                           # "flash_attention_2" | "eager" | "sdpa"
HF_TOKEN            = None  # private면 토큰 넣어줘
LOWER_K = 8
# 길이 (요청 조건)
LEN_S   = 256   # system
LEN_U   = 256   # user(질의)
LEN_D   = 10240  # document
LEN_H   = 4     # assistant header 길이(고정 더미)
LEN_GEN = 512   # 생성 토큰 수

# 실행 환경 권장 변수
import os, torch, sys
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype  = (torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported())
          else (torch.float16 if torch.cuda.is_available() else torch.float32))
print("device:", device, "| dtype:", dtype)


device: cuda | dtype: torch.bfloat16


In [2]:
# ==== 토크나이저 ====
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# ==== 바닐라 모델 ====
from transformers import AutoModelForCausalLM
van_model = AutoModelForCausalLM.from_pretrained(
    VANILLA_REPO_ID, torch_dtype=dtype, token=HF_TOKEN
).to(device)
van_model.eval()

# ==== TRI 모델링을 transformers 내부에 주입 ====
import importlib.util, transformers, transformers.models.llama as llama_pkg
target_name = "transformers.models.llama.modeling_llama"
spec = importlib.util.spec_from_file_location(target_name, LOPA_MODELING_PATH)
mod = importlib.util.module_from_spec(spec)
sys.modules.pop(target_name, None)
sys.modules[target_name] = mod
spec.loader.exec_module(mod)
setattr(llama_pkg, "modeling_llama", mod)
from transformers.models.llama.modeling_llama import LlamaForCausalLM  # (패치된) 클래스
print("[DEBUG] TRI patch loaded:", LOPA_MODELING_PATH)

# ==== TRI 모델 ====
tri_model = LlamaForCausalLM.from_pretrained(
    TRI_REPO_ID, subfolder=BASE_SUBFOLDER, torch_dtype=dtype, token=HF_TOKEN
).to(device)

# LoRA 어댑터(있으면 자동 로드)
try:
    from peft import PeftModel
    tri_model = PeftModel.from_pretrained(tri_model, TRI_REPO_ID, subfolder=LORA_SUBFOLDER, token=HF_TOKEN)
    tri_model = tri_model.to(device)
    print(f"[info] LoRA adapters loaded: {TRI_REPO_ID}/{LORA_SUBFOLDER}")
except Exception as e:
    print("[info] LoRA not found or skipped:", str(e).split("\n")[0])

# 어텐션 백엔드 지정(둘 다)
for m in (van_model, tri_model):
    try:
        m.config._attn_implementation = ATTN_IMPL
    except Exception:
        pass

# 추론모드 & TF32
torch.set_grad_enabled(False)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# TRI API 존재 확인
for need in ("tri_build_caches", "tri_forward_assistant", "tri_step_logits"):
    assert hasattr(tri_model, need), f"Missing TRI API: {need}"

van_model.eval(); tri_model.eval()
print("[OK] models ready")


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


[DEBUG] TRI patch loaded: ./lopa_llama_modeling.py


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[info] LoRA adapters loaded: jeongseokoh/LoPA_Llama3.1_8B_8_Lowers/lora
[OK] models ready


In [3]:
import torch

# "한 글자 → 1토큰" 되는 후보를 찾아서 filler 토큰으로 사용
def pick_single_token_id(candidates=("A", "B", "C", "x", "y", "z")):
    for s in candidates:
        ids = tok.encode(s, add_special_tokens=False)
        if len(ids) == 1 and ids[0] != tok.eos_token_id:
            return ids[0]
    # fallback: vocab 중앙값
    return min(len(tok), 32_000) - 10

FILL_ID = pick_single_token_id()
HEAD_ID = pick_single_token_id(("?", ":", ".", "!"))

def make_ids(length: int, fill_id: int = FILL_ID):
    return torch.full((1, length), fill_id, dtype=torch.long, device=device)

def build_segments_lengths():
    S_ids  = make_ids(LEN_S,  FILL_ID)
    U_ids  = make_ids(LEN_U,  FILL_ID)
    D_ids  = make_ids(LEN_D,  FILL_ID)
    H_ids  = make_ids(LEN_H,  HEAD_ID)
    # LoPA user 프리필 = Document + User
    U_total = torch.cat([D_ids, U_ids], dim=1)  # 1 x (1024+256)
    return S_ids, U_total, H_ids

S_ids, Utotal_ids, H_ids = build_segments_lengths()
print("lens:", S_ids.size(1), Utotal_ids.size(1), H_ids.size(1))


lens: 256 10496 4


In [4]:
# ==== FLOPs 계산 (Attention-only) ====
# 가정:
#  - 어텐션 FLOPs ≈ 4 * (#heads) * head_dim * (Tq * Tk)     (QKᵀ + AV 두 matmul)
#  - Prefill에서 한 번에 블록 길이 N을 넣으면 per-layer FLOPs ≈ 4 * H * Dh * N^2
#  - Generation(단일 토큰 루프) per-layer FLOPs ≈ 4 * H * Dh * sum_i Tk_i
#  - TRI: Prefill은 Upper: (S+H)^2, Lower: (S+U+H)^2 (Header 포함)
#         Generation에서 Upper: Tk_i = S+H+(i-1), Lower: Tk_i = S+U+H+(i-1)

def _attn_cfg(model):
    cfg = model.config
    H = int(cfg.num_attention_heads)
    Dh = getattr(cfg, "head_dim", cfg.hidden_size // H)
    L = int(cfg.num_hidden_layers)
    return L, H, Dh

def flops_vanilla_attention(model, S:int, U:int, Htok:int, A:int):
    L, Hh, Dh = _attn_cfg(model)
    N = S + U + Htok
    # Prefill: per-layer 4*H*Dh*N^2
    pre_per_layer = 4.0 * Hh * Dh * (N**2)
    # Gen: sum Tk = A*N + A*(A-1)/2
    sum_Tk = A * N + (A * (A - 1)) / 2.0
    gen_per_layer = 4.0 * Hh * Dh * sum_Tk
    pre = pre_per_layer * L
    gen = gen_per_layer * L
    tot = pre + gen
    return pre, gen, tot

def flops_tri_attention(model, S:int, U:int, Htok:int, A:int, K:int):
    L, Hh, Dh = _attn_cfg(model)
    K = int(K)
    # Prefill:
    #  Lower K layers: (S+U+H)^2, Upper L-K layers: (S+H)^2
    N_lower = S + U + Htok
    N_upper = S + Htok
    pre_lower = 4.0 * Hh * Dh * (N_lower**2) * K
    pre_upper = 4.0 * Hh * Dh * (N_upper**2) * (L - K)
    pre = pre_lower + pre_upper
    # Generation:
    #  sum Tk lower = A*(S+U+H) + A(A-1)/2
    #  sum Tk upper = A*(S+H)   + A(A-1)/2
    sumTk_lower = A * (S + U + Htok) + (A * (A - 1)) / 2.0
    sumTk_upper = A * (S + Htok)     + (A * (A - 1)) / 2.0
    gen_lower = 4.0 * Hh * Dh * sumTk_lower * K
    gen_upper = 4.0 * Hh * Dh * sumTk_upper * (L - K)
    gen = gen_lower + gen_upper
    tot = pre + gen
    return pre, gen, tot

def to_gflops(x):   # FLOPs → GFLOPs
    return x / 1e9

def throughput_gflops_per_s(gflops, seconds):
    return (gflops / max(1e-9, seconds))


In [5]:
import time

def sync():
    if device.type == "cuda":
        torch.cuda.synchronize()

def reset_peak():
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()

def get_mem_mib():
    if device.type != "cuda": 
        return {"alloc_MiB": 0.0, "reserved_MiB": 0.0}
    alloc = torch.cuda.max_memory_allocated() / (1024**2)
    reserv= torch.cuda.max_memory_reserved() / (1024**2)
    return {"alloc_MiB": round(alloc,2), "reserved_MiB": round(reserv,2)}


In [6]:
import torch.nn.functional as F
import time

@torch.inference_mode()
def profile_vanilla(model, S_ids, Utotal_ids, H_ids, gen_len, greedily=True):
    S = int(S_ids.size(1)); U = int(Utotal_ids.size(1)); Htok = int(H_ids.size(1)); A = int(gen_len)
    prompt = torch.cat([S_ids, Utotal_ids, H_ids], dim=1)
    L_prompt = prompt.size(1)

    # FLOPs 이론값(어텐션 기준)
    FLOP_pre, FLOP_gen, FLOP_tot = flops_vanilla_attention(model, S, U, Htok, A)

    reset_peak(); sync(); t0 = time.perf_counter()
    out = model(input_ids=prompt, use_cache=True)
    sync(); t_prefill = time.perf_counter() - t0

    logits = out.logits[:, -1, :]
    if greedily:
        next_id = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        probs = F.softmax(logits, dim=-1); next_id = torch.multinomial(probs, num_samples=1)
    pkv = out.past_key_values

    sync(); t1 = time.perf_counter()
    out = model(input_ids=next_id, past_key_values=pkv, use_cache=True)
    sync(); t_first = time.perf_counter() - t1

    cur = next_id; pkv = out.past_key_values
    sync(); tg0 = time.perf_counter()
    for _ in range(gen_len - 1):
        out = model(input_ids=cur, past_key_values=pkv, use_cache=True)
        logits = out.logits[:, -1, :]
        if greedily:
            cur = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits, dim=-1); cur = torch.multinomial(probs, num_samples=1)
        pkv = out.past_key_values
    sync(); t_gen = time.perf_counter() - tg0

    mem = get_mem_mib()
    ttft = t_prefill + t_first
    ttot = ttft + t_gen

    # Throughput (GFLOPs/s)
    pre_gflops   = to_gflops(FLOP_pre)
    gen_gflops   = to_gflops(FLOP_gen)
    tot_gflops   = to_gflops(FLOP_tot)
    pre_gflops_s = throughput_gflops_per_s(pre_gflops, t_prefill)
    gen_gflops_s = throughput_gflops_per_s(gen_gflops, t_gen)
    tot_gflops_s = throughput_gflops_per_s(tot_gflops, ttot)

    metrics = {
        "model": "vanilla",
        "lens": {"S": S, "U_total": U, "H": Htok, "gen": A, "prompt": L_prompt},
        "prefill_ms": round(t_prefill*1000, 3),
        "first_token_ms": round(t_first*1000, 3),
        "ttft_ms": round(ttft*1000, 3),
        "gen_ms": round(t_gen*1000, 3),
        "prefill_ms_per_tok": round(1000 * t_prefill / L_prompt, 5),
        "gen_ms_per_tok": round(1000 * t_gen / A, 5),
        "total_ms_per_tok": round(1000 * ttot / (L_prompt + A), 5),
        "total_ms": round(ttot*1000, 3),
        "peak_mem_alloc_MiB": mem["alloc_MiB"],
        "peak_mem_reserved_MiB": mem["reserved_MiB"],
        # FLOPs/GFLOPs
        "FLOPs_prefill": int(FLOP_pre), "FLOPs_gen": int(FLOP_gen), "FLOPs_total": int(FLOP_tot),
        "GFLOPs_prefill": round(pre_gflops, 3), "GFLOPs_gen": round(gen_gflops, 3), "GFLOPs_total": round(tot_gflops, 3),
        "GFLOPs/s_prefill": round(pre_gflops_s, 2), "GFLOPs/s_gen": round(gen_gflops_s, 2), "GFLOPs/s_total": round(tot_gflops_s, 2),
    }
    return metrics

@torch.inference_mode()
def profile_tri(model, S_ids, Utotal_ids, H_ids, lower_k, gen_len, greedily=True):
    S = int(S_ids.size(1)); U = int(Utotal_ids.size(1)); Htok = int(H_ids.size(1)); A = int(gen_len)
    L_prompt = S + U + Htok

    # FLOPs 이론값(어텐션 기준)
    FLOP_pre, FLOP_gen, FLOP_tot = flops_tri_attention(model, S, U, Htok, A, K=lower_k)

    reset_peak(); sync(); t0 = time.perf_counter()
    pkv, S_len, U_len = model.tri_build_caches(system_ids=S_ids, user_ids=Utotal_ids, lower_k=lower_k)
    out = model.tri_step_logits(H_ids, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
    sync(); t_prefill = time.perf_counter() - t0

    logits = out.logits[:, -1, :]
    if greedily:
        cur = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        probs = F.softmax(logits, dim=-1); cur = torch.multinomial(probs, num_samples=1)

    sync(); t1 = time.perf_counter()
    out = model.tri_step_logits(cur, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
    sync(); t_first = time.perf_counter() - t1

    sync(); tg0 = time.perf_counter()
    for _ in range(gen_len - 1):
        out = model.tri_step_logits(cur, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
        logits = out.logits[:, -1, :]
        if greedily:
            cur = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits, dim=-1); cur = torch.multinomial(probs, num_samples=1)
    sync(); t_gen = time.perf_counter() - tg0

    mem = get_mem_mib()
    ttft = t_prefill + t_first
    ttot = ttft + t_gen

    pre_gflops   = to_gflops(FLOP_pre)
    gen_gflops   = to_gflops(FLOP_gen)
    tot_gflops   = to_gflops(FLOP_tot)
    pre_gflops_s = throughput_gflops_per_s(pre_gflops, t_prefill)
    gen_gflops_s = throughput_gflops_per_s(gen_gflops, t_gen)
    tot_gflops_s = throughput_gflops_per_s(tot_gflops, ttot)

    metrics = {
        "model": f"LoPA-TRI(K={lower_k})",
        "lens": {"S": S, "U_total": U, "H": Htok, "gen": A, "prompt": L_prompt},
        "prefill_ms": round(t_prefill*1000, 3),
        "first_token_ms": round(t_first*1000, 3),
        "ttft_ms": round(ttft*1000, 3),
        "gen_ms": round(t_gen*1000, 3),
        "prefill_ms_per_tok": round(1000 * t_prefill / L_prompt, 5),
        "gen_ms_per_tok": round(1000 * t_gen / A, 5),
        "total_ms_per_tok": round(1000 * ttot / (L_prompt + A), 5),
        "total_ms": round(ttot*1000, 3),
        "peak_mem_alloc_MiB": mem["alloc_MiB"],
        "peak_mem_reserved_MiB": mem["reserved_MiB"],
        # FLOPs/GFLOPs
        "FLOPs_prefill": int(FLOP_pre), "FLOPs_gen": int(FLOP_gen), "FLOPs_total": int(FLOP_tot),
        "GFLOPs_prefill": round(pre_gflops, 3), "GFLOPs_gen": round(gen_gflops, 3), "GFLOPs_total": round(tot_gflops, 3),
        "GFLOPs/s_prefill": round(pre_gflops_s, 2), "GFLOPs/s_gen": round(gen_gflops_s, 2), "GFLOPs/s_total": round(tot_gflops_s, 2),
    }
    return metrics


In [7]:
# 워밍업
_ = profile_vanilla(van_model, S_ids[:, :64], Utotal_ids[:, :128], H_ids[:, :2], gen_len=32, greedily=True)
_ = profile_tri(tri_model, S_ids[:, :64], Utotal_ids[:, :128], H_ids[:, :2], lower_k=LOWER_K, gen_len=32, greedily=True)

# 본측정
m_van = profile_vanilla(van_model, S_ids, Utotal_ids, H_ids, gen_len=LEN_GEN, greedily=True)
m_tri = profile_tri(tri_model, S_ids, Utotal_ids, H_ids, lower_k=LOWER_K, gen_len=LEN_GEN, greedily=True)

import pandas as pd
def _row(m):
    return {
        "Model": m["model"],
        "Prompt(S/U/H)": f'{m["lens"]["S"]}/{m["lens"]["U_total"]}/{m["lens"]["H"]}',
        "GenToks": m["lens"]["gen"],
        "Prefill (ms)": m["prefill_ms"], "FirstTok (ms)": m["first_token_ms"], "TTFT (ms)": m["ttft_ms"], "Gen (ms)": m["gen_ms"],
        "Prefill (ms/tok)": m["prefill_ms_per_tok"], "Gen (ms/tok)": m["gen_ms_per_tok"], "Total (ms/tok)": m["total_ms_per_tok"], "Total (ms)": m["total_ms"],
        "Peak alloc (MiB)": m["peak_mem_alloc_MiB"], "Peak reserved (MiB)": m["peak_mem_reserved_MiB"],
        # FLOPs
        "GFLOPs Prefill (theory)": m["GFLOPs_prefill"], "GFLOPs/s Prefill": m["GFLOPs/s_prefill"],
        "GFLOPs Gen (theory)": m["GFLOPs_gen"], "GFLOPs/s Gen": m["GFLOPs/s_gen"],
        "GFLOPs Total (theory)": m["GFLOPs_total"], "GFLOPs/s Total": m["GFLOPs/s_total"],
    }

df = pd.DataFrame([_row(m_van), _row(m_tri)])
display(df)


Unnamed: 0,Model,Prompt(S/U/H),GenToks,Prefill (ms),FirstTok (ms),TTFT (ms),Gen (ms),Prefill (ms/tok),Gen (ms/tok),Total (ms/tok),Total (ms),Peak alloc (MiB),Peak reserved (MiB),GFLOPs Prefill (theory),GFLOPs/s Prefill,GFLOPs Gen (theory),GFLOPs/s Gen,GFLOPs Total (theory),GFLOPs/s Total
0,vanilla,256/10496/4,512,830.868,18.556,849.424,9030.42,0.07725,17.63754,0.87681,9879.844,34741.07,96564.0,60655.684,73002.8,2955.877,327.32,63611.561,6438.52
1,LoPA-TRI(K=8),256/10496/4,512,278.757,22.794,301.551,11745.232,0.02592,22.93991,1.06911,12046.782,33262.51,96562.0,15190.502,54493.66,842.753,71.75,16033.256,1330.92


In [6]:
@torch.inference_mode()
def profile_tri(model, S_ids, Utotal_ids, H_ids, lower_k=8, gen_len=LEN_GEN, greedily=True):
    # Prefill: S/U_total
    reset_peak(); sync(); t0 = time.perf_counter()
    pkv, S_len, U_len = model.tri_build_caches(system_ids=S_ids, user_ids=Utotal_ids, lower_k=lower_k)
    # Header 기록(+ start logits 1개만)
    out = model.tri_step_logits(H_ids, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
    sync(); t_prefill = time.perf_counter() - t0

    # 첫 토큰
    logits = out.logits[:, -1, :]
    if greedily:
        cur = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        probs = F.softmax(logits, dim=-1)
        cur = torch.multinomial(probs, num_samples=1)

    sync(); t1 = time.perf_counter()
    out = model.tri_step_logits(cur, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
    sync(); t_first = time.perf_counter() - t1

    # Generation loop
    sync(); tg0 = time.perf_counter()
    for _ in range(gen_len - 1):
        out = model.tri_step_logits(cur, lower_k, pkv, S_len, U_len, logits_to_keep=1, labels=None, write_cache=True)
        logits = out.logits[:, -1, :]
        if greedily:
            cur = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits, dim=-1)
            cur = torch.multinomial(probs, num_samples=1)
    sync(); t_gen = time.perf_counter() - tg0

    mem = get_mem_mib()
    L_prompt = S_ids.size(1) + Utotal_ids.size(1) + H_ids.size(1)

    metrics = {
        "model": f"LoPA-TRI(K={lower_k})",
        "lens": {"S": S_ids.size(1), "U_total": Utotal_ids.size(1), "H": H_ids.size(1), "gen": gen_len, "prompt": L_prompt},
        "prefill_ms": round(t_prefill*1000, 3),
        "first_token_ms": round(t_first*1000, 3),
        "ttft_ms": round((t_prefill + t_first)*1000, 3),
        "gen_ms": round(t_gen*1000, 3),
        "prefill_ms_per_tok": round(1000 * t_prefill / L_prompt, 5),
        "gen_ms_per_tok": round(1000 * t_gen / gen_len, 5),
        "total_ms_per_tok": round(1000 * (t_prefill + t_first + t_gen) / (L_prompt + gen_len), 5),
        "total_ms": round((t_prefill + t_first + t_gen)*1000, 3),
        "peak_mem_alloc_MiB": mem["alloc_MiB"],
        "peak_mem_reserved_MiB": mem["reserved_MiB"],
    }
    return metrics


In [7]:
# 워밍업(커널/캐시 안정화): 각 1회 짧게
_ = profile_vanilla(van_model, S_ids[:, :64], Utotal_ids[:, :128], H_ids[:, :2], gen_len=32)
_ = profile_tri(tri_model, S_ids[:, :64], Utotal_ids[:, :128], H_ids[:, :2], lower_k=LOWER_K, gen_len=32)

# 본측정 (정확히 512토큰 생성)
m_van = profile_vanilla(van_model, S_ids, Utotal_ids, H_ids, gen_len=LEN_GEN, greedily=True)
m_tri = profile_tri(tri_model, S_ids, Utotal_ids, H_ids, lower_k=LOWER_K, gen_len=LEN_GEN, greedily=True)

print(m_van)
print(m_tri)


{'model': 'vanilla', 'lens': {'S': 256, 'U_total': 10496, 'H': 4, 'gen': 512, 'prompt': 10756}, 'prefill_ms': 830.267, 'first_token_ms': 18.328, 'ttft_ms': 848.595, 'gen_ms': 9041.353, 'prefill_ms_per_tok': 0.07719, 'gen_ms_per_tok': 17.65889, 'total_ms_per_tok': 0.8777, 'total_ms': 9889.948, 'peak_mem_alloc_MiB': 34741.07, 'peak_mem_reserved_MiB': 96564.0}
{'model': 'LoPA-TRI(K=8)', 'lens': {'S': 256, 'U_total': 10496, 'H': 4, 'gen': 512, 'prompt': 10756}, 'prefill_ms': 279.043, 'first_token_ms': 22.784, 'ttft_ms': 301.827, 'gen_ms': 11811.404, 'prefill_ms_per_tok': 0.02594, 'gen_ms_per_tok': 23.06915, 'total_ms_per_tok': 1.07501, 'total_ms': 12113.231, 'peak_mem_alloc_MiB': 33262.51, 'peak_mem_reserved_MiB': 96562.0}


In [None]:
import pandas as pd
# LoRA가 붙어 있으면 병합해서 언로드 (추론 속도 개선)
try:
    from peft import PeftModel
    if isinstance(tri_model, PeftModel):
        tri_model = tri_model.merge_and_unload()
        tri_model = tri_model.to(device).eval()
        print("[info] merged LoRA into base (unloaded PEFT)")
except Exception as e:
    print("[info] LoRA merge skipped:", e)

# 동일 백엔드 재지정
try:
    tri_model.config._attn_implementation = ATTN_IMPL
except Exception:
    pass

# 재측정
m_tri2 = profile_tri(tri_model, S_ids, Utotal_ids, H_ids, lower_k=LOWER_K, gen_len=LEN_GEN, greedily=True)
display(pd.DataFrame([m_van, m_tri, m_tri2]).assign(note=["vanilla","LoPA(before)","LoPA(merged)"]))


[info] merged LoRA into base (unloaded PEFT)


Unnamed: 0,model,lens,prefill_ms,first_token_ms,ttft_ms,gen_ms,prefill_ms_per_tok,gen_ms_per_tok,total_ms_per_tok,total_ms,peak_mem_alloc_MiB,peak_mem_reserved_MiB,note
0,vanilla,"{'S': 256, 'U_total': 10496, 'H': 4, 'gen': 51...",830.267,18.328,848.595,9041.353,0.07719,17.65889,0.8777,9889.948,34741.07,96564.0,vanilla
1,LoPA-TRI(K=8),"{'S': 256, 'U_total': 10496, 'H': 4, 'gen': 51...",279.043,22.784,301.827,11811.404,0.02594,23.06915,1.07501,12113.231,33262.51,96562.0,LoPA(before)
2,LoPA-TRI(K=8),"{'S': 256, 'U_total': 10496, 'H': 4, 'gen': 51...",182.177,15.669,197.846,8192.882,0.01694,16.00172,0.74465,8390.728,32197.51,96552.0,LoPA(merged)


In [8]:
import pandas as pd

def to_row(m):
    return {
        "Model": m["model"],
        "Prompt(S/U/H)": f'{m["lens"]["S"]}/{m["lens"]["U_total"]}/{m["lens"]["H"]}',
        "GenToks": m["lens"]["gen"],
        "Prefill (ms)": m["prefill_ms"],
        "FirstTok (ms)": m["first_token_ms"],
        "TTFT (ms)": m["ttft_ms"],
        "Gen (ms)": m["gen_ms"],
        "Prefill (ms/tok)": m["prefill_ms_per_tok"],
        "Gen (ms/tok)": m["gen_ms_per_tok"],
        "Total (ms/tok)": m["total_ms_per_tok"],
        "Total (ms)": m["total_ms"],
        "Peak alloc (MiB)": m["peak_mem_alloc_MiB"],
        "Peak reserved (MiB)": m["peak_mem_reserved_MiB"],
    }

df = pd.DataFrame([to_row(m_van), to_row(m_tri)])
display(df)


Unnamed: 0,Model,Prompt(S/U/H),GenToks,Prefill (ms),FirstTok (ms),TTFT (ms),Gen (ms),Prefill (ms/tok),Gen (ms/tok),Total (ms/tok),Total (ms),Peak alloc (MiB),Peak reserved (MiB)
0,vanilla,256/10496/4,512,830.267,18.328,848.595,9041.353,0.07719,17.65889,0.8777,9889.948,34741.07,96564.0
1,LoPA-TRI(K=8),256/10496/4,512,279.043,22.784,301.827,11811.404,0.02594,23.06915,1.07501,12113.231,33262.51,96562.0


In [9]:
# TRI 어텐션 fwd를 훅킹해서 어떤 경로/마스크가 쓰이는지 카운트
from transformers.models.llama.modeling_llama import LlamaAttention

_attn_calls = {"fa2_calls":0, "mask_none":0, "mask_additive":0}
_orig_fwd = LlamaAttention.forward

def _hook_fwd(self, hidden_states, position_embeddings, attention_mask,
              past_key_values=None, cache_position=None, **kwargs):
    impl = getattr(self.config, "_attn_implementation", "eager")
    if impl == "flash_attention_2":
        _attn_calls["fa2_calls"] += 1
    if attention_mask is None:
        _attn_calls["mask_none"] += 1
    else:
        _attn_calls["mask_additive"] += 1
    return _orig_fwd(self, hidden_states, position_embeddings, attention_mask,
                     past_key_values=past_key_values, cache_position=cache_position, **kwargs)

# 훅 설치
LlamaAttention.forward = _hook_fwd

# LoPA 한 번 짧게 돌려서 카운트
_ = profile_tri(tri_model, S_ids[:, :64], Utotal_ids[:, :128], H_ids[:, :2], lower_k=LOWER_K, gen_len=8, greedily=True)
print("[ATTN HOOK] calls:", _attn_calls)

# 사용 후 원복
LlamaAttention.forward = _orig_fwd


[ATTN HOOK] calls: {'fa2_calls': 328, 'mask_none': 328, 'mask_additive': 0}
