SARIMAX — Modified Version

1. 라이브러리 Import

In [7]:
import os
import json
import gc
import re
from pathlib import Path

import numpy as np
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler

# Path configuration (same as the Linear Regression notebook).
DATA_DIR = Path("../feature_datasets")        # directory scanned for *.parquet inputs
OUTPUT_DIR = Path("results_sarimax")          # root of everything written by this notebook
RESULTS_DIR = OUTPUT_DIR.joinpath("results")  # per-dataset prediction JSON files

# Create both output directories up front; existing directories are left alone.
for _out_dir in (OUTPUT_DIR, RESULTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

print(f"Data Source: {DATA_DIR}", f"Output Path: {OUTPUT_DIR}", sep="\n")

Data Source: ../feature_datasets
Output Path: results_sarimax


In [8]:
def load_data(file_path):
    """Read one parquet file and attach a parsed ``Date`` column.

    Args:
        file_path: Path-like object exposing ``.name``; passed as-is to
            ``pd.read_parquet``.

    Returns:
        The loaded DataFrame. When a ``date_index`` column exists the rows
        are sorted by it and the index is reset. A ``Date`` column is added,
        parsed from ``pub_date`` with the ``%Y_%m_%d`` format.

    Raises:
        ValueError: If the file has no ``pub_date`` column.
    """
    fname = file_path.name
    print(f"Loading {fname}...")
    frame = pd.read_parquet(file_path)

    # Sorting does not change the column set, so one lookup serves both checks.
    columns = frame.columns

    # Chronological ordering, when the ordering key is available.
    if 'date_index' in columns:
        frame = frame.sort_values('date_index').reset_index(drop=True)

    # pub_date -> Date conversion; pub_date is mandatory.
    if 'pub_date' not in columns:
        raise ValueError(f"'pub_date' column not found in {fname}")
    frame['Date'] = pd.to_datetime(frame['pub_date'], format='%Y_%m_%d')

    # The column count is taken after adding 'Date', so it includes it.
    n_rows, n_cols = frame.shape
    print(f"   Loaded {n_rows} rows, {n_cols} columns")

    dates = frame['Date']
    first_day = dates.min().strftime('%Y-%m-%d')
    last_day = dates.max().strftime('%Y-%m-%d')
    print(f"   Date range: {first_day} to {last_day}")

    return frame

3. Drop Rule 정의 + Daily Aggregation

SARIMAX는 날짜별 딱 1개의 row만 필요하므로

(1) Drop Rule → feature selection 수행 후

(2) 날짜별 mean aggregation 수행 (기사 많아도 1일 1행)

In [9]:
# Regular expressions naming the columns that are excluded from the exogenous
# feature set. Patterns without ^...$ anchors hit anywhere in a column name
# when applied with re.search.
DROP_PATTERNS = [
    # Endog (target)
    r"^value$",
    # Identifier columns
    r"^date_index$",
    r"^person_id$",
    r"^article_id$",
    r"^idx$",
    # Redundant date information
    r"article_date",
    r"pub_date",
    r"^date$",
    # Raw article fields (presumably text/metadata columns -- confirm against the data)
    r"headline",
    r"trailText",
    r"bodyText",
    r"web.*",
    r"api.*",
    r"wordcount",
    # Person string
    r"^person$",
    # Leakage
    r"^fg_value$",
    # Lag variables are removed
    r"^lag_\d+$",
    r"^fg_lag_\d+$",
]

# 패턴 매칭 함수
def match_any(col, patterns):
    """Return True when ``col`` matches at least one regex in ``patterns``.

    Matching uses ``re.search``, so an unanchored pattern matches anywhere
    in ``col``. An empty ``patterns`` iterable yields False.
    """
    for pattern in patterns:
        if re.search(pattern, col):
            return True
    return False

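4. SARIMAX용 preprocessing:
   - Drop rule 적용 (numeric 컬럼만 feature로 유지)
   - Daily aggregation (날짜별 평균)
   - Endog/Exog 분리
   ※ Embedding 확장은 아래 함수에서 수행하지 않음 (입력 parquet에 이미 numeric 컬럼으로 펼쳐져 있다고 가정)
   ※ Scaling은 train/test leakage 방지를 위해 split 이후 run_sarimax_for_file에서 수행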

In [10]:
def preprocess_for_sarimax(df: pd.DataFrame):
    """Reduce an article-level frame to one row per day and split endog/exog.

    Feature columns are the numeric columns of ``df`` that are not the
    ``Date`` column and do not match any regex in ``DROP_PATTERNS``. The
    features and the ``value`` target are then averaged per ``Date``.

    Args:
        df: Frame containing a ``Date`` column and a ``value`` column.

    Returns:
        Tuple ``(y, X)``: ``y`` is the daily mean of ``value`` indexed by
        date in ascending order; ``X`` is the daily mean of the selected
        feature columns on the same index, or ``None`` when no feature
        column survives the filter.

    Raises:
        ValueError: If ``df`` has no ``Date`` column.
    """
    date_col = 'Date'
    if date_col not in df.columns:
        raise ValueError("'Date' column not found in dataframe")

    # Keep numeric columns that are neither the group key nor on the drop list.
    valid_cols = [
        name
        for name in df.columns
        if name != date_col
        and not match_any(name, DROP_PATTERNS)
        and pd.api.types.is_numeric_dtype(df[name])
    ]

    # One row per day: mean over all rows sharing a date, target included.
    daily_df = df.groupby(date_col)[valid_cols + ["value"]].mean().sort_index()

    y = daily_df["value"]
    if not valid_cols:
        return y, None
    return y, daily_df[valid_cols]

5. 파일명 파싱

In [11]:
# dataset_A.parquet
# dataset_B_headlines_orig.parquet
# dataset_D_paragraphs_pca.parquet

def parse_filename(fname):
    """Split a dataset file name into ``(Dataset, Method, Type)``.

    ``"dataset_"`` and ``".parquet"`` are stripped wherever they occur and
    the remainder is split on underscores:

    * one token        -> ``(token, "none", "none")``
    * three or more    -> first token, middle tokens re-joined with ``_``,
      last token
    * exactly two      -> ``(None, None, None)`` (not a recognised layout)
    """
    stem = fname.replace("dataset_", "").replace(".parquet", "")
    tokens = stem.split("_")
    count = len(tokens)

    if count == 1:
        return tokens[0], "none", "none"
    if count == 2:
        return None, None, None

    feature_set, *middle, reduction = tokens
    return feature_set, "_".join(middle), reduction

6. SARIMAX Single Run 함수

In [12]:
def run_sarimax_for_file(fname: str, metrics: list):
    """Fit and evaluate one SARIMAX model on a single feature parquet file.

    The file is loaded from ``DATA_DIR``, aggregated to one row per day and
    split into a train period (dates up to 2019-06-30) and a test period
    (2019-07-01 through 2019-12-31). Exogenous features are standardized
    with statistics from the train period only, a
    SARIMAX(1,1,1)x(1,1,1,5) model is fitted, and the test-period forecast
    is scored with MSE. Predictions are written to
    ``RESULTS_DIR / "pred_SARIMAX_<dataset>.json"``.

    Files whose parsed type is ``"orig"`` are skipped, and a file with an
    empty train or test period is reported and skipped. Any exception
    raised while fitting, forecasting, scoring or saving is caught and
    printed, so one bad file does not stop a batch run.

    Args:
        fname: File name (not a full path) of a parquet file in ``DATA_DIR``.
        metrics: List mutated in place; receives one result dict
            (``Feature_set``, ``Embeddings``, ``Dim_reduction``, ``Model``,
            ``MSE``) once the MSE has been computed.

    Returns:
        None.
    """
    feature_set, method, reduction = parse_filename(fname)
    print(f"\n>>> Processing {fname} | Dataset={feature_set}, Method={method}, Type={reduction}")

    # ---- Un-reduced ("orig") embedding files are skipped on purpose ----
    if reduction == "orig":
        print("   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).")
        return

    # ---- 1) Load ----
    file_path = DATA_DIR / fname
    df = load_data(file_path)

    # ---- 2) Preprocess: one row per day, endog/exog split ----
    y, X = preprocess_for_sarimax(df)

    print(f"   >>> Daily rows: {len(y)}")
    if X is not None:
        print(f"   >>> Exog Feature Count: {X.shape[1]}")

    # ---- 3) Train/test split by calendar date ----
    train_mask = (y.index <= "2019-06-30")
    test_mask = (y.index >= "2019-07-01") & (y.index <= "2019-12-31")

    y_train = y[train_mask]
    y_test = y[test_mask]

    if X is not None:
        X_train = X[train_mask]
        X_test = X[test_mask]
    else:
        X_train = X_test = None

    print(f"   >>> Train={len(y_train)}, Test={len(y_test)}")

    if len(y_train) == 0 or len(y_test) == 0:
        print("   [Error] Empty split.")
        return

    # ---- 4) Scaling: fitted on train only so no test statistics leak ----
    if X_train is not None:
        scaler = StandardScaler()
        scaler.fit(X_train)
        X_train_scaled = pd.DataFrame(scaler.transform(X_train), index=X_train.index, columns=X_train.columns)
        X_test_scaled = pd.DataFrame(scaler.transform(X_test), index=X_test.index, columns=X_test.columns)
    else:
        X_train_scaled = X_test_scaled = None

    # ---- 5) Fit the SARIMAX model ----
    print("   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))")
    try:
        model = SARIMAX(
            endog=y_train,
            exog=X_train_scaled,
            order=(1, 1, 1),
            seasonal_order=(1, 1, 1, 5),
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        fit_model = model.fit(disp=False, maxiter=300)

        # ---- 6) Forecast the whole test period in one shot ----
        forecast = fit_model.get_forecast(steps=len(y_test), exog=X_test_scaled)
        pred = forecast.predicted_mean

        # ---- 7) Metrics ----
        mse = mean_squared_error(y_test, pred)
        print(f"   [Result] Test MSE = {mse:.6f}")

        metrics.append({
            "Feature_set": feature_set,
            "Embeddings": method or "-",
            "Dim_reduction": reduction or "-",
            "Model": "SARIMAX",
            "MSE": mse,
        })

        # ---- 8) Save predictions as JSON ----
        # Values are paired with the test dates by position, not by index
        # label. float() is required because json cannot serialize NumPy
        # scalars other than float64 (e.g. float32 or int64), which would
        # otherwise fail here after the metric was already recorded.
        pred_json = {
            date.strftime("%Y-%m-%d"): {"actual": float(actual), "pred": float(p)}
            for date, actual, p in zip(y_test.index, y_test.values, pred.values)
        }

        # Dataset "A" keeps its bare name; all others carry method and type.
        if feature_set == "A":
            dataset_name = "A"
        else:
            dataset_name = f"{feature_set}_{method or 'none'}_{reduction or 'none'}"

        json_path = RESULTS_DIR / f"pred_SARIMAX_{dataset_name}.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(pred_json, f, indent=4)

    except Exception as e:
        # Deliberate best-effort boundary: report and move on to the next file.
        print(f"   [Error] Model failed: {e}")

    gc.collect()

7. Main Loop — 모든 dataset 처리

In [13]:
def main():
    """Run SARIMAX over every ``.parquet`` file in ``DATA_DIR``.

    Files are processed in sorted name order. A failure in one file is
    printed and does not stop the run. When at least one metric row was
    produced, the rows are sorted and written to
    ``OUTPUT_DIR / "sarimax_evaluation_metrics.csv"``.
    """
    if not DATA_DIR.exists():
        print("[Error] Feature directory not found.")
        return

    files = sorted(name for name in os.listdir(DATA_DIR) if name.endswith(".parquet"))
    print(f"Found {len(files)} files.")

    metrics = []
    for parquet_name in files:
        try:
            run_sarimax_for_file(parquet_name, metrics)
        except Exception as e:
            # Keep going: one unreadable or malformed file must not end the batch.
            print(f"[Error] {parquet_name}: {e}")

    if not metrics:
        print("No metrics produced.")
        return

    # ---- Save the final metrics table ----
    sort_keys = ["Feature_set", "Embeddings", "Dim_reduction"]
    df_metric = pd.DataFrame(metrics).sort_values(sort_keys)

    csv_path = OUTPUT_DIR / "sarimax_evaluation_metrics.csv"
    df_metric.to_csv(csv_path, index=False)

    print("\n[Done] SARIMAX completed.")
    print("Metrics saved to:", csv_path)
    print(df_metric)


if __name__ == "__main__":
    main()

Found 25 files.

>>> Processing dataset_A.parquet | Dataset=A, Method=none, Type=none
Loading dataset_A.parquet...
   Loaded 754 rows, 9 columns
   Date range: 2017-01-03 to 2019-12-31
   >>> Daily rows: 754
   >>> Train=626, Test=128
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)
  return get_prediction_index(
  return get_prediction_index(


   [Result] Test MSE = 5241.398086

>>> Processing dataset_B_bodyText_orig.parquet | Dataset=B, Method=bodyText, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_bodyText_pca.parquet | Dataset=B, Method=bodyText, Type=pca
Loading dataset_B_bodyText_pca.parquet...
   Loaded 460722 rows, 204 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 191
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 5059.568327

>>> Processing dataset_B_chunking_orig.parquet | Dataset=B, Method=chunking, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_chunking_pca.parquet | Dataset=B, Method=chunking, Type=pca
Loading dataset_B_chunking_pca.parquet...
   Loaded 460722 rows, 257 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 244
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 5709.054234

>>> Processing dataset_B_headlines_orig.parquet | Dataset=B, Method=headlines, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_headlines_pca.parquet | Dataset=B, Method=headlines, Type=pca
Loading dataset_B_headlines_pca.parquet...
   Loaded 461270 rows, 317 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 304
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 4698.667975

>>> Processing dataset_B_paragraphs_orig.parquet | Dataset=B, Method=paragraphs, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_B_paragraphs_pca.parquet | Dataset=B, Method=paragraphs, Type=pca
Loading dataset_B_paragraphs_pca.parquet...
   Loaded 460722 rows, 221 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 208
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 5428.652941

>>> Processing dataset_C_bodyText_orig.parquet | Dataset=C, Method=bodyText, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_C_bodyText_pca.parquet | Dataset=C, Method=bodyText, Type=pca
Loading dataset_C_bodyText_pca.parquet...
   Loaded 460722 rows, 304 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 291
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 4839.386842

>>> Processing dataset_C_chunking_orig.parquet | Dataset=C, Method=chunking, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_C_chunking_pca.parquet | Dataset=C, Method=chunking, Type=pca
Loading dataset_C_chunking_pca.parquet...
   Loaded 460722 rows, 357 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 344
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 3858.753400

>>> Processing dataset_C_headlines_orig.parquet | Dataset=C, Method=headlines, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_C_headlines_pca.parquet | Dataset=C, Method=headlines, Type=pca
Loading dataset_C_headlines_pca.parquet...
   Loaded 461270 rows, 417 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 404
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 6680.211604

>>> Processing dataset_C_paragraphs_orig.parquet | Dataset=C, Method=paragraphs, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_C_paragraphs_pca.parquet | Dataset=C, Method=paragraphs, Type=pca
Loading dataset_C_paragraphs_pca.parquet...
   Loaded 460722 rows, 321 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 308
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 3994.778622

>>> Processing dataset_D_bodyText_orig.parquet | Dataset=D, Method=bodyText, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_D_bodyText_pca.parquet | Dataset=D, Method=bodyText, Type=pca
Loading dataset_D_bodyText_pca.parquet...
   Loaded 460722 rows, 310 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 291
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 4839.386842

>>> Processing dataset_D_chunking_orig.parquet | Dataset=D, Method=chunking, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_D_chunking_pca.parquet | Dataset=D, Method=chunking, Type=pca
Loading dataset_D_chunking_pca.parquet...
   Loaded 460722 rows, 363 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 344
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 3858.753400

>>> Processing dataset_D_headlines_orig.parquet | Dataset=D, Method=headlines, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_D_headlines_pca.parquet | Dataset=D, Method=headlines, Type=pca
Loading dataset_D_headlines_pca.parquet...
   Loaded 461270 rows, 423 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 404
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 6680.211604

>>> Processing dataset_D_paragraphs_orig.parquet | Dataset=D, Method=paragraphs, Type=orig
   [Skip] ORIG dataset skipped (1024-d embedding too large for SARIMAX).

>>> Processing dataset_D_paragraphs_pca.parquet | Dataset=D, Method=paragraphs, Type=pca
Loading dataset_D_paragraphs_pca.parquet...
   Loaded 460722 rows, 327 columns
   Date range: 2017-01-01 to 2019-12-31
   >>> Daily rows: 1095
   >>> Exog Feature Count: 308
   >>> Train=911, Test=184
   >>> Fitting SARIMAX (order=(1,1,1), seasonal_order=(1,1,1,5))


  self._init_dates(dates, freq)
  self._init_dates(dates, freq)


   [Result] Test MSE = 3994.778622

[Done] SARIMAX completed.
Metrics saved to: results_sarimax/sarimax_evaluation_metrics.csv
   Feature_set  Embeddings Dim_reduction    Model          MSE
0            A        none          none  SARIMAX  5241.398086
1            B    bodyText           pca  SARIMAX  5059.568327
2            B    chunking           pca  SARIMAX  5709.054234
3            B   headlines           pca  SARIMAX  4698.667975
4            B  paragraphs           pca  SARIMAX  5428.652941
5            C    bodyText           pca  SARIMAX  4839.386842
6            C    chunking           pca  SARIMAX  3858.753400
7            C   headlines           pca  SARIMAX  6680.211604
8            C  paragraphs           pca  SARIMAX  3994.778622
9            D    bodyText           pca  SARIMAX  4839.386842
10           D    chunking           pca  SARIMAX  3858.753400
11           D   headlines           pca  SARIMAX  6680.211604
12           D  paragraphs           pca  SARIMAX  399