In [None]:
# quantize_kobert_intent_v2.py
import os
import torch
import torch.nn as nn
import torch.ao.quantization as quant
from kobert_transformers import get_kobert_model

# ===== 0) User environment =====
# Defaults are absolute paths on the original author's machine. Each one can be
# overridden through an environment variable so the script runs elsewhere
# without editing the source; with no variables set, the values are unchanged.
MODEL_PATH = os.environ.get(
    "KOBERT_INTENT_MODEL_PATH",
    "/home/j-k13b204/S13P31B204/model_test/kobert_intent_v2/model_best.pt",
)
SAVE_DIR = os.environ.get(
    "KOBERT_INTENT_SAVE_DIR",
    "/home/j-k13b204/S13P31B204/model_test/kobert_intent_v2/quantize",
)
# Created eagerly (at import time) so the later torch.save has a target directory.
os.makedirs(SAVE_DIR, exist_ok=True)
QUANTIZED_PT_PATH = os.path.join(SAVE_DIR, "model_best_quantized.pt")

# ===== 1) 모델 클래스 정의 (intent용, BCEWithLogits) =====
class KoBertClassifier(nn.Module):
    def __init__(
        self,
        bert,
        hidden_size: int = 768,
        num_classes: int = 5,
        dr_rate: float = 0.3,
        pos_weight: torch.Tensor | None = None,  # 각 클래스의 양성 가중치 (불균형 보정)
    ):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(p=dr_rate) if dr_rate and dr_rate > 0 else nn.Identity()
        self.classifier = nn.Linear(hidden_size, num_classes)

        if pos_weight is not None:
            self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        else:
            self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,  # (B, num_classes), multi-hot
    ):
        out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        if getattr(out, "pooler_output", None) is not None:
            pooled = out.pooler_output      # (B, 768)
        else:
            pooled = out[0][:, 0]           # last_hidden_state[:, 0]

        logits = self.classifier(self.dropout(pooled))  # (B, num_classes)

        if labels is not None:
            labels = labels.float()
            loss = self.loss_fn(logits, labels)
            probs = torch.sigmoid(logits)
            return probs, loss
        probs = torch.sigmoid(logits)
        return probs, None

# ===== 2) Model loading =====
def load_fp32_model(model_path: str, num_classes: int = 5, dr_rate: float = 0.3) -> nn.Module:
    """Build a CPU ``KoBertClassifier`` and load the FP32 checkpoint at ``model_path``.

    Args:
        model_path: Checkpoint file. It is treated as a state dict
            (key -> tensor mapping) of ``KoBertClassifier``.
        num_classes: Number of outputs of the classifier head.
        dr_rate: Dropout probability passed to the classifier (inactive after
            ``eval()``).

    Returns:
        The loaded model, on CPU, in eval mode.

    Raises:
        RuntimeError: If the checkpoint does not supply every learnable
            parameter of the model. Loading is non-strict (unexpected keys and
            missing buffers only warn), but a missing parameter would remain
            randomly initialised and silently corrupt the quantized artifact.
    """
    device = torch.device("cpu")
    print("[1/3] 모델 구조 초기화…")

    # get_kobert_model() returns the encoder itself (no tuple unpacking needed).
    bert = get_kobert_model()
    model = KoBertClassifier(
        bert=bert,
        hidden_size=768,
        num_classes=num_classes,
        dr_rate=dr_rate,
        pos_weight=None,  # inference only: the loss weighting is not needed
    ).to(device)

    print("[2/3] 가중치 로드…")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(model_path, map_location=device)

    # Training was done with pos_weight, so the checkpoint also holds loss_fn.*
    # buffers (e.g. "loss_fn.pos_weight") that this inference model does not
    # have. Delete them in place (instead of rebuilding the dict) so the state
    # dict's _metadata attribute survives for load_state_dict.
    for key in [k for k in state.keys() if k.startswith("loss_fn.")]:
        del state[key]

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[WARN] missing keys: {missing}")
    if unexpected:
        print(f"[WARN] unexpected keys: {unexpected}")

    # Non-strict loading is fine for stray buffers, but not for learned weights.
    param_names = {name for name, _ in model.named_parameters()}
    missing_params = [k for k in missing if k in param_names]
    if missing_params:
        raise RuntimeError(
            f"Checkpoint {model_path!r} is missing learnable parameters: {missing_params}"
        )

    model.eval()
    print("[3/3] 로드 완료.")
    return model

# ===== 3) Quantization =====
def quantize_dynamic_int8(model: nn.Module) -> nn.Module:
    """Return a dynamically INT8-quantized copy of ``model``.

    Only ``nn.Linear`` submodules are replaced by dynamic-quantized versions
    with ``torch.qint8`` weights; every other module is left as-is.
    ``quantize_dynamic`` runs with its default ``inplace=False``, so the
    caller's model is not modified.
    """
    print("\n=> 동적 양자화(INT8) 수행 중…")
    layers_to_quantize = {nn.Linear}
    quantized_model = quant.quantize_dynamic(
        model,
        qconfig_spec=layers_to_quantize,
        dtype=torch.qint8,
    )
    print("완료.")
    return quantized_model

# ===== 4) Entry point =====
if __name__ == "__main__":
    NUM_CLASSES = 5  # NCS five core competencies, one sigmoid output each

    # Load the FP32 model, then quantize its Linear layers to INT8.
    model = load_fp32_model(MODEL_PATH, num_classes=NUM_CLASSES, dr_rate=0.3)
    qmodel = quantize_dynamic_int8(model)

    # Save the whole quantized module (pickled), not just a state_dict.
    # NOTE: loading it back requires KoBertClassifier to be importable under
    # the qualified name it had when saved (``__main__`` when run as a script).
    torch.save(qmodel, QUANTIZED_PT_PATH)
    print(f"\n✅ 양자화된 intent 모델 저장 완료: {QUANTIZED_PT_PATH}")

    def size_mb(path):
        """Return the on-disk size of ``path`` in MiB."""
        return os.path.getsize(path) / 2**20

    # Compare the original checkpoint file against the saved quantized module.
    orig, quantized = size_mb(MODEL_PATH), size_mb(QUANTIZED_PT_PATH)
    print(f"\n원본 모델 크기: {orig:.2f} MB")
    print(f"양자화 모델 크기: {quantized:.2f} MB")
    print(f"압축률: {orig/quantized:.2f}x, 크기 감소: {(orig - quantized)/orig * 100:.1f}%")


[INFO] FP32 모델 로드: /home/j-k13b204/S13P31B204/model_test/kobert_intent_v2/model_best.pt


RuntimeError: Error(s) in loading state_dict for KoBertClassifier:
	Unexpected key(s) in state_dict: "loss_fn.pos_weight". 