# In [None]:
# -*- coding: utf-8 -*-
import os
import re
import argparse
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
import sys


# ----------------------------
# 유틸: 후처리 (tidy + drop incomplete)
# ----------------------------
def tidy_korean_summary(text: str) -> str:
    """Normalize whitespace in a generated summary and make sure it ends cleanly.

    Every run of whitespace (spaces, tabs, newlines) is collapsed to a single
    space and the result is stripped. A non-empty result whose last character
    is not one of ``. ! ? … 。 " ' ” ’`` gets a trailing ``"."`` appended.
    ``None`` or ``""`` yields ``""``.
    """
    collapsed = re.sub(r"\s+", " ", text or "").strip()
    # Nothing to terminate, or it already ends with an accepted mark.
    if not collapsed or collapsed.endswith(tuple(".!?…。\"'”’")):
        return collapsed
    return collapsed + "."

def drop_incomplete_sentences(text: str) -> str:
    """Heuristically strip a dangling, unfinished sentence from a summary.

    The text is split right after ``.``, ``?``, ``!`` or ``。`` (any whitespace
    following the mark is consumed), and blank fragments are discarded. Then:

    * exactly one piece: it is kept only if it has at least 3
      whitespace-separated words and does not end in one of the listed
      connective/ending suffixes; otherwise ``""`` is returned;
    * several pieces: only the last one is examined. It is dropped when it has
      fewer than 4 words, does not end in one of ``. ? ! 。 : ;``, or ends in
      one of the listed suffixes. The remaining pieces are re-joined with
      single spaces.

    NOTE(review): every suffix pattern is anchored on a Hangul ending, so it
    can never match a piece that ends in punctuation (e.g. text that already
    went through ``tidy_korean_summary``). Confirm that this is intended.

    Returns ``""`` for ``None``/empty/whitespace-only input.
    """
    if not text or not str(text).strip():
        return ""

    fragments = re.split(r'(?<=[\.\?\!。\!?])\s*', text)
    pieces = [frag.strip() for frag in fragments if frag.strip()]
    if not pieces:
        return ""

    tail = pieces[-1]

    if len(pieces) == 1:
        # A lone piece needs >= 3 words and must not stop on a connective.
        too_short = len(tail.split()) < 3
        if too_short or re.search(r'(며$|고$|하며$|으로$|에게$|서$|로$|에$|에 대해$|한다$|된다$)', tail):
            return ""
        return tail

    # With several pieces, only the final one is suspected of being cut off.
    tail_is_fragment = (
        len(tail.split()) < 4
        or not re.search(r'[\.\?\!。\!?:;]$', tail)
        or re.search(r'(며$|고$|하며$|으로$|에게$|서$|로$|에$|에 대해$|한다$|된다$|것이다$)', tail)
    )
    kept = pieces[:-1] if tail_is_fragment else pieces
    return " ".join(kept).strip()


# ----------------------------
# 무조건 kobart_finetuned 사용
# ----------------------------
def load_tokenizer_model(local_dir: str = "./kobart_finetuned"):
    p = Path(local_dir)
    if not p.exists():
        raise FileNotFoundError(
            f"[ERROR] 로컬 모델 폴더를 찾을 수 없습니다: {local_dir}\n"
            "확인해주세요. (model.safetensors / tokenizer 파일들이 있어야 함)"
        )

    print(f"[INFO] Using ONLY local model: {local_dir}")

    # config.json 수정: num_labels 같은 classification 필드 제거
    cfg_path = p / "config.json"
    try:
        if cfg_path.exists():
            import json
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            for k in ["id2label", "label2id", "num_labels", "problem_type"]:
                cfg.pop(k, None)
            with open(cfg_path, "w", encoding="utf-8") as f:
                json.dump(cfg, f, ensure_ascii=False, indent=2)
    except Exception:
        pass

    cfg = AutoConfig.from_pretrained(local_dir, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(local_dir, use_fast=True, local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(local_dir, config=cfg, local_files_only=True)

    return tokenizer, model


# ----------------------------
# 요약 호출 함수
# ----------------------------
def summarize_text(
    tokenizer,
    model,
    device,
    text: str,
    min_new_tokens: int,
    max_new_tokens: int,
    *,
    max_input_length: int = 1024,
    num_beams: int = 5,
) -> str:
    """Summarize one document with deterministic beam search and clean the output.

    Args:
        tokenizer: Tokenizer matching ``model``; called with ``return_tensors="pt"``.
        model: Seq2seq model exposing ``generate``. It is expected to already
            live on ``device``; callers in this file do ``model.to(device)``.
        device: Device the input tensors are moved to (e.g. ``"cpu"``, ``"cuda"``).
        text: Source document. ``None``/empty/whitespace-only input returns
            ``""`` without touching the tokenizer or model. Non-str values are
            converted with ``str()``.
        min_new_tokens: Minimum number of generated tokens.
        max_new_tokens: Maximum number of generated tokens.
        max_input_length: Token budget for the input; longer inputs are
            truncated. The default 1024 is presumably chosen to match the
            BART-family position-embedding limit.
        num_beams: Beam width for decoding.

    Returns:
        The decoded summary after ``tidy_korean_summary`` (whitespace and
        terminal punctuation) and ``drop_incomplete_sentences`` (trailing
        fragment removal). May be ``""`` if the whole summary is judged
        incomplete.
    """
    if not text or not str(text).strip():
        return ""
    encoded = tokenizer(
        str(text), return_tensors="pt", truncation=True, max_length=max_input_length
    )
    # BART-style models have no token-type embeddings, but some KoBART tokenizers
    # still emit ``token_type_ids``. Recent transformers versions reject unused
    # model kwargs in generate(), so that entry is dropped here.
    inputs = {
        name: tensor.to(device)
        for name, tensor in encoded.items()
        if name != "token_type_ids"
    }
    with torch.no_grad():
        out = model.generate(
            **inputs,
            do_sample=False,
            num_beams=num_beams,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )
    decoded = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    cleaned = tidy_korean_summary(decoded)
    return drop_incomplete_sentences(cleaned)


# ----------------------------
# CSV/Excel 자동 읽기
# ----------------------------
def _read_table_auto(path: str, encoding: str = "utf-8"):
    ext = os.path.splitext(path)[1].lower()
    if ext in [".xlsx", ".xls"]:
        return pd.read_excel(path, engine="openpyxl")
    elif ext in [".tsv", ".tab"]:
        return pd.read_csv(path, sep="\t", encoding=encoding)
    else:
        return pd.read_csv(path, encoding=encoding)


# ----------------------------
# 메인 실행부
# ----------------------------
def main():
    """Command-line entry point: summarize one text column of a CSV/TSV/Excel file.

    The ``--in`` file is read with ``_read_table_auto``, and every row of
    ``--text-col`` is summarized with ``summarize_text``. Missing values become
    empty strings, which yield empty summaries. The input table plus a
    ``--summary-col`` column is written to ``--out`` as UTF-8 CSV with a BOM
    (``utf-8-sig``).

    Argument problems (bad token bounds, missing text column) are reported
    through ``argparse``'s error mechanism, which exits with status 2.
    """
    ap = argparse.ArgumentParser(description="Summarize a text column with a local KoBART model.")
    ap.add_argument("--in", dest="input_path", required=True)
    ap.add_argument("--out", dest="output_path", required=True)
    ap.add_argument("--text-col", dest="text_col", default="newsContent")
    ap.add_argument("--summary-col", dest="summary_col", default="summary")
    ap.add_argument("--min-tokens", dest="min_tokens", type=int, default=40)
    ap.add_argument("--max-tokens", dest="max_tokens", type=int, default=150)
    ap.add_argument("--encoding", dest="encoding", default="utf-8")
    ap.add_argument("--device", dest="device", default=None)
    ap.add_argument("--model-dir", dest="model_dir", default="./kobart_finetuned")
    args = ap.parse_args()

    # Validate cheap arguments before reading data or loading the model.
    if args.min_tokens < 0 or args.max_tokens < 1:
        ap.error("--min-tokens must be >= 0 and --max-tokens must be >= 1")
    if args.min_tokens > args.max_tokens:
        ap.error(
            f"--min-tokens ({args.min_tokens}) must not exceed --max-tokens ({args.max_tokens})"
        )

    df = _read_table_auto(args.input_path, encoding=args.encoding)
    if args.text_col not in df.columns:
        ap.error(
            f"column '{args.text_col}' not found in {args.input_path}; "
            f"available columns: {list(df.columns)}"
        )

    tokenizer, model = load_tokenizer_model(args.model_dir)
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[DEVICE] Using device: {device}")
    model.to(device).eval()

    print(f"[RUN] Summarizing '{args.text_col}' → '{args.summary_col}'")

    texts = df[args.text_col].fillna("").astype(str)
    results = [
        summarize_text(tokenizer, model, device, text, args.min_tokens, args.max_tokens)
        for text in tqdm(texts, desc="Summarizing", ncols=120)
    ]

    df_out = df.copy()
    df_out[args.summary_col] = results
    df_out.to_csv(args.output_path, index=False, encoding="utf-8-sig")

    print(f"[DONE] 요약 완료 → {args.output_path}")



# ----------------------------
# 단일 문자열 요약 (파일 입출력 X)
# ----------------------------
# Per-model-dir cache of (tokenizer, model, device) used by summarize_article.
# Entries live for the life of the process (model weights stay on the device).
_ARTICLE_MODEL_CACHE = {}


def _get_article_model(model_dir: str):
    """Return ``(tokenizer, model, device)`` for ``model_dir``, loading it on first use."""
    cached = _ARTICLE_MODEL_CACHE.get(model_dir)
    if cached is None:
        tokenizer, model = load_tokenizer_model(model_dir)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device).eval()
        cached = (tokenizer, model, device)
        _ARTICLE_MODEL_CACHE[model_dir] = cached
    return cached


def summarize_article(
    text: str,
    *,
    model_dir: str = "./kobart_finetuned",
    min_new_tokens: int = 40,
    max_new_tokens: int = 150,
) -> str:
    """Summarize a single article body string (no file I/O).

    The model in ``model_dir`` is loaded once per process and reused on later
    calls, instead of being reloaded from disk on every call. Empty or
    whitespace-only input returns ``""`` without loading the model.

    Args:
        text: Article body.
        model_dir: Local model folder passed to ``load_tokenizer_model``.
        min_new_tokens: Minimum number of generated summary tokens.
        max_new_tokens: Maximum number of generated summary tokens.

    Returns:
        The cleaned summary produced by ``summarize_text``.
    """
    if not text or not str(text).strip():
        return ""
    tokenizer, model, device = _get_article_model(model_dir)
    return summarize_text(
        tokenizer, model, device,
        text,
        min_new_tokens=min_new_tokens,
        max_new_tokens=max_new_tokens,
    )

if __name__ == "__main__":
    if "--in" in sys.argv and "--out" in sys.argv:
        main()
    
    else:        
        text = """삼성전자가 2025년 3분기 실적을 발표했다. 
        영업이익은 전년 동기 대비 30% 증가하며 반도체 부문이 실적을 견인했다. 
        특히 AI 반도체 수요 확대가 주요 원인으로 분석된다."""
        result = summarize_article(text)
        print("요약문:", result)