In [1]:
# %pip install onnxruntime

In [2]:
import os
import json
import joblib
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from dotenv import load_dotenv
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import ta
import onnx
import onnxruntime as ort
import warnings

# Suppress ONNX export warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.onnx")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="torch.onnx")

# Check PyTorch version for dynamo compatibility
PYTORCH_VERSION = torch.__version__.split('+')[0]
PYTORCH_MAJOR, PYTORCH_MINOR = map(int, PYTORCH_VERSION.split('.')[:2])
USE_DYNAMO = PYTORCH_MAJOR >= 2 and PYTORCH_MINOR >= 9

# Configuration
@dataclass
class Config:
    """Configuration for the stock prediction pipeline."""
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    context_len: int = 60
    pred_len: int = 5
    features: List[str] = field(default_factory=lambda: ["Open", "High", "Low", "Close", "Volume", "RSI14", "MACD"])
    batch_size: int = 32
    parent_ticker: str = "^GSPC"
    child_tickers: List[str] = field(default_factory=lambda: ["GOOG", "AMZN", "META", "AXP"])
    start_date: str = "2004-08-19"  # Google's IPO date
    parent_epochs: int = 20
    child_epochs: int = 10
    parent_dir: str = "outputs/parent"
    workdir: str = "outputs"

    @property
    def input_size(self) -> int:
        return len(self.features)

# Custom exception
class PipelineError(Exception):
    """Custom exception for pipeline errors."""
    pass

# Initialize directories
def initialize_dirs():
    """Initialize output directories."""
    config = Config()
    try:
        os.makedirs(config.workdir, exist_ok=True)
        os.makedirs(config.parent_dir, exist_ok=True)
        for ticker in config.child_tickers:
            os.makedirs(os.path.join(config.workdir, ticker), exist_ok=True)
    except Exception as e:
        raise PipelineError(f"Failed to create output directories: {e}")

initialize_dirs()
load_dotenv()

def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """Calculate RSI for a given series."""
    delta = series.diff()
    gain = delta.where(delta > 0, 0).rolling(window=period).mean()
    loss = -delta.where(delta < 0, 0).rolling(window=period).mean()
    rs = gain / loss
    return 100 - 100 / (1 + rs)

def macd(series: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
    """Calculate MACD for a given series."""
    ema_fast = series.ewm(span=fast, adjust=False).mean()
    ema_slow = series.ewm(span=slow, adjust=False).mean()
    return ema_fast - ema_slow

def fetch_ohlcv(ticker: str, start: str = Config().start_date, end: Optional[str] = None) -> pd.DataFrame:
    """Fetch OHLCV data with technical indicators."""
    config = Config()
    try:
        df = yf.download(ticker, start=start, end=end, interval="1d", auto_adjust=True, progress=False)
        if df.empty:
            raise PipelineError(f"No data downloaded for {ticker}")
        df = df.reset_index().rename(columns={"Date": "date"})
        df = df[["date", "Open", "High", "Low", "Close", "Volume"]].dropna()
        df["RSI14"] = rsi(df["Close"])
        df["MACD"] = macd(df["Close"])
        df = df[["date"] + config.features].dropna()
        
        # Validate data
        if len(df) < config.context_len + config.pred_len:
            raise PipelineError(f"Insufficient data for {ticker}: {len(df)} rows, need at least {config.context_len + config.pred_len}")
        if df[config.features].isnull().any().any():
            raise PipelineError(f"NaN values found in features for {ticker}")
        if not df[config.features].apply(lambda x: pd.api.types.is_numeric_dtype(x)).all():
            raise PipelineError(f"Non-numeric values found in features for {ticker}")
        print(f"Fetched {len(df)} rows for {ticker}")
        return df
    except Exception as e:
        raise PipelineError(f"Failed to fetch data for {ticker}: {e}")

class StockDataset(Dataset):
    """Dataset for stock price sequences."""
    def __init__(self, df: pd.DataFrame, scaler: StandardScaler, context_len: int = Config().context_len, pred_len: int = Config().pred_len):
        self.context_len = context_len
        self.pred_len = pred_len
        try:
            vals = scaler.transform(df[Config().features]).astype("float32")
            self.samples = []
            for t in range(context_len, len(df) - pred_len):
                past = vals[t - context_len:t]
                fut = vals[t:t + pred_len]
                if past.shape == (context_len, len(Config().features)) and fut.shape == (pred_len, len(Config().features)):
                    self.samples.append((past, fut))
                else:
                    print(f"Skipping invalid sample at index {t}: past shape {past.shape}, fut shape {fut.shape}")
            if not self.samples:
                raise PipelineError("No valid samples created for dataset")
        except Exception as e:
            raise PipelineError(f"Failed to create dataset: {e}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        past, fut = self.samples[idx]
        return torch.tensor(past), torch.tensor(fut)

class LSTMModel(nn.Module):
    """LSTM model for stock price prediction."""
    def __init__(self, input_size: int = Config().input_size, hidden_size: int = 128, num_layers: int = 3, 
                 pred_len: int = Config().pred_len, dropout: float = 0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, input_size * pred_len)
        self.pred_len = pred_len
        self.input_size = input_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)
        out = out[:, -1, :]
        out = self.fc(out)
        return out.view(-1, self.pred_len, self.input_size)

def fit_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 8, lr: float = 1e-3) -> nn.Module:
    """Train the LSTM model with early stopping."""
    model.to(Config().device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=3)
    best_val_loss = float('inf')
    patience, counter = 5, 0

    for ep in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for X, Y in train_loader:
            X, Y = X.to(Config().device), Y.to(Config().device)
            opt.zero_grad()
            pred = model(X)
            loss = criterion(pred, Y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {ep}/{epochs} - Train Loss: {avg_train_loss:.5f}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X, Y in val_loader:
                X, Y = X.to(Config().device), Y.to(Config().device)
                pred = model(X)
                val_loss += criterion(pred, Y).item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {ep}/{epochs} - Val Loss: {avg_val_loss:.5f}")

        scheduler.step(avg_val_loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered")
                break

    return model

def model_to_onnx(model: nn.Module, path: str):
    """Convert PyTorch model to ONNX and create inference session."""
    try:
        model.eval()
        dummy_input = torch.randn(1, Config().context_len, Config().input_size).to(Config().device)
        export_kwargs = {
            "model": model,
            "args": dummy_input,
            "f": path,
            "input_names": ['input'],
            "output_names": ['output'],
            "dynamic_axes": {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
            "opset_version": 12
        }
        if USE_DYNAMO:
            export_kwargs["dynamo"] = True
        torch.onnx.export(**export_kwargs)
        return ort.InferenceSession(path)
    except Exception as e:
        raise PipelineError(f"Failed to export model to ONNX at {path}: {e}")

def save_model(model: nn.Module, scaler: StandardScaler, path: str, model_type: str = "parent", ticker: Optional[str] = None):
    """Save model, scaler, and ONNX model."""
    try:
        os.makedirs(path, exist_ok=True)
        torch_save_path = os.path.join(path, "model.pt")
        scaler_filename = "parent_scaler.pkl" if model_type == "parent" else f"{ticker}_child_scaler.pkl"
        scaler_path = os.path.join(path, scaler_filename)
        onnx_path = os.path.join(path, "model.onnx")

        torch.save(model.state_dict(), torch_save_path)
        joblib.dump(scaler, scaler_path)
        model_to_onnx(model, onnx_path)
        print(f"Model and scaler saved locally at {path}")
    except Exception as e:
        raise PipelineError(f"Failed to save model at {path}: {e}")

def load_model(path: str, model_type: str = "parent", ticker: Optional[str] = None) -> Tuple[ort.InferenceSession, StandardScaler]:
    """Load ONNX model and scaler from path."""
    if model_type == "child" and not ticker:
        raise ValueError("Ticker must be provided for child model")
    
    scaler_filename = "parent_scaler.pkl" if model_type == "parent" else f"{ticker}_child_scaler.pkl"
    onnx_path = os.path.join(path, "model.onnx")
    scaler_path = os.path.join(path, scaler_filename)
    
    if os.path.exists(onnx_path) and os.path.exists(scaler_path):
        return ort.InferenceSession(onnx_path), joblib.load(scaler_path)
    raise FileNotFoundError(f"ONNX model or scaler not found at {path}")

def predict_one_step_and_week(session: ort.InferenceSession, df: pd.DataFrame, scaler: StandardScaler, ticker: str) -> Dict:
    """Predict next day and week OHLCV values."""
    try:
        config = Config()
        vals = scaler.transform(df[config.features]).astype("float32")
        if vals.shape[0] < config.context_len:
            raise PipelineError(f"Insufficient data: {vals.shape[0]} rows, need at least {config.context_len}")
        X = vals[-config.context_len:].reshape(1, config.context_len, config.input_size)
        print(f"Input shape for {ticker}: {X.shape}")  # Debug

        pred = session.run(None, {'input': X})[0]
        print(f"Prediction shape for {ticker}: {pred.shape}")  # Debug
        if pred.shape != (1, config.pred_len, config.input_size):
            raise PipelineError(f"Unexpected prediction shape for {ticker}: {pred.shape}")

        pred_full = scaler.inverse_transform(pred.reshape(-1, config.input_size))
        print(f"Inverse transformed shape for {ticker}: {pred_full.shape}")  # Debug
        pred_full = pred_full.reshape(config.pred_len, config.input_size)
        pred = pred_full[:, :5]  # OHLCV only

        # Validate predictions
        if np.any(np.isnan(pred)):
            raise PipelineError(f"NaN values in predictions for {ticker}")

        last_date = pd.to_datetime(df["date"].iloc[-1])
        next_business_days = pd.bdate_range(last_date + pd.Timedelta(days=1), periods=config.pred_len)
        next_business_days_str = [str(d.date()) for d in next_business_days]

        full_forecast = []
        for i, d in enumerate(next_business_days):
            forecast_entry = {
                "date": str(d.date()),
                "open": round(float(pred[i, 0]), 2),
                "high": round(float(pred[i, 1]), 2),
                "low": round(float(pred[i, 2]), 2),
                "close": round(float(pred[i, 3]), 2),
                "volume": int(pred[i, 4])
            }
            print(f"Forecast entry {i} for {ticker}: {forecast_entry}")  # Debug
            full_forecast.append(forecast_entry)

        output = {
            "ticker": ticker,
            "last_date": str(last_date.date()),
            "future_window_days": config.pred_len,
            "next_business_days": next_business_days_str,
            "predictions": {
                "next_day": {
                    "open": round(float(pred[0, 0]), 2),
                    "high": round(float(pred[0, 1]), 2),
                    "low": round(float(pred[0, 2]), 2),
                    "close": round(float(pred[0, 3]), 2),
                    "volume": int(pred[0, 4])
                },
                "next_week": {
                    "high": round(float(np.max(pred[:, 1])), 2),
                    "low": round(float(np.min(pred[:, 2])), 2)
                },
                "full_forecast": full_forecast
            }
        }
        print(f"Prediction output for {ticker}: {json.dumps(output, indent=2)}")  # Debug
        return output
    except Exception as e:
        print(f"Prediction failed for {ticker}: {e}")
        return {"ticker": ticker, "error": str(e)}

def evaluate_model(session: ort.InferenceSession, df: pd.DataFrame, scaler: StandardScaler, out_dir: str, ticker: str) -> Dict:
    """Evaluate model performance."""
    try:
        os.makedirs(out_dir, exist_ok=True)
        config = Config()
        vals = scaler.transform(df[config.features]).astype("float32")
        X, Y = [], []
        for t in range(config.context_len, len(vals) - config.pred_len):
            past = vals[t - config.context_len:t]
            fut = vals[t:t + config.pred_len]
            if past.shape == (config.context_len, config.input_size) and fut.shape == (config.pred_len, config.input_size):
                X.append(past)
                Y.append(fut)
            else:
                print(f"Skipping invalid evaluation sample at index {t}: past shape {past.shape}, fut shape {fut.shape}")

        if not X:
            print(f"No valid samples for evaluation for {ticker}")
            return {}

        X, Y = np.array(X), np.array(Y)
        preds = [session.run(None, {'input': x.reshape(1, config.context_len, config.input_size)})[0] for x in X]
        preds = np.array(preds)
        Y_ohlcv = Y.reshape(-1, config.input_size)[:, :5]
        preds_ohlcv = preds.reshape(-1, config.input_size)[:, :5]

        mse = mean_squared_error(Y_ohlcv, preds_ohlcv)
        rmse = np.sqrt(mse)
        r2 = r2_score(Y_ohlcv, preds_ohlcv)

        metrics = {"MSE": mse, "RMSE": rmse, "R2": r2}
        metrics_filename = f"{ticker}_parent_metrics.json" if "parent" in out_dir else f"{ticker}_child_metrics.json"
        metrics_path = os.path.join(out_dir, metrics_filename)
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)
        print(f"{ticker} → MSE: {mse:.5f}, RMSE: {rmse:.5f}, R²: {r2:.5f}")

        return metrics
    except Exception as e:
        print(f"Evaluation failed for {ticker}: {e}")
        return {}

def save_json(payload: Dict, path: str) -> str:
    """Save dictionary to JSON file."""
    try:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            json.dump(payload, f, indent=2)
        return path
    except Exception as e:
        raise PipelineError(f"Failed to save JSON at {path}: {e}")

def plot_outputs(df: pd.DataFrame, payload: Dict, out_dir: str, ticker: str):
    """Plot historical and forecasted prices."""
    try:
        if "error" in payload:
            raise PipelineError(f"Cannot plot for {ticker}: prediction failed with error {payload['error']}")

        os.makedirs(out_dir, exist_ok=True)
        plt.figure(figsize=(12, 5))
        plt.plot(df["date"], df["Close"], label="History")
        
        ndo = payload["predictions"]["next_day"]["open"]
        ndc = payload["predictions"]["next_day"]["close"]
        ndh = payload["predictions"]["next_day"]["high"]
        ndl = payload["predictions"]["next_day"]["low"]
        whi = payload["predictions"]["next_week"]["high"]
        wlo = payload["predictions"]["next_week"]["low"]
        
        # Validate forecast_closes
        forecast_closes = []
        for p in payload["predictions"]["full_forecast"]:
            close = p.get("close")
            if not isinstance(close, (int, float)) or pd.isna(close):
                raise PipelineError(f"Invalid close value in full_forecast for {ticker}: {close}")
            forecast_closes.append(float(close))
        print(f"Forecast closes for {ticker}: {forecast_closes}")
        
        # Ensure last historical close is a scalar
        last_close = df["Close"].iloc[-1]
        if isinstance(last_close, (np.ndarray, list, pd.Series)):
            last_close = float(last_close.item() if isinstance(last_close, np.ndarray) else last_close[0])
        elif not isinstance(last_close, (int, float)):
            raise PipelineError(f"Invalid last historical close for {ticker}: {last_close}")
        print(f"Last historical close for {ticker}: {last_close}")
        
        # Validate date arrays
        last_date = pd.to_datetime(payload["last_date"])
        next_dates = [pd.to_datetime(d) for d in payload["next_business_days"]]
        y_values = [last_close] + forecast_closes
        print(f"Y-values for plotting {ticker}: {y_values}")
        
        plt.axhline(ndo, color="orange", linestyle="-", alpha=0.7, label="Next-day open")
        plt.axhline(ndc, color="r", linestyle="--", label="Next-day close")
        plt.axhline(ndh, color="darkgreen", linestyle="-", alpha=0.7, label="Next-day high")
        plt.axhline(ndl, color="darkred", linestyle="-", alpha=0.7, label="Next-day low")
        plt.axhline(whi, color="g", linestyle=":", label="Next-week high")
        plt.axhline(wlo, color="b", linestyle=":", label="Next-week low")
        
        plt.plot([last_date] + next_dates, y_values, 'm--', label="Multi-step forecast closes")
        
        plt.legend()
        plt.title(f"{ticker} Close + Next Day & Week Forecast")
        plot_filename = f"{ticker}_parent_history_forecast.png" if "parent" in out_dir else f"{ticker}_child_history_forecast.png"
        plot_path = os.path.join(out_dir, plot_filename)
        plt.savefig(plot_path)
        plt.close()
        print(f"Plot saved for {ticker} at {plot_path}")
    except Exception as e:
        raise PipelineError(f"Plotting failed for {ticker}: {e}")

def train_parent(ticker: str = Config().parent_ticker, start: str = Config().start_date, 
                 epochs: int = Config().parent_epochs, out_dir: str = Config().parent_dir) -> Dict:
    """Train parent model."""
    try:
        df = fetch_ohlcv(ticker, start)
        scaler = StandardScaler().fit(df[Config().features])
        dataset = StockDataset(df, scaler)
        train_size = int(0.8 * len(dataset))
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
        train_loader = DataLoader(train_dataset, batch_size=Config().batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=Config().batch_size)

        model = LSTMModel()
        model = fit_model(model, train_loader, val_loader, epochs=epochs, lr=1e-3)
        save_model(model, scaler, out_dir, model_type="parent", ticker=ticker)
        session = model_to_onnx(model, os.path.join(out_dir, "model.onnx"))
        evaluate_model(session, df, scaler, out_dir, ticker.replace("^", ""))
        return {"checkpoint": out_dir}
    except Exception as e:
        raise PipelineError(f"Parent model training failed for {ticker}: {e}")

def train_child(ticker: str, start: str = Config().start_date, epochs: int = Config().child_epochs, 
                parent_dir: str = Config().parent_dir, workdir: str = Config().workdir) -> Dict:
    """Train child model using parent model weights."""
    try:
        df = fetch_ohlcv(ticker, start)
        parent_model = LSTMModel()
        parent_model_path = os.path.join(parent_dir, "model.pt")
        if not os.path.exists(parent_model_path):
            raise FileNotFoundError(f"Parent model not found at {parent_model_path}")
        parent_model.load_state_dict(torch.load(parent_model_path, map_location=Config().device))

        for name, param in parent_model.named_parameters():
            if "lstm" in name:
                param.requires_grad = False

        scaler = StandardScaler().fit(df[Config().features])
        if (df[Config().features].std() == 0).any():
            raise PipelineError(f"Zero variance in features for {ticker}")
        dataset = StockDataset(df, scaler)
        train_size = int(0.8 * len(dataset))
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
        train_loader = DataLoader(train_dataset, batch_size=Config().batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=Config().batch_size)

        child_model = fit_model(parent_model, train_loader, val_loader, epochs=epochs, lr=3e-4)
        child_dir = os.path.join(workdir, ticker)
        save_model(child_model, scaler, child_dir, model_type="child", ticker=ticker)

        session = model_to_onnx(child_model, os.path.join(child_dir, "model.onnx"))
        payload = predict_one_step_and_week(session, df, scaler, ticker)
        json_filename = f"{ticker}_child_forecast.json"
        json_path = save_json(payload, os.path.join(child_dir, json_filename))
        plot_outputs(df, payload, child_dir, ticker)
        evaluate_model(session, df, scaler, child_dir, ticker)
        return {"checkpoint": child_dir, "json": json_path}
    except Exception as e:
        raise PipelineError(f"Child model training failed for {ticker}: {e}")

def predict_child(ticker: str, parent_dir: str = Config().parent_dir, workdir: str = Config().workdir) -> Dict:
    """Predict using child model."""
    try:
        child_dir = os.path.join(workdir, ticker)
        df = fetch_ohlcv(ticker)
        session, scaler = load_model(path=child_dir, model_type="child", ticker=ticker)
        return predict_one_step_and_week(session, df, scaler, ticker)
    except Exception as e:
        print(f"Prediction failed for {ticker}: {e}")
        return {"ticker": ticker, "error": str(e)}

def infer_child_stock(
    ticker: str,
    start: str = Config().start_date,
    epochs: int = Config().child_epochs,
    parent_dir: str = Config().parent_dir,
    workdir: str = Config().workdir,
    train_if_not_exists: bool = True,
    return_plot: bool = False,
    return_metrics: bool = False
) -> Dict:
    """Infer stock predictions, training if necessary."""
    child_dir = os.path.join(workdir, ticker)
    onnx_path = os.path.join(child_dir, "model.onnx")
    scaler_path = os.path.join(child_dir, f"{ticker}_child_scaler.pkl")
    
    model_exists = os.path.exists(onnx_path) and os.path.exists(scaler_path)
    
    try:
        session, scaler = load_model(path=child_dir, model_type="child", ticker=ticker)
        print(f"Loaded model for {ticker} from {child_dir}")
    except FileNotFoundError as e:
        if train_if_not_exists:
            print(f"Model for {ticker} not found. Training now...")
            try:
                train_summary = train_child(ticker, start, epochs, parent_dir, workdir)
                child_dir = train_summary["checkpoint"]
                session, scaler = load_model(path=child_dir, model_type="child", ticker=ticker)
            except Exception as e:
                raise PipelineError(f"Failed to train child model for {ticker}: {e}")
        else:
            raise PipelineError(f"Child model for {ticker} not found in {child_dir}.") from e
    
    print(f"Running inference for {ticker}...")
    df = fetch_ohlcv(ticker, start)
    predictions = predict_one_step_and_week(session, df, scaler, ticker)
    
    output = {"ticker": ticker, "predictions": predictions}
    
    if return_metrics:
        output["metrics"] = evaluate_model(session, df, scaler, child_dir, ticker)
    
    if return_plot:
        plot_outputs(df, predictions, child_dir, ticker)
        output["plot_path"] = os.path.join(child_dir, f"{ticker}_child_history_forecast.png")
    
    return output

def main():
    """Run the stock prediction pipeline."""
    config = Config()
    
    # Train parent model
    print("1. Training parent model for S&P 500...")
    parent_model_path = os.path.join(config.parent_dir, "model.pt")
    parent_scaler_path = os.path.join(config.parent_dir, "parent_scaler.pkl")
    parent_onnx_path = os.path.join(config.parent_dir, "model.onnx")
    if os.path.exists(parent_model_path) and os.path.exists(parent_scaler_path) and os.path.exists(parent_onnx_path):
        print(f"✓ Using existing parent model at: {config.parent_dir}")
    else:
        try:
            parent_summary = train_parent(ticker=config.parent_ticker, start=config.start_date, 
                                        epochs=config.parent_epochs, out_dir=config.parent_dir)
            print(f"✓ Parent model trained and saved to: {config.parent_dir}")
        except Exception as e:
            print(f"✗ Error training parent model: {e}")
            if os.path.exists(parent_model_path) and os.path.exists(parent_scaler_path):
                print(f"✓ Found existing parent model at: {config.parent_dir}. Continuing...")
            else:
                print("✗ No existing parent model found. Cannot proceed without parent model. Exiting.")
                exit(1)

    # Train child models
    results = {}
    print("\n2. Training child models sequentially...")
    for ticker in config.child_tickers:
        print(f"Training child model for {ticker}...")
        try:
            summary = train_child(ticker=ticker, start=config.start_date, epochs=config.child_epochs, 
                                 parent_dir=config.parent_dir, workdir=config.workdir)
            results[ticker] = summary
            print(f"✓ {ticker} model trained and saved to: {summary['checkpoint']}")
            print(f"✓ Predictions saved to: {summary['json']}")
            print(f"✓ Metrics saved to: {summary['checkpoint']}/{ticker}_child_metrics.json")
        except Exception as e:
            print(f"✗ Error training {ticker}: {e}")

    # Generate predictions
    print("\n3. Generating fresh predictions...")
    for ticker in config.child_tickers:
        try:
            preds = predict_child(ticker=ticker, parent_dir=config.parent_dir, workdir=config.workdir)
            if "error" in preds:
                print(f"✗ Error predicting {ticker}: {preds['error']}")
                continue
            predictions = preds.get('predictions', {})
            next_business_days = predictions.get('next_business_days', [])
            next_day = predictions.get('next_day', {})
            next_week = predictions.get('next_week', {})
            print(f"✓ {ticker} predictions for {next_business_days}:")
            print(f"  Next-day open: ${next_day.get('open', 'N/A'):.2f}")
            print(f"  Next-day high: ${next_day.get('high', 'N/A'):.2f}")
            print(f"  Next-day low: ${next_day.get('low', 'N/A'):.2f}")
            print(f"  Next-day close: ${next_day.get('close', 'N/A'):.2f}")
            print(f"  Next-week high: ${next_week.get('high', 'N/A'):.2f}")
            print(f"  Next-week low: ${next_week.get('low', 'N/A'):.2f}")
        except Exception as e:
            print(f"✗ Error predicting {ticker}: {e}")

    print(f"\n{'=' * 50}")
    print("Pipeline completed! Check 'outputs/' directory for models, scalers, predictions, metrics, and plots.")
    print("\nFile structure:")
    print("outputs/")
    print("├── parent/")
    print(f"│   ├── model.pt")
    print(f"│   ├── model.onnx")
    print(f"│   ├── parent_scaler.pkl")
    print(f"│   └── {config.parent_ticker.replace('^', '')}_parent_metrics.json")
    for ticker in config.child_tickers:
        if ticker in results:
            print(f"├── {ticker}/")
            print(f"│   ├── model.pt")
            print(f"│   ├── model.onnx")
            print(f"│   ├── {ticker}_child_scaler.pkl")
            print(f"│   ├── {ticker}_child_forecast.json")
            print(f"│   ├── {ticker}_child_metrics.json")
            print(f"│   └── {ticker}_child_history_forecast.png")

if __name__ == '__main__':
    main()

1. Training parent model for S&P 500...
Fetched 5288 rows for ^GSPC
Epoch 1/20 - Train Loss: 0.40425
Epoch 1/20 - Val Loss: 0.18780
Epoch 2/20 - Train Loss: 0.17264
Epoch 2/20 - Val Loss: 0.14371
Epoch 3/20 - Train Loss: 0.14790
Epoch 3/20 - Val Loss: 0.13367
Epoch 4/20 - Train Loss: 0.13877
Epoch 4/20 - Val Loss: 0.12930
Epoch 5/20 - Train Loss: 0.13355
Epoch 5/20 - Val Loss: 0.12181
Epoch 6/20 - Train Loss: 0.12832
Epoch 6/20 - Val Loss: 0.12780
Epoch 7/20 - Train Loss: 0.12615
Epoch 7/20 - Val Loss: 0.11840
Epoch 8/20 - Train Loss: 0.12268
Epoch 8/20 - Val Loss: 0.11871
Epoch 9/20 - Train Loss: 0.12092
Epoch 9/20 - Val Loss: 0.11420
Epoch 10/20 - Train Loss: 0.11745
Epoch 10/20 - Val Loss: 0.11444
Epoch 11/20 - Train Loss: 0.11487
Epoch 11/20 - Val Loss: 0.11293
Epoch 12/20 - Train Loss: 0.11410
Epoch 12/20 - Val Loss: 0.10988
Epoch 13/20 - Train Loss: 0.11201
Epoch 13/20 - Val Loss: 0.10979
Epoch 14/20 - Train Loss: 0.11045
Epoch 14/20 - Val Loss: 0.10869
Epoch 15/20 - Train Loss: 

  torch.onnx.export(**export_kwargs)


GSPC → MSE: 0.06793, RMSE: 0.26064, R²: 0.92977
✓ Parent model trained and saved to: outputs/parent

2. Training child models sequentially...
Training child model for GOOG...
Fetched 5288 rows for GOOG
Epoch 1/10 - Train Loss: 0.13160
Epoch 1/10 - Val Loss: 0.10523
Epoch 2/10 - Train Loss: 0.12313
Epoch 2/10 - Val Loss: 0.10268
Epoch 3/10 - Train Loss: 0.12049
Epoch 3/10 - Val Loss: 0.10066
Epoch 4/10 - Train Loss: 0.11911
Epoch 4/10 - Val Loss: 0.09937
Epoch 5/10 - Train Loss: 0.11801
Epoch 5/10 - Val Loss: 0.09905
Epoch 6/10 - Train Loss: 0.11629
Epoch 6/10 - Val Loss: 0.09803
Epoch 7/10 - Train Loss: 0.11651
Epoch 7/10 - Val Loss: 0.09756
Epoch 8/10 - Train Loss: 0.11612
Epoch 8/10 - Val Loss: 0.09695
Epoch 9/10 - Train Loss: 0.11442
Epoch 9/10 - Val Loss: 0.09663
Epoch 10/10 - Train Loss: 0.11430
Epoch 10/10 - Val Loss: 0.09611
Model and scaler saved locally at outputs/GOOG
Input shape for GOOG: (1, 60, 7)
Prediction shape for GOOG: (1, 5, 7)
Inverse transformed shape for GOOG: (5,

  torch.onnx.export(**export_kwargs)
  last_close = float(last_close.item() if isinstance(last_close, np.ndarray) else last_close[0])


GOOG → MSE: 0.05833, RMSE: 0.24152, R²: 0.93689
✓ GOOG model trained and saved to: outputs/GOOG
✓ Predictions saved to: outputs/GOOG/GOOG_child_forecast.json
✓ Metrics saved to: outputs/GOOG/GOOG_child_metrics.json
Training child model for AMZN...
Fetched 5288 rows for AMZN
Epoch 1/10 - Train Loss: 0.18161
Epoch 1/10 - Val Loss: 0.20865
Epoch 2/10 - Train Loss: 0.17617
Epoch 2/10 - Val Loss: 0.20591
Epoch 3/10 - Train Loss: 0.17232
Epoch 3/10 - Val Loss: 0.20437
Epoch 4/10 - Train Loss: 0.17108
Epoch 4/10 - Val Loss: 0.20337
Epoch 5/10 - Train Loss: 0.17037
Epoch 5/10 - Val Loss: 0.20251
Epoch 6/10 - Train Loss: 0.16876
Epoch 6/10 - Val Loss: 0.20193
Epoch 7/10 - Train Loss: 0.16830
Epoch 7/10 - Val Loss: 0.20136
Epoch 8/10 - Train Loss: 0.16867
Epoch 8/10 - Val Loss: 0.20080
Epoch 9/10 - Train Loss: 0.16714
Epoch 9/10 - Val Loss: 0.20055
Epoch 10/10 - Train Loss: 0.16744
Epoch 10/10 - Val Loss: 0.20008
Model and scaler saved locally at outputs/AMZN
Input shape for AMZN: (1, 60, 7)
Pre

  torch.onnx.export(**export_kwargs)
  last_close = float(last_close.item() if isinstance(last_close, np.ndarray) else last_close[0])


AMZN → MSE: 0.14490, RMSE: 0.38065, R²: 0.85428
✓ AMZN model trained and saved to: outputs/AMZN
✓ Predictions saved to: outputs/AMZN/AMZN_child_forecast.json
✓ Metrics saved to: outputs/AMZN/AMZN_child_metrics.json
Training child model for META...
Fetched 3336 rows for META
Epoch 1/10 - Train Loss: 0.16689
Epoch 1/10 - Val Loss: 0.16882
Epoch 2/10 - Train Loss: 0.15851
Epoch 2/10 - Val Loss: 0.16527
Epoch 3/10 - Train Loss: 0.15662
Epoch 3/10 - Val Loss: 0.16329
Epoch 4/10 - Train Loss: 0.15535
Epoch 4/10 - Val Loss: 0.16189
Epoch 5/10 - Train Loss: 0.15445
Epoch 5/10 - Val Loss: 0.16101
Epoch 6/10 - Train Loss: 0.15363
Epoch 6/10 - Val Loss: 0.16020
Epoch 7/10 - Train Loss: 0.15269
Epoch 7/10 - Val Loss: 0.15945
Epoch 8/10 - Train Loss: 0.15201
Epoch 8/10 - Val Loss: 0.15903
Epoch 9/10 - Train Loss: 0.15288
Epoch 9/10 - Val Loss: 0.15849
Epoch 10/10 - Train Loss: 0.15130
Epoch 10/10 - Val Loss: 0.15835
Model and scaler saved locally at outputs/META
Input shape for META: (1, 60, 7)
Pre

  torch.onnx.export(**export_kwargs)
  last_close = float(last_close.item() if isinstance(last_close, np.ndarray) else last_close[0])


META → MSE: 0.11548, RMSE: 0.33982, R²: 0.88289
✓ META model trained and saved to: outputs/META
✓ Predictions saved to: outputs/META/META_child_forecast.json
✓ Metrics saved to: outputs/META/META_child_metrics.json
Training child model for AXP...
Fetched 5288 rows for AXP
Epoch 1/10 - Train Loss: 0.14368
Epoch 1/10 - Val Loss: 0.13031
Epoch 2/10 - Train Loss: 0.13508
Epoch 2/10 - Val Loss: 0.12754
Epoch 3/10 - Train Loss: 0.13196
Epoch 3/10 - Val Loss: 0.12617
Epoch 4/10 - Train Loss: 0.12995
Epoch 4/10 - Val Loss: 0.12528
Epoch 5/10 - Train Loss: 0.12994
Epoch 5/10 - Val Loss: 0.12441
Epoch 6/10 - Train Loss: 0.12866
Epoch 6/10 - Val Loss: 0.12407
Epoch 7/10 - Train Loss: 0.12821
Epoch 7/10 - Val Loss: 0.12345
Epoch 8/10 - Train Loss: 0.12727
Epoch 8/10 - Val Loss: 0.12312
Epoch 9/10 - Train Loss: 0.12743
Epoch 9/10 - Val Loss: 0.12296
Epoch 10/10 - Train Loss: 0.12638
Epoch 10/10 - Val Loss: 0.12284
Model and scaler saved locally at outputs/AXP
Input shape for AXP: (1, 60, 7)
Predict

  torch.onnx.export(**export_kwargs)
  last_close = float(last_close.item() if isinstance(last_close, np.ndarray) else last_close[0])


AXP → MSE: 0.07930, RMSE: 0.28161, R²: 0.92139
✓ AXP model trained and saved to: outputs/AXP
✓ Predictions saved to: outputs/AXP/AXP_child_forecast.json
✓ Metrics saved to: outputs/AXP/AXP_child_metrics.json

3. Generating fresh predictions...
Fetched 5288 rows for GOOG
Input shape for GOOG: (1, 60, 7)
Prediction shape for GOOG: (1, 5, 7)
Inverse transformed shape for GOOG: (5, 7)
Forecast entry 0 for GOOG: {'date': '2025-09-15', 'open': 188.99, 'high': 193.69, 'low': 188.54, 'close': 190.51, 'volume': 104224232}
Forecast entry 1 for GOOG: {'date': '2025-09-16', 'open': 189.0, 'high': 191.89, 'low': 187.64, 'close': 190.63, 'volume': 74471040}
Forecast entry 2 for GOOG: {'date': '2025-09-17', 'open': 189.69, 'high': 192.71, 'low': 188.88, 'close': 189.86, 'volume': 70436792}
Forecast entry 3 for GOOG: {'date': '2025-09-18', 'open': 190.41, 'high': 190.27, 'low': 187.7, 'close': 190.25, 'volume': 64933716}
Forecast entry 4 for GOOG: {'date': '2025-09-19', 'open': 189.8, 'high': 192.29, 