In [None]:
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.metrics import classification_report, mean_absolute_error, mean_squared_error
from sklearn.model_selection import GroupShuffleSplit, StratifiedGroupKFold, train_test_split
from xgboost import XGBClassifier, XGBRegressor


In [None]:
# =========================
# 데이터 불러오기 (프로젝트 상대경로)
# =========================

def find_project_root() -> Path:
    """Locate the project root by walking up from the current working directory.

    The root is the first directory (the cwd itself, then each ancestor in
    order) that contains both a ``data`` and a ``notebooks`` entry.

    Returns:
        The matching directory, or the current working directory when no
        ancestor qualifies.
    """
    start = Path.cwd()
    candidates = (start, *start.parents)
    return next(
        (d for d in candidates if (d / "data").exists() and (d / "notebooks").exists()),
        start,
    )

def latest_versioned_csv(folder: Path, base_name: str) -> Path | None:
    pattern = re.compile(rf"^{re.escape(base_name)}_v(\d+)\.csv$")
    best_v, best_path = None, None

    for f in folder.glob(f"{base_name}_v*.csv"):
        m = pattern.match(f.name)

        if m:
            v = int(m.group(1))

            if best_v is None or v > best_v:
                best_v, best_path = v, f

    return best_path

def next_versioned_file(folder: Path, base_name: str, ext: str = ".csv") -> Path:
    """Build the path for the next ``<base_name>_v<N><ext>`` file in ``folder``.

    ``folder`` is created (with parents) if needed. N is one more than the
    highest existing version, or 1 when there is none.

    Args:
        folder: Output directory.
        base_name: File-name prefix before ``_v<N>``.
        ext: File extension including the dot (default ``".csv"``).

    Returns:
        The (not yet existing) path for the next version.
    """
    folder.mkdir(parents=True, exist_ok=True)

    version_rx = re.compile(rf"^{re.escape(base_name)}_v(\d+){re.escape(ext)}$")
    matches = (version_rx.match(p.name) for p in folder.glob(f"{base_name}_v*{ext}"))
    highest = max((int(m.group(1)) for m in matches if m), default=0)

    return folder / f"{base_name}_v{highest + 1}{ext}"

# Resolve every project path relative to the detected project root.
PROJECT_ROOT = find_project_root()
CLEAN_DIR = PROJECT_ROOT / "data" / "processed"
SHAP_FIG_DIR = PROJECT_ROOT / "figures" / "shap"

# Prefer the newest versioned clean file; fall back to the v1 name so the
# existence check below reports a concrete path.
csv_path = latest_versioned_csv(CLEAN_DIR, "trending_videos_clean")
csv_path = csv_path if csv_path is not None else CLEAN_DIR / "trending_videos_clean_v1.csv"

SHAP_FIG_DIR.mkdir(parents=True, exist_ok=True)

for label, value in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("CLEAN_DIR:", CLEAN_DIR),
    ("SHAP_FIG_DIR:", SHAP_FIG_DIR),
):
    print(label, value)

if not csv_path.exists():
    raise FileNotFoundError(f"트렌딩 clean 파일이 없습니다: {csv_path}")


In [None]:
# =========================
# Load data: read only the needed columns in chunks and keep a random sample
# =========================

# Full load (only when memory allows) — kept for reference:
# df = pd.read_csv(csv_path, low_memory=False)
# print("데이터 로드 완료, shape:", df.shape)
# print("컬럼 목록:", list(df.columns))

usecols = [
    "video_id",
    "view_count", "likes", "comment_count",
    "categoryId", "publish_dayofweek", "tags_count",
    "trending_days", "engagement_score"
]

dtype_map = {
    "video_id": "string",
    "categoryId": "Int16",
    "publish_dayofweek": "Int8",
    "tags_count": "Int16",
    # Int64 for counts: popular videos can exceed the Int32 maximum
    # (2,147,483,647), which would make read_csv fail on overflow.
    "view_count": "Int64",
    "likes": "Int64",
    "comment_count": "Int64",
    "trending_days": "Int16",
    "engagement_score": "float32",
}

# Columns required downstream (features + both targets).
required_cols = [
    "view_count", "likes", "comment_count",
    "categoryId", "publish_dayofweek", "tags_count",
    "trending_days", "engagement_score"
]

# Validate the header BEFORE reading: the chunk loop indexes the target
# columns, so a missing column would otherwise surface as a bare KeyError.
header_cols = pd.read_csv(csv_path, nrows=0).columns
missing = [c for c in required_cols if c not in header_cols]

if missing:
    raise ValueError(f"다음 컬럼이 없습니다: {missing}")

# Read in chunks and keep only a small random sample of each in memory.
chunksize = 200_000     # lower to 100_000 if memory is tight
sample_frac = 0.01      # 1% sample (0.5–2% is usually enough for SHAP)

# One RandomState shared by all chunks: still reproducible, but each chunk
# draws different row positions (a fixed int seed per chunk would select the
# same relative positions in every full chunk).
sampling_rng = np.random.RandomState(42)

parts = []
total_rows = 0

for chunk in pd.read_csv(
    csv_path,
    usecols=lambda c: c in usecols,   # tolerant of a missing optional video_id
    dtype=dtype_map,
    chunksize=chunksize,
    low_memory=True
):
    total_rows += len(chunk)

    # Rows without a target cannot be used for training or analysis.
    chunk = chunk.dropna(subset=["trending_days", "engagement_score"])

    if len(chunk) > 0:
        parts.append(chunk.sample(frac=sample_frac, random_state=sampling_rng))

df = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame(columns=usecols)

print("✅ 원본 처리 row 수:", total_rows)
print("✅ 샘플 df shape:", df.shape)
print("✅ 컬럼:", list(df.columns))
display(df.head(3))

# Fail here with a clear message instead of deep inside train_test_split.
if df.empty:
    raise ValueError("샘플링 결과가 비어 있습니다. sample_frac 값이나 타깃 결측 여부를 확인하세요.")

# Shared model inputs
feature_cols = ["view_count", "likes", "comment_count",
                "categoryId", "publish_dayofweek", "tags_count"]

# Plain float32 instead of pandas nullable Int* dtypes: XGBoost works in
# float32 internally, and some XGBoost versions reject nullable extension
# dtypes.
X = df[feature_cols].fillna(0).astype("float32")


In [None]:
# =======================================================
# Regression model for trending duration (trending_days)
# =======================================================

# Rows with a missing target were dropped while loading, so fillna(0) is only
# a safeguard; float32 gives XGBoost a plain numpy label dtype.
y_trend = df["trending_days"].fillna(0).astype("float32")

# If the same video_id occurs in several rows, a row-level random split puts
# rows of one video in both train and test and inflates the test scores, so
# split by video in that case (rows with a missing id share one group).
if "video_id" in df.columns and df["video_id"].duplicated().any():
    video_groups_tr = df["video_id"].fillna("")
    splitter_tr = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx_tr, test_idx_tr = next(splitter_tr.split(X, y_trend, groups=video_groups_tr))
    X_train_tr, X_test_tr = X.iloc[train_idx_tr], X.iloc[test_idx_tr]
    y_train_tr, y_test_tr = y_trend.iloc[train_idx_tr], y_trend.iloc[test_idx_tr]
else:
    X_train_tr, X_test_tr, y_train_tr, y_test_tr = train_test_split(
        X, y_trend,
        test_size=0.2,
        random_state=42
    )

xgb_trend = XGBRegressor(
    n_estimators=200,
    max_depth=6,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    tree_method="hist",
    n_jobs=-1
)

xgb_trend.fit(X_train_tr, y_train_tr)

y_pred_tr = xgb_trend.predict(X_test_tr)
rmse_tr = mean_squared_error(y_test_tr, y_pred_tr) ** 0.5   # RMSE = sqrt(MSE)
mae_tr = mean_absolute_error(y_test_tr, y_pred_tr)

print("\n[TrendingDays - XGBoost 회귀 결과]")
print("RMSE:", rmse_tr)
print("MAE :", mae_tr)


In [None]:
# =======================================================
# SHAP plots for the trending_days regressor
# =======================================================

# Output files for the two summary plots
trending_bar_path = SHAP_FIG_DIR / "shap_fig01_trending_days_bar.png"
trending_beeswarm_path = SHAP_FIG_DIR / "shap_fig02_trending_days_beeswarm.png"

# Explain at most 5,000 held-out rows to keep TreeExplainer fast.
sample_size_tr = min(5000, len(X_test_tr))
X_tr_sample = X_test_tr.sample(sample_size_tr, random_state=42)

explainer_trend = shap.TreeExplainer(xgb_trend)
shap_values_trend = explainer_trend.shap_values(X_tr_sample)

# (output path, extra summary_plot kwargs): bar = global feature importance,
# default = beeswarm of per-sample impact.
plot_specs_tr = (
    (trending_bar_path, {"plot_type": "bar"}),
    (trending_beeswarm_path, {}),
)

for out_path, plot_kwargs in plot_specs_tr:
    plt.figure()
    shap.summary_plot(shap_values_trend, X_tr_sample, show=False, **plot_kwargs)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

for out_path in (trending_bar_path, trending_beeswarm_path):
    print("✅ saved:", out_path)


In [None]:
# =======================================================
# Classification model for high engagement (high_engagement)
# =======================================================

# Build the high_engagement label from engagement_score.
if df["engagement_score"].isna().all():
    raise ValueError("engagement_score가 모두 NaN 입니다. 먼저 engagement_score를 채워야 합니다.")

df_eng = df.dropna(subset=["engagement_score"]).copy()

# Top 20% of engagement_score -> positive class.
threshold = df_eng["engagement_score"].quantile(0.8)
df_eng["high_engagement"] = (df_eng["engagement_score"] >= threshold).astype(int)

print("high_engagement threshold (상위 20%):", threshold)
print("high_engagement 분포:")
print(df_eng["high_engagement"].value_counts())

# NOTE(review): if engagement_score is derived from likes / comment_count /
# view_count (all used as features below), the classifier can partly
# reconstruct the label from its inputs — confirm how engagement_score is built.
X_cls = df_eng[feature_cols].fillna(0).astype("float32")
y_cls = df_eng["high_engagement"]

# Stratified splitting needs both classes; heavy ties at the threshold can
# collapse the label to a single class.
if y_cls.nunique() < 2:
    raise ValueError("high_engagement 라벨이 한 클래스뿐입니다. threshold를 확인하세요.")

# As in the regression cell: when a video_id repeats, keep all its rows on one
# side of the split. StratifiedGroupKFold with 5 folds yields a ~20% test fold
# that also preserves the class ratio.
if "video_id" in df_eng.columns and df_eng["video_id"].duplicated().any():
    video_groups_cl = df_eng["video_id"].fillna("")
    splitter_cl = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    train_idx_cl, test_idx_cl = next(splitter_cl.split(X_cls, y_cls, groups=video_groups_cl))
    X_train_cl, X_test_cl = X_cls.iloc[train_idx_cl], X_cls.iloc[test_idx_cl]
    y_train_cl, y_test_cl = y_cls.iloc[train_idx_cl], y_cls.iloc[test_idx_cl]
else:
    X_train_cl, X_test_cl, y_train_cl, y_test_cl = train_test_split(
        X_cls, y_cls,
        test_size=0.2,
        random_state=42,
        stratify=y_cls
    )

xgb_cls = XGBClassifier(
    n_estimators=200,
    max_depth=6,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    tree_method="hist",
    n_jobs=-1,
    eval_metric="logloss"
)

xgb_cls.fit(X_train_cl, y_train_cl)

y_pred_cl = xgb_cls.predict(X_test_cl)
print("\n[XGBoost Classifier 결과 (high_engagement)]")
print(classification_report(y_test_cl, y_pred_cl))


In [None]:
# =======================================================
# SHAP plots for the high_engagement classifier
# =======================================================

# Output files for the two summary plots
engagement_bar_path = SHAP_FIG_DIR / "shap_fig03_high_engagement_bar.png"
engagement_beeswarm_path = SHAP_FIG_DIR / "shap_fig04_high_engagement_beeswarm.png"

# Explain at most 5,000 held-out rows to keep TreeExplainer fast.
sample_size_cl = min(5000, len(X_test_cl))
X_cl_sample = X_test_cl.sample(sample_size_cl, random_state=42)

explainer_cls = shap.TreeExplainer(xgb_cls)
shap_values_cls = explainer_cls.shap_values(X_cl_sample)

# Select positive-class attributions whatever form this shap version returns:
# - list of per-class arrays (usually [0] = negative, [1] = positive)
# - 3-D array (n_samples, n_features, n_outputs) returned by newer shap
#   versions for multi-output models -> last axis index 1
# - 2-D array: a single model output, used as-is
if isinstance(shap_values_cls, list):
    shap_values_cls_plot = shap_values_cls[1] if len(shap_values_cls) > 1 else shap_values_cls[0]
elif np.ndim(shap_values_cls) == 3:
    shap_values_cls_plot = shap_values_cls[..., 1]
else:
    shap_values_cls_plot = shap_values_cls

# (output path, extra summary_plot kwargs): bar = global feature importance,
# default = beeswarm of per-sample impact.
plot_specs_cl = (
    (engagement_bar_path, {"plot_type": "bar"}),
    (engagement_beeswarm_path, {}),
)

for out_path, plot_kwargs in plot_specs_cl:
    plt.figure()
    shap.summary_plot(shap_values_cls_plot, X_cl_sample, show=False, **plot_kwargs)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

for out_path in (engagement_bar_path, engagement_beeswarm_path):
    print("✅ saved:", out_path)
