In [1]:
# Notebook 1 — Env & Paths
from pathlib import Path
import os, torch

# 안전: tokenizer 멀티스레드 이슈 방지
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# 수정된 모델링 파일 경로 (이미 앞서 저장/교체했다고 가정)
LOPA_MODELING_PATH = Path("lopa_llama_modeling.py").resolve()
assert LOPA_MODELING_PATH.exists(), f"❌ modeling file not found: {LOPA_MODELING_PATH}"

# 학습 산출물 경로 (예시)
BEST_DIR  = Path("/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best")
BASE_DIR  = BEST_DIR / "base"   # base 가중치
LORA_DIR  = BEST_DIR / "lora"   # LoRA 어댑터(있으면 사용)
assert BASE_DIR.exists(), f"❌ Base dir not found: {BASE_DIR}"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    major_cc = torch.cuda.get_device_capability(0)[0]
    DTYPE = torch.bfloat16 if major_cc >= 8 else torch.float16
else:
    DTYPE = torch.float32

print(f"Device={DEVICE}, DTYPE={DTYPE}")


Device=cuda, DTYPE=torch.bfloat16


In [2]:
# Notebook 2 — Load custom modeling, tokenizer, model
import importlib.util, sys, transformers, transformers.models.llama

def load_custom_llama_modeling(modeling_path: Path):
    target_name = "transformers.models.llama.modeling_llama"
    if target_name in sys.modules:
        del sys.modules[target_name]
    spec = importlib.util.spec_from_file_location(target_name, str(modeling_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load spec for {modeling_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[target_name] = module
    spec.loader.exec_module(module)
    # 간단 검증
    for klass in ("LlamaModel", "LlamaForCausalLM"):
        assert hasattr(module, klass), f"{klass} missing in {modeling_path}"
    return module

_ = load_custom_llama_modeling(LOPA_MODELING_PATH)
print("✅ Loaded custom LoPA modeling into transformers.models.llama.modeling_llama")

from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

# tokenizer: best 루트(없으면 base)
tok_src = BEST_DIR if (BEST_DIR / "tokenizer_config.json").exists() else BASE_DIR
tokenizer = AutoTokenizer.from_pretrained(tok_src, use_fast=True)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

# base 모델 로드
model = LlamaForCausalLM.from_pretrained(
    BASE_DIR,
    torch_dtype=DTYPE,
    device_map="auto",
)
print("Loaded base weights.")

# LoRA 병합(있으면)
if LORA_DIR.exists():
    try:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, LORA_DIR)
        model = model.merge_and_unload()
        print("✅ LoRA merged for inference")
    except Exception as e:
        print(f"⚠️ LoRA merge failed, using wrapped model: {e}")

model.eval()
DEVICE = next(model.parameters()).device
print("Model dtype:", next(model.parameters()).dtype, "| device:", DEVICE)

# 참조용: 전체 레이어 수
TOTAL_LAYERS = model.config.num_hidden_layers
print("TOTAL_LAYERS:", TOTAL_LAYERS)


✅ Loaded custom LoPA modeling into transformers.models.llama.modeling_llama


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded base weights.
✅ LoRA merged for inference
Model dtype: torch.bfloat16 | device: cuda:0
TOTAL_LAYERS: 32


In [3]:
# Notebook 3 — Helpers
import torch

def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Question: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

def apply_chat_template(tokenizer, messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        tmpl = getattr(tokenizer, "chat_template", "") or ""
        if add_generation_prompt and "<|start_header_id|>" in tmpl:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False):
    s = apply_chat_template(tokenizer, messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

def sample_from_logits(logits, *, do_sample=True, temperature=0.7, top_p=0.9,
                       repetition_penalty=1.15, generated_ids=None):
    """ nucleus + repetition penalty + temperature """
    if repetition_penalty != 1.0 and generated_ids is not None and generated_ids.numel() > 0:
        uniq = torch.unique(generated_ids)
        # (logit>0)/penalty, (logit<0)*penalty
        gathered = logits.index_select(dim=-1, index=uniq)
        gathered = torch.where(gathered > 0, gathered / repetition_penalty, gathered * repetition_penalty)
        logits.scatter_(dim=-1, index=uniq.unsqueeze(0), src=gathered)

    if not do_sample:
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / max(1e-6, temperature)
    probs = torch.softmax(logits, dim=-1)

    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        mask = cumsum > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        sorted_probs[mask] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        next_local = torch.multinomial(sorted_probs, num_samples=1)
        next_id = sorted_idx.gather(-1, next_local)
    else:
        next_id = torch.multinomial(probs, num_samples=1)
    return next_id


In [4]:
# ── helper: rope_mode를 임시로 바꿨다가 복원
import contextlib, torch

@contextlib.contextmanager
def _tmp_rope_mode(model, mode: str):
    inner = getattr(model, "model", model)
    prev = getattr(inner, "lopa_rope_mode", "local")
    try:
        inner.lopa_rope_mode = str(mode)
        yield
    finally:
        inner.lopa_rope_mode = prev


# ── patched lopa_generate: "fast_global" 지원
@torch.inference_mode()
def lopa_generate(
    model,
    tokenizer,
    document: str,
    question: str,
    system_prompt: str = "You are a helpful assistant that answers questions based on the given document. ",
    K: int = 8,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.9,
    do_sample: bool = True,
    stop_on_eos: bool = True,
    repetition_penalty: float = 1.15,
    seed_debug: bool = False,
    step_debug: bool = False,
    explicit_empty_upper: bool = False,
    rope_mode: str = "local",       # ← "local" | "global" | "fast_global"
    zero_pad_prefix: bool = False,  # ← global에서만 의미 있음; fast_global에서는 무시
):
    if rope_mode not in ("local", "global", "fast_global"):
        raise ValueError('rope_mode must be "local", "global", or "fast_global"')

    device = next(model.parameters()).device
    inner = getattr(model, "model", model)
    cfg = getattr(model, "config", None)

    # 1) 메시지 토크나이즈
    msgs = build_messages(system_prompt, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False)
    ids_hdr    = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=True)

    # 2) 하층 K 레이어만 prefill
    pref = inner.lopa_prefill_lower_k(input_ids=ids_phase1, lower_k=K, use_cache=True)
    lower_cache = pref.past_key_values
    L_all = lower_cache.get_seq_length()

    # 3) upper 캐시 합성
    if rope_mode == "global":
        if zero_pad_prefix:
            # 정확한 zero-pad 등가: upper에 L_all 길이 0-KV
            combined = inner.lopa_build_zero_padded_cache(
                lower_cache, lower_k=K, batch_size=ids_phase1.size(0),
                device=ids_phase1.device, zero_len=L_all
            )
        elif explicit_empty_upper:
            combined = inner.lopa_build_combined_cache(
                lower_cache, lower_k=K, batch_size=ids_phase1.size(0), device=ids_phase1.device
            )
        else:
            combined = lower_cache
    elif rope_mode == "fast_global":
        if zero_pad_prefix:
            print("[warn] zero_pad_prefix=True 는 fast_global에서 무시됩니다 (virtual zero로 동치 보정).")
        # fast_global은 zero-pad를 만들 필요 없음
        if explicit_empty_upper:
            # 디버깅 가독성용으로 원하면 빈 KV를 명시할 수 있음(필수 아님)
            combined = inner.lopa_build_combined_cache(
                lower_cache, lower_k=K, batch_size=ids_phase1.size(0), device=ids_phase1.device
            )
        else:
            combined = lower_cache
    else:  # local
        if explicit_empty_upper:
            combined = inner.lopa_build_combined_cache(
                lower_cache, lower_k=K, batch_size=ids_phase1.size(0), device=ids_phase1.device
            )
        else:
            combined = lower_cache

    # 4) seed → 첫 로짓
    seed_ids = ids_hdr[:, L_all:] if ids_hdr.size(1) > L_all else ids_phase1[:, -1:]

    if seed_debug:
        print(f"[seed] L_all={L_all}, seed_len={seed_ids.size(1)}")

    with _tmp_rope_mode(model, rope_mode):
        seed_out = model.lopa_step_logits(
            input_ids=seed_ids,
            prefix_len=L_all,
            past_key_values=combined,
            attention_mask_total_len=L_all + seed_ids.size(1),
            logits_to_keep=1,
            labels=None,
        )
    logits = seed_out.logits[:, -1, :]
    pkv = seed_out.past_key_values
    total_len = L_all + seed_ids.size(1)

    # 5) 첫 토큰
    generated = []
    next_id = sample_from_logits(
        logits, do_sample=do_sample, temperature=temperature, top_p=top_p,
        repetition_penalty=repetition_penalty, generated_ids=None
    )
    generated.append(next_id)

    eos_ids = set(t for t in [tokenizer.eos_token_id, getattr(tokenizer, "eot_token_id", None)] if t is not None)
    if stop_on_eos and int(next_id[0, 0]) in eos_ids:
        gen_ids = torch.cat(generated, dim=1)
        return tokenizer.decode(gen_ids[0], skip_special_tokens=True), gen_ids

    last = next_id

    # 6) 디코딩 루프
    for _ in range(max_new_tokens - 1):
        if step_debug:
            print("total_len(before):", total_len)

        with _tmp_rope_mode(model, rope_mode):
            step_out = model.lopa_step_logits(
                input_ids=last,
                prefix_len=total_len,                      # 이번 토큰의 절대 시작 위치
                past_key_values=pkv,
                attention_mask_total_len=total_len + 1,    # 한 토큰 추가 후 총 길이
                logits_to_keep=1,
                labels=None,
            )
        logits = step_out.logits[:, -1, :]
        pkv = step_out.past_key_values
        total_len += 1

        next_id = sample_from_logits(
            logits, do_sample=do_sample, temperature=temperature, top_p=top_p,
            repetition_penalty=repetition_penalty,
            generated_ids=torch.cat(generated, dim=1)
        )
        generated.append(next_id)
        last = next_id
        if stop_on_eos and int(next_id[0, 0]) in eos_ids:
            break

    gen_ids = torch.cat(generated, dim=1) if generated else torch.zeros((1, 0), dtype=torch.long, device=device)
    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return text, gen_ids


In [13]:
# Notebook 5 — Run simple example
doc = """Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting."""
q   = "How much did she earn? Let's think step by step. "
txt, _ = lopa_generate(model, tokenizer, document=doc, question=q,
                       K=8, do_sample=False, rope_mode="fast_global", zero_pad_prefix=True)
print(txt)

[warn] zero_pad_prefix=True 는 fast_global에서 무시됩니다 (virtual zero로 동치 보정).


To find out how much she earned, we need to look at the information provided in the document. However, since you didn't provide any specific details about her earnings, I'll have to make an assumption.

Let's assume that the question is asking for a general idea of what she might earn based on the context given. In this case, the document doesn't mention anything about her salary or income directly. 

However, if we were to consider a more hypothetical scenario where she was working as a writer or editor (given the context), here are some rough estimates:

- According to the Bureau of Labor Statistics, the median annual wage for writers and authors in 2020 was around $67,120.
- For editors, it was around $61,370 per year.

Please note these figures are just examples and may not be relevant to your specific situation. If you could provide more context or clarify which "she" refers to, I'd be happy to try and help further!


In [14]:
# Notebook 5 — Run simple example
doc = """The Nile is a major north-flowing river in northeastern Africa, widely regarded as the longest river in the world..."""
q   = "Which continent is the Nile river located in? Let's think step by step. "

# 재현성(탐욕은 의미 없음, 샘플링시만 영향)
torch.manual_seed(0)
model.model.lopa_rope_mode = "global"
# Greedy (반복 붕괴 진단에 좋음)
answer_greedy, _ = lopa_generate(
    model, tokenizer, document=doc, question=q,
    K=8, max_new_tokens=64,
    do_sample=False, stop_on_eos=True,
    seed_debug=True, step_debug=False,   # 필요시 True로
    explicit_empty_upper=False
)
print("Greedy Answer:\n", answer_greedy)

# Sampling
answer_sample, _ = lopa_generate(
    model, tokenizer, document=doc, question=q,
    K=8, max_new_tokens=64,
    temperature=0.7, top_p=0.9,
    do_sample=True, stop_on_eos=True,
    repetition_penalty=1.15,
    seed_debug=False, step_debug=False
)
print("\nSampled Answer:\n", answer_sample)


[seed] L_all=88, seed_len=4
li | past | global_start | off | local_start
[00]   88 |   92 |   +4 |   88
[01]   88 |   92 |   +4 |   88
[02]   88 |   92 |   +4 |   88
[03]   88 |   92 |   +4 |   88
[04]   88 |   92 |   +4 |   88
[05]   88 |   92 |   +4 |   88
[06]   88 |   92 |   +4 |   88
[07]   88 |   92 |   +4 |   88
[28]    0 |   92 |  +92 |    0
[29]    0 |   92 |  +92 |    0
[30]    0 |   92 |  +92 |    0
[31]    0 |   92 |  +92 |    0
Greedy Answer:
 The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The

Sampled Answer:
 To answer this question, To find the answer to the given information provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provided provide

In [6]:
# Notebook 6 — Compare different K
def run_compare_K(doc, q, Ks=(8, None)):
    outs = {}
    for K in Ks:
        k_val = TOTAL_LAYERS if (K is None or K == "full") else int(K)
        txt, _ = lopa_generate(
            model, tokenizer, document=doc, question=q,
            K=k_val, max_new_tokens=64,
            do_sample=False, stop_on_eos=True,
            repetition_penalty=1.0,
            seed_debug=False, step_debug=False
        )
        outs[k_val] = txt
    return outs

outs = run_compare_K(doc, q, Ks=(8, "full"))
for k, txt in outs.items():
    print(f"\n=== K={k} ===\n{txt}")



=== K=8 ===
The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The

=== K=32 ===
To answer the question, we need to look at the given information. 

Step 1: The Nile river is described as being in northeastern Africa.

Step 2: The question asks for the continent where the Nile river is located.

Step 3: Since Africa is a continent and the Nile river is in northeastern Africa


In [None]:
# Cell 1 — 기본 설정
import os, sys, math, contextlib
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
MISTRAL_ASSIST_START = "<Mistral_start>"

def set_seed(seed: int = 42):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# 🔧 체크포인트 루트 (trainer가 저장한 _best_ckpt)
CKPT_ROOT = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"   # <-- 필요시 절대경로로 바꾸세요. 구조: base/, (옵션) lora/, tokenizer 파일
PREFILL_LAYERS = 8                         # 하위 K 레이어만 prefill
DTYPE = "auto"               # "auto" | "bf16" | "fp16" | "fp32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
# Cell 2 — 로드 유틸

def _get_inner_model(m):
    if hasattr(m, "module"): m = m.module
    try:
        from peft import PeftModel
        if isinstance(m, PeftModel):
            try: m = m.get_base_model()
            except Exception: m = getattr(m, "base_model", m)
    except Exception:
        pass

    for attr in ("model","transformer","backbone","base_model","language_model"):
        if hasattr(m, attr):
            cand = getattr(m, attr)
            if hasattr(cand, "layers") and isinstance(getattr(cand, "layers", None), nn.ModuleList):
                return cand
            if hasattr(cand, "decoder") and hasattr(cand.decoder, "layers") and isinstance(cand.decoder.layers, nn.ModuleList):
                return cand.decoder
    if hasattr(m, "layers") and isinstance(getattr(m, "layers", None), nn.ModuleList):
        return m
    for child in m.modules():
        if child is m: continue
        if hasattr(child, "layers") and isinstance(getattr(child, "layers", None), nn.ModuleList):
            return child
    raise AttributeError("Could not locate inner base model with a .layers attribute")

def _is_mistral_template(tokenizer) -> bool:
    tmpl = getattr(tokenizer, "chat_template", "") or ""
    name = getattr(getattr(tokenizer, "init_kwargs", {}), "get", lambda k, d=None: d)("name_or_path", "")
    return ("[INST]" in tmpl) or ("mistral" in str(name).lower()) or ("mistral" in tmpl.lower())

def ensure_mistral_special_token(tokenizer, model=None):
    if not _is_mistral_template(tokenizer):
        return False
    add_tok = []
    cur = set(tokenizer.get_vocab().keys())
    if MISTRAL_ASSIST_START not in cur:
        add_tok.append(MISTRAL_ASSIST_START)
    if add_tok:
        tokenizer.add_special_tokens({"additional_special_tokens": tokenizer.special_tokens_map_extended.get("additional_special_tokens", []) + add_tok})
        if model is not None:
            try: model.resize_token_embeddings(len(tokenizer))
            except Exception: pass
        return True
    return False

def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

def apply_chat_template(tokenizer, messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        tmpl = getattr(tokenizer, "chat_template", "") or ""
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt and "<|start_header_id|>" in tmpl:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False):
    s = apply_chat_template(tokenizer, messages, add_generation_prompt)
    print(s)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

def pkv_len(pkv) -> int:
    if hasattr(pkv, "key_cache"): return len(pkv.key_cache)
    if hasattr(pkv, "layers"):    return len(pkv.layers)
    return len(pkv)

def pkv_get(pkv, idx: int):
    if hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache"):
        return pkv.key_cache[idx], pkv.value_cache[idx]
    if hasattr(pkv, "layers"):
        layer = pkv.layers[idx]
        return layer.keys, layer.values
    return pkv[idx]

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    dc = DynamicCache()
    for li in layer_indices:
        k, v = pkv_get(pkv_src, li)
        dc.update(k, v, li)
    return dc

def _kv_meta_from_model(model_like):
    try: cfg = getattr(model_like, "config", None) or getattr(_get_inner_model(model_like), "config", None)
    except Exception: cfg = getattr(_get_inner_model(model_like), "config", None)
    num_heads = getattr(cfg, "num_attention_heads", None)
    num_kv    = getattr(cfg, "num_key_value_heads", None) or num_heads
    hidden    = getattr(cfg, "hidden_size", None)
    head_dim  = (hidden // num_heads) if (hidden and num_heads) else None
    try: dtype = next(_get_inner_model(model_like).parameters()).dtype
    except Exception: dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    return int(num_kv), int(head_dim), dtype

def _make_empty_kv(batch: int, num_kv: int, head_dim: int, device, dtype):
    shape = (batch, num_kv, 0, head_dim)
    k = torch.empty(shape, device=device, dtype=dtype)
    v = torch.empty(shape, device=device, dtype=dtype)
    return k.contiguous(), v.contiguous()

# dtype 선택
if DTYPE == "fp32": TORCH_DTYPE = torch.float32
elif DTYPE == "bf16": TORCH_DTYPE = torch.bfloat16
elif DTYPE == "fp16": TORCH_DTYPE = torch.float16
else:
    TORCH_DTYPE = (torch.bfloat16 if (DEVICE=="cuda" and torch.cuda.is_bf16_supported())
                   else (torch.float16 if DEVICE=="cuda" else torch.float32))

# 토크나이저 & 모델 로드
tok = AutoTokenizer.from_pretrained(CKPT_ROOT, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "right"

base_dir = Path(CKPT_ROOT) / "base"
assert base_dir.is_dir(), f"Base backbone not found at {base_dir}"

model = AutoModelForCausalLM.from_pretrained(str(base_dir), trust_remote_code=False, torch_dtype=TORCH_DTYPE).to(DEVICE)

lora_dir = Path(CKPT_ROOT) / "lora"
if lora_dir.is_dir():
    try:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, str(lora_dir))
        print("[Info] LoRA adapters loaded.")
    except Exception as e:
        print(f"[Warn] LoRA load failed: {e}")

# Mistral 템플릿이면 start special 보장
_ = ensure_mistral_special_token(tok, model)

# 안정성: eager 강제 (훈련과 동일)
for k in ("attn_implementation","_attn_implementation"):
    try:
        setattr(model.config, k, "eager")
        setattr(_get_inner_model(model).config, k, "eager")
    except Exception:
        pass

model.eval()
print("Loaded model dtype:", next(model.parameters()).dtype, "| device:", DEVICE)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

[Info] LoRA adapters loaded.
Loaded model dtype: torch.bfloat16 | device: cuda


In [3]:
@contextlib.contextmanager
def lopa_cache_position_patch_global(model, past_key_values, doc_len: int):
    inner = _get_inner_model(model)

    def _pkv_past_len(li: int) -> int:
        if hasattr(past_key_values, "key_cache"):
            return int(past_key_values.key_cache[li].shape[2])
        if hasattr(past_key_values, "layers"):
            return int(past_key_values.layers[li].keys.shape[2])
        return int(past_key_values[li][0].shape[2])

    nL = len(inner.layers)
    past_lens = [ _pkv_past_len(li) for li in range(nL) ]

    handles = []
    for li, layer in enumerate(inner.layers):
        layer._lopa_past = past_lens[li]
        layer._lopa_li = li

        def _pre_hook(module, args, kwargs):
            li_local = getattr(module, "_lopa_li", 0)
            past_len = getattr(module, "_lopa_past", 0)

            cp = kwargs.get("cache_position", None)
            pi = kwargs.get("position_ids", None)
            start_val = None
            if isinstance(cp, torch.Tensor) and cp.numel() > 0:
                start_val = int(cp.view(-1)[0].item())
            elif isinstance(pi, torch.Tensor) and pi.numel() > 0:
                start_val = int(pi.view(-1)[0].item())
            if start_val is None:
                return args, kwargs

            # 전 레이어가 '전역' 절대 위치(=past_len + doc_len 기준)로 보이도록 보정
            # 하층(K): past_len 약 L_sys+L_doc → doc_len 더해도 큰 변화 없음
            # 상층  : past_len=seed_len → doc_len을 더해 절대 프레임을 하층과 맞춤
            desired_start = past_len + doc_len
            off = start_val - desired_start
            if off != 0:
                if isinstance(cp, torch.Tensor): kwargs["cache_position"] = cp - off
                if isinstance(pi, torch.Tensor): kwargs["position_ids"] = pi - off
            return args, kwargs

        handles.append(layer.register_forward_pre_hook(_pre_hook, with_kwargs=True))
    try:
        yield
    finally:
        for h in handles: h.remove()
        for layer in inner.layers:
            for a in ("_lopa_past","_lopa_li"):
                if hasattr(layer,a): delattr(layer,a)


In [4]:
# Cell 4 — prefill 캐시 구성 (하위 K만 sys+doc, 상위는 빈 KV)

def build_combined_prefill_cache(model, tokenizer, system_prompt: str, document: str, question: str, K: int, device):
    msgs = build_messages(system_prompt, document, question, include_query=True)
    print("=== Phase1 ===")
    ids_phase1 = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=False)
    print("=== Header ===\n")
    ids_hdr    = tokens_from_messages(tokenizer, msgs, device, add_generation_prompt=True)
    sys_only   = tokens_from_messages(tokenizer, [{"role":"system","content":system_prompt}], device, add_generation_prompt=False)

    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Document tokens must be > 0"

    inner = _get_inner_model(model)
    full_layers: nn.ModuleList = inner.layers
    n_layers = len(full_layers)
    K_eff = max(0, min(K, n_layers))

    # lower K: [system] prefill
    lower_layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])
    inner.layers = lower_layers
    out_sys_low = inner(input_ids=sys_only, attention_mask=torch.ones_like(sys_only), use_cache=True, return_dict=True)
    pkv_sys_low = out_sys_low.past_key_values

    # then [document] prefill (still lower K)
    dc_low_in = dc_from_subset(pkv_sys_low, list(range(K_eff))) if K_eff > 0 else DynamicCache()
    out_low = inner(input_ids=ids_phase1[:, L_sys:], past_key_values=dc_low_in,
                    attention_mask=None, use_cache=True, return_dict=True)
    pkv_low = out_low.past_key_values

    # restore all layers
    inner.layers = full_layers

    # combined cache: lower K = sys+doc, upper = empty
    combined = DynamicCache()
    num_kv, head_dim, kv_dtype = _kv_meta_from_model(model)
    for li in range(n_layers):
        if li < K_eff:
            k_sys, v_sys = pkv_get(pkv_sys_low, li)
            k_low, v_low = pkv_get(pkv_low, li)
            k_sys_slice = k_sys[:, :, :L_sys, :]
            v_sys_slice = v_sys[:, :, :L_sys, :]
            k_doc = k_low[:, :, -L_doc:, :]
            v_doc = v_low[:, :, -L_doc:, :]
            combined.update(torch.cat([k_sys_slice, k_doc], dim=2).contiguous(),
                            torch.cat([v_sys_slice, v_doc], dim=2).contiguous(), li)
        else:
            k_empty, v_empty = _make_empty_kv(1, num_kv, head_dim, DEVICE, kv_dtype)
            combined.update(k_empty, v_empty, li)

    # header tail as seed
    hdr_tail = ids_hdr[:, L_all:]
    seed_default = hdr_tail if hdr_tail.numel() > 0 else ids_phase1[:, -1:]

    meta = dict(L_sys=int(L_sys), L_doc=int(L_doc), L_all=int(L_all),
                n_layers=int(n_layers), K_eff=int(K_eff))
    return combined, seed_default, ids_phase1, ids_hdr, meta

def print_cache_lengths(cache, tag: str):
    print(f"\n[{tag}] per-layer past lengths")
    n = pkv_len(cache)
    for li in range(n):
        k, _ = pkv_get(cache, li)
        print(f"  layer {li:02d}: past_seq = {int(k.shape[2])}")


In [5]:
# Cell 5 — 입력과 prefill 실행/검사

SYSTEM_PROMPT = "You are a helpful assistant that answers questions based on the given document. "
QUESTION = "What is the capital of France?\nFinally, provide your answer in '\\boxed{answer}' at the end of your explanation."
DOCUMENT = "France is a country in Western Europe. Its capital city is Paris. It is known for art, fashion, and cuisine."

combined, seed_default, ids_phase1, ids_hdr, meta = build_combined_prefill_cache(
    model, tok, SYSTEM_PROMPT, DOCUMENT, QUESTION, PREFILL_LAYERS, DEVICE
)

L_sys, L_doc, L_all, n_layers, K_eff = meta["L_sys"], meta["L_doc"], meta["L_all"], meta["n_layers"], meta["K_eff"]

print("Layers total=", n_layers, "| K_eff=", K_eff, "| L_sys=", L_sys, "| L_doc=", L_doc, "| L_all=", L_all)
print_cache_lengths(combined, tag="prefill")

# 검증: 하위 K는 L_sys+L_doc, 상위는 0
ok = True
for li in range(n_layers):
    k, _ = pkv_get(combined, li)
    expect = (L_sys+L_doc) if li < K_eff else 0
    if int(k.shape[2]) != expect:
        ok = False
print("Prefill KV check:", "OK" if ok else "MISMATCH")

# Seed 확정 (header tail 우선, 그 다음 Mistral start, 그 다음 fallback)
seed = seed_default
if seed.numel() == 0:
    if _is_mistral_template(tok):
        tid = tok.convert_tokens_to_ids(MISTRAL_ASSIST_START)
        if tid is not None and tid >= 0:
            seed = torch.tensor([[int(tid)]], device=DEVICE, dtype=ids_hdr.dtype)
    if seed.numel() == 0:
        seed = ids_phase1[:, -1:]
# 시드 뒤에 'Answer: ' 같은 오프너를 1회만 덧붙인다.
SEED_PREFIX = "To find the capital of France,"  # 또는 "답변: ", "Final answer: " 등 태스크에 맞게
prefix_ids = tok(SEED_PREFIX, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
seed = torch.cat([seed, prefix_ids], dim=1)

# 길이 기대치 검증 위해 meta에 반영(선택)
meta["seed_len"] = int(seed.size(1))
def decode_piece(ids: torch.Tensor, limit=120):
    return tok.decode(ids[0].tolist(), skip_special_tokens=False)[:limit].replace("\n","⏎")

print("\n[Seed]")
print("  seed length =", int(seed.size(1)), "| preview:", decode_piece(seed))


=== Phase1 ===
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant that answers questions based on the given document.<|eot_id|><|start_header_id|>user<|end_header_id|>

Document:
France is a country in Western Europe. Its capital city is Paris. It is known for art, fashion, and cuisine.

Question: What is the capital of France?
Finally, provide your answer in '\boxed{answer}' at the end of your explanation.<|eot_id|>
=== Header ===

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant that answers questions based on the given document.<|eot_id|><|start_header_id|>user<|end_header_id|>

Document:
France is a country in Western Europe. Its capital city is Paris. It is known for art, fashion, and cuisine.

Question: What is the capital of France?
Finally, provide your answer in '\boxed{answer

In [6]:
# Cell 6 — seed만 넣어 logits/Top-K 보기

@torch.no_grad()
def step_once(model, cache, input_ids, patch=True):
    if patch:
        doc_len = L_sys + L_doc
        with lopa_cache_position_patch_global(model, cache, doc_len):
            out = model(input_ids=input_ids, past_key_values=cache, use_cache=True, return_dict=True)
    else:
        out = model(input_ids=input_ids, past_key_values=cache, use_cache=True, return_dict=True)
    return out.logits, out.past_key_values

def topk_from_logits(logits, tokenizer, k=10, temperature=0.0):
    last = logits[:, -1, :]
    if temperature and temperature > 0.0:
        last = last / float(temperature)
    probs = last.softmax(dim=-1)
    top_p, top_i = torch.topk(probs, k, dim=-1)
    toks = [tokenizer.decode([int(t.item())], skip_special_tokens=False) for t in top_i[0]]
    return [(int(i.item()), float(p.item()), t) for i, p, t in zip(top_i[0], top_p[0], toks)]

# seed → 한 번 forward
logits, combined = step_once(model, combined, input_ids=seed, patch=True)
topk = topk_from_logits(logits, tok, k=10, temperature=0.0)

print("[Next-token distribution | after SEED]")
for r,(tid,prob,txt) in enumerate(topk, 1):
    print(f"  {r:2d}. id={tid:6d}  p={prob:8.5f}  tok={repr(txt)}")

print_cache_lengths(combined, tag="after-seed")

# 기대 길이: 하위K = L_sys+L_doc+len(seed), 상위 = len(seed)
seed_len = int(seed.size(1))
ok2 = True
for li in range(n_layers):
    k, _ = pkv_get(combined, li)
    expect = (L_sys+L_doc+seed_len) if li < K_eff else seed_len
    if int(k.shape[2]) != expect:
        ok2 = False
print("After-seed KV check:", "OK" if ok2 else "MISMATCH")


[Next-token distribution | after SEED]
   1. id=  2057  p= 0.98828  tok=' To'
   2. id=   311  p= 0.01099  tok=' to'
   3. id=    11  p= 0.00004  tok=','
   4. id=  1271  p= 0.00000  tok='To'
   5. id=  5257  p= 0.00000  tok=' TO'
   6. id=   315  p= 0.00000  tok=' of'
   7. id=   320  p= 0.00000  tok=' ('
   8. id=  4194  p= 0.00000  tok='\xa0'
   9. id=  9822  p= 0.00000  tok=' France'
  10. id=   264  p= 0.00000  tok=' a'

[after-seed] per-layer past lengths
  layer 00: past_seq = 110
  layer 01: past_seq = 110
  layer 02: past_seq = 110
  layer 03: past_seq = 110
  layer 04: past_seq = 110
  layer 05: past_seq = 110
  layer 06: past_seq = 110
  layer 07: past_seq = 110
  layer 08: past_seq = 11
  layer 09: past_seq = 11
  layer 10: past_seq = 11
  layer 11: past_seq = 11
  layer 12: past_seq = 11
  layer 13: past_seq = 11
  layer 14: past_seq = 11
  layer 15: past_seq = 11
  layer 16: past_seq = 11
  layer 17: past_seq = 11
  layer 18: past_seq = 11
  layer 19: past_seq = 11
  laye

In [7]:
# Cell 7b — 시드를 이미 한 번 적용한 상태(after-seed)에서 바로 스텝 시작
@torch.no_grad()
def generate_stepwise_from_after_seed(model, tokenizer, cache, first_logits, meta, max_new_tokens=20, temperature=0.0, topk=10):
    logits = first_logits  # Cell 6에서 얻은 logits을 그대로 사용 (seed를 다시 넣지 않음)
    gen = []
    for step in range(max_new_tokens):
        # 1) 직전 logits에서 다음 토큰 결정
        if temperature and temperature > 0.0:
            dist = torch.distributions.Categorical(logits=logits[:, -1, :]/float(temperature))
            next_id = dist.sample().unsqueeze(0)
        else:
            next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        # if step == 0:
        #     next_id = [[11439]]
        print(next_id)
        # 2) 그 토큰을 입력으로 한 스텝 전진
        logits, cache = step_once(model, cache, input_ids=next_id, patch=True)

        # 3) 리포트
        report = topk_from_logits(logits, tokenizer, k=topk, temperature=temperature)
        gen.append(int(next_id[0,0].item()))
        print(f"\n[Step {step:02d}] input id={int(next_id[0,0])} '{tokenizer.decode([int(next_id[0,0])], skip_special_tokens=False)}'")
        for r,(tid,prob,txt) in enumerate(report, 1):
            print(f"  {r:2d}. id={tid:6d}  p={prob:8.5f}  tok={repr(txt)}'")

        # 4) 길이 검증 (expect: 하위K = L_sys+L_doc+seed_len+(step+1), 상위 = seed_len+(step+1))
        seed_len =  meta["L_all"] - meta["L_sys"]  # = L_doc?  ← 이미 after-seed라면 seed_len은 Cell 5에서 별도로 보관해둔 값을 쓰는 게 깔끔
        # 권장: Cell 5에서 seed_len을 meta에 넣어두세요. 아래는 예시로 seed_len을 별도로 넘겨 받는 쪽이 안전.
        # 여기서는 설명 단순화를 위해 meta에 seed_len이 저장돼 있다고 가정:
        seed_len = meta.get("seed_len", 4)
        ok_step = True
        for li in range(meta["n_layers"]):
            k, _ = pkv_get(cache, li)
            expect = (meta["L_sys"]+meta["L_doc"]+seed_len+(step+1)) if li < meta["K_eff"] else (seed_len+(step+1))
            if int(k.shape[2]) != expect:
                ok_step = False
        print("  KV check:", "OK" if ok_step else "MISMATCH")
        print("  partial:", tokenizer.decode(gen, skip_special_tokens=False).replace("\n","⏎")[:200])

    return tokenizer.decode(gen, skip_special_tokens=False), cache


In [8]:
# Cell 6을 실행한 직후에:
first_logits = logits  # Cell 6에서 나온 logits
# seed_len을 meta에 기록해두면 위 함수의 기대 길이 계산이 정확해집니다.
meta["seed_len"] = int(seed.size(1))  # 여기서는 4

FINAL, combined = generate_stepwise_from_after_seed(
    model, tok, combined, first_logits, meta,
    max_new_tokens=20, temperature=0.0, topk=10
)
print("\n=== [Final output (raw)] ===")
print(FINAL)


tensor([[2057]], device='cuda:0')

[Step 00] input id=2057 ' To'
   1. id=  1505  p= 1.00000  tok=' find''
   2. id=  1766  p= 0.00000  tok=' found''
   3. id=  2733  p= 0.00000  tok=' feel''
   4. id=  2725  p= 0.00000  tok='.find''
   5. id= 14035  p= 0.00000  tok=' finds''
   6. id=  3990  p= 0.00000  tok='find''
   7. id=  1833  p= 0.00000  tok=' follow''
   8. id=  9455  p= 0.00000  tok=' finding''
   9. id=  7531  p= 0.00000  tok=' Find''
  10. id=  5266  p= 0.00000  tok=' fill''
  KV check: OK
  partial:  To
tensor([[1505]], device='cuda:0')

[Step 01] input id=1505 ' find'
   1. id=   279  p= 1.00000  tok=' the''
   2. id=  1820  p= 0.00000  tok='the''
   3. id=   264  p= 0.00000  tok=' a''
   4. id=   578  p= 0.00000  tok=' The''
   5. id=    11  p= 0.00000  tok=',''
   6. id=   315  p= 0.00000  tok=' of''
   7. id=  1505  p= 0.00000  tok=' find''
   8. id=  1766  p= 0.00000  tok=' found''
   9. id=  2057  p= 0.00000  tok=' To''
  10. id=  3247  p= 0.00000  tok=' THE''
  KV ch


[Step 06] input id=11 ','
   1. id=  2057  p= 1.00000  tok=' To''
   2. id=  5257  p= 0.00000  tok=' TO''
   3. id=   350  p= 0.00000  tok=' T''
   4. id=  1271  p= 0.00000  tok='To''
   5. id=   311  p= 0.00000  tok=' to''
   6. id=  3354  p= 0.00000  tok='.To''
   7. id=   578  p= 0.00000  tok=' The''
   8. id=  1183  p= 0.00000  tok=' Tr''
   9. id=  7054  p= 0.00000  tok=' Top''
  10. id=   666  p= 0.00000  tok=' Th''
  KV check: OK
  partial:  To find the capital of France,
tensor([[2057]], device='cuda:0')

[Step 07] input id=2057 ' To'
   1. id=  1505  p= 1.00000  tok=' find''
   2. id=  1766  p= 0.00000  tok=' found''
   3. id=  2733  p= 0.00000  tok=' feel''
   4. id=  2725  p= 0.00000  tok='.find''
   5. id= 14035  p= 0.00000  tok=' finds''
   6. id=  5266  p= 0.00000  tok=' fill''
   7. id=  9455  p= 0.00000  tok=' finding''
   8. id=  1833  p= 0.00000  tok=' follow''
   9. id=  3887  p= 0.00000  tok=' fund''
  10. id=  3990  p= 0.00000  tok='find''
  KV check: OK
  partial

In [10]:

tok.encode("According")

[128000, 11439]

In [None]:
import json, random, torch
from statistics import mean

# 학습 스크립트에서 쓰던 유틸 alias 가져오기 (이미 전 셀에서 wired 했다면 생략)
# from train_lopa_pure import build_messages, tokens_from_messages, _get_inner_model, pkv_get, _kv_meta_from_model, _make_empty_kv, dc_from_subset, lopa_cache_position_patch

DEVICE = next(model.parameters()).device
SYSTEM = "You are a helpful assistant that answers questions based on the given document. "

def make_combined(tokenizer, model, system_prompt, d, q, K):
    msgs = build_messages(system_prompt, d, q, include_query=True)
    ids_phase1 = tokens_from_messages(tokenizer, msgs, DEVICE, add_generation_prompt=False)
    ids_hdr    = tokens_from_messages(tokenizer, msgs, DEVICE, add_generation_prompt=True)
    sys_only   = tokens_from_messages(tokenizer, [{"role":"system","content":system_prompt}], DEVICE, add_generation_prompt=False)

    L_sys, L_all = sys_only.size(1), ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0
    inner = _get_inner_model(model)
    full_layers = inner.layers
    n_layers = len(full_layers)
    K_eff = max(0, min(int(K), n_layers))

    # lower-K prefill (model(...) 경유로도 가능하지만 여기선 빠른 진단 목적: inner 사용, no_grad)
    lower = torch.nn.ModuleList([full_layers[i] for i in range(K_eff)])
    inner.layers = lower
    with torch.no_grad():
        out_sys = inner(input_ids=sys_only, attention_mask=torch.ones_like(sys_only),
                        use_cache=True, return_dict=True)
    pkv_sys = out_sys.past_key_values
    dc_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff>0 else DynamicCache()
    with torch.no_grad():
        out_low = inner(input_ids=ids_phase1[:, L_sys:], past_key_values=dc_in,
                        attention_mask=None, use_cache=True, return_dict=True)
    pkv_low = out_low.past_key_values
    inner.layers = full_layers

    combined = DynamicCache()
    num_kv, head_dim, kv_dtype = _kv_meta_from_model(model)
    for li in range(n_layers):
        if li < K_eff:
            k_sys, v_sys = pkv_get(pkv_sys, li)
            k_low, v_low = pkv_get(pkv_low, li)
            k_cat = torch.cat([k_sys[:, :, :L_sys, :], k_low[:, :, -L_doc:, :]], dim=2)
            v_cat = torch.cat([v_sys[:, :, :L_sys, :], v_low[:, :, -L_doc:, :]], dim=2)
        else:
            k_cat, v_cat = _make_empty_kv(1, num_kv, head_dim, DEVICE, kv_dtype)
        combined.update(k_cat.contiguous(), v_cat.contiguous(), li)

    hdr_tail = tokens_from_messages(tokenizer, msgs, DEVICE, add_generation_prompt=True)[:, L_all:]
    seed = hdr_tail if hdr_tail.numel() > 0 else ids_phase1[:, -1:]
    return combined, seed, ids_hdr

@torch.inference_mode()
def ce_for_pair(tokenizer, model, q, d, resp, K):
    combined, seed, ids_hdr = make_combined(tokenizer, model, SYSTEM, d, q, K)
    msgs = build_messages(SYSTEM, d, q, include_query=True)
    msgs_ass = msgs + [{"role":"assistant","content":resp}]
    full_ids = tokens_from_messages(tokenizer, msgs_ass, DEVICE, add_generation_prompt=False)
    a = full_ids[:, ids_hdr.size(1):]
    if a.numel() == 0:
        return None
    inp = torch.cat([seed, a], dim=1)
    lab = inp.clone(); lab[:, :seed.size(1)] = -100
    doc_len = L_sys + L_doc
    with lopa_cache_position_patch_global(model, combined, doc_len):
        out = model(input_ids=inp, past_key_values=combined, labels=lab, use_cache=True, return_dict=True)
    return float(out.loss.item()) if out.loss is not None and torch.isfinite(out.loss) else None

# 샘플링
DATA_PATH = "triviaqa_hotpotqa_6000_merged2.jsonl"
pairs = []
with open(DATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        rec = json.loads(line)
        q = rec.get("question","").strip()
        d = rec.get("document","").strip()
        rs = rec.get("responses",[])
        if q and d and rs:
            pairs.append((q,d,rs[0]))
        if len(pairs) >= 24: break

# Shuffled 문서 만들기
docs = [d for _, d, _ in pairs]
random.shuffle(docs)
shuffled = [(q, d_shuf, r) for (q,_,r), d_shuf in zip(pairs, docs)]

def run_block(title, items, K):
    vals = []
    for (q,d,r) in items:
        v = ce_for_pair(tokenizer, model, q, d, r, K)
        if v is not None: vals.append(v)
    print(f"{title} | K={K} | N={len(vals)} | mean CE={mean(vals):.6f}")
    return vals

print("=== CE sanity ===")
_ = run_block("LoPA", pairs, K=4)
_ = run_block("Full ", pairs, K=9999)     # effectively all layers
_ = run_block("Shuf ", shuffled, K=4)


NameError: name 'model' is not defined

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel  # pip install peft
from LatentCOMP_cleaned.infer_lopa_pure import lopa_generate, ensure_mistral_special_token, _get_inner_model

repo_id = "jeongseokoh/Llama-3.1-8B-Instruct-LOPA-partial4-0specials"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if (device=="cuda" and torch.cuda.is_bf16_supported()) else (torch.float16 if device=="cuda" else torch.float32)

# Tokenizer (saved at repo root)
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Base (under subfolder=base)
model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder="base", trust_remote_code=False, torch_dtype=dtype)
ensure_mistral_special_token(tokenizer, model)

# Merge LoRA if present (under subfolder=lora)
try:
    model = PeftModel.from_pretrained(model, repo_id, subfolder="lora").merge_and_unload()
except Exception as e:
    print(f"[warn] LoRA merge failed or missing, using base only: {e}")

model = model.to(device).eval()

# Force eager attention (stability)
for k in ("attn_implementation", "_attn_implementation"):
    try:
        setattr(model.config, k, "eager")
        setattr(_get_inner_model(model).config, k, "eager")
    except Exception:
        pass

# Fill these
system = "You are a helpful assistant that answers questions based on the given document. "
document = "Replace with your full document text here."
question = "Replace with your question here."
K = 4  # same as training (partial4)

# Generate with LoPA
text = lopa_generate(
    model, tokenizer,
    system=system, document=document, question=question,
    K=K, device=device,
    max_new_tokens=256, min_length=16,
    temperature=0.7, top_p=0.9, top_k=None,
    do_sample=True, debug=True,
)
print(text)


In [1]:
import torch
import transformers

print("Torch version:", torch.__version__)
print("Transformers version:", transformers.__version__)

Torch version: 2.5.1
Transformers version: 4.56.1


In [2]:
import torch
from transformers import AutoModelForCausalLM

# 1) 경로와 원본 모델 ID 지정
best_dir = "LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"
base_dir = f"{best_dir}/base"
orig_model_id = "meta-llama/Llama-3.1-8B-Instruct"  # 학습에 쓴 원본 모델 이름

# 2) 원본 베이스 모델 로드(캐시 사용), 저장
m = AutoModelForCausalLM.from_pretrained(
    orig_model_id,
    trust_remote_code=True,          # Llama3 계열이면 True 괜찮습니다
    cache_dir="/data2/jeongseokoh/hub/model"  # 학습 때 쓰던 캐시 경로가 있으면 지정
)
# base/config.json에 remote code 흔적이 섞이지 않도록 auto_map 클리어(있다면)
try:
    setattr(m.config, "auto_map", None)
except Exception:
    pass

m.save_pretrained(base_dir, safe_serialization=True)
del m

print("Rewrote clean base to:", base_dir)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[2025-09-13 20:39:19,291] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigne

[2025-09-13 20:39:20,573] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Rewrote clean base to: LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/base


In [1]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (best_dir 우선)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    torch_dtype=dtype,
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

from peft import PeftModel
assert lora_dir.is_dir(), f"LoRA folder not found: {lora_dir}"
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# SDPA 고정 (필요시 eager로 비교)
for k in ("attn_implementation", "_attn_implementation"):
    try: setattr(model.config, k, "sdpa")
    except: pass
try:
    for k in ("attn_implementation", "_attn_implementation"):
        setattr(model.model.config, k, "sdpa")
except: pass

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    if hasattr(pkv, "key_cache"): return len(pkv.key_cache)
    if hasattr(pkv, "layers"):    return len(pkv.layers)
    try: return len(pkv)
    except: return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache"):
        return pkv.key_cache[idx], pkv.value_cache[idx]
    if hasattr(pkv, "layers"):
        layer = pkv.layers[idx]
        return layer.keys, layer.values
    return pkv[idx]

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    dc = DynamicCache()
    for li in layer_indices:
        k, v = pkv_get(pkv_src, li)
        dc.update(k, v, li)  # layer index 유지
    return dc

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    return k[:, :, -doc_len:, :], v[:, :, -doc_len:, :]

def build_mask(length: int, batch: int = 1):
    return torch.ones(batch, length, device=device, dtype=torch.long)

def build_pos_ids(start: int, length: int, batch: int = 1):
    return torch.arange(start, start + length, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role":"system","content":system},{"role":"user","content":user}]

def apply_chat_template(messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(messages, add_generation_prompt=False):
    s = apply_chat_template(messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True, # Phase2 첫 토큰: header를 단계적으로 투입
    force_attn_impl: Optional[str] = None,   # "sdpa" | "eager"
):
    if force_attn_impl:
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # [1, L_sys+doc]
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # [1, L_sys+doc + header]
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill (base model)
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K doc pass (base model) — no upper continue
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    lower_layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])

    model.model.layers = lower_layers
    dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
    attn_doc_full = torch.cat([build_mask(L_sys), build_mask(L_doc)], dim=1)

    out_low = model.model(
        input_ids=ids_phase1[:, L_sys:],  # doc only
        past_key_values=dc_low_in,
        attention_mask=attn_doc_full,     # sys+doc 길이
        use_cache=True,
        return_dict=True,
    )
    pkv_low = out_low.past_key_values

    # 복원
    model.model.layers = full_layers

    # 3) Combine caches
    # lower(0..K-1): sys + doc
    # upper(K..):    sys only  (doc 미포함; generation 시작 시점부터 쌓임)
    combined = DynamicCache()
    # 헤드/차원 shape 확보
    k0_sys, v0_sys = pkv_get(pkv_sys, 0)
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            # upper: sys only
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2: Generation
    # seed: assistant header를 한 토큰씩 밀어 넣어 upper past가 L_sys → L_sys+H로 자라도록 함
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (len H)
        H = int(hdr_tail.size(1))
        for j in range(H):
            past_len = pkv_get(combined, 0)[0].shape[2]  # lower 기준: L_sys+L_doc+grown
            attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)
            step_tok = hdr_tail[:, j:j+1]
            out_seed = model(
                input_ids=step_tok,
                past_key_values=combined,
                attention_mask=attn_mask,
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values

    # 5) Decoding (CausalLM) with safe processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)
    # 첫 생성 스텝의 입력은 "직전 토큰" (header 마지막 또는 마지막 doc)
    last = (ids_hdr[:, -1:] if start_with_assistant_header else ids_phase1[:, -1:])

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # lower past_len은 L_sys+L_doc + grown; upper past_len은 L_sys + grown
        past_len = pkv_get(combined, 0)[0].shape[2]
        attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=attn_mask,
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # 즉시 EOS 방지 (min_length 수동 적용)
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        inp = generated_ids if generated_ids.numel() > 0 else last.new_zeros((1, 0), dtype=torch.long, device=device)
        inp = inp.to(logits.device)
        logits = processors(inp, logits)

        # Fallback-safe sampling
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                    next_tok = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    next_tok = torch.multinomial(probs, num_samples=1)
            else:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

answer = lopa_generate(
    system, document, question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7, top_p=0.9,
    start_with_assistant_header=True,  # Phase2 처음에 header를 단계적으로 투입
    # force_attn_impl="eager",         # 필요시 비교
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

I'm ready to help. What would you like to know? Please go ahead and ask your question.


In [None]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (best_dir 우선)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    torch_dtype=dtype,
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

from peft import PeftModel
assert lora_dir.is_dir(), f"LoRA folder not found: {lora_dir}"
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# SDPA 고정 (필요시 eager로 비교)
for k in ("attn_implementation", "_attn_implementation"):
    try: setattr(model.config, k, "sdpa")
    except: pass
try:
    for k in ("attn_implementation", "_attn_implementation"):
        setattr(model.model.config, k, "sdpa")
except: pass

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    if hasattr(pkv, "key_cache"): return len(pkv.key_cache)
    if hasattr(pkv, "layers"):    return len(pkv.layers)
    try: return len(pkv)
    except: return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache"):
        return pkv.key_cache[idx], pkv.value_cache[idx]
    if hasattr(pkv, "layers"):
        layer = pkv.layers[idx]
        return layer.keys, layer.values
    return pkv[idx]

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    dc = DynamicCache()
    for li in layer_indices:
        k, v = pkv_get(pkv_src, li)
        dc.update(k, v, li)  # layer index 유지
    return dc

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    return k[:, :, -doc_len:, :], v[:, :, -doc_len:, :]

def build_mask(length: int, batch: int = 1):
    return torch.ones(batch, length, device=device, dtype=torch.long)

def build_pos_ids(start: int, length: int, batch: int = 1):
    return torch.arange(start, start + length, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role":"system","content":system},{"role":"user","content":user}]

def apply_chat_template(messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(messages, add_generation_prompt=False):
    s = apply_chat_template(messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True, # Phase2 첫 토큰: header를 단계적으로 투입
    force_attn_impl: Optional[str] = None,   # "sdpa" | "eager"
):
    if force_attn_impl:
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # [1, L_sys+doc]
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # [1, L_sys+doc + header]
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill (base model)
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K doc pass (base model) — no upper continue
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    lower_layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])

    model.model.layers = lower_layers
    dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
    attn_doc_full = torch.cat([build_mask(L_sys), build_mask(L_doc)], dim=1)

    out_low = model.model(
        input_ids=ids_phase1[:, L_sys:],  # doc only
        past_key_values=dc_low_in,
        attention_mask=attn_doc_full,     # sys+doc 길이
        use_cache=True,
        return_dict=True,
    )
    pkv_low = out_low.past_key_values

    # 복원
    model.model.layers = full_layers

    # 3) Combine caches
    # lower(0..K-1): sys + doc
    # upper(K..):    sys only  (doc 미포함; generation 시작 시점부터 쌓임)
    combined = DynamicCache()
    # 헤드/차원 shape 확보
    k0_sys, v0_sys = pkv_get(pkv_sys, 0)
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            # upper: sys only
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2: Generation
    # seed: assistant header를 한 토큰씩 밀어 넣어 upper past가 L_sys → L_sys+H로 자라도록 함
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (len H)
        H = int(hdr_tail.size(1))
        for j in range(H):
            past_len = pkv_get(combined, 0)[0].shape[2]  # lower 기준: L_sys+L_doc+grown
            attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)
            step_tok = hdr_tail[:, j:j+1]
            out_seed = model(
                input_ids=step_tok,
                past_key_values=combined,
                attention_mask=attn_mask,
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values

    # 5) Decoding (CausalLM) with safe processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)
    # 첫 생성 스텝의 입력은 "직전 토큰" (header 마지막 또는 마지막 doc)
    last = (ids_hdr[:, -1:] if start_with_assistant_header else ids_phase1[:, -1:])

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # lower past_len은 L_sys+L_doc + grown; upper past_len은 L_sys + grown
        past_len = pkv_get(combined, 0)[0].shape[2]
        attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=attn_mask,
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # 즉시 EOS 방지 (min_length 수동 적용)
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        inp = generated_ids if generated_ids.numel() > 0 else last.new_zeros((1, 0), dtype=torch.long, device=device)
        inp = inp.to(logits.device)
        logits = processors(inp, logits)

        # Fallback-safe sampling
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                    next_tok = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    next_tok = torch.multinomial(probs, num_samples=1)
            else:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

answer = lopa_generate(
    system, document, question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7, top_p=0.9,
    start_with_assistant_header=True,  # Phase2 처음에 header를 단계적으로 투입
    # force_attn_impl="eager",         # 필요시 비교
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[Lens] L_sys=40, L_doc=35, L_all=75, header_len=4
[Layers] n_layers=32, K_eff=4
[sys] n_layers=32 | L0:K(1, 8, 40, 128) | L1:K(1, 8, 40, 128) | L31:K(1, 8, 40, 128)
[OK]   system_prefill: all present layers seq_len == 40
[low] h_low=(1, 35, 4096)
[low] n_layers=4 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L3:K(1, 8, 75, 128)
[up] n_layers=32 | L4:K(1, 8, 75, 128) | L5:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[combined(before_hdr)] n_layers=32 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[OK]   combined_prefill: all present layers seq_len == 75
[hdr] past_len before=75, after=79, delta=4
[combined(after_hdr)] n_layers=32 | L0:K(1, 8, 79, 128) | L1:K(1, 8, 79, 128) | L31:K(1, 8, 79, 128)
[decode step 0] past_len 79 -> 80 | next_id=40
[decode step 1] past_len 80 -> 81 | next_id=2846
[decode step 2] past_len 81 -> 82 | next_id=539
[debug answer] I'm not able to answer your question as I don't have any information about a document. Can you please provide the document

In [1]:
import sys, torch
from transformers import AutoTokenizer

model_dir = "/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best"
prefill_layers = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

sys.path.insert(0, model_dir)  # best/ 를 import 경로에 추가
from modeling_partial_layer import LlamaForCausalLM

tok = AutoTokenizer.from_pretrained(model_dir)
model = LlamaForCausalLM.from_pretrained(model_dir, device_map="cuda:0", dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
for k in ("attn_implementation", "_attn_implementation"):
    setattr(model.config, k, "sdpa")
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa..."
question = "Which continent is the Nile the longest river in?"

out = model.generate(
    system=system, document=document, query=question,
    compress=False, tokenizer=tok, prefill_layers=prefill_layers,
    max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9
)
print(tok.batch_decode(out, skip_special_tokens=True))  # 문자열 또는 문자열 리스트


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['system\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a helpful assistant that answers questions based on the given document.user\n\nDocument:\nThe Nile is the longest river in Africa...\n\nQuestion: Which continent is the Nile the longest river in?assistant\n\nThe Nile is the longest river in Africa.']


In [2]:
import os
os.environ["LATENTRAG_DEBUG"]="1"
print(model.lora_debug_status())

peft_available=True | attached=True | merged=False | is_sharded=True | source=/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best/lora | error=None | last_used_lora=True


In [None]:
import sys, torch
from pathlib import Path
from transformers import AutoTokenizer

# 경로 설정 (혼동 방지를 위해 절대경로 권장)
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
prefill_layers = 4  # 학습 시 사용한 값과 동일하게

device = "cuda" if torch.cuda.is_available() else "cpu"

# best/의 remote-code(LOPA 구현)를 import할 수 있게 경로 추가
sys.path.insert(0, str(best_dir))
from modeling_partial_layer import LlamaForCausalLM  # Llama 3.1 계열용

# 토크나이저: chat template 포함
tok = AutoTokenizer.from_pretrained(str(best_dir))
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# 1) 백본을 허깅페이스에서 ‘깨끗하게’ 로드 (백본=llama3.1 8B instruct)
#    주의: trust_remote_code=False → 표준 HF 클래스 경로를 통해 가중치만 로드
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
base_model = LlamaForCausalLM.from_pretrained(
    backbone_id,
    device_map="cuda:0",
    dtype=torch.bfloat16,
    trust_remote_code=False,  # 가중치만 가져오고 동작은 our remote-code 클래스가 담당
    cache_dir="/data2/jeongseokoh/hub/model",
).to(device).eval()

# 2) LoRA 어댑터 attach 후 merge (단일 모델로)
from peft import PeftModel
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
merged = peft_model.merge_and_unload().to(device).eval()

# 3) LOPA 추론 (compress=True → partial prefill)
#    이미 merge된 모델이므로 추가 attach를 막기 위해 use_lora=False로 고정
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

out = merged.generate(
    system=system,
    document=document,
    query=question,
    compress=True,                 # LOPA 경로
    tokenizer=tok,                 # 필수
    prefill_layers=prefill_layers, # 학습과 동일
    use_lora=False,                # 이미 merge됐으므로 재-부착 방지
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(out)  # 문자열 또는 문자열 리스트


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 0 files: 0it [00:00, ?it/s]

assistant<|end_header_id|>assistant<|end_header_id|>

Africa.<|eot_id|>
