In [None]:
import re
import shap
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split


In [None]:
# =====================
# 데이터 불러오기 (프로젝트 상대경로)
# =====================

def find_project_root() -> Path:
    """Locate the project root by walking upward from the working directory.

    A directory qualifies as the root when it contains both a ``data`` entry
    and a ``notebooks`` entry. The current working directory is checked
    first, then each ancestor in turn.

    Returns:
        The first qualifying directory, or the current working directory
        itself when no directory on the way up qualifies.
    """
    start = Path.cwd()

    for candidate in (start, *start.parents):
        has_data = (candidate / "data").exists()
        if has_data and (candidate / "notebooks").exists():
            return candidate

    # Nothing matched on the way up: fall back to where we started.
    return start

def latest_versioned_csv(folder: Path, base_name: str) -> Path | None:
    pattern = re.compile(rf"^{re.escape(base_name)}_v(\d+)\.csv$")
    best_v, best_path = None, None

    for f in folder.glob(f"{base_name}_v*.csv"):
        m = pattern.match(f.name)

        if m:
            v = int(m.group(1))

            if best_v is None or v > best_v:
                best_v, best_path = v, f

    return best_path

def next_versioned_file(folder: Path, base_name: str, ext: str = ".csv") -> Path:
    """Build the path for the next unused ``<base_name>_v<N><ext>`` version.

    ``folder`` is created (including parents) if it does not exist yet. The
    returned file itself is NOT created.

    Args:
        folder: Directory that holds the versioned files.
        base_name: File name stem that precedes the ``_v<N>`` suffix; matched
            literally, even if it contains glob or regex metacharacters.
        ext: File extension including the leading dot; matched literally.

    Returns:
        ``folder / "<base_name>_v<N><ext>"`` where ``N`` is one more than the
        largest existing version, or ``1`` when no versioned file exists.
    """
    folder.mkdir(parents=True, exist_ok=True)

    pattern = re.compile(rf"^{re.escape(base_name)}_v(\d+){re.escape(ext)}$")

    # List the directory and filter with the (escaped) regex only. Building a
    # glob pattern from base_name/ext would treat "[", "]", "*" and "?" as
    # wildcards, miss the existing files, and hand back an already-used version.
    versions = [
        int(m.group(1))
        for entry in folder.iterdir()
        if (m := pattern.match(entry.name))
    ]

    v = max(versions, default=0) + 1
    return folder / f"{base_name}_v{v}{ext}"

# Locate the repository root and derive the input directory from it.
PROJECT_ROOT = find_project_root()
CLEAN_DIR = PROJECT_ROOT / "data" / "processed"

# Use the newest channels_clean_v<N>.csv; when none is found, point at v1 so
# the existence check further down reports a concrete path.
csv_path = latest_versioned_csv(CLEAN_DIR, "channels_clean") or (
    CLEAN_DIR / "channels_clean_v1.csv"
)

# Output directory for the SHAP figures, created up front.
SHAP_FIG_DIR = PROJECT_ROOT / "figures" / "shap"
SHAP_FIG_DIR.mkdir(parents=True, exist_ok=True)

for label, location in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("CLEAN_DIR:", CLEAN_DIR),
    ("SHAP_FIG_DIR:", SHAP_FIG_DIR),
):
    print(label, location)

# Stop with a clear message before pandas raises a less specific error.
if not csv_path.exists():
    raise FileNotFoundError(f"채널 clean 파일이 없습니다: {csv_path}")

df = pd.read_csv(csv_path, low_memory=False)

print("채널 데이터 로드 완료:", df.shape)
print("컬럼 목록:", list(df.columns))
display(df.head(3))

In [None]:
# =====================
# 입력 변수 / 타깃 설정
# =====================

# Columns used as model inputs, in the order they appear in X.
feature_cols = [
    "upload_frequency",
    "views_per_video",
    "subscriber_per_view",
    "video_count",
    "channel_age_days",
    "category_encoded",
    "country_encoded",
]

# Regression target.
target_col = "views_last_30_days"

# Column check: fail fast if the loaded frame lacks any required column.
missing_feats = [name for name in feature_cols if name not in df.columns]

if missing_feats:
    raise ValueError(f"아래 feature 컬럼이 csv에 없습니다: {missing_feats}")

if target_col not in df.columns:
    raise ValueError(f"타깃 컬럼 '{target_col}' 이(가) csv에 없습니다!")

# Missing-value handling: every NaN in the features and in the target becomes 0.
# NOTE(review): zero-filling the target treats rows with an unknown target as
# having 0 views - confirm this is intended rather than dropping those rows.
X = df[feature_cols].fillna(0)
y = df[target_col].fillna(0)

print("\n사용 독립변수(feature_cols):", feature_cols)
print("타깃 변수:", target_col)


In [None]:
# =====================
# Train / Test 분할
# =====================

# Hold out 20% of the rows for testing; random_state is fixed at 42.
split_parts = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split_parts

print(
    "\nTrain size:", X_train.shape,
    " / Test size:", X_test.shape,
)


In [None]:
# =====================
# RandomForest 회귀 모델 학습
# =====================

# Hyperparameters gathered in one place so they are easy to inspect and tune.
rf_params = {
    "n_estimators": 300,
    "max_depth": 10,
    "min_samples_leaf": 20,
    "random_state": 42,
    "n_jobs": -1,
}

rf = RandomForestRegressor(**rf_params)
rf.fit(X_train, y_train)

print("RandomForest 회귀 학습 완료")


In [None]:
# =====================
# SHAP 값 계산 (채널 성장)
# =====================

# Sampling: explain at most 3000 training rows (presumably to keep the SHAP
# computation affordable); smaller training sets are used in full.
sample_cap = 3000
X_sample = (
    X_train.sample(sample_cap, random_state=42)
    if len(X_train) > sample_cap
    else X_train.copy()
)

print("SHAP 계산용 샘플 크기:", X_sample.shape)

# Compute SHAP values for the sampled rows with the fitted forest.
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_sample)

print("SHAP 값 계산 완료")


In [None]:
# =====================
# SHAP 시각화
# =====================

# Output paths (Path objects) for the two figures.
bar_path = SHAP_FIG_DIR / "shap_fig05_channel_growth_bar.png"
dot_path = SHAP_FIG_DIR / "shap_fig06_channel_growth_dot.png"

# Each job is (destination, extra summary_plot arguments):
#   1. feature-importance bar chart (mean(|SHAP|))
#   2. SHAP summary dot plot (summary_plot's default plot type)
plot_jobs = [
    (bar_path, {"plot_type": "bar"}),
    (dot_path, {}),
]

for out_path, extra_kwargs in plot_jobs:
    plt.figure()
    # show=False keeps the figure open so it can be saved below.
    shap.summary_plot(shap_values, X_sample, show=False, **extra_kwargs)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()

for saved_path in (bar_path, dot_path):
    print("✅ saved:", saved_path)
