In [None]:
# notebooks/04_model_training.ipynb

import pandas as pd
import sys
import os
import torch
import pytorch_lightning as pl

# Put the parent of the current working directory on sys.path (only once)
# so the project-local imports that follow (`config`, `src.model`) resolve.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

import config
from src.model import create_tft_dataset, train_tft_model, save_model
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# --- 1. Load Processed Data ---
# The CSV is read with its 'time' column as the index; parse_dates=True asks
# pandas to parse that index as dates.
try:
    df = pd.read_csv(config.PROCESSED_DATA_FILE, index_col='time', parse_dates=True)
except FileNotFoundError:
    # Name the missing path so the failure is actionable, then re-raise so the
    # later steps never run without data.
    print(
        f"Error: Processed data file not found at {config.PROCESSED_DATA_FILE}. "
        "Please run notebook 02 first."
    )
    raise
else:
    print(f"Processed data loaded successfully from {config.PROCESSED_DATA_FILE}")

# --- 2. Create the TimeSeriesDataSet ---
# This can take some time. create_tft_dataset returns a single object here
# (the dataset); per the original note, it no longer returns a scaler.
print("Creating TimeSeriesDataSet with advanced features...")
tft_dataset = create_tft_dataset(df)
print("DataSet created successfully.")
print(f"Number of training samples: {len(tft_dataset)}")

# --- 3. Train the Model ---
# This is computationally intensive. Use Google Colab with GPU for best results.
# train_tft_model (defined in src/model.py) returns the fitted model together
# with the trainer that produced it.
print("\nStarting model training...")
tft_model, trainer = train_tft_model(tft_dataset)
print("Training complete.")

# --- 4. Save the Trained Model ---
# save_model receives the trainer (not the model object); per the original
# note, it locates the best checkpoint from the trainer — see src/model.py.
save_model(trainer, config.MODEL_FILE)