# ReWTSE-LLM-RL Trading System Training

Training notebook for the hybrid LLM+RL trading system.

**System Architecture:**
- LLM Agent (Google Gemini): Generates strategic trading signals
- RL Agent (DDQN): Learns tactical execution policies
- ReWTSE Ensemble: Weighted ensemble of per-chunk RL agents, with weights set by quadratic programming (QP) optimization
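
The QP weighting idea can be sketched in a few lines. This is an illustration under stated assumptions, not the project's `ReWTSEnsembleController` code. It assumes each chunk model produces a forecast over the last `lookback_length` steps. It also assumes the weights minimize the squared error of the weighted forecast, subject to non-negative weights that sum to 1. The project lists `cvxopt` for the QP; this sketch swaps in projected gradient descent with NumPy instead. Both function names are hypothetical.

```python
import numpy as np

def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    u = np.sort(v)[::-1]
    cumsum = np.cumsum(u)
    ks = np.arange(1, len(v) + 1)
    # Last index where u_k - (cumsum_k - 1) / k is still positive
    rho = np.nonzero(u - (cumsum - 1) / ks > 0)[0][-1]
    theta = (cumsum[rho] - 1) / (rho + 1)
    return np.maximum(v - theta, 0.0)

def qp_ensemble_weights(preds, targets, n_iter=500):
    """Minimize ||preds @ w - targets||^2 over the probability simplex.

    preds:   (lookback_length, n_models) forecasts of each chunk model
    targets: (lookback_length,) realized values
    """
    n_models = preds.shape[1]
    w = np.full(n_models, 1.0 / n_models)  # start from uniform weights
    # Step size 1/L, with L = 2 * largest eigenvalue of preds^T preds
    lipschitz = 2 * np.linalg.eigvalsh(preds.T @ preds).max()
    step = 1.0 / max(lipschitz, 1e-12)
    for _ in range(n_iter):
        grad = 2 * preds.T @ (preds @ w - targets)
        w = project_to_simplex(w - step * grad)
    return w
```

Projected gradient is only a simple stand-in here. A dedicated QP solver such as `cvxopt` solves the same constrained least-squares problem directly.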

**Usage:**
- **Local**: Navigate to the project directory and run `jupyter notebook`, or open the notebook in VS Code
- **Google Colab**: Upload the notebook to Colab, mount Drive, and run all cells

The notebook automatically detects whether it is running on Colab or locally.

## 1. Environment Setup

In [11]:
# Detect environment (Colab or Local)
try:
    import google.colab
    IN_COLAB = True
    print("Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Check GPU availability
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

Running locally

PyTorch version: 2.9.0
CUDA available: False


In [12]:
# Mount Google Drive (only on Colab)
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted!")
else:
    print("Skipping Google Drive mount (not on Colab)")

Skipping Google Drive mount (not on Colab)


In [None]:
import os

# Setup project path based on environment
if IN_COLAB:
    print("Google Colab Setup:")
    print("Option 1 (default): cloning the repository from GitHub.")
    print("To use Option 2 or 3 instead, comment out the clone commands in this cell.\n")
    
    # Option 1: Clone from GitHub (recommended; active by default)
    !git clone https://github.com/mattiazingaretti/rewts_quant_trading.git
    %cd rewts_quant_trading
    
    # Option 2: Use files from Google Drive
    # Comment out Option 1, then uncomment and adjust the path to your project folder:
    # project_path = '/content/drive/MyDrive/rewts_quant_trading'
    # os.chdir(project_path)
    
    # Option 3: Upload the project files manually
    # Comment out Option 1 and use the folder icon in the left sidebar to upload files.
else:
    print("Local Setup:")
    print("Make sure you're in the project root directory\n")
    # Navigate to project root (assuming notebook is in notebooks/)
    if os.path.basename(os.getcwd()) == 'notebooks':
        os.chdir('..')
        print("Changed to project root directory")

print(f"Current working directory: {os.getcwd()}")
print(f"\nDirectory contents:")
!ls -la

In [14]:
# Install required packages
print("Checking and installing dependencies...\n")

if IN_COLAB:
    print("Installing all dependencies on Colab...")
    # Quote each requirement: an unquoted '>' is parsed by the shell as output redirection
    !pip install -q "google-generativeai>=0.3.0"
    !pip install -q "torch>=2.0.0" "torchvision>=0.15.0"
    !pip install -q "gym>=0.26.0"
    !pip install -q "stable-baselines3>=2.0.0"
    !pip install -q "pandas>=1.5.0" "numpy>=1.23.0"
    !pip install -q "matplotlib>=3.6.0" "seaborn>=0.12.0"
    !pip install -q "yfinance>=0.2.0"
    !pip install -q "cvxopt>=1.3.0"
    !pip install -q "tqdm>=4.64.0"
    !pip install -q "pyyaml>=6.0"
else:
    print("Local environment detected.")
    print("Using existing virtual environment or system packages.")
    print("\nIf you get import errors, install dependencies with:")
    print("  pip install -r requirements.txt")
    print("\nOr install individually:")
    print("  pip install google-generativeai torch gym stable-baselines3")
    print("  pip install pandas numpy matplotlib seaborn yfinance cvxopt tqdm pyyaml")

print("\n✓ Dependencies check complete!")

Checking and installing dependencies...

Local environment detected.
Using existing virtual environment or system packages.

If you get import errors, install dependencies with:
  pip install -r requirements.txt

Or install individually:
  pip install google-generativeai torch gym stable-baselines3
  pip install pandas numpy matplotlib seaborn yfinance cvxopt tqdm pyyaml

✓ Dependencies check complete!


## 2. Configuration

In [None]:
# API Configuration
import os
from getpass import getpass

# Set your Gemini API key
# Get your API key from: https://makersuite.google.com/app/apikey

# Check if already set in environment
if os.getenv('GEMINI_API_KEY'):
    print("Using GEMINI_API_KEY from environment variables")
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
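else:
    print("Enter your Gemini API Key (input will be hidden)")
    GEMINI_API_KEY = getpass('Gemini API Key: ')
    os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY

# Fail early instead of sending requests without a key
if not GEMINI_API_KEY:
    raise ValueError("No Gemini API key provided")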

# Google AI Studio project ID (paid tier)
# Replace with your own project ID if you have a paid plan; set to None to use the free tier
GOOGLE_PROJECT_ID = 'gen-lang-client-0177583380'

print("\n✓ API key configured!")
if GOOGLE_PROJECT_ID:
    print(f"✓ Using Google AI Studio project: {GOOGLE_PROJECT_ID}")
    print("  This will use your paid tier quota instead of free tier limits")

### Monitoring API Usage

With a paid-tier project configured, you can monitor API usage at:
- **Google AI Studio**: https://aistudio.google.com/app/apikey
- **Cloud Console**: https://console.cloud.google.com/apis/api/generativelanguage.googleapis.com/quotas?project=gen-lang-client-0177583380

The project ID `gen-lang-client-0177583380` is passed to the LLM agents through `config['llm']['project_id']`. The notebook also uses it to choose a higher client-side request rate in `precompute_llm_strategies`. Paid-tier quota applies only if the API key itself was created in that billing-enabled project.

In [None]:
# Training Configuration
config = {
    'tickers': ['AAPL'],  # Add more tickers as needed
    
    # LLM Configuration (Google Gemini)
    'llm': {
        'llm_model': 'gemini-2.0-flash-exp',  # Experimental model; substitute another available Gemini model name if needed
        'temperature': 0.0,  # Deterministic for reproducibility
        'seed': 49,
        'gemini_api_key': os.getenv('GEMINI_API_KEY'),
        'project_id': GOOGLE_PROJECT_ID  # Use paid tier project
    },
    
    # ReWTSE Ensemble Configuration
    'rewts': {
        'chunk_length': 500,  # Number of trading days per chunk
        'lookback_length': 100,  # Lookback period for weight optimization
        'forecast_horizon': 1,  # Steps ahead to forecast
        'episodes_per_chunk': 50,  # Training episodes per chunk (increase for better results)
        'gamma': 0.99,  # Discount factor
        'epsilon_start': 1.0,  # Initial exploration rate
        'epsilon_min': 0.01,  # Minimum exploration rate
        'epsilon_decay': 0.995,  # Exploration decay rate
        'learning_rate': 1e-3,  # Adam learning rate
        'batch_size': 64,  # Mini-batch size
        'buffer_size': 10000,  # Replay buffer size
        'target_update_freq': 10,  # Target network update frequency
        'hidden_dims': [128, 64]  # Neural network architecture
    },
    
    # Trading Environment Configuration
    'trading_env': {
        'initial_balance': 10000,  # Starting capital
        'transaction_cost': 0.001,  # 0.1% transaction cost
        'max_position': 1.0  # Maximum position size (1.0 = 100% of capital)
    },
    
    # Strategy generation frequency (days)
    'strategy_frequency': 20  # Generate new strategy every 20 trading days (~monthly)
}

import json

# Print everything except the 'llm' section, which holds the API key
print("Configuration:")
print(json.dumps({k: v for k, v in config.items() if k != 'llm'}, indent=2))
print(f"\nLLM Model: {config['llm']['llm_model']}")
print(f"LLM Project: {config['llm']['project_id']}")
print(f"Tickers: {config['tickers']}")
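
How fast exploration falls depends on whether `epsilon_decay` is applied once per episode or once per environment step. That choice is made inside `ReWTSEnsembleController` and is not visible in this notebook. A quick calculation makes the difference concrete. It assumes multiplicative decay, `epsilon = max(epsilon_min, epsilon * epsilon_decay)`, with the values from `config['rewts']`. Both helper names are hypothetical.

```python
import math

def epsilon_after(n_decays, start=1.0, decay=0.995, minimum=0.01):
    """Epsilon after n multiplicative decay steps, floored at `minimum`."""
    return max(minimum, start * decay ** n_decays)

def decays_to_minimum(start=1.0, decay=0.995, minimum=0.01):
    """Smallest n with start * decay**n <= minimum."""
    return math.ceil(math.log(minimum / start) / math.log(decay))

# Per-episode decay is an assumption; 50 = episodes_per_chunk
print(f"epsilon after 50 per-episode decays: {epsilon_after(50):.3f}")
print(f"decays needed to reach epsilon_min: {decays_to_minimum()}")
```

Under per-episode decay, epsilon is still about 0.78 at the end of a 50-episode chunk, so the agent would still be acting mostly at random. Under per-step decay, a 500-day chunk gives roughly 500 decays per episode. Reaching the floor takes 919 decays, so epsilon would hit it during the second episode.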

## 3. Data Preparation

In [17]:
# Check if data files exist
import os
import pandas as pd

for ticker in config['tickers']:
    market_data_path = f"data/processed/{ticker}_full_data.csv"
    news_data_path = f"data/processed/{ticker}_news.csv"
    
    print(f"\nChecking data for {ticker}:")
    
    if os.path.exists(market_data_path):
        df = pd.read_csv(market_data_path, index_col=0, parse_dates=True)
        print(f"  Market data: ✓ ({len(df)} rows, {df.index.min()} to {df.index.max()})")
    else:
        print(f"  Market data: ✗ NOT FOUND at {market_data_path}")
    
    if os.path.exists(news_data_path):
        news_df = pd.read_csv(news_data_path, index_col=0, parse_dates=True)
        print(f"  News data: ✓ ({len(news_df)} rows)")
    else:
        print(f"  News data: ✗ NOT FOUND at {news_data_path}")

# If any file is missing, the next cell downloads it with DataDownloader


Checking data for AAPL:
  Market data: ✓ (2215 rows, 2012-03-14 00:00:00-04:00 to 2020-12-30 00:00:00-05:00)
  News data: ✓ (114 rows)




In [None]:
# Download data if not present using the existing DataDownloader class
import sys
import os

# Add the project root to sys.path so scripts.download_data can be imported
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

from scripts.download_data import DataDownloader

# Check if data needs to be downloaded
data_missing = False
for ticker in config['tickers']:
    market_data_path = f"data/processed/{ticker}_full_data.csv"
    news_data_path = f"data/processed/{ticker}_news.csv"
    
    if not os.path.exists(market_data_path) or not os.path.exists(news_data_path):
        data_missing = True
        break

if data_missing:
    print(f"\n{'='*60}")
    print("Downloading missing data...")
    print(f"{'='*60}")
    
    # Configure the downloader with the training tickers and the 2012-2020 date range
    download_config = {
        'tickers': config['tickers'],
        'start_date': '2012-01-01',
        'end_date': '2020-12-31',
    }
    
    downloader = DataDownloader(download_config)
    datasets = downloader.prepare_full_dataset()
    
    print(f"\n{'='*60}")
    print("✓ Data download complete!")
    print(f"{'='*60}")
else:
    print(f"\n{'='*60}")
    print("✓ All data already exists")
    print(f"{'='*60}")

## 4. Import Project Modules

In [18]:
# Add the project root to sys.path so the src package can be imported
import sys
import os

project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

# Import project modules
from src.llm_agents.strategist_agent import StrategistAgent
from src.llm_agents.analyst_agent import AnalystAgent
from src.rl_agents.trading_env import TradingEnv
from src.hybrid_model.ensemble_controller import ReWTSEnsembleController
from src.utils.data_utils import load_market_data, load_news_data, filter_news_by_period

import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle

print("Modules imported successfully!")

Modules imported successfully!


## 5. Training Pipeline

### 5.1 Load Data

In [19]:
# Load data for the first ticker
ticker = config['tickers'][0]
print(f"Loading data for {ticker}...")

# Use utility functions for robust data loading
market_df = load_market_data(ticker)
news_df = load_news_data(ticker)

print(f"\nMarket data shape: {market_df.shape}")
print(f"News data shape: {news_df.shape}")
print(f"\nMarket data index type: {type(market_df.index)}")
print(f"News data index type: {type(news_df.index)}")
print(f"\nMarket data columns: {list(market_df.columns)}")
print(f"News data columns: {list(news_df.columns)}")
print(f"\nMarket date range: {market_df.index.min()} to {market_df.index.max()}")
print(f"News date range: {news_df.index.min()} to {news_df.index.max()}")

Loading data for AAPL...

Market data shape: (2215, 29)
News data shape: (114, 4)

Market data index type: <class 'pandas.core.indexes.base.Index'>
News data index type: <class 'pandas.core.indexes.datetimes.DatetimeIndex'>

Market data columns: ['Open', 'High', 'Low', 'Close', 'Volume', 'Dividends', 'Stock Splits', 'HV_Close', 'SPX_Close', 'VIX_Close', 'SMA_20', 'SMA_50', 'SMA_100', 'SMA_200', 'RSI', 'MACD', 'MACD_Signal', 'MACD_Hist', 'ATR', 'SMA_20_Slope', 'SMA_50_Slope', 'SMA_100_Slope', 'SMA_200_Slope', 'PE_Ratio', 'Debt_to_Equity', 'Current_Ratio', 'ROE', 'Gross_Margin', 'Operating_Margin']
News data columns: ['Unnamed: 0', 'headline', 'summary', 'source']

Market date range: 2012-03-14 00:00:00-04:00 to 2020-12-30 00:00:00-05:00
News date range: 2012-01-03 05:00:00 to 2020-12-24 05:00:00
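
The output above shows that the market index loaded as a plain `Index` rather than a `DatetimeIndex`. Its timestamps mix `-04:00` and `-05:00` UTC offsets because of daylight saving time, and pandas cannot hold mixed offsets in one datetime index unless it converts them to UTC. The news index, by contrast, is timezone-naive.

The sketch below puts the market index on the same footing as the news index before date comparisons. It assumes the market timestamps are US/Eastern session dates. The helper name is hypothetical and not part of `src.utils.data_utils`.

```python
import pandas as pd

def to_naive_eastern_dates(index):
    """Convert timestamps with mixed UTC offsets to tz-naive US/Eastern dates."""
    # utc=True lets pandas parse mixed offsets into one UTC DatetimeIndex
    idx = pd.to_datetime(index, utc=True)
    # Back to exchange-local time, then drop the timezone and the time of day
    return idx.tz_convert("America/New_York").tz_localize(None).normalize()

raw = pd.Index(["2012-03-14 00:00:00-04:00", "2020-12-30 00:00:00-05:00"])
print(to_naive_eastern_dates(raw))
```

It could be applied with `market_df.index = to_naive_eastern_dates(market_df.index)` before any slicing by date.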


### 5.2 Pre-compute LLM Strategies

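This step uses the Gemini LLM to generate one strategic signal per `strategy_frequency`-day period (20 trading days in the configuration above). Each period makes up to two API calls: one to the Analyst Agent for that period's news and one to the Strategist Agent. Runtime therefore grows with the number of periods and the configured request rate. Progress is checkpointed every 10 strategies, and with `resume=True` an interrupted run continues from the last checkpoint.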
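
The scheduling arithmetic in `precompute_llm_strategies` can be checked on its own. This sketch reproduces the period boundaries (`len(market_df) // strategy_frequency` full periods) and gives a rough runtime bound. It assumes at most two throttled API calls per period. The function is illustrative and not part of the project.

```python
def plan_strategy_periods(n_rows, frequency=20, requests_per_minute=60, calls_per_period=2):
    """Return (start, end) row ranges per strategy period and an upper-bound runtime in minutes."""
    n_periods = n_rows // frequency  # trailing rows that do not fill a period get no strategy
    periods = [(i * frequency, (i + 1) * frequency) for i in range(n_periods)]
    delay = 60.0 / requests_per_minute + 0.2  # same 0.2 s buffer as the notebook
    minutes = n_periods * calls_per_period * delay / 60
    return periods, minutes

periods, minutes = plan_strategy_periods(2215)
print(len(periods), periods[-1], f"{minutes:.1f} min")
```

With the 2215 AAPL rows shown above, this gives 110 periods, and the last 15 rows get no strategy. The estimated runtime is about 4.4 minutes at 60 requests per minute, or about 22.7 minutes at the free-tier setting of 10. Both figures ignore API latency and rate-limit waits.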

In [None]:
import time
import re

def precompute_llm_strategies(ticker, market_df, news_df, config, resume=True):
    """
    Pre-compute LLM strategies for the entire period with automatic retry and progress saving
    
    Args:
        ticker: Stock ticker symbol
        market_df: Market data DataFrame
        news_df: News data DataFrame
        config: Configuration dictionary
        resume: If True, resume from saved progress
    """
    
    print(f"\n{'='*60}")
    print(f"Pre-computing LLM Strategies for {ticker}")
    print(f"{'='*60}")
    
    # Setup checkpoint directory
    checkpoint_dir = 'data/llm_strategies/checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f"{checkpoint_dir}/{ticker}_strategies_checkpoint.pkl"
    
    # Try to resume from checkpoint
    strategies = []
    start_idx = 0
    
    if resume and os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, 'rb') as f:
                checkpoint_data = pickle.load(f)
                strategies = checkpoint_data['strategies']
                start_idx = checkpoint_data['next_idx']
            print(f"✓ Resuming from checkpoint: {len(strategies)} strategies already generated")
            print(f"  Starting from strategy #{start_idx}")
        except Exception as e:
            print(f"Warning: Could not load checkpoint: {e}")
            print("  Starting from scratch")
    
    # Initialize agents
    strategist = StrategistAgent(config['llm'])
    analyst = AnalystAgent(config['llm'])
    
    # Generate strategies at specified frequency
    strategy_frequency = config.get('strategy_frequency', 20)
    num_strategies = len(market_df) // strategy_frequency
    
    print(f"\nTotal data points: {len(market_df)}")
    print(f"Strategy frequency: every {strategy_frequency} days")
    print(f"Total strategies to generate: {num_strategies}")
    print(f"Remaining: {num_strategies - start_idx}")
    
    # API rate limiting parameters
    # Actual limits vary by model and tier; check your project's quota page
    project_id = config['llm'].get('project_id')
    if project_id:
        # Paid tier: limits are much higher than on the free tier
        requests_per_minute = 60  # Conservative limit for paid tier
        print(f"\n✓ Using paid tier project: {project_id}")
    else:
        # Free tier: assume about 10 requests/minute
        requests_per_minute = 10
        print("\n⚠ Using free tier (no project ID)")
    
    delay_between_requests = 60.0 / requests_per_minute + 0.2  # Add 0.2s buffer
    
    print(f"\nAPI Rate Limiting:")
    print(f"  Requests per minute: {requests_per_minute}")
    print(f"  Delay between requests: {delay_between_requests:.1f}s")
    # Each strategy needs up to 2 throttled API calls (Analyst + Strategist)
    print(f"  Estimated time: up to {(num_strategies - start_idx) * 2 * delay_between_requests / 60:.1f} minutes")
    
    for i in tqdm(range(start_idx, num_strategies), desc="Generating LLM strategies", initial=start_idx, total=num_strategies):
        strategy_start_idx = i * strategy_frequency
        strategy_end_idx = min((i + 1) * strategy_frequency, len(market_df))
        
        # Data for this strategy period
        period_data = market_df.iloc[strategy_start_idx:strategy_end_idx]
        
        # Filter news for this period using utility function
        period_news = filter_news_by_period(
            news_df,
            period_data.index[0],
            period_data.index[-1]
        )
        
        # Process news with Analyst Agent with retry logic
        max_retries = 3
        retry_count = 0
        news_signals = None
        
        while retry_count < max_retries and news_signals is None:
            try:
                if len(period_news) > 0:
                    news_signals = analyst.process_news(period_news.to_dict('records'))
                    # Throttle only after an actual API call
                    time.sleep(delay_between_requests)
                else:
                    # No news available for this period: neutral signal, no API call
                    news_signals = {
                        'sentiment': 'neutral',
                        'confidence': 0.5,
                        'key_topics': []
                    }
                
            except Exception as e:
                error_msg = str(e)
                
                # Check if it's a rate limit error
                if '429' in error_msg or 'quota' in error_msg.lower():
                    # Extract wait time from error message if available
                    wait_time = 15 if not project_id else 5  # Shorter wait for paid tier
                    match = re.search(r'retry in (\d+\.?\d*)s', error_msg)
                    if match:
                        wait_time = float(match.group(1)) + 1  # Add 1s buffer
                    
                    print(f"\n⚠ Rate limit reached. Waiting {wait_time:.0f}s...")
                    time.sleep(wait_time)
                    retry_count += 1
                else:
                    # Other error, use neutral sentiment
                    print(f"\n⚠ Error processing news: {error_msg}")
                    news_signals = {
                        'sentiment': 'neutral',
                        'confidence': 0.5,
                        'key_topics': []
                    }
                    break
        
        # If still failed, use neutral
        if news_signals is None:
            news_signals = {
                'sentiment': 'neutral',
                'confidence': 0.5,
                'key_topics': []
            }
        
        # Prepare input for Strategist
        market_data = {
            'timestamp': str(period_data.index[-1]),
            'Close': float(period_data['Close'].iloc[-1]),
            'Volume': float(period_data['Volume'].iloc[-1]),
            'Weekly_Returns': period_data['Close'].pct_change().dropna().tail(20).tolist(),  # daily close-to-close returns; leading NaN dropped
            'HV_Close': float(period_data.get('HV_Close', pd.Series([0])).iloc[-1]),
            'IV_Close': float(period_data.get('IV_Close', pd.Series([0])).iloc[-1]),
            'Beta': 1.0,
            'Classification': 'Growth'
        }
        
        fundamentals = {
            'current_ratio': float(period_data.get('Current_Ratio', pd.Series([1.5])).iloc[-1]),
            'debt_to_equity': float(period_data.get('Debt_to_Equity', pd.Series([0.5])).iloc[-1]),
            'pe_ratio': float(period_data.get('PE_Ratio', pd.Series([20])).iloc[-1]),
            'gross_margin': float(period_data.get('Gross_Margin', pd.Series([0.4])).iloc[-1]),
            'operating_margin': float(period_data.get('Operating_Margin', pd.Series([0.2])).iloc[-1]),
            'eps_yoy': 0.1,
            'net_income_yoy': 0.1
        }
        
        analytics = {
            'ma_20': float(period_data['SMA_20'].iloc[-1]),
            'ma_50': float(period_data['SMA_50'].iloc[-1]),
            'ma_200': float(period_data['SMA_200'].iloc[-1]),
            'ma_20_slope': float(period_data['SMA_20_Slope'].iloc[-1]),
            'ma_50_slope': float(period_data['SMA_50_Slope'].iloc[-1]),
            'rsi': float(period_data['RSI'].iloc[-1]),
            'macd': float(period_data['MACD'].iloc[-1]),
            'macd_signal': float(period_data['MACD_Signal'].iloc[-1]),
            'atr': float(period_data['ATR'].iloc[-1])
        }
        
        macro_data = {
            'SPX_Close': float(period_data.get('SPX_Close', pd.Series([0])).iloc[-1]),
            'SPX_Slope': float(period_data['SPX_Close'].diff().iloc[-1]) if 'SPX_Close' in period_data else 0.0,
            'VIX_Close': float(period_data.get('VIX_Close', pd.Series([0])).iloc[-1]),
            'VIX_Slope': float(period_data['VIX_Close'].diff().iloc[-1]) if 'VIX_Close' in period_data else 0.0,
            'GDP_QoQ': 0.0,
            'PMI': 50.0,
            'PPI_YoY': 0.0,
            'Treasury_YoY': 0.0
        }
        
        # Generate strategy with retry logic
        strategy = None
        retry_count = 0
        
        while retry_count < max_retries and strategy is None:
            try:
                last_strategy = strategies[-1] if strategies else None
                
                strategy = strategist.generate_strategy(
                    market_data=market_data,
                    fundamentals=fundamentals,
                    analytics=analytics,
                    macro_data=macro_data,
                    news_signals=news_signals,
                    last_strategy=last_strategy
                )
                
                # Add delay after successful API call
                time.sleep(delay_between_requests)
                
            except Exception as e:
                error_msg = str(e)
                
                # Check if it's a rate limit error
                if '429' in error_msg or 'quota' in error_msg.lower():
                    wait_time = 15 if not project_id else 5
                    match = re.search(r'retry in (\d+\.?\d*)s', error_msg)
                    if match:
                        wait_time = float(match.group(1)) + 1
                    
                    print(f"\n⚠ Rate limit reached. Waiting {wait_time:.0f}s...")
                    time.sleep(wait_time)
                    retry_count += 1
                else:
                    print(f"\n⚠ Error generating strategy: {error_msg}")
                    raise  # Re-raise non-rate-limit errors
        
        if strategy is None:
            raise RuntimeError(f"Failed to generate strategy #{i} after {max_retries} retries")
        
        strategies.append(strategy)
        
        # Save checkpoint every 10 strategies
        if (i + 1) % 10 == 0 or (i + 1) == num_strategies:
            checkpoint_data = {
                'strategies': strategies,
                'next_idx': i + 1,
                'ticker': ticker,
                'timestamp': time.time()
            }
            with open(checkpoint_path, 'wb') as f:
                pickle.dump(checkpoint_data, f)
    
    print(f"\n{'='*60}")
    print(f"✓ Generated {len(strategies)} strategies")
    print(f"{'='*60}")
    
    # Save final strategies
    os.makedirs('data/llm_strategies', exist_ok=True)
    final_path = f"data/llm_strategies/{ticker}_strategies.pkl"
    with open(final_path, 'wb') as f:
        pickle.dump(strategies, f)
    
    print(f"✓ Strategies saved to {final_path}")
    
    # Clean up checkpoint
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        print(f"✓ Checkpoint cleaned up")
    
    return strategies

# Generate strategies with auto-resume
strategies = precompute_llm_strategies(ticker, market_df, news_df, config, resume=True)

# Display sample strategies
print(f"\nSample strategies:")
for i, strategy in enumerate(strategies[:3]):
    print(f"\nStrategy {i+1}:")
    print(f"  Direction: {strategy.direction}")
    print(f"  Strength: {strategy.strength:.2f}")
    print(f"  Signal (τ): {(2*strategy.direction-1) * strategy.strength:.2f}")
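
The Analyst and Strategist calls above repeat the same rate-limit handling. One way to factor it into a single helper is sketched below. It mirrors the notebook's rules:

- messages containing `429` or `quota` count as rate limits;
- a `retry in Ns` hint is honored plus 1 s;
- any other error is re-raised.

This is an illustrative refactor, not code from the project. The `sleep` argument is injectable so the logic can be tested without waiting.

```python
import re
import time

def call_with_rate_limit_retry(fn, max_retries=3, default_wait=5.0, sleep=time.sleep):
    """Call fn(), retrying on rate-limit errors; re-raise any other error."""
    for _ in range(max_retries):
        try:
            return fn()
        except Exception as e:
            msg = str(e)
            if '429' not in msg and 'quota' not in msg.lower():
                raise  # not a rate limit: let the caller handle it
            match = re.search(r'retry in (\d+\.?\d*)s', msg)
            wait = float(match.group(1)) + 1 if match else default_wait
            sleep(wait)
    raise RuntimeError(f"Still rate-limited after {max_retries} attempts")
```

Usage inside the loop would look like `strategy = call_with_rate_limit_retry(lambda: strategist.generate_strategy(...))`, with the existing keyword arguments filled in.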

### 5.3 Train ReWTSE Ensemble

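This is the main training loop. It trains one DDQN agent per non-overlapping chunk of `chunk_length` trading days and adds each agent to the ReWTSE ensemble. With the 2215 AAPL rows and `chunk_length = 500`, this gives 4 chunks. The last 215 rows do not fill a chunk and are not used for training.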
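
For reference, the core of a Double DQN update is the target computation. The online network picks the next action and the target network evaluates it, which reduces the overestimation of plain DQN. Below is a NumPy sketch of that target for a mini-batch. It assumes the three actions used in the evaluation section (0 = SHORT, 1 = HOLD, 2 = LONG) and `gamma = 0.99` from `config['rewts']`. The agent's actual implementation lives in the project's RL code and is not shown here.

```python
import numpy as np

def ddqn_targets(rewards, dones, q_next_online, q_next_target, gamma=0.99):
    """Double DQN targets: r + gamma * (1 - done) * Q_target(s', argmax_a Q_online(s', a))."""
    best_actions = np.argmax(q_next_online, axis=1)   # action selection: online network
    rows = np.arange(len(best_actions))
    next_values = q_next_target[rows, best_actions]   # action evaluation: target network
    return rewards + gamma * (1.0 - dones) * next_values

rewards = np.array([1.0, -0.5])
dones = np.array([0.0, 1.0])
q_online = np.array([[0.1, 0.9, 0.3], [0.5, 0.2, 0.1]])
q_target = np.array([[0.4, 0.2, 0.8], [0.3, 0.6, 0.9]])
print(ddqn_targets(rewards, dones, q_online, q_target))
```

For the first transition the online network picks HOLD, so the target is 1 + 0.99 × 0.2 = 1.198. Plain DQN would take the target network's maximum (0.8) and give 1.792 instead. The second transition is terminal, so its target is just the reward.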

In [None]:
def train_rewts_ensemble(ticker, market_df, strategies, config):
    """Train ReWTSE ensemble of DDQN agents"""
    
    print(f"\n{'='*60}")
    print(f"Training ReWTSE Ensemble for {ticker}")
    print(f"{'='*60}")
    
    # Initialize ensemble controller
    ensemble = ReWTSEnsembleController(config['rewts'])
    
    # Determine number of chunks
    chunk_length = config['rewts']['chunk_length']
    num_chunks = len(market_df) // chunk_length
    
    print(f"Total data points: {len(market_df)}")
    print(f"Chunk length: {chunk_length}")
    print(f"Number of chunks: {num_chunks}")
    print(f"Episodes per chunk: {config['rewts']['episodes_per_chunk']}")
    
    # Train a DDQN for each chunk
    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_length
        end_idx = min((chunk_id + 1) * chunk_length, len(market_df))
        
        # Extract chunk data
        chunk_df = market_df.iloc[start_idx:end_idx].copy()
        
        # LLM strategies for this chunk (stays aligned only when chunk_length is a multiple of strategy_frequency)
        strategy_start_idx = start_idx // config['strategy_frequency']
        strategy_end_idx = end_idx // config['strategy_frequency']
        chunk_strategies = strategies[strategy_start_idx:strategy_end_idx]
        
        # Ensure we have strategies
        if len(chunk_strategies) == 0:
            print(f"Warning: No strategies for chunk {chunk_id}, skipping")
            continue
        
        print(f"\nChunk {chunk_id}: {len(chunk_df)} days, {len(chunk_strategies)} strategies")
        
        # Create environment for the chunk
        env = TradingEnv(chunk_df, chunk_strategies, config['trading_env'])
        
        # Train DDQN agent
        agent = ensemble.train_chunk_model(
            chunk_id=chunk_id,
            env=env,
            num_episodes=config['rewts']['episodes_per_chunk']
        )
        
        ensemble.chunk_models.append(agent)
    
    print(f"\n{'='*60}")
    print(f"✓ Ensemble training complete!")
    print(f"  Total chunk models: {len(ensemble.chunk_models)}")
    print(f"{'='*60}")
    
    return ensemble

# Train the ensemble
ensemble = train_rewts_ensemble(ticker, market_df, strategies, config)
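
The chunk-to-strategy slicing in `train_rewts_ensemble` stays aligned only because `chunk_length` (500) is a multiple of `strategy_frequency` (20). The check below is a standalone illustration, not project code. It lists each chunk's row range and strategy slice using the notebook's integer-division formulas and rejects misaligned settings.

```python
def chunk_layout(n_rows, chunk_length=500, strategy_frequency=20):
    """Row ranges and strategy index ranges per chunk, using the notebook's integer division."""
    if chunk_length % strategy_frequency != 0:
        raise ValueError("chunk boundaries would fall inside a strategy period")
    layout = []
    for chunk_id in range(n_rows // chunk_length):
        start, end = chunk_id * chunk_length, (chunk_id + 1) * chunk_length
        layout.append((start, end, start // strategy_frequency, end // strategy_frequency))
    return layout

for start, end, s0, s1 in chunk_layout(2215):
    print(f"rows {start}-{end}: strategies {s0}-{s1}")
```

With the AAPL data, strategies 100 to 109 (rows 2000 onward) are generated but never used during training.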

### 5.4 Save Trained Model

In [None]:
# Save the ensemble model
os.makedirs('models', exist_ok=True)
model_path = f"models/{ticker}_rewts_ensemble.pkl"

with open(model_path, 'wb') as f:
    pickle.dump(ensemble, f)

print(f"✓ Ensemble model saved to {model_path}")

# Save to Google Drive (only on Colab)
if IN_COLAB:
    try:
        drive_models_path = '/content/drive/MyDrive/Papers/models'
        os.makedirs(drive_models_path, exist_ok=True)
        
        drive_model_path = f"{drive_models_path}/{ticker}_rewts_ensemble.pkl"
        with open(drive_model_path, 'wb') as f:
            pickle.dump(ensemble, f)
        
        print(f"✓ Ensemble model also saved to Google Drive: {drive_model_path}")
    except Exception as e:
        print(f"Warning: Could not save to Drive: {e}")
else:
    print(f"✓ Model saved locally in the project directory")
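
Writing the pickle directly to `model_path` can leave a truncated file if the kernel dies mid-write, which is a real risk on Colab. A common pattern is to write to a temporary file in the same directory and then swap it in with `os.replace`. On POSIX systems that rename is atomic when source and destination are on the same filesystem. The sketch uses only the standard library, and the helper name is hypothetical.

```python
import os
import pickle
import tempfile

def atomic_pickle_dump(obj, path):
    """Pickle obj to path via a temp file + os.replace, so readers never see a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # flush bytes to disk before the swap
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)  # clean up the partial temp file
        raise
```

Usage would be `atomic_pickle_dump(ensemble, model_path)` in place of the `open(...)`/`pickle.dump` pair.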

## 6. Quick Evaluation

In [None]:
# Quick evaluation on the training data (in-sample, so not an estimate of out-of-sample performance)
print("Evaluating ensemble on training data...\n")

# Create evaluation environment
eval_env = TradingEnv(market_df, strategies, config['trading_env'])

# Initialize weights (uniform for now)
if len(ensemble.chunk_models) > 0:
    ensemble.current_weights = np.ones(len(ensemble.chunk_models)) / len(ensemble.chunk_models)

state = eval_env.reset()
done = False
total_reward = 0
actions_taken = []

while not done:
    # Get ensemble action
    action, _ = ensemble.predict_ensemble(state)
    actions_taken.append(action)
    
    # Execute action
    state, reward, done, _ = eval_env.step(action)
    total_reward += reward

# Results
final_portfolio_value = eval_env.portfolio_value
initial_balance = config['trading_env']['initial_balance']
total_return = (final_portfolio_value - initial_balance) / initial_balance * 100

print(f"{'='*60}")
print(f"Training Data Evaluation Results:")
print(f"{'='*60}")
print(f"Initial Balance: ${initial_balance:,.2f}")
print(f"Final Portfolio Value: ${final_portfolio_value:,.2f}")
print(f"Total Return: {total_return:.2f}%")
print(f"Total Reward: {total_reward:.4f}")
print(f"{'='*60}")

# Action distribution
action_names = ['SHORT', 'HOLD', 'LONG']
action_counts = np.bincount(actions_taken, minlength=3)
print(f"\nAction Distribution:")
for i, name in enumerate(action_names):
    pct = action_counts[i] / len(actions_taken) * 100
    print(f"  {name}: {action_counts[i]} ({pct:.1f}%)")
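
`predict_ensemble` is implemented in `ReWTSEnsembleController` and is not shown in this notebook. One common way for a weighted ensemble of Q-networks to choose an action is to average each model's Q-values by weight and act greedily on the result. The sketch below illustrates that assumption with NumPy; it should not be read as the controller's actual rule, and the function name is hypothetical.

```python
import numpy as np

ACTION_NAMES = ['SHORT', 'HOLD', 'LONG']  # same order as the action distribution above

def weighted_greedy_action(q_values_per_model, weights):
    """Greedy action on the weight-averaged Q-values.

    q_values_per_model: (n_models, n_actions) Q-values for the current state
    weights:            (n_models,) non-negative ensemble weights summing to 1
    """
    combined = np.asarray(weights) @ np.asarray(q_values_per_model)
    return int(np.argmax(combined)), combined

q = [[0.2, 0.1, 0.5],   # model 0 prefers LONG
     [0.6, 0.0, 0.1]]   # model 1 prefers SHORT
action, combined = weighted_greedy_action(q, np.ones(2) / 2)
print(ACTION_NAMES[action], combined)
```

With uniform weights the ensemble picks SHORT. Shifting the weights to `[0.8, 0.2]` flips the choice to LONG. An alternative is a weighted vote over each model's greedy action, and the two methods can disagree when the models' Q-values have different scales.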

## 7. Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

# Plot portfolio value over time
plt.figure(figsize=(14, 6))

plt.subplot(1, 2, 1)
plt.plot(eval_env.portfolio_history, label='Portfolio Value', linewidth=2)
plt.axhline(y=initial_balance, color='r', linestyle='--', label='Initial Balance')
plt.title(f'{ticker} - Portfolio Value Over Time', fontsize=14, fontweight='bold')
plt.xlabel('Trading Steps')
plt.ylabel('Portfolio Value ($)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
returns = np.diff(eval_env.portfolio_history) / eval_env.portfolio_history[:-1]
plt.hist(returns, bins=50, alpha=0.7, color='blue', edgecolor='black')
plt.axvline(x=0, color='r', linestyle='--', linewidth=2)
plt.title('Return Distribution', fontsize=14, fontweight='bold')
plt.xlabel('Return')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

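# Performance metrics (assumes one step per trading day and a zero risk-free rate)
sharpe_ratio = np.mean(returns) / np.std(returns) * np.sqrt(252) if np.std(returns) > 0 else 0

# Max drawdown: largest peak-to-trough decline in portfolio value
portfolio_values = np.asarray(eval_env.portfolio_history, dtype=float)
running_peak = np.maximum.accumulate(portfolio_values)
max_drawdown = np.min(portfolio_values / running_peak - 1) if len(portfolio_values) > 0 else 0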

print(f"\nPerformance Metrics:")
print(f"  Sharpe Ratio (annualized): {sharpe_ratio:.2f}")
print(f"  Max Drawdown: {max_drawdown*100:.2f}%")
print(f"  Volatility (annualized): {np.std(returns)*np.sqrt(252)*100:.2f}%")
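
The metrics above can be packaged into a small, testable helper. Like the cell, it assumes one step per trading day (252 per year) and a zero risk-free rate. Max drawdown is measured as the largest peak-to-trough decline of the portfolio value. The helper name is illustrative, and it needs at least two portfolio values.

```python
import numpy as np

def performance_metrics(portfolio_values, periods_per_year=252):
    """Annualized Sharpe (zero risk-free rate), annualized volatility, and max drawdown."""
    values = np.asarray(portfolio_values, dtype=float)
    returns = np.diff(values) / values[:-1]
    std = returns.std()
    sharpe = returns.mean() / std * np.sqrt(periods_per_year) if std > 0 else 0.0
    running_peak = np.maximum.accumulate(values)
    max_drawdown = (values / running_peak - 1).min()  # -0.2 means a 20% decline from the peak
    return {
        'sharpe_ratio': float(sharpe),
        'volatility': float(std * np.sqrt(periods_per_year)),
        'max_drawdown': float(max_drawdown),
    }

print(performance_metrics([100, 110, 99, 105, 88, 120]))
```

In this toy series the worst single-step return is about -16.2% (105 to 88). The drawdown from the 110 peak to 88, however, is -20%, which is why the two measures should not be confused.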

## 8. Next Steps

After training, you can:

1. **Backtest**: Run the backtesting script to evaluate on test data
2. **Add more tickers**: Modify `config['tickers']` to train on multiple stocks
3. **Tune hyperparameters**: Adjust learning rate, episodes, chunk length, etc.
4. **Paper trading**: Use the Alpaca API to test on live data without real money
5. **Analyze ensemble weights**: Examine how the QP optimization weights different chunks

**Important Notes:**
- Training may take several hours depending on data size and configuration
- For production use, train on more episodes and larger datasets
- Always validate on out-of-sample test data before real trading
- Consider implementing proper train/validation/test splits
- Monitor API usage and costs for Gemini calls
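
For the train/validation/test split suggested above, time-series data should be split chronologically rather than shuffled, so that no future information leaks into training. Below is a minimal sketch with hypothetical 70/15/15 fractions; the function name is illustrative.

```python
def chronological_split(n_rows, train_frac=0.70, val_frac=0.15):
    """Return (train, val, test) row-index ranges in time order; the test set takes the remainder."""
    train_end = int(n_rows * train_frac)
    val_end = train_end + int(n_rows * val_frac)
    return range(0, train_end), range(train_end, val_end), range(val_end, n_rows)

train, val, test = chronological_split(2215)
print(len(train), len(val), len(test))
```

For the AAPL data this gives 1550, 332 and 333 rows. The training slice would then be `market_df.iloc[train.start:train.stop]`, and with `chunk_length = 500` it holds 3 full chunks. The LLM strategies would need to be sliced at matching `strategy_frequency` boundaries.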

In [None]:
# Save training checkpoint with metadata
import datetime

checkpoint = {
    'ensemble': ensemble,
    'strategies': strategies,
    'config': config,
    'ticker': ticker,
    'training_date': datetime.datetime.now().isoformat(),
    'data_period': {
        'start': str(market_df.index.min()),
        'end': str(market_df.index.max()),
        'num_days': len(market_df)
    },
    'performance': {
        'final_value': final_portfolio_value,
        'total_return': total_return,
        'sharpe_ratio': sharpe_ratio
    }
}

checkpoint_path = f"models/{ticker}_checkpoint_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl"
with open(checkpoint_path, 'wb') as f:
    pickle.dump(checkpoint, f)

print(f"✓ Training checkpoint saved to {checkpoint_path}")