# TFT (Temporal Fusion Transformer) Training

25개 dataset variants에 대해 TFT 모델을 학습하고 MSE를 평가합니다.

## 필요 패키지
```bash
pip install pytorch-forecasting lightning pandas numpy scikit-learn pyarrow torch
```

In [1]:
# ==================== IMPORTS ====================
import warnings
warnings.filterwarnings('ignore')


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import json
from datetime import datetime
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
import gc

# PyTorch Forecasting
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss, MAE, RMSE
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu128
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4090


In [2]:
# ==================== CONFIGURATION ====================
CONFIG = {
    # Paths
    'data_dir': Path('../feature_datasets'),
    'output_file': Path('./output_tft.csv'),
    'log_dir': Path('./tft_logs'),
    'checkpoint_dir': Path('./tft_checkpoints'),
    
    # Data settings
    'remove_columns': ['person', 'article_id', 'pub_date', 'article_date', 'idx', 'date_str', 'person_id', 'headline'],
    
    # Time series settings
    'time_idx': 'date_index',
    'target': 'value',
    'max_encoder_length': 30,
    'max_prediction_length': 5,
    
    # Model hyperparameters
    'hidden_size': 128,
    'lstm_layers': 2,
    'attention_head_size': 4,
    'dropout': 0.1,
    'hidden_continuous_size': 32,
    
    # Training settings
    'batch_size': 32,
    'learning_rate': 1e-3,
    'max_epochs': 50,
    'gradient_clip_val': 0.1,
    'patience': 10,
    
    # Loss & metrics
    'quantiles': [0.1, 0.5, 0.9],
}

# Create output directories
CONFIG['log_dir'].mkdir(exist_ok=True, parents=True)
CONFIG['checkpoint_dir'].mkdir(exist_ok=True, parents=True)

print("✓ Configuration loaded!")

✓ Configuration loaded!


In [3]:
# ==================== HELPER FUNCTIONS ====================

def get_dataset_filename(variant: Dict) -> str:
    if variant['type'] == 'A':
        return 'dataset_A.parquet'
    else:
        emb = variant['embedding']
        pca = variant['pca']
        return f"dataset_{variant['type']}_{emb}_{pca}.parquet"

def get_variant_name(variant: Dict) -> str:
    if variant['type'] == 'A':
        return 'A_sp500_only'
    else:
        return f"{variant['type']}_{variant['embedding']}_{variant['pca']}"

def calculate_mse(actual: np.ndarray, predicted: np.ndarray) -> float:
    return np.mean((actual - predicted) ** 2)

def load_and_preprocess_data(filepath: Path, variant: Dict, config: Dict) -> pd.DataFrame:
    print(f"  Loading: {filepath.name}")
    df = pd.read_parquet(filepath)
    print(f"  Original shape: {df.shape}")
    
    # Remove unnecessary columns
    cols_to_remove = [col for col in config['remove_columns'] if col in df.columns]
    df = df.drop(columns=cols_to_remove)
    print(f"  After removing columns: {df.shape}")
    
    # Handle embedding column
    if 'embedding' in df.columns:
        print(f"  Expanding embedding column...")
        emb_df = pd.DataFrame(df['embedding'].tolist(), 
                            index=df.index,
                            columns=[f'emb_{i}' for i in range(len(df['embedding'].iloc[0]))])
        df = pd.concat([df.drop('embedding', axis=1), emb_df], axis=1)
        print(f"  Expanded to {emb_df.shape[1]} columns")
    
    # Ensure date_index is integer
    if config['time_idx'] in df.columns:
        df[config['time_idx']] = df[config['time_idx']].astype(int)
    
    # AGGREGATION BY DATE (for B, C, D)
    if variant['type'] != 'A':
        print(f"  Aggregating by date...")
        agg_dict = {}
        for col in df.columns:
            if col == config['time_idx']:
                continue
            elif pd.api.types.is_numeric_dtype(df[col]):
                agg_dict[col] = 'mean'
            else:
                agg_dict[col] = 'first'
        df = df.groupby(config['time_idx']).agg(agg_dict).reset_index()
        print(f"  After aggregation: {len(df)} rows")
    
    # Sort by date_index
    df = df.sort_values(config['time_idx']).reset_index(drop=True)
    print(f"  Final shape: {df.shape}")
    
    return df

def split_data_by_date(df: pd.DataFrame, config: Dict) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    unique_dates = sorted(df[config['time_idx']].unique())
    total_days = len(unique_dates)
    print(f"  Total unique dates: {total_days}")
    
    # 66.7% / 16.7% / 16.7%
    train_end_idx = int(total_days * 0.667)
    valid_end_idx = int(total_days * 0.833)
    
    train_dates = unique_dates[:train_end_idx]
    valid_dates = unique_dates[train_end_idx:valid_end_idx]
    test_dates = unique_dates[valid_end_idx:]
    
    train_df = df[df[config['time_idx']].isin(train_dates)].copy()
    valid_df = df[df[config['time_idx']].isin(valid_dates)].copy()
    test_df = df[df[config['time_idx']].isin(test_dates)].copy()
    
    print(f"  Train: {len(train_df)} rows, {len(train_dates)} dates")
    print(f"  Valid: {len(valid_df)} rows, {len(valid_dates)} dates")
    print(f"  Test: {len(test_df)} rows, {len(test_dates)} dates")
    
    return train_df, valid_df, test_df

def create_time_series_dataset(
    df: pd.DataFrame,
    config: Dict,
    variant: Dict,
) -> TimeSeriesDataSet:
    # NOTE: series_id must be already added to df!
    
    # Identify features
    time_varying_unknown = []
    static_reals = []
    
    exclude_cols = {config['time_idx'], config['target'], 'series_id'}
    
    for col in df.columns:
        if col in exclude_cols:
            continue
        elif col.startswith('person_'):
            static_reals.append(col)
        elif pd.api.types.is_numeric_dtype(df[col]):
            time_varying_unknown.append(col)
    
    time_varying_unknown = sorted(list(set(time_varying_unknown)))
    static_reals = sorted(list(set(static_reals)))
    
    print(f"  Time-varying unknown: {len(time_varying_unknown)}")
    print(f"  Static reals: {len(static_reals)}")
    
    # Create dataset
    dataset = TimeSeriesDataSet(
        df,
        time_idx=config['time_idx'],
        target=config['target'],
        group_ids=['series_id'],
        max_encoder_length=config['max_encoder_length'],
        max_prediction_length=config['max_prediction_length'],
        static_reals=static_reals if static_reals else None,
        time_varying_unknown_reals=time_varying_unknown if time_varying_unknown else None,
        target_normalizer=GroupNormalizer(
            groups=['series_id'], 
            transformation="softplus"
        ),
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
        allow_missing_timesteps=False,
    )
    
    return dataset

def train_tft_model(
    train_dataset: TimeSeriesDataSet,
    valid_dataset: TimeSeriesDataSet,
    config: Dict,
    variant: Dict,
) -> TemporalFusionTransformer:
    # Create dataloaders
    train_dataloader = train_dataset.to_dataloader(
        train=True, 
        batch_size=config['batch_size'],
        num_workers=0
    )
    valid_dataloader = valid_dataset.to_dataloader(
        train=False,
        batch_size=config['batch_size'] * 2,
        num_workers=0
    )
    
    # Define model
    tft = TemporalFusionTransformer.from_dataset(
        train_dataset,
        learning_rate=config['learning_rate'],
        hidden_size=config['hidden_size'],
        attention_head_size=config['attention_head_size'],
        dropout=config['dropout'],
        hidden_continuous_size=config['hidden_continuous_size'],
        lstm_layers=config['lstm_layers'],
        loss=QuantileLoss(quantiles=config['quantiles']),
        reduce_on_plateau_patience=4,
    )
    
    print(f"  Model parameters: {sum(p.numel() for p in tft.parameters()):,}")
    
    # Setup callbacks
    variant_name = get_variant_name(variant)
    
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=1e-4,
        patience=config['patience'],
        verbose=False,
        mode="min"
    )
    
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=config['checkpoint_dir'] / variant_name,
        filename='best',
        save_top_k=1,
        mode='min',
    )
    
    # Create trainer
    trainer = pl.Trainer(
        max_epochs=config['max_epochs'],
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        gradient_clip_val=config['gradient_clip_val'],
        callbacks=[early_stop_callback, checkpoint_callback],
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    
    # Train
    print("  Training...")
    trainer.fit(
        tft,
        train_dataloaders=train_dataloader,
        val_dataloaders=valid_dataloader,
    )
    
    # Load best model
    best_model_path = checkpoint_callback.best_model_path
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
    
    return best_tft

def predict_and_evaluate(
    model: TemporalFusionTransformer,
    dataset: TimeSeriesDataSet,
    config: Dict,
) -> float:
    dataloader = dataset.to_dataloader(
        train=False,
        batch_size=config['batch_size'] * 4,
        num_workers=0
    )
    
    # Get predictions
    predictions, x = model.predict(dataloader, mode="prediction", return_x=True, return_index=True)
    
    # Extract median (index 1)
    pred_raw = predictions[:, :, 1].cpu().numpy()
    actual_raw = x["decoder_target"].cpu().numpy()
    
    # Flatten
    actual_flat = actual_raw.flatten()
    pred_flat = pred_raw.flatten()
    
    # Remove NaN
    valid_mask = ~(np.isnan(actual_flat) | np.isnan(pred_flat))
    actual_clean = actual_flat[valid_mask]
    pred_clean = pred_flat[valid_mask]
    
    # Calculate MSE
    mse = calculate_mse(actual_clean, pred_clean)
    
    print(f"  Valid predictions: {len(actual_clean)}")
    print(f"  Test MSE: {mse:.8f}")
    
    return mse

print("✓ All functions loaded!")

✓ All functions loaded!


In [4]:
# ==================== DATASET VARIANTS ====================

DATASET_VARIANTS = [
    # A: SP500만 (1가지)
    {'type': 'A', 'embedding': None, 'pca': None},
    
    # B: SP500 + embeddings (8가지)
    {'type': 'B', 'embedding': 'headlines', 'pca': 'orig'},
    {'type': 'B', 'embedding': 'headlines', 'pca': 'pca'},
    {'type': 'B', 'embedding': 'chunking', 'pca': 'orig'},
    {'type': 'B', 'embedding': 'chunking', 'pca': 'pca'},
    {'type': 'B', 'embedding': 'bodyText', 'pca': 'orig'},
    {'type': 'B', 'embedding': 'bodyText', 'pca': 'pca'},
    {'type': 'B', 'embedding': 'paragraphs', 'pca': 'orig'},
    {'type': 'B', 'embedding': 'paragraphs', 'pca': 'pca'},
    
    # C: B + person one-hot (8가지)
    {'type': 'C', 'embedding': 'headlines', 'pca': 'orig'},
    {'type': 'C', 'embedding': 'headlines', 'pca': 'pca'},
    {'type': 'C', 'embedding': 'chunking', 'pca': 'orig'},
    {'type': 'C', 'embedding': 'chunking', 'pca': 'pca'},
    {'type': 'C', 'embedding': 'bodyText', 'pca': 'orig'},
    {'type': 'C', 'embedding': 'bodyText', 'pca': 'pca'},
    {'type': 'C', 'embedding': 'paragraphs', 'pca': 'orig'},
    {'type': 'C', 'embedding': 'paragraphs', 'pca': 'pca'},
    
    # D: C + FG index (8가지)
    {'type': 'D', 'embedding': 'headlines', 'pca': 'orig'},
    {'type': 'D', 'embedding': 'headlines', 'pca': 'pca'},
    {'type': 'D', 'embedding': 'chunking', 'pca': 'orig'},
    {'type': 'D', 'embedding': 'chunking', 'pca': 'pca'},
    {'type': 'D', 'embedding': 'bodyText', 'pca': 'orig'},
    {'type': 'D', 'embedding': 'bodyText', 'pca': 'pca'},
    {'type': 'D', 'embedding': 'paragraphs', 'pca': 'orig'},
    {'type': 'D', 'embedding': 'paragraphs', 'pca': 'pca'},
]

print(f"Total variants: {len(DATASET_VARIANTS)}")

Total variants: 25


## 전체 학습 실행

In [None]:
# ==================== MAIN TRAINING LOOP ====================

results = []

print("=" * 80)
print(f"Starting TFT training for {len(DATASET_VARIANTS)} variants")
print("=" * 80)

for i, variant in enumerate(DATASET_VARIANTS, 1):
    print(f"\n{'='*80}")
    print(f"Variant {i}/{len(DATASET_VARIANTS)}: {get_variant_name(variant)}")
    print(f"{'='*80}")
    
    try:
        # 1. Load data
        filename = get_dataset_filename(variant)
        filepath = CONFIG['data_dir'] / filename
        
        if not filepath.exists():
            print(f"  WARNING: File not found: {filepath}")
            continue
        
        df = load_and_preprocess_data(filepath, variant, CONFIG)
        
        # 2. Split data
        print("\n  Splitting data...")
        train_df, valid_df, test_df = split_data_by_date(df, CONFIG)
        
        # 3. Add series_id (CRITICAL!)
        train_df['series_id'] = 0
        valid_df['series_id'] = 0
        test_df['series_id'] = 0
        
        # 4. Create datasets
        print("\n  Creating TimeSeriesDataSet...")
        train_dataset = create_time_series_dataset(train_df, CONFIG, variant)
        valid_dataset = TimeSeriesDataSet.from_dataset(train_dataset, valid_df, predict=True, stop_randomization=True)
        test_dataset = TimeSeriesDataSet.from_dataset(train_dataset, test_df, predict=True, stop_randomization=True)
        
        # 5. Train model
        print("\n  Training TFT model...")
        model = train_tft_model(train_dataset, valid_dataset, CONFIG, variant)
        
        # 6. Evaluate
        print("\n  Evaluating...")
        test_mse = predict_and_evaluate(model, test_dataset, CONFIG)
        
        # 7. Save result
        result = {
            'Dataset': variant['type'],
            'Method': variant.get('embedding', '-'),
            'Type': variant.get('pca', '-'),
            'Model': 'TFT',
            'MSE': test_mse
        }
        results.append(result)
        
        print(f"\n  ✓ Completed: MSE={test_mse:.8f}")
        
        # Clean up
        del df, train_df, valid_df, test_df
        del train_dataset, valid_dataset, test_dataset, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"  ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print("\n" + "=" * 80)
print("Training completed!")
print("=" * 80)

Starting TFT training for 25 variants

Variant 1/25: A_sp500_only
  Loading: dataset_A.parquet
  Original shape: (754, 8)
  After removing columns: (754, 7)
  Final shape: (754, 7)

  Splitting data...
  Total unique dates: 754
  Train: 502 rows, 502 dates
  Valid: 126 rows, 126 dates
  Test: 126 rows, 126 dates

  Creating TimeSeriesDataSet...
  Time-varying unknown: 5
  Static reals: 0

  Training TFT model...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


  Model parameters: 1,199,067
  Training...
Epoch 15: 100%|██████████| 14/14 [00:01<00:00,  8.48it/s, v_num=0, train_loss_step=17.70, val_loss=66.40, train_loss_epoch=19.20]
  ERROR: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL pytorch_forecasting.data.encoders.GroupNormalizer was not an allowed global by default. Please use `torch.serialization.add_safe_globals([pytorch_forecasting.data.encoders.GroupNormal

Traceback (most recent call last):
  File "/tmp/ipykernel_1048944/2841510984.py", line 42, in <module>
    model = train_tft_model(train_dataset, valid_dataset, CONFIG, variant)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1048944/3957883494.py", line 208, in train_tft_model
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wonjun/.conda/envs/mlproject/lib/python3.11/site-packages/lightning/pytorch/utilities/model_helpers.py", line 130, in wrapper
    return self.method(cls_type, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wonjun/.conda/envs/mlproject/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1781, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wonjun/.conda/envs/mlproject/lib/python3.11/s

  Original shape: (461270, 16)
  After removing columns: (461270, 8)
  Expanding embedding column...
  Expanded to 1024 columns
  Aggregating by date...
  After aggregation: 754 rows
  Final shape: (754, 1031)

  Splitting data...
  Total unique dates: 754
  Train: 502 rows, 502 dates
  Valid: 126 rows, 126 dates
  Test: 126 rows, 126 dates

  Creating TimeSeriesDataSet...
  Time-varying unknown: 1029
  Static reals: 0

  Training TFT model...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


  Model parameters: 17,254,115
  Training...
Epoch 2:   7%|▋         | 1/14 [00:02<00:30,  0.42it/s, v_num=1, train_loss_step=28.20, val_loss=86.60, train_loss_epoch=55.10] 

In [None]:
# ==================== SAVE RESULTS ====================

results_df = pd.DataFrame(results)
results_df = results_df.sort_values('MSE').reset_index(drop=True)
results_df.to_csv(CONFIG['output_file'], index=False)

print(f"\n✓ Results saved to: {CONFIG['output_file']}")
print(f"\nResults:")
print(results_df.to_string())

In [None]:
# ==================== SUMMARY ====================

print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)

print(f"\nTotal variants: {len(results_df)}")
print(f"\nBest Model:")
best = results_df.iloc[0]
print(f"  Dataset: {best['Dataset']}")
print(f"  Method: {best['Method']}")
print(f"  Type: {best['Type']}")
print(f"  MSE: {best['MSE']:.8f}")

print(f"\nTop 5:")
print(results_df.head(5).to_string())