In [None]:
import os
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# Hugging Face model identifier; passed to both AutoTokenizer and AutoModel below.
ENCODER_MODEL_NAME = "monologg/koelectra-base-v3-discriminator"


class SoftLabelTrainer(Trainer):
    """Trainer variant whose loss is cross-entropy against soft targets."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the batch-mean soft-label cross-entropy.

        The ``labels`` entry is popped from ``inputs`` (so it is not passed
        to the model) and used as the target distribution. The model output
        may be either a dict with a ``"logits"`` key or an object exposing a
        ``.logits`` attribute.

        Args:
            model: Callable invoked as ``model(**inputs)``.
            inputs: Batch dict; must contain ``"labels"``.
            return_outputs: When true, also return the raw model outputs.
            **kwargs: Accepted and ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is true.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)

        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

        # Per-sample sum of target * log-probability, averaged and negated.
        weighted_log_probs = targets * torch.log_softmax(logits, dim=-1)
        loss = -weighted_log_probs.sum(dim=-1).mean()

        return (loss, outputs) if return_outputs else loss


def compute_metrics(eval_pred):
    """Score predicted logits against soft-label targets.

    Args:
        eval_pred: A pair ``(logits, labels)``; both are converted with
            ``np.array``. The last axis of each is treated as the class axis.

    Returns:
        A dict of Python floats with keys ``top1_accuracy``,
        ``top3_hit_rate``, ``cross_entropy`` and ``rmse``. The "true" class
        for the accuracy metrics is the argmax of the soft label.
    """
    raw_logits, raw_labels = eval_pred
    targets = np.array(raw_labels)
    probs = torch.softmax(torch.tensor(np.array(raw_logits)), dim=-1).numpy()

    gold = targets.argmax(axis=-1)

    # Top-1: the most probable predicted class equals the dominant target class.
    top1 = (probs.argmax(axis=-1) == gold).mean()

    # Top-3: the dominant target class is among the three most probable classes.
    top_k = 3
    best_k = np.argsort(-probs, axis=-1)[:, :top_k]
    top3 = (best_k == gold[:, None]).any(axis=1).mean()

    # Soft cross-entropy; the small constant keeps log() finite at probability 0.
    tiny = 1e-12
    cross_entropy = -(targets * np.log(probs + tiny)).sum(axis=1).mean()

    # Root of the mean squared error over every (sample, class) cell.
    rmse = np.sqrt(((probs - targets) ** 2).mean())

    return {
        "top1_accuracy": float(top1),
        "top3_hit_rate": float(top3),
        "cross_entropy": float(cross_entropy),
        "rmse": float(rmse),
    }


def build_demographics_soft_labels(
    demo_path1: str,
    demo_path2: str,
) -> Tuple[pd.DataFrame, List[str]]:
    """Build a per-article soft-label matrix over demographic groups.

    Both Excel files are read and concatenated. The code uses the columns
    ``article_id``, ``age_group``, ``gender``, ``views`` and ``ratio``;
    ``ratio`` is divided by 100 (presumably a percentage - confirm against
    the data source).

    For each ``(article_id, "<age_group>_<gender>")`` pair the views-weighted
    mean of the ratio is computed. Columns whose name starts with ``"전체_"``
    are dropped, articles whose remaining labels sum to zero are removed, and
    every remaining row is normalised to sum to 1.

    Args:
        demo_path1: Path of the first demographics Excel file.
        demo_path2: Path of the second demographics Excel file.

    Returns:
        ``(label_mat, demo_labels)`` where ``label_mat`` is indexed by
        ``article_id`` with one column per demographic group, and
        ``demo_labels`` is the list of those column names.
    """
    df_demo = pd.concat(
        [pd.read_excel(demo_path1), pd.read_excel(demo_path2)],
        ignore_index=True,
    )

    # Rows with neither views nor ratio carry no information.
    df_demo = df_demo[~((df_demo["views"] == 0) & (df_demo["ratio"] == 0))].copy()

    df_demo["demo_group"] = df_demo["age_group"].astype(str) + "_" + df_demo["gender"].astype(str)
    df_demo["weighted_ratio"] = df_demo["views"] * (df_demo["ratio"] / 100.0)

    # Views-weighted mean = sum(views * ratio) / sum(views), computed with
    # vectorised group sums instead of a per-group Python lambda.
    sums = (
        df_demo
        .groupby(["article_id", "demo_group"])[["weighted_ratio", "views"]]
        .sum()
    )
    grouped = (sums["weighted_ratio"] / sums["views"]).reset_index(name="soft_label")

    # A 0/0 group yields NaN and is treated as "no signal" (0.0), as are
    # demographic groups that never occur for an article.
    label_mat = (
        grouped
        .pivot(index="article_id", columns="demo_group", values="soft_label")
        .fillna(0.0)
    )

    drop_cols = [c for c in label_mat.columns if str(c).startswith("전체_")]
    if drop_cols:
        print("[Demographics] Dropping overall gender groups:", drop_cols)
        label_mat = label_mat.drop(columns=drop_cols)

    # Keep only articles that still have some mass; every remaining row sum
    # is therefore strictly positive and the normalisation cannot divide by 0.
    label_mat = label_mat[label_mat.sum(axis=1) > 0].copy()
    label_mat = label_mat.div(label_mat.sum(axis=1), axis=0)

    demo_labels = list(label_mat.columns)
    print("[Demographics] demo_labels 개수:", len(demo_labels))

    return label_mat, demo_labels


def build_demographics_hard_labels(
    demo_path1: str,
    demo_path2: str,
) -> Tuple[pd.DataFrame, List[str]]:
    """Build one-hot demographic labels from the soft-label matrix.

    The soft labels are produced by ``build_demographics_soft_labels``; for
    every article the column holding the largest soft value is set to 1.0
    and all other columns to 0.0 (``float32``).

    Args:
        demo_path1: Path of the first demographics Excel file.
        demo_path2: Path of the second demographics Excel file.

    Returns:
        ``(label_mat_hard, demo_labels)`` with the same index and columns as
        the soft-label matrix.
    """
    soft_mat, demo_labels = build_demographics_soft_labels(demo_path1, demo_path2)

    soft_values = soft_mat.values
    winners = soft_values.argmax(axis=1)

    # Selecting rows of an identity matrix by the winning column index yields
    # the one-hot encoding directly.
    one_hot = np.eye(soft_values.shape[1], dtype=np.float32)[winners]

    label_mat_hard = pd.DataFrame(
        one_hot,
        index=soft_mat.index,
        columns=soft_mat.columns,
    )

    print("[Demographics-HARD] demo_labels 개수:", len(demo_labels))
    print("[Demographics-HARD] one-hot 예시:")
    print(label_mat_hard.head())

    return label_mat_hard, demo_labels


def build_referrer_soft_labels(
    df_news: pd.DataFrame,
    min_ref_count: int = 100,
) -> Tuple[pd.DataFrame, List[str]]:
    """Build a per-article soft-label matrix over referrers.

    Steps:
      1. Validate that the required columns exist.
      2. Drop every group (``article_id`` + ``period`` when a ``period``
         column exists, otherwise ``article_id`` alone) whose referrers are
         all in {"네이버", "네이버 블로그"}.
      3. Keep only referrers that occur in at least ``min_ref_count`` rows.
      4. For each ``(article_id, referrer)`` pair compute the
         ``views_total``-weighted mean of ``share / 100``.
      5. Pivot to an article x referrer matrix, fill missing cells with 0 and
         normalise each row to sum to 1 (all-zero rows are left as zeros).

    Args:
        df_news: Input rows; not modified (a copy is used).
        min_ref_count: Minimum number of rows a referrer needs to be kept.

    Returns:
        ``(label_mat, ref_labels)`` where ``label_mat`` is indexed by
        ``article_id`` with one column per referrer, and ``ref_labels`` is
        the list of those column names.

    Raises:
        ValueError: If a required column is missing from ``df_news``.
    """
    # Validate before any column is touched so that a missing column is
    # reported with this message instead of a bare KeyError further down.
    needed_cols = ["article_id", "summary", "views_total", "referrer", "share"]
    for c in needed_cols:
        if c not in df_news.columns:
            raise ValueError(f"필수 컬럼 {c} 이(가) 데이터에 없습니다.")

    df = df_news.copy()

    naver_set = {"네이버", "네이버 블로그"}
    has_period = "period" in df.columns
    if has_period:
        group_keys = ["article_id", "period"]
    else:
        print("[Referrer] 'period' 컬럼이 없어 NAVER-only 필터는 article_id 단위로만 적용할 수 있습니다.")
        group_keys = ["article_id"]

    # A group is NAVER-only exactly when every one of its rows has a referrer
    # in naver_set.
    group_is_naver_only = (
        df["referrer"].isin(naver_set)
        .groupby([df[k] for k in group_keys])
        .transform("all")
    )
    before_rows = len(df)
    df = df[~group_is_naver_only].copy()
    if has_period:
        print(f"[Referrer] NAVER/NAVER_BLOG only 그룹 제거: {before_rows} -> {len(df)} rows")
    else:
        print(f"[Referrer] NAVER/NAVER_BLOG only (article_id 기준) 제거: {before_rows} -> {len(df)} rows")

    ref_counts = df["referrer"].value_counts()
    print("[Referrer] 필터 후 referrer 분포:")
    print(ref_counts)

    valid_referrers = ref_counts[ref_counts >= min_ref_count].index.tolist()
    print("[Referrer] 사용할 referrer 수:", len(valid_referrers))
    print("[Referrer] 예시 referrer:", valid_referrers[:10])

    df = df[df["referrer"].isin(valid_referrers)].copy()

    # Weighted mean = sum(views_total * share_frac) / sum(views_total),
    # computed with vectorised group sums instead of a per-group lambda.
    df["weighted_share"] = df["views_total"] * (df["share"] / 100.0)
    sums = (
        df
        .groupby(["article_id", "referrer"])[["weighted_share", "views_total"]]
        .sum()
    )
    grouped = (sums["weighted_share"] / sums["views_total"]).reset_index(name="soft_label")

    label_mat = (
        grouped
        .pivot(index="article_id", columns="referrer", values="soft_label")
        .fillna(0.0)
    )

    # Normalise rows to probability distributions; an all-zero row is divided
    # by 1 so it stays all-zero instead of turning into NaN.
    row_sum = label_mat.sum(axis=1)
    row_sum[row_sum == 0.0] = 1.0
    label_mat = label_mat.div(row_sum, axis=0)

    ref_labels = list(label_mat.columns)
    print("[Referrer] 라벨(referrer) 개수:", len(ref_labels))

    # Mean Shannon entropy of the label rows, reported for information only.
    label_array = label_mat.values
    eps = 1e-12
    entropy = -(label_array * np.log(label_array + eps)).sum(axis=1).mean()
    print("[Referrer] 평균 라벨 엔트로피 H(y):", entropy)

    return label_mat, ref_labels


class TextWithDemoVectorDataset(Dataset):
    """Dataset yielding tokenized text, a demographics vector and soft labels.

    Each item is a dict with ``input_ids``, ``attention_mask``,
    ``demo_feats`` (float32 values of ``demo_cols``) and ``labels``
    (float32 values of ``label_cols``).
    """

    def __init__(self, df, text_col, demo_cols, label_cols, tokenizer, max_length=512):
        """Store the configuration and a positionally indexed copy of ``df``.

        Args:
            df: Source frame; one row per sample.
            text_col: Name of the column holding the text.
            demo_cols: Column names forming the demographics vector.
            label_cols: Column names forming the label vector.
            tokenizer: Callable invoked with the text plus ``truncation``,
                ``padding``, ``max_length`` and ``return_tensors`` keywords.
            max_length: Sequence length passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_col = text_col
        self.demo_cols = demo_cols
        self.label_cols = label_cols
        # The previous index is kept as a regular column; rows are then
        # addressed by position in __getitem__.
        self.df = df.reset_index(drop=False)

    def __len__(self):
        """Return the number of rows."""
        return len(self.df)

    @staticmethod
    def _float_vector(record, columns):
        """Collect ``record[c]`` for each column as a float32 tensor."""
        return torch.tensor(
            [float(record[name]) for name in columns],
            dtype=torch.float32,
        )

    def __getitem__(self, idx):
        """Return the model inputs and labels for the row at position ``idx``."""
        record = self.df.iloc[idx]

        encoded = self.tokenizer(
            str(record[self.text_col]),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer output carries a leading batch axis of size 1 that is
        # squeezed away here.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "demo_feats": self._float_vector(record, self.demo_cols),
            "labels": self._float_vector(record, self.label_cols),
        }


class DemoReferrerModel(torch.nn.Module):
    """Text encoder plus demographics vector feeding a linear classifier.

    The first-position hidden state of the encoder is concatenated with the
    demographics vector and projected to ``num_labels`` logits.
    """

    def __init__(self, encoder_name: str, demo_dim: int, num_labels: int):
        """Load the pretrained encoder and create the classification head.

        Args:
            encoder_name: Identifier passed to ``AutoModel.from_pretrained``.
            demo_dim: Length of the demographics vector.
            num_labels: Number of output logits.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.feature_dropout = torch.nn.Dropout(p=0.1)
        # The classifier consumes the encoder vector and the demographics
        # vector side by side.
        fused_width = self.encoder.config.hidden_size + demo_dim
        self.classifier = torch.nn.Linear(fused_width, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        demo_feats=None,
        labels=None,
        **kwargs,
    ):
        """Return ``{"logits": tensor}`` for a batch.

        ``labels`` and any extra keyword arguments are accepted but unused.

        Raises:
            ValueError: If ``demo_feats`` is not provided.
        """
        if demo_feats is None:
            raise ValueError("demo_feats (demographics vector)가 필요합니다.")

        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the text summary.
        text_vector = self.dropout(encoded.last_hidden_state[:, 0, :])

        fused = self.feature_dropout(torch.cat((text_vector, demo_feats), dim=-1))

        return {"logits": self.classifier(fused)}


def train_referrer_model_with_demo_vector(
    news_path: str,
    demo_label_mat: pd.DataFrame,
    batch_size: int = 16,
    lr: float = 1e-5,
    num_epochs: int = 5,
    max_length: int = 512,
    min_ref_count: int = 100,
    output_dir: str = "./koelectra_referrer_with_demo_vector",
):
    """Train a referrer-distribution model from article text and demographics.

    The news Excel file is read, referrer soft labels are built with
    ``build_referrer_soft_labels``, and each article is paired with its row
    of ``demo_label_mat``. Articles without a non-zero demographics row are
    discarded. The data is split 80/20 (stratified by the dominant referrer)
    and a ``DemoReferrerModel`` is trained with ``SoftLabelTrainer``.

    Args:
        news_path: Path of the news Excel file. The code reads the columns
            ``title``, ``summary`` and ``article_id`` here, plus whatever
            ``build_referrer_soft_labels`` requires.
        demo_label_mat: Demographics matrix; it is reindexed by article id,
            so its index is expected to hold ``article_id`` values.
        batch_size: Per-device batch size for both training and evaluation.
        lr: Learning rate.
        num_epochs: Number of training epochs.
        max_length: Token sequence length used by both datasets.
        min_ref_count: Forwarded to ``build_referrer_soft_labels``.
        output_dir: Directory for checkpoints; the best model is additionally
            written to ``<output_dir>/best_model``.

    Returns:
        A tuple ``(model, tokenizer, ref_labels, demo_cols, trainer,
        dataset_df)``.

    Raises:
        ValueError: If the train or validation split is empty.
    """
    df_news = pd.read_excel(news_path)

    # Model input text is "<title> <summary>", with missing parts as "".
    df_news["base_text"] = df_news["title"].fillna("") + " " + df_news["summary"].fillna("")

    ref_label_mat, ref_labels = build_referrer_soft_labels(
        df_news,
        min_ref_count=min_ref_count,
    )

    # One text per article: the first row seen for each article_id wins.
    article_meta = (
        df_news
        .drop_duplicates("article_id")
        .set_index("article_id")[["base_text"]]
    )

    article_text = article_meta[["base_text"]]

    # Restrict to articles that have both a text and referrer labels.
    common_ids = ref_label_mat.index.intersection(article_text.index)

    article_text = article_text.loc[common_ids]
    ref_label_mat = ref_label_mat.loc[common_ids]

    # Align demographics to those articles; articles missing from
    # demo_label_mat become all-zero rows and are filtered out just below.
    demo_for_articles = demo_label_mat.reindex(common_ids)
    demo_for_articles = demo_for_articles.fillna(0.0)
    valid_mask = demo_for_articles.sum(axis=1) > 0
    demo_for_articles = demo_for_articles[valid_mask]

    # Re-intersect so that text, labels and demographics share one index.
    common_ids = demo_for_articles.index.intersection(article_text.index)
    article_text = article_text.loc[common_ids]
    ref_label_mat = ref_label_mat.loc[common_ids]
    demo_for_articles = demo_for_articles.loc[common_ids]

    demo_cols = list(demo_for_articles.columns)

    # Single frame: base_text + demographics columns + referrer label columns.
    dataset_df = article_text.join(demo_for_articles, how="inner")
    dataset_df = dataset_df.join(ref_label_mat, how="inner")

    print("[Step2-Vector] 전체 dataset 크기:", dataset_df.shape)

    # Force label and feature columns to float32; unparseable values become 0.
    dataset_df[ref_labels] = (
        dataset_df[ref_labels]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0.0)
        .astype("float32")
    )
    dataset_df[demo_cols] = (
        dataset_df[demo_cols]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0.0)
        .astype("float32")
    )

    # The referrer with the largest soft label is the stratification class.
    dataset_df["major_referrer"] = dataset_df[ref_labels].idxmax(axis=1)
    print("[Step2-Vector] major_referrer 분포:")
    print(dataset_df["major_referrer"].value_counts(normalize=True))

    # NOTE(review): stratified splitting presumably needs at least two
    # articles per major_referrer class - confirm against the data.
    train_df, val_df = train_test_split(
        dataset_df,
        test_size=0.2,
        random_state=42,
        stratify=dataset_df["major_referrer"],
    )

    print("[Hold-out] train size:", len(train_df), " / val size:", len(val_df))

    if len(train_df) == 0 or len(val_df) == 0:
        raise ValueError(f"[ERROR] train/val 데이터가 비었습니다. train={len(train_df)}, val={len(val_df)}")

    tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL_NAME)

    train_dataset = TextWithDemoVectorDataset(
        train_df,
        text_col="base_text",
        demo_cols=demo_cols,
        label_cols=ref_labels,
        tokenizer=tokenizer,
        max_length=max_length,
    )
    val_dataset = TextWithDemoVectorDataset(
        val_df,
        text_col="base_text",
        demo_cols=demo_cols,
        label_cols=ref_labels,
        tokenizer=tokenizer,
        max_length=max_length,
    )

    model = DemoReferrerModel(
        encoder_name=ENCODER_MODEL_NAME,
        demo_dim=len(demo_cols),
        num_labels=len(ref_labels),
    ).to(DEVICE)

    # Evaluate and checkpoint once per epoch; the checkpoint with the highest
    # top1_accuracy is reloaded at the end of training.
    # remove_unused_columns=False leaves the dataset's extra "demo_feats"
    # entry in place for the model.
    # NOTE(review): newer transformers releases appear to rename
    # `evaluation_strategy` to `eval_strategy` - confirm for the installed version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=num_epochs,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="top1_accuracy",
        greater_is_better=True,
        report_to="none",
        save_safetensors=False,
        save_total_limit=2,
        remove_unused_columns=False,
        weight_decay=0.01,
    )

    trainer = SoftLabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    eval_result = trainer.evaluate()
    print(f"\n[Step2-Vector @ {output_dir}] 최종 평가 결과 (best top-1 기준):")
    for k, v in eval_result.items():
        print(f"{k}: {v}")

    best_ckpt = trainer.state.best_model_checkpoint
    print("[Step2-Vector] Best checkpoint:", best_ckpt)
    if best_ckpt is not None:
        best_model_dir = os.path.join(output_dir, "best_model")
        print("[Step2-Vector] Saving best model to:", best_model_dir)
        os.makedirs(best_model_dir, exist_ok=True)

        # Save a plain state_dict with every tensor detached, on CPU and
        # contiguous, next to the tokenizer files.
        state_dict = trainer.model.state_dict()
        safe_state_dict = {
            k: v.detach().cpu().contiguous() for k, v in state_dict.items()
        }
        torch.save(
            safe_state_dict,
            os.path.join(best_model_dir, "pytorch_model.bin")
        )
        tokenizer.save_pretrained(best_model_dir)

    return model.to(DEVICE), tokenizer, ref_labels, demo_cols, trainer, dataset_df


def rank_platforms_with_meta(
    title: str,
    summary: str,
    age_group: str,
    gender: str,
    model,
    tokenizer,
    ref_labels: List[str],
    demo_cols: List[str],
    max_length: int = 512,
):
    """Rank referrers for one article and one target demographic group.

    The demographics input is a one-hot vector over ``demo_cols`` with a 1 at
    the column named ``"<age_group>_<gender>"``; if no column has that name
    the vector is all zeros.

    Args:
        title: Article title; a falsy value is treated as "".
        summary: Article summary; a falsy value is treated as "".
        age_group: Age-group part of the target demographic column name.
        gender: Gender part of the target demographic column name.
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``demo_feats``; must return a mapping with a ``"logits"`` entry.
            It is switched to eval mode.
        tokenizer: Tokenizer producing ``input_ids`` and ``attention_mask``.
        ref_labels: Referrer names, in the order of the model's logits.
        demo_cols: Demographic column names, in the model's input order.
        max_length: Token sequence length passed to the tokenizer.

    Returns:
        A list of ``(referrer, probability)`` pairs sorted by descending
        probability.
    """
    model.eval()
    with torch.no_grad():
        text = (title or "") + " " + (summary or "")
        encoded = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )

        # One-hot row (batch size 1) marking the requested demographic group.
        wanted_group = f"{age_group}_{gender}"
        one_hot = torch.tensor(
            [[1.0 if col == wanted_group else 0.0 for col in demo_cols]],
            dtype=torch.float32,
        )

        outputs = model(
            input_ids=encoded["input_ids"].to(DEVICE),
            attention_mask=encoded["attention_mask"].to(DEVICE),
            demo_feats=one_hot.to(DEVICE),
        )
        probs = torch.softmax(outputs["logits"], dim=-1)[0].cpu().numpy()

    return sorted(zip(ref_labels, probs), key=lambda pair: pair[1], reverse=True)


# Script entry point: train one model on soft demographics vectors and one on
# one-hot (hard) demographics vectors, compare their evaluation metrics, then
# print an example platform ranking from each model.
if __name__ == "__main__":
    news_path = "../Korean_Spoken/news_merged_grouped_final_summary.xlsx"
    demo_path1 = "../Korean_Spoken/demographics_part001.xlsx"
    demo_path2 = "../Korean_Spoken/demographics_part002.xlsx"

    # --- Run 1: soft demographics vectors as input features ---
    demo_soft_mat, demo_labels_soft = build_demographics_soft_labels(demo_path1, demo_path2)
    print("재구성된 demo_labels_soft 개수:", len(demo_labels_soft))

    (ref_model_soft,
     ref_tokenizer_soft,
     ref_labels_soft,
     demo_cols_soft,
     ref_trainer_soft,
     ref_dataset_soft) = train_referrer_model_with_demo_vector(
         news_path=news_path,
         demo_label_mat=demo_soft_mat,
         batch_size=8,
         lr=1e-5,
         num_epochs=5,
         max_length=512,
         min_ref_count=100,
         output_dir="./koelectra_referrer_with_demo_vector_soft_demo_nocat",
    )

    # --- Run 2: one-hot demographics vectors, same hyperparameters ---
    # Note: this call reads both demographics Excel files again.
    demo_hard_mat, demo_labels_hard = build_demographics_hard_labels(demo_path1, demo_path2)
    print("재구성된 demo_labels_hard 개수:", len(demo_labels_hard))

    (ref_model_hard,
     ref_tokenizer_hard,
     ref_labels_hard,
     demo_cols_hard,
     ref_trainer_hard,
     ref_dataset_hard) = train_referrer_model_with_demo_vector(
         news_path=news_path,
         demo_label_mat=demo_hard_mat,
         batch_size=8,
         lr=1e-5,
         num_epochs=5,
         max_length=512,
         min_ref_count=100,
         output_dir="./koelectra_referrer_with_demo_vector_hard_demo_nocat",
    )

    # --- Side-by-side metric comparison on each trainer's evaluation set ---
    print("\n===================== SOFT vs HARD DEMO 비교 (NO CATEGORY) =====================")
    soft_metrics = ref_trainer_soft.evaluate()
    hard_metrics = ref_trainer_hard.evaluate()

    # Keys looked up in the dicts returned by evaluate().
    metric_names = [
        "eval_loss",
        "eval_top1_accuracy",
        "eval_top3_hit_rate",
        "eval_cross_entropy",
        "eval_rmse",
    ]

    print("\n[Metrics 비교]")
    print("{:<22} | {:>12} | {:>12}".format("metric", "soft_demo", "hard_demo"))
    print("-" * 52)
    for m in metric_names:
        sv = soft_metrics.get(m, None)
        hv = hard_metrics.get(m, None)
        # A metric missing from either result is printed as NaN.
        print("{:<22} | {:>12.6f} | {:>12.6f}".format(
            m,
            sv if sv is not None else float("nan"),
            hv if hv is not None else float("nan"),
        ))

    # --- Example inference: same article and target demographic for both models ---
    example_title = "10대 여성 소비 트렌드: 편의점과 패션의 변화"
    example_summary = "10대 여성의 소비 패턴을 분석하고 주요 플랫폼별 반응을 정리했다."
    example_age = "10대"
    example_gender = "여"

    print("\n[예측] 플랫폼 추천 순위 (상위 5개) - SOFT DEMO 모델:")
    ranked_soft = rank_platforms_with_meta(
        title=example_title,
        summary=example_summary,
        age_group=example_age,
        gender=example_gender,
        model=ref_model_soft,
        tokenizer=ref_tokenizer_soft,
        ref_labels=ref_labels_soft,
        demo_cols=demo_cols_soft,
        max_length=512,
    )
    for ref, p in ranked_soft[:5]:
        print(f"[SOFT] {ref}: {p:.4f}")

    print("\n[예측] 플랫폼 추천 순위 (상위 5개) - HARD DEMO 모델:")
    ranked_hard = rank_platforms_with_meta(
        title=example_title,
        summary=example_summary,
        age_group=example_age,
        gender=example_gender,
        model=ref_model_hard,
        tokenizer=ref_tokenizer_hard,
        ref_labels=ref_labels_hard,
        demo_cols=demo_cols_hard,
        max_length=512,
    )
    for ref, p in ranked_hard[:5]:
        print(f"[HARD] {ref}: {p:.4f}")
