In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from scipy import stats
from scipy.ndimage import gaussian_filter1d
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import warnings
# NOTE(review): this call silences every warning category for the rest of the
# kernel session, including any raised by the pandas/NumPy code defined below.
warnings.filterwarnings('ignore')

In [12]:
class EarningsIVDataPipeline:
    """
    Enhanced pipeline for detailed earnings IV analysis with focus on single-name studies
    """
    
    def __init__(self, db_connection):
        self.db = db_connection
        self.data = {}
        self.available_tables = None
        self.analysis_results = {}
        
    # Query builders and data loaders carried over from the original pipeline
    # (setup_optionm_tables, build_optionm_query, etc.) are defined below.

    def setup_optionm_tables(self):
        """
        Get available OptionMetrics tables
        """
        if self.available_tables is None:
            tables_df = self.db.raw_sql("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'optionm'
                ORDER BY table_name
            """)
            self.available_tables = set(tables_df['table_name'].str.lower())
        return self.available_tables
    
    def build_optionm_query(self, table_base, start_date, end_date, fields, secids=None):
        """
        Build OptionMetrics query using your existing query builder logic
        """
        table_base = table_base.lower()
        
        # Match all tables starting with the given base (e.g. opprcd)
        matching_tables = [t for t in self.available_tables if t.startswith(table_base)]
        if not matching_tables:
            return f"Table '{table_base}' not found in OptionMetrics."
        
        # SECID filter
        secid_filter = ""
        if secids is not None:
            if isinstance(secids, (list, tuple, set)):
                secid_list = ", ".join(str(s) for s in secids)
                secid_filter = f"AND secid IN ({secid_list})"
            else:
                secid_filter = f"AND secid = {secids}"
        
        # Determine year range
        years = list(range(pd.to_datetime(start_date).year, pd.to_datetime(end_date).year + 1))
         
        # Case: non-suffixed (single) table (e.g., 'securd1')
        if table_base in matching_tables:
            return f"""
    SELECT {', '.join(fields)}
    FROM optionm.{table_base}
    WHERE date BETWEEN '{start_date}' AND '{end_date}'
    {secid_filter}
            """.strip()
        
        # Case: year-suffixed tables (e.g., opprcd2014, hvold2015, etc.)
        union_queries = []
        for year in years:
            table_year = f"{table_base}{year}"
            if table_year in matching_tables:
                query = f"""
    SELECT {', '.join(fields)}
    FROM optionm.{table_year}
    WHERE date BETWEEN '{start_date}' AND '{end_date}'
    {secid_filter}
    """.strip()
                union_queries.append(query)
        
        if not union_queries:
            return f"No available year-specific tables for '{table_base}' in range {years}."
        
        return "\nUNION ALL\n".join(union_queries)

    def build_rdq_query_from_tickers(self, ticker_list, start_date, end_date):
        """
        Build SQL query to fetch earnings report dates (rdq) for a list of tickers.
        """
        if not ticker_list:
            raise ValueError("You must provide at least one ticker.")
        
        # Format tickers for SQL IN clause
        formatted_tickers = ', '.join([f"'{ticker}'" for ticker in ticker_list])
        
        query = f"""
        SELECT cusip,
               tic as ticker,
               datadate,
               rdq as earnings_date,
               fyearq,
               fqtr
        FROM comp.fundq
        WHERE tic IN ({formatted_tickers})
          AND rdq BETWEEN '{start_date}' AND '{end_date}'
          AND rdq IS NOT NULL
        ORDER BY tic, rdq;
        """
        return query
    
    def build_secprd_query(self, secid_list, start_date, end_date):
        """
        Build SQL query to fetch daily stock data from optionm.secprd for a list of SECIDs.
        """
        if not secid_list:
            raise ValueError("SECID list is empty.")
        
        # Format SECIDs as numeric values, no quotes
        formatted_secids = ', '.join([str(int(secid)) for secid in secid_list])
        
        query = f"""
        SELECT *
        FROM optionm.secprd
        WHERE secid IN ({formatted_secids})
          AND date BETWEEN '{start_date}' AND '{end_date}'
        ORDER BY secid, date;
        """
        return query
    def get_securities_info(self, ticker_list):
        """
        Get security information from OptionMetrics securd1 table
        """
        print("Fetching security information from OptionMetrics...")
        
        # Format tickers for SQL IN clause
        formatted_tickers = ', '.join([f"'{ticker}'" for ticker in ticker_list])
        
        query = f"""
        SELECT DISTINCT *
        FROM optionm.securd1
        WHERE ticker IN ({formatted_tickers})
          AND exchange_d != 0
        ORDER BY ticker
        """
        
        self.data['securities'] = self.db.raw_sql(query)
        print(f"Retrieved {len(self.data['securities'])} securities")
        
        return self.data['securities']
    
    def get_earnings_dates(self, ticker_list, start_date='2023-01-01', end_date='2024-12-31'):
        """
        Fetch earnings announcement dates using Compustat
        """
        print("Fetching earnings announcement dates from Compustat...")
        
        query = self.build_rdq_query_from_tickers(ticker_list, start_date, end_date)
        
        try:
            self.data['earnings'] = self.db.raw_sql(query)
            print(f"Retrieved {len(self.data['earnings'])} earnings announcements")
            return self.data['earnings']
        except Exception as e:
            print(f"Error fetching earnings data: {e}")
            return None
    
    def get_option_data(self, secid_list, start_date='2023-01-01', end_date='2024-12-31'):
        """
        Fetch option data from OptionMetrics using secids - NO SYNTHETIC DATA
        """
        print("Fetching option data from OptionMetrics...")
        
        # Setup available tables
        self.setup_optionm_tables()
        
        # Define fields to select - using common OptionMetrics field names
        fields = [
            'date', 'secid', 'exdate', 'strike_price', 'cp_flag',
            'best_bid', 'best_offer', 'open_interest',
            'impl_volatility', 'delta', 'gamma', 'theta', 'vega', 'volume'
        ]
        
        # Build query using your query builder
        query = self.build_optionm_query('opprcd', start_date, end_date, fields, secid_list)
        
        if "not found" in query or "No available" in query:
            print(f"Query build failed: {query}")
            return None
        
        try:
            print("Executing options query...")
            self.data['options'] = self.db.raw_sql(query)
            
            # Debug: Print column names to see what's available
            print(f"Available columns in options data: {list(self.data['options'].columns)}")
            
            # Check for volume column variations
            volume_field_candidates = ['volume', 'vol', 'contract_volume', 'opt_volume']
            volume_col = None
            for col_candidate in volume_field_candidates:
                if col_candidate in self.data['options'].columns:
                    volume_col = col_candidate
                    break
            
            if volume_col and volume_col != 'volume':
                print(f"Found volume column: {volume_col}, renaming to 'volume'")
                self.data['options']['volume'] = self.data['options'][volume_col]
            
            print(f"Retrieved {len(self.data['options'])} option records")
            return self.data['options']
            
        except Exception as e:
            print(f"Error fetching options data: {e}")
            print("Attempting to query with minimal fields...")
            
            # Fallback: try with minimal fields - NO SYNTHETIC DATA
            minimal_fields = ['date', 'secid', 'exdate', 'strike_price', 'cp_flag', 
                            'best_bid', 'best_offer', 'impl_volatility']
            
            query = self.build_optionm_query('opprcd', start_date, end_date, minimal_fields, secid_list)
            
            try:
                self.data['options'] = self.db.raw_sql(query)
                print(f"Retrieved {len(self.data['options'])} option records with minimal fields")
                print("Warning: Limited fields available - some analyses may not be possible")
                return self.data['options']
            except Exception as e2:
                print(f"Fallback also failed: {e2}")
                return None
    
    def get_stock_prices(self, secid_list, start_date='2023-01-01', end_date='2024-12-31'):
        """
        Get underlying stock prices from OptionMetrics secprd
        """
        print("Fetching stock prices from OptionMetrics...")
        
        query = self.build_secprd_query(secid_list, start_date, end_date)
        
        try:
            self.data['stock_prices'] = self.db.raw_sql(query)
            print(f"Retrieved {len(self.data['stock_prices'])} stock price records")
            return self.data['stock_prices']
        except Exception as e:
            print(f"Error fetching stock prices: {e}")
            return None
    
    def merge_securities_earnings(self):
        """
        Merge securities info with earnings data using ticker matching
        """
        if 'securities' not in self.data or 'earnings' not in self.data:
            print("Need both securities and earnings data")
            return None
        
        # Merge on ticker
        merged = self.data['earnings'].merge(
            self.data['securities'][['secid', 'ticker', 'cusip', 'issuer']], 
            on='ticker', 
            how='inner'
        )
        
        self.data['earnings_securities'] = merged
        print(f"Merged {len(merged)} earnings-securities records")
        return merged
    
    def calculate_option_metrics(self):
        """
        Calculate additional option metrics - ONLY WITH REAL DATA
        """
        if 'options' not in self.data:
            raise ValueError("Options data not loaded. Run get_option_data() first.")
        
        df = self.data['options'].copy()
        
        # Convert dates to datetime objects
        df['date'] = pd.to_datetime(df['date'])
        df['exdate'] = pd.to_datetime(df['exdate'])
        
        # Calculate time to expiration in days
        df['tte'] = (df['exdate'] - df['date']).dt.days
        
        # Initialize underlying_price columns
        df['underlying_price'] = np.nan

        # Merge with stock prices to get underlying prices
        if 'stock_prices' in self.data and not self.data['stock_prices'].empty:
            stock_df = self.data['stock_prices'].copy()
            stock_df['date'] = pd.to_datetime(stock_df['date'])
            
            # Rename columns in stock_df BEFORE merging to avoid suffix issues
            stock_df = stock_df.rename(columns={'close': 'underlying_price'})
            
            # Only merge if stock_df has the necessary columns after rename
            if 'underlying_price' in stock_df.columns:
                df = df.merge(stock_df[['date', 'secid', 'underlying_price']], 
                              on=['date', 'secid'], 
                              how='left',
                              suffixes=('', '_stock')) 
            else:
                print("Warning: stock_prices DataFrame is missing 'close' column. Cannot merge underlying price.")
        else:
            print("Warning: No stock prices data available. Cannot calculate proper moneyness.")
            
        # Calculate mid_price
        if 'best_bid' in df.columns and 'best_offer' in df.columns:
            df['mid_price'] = (df['best_bid'] + df['best_offer']) / 2
        
        # Only fill missing underlying prices with mid_price if both exist
        if 'mid_price' in df.columns:
            df['underlying_price'] = df['underlying_price'].fillna(df['mid_price'])

        # Calculate moneyness: Keep your original calculation
        df['moneyness'] = df['strike_price'] / 100000.0
        
        # Calculate bid-ask spread only if both bid and offer exist
        if 'best_bid' in df.columns and 'best_offer' in df.columns:
            df['bid_ask_spread'] = np.where(
                df['best_bid'] > 0, 
                (df['best_offer'] - df['best_bid']) / df['best_bid'], 
                np.nan 
            )
            df['bid_ask_spread'] = df['bid_ask_spread'].clip(lower=0) 
        
        # Log moneyness - only if moneyness is valid
        if 'moneyness' in df.columns:
            df['log_moneyness'] = np.log(df['moneyness'].clip(lower=0.01)) 
        
        self.data['options_enhanced'] = df
        print(f"Enhanced {len(df)} option records with calculated metrics")
        
        return df
    
    def apply_data_filters(self, min_volume=10, max_bid_ask_spread=0.5, 
                           tte_range=(7, 60), moneyness_range=(0.8, 1.2)):
        """
        Apply data quality filters - ONLY filter on available real data
        """
        if 'options_enhanced' not in self.data:
            self.calculate_option_metrics()
        
        df = self.data['options_enhanced'].copy()
        initial_count = len(df)
        
        print(f"Applying filters to {initial_count:,} records...")
        print(f"Available columns: {list(df.columns)}")
        
        # Apply filters conditionally based on available columns
        
        if 'volume' in df.columns:
            df_before = len(df)
            df = df[df['volume'].notna() & (df['volume'] >= min_volume)]
            print(f"  After Volume filter (>= {min_volume}): {len(df):,} records ({len(df)/df_before:.1%} retained)")
        
        if 'bid_ask_spread' in df.columns:
            df_before = len(df)
            df = df[df['bid_ask_spread'].notna() & (df['bid_ask_spread'] <= max_bid_ask_spread)]
            print(f"  After Bid-ask spread filter (<= {max_bid_ask_spread}): {len(df):,} records ({len(df)/df_before:.1%} retained)")
        
        if 'tte' in df.columns:
            df_before = len(df)
            df = df[(df['tte'] >= tte_range[0]) & (df['tte'] <= tte_range[1])]
            print(f"  After TTE filter ({tte_range[0]} <= tte <= {tte_range[1]}): {len(df):,} records ({len(df)/df_before:.1%} retained)")
        
        if 'moneyness' in df.columns:
            df_before = len(df)
            df = df[(df['moneyness'] >= moneyness_range[0]) & (df['moneyness'] <= moneyness_range[1])]
            print(f"  After Moneyness filter ({moneyness_range[0]} <= moneyness <= {moneyness_range[1]}): {len(df):,} records ({len(df)/df_before:.1%} retained)")
        
        # Only filter on real data columns
        real_data_filters = [
            ('vega', lambda x: x > 0),
            ('best_bid', lambda x: x > 0),
            ('best_offer', lambda x: x > 0),
            ('impl_volatility', lambda x: x > 0)
        ]
        
        for col_name, filter_func in real_data_filters:
            if col_name in df.columns:
                df_before = len(df)
                df = df[df[col_name].notna() & df[col_name].apply(filter_func)]
                print(f"  After {col_name} filter: {len(df):,} records ({len(df)/df_before:.1%} retained)")
        
        filtered_count = len(df)
        print(f"Filtered from {initial_count:,} to {filtered_count:,} records "
              f"({filtered_count/initial_count:.1%} retention)")
        
        self.data['options_filtered'] = df
        
        return df
    
    def merge_earnings_options(self, event_window_days=30):
        """
        Merge earnings dates with option data using secid
        """
        if 'earnings_securities' not in self.data or 'options_filtered' not in self.data:
            print("Need both earnings_securities and options_filtered data")
            return None
        
        earnings = self.data['earnings_securities'].copy()
        options = self.data['options_filtered'].copy()
        
        # Convert dates
        earnings['earnings_date'] = pd.to_datetime(earnings['earnings_date'])
        options['date'] = pd.to_datetime(options['date'])
        
        # Merge on secid
        merged_data = []
        
        for _, earning in earnings.iterrows():
            secid = earning['secid']
            earnings_date = earning['earnings_date']
            
            # Get options data within event window
            secid_options = options[
                (options['secid'] == secid) &
                (options['date'] >= earnings_date - timedelta(days=event_window_days)) &
                (options['date'] <= earnings_date)
            ].copy()
            
            if len(secid_options) > 0:
                secid_options['earnings_date'] = earnings_date
                secid_options['days_to_earnings'] = (earnings_date - secid_options['date']).dt.days
                secid_options['ticker'] = earning['ticker']
                merged_data.append(secid_options)
        
        if merged_data:
            self.data['earnings_options'] = pd.concat(merged_data, ignore_index=True)
            print(f"Merged dataset contains {len(self.data['earnings_options'])} records")
        else:
            print("No matching earnings-options data found")
            
        return self.data.get('earnings_options')
        
    
    def get_large_cap_universe(self, min_market_cap=1e9, min_option_volume=1000):
        """
        Get universe of large cap stocks with sufficient option volume
        """
        print("Building large cap universe with option volume filters...")
        
        query = f"""
        WITH market_caps AS (
            SELECT DISTINCT s.secid, s.ticker, s.issuer,
                   p.close * s.share_vol as market_cap,
                   p.date
            FROM optionm.securd1 s
            JOIN optionm.secprd p ON s.secid = p.secid
            WHERE s.exchange_d != 0
              AND p.date >= '2023-01-01'
              AND p.close * s.share_vol >= {min_market_cap}
        ),
        option_volumes AS (
            SELECT secid, 
                   AVG(COALESCE(volume, 0)) as avg_daily_volume,
                   COUNT(*) as trading_days
            FROM optionm.opprcd2023
            WHERE volume IS NOT NULL AND volume > 0
            GROUP BY secid
            HAVING AVG(COALESCE(volume, 0)) >= {min_option_volume}
        )
        SELECT DISTINCT m.secid, m.ticker, m.issuer, 
               AVG(m.market_cap) as avg_market_cap,
               v.avg_daily_volume as avg_option_volume
        FROM market_caps m
        JOIN option_volumes v ON m.secid = v.secid
        GROUP BY m.secid, m.ticker, m.issuer, v.avg_daily_volume
        ORDER BY avg_market_cap DESC
        LIMIT 200
        """
        
        try:
            self.data['universe'] = self.db.raw_sql(query)
            print(f"Universe contains {len(self.data['universe'])} stocks")
            return self.data['universe']
        except Exception as e:
            print(f"Error building universe: {e}")
            # Fallback to manual list
            fallback_tickers = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'JPM', 'V', 'UNH']
            return self.get_securities_info(fallback_tickers)

        def single_name_deep_dive(self, ticker, start_date='2023-01-01', end_date='2024-12-31', 
                             tte_focus=[14, 30, 45], moneyness_focus=[0.95, 1.0, 1.05]):
            """
            Comprehensive single-name analysis focusing on specific TTM/moneyness points
            """
        print(f"\n{'='*60}")
        print(f"DEEP DIVE ANALYSIS: {ticker}")
        print(f"{'='*60}")
        
        # Get security info
        securities = self.get_securities_info([ticker])
        if len(securities) == 0:
            print(f"No security data found for {ticker}")
            return None
            
        secid = securities.iloc[0]['secid']
        
        # Get all data for this ticker
        earnings = self.get_earnings_dates([ticker], start_date, end_date)
        stock_prices = self.get_stock_prices([secid], start_date, end_date)
        
        # Enhanced options data query with more fields
        options = self.get_enhanced_option_data([secid], start_date, end_date)
        
        if options is None or len(options) == 0:
            print(f"No options data found for {ticker}")
            return None
        
        # Calculate realized volatility
        realized_vol = self.calculate_realized_volatility(stock_prices, windows=[5, 10, 21, 30])
        
        # Focus analysis on liquid options
        liquid_options = self.filter_liquid_options(options, tte_focus, moneyness_focus)
        
        # Merge with earnings
        if earnings is not None and len(earnings) > 0:
            earnings_options = self.merge_earnings_with_focused_options(
                earnings, liquid_options, event_window=45
            )
        else:
            earnings_options = liquid_options.copy()
            earnings_options['earnings_date'] = None
            earnings_options['days_to_earnings'] = None
        
        # Store results
        analysis_key = f"{ticker}_analysis"
        self.analysis_results[analysis_key] = {
            'ticker': ticker,
            'securities': securities,
            'earnings': earnings,
            'stock_prices': stock_prices,
            'realized_vol': realized_vol,
            'options': options,
            'liquid_options': liquid_options,
            'earnings_options': earnings_options
        }
        
        # Generate comprehensive plots
        self.plot_single_name_analysis(ticker, analysis_key)
        
        # Volume analysis
        self.analyze_option_volume_vs_stock_adv(ticker, analysis_key)
        
        # Volatility surface analysis
        self.analyze_pre_earnings_surface(ticker, analysis_key)
        
        return analysis_key
    
    def get_enhanced_option_data(self, secid_list, start_date, end_date):
        """
        Get enhanced options data with additional calculated fields
        """
        options = self.get_option_data(secid_list, start_date, end_date)
        if options is None:
            return None
            
        # Enhanced calculations
        options = self.calculate_option_metrics()
        
        # Add more derived fields
        df = options.copy()
        
        # Option notional value
        if 'mid_price' in df.columns:
            df['notional'] = df['mid_price'] * 100  # Standard option multiplier
        
        # Vega-weighted IV (if vega available)
        if 'vega' in df.columns:
            df['vega_weighted_iv'] = df['impl_volatility'] * df['vega']
        
        # Distance from ATM
        if 'moneyness' in df.columns:
            df['atm_distance'] = np.abs(np.log(df['moneyness']))
        
        return df
    
    def calculate_realized_volatility(self, stock_prices, windows=[5, 10, 21, 30]):
        """
        Calculate realized volatility using multiple estimators and windows
        """
        if stock_prices is None or len(stock_prices) == 0:
            return None
            
        df = stock_prices.copy()
        df['date'] = pd.to_datetime(df['date'])
        df = df.sort_values('date')
        
        # Calculate returns
        df['returns'] = np.log(df['close'] / df['close'].shift(1))
        
        realized_vol_data = {}
        
        for window in windows:
            # Standard realized volatility
            df[f'realized_vol_{window}d'] = df['returns'].rolling(window=window).std() * np.sqrt(252)
            
            # Exponentially weighted
            df[f'ewm_vol_{window}d'] = df['returns'].ewm(span=window).std() * np.sqrt(252)
            
            # Parkinson estimator (if OHLC available)
            if all(col in df.columns for col in ['high', 'low', 'open']):
                hl_ratio = np.log(df['high'] / df['low'])
                df[f'parkinson_vol_{window}d'] = np.sqrt(
                    hl_ratio.rolling(window=window).apply(lambda x: np.sum(x**2) / len(x)) * 252
                )
        
        return df
    
    def filter_liquid_options(self, options, tte_focus, moneyness_focus, 
                             min_volume=10, max_spread=0.3):
        """
        Filter for liquid options around focal points
        """
        df = options.copy()
        
        # TTE filter - within range of focus points
        tte_ranges = [(tte-7, tte+7) for tte in tte_focus]
        tte_mask = pd.Series(False, index=df.index)
        for low, high in tte_ranges:
            tte_mask |= (df['tte'] >= low) & (df['tte'] <= high)
        
        # Moneyness filter - within range of focus points
        if 'moneyness' in df.columns:
            money_ranges = [(money-0.05, money+0.05) for money in moneyness_focus]
            money_mask = pd.Series(False, index=df.index)
            for low, high in money_ranges:
                money_mask |= (df['moneyness'] >= low) & (df['moneyness'] <= high)
        else:
            money_mask = pd.Series(True, index=df.index)
        
        # Volume and spread filters
        volume_mask = pd.Series(True, index=df.index)
        if 'volume' in df.columns:
            volume_mask = df['volume'] >= min_volume
            
        spread_mask = pd.Series(True, index=df.index)
        if 'bid_ask_spread' in df.columns:
            spread_mask = df['bid_ask_spread'] <= max_spread
        
        # Combine filters
        final_mask = tte_mask & money_mask & volume_mask & spread_mask
        
        print(f"Liquidity filtering: {len(df)} -> {final_mask.sum()} options")
        return df[final_mask].copy()
    
    def merge_earnings_with_focused_options(self, earnings, options, event_window=45):
        """
        Merge earnings with options data, focusing on pre-earnings period
        """
        if earnings is None or len(earnings) == 0:
            return options
            
        earnings['earnings_date'] = pd.to_datetime(earnings['earnings_date'])
        options['date'] = pd.to_datetime(options['date'])
        
        merged_data = []
        
        for _, earning in earnings.iterrows():
            secid = earning['secid']
            earnings_date = earning['earnings_date']
            
            # Get options in event window (focusing on pre-earnings)
            event_options = options[
                (options['secid'] == secid) &
                (options['date'] >= earnings_date - timedelta(days=event_window)) &
                (options['date'] <= earnings_date + timedelta(days=5))  # Small post window
            ].copy()
            
            if len(event_options) > 0:
                event_options['earnings_date'] = earnings_date
                event_options['days_to_earnings'] = (earnings_date - event_options['date']).dt.days
                event_options['pre_post_earnings'] = np.where(
                    event_options['days_to_earnings'] >= 0, 'pre', 'post'
                )
                merged_data.append(event_options)
        
        if merged_data:
            return pd.concat(merged_data, ignore_index=True)
        else:
            return options.copy()
    
    def plot_single_name_analysis(self, ticker, analysis_key):
        """
        Comprehensive plotting for single name analysis

        Reads the bundle stored at ``self.analysis_results[analysis_key]`` and
        draws a 3x3 figure: stock price, 21-day realized volatility, implied vs
        realized scatter, volatility surface evolution, daily option volume,
        pre-earnings term structure, earnings IV evolution, moneyness smile and
        summary statistics. Panels whose inputs are missing are left empty
        instead of raising. The stored frames are not modified.

        Args:
            ticker: Symbol used in the titles.
            analysis_key: Key into ``self.analysis_results``.
        """
        data = self.analysis_results[analysis_key]
        earnings = data['earnings']

        def _mark_earnings(ax):
            # Dashed red vertical line at every earnings announcement date
            if earnings is None:
                return
            for _, earning in earnings.iterrows():
                ax.axvline(pd.to_datetime(earning['earnings_date']),
                           color='red', alpha=0.7, linestyle='--')

        # Convert dates on a copy so the stored realized-vol frame is left untouched
        rv_df = data['realized_vol']
        if rv_df is not None:
            rv_df = rv_df.copy()
            rv_df['date'] = pd.to_datetime(rv_df['date'])

        fig = plt.figure(figsize=(20, 16))

        # 1. Stock price and realized volatility
        ax1 = plt.subplot(3, 3, 1)
        if data['stock_prices'] is not None:
            stock_df = data['stock_prices']
            ax1.plot(pd.to_datetime(stock_df['date']), stock_df['close'], 'b-', linewidth=1)
            ax1.set_title(f'{ticker} Stock Price')
            ax1.set_ylabel('Price ($)')

            # Mark earnings dates
            _mark_earnings(ax1)

        # 2. Realized volatility comparison
        ax2 = plt.subplot(3, 3, 2)
        if rv_df is not None:
            # Plot different estimators
            for col in rv_df.columns:
                if 'vol_' in col and '21d' in col:  # Focus on 21-day window
                    if 'realized' in col:
                        ax2.plot(rv_df['date'], rv_df[col], label='Standard', alpha=0.8)
                    elif 'ewm' in col:
                        ax2.plot(rv_df['date'], rv_df[col], label='EWM', alpha=0.8)
                    elif 'parkinson' in col:
                        ax2.plot(rv_df['date'], rv_df[col], label='Parkinson', alpha=0.8)

            ax2.set_title('Realized Volatility (21d)')
            ax2.set_ylabel('Annualized Vol')
            ax2.legend()

            # Mark earnings
            _mark_earnings(ax2)

        # 3. Implied vs Realized Volatility
        ax3 = plt.subplot(3, 3, 3)
        if (data['earnings_options'] is not None and rv_df is not None
                and 'realized_vol_21d' in rv_df.columns):
            options_df = data['earnings_options']

            # Daily average IV
            daily_iv = options_df.groupby('date')['impl_volatility'].mean()

            # Merge with realized vol
            rv_df_daily = rv_df.set_index('date')['realized_vol_21d']

            common_dates = daily_iv.index.intersection(rv_df_daily.index)
            if len(common_dates) > 0:
                ax3.scatter(rv_df_daily[common_dates], daily_iv[common_dates],
                            alpha=0.6, s=20)

                # 45-degree line
                min_val = min(rv_df_daily[common_dates].min(), daily_iv[common_dates].min())
                max_val = max(rv_df_daily[common_dates].max(), daily_iv[common_dates].max())
                ax3.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.7)

                ax3.set_xlabel('Realized Vol (21d)')
                ax3.set_ylabel('Implied Vol')
                ax3.set_title('IV vs RV Scatter')

        # 4. Volatility surface evolution
        ax4 = plt.subplot(3, 3, 4)
        if data['liquid_options'] is not None:
            self.plot_volatility_surface_evolution(data['liquid_options'], ax4)

        # 5. Volume analysis
        ax5 = plt.subplot(3, 3, 5)
        if data['liquid_options'] is not None:
            volume_data = data['liquid_options']
            if 'volume' in volume_data.columns:
                daily_volume = volume_data.groupby('date')['volume'].sum()
                ax5.plot(daily_volume.index, daily_volume.values, 'g-', alpha=0.7)
                ax5.set_title('Daily Option Volume')
                ax5.set_ylabel('Contracts')

        # 6. Pre-earnings IV term structure
        ax6 = plt.subplot(3, 3, 6)
        event_df = data['earnings_options']
        # The pre/post tag only exists when options were matched to an earnings
        # window; without it there is no pre-earnings slice to plot
        if event_df is not None and 'pre_post_earnings' in event_df.columns:
            pre_earnings = event_df[event_df['pre_post_earnings'] == 'pre']
            if len(pre_earnings) > 0:
                # Average IV by TTE for pre-earnings period
                ts_data = pre_earnings.groupby('tte')['impl_volatility'].mean()
                ax6.plot(ts_data.index, ts_data.values, 'bo-', markersize=4)
                ax6.set_xlabel('Days to Expiration')
                ax6.set_ylabel('Implied Volatility')
                ax6.set_title('Pre-Earnings Term Structure')

        # 7. Earnings effect on IV
        ax7 = plt.subplot(3, 3, 7)
        if data['earnings_options'] is not None:
            self.plot_earnings_iv_evolution(data['earnings_options'], ax7)

        # 8. Moneyness smile
        ax8 = plt.subplot(3, 3, 8)
        if data['liquid_options'] is not None and 'moneyness' in data['liquid_options'].columns:
            self.plot_volatility_smile(data['liquid_options'], ax8)

        # 9. Summary statistics
        ax9 = plt.subplot(3, 3, 9)
        self.plot_summary_statistics(data, ax9)

        plt.suptitle(f'{ticker} - Comprehensive Analysis', fontsize=16, y=0.98)
        plt.tight_layout()
        plt.subplots_adjust(top=0.94)
        plt.show()
    
    def plot_volatility_surface_evolution(self, options_df, ax):
        """
        Overlay ~30-day smile snapshots taken from the most recent dates.

        The last 10 distinct values of ``options_df['date']`` are thinned to
        every third one. For each remaining date, mean implied volatility is
        plotted against moneyness for rows whose ``tte`` lies in [25, 35].
        A date is skipped when it has fewer than 5 rows, or when 2 or fewer
        (moneyness, tte) points fall inside the TTE window.

        Args:
            options_df: DataFrame with 'date', 'tte' and 'impl_volatility'
                columns. Without a 'moneyness' column only a placeholder
                message is drawn.
            ax: Axes to draw on.
        """
        if 'moneyness' not in options_df.columns:
            ax.text(0.5, 0.5, 'No moneyness data', ha='center', va='center', transform=ax.transAxes)
            return

        # Last 10 distinct dates, thinned to every third one to avoid clutter
        recent_dates = sorted(options_df['date'].unique())[-10:]
        snapshot_dates = recent_dates[::3]
        n_snapshots = len(snapshot_dates)

        for idx, snap_date in enumerate(snapshot_dates):
            snapshot = options_df.loc[options_df['date'] == snap_date]
            if len(snapshot) >= 5:
                # Mean IV for every (moneyness, tte) pair seen on this date
                grid = (
                    snapshot
                    .groupby(['moneyness', 'tte'])['impl_volatility']
                    .mean()
                    .reset_index()
                )

                # Keep only the points close to a 30-day expiry
                near_30d = grid[grid['tte'].between(25, 35)]
                if len(near_30d) > 2:
                    ax.plot(near_30d['moneyness'], near_30d['impl_volatility'],
                            'o-', color=plt.cm.viridis(idx / n_snapshots),
                            alpha=0.7, markersize=3,
                            label=f'{pd.to_datetime(snap_date).strftime("%m/%d")}')

        ax.set_xlabel('Moneyness')
        ax.set_ylabel('Implied Volatility')
        ax.set_title('IV Surface Evolution (30d TTE)')

        # Only add a legend when at least one snapshot was actually drawn
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=8)
    
    def plot_earnings_iv_evolution(self, earnings_options, ax):
        """
        Draw mean implied volatility (with std error bars) per earnings offset.

        Rows are grouped by ``days_to_earnings``; offsets with fewer than 3
        non-null IV observations are dropped. Nothing is drawn when the
        'days_to_earnings' column is missing or no offset survives the filter.

        Args:
            earnings_options: DataFrame with 'days_to_earnings' and
                'impl_volatility' columns.
            ax: Axes to draw on.
        """
        if 'days_to_earnings' not in earnings_options.columns:
            return

        # Mean / dispersion / sample size of IV at each offset
        per_offset = (
            earnings_options
            .groupby('days_to_earnings')['impl_volatility']
            .agg(['mean', 'std', 'count'])
            .reset_index()
        )

        # Require at least 3 observations per offset
        well_sampled = per_offset.loc[per_offset['count'] >= 3]
        if well_sampled.empty:
            return

        ax.errorbar(well_sampled['days_to_earnings'], well_sampled['mean'],
                    yerr=well_sampled['std'], marker='o', capsize=3, capthick=1, alpha=0.8)
        ax.axvline(x=0, color='red', linestyle='--', alpha=0.7, label='Earnings Date')
        ax.set_xlabel('Days to Earnings')
        ax.set_ylabel('Implied Volatility')
        ax.set_title('IV Evolution Around Earnings')
        ax.legend()
    
    def plot_volatility_smile(self, options_df, ax):
        """
        Plot the short-dated volatility smile (mean IV per moneyness bin).

        Rows with ``tte`` in [7, 45] are split into 15 equal-width moneyness
        bins; bins with fewer than 3 non-null IV observations are dropped and
        the remaining bin means are plotted at the bin midpoints.

        Nothing is drawn when the 'moneyness' column is missing, when no row
        has a usable moneyness inside the TTE window, or when no bin has
        enough observations.

        Args:
            options_df: DataFrame with 'moneyness', 'tte' and
                'impl_volatility' columns.
            ax: Axes to draw on.
        """
        if 'moneyness' not in options_df.columns:
            return

        # Focus on short-term options (7-45 days). Rows without a moneyness
        # value cannot be binned, and an all-NaN column would make pd.cut fail.
        in_window = (options_df['tte'] >= 7) & (options_df['tte'] <= 45)
        short_term = options_df[in_window & options_df['moneyness'].notna()]

        if len(short_term) == 0:
            return

        # Group by moneyness bins. `observed` is passed explicitly because the
        # default for categorical groupers is changing across pandas versions;
        # empty bins are removed by the count filter below either way.
        moneyness_bins = pd.cut(short_term['moneyness'], bins=15)
        smile_data = (
            short_term
            .groupby(moneyness_bins, observed=False)['impl_volatility']
            .agg(['mean', 'count'])
            .reset_index()
        )

        # Filter for sufficient observations
        smile_data = smile_data[smile_data['count'] >= 3]

        if len(smile_data) > 0:
            # Iterating the categorical column yields pd.Interval objects
            moneyness_centers = [interval.mid for interval in smile_data['moneyness']]
            ax.plot(moneyness_centers, smile_data['mean'], 'bo-', markersize=4)
            ax.set_xlabel('Moneyness')
            ax.set_ylabel('Implied Volatility')
            ax.set_title('Volatility Smile (7-45d)')
    
    def plot_summary_statistics(self, data, ax):
        """
        Render headline statistics as text lines on a hidden-axis panel.

        Args:
            data: Mapping with 'earnings', 'liquid_options' and 'realized_vol'
                entries, each either None or a sized/tabular object. Lines are
                only emitted for entries (and columns) that are present.
            ax: Axes used as a text panel; its axis decorations are switched off.
        """
        ax.axis('off')

        lines = []

        earnings = data['earnings']
        if earnings is not None:
            lines.append(f"Earnings Events: {len(earnings)}")

        liquid = data['liquid_options']
        if liquid is not None:
            lines.append(f"Liquid Options: {len(liquid):,}")

            if 'volume' in liquid.columns:
                lines.append(f"Avg Volume: {liquid['volume'].mean():.0f}")

            if 'impl_volatility' in liquid.columns:
                iv_column = liquid['impl_volatility']
                lines.append(f"Avg IV: {iv_column.mean():.3f} ± {iv_column.std():.3f}")

        realized = data['realized_vol']
        if realized is not None and 'realized_vol_21d' in realized.columns:
            lines.append(f"Avg RV (21d): {realized['realized_vol_21d'].mean():.3f}")

        # Stack the lines top-down in axes coordinates, 0.15 apart
        cursor = 0.9
        for line in lines:
            ax.text(0.1, cursor, line, transform=ax.transAxes, fontsize=10,
                    verticalalignment='top')
            cursor -= 0.15

        ax.set_title('Summary Statistics')
    
    def analyze_option_volume_vs_stock_adv(self, ticker, analysis_key):
        """
        Analyze option volume relative to stock average daily volume
        """
        data = self.analysis_results[analysis_key]
        
        if data['liquid_options'] is None or data['stock_prices'] is None:
            print("Insufficient data for volume analysis")
            return
        
        # Calculate stock ADV
        stock_df = data['stock_prices']
        if 'volume' in stock_df.columns:
            stock_adv = stock_df['volume'].mean()
        else:
            print("No stock volume data available")
            return
        
        options_df = data['liquid_options']
        
        # Calculate daily option volume and notional
        if 'volume' in options_df.columns:
            daily_option_volume = options_df.groupby('date').agg({
                'volume': 'sum',
                'notional': 'sum' if 'notional' in options_df.columns else lambda x: 0
            }).reset_index()
            
            # Option volume as % of stock ADV
            daily_option_volume['volume_ratio'] = daily_option_volume['volume'] / stock_adv * 100
            
            print(f"\n{ticker} Volume Analysis:")
            print(f"Stock ADV: {stock_adv:,.0f}")
            print(f"Avg Daily Option Volume: {daily_option_volume['volume'].mean():,.0f}")
            print(f"Option/Stock Volume Ratio: {daily_option_volume['volume_ratio'].mean():.1f}%")
            
            # Plot volume comparison
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
            
            # Daily option volume
            ax1.bar(daily_option_volume['date'], daily_option_volume['volume'], 
                   alpha=0.7, color='blue')
            ax1.axhline(y=daily_option_volume['volume'].mean(), color='red', 
                       linestyle='--', alpha=0.7, label='Average')
            ax1.set_ylabel('Option Volume (Contracts)')
            ax1.set_title(f'{ticker} Daily Option Volume')
            ax1.legend()
            
            # Volume ratio
            ax2.plot(daily_option_volume['date'], daily_option_volume['volume_ratio'], 
                    'go-', markersize=4, alpha=0.7)
            ax2.axhline(y=daily_option_volume['volume_ratio'].mean(), color='red', 
                       linestyle='--', alpha=0.7)
            ax2.set_ylabel('Option/Stock Volume (%)')
            ax2.set_xlabel('Date')
            ax2.set_title('Option Volume as % of Stock ADV')
            
            plt.tight_layout()
            plt.show()
    
    def analyze_pre_earnings_surface(self, ticker, analysis_key):
            """
            Detailed analysis of volatility surface before earnings
            """
            data = self.analysis_results[analysis_key]
            
            if data['earnings_options'] is None:
                print("No earnings-options data for surface analysis")
                return
            
            earnings_df = data['earnings_options']
            
            # Focus on pre-earnings period (1-5 days before)
            pre_earnings = earnings_df[
                (earnings_df['days_to_earnings'] >= 1) & 
                (earnings_df['days_to_earnings'] <= 5)
            ]
            
            if len(pre_earnings) == 0:
                print("No pre-earnings options data found")
                return
            
            print(f"\n{ticker} Pre-Earnings Surface Analysis:")
            print(f"Options in 1-5 days before earnings: {len(pre_earnings)}")
            
            # Create surface plot
            fig = plt.figure(figsize=(15, 10))
            
            # 3D Surface plot
            ax1 = fig.add_subplot(221, projection='3d')
            
            if all(col in pre_earnings.columns for col in ['moneyness', 'tte', 'impl_volatility']):
                # Create surface data
                surface_data = pre_earnings.groupby(['moneyness', 'tte'])['impl_volatility'].mean().reset_index()
                
                if len(surface_data) > 10:
                    # Create meshgrid
                    moneyness_unique = sorted(surface_data['moneyness'].unique())
                    tte_unique = sorted(surface_data['tte'].unique())
                    
                    X, Y = np.meshgrid(moneyness_unique[:10], tte_unique[:10])  # Limit size
                    Z = np.zeros_like(X)
                    
                    for i, tte in enumerate(tte_unique[:10]):
                        for j, moneyness in enumerate(moneyness_unique[:10]):
                            iv_val = surface_data[
                                (surface_data['tte'] == tte) & 
                                (surface_data['moneyness'] == moneyness)
                            ]['impl_volatility']
                            Z[i, j] = iv_val.iloc[0] if len(iv_val) > 0 else np.nan
                    
                    # Remove NaN values
                    mask = ~np.isnan(Z)
                    if mask.sum() > 0:
                        ax1.plot_surface(X, Y, Z, alpha=0.7, cmap='viridis')
                        ax1.set_xlabel('Moneyness')
                        ax1.set_ylabel('Time to Expiration')
                        ax1.set_zlabel('Implied Volatility')
                        ax1.set_title('Pre-Earnings IV Surface')
            
            # 2D slices
            ax2 = fig.add_subplot(222)
            # ATM term structure
            atm_data = pre_earnings[
                (pre_earnings['moneyness'] >= 0.95) & 
                (pre_earnings['moneyness'] <= 1.05)
            ]
            if len(atm_data) > 0:
                ts_data = atm_data.groupby('tte')['impl_volatility'].mean()
                ax2.plot(ts_data.index, ts_data.values, 'bo-', markersize=5)
                ax2.set_xlabel('Days to Expiration')
                ax2.set_ylabel('Implied Volatility')
                ax2.set_title('ATM Term Structure (Pre-Earnings)')
            
            # Volatility smile for specific TTE
            ax3 = fig.add_subplot(223)
            short_tte = pre_earnings[(pre_earnings['tte'] >= 14) & (pre_earnings['tte'] <= 30)]
            if len(short_tte) > 0:
                smile_data = short_tte.groupby('moneyness')['impl_volatility'].mean()
                ax3.plot(smile_data.index, smile_data.values, 'ro-', markersize=4)
                ax3.set_xlabel('Moneyness')
                ax3.set_ylabel('Implied Volatility')
                ax3.set_title('Volatility Smile (14-30 DTE)')
                ax3.axvline(x=1.0, color='black', linestyle='--', alpha=0.5, label='ATM')
                ax3.legend()
            
            # Volume and Open Interest analysis
            ax4 = fig.add_subplot(224)
            if 'volume' in pre_earnings.columns and 'open_interest' in pre_earnings.columns:
                # Group by strike and sum volume/OI
                strike_analysis = pre_earnings.groupby('strike').agg({
                    'volume': 'sum',
                    'open_interest': 'sum',
                    'impl_volatility': 'mean'
                }).reset_index()
                
                # Filter for reasonable strikes (within 20% of ATM)
                current_price = strike_analysis['strike'].median()  # Approximate current price
                reasonable_strikes = strike_analysis[
                    (strike_analysis['strike'] >= current_price * 0.8) & 
                    (strike_analysis['strike'] <= current_price * 1.2)
                ]
                
                if len(reasonable_strikes) > 0:
                    ax4_vol = ax4.twinx()
                    
                    # Volume bars
                    ax4.bar(reasonable_strikes['strike'], reasonable_strikes['volume'], 
                           alpha=0.6, color='blue', label='Volume')
                    ax4.set_xlabel('Strike Price')
                    ax4.set_ylabel('Volume', color='blue')
                    ax4.tick_params(axis='y', labelcolor='blue')
                    
                    # Open Interest line
                    ax4_vol.plot(reasonable_strikes['strike'], reasonable_strikes['open_interest'], 
                               'ro-', markersize=4, label='Open Interest')
                    ax4_vol.set_ylabel('Open Interest', color='red')
                    ax4_vol.tick_params(axis='y', labelcolor='red')
                    
                    ax4.set_title('Volume & Open Interest by Strike')
                    ax4.legend(loc='upper left')
                    ax4_vol.legend(loc='upper right')
            
            plt.tight_layout()
            plt.show()
            
            # Additional analysis and statistics
            self._print_surface_statistics(pre_earnings, ticker)
    
    def print_surface_statistics(self, pre_earnings, ticker):
        """
        Print detailed statistics about the pre-earnings surface.

        Each section is only printed when the columns it needs are present:
        IV summary ('impl_volatility'), expiration range and IV per expiry
        bucket ('tte'), moneyness distribution ('moneyness'), volume summary
        ('volume') and a put/call breakdown ('option_type').

        Args:
            pre_earnings: DataFrame of option rows to summarise.
            ticker: Label used in the printed header.

        Returns:
            None. All output goes to stdout.
        """
        columns = pre_earnings.columns
        has_iv = 'impl_volatility' in columns
        has_volume = 'volume' in columns

        print(f"\n{ticker} Pre-Earnings Surface Statistics:")
        print("=" * 50)

        # IV statistics
        if has_iv:
            iv_stats = pre_earnings['impl_volatility'].describe()
            print("Implied Volatility Statistics:")
            print(f"  Mean: {iv_stats['mean']:.2%}")
            print(f"  Median: {pre_earnings['impl_volatility'].median():.2%}")
            print(f"  Std Dev: {iv_stats['std']:.2%}")
            print(f"  Min: {iv_stats['min']:.2%}")
            print(f"  Max: {iv_stats['max']:.2%}")

        # Term structure analysis
        if 'tte' in columns:
            print("\nTime to Expiration Range:")
            print(f"  Min: {pre_earnings['tte'].min():.0f} days")
            print(f"  Max: {pre_earnings['tte'].max():.0f} days")

            if has_iv:
                # Average IV by expiration bucket
                tte_buckets = pd.cut(pre_earnings['tte'],
                                     bins=[0, 7, 14, 30, 60, 90, float('inf')],
                                     labels=['0-7d', '7-14d', '14-30d', '30-60d', '60-90d', '90d+'])

                # `observed` is explicit because the default for categorical
                # groupers is changing across pandas versions; empty buckets
                # yield NaN and are skipped below.
                bucket_iv = pre_earnings.groupby(tte_buckets, observed=False)['impl_volatility'].mean()
                print("\nAverage IV by Expiration Bucket:")
                for bucket, iv in bucket_iv.items():
                    if not pd.isna(iv):
                        print(f"  {bucket}: {iv:.2%}")

        # Moneyness analysis
        if 'moneyness' in columns:
            print("\nMoneyness Distribution:")
            money_buckets = pd.cut(pre_earnings['moneyness'],
                                   bins=[0, 0.9, 0.95, 1.05, 1.1, float('inf')],
                                   labels=['Deep OTM', 'OTM', 'ATM', 'ITM', 'Deep ITM'])

            by_moneyness = pre_earnings.groupby(money_buckets, observed=False)
            bucket_counts = by_moneyness.size()
            bucket_iv = by_moneyness['impl_volatility'].mean() if has_iv else None

            for bucket, n_options in bucket_counts.items():
                if bucket_iv is None:
                    # Without an IV column only the non-empty counts are shown
                    if n_options > 0:
                        print(f"  {bucket}: {n_options} options")
                elif not pd.isna(bucket_iv.loc[bucket]):
                    print(f"  {bucket}: {n_options} options, Avg IV: {bucket_iv.loc[bucket]:.2%}")

        # Volume analysis
        if has_volume:
            total_volume = pre_earnings['volume'].sum()
            avg_volume = pre_earnings['volume'].mean()
            print("\nVolume Analysis:")
            print(f"  Total Volume: {total_volume:,.0f}")
            print(f"  Average Volume per Option: {avg_volume:.1f}")

            # High volume options
            high_vol_threshold = pre_earnings['volume'].quantile(0.9)
            high_vol_options = pre_earnings[pre_earnings['volume'] >= high_vol_threshold]
            print(f"  High Volume Options (>90th percentile): {len(high_vol_options)}")

            if has_iv and len(high_vol_options) > 0:
                print(f"  Average IV for High Volume Options: {high_vol_options['impl_volatility'].mean():.2%}")

        # Put/Call analysis if option_type is available
        if 'option_type' in columns:
            # Aggregate only the columns that exist; naming a missing column
            # in the agg spec raises KeyError.
            agg_spec = {}
            if has_iv:
                agg_spec['impl_volatility'] = 'mean'
            if has_volume:
                agg_spec['volume'] = 'sum'

            if agg_spec:
                put_call_summary = pre_earnings.groupby('option_type').agg(agg_spec)

                print("\nPut/Call Analysis:")
                for opt_type, row in put_call_summary.iterrows():
                    parts = []
                    if has_iv:
                        parts.append(f"Avg IV: {row['impl_volatility']:.2%}")
                    if has_volume:
                        parts.append(f"Total Volume: {row['volume']:,.0f}")
                    print(f"  {str(opt_type).upper()}s - {', '.join(parts)}")

        print("=" * 50)

In [5]:
# Example usage function
def run_earnings_iv_analysis(wrds_connection, tickers=('AAPL', 'MSFT', 'GOOGL'),
                             start_date='2023-01-01', end_date='2024-12-31'):
    """
    Convenience function to run the complete analysis

    Args:
        wrds_connection: Active WRDS database connection
        tickers: Ticker symbols to analyze. Any iterable of strings is
            accepted; a single string is treated as one ticker.
        start_date: Analysis start date
        end_date: Analysis end date

    Returns:
        EarningsIVDataPipeline: Configured pipeline object with results

    Raises:
        AttributeError: If the pipeline class in scope does not define
            ``run_full_analysis`` / ``export_data``. This is checked before
            any work starts so the failure is immediate and self-explanatory.
    """
    # Copy into a fresh list so the default is never shared or mutated;
    # a bare string would otherwise be iterated character by character.
    tickers = [tickers] if isinstance(tickers, str) else list(tickers)

    # Initialize pipeline
    pipeline = EarningsIVDataPipeline(wrds_connection)

    # Fail fast when the class definition in scope is incomplete
    missing = [
        name for name in ('run_full_analysis', 'export_data')
        if not callable(getattr(pipeline, name, None))
    ]
    if missing:
        raise AttributeError(
            f"'{type(pipeline).__name__}' object is missing required method(s): "
            f"{', '.join(missing)}. Define them on the class before running the analysis."
        )

    # Run full analysis
    success = pipeline.run_full_analysis(tickers, start_date, end_date)

    if success:
        print(f"\n📁 Exporting results...")
        pipeline.export_data()
        print(f"\n🎉 Analysis complete! Results saved and pipeline ready for further exploration.")
    else:
        print(f"\n💥 Analysis failed. Check the logs above for details.")

    return pipeline

In [13]:
# Main execution block
if __name__ == "__main__":
    # Example usage - modify as needed
    import os

    import wrds

    # Connect to WRDS. The username is read from the WRDS_USERNAME environment
    # variable instead of being hardcoded in the notebook; when it is unset,
    # wrds.Connection is called without one.
    # NOTE(review): presumably wrds then prompts for credentials - confirm.
    wrds_username = os.environ.get('WRDS_USERNAME')
    connection_kwargs = {'wrds_username': wrds_username} if wrds_username else {}
    db = wrds.Connection(**connection_kwargs)

    # Define analysis parameters
    #TICKERS = ['AAPL', 'MSFT', 'GOOGL', 'TSLA', 'NVDA']
    TICKERS = ['AAPL']

    START_DATE = '2023-01-01'
    END_DATE = '2024-12-31'

    try:
        # Run analysis
        pipeline = run_earnings_iv_analysis(
            wrds_connection=db,
            tickers=TICKERS,
            start_date=START_DATE,
            end_date=END_DATE
        )

        # Access results (the entry may be missing or None if a step failed)
        earnings_options = pipeline.data.get('earnings_options')
        if earnings_options is not None:
            print(f"Final dataset shape: {earnings_options.shape}")
            print(f"Available columns: {list(earnings_options.columns)}")
    finally:
        # Close connection even when the analysis raises, so a failed run
        # does not leave the WRDS session open.
        db.close()

    print("Earnings IV Analysis Pipeline Loaded Successfully!")
    print("To use: create WRDS connection and call run_earnings_iv_analysis()")


Loading library list...
Done


AttributeError: 'EarningsIVDataPipeline' object has no attribute 'run_full_analysis'


# 📈 Earnings Volatility Forecasting Project

**Objective**: Predict post-earnings realized volatility using pre-earnings implied volatility and option features, as motivated by Wolfe Research's paper on Unexpected Earnings Risk.

**This Notebook Covers:**
- Data filtering
- Realized volatility computation
- Pre-earnings implied volatility extraction
- Simple regression and evaluation interface
- Visualizations and case study on AAPL


In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score


In [None]:

def compute_realized_volatility(price_series, window=21, trading_days=252):
    """
    Annualised rolling realized volatility from a price series.

    Computes log returns ``ln(P_t / P_{t-1})``, takes their rolling sample
    standard deviation over ``window`` observations and scales it by
    ``sqrt(trading_days)``.

    Args:
        price_series: pandas Series of prices in chronological order.
        window: Number of return observations per rolling window.
        trading_days: Periods per year used for annualisation (252 for daily
            data; pass e.g. 52 for weekly or 12 for monthly prices).

    Returns:
        pandas Series aligned with ``price_series``; the first ``window``
        entries are NaN because a full window of returns is not yet available.
    """
    log_returns = np.log(price_series / price_series.shift(1))
    realized_vol = log_returns.rolling(window).std() * np.sqrt(trading_days)
    return realized_vol


In [None]:

def filter_option_data(df, ticker='AAPL', moneyness_range=(0.95, 1.05),
                       maturity_range=(10, 45), min_volume=100):
    """
    Select liquid, near-the-money, short-dated option rows for one ticker.

    Filters are applied in sequence to a copy of ``df`` (the input is left
    untouched): ticker equality, inclusive moneyness band, inclusive
    days-to-expiry band, then a minimum volume.

    Args:
        df: DataFrame with 'ticker', 'moneyness', 'days_to_expiry' and
            'volume' columns.
        ticker: Ticker symbol to keep.
        moneyness_range: (low, high) inclusive bounds on 'moneyness'.
        maturity_range: (low, high) inclusive bounds on 'days_to_expiry'.
        min_volume: Minimum 'volume' a row must have.

    Returns:
        DataFrame containing the rows that pass every filter.
    """
    selected = df.copy()
    selected = selected.loc[selected['ticker'] == ticker]

    # Inclusive range filters, applied one after the other
    banded_columns = (
        ('moneyness', moneyness_range),
        ('days_to_expiry', maturity_range),
    )
    for column, bounds in banded_columns:
        selected = selected.loc[selected[column].between(bounds[0], bounds[1])]

    return selected.loc[selected['volume'] >= min_volume]


In [None]:

def kernel_regression(X_train, y_train, X_test, y_test, gamma=0.1, alpha=1.0):
    """
    Fit an RBF kernel ridge regression and report out-of-sample fit.

    Args:
        X_train: Training features.
        y_train: Training targets.
        X_test: Test features.
        y_test: Test targets used for scoring.
        gamma: RBF kernel coefficient passed to ``KernelRidge``.
        alpha: Regularisation strength passed to ``KernelRidge``.

    Returns:
        tuple: ``(model, preds)`` - the fitted estimator and its predictions
        on ``X_test``. R^2 and RMSE on the test set are printed.
    """
    model = KernelRidge(kernel='rbf', gamma=gamma, alpha=alpha)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)

    # The `squared=False` argument of mean_squared_error was deprecated in
    # scikit-learn 1.4 and later removed; taking the square root explicitly
    # works on every version.
    rmse = float(np.sqrt(mean_squared_error(y_test, preds)))

    print("R^2:", r2_score(y_test, preds))
    print("RMSE:", rmse)
    return model, preds


In [None]:

def plot_vol_series(dates, iv, rv, earnings_date=None):
    """
    Plot implied and realized volatility on a shared time axis.

    Args:
        dates: x-axis values (dates, or any numeric event-time index).
        iv: Implied volatility values aligned with ``dates``.
        rv: Realized volatility values aligned with ``dates``.
        earnings_date: Optional x position at which to draw a vertical
            earnings marker. Any value other than None is drawn, including
            falsy ones such as 0 on an event-time axis.
    """
    plt.figure(figsize=(12, 6))
    plt.plot(dates, iv, label='Implied Volatility')
    plt.plot(dates, rv, label='Realized Volatility', linestyle='--')

    # Explicit None check: truthiness would drop a marker at x=0
    if earnings_date is not None:
        plt.axvline(x=earnings_date, color='red', linestyle='--', label='Earnings')

    plt.legend()
    plt.title("IV and RV Over Time")
    plt.xlabel("Date")
    plt.ylabel("Volatility")
    plt.show()


In [None]:

def plot_iv_surface_slices(pre_earnings):
    """
    Plot two 2D slices of the pre-earnings IV surface side by side.

    Left panel: ATM term structure - mean IV per ``tte`` for rows with
    moneyness in [0.95, 1.05]. Right panel: volatility smile - mean IV per
    moneyness for rows with ``tte`` in [14, 30].

    Args:
        pre_earnings: DataFrame with 'moneyness', 'tte' and 'impl_volatility'
            columns.

    Raises:
        KeyError: If any of the required columns is missing.
    """
    required = ('moneyness', 'tte', 'impl_volatility')
    missing = [col for col in required if col not in pre_earnings.columns]
    if missing:
        raise KeyError(f"pre_earnings is missing required column(s): {missing}")

    # One row of two panels; uses the notebook-level matplotlib import
    fig, (ax_term, ax_smile) = plt.subplots(1, 2, figsize=(14, 5))

    # ATM term structure
    atm_data = pre_earnings[
        (pre_earnings['moneyness'] >= 0.95) &
        (pre_earnings['moneyness'] <= 1.05)
    ]
    if len(atm_data) > 0:
        ts_data = atm_data.groupby('tte')['impl_volatility'].mean()
        ax_term.plot(ts_data.index, ts_data.values, 'bo-', markersize=5)
        ax_term.set_xlabel('Days to Expiration')
        ax_term.set_ylabel('Implied Volatility')
        ax_term.set_title('ATM Term Structure (Pre-Earnings)')

    # Volatility smile for a specific TTE range (14–30 days)
    short_tte = pre_earnings[(pre_earnings['tte'] >= 14) & (pre_earnings['tte'] <= 30)]
    if len(short_tte) > 0:
        smile_data = short_tte.groupby('moneyness')['impl_volatility'].mean()
        ax_smile.plot(smile_data.index, smile_data.values, 'go-', markersize=5)
        ax_smile.set_xlabel('Moneyness')
        ax_smile.set_ylabel('Implied Volatility')
        ax_smile.set_title('Volatility Smile (TTE 14–30 Days)')

    plt.tight_layout()
    plt.show()
