In [None]:
import glob
import inspect
import os
import shutil
import warnings
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
)

# Global compute device: the first CUDA device when available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("DEVICE:", DEVICE)


def split_into_chunks(text, chunk_chars=4000):
    """
    Split ``text`` into consecutive slices of at most ``chunk_chars`` characters.

    ``text`` is first coerced with ``str()``. With chunk_chars=4000 the pieces
    cover characters 0-3999, 4000-7999, ...; the last piece may be shorter.
    An empty string yields an empty list.

    Args:
        text: value to split (converted to ``str``).
        chunk_chars: maximum characters per chunk; must be positive.

    Returns:
        List of string chunks in original order.

    Raises:
        ValueError: if ``chunk_chars`` is not positive. A zero step would make
            ``range`` fail cryptically, and a negative one would silently
            return ``[]``, dropping the whole text.
    """
    if chunk_chars <= 0:
        raise ValueError(f"chunk_chars must be positive, got {chunk_chars!r}")
    text = str(text)
    return [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]


def build_demographics_soft_labels(
    demo_path1: str,
    demo_path2: str,
) -> Tuple[pd.DataFrame, List[str]]:
    """
    Build a per-article soft-label distribution over demographic groups.

    Both Excel files are read and concatenated. The columns used are
    ``article_id``, ``age_group``, ``gender``, ``views`` and ``ratio``
    (``ratio`` is divided by 100). For every (article_id, demo_group)
    pair, where demo_group = f"{age_group}_{gender}", the soft label is the
    view-weighted mean

        sum(views * ratio / 100) / sum(views)

    The result is pivoted to one column per demo_group, with missing cells set
    to 0. The overall "전체_*" groups are dropped, as are articles whose
    remaining row sums to 0. Each row is then normalised to sum to 1.

    Args:
        demo_path1: path of demographics_part001.xlsx.
        demo_path2: path of demographics_part002.xlsx.

    Returns:
        (label_mat, demo_labels):
          - label_mat: index = article_id, columns = demo_group (soft distribution)
          - demo_labels: list of demo_group column names
    """
    df_demo = pd.concat(
        [pd.read_excel(demo_path1), pd.read_excel(demo_path2)],
        ignore_index=True,
    )

    # Drop rows where both views and ratio are zero.
    df_demo = df_demo[~((df_demo["views"] == 0) & (df_demo["ratio"] == 0))].copy()

    df_demo["ratio_frac"] = df_demo["ratio"] / 100.0
    df_demo["demo_group"] = df_demo["age_group"].astype(str) + "_" + df_demo["gender"].astype(str)

    # Vectorised weighted mean: summing numerator and denominator per group
    # replaces a per-group Python lambda via groupby.apply, which is slow and
    # deprecated in recent pandas. A group whose views sum to 0 gives 0/0 = NaN;
    # the fillna after the pivot turns it into 0.
    df_demo["_weighted"] = df_demo["views"] * df_demo["ratio_frac"]
    sums = df_demo.groupby(["article_id", "demo_group"])[["_weighted", "views"]].sum()
    grouped = (sums["_weighted"] / sums["views"]).rename("soft_label").reset_index()

    label_mat = (
        grouped
        .pivot(index="article_id", columns="demo_group", values="soft_label")
        .fillna(0.0)
    )

    # "전체_*" columns are gender-level totals, not an age/gender cell.
    drop_cols = [c for c in label_mat.columns if str(c).startswith("전체_")]
    if drop_cols:
        print("[Demographics] Dropping overall gender groups:", drop_cols)
        label_mat = label_mat.drop(columns=drop_cols)

    label_mat = label_mat[label_mat.sum(axis=1) > 0].copy()

    # Rows are all positive-sum after the filter above; the guard only protects
    # against dividing by zero.
    row_sum = label_mat.sum(axis=1)
    row_sum[row_sum == 0.0] = 1.0
    label_mat = label_mat.div(row_sum, axis=0)

    demo_labels = list(label_mat.columns)
    print("[Demographics] demo_labels 개수:", len(demo_labels))

    return label_mat, demo_labels


def build_demographics_hard_labels(
    demo_path1: str,
    demo_path2: str,
) -> Tuple[pd.DataFrame, List[str]]:
    """
    Reduce the per-article demographic soft distribution to a one-hot vector.

    The soft matrix comes from ``build_demographics_soft_labels``. For each
    article the column with the largest value (first one on ties, as
    ``numpy.argmax`` does) is set to 1.0 and every other column to 0.0.

    Args:
        demo_path1: path of demographics_part001.xlsx.
        demo_path2: path of demographics_part002.xlsx.

    Returns:
        (label_mat_hard, demo_labels): float32 one-hot DataFrame with the same
        index and columns as the soft matrix, plus the demo_group name list.
    """
    soft_mat, demo_labels = build_demographics_soft_labels(demo_path1, demo_path2)

    soft_values = soft_mat.values
    n_articles = soft_values.shape[0]
    winner_cols = np.argmax(soft_values, axis=1)

    one_hot = np.zeros(soft_values.shape, dtype=np.float32)
    one_hot[np.arange(n_articles), winner_cols] = 1.0

    label_mat_hard = pd.DataFrame(one_hot, index=soft_mat.index, columns=soft_mat.columns)

    print("[Demographics-HARD] demo_labels 개수:", len(demo_labels))
    print("[Demographics-HARD] one-hot 예시:")
    print(label_mat_hard.head())

    return label_mat_hard, demo_labels


class NewsLongformerDataset(Dataset):
    """
    Map-style torch dataset of tokenized text, demographic features and referrer labels.

    Each item corresponds to one row of the frame passed in. Its ``content`` is
    tokenized with truncation and padded to exactly ``max_length``. The item
    also carries the listed demographic columns as ``demo_feats`` and the
    listed referrer columns as ``labels``, both as float32 tensors.
    """

    def __init__(self, df, referrers, demo_cols, tokenizer, max_length=4096):
        """
        Args:
            df: frame (index = article_id) with a 'content' column, one column per
                name in ``referrers`` and one per name in ``demo_cols``. The index
                is moved into a regular column via ``reset_index``.
            referrers: referrer column names used as the label vector (in order).
            demo_cols: demographic feature column names (in order).
            tokenizer: HF-style tokenizer called as ``tokenizer(text, **kwargs)``.
            max_length: pad/truncate length in tokens.
        """
        self.df = df.reset_index()
        self.referrers = referrers
        self.demo_cols = demo_cols
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        encoded = self.tokenizer(
            str(record["content"]),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a batch axis of size 1; drop it.
        token_ids = encoded["input_ids"].squeeze(0)
        token_mask = encoded["attention_mask"].squeeze(0)

        target = torch.tensor([record[name] for name in self.referrers], dtype=torch.float32)
        demo_vector = torch.tensor(
            [float(record[name]) for name in self.demo_cols],
            dtype=torch.float32,
        )

        return {
            "input_ids": token_ids,
            "attention_mask": token_mask,
            "demo_feats": demo_vector,
            "labels": target,
        }


class DemoBigBirdReferrerModel(nn.Module):
    """
    Referrer-distribution head over a pretrained encoder plus demographic features.

    The encoder's first-position hidden state (the CLS position for BERT-style
    tokenizers) goes through dropout. It is then concatenated with
    ``demo_feats`` (a hard one-hot vector here; a soft distribution also fits
    the shape), passed through a second dropout and a linear layer that
    produces one logit per referrer.
    """

    def __init__(self, encoder_name: str, demo_dim: int, num_labels: int):
        super().__init__()
        # Submodule names/order are kept stable: they are the state_dict keys.
        self.encoder = AutoModel.from_pretrained(encoder_name)
        enc_width = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.2)
        self.feature_dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(enc_width + demo_dim, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        demo_feats=None,
        labels=None,
        **kwargs,
    ):
        """
        Return ``{"logits": (batch, num_labels)}``.

        ``labels`` is accepted but unused here; the loss is computed by the
        trainer. Its presence in the signature lets HF Trainer detect the
        label field.

        Raises:
            ValueError: if ``demo_feats`` is missing.
        """
        if demo_feats is None:
            raise ValueError("demo_feats (demographic one-hot) 가 필요합니다.")

        hidden = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        cls_vec = self.dropout(hidden[:, 0, :])

        fused = self.feature_dropout(torch.cat([cls_vec, demo_feats], dim=-1))
        return {"logits": self.classifier(fused)}


class SoftLabelTrainer(Trainer):
    """HF Trainer whose loss is cross-entropy against soft (probability) targets."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute mean over the batch of ``-sum(labels * log_softmax(logits))``.

        ``labels`` is popped from ``inputs`` before the model is called, so the
        model sees only its feature inputs. Extra keyword arguments passed by
        newer Trainer versions are accepted and ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        soft_targets = inputs.pop("labels")
        outputs = model(**inputs)

        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

        per_example = (soft_targets * torch.nn.functional.log_softmax(logits, dim=-1)).sum(dim=-1)
        loss = -per_example.mean()

        return (loss, outputs) if return_outputs else loss


def compute_metrics(eval_pred):
    """
    Score predicted referrer distributions against soft referrer labels.

    Args:
        eval_pred: ``(logits, labels)`` pair (e.g. an HF ``EvalPrediction``).
            Both are (N, L): raw scores and the soft referrer distribution.

    Returns:
        dict with float values:
          - top1_accuracy: share of rows whose predicted argmax equals the label argmax
          - top3_hit_rate: share of rows whose label argmax is among the 3 most
            probable predictions
          - cross_entropy: mean over rows of -sum(labels * log(probs + 1e-12))
          - rmse: root of the mean squared difference over every (row, label) cell
    """
    pred_scores, gold_dist = eval_pred

    gold_dist = np.array(gold_dist)
    pred_probs = torch.softmax(torch.tensor(np.array(pred_scores)), dim=-1).numpy()

    gold_top = gold_dist.argmax(axis=-1)

    top1_acc = np.mean(pred_probs.argmax(axis=-1) == gold_top)

    k = 3
    order = np.argsort(-pred_probs, axis=-1)
    top3_hit_rate = np.mean(np.any(order[:, :k] == gold_top[:, None], axis=1))

    log_probs = np.log(pred_probs + 1e-12)
    ce = -(gold_dist * log_probs).sum(axis=1).mean()

    rmse = np.sqrt(np.square(pred_probs - gold_dist).mean())

    return {
        "top1_accuracy": float(top1_acc),
        "top3_hit_rate": float(top3_hit_rate),
        "cross_entropy": float(ce),
        "rmse": float(rmse),
    }


def _filter_naver_only_groups(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows whose (article_id[, period]) group has only NAVER / NAVER-blog referrers."""
    naver_set = {"네이버", "네이버 블로그"}
    has_period = "period" in df.columns
    if has_period:
        group_keys = ["article_id", "period"]
    else:
        print("[Referrer] 'period' 컬럼이 없어 NAVER-only 필터는 article_id 단위로만 적용할 수 있습니다.")
        group_keys = ["article_id"]

    # .eq(True) forces a clean boolean mask. Rows with a NaN group key are
    # skipped by groupby and come back as NaN; `~` on that object column would
    # not be a boolean negation. Such rows are kept.
    group_is_naver_only = (
        df.groupby(group_keys)["referrer"]
        .transform(lambda s: set(s.unique()).issubset(naver_set))
        .eq(True)
    )
    before_rows = len(df)
    df = df[~group_is_naver_only].copy()
    if has_period:
        print(f"[Referrer] NAVER/NAVER_BLOG only 그룹 제거: {before_rows} -> {len(df)} rows")
    else:
        print(f"[Referrer] NAVER/NAVER_BLOG only (article_id 기준) 제거: {before_rows} -> {len(df)} rows")
    return df


def _build_referrer_soft_labels(df: pd.DataFrame, min_ref_count: int):
    """
    Merge minor referrers into "기타", keep frequent ones, and build per-article soft labels.

    The soft label for (article_id, referrer) is
    sum(views_total * share / 100) / sum(views_total). Rows with a positive
    total are then normalised to sum to 1.

    Returns:
        (df, label_mat): the frame restricted to kept referrers, and a DataFrame
        indexed by article_id with one column per kept referrer.
    """
    df = df.copy()
    df["referrer"] = df["referrer"].replace({
        "Facebook": "기타",
        "namu.wiki": "기타",
        "야후": "기타",
    })

    print("[Referrer] 분포 (after grouping):")
    print(df["referrer"].value_counts())

    # The threshold is on row counts (rows of the merged sheet), not articles.
    ref_counts = df["referrer"].value_counts()
    valid_referrers = ref_counts[ref_counts >= min_ref_count].index.tolist()
    print("사용할 referrer 수:", len(valid_referrers))
    print("예시 referrer:", valid_referrers[:10])

    df = df[df["referrer"].isin(valid_referrers)].copy()
    df["share_frac"] = df["share"] / 100.0

    # Vectorised weighted mean (replaces a per-group groupby.apply lambda).
    df["_weighted"] = df["views_total"] * df["share_frac"]
    sums = df.groupby(["article_id", "referrer"])[["_weighted", "views_total"]].sum()
    grouped = (sums["_weighted"] / sums["views_total"]).rename("soft_label").reset_index()
    df = df.drop(columns=["_weighted"])

    label_mat = (
        grouped
        .pivot(index="article_id", columns="referrer", values="soft_label")
        .fillna(0.0)
    )

    row_sum = label_mat.sum(axis=1)
    row_sum[row_sum == 0.0] = 1.0
    label_mat = label_mat.div(row_sum, axis=0)
    return df, label_mat


def _expand_into_token_filtered_chunks(dataset_df, tokenizer, chunk_chars, min_chunk_tokens):
    """Split every article's content into char chunks, keeping chunks with >= min_chunk_tokens tokens."""
    kept_rows = []
    for article_id, row in dataset_df.iterrows():
        for ci, chunk in enumerate(split_into_chunks(row["content"], chunk_chars=chunk_chars)):
            # Token count includes special tokens, as in the original filter.
            if len(tokenizer.encode(chunk, add_special_tokens=True)) < min_chunk_tokens:
                continue
            new_row = row.copy()
            new_row["content"] = chunk
            new_row["chunk_id"] = ci
            new_row["article_id"] = article_id
            kept_rows.append(new_row)

    if not kept_rows:
        raise ValueError(
            f"토큰 수 {min_chunk_tokens} 이상인 chunk가 없습니다 "
            f"(chunk_chars={chunk_chars}). min_chunk_tokens 또는 chunk_chars를 조정하세요."
        )
    # Rows are object-dtype Series (text mixed with floats); infer numeric dtypes back.
    return pd.DataFrame(kept_rows).infer_objects().set_index("article_id")


def _split_by_article(dataset_df, test_size=0.2, random_state=42):
    """
    Split chunk rows into train/val by article_id so no article is on both sides.

    A chunk-level split would put chunks of one article (identical labels and
    demographic features) in both sets and inflate validation metrics. The
    split is stratified on the article's major referrer when possible.
    """
    article_major = dataset_df.groupby(level=0)["major_referrer"].first()
    article_ids = article_major.index.to_numpy()
    try:
        train_ids, _ = train_test_split(
            article_ids,
            test_size=test_size,
            random_state=random_state,
            stratify=article_major.to_numpy(),
        )
    except ValueError as err:
        # e.g. a major referrer with a single article; fall back to a plain split.
        print(f"[Split] article 단위 stratified split 불가 ({err}); stratify 없이 분할합니다.")
        train_ids, _ = train_test_split(
            article_ids, test_size=test_size, random_state=random_state
        )

    in_train = dataset_df.index.isin(train_ids)
    train_df, val_df = dataset_df[in_train], dataset_df[~in_train]
    print(
        f"[Split] articles train/val: {len(train_ids)}/{len(article_ids) - len(train_ids)}, "
        f"chunks train/val: {len(train_df)}/{len(val_df)}"
    )
    return train_df, val_df


def _transformers_compat_kwargs():
    """Return (eval-strategy kwarg name, tokenizer kwarg name) accepted by the installed transformers."""
    args_params = inspect.signature(TrainingArguments.__init__).parameters
    trainer_params = inspect.signature(Trainer.__init__).parameters
    # `evaluation_strategy` -> `eval_strategy` (4.41) and `tokenizer` -> `processing_class` (4.46);
    # the old names were later removed.
    eval_kw = "eval_strategy" if "eval_strategy" in args_params else "evaluation_strategy"
    tok_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    return eval_kw, tok_kw


def _keep_best_and_last_checkpoints(output_dir, best_ckpt):
    """Copy the best/last checkpoints to best_model/ and last_model/, then delete all other checkpoints."""
    def _step(path):
        return int(os.path.basename(os.path.normpath(path)).rsplit("-", 1)[-1])

    ckpt_dirs = sorted(
        (
            p
            for p in glob.glob(os.path.join(output_dir, "checkpoint-*"))
            if os.path.isdir(p)
            and os.path.basename(os.path.normpath(p)).rsplit("-", 1)[-1].isdigit()
        ),
        key=_step,
    )

    last_ckpt = ckpt_dirs[-1] if ckpt_dirs else None
    print("Last checkpoint:", last_ckpt)

    if best_ckpt is not None:
        best_model_dir = os.path.join(output_dir, "best_model")
        print("Saving best model dir to:", best_model_dir)
        shutil.copytree(best_ckpt, best_model_dir, dirs_exist_ok=True)

    if last_ckpt is not None:
        last_model_dir = os.path.join(output_dir, "last_model")
        print("Saving last model dir to:", last_model_dir)
        shutil.copytree(last_ckpt, last_model_dir, dirs_exist_ok=True)

    # Compare normalised absolute paths: Trainer may record the best checkpoint
    # path in a different textual form than glob returns.
    keep = {os.path.abspath(p) for p in (best_ckpt, last_ckpt) if p is not None}
    for ckpt in ckpt_dirs:
        if os.path.abspath(ckpt) not in keep:
            print("Removing old checkpoint:", ckpt)
            shutil.rmtree(ckpt)


def train_model(
    df_path: str,
    demo_path1: str,
    demo_path2: str,
    batch_size: int = 1,
    lr: float = 2e-5,
    num_epochs: int = 3,
    max_length: int = 4096,
    min_ref_count: int = 100,
    output_dir: str = "./kobigbird_demo_hard_title_content",
    chunk_chars: int = 4000,
    min_chunk_tokens: int = 1000,
    encoder_name: str = "monologg/kobigbird-bert-base",
):
    """
    Train the demographics-conditioned referrer model end to end.

    Steps:
      1. Load ``df_path`` and check the required columns.
      2. Drop (article_id[, period]) groups referred only by NAVER / NAVER blog.
      3. Build "title content" text and the per-article referrer soft labels.
      4. Join one-hot demographic features built from ``demo_path1``/``demo_path2``.
      5. Split texts into ``chunk_chars``-character chunks; keep chunks with at
         least ``min_chunk_tokens`` tokens.
      6. Split train/validation by article (no article in both sets).
      7. Train with ``SoftLabelTrainer``, selecting the best epoch by
         top1_accuracy, then keep only the best and last checkpoints.

    Args:
        df_path: merged news Excel file (e.g. news_merged.xlsx).
        demo_path1, demo_path2: demographics_part001/002.xlsx paths.
        batch_size: per-device train/eval batch size.
        lr: learning rate.
        num_epochs: training epochs.
        max_length: tokenizer pad/truncate length.
        min_ref_count: minimum row count for a referrer to become a label.
        output_dir: Trainer output dir; best_model/ and last_model/ are written here.
        chunk_chars: characters per text chunk.
        min_chunk_tokens: chunks with fewer tokens (incl. special tokens) are dropped.
        encoder_name: Hugging Face model id for both tokenizer and encoder.

    Returns:
        (model, tokenizer, referrer_labels, demo_cols, trainer)

    Raises:
        ValueError: a required column is missing, or no chunk survives the token filter.
    """
    df = pd.read_excel(df_path)

    needed_cols = ["article_id", "title", "content", "views_total", "referrer", "share"]
    for c in needed_cols:
        if c not in df.columns:
            raise ValueError(f"필수 컬럼 {c} 이(가) 데이터에 없습니다.")

    df = _filter_naver_only_groups(df.copy())

    # astype(str) keeps the concatenation valid for numeric/non-string cells.
    df["base_text"] = df["title"].fillna("").astype(str) + " " + df["content"].fillna("").astype(str)

    df, label_mat = _build_referrer_soft_labels(df, min_ref_count)

    article_text = (
        df
        .drop_duplicates("article_id")
        .set_index("article_id")[["base_text"]]
        .rename(columns={"base_text": "content"})
    )

    dataset_df = article_text.join(label_mat, how="inner").dropna()

    referrer_labels = list(label_mat.columns)
    num_labels = len(referrer_labels)

    print("최종(텍스트+라벨) 데이터 크기:", dataset_df.shape)
    print("라벨(referrer) 개수:", num_labels)

    label_array = dataset_df[referrer_labels].values
    eps = 1e-12
    entropy = -(label_array * np.log(label_array + eps)).sum(axis=1).mean()
    print("평균 라벨 엔트로피 H(y):", entropy)

    demo_hard_mat, demo_labels = build_demographics_hard_labels(demo_path1, demo_path2)
    dataset_df = dataset_df.join(demo_hard_mat, how="inner")

    demo_cols = list(demo_hard_mat.columns)
    print("demo_cols 개수:", len(demo_cols))

    tokenizer = AutoTokenizer.from_pretrained(encoder_name)

    dataset_df = _expand_into_token_filtered_chunks(
        dataset_df, tokenizer, chunk_chars, min_chunk_tokens
    )
    print("chunk 확장 후 데이터 크기:", dataset_df.shape)

    dataset_df["major_referrer"] = dataset_df[referrer_labels].idxmax(axis=1)
    print("[major_referrer 분포 (chunk 단위)]:")
    print(dataset_df["major_referrer"].value_counts(normalize=True))

    train_df, val_df = _split_by_article(dataset_df, test_size=0.2, random_state=42)

    train_dataset = NewsLongformerDataset(
        train_df, referrer_labels, demo_cols, tokenizer, max_length=max_length
    )
    val_dataset = NewsLongformerDataset(
        val_df, referrer_labels, demo_cols, tokenizer, max_length=max_length
    )

    model = DemoBigBirdReferrerModel(
        encoder_name=encoder_name,
        demo_dim=len(demo_cols),
        num_labels=num_labels,
    ).to(DEVICE)

    eval_kw, tok_kw = _transformers_compat_kwargs()

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=num_epochs,
        save_strategy="epoch",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="top1_accuracy",
        greater_is_better=True,
        report_to="none",
        save_safetensors=True,
        # demo_feats is not in the encoder signature; keep it in the batches.
        remove_unused_columns=False,
        **{eval_kw: "epoch"},
    )

    trainer = SoftLabelTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        **{tok_kw: tokenizer},
    )

    trainer.train()

    eval_result = trainer.evaluate()
    print("\n최종 평가 결과 (best top-1 기준 모델):")
    for k, v in eval_result.items():
        print(f"{k}: {v}")

    best_ckpt = trainer.state.best_model_checkpoint
    print("Best checkpoint:", best_ckpt)

    _keep_best_and_last_checkpoints(output_dir, best_ckpt)

    return trainer.model, tokenizer, referrer_labels, demo_cols, trainer


def predict_platform_distribution(
    text: str,
    age_group: str,
    gender: str,
    model,
    tokenizer,
    referrers,
    demo_cols,
    max_length: int = 4096,
    threshold: float = 0.1,
):
    """
    Predict the referrer (platform) distribution for one text and one demographic group.

    The demographic input is a one-hot vector over ``demo_cols`` marking the
    column named ``f"{age_group}_{gender}"``. If no such column exists, a
    ``UserWarning`` is issued and an all-zero demographic vector is used.

    Args:
        text: article text; tokenized with truncation/padding to ``max_length``.
        age_group: age-group part of the demographic column name.
        gender: gender part of the demographic column name.
        model: module whose forward takes input_ids / attention_mask / demo_feats
            and returns ``{"logits": ...}`` or an object with ``.logits``.
        tokenizer: tokenizer used at training time.
        referrers: referrer names in the model's output order.
        demo_cols: demographic column names in training order.
        max_length: pad/truncate length.
        threshold: minimum probability for a referrer to be included.

    Returns:
        dict referrer -> probability, only entries with p >= threshold, in
        descending probability order.
    """
    # Send inputs to wherever the model's weights live instead of assuming the global DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    target_demo = f"{age_group}_{gender}"
    if target_demo not in demo_cols:
        warnings.warn(
            f"Unknown demographic group {target_demo!r} (not in demo_cols); "
            "using an all-zero demographic vector.",
            stacklevel=2,
        )
    demo_vec = [1.0 if col == target_demo else 0.0 for col in demo_cols]

    model.eval()
    with torch.no_grad():
        enc = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)
        demo_tensor = torch.tensor(demo_vec, dtype=torch.float32).unsqueeze(0).to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            demo_feats=demo_tensor,
        )
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    ranked = sorted(zip(referrers, probs), key=lambda pair: pair[1], reverse=True)
    return {ref: float(p) for ref, p in ranked if p >= threshold}


if __name__ == "__main__":
    # Input files: merged news/referrer sheet plus the two demographics exports.
    excel_path = "../Korean_Spoken/news_merged_grouped_final_summary.xlsx"
    demo_path1 = "../Korean_Spoken/demographics_part001.xlsx"
    demo_path2 = "../Korean_Spoken/demographics_part002.xlsx"

    train_config = dict(
        batch_size=2,
        lr=1e-5,
        num_epochs=5,
        max_length=4096,
        min_ref_count=100,
        output_dir="./kobigbird_demo_hard_title_content",
    )
    model, tokenizer, referrers, demo_cols, trainer = train_model(
        df_path=excel_path,
        demo_path1=demo_path1,
        demo_path2=demo_path2,
        **train_config,
    )

    # Smoke-test inference for one demographic group.
    example_text = "10대 여성 소비 트렌드: 편의점과 패션의 변화에 대한 심층 분석 기사 본문 ..."
    dist = predict_platform_distribution(
        example_text,
        "10대",
        "여",
        model,
        tokenizer,
        referrers,
        demo_cols,
        max_length=4096,
        threshold=0.1,
    )

    print("\n[예측 분포 (p>=0.1)]")
    for platform, prob in dist.items():
        print(f"{platform}: {prob:.4f}")
