In [None]:
import os
os.environ["TRI_DEBUG_DEV"] = "1"
import re
import torch
import sys
from pathlib import Path
# ==== HF repo 설정 ====
REPO_ID         = "jeongseokoh/LoPA_Llama3.1_8B_16_Lowers"  # ← repo 루트만!
BASE_SUBFOLDER  = "base"
LORA_SUBFOLDER  = "lora"    # 없으면 자동으로 건너뜀
HF_TOKEN        = os.environ.get("HF_TOKEN", None)  # private면 토큰 필요
LOPA_MODELING_PATH  = "./lopa_llama_modeling.py"  # (완성본) TRI 모델링 파일
ATTN_IMPL           = "flash_attention_2"         # "flash_attention_2" | "eager" | "sdpa"
LOWER_K             = 16                          # 하위 레이어 수(K). tri_info.txt에서 자동 로드 시 덮어씀
MAX_NEW_TOKENS      = 256
TEMPERATURE         = 0.7
TOP_P               = 0.9
REPETITION_PENALTY  = 1.0
SEED_TEXT           = ""                   # 헤더 뒤에 1~2 토큰 seed 권장

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
        else (torch.float16 if torch.cuda.is_available() else torch.float32)
print("device:", device, "| dtype:", dtype)

# ==== TRI 모델링 주입 유지 ====
import importlib.util, transformers, transformers.models.llama as llama_pkg, sys
target_name = "transformers.models.llama.modeling_llama"
spec = importlib.util.spec_from_file_location(target_name, LOPA_MODELING_PATH)
mod = importlib.util.module_from_spec(spec)
sys.modules.pop(target_name, None)
sys.modules[target_name] = mod
spec.loader.exec_module(mod)
setattr(llama_pkg, "modeling_llama", mod)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
print("[DEBUG] TRI patch loaded:", LOPA_MODELING_PATH)

# ==== 토크나이저 ====
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(REPO_ID, use_fast=True, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# ==== 모델(베이스) ====
model = LlamaForCausalLM.from_pretrained(
    REPO_ID,
    subfolder=BASE_SUBFOLDER,      # ★ subfolder로 지정
    torch_dtype=dtype,
    token=HF_TOKEN,
    cache_dir="/data2/jeongseokoh/hub",
    device_map="auto"
)

# ==== LoRA 어댑터(있으면 로딩) ====
try:
    from peft import PeftModel
    # 허브에서 서브폴더로 바로 로딩
    model = PeftModel.from_pretrained(model, REPO_ID, subfolder=LORA_SUBFOLDER, token=HF_TOKEN)
    print(f"[info] LoRA adapters loaded from: {REPO_ID}/{LORA_SUBFOLDER}")
except Exception as e:
    print("[info] LoRA subfolder not found or load skipped:", str(e).split("\n")[0])

# ==== 어텐션 백엔드 ====
try:
    model.config._attn_implementation = ATTN_IMPL
    print("[info] attn_impl =", ATTN_IMPL)
except Exception:
    pass

# ==== tri_info.txt에서 lower_k 자동 로드 (repo 또는 subfolder 둘 다 시도) ====
from huggingface_hub import hf_hub_download
import re
loaded_lower_k = False
for sub in (None, BASE_SUBFOLDER):
    try:
        tri_path = hf_hub_download(REPO_ID, filename="tri_info.txt", subfolder=sub, token=HF_TOKEN)
        txt = Path(tri_path).read_text(encoding="utf-8")
        m = re.search(r"lower_k\s*=\s*(\d+)", txt)
        if m:
            LOWER_K = int(m.group(1))
            print(f"[info] lower_k loaded from tri_info.txt ({'root' if sub is None else sub}):", LOWER_K)
            loaded_lower_k = True
            break
    except Exception:
        continue
if not loaded_lower_k:
    print("[warn] tri_info.txt not found on hub; using LOWER_K =", LOWER_K)

# ==== 추론 모드 / TF32 ====
torch.set_grad_enabled(False)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# TRI API 확인
for need in ("tri_build_caches", "tri_forward_assistant", "tri_step_logits"):
    assert hasattr(model, need), f"Missing TRI API: {need}"
model.eval()


device: cuda | dtype: torch.bfloat16
[DEBUG] TRI patch loaded: ./lopa_llama_modeling.py


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[info] LoRA adapters loaded from: jeongseokoh/LoPA_Llama3.1_8B_16_Lowers/lora
[info] attn_impl = flash_attention_2
[info] lower_k loaded from tri_info.txt (root): 16


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=4, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=4, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(


In [2]:
from copy import copy
from typing import List, Optional

class TRIInfer:
    def __init__(self, model, tok, device, lower_k: int, attn_impl: str = "flash_attention_2"):
        self.model = model.eval()
        self.tok = tok
        self.device = device
        self.lower_k = int(lower_k)
        try:
            self.model.config._attn_implementation = attn_impl
        except Exception:
            pass

    @staticmethod
    def build_messages(system: str, document: str, question: str, include_query: bool = True):
        user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
        return [{"role":"system","content":system},{"role":"user","content":user}]

    def tokens_from_messages(self, messages, add_generation_prompt: bool):
        try:
            s = self.tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
        except TypeError:
            s = self.tok.apply_chat_template(messages, tokenize=False)
            tmpl = getattr(self.tok, "chat_template", "") or ""
            if add_generation_prompt and "<|start_header_id|>" in tmpl:
                s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return self.tok(s, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device)

    @staticmethod
    def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int:
        L = min(a.size(1), b.size(1))
        eq = (a[0, :L] == b[0, :L])
        nz = (~eq).nonzero(as_tuple=False)
        return int(nz[0, 0]) if nz.numel() else L

    @staticmethod
    def sample_top_p(logits: torch.Tensor, temperature=1.0, top_p=0.9,
                     repetition_penalty: float = 1.0, prev_ids: Optional[torch.Tensor] = None):
        if repetition_penalty != 1.0 and prev_ids is not None and prev_ids.numel() > 0:
            logits = logits.clone()
            logits[:, prev_ids.unique()] /= repetition_penalty
        logits = logits.float() / max(1e-6, float(temperature))
        probs  = torch.softmax(logits, dim=-1)
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum > top_p
            mask[..., 0] = False
            keep = ~mask
            filtered = torch.zeros_like(sorted_probs).masked_scatter_(keep, sorted_probs[keep])
            probs = torch.zeros_like(probs).scatter(dim=-1, index=sorted_idx, src=filtered)
            probs = probs / probs.sum(dim=-1, keepdim=True)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def generate(self, system: str, document: str, question: str,
                 max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9,
                 repetition_penalty: float = 1.0, seed: str = "") -> str:

        msgs  = self.build_messages(system, document, question, include_query=True)
        S_ids = self.tokens_from_messages(msgs[:1], add_generation_prompt=False)
        SU_ids= self.tokens_from_messages(msgs, add_generation_prompt=False)
        SU_gen= self.tokens_from_messages(msgs, add_generation_prompt=True)

        lcp = self.lcp_len(S_ids, SU_ids)
        user_delta   = SU_ids[:, lcp:SU_ids.size(1)]
        header_delta = SU_gen[:, SU_ids.size(1):]

        # 1) S/U 프리필
        pkv, S_len, U_len = self.model.tri_build_caches(S_ids, user_delta, lower_k=self.lower_k)

        # 2) 헤더(+선택 seed) 1회 기록 & 다음 토큰 분포
        head = header_delta
        if seed:
            seed_ids = self.tok(seed, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device)
            head = torch.cat([head, seed_ids], dim=1)

        out = self.model.tri_step_logits(head, self.lower_k, pkv, S_len, U_len,
                                         logits_to_keep=1, labels=None, write_cache=True)
        logits = out.logits[:, -1, :]

        # 3) 루프
        cur = self.sample_top_p(logits, temperature, top_p, 1.0, None).unsqueeze(0).to(self.device)
        generated: List[int] = []
        for _ in range(max_new_tokens):
            out = self.model.tri_step_logits(cur, self.lower_k, pkv, S_len, U_len,
                                             logits_to_keep=1, labels=None, write_cache=True)
            tid = int(cur.item()); generated.append(tid)
            if self.tok.eos_token_id is not None and tid == int(self.tok.eos_token_id):
                break
            logits = out.logits[:, -1, :]
            prev = torch.tensor(generated, device=logits.device, dtype=torch.long).unsqueeze(0)
            cur = self.sample_top_p(logits, temperature, top_p, repetition_penalty, prev).unsqueeze(0).to(self.device)

        return self.tok.decode(generated, skip_special_tokens=True)


In [3]:
runner = TRIInfer(model, tok, device, lower_k=LOWER_K, attn_impl=ATTN_IMPL)

system_prompt = "You are a helpful assistant that answers questions based on the given document."

document = """The Nile is the longest river in Africa and flows northward through northeastern Africa to drain into the Mediterranean Sea.
It has two major tributaries: the White Nile and the Blue Nile."""

question = "Which continent is the Nile the longest river in?\nAt the end of your explanation, wrap the answer in '\\boxed{answer}'."

out = runner.generate(system_prompt, document, question,
                      max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,
                      top_p=TOP_P, repetition_penalty=REPETITION_PENALTY,
                      seed=SEED_TEXT)
print(out)


The Nile is the longest river in Africa. 

\boxed{Africa}


In [1]:
import re
import torch
import os
import sys
from pathlib import Path
# ==== HF repo 설정 ====
REPO_ID         = "jeongseokoh/LoPA_Llama3.1_8B_8_Lowers"  # ← repo 루트만!
BASE_SUBFOLDER  = "base"
LORA_SUBFOLDER  = "lora"    # 없으면 자동으로 건너뜀
HF_TOKEN        = os.environ.get("HF_TOKEN", None)  # private면 토큰 필요
LOPA_MODELING_PATH  = "./lopa_llama_modeling.py"  # (완성본) TRI 모델링 파일
ATTN_IMPL           = "flash_attention_2"         # "flash_attention_2" | "eager" | "sdpa"
LOWER_K             = 8                           # 하위 레이어 수(K). tri_info.txt에서 자동 로드 시 덮어씀
MAX_NEW_TOKENS      = 256
TEMPERATURE         = 0.7
TOP_P               = 0.9
REPETITION_PENALTY  = 1.0
SEED_TEXT           = ""                   # 헤더 뒤에 1~2 토큰 seed 권장

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
        else (torch.float16 if torch.cuda.is_available() else torch.float32)
print("device:", device, "| dtype:", dtype)

# ==== TRI 모델링 주입 유지 ====
import importlib.util, transformers, transformers.models.llama as llama_pkg, sys
target_name = "transformers.models.llama.modeling_llama"
spec = importlib.util.spec_from_file_location(target_name, LOPA_MODELING_PATH)
mod = importlib.util.module_from_spec(spec)
sys.modules.pop(target_name, None)
sys.modules[target_name] = mod
spec.loader.exec_module(mod)
setattr(llama_pkg, "modeling_llama", mod)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
print("[DEBUG] TRI patch loaded:", LOPA_MODELING_PATH)

# ==== 토크나이저 ====
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(REPO_ID, use_fast=True, token=HF_TOKEN)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# ==== 모델(베이스) ====
model = LlamaForCausalLM.from_pretrained(
    REPO_ID,
    subfolder=BASE_SUBFOLDER,      # ★ subfolder로 지정
    torch_dtype=dtype,
    token=HF_TOKEN,
).to(device)

# ==== LoRA 어댑터(있으면 로딩) ====
try:
    from peft import PeftModel
    # 허브에서 서브폴더로 바로 로딩
    model = PeftModel.from_pretrained(model, REPO_ID, subfolder=LORA_SUBFOLDER, token=HF_TOKEN)
    model = model.to(device)
    print(f"[info] LoRA adapters loaded from: {REPO_ID}/{LORA_SUBFOLDER}")
except Exception as e:
    print("[info] LoRA subfolder not found or load skipped:", str(e).split("\n")[0])

# ==== 어텐션 백엔드 ====
try:
    model.config._attn_implementation = ATTN_IMPL
    print("[info] attn_impl =", ATTN_IMPL)
except Exception:
    pass

# ==== tri_info.txt에서 lower_k 자동 로드 (repo 또는 subfolder 둘 다 시도) ====
from huggingface_hub import hf_hub_download
import re
loaded_lower_k = False
for sub in (None, BASE_SUBFOLDER):
    try:
        tri_path = hf_hub_download(REPO_ID, filename="tri_info.txt", subfolder=sub, token=HF_TOKEN)
        txt = Path(tri_path).read_text(encoding="utf-8")
        m = re.search(r"lower_k\s*=\s*(\d+)", txt)
        if m:
            LOWER_K = int(m.group(1))
            print(f"[info] lower_k loaded from tri_info.txt ({'root' if sub is None else sub}):", LOWER_K)
            loaded_lower_k = True
            break
    except Exception:
        continue
if not loaded_lower_k:
    print("[warn] tri_info.txt not found on hub; using LOWER_K =", LOWER_K)

# ==== 추론 모드 / TF32 ====
torch.set_grad_enabled(False)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# TRI API 확인
for need in ("tri_build_caches", "tri_forward_assistant", "tri_step_logits"):
    assert hasattr(model, need), f"Missing TRI API: {need}"
model.eval()


device: cuda | dtype: torch.bfloat16
[DEBUG] TRI patch loaded: ./lopa_llama_modeling.py


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[info] LoRA adapters loaded from: jeongseokoh/LoPA_Llama3.1_8B_8_Lowers/lora
[info] attn_impl = flash_attention_2
[info] lower_k loaded from tri_info.txt (root): 8


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=4, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=4, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(


In [2]:
# ==== 설정 (환경에 맞게 수정) ====
MODEL_BASE_DIR      = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/base"         # 학습된 base 가중치 저장 폴더
LORA_DIR            = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/lora"         # LoRA 어댑터 폴더(없으면 None/미존재)
LOPA_MODELING_PATH  = "./lopa_llama_modeling.py"  # (완성본) TRI 모델링 파일
TOKENIZER_PATH      = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"
ATTN_IMPL           = "flash_attention_2"         # "flash_attention_2" | "eager" | "sdpa"
LOWER_K             = 8                           # 하위 레이어 수(K). tri_info.txt에서 자동 로드 시 덮어씀
MAX_NEW_TOKENS      = 256
TEMPERATURE         = 0.7
TOP_P               = 0.9
REPETITION_PENALTY  = 1.0
SEED_TEXT           = ""                   # 헤더 뒤에 1~2 토큰 seed 권장

# tri_info.txt에서 lower_k 자동 로드
import re, pathlib
try:
    info_p = pathlib.Path(MODEL_BASE_DIR).parent / "tri_info.txt"
    if info_p.exists():
        m = re.search(r"lower_k\s*=\s*(\d+)", info_p.read_text())
        if m:
            LOWER_K = int(m.group(1))
            print("[info] lower_k loaded from tri_info.txt:", LOWER_K)
except Exception as e:
    print("[warn] tri_info.txt read failed:", e)

# 권장 환경변수(추론 효율↑): 노트북에서 바로 설정해도 됨
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128,expandable_segments:True")

import torch, sys
from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
        else (torch.float16 if torch.cuda.is_available() else torch.float32)
print("device:", device, "| dtype:", dtype)


[info] lower_k loaded from tri_info.txt: 8
device: cuda | dtype: torch.bfloat16


In [3]:
# ==== TRI 모델링을 transformers 내부에 주입 ====
import importlib.util, transformers, transformers.models.llama as llama_pkg

target_name = "transformers.models.llama.modeling_llama"
spec = importlib.util.spec_from_file_location(target_name, LOPA_MODELING_PATH)
mod = importlib.util.module_from_spec(spec)
sys.modules.pop(target_name, None)
sys.modules[target_name] = mod
spec.loader.exec_module(mod)
setattr(llama_pkg, "modeling_llama", mod)

from transformers.models.llama.modeling_llama import LlamaForCausalLM  # (패치된) 클래스
print("[DEBUG] TRI patch loaded:", LOPA_MODELING_PATH)

# ==== 토크나이저/모델 ====
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

model = LlamaForCausalLM.from_pretrained(MODEL_BASE_DIR, torch_dtype=dtype).to(device)

# LoRA 어댑터(있으면 자동 로드)
if LORA_DIR and Path(LORA_DIR).exists():
    try:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, LORA_DIR)
        model = model.to(device)
        print("[info] LoRA adapters loaded from:", LORA_DIR)
    except Exception as e:
        print("[warn] LoRA load failed:", e)

# 어텐션 백엔드 지정
try:
    model.config._attn_implementation = ATTN_IMPL
    print("[info] attn_impl =", ATTN_IMPL)
except Exception:
    pass

# 추론 모드/TF32
torch.set_grad_enabled(False)
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass

# TRI API 존재 확인
for need in ("tri_build_caches", "tri_forward_assistant", "tri_step_logits"):
    assert hasattr(model, need), f"Missing TRI API: {need}"

model.eval()


[DEBUG] TRI patch loaded: ./lopa_llama_modeling.py


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[info] LoRA adapters loaded from: /workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/lora
[info] attn_impl = flash_attention_2


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=4, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=4, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(


In [4]:
from copy import copy
from typing import List, Optional

class TRIInfer:
    def __init__(self, model, tok, device, lower_k: int, attn_impl: str = "flash_attention_2"):
        self.model = model.eval()
        self.tok = tok
        self.device = device
        self.lower_k = int(lower_k)
        try:
            self.model.config._attn_implementation = attn_impl
        except Exception:
            pass

    @staticmethod
    def build_messages(system: str, document: str, question: str, include_query: bool = True):
        user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
        return [{"role":"system","content":system},{"role":"user","content":user}]

    def tokens_from_messages(self, messages, add_generation_prompt: bool):
        try:
            s = self.tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
        except TypeError:
            s = self.tok.apply_chat_template(messages, tokenize=False)
            tmpl = getattr(self.tok, "chat_template", "") or ""
            if add_generation_prompt and "<|start_header_id|>" in tmpl:
                s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return self.tok(s, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device)

    @staticmethod
    def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int:
        L = min(a.size(1), b.size(1))
        eq = (a[0, :L] == b[0, :L])
        nz = (~eq).nonzero(as_tuple=False)
        return int(nz[0, 0]) if nz.numel() else L

    @staticmethod
    def sample_top_p(logits: torch.Tensor, temperature=1.0, top_p=0.9,
                     repetition_penalty: float = 1.0, prev_ids: Optional[torch.Tensor] = None):
        if repetition_penalty != 1.0 and prev_ids is not None and prev_ids.numel() > 0:
            logits = logits.clone()
            logits[:, prev_ids.unique()] /= repetition_penalty
        logits = logits.float() / max(1e-6, float(temperature))
        probs  = torch.softmax(logits, dim=-1)
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum > top_p
            mask[..., 0] = False
            keep = ~mask
            filtered = torch.zeros_like(sorted_probs).masked_scatter_(keep, sorted_probs[keep])
            probs = torch.zeros_like(probs).scatter(dim=-1, index=sorted_idx, src=filtered)
            probs = probs / probs.sum(dim=-1, keepdim=True)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def generate(self, system: str, document: str, question: str,
                 max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9,
                 repetition_penalty: float = 1.0, seed: str = "") -> str:

        msgs  = self.build_messages(system, document, question, include_query=True)
        S_ids = self.tokens_from_messages(msgs[:1], add_generation_prompt=False)
        SU_ids= self.tokens_from_messages(msgs, add_generation_prompt=False)
        SU_gen= self.tokens_from_messages(msgs, add_generation_prompt=True)

        lcp = self.lcp_len(S_ids, SU_ids)
        user_delta   = SU_ids[:, lcp:SU_ids.size(1)]
        header_delta = SU_gen[:, SU_ids.size(1):]

        # 1) S/U 프리필
        pkv, S_len, U_len = self.model.tri_build_caches(S_ids, user_delta, lower_k=self.lower_k)

        # 2) 헤더(+선택 seed) 1회 기록 & 다음 토큰 분포
        head = header_delta
        if seed:
            seed_ids = self.tok(seed, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device)
            head = torch.cat([head, seed_ids], dim=1)

        out = self.model.tri_step_logits(head, self.lower_k, pkv, S_len, U_len,
                                         logits_to_keep=1, labels=None, write_cache=True)
        logits = out.logits[:, -1, :]

        # 3) 루프
        cur = self.sample_top_p(logits, temperature, top_p, 1.0, None).unsqueeze(0).to(self.device)
        generated: List[int] = []
        for _ in range(max_new_tokens):
            out = self.model.tri_step_logits(cur, self.lower_k, pkv, S_len, U_len,
                                             logits_to_keep=1, labels=None, write_cache=True)
            tid = int(cur.item()); generated.append(tid)
            if self.tok.eos_token_id is not None and tid == int(self.tok.eos_token_id):
                break
            logits = out.logits[:, -1, :]
            prev = torch.tensor(generated, device=logits.device, dtype=torch.long).unsqueeze(0)
            cur = self.sample_top_p(logits, temperature, top_p, repetition_penalty, prev).unsqueeze(0).to(self.device)

        return self.tok.decode(generated, skip_special_tokens=True)


In [5]:
runner = TRIInfer(model, tok, device, lower_k=LOWER_K, attn_impl=ATTN_IMPL)

system_prompt = "You are a helpful assistant that answers questions based on the given document."

document = """The Nile is the longest river in Africa and flows northward through northeastern Africa to drain into the Mediterranean Sea.
It has two major tributaries: the White Nile and the Blue Nile."""

question = "Which continent is the Nile the longest river in?\nAt the end of your explanation, wrap the answer in '\\boxed{answer}'."

out = runner.generate(system_prompt, document, question,
                      max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,
                      top_p=TOP_P, repetition_penalty=REPETITION_PENALTY,
                      seed=SEED_TEXT)
print(out)


The Nile is the longest river in Africa. 
\boxed{Africa}


In [6]:
# 같은 컨텍스트에서 N개 샘플을 뽑고 싶다면 간단 루프:
outs = [runner.generate(system_prompt, document, question,
                        max_new_tokens=128, temperature=0.9, top_p=0.95,
                        repetition_penalty=1.0, seed=SEED_TEXT)
        for _ in range(3)]
for i, o in enumerate(outs, 1):
    print(f"\n[{i}]\n{o}")



[1]
The Nile is the longest river in Africa.

[2]
The Nile is the longest river in Africa.

[3]
The Nile is the longest river in Africa.


In [2]:
import json

file1 = "gsm8k_train_5resp_seed42_samples4000_boxed_numeric_exact.jsonl"
file2 = "triviaqa_hotpotqa_6000_merged2.jsonl"
output = "gsm8k_triviaqa_hotpotqa_6000_merged2.jsonl"

seen = set()
with open(output, "w", encoding="utf-8") as fout:
    for fname in [file1, file2]:
        with open(fname, "r", encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line)
                # 중복 제거 기준 (예: "id" 필드가 있다고 가정)
                key = data.get("id", None)
                if key is not None:
                    if key in seen:
                        continue
                    seen.add(key)
                fout.write(json.dumps(data, ensure_ascii=False) + "\n")


In [3]:
# === Config ===
MODEL_DIR  = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"  
MODEL_BASE_DIR = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/base"      # train 스크립트가 저장한 base 가중치 폴더
LORA_DIR       = "/workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/lora"      # LoRA 어댑터 폴더(없으면 None로 두거나 경로를 존재하지 않게)
LOPA_MODELING_PATH = "./lopa_llama_modeling.py"  # TRI 모델링 패치 파일 경로
LOWER_K = 8                               # 학습에 썼던 prefill_layers 값과 동일하게

# (선택) tri_info.txt에서 LOWER_K 자동 로드
try:
    import re, pathlib
    p = pathlib.Path(MODEL_BASE_DIR).parent / "tri_info.txt"
    if p.exists():
        txt = p.read_text()
        m = re.search(r"lower_k\s*=\s*(\d+)", txt)
        if m:
            LOWER_K = int(m.group(1))
            print("[info] lower_k loaded from tri_info.txt:", LOWER_K)
except Exception as e:
    print("[warn] fail to read tri_info.txt:", e)

# === Imports ===
import os, sys, torch
from pathlib import Path
from typing import List, Tuple
from transformers import AutoTokenizer
torch.set_grad_enabled(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
        else (torch.float16 if torch.cuda.is_available() else torch.float32)
print("device:", device, "| dtype:", dtype)


[info] lower_k loaded from tri_info.txt: 8
device: cuda | dtype: torch.bfloat16


In [4]:
# TRI 모델링 패치를 transformers에 주입
import importlib.util
import transformers, transformers.models.llama as llama_pkg

target_name = "transformers.models.llama.modeling_llama"
spec = importlib.util.spec_from_file_location(target_name, LOPA_MODELING_PATH)
mod = importlib.util.module_from_spec(spec)
sys.modules.pop(target_name, None)
sys.modules[target_name] = mod
spec.loader.exec_module(mod)
setattr(llama_pkg, "modeling_llama", mod)

from transformers.models.llama.modeling_llama import LlamaForCausalLM  # 패치된 클래스
print("[DEBUG] TRI patch loaded from:", LOPA_MODELING_PATH)

# 토크나이저/모델 로드 (LoRA가 있으면 어댑터도 적용)
tok = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# base 모델
model = LlamaForCausalLM.from_pretrained(MODEL_BASE_DIR, torch_dtype=dtype).to(device)

# LoRA 적용(폴더가 실제 있으면)
if LORA_DIR and Path(LORA_DIR).exists():
    try:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, LORA_DIR)
        model = model.to(device)
        print("[info] LoRA adapters loaded from:", LORA_DIR)
    except Exception as e:
        print("[warn] LoRA load failed:", e)

# 안전하게 eager로
try:
    model.config._attn_implementation = "eager"
except Exception:
    pass

model.eval()


[DEBUG] TRI patch loaded from: ./lopa_llama_modeling.py


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[info] LoRA adapters loaded from: /workspace/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/lora


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=4, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=4, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(


In [5]:
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role":"system","content":system},{"role":"user","content":user}]

def apply_chat_template(tokenizer, messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        tmpl = getattr(tokenizer, "chat_template", "") or ""
        if add_generation_prompt and "<|start_header_id|>" in tmpl:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(tokenizer, messages, device, add_generation_prompt=False):
    s = apply_chat_template(tokenizer, messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

def lcp_len(a: torch.Tensor, b: torch.Tensor) -> int:
    L = min(a.size(1), b.size(1))
    eq = (a[0, :L] == b[0, :L])
    nz = (~eq).nonzero(as_tuple=False)
    return int(nz[0, 0]) if nz.numel() else L

# 편의: 메시지를 토큰 세그먼트로 쪼개기
def build_segments(system_prompt: str, document: str, question: str):
    msgs  = build_messages(system_prompt, document, question, include_query=True)
    S_ids = tokens_from_messages(tok, msgs[:1], device, add_generation_prompt=False)
    SU_ids= tokens_from_messages(tok, msgs, device, add_generation_prompt=False)
    SU_gen= tokens_from_messages(tok, msgs, device, add_generation_prompt=True)

    l_su = lcp_len(S_ids, SU_ids)
    user_delta     = SU_ids[:, l_su:SU_ids.size(1)]
    header_delta   = SU_gen[:, SU_ids.size(1):]   # assistant header-only
    return msgs, S_ids, SU_ids, SU_gen, user_delta, header_delta


In [13]:
system_prompt = "You are a helpful assistant that answers questions based on the given document."

document = """Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs.
This increased the value of the house by 150%."""

question = "How much profit did he make?"

msgs, S_ids, SU_ids, SU_gen, user_delta, header_delta = build_segments(system_prompt, document, question)
print("S len:", S_ids.size(1), "SU len:", SU_ids.size(1), "header len:", header_delta.size(1))


S len: 40 SU len: 95 header len: 4


In [7]:
# 단순 확인용: TRI 제약 없이 전체 콘텍스트로 generate
prompt_text = apply_chat_template(tok, build_messages(system_prompt, document, question), add_generation_prompt=True)
input_ids = tok(prompt_text, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

out_ids = model.generate(
    input_ids=input_ids,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    eos_token_id=tok.eos_token_id,
)
print(tok.decode(out_ids[0], skip_special_tokens=True))


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


system

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are a helpful assistant that answers questions based on the given document.user

Document:
Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs.
This increased the value of the house by 150%.

Question: How much profit did he make?assistant

To  fetisch Animalia href="? ( (  	'gc aalborg-Compatible waited throat aalborg Party sat weiber echangaimassage odense fetisch Awards throat aalborg mouth rumpe fetisch Animalia için/Object Insecta favorite throat aalborg question and overposting href=" ( addCriterion ( sourceMapping Animalia father wife century was salopes aalborg Academy href=" (aa father brother için Insecta entre Telegraph Province | aalborg href="? eoq League000 ن? overposting Province |
 salopes	 +#+olumn Academy href=" (  geschichten fetisch Hospital anniversary Rica eyes fetisch Initiative entre Animalia and echangernetes Cathedral dengan-Cola000 hund

In [None]:
# 샘플러: temperature + top-p
def sample_top_p(logits: torch.Tensor, temperature=1.0, top_p=0.9):
    logits = logits.float() / max(1e-6, float(temperature))
    probs  = torch.softmax(logits, dim=-1)
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        mask = cumsum > top_p
        # 첫 토큰은 허용
        mask[..., 0] = False
        probs.scatter_(dim=-1, index=sorted_idx, src=torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs))
        probs = probs / probs.sum(dim=-1, keepdim=True)
    next_id = torch.multinomial(probs, num_samples=1)
    return next_id.squeeze(-1)

# 1) S/U 프리필 (grad 불필요)
with torch.no_grad():
    pkv, S_len, U_len = model.tri_build_caches(system_ids=S_ids, user_ids=user_delta, lower_k=LOWER_K)

# 2) 헤더를 캐시에 기록(write_cache=True)
with torch.no_grad():
    _ = model.tri_step_logits(
        assistant_ids=header_delta, lower_k=LOWER_K, pkv=pkv, S=S_len, U=U_len,
        logits_to_keep=0, labels=None, write_cache=True
    )

# 3) 첫 토큰 분포: "헤더의 마지막 토큰"을 질의로 넣되 캐시는 건드리지 않음(write_cache=False)
with torch.no_grad():
    first_query = header_delta[:, -1:]   # 마지막 헤더 토큰 1개
    out = model.tri_step_logits(
        assistant_ids=first_query, lower_k=LOWER_K, pkv=pkv, S=S_len, U=U_len,
        logits_to_keep=1, labels=None, write_cache=True
    )
    logits = out.logits[:, -1, :]  # next 토큰 분포
    next_id = sample_top_p(logits, temperature=0.7, top_p=0.9).unsqueeze(0).to(device)

# 4) 루프: 직전 생성 토큰을 "기록(write_cache=True)하면서" 다음 분포를 받는다
max_new_tokens = 128
generated = []

with torch.no_grad():
    cur = next_id  # 직전 생성 토큰(이번 스텝에 캐시에 기록할 토큰)
    for step in range(max_new_tokens):
        # 기록 + 다음 분포 획득
        out = model.tri_step_logits(
            assistant_ids=cur, lower_k=LOWER_K, pkv=pkv, S=S_len, U=U_len,
            logits_to_keep=1, labels=None, write_cache=True
        )
        generated.append(int(cur.item()))
        # EOS 체크
        if tok.eos_token_id is not None and int(cur.item()) == int(tok.eos_token_id):
            break
        # 다음 토큰 샘플
        logits = out.logits[:, -1, :]
        cur = sample_top_p(logits, temperature=0.7, top_p=0.9).unsqueeze(0).to(device)

# 디코딩(헤더 이후 생성된 assistant 콘텐츠만 디코드)
print(tok.decode(generated, skip_special_tokens=True))


I'm ready to help. What's on your mind?


: 

In [9]:
def tri_generate(system_prompt: str, document: str, question: str,
                 lower_k: int = LOWER_K, max_new_tokens: int = 128,
                 temperature: float = 0.7, top_p: float = 0.9) -> str:
    msgs, S_ids, SU_ids, SU_gen, user_delta, header_delta = build_segments(system_prompt, document, question)

    with torch.no_grad():
        pkv, S_len, U_len = model.tri_build_caches(system_ids=S_ids, user_ids=user_delta, lower_k=lower_k)
        _ = model.tri_step_logits(header_delta, lower_k, pkv, S_len, U_len,
                                  logits_to_keep=0, labels=None, write_cache=True)
        # 첫 분포
        first_query = header_delta[:, -1:]
        out = model.tri_step_logits(first_query, lower_k, pkv, S_len, U_len,
                                    logits_to_keep=1, labels=None, write_cache=False)
        logits = out.logits[:, -1, :]
        cur = sample_top_p(logits, temperature=temperature, top_p=top_p).unsqueeze(0).to(device)

        generated = []
        for _ in range(max_new_tokens):
            out = model.tri_step_logits(cur, lower_k, pkv, S_len, U_len,
                                        logits_to_keep=1, labels=None, write_cache=True)
            tok_id = int(cur.item())
            generated.append(tok_id)
            if tok.eos_token_id is not None and tok_id == int(tok.eos_token_id):
                break
            logits = out.logits[:, -1, :]
            cur = sample_top_p(logits, temperature=temperature, top_p=top_p).unsqueeze(0).to(device)

    return tok.decode(generated, skip_special_tokens=True)

print(tri_generate(system_prompt, document, question, lower_k=LOWER_K))


 I'm ready to assist you. What would you like to know?


In [1]:
import torch
import transformers

print("Torch version:", torch.__version__)
print("Transformers version:", transformers.__version__)

Torch version: 2.5.1
Transformers version: 4.56.1


In [2]:
import torch
from transformers import AutoModelForCausalLM

# 1) 경로와 원본 모델 ID 지정
best_dir = "LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best"
base_dir = f"{best_dir}/base"
orig_model_id = "meta-llama/Llama-3.1-8B-Instruct"  # 학습에 쓴 원본 모델 이름

# 2) 원본 베이스 모델 로드(캐시 사용), 저장
m = AutoModelForCausalLM.from_pretrained(
    orig_model_id,
    trust_remote_code=True,          # Llama3 계열이면 True 괜찮습니다
    cache_dir="/data2/jeongseokoh/hub/model"  # 학습 때 쓰던 캐시 경로가 있으면 지정
)
# base/config.json에 remote code 흔적이 섞이지 않도록 auto_map 클리어(있다면)
try:
    setattr(m.config, "auto_map", None)
except Exception:
    pass

m.save_pretrained(base_dir, safe_serialization=True)
del m

print("Rewrote clean base to:", base_dir)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[2025-09-13 20:39:19,291] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/data2/jeongseokoh/miniconda3/envs/vllm/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigne

[2025-09-13 20:39:20,573] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Rewrote clean base to: LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial8-0specials/best/base


In [1]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (best_dir 우선)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    torch_dtype=dtype,
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

from peft import PeftModel
assert lora_dir.is_dir(), f"LoRA folder not found: {lora_dir}"
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# SDPA 고정 (필요시 eager로 비교)
for k in ("attn_implementation", "_attn_implementation"):
    try: setattr(model.config, k, "sdpa")
    except: pass
try:
    for k in ("attn_implementation", "_attn_implementation"):
        setattr(model.model.config, k, "sdpa")
except: pass

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    if hasattr(pkv, "key_cache"): return len(pkv.key_cache)
    if hasattr(pkv, "layers"):    return len(pkv.layers)
    try: return len(pkv)
    except: return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache"):
        return pkv.key_cache[idx], pkv.value_cache[idx]
    if hasattr(pkv, "layers"):
        layer = pkv.layers[idx]
        return layer.keys, layer.values
    return pkv[idx]

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    dc = DynamicCache()
    for li in layer_indices:
        k, v = pkv_get(pkv_src, li)
        dc.update(k, v, li)  # layer index 유지
    return dc

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    return k[:, :, -doc_len:, :], v[:, :, -doc_len:, :]

def build_mask(length: int, batch: int = 1):
    return torch.ones(batch, length, device=device, dtype=torch.long)

def build_pos_ids(start: int, length: int, batch: int = 1):
    return torch.arange(start, start + length, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role":"system","content":system},{"role":"user","content":user}]

def apply_chat_template(messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(messages, add_generation_prompt=False):
    s = apply_chat_template(messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True, # Phase2 첫 토큰: header를 단계적으로 투입
    force_attn_impl: Optional[str] = None,   # "sdpa" | "eager"
):
    if force_attn_impl:
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # [1, L_sys+doc]
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # [1, L_sys+doc + header]
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill (base model)
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K doc pass (base model) — no upper continue
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    lower_layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])

    model.model.layers = lower_layers
    dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
    attn_doc_full = torch.cat([build_mask(L_sys), build_mask(L_doc)], dim=1)

    out_low = model.model(
        input_ids=ids_phase1[:, L_sys:],  # doc only
        past_key_values=dc_low_in,
        attention_mask=attn_doc_full,     # sys+doc 길이
        use_cache=True,
        return_dict=True,
    )
    pkv_low = out_low.past_key_values

    # 복원
    model.model.layers = full_layers

    # 3) Combine caches
    # lower(0..K-1): sys + doc
    # upper(K..):    sys only  (doc 미포함; generation 시작 시점부터 쌓임)
    combined = DynamicCache()
    # 헤드/차원 shape 확보
    k0_sys, v0_sys = pkv_get(pkv_sys, 0)
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            # upper: sys only
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2: Generation
    # seed: assistant header를 한 토큰씩 밀어 넣어 upper past가 L_sys → L_sys+H로 자라도록 함
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (len H)
        H = int(hdr_tail.size(1))
        for j in range(H):
            past_len = pkv_get(combined, 0)[0].shape[2]  # lower 기준: L_sys+L_doc+grown
            attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)
            step_tok = hdr_tail[:, j:j+1]
            out_seed = model(
                input_ids=step_tok,
                past_key_values=combined,
                attention_mask=attn_mask,
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values

    # 5) Decoding (CausalLM) with safe processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)
    # 첫 생성 스텝의 입력은 "직전 토큰" (header 마지막 또는 마지막 doc)
    last = (ids_hdr[:, -1:] if start_with_assistant_header else ids_phase1[:, -1:])

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # lower past_len은 L_sys+L_doc + grown; upper past_len은 L_sys + grown
        past_len = pkv_get(combined, 0)[0].shape[2]
        attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=attn_mask,
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # 즉시 EOS 방지 (min_length 수동 적용)
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        inp = generated_ids if generated_ids.numel() > 0 else last.new_zeros((1, 0), dtype=torch.long, device=device)
        inp = inp.to(logits.device)
        logits = processors(inp, logits)

        # Fallback-safe sampling
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                    next_tok = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    next_tok = torch.multinomial(probs, num_samples=1)
            else:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

answer = lopa_generate(
    system, document, question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7, top_p=0.9,
    start_with_assistant_header=True,  # Phase2 처음에 header를 단계적으로 투입
    # force_attn_impl="eager",         # 필요시 비교
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

I'm ready to help. What would you like to know? Please go ahead and ask your question.


In [None]:
# LoPA (low-only prefill) scratch pipeline without trust_remote_code
# Torch 2.5.1, Transformers 4.56.1, GPU+SDPA
# - Phase1: lower-K only prefill (no assistant header)
# - Phase2: generation (first feed assistant header tokens step-by-step)
# - Positions unchanged; no remapping
# - LoRA: best/lora attach → merge
# - MinLength 수동 적용(4.56.1 디바이스 불일치 회피)

import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# -----------------------------
# Config
# -----------------------------
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
K = 4  # prefill layers

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = (
    torch.bfloat16
    if (device == "cuda" and torch.cuda.is_bf16_supported())
    else (torch.float16 if device == "cuda" else torch.float32)
)

cache_dir_model = "/data2/jeongseokoh/hub/model"
cache_dir_tok   = "/data2/jeongseokoh/hub/tokenizer"

# -----------------------------
# Tokenizer (best_dir 우선)
# -----------------------------
tok_src = str(best_dir) if (best_dir / "tokenizer.json").is_file() else backbone_id
tokenizer = AutoTokenizer.from_pretrained(tok_src, cache_dir=cache_dir_tok, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# -----------------------------
# Backbone + LoRA (merge)
# -----------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    backbone_id,
    trust_remote_code=False,
    torch_dtype=dtype,
    cache_dir=cache_dir_model,
)
base_model.to(device).eval()

from peft import PeftModel
assert lora_dir.is_dir(), f"LoRA folder not found: {lora_dir}"
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
model = peft_model.merge_and_unload().to(device).eval()
del peft_model

# SDPA 고정 (필요시 eager로 비교)
for k in ("attn_implementation", "_attn_implementation"):
    try: setattr(model.config, k, "sdpa")
    except: pass
try:
    for k in ("attn_implementation", "_attn_implementation"):
        setattr(model.model.config, k, "sdpa")
except: pass

# -----------------------------
# Cache helpers
# -----------------------------
def pkv_len(pkv) -> int:
    if hasattr(pkv, "key_cache"): return len(pkv.key_cache)
    if hasattr(pkv, "layers"):    return len(pkv.layers)
    try: return len(pkv)
    except: return 0

def pkv_get(pkv, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if hasattr(pkv, "key_cache") and hasattr(pkv, "value_cache"):
        return pkv.key_cache[idx], pkv.value_cache[idx]
    if hasattr(pkv, "layers"):
        layer = pkv.layers[idx]
        return layer.keys, layer.values
    return pkv[idx]

def dc_from_subset(pkv_src, layer_indices: List[int]) -> DynamicCache:
    dc = DynamicCache()
    for li in layer_indices:
        k, v = pkv_get(pkv_src, li)
        dc.update(k, v, li)  # layer index 유지
    return dc

def cache_slice_doc_only(k: torch.Tensor, v: torch.Tensor, doc_len: int):
    return k[:, :, -doc_len:, :], v[:, :, -doc_len:, :]

def build_mask(length: int, batch: int = 1):
    return torch.ones(batch, length, device=device, dtype=torch.long)

def build_pos_ids(start: int, length: int, batch: int = 1):
    return torch.arange(start, start + length, device=device, dtype=torch.long).unsqueeze(0).expand(batch, -1)

# -----------------------------
# Prompt builders (Gen3)
# -----------------------------
def build_messages(system: str, document: str, question: str, include_query: bool = True):
    user = f"Document:\n{document}\n\nQuestion: {question}" if include_query else f"Document:\n{document}\n\n"
    return [{"role":"system","content":system},{"role":"user","content":user}]

def apply_chat_template(messages, add_generation_prompt: bool):
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except TypeError:
        s = tokenizer.apply_chat_template(messages, tokenize=False)
        if add_generation_prompt:
            s += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return s

def tokens_from_messages(messages, add_generation_prompt=False):
    s = apply_chat_template(messages, add_generation_prompt)
    return tokenizer(s, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

# -----------------------------
# LoPA generate (low-only prefill; no upper continue)
# -----------------------------
@torch.inference_mode()
def lopa_generate(
    system: str,
    document: str,
    question: str,
    *,
    K: int = 4,
    max_new_tokens: int = 256,
    min_length: int = 16,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    repetition_penalty: Optional[float] = None,
    start_with_assistant_header: bool = True, # Phase2 첫 토큰: header를 단계적으로 투입
    force_attn_impl: Optional[str] = None,   # "sdpa" | "eager"
):
    if force_attn_impl:
        for k in ("attn_implementation", "_attn_implementation"):
            try:
                setattr(model.config, k, force_attn_impl)
                setattr(model.model.config, k, force_attn_impl)
            except Exception:
                pass

    # Phase-1 ids
    msgs = build_messages(system, document, question, include_query=True)
    ids_phase1 = tokens_from_messages(msgs, add_generation_prompt=False)  # [1, L_sys+doc]
    ids_hdr    = tokens_from_messages(msgs, add_generation_prompt=True)   # [1, L_sys+doc + header]
    sys_only   = tokens_from_messages([{"role":"system","content":system}], add_generation_prompt=False)

    L_sys = sys_only.size(1)
    L_all = ids_phase1.size(1)
    L_doc = L_all - L_sys
    assert L_doc > 0, "Phase-1 must include document tokens."

    # 1) System-only prefill (base model)
    out_sys = model.model(
        input_ids=sys_only,
        attention_mask=build_mask(L_sys),
        use_cache=True,
        return_dict=True,
    )
    pkv_sys = out_sys.past_key_values
    n_layers = pkv_len(pkv_sys)

    # 2) Lower-K doc pass (base model) — no upper continue
    K_eff = max(0, min(K, n_layers))
    full_layers: nn.ModuleList = model.model.layers
    lower_layers = nn.ModuleList([full_layers[i] for i in range(0, K_eff)])

    model.model.layers = lower_layers
    dc_low_in = dc_from_subset(pkv_sys, list(range(K_eff))) if K_eff > 0 else DynamicCache()
    attn_doc_full = torch.cat([build_mask(L_sys), build_mask(L_doc)], dim=1)

    out_low = model.model(
        input_ids=ids_phase1[:, L_sys:],  # doc only
        past_key_values=dc_low_in,
        attention_mask=attn_doc_full,     # sys+doc 길이
        use_cache=True,
        return_dict=True,
    )
    pkv_low = out_low.past_key_values

    # 복원
    model.model.layers = full_layers

    # 3) Combine caches
    # lower(0..K-1): sys + doc
    # upper(K..):    sys only  (doc 미포함; generation 시작 시점부터 쌓임)
    combined = DynamicCache()
    # 헤드/차원 shape 확보
    k0_sys, v0_sys = pkv_get(pkv_sys, 0)
    for li in range(n_layers):
        k_sys, v_sys = pkv_get(pkv_sys, li)
        k_sys = k_sys[:, :, :L_sys, :]
        v_sys = v_sys[:, :, :L_sys, :]
        if li < K_eff:
            k_low, v_low = pkv_get(pkv_low, li)
            k_doc, v_doc = cache_slice_doc_only(k_low, v_low, L_doc)
            k_cat = torch.cat([k_sys, k_doc], dim=2).contiguous()
            v_cat = torch.cat([v_sys, v_doc], dim=2).contiguous()
            combined.update(k_cat, v_cat, li)
        else:
            # upper: sys only
            combined.update(k_sys.contiguous(), v_sys.contiguous(), li)

    # 4) Phase-2: Generation
    # seed: assistant header를 한 토큰씩 밀어 넣어 upper past가 L_sys → L_sys+H로 자라도록 함
    if start_with_assistant_header:
        hdr_tail = ids_hdr[:, L_all:]  # header-only tokens (len H)
        H = int(hdr_tail.size(1))
        for j in range(H):
            past_len = pkv_get(combined, 0)[0].shape[2]  # lower 기준: L_sys+L_doc+grown
            attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)
            step_tok = hdr_tail[:, j:j+1]
            out_seed = model(
                input_ids=step_tok,
                past_key_values=combined,
                attention_mask=attn_mask,
                use_cache=True,
                return_dict=True,
            )
            combined = out_seed.past_key_values

    # 5) Decoding (CausalLM) with safe processors
    from transformers.generation import LogitsProcessorList
    from transformers.generation.logits_process import (
        TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
    )

    eos_id = tokenizer.eos_token_id
    generated_ids = torch.empty((1, 0), dtype=torch.long, device=device)
    # 첫 생성 스텝의 입력은 "직전 토큰" (header 마지막 또는 마지막 doc)
    last = (ids_hdr[:, -1:] if start_with_assistant_header else ids_phase1[:, -1:])

    processors = LogitsProcessorList()
    if repetition_penalty and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty)))
    if temperature and temperature != 1.0:
        processors.append(TemperatureLogitsWarper(temperature=float(temperature)))
    if top_k is not None and top_k > 0:
        processors.append(TopKLogitsWarper(top_k=int(top_k), filter_value=-float("inf")))
    if top_p is not None and top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p=float(top_p), min_tokens_to_keep=1))

    cur = 0
    while cur < max_new_tokens:
        # lower past_len은 L_sys+L_doc + grown; upper past_len은 L_sys + grown
        past_len = pkv_get(combined, 0)[0].shape[2]
        attn_mask = torch.cat([build_mask(past_len), build_mask(1)], dim=1)

        out = model(
            input_ids=last,
            past_key_values=combined,
            attention_mask=attn_mask,
            use_cache=True,
            return_dict=True,
        )
        combined = out.past_key_values
        logits = out.logits[:, -1, :].to(torch.float32)

        # 즉시 EOS 방지 (min_length 수동 적용)
        if eos_id is not None and cur < min_length:
            logits[:, eos_id] = -float("inf")

        if not torch.isfinite(logits).all():
            logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        inp = generated_ids if generated_ids.numel() > 0 else last.new_zeros((1, 0), dtype=torch.long, device=device)
        inp = inp.to(logits.device)
        logits = processors(inp, logits)

        # Fallback-safe sampling
        if not torch.isfinite(logits).any():
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            if do_sample:
                probs = torch.softmax(logits, dim=-1)
                if (not torch.isfinite(probs).all()) or (probs.sum(dim=-1) <= 0).any():
                    next_tok = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    next_tok = torch.multinomial(probs, num_samples=1)
            else:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)

        if eos_id is not None and int(next_tok.item()) == eos_id:
            break

        generated_ids = torch.cat([generated_ids, next_tok], dim=1)
        last = next_tok
        cur += 1

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# -----------------------------
# Example
# -----------------------------
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

answer = lopa_generate(
    system, document, question,
    K=K,
    max_new_tokens=256,
    min_length=16,
    temperature=0.7, top_p=0.9,
    start_with_assistant_header=True,  # Phase2 처음에 header를 단계적으로 투입
    # force_attn_impl="eager",         # 필요시 비교
)
print(answer)


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[Lens] L_sys=40, L_doc=35, L_all=75, header_len=4
[Layers] n_layers=32, K_eff=4
[sys] n_layers=32 | L0:K(1, 8, 40, 128) | L1:K(1, 8, 40, 128) | L31:K(1, 8, 40, 128)
[OK]   system_prefill: all present layers seq_len == 40
[low] h_low=(1, 35, 4096)
[low] n_layers=4 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L3:K(1, 8, 75, 128)
[up] n_layers=32 | L4:K(1, 8, 75, 128) | L5:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[combined(before_hdr)] n_layers=32 | L0:K(1, 8, 75, 128) | L1:K(1, 8, 75, 128) | L31:K(1, 8, 75, 128)
[OK]   combined_prefill: all present layers seq_len == 75
[hdr] past_len before=75, after=79, delta=4
[combined(after_hdr)] n_layers=32 | L0:K(1, 8, 79, 128) | L1:K(1, 8, 79, 128) | L31:K(1, 8, 79, 128)
[decode step 0] past_len 79 -> 80 | next_id=40
[decode step 1] past_len 80 -> 81 | next_id=2846
[decode step 2] past_len 81 -> 82 | next_id=539
[debug answer] I'm not able to answer your question as I don't have any information about a document. Can you please provide the document

In [1]:
import sys, torch
from transformers import AutoTokenizer

model_dir = "/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best"
prefill_layers = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

sys.path.insert(0, model_dir)  # best/ 를 import 경로에 추가
from modeling_partial_layer import LlamaForCausalLM

tok = AutoTokenizer.from_pretrained(model_dir)
model = LlamaForCausalLM.from_pretrained(model_dir, device_map="cuda:0", dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
for k in ("attn_implementation", "_attn_implementation"):
    setattr(model.config, k, "sdpa")
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa..."
question = "Which continent is the Nile the longest river in?"

out = model.generate(
    system=system, document=document, query=question,
    compress=False, tokenizer=tok, prefill_layers=prefill_layers,
    max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9
)
print(tok.batch_decode(out, skip_special_tokens=True))  # 문자열 또는 문자열 리스트


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['system\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a helpful assistant that answers questions based on the given document.user\n\nDocument:\nThe Nile is the longest river in Africa...\n\nQuestion: Which continent is the Nile the longest river in?assistant\n\nThe Nile is the longest river in Africa.']


In [2]:
import os
os.environ["LATENTRAG_DEBUG"]="1"
print(model.lora_debug_status())

peft_available=True | attached=True | merged=False | is_sharded=True | source=/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best/lora | error=None | last_used_lora=True


In [None]:
import sys, torch
from pathlib import Path
from transformers import AutoTokenizer

# 경로 설정 (혼동 방지를 위해 절대경로 권장)
best_dir = Path("/data2/jeongseokoh/jeongseokoh/LatentCOMP_cleaned/LatentCOMP_cleaned/outputs/Llama-3.1-8B-Instruct-LOPA-partial4-0specials/best")
lora_dir = best_dir / "lora"
prefill_layers = 4  # 학습 시 사용한 값과 동일하게

device = "cuda" if torch.cuda.is_available() else "cpu"

# best/의 remote-code(LOPA 구현)를 import할 수 있게 경로 추가
sys.path.insert(0, str(best_dir))
from modeling_partial_layer import LlamaForCausalLM  # Llama 3.1 계열용

# 토크나이저: chat template 포함
tok = AutoTokenizer.from_pretrained(str(best_dir))
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# 1) 백본을 허깅페이스에서 ‘깨끗하게’ 로드 (백본=llama3.1 8B instruct)
#    주의: trust_remote_code=False → 표준 HF 클래스 경로를 통해 가중치만 로드
backbone_id = "meta-llama/Llama-3.1-8B-Instruct"
base_model = LlamaForCausalLM.from_pretrained(
    backbone_id,
    device_map="cuda:0",
    dtype=torch.bfloat16,
    trust_remote_code=False,  # 가중치만 가져오고 동작은 our remote-code 클래스가 담당
    cache_dir="/data2/jeongseokoh/hub/model",
).to(device).eval()

# 2) LoRA 어댑터 attach 후 merge (단일 모델로)
from peft import PeftModel
peft_model = PeftModel.from_pretrained(base_model, str(lora_dir))
merged = peft_model.merge_and_unload().to(device).eval()

# 3) LOPA 추론 (compress=True → partial prefill)
#    이미 merge된 모델이므로 추가 attach를 막기 위해 use_lora=False로 고정
system = "You are a helpful assistant that answers questions based on the given document."
document = "The Nile is the longest river in Africa and flows northward through several countries..."
question = "Which continent is the Nile the longest river in?"

out = merged.generate(
    system=system,
    document=document,
    query=question,
    compress=True,                 # LOPA 경로
    tokenizer=tok,                 # 필수
    prefill_layers=prefill_layers, # 학습과 동일
    use_lora=False,                # 이미 merge됐으므로 재-부착 방지
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(out)  # 문자열 또는 문자열 리스트


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 0 files: 0it [00:00, ?it/s]

assistant<|end_header_id|>assistant<|end_header_id|>

Africa.<|eot_id|>
