SARIMAX — Modified Version

1. 라이브러리 Import

In [12]:
import os
import json
import gc
import re
from pathlib import Path

import numpy as np
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler

# Project paths, all anchored at the current working directory.
ROOT = Path.cwd()
FEATURE_DIR = ROOT.joinpath("../feature_datasets")  # parquet inputs are read from here
OUTPUT_DIR = ROOT.joinpath("results_sarimax")  # prediction JSON and metrics CSV go here
OUTPUT_DIR.mkdir(exist_ok=True)

# Echo the configured locations.
print(f"PROJECT ROOT: {ROOT}")
print(f"FEATURE DIR : {FEATURE_DIR}")
print(f"OUTPUT DIR  : {OUTPUT_DIR}")

PROJECT ROOT: /NAS/wonjun/20252R0136COSE36203/Prediction
FEATURE DIR : /NAS/wonjun/20252R0136COSE36203/Prediction/../feature_datasets
OUTPUT DIR  : /NAS/wonjun/20252R0136COSE36203/Prediction/results_sarimax


2. 임베딩 확장 함수 — orig dataset 지원

orig 데이터는 embedding 컬럼이 numpy.ndarray 형태이므로 emb_0 ~ emb_1023 같은 형태로 펼쳐줘야 SARIMAX에서 사용 가능

(단, PCA 데이터는 embedding 컬럼이 없음 → 그대로 사용)

In [13]:
def expand_embeddings(df: pd.DataFrame) -> pd.DataFrame:
    """Replace an ``embedding`` column with one scalar column per dimension.

    Args:
        df: Input table. When it has no ``embedding`` column it is returned
            unchanged (the very same object).

    Returns:
        A new DataFrame in which ``embedding`` is removed and columns
        ``emb_0`` .. ``emb_{k-1}`` are appended on the right, keeping the
        original index. The input frame itself is not modified.
    """
    if "embedding" in df.columns:
        print("   >>> Expanding 'embedding' column into emb_0 ~ emb_k ...")

        # Stack the per-row vectors into a single 2-D array (rows x dims).
        stacked = np.stack(df["embedding"].values)
        names = [f"emb_{i}" for i in range(stacked.shape[1])]
        expanded = pd.DataFrame(stacked, columns=names, index=df.index)

        # Swap the original column for the expanded ones.
        remainder = df.drop(columns=["embedding"])
        result = pd.concat([remainder, expanded], axis=1)

        # Release the large temporaries before returning.
        del stacked, expanded
        gc.collect()
        return result

    return df

3. Drop Rule 정의 + Daily Aggregation

SARIMAX는 날짜별 딱 1개의 row만 필요하므로

(1) Drop Rule → feature selection 수행 후

(2) 날짜별 mean aggregation 수행 (기사 많아도 1일 1행)

In [14]:
# Regular expressions for columns that must never be used as exogenous features.
# NOTE(review): the unanchored entries (e.g. "web.*", "api.*", "headline") are
# applied with re.search, so they match any column name that merely contains
# that text -- confirm this is intended.
DROP_PATTERNS = [
    # Endogenous target
    r"^value$",
    # Identifier-like columns
    r"^date_index$",
    r"^person_id$",
    r"^article_id$",
    r"^idx$",
    # Redundant date information
    r"article_date",
    r"pub_date",
    r"^date$",
    # Raw article text / metadata
    r"headline",
    r"trailText",
    r"bodyText",
    r"web.*",
    r"api.*",
    r"wordcount",
    # Person name string
    r"^person$",
    # Leakage
    r"^fg_value$",
    # Lag variables
    r"^lag_\d+$",
    r"^fg_lag_\d+$",
]

# Pattern-matching helper
def match_any(col, patterns):
    """Return True if ``col`` is matched (via ``re.search``) by any pattern.

    Args:
        col: Column name to test.
        patterns: Iterable of regular-expression strings.

    Returns:
        ``True`` on the first pattern that matches anywhere in ``col``,
        ``False`` when none match (including when ``patterns`` is empty).
    """
    for pattern in patterns:
        if re.search(pattern, col) is not None:
            return True
    return False

4. SARIMAX용 preprocessing:
   - Drop rule 적용
   - Embedding 확장
   - Daily aggregation
   - Endog/Exog 분리
   - Scaling (train/test leakage 방지) — 실제 스케일링은 이 전처리 함수가 아니라 run_sarimax_for_file에서 train 구간으로만 fit하여 수행

In [15]:
def preprocess_for_sarimax(df_raw: pd.DataFrame):
    """Turn a row-per-article feature table into daily endog/exog series.

    Steps: expand an ``embedding`` column if present, pick the date column,
    keep the numeric columns that are not excluded by ``DROP_PATTERNS``,
    average all rows that fall on the same calendar day, and split the target
    from the features.

    Args:
        df_raw: Input table. Must contain a ``value`` column (the target) and
            either a ``date_str`` or a ``pub_date`` column (``date_str`` wins
            when both exist).

    Returns:
        A ``(y, X)`` tuple. ``y`` is the daily mean of ``value`` on a sorted
        datetime index; ``X`` holds the daily mean of the retained feature
        columns on the same index, or is ``None`` when no feature column
        survives the drop rules.

    Raises:
        ValueError: If neither date column is present, or ``value`` is missing.
    """
    # ---- 1. Expand embeddings (no-op when there is no 'embedding' column) ----
    df = expand_embeddings(df_raw)

    # ---- 2. Detect the date column (date_str preferred over pub_date) ----
    if "date_str" in df.columns:
        date_col = "date_str"
    elif "pub_date" in df.columns:
        date_col = "pub_date"
    else:
        # Neither date column exists: fail early with the available columns.
        raise ValueError(f"날짜 컬럼(date_str 또는 pub_date)이 없습니다. Columns: {df.columns.tolist()}")

    # Fail with an explicit message instead of an opaque KeyError from groupby.
    if "value" not in df.columns:
        raise ValueError(f"타깃 컬럼(value)이 없습니다. Columns: {df.columns.tolist()}")

    # ---- 3. Apply the drop rules; keep numeric feature columns only ----
    valid_cols = [
        col
        for col in df.columns
        if col != date_col  # the date column is only the grouping key
        and not match_any(col, DROP_PATTERNS)
        and pd.api.types.is_numeric_dtype(df[col])
    ]

    # The target is aggregated together with the features.
    agg_cols = valid_cols + ["value"]

    # ---- 4. Build a calendar-day key ----
    # Parse BEFORE grouping and strip any time-of-day component, so that a
    # timestamp-valued date column still yields exactly one row per day.
    # Values written as YYYY_MM_DD are converted to YYYY-MM-DD first.
    raw_dates = df[date_col]
    date_text = (
        raw_dates.astype(str)
        .str.replace("_", "-", regex=False)
        .where(raw_dates.notna())  # keep missing dates missing (-> NaT)
    )
    day_key = pd.to_datetime(date_text).dt.normalize()

    # ---- 5. Daily mean aggregation on a sorted datetime index ----
    # Rows whose date is missing are dropped by groupby.
    daily_df = df.groupby(day_key)[agg_cols].mean().sort_index()

    # ---- 6. Split y (endog) / X (exog) ----
    y = daily_df["value"]
    X = daily_df[valid_cols] if valid_cols else None

    return y, X

5. 파일명 파싱

In [16]:
# Recognised file-name shapes:
#   dataset_A.parquet                 -> ("A", "none", "none")
#   dataset_B_headlines_orig.parquet  -> ("B", "headlines", "orig")
#   dataset_D_paragraphs_pca.parquet  -> ("D", "paragraphs", "pca")

def parse_filename(fname):
    """Split a feature-file name into ``(Dataset, Method, Type)``.

    The ``dataset_`` and ``.parquet`` substrings are removed and the rest is
    split on underscores.

    Args:
        fname: File name such as ``dataset_B_headlines_orig.parquet``.

    Returns:
        * one token: ``(token, "none", "none")``
        * three or more tokens: first token is the dataset, last token is the
          type, everything in between (re-joined with ``_``) is the method
        * exactly two tokens: ``(None, None, None)``
    """
    stem = fname.replace("dataset_", "").replace(".parquet", "")
    tokens = stem.split("_")
    count = len(tokens)

    if count == 1:
        return tokens[0], "none", "none"

    if count == 2:
        # Neither a bare dataset name nor a full dataset/method/type triple.
        return None, None, None

    head, *middle, tail = tokens
    return head, "_".join(middle), tail

6. SARIMAX Single Run 함수

In [17]:
def run_sarimax_for_file(fname: str, metrics: list):
    """Fit and evaluate one SARIMAX model on a single feature file.

    The file is read from ``FEATURE_DIR``, reduced to daily series with
    ``preprocess_for_sarimax``, and split into a train window (dates up to
    2019-06-30) and a test window (2019-07-01 through 2019-12-31). A
    SARIMAX model with order (1,1,1) and seasonal order (1,1,1,5) is fitted on
    the train window; exogenous features are standardised with statistics
    from the train window only. The test MSE is appended to ``metrics`` and
    the per-day actual/predicted values are written to a JSON file under
    ``OUTPUT_DIR``.

    Files whose parsed type is ``"orig"`` are skipped, as are files that yield
    an empty train or test window. Any exception raised while fitting,
    forecasting, scoring or saving is caught and printed (best effort).

    Args:
        fname: File name (not a full path) inside ``FEATURE_DIR``.
        metrics: List that receives one result dict per successful run; it is
            mutated in place.

    Returns:
        None.
    """
    Dataset, Method, Type = parse_filename(fname)
    print(f"\n>>> Processing {fname} | Dataset={Dataset}, Method={Method}, Type={Type}")

    # ---- Skip "orig" datasets automatically ----
    if Type == "orig":
        print("   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).")
        return

    # ---- 1) Load data + preprocess ----
    df_raw = pd.read_parquet(FEATURE_DIR / fname)
    y, X = preprocess_for_sarimax(df_raw)

    print(f"   >>> Daily rows: {len(y)}")
    if X is not None:
        print(f"   >>> Exog Feature Count: {X.shape[1]}")

    # ---- 2) Train/Test split (fixed calendar windows) ----
    train_mask = (y.index <= "2019-06-30")
    test_mask  = (y.index >= "2019-07-01") & (y.index <= "2019-12-31")

    y_train = y[train_mask]
    y_test  = y[test_mask]

    # Exogenous features are split with the same masks (X shares y's index).
    if X is not None:
        X_train = X[train_mask]
        X_test  = X[test_mask]
    else:
        X_train = X_test = None

    print(f"   >>> Train={len(y_train)}, Test={len(y_test)}")

    if len(y_train) == 0 or len(y_test) == 0:
        print("   [Error] Empty split.")
        return

    # ---- 3) Scaling: fit on train only so no test statistics leak in ----
    if X_train is not None:
        scaler = StandardScaler()
        scaler.fit(X_train)
        X_train_scaled = pd.DataFrame(scaler.transform(X_train), index=X_train.index, columns=X_train.columns)
        X_test_scaled  = pd.DataFrame(scaler.transform(X_test), index=X_test.index, columns=X_test.columns)
    else:
        X_train_scaled = X_test_scaled = None

    # ---- 4) Fit the SARIMAX model ----
    print("   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))")
    try:
        model = SARIMAX(
            endog=y_train,
            exog=X_train_scaled,
            order=(1,1,1),
            seasonal_order=(1,1,1,5),
            enforce_stationarity=False,
            enforce_invertibility=False
        )
        fit_model = model.fit(disp=False, maxiter=300)

        # ---- 5) Forecast as many steps as there are test rows ----
        forecast = fit_model.get_forecast(steps=len(y_test), exog=X_test_scaled)
        pred = forecast.predicted_mean

        # ---- 6) Metrics ----
        mse = mean_squared_error(y_test, pred)
        print(f"   [Result] Test MSE = {mse:.6f}")

        metrics.append({
            "Dataset": Dataset,
            "Method": Method or "none",
            "Type": Type or "none",
            "Model": "SARIMAX",
            "MSE": mse
        })

        # ---- 7) Save predictions as JSON ----
        # Values are paired positionally with the test dates. They are cast to
        # builtin float because NumPy scalars other than float64 (e.g. a
        # float32 target) are rejected by json.dump.
        pred_json = {
            date.strftime("%Y-%m-%d"): {"actual": float(actual), "pred": float(p)}
            for date, actual, p in zip(y_test.index, y_test.values, pred.values)
        }

        json_path = OUTPUT_DIR / f"pred_SARIMAX_{Dataset}_{Method or 'none'}_{Type or 'none'}.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(pred_json, f, indent=4)

    except Exception as e:
        # Best effort: report the failure and let the caller move on.
        print(f"   [Error] Model failed: {e}")

    gc.collect()

7. Main Loop — 모든 dataset 처리

In [18]:
def main():
    """Run SARIMAX on every parquet file in ``FEATURE_DIR`` and save the metrics.

    Files are processed in sorted name order. A failure on one file is printed
    and does not stop the loop. When at least one run produced a metric, the
    collected rows are sorted by Dataset/Method/Type, written to
    ``sarimax_evaluation_metrics.csv`` under ``OUTPUT_DIR`` and printed.
    """
    if not FEATURE_DIR.exists():
        print("[Error] Feature directory not found.")
        return

    parquet_names = [entry for entry in os.listdir(FEATURE_DIR) if entry.endswith(".parquet")]
    parquet_names.sort()

    print(f"Found {len(parquet_names)} files.")

    collected = []
    for parquet_name in parquet_names:
        try:
            run_sarimax_for_file(parquet_name, collected)
        except Exception as err:
            print(f"[Error] {parquet_name}: {err}")

    # ---- Save the final metrics ----
    if not collected:
        print("No metrics produced.")
        return

    summary = pd.DataFrame(collected).sort_values(["Dataset", "Method", "Type"])

    out_csv = OUTPUT_DIR / "sarimax_evaluation_metrics.csv"
    summary.to_csv(out_csv, index=False)

    print("\n[Done] SARIMAX completed.")
    print("Metrics saved to:", out_csv)
    print(summary)


if __name__ == "__main__":
    main()

Found 25 files.

>>> Processing dataset_A.parquet | Dataset=A, Method=none, Type=none
   >>> Daily rows: 754
   >>> Train=626, Test=128
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(


   [Result] Test MSE = 5241.431595

>>> Processing dataset_B_bodyText_orig.parquet | Dataset=B, Method=bodyText, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_bodyText_pca.parquet | Dataset=B, Method=bodyText, Type=pca
   >>> Daily rows: 1095
   >>> Exog Feature Count: 191
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 4938.855466

>>> Processing dataset_B_chunking_orig.parquet | Dataset=B, Method=chunking, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_chunking_pca.parquet | Dataset=B, Method=chunking, Type=pca
   >>> Daily rows: 1095
   >>> Exog Feature Count: 244
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 5654.456251

>>> Processing dataset_B_headlines_orig.parquet | Dataset=B, Method=headlines, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_headlines_pca.parquet | Dataset=B, Method=headlines, Type=pca
   >>> Daily rows: 1095
   >>> Exog Feature Count: 304
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 4606.232972

>>> Processing dataset_B_paragraphs_orig.parquet | Dataset=B, Method=paragraphs, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_paragraphs_pca.parquet | Dataset=B, Method=paragraphs, Type=pca
   >>> Daily rows: 1095
   >>> Exog Feature Count: 208
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


KeyboardInterrupt: 