In [None]:
# ===== 0) Drive mount + import path =====
# Mount Google Drive at /content/drive; every Drive path used below starts there.
from google.colab import drive
drive.mount('/content/drive')

import os, sys, json, math, tarfile, shutil
from typing import Dict, Any, List, Optional, Tuple

# The hf_ax4_downloader.py file must be inside this folder.
CODE_DIR = "/content/drive/MyDrive/TRAITHON/code"
# Add the folder to the import path only once, so re-running the cell is harmless.
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)

import torch

# (Optional) Allow TF32 matmuls (may give a speed benefit on A100).
if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True

# Project modules (expected to be importable via CODE_DIR).
import clickbait_preprocess as cbp
import ax4_clickbait_scorer as ax4
import hf_ax4_downloader as ax4_dl


# ===== 1) (Optional) Install/load PEFT =====
# Bring PeftModel into scope; if the import fails, install `peft` and retry once.
try:
    import peft
    from peft import PeftModel
except Exception:
    # `!pip` is IPython shell syntax, so this cell only runs inside a notebook.
    !pip -q install peft
    from peft import PeftModel


# ===== 2) Model loader (structured so the adapter can be toggled ON/OFF) =====
def safe_exp(x: Optional[float], cap: float = 80.0) -> Optional[float]:
    """Return ``exp(x)``, or ``None`` when no usable value can be produced.

    Args:
        x: Exponent (here: a mean negative log-likelihood). ``None`` is
            passed through as ``None``.
        cap: Largest exponent that is still exponentiated; anything above it
            yields ``None`` instead of an astronomically large number.

    Returns:
        ``math.exp(x)`` as a float, or ``None`` if ``x`` is ``None``, NaN,
        greater than ``cap``, or too large for a float.
    """
    if x is None:
        return None
    # NaN compares False against everything, so `x > cap` alone lets it through
    # and it would later be serialised as the non-standard JSON token `NaN`.
    if math.isnan(x) or x > cap:
        return None
    try:
        return float(math.exp(x))
    except OverflowError:
        # Only reachable when the caller raises `cap` beyond ~709.
        return None


class PPLModelWrapper:
    """Causal LM with an optional PEFT adapter, scored with the adapter off and on.

    The constructor loads the tokenizer and base model from ``model_path`` and,
    when ``adapter_path`` is given, wraps the base model with
    ``PeftModel.from_pretrained``. ``compute_loss_pair`` / ``compute_ppl_pair``
    then run the same inputs twice: once inside ``disable_adapter()`` and once
    with the adapter active.
    """

    def __init__(
        self,
        model_path: str,
        adapter_path: Optional[str] = None,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None
    ):
        """Load tokenizer, base model and (optionally) the adapter.

        Args:
            model_path: Path or id handed to ``from_pretrained``.
            adapter_path: Adapter directory; falsy means "no adapter".
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``. Indexed names such as ``"cuda:0"`` are
                treated as CUDA as well.
            dtype: Weight dtype. Defaults to bfloat16 on CUDA, float32 otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Compare by prefix: an exact `== "cuda"` test would classify "cuda:0"
        # as a CPU device and pick float32 / CPU placement for it.
        on_cuda = str(device).startswith("cuda")
        if dtype is None:
            dtype = torch.bfloat16 if on_cuda else torch.float32

        self.device_str = device
        self.dtype = dtype

        from transformers import AutoModelForCausalLM, AutoTokenizer
        print(f"[ppl] Loading base model: {model_path}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
        except TypeError:
            # Retry without `use_fast` for tokenizers that reject the argument.
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Reuse EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Plain "cuda" keeps automatic placement; any other device is honoured
        # explicitly after loading.
        use_auto_map = str(device) == "cuda"
        self.base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map="auto" if use_auto_map else None,
            trust_remote_code=True,
        )
        if not use_auto_map:
            # Without a device_map the weights stay where they were loaded, so
            # move them to the requested device (a no-op for "cpu").
            self.base_model.to(device)
        self.base_model.eval()

        self.model = self.base_model
        self.has_adapter = False

        if adapter_path:
            print(f"[ppl] Attaching adapter: {adapter_path}")
            self.model = PeftModel.from_pretrained(self.base_model, adapter_path)
            self.model.eval()
            self.has_adapter = True

        # Device of the first parameter; callers move their input tensors here.
        self.run_device = next(self.model.parameters()).device
        print(f"[ppl] Ready. run_device={self.run_device}")

    def _forward_loss(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor) -> float:
        """Run one forward pass without gradients and return ``out.loss`` as a float."""
        # Autocast is only meaningful for half-precision dtypes; skip it for
        # float32 rather than entering a context that cannot apply.
        use_autocast = (
            str(self.run_device).startswith("cuda")
            and self.dtype in (torch.float16, torch.bfloat16)
        )
        with torch.inference_mode():
            if use_autocast:
                with torch.autocast(device_type="cuda", dtype=self.dtype):
                    out = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            else:
                out = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            return float(out.loss.detach().float().cpu().item())

    def _disable_adapter_ctx(self):
        """Return the model's ``disable_adapter()`` context manager.

        Raises:
            RuntimeError: If the wrapped model has no ``disable_adapter``.
        """
        if hasattr(self.model, "disable_adapter"):
            return self.model.disable_adapter()
        raise RuntimeError("PEFT model does not support disable_adapter().")

    def compute_loss_pair(self, input_ids, attention_mask, labels) -> Tuple[float, float]:
        """Return ``(loss_off, loss_on)`` for the same inputs.

        With an adapter attached, ``loss_off`` is computed inside
        ``disable_adapter()`` and ``loss_on`` with the adapter active. Without
        an adapter a single forward pass is made and both values are equal.
        """
        if self.has_adapter:
            with self._disable_adapter_ctx():
                loss_off = self._forward_loss(input_ids, attention_mask, labels)
            loss_on = self._forward_loss(input_ids, attention_mask, labels)
        else:
            loss_off = self._forward_loss(input_ids, attention_mask, labels)
            loss_on = loss_off
        return loss_off, loss_on

    def compute_ppl_pair(self, input_ids, attention_mask, labels, ppl_cap_exp: float = 80.0):
        """Return ``(ppl_off, ppl_on, loss_off, loss_on)``.

        Each perplexity is ``safe_exp(loss, cap=ppl_cap_exp)``, so it is
        ``None`` whenever ``safe_exp`` declines to exponentiate the loss.
        """
        loss_off, loss_on = self.compute_loss_pair(input_ids, attention_mask, labels)
        ppl_off = safe_exp(loss_off, cap=ppl_cap_exp)
        ppl_on  = safe_exp(loss_on,  cap=ppl_cap_exp)
        return ppl_off, ppl_on, loss_off, loss_on


# ===== 3) Masking logic =====
def _find_subsequence(hay: List[int], needle: List[int]) -> Optional[int]:
    if not needle:
        return None
    n, m = len(hay), len(needle)
    if m > n:
        return None
    first = needle[0]
    for i in range(0, n - m + 1):
        if hay[i] != first:
            continue
        if hay[i:i+m] == needle:
            return i
    return None


def build_masked_inputs_for_ppl(flat_article: Dict[str, Any], tokenizer, strict_span: bool = False):
    """Tokenize the chat prompt for an article and build labels that cover only the article span.

    The prompt is rendered with ``tokenizer.apply_chat_template`` from the
    messages returned by ``ax4.build_messages_for_binary``; the span to score is
    the text returned by ``ax4.build_article_text``. Label positions outside the
    span are set to -100 (presumably the loss ignore index — confirm against the
    model's loss implementation).

    Span detection is tried in this order:
      1. ``offset_mapping`` (fast tokenizers only): find the article text by
         character position in the rendered prompt and map it to token indices.
      2. ``subsequence``: tokenize the article on its own and search for its
         token ids inside the prompt's token ids.
      3. ``fallback_full``: if both fail and ``strict_span`` is False, every
         token is used as a label.

    Args:
        flat_article: Article dict passed to the ``ax4`` helpers.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        strict_span: When True, re-raise the subsequence-stage error instead of
            falling back to full-sequence labels.

    Returns:
        ``(input_ids, attention_mask, labels, debug)`` where ``debug`` records
        the ``method`` used plus either ``span_tok`` (``[start, end]`` token
        indices, end exclusive) or a ``reason`` for the full fallback.

    Raises:
        Exception: Whatever the subsequence stage raised, only if ``strict_span``.
    """
    messages = ax4.build_messages_for_binary(flat_article)
    article_text = ax4.build_article_text(flat_article)
    rendered = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Baseline tokenization (no offsets); used by the subsequence stage and the
    # full fallback unless the offset stage replaces it below.
    enc_prompt = tokenizer(rendered, add_special_tokens=False, return_tensors="pt", return_attention_mask=True)
    input_ids = enc_prompt["input_ids"]
    attention_mask = enc_prompt.get("attention_mask", torch.ones_like(input_ids))

    # (1) Prefer offset_mapping when the tokenizer is a fast tokenizer
    if getattr(tokenizer, "is_fast", False):
        try:
            start_char = rendered.find(article_text)
            if start_char < 0:
                # Retry after normalizing CRLF to LF on both sides.
                rendered_norm = rendered.replace("\r\n", "\n")
                article_norm = article_text.replace("\r\n", "\n")
                start_char = rendered_norm.find(article_norm)
                if start_char >= 0:
                    # Adopt the normalized strings so offsets and ids agree;
                    # later stages also see these rebinding results.
                    rendered = rendered_norm
                    article_text = article_norm
                    enc_prompt = tokenizer(
                        rendered,
                        add_special_tokens=False,
                        return_offsets_mapping=True,
                        return_tensors="pt",
                        return_attention_mask=True
                    )
                    input_ids = enc_prompt["input_ids"]
                    attention_mask = enc_prompt.get("attention_mask", torch.ones_like(input_ids))
                else:
                    raise RuntimeError("article_text not found")
            else:
                enc_prompt = tokenizer(
                    rendered,
                    add_special_tokens=False,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                    return_attention_mask=True
                )
                input_ids = enc_prompt["input_ids"]
                attention_mask = enc_prompt.get("attention_mask", torch.ones_like(input_ids))

            end_char = start_char + len(article_text)
            offsets = enc_prompt["offset_mapping"][0].tolist()
            start_tok, end_tok = None, None

            for i, (s, e) in enumerate(offsets):
                # Skip (0, 0) offsets — presumably tokens with no source text.
                if s == 0 and e == 0:
                    continue
                # First token whose end lies past the article start.
                if start_tok is None and e > start_char:
                    start_tok = i
                # First token starting at/after the article end closes the span.
                if start_tok is not None and s >= end_char:
                    end_tok = i
                    break

            if start_tok is None:
                raise RuntimeError("Failed map")
            if end_tok is None:
                # The article runs to the end of the prompt.
                end_tok = len(offsets)

            labels = input_ids.clone()
            labels[:] = -100
            labels[:, start_tok:end_tok] = input_ids[:, start_tok:end_tok]
            return input_ids, attention_mask, labels, {"method": "offset_mapping", "span_tok": [start_tok, end_tok]}
        except Exception:
            # Best effort: any failure here falls through to the subsequence stage.
            pass

    # (2) subsequence fallback
    try:
        enc_article = tokenizer(article_text, add_special_tokens=False, return_tensors="pt")
        article_ids = enc_article["input_ids"][0].tolist()
        prompt_ids = input_ids[0].tolist()
        idx = _find_subsequence(prompt_ids, article_ids)

        if idx is None:
            # Retry with CRLF normalized, but only if normalization changed something.
            rendered_norm = rendered.replace("\r\n", "\n")
            article_norm = article_text.replace("\r\n", "\n")
            if (rendered_norm != rendered) or (article_norm != article_text):
                enc_prompt2 = tokenizer(rendered_norm, add_special_tokens=False, return_tensors="pt", return_attention_mask=True)
                enc_article2 = tokenizer(article_norm, add_special_tokens=False, return_tensors="pt")
                idx2 = _find_subsequence(enc_prompt2["input_ids"][0].tolist(), enc_article2["input_ids"][0].tolist())
                if idx2 is not None:
                    input_ids = enc_prompt2["input_ids"]
                    attention_mask = enc_prompt2.get("attention_mask", torch.ones_like(input_ids))
                    idx = idx2
                    article_ids = enc_article2["input_ids"][0].tolist()

        if idx is None:
            raise RuntimeError("Subsequence match failed")

        start_tok, end_tok = idx, idx + len(article_ids)
        labels = input_ids.clone()
        labels[:] = -100
        labels[:, start_tok:end_tok] = input_ids[:, start_tok:end_tok]
        return input_ids, attention_mask, labels, {"method": "subsequence", "span_tok": [start_tok, end_tok]}

    except Exception as e:
        if strict_span:
            raise
        # (3) Last resort: use every token as a label and record why.
        labels = input_ids.clone()
        return input_ids, attention_mask, labels, {"method": "fallback_full", "reason": repr(e)}


# ===== 4) Folder processing + saving results =====
def list_json_files(root: str) -> List[str]:
    """Collect every ``*.json`` file (case-insensitive suffix) beneath ``root``.

    The result is sorted by full path so repeated runs, and sharding by list
    index, always see the files in the same order.
    """
    found = [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root)
        for name in names
        if name.lower().endswith(".json")
    ]
    return sorted(found)


def process_folder_to_jsonl(
    src_root,
    out_jsonl_path,
    model_path,
    adapter_path=None,
    shard_id=0,
    num_shards=1,
    max_files=None,
    ppl_cap_exp=80.0,
    strict_span=False
):
    """Score every JSON article under ``src_root`` and write one JSONL record per file.

    Each input file is loaded, passed through ``cbp.preprocess_article``,
    turned into masked inputs by ``build_masked_inputs_for_ppl`` and scored with
    ``PPLModelWrapper.compute_ppl_pair``. Successful records hold the
    adapter-off / adapter-on mean NLL and perplexity plus span debug info; a
    failing file produces a record with ``path``, ``rel_path`` and ``error``
    instead of aborting the run.

    Args:
        src_root: Directory searched recursively for ``*.json`` files.
        out_jsonl_path: Output file; it is overwritten.
        model_path: Base model location passed to ``PPLModelWrapper``.
        adapter_path: Optional adapter location passed to ``PPLModelWrapper``.
        shard_id: Index of the shard to process when ``num_shards > 1``.
        num_shards: Files are split round-robin by sorted index across shards.
        max_files: Optional cap applied after sharding.
        ppl_cap_exp: Exponent cap forwarded to ``compute_ppl_pair``.
        strict_span: Forwarded to ``build_masked_inputs_for_ppl``.
    """
    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create the directory when the path actually names one.
    out_dir = os.path.dirname(out_jsonl_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    all_paths = list_json_files(src_root)
    if num_shards > 1:
        all_paths = [p for i, p in enumerate(all_paths) if (i % num_shards) == shard_id]
    if max_files is not None:
        all_paths = all_paths[:max_files]

    total = len(all_paths)
    print(f"[run] src_root={src_root}")
    print(f"[run] total_files={total} (shard {shard_id}/{num_shards})")

    wrapper = PPLModelWrapper(model_path=model_path, adapter_path=adapter_path)
    device = wrapper.run_device
    n_ok, n_err = 0, 0

    with open(out_jsonl_path, "w", encoding="utf-8") as w:
        for done, path in enumerate(all_paths, start=1):
            # Computed before the try so error records carry the same key.
            rel_path = os.path.relpath(path, src_root).replace("\\", "/")
            try:
                with open(path, "r", encoding="utf-8") as f:
                    raw = json.load(f)

                flat = cbp.preprocess_article(raw)
                input_ids, attention_mask, labels, dbg = build_masked_inputs_for_ppl(
                    flat, wrapper.tokenizer, strict_span
                )

                input_ids, attention_mask, labels = (
                    input_ids.to(device),
                    attention_mask.to(device),
                    labels.to(device),
                )

                ppl_off, ppl_on, loss_off, loss_on = wrapper.compute_ppl_pair(
                    input_ids, attention_mask, labels, ppl_cap_exp
                )

                rec = {
                    "path": path,
                    "rel_path": rel_path,
                    "newsID": flat.get("newsID"),
                    "clickbaitClass": flat.get("clickbaitClass"),
                    "mean_nll_anchor_off": loss_off,
                    "mean_nll_deployed_on": loss_on,
                    "ppl_anchor_off": ppl_off,
                    "ppl_deployed_on": ppl_on,
                    "span_debug": dbg,
                }
                w.write(json.dumps(rec, ensure_ascii=False) + "\n")
                n_ok += 1

            except Exception as e:
                # Per-file boundary: record the failure and keep going.
                n_err += 1
                w.write(json.dumps({
                    "path": path,
                    "rel_path": rel_path,
                    "error": repr(e)
                }, ensure_ascii=False) + "\n")

            if done % 100 == 0:
                print(f"[run] done={done}/{total}")

    print(f"[run] finished. ok={n_ok} err={n_err} -> {out_jsonl_path}")


# ===== 5) Helpers: copy data/adapter to local disk =====
def extract_data_to_local(tar_path: str, dest_dir: str = "/content/data_extracted") -> str:
    """Extract a tar archive (plain or compressed) into ``dest_dir``.

    Extraction is skipped when ``dest_dir`` already exists and is non-empty, so
    give each data partition its own ``dest_dir`` to avoid mixing contents.

    Args:
        tar_path: Archive to extract; opened with mode ``"r:*"`` so the
            compression is detected automatically.
        dest_dir: Target directory, created if missing.

    Returns:
        ``dest_dir``.

    Raises:
        FileNotFoundError: If ``tar_path`` does not exist.
    """
    if not os.path.exists(tar_path):
        raise FileNotFoundError(f"Tar file not found: {tar_path}")

    # NOTE: a non-empty directory is trusted as a finished extraction; an
    # interrupted earlier run would be reused as-is.
    if os.path.exists(dest_dir) and os.listdir(dest_dir):
        print(f"[setup] Data dir exists, skipping extract: {dest_dir}")
        return dest_dir

    os.makedirs(dest_dir, exist_ok=True)
    print(f"[setup] Extracting {tar_path} -> {dest_dir} ...")
    with tarfile.open(tar_path, "r:*") as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter rejects members that would land outside
            # dest_dir and avoids the unfiltered-extraction warning.
            tar.extractall(path=dest_dir, filter="data")
        else:
            # Older Python without extraction filters.
            tar.extractall(path=dest_dir)
    print("[setup] Data extraction done.")
    return dest_dir


def copy_adapter_to_local(src_path: str, dest_path: str = "/content/adapter_local") -> str:
    """Copy the adapter directory at ``src_path`` into ``dest_path`` and return ``dest_path``.

    An existing ``dest_path`` is copied into rather than rejected. Raises
    ``FileNotFoundError`` when ``src_path`` does not exist.
    """
    source_missing = not os.path.exists(src_path)
    if source_missing:
        raise FileNotFoundError(f"Adapter source not found: {src_path}")

    start_message = f"[setup] Copying adapter: {src_path} -> {dest_path} ..."
    print(start_message)
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
    print("[setup] Adapter copy done.")

    return dest_path


# ===== 6) Example run =====
# [Config] Google Drive path of the data archive
DRIVE_TAR_PATH = "/content/drive/MyDrive/TRAITHON/data/golden_set_v1_TL.tar"

# Original adapter location on Drive
DRIVE_ADAPTER_SRC = "/content/drive/MyDrive/TRAITHON/models/adapter"

# ★ [Fix] Keep the part number in this name consistent with the data being processed
OUT_JSONL = "/content/drive/MyDrive/TRAITHON/DRIFT/ppl_golden.jsonl"

# 1) Prepare the base model (via the downloader module)
MODEL_PATH = ax4_dl.download_ax4_light_to_content(force=False)

# 2) Prepare the data (tar -> local)  ★ [Fix] use a different dest_dir per partition
LOCAL_DATA_DIR = extract_data_to_local(
    tar_path=DRIVE_TAR_PATH,
    dest_dir="/content/data_part0"
)

# 3) Prepare the adapter (Drive -> local)
LOCAL_ADAPTER_PATH = copy_adapter_to_local(
    src_path=DRIVE_ADAPTER_SRC,
    dest_path="/content/adapter_temp"
)

# 4) Run the main logic
process_folder_to_jsonl(
    src_root=LOCAL_DATA_DIR,
    out_jsonl_path=OUT_JSONL,
    model_path=MODEL_PATH,
    adapter_path=LOCAL_ADAPTER_PATH,

    shard_id=0,      # ★ [Fix] with num_shards=1, keeping this fixed at 0 is recommended
    num_shards=1,
    max_files=None,
    ppl_cap_exp=80.0,
    strict_span=False,
)


Mounted at /content/drive
[hf_dl] downloading repo: skt/A.X-4.0-Light
[hf_dl] -> /content/A.X-4.0-Light
[hf_dl] token: none (public repo라면 OK, gated면 필요)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/130 [00:00<?, ?B/s]

A.X_logo.png:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

LICENSE: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/674 [00:00<?, ?B/s]

assets/A.X_logo_ko_4x3.png:   0%|          | 0.00/183k [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.60G [00:00<?, ?B/s]

[hf_dl] done: /content/A.X-4.0-Light
[setup] Extracting /content/drive/MyDrive/TRAITHON/data/golden_set_v1_TL.tar -> /content/data_part0 ...


  tar.extractall(path=dest_dir)


[setup] Data extraction done.
[setup] Copying adapter: /content/drive/MyDrive/TRAITHON/models/adapter -> /content/adapter_temp ...
[setup] Adapter copy done.
[run] src_root=/content/data_part0
[run] total_files=5000 (shard 0/1)
[ppl] Loading base model: /content/A.X-4.0-Light


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

[ppl] Attaching adapter: /content/adapter_temp
[ppl] Ready. run_device=cuda:0
[run] done=100/5000
[run] done=200/5000
[run] done=300/5000
[run] done=400/5000
[run] done=500/5000
[run] done=600/5000
[run] done=700/5000
[run] done=800/5000
[run] done=900/5000
[run] done=1000/5000
[run] done=1100/5000
[run] done=1200/5000
[run] done=1300/5000
[run] done=1400/5000
[run] done=1500/5000
[run] done=1600/5000
[run] done=1700/5000
[run] done=1800/5000
[run] done=1900/5000
[run] done=2000/5000
[run] done=2100/5000
[run] done=2200/5000
[run] done=2300/5000
[run] done=2400/5000
[run] done=2500/5000
[run] done=2600/5000
[run] done=2700/5000
[run] done=2800/5000
[run] done=2900/5000
[run] done=3000/5000
[run] done=3100/5000
[run] done=3200/5000
[run] done=3300/5000
[run] done=3400/5000
[run] done=3500/5000
[run] done=3600/5000
[run] done=3700/5000
[run] done=3800/5000
[run] done=3900/5000
[run] done=4000/5000
[run] done=4100/5000
[run] done=4200/5000
[run] done=4300/5000
[run] done=4400/5000
[run] d