Trajectory forecasting with adaptive updating

**Main Features**
- Brent Crude, WTI Crude, Dutch TTF Gas, Henry Hub Gas
- Equinor (EQNR.OL): Open, Close, High, Low, Volume, Market Cap
- OSEBX Index: Open, Close, High, Low, Volume
- VIX (volatility index)
- Dollar Index (DXY)

**Relevant Stocks**
- **Norway**: Aker BP (AKRBP), DNO (DNO), Vår Energi (VAR), Petroleum Geo-Services (PGS), BW Offshore (BWO), Frontline (FRO)
- **US/Global**: Exxon (XOM), Chevron (CVX), Shell (SHEL), BP (BP), TotalEnergies (TTE), ConocoPhillips (COP), Occidental (OXY)

**Stock Exchanges**
- S&P 500, NASDAQ, Dow Jones
- FTSE 100, DAX, CAC 40
- Nikkei 225, Hang Seng

**Commodity Prices**
- Gold (XAU), Silver (XAG)
- **Currencies**: USD/NOK, EUR/NOK, GBP/NOK, SEK/NOK, EUR/USD
- Coal (API2), Uranium (UX)
- Carbon Credits (EU ETS)

**Economic Indicators**
- **Interest Rates**: Norway (Norges Bank), US Fed Funds, ECB, BoE, BoJ, PBoC
- **Inflation**: Norway CPI, US CPI, EU HICP
- **Unemployment**: Norway, US, EU rates
- **Analyst Targets**: Equinor consensus price targets, EPS estimates


#### Fetch Dependencies

In [None]:
# Data fetching
import yfinance as yf

# Data manipulation
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ML libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Data preprocessing
from sklearn.preprocessing import StandardScaler

# Date and time handling
from datetime import datetime, timedelta

# Suppress warnings for cleaner output
import warnings
warnings.filterwarnings('ignore')

#### Collect Data

In [None]:
# CELL 1: COLLECT DATA

def collect_data(start_date="2021-01-01", end_date=None, *, tickers=None, backfill=True):
    """Download daily OHLCV history for a set of Yahoo tickers and merge it by date.

    Each ticker is downloaded on its own with ``yf.download``; its Open/High/Low/
    Close/Volume columns are renamed to ``<name>_<field>`` (e.g. ``equinor_close``)
    and all tickers are outer-joined on the date index.  When ``equinor_close`` is
    present, rows are restricted to days on which Equinor traded, and remaining
    gaps are forward-filled.  Download errors are printed and that ticker skipped.

    Args:
        start_date: First date to request, ``'YYYY-MM-DD'``.
        end_date: Last date to request; defaults to today's date.
        tickers: Optional mapping of Yahoo symbol -> column prefix.  Defaults to
            the energy / index / FX universe defined below.
        backfill: If True (the default and historical behaviour), gaps still left
            at the start of a series after forward-filling are back-filled with
            that series' first observation.  This copies later prices into earlier
            rows (look-ahead).  If False, leading rows are dropped instead, until
            every column that has any data is populated.

    Returns:
        DataFrame indexed by date, or an empty DataFrame if nothing was downloaded.
    """
    end_date = end_date or datetime.now().strftime('%Y-%m-%d')

    if tickers is None:
        # Updated tickers based on research
        tickers = {
            # Main stock
            'EQNR.OL': 'equinor',

            # Energy commodities
            'BZ=F': 'brent_crude',
            'CL=F': 'wti_crude',
            'TTF=F': 'ttf_gas',
            'NG=F': 'henry_hub',

            # Norwegian energy stocks
            'AKRBP.OL': 'aker_bp',
            'DNO.OL': 'dno',
            'VAR.OL': 'var_energi',
            'PGS.OL': 'pgs',
            'BWO.OL': 'bw_offshore',
            'FRO.OL': 'frontline',

            # Global energy stocks
            'XOM': 'exxon',
            'CVX': 'chevron',
            'SHEL': 'shell',
            'BP': 'bp',
            'TTE': 'totalenergies',
            'COP': 'conocophillips',
            'OXY': 'occidental',

            # Indices
            'OSEBX.OL': 'osebx',
            '^GSPC': 'sp500',
            '^IXIC': 'nasdaq',
            '^DJI': 'dow_jones',
            '^FTSE': 'ftse100',
            '^GDAXI': 'dax',
            '^FCHI': 'cac40',
            '^N225': 'nikkei',
            '^HSI': 'hang_seng',

            # Volatility and Dollar
            '^VIX': 'vix',
            'DX-Y.NYB': 'dollar_index',

            # Commodities
            'GC=F': 'gold',
            'SI=F': 'silver',

            # Currencies
            'NOK=X': 'usd_nok',
            'EURNOK=X': 'eur_nok',
            'GBPNOK=X': 'gbp_nok',
            'SEKNOK=X': 'sek_nok',
            'EURUSD=X': 'eur_usd'
        }

    cols_to_keep = ['Open', 'High', 'Low', 'Close', 'Volume']
    all_data = {}

    # Download each ticker separately to avoid alignment issues
    for ticker, name in tickers.items():
        try:
            print(f"Downloading {name}...")
            data = yf.download(ticker, start=start_date, end=end_date, progress=False)

            if len(data) > 0:
                # Only keep OHLC and Volume columns
                data = data[[c for c in cols_to_keep if c in data.columns]]
                # Columns may be plain strings or (field, ticker) tuples from a
                # MultiIndex; keep only the field name either way.
                new_cols = []
                for col in data.columns:
                    if isinstance(col, tuple):
                        col_name = col[0] if len(col) > 0 else str(col)
                    else:
                        col_name = str(col)
                    new_cols.append(f"{name}_{col_name.lower()}")
                data.columns = new_cols
                all_data[name] = data
                print(f"  ✓ {name}: {len(data)} rows")
            else:
                print(f"  ✗ No data received for {ticker}")

        except Exception as e:
            # Best effort: one failing symbol must not abort the whole collection.
            print(f"  ✗ Error fetching {ticker}: {e}")

    if not all_data:
        print("No data collected")
        return pd.DataFrame()

    # Combine using outer join to keep all dates
    df = pd.concat(all_data.values(), axis=1, join='outer')
    print(f"Combined data: {len(df)} rows, {len(df.columns)} columns")

    # Check initial NaN percentage
    nan_pct = df.isnull().sum().sum() / df.size * 100
    print(f"Initial NaN percentage: {nan_pct:.2f}%")

    # Keep only dates where Equinor traded (removes weekends/holidays)
    if 'equinor_close' in df.columns:
        before_filter = len(df)
        df = df[df['equinor_close'].notna()]
        print(f"Filtered to Equinor trading days: {before_filter} → {len(df)} rows")

    # Forward fill gaps left by other markets' calendars
    df = df.ffill()
    if backfill:
        df = df.bfill()

    # Drop leading rows in which some series has not started trading yet.
    # Columns without any data are ignored so they cannot empty the frame.
    # After a backfill every populated column is complete, so no row is dropped.
    populated = df.columns[df.notna().any()]
    complete_rows = df[populated].notna().all(axis=1).to_numpy()
    if complete_rows.any():
        df = df.iloc[complete_rows.argmax():]

    # Final check for NaN percentage
    nan_count = df.isnull().sum().sum()
    if nan_count > 0:
        nan_pct_final = nan_count / df.size * 100
        print(f"Warning: {nan_count} NaN values remain ({nan_pct_final:.2f}%)")
        # Show which columns have NaNs
        nan_cols = df.columns[df.isnull().any()].tolist()
        if nan_cols:
            print(f"  Columns with NaNs: {nan_cols}")

    print(f"\nFinal data: {len(df)} rows, {len(df.columns)} columns")
    if len(df) > 0:
        print(f"Date range: {df.index[0].date()} to {df.index[-1].date()}")
        nan_pct_final = df.isnull().sum().sum() / df.size * 100
        print(f"Final NaN percentage: {nan_pct_final:.2f}%")
    else:
        print("WARNING: No data remaining after processing")

    return df

# Fetch the full history from 2015 onwards (end date defaults to today)
data = collect_data("2015-01-01")

#### Rate of Change

# Add rate of change features
def add_rate_of_change(df, key_assets=None, window=20):
    """Return a copy of ``df`` with return/acceleration/volatility-momentum features.

    For every asset in ``key_assets`` whose ``<asset>_close`` column exists, adds:
      * ``<asset>_return``       - 1-row percentage change of the close
      * ``<asset>_acceleration`` - row-over-row change of that return
      * ``<asset>_vol_momentum`` - percentage change of the ``window``-row rolling
        std of the returns
    plus ``eq_brent_corr_momentum`` (row-over-row change of the rolling
    Equinor/Brent return correlation) when both return columns exist.

    The input frame is not modified.  Rows containing NaN are dropped, as are
    rows where a derived feature is +/-inf (``pct_change`` over a zero rolling
    volatility divides by zero).

    Args:
        df: Frame with ``<asset>_close`` columns.
        key_assets: Asset prefixes to process; defaults to Equinor, Brent, WTI,
            USD/NOK and VIX.
        window: Rolling window length (rows) for volatility and correlation.

    Returns:
        New DataFrame with the added columns and incomplete rows removed.
    """
    if key_assets is None:
        key_assets = ['equinor', 'brent_crude', 'wti_crude', 'usd_nok', 'vix']

    # Work on a copy so the caller's frame does not silently gain columns.
    df = df.copy()
    derived = []

    for asset in key_assets:
        close_col = f'{asset}_close'
        if close_col not in df.columns:
            continue
        ret_col = f'{asset}_return'

        # First derivative (daily returns); fill_method=None avoids the
        # deprecated implicit forward-fill inside pct_change.
        df[ret_col] = df[close_col].pct_change(fill_method=None)

        # Second derivative (acceleration)
        df[f'{asset}_acceleration'] = df[ret_col].diff()

        # Volatility momentum (change of the rolling volatility)
        rolling_vol = df[ret_col].rolling(window).std()
        df[f'{asset}_vol_momentum'] = rolling_vol.pct_change(fill_method=None)

        derived += [ret_col, f'{asset}_acceleration', f'{asset}_vol_momentum']

    # Cross-asset correlation momentum (Equinor vs Brent)
    if 'equinor_return' in df.columns and 'brent_crude_return' in df.columns:
        rolling_corr = df['equinor_return'].rolling(window).corr(df['brent_crude_return'])
        df['eq_brent_corr_momentum'] = rolling_corr.diff()
        derived.append('eq_brent_corr_momentum')

    # Division by a zero volatility yields +/-inf, which dropna() would keep.
    if derived:
        df[derived] = df[derived].replace([np.inf, -np.inf], np.nan)

    # Drop the warm-up rows (and any undefined ones) created above
    return df.dropna()

# Extend the collected matrix with rate-of-change features
data = add_rate_of_change(data)
print(f"After adding rate of change: {data.shape}")

#### Print Data

# CELL 2: PRINT HEAD OF DATA
nan_total = data.isna().sum().sum()
print(f"Shape: {data.shape}")
print(f"Date range: {data.index.min().date()} to {data.index.max().date()}")
print(f"Columns: {data.shape[1]}")
print(f"Remaining NaNs: {nan_total}")
print(f"NaN percentage: {nan_total / data.size * 100:.2f}%")
data.head()

#### Save Matrix to CSV

# CELL 3: SAVE TO CSV
from pathlib import Path

filepath = "data/equinor_data_8sept.csv"
# to_csv refuses to write into a missing directory, so create it on first run.
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
data.to_csv(filepath)
print(f"Saved {len(data)} rows to {filepath}")

### Model

In [None]:
"""
Enhanced Trajectory Distribution Model with Optuna Optimization
Includes domain-specific features, rank predictions, and sharper loss functions
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import optuna
from optuna.trial import TrialState
import warnings
warnings.filterwarnings('ignore')

# Seed the NumPy and PyTorch generators so runs are repeatable
np.random.seed(42)
torch.manual_seed(42)

# Prefer Apple's Metal (MPS) backend when it is available, otherwise use the CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ============================================================================
# ENHANCED DATA PREPARATION WITH DOMAIN FEATURES
# ============================================================================

def prepare_distribution_data(filepath='data/equinor_data_8sept.csv'):
    """Load the saved market matrix and build features plus next-day targets.

    Every column whose name contains ``close`` defines an asset prefix.  For each
    asset the function generates multi-horizon returns, rolling volatility/skew/
    kurtosis, moving averages, ranges, RSI, MACD and Bollinger features.  On top
    of these it adds Equinor-specific composites, an energy-sector momentum
    average, a VIX risk signal, a limited set of rolling cross-asset return
    correlations and calendar encodings.

    Feature columns are returned in creation order, with assets in CSV column
    order.  The layout of ``X``, and the choice of correlation pairs (which
    depends on that order), is therefore the same in every Python process.

    Args:
        filepath: CSV with a date index in column 0; must contain
            ``equinor_close``.

    Returns:
        ``(X, y_return, y_volatility, y_rank)``, restricted to rows where both
        targets are defined:
          * ``X`` - the feature frame.
          * ``y_return`` - Equinor's next-row return.
          * ``y_volatility`` - 20-row rolling std of Equinor returns, shifted
            one row earlier.
          * ``y_rank`` - quintile label 0-4 of ``y_return``.
    """
    data = pd.read_csv(filepath, index_col=0, parse_dates=True)

    # Calculate next-day targets
    data['next_day_return'] = data['equinor_close'].pct_change(1).shift(-1)
    data['next_day_volatility'] = data['equinor_close'].pct_change(1).rolling(20).std().shift(-1)

    # Snapshot the input columns so generated features can be identified later
    original_columns = set(data.columns)

    # Detect assets from *_close columns.  dict.fromkeys de-duplicates while
    # keeping CSV order; a set would order them by string hash, which Python
    # randomises per process.
    price_columns = [col for col in data.columns if 'close' in col.lower()]
    assets = list(dict.fromkeys(col.replace('_close', '') for col in price_columns))

    print(f"   Found {len(assets)} assets: {assets[:5]}...")

    # Generate features for each asset
    for asset in assets:
        base_col = f'{asset}_close'
        if base_col not in data.columns:
            continue

        # Returns at multiple horizons
        for period in [1, 2, 3, 5, 7, 10, 15, 20, 30, 60]:
            data[f'{asset}_ret_{period}d'] = data[base_col].pct_change(period)

        # Log returns
        data[f'{asset}_logret_1d'] = np.log(data[base_col] / data[base_col].shift(1))

        # Volatility features
        for period in [5, 10, 20, 30, 60]:
            data[f'{asset}_vol_{period}d'] = data[f'{asset}_ret_1d'].rolling(period).std()
            data[f'{asset}_vol_skew_{period}d'] = data[f'{asset}_ret_1d'].rolling(period).skew()
            data[f'{asset}_vol_kurt_{period}d'] = data[f'{asset}_ret_1d'].rolling(period).kurt()

        # Moving averages
        for period in [5, 10, 20, 50, 100, 200]:
            if len(data) >= period:
                data[f'{asset}_ma{period}'] = data[base_col].rolling(period).mean()
                data[f'{asset}_rel_ma{period}'] = data[base_col] / data[f'{asset}_ma{period}'] - 1

        # Technical indicators
        for period in [5, 10, 20]:
            data[f'{asset}_min_{period}d'] = data[base_col].rolling(period).min()
            data[f'{asset}_max_{period}d'] = data[base_col].rolling(period).max()
            data[f'{asset}_range_{period}d'] = (data[base_col] - data[f'{asset}_min_{period}d']) / \
                                               (data[f'{asset}_max_{period}d'] - data[f'{asset}_min_{period}d'] + 1e-10)

        # RSI
        delta = data[base_col].diff()
        gain = (delta.where(delta > 0, 0)).rolling(14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
        rs = gain / (loss + 1e-10)
        data[f'{asset}_rsi'] = 100 - (100 / (1 + rs))

        # MACD
        exp1 = data[base_col].ewm(span=12, adjust=False).mean()
        exp2 = data[base_col].ewm(span=26, adjust=False).mean()
        data[f'{asset}_macd'] = exp1 - exp2
        data[f'{asset}_macd_signal'] = data[f'{asset}_macd'].ewm(span=9, adjust=False).mean()

        # Bollinger Bands
        ma20 = data[base_col].rolling(20).mean()
        std20 = data[base_col].rolling(20).std()
        data[f'{asset}_bb_upper'] = ma20 + (std20 * 2)
        data[f'{asset}_bb_lower'] = ma20 - (std20 * 2)
        data[f'{asset}_bb_pos'] = (data[base_col] - data[f'{asset}_bb_lower']) / \
                                  (data[f'{asset}_bb_upper'] - data[f'{asset}_bb_lower'] + 1e-10)

    # DOMAIN-SPECIFIC FEATURES FOR EQUINOR
    if all(col in data.columns for col in ['brent_crude_ret_1d', 'usd_nok_ret_1d']):
        # Oil-equity signal (missing inputs contribute 0)
        data['oil_equity_signal'] = (
            data.get('brent_crude_ret_1d', 0) * 0.5 +
            data.get('osebx_ret_1d', 0) * 0.3 +
            data.get('usd_nok_ret_1d', 0) * -0.2  # NOK strength = bad for Equinor
        )

        # Oil-equity divergence
        data['oil_equity_divergence'] = (
            data.get('brent_crude_ret_5d', 0) - data.get('equinor_ret_5d', 0)
        )

    # Energy sector momentum: equal-weight mean of available peer returns
    energy_assets = ['aker_bp', 'var_energi', 'dno', 'equinor']
    energy_returns = [data[f'{asset}_ret_1d'] for asset in energy_assets
                      if f'{asset}_ret_1d' in data.columns]

    if energy_returns:
        data['energy_momentum'] = pd.concat(energy_returns, axis=1).mean(axis=1)
        data['energy_momentum_5d'] = pd.concat([data[f'{asset}_ret_5d']
                                                for asset in energy_assets
                                                if f'{asset}_ret_5d' in data.columns], axis=1).mean(axis=1)

    # Macro signals
    if 'vix_close' in data.columns:
        data['risk_on_signal'] = -data['vix_close'].pct_change(5)  # Lower VIX = risk on

    # Cross-asset correlations: each of the first 10 return columns is paired
    # with its next two neighbours (so pairs depend on column order).
    return_cols = [col for col in data.columns if '_ret_1d' in col]
    for i, col1 in enumerate(return_cols[:10]):  # Limit to avoid explosion
        for col2 in return_cols[i+1:i+3]:
            asset1 = col1.replace('_ret_1d', '')
            asset2 = col2.replace('_ret_1d', '')

            for period in [5, 20]:
                data[f'{asset1}_{asset2}_corr{period}'] = data[col1].rolling(period).corr(data[col2])

    # Calendar features
    data['day_of_week'] = pd.to_datetime(data.index).dayofweek
    data['month'] = pd.to_datetime(data.index).month
    data['quarter'] = pd.to_datetime(data.index).quarter

    # Cyclical encoding
    data['day_sin'] = np.sin(2 * np.pi * data['day_of_week'] / 7)
    data['day_cos'] = np.cos(2 * np.pi * data['day_of_week'] / 7)
    data['month_sin'] = np.sin(2 * np.pi * data['month'] / 12)
    data['month_cos'] = np.cos(2 * np.pi * data['month'] / 12)

    # Generated features, in creation order (deterministic column layout)
    feature_columns = [col for col in data.columns
                       if col not in original_columns
                       and col not in ('next_day_return', 'next_day_volatility')]

    print(f"   Created {len(feature_columns)} features dynamically")

    # Select features and handle missing values.
    # NOTE(review): bfill() copies each feature's first valid value backwards
    # into its warm-up window (e.g. the first 199 rows of a 200-row MA), which
    # leaks later information into earlier rows.
    X = data[feature_columns].ffill().bfill().fillna(0)
    X = X.replace([np.inf, -np.inf], [1e6, -1e6])

    # Targets
    y_return = data['next_day_return']
    y_volatility = data['next_day_volatility']

    # Create rank targets (quintiles).
    # NOTE(review): quintile edges use the full sample, including rows a later
    # train/validation split may hold out.
    y_rank = pd.qcut(y_return.dropna(), q=5, labels=[0, 1, 2, 3, 4])

    # Remove NaN targets
    valid_idx = ~(y_return.isna() | y_volatility.isna())
    X = X[valid_idx]
    y_return = y_return[valid_idx]
    y_volatility = y_volatility[valid_idx]
    y_rank = y_rank[valid_idx]

    return X, y_return, y_volatility, y_rank

# ============================================================================
# ENHANCED MODEL WITH ATTENTION
# ============================================================================

class EnhancedTrajectoryGRU(nn.Module):
    """Feature-gated GRU + self-attention network with return, volatility and rank heads.

    Each input row is gated by softmax feature weights, concatenated with a
    carried state vector, projected, and passed as a length-1 sequence through
    a 3-layer GRU and a multi-head self-attention layer (with a residual sum).
    ``forward`` returns
    ``(predicted_return, predicted_vol, rank_logits, new_state, feature_weights)``.
    """

    def __init__(self, input_features, hidden_size=256, state_size=128,
                 dropout=0.3, n_heads=8):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size

        # Softmax gate over the input features
        self.feature_attention = nn.Sequential(
            nn.Linear(input_features, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, input_features),
            nn.Softmax(dim=-1),
        )

        # Maps [gated features | carried state] to the GRU width
        self.input_projection = nn.Sequential(
            nn.Linear(input_features + state_size, hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size),
        )

        self.gru = nn.GRU(hidden_size, hidden_size,
                          batch_first=True, num_layers=3, dropout=dropout)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=n_heads,
                                               dropout=dropout, batch_first=True)

        # Scalar heads: unconstrained return, Softplus-positive volatility
        self.return_head = self._scalar_head(hidden_size, dropout)
        self.volatility_head = self._scalar_head(hidden_size, dropout, positive=True)

        # 5-way rank classifier (logits)
        self.rank_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 5),
        )

        # Next carried state, squashed to [-1, 1]
        self.state_head = nn.Sequential(
            nn.Linear(hidden_size, state_size),
            nn.Tanh(),
        )

    @staticmethod
    def _scalar_head(in_size, dropout, positive=False):
        """Build the 128 -> 64 -> 1 MLP head, optionally ending in Softplus."""
        layers = [nn.Linear(in_size, 128), nn.ReLU(), nn.Dropout(dropout),
                  nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1)]
        if positive:
            layers.append(nn.Softplus())
        return nn.Sequential(*layers)

    def forward(self, x, hidden_state=None):
        """Run one step; ``hidden_state`` defaults to zeros of width ``state_size``."""
        if hidden_state is None:
            hidden_state = torch.zeros(x.size(0), self.state_size).to(x.device)

        weights = self.feature_attention(x)
        projected = self.input_projection(torch.cat([x * weights, hidden_state], dim=-1))

        # Treat each sample as a length-1 sequence for the GRU / attention stack
        seq, _ = self.gru(projected.unsqueeze(1))
        context, _ = self.attention(seq, seq, seq)
        summary = (seq + context).squeeze(1)

        return (
            self.return_head(summary).squeeze(-1),
            self.volatility_head(summary).squeeze(-1),
            self.rank_head(summary),
            self.state_head(summary),
            weights,
        )

# ============================================================================
# SHARPER LOSS FUNCTION
# ============================================================================

class SharperDistributionLoss(nn.Module):
    """Composite multi-task loss that also punishes overly wide predicted spreads.

    total = alpha * MSE(mean, return) + beta * MSE(std, realized vol)
          + gamma * Gaussian NLL(return | mean, std)
          + uncertainty_penalty * mean(relu(std - target_std))
          + rank_weight * CrossEntropy(rank_logits, rank_labels)

    The rank term counts as 0 when either rank input is None.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, uncertainty_penalty=5.0,
                 target_std=0.02, rank_weight=1.0):
        super().__init__()
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.uncertainty_penalty = uncertainty_penalty
        self.target_std = target_std
        self.rank_weight = rank_weight

    def forward(self, pred_mean, pred_std, rank_logits, actual_return,
                realized_vol, rank_labels):
        """Return ``(total_loss, components)``; components maps term names to floats."""
        mse_return = self.mse(pred_mean, actual_return)
        mse_vol = self.mse(pred_std, realized_vol)

        # Gaussian negative log-likelihood; the small offset keeps the variance > 0
        var = (pred_std + 1e-6) ** 2
        residual_sq = (actual_return - pred_mean) ** 2
        nll = (0.5 * (torch.log(2 * np.pi * var) + residual_sq / var)).mean()

        # Only the part of the predicted std above target_std is penalised
        spread_penalty = torch.relu(pred_std - self.target_std).mean() * self.uncertainty_penalty

        have_ranks = rank_labels is not None and rank_logits is not None
        ce_rank = self.ce(rank_logits, rank_labels) if have_ranks else 0

        total = (self.alpha * mse_return
                 + self.beta * mse_vol
                 + self.gamma * nll
                 + spread_penalty
                 + self.rank_weight * ce_rank)

        components = {
            'return_mse': mse_return.item(),
            'vol_mse': mse_vol.item(),
            'nll': nll.item(),
            'uncertainty_penalty': spread_penalty.item(),
            'rank_loss': ce_rank.item() if ce_rank != 0 else 0,
        }
        return total, components

# ============================================================================
# DATASET
# ============================================================================

class EnhancedDataset(Dataset):
    """Map-style dataset yielding ``(features, return, volatility, rank)`` tensors.

    Inputs may be pandas objects (their ``.values`` are used) or array-likes.
    Features, returns and volatilities are stored as float tensors; ranks as
    long tensors.

    NOTE(review): when ``ranks`` is None every item carries a dummy rank of 0.
    Any consumer that feeds it into a classification loss will treat that dummy
    as a real label.
    """

    def __init__(self, features, returns, volatilities, ranks=None):
        def unwrap(obj):
            # pandas containers expose their underlying array via .values
            return obj.values if hasattr(obj, 'values') else obj

        self.features = torch.FloatTensor(unwrap(features))
        self.returns = torch.FloatTensor(unwrap(returns))
        self.volatilities = torch.FloatTensor(unwrap(volatilities))
        self.ranks = None if ranks is None else torch.LongTensor(unwrap(ranks))

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        rank = torch.tensor(0) if self.ranks is None else self.ranks[idx]
        return self.features[idx], self.returns[idx], self.volatilities[idx], rank

# ============================================================================
# TRAINING WITH LONGER EPOCHS
# ============================================================================

def train_enhanced_model(model, X_train, y_return_train, y_vol_train, y_rank_train,
                         X_val, y_return_val, y_vol_val, y_rank_val,
                         epochs=100, batch_size=32, lr=0.001,
                         loss_weights=None):
    """Train ``model`` with AdamW, LR-on-plateau and early stopping.

    The model and every batch are moved to the module-level ``device``.  Mean
    validation loss is tracked each epoch.  Training stops after 15 epochs
    without improvement, and the model is returned holding the weights of its
    best validation epoch, whether it stopped early or ran all epochs.

    Args:
        model: Module exposing ``feature_attention``, ``gru``, ``return_head``,
            ``volatility_head`` and ``rank_head`` submodules (they get their own
            learning-rate groups); all other parameters train at ``lr``.
        X_train, y_return_train, y_vol_train, y_rank_train: training split,
            in any form accepted by ``EnhancedDataset``.
        X_val, y_return_val, y_vol_val, y_rank_val: validation split.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size for both loaders.
        lr: Base learning rate.
        loss_weights: Keyword arguments for ``SharperDistributionLoss``.

    Returns:
        ``(model, train_losses, val_losses)`` with per-epoch mean batch losses.
    """
    model = model.to(device)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    max_patience = 15

    # Default loss weights
    if loss_weights is None:
        loss_weights = {'alpha': 1.0, 'beta': 1.0, 'gamma': 1.0,
                        'uncertainty_penalty': 5.0, 'rank_weight': 1.0}

    # Create datasets
    train_dataset = EnhancedDataset(X_train, y_return_train, y_vol_train, y_rank_train)
    val_dataset = EnhancedDataset(X_val, y_return_val, y_vol_val, y_rank_val)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Per-module learning rates.  Every parameter not listed explicitly (input
    # projection, self-attention, state head) joins a base-lr group so that it
    # is actually optimised.
    param_groups = [
        {'params': list(model.feature_attention.parameters()), 'lr': lr * 2},
        {'params': list(model.gru.parameters()), 'lr': lr},
        {'params': list(model.return_head.parameters()), 'lr': lr},
        {'params': list(model.volatility_head.parameters()), 'lr': lr},
        {'params': list(model.rank_head.parameters()), 'lr': lr * 1.5},
    ]
    grouped = {id(p) for group in param_groups for p in group['params']}
    remaining = [p for p in model.parameters() if id(p) not in grouped]
    if remaining:
        param_groups.append({'params': remaining, 'lr': lr})
    optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-5)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    criterion = SharperDistributionLoss(**loss_weights)

    for epoch in range(epochs):
        # Training
        model.train()
        train_metrics = {'return_mse': 0, 'vol_mse': 0, 'nll': 0,
                         'uncertainty_penalty': 0, 'rank_loss': 0}
        epoch_train_loss = 0

        for features, returns, vols, ranks in train_loader:
            features = features.to(device)
            returns = returns.to(device)
            vols = vols.to(device)
            ranks = ranks.to(device)

            optimizer.zero_grad()

            pred_mean, pred_std, rank_logits, _, _ = model(features)
            loss, metrics = criterion(pred_mean, pred_std, rank_logits,
                                      returns, vols, ranks)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_train_loss += loss.item()
            for key in train_metrics:
                train_metrics[key] += metrics[key]

        # Validation
        model.eval()
        val_metrics = {'return_mse': 0, 'vol_mse': 0, 'nll': 0,
                       'uncertainty_penalty': 0, 'rank_loss': 0}
        epoch_val_loss = 0

        with torch.no_grad():
            for features, returns, vols, ranks in val_loader:
                features = features.to(device)
                returns = returns.to(device)
                vols = vols.to(device)
                ranks = ranks.to(device)

                pred_mean, pred_std, rank_logits, _, _ = model(features)
                loss, metrics = criterion(pred_mean, pred_std, rank_logits,
                                          returns, vols, ranks)

                epoch_val_loss += loss.item()
                for key in val_metrics:
                    val_metrics[key] += metrics[key]

        train_losses.append(epoch_train_loss / len(train_loader))
        val_losses.append(epoch_val_loss / len(val_loader))

        scheduler.step(val_losses[-1])

        # Early stopping
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            patience_counter = 0
            # state_dict() returns references to the live tensors; clone them so
            # later optimizer steps cannot overwrite the snapshot.
            best_model_state = {k: v.detach().clone()
                                for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print(f"Early stopping at epoch {epoch}")
                break

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss: {train_losses[-1]:.4f}, "
                  f"Val Loss: {val_losses[-1]:.4f}")
            print(f"  Return MSE: {val_metrics['return_mse']/len(val_loader):.4f}, "
                  f"Rank Loss: {val_metrics['rank_loss']/len(val_loader):.4f}")

    # Restore the best snapshot whether training stopped early or ran all epochs
    # (None only if validation loss never improved, e.g. it was always NaN).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, train_losses, val_losses

# ============================================================================
# OPTUNA OPTIMIZATION
# ============================================================================

def optimize_hyperparameters(X_train, y_return_train, y_vol_train, y_rank_train,
                            X_val, y_return_val, y_vol_val, y_rank_val,
                            n_trials=50):
    """Search model and loss hyperparameters with Optuna and return the best set.

    Each trial builds an ``EnhancedTrajectoryGRU`` and trains it for up to 30
    epochs with ``train_enhanced_model``, using trial-specific loss weights.

    Because the loss weights are themselves tuned, each trial's own training
    loss is on a different scale (the NLL term is often negative, so a larger
    ``gamma`` alone would look "better").  Every trial is therefore scored on
    the validation split with one fixed-weight ``SharperDistributionLoss()``,
    plus a penalty when the mean predicted std exceeds 0.025.

    Returns:
        ``study.best_params``: the suggested values of the best trial.  Loss
        weights are under ``alpha_return``, ``beta_vol``, ``gamma_nll``,
        ``uncertainty_penalty``, ``rank_weight`` and ``target_std``.
    """
    # Scoring-only criterion with default weights, shared by all trials
    reference_criterion = SharperDistributionLoss()

    def objective(trial):
        # Hyperparameters to optimize
        hidden_size = trial.suggest_int('hidden_size', 128, 512, step=64)
        state_size = trial.suggest_int('state_size', 64, 256, step=32)
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
        n_heads = trial.suggest_categorical('n_heads', [4, 8, 16])

        # Loss function weights
        alpha = trial.suggest_float('alpha_return', 0.5, 2.0)
        beta = trial.suggest_float('beta_vol', 0.1, 1.0)
        gamma = trial.suggest_float('gamma_nll', 0.1, 1.0)
        uncertainty_penalty = trial.suggest_float('uncertainty_penalty', 1.0, 10.0)
        rank_weight = trial.suggest_float('rank_weight', 0.5, 2.0)
        target_std = trial.suggest_float('target_std', 0.01, 0.03)

        loss_weights = {
            'alpha': alpha, 'beta': beta, 'gamma': gamma,
            'uncertainty_penalty': uncertainty_penalty,
            'rank_weight': rank_weight,
            'target_std': target_std
        }

        # Create and train model
        model = EnhancedTrajectoryGRU(
            input_features=X_train.shape[1],
            hidden_size=hidden_size,
            state_size=state_size,
            dropout=dropout,
            n_heads=n_heads
        )

        # Train for fewer epochs during optimization
        model, _, _ = train_enhanced_model(
            model, X_train, y_return_train, y_vol_train, y_rank_train,
            X_val, y_return_val, y_vol_val, y_rank_val,
            epochs=30, batch_size=batch_size, lr=lr,
            loss_weights=loss_weights
        )

        # Score on validation: fixed-weight loss and mean predicted std
        model.eval()
        val_dataset = EnhancedDataset(X_val, y_return_val, y_vol_val, y_rank_val)
        val_loader = DataLoader(val_dataset, batch_size=32)

        all_stds = []
        reference_loss = 0.0
        with torch.no_grad():
            for features, returns, vols, ranks in val_loader:
                features = features.to(device)
                returns = returns.to(device)
                vols = vols.to(device)
                ranks = ranks.to(device)

                pred_mean, pred_std, rank_logits, _, _ = model(features)
                loss, _ = reference_criterion(pred_mean, pred_std, rank_logits,
                                              returns, vols, ranks)
                reference_loss += loss.item()
                all_stds.append(pred_std.cpu().numpy())

        reference_loss /= len(val_loader)
        avg_std = np.concatenate(all_stds).mean()

        # Penalize models with too-wide confidence intervals
        penalty = max(0, (avg_std - 0.025) * 100)

        return reference_loss + penalty

    # Create study.  NOTE: the objective never calls trial.report(), so the
    # pruner has no intermediate values to act on.
    study = optuna.create_study(
        direction='minimize',
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
    )

    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

    print("\nBest hyperparameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

    return study.best_params

# ============================================================================
# TRAJECTORY GENERATION
# ============================================================================

@torch.no_grad()
def generate_sharp_trajectories(model, initial_features, n_trajectories=1000,
                               horizon=15, force_decision=True, batch_size=100):
    """Monte-Carlo sample return paths by rolling the model forward step by step.

    Every path starts from ``initial_features`` (tiled once per path). For each
    simulated day, the model yields a mean/std. A return is drawn from the
    corresponding normal distribution and written back into the feature vector
    used for the next day.

    Args:
        model: module whose ``forward(features, hidden_state)`` returns
            ``(pred_mean, pred_std, rank_logits, hidden_state, feature_weights)``.
        initial_features: starting feature row, tiled with ``repeat(n, 1)``, so a
            single ``(1, n_features)`` row is expected.
        n_trajectories: number of paths to sample; ``<= 0`` gives empty output.
        horizon: number of simulated days per path; ``<= 0`` gives empty output.
        force_decision: when True, shift the mean by the argmax rank
            (-0.02, -0.01, 0.0, +0.01, +0.02 for ranks 0, 1, 2, 3, >=4) and
            shrink the model std by 30% before sampling.
        batch_size: number of paths simulated in parallel per forward pass.

    Returns:
        ``(trajectories, ranks)``. ``trajectories`` has shape
        ``(n_trajectories, horizon)`` and holds the per-day sampled returns,
        not cumulative ones. ``ranks`` has the same shape and holds the argmax
        rank predicted at each step.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    device = next(model.parameters()).device

    n_paths = max(int(n_trajectories), 0)
    n_days = max(int(horizon), 0)
    if n_paths == 0 or n_days == 0:
        # np.vstack([]) would raise, and zero-length batches would collapse the
        # row count, so return correctly shaped empty arrays instead.
        return (np.empty((n_paths, n_days), dtype=np.float32),
                np.empty((n_paths, n_days), dtype=np.int64))

    # Mean shift per rank class. Ranks above 4 clamp to the last entry, which
    # reproduces the "anything else is a strong buy" fallthrough.
    rank_shift = torch.tensor([-0.02, -0.01, 0.0, 0.01, 0.02], device=device)
    max_rank = len(rank_shift) - 1

    traj_chunks = []
    rank_chunks = []

    for start in range(0, n_paths, batch_size):
        n_batch = min(batch_size, n_paths - start)

        features = initial_features.repeat(n_batch, 1).to(device)
        hidden_state = None
        # NOTE(review): assumes pred_mean / pred_std have shape (n_batch,).
        # A trailing singleton dim would broadcast against this vector. Confirm.
        cumulative_variance = torch.zeros(n_batch, device=device)
        day_returns = []
        day_ranks = []

        for day in range(n_days):
            pred_mean, pred_std, rank_logits, hidden_state, _ = model(features, hidden_state)

            # softmax is monotonic, so argmax over the logits selects the same class.
            predicted_ranks = torch.argmax(rank_logits, dim=-1)
            day_ranks.append(predicted_ranks.cpu().numpy())

            if force_decision:
                pred_mean = pred_mean + rank_shift[predicted_ranks.clamp(max=max_rank)]
                # Narrower std -> more decisive (sharper) sampled paths.
                pred_std = pred_std * 0.7

            # Blend in half of the uncertainty accumulated over previous days.
            total_std = torch.sqrt(pred_std ** 2 + cumulative_variance * 0.5)

            sampled_returns = torch.normal(pred_mean, total_std)
            day_returns.append(sampled_returns.cpu().numpy())

            # Decay old variance and add a fraction of today's (post-shrink) variance.
            cumulative_variance = cumulative_variance * 0.9 + pred_std ** 2 * 0.3

            # Feed the sampled return back in as the next step's input.
            # NOTE(review): assumes column 0 is the model's own lagged-return
            # feature and that an unscaled return is acceptable there (main()
            # StandardScaler-transforms features). Confirm.
            features = features.clone()
            features[:, 0] = sampled_returns
            noise_scale = 0.005 * np.sqrt(day + 1)
            features = features + torch.randn_like(features) * noise_scale

        # Lists of per-day (n_batch,) arrays -> (n_batch, n_days).
        traj_chunks.append(np.stack(day_returns, axis=1))
        rank_chunks.append(np.stack(day_ranks, axis=1))

    return np.concatenate(traj_chunks, axis=0), np.concatenate(rank_chunks, axis=0)

# ============================================================================
# ENHANCED VISUALIZATION
# ============================================================================

def plot_enhanced_results(trajectories, ranks, feature_importance=None):
    """Draw a 2x3 dashboard of simulated returns, rank predictions and feature weights.

    Args:
        trajectories: array ``(n_paths, horizon)`` of per-day sampled returns,
            as produced by ``generate_sharp_trajectories``. Values are not cumulative.
        ranks: int array of the same shape with the predicted rank (0..4) per
            step. Plot colours run from dark red (0) to dark green (4).
        feature_importance: optional per-feature weights. When given, the
            largest ones (up to 20) are shown as a horizontal bar chart.

    Returns:
        dict mapping ``'p5'``, ``'p25'``, ``'p50'``, ``'p75'``, ``'p95'`` to
        per-day percentiles of the raw per-day returns. The fan chart itself
        shows cumulative returns.
    """
    plt.figure(figsize=(16, 12))
    days = np.arange(1, trajectories.shape[1] + 1)
    levels = (5, 25, 50, 75, 95)

    # Per-day return percentiles: returned to the caller and used for the
    # CI-width panel below.
    percentiles = {f'p{q}': np.percentile(trajectories, q, axis=0) for q in levels}

    # The fan chart is labelled "Cumulative Return", so plot each path's running
    # sum. This is the same cumsum main() uses to turn paths into prices; the
    # old version plotted per-day returns under that label.
    cumulative = np.cumsum(trajectories, axis=1)
    cum_pct = {f'p{q}': np.percentile(cumulative, q, axis=0) for q in levels}

    # Trajectory distribution
    ax1 = plt.subplot(2, 3, 1)
    ax1.fill_between(days, cum_pct['p5'], cum_pct['p95'],
                     alpha=0.2, color='blue', label='90% CI')
    ax1.fill_between(days, cum_pct['p25'], cum_pct['p75'],
                     alpha=0.3, color='blue', label='50% CI')
    ax1.plot(days, cum_pct['p50'], 'b-', linewidth=2, label='Median')
    ax1.set_xlabel('Days Ahead')
    ax1.set_ylabel('Cumulative Return')
    ax1.set_title('Sharper Trajectory Predictions')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Rank distribution over time: share of paths predicting each rank per day
    ax2 = plt.subplot(2, 3, 2)
    rank_means = ranks.mean(axis=0)
    rank_colors = ['darkred', 'red', 'gray', 'green', 'darkgreen']

    for rank in range(5):
        rank_pct = (ranks == rank).mean(axis=0) * 100
        ax2.plot(days, rank_pct, label=f'Rank {rank}', color=rank_colors[rank])

    ax2.set_xlabel('Days Ahead')
    ax2.set_ylabel('Percentage')
    ax2.set_title('Rank Distribution Over Time')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Feature importance (if available)
    if feature_importance is not None:
        ax3 = plt.subplot(2, 3, 3)
        importance = np.ravel(np.asarray(feature_importance))
        # Never request more bars than there are features. The old
        # barh(range(20), ...) raised a shape mismatch below 20 features.
        n_top = min(20, importance.size)
        top_features = np.argsort(importance)[-n_top:] if n_top else np.array([], dtype=int)
        ax3.barh(range(n_top), importance[top_features])
        # Label bars with their column index so they can be traced back to X.
        ax3.set_yticks(range(n_top))
        ax3.set_yticklabels([f'feature {i}' for i in top_features])
        ax3.set_xlabel('Importance')
        ax3.set_title(f'Top {n_top} Feature Importance')
        ax3.grid(True, alpha=0.3)

    # Confidence interval width (per-day returns)
    ax4 = plt.subplot(2, 3, 4)
    ci_width = percentiles['p75'] - percentiles['p25']
    ax4.plot(days, ci_width, 'g-', linewidth=2)
    ax4.fill_between(days, 0, ci_width, alpha=0.3, color='green')
    ax4.set_xlabel('Days Ahead')
    ax4.set_ylabel('50% CI Width')
    ax4.set_title('Prediction Confidence')
    ax4.grid(True, alpha=0.3)

    # Return distribution at key horizons
    ax5 = plt.subplot(2, 3, 5)
    horizons = [1, 5, 10, 15]
    colors = ['blue', 'green', 'orange', 'red']

    for horizon, color in zip(horizons, colors):
        if horizon <= trajectories.shape[1]:
            ax5.hist(trajectories[:, horizon - 1], bins=30, alpha=0.3,
                     color=color, label=f'Day {horizon}', density=True)

    ax5.set_xlabel('Return')
    ax5.set_ylabel('Density')
    ax5.set_title('Return Distributions')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    # Decision signal strength: distance of the mean rank from neutral (rank 2)
    ax6 = plt.subplot(2, 3, 6)
    decision_strength = np.abs(rank_means - 2)
    ax6.plot(days, decision_strength, 'r-', linewidth=2)
    ax6.fill_between(days, 0, decision_strength, alpha=0.3, color='red')
    ax6.set_xlabel('Days Ahead')
    ax6.set_ylabel('Signal Strength')
    ax6.set_title('Decision Confidence')
    ax6.grid(True, alpha=0.3)

    plt.suptitle('Enhanced Model Results with Sharper Predictions', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    return percentiles

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main(use_optuna=True, *, current_price=36.50):
    """Run the end-to-end pipeline: load, split, scale, tune, train, simulate, report.

    Args:
        use_optuna: when True, run a 30-trial Optuna search for hyperparameters.
            Otherwise use the fixed defaults below.
        current_price: price that the simulated return paths are compounded
            from in the printed forecast. Defaults to the previously hard-coded
            36.50; pass the latest observed close for a real forecast.

    Returns:
        ``(model, trajectories, ranks, feature_importance)``.
    """
    print("=" * 60)
    print("ENHANCED TRAJECTORY MODEL WITH SHARPER PREDICTIONS")
    print("=" * 60)

    # Load data
    print("\n1. Loading and preparing enhanced data...")
    X, y_return, y_volatility, y_rank = prepare_distribution_data()
    print(f"   Features: {X.shape[1]}")
    print(f"   Samples: {len(X)}")

    # Sequential 80/10/10 split with no shuffling. This assumes X is
    # date-ordered, so validation and test rows come after training rows.
    n_samples = len(X)
    train_idx = int(0.8 * n_samples)
    val_idx = int(0.9 * n_samples)

    X_train = X.iloc[:train_idx]
    y_return_train = y_return.iloc[:train_idx]
    y_vol_train = y_volatility.iloc[:train_idx]
    y_rank_train = y_rank.iloc[:train_idx]

    X_val = X.iloc[train_idx:val_idx]
    y_return_val = y_return.iloc[train_idx:val_idx]
    y_vol_val = y_volatility.iloc[train_idx:val_idx]
    y_rank_val = y_rank.iloc[train_idx:val_idx]

    X_test = X.iloc[val_idx:]

    # Fit the scaler on training rows only, so no statistics leak from later periods.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    # Optuna optimization
    if use_optuna:
        print("\n2. Running Optuna hyperparameter optimization...")
        best_params = optimize_hyperparameters(
            X_train_scaled, y_return_train, y_vol_train, y_rank_train,
            X_val_scaled, y_return_val, y_vol_val, y_rank_val,
            n_trials=30
        )
    else:
        best_params = {
            'hidden_size': 256, 'state_size': 128, 'dropout': 0.3,
            'lr': 0.001, 'batch_size': 32, 'n_heads': 8,
            'alpha_return': 1.0, 'beta_vol': 0.5, 'gamma_nll': 0.5,
            'uncertainty_penalty': 5.0, 'rank_weight': 1.0, 'target_std': 0.02
        }

    # Train final model with best parameters
    print("\n3. Training final model with extended epochs...")
    model = EnhancedTrajectoryGRU(
        input_features=X_train_scaled.shape[1],
        hidden_size=best_params['hidden_size'],
        state_size=best_params['state_size'],
        dropout=best_params['dropout'],
        n_heads=best_params['n_heads']
    )

    loss_weights = {
        'alpha': best_params['alpha_return'],
        'beta': best_params['beta_vol'],
        'gamma': best_params['gamma_nll'],
        'uncertainty_penalty': best_params['uncertainty_penalty'],
        'rank_weight': best_params['rank_weight'],
        'target_std': best_params['target_std']
    }

    model, train_losses, val_losses = train_enhanced_model(
        model, X_train_scaled, y_return_train, y_vol_train, y_rank_train,
        X_val_scaled, y_return_val, y_vol_val, y_rank_val,
        epochs=100,  # Longer training than the 30-epoch Optuna trials
        batch_size=best_params['batch_size'],
        lr=best_params['lr'],
        loss_weights=loss_weights
    )

    # Generate sharp trajectories
    print("\n4. Generating sharp trajectory predictions...")
    # Simulate from the most recent feature row so the paths line up with
    # current_price. The first test row ([0:1]) sits ~10% of the history back.
    test_features = torch.FloatTensor(X_test_scaled[-1:])
    trajectories, ranks = generate_sharp_trajectories(
        model, test_features, n_trajectories=1000, horizon=15,
        force_decision=True
    )

    # Get feature importance. The input must live on the model's device;
    # a CPU tensor fails against a model trained on mps/cuda.
    model_device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        _, _, _, _, feature_weights = model(test_features.to(model_device))
        # NOTE(review): assumes feature_weights is (batch, n_features). Confirm.
        feature_importance = feature_weights[0].cpu().numpy()

    # Visualize results
    print("\n5. Plotting enhanced results...")
    plot_enhanced_results(trajectories, ranks, feature_importance)

    # Price forecast: compound each path from current_price.
    # NOTE(review): summing returns treats y_return as additive. If it is a
    # log return, np.exp(cumsum) - 1 would be exact; if it is a simple return,
    # np.cumprod(1 + r) - 1 would be. Confirm which target prepare_distribution_data builds.
    # NOTE(review): EQNR.OL is quoted in NOK; the '$' label and the 36.50
    # default look like the US listing. Confirm the intended currency.
    cumulative_returns = np.cumsum(trajectories, axis=1)
    median_prices = current_price * (1 + np.median(cumulative_returns, axis=0))
    p25_prices = current_price * (1 + np.percentile(cumulative_returns, 25, axis=0))
    p75_prices = current_price * (1 + np.percentile(cumulative_returns, 75, axis=0))

    print("\n" + "="*60)
    print("ENHANCED PRICE FORECAST (Sharper Predictions)")
    print(f"Current Price: ${current_price:.2f}")
    print("="*60)

    signal_names = ['STRONG SELL', 'SELL', 'HOLD', 'BUY', 'STRONG BUY']
    for day in [1, 5, 10, 15]:
        if day <= len(median_prices):
            rank_at_day = ranks[:, day-1].mean()
            # Round rather than truncate: int(2.9) reported HOLD for a near-BUY mean rank.
            signal_idx = int(np.clip(np.rint(rank_at_day), 0, len(signal_names) - 1))
            signal = signal_names[signal_idx]
            print(f"Day {day:2d}: ${median_prices[day-1]:.2f} "
                  f"[${p25_prices[day-1]:.2f} - ${p75_prices[day-1]:.2f}] "
                  f"Signal: {signal}")

    print("\n" + "="*60)
    print("Model complete with sharper, more decisive predictions!")
    print("="*60)

    return model, trajectories, ranks, feature_importance

if __name__ == "__main__":
    # Script entry point: run the whole pipeline with Optuna tuning enabled.
    # The four results stay bound as module globals for interactive inspection.
    (model, trajectories,
     ranks, feature_importance) = main(use_optuna=True)

Using device: mps
ENHANCED TRAJECTORY MODEL WITH SHARPER PREDICTIONS

1. Loading and preparing enhanced data...
   Found 36 assets: ['nikkei', 'vix', 'wti_crude', 'dax', 'nasdaq']...
   Created 1960 features dynamically


[I 2025-09-09 00:13:03,862] A new study created in memory with name: no-name-f2d6258f-f7f4-4118-965e-d5a914b52483


   Features: 1960
   Samples: 2642

2. Running Optuna hyperparameter optimization...


  0%|          | 0/30 [00:00<?, ?it/s]