# Attention-MIL + LoRA Inference
학습된 모델을 불러와 테스트 데이터에 대한 추론을 수행하고 submission 파일을 생성합니다.


In [None]:
# Google Colab environment setup: mount Google Drive so files under
# /content/drive are reachable (presumably where DIR / OUTPUT_DIR point — verify).
from google.colab import drive
drive.mount('/content/drive')

# Install inference dependencies. NOTE: `!pip` is an IPython shell escape and
# only works inside a notebook, not in a plain .py script.
!pip install -q "transformers>=4.40,<5" "peft>=0.11.0" accelerate scikit-learn tqdm


In [None]:
# 라이브러리 import 및 설정
import os, re, gc, random, math, numpy as np, pandas as pd, torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from contextlib import nullcontext

# Ask the CUDA caching allocator to use expandable segments (reduces fragmentation).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Reproducibility: seed every RNG this notebook may touch, in a fixed order.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

_HAS_CUDA = torch.cuda.is_available()
device = "cuda" if _HAS_CUDA else "cpu"
print("Device:", device, " | Torch:", torch.__version__)

# Mixed precision: USE_BF16 selects bfloat16 for autocast_ctx(); on CUDA also
# allow TF32 for matmuls and cuDNN convolutions.
USE_BF16 = True
if _HAS_CUDA:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

def autocast_ctx():
    """Return the mixed-precision context manager used for inference.

    On machines without CUDA this is a no-op ``nullcontext()``. On CUDA it is a
    ``torch.autocast`` context that uses bfloat16 when ``USE_BF16`` is set and
    the current GPU reports bf16 support, and float16 otherwise.
    """
    if not torch.cuda.is_available():
        return nullcontext()
    # torch.autocast may reject bf16 on GPUs without bf16 support by raising a
    # RuntimeError (not a TypeError), so decide the dtype up front rather than
    # relying on an exception-based fallback.
    use_bf16 = USE_BF16 and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    try:
        return torch.autocast(device_type="cuda", dtype=dtype)
    except (TypeError, RuntimeError):
        # Last-resort fallback if this torch build rejects the requested dtype.
        return torch.autocast(device_type="cuda", dtype=torch.float16)

# Hyper-parameters.
BASE_MODEL = "monologg/koelectra-base-v3-discriminator"  # HF encoder + tokenizer id
MAX_LEN = 192  # max tokens per paragraph passed to the tokenizer

# pos_weight value from training (reference only): at inference it just fills
# the model's `pos_weight` buffer.
POS_WEIGHT = 14.23

# Paths: DIR holds test.csv; OUTPUT_DIR holds the checkpoint and receives the
# submission. An empty string means the current working directory.
DIR = ''
OUTPUT_DIR = ''


In [None]:
# 모델 정의
from transformers import AutoTokenizer
from transformers.models.electra import ElectraModel
from peft import LoraConfig, get_peft_model

class MILPooler(nn.Module):
    """Pool a bag of instance embeddings into a single bag embedding.

    Modes ('attn' | 'lse' | 'max' | 'mean'):
      - 'mean': uniform average of all instances.
      - 'max':  the single instance with the largest L2 norm (one-hot weights).
      - 'lse':  softmax over ``tau * L2-norm`` of each instance, used as weights
                for a weighted average (despite the name, no log-sum-exp value
                is returned).
      - 'attn': learned attention, weights = softmax(u(tanh(W h))).

    ``forward(H)`` takes ``H`` of shape ``[n_i, d]`` and returns ``(z, alpha)``:
    ``z`` is the pooled ``[d]`` vector and ``alpha`` the ``[n_i]`` instance weights.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes. Previously an
            unknown mode fell through to the attention branch and failed later
            with an AttributeError because ``W``/``u`` only exist for 'attn'.
    """

    _MODES = ("attn", "lse", "max", "mean")

    def __init__(self, d, mode="lse", r=128, lse_tau=10.0):
        super().__init__()
        if mode not in self._MODES:
            raise ValueError(f"unknown pooling mode {mode!r}; expected one of {self._MODES}")
        self.mode = mode
        self.tau = lse_tau
        # Learnable parameters exist only for attention pooling.
        if mode == "attn":
            self.W = nn.Linear(d, r)
            self.u = nn.Linear(r, 1, bias=False)

    def forward(self, H):  # H: [n_i, d]
        if self.mode == "mean":
            z = H.mean(dim=0)
            alpha = torch.full((H.size(0),), 1.0 / H.size(0), device=H.device)
            return z, alpha
        if self.mode == "max":
            idx = torch.argmax(H.norm(dim=1))
            z = H[idx]
            alpha = torch.zeros(H.size(0), device=H.device)
            alpha[idx] = 1.0
            return z, alpha
        if self.mode == "lse":
            # Larger tau -> weights concentrate on the highest-norm instance.
            e = H.norm(dim=1) * self.tau
            alpha = torch.softmax(e, dim=0)
            z = torch.sum(alpha.unsqueeze(-1) * H, dim=0)
            return z, alpha
        # mode == "attn"
        A = torch.tanh(self.W(H))
        e = self.u(A).squeeze(-1)
        alpha = torch.softmax(e, dim=0)
        z = torch.sum(alpha.unsqueeze(-1) * H, dim=0)
        return z, alpha


class AttentionMILModel(nn.Module):
    """Multiple-instance-learning model over bags of paragraphs.

    Each row of ``input_ids`` is one instance (paragraph). The encoder's
    first-token hidden state (typically [CLS], depending on the tokenizer) is
    the instance embedding. ``instance_head`` gives one logit per instance;
    each bag (a ``[start, end)`` row range from ``bag_bounds``) is pooled by
    ``MILPooler`` and scored by ``bag_head``.

    ``pos_weight``, ``inst_aux_lambda`` and ``inst_aux_k`` are stored but not
    used in ``forward`` (presumably they belong to the training loss).
    ``pos_weight`` is a registered buffer, so it is part of ``state_dict`` and
    presumably of saved checkpoints — keep it so strict loading still matches.
    """

    def __init__(self, encoder: nn.Module, attn_r=128, pos_weight=1.0,
                 pool_mode="lse", lse_tau=10.0, inst_aux_lambda=0.1, inst_aux_k=1):
        super().__init__()
        self.encoder = encoder
        # No KV cache is needed for a single encoder pass.
        if hasattr(self.encoder.config, "use_cache"):
            self.encoder.config.use_cache = False
        d = self.encoder.config.hidden_size
        self.instance_head = nn.Linear(d, 1)
        self.pool = MILPooler(d, mode=pool_mode, r=attn_r, lse_tau=lse_tau)
        self.bag_head = nn.Linear(d, 1)
        self.register_buffer("pos_weight", torch.tensor(float(pos_weight)))
        self.inst_aux_lambda = inst_aux_lambda
        self.inst_aux_k = inst_aux_k

    def encode_instances(self, input_ids, attention_mask, need_instance_logits=True):
        """Return (first-token embeddings ``[N, d]``, instance logits ``[N]`` or None)."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        H = out.last_hidden_state[:, 0, :]
        s = self.instance_head(H).squeeze(-1) if need_instance_logits else None
        return H, s

    def forward(self, input_ids, attention_mask, bag_bounds, labels=None,
                return_instance=True, return_alphas=False, **_):
        """Score every bag (and optionally every instance).

        Args:
            input_ids, attention_mask: token tensors for all instances of all
                bags, stacked along dim 0.
            bag_bounds: integer tensor of ``(start, end)`` row ranges, one per bag.
            labels: accepted for interface compatibility; ignored here.
            return_instance: also return per-instance logits.
            return_alphas: also return each bag's pooling weights.

        Returns:
            dict with ``"bag_logits"`` (one logit per bag), plus
            ``"instance_logits"`` and/or ``"alphas"`` (list, one tensor per bag)
            when requested.
        """
        H_all, s_all = self.encode_instances(input_ids, attention_mask,
                                             need_instance_logits=return_instance)
        bag_logits = []
        alphas_all = [] if return_alphas else None

        for start, end in bag_bounds.tolist():
            z, alpha = self.pool(H_all[start:end])
            bag_logits.append(self.bag_head(z).squeeze(-1))
            if return_alphas:
                alphas_all.append(alpha)

        out = {"bag_logits": torch.stack(bag_logits, dim=0)}
        if return_instance:
            out["instance_logits"] = s_all
        if return_alphas:
            out["alphas"] = alphas_all
        return out

print("모델 클래스 정의 완료")


In [None]:
# Tokenizer matching the base encoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BASE_MODEL)
print(f"토크나이저 로드 완료: {BASE_MODEL}")


In [None]:
# Load the test data and normalise the paragraph text.
test_df = pd.read_csv(os.path.join(DIR, "test.csv"))

# Explicit validation (an `assert` would be stripped under `python -O`).
_required_cols = {"ID", "title", "paragraph_index", "paragraph_text"}
_missing_cols = _required_cols.difference(test_df.columns)
if _missing_cols:
    raise ValueError(f"test.csv is missing required columns: {sorted(_missing_cols)}")

# fillna must run BEFORE astype(str): astype(str) turns NaN into the literal
# string "nan", after which fillna("") has nothing left to replace.
test_df["paragraph_text"] = test_df["paragraph_text"].fillna("").astype(str)

# Group each document's paragraphs contiguously and in order; reset_index makes
# the index equal to row position 0..len-1 so it can serve as an array offset.
test_df = test_df.sort_values(["title", "paragraph_index"]).reset_index(drop=True)

print(f"테스트 데이터 로드 완료: {len(test_df)} rows")

class TestParagraphDataset(Dataset):
    """Dataset with one item per document (bag of paragraphs).

    Rows of ``df`` are grouped by ``title`` (groups come out in sorted title
    order). Each item is a dict with:
      - ``"title"``: the group key,
      - ``"paras"``: the group's ``paragraph_text`` values in row order,
      - ``"index"``: the group's index labels (row positions when ``df`` has a
        default RangeIndex).

    ``dropna=False`` keeps rows with a missing title as their own group; with
    pandas' default they would be dropped and never scored.
    """

    def __init__(self, df):
        self.groups = [
            {
                "title": title,
                "paras": g["paragraph_text"].tolist(),
                "index": g.index.tolist(),
            }
            for title, g in df.groupby("title", dropna=False)
        ]

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        return self.groups[idx]

def collate_test(batch, tokenizer, max_len=256):
    """Flatten a batch of documents into one tokenized tensor batch.

    All paragraphs of all documents are tokenized together. ``bag_bounds``
    records each document's ``(start, end)`` row range in that flat batch. A
    document with no paragraphs contributes a single empty string so its bag
    is never empty.

    Returns:
        dict with ``input_ids``, ``attention_mask`` (tokenizer output),
        ``bag_bounds`` (long tensor of row ranges), ``titles`` and ``indices``
        (per-document lists, in batch order).
    """
    texts, bounds = [], []
    for doc in batch:
        doc_paras = doc["paras"] or [""]
        offset = len(texts)
        texts += doc_paras
        bounds.append((offset, len(texts)))

    encoded = tokenizer(texts, truncation=True, padding=True, max_length=max_len, return_tensors="pt")
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "bag_bounds": torch.tensor(bounds, dtype=torch.long),
        "titles": [doc["title"] for doc in batch],
        "indices": [doc["index"] for doc in batch],
    }

# One document (bag) per batch, iterated in dataset order.
te_set = TestParagraphDataset(test_df)


def _collate_te(batch):
    """Tokenize a batch of documents with the module-level tokenizer and MAX_LEN."""
    return collate_test(batch, tokenizer, MAX_LEN)


te_loader = DataLoader(te_set, batch_size=1, shuffle=False, collate_fn=_collate_te)

print(f"테스트 DataLoader 생성 완료: {len(te_set)} documents")


In [None]:
# Load the trained model.
# LoRA settings must match the training run; otherwise the LoRA parameter
# names/shapes will presumably not match the checkpoint and the (strict)
# load_state_dict below will fail.
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,  # dropout is inactive once best_model.eval() is called
    target_modules=["query", "value"],
    task_type="FEATURE_EXTRACTION",
)

# ELECTRA encoder with LoRA adapters injected.
best_encoder = ElectraModel.from_pretrained(BASE_MODEL)
best_encoder = get_peft_model(best_encoder, lora_cfg)

# MIL model; constructor arguments presumably mirror the training run.
best_model = AttentionMILModel(
    encoder=best_encoder,
    attn_r=128,
    pos_weight=POS_WEIGHT,
    pool_mode="lse",
    lse_tau=10.0,
    inst_aux_lambda=0.1,
    inst_aux_k=1,
).to(device)

# Load the trained weights.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# Loading to CPU avoids a transient second full copy of the weights on the GPU;
# load_state_dict copies the tensors into the already-on-device parameters.
MODEL_PATH = os.path.join(OUTPUT_DIR, "attnmil_lora_best6.pt")
best_model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
best_model.eval()

print(f"모델 로드 완료: {MODEL_PATH}")


In [None]:
# Inference 함수 정의
@torch.no_grad()
def infer(model, loader, T: float = 1.0):
    """Predict a per-paragraph probability for every row of ``test_df``.

    Runs ``model`` over each batch from ``loader``. For every paragraph of every
    document (bag) in the batch, it writes ``sigmoid(instance_logits)`` into a
    flat array aligned with ``test_df``'s row positions.

    Args:
        model: trained ``AttentionMILModel``.
        loader: DataLoader whose batches come from ``collate_test`` (keys
            ``input_ids``, ``attention_mask``, ``bag_bounds``, ``indices``).
            Any batch size works; every bag in a batch is written back.
        T: temperature intended for calibration. NOTE(review): it is accepted
            but NOT applied; the returned values are plain
            ``sigmoid(instance_logits)``. Dividing the logits by the value this
            notebook passes (0.0385, i.e. ~26x scaling) could saturate many
            float32 probabilities to exactly 0/1. Confirm the intended use
            before enabling it.

    Returns:
        np.ndarray of shape ``(len(test_df),)`` with each paragraph's predicted
        AI-generated probability. Rows never produced by ``loader`` stay 0.0.
    """
    paragraph_probs = np.zeros(len(test_df), dtype=float)
    model.eval()
    use_cuda = torch.cuda.is_available()

    with torch.inference_mode():
        for batch in tqdm(loader, desc="Inference", ncols=100):
            batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}

            with autocast_ctx():
                out = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    bag_bounds=batch["bag_bounds"],
                    labels=None,
                    return_instance=True,
                    return_alphas=False,
                )

            # Upcast the (possibly bf16/fp16) instance logits to float32 before the sigmoid.
            p_all = torch.sigmoid(out["instance_logits"].float()).cpu().numpy()

            # Scatter each bag's slice back to its test_df rows. Iterating over all
            # bags (not just bag 0) keeps the result correct for batch_size > 1.
            for (start, end), idxs in zip(batch["bag_bounds"].tolist(), batch["indices"]):
                paragraph_probs[idxs] = p_all[start:end]

            del out, batch, p_all
            # Best-effort release of cached GPU memory between batches.
            if use_cuda:
                torch.cuda.empty_cache()
                try:
                    torch.cuda.ipc_collect()
                except RuntimeError:
                    pass

    return paragraph_probs

print("Inference 함수 정의 완료")


In [None]:
# Run inference over the whole test set.
best_T = 0.0385  # temperature tuned during training (calibration); passed to infer() as T
paragraph_probs = infer(best_model, te_loader, T=best_T)

print(f"Inference 완료: {len(paragraph_probs)} paragraphs")
_p_lo, _p_hi = paragraph_probs.min(), paragraph_probs.max()
print(f"예측 확률 범위: [{_p_lo:.4f}, {_p_hi:.4f}]")
_p_avg = paragraph_probs.mean()
print(f"예측 확률 평균: {_p_avg:.4f}")


In [None]:
# Build and save the submission file.
submission = pd.DataFrame({
    "ID": test_df["ID"],
    "generated": paragraph_probs,
})

# Sort rows by the first run of digits in each ID.
# expand=False yields a Series (not a one-column DataFrame) for a single group;
# astype(str) lets purely numeric IDs use the .str accessor too.
id_num = submission["ID"].astype(str).str.extract(r"(\d+)", expand=False)
if id_num.isna().any():
    raise ValueError("Some IDs contain no digits; cannot sort the submission by numeric ID.")
submission = (
    submission.assign(id_num=id_num.astype(int))
    .sort_values("id_num", kind="stable")
    .drop(columns="id_num")
    .reset_index(drop=True)
)

# Save as CSV.
SUBMISSION_PATH = os.path.join(OUTPUT_DIR, "submission_inference.csv")
submission.to_csv(SUBMISSION_PATH, index=False)

print(f"Submission 저장 완료: {SUBMISSION_PATH}")
print(f"Shape: {submission.shape}")
submission.head(10)
