In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel  # pip install peft
from LatentCOMP_cleaned.infer_lopa_pure import lopa_generate, ensure_mistral_special_token, _get_inner_model

# Hub repository that holds the tokenizer (root), the base weights (subfolder
# "base") and, optionally, a LoRA adapter (subfolder "lora").
repo_id = "jeongseokoh/Llama-3.1-8B-Instruct-LOPA-partial4-0specials"
device = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU.
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32

# Tokenizer (saved at repo root)
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Base (under subfolder=base)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    subfolder="base",
    trust_remote_code=False,
    # `torch_dtype` is reported as deprecated by the Transformers release this
    # notebook runs on (4.56.1); `dtype` is the replacement keyword.
    dtype=dtype,
)
ensure_mistral_special_token(tokenizer, model)

# Merge LoRA if present (under subfolder=lora).
# Best-effort by design: any failure falls back to the base weights, but the
# reason is printed so a genuine load error is not mistaken for "no adapter".
try:
    model = PeftModel.from_pretrained(model, repo_id, subfolder="lora").merge_and_unload()
except Exception as e:
    print(f"[warn] LoRA merge failed or missing, using base only: {e}")

model = model.to(device).eval()

# Force eager attention (stability).
# Each config/key pair is attempted independently so one failing assignment
# does not skip the remaining ones, and failures are reported instead of
# being silently discarded.
try:
    attn_configs = [model.config, _get_inner_model(model).config]
except Exception as e:
    print(f"[warn] could not resolve inner model config: {e}")
    attn_configs = [model.config]
for cfg in attn_configs:
    for k in ("attn_implementation", "_attn_implementation"):
        try:
            setattr(cfg, k, "eager")
        except Exception as e:
            print(f"[warn] could not set {k}='eager' on {type(cfg).__name__}: {e}")

# Fill these
system = "You are a helpful assistant that answers questions based on the given document. "
document = "Replace with your full document text here."
question = "Replace with your question here."
K = 4  # same as training (partial4)

# Generate with LoPA
text = lopa_generate(
    model, tokenizer,
    system=system, document=document, question=question,
    K=K, device=device,
    max_new_tokens=256, min_length=16,
    temperature=0.7, top_p=0.9, top_k=None,
    do_sample=True, debug=True,
)
print(text)


In [1]:
import torch
import transformers

# Report the library versions this notebook was executed with.
torch_version, transformers_version = torch.__version__, transformers.__version__
print("Torch version:", torch_version)
print("Transformers version:", transformers_version)

Torch version: 2.5.1
Transformers version: 4.56.1


In [2]:
import torch
from transformers import AutoModelForCausalLM

# 1) Output location and the ID of the original backbone.
best_dir = "LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"
base_dir = f"{best_dir}/base"
orig_model_id = "meta-llama/Llama-3.1-8B-Instruct"  # backbone used for training

# 2) Load the original backbone (from the local cache when available) and
#    re-save it under best/base.
# trust_remote_code stays False: the goal of this cell is a base checkpoint
# with no remote-code traces, and the other cells of this notebook load the
# same backbone with trust_remote_code=False.
# NOTE(review): no dtype is passed, so the weights are loaded and saved in the
# library's default precision -- confirm that is the intended on-disk format.
m = AutoModelForCausalLM.from_pretrained(
    orig_model_id,
    trust_remote_code=False,
    cache_dir="/data2/jeongseokoh/hub/model",  # cache path used during training
)
# Clear auto_map so base/config.json does not reference remote code.
# Best-effort, but a failure is reported rather than silently ignored.
try:
    setattr(m.config, "auto_map", None)
except Exception as e:
    print(f"[warn] could not clear config.auto_map: {e}")

m.save_pretrained(base_dir, safe_serialization=True)
del m

print("Rewrote clean base to:", base_dir)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[2025-09-13 20:39:19,291] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigne

[2025-09-13 20:39:20,573] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Rewrote clean base to: LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/base


In [1]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU.
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (prefer the one saved in best_dir)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
from peft import PeftModel

# Validate the adapter path before the expensive backbone load, and with a real
# exception: an `assert` would be stripped when Python runs with -O.
if not lora_dir.is_dir():
    raise FileNotFoundError(f"LoRA folder not found: {lora_dir}")

base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    # `torch_dtype` is reported as deprecated by this notebook's Transformers
    # release (see the warning in the cell output); `dtype` replaces it.
    dtype=dtype,
    # Pin SDPA attention at load time instead of patching config attributes
    # afterwards. lopa_generate(force_attn_impl="eager") can still switch the
    # implementation for comparison runs.
    attn_implementation="sdpa",
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    """Return how many layers a key/value cache holds.

    Three layouts are recognised, checked in this order:
      * an object exposing ``key_cache`` (its length is used),
      * an object exposing ``layers`` (its length is used),
      * anything supporting ``len()`` (e.g. a tuple/list of per-layer pairs).

    Returns 0 when the length cannot be determined (for example ``None``).
    """
    if hasattr(pkv, "key_cache"):
        return len(pkv.key_cache)
    if hasattr(pkv, "layers"):
        return len(pkv.layers)
    try:
        return len(pkv)
    except Exception:
        # `except Exception` rather than a bare `except`, so KeyboardInterrupt
        # and SystemExit are not swallowed by this best-effort fallback.
        return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fetch the (key, value) pair stored for layer ``idx``.

    Handles the same cache layouts as ``pkv_len``: parallel ``key_cache`` /
    ``value_cache`` sequences, a ``layers`` sequence whose items expose
    ``keys`` and ``values``, or a plain indexable of per-layer pairs.
    """
    has_parallel_lists = hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache")
    if has_parallel_lists:
        keys = pkv.key_cache[idx]
        values = pkv.value_cache[idx]
        return keys, values
    if not hasattr(pkv, "layers"):
        # Plain indexable: the stored entry is returned as-is.
        return pkv[idx]
    entry = pkv.layers[idx]
    return entry.keys, entry.values

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    """Build a new ``DynamicCache`` holding only the requested layers.

    Each layer's (key, value) pair is read from ``pkv_src`` with ``pkv_get``
    and inserted through ``DynamicCache.update`` under its original index.

    NOTE(review): callers in this file pass ``range(K)``; how ``DynamicCache``
    treats non-consecutive indices is not established here -- confirm before
    passing gaps.
    """
    subset = DynamicCache()
    # Lazy pairing keeps the read/insert order identical to a plain loop.
    entries = ((layer_idx, pkv_get(pkv_src, layer_idx)) for layer_idx in layer_indices)
    for layer_idx, (keys, values) in entries:
        subset.update(keys, values, layer_idx)
    return subset

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    """Return the trailing ``doc_len`` entries of ``k`` and ``v`` along dim 2.

    Both tensors are indexed as 4-D (``[:, :, start:, :]``). The start index is
    computed explicitly instead of using ``-doc_len:``, because ``-0:`` selects
    the *entire* axis: with ``doc_len == 0`` the previous form returned the full
    cache rather than an empty slice.

    A ``doc_len`` larger than the axis returns the whole tensor; a non-positive
    ``doc_len`` returns an empty slice.
    """
    k_start = max(k.shape[2] - doc_len, 0)
    v_start = max(v.shape[2] - doc_len, 0)
    return k[:, :, k_start:, :], v[:, :, v_start:, :]

def build_mask(length: int, batch: int = 1):
    """Return an all-ones int64 mask of shape ``(batch, length)`` on ``device``."""
    mask_shape = (batch, length)
    return torch.ones(mask_shape, dtype=torch.long, device=device)

def build_pos_ids(start: int, length: int, batch: int = 1):
    """Return int64 positions ``start .. start+length-1`` expanded to ``(batch, length)``.

    The tensor lives on the module-level ``device``. ``lopa_generate`` in this
    cell does not call this helper.
    """
    stop = start + length
    positions = torch.arange(start, stop, device=device, dtype=torch.long)
    return positions.unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    """Build the two-message chat (system, user) for a document QA prompt.

    The user turn always carries the document; the question is appended only
    when ``include_query`` is true.
    """
    if include_query:
        user_content = f"Document:\n{document}\n\nQuestion: {question}"
    else:
        user_content = f"Document:\n{document}\n\n"
    system_msg = {"role": "system", "content": system}
    user_msg = {"role": "user", "content": user_content}
    return [system_msg, user_msg]

def apply_chat_template(messages, add_generation_prompt: bool):
    """Render ``messages`` to a prompt string with the tokenizer's chat template.

    If the tokenizer rejects the ``add_generation_prompt`` keyword with a
    ``TypeError``, the template is rendered without it and, when a generation
    prompt was requested, a hard-coded assistant header is appended instead.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    except TypeError:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            rendered = rendered + "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return rendered

def tokens_from_messages(messages, add_generation_prompt=False):
    """Render ``messages`` with the chat template and return their token IDs.

    The rendered text is tokenized without adding special tokens again, and the
    resulting ``input_ids`` tensor is moved to the module-level ``device``.
    """
    prompt_text = apply_chat_template(messages, add_generation_prompt)
    encoded = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")
    return encoded.input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True,  # Phase 2 starts by feeding the assistant header
    force_attn_impl: Optional[str] = None,     # "sdpa" | "eager"
):
    """Generate an answer using a low-only (first ``K`` layers) prompt prefill.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. Steps:

      1. Prefill the system message through the full model.
      2. Run the rest of the prompt (the user turn built by ``build_messages``,
         i.e. document and question) through only the first ``K`` decoder
         layers, continuing from the system cache of those layers.
      3. Combine the caches: layers ``< K`` hold system + user-turn entries,
         layers ``>= K`` hold the system entries only.
      4. Optionally feed the assistant header tokens one at a time.
      5. Decode token by token until EOS or ``max_new_tokens``.

    Args:
        system, document, question: prompt parts passed to ``build_messages``.
        K: number of lower layers used for the prefill; clamped to
            ``[0, n_layers]``.
        max_new_tokens: maximum number of generated tokens.
        min_length: number of initial decode steps during which EOS is masked.
        temperature, top_p, top_k, repetition_penalty: logits processors;
            each is skipped when unset or neutral (``temperature`` of 0/1.0,
            ``top_p >= 1.0``, ``top_k`` None/<=0, penalty None/1.0).
        do_sample: sample from the processed distribution; otherwise argmax.
        start_with_assistant_header: feed the generation-prompt tokens before
            decoding. When false, decoding starts from the last prompt token.
        force_attn_impl: if given, best-effort override of the attention
            implementation on the model configs before running.

    Returns:
        The generated text decoded with ``skip_special_tokens=True``.

    Raises:
        AssertionError: if the prompt has no tokens beyond the system message.
    """
    if force_attn_impl:
        # Best-effort override; failures are deliberately ignored.
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # system + user turn
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # ... + assistant header
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    # NOTE(review): the length arithmetic below assumes the system-only
    # rendering is a token prefix of the full prompt, and that the full prompt
    # is a prefix of the generation-prompt rendering -- confirm for the chat
    # template in use.
    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill through the full base model
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K pass over the remaining prompt tokens (no upper layers).
    # The decoder stack is temporarily replaced by its first K_eff layers; the
    # try/finally guarantees the full stack is restored even if the forward
    # pass raises, so the shared global model is never left truncated.
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    model.model.layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])
    try:
        dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
        out_low = model.model(
            input_ids=ids_phase1[:, L_sys:],            # tokens after the system message
            past_key_values=dc_low_in,
            attention_mask=build_mask(L_sys + L_doc),   # covers system + new tokens
            use_cache=True,
            return_dict=True,
        )
    finally:
        model.model.layers = full_layers
    pkv_low = out_low.past_key_values

    # 3) Combine caches
    #    lower (0..K-1): system + user-turn entries
    #    upper (K..):    system entries only (grows from the start of generation)
    combined = DynamicCache()
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2 seeding with the assistant header.
    # Only the first H-1 header tokens are pushed into the cache here; the last
    # header token becomes the first input of the decode loop. Feeding it in
    # both places would insert it into the cache twice and shift every
    # generated position by one.
    # NOTE(review): with start_with_assistant_header=False the decode loop
    # re-feeds the last prompt token, which the lower layers already cached.
    last = ids_phase1[:, -1:]
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (length H)
        H = int(hdr_tail.size(1))
        for j in range(H - 1):
            past_len = pkv_get(combined, 0)[0].shape[2]  # measured on layer 0
            out_seed = model(
                input_ids=hdr_tail[:, j:j+1],
                past_key_values=combined,
                attention_mask=build_mask(past_len + 1),
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values
        if H > 0:
            last = hdr_tail[:, -1:]

    # 5) Decoding (CausalLM) with manually applied processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # The mask length follows layer 0 of the cache plus the token being fed.
        past_len = pkv_get(combined, 0)[0].shape[2]

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=build_mask(past_len + 1),
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # Manual min-length: block EOS for the first `min_length` steps.
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        # Replace NaN/inf with large finite values before the processors run.
        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        # Processors receive only the generated tokens as history.
        logits = processors(generated_ids.to(logits.device), logits)

        # Fallback-safe sampling: use argmax whenever sampling is not viable.
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        elif do_sample:
            probs = torch.softmax(logits, dim=-1)
            if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
# Example inputs: a short document and a question answerable from it.
question = "Which continent is the Nile the longest river in?"
document = "The Nile is the longest river in Africa and flows northward through several countries..."
system = "You are a helpful assistant that answers questions based on the given document."

# Sampling is enabled by default in lopa_generate, so repeated runs may differ.
answer = lopa_generate(
    system,
    document,
    question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7,
    top_p=0.9,
    start_with_assistant_header=True,  # feed the assistant header at the start of Phase 2
    # force_attn_impl="eager",         # enable to compare against eager attention
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

I'm ready to help. What would you like to know? Please go ahead and ask your question.


In [None]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
# bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU.
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (prefer the one saved in best_dir)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
from peft import PeftModel

# Validate the adapter path before the expensive backbone load, and with a real
# exception: an `assert` would be stripped when Python runs with -O.
if not lora_dir.is_dir():
    raise FileNotFoundError(f"LoRA folder not found: {lora_dir}")

base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    # `torch_dtype` is reported as deprecated by this notebook's Transformers
    # release (see the warning in the cell output); `dtype` replaces it.
    dtype=dtype,
    # Pin SDPA attention at load time instead of patching config attributes
    # afterwards. lopa_generate(force_attn_impl="eager") can still switch the
    # implementation for comparison runs.
    attn_implementation="sdpa",
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    """Return how many layers a key/value cache holds.

    Three layouts are recognised, checked in this order:
      * an object exposing ``key_cache`` (its length is used),
      * an object exposing ``layers`` (its length is used),
      * anything supporting ``len()`` (e.g. a tuple/list of per-layer pairs).

    Returns 0 when the length cannot be determined (for example ``None``).
    """
    if hasattr(pkv, "key_cache"):
        return len(pkv.key_cache)
    if hasattr(pkv, "layers"):
        return len(pkv.layers)
    try:
        return len(pkv)
    except Exception:
        # `except Exception` rather than a bare `except`, so KeyboardInterrupt
        # and SystemExit are not swallowed by this best-effort fallback.
        return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fetch the (key, value) pair stored for layer ``idx``.

    Handles the same cache layouts as ``pkv_len``: parallel ``key_cache`` /
    ``value_cache`` sequences, a ``layers`` sequence whose items expose
    ``keys`` and ``values``, or a plain indexable of per-layer pairs.
    """
    has_parallel_lists = hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache")
    if has_parallel_lists:
        keys = pkv.key_cache[idx]
        values = pkv.value_cache[idx]
        return keys, values
    if not hasattr(pkv, "layers"):
        # Plain indexable: the stored entry is returned as-is.
        return pkv[idx]
    entry = pkv.layers[idx]
    return entry.keys, entry.values

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    """Build a new ``DynamicCache`` holding only the requested layers.

    Each layer's (key, value) pair is read from ``pkv_src`` with ``pkv_get``
    and inserted through ``DynamicCache.update`` under its original index.

    NOTE(review): callers in this file pass ``range(K)``; how ``DynamicCache``
    treats non-consecutive indices is not established here -- confirm before
    passing gaps.
    """
    subset = DynamicCache()
    # Lazy pairing keeps the read/insert order identical to a plain loop.
    entries = ((layer_idx, pkv_get(pkv_src, layer_idx)) for layer_idx in layer_indices)
    for layer_idx, (keys, values) in entries:
        subset.update(keys, values, layer_idx)
    return subset

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    """Return the trailing ``doc_len`` entries of ``k`` and ``v`` along dim 2.

    Both tensors are indexed as 4-D (``[:, :, start:, :]``). The start index is
    computed explicitly instead of using ``-doc_len:``, because ``-0:`` selects
    the *entire* axis: with ``doc_len == 0`` the previous form returned the full
    cache rather than an empty slice.

    A ``doc_len`` larger than the axis returns the whole tensor; a non-positive
    ``doc_len`` returns an empty slice.
    """
    k_start = max(k.shape[2] - doc_len, 0)
    v_start = max(v.shape[2] - doc_len, 0)
    return k[:, :, k_start:, :], v[:, :, v_start:, :]

def build_mask(length: int, batch: int = 1):
    """Return an all-ones int64 mask of shape ``(batch, length)`` on ``device``."""
    mask_shape = (batch, length)
    return torch.ones(mask_shape, dtype=torch.long, device=device)

def build_pos_ids(start: int, length: int, batch: int = 1):
    """Return int64 positions ``start .. start+length-1`` expanded to ``(batch, length)``.

    The tensor lives on the module-level ``device``. ``lopa_generate`` in this
    cell does not call this helper.
    """
    stop = start + length
    positions = torch.arange(start, stop, device=device, dtype=torch.long)
    return positions.unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    """Build the two-message chat (system, user) for a document QA prompt.

    The user turn always carries the document; the question is appended only
    when ``include_query`` is true.
    """
    if include_query:
        user_content = f"Document:\n{document}\n\nQuestion: {question}"
    else:
        user_content = f"Document:\n{document}\n\n"
    system_msg = {"role": "system", "content": system}
    user_msg = {"role": "user", "content": user_content}
    return [system_msg, user_msg]

def apply_chat_template(messages, add_generation_prompt: bool):
    """Render ``messages`` to a prompt string with the tokenizer's chat template.

    If the tokenizer rejects the ``add_generation_prompt`` keyword with a
    ``TypeError``, the template is rendered without it and, when a generation
    prompt was requested, a hard-coded assistant header is appended instead.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    except TypeError:
        rendered = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            rendered = rendered + "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return rendered

def tokens_from_messages(messages, add_generation_prompt=False):
    """Render ``messages`` with the chat template and return their token IDs.

    The rendered text is tokenized without adding special tokens again, and the
    resulting ``input_ids`` tensor is moved to the module-level ``device``.
    """
    prompt_text = apply_chat_template(messages, add_generation_prompt)
    encoded = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt")
    return encoded.input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True,  # Phase 2 starts by feeding the assistant header
    force_attn_impl: Optional[str] = None,     # "sdpa" | "eager"
):
    """Generate an answer using a low-only (first ``K`` layers) prompt prefill.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. Steps:

      1. Prefill the system message through the full model.
      2. Run the rest of the prompt (the user turn built by ``build_messages``,
         i.e. document and question) through only the first ``K`` decoder
         layers, continuing from the system cache of those layers.
      3. Combine the caches: layers ``< K`` hold system + user-turn entries,
         layers ``>= K`` hold the system entries only.
      4. Optionally feed the assistant header tokens one at a time.
      5. Decode token by token until EOS or ``max_new_tokens``.

    Args:
        system, document, question: prompt parts passed to ``build_messages``.
        K: number of lower layers used for the prefill; clamped to
            ``[0, n_layers]``.
        max_new_tokens: maximum number of generated tokens.
        min_length: number of initial decode steps during which EOS is masked.
        temperature, top_p, top_k, repetition_penalty: logits processors;
            each is skipped when unset or neutral (``temperature`` of 0/1.0,
            ``top_p >= 1.0``, ``top_k`` None/<=0, penalty None/1.0).
        do_sample: sample from the processed distribution; otherwise argmax.
        start_with_assistant_header: feed the generation-prompt tokens before
            decoding. When false, decoding starts from the last prompt token.
        force_attn_impl: if given, best-effort override of the attention
            implementation on the model configs before running.

    Returns:
        The generated text decoded with ``skip_special_tokens=True``.

    Raises:
        AssertionError: if the prompt has no tokens beyond the system message.
    """
    if force_attn_impl:
        # Best-effort override; failures are deliberately ignored.
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # system + user turn
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # ... + assistant header
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    # NOTE(review): the length arithmetic below assumes the system-only
    # rendering is a token prefix of the full prompt, and that the full prompt
    # is a prefix of the generation-prompt rendering -- confirm for the chat
    # template in use.
    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill through the full base model
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K pass over the remaining prompt tokens (no upper layers).
    # The decoder stack is temporarily replaced by its first K_eff layers; the
    # try/finally guarantees the full stack is restored even if the forward
    # pass raises, so the shared global model is never left truncated.
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    model.model.layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])
    try:
        dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
        out_low = model.model(
            input_ids=ids_phase1[:, L_sys:],            # tokens after the system message
            past_key_values=dc_low_in,
            attention_mask=build_mask(L_sys + L_doc),   # covers system + new tokens
            use_cache=True,
            return_dict=True,
        )
    finally:
        model.model.layers = full_layers
    pkv_low = out_low.past_key_values

    # 3) Combine caches
    #    lower (0..K-1): system + user-turn entries
    #    upper (K..):    system entries only (grows from the start of generation)
    combined = DynamicCache()
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2 seeding with the assistant header.
    # Only the first H-1 header tokens are pushed into the cache here; the last
    # header token becomes the first input of the decode loop. Feeding it in
    # both places would insert it into the cache twice and shift every
    # generated position by one.
    # NOTE(review): with start_with_assistant_header=False the decode loop
    # re-feeds the last prompt token, which the lower layers already cached.
    last = ids_phase1[:, -1:]
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (length H)
        H = int(hdr_tail.size(1))
        for j in range(H - 1):
            past_len = pkv_get(combined, 0)[0].shape[2]  # measured on layer 0
            out_seed = model(
                input_ids=hdr_tail[:, j:j+1],
                past_key_values=combined,
                attention_mask=build_mask(past_len + 1),
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values
        if H > 0:
            last = hdr_tail[:, -1:]

    # 5) Decoding (CausalLM) with manually applied processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # The mask length follows layer 0 of the cache plus the token being fed.
        past_len = pkv_get(combined, 0)[0].shape[2]

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=build_mask(past_len + 1),
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # Manual min-length: block EOS for the first `min_length` steps.
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        # Replace NaN/inf with large finite values before the processors run.
        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        # Processors receive only the generated tokens as history.
        logits = processors(generated_ids.to(logits.device), logits)

        # Fallback-safe sampling: use argmax whenever sampling is not viable.
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        elif do_sample:
            probs = torch.softmax(logits, dim=-1)
            if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                next_tok = torch.multinomial(probs, num_samples=1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
# Example inputs: a short document and a question answerable from it.
question = "Which continent is the Nile the longest river in?"
document = "The Nile is the longest river in Africa and flows northward through several countries..."
system = "You are a helpful assistant that answers questions based on the given document."

# Sampling is enabled by default in lopa_generate, so repeated runs may differ.
answer = lopa_generate(
    system,
    document,
    question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7,
    top_p=0.9,
    start_with_assistant_header=True,  # feed the assistant header at the start of Phase 2
    # force_attn_impl="eager",         # enable to compare against eager attention
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[Lens] L_sys=40, L_doc=35, L_all=75, header_len=4
[Layers] n_layers=32, K_eff=4
[sys] n_layers=32 | L0:K(1, 8, 40, 128) | L1:K(1, 8, 40, 128) | L31:K(1, 8, 40, 128)
[OK]   system_prefill: all present layers seq_len == 40
[low] h_low=(1, 35, 4096)
[low] n_layers=4 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L3:K(1, 8, 75, 128)
[up] n_layers=32 | L4:K(1, 8, 75, 128) | L5:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[combined(before_hdr)] n_layers=32 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[OK]   combined_prefill: all present layers seq_len == 75
[hdr] past_len before=75, after=79, delta=4
[combined(after_hdr)] n_layers=32 | L0:K(1, 8, 79, 128) | L1:K(1, 8, 79, 128) | L31:K(1, 8, 79, 128)
[decode step 0] past_len 79 -> 80 | next_id=40
[decode step 1] past_len 80 -> 81 | next_id=2846
[decode step 2] past_len 81 -> 82 | next_id=539
[debug answer] I'm not able to answer your question as I don't have any information about a document. Can you please provide the document

In [1]:
import sys, torch
from transformers import AutoTokenizer

model_dir = "/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best"
prefill_layers = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Make the model code shipped inside best/ importable.
sys.path.insert(0, model_dir)
from modeling_partial_layer import LlamaForCausalLM

tok = AutoTokenizer.from_pretrained(model_dir)
# Place the weights on the detected device. The previous hard-coded
# device_map="cuda:0" contradicted the CPU fallback computed above and made
# the follow-up .to(device) redundant on GPU.
model = LlamaForCausalLM.from_pretrained(
    model_dir,
    device_map=device,
    dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()
for k in ("attn_implementation", "_attn_implementation"):
    setattr(model.config, k, "sdpa")
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa..."
question = "Which continent is the Nile the longest river in?"

# NOTE(review): with do_sample=False the run log reports that temperature and
# top_p "may be ignored"; they are kept because the signature of this custom
# generate() is not visible here.
out = model.generate(
    system=system, document=document, query=question,
    compress=False, tokenizer=tok, prefill_layers=prefill_layers,
    max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9
)
print(tok.batch_decode(out, skip_special_tokens=True))  # list of decoded strings


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['system\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a helpful assistant that answers questions based on the given document.user\n\nDocument:\nThe Nile is the longest river in Africa...\n\nQuestion: Which continent is the Nile the longest river in?assistant\n\nThe Nile is the longest river in Africa.']


In [2]:
import os
# Turn on the LATENTRAG_DEBUG flag for this process, then show the model's
# LoRA status report.
# NOTE(review): the flag is set after the model was loaded -- confirm the model
# code reads it at call time rather than at import time.
os.environ.update({"LATENTRAG_DEBUG": "1"})
print(model.lora_debug_status())

peft_available=True | attached=True | merged=False | is_sharded=True | source=/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best/lora | error=None | last_used_lora=True


In [None]:
import sys, torch
from pathlib import Path
from transformers import AutoTokenizer

# Paths (absolute paths are used to avoid ambiguity)
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
prefill_layers = 4  # keep equal to the value used during training

device = "cuda" if torch.cuda.is_available() else "cpu"

# Make the LOPA model code shipped inside best/ importable.
sys.path.insert(0, str(best_dir))
from modeling_partial_layer import LlamaForCausalLM  # for the Llama 3.1 family

# Tokenizer (includes the chat template)
tok = AutoTokenizer.from_pretrained(str(best_dir))
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# 1) Load the clean backbone weights (Llama 3.1 8B Instruct) from the Hub.
#    The class used here is the one imported from best/ above, so the LOPA
#    behaviour comes from that local module; trust_remote_code=False only keeps
#    the Hub from supplying any code of its own.
#    device_map follows the detected device; the previous hard-coded "cuda:0"
#    contradicted the CPU fallback above and made .to(device) redundant on GPU.
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
base_model = LlamaForCausalLM.from_pretrained(
    backbone_id,
    device_map=device,
    dtype=torch.bfloat16,
    trust_remote_code=False,
    cache_dir="/data2/jeongseokoh/hub/model",
).eval()

# 2) Attach the LoRA adapter and merge it into a single model.
from peft import PeftModel
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
merged = peft_model.merge_and_unload().to(device).eval()

# 3) LOPA inference (compress=True -> partial prefill).
#    The adapter is already merged, so use_lora=False prevents re-attaching it.
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

out = merged.generate(
    system=system,
    document=document,
    query=question,
    compress=True,                 # LOPA path
    tokenizer=tok,                 # required
    prefill_layers=prefill_layers, # same as training
    use_lora=False,                # already merged; do not re-attach
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
# NOTE(review): the recorded output of this cell still contains special tokens
# (a repeated assistant header and <|eot_id|>) -- check how this generate()
# decodes its result.
print(out)  # a string or a list of strings


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 0 files: 0it [00:00, ?it/s]

assistant<|end_header_id|>assistant<|end_header_id|>

Africa.<|eot_id|>
