In [1]:
import pandas as pd
import re, shap
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# =====================
# 1. 데이터 불러오기 (프로젝트 상대경로)
# =====================

def find_project_root() -> Path:
    """Locate the project root by walking upward from the working directory.

    The working directory itself is examined first, then each ancestor in
    turn. The first directory that contains both a ``data`` and a
    ``notebooks`` entry is returned.

    Returns:
        The matching directory, or the current working directory when no
        directory on the way up contains both entries.
    """
    start = Path.cwd()
    candidates = (start, *start.parents)
    return next(
        (d for d in candidates if (d / "data").exists() and (d / "notebooks").exists()),
        start,  # fallback: nothing matched on the way up
    )

def latest_versioned_csv(folder: Path, base_name: str) -> Path | None:
    """
    folder 안에서 base_name_v{n}.csv 중 가장 큰 n 파일 Path 반환
    없으면 None
    """
    pattern = re.compile(rf"^{re.escape(base_name)}_v(\d+)\.csv$")
    best_v, best_path = None, None

    for f in folder.glob(f"{base_name}_v*.csv"):
        m = pattern.match(f.name)
        if m:
            v = int(m.group(1))
            if best_v is None or v > best_v:
                best_v, best_path = v, f

    return best_path

def next_versioned_file(folder: Path, base_name: str, ext: str = ".csv") -> Path:
    """Return the path for the next ``{base_name}_v{n}{ext}`` file in ``folder``.

    ``folder`` is created (including parents) when it does not exist yet.
    The returned path itself is not created; it is meant to be used as a
    save target.

    Args:
        folder: Directory that holds the versioned files.
        base_name: File-name stem that precedes the ``_v{n}`` suffix; matched
            literally.
        ext: File extension including the leading dot; matched literally.

    Returns:
        ``folder / f"{base_name}_v{k}{ext}"`` where ``k`` is one more than
        the largest existing version, or ``1`` when none exists.
    """
    folder.mkdir(parents=True, exist_ok=True)
    pattern = re.compile(rf"^{re.escape(base_name)}_v(\d+){re.escape(ext)}$")

    # List the directory and filter with the regex instead of building a glob
    # from base_name/ext: glob metacharacters in either would otherwise hide
    # existing versions and make this function hand out "_v1" again.
    matches = (pattern.match(entry.name) for entry in folder.iterdir())
    versions = [int(m.group(1)) for m in matches if m]

    next_version = max(versions, default=0) + 1
    return folder / f"{base_name}_v{next_version}{ext}"



In [3]:
# Resolve the project root once and derive the cleaned-data directory from it.
PROJECT_ROOT = find_project_root()
CLEAN_DIR = PROJECT_ROOT.joinpath("data", "processed", "03_kaggle_clean")

# Prefer the highest-numbered clean file. When none is found, point at the v1
# name so the existence check below reports a concrete path.
csv_path = latest_versioned_csv(CLEAN_DIR, "youtube_channels_clean")
csv_path = CLEAN_DIR / "youtube_channels_clean_v1.csv" if csv_path is None else csv_path

print("PROJECT_ROOT:", PROJECT_ROOT)
print("CLEAN_DIR:", CLEAN_DIR)

# Stop here with a clear message rather than letting read_csv fail.
if not csv_path.exists():
    raise FileNotFoundError(f"채널 clean 파일이 없습니다: {csv_path}")

# low_memory=False makes pandas read the file in one pass.
df = pd.read_csv(csv_path, low_memory=False)

print("채널 데이터 로드 완료:", df.shape)
print("컬럼 목록:", list(df.columns))
display(df.head(3))

PROJECT_ROOT: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml
CLEAN_DIR: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\data\processed\03_kaggle_clean
채널 데이터 로드 완료: (15830, 17)
컬럼 목록: ['channel_id', 'channel_name', 'subscriber_count', 'view_count', 'video_count', 'created_date', 'category', 'country', 'videos_last_30_days', 'views_last_30_days', 'channel_age_days', 'upload_frequency', 'subscriber_per_view', 'views_per_video', 'uploads_per_subscriber', 'category_encoded', 'country_encoded']


Unnamed: 0,channel_id,channel_name,subscriber_count,view_count,video_count,created_date,category,country,videos_last_30_days,views_last_30_days,channel_age_days,upload_frequency,subscriber_per_view,views_per_video,uploads_per_subscriber,category_encoded,country_encoded
0,UCOmHUn--16B90oW2L6FRR3A,BLACKPINK,99000000,39962585446,636,2016-06-29 03:15:23+00:00,"Music of Asia, Pop music, Music, Electronic music",KR,1,3256869,3450.0,0.184348,0.002477,62834250.0,6e-06,1916,56
1,UC3IZKseVpdzPSBaWxBxundA,HYBE LABELS,78700000,41604896923,2817,2008-06-04 08:23:22+00:00,"Hip hop music, Pop music, Music, Music of Asia",KR,79,46074833,6397.0,0.440363,0.001892,14769220.0,3.6e-05,1122,56
2,UCVNE660NcgYzi18LwwUZb7Q,BILLIE EILISH,82300,14316364,1,2019-01-18 05:14:32+00:00,,,0,0,2517.0,0.000397,0.005749,14316360.0,1.2e-05,3443,110


In [4]:
# =====================
# 2. 입력 변수 / 타깃 설정
#    (3번에서 쓰던 것과 동일하게 맞춤)
# =====================

feature_cols = [
    'upload_frequency',
    'views_per_video',
    'subscriber_per_view',
    'video_count',
    'channel_age_days',
    'category_encoded',
    'country_encoded',
]

target_col = 'views_last_30_days'

# 방어 코드: 컬럼 존재 체크
missing_feats = [c for c in feature_cols if c not in df.columns]
if missing_feats:
    raise ValueError(f"아래 feature 컬럼이 csv에 없습니다: {missing_feats}")

if target_col not in df.columns:
    raise ValueError(f"타깃 컬럼 '{target_col}' 이(가) csv에 없습니다!")

X = df[feature_cols].copy()
y = df[target_col].copy()

# 결측치 간단 처리
X = X.fillna(0)
y = y.fillna(0)

print("\n사용 독립변수(feature_cols):", feature_cols)
print("타깃 변수:", target_col)



사용 독립변수(feature_cols): ['upload_frequency', 'views_per_video', 'subscriber_per_view', 'video_count', 'channel_age_days', 'category_encoded', 'country_encoded']
타깃 변수: views_last_30_days


In [5]:
# =====================
# 3. Train / Test 분할
# =====================

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)

print("\nTrain size:", X_train.shape, " / Test size:", X_test.shape)



Train size: (12664, 7)  / Test size: (3166, 7)


In [6]:
# =====================
# 4. RandomForest 회귀 모델 학습
# =====================

rf = RandomForestRegressor(
    n_estimators=300,
    max_depth=10,
    min_samples_leaf=20,
    random_state=42,
    n_jobs=-1
)

rf.fit(X_train, y_train)
print("RandomForest 회귀 학습 완료")


RandomForest 회귀 학습 완료


In [7]:
# =====================
# 5. SHAP 값 계산 (채널 성장)
# =====================

# 샘플링 (채널 15830개라 전부 써도 되지만, 안전하게 3000개 제한)
if len(X_train) > 3000:
    X_sample = X_train.sample(3000, random_state=42)
else:
    X_sample = X_train.copy()

print("SHAP 계산용 샘플 크기:", X_sample.shape)

explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_sample)

print("SHAP 값 계산 완료")


SHAP 계산용 샘플 크기: (3000, 7)
SHAP 값 계산 완료


In [8]:
# =====================
# 6. SHAP 시각화 저장 (프로젝트 상대경로)
# =====================

def find_project_root() -> Path:
    p = Path.cwd()
    for parent in [p] + list(p.parents):
        if (parent / "data").exists() and (parent / "notebooks").exists():
            return parent
    return p

out_dir = find_project_root() / "reports" / "figures" / "shap"
out_dir.mkdir(parents=True, exist_ok=True)

# 저장 경로(Path)
bar_path = out_dir / "shap_channel_growth_bar.png"
dot_path = out_dir / "shap_channel_growth_dot.png"

print("✅ SHAP out_dir:", out_dir)
print("bar_path:", bar_path)
print("dot_path:", dot_path)

# 6-1. 변수 중요도 막대그래프 (mean(|SHAP|))
plt.figure()
shap.summary_plot(
    shap_values,
    X_sample,
    plot_type="bar",
    show=False
)
plt.tight_layout()
plt.savefig(bar_path, dpi=200, bbox_inches="tight")
plt.close()

# 6-2. 점 그래프 (SHAP summary dot)
plt.figure()
shap.summary_plot(
    shap_values,
    X_sample,
    show=False
)
plt.tight_layout()
plt.savefig(dot_path, dpi=200, bbox_inches="tight")
plt.close()

print("✅ saved:", bar_path)
print("✅ saved:", dot_path)


✅ SHAP out_dir: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\reports\figures\shap
bar_path: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\reports\figures\shap\shap_channel_growth_bar.png
dot_path: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\reports\figures\shap\shap_channel_growth_dot.png
✅ saved: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\reports\figures\shap\shap_channel_growth_bar.png
✅ saved: c:\Users\73bib\Desktop\유혜원\제주한라대학교\[2025] 프로젝트\bigdata_project\youtube_trending_ml\reports\figures\shap\shap_channel_growth_dot.png
