In [None]:
import glob
import inspect
import os
import shutil
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

# ===============================
# 0. Common settings
# ===============================

# Run on the GPU whenever torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("DEVICE:", DEVICE)

# Hugging Face hub id of the pretrained encoder used for tokenizer and model.
ENCODER_MODEL_NAME = "monologg/koelectra-base-v3-discriminator"


# ===============================
# 1. Dataset 정의 (title 텍스트 사용)
# ===============================

class NewsDataset(Dataset):
    def __init__(self, df, referrers, tokenizer, max_length=512):
        """
        df: index = article_id, columns = ["title", (optional: "summary"), <referrer1> ...]
        referrers: 라벨로 쓸 referrer(또는 3-class) 이름 리스트
        """
        self.df = df.reset_index()
        self.referrers = referrers
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        #title만 이용
        text = str(row["title"])

        #title + summary 쓰는 코드
        # text = str(row["title"]) + " " + str(row.get("summary", ""))

        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)

        labels = torch.tensor(
            [row[r] for r in self.referrers],
            dtype=torch.float32
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }



class SoftLabelTrainer(Trainer):
    """Trainer whose loss is cross-entropy against soft (distribution) targets.

    The batch's "labels" entry is treated as a per-example probability
    distribution over classes rather than as class indices.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the soft cross-entropy -mean_i sum_k y_ik * log p_ik.

        ``labels`` is removed from ``inputs`` before the forward pass, so the
        model is called without targets and only its logits are used. Extra
        keyword arguments (newer Trainer versions may pass some) are ignored.
        """
        soft_targets = inputs.pop("labels")
        model_out = model(**inputs)

        log_p = torch.nn.functional.log_softmax(model_out.logits, dim=-1)
        loss = -(soft_targets * log_p).sum(dim=-1).mean()

        return (loss, model_out) if return_outputs else loss


def compute_metrics(eval_pred):
    """Compute evaluation metrics for soft-label (distribution) targets.

    Args:
        eval_pred: ``(logits, labels)`` pair; an ``EvalPrediction`` unpacks the
            same way. ``logits`` may also be a tuple/list of arrays, in which
            case its first entry is used as the classification logits.
            ``labels`` holds one target distribution row per example.

    Returns:
        dict with
          - ``accuracy``: fraction of rows where argmax(prediction) equals
            argmax(label),
          - ``cross_entropy``: mean over rows of -sum_k y_k * log p_k (the
            same formula as the soft-label training loss),
          - ``rmse``: root mean squared error between predicted probabilities
            and labels, averaged over every (row, class) entry.
    """
    logits, labels = eval_pred

    # When a model returns extra outputs (e.g. hidden states) the predictions
    # may arrive as a tuple; the classification logits come first.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)

    # Numerically stable log-softmax in NumPy. Working with log-probabilities
    # directly gives an exact cross-entropy without a log(p + eps) fudge.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)

    true_idx = labels.argmax(axis=-1)
    pred_idx = probs.argmax(axis=-1)
    accuracy = (true_idx == pred_idx).mean()

    ce = -(labels * log_probs).sum(axis=-1).mean()

    rmse = np.sqrt(((probs - labels) ** 2).mean())

    return {
        "accuracy": float(accuracy),
        "cross_entropy": float(ce),
        "rmse": float(rmse),
    }



def _supported_kwarg(func, preferred, legacy):
    """Return ``preferred`` if ``func``'s signature has that parameter, else ``legacy``."""
    return preferred if preferred in inspect.signature(func).parameters else legacy


def _drop_naver_only_groups(df):
    """Drop every (article_id, period) group whose referrers are all NAVER sources.

    A group is dropped when each of its rows has referrer "네이버" or
    "네이버 블로그". The filter only runs when a "period" column exists;
    otherwise ``df`` is returned unchanged.
    """
    if "period" not in df.columns:
        return df

    naver_set = {"네이버", "네이버 블로그"}
    # "The group's set of referrers is a subset of naver_set" is the same as
    # "every row's referrer is in naver_set", so a vectorised groupwise all()
    # replaces a per-group Python lambda. .eq(True) maps any NaN result (e.g.
    # for rows whose group key is missing) to False, i.e. keep the row,
    # instead of making `~` fail on a non-boolean series.
    group_is_naver_only = (
        df["referrer"]
        .isin(naver_set)
        .groupby([df["article_id"], df["period"]])
        .transform("all")
        .eq(True)
    )
    before_rows = len(df)
    df = df[~group_is_naver_only].copy()
    print(f"[Referrer] NAVER/NAVER_BLOG only 그룹 제거: {before_rows} -> {len(df)} rows")
    return df


def _filter_referrers(df, min_ref_count):
    """Keep only rows whose referrer appears in at least ``min_ref_count`` rows.

    Raises:
        ValueError: if no rows survive the filter.
    """
    ref_counts = df["referrer"].value_counts()
    print("[Referrer] 필터 전 referrer 분포:")
    print(ref_counts)

    valid_referrers = ref_counts[ref_counts >= min_ref_count].index.tolist()
    print("[Referrer] 사용할 referrer 수:", len(valid_referrers))
    print("[Referrer] 예시 referrer:", valid_referrers[:10])

    df = df[df["referrer"].isin(valid_referrers)].copy()
    if df.empty:
        raise ValueError(f"min_ref_count={min_ref_count} 필터 후 남은 데이터가 없습니다.")
    return df


def _normalize_rows(mat):
    """Divide each row by its sum; rows summing to exactly 0 are left as-is."""
    row_sum = mat.sum(axis=1)
    row_sum = row_sum.mask(row_sum == 0.0, 1.0)
    return mat.div(row_sum, axis=0)


def _build_soft_label_matrix(df):
    """Build the article_id x {domestic_search, global_search, others} soft-label matrix.

    For each (article_id, referrer) pair the raw label is
    sum(views_total * share / 100) / sum(views_total). Each article's referrer
    vector is normalised to sum to 1, referrers are summed into three groups,
    and the grouped vector is normalised again. Referrers not listed in any
    group are not counted in any class.
    """
    # Vectorised equivalent of a per-group
    # (views_total * share_frac).sum() / views_total.sum().
    sums = (
        df.assign(_weighted_share=df["views_total"] * (df["share"] / 100.0))
        .groupby(["article_id", "referrer"])[["_weighted_share", "views_total"]]
        .sum()
    )
    soft_label = sums["_weighted_share"] / sums["views_total"]
    label_mat = _normalize_rows(soft_label.unstack("referrer").fillna(0.0))

    print("[원본 라벨 행렬 shape]:", label_mat.shape)
    print("[원본 referrer 컬럼들]:", list(label_mat.columns))

    group_cols = {
        "domestic_search": ["네이버", "네이버 블로그", "Daum"],
        "global_search": ["Google", "Bing", "AI 검색엔진"],
        "others": ["기타"],
    }

    # Surface referrers that silently fall out of the 3-class labels.
    mapped = {name for names in group_cols.values() for name in names}
    unmapped = [c for c in label_mat.columns if c not in mapped]
    if unmapped:
        print("[경고] 3-class 그룹에 매핑되지 않은 referrer (라벨에서 제외):", unmapped)

    label_mat_3 = pd.DataFrame(index=label_mat.index)
    for gname, cols in group_cols.items():
        exist_cols = [c for c in cols if c in label_mat.columns]
        label_mat_3[gname] = label_mat[exist_cols].sum(axis=1) if exist_cols else 0.0

    label_mat_3 = _normalize_rows(label_mat_3)

    # All-zero rows contribute nothing to the soft cross-entropy loss.
    n_all_zero = int((label_mat_3.sum(axis=1) == 0.0).sum())
    if n_all_zero:
        print(f"[경고] 3-class 라벨이 모두 0인 기사 수: {n_all_zero}")
    return label_mat_3


def _checkpoint_step(path):
    """Return the integer N of a ``checkpoint-N`` path, or None if N is not an integer."""
    try:
        return int(path.rsplit("-", 1)[-1])
    except ValueError:
        return None


def _export_checkpoint(ckpt, target_dir, tokenizer, num_labels, tag):
    """Reload the model stored in ``ckpt`` and save it with the tokenizer to ``target_dir``."""
    print(f"Saving {tag} model to:", target_dir)
    model = AutoModelForSequenceClassification.from_pretrained(
        ckpt,
        num_labels=num_labels,
    )
    model.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)


def _export_and_prune_checkpoints(trainer, tokenizer, output_dir, num_labels):
    """Export best/last checkpoints to best_model/ and last_model/, then delete the rest."""
    best_ckpt = trainer.state.best_model_checkpoint
    print("Best checkpoint:", best_ckpt)

    # Directories whose suffix is not an integer are ignored (neither sorted
    # nor deleted) instead of crashing the sort key.
    ckpt_dirs = sorted(
        (
            p
            for p in glob.glob(os.path.join(output_dir, "checkpoint-*"))
            if os.path.isdir(p) and _checkpoint_step(p) is not None
        ),
        key=_checkpoint_step,
    )

    last_ckpt = ckpt_dirs[-1] if ckpt_dirs else None
    print("Last checkpoint:", last_ckpt)

    if best_ckpt is not None:
        _export_checkpoint(
            best_ckpt, os.path.join(output_dir, "best_model"), tokenizer, num_labels, "best"
        )
    if last_ckpt is not None:
        _export_checkpoint(
            last_ckpt, os.path.join(output_dir, "last_model"), tokenizer, num_labels, "last"
        )

    # Compare absolute paths so that equal paths spelled differently (e.g.
    # with or without a "./" prefix) are never deleted by mistake.
    keep = {os.path.abspath(p) for p in (best_ckpt, last_ckpt) if p is not None}
    for ckpt in ckpt_dirs:
        if os.path.abspath(ckpt) not in keep:
            print("Removing old checkpoint:", ckpt)
            shutil.rmtree(ckpt)


def train_model(
    df_path: str,
    batch_size: int = 4,
    lr: float = 2e-5,
    num_epochs: int = 5,
    max_length: int = 512,
    min_ref_count: int = 100,
    output_dir: str = "./koelectra_title_3class",
):
    """Fine-tune KoELECTRA to predict a 3-class referrer soft label from article titles.

    Referrers are collapsed into domestic_search / global_search / others
    soft labels, the data is split 80/20 by each article's majority class,
    and the model is trained with SoftLabelTrainer. The best and last
    checkpoints are exported and the other checkpoints are removed.

    Args:
        df_path: Excel file with at least the columns article_id, title,
            views_total, referrer and share ("summary" and "period" are optional).
        batch_size: per-device train and eval batch size.
        lr: learning rate.
        num_epochs: number of training epochs.
        max_length: tokenizer pad/truncate length.
        min_ref_count: referrers occurring in fewer rows than this are dropped.
        output_dir: Trainer output directory; best_model/ and last_model/ are
            written inside it when the corresponding checkpoints exist.

    Returns:
        (trainer.model, tokenizer, referrer_labels, trainer), where
        referrer_labels is the list of the three class names.

    Raises:
        ValueError: if a required column is missing or no data remains after filtering.
    """
    df = pd.read_excel(df_path)

    needed_cols = ["article_id", "title", "views_total", "referrer", "share"]
    for c in needed_cols:
        if c not in df.columns:
            raise ValueError(f"필수 컬럼 {c} 이(가) 데이터에 없습니다.")

    df = df.copy()

    if "summary" not in df.columns:
        print("[경고] 'summary' 컬럼이 없습니다. title+summary 모드는 나중에 사용할 수 없습니다.")

    df = _drop_naver_only_groups(df)
    df = _filter_referrers(df, min_ref_count)

    label_mat = _build_soft_label_matrix(df)
    referrer_labels = list(label_mat.columns)
    num_labels = len(referrer_labels)

    print("[3-class 라벨 행렬 shape]:", label_mat.shape)
    print("[3-class 라벨 컬럼들]:", referrer_labels)

    # Only the title is used as model input. To use title + summary, build a
    # "base_text" column (title + " " + summary), select it here instead of
    # "title", and have NewsDataset read row["base_text"].
    article_text = (
        df
        .drop_duplicates("article_id")
        .set_index("article_id")[["title"]]
    )

    dataset_df = article_text.join(label_mat, how="inner").dropna()

    print("[최종 데이터 크기]:", dataset_df.shape)
    print("[라벨 개수(num_labels)]:", num_labels)
    if dataset_df.empty:
        raise ValueError("학습에 사용할 데이터가 없습니다 (title/라벨 join 결과가 비어 있음).")

    dataset_df["major_referrer"] = dataset_df[referrer_labels].idxmax(axis=1)
    print("[major_referrer 분포 (3-class)]:")
    print(dataset_df["major_referrer"].value_counts(normalize=True))

    # Stratified splitting needs at least two samples per class; fall back to
    # a plain random split rather than letting train_test_split raise.
    stratify = dataset_df["major_referrer"]
    if stratify.value_counts().min() < 2:
        print("[경고] 샘플 수가 2개 미만인 클래스가 있어 stratify 없이 분할합니다.")
        stratify = None

    train_df, val_df = train_test_split(
        dataset_df,
        test_size=0.2,
        random_state=42,
        stratify=stratify,
    )

    tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL_NAME)

    train_dataset = NewsDataset(train_df, referrer_labels, tokenizer, max_length=max_length)
    val_dataset = NewsDataset(val_df, referrer_labels, tokenizer, max_length=max_length)

    model = AutoModelForSequenceClassification.from_pretrained(
        ENCODER_MODEL_NAME,
        num_labels=num_labels,
    )
    model.to(DEVICE)

    # transformers renamed TrainingArguments(evaluation_strategy=...) to
    # eval_strategy=... and Trainer(tokenizer=...) to processing_class=...;
    # the old names are deprecated and removed in some releases. Pass
    # whichever name the installed version accepts.
    eval_strategy_key = _supported_kwarg(TrainingArguments, "eval_strategy", "evaluation_strategy")
    tokenizer_key = _supported_kwarg(SoftLabelTrainer, "processing_class", "tokenizer")

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=num_epochs,
        **{eval_strategy_key: "epoch"},
        save_strategy="epoch",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        save_safetensors=True,
        save_total_limit=2,
    )

    trainer = SoftLabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        **{tokenizer_key: tokenizer},
        compute_metrics=compute_metrics,
    )

    trainer.train()

    eval_result = trainer.evaluate()
    print("\n최종 평가 결과 (3-class):")
    for k, v in eval_result.items():
        print(f"{k}: {v}")

    _export_and_prune_checkpoints(trainer, tokenizer, output_dir, num_labels)

    return trainer.model, tokenizer, referrer_labels, trainer



if __name__ == "__main__":
    # Script entry point: train the title-only 3-class model for 10 epochs
    # on the merged news spreadsheet.
    excel_path = "../Korean_Spoken/news_merged_grouped_final_summary.xlsx"

    run_kwargs = {
        "df_path": excel_path,
        "batch_size": 4,
        "lr": 2e-5,
        "num_epochs": 10,
        "max_length": 512,
        "min_ref_count": 100,
        "output_dir": "./koelectra_title_3class_10epoch",
    }
    model, tokenizer, referrers, trainer = train_model(**run_kwargs)

    print("\n=== Training Finished ===")
    print("Referrer labels (3-class):", referrers)
