In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from datetime import datetime as dt
import datetime as datetime
from collections import defaultdict
import os
import time
import traceback
import gc
from itertools import product
import sys
from scipy import sparse
import pickle
import uuid
import json

In [None]:
import uuid
import numpy as np

class Trade:
    """
    Class to represent a pair trade with all relevant information.

    Tracks entry and exit information, calculates PnL, and handles financing costs.
    A trade has two legs, "black" and "white"; ``side`` states which leg is
    short and which is long.
    """
    # Valid trading sides
    VALID_SIDES = ['short_black_long_white', 'long_black_short_white']

    # Financing cost parameters
    DEFAULT_SHORT_SPREAD = 0.01  # 100 bps over Fed Funds
    DEFAULT_LONG_SPREAD = 0.015  # 150 bps over Fed Funds
    DAYS_PER_YEAR = 365

    # Commission charged per share when the trade is closed
    EXIT_COST_PER_SHARE = 0.01

    def __init__(self, entry_date, permno_black, permno_white, side, z_diff_entry,
                 investment_black, investment_white, shares_black, shares_white,
                 entry_price_black, entry_price_white, entry_transaction_cost,
                 zscore_method, horizon, lookback,
                 short_spread=DEFAULT_SHORT_SPREAD,
                 long_spread=DEFAULT_LONG_SPREAD):
        """Initialize a new open trade.

        Args:
            entry_date: Date on which the position is opened.
            permno_black, permno_white: Identifiers of the two legs.
            side: One of ``VALID_SIDES``.
            z_diff_entry: Z-score difference (black minus white) at entry.
            investment_black, investment_white: Positive amounts invested per leg.
            shares_black, shares_white: Positive integer share counts
                (Python or numpy integers).
            entry_price_black, entry_price_white: Positive entry prices.
            entry_transaction_cost: Non-negative cost paid at entry.
            zscore_method, horizon, lookback: Parameters of the signal that
                triggered the trade; stored for later reference.
            short_spread, long_spread: Annual spreads added to the Fed Funds
                rate for the short and long leg respectively.

        Raises:
            ValueError: If any input fails validation.
        """
        # Validate inputs
        self._validate_inputs(entry_date, side, investment_black, investment_white,
                              shares_black, shares_white, entry_price_black,
                              entry_price_white, entry_transaction_cost)

        # Generate unique trade ID
        self.trade_id = str(uuid.uuid4())

        # Entry information
        self.entry_date = entry_date
        self.permno_black = permno_black
        self.permno_white = permno_white
        self.side = side
        self.z_diff_entry = z_diff_entry
        self.investment_black = investment_black
        self.investment_white = investment_white
        # Normalise numpy integers to plain ints so the trade serialises cleanly
        self.shares_black = int(shares_black)
        self.shares_white = int(shares_white)
        self.entry_price_black = entry_price_black
        self.entry_price_white = entry_price_white
        self.entry_transaction_cost = entry_transaction_cost

        # Parameters used for this trade
        self.zscore_method = zscore_method
        self.horizon = horizon
        self.lookback = lookback

        # Financing parameters
        self.short_spread = short_spread
        self.long_spread = long_spread

        # Status tracking
        self.status = 'open'
        self.exit_date = None
        self.days_held = 0

        # Exit information (filled in by close_trade)
        self.exit_price_black = None
        self.exit_price_white = None
        self.exit_transaction_cost = None
        self.financing_cost = 0
        self.z_diff_exit = None
        self.exit_reason = None
        self.gross_pnl = None
        self.net_pnl = None
        self.roi = None

        # Risk metrics
        self.max_drawdown = 0.0
        self.peak_value = self.investment_black + self.investment_white
        self.current_value = self.peak_value

        # Track daily values for analysis
        self.daily_values = {entry_date: self.peak_value}
        self.daily_pnl = {}

    def _validate_inputs(self, entry_date, side, investment_black, investment_white,
                         shares_black, shares_white, price_black, price_white,
                         transaction_cost):
        """Validate inputs to ensure they make sense.

        The numeric checks are written as ``not (x > 0)`` so that NaN values,
        for which every comparison is False, are rejected as well.

        Raises:
            ValueError: On an unknown side, non-positive (or NaN) investments
                or prices, non-integer or non-positive share counts, or a
                negative (or NaN) transaction cost.
        """
        # Check side is valid
        if side not in self.VALID_SIDES:
            raise ValueError(f"side must be one of {self.VALID_SIDES}, got {side}")

        # Check investments are positive
        if not (investment_black > 0 and investment_white > 0):
            raise ValueError("Investment amounts must be positive")

        # Check shares are positive integers (numpy integer types are accepted)
        for label, shares in (('shares_black', shares_black), ('shares_white', shares_white)):
            if not isinstance(shares, (int, np.integer)) or shares <= 0:
                raise ValueError(f"{label} must be a positive integer, got {shares}")

        # Check prices are positive
        if not (price_black > 0 and price_white > 0):
            raise ValueError("Prices must be positive")

        # Check transaction cost is non-negative
        if not (transaction_cost >= 0):
            raise ValueError("Transaction cost cannot be negative")

    def _leg_signs(self):
        """Return ``(sign_black, sign_white)``: +1 for a long leg, -1 for a short leg."""
        if self.side == 'short_black_long_white':
            return -1, 1
        return 1, -1

    def update_daily_financing(self, current_date, fed_funds_rate):
        """Accrue one day of financing and increment the holding-day counter.

        The short leg is booked as a credit at ``fed_funds_rate + short_spread``
        and the long leg as a debit at ``fed_funds_rate + long_spread``, both
        divided by ``DAYS_PER_YEAR``.

        Args:
            current_date: Date of the accrual (not used in the calculation).
            fed_funds_rate: Annualised Fed Funds rate for the day.

        Returns:
            The day's net financing amount; it is added to ``financing_cost``,
            which ``close_trade`` subtracts from gross PnL.
        """
        if self.side == 'short_black_long_white':
            # Short black (credit at short rate), Long white (debit at long rate)
            black_daily_cost = -self.investment_black * (fed_funds_rate + self.short_spread) / self.DAYS_PER_YEAR
            white_daily_cost = self.investment_white * (fed_funds_rate + self.long_spread) / self.DAYS_PER_YEAR
        else:
            # Long black (debit at long rate), Short white (credit at short rate)
            black_daily_cost = self.investment_black * (fed_funds_rate + self.long_spread) / self.DAYS_PER_YEAR
            white_daily_cost = -self.investment_white * (fed_funds_rate + self.short_spread) / self.DAYS_PER_YEAR

        daily_financing = black_daily_cost + white_daily_cost
        self.financing_cost += daily_financing

        # Increment days held
        self.days_held += 1

        return daily_financing

    def update_market_value(self, current_date, current_price_black, current_price_white):
        """Update the market value of the position and track drawdowns.

        Each leg is valued as its initial investment plus the signed price move
        times the share count. Peak value, maximum drawdown and the per-date
        value / unrealized-PnL records are updated as a side effect.

        Returns:
            The combined current value of both legs.
        """
        sign_black, sign_white = self._leg_signs()
        black_value = self.investment_black + sign_black * (
            self.shares_black * (current_price_black - self.entry_price_black))
        white_value = self.investment_white + sign_white * (
            self.shares_white * (current_price_white - self.entry_price_white))

        # Calculate current total value and unrealized PnL
        current_value = black_value + white_value
        unrealized_pnl = current_value - (self.investment_black + self.investment_white)

        # Update peak value if current value is higher
        if current_value > self.peak_value:
            self.peak_value = current_value

        # Update max drawdown if current drawdown is larger
        current_drawdown = (self.peak_value - current_value) / self.peak_value
        if current_drawdown > self.max_drawdown:
            self.max_drawdown = current_drawdown

        # Store daily values
        self.daily_values[current_date] = current_value
        self.daily_pnl[current_date] = unrealized_pnl

        # Update current value
        self.current_value = current_value

        return current_value

    def close_trade(self, exit_date, exit_price_black, exit_price_white, exit_reason,
                    z_diff_exit, exit_cost_per_share=None):
        """Close the trade and calculate PnL.

        Args:
            exit_date: Date on which the position is closed.
            exit_price_black, exit_price_white: Exit prices of the two legs.
            exit_reason: Free-form label describing why the trade was closed.
            z_diff_exit: Z-score difference at exit.
            exit_cost_per_share: Commission per share at exit; defaults to
                ``EXIT_COST_PER_SHARE`` when omitted.

        Returns:
            Net PnL: gross PnL minus entry cost, exit cost and accrued financing.
        """
        if exit_cost_per_share is None:
            exit_cost_per_share = self.EXIT_COST_PER_SHARE

        self.exit_date = exit_date
        self.exit_price_black = exit_price_black
        self.exit_price_white = exit_price_white
        self.exit_reason = exit_reason
        self.z_diff_exit = z_diff_exit

        # Calculate transaction costs for exit
        self.exit_transaction_cost = exit_cost_per_share * (self.shares_black + self.shares_white)

        # A long leg profits when its price rises, a short leg when it falls
        sign_black, sign_white = self._leg_signs()
        black_pnl = sign_black * (self.shares_black * (exit_price_black - self.entry_price_black))
        white_pnl = sign_white * (self.shares_white * (exit_price_white - self.entry_price_white))

        # Calculate gross and net PnL
        self.gross_pnl = black_pnl + white_pnl

        # Total costs include entry and exit transaction costs plus financing
        total_costs = self.entry_transaction_cost + self.exit_transaction_cost + self.financing_cost
        self.net_pnl = self.gross_pnl - total_costs

        # Update status
        self.status = 'closed'

        # Final update to daily values
        invested = self.investment_black + self.investment_white
        self.daily_values[exit_date] = invested + self.net_pnl

        # Calculate ROI
        self.roi = self.net_pnl / invested

        return self.net_pnl

    def to_dict(self):
        """Convert trade object to dictionary for logging and analysis."""
        return {
            'trade_id': self.trade_id,
            'entry_date': self.entry_date,
            'exit_date': self.exit_date,
            'permno_black': self.permno_black,
            'permno_white': self.permno_white,
            'side': self.side,
            'z_diff_entry': self.z_diff_entry,
            'z_diff_exit': self.z_diff_exit,
            'investment_black': self.investment_black,
            'investment_white': self.investment_white,
            'shares_black': self.shares_black,
            'shares_white': self.shares_white,
            'entry_price_black': self.entry_price_black,
            'entry_price_white': self.entry_price_white,
            'exit_price_black': self.exit_price_black,
            'exit_price_white': self.exit_price_white,
            'entry_transaction_cost': self.entry_transaction_cost,
            'exit_transaction_cost': self.exit_transaction_cost,
            'financing_cost': self.financing_cost,
            'gross_pnl': self.gross_pnl,
            'net_pnl': self.net_pnl,
            # getattr keeps objects created before ``roi`` was set in __init__ usable
            'roi': getattr(self, 'roi', None),
            'exit_reason': self.exit_reason,
            'status': self.status,
            'days_held': self.days_held,
            'max_drawdown': self.max_drawdown,
            'zscore_method': self.zscore_method,
            'horizon': self.horizon,
            'lookback': self.lookback
        }

In [None]:
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

class SignalGenerator:
    """
    Generate pair-trading entry signals from z-score columns of ``df_main``.

    For every pair in ``df_pairs`` and every date, the z-score difference
    (black minus white) is compared with ``zscore_threshold``: a difference at
    or above the threshold yields ``'short_black_long_white'``, a difference at
    or below minus the threshold yields ``'long_black_short_white'``.
    """

    def __init__(self, df_main, df_pairs, zscore_method='ou', zscore_threshold=1.5, horizon=5, lookback_period=20):
        """Store the configuration and print diagnostics about the z-score column.

        Args:
            df_main: DataFrame holding the z-score column named
                ``z_{zscore_method}_{horizon}d_lb{lookback_period}``.
            df_pairs: DataFrame of pairs (``group_id``, ``permno_black``,
                ``permno_white`` are read by the other methods).
            zscore_method, horizon, lookback_period: Components of the z-score
                column name.
            zscore_threshold: Absolute z-score difference that triggers a signal.
        """
        self.df_main = df_main
        self.df_pairs = df_pairs
        self.zscore_method = zscore_method
        self.zscore_threshold = zscore_threshold
        self.lookback_period = lookback_period
        self.precomputed_signals = None
        self.horizon = horizon

        # Add diagnostic print
        print(f"SignalGenerator initialized with {len(df_pairs)} pairs")
        print(f"Z-score method: {zscore_method}, threshold: {zscore_threshold}, lookback: {lookback_period}")

        # Check sample z-score values
        z_col = f'z_{zscore_method}_{self.horizon}d_lb{lookback_period}'
        if z_col in df_main.columns:
            z_values = df_main[z_col].dropna()
            n_total = len(df_main)
            n_valid = len(z_values)
            # Guard the percentages so empty inputs do not raise ZeroDivisionError
            pct_valid = n_valid / n_total * 100 if n_total else 0.0
            print(f"Z-score column '{z_col}' stats:")
            print(f"  - Non-null values: {n_valid} out of {n_total} ({pct_valid:.2f}%)")
            if n_valid:
                n_exceeding = int((abs(z_values) >= zscore_threshold).sum())
                print(f"  - Range: {z_values.min():.4f} to {z_values.max():.4f}")
                print(f"  - Values exceeding threshold {zscore_threshold}: {n_exceeding} ({n_exceeding/n_valid*100:.2f}%)")
            else:
                print("  - No valid z-score values available")
        else:
            print(f"WARNING: Z-score column '{z_col}' not found in data!")

    def precompute_signals_parallel(self, horizon=5, n_jobs=4):
        """Precompute signals for all dates and pairs in parallel.

        The result is stored in ``self.precomputed_signals`` (an empty
        DataFrame when nothing was generated). If the z-score column is
        missing, an error is printed and the attribute is left untouched.

        Args:
            horizon: Horizon used to build the z-score column name.
            n_jobs: Number of joblib workers; also used to size the group chunks.
        """
        z_col = f'z_{self.zscore_method}_{horizon}d_lb{self.lookback_period}'
        print(f"Precomputing signals for z-score column: {z_col}")

        # Verify z-score column exists
        if z_col not in self.df_main.columns:
            print(f"ERROR: Z-score column '{z_col}' not found in data columns!")
            print(f"Available columns: {self.df_main.columns}")
            return

        # Group by group_id for efficient processing
        group_ids = self.df_pairs['group_id'].unique()
        print(f"Processing {len(group_ids)} unique group_ids")

        chunk_size = max(1, len(group_ids) // n_jobs)
        chunked_groups = [group_ids[i:i + chunk_size] for i in range(0, len(group_ids), chunk_size)]

        # Precompute group dictionaries with a single groupby pass instead of
        # one boolean filter over df_main per group
        print("Building group dictionaries...")
        required = ['permno', z_col, 'date']
        group_df_main_dict = {}
        if all(col in self.df_main.columns for col in required):
            relevant = self.df_main[self.df_main['group_id'].isin(group_ids)]
            valid_rows = relevant[['group_id'] + required].dropna(subset=required)
            for group_id, frame in valid_rows.groupby('group_id', sort=False):
                group_df_main_dict[group_id] = frame[required]
            empty_frame = valid_rows.iloc[0:0][required]
            for group_id in group_ids:
                filtered_data = group_df_main_dict.setdefault(group_id, empty_frame)
                if 0 < len(filtered_data) < 10:
                    print(f"  Group {group_id}: Only {len(filtered_data)} records with valid z-scores")
        else:
            for group_id in group_ids:
                print(f"  WARNING: Missing required columns for group {group_id}")

        print(f"Created dictionaries for {len(group_df_main_dict)} groups")

        # Split the pairs once per group rather than filtering inside the loop
        pairs_by_group = {gid: frame for gid, frame in self.df_pairs.groupby('group_id', sort=False)}
        no_pairs = self.df_pairs.iloc[0:0]

        # Process chunks in parallel
        all_results = []

        for chunk_idx, chunk in enumerate(chunked_groups):
            # Process each group in parallel within the chunk
            parallel_results = Parallel(n_jobs=n_jobs, prefer="threads")(
                delayed(self._process_group_signal)(
                    group_id,
                    group_df_main_dict,
                    pairs_by_group.get(group_id, no_pairs),
                    z_col,
                    self.zscore_threshold,
                    horizon
                )
                for group_id in chunk
            )

            # Keep non-empty results and count the signals of this chunk
            chunk_signals = 0
            for df in parallel_results:
                if not df.empty:
                    all_results.append(df)
                    chunk_signals += len(df)

            print(f"Chunk {chunk_idx+1}: Generated {chunk_signals} signals")

        # Concatenate results
        if all_results:
            self.precomputed_signals = pd.concat(all_results).reset_index(drop=True)
            print(f"Total signals generated: {len(self.precomputed_signals)}")
            print(f"Signals span {self.precomputed_signals['date'].nunique()} unique trading days")

            # Distribution of signals
            if len(self.precomputed_signals) > 0:
                signal_distribution = self.precomputed_signals['signal'].value_counts()
                print(f"Signal distribution: {dict(signal_distribution)}")
        else:
            print("WARNING: No signals were generated!")
            self.precomputed_signals = pd.DataFrame()

    def _process_group_signal(self, group_id, group_df_main_dict, df_pairs_group, z_col, zscore_threshold, horizon):
        """Process signals for a specific group (used for parallel processing).

        Z-scores are joined to the pairs on (date, permno), so each date's
        signal is based on that date's own z-scores for both legs.

        Args:
            group_id: Key into ``group_df_main_dict``.
            group_df_main_dict: Mapping of group id to a DataFrame with the
                columns ``permno``, ``date`` and ``z_col``.
            df_pairs_group: Pairs of this group (``permno_black``, ``permno_white``).
            z_col: Name of the z-score column.
            zscore_threshold: Absolute z-score difference that triggers a signal.
            horizon: Recorded in the ``horizon`` column of the output.

        Returns:
            DataFrame with one row per (pair, date) signal, or an empty
            DataFrame when there is nothing to report.
        """
        if group_id not in group_df_main_dict:
            return pd.DataFrame()

        group_df_main = group_df_main_dict[group_id]

        if group_df_main.empty or df_pairs_group.empty:
            return pd.DataFrame()

        # One z-score per (date, permno); the last row wins on duplicates
        z_table = (
            group_df_main[['date', 'permno', z_col]]
            .dropna()
            .drop_duplicates(subset=['date', 'permno'], keep='last')
        )
        z_black = z_table.rename(columns={'permno': 'permno_black', z_col: 'z_black'})
        z_white = z_table.rename(columns={'permno': 'permno_white', z_col: 'z_white'})

        # Columns produced below replace any same-named columns of the pairs table
        generated_cols = ['date', 'z_black', 'z_white', 'z_diff', 'signal',
                          'zscore_method', 'horizon', 'lookback']
        pairs = df_pairs_group.drop(
            columns=[col for col in generated_cols if col in df_pairs_group.columns]
        )

        # Process in chunks for memory efficiency
        chunk_size = 1000
        results = []

        for start in range(0, len(pairs), chunk_size):
            pair_chunk = pairs.iloc[start:start + chunk_size]

            # Attach the black z-score for every date, then require a white
            # z-score for the same date
            merged = pair_chunk.merge(z_black, on='permno_black', how='inner')
            merged = merged.merge(z_white, on=['permno_white', 'date'], how='inner')

            if merged.empty:
                continue

            # Calculate z_diff and the signal masks
            merged['z_diff'] = merged['z_black'] - merged['z_white']
            mask_short = (merged['z_diff'] >= zscore_threshold).values
            mask_long = (merged['z_diff'] <= -zscore_threshold).values
            mask_any = mask_short | mask_long

            if not mask_any.any():
                continue

            # Filter valid signals only; the long side takes precedence when
            # both masks hold (only possible for a zero threshold)
            chunk_signals = merged.loc[mask_any].copy()
            chunk_signals['signal'] = np.where(
                mask_long[mask_any], 'long_black_short_white', 'short_black_long_white'
            ).astype(object)

            # Add method info for later reference
            chunk_signals['zscore_method'] = self.zscore_method
            chunk_signals['horizon'] = horizon
            chunk_signals['lookback'] = self.lookback_period
            results.append(chunk_signals)

        if not results:
            return pd.DataFrame()

        return pd.concat(results)

    def generate_signals(self, date):
        """Get signals for a specific date.

        Returns:
            A list of dicts with the keys ``date``, ``permno_black``,
            ``permno_white``, ``signal``, ``z_diff``, ``zscore_method``,
            ``horizon`` and ``lookback``; empty when nothing was precomputed
            or no signal exists for ``date``.
        """
        if self.precomputed_signals is None or self.precomputed_signals.empty:
            return []

        signals_today = self.precomputed_signals[self.precomputed_signals['date'] == date]

        if signals_today.empty:
            return []

        wanted = ['date', 'permno_black', 'permno_white', 'signal', 'z_diff',
                  'zscore_method', 'horizon', 'lookback']
        return signals_today[wanted].to_dict('records')

In [None]:
import numpy as np
import pandas as pd
from scipy import sparse
#from .trade import Trade

class PortfolioManager:
    """
    Manage capital, open pair trades and the realised equity curve of a backtest.

    Price, volume, volatility, Fed Funds and market-return lookups are built
    once from ``df_main``. Each trading day accrues financing, processes exits,
    opens new trades from the supplied signals and appends the day's realised
    PnL to the equity curve.
    """

    def __init__(self, df_main, initial_capital, max_holding_days=5):
        """Set up capital tracking and build the lookup tables.

        Args:
            df_main: DataFrame with the columns ``date``, ``permno``,
                ``adj_prc``, ``adv20``, ``garch_vol``, ``fed_funds_rate``
                and ``vwretd``.
            initial_capital: Starting capital.
            max_holding_days: Holding period after which a trade is closed.
        """
        self.df_main = df_main
        self.initial_capital = initial_capital
        self.available_capital = initial_capital
        self.max_holding_days = max_holding_days
        self.active_trades = []
        self.trade_history = []
        self.daily_pnl = {}
        self.equity_curve = {pd.Timestamp.min: initial_capital}  # Initialize with starting capital
        # Rejection counters of the most recent _process_entries call
        self.last_entry_rejections = {}

        # Create lookups for efficient access
        self._create_lookups()

    @staticmethod
    def _is_positive_number(value):
        """Return True when ``value`` is a finite number strictly greater than zero."""
        if value is None:
            return False
        try:
            return bool(np.isfinite(value) and value > 0)
        except TypeError:
            return False

    def _latest_equity(self):
        """Return the most recently recorded equity value.

        Dicts preserve insertion order, so the last value is the previous
        equity. Using the maximum instead would hide every loss.
        """
        return next(reversed(self.equity_curve.values()))

    def _create_lookups(self):
        """Create efficient lookups for prices and volumes"""
        # (date, permno) keyed lookups straight from df_main
        lookup_data = self.df_main.set_index(['date', 'permno'])
        self.price_lookup = lookup_data['adj_prc'].to_dict()
        self.vol_lookup = lookup_data['adv20'].to_dict()
        self.volatility_lookup = lookup_data['garch_vol'].to_dict()

        # Date keyed lookups, taken from the first row of each date
        date_indexed = self.df_main.drop_duplicates('date').set_index('date')
        self.ffr_lookup = date_indexed['fed_funds_rate'].to_dict()
        self.market_return_lookup = date_indexed['vwretd'].to_dict()

    def _calculate_max_shares(self, permno, current_date, price, allocated_money=None):
        """Calculate maximum number of shares based on liquidity and capital.

        Args:
            permno: Stock identifier.
            current_date: Date used for the ADV20 lookup.
            price: Current price of the stock.
            allocated_money: Capital earmarked for this leg; defaults to 10%
                of the available capital.

        Returns:
            The smaller of the capital-based share count and 10% of ADV20. It
            is at least 1 when a share is affordable and 0 otherwise. When
            ADV20 is missing or invalid, the capital-based count (at least 1)
            is returned.
        """
        # Default allocated money if not provided
        if allocated_money is None:
            allocated_money = self.available_capital / 10  # Default to 10% of available capital

        # Calculate shares based on allocated money
        capital_shares = 0
        if self._is_positive_number(price) and self._is_positive_number(allocated_money):
            capital_shares = int(allocated_money / price)

        # Missing, NaN, infinite, non-numeric or non-positive ADV20: fall back
        # to capital-based shares
        adv20 = self.vol_lookup.get((current_date, permno), 0)
        if not self._is_positive_number(adv20):
            return max(1, capital_shares)

        # Limit to 10% of average volume
        liquidity_shares = int(adv20 * 0.1)
        max_shares = min(capital_shares, liquidity_shares)

        # Ensure at least 1 share if we have capital
        return max(1, max_shares) if capital_shares > 0 else 0

    def reset_capital(self, amount):
        """Reset available capital (called at the start of each quarter)"""
        self.available_capital = amount

    def process_trading_day(self, current_date, signals, current_data):
        """Process a single trading day.

        Order of operations: accrue financing and mark open trades to market,
        process exits, release capital of closed trades, open new trades, then
        record the day's realised PnL and equity.

        Args:
            current_date: Date being processed.
            signals: List of signal dicts as produced by the signal generator.
            current_data: DataFrame slice used to look up current z-scores.

        Returns:
            List of ``to_dict()`` records for the trades closed on this day.
        """
        trade_updates = []

        # First update financing costs for all active trades
        fed_funds_rate = self.ffr_lookup.get(current_date, 0.02)  # Default to 2% if missing
        if fed_funds_rate is None or pd.isna(fed_funds_rate):
            # A NaN rate would turn every financing cost and PnL into NaN
            fed_funds_rate = 0.02
        for trade in self.active_trades:
            trade.update_daily_financing(current_date, fed_funds_rate)

            # Update market value for active trades (for internal tracking only)
            price_black = self.price_lookup.get((current_date, trade.permno_black))
            price_white = self.price_lookup.get((current_date, trade.permno_white))

            if self._is_positive_number(price_black) and self._is_positive_number(price_white):
                trade.update_market_value(current_date, price_black, price_white)

        # Then check for exits (z-score reversal or max holding period)
        closed_trades = self._process_exits(current_date, current_data)

        # Update available capital from closed trades
        for trade in closed_trades:
            # Return the invested capital plus profit/loss
            self.available_capital += (trade.investment_black + trade.investment_white + trade.net_pnl)
            # Add to trade history (only for closed trades)
            self.trade_history.append(trade)
            # Add to trade updates (for logging) - only adding CLOSED trades
            trade_updates.append(trade.to_dict())

        # Calculate daily PnL from closed trades only
        day_pnl = sum(trade.net_pnl for trade in closed_trades)

        # Then process new entries if we have signals and available capital
        self._process_entries(current_date, signals, current_data)

        # Update equity curve from the previous equity, not the historical peak
        self.equity_curve[current_date] = self._latest_equity() + day_pnl

        # Save daily PnL
        self.daily_pnl[current_date] = day_pnl

        # Return only the updates for CLOSED trades
        return trade_updates

    def _check_exit_conditions(self, trade, current_date, current_z_diff):
        """Check if a trade should be exited based on the specified conditions.

        Returns:
            ``(True, reason)`` when the trade should be closed, otherwise
            ``(False, None)``. A NaN ``current_z_diff`` never triggers the
            mean-reversion exit because NaN comparisons are False.
        """
        # Condition 1: Z-score mean reversion toward zero
        if trade.side == 'short_black_long_white' and current_z_diff <= 0:
            return True, 'mean_reversion'
        elif trade.side == 'long_black_short_white' and current_z_diff >= 0:
            return True, 'mean_reversion'

        # Condition 2: Max holding period reached
        if trade.days_held >= self.max_holding_days:
            return True, 'max_holding'

        return False, None

    @staticmethod
    def _build_z_lookup(current_data, z_col):
        """Map permno to its z-score in ``current_data`` (first row per permno)."""
        if z_col not in current_data.columns or 'permno' not in current_data.columns:
            return {}
        first_rows = current_data.drop_duplicates('permno')
        return dict(zip(first_rows['permno'].values, first_rows[z_col].values))

    def _process_exits(self, current_date, current_data):
        """Check active trades for exit conditions.

        A trade is closed only when an exit condition holds and both exit
        prices are valid. When a z-score is unavailable the z-difference is
        NaN, so only the max-holding rule can close the trade.

        Returns:
            List of trades closed on ``current_date``; ``self.active_trades``
            is replaced by the trades that remain open.
        """
        closed_trades = []
        remaining_trades = []
        # One permno -> z-score map per z-score column, built once per day
        z_lookups = {}

        for trade in self.active_trades:
            z_col = f"z_{trade.zscore_method}_{trade.horizon}d_lb{trade.lookback}"
            if z_col not in z_lookups:
                z_lookups[z_col] = self._build_z_lookup(current_data, z_col)
            z_by_permno = z_lookups[z_col]

            # Calculate current z-diff (NaN when either z-score is missing)
            z_black = z_by_permno.get(trade.permno_black, np.nan)
            z_white = z_by_permno.get(trade.permno_white, np.nan)
            current_z_diff = z_black - z_white

            # Check exit conditions
            should_exit, exit_reason = self._check_exit_conditions(trade, current_date, current_z_diff)
            if not should_exit:
                remaining_trades.append(trade)
                continue

            # Get exit prices
            exit_price_black = self.price_lookup.get((current_date, trade.permno_black))
            exit_price_white = self.price_lookup.get((current_date, trade.permno_white))

            if not (self._is_positive_number(exit_price_black)
                    and self._is_positive_number(exit_price_white)):
                # Can't exit if prices are missing or invalid, keep the trade
                remaining_trades.append(trade)
                continue

            # Close the trade
            trade.close_trade(current_date, exit_price_black, exit_price_white,
                              exit_reason, current_z_diff)
            closed_trades.append(trade)

        # Update active trades list
        self.active_trades = remaining_trades
        return closed_trades

    def _process_entries(self, current_date, signals, current_data):
        """Process new trade entries with liquidity constraints.

        Available capital is split across the signalled pairs in proportion to
        the inverse of their average GARCH volatility, halved between the two
        legs, and capped by the liquidity limit. Signals with missing or
        invalid volatility or price are skipped. The rejection counts are
        stored in ``self.last_entry_rejections``.

        Returns:
            List of the trades opened on ``current_date``.
        """
        rejections = {
            'missing_volatility': 0,
            'missing_price': 0,
            'zero_shares': 0,
            'insufficient_capital': 0,
        }
        self.last_entry_rejections = rejections

        if not signals or self.available_capital <= 0:
            return []

        # Calculate volatility for each pair in signals for position sizing
        pairs_volatility = {}
        total_inv_vol = 0

        for sig in signals:
            permno_black = sig['permno_black']
            permno_white = sig['permno_white']

            # Get GARCH volatilities from lookup table
            vol_black = self.volatility_lookup.get((current_date, permno_black))
            vol_white = self.volatility_lookup.get((current_date, permno_white))

            # NaN must be rejected here: a single NaN would poison the total
            # and with it every allocation
            if not (self._is_positive_number(vol_black) and self._is_positive_number(vol_white)):
                rejections['missing_volatility'] += 1
                continue

            # Use combined volatility for the pair
            pair_vol = (vol_black + vol_white) / 2
            pairs_volatility[(permno_black, permno_white)] = pair_vol
            total_inv_vol += 1 / pair_vol

        if total_inv_vol == 0:
            return []

        # Allocate capital by inverse volatility
        capital_allocations = {
            pair_key: ((1 / vol) / total_inv_vol) * self.available_capital
            for pair_key, vol in pairs_volatility.items()
        }

        # Execute trades
        executed_trades = []
        capital_used = 0

        for sig in signals:
            permno_black = sig['permno_black']
            permno_white = sig['permno_white']
            pair_key = (permno_black, permno_white)

            if pair_key not in capital_allocations:
                continue

            # Get stock prices
            px_b = self.price_lookup.get((current_date, permno_black))
            px_w = self.price_lookup.get((current_date, permno_white))

            if not (self._is_positive_number(px_b) and self._is_positive_number(px_w)):
                rejections['missing_price'] += 1
                continue

            # Split the pair's capital equally between black and white stocks
            inv_b = inv_w = capital_allocations[pair_key] / 2

            # Calculate maximum shares based on liquidity constraint (10% of ADV20)
            max_shares_b = self._calculate_max_shares(permno_black, current_date, px_b, inv_b)
            max_shares_w = self._calculate_max_shares(permno_white, current_date, px_w, inv_w)

            # Calculate shares based on capital allocation
            capital_shares_b = int(inv_b / px_b)
            capital_shares_w = int(inv_w / px_w)

            # Apply liquidity constraint
            sh_b = min(capital_shares_b, max_shares_b) if max_shares_b > 0 else capital_shares_b
            sh_w = min(capital_shares_w, max_shares_w) if max_shares_w > 0 else capital_shares_w

            # Skip if not enough shares can be purchased
            if sh_b == 0 or sh_w == 0:
                rejections['zero_shares'] += 1
                continue

            # Recalculate actual investment based on constrained shares
            inv_b = sh_b * px_b
            inv_w = sh_w * px_w

            # Calculate transaction costs
            entry_tc = 0.01 * (sh_b + sh_w)

            # Check if we have enough capital
            total_cost = inv_b + inv_w + entry_tc
            if total_cost > (self.available_capital - capital_used):
                rejections['insufficient_capital'] += 1
                continue

            # Record the capital used
            capital_used += total_cost

            # Create new trade
            new_trade = Trade(
                entry_date=current_date,
                permno_black=permno_black,
                permno_white=permno_white,
                side=sig['signal'],
                z_diff_entry=sig['z_diff'],
                investment_black=inv_b,
                investment_white=inv_w,
                shares_black=sh_b,
                shares_white=sh_w,
                entry_price_black=px_b,
                entry_price_white=px_w,
                entry_transaction_cost=entry_tc,
                zscore_method=sig.get('zscore_method', 'ou'),
                horizon=sig.get('horizon', 5),
                lookback=sig.get('lookback', 20)
            )

            # Add to active trades list
            self.active_trades.append(new_trade)
            executed_trades.append(new_trade)

        # Update available capital
        self.available_capital -= capital_used

        return executed_trades

    def _last_valid_price_before(self, permno, cutoff_date):
        """Return the latest valid price of ``permno`` dated before ``cutoff_date``, or None."""
        best_date = None
        best_price = None
        for (date, candidate), price in self.price_lookup.items():
            if candidate != permno or not date < cutoff_date:
                continue
            if not self._is_positive_number(price):
                continue
            if best_date is None or date > best_date:
                best_date, best_price = date, price
        return best_price

    def mark_to_market_open_positions(self, final_date):
        """Close all open positions at the end of the backtest period using latest prices.

        A leg without a valid price on ``final_date`` falls back to its most
        recent earlier valid price. Trades for which no price can be found
        stay in ``self.active_trades`` instead of being discarded.

        NOTE(review): unlike process_trading_day, the capital of the closed
        trades is not added back to ``available_capital`` here; confirm the
        caller resets capital afterwards.

        Returns:
            List of the trades closed by this call.
        """
        closed_trades = []

        # Skip if no active trades
        if not self.active_trades:
            return []

        still_open = []
        for trade in self.active_trades:
            # Get exit prices for the final date, falling back per leg
            exit_price_black = self.price_lookup.get((final_date, trade.permno_black))
            if not self._is_positive_number(exit_price_black):
                exit_price_black = self._last_valid_price_before(trade.permno_black, final_date)

            exit_price_white = self.price_lookup.get((final_date, trade.permno_white))
            if not self._is_positive_number(exit_price_white):
                exit_price_white = self._last_valid_price_before(trade.permno_white, final_date)

            # If still no prices, the trade cannot be closed
            if exit_price_black is None or exit_price_white is None:
                still_open.append(trade)
                continue

            # Close the trade with "end_of_period" as reason
            trade.close_trade(final_date, exit_price_black, exit_price_white,
                              'end_of_period', 0)  # Use 0 as z_diff_exit

            # Add to closed trades and trade history
            closed_trades.append(trade)
            self.trade_history.append(trade)

        # Only trades that could not be priced remain active
        self.active_trades = still_open

        # Update equity curve with the PnL from these trades
        if closed_trades:
            day_pnl = sum(trade.net_pnl for trade in closed_trades)
            if final_date not in self.equity_curve:
                self.equity_curve[final_date] = self._latest_equity() + day_pnl
            else:
                self.equity_curve[final_date] += day_pnl

        # Return the closed trades
        return closed_trades

    def get_trade_history(self):
        """Return the trade history for analysis"""
        return self.trade_history

In [None]:
import numpy as np
import pandas as pd

def calculate_trade_based_metrics(trade_df, market_returns, ffr_lookup=None, initial_capital=1_000_000_000):
    """
    Calculate performance metrics based solely on trade log data and return the
    per-exit-date returns for graphing in reports.

    The caller's ``trade_df`` is never modified; date columns are converted on
    an internal copy.

    Parameters:
    -----------
    trade_df : pandas DataFrame
        DataFrame containing trade information with columns
        entry_date, exit_date, net_pnl and (optionally) days_held.
    market_returns : dict or Series
        Market returns indexed by date (vwretd values). Dates missing from the
        mapping are treated as a 0 return.
    ffr_lookup : dict or None
        Federal Funds Rate lookup by date. If None, or if none of the trading
        days are present in it, 0.02 is used as the default annual rate.
    initial_capital : float
        Initial capital for calculating returns

    Returns:
    --------
    dict : Dictionary of performance metrics and daily returns data. The keys
        are identical on every return path (including 'avg_fed_funds_rate'),
        so callers can index the result without checking which path was taken.
    """
    # Annual rate assumed whenever no usable Fed Funds data is available
    default_annual_rate = 0.02

    def _summary_only(hit_rate=0, num_trades=0, avg_trade_pnl=0, avg_holding=0,
                      num_trading_days=0, avg_ffr=default_annual_rate):
        """Build the result used when return-based statistics cannot be computed."""
        return {
            'sharpe_ratio': 0,
            'sortino_ratio': 0,
            'alpha': 0,
            'beta': 0,
            'max_drawdown': 0,
            'hit_rate': hit_rate,
            'num_trades': num_trades,
            'avg_trade_pnl': avg_trade_pnl,
            'avg_holding_period': avg_holding,
            'num_trading_days': num_trading_days,
            'daily_returns': pd.DataFrame(columns=['date', 'return']),
            'avg_fed_funds_rate': avg_ffr
        }

    # Ensure we have trades to analyze
    if len(trade_df) == 0:
        return _summary_only()

    # Convert dates to datetime if they aren't already. ``assign`` returns a
    # new frame, so the DataFrame passed in by the caller is left untouched.
    for date_col in ('exit_date', 'entry_date'):
        if not pd.api.types.is_datetime64_dtype(trade_df[date_col]):
            trade_df = trade_df.assign(**{date_col: pd.to_datetime(trade_df[date_col])})

    # Sort trades by exit date
    trade_df = trade_df.sort_values('exit_date')

    # Calculate basic trade metrics
    num_trades = len(trade_df)
    hit_rate = (trade_df['net_pnl'] > 0).mean()
    avg_trade_pnl = trade_df['net_pnl'].mean()
    avg_holding = trade_df['days_held'].mean() if 'days_held' in trade_df.columns else 0

    # Get unique trading days (both entry and exit dates)
    trading_days = sorted(set(trade_df['entry_date'].tolist() +
                              trade_df['exit_date'].dropna().tolist()))
    num_trading_days = len(trading_days)

    # Average Fed Funds Rate over the trading days that have a quote. Falls
    # back to the default when no lookup is given or nothing matches, so
    # avg_ffr is always defined by the time the result is built.
    avg_ffr = default_annual_rate
    if ffr_lookup is not None:
        trading_day_rates = [ffr_lookup[date] for date in trading_days if date in ffr_lookup]
        if trading_day_rates:
            avg_ffr = sum(trading_day_rates) / len(trading_day_rates)

    # NOTE(review): the annual rate is spread over the number of observed
    # trading days (not a fixed 252/365 convention) - confirm this is intended.
    daily_rfr = avg_ffr / num_trading_days

    # Build the equity curve: one point per exit date, cumulating realised PnL
    equity_curve = {}
    current_equity = initial_capital
    for date, group in trade_df.groupby('exit_date'):
        current_equity += group['net_pnl'].sum()
        equity_curve[date] = current_equity

    # Convert to Series for easier manipulation
    equity_series = pd.Series(equity_curve).sort_index()

    # Handle case with insufficient data points
    if len(equity_series) <= 1:
        return _summary_only(hit_rate, num_trades, avg_trade_pnl, avg_holding,
                             num_trading_days, avg_ffr)

    # Returns between consecutive exit dates. The first point has no
    # predecessor in the curve, so its return is recorded as 0.
    daily_returns = equity_series.pct_change().fillna(0)

    # Create a DataFrame of dates and returns for graphing
    returns_df = pd.DataFrame({
        'date': daily_returns.index,
        'return': daily_returns.values
    })

    # Get corresponding market returns; missing dates count as 0 in both branches
    if isinstance(market_returns, dict):
        market_returns_series = pd.Series({date: market_returns.get(date, 0)
                                           for date in daily_returns.index})
    else:
        # Assume it's already a Series indexed by date
        market_returns_series = market_returns.reindex(daily_returns.index, fill_value=0)

    # Calculate metrics based on daily returns
    mean_daily_return = daily_returns.mean()
    std_daily_return = daily_returns.std()
    n_returns = len(daily_returns)

    # Calculate Sharpe ratio (scaled by the number of return observations)
    sharpe_ratio = 0
    if std_daily_return > 0:
        sharpe_ratio = mean_daily_return / std_daily_return * np.sqrt(n_returns)

    # Calculate Sortino ratio (downside deviation from negative returns only)
    downside_daily_returns = daily_returns[daily_returns < 0]
    downside_std_daily = downside_daily_returns.std() if len(downside_daily_returns) > 0 else 0
    sortino_ratio = 0
    if downside_std_daily > 0:
        sortino_ratio = mean_daily_return / downside_std_daily * np.sqrt(n_returns)

    # Calculate CAPM metrics (Beta, Alpha)
    beta = 0
    alpha = 0
    if n_returns > 1 and len(market_returns_series) > 1:
        portfolio_values = daily_returns.to_numpy(dtype=float)
        market_values = market_returns_series.to_numpy(dtype=float)

        # Use the same normalisation (ddof=1) for covariance and variance;
        # np.cov defaults to ddof=1 while np.var defaults to ddof=0.
        market_var = np.var(market_values, ddof=1)
        if market_var > 0:
            beta = np.cov(portfolio_values, market_values)[0, 1] / market_var

            # Calculate alpha (based on actual trading days)
            expected_return = daily_rfr + beta * (market_returns_series.mean() - daily_rfr)
            alpha = (mean_daily_return - expected_return) * n_returns

    # Calculate drawdowns relative to the running peak of the equity curve
    peak = equity_series.expanding().max()
    drawdowns = (equity_series - peak) / peak
    max_drawdown = abs(drawdowns.min()) if len(drawdowns) > 0 else 0

    # Return comprehensive metrics and daily returns data
    return {
        'sharpe_ratio': sharpe_ratio,
        'sortino_ratio': sortino_ratio,
        'alpha': alpha,
        'beta': beta,
        'max_drawdown': max_drawdown,
        'hit_rate': hit_rate,
        'num_trades': num_trades,
        'avg_trade_pnl': avg_trade_pnl,
        'avg_holding_period': avg_holding,
        'num_trading_days': num_trading_days,
        'daily_returns': returns_df,
        'avg_fed_funds_rate': avg_ffr  # Include average FFR in the results
    }

In [None]:
import gc
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm

#from .signal_generator import SignalGenerator
#from .portfolio_manager import PortfolioManager
#from .performance import calculate_trade_based_metrics

def _process_quarter_parallel(quarter, df_main, filtered_pairs, signal_generator, initial_capital, max_holding_days):
    """Process a single quarter in parallel"""
    # Filter data for this quarter
    quarter_data = df_main[df_main['quarter'] == quarter]
    if quarter_data.empty:
        print(f"Warning: No data found for quarter {quarter}")
        return {'quarter': quarter, 'trade_log': [], 'performance': {}}
    
    print(f"Processing calendar quarter {quarter}")
    
    # Extract quarter start and end dates for better reporting
    quarter_dates = sorted(quarter_data['date'].unique())
    quarter_start = quarter_dates[0]
    quarter_end = quarter_dates[-1]
    print(f"Quarter period: {quarter_start.strftime('%Y-%m-%d')} to {quarter_end.strftime('%Y-%m-%d')}")
    
    # Create portfolio manager for this quarter
    portfolio_manager = PortfolioManager(
        quarter_data, 
        initial_capital,
        max_holding_days=max_holding_days
    )
    
    # Reset capital
    portfolio_manager.reset_capital(initial_capital)
    
    # Group data by date for faster access
    date_grouped_data = {date: group for date, group in quarter_data.groupby('date')}
    
    # Process each trading day
    trade_log = []
    signals_count = 0
    trades_count = 0
    
    # Process each trading day in the quarter
    for current_date in quarter_dates:
        # Get signals for the current date
        try:
            signals = signal_generator.generate_signals(current_date)
            signals_count += len(signals)
        except Exception as e:
            print(f"Error generating signals for date {current_date}: {str(e)}")
            signals = []
        
        # Process the trading day
        try:
            day_results = portfolio_manager.process_trading_day(
                current_date, 
                signals,
                date_grouped_data[current_date]
            )
            
            if day_results:
                trade_log.extend(day_results)
                trades_count += len(day_results)
                
        except Exception as e:
            print(f"Error processing trading day {current_date}: {str(e)}")
    
    print(f"Quarter {quarter} summary: {signals_count} signals generated, {trades_count} trades executed")
    
    # Return the trade log for this quarter
    return {
        'quarter': quarter,
        'trade_log': trade_log,
        'performance': {}  # We'll calculate this later from the trade log
    }
    
class BacktestEngine:
    """Drive a quarter-by-quarter pairs-trading backtest for one hyperparameter set.

    Construction cleans the market data, derives the list of calendar quarters
    and filters the candidate pairs. ``run_backtest`` then generates signals,
    processes the quarters in parallel batches and computes performance
    metrics from the collected trade log.
    """

    # Column names accepted for the correlation / cointegration filters,
    # checked in this order; the first one present in the pairs data is used.
    _CORR_COLUMNS = ('corr', 'correlation')
    _COINT_COLUMNS = ('coint_pval', 'pval', 'p_value')

    def __init__(self, df_main, df_pairs, hyperparams):
        """Store the inputs, clean the market data and pre-process it.

        Parameters
        ----------
        df_main : pandas DataFrame
            Market data; only the columns needed for the chosen z-score
            configuration are kept (missing ones are silently skipped).
        df_pairs : pandas DataFrame
            Candidate pairs, optionally with correlation / cointegration columns.
        hyperparams : dict
            Must contain INITIAL_CAPITAL, ZSCORE_METHOD, LOOKBACK_PERIOD,
            HORIZON, CORRELATION_THRESHOLD and COINTEGRATION_THRESHOLD
            (other keys are read later by ``run_backtest``).
        """
        self.hyperparams = hyperparams
        self.initial_capital = hyperparams['INITIAL_CAPITAL']

        # Select specific columns directly instead of filtering
        zscore_method = hyperparams['ZSCORE_METHOD']
        lookback_period = hyperparams['LOOKBACK_PERIOD']
        horizon = hyperparams['HORIZON']

        z_col = f'z_{zscore_method}_{horizon}d_lb{lookback_period}'
        base_cols = ['date', 'permno', 'trading_start', 'group_id', 'adj_prc', 'fed_funds_rate', 'adv20', 'vwretd', 'garch_vol']
        future_return_col = f'future_cumret_{horizon}d'

        # Select only required columns
        needed_cols = base_cols + [z_col]
        if future_return_col in df_main.columns:
            needed_cols.append(future_return_col)

        needed_cols = [col for col in needed_cols if col in df_main.columns]

        # Clean data: rows with infinities or NaN in any kept column are dropped
        self.df_main = df_main[needed_cols].copy()
        self.df_main = self.df_main.replace([np.inf, -np.inf], np.nan)
        self.df_main = self.df_main.dropna()

        # Keep a reference to the pairs data
        self.df_pairs = df_pairs

        # Pre-process data for faster lookups
        self._preprocess_data()

    @staticmethod
    def _pct(part, whole):
        """Return part/whole as a percentage, or 0.0 when ``whole`` is zero."""
        return part / whole * 100 if whole else 0.0

    def _first_existing_column(self, candidates):
        """Return the first name in ``candidates`` that is a pairs column, else None."""
        for name in candidates:
            if name in self.df_pairs.columns:
                return name
        return None

    def _preprocess_data(self):
        """Preprocess data for efficient backtest execution"""
        # Clean data by replacing infinites with NaN and dropping NaN values
        self.df_main = self.df_main.replace([np.inf, -np.inf], np.nan)
        self.df_main = self.df_main.dropna()

        # Make sure date is datetime
        if not pd.api.types.is_datetime64_dtype(self.df_main['date']):
            self.df_main['date'] = pd.to_datetime(self.df_main['date'])

        # Create a quarter column based on date
        self.df_main['quarter'] = self.df_main['date'].dt.to_period('Q').astype(str)

        # Get unique quarters for processing
        self.quarters = sorted(self.df_main['quarter'].unique())

        print(f"Identified {len(self.quarters)} calendar quarters for processing")

        # Filter pairs based on correlation and cointegration thresholds
        corr_threshold = self.hyperparams['CORRELATION_THRESHOLD']
        coint_threshold = self.hyperparams['COINTEGRATION_THRESHOLD']

        # Start from an all-True mask aligned with the pairs index. A scalar
        # True cannot be used here: if no filter ends up being applied,
        # df_pairs[True] would be treated as a column lookup and fail.
        keep = pd.Series(True, index=self.df_pairs.index)

        # Correlation filter (only if a threshold is given AND a column exists)
        if corr_threshold is not None:
            corr_col = self._first_existing_column(self._CORR_COLUMNS)
            if corr_col is not None:
                keep = keep & (self.df_pairs[corr_col] >= corr_threshold)
            else:
                print("Warning: No correlation column found in pairs data, skipping correlation filter")

        # Cointegration filter (only if a threshold is given AND a column exists)
        if coint_threshold is not None:
            coint_col = self._first_existing_column(self._COINT_COLUMNS)
            if coint_col is not None:
                keep = keep & (self.df_pairs[coint_col] <= coint_threshold)
            else:
                print("Warning: No cointegration p-value column found in pairs data, skipping cointegration filter")

        # Apply the filter
        self.filtered_pairs = self.df_pairs[keep].copy()

        # Log preprocessing results
        print(f"Preprocessing complete: {len(self.filtered_pairs)} pairs after filtering")

    def run_backtest(self):
        """Run the full backtest using the specified hyperparameters.

        Returns a dict with keys 'trade_log' (DataFrame of all collected
        trades), 'performance' (metrics dict, empty when there are no trades),
        'hyperparams' and 'quarterly_results'.
        """
        print("Running pre-backtest diagnostics...")
        self.run_diagnostics()

        zscore_method = self.hyperparams['ZSCORE_METHOD']
        zscore_threshold = self.hyperparams['ZSCORE_THRESHOLD']
        lookback_period = self.hyperparams['LOOKBACK_PERIOD']
        horizon = self.hyperparams['HORIZON']
        max_holding_days = self.hyperparams['MAX_HOLDING_DAYS']

        # Create optimized dataset with only needed columns
        z_col = f'z_{zscore_method}_{horizon}d_lb{lookback_period}'
        needed_cols = ['date', 'permno', 'quarter', 'group_id', 'adj_prc', 'fed_funds_rate',
                       'adv20', 'vwretd', 'garch_vol', z_col]

        # Check if all needed columns exist
        needed_cols = [col for col in needed_cols if col in self.df_main.columns]

        # Only keep needed columns in memory
        self.optimized_df = self.df_main[needed_cols].copy()

        # Pre-sort data for faster operations
        self.optimized_df.sort_values(['date', 'permno'], inplace=True)

        # Nothing to process: bail out before building/precomputing signals
        if len(self.quarters) == 0:
            print("No quarters found to process! Check data filtering.")
            return {
                'trade_log': pd.DataFrame(),
                'performance': {},
                'hyperparams': self.hyperparams,
                'quarterly_results': {}
            }

        # Create signal generator with optimized dataset
        signal_generator = SignalGenerator(
            self.optimized_df,
            self.filtered_pairs,
            zscore_method=zscore_method,
            zscore_threshold=zscore_threshold,
            horizon=horizon,
            lookback_period=lookback_period
        )

        # Fixed worker count, shared by signal precomputation and quarter processing
        n_jobs = 4

        # Precompute signals with progress bar
        signal_generator.precompute_signals_parallel(horizon=horizon, n_jobs=n_jobs)

        # Process quarters in batches to reduce memory pressure
        batch_size = 4  # Adjust based on your system's memory
        all_closed_trades = []
        quarterly_results = {}

        for i in range(0, len(self.quarters), batch_size):
            batch_quarters = self.quarters[i:i+batch_size]

            # Process quarters in parallel
            batch_results = Parallel(n_jobs=n_jobs, prefer="threads")(
                delayed(_process_quarter_parallel)(
                    quarter,
                    self.optimized_df,
                    self.filtered_pairs,
                    signal_generator,
                    self.initial_capital,
                    max_holding_days
                )
                for quarter in tqdm(batch_quarters, desc=f"Processing Quarters Batch {i//batch_size+1}")
            )

            # Collect results
            for result in batch_results:
                all_closed_trades.extend(result['trade_log'])
                quarterly_results[result['quarter']] = result['performance']

            # Force garbage collection after each batch
            gc.collect()

        print(f"\nCollected {len(all_closed_trades)} closed trades across all quarters")

        # Build the trade log once; it is both analysed and returned
        trade_df = pd.DataFrame(all_closed_trades) if all_closed_trades else pd.DataFrame()
        performance_metrics = {}

        if len(trade_df) > 0:
            # Create lookups for market returns and Fed Funds Rate from the
            # original data; tolerate the columns being absent, as __init__ does.
            date_indexed = self.df_main.drop_duplicates('date').set_index('date')
            market_return_lookup = (date_indexed['vwretd'].to_dict()
                                    if 'vwretd' in date_indexed.columns else {})
            ffr_lookup = (date_indexed['fed_funds_rate'].to_dict()
                          if 'fed_funds_rate' in date_indexed.columns else None)

            # Pass a copy so the returned trade log keeps its raw column types
            # even if the metrics function converts date columns in place.
            performance_metrics = calculate_trade_based_metrics(
                trade_df=trade_df.copy(),
                market_returns=market_return_lookup,
                ffr_lookup=ffr_lookup,
                initial_capital=self.initial_capital
            )

            # Save daily returns data to file for graphing
            if 'daily_returns' in performance_metrics:
                daily_returns_df = performance_metrics['daily_returns']
                timestamp = time.strftime("%Y%m%d_%H%M%S")
                daily_returns_file = f'daily_returns_{timestamp}.csv'
                daily_returns_df.to_csv(daily_returns_file, index=False)
                print(f"Daily returns data saved to {daily_returns_file}")

            # Print metrics
            print(f"\nCalculated metrics from trade data:")
            print(f"Number of trades: {performance_metrics['num_trades']}")
            print(f"Number of unique trading days: {performance_metrics['num_trading_days']}")
            # The average rate may be absent when too few data points exist
            avg_ffr = performance_metrics.get('avg_fed_funds_rate')
            if avg_ffr is not None:
                print(f"Average Fed Funds Rate: {avg_ffr*100:.2f}%")
            print(f"Hit rate: {performance_metrics['hit_rate']*100:.2f}%")
            print(f"Average trade PnL: ${performance_metrics['avg_trade_pnl']:,.2f}")
            print(f"Average holding period: {performance_metrics['avg_holding_period']:.2f} days")
            print(f"Sharpe ratio (based on trading days): {performance_metrics['sharpe_ratio']:.4f}")
            print(f"Sortino ratio (based on trading days): {performance_metrics['sortino_ratio']:.4f}")
            print(f"CAPM Alpha: {performance_metrics['alpha']:.6f}")
            print(f"CAPM Beta: {performance_metrics['beta']:.4f}")
            print(f"Max drawdown: {performance_metrics['max_drawdown']*100:.2f}%")
        else:
            print("No closed trades found, using empty metrics")

        # Return combined results
        return {
            'trade_log': trade_df,
            'performance': performance_metrics,
            'hyperparams': self.hyperparams,
            'quarterly_results': quarterly_results
        }

    def run_diagnostics(self):
        """Run diagnostic checks to identify potential issues.

        Only prints a report; it never raises on empty data (percentages are
        reported as 0.00% when the denominator is zero).
        """
        print("\n=== DIAGNOSTIC REPORT ===\n")

        # 1. Check for pairs after filtering
        if hasattr(self, "filtered_pairs"):
            print(f"Filtered pairs: {len(self.filtered_pairs)} of {len(self.df_pairs)} original pairs")
            if len(self.filtered_pairs) == 0:
                print("CRITICAL ERROR: No pairs remain after correlation/cointegration filtering!")

                # Check correlation threshold
                corr_threshold = self.hyperparams.get('CORRELATION_THRESHOLD')
                corr_col = self._first_existing_column(self._CORR_COLUMNS)
                if corr_threshold is not None and corr_col is not None:
                    corr_values = self.df_pairs[corr_col].dropna()
                    print(f"{corr_col} stats: min={corr_values.min():.4f}, max={corr_values.max():.4f}, mean={corr_values.mean():.4f}")
                    above_threshold = (corr_values >= corr_threshold).sum()
                    print(f"Values >= {corr_threshold}: {above_threshold} ({self._pct(above_threshold, len(corr_values)):.2f}%)")

                # Check cointegration threshold
                coint_threshold = self.hyperparams.get('COINTEGRATION_THRESHOLD')
                coint_col = self._first_existing_column(self._COINT_COLUMNS)
                if coint_threshold is not None and coint_col is not None:
                    coint_values = self.df_pairs[coint_col].dropna()
                    print(f"{coint_col} stats: min={coint_values.min():.4f}, max={coint_values.max():.4f}, mean={coint_values.mean():.4f}")
                    below_threshold = (coint_values <= coint_threshold).sum()
                    print(f"Values <= {coint_threshold}: {below_threshold} ({self._pct(below_threshold, len(coint_values)):.2f}%)")

        # 2. Check for z-score columns
        zscore_method = self.hyperparams['ZSCORE_METHOD']
        lookback_period = self.hyperparams['LOOKBACK_PERIOD']
        horizon = self.hyperparams['HORIZON']

        z_col = f'z_{zscore_method}_{horizon}d_lb{lookback_period}'
        print(f"\nChecking for z-score column: {z_col}")

        total_rows = len(self.df_main)

        if z_col in self.df_main.columns:
            z_values = self.df_main[z_col].dropna()
            z_threshold = self.hyperparams['ZSCORE_THRESHOLD']
            exceeding = (abs(z_values) >= z_threshold).sum()

            print(f"Z-score column stats:")
            print(f"  - Non-null values: {len(z_values)} out of {total_rows} ({self._pct(len(z_values), total_rows):.2f}%)")
            print(f"  - Range: {z_values.min():.4f} to {z_values.max():.4f}")
            print(f"  - Values exceeding threshold {z_threshold}: {exceeding} ({self._pct(exceeding, len(z_values)):.2f}%)")
        else:
            print(f"CRITICAL ERROR: Z-score column '{z_col}' not found in data!")
            z_cols = [col for col in self.df_main.columns if col.startswith('z_')]
            if z_cols:
                print(f"Available z-score columns: {z_cols}")
            else:
                print("No z-score columns found in data!")

        # 3. Check for essential columns
        required_cols = ['date', 'permno', 'group_id', 'adj_prc', 'fed_funds_rate', 'adv20', 'vwretd', 'garch_vol']
        missing_cols = [col for col in required_cols if col not in self.df_main.columns]

        if missing_cols:
            print(f"\nMISSING REQUIRED COLUMNS: {missing_cols}")
        else:
            print("\nAll required base columns are present")

        # 4. Check for NaN values in essential columns
        print("\nNaN check for essential columns:")
        for col in required_cols:
            if col in self.df_main.columns:
                null_count = self.df_main[col].isna().sum()
                null_pct = self._pct(null_count, total_rows)
                print(f"  - {col}: {null_count} NaN values ({null_pct:.2f}%)")

        print("\n=== END OF DIAGNOSTIC REPORT ===\n")

In [None]:
import os
import pickle
import time
import traceback
import numpy as np
import pandas as pd
from itertools import product

#from .backtest_engine import BacktestEngine

def run_hyperparameter_grid_search(df_main, df_pairs, param_grid, output_file='backtest_results.csv'):
    """Run the backtest once for every hyperparameter combination in ``param_grid``.

    Parameters
    ----------
    df_main, df_pairs : pandas DataFrame
        Passed unchanged to ``BacktestEngine`` for every combination.
    param_grid : dict
        Maps each hyperparameter name to an iterable of candidate values.
        'INITIAL_CAPITAL' is the exception: its value is copied as-is into
        every combination rather than being expanded.
    output_file : str
        CSV file that receives the accumulated result rows. A pickle
        checkpoint named ``checkpoint_<basename of output_file>.pkl`` is
        written to the current working directory.

    Returns
    -------
    pandas DataFrame with one row per executed combination (parameters plus
    metrics; failed runs carry an 'error' column and zeroed metrics).

    Notes
    -----
    Combinations recorded as completed in the checkpoint are skipped. Failed
    combinations are not marked completed, so they are retried when the search
    is resumed (their earlier error rows stay in the results).
    """
    # Imported locally so the function does not rely on a module-level import.
    import zlib

    param_columns = ('CORRELATION_THRESHOLD', 'COINTEGRATION_THRESHOLD', 'ZSCORE_METHOD',
                     'ZSCORE_THRESHOLD', 'LOOKBACK_PERIOD', 'HORIZON', 'MAX_HOLDING_DAYS')
    metric_columns = ('sharpe_ratio', 'sortino_ratio', 'alpha', 'beta', 'max_drawdown',
                      'hit_rate', 'num_trades', 'avg_trade_pnl', 'avg_holding_period',
                      'num_trading_days')

    results = []

    # Expand every parameter except INITIAL_CAPITAL into the full grid
    non_capital_keys = [k for k in param_grid if k != 'INITIAL_CAPITAL']
    non_capital_values = [param_grid[k] for k in non_capital_keys]

    param_combinations = []
    for combination in product(*non_capital_values):
        params = dict(zip(non_capital_keys, combination))
        params['INITIAL_CAPITAL'] = param_grid['INITIAL_CAPITAL']
        param_combinations.append(params)

    print(f"Running {len(param_combinations)} parameter combinations")

    # Use a checkpointing mechanism
    checkpoint_file = f"checkpoint_{os.path.basename(output_file)}.pkl"
    completed_runs = set()
    try:
        if os.path.exists(checkpoint_file):
            # NOTE: pickle must only be loaded from trusted files; this reads
            # the checkpoint this function wrote itself.
            with open(checkpoint_file, 'rb') as f:
                checkpoint_data = pickle.load(f)
                results = checkpoint_data.get('results', [])
                completed_runs = set(checkpoint_data.get('completed', []))
                print(f"Loaded {len(results)} previous results from checkpoint")
    except Exception as e:
        print(f"Error loading checkpoint: {str(e)}. Starting fresh.")
        results = []
        completed_runs = set()

    def _save_progress(failure_label):
        """Write checkpoint and CSV; a failure here is reported, not raised."""
        try:
            with open(checkpoint_file, 'wb') as f:
                pickle.dump({'results': results, 'completed': list(completed_runs)}, f)
            pd.DataFrame(results).to_csv(output_file, index=False)
        except Exception as save_err:
            print(f"{failure_label}: {str(save_err)}")

    # Run backtest for each combination
    for i, params in enumerate(param_combinations):
        # Skip already completed runs
        params_str = str(params)
        if params_str in completed_runs:
            print(f"Skipping combination {i+1}/{len(param_combinations)}: already completed")
            continue

        print(f"Running combination {i+1}/{len(param_combinations)}: {params}")

        try:
            # Derive the seed from a CRC of the parameter string. The built-in
            # hash() of a str is salted per interpreter process, so it would
            # give a different seed on every run and defeat reproducibility.
            seed = zlib.crc32(params_str.encode('utf-8')) % 10000
            np.random.seed(seed)

            backtest = BacktestEngine(df_main, df_pairs, params)
            result = backtest.run_backtest()

            # Extract performance metrics
            performance = result['performance']

            # Save trade log to file with timestamp
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            if not result['trade_log'].empty:
                trade_log_file = f"trade_log_{params['ZSCORE_METHOD']}_{params['ZSCORE_THRESHOLD']}_{params['LOOKBACK_PERIOD']}_{params['HORIZON']}_{params['MAX_HOLDING_DAYS']}_{timestamp}.csv"
                result['trade_log'].to_csv(trade_log_file, index=False)
                print(f"Saved {len(result['trade_log'])} trades to {trade_log_file}")
            else:
                print("No trades to save!")

            # Combine parameters and performance metrics for output
            result_row = {name: params[name] for name in param_columns}
            result_row.update({name: performance.get(name, 0) for name in metric_columns})

            results.append(result_row)
            completed_runs.add(params_str)

            # Save checkpoint (and CSV) after each successful run
            _save_progress("Error saving checkpoint")

        except Exception as e:
            print(f"Error running combination {i+1}: {params}")
            print(f"Error details: {str(e)}")
            traceback.print_exc()

            # Add a row with error information and zeroed metrics
            error_row = {name: params[name] for name in param_columns}
            error_row['error'] = str(e)
            error_row.update({name: 0 for name in metric_columns})
            results.append(error_row)

            # Save checkpoint and CSV after error
            _save_progress("Error saving checkpoint after error")

    # Final save and return
    try:
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_file, index=False)
        return results_df
    except Exception as final_err:
        print(f"Error saving final results: {str(final_err)}")
        return pd.DataFrame(results)

In [None]:
import pandas as pd
import time
import traceback

#from .grid_search import run_hyperparameter_grid_search  # Uncomment this import

def run_backtest(df_main_path='final_backtest_data.csv',
                 df_pairs_path='corr_coin.csv',
                 period='train'):
    """Load the backtest inputs, restrict them to one sample period and run the grid search.

    Args:
        df_main_path: CSV file with the main panel. The code reads its
            'date' and 'permno' columns.
        df_pairs_path: CSV file with the candidate pairs. Columns
            'permno_1'/'permno_2' are renamed to 'permno_black'/'permno_white'
            when both are present; an optional 'formation_date' column is
            used to filter pairs to the selected period.
        period: 'train' (2015-01-01 to 2021-12-31) or 'test'
            (2022-01-01 to 2024-12-31), case-insensitive.

    Returns:
        The DataFrame returned by run_hyperparameter_grid_search, or None if
        any step raised (the error and traceback are printed, not re-raised).

    Side effects:
        Passes a timestamped 'backtest_results_*.csv' path to the grid search
        as its output file and writes a 'best_params_*.txt' file in the
        current working directory when at least one combination succeeded.
    """
    # Calendar window and display label for each supported period.
    period_windows = {
        'train': ('2015-01-01', '2021-12-31', "in-sample"),
        'test': ('2022-01-01', '2024-12-31', "out-of-sample"),
    }

    try:
        # Validate the period up front so a typo is reported immediately
        # instead of after (or hidden behind) the CSV loads.
        period_key = period.lower()
        if period_key not in period_windows:
            raise ValueError(f"Invalid period: {period}. Use 'train' or 'test'.")
        start_date, end_date, period_name = period_windows[period_key]

        print("Loading data...")

        # Load the datasets with the correct filenames
        try:
            df_merged_filtered = pd.read_csv(df_main_path)
            print(f"Successfully loaded {df_main_path}")
        except Exception as e:
            print(f"Error loading {df_main_path}: {str(e)}")
            raise

        try:
            df_pairs = pd.read_csv(df_pairs_path)
            print(f"Successfully loaded {df_pairs_path}")
            print(f"Columns in df_pairs: {list(df_pairs.columns)}")
        except Exception as e:
            print(f"Error loading {df_pairs_path}: {str(e)}")
            raise

        # Rename column names if needed
        if 'permno_1' in df_pairs.columns and 'permno_2' in df_pairs.columns:
            df_pairs.rename(columns={'permno_1': 'permno_black', 'permno_2': 'permno_white'}, inplace=True)

        # Convert date columns to datetime
        df_merged_filtered['date'] = pd.to_datetime(df_merged_filtered['date'])

        # Filter main dataframe by date
        date_mask = (df_merged_filtered['date'] >= start_date) & (df_merged_filtered['date'] <= end_date)
        df_merged_filtered = df_merged_filtered[date_mask].copy()

        # Define quarters based on calendar date
        df_merged_filtered['quarter'] = df_merged_filtered['date'].dt.to_period('Q').astype(str)

        # Filter pairs by date range if formation_date exists
        if 'formation_date' in df_pairs.columns:
            df_pairs['formation_date'] = pd.to_datetime(df_pairs['formation_date'])
            date_mask = (df_pairs['formation_date'] >= start_date) & (df_pairs['formation_date'] <= end_date)
            df_pairs = df_pairs[date_mask].copy()
            print(f"Filtered pairs: {len(df_pairs)} within date range")

        # Print data overview
        quarters = df_merged_filtered['quarter'].unique()
        print(f"\n=== Data overview for {period_name} period ({start_date} to {end_date}) ===")
        print(f"Date range: {df_merged_filtered['date'].min()} to {df_merged_filtered['date'].max()}")
        print(f"Number of trading days: {df_merged_filtered['date'].nunique()}")
        print(f"Number of stocks: {df_merged_filtered['permno'].nunique()}")
        print(f"Number of calendar quarters: {len(quarters)}")
        print(f"Number of pairs: {len(df_pairs)}")

        # Define hyperparameter grid
        param_grid = {
            'COINTEGRATION_THRESHOLD': [0.05],
            'CORRELATION_THRESHOLD': [0.9], #0.5, 0.7, ],
            'ZSCORE_METHOD': ['classical'], #'ou'],
            'ZSCORE_THRESHOLD': [1],
            'LOOKBACK_PERIOD': [10],
            'HORIZON': [10],
            'MAX_HOLDING_DAYS': [10],
            'INITIAL_CAPITAL': 1_000_000_000
        }

        # Output file path
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        output_file = f'backtest_results_{period}_{timestamp}.csv'

        # Calculate number of combinations; scalar entries such as
        # INITIAL_CAPITAL are not swept and so do not multiply the count.
        num_combinations = 1
        for values in param_grid.values():
            if isinstance(values, list):
                num_combinations *= len(values)

        print(f"Starting grid search with {num_combinations} combinations...")

        # Run grid search
        results = run_hyperparameter_grid_search(df_merged_filtered, df_pairs, param_grid, output_file)

        # Failed combinations are recorded with an 'error' message and
        # placeholder metrics of 0. Rank only the rows without an error so a
        # failed run cannot be reported as "best" when every successful run
        # has a negative Sharpe ratio.
        if 'error' in results.columns:
            ranked = results[results['error'].isna()]
        else:
            ranked = results

        # Print summary of best results
        if not ranked.empty:
            print("\nTop 5 parameter combinations by Sharpe ratio:")
            top_sharpe = ranked.sort_values('sharpe_ratio', ascending=False).head(5)
            print(top_sharpe[['ZSCORE_METHOD', 'ZSCORE_THRESHOLD', 'LOOKBACK_PERIOD', 'HORIZON', 'MAX_HOLDING_DAYS', 'sharpe_ratio', 'sortino_ratio', 'alpha']])

            # Save the best performing parameters for future use
            try:
                best_params_idx = ranked['sharpe_ratio'].idxmax()
                best_params = ranked.loc[best_params_idx].to_dict()
                with open(f'best_params_{period}_{timestamp}.txt', 'w') as f:
                    for k, v in best_params.items():
                        f.write(f"{k}: {v}\n")

                print(f"\nResults saved to {output_file}")
                print(f"Best parameters saved to best_params_{period}_{timestamp}.txt")
            except Exception as e:
                # Best-effort: failing to write the summary file must not
                # discard the grid-search results.
                print(f"Error saving best parameters: {str(e)}")
        else:
            print("No valid results were generated. Check the error logs.")

        # The full table (including error rows) is still returned to the caller.
        return results

    except Exception as e:
        # Top-level boundary: report and signal failure with None.
        print(f"Error in main execution: {str(e)}")
        traceback.print_exc()
        return None

In [None]:
# NOTE(review): the relative imports below are disabled; they look like
# leftovers from a package layout (trade.py, signal_generator.py, ...) —
# confirm before re-enabling. In this file the names are expected to come
# from definitions elsewhere in the file instead.
#from .trade import Trade
#from .signal_generator import SignalGenerator
#from .portfolio_manager import PortfolioManager
#from .performance import calculate_trade_based_metrics
#from .backtest_engine import BacktestEngine
#from .grid_search import run_hyperparameter_grid_search
#from .main import run_backtest

# Public names exported by `from <module> import *`.
__all__ = [
    'Trade',
    'SignalGenerator',
    'PortfolioManager',
    'calculate_trade_based_metrics',
    'BacktestEngine',
    'run_hyperparameter_grid_search',
    'run_backtest'
]

In [None]:
if __name__ == "__main__":
    # Sample to evaluate: 'train' or 'test' (the two values run_backtest accepts)
    period = 'test'

    print(f"Running backtest for period: {period}")

    # run_backtest hands back None when it could not complete
    results = run_backtest(
        period=period,
        df_pairs_path='corr_coin.csv',
        df_main_path='final_backtest_data.csv',
    )

    # Report the outcome based on whether a result table came back
    print(
        "Backtest failed. Check error logs."
        if results is None
        else "Backtest completed successfully!"
    )