In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta, time
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.spatial.distance import euclidean, cosine
from dataclasses import dataclass
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# MOCK DATA GENERATOR
# ============================================================================

class MarketDataGenerator:
    """Generate realistic mock market microstructure data"""

    def __init__(self, seed=42):
        np.random.seed(seed)
        self.symbols = {
            'AAPL': {'price': 180.0, 'volatility': 0.02, 'spread_bps': 1.0, 'avg_volume': 1000000},
            'MSFT': {'price': 380.0, 'volatility': 0.018, 'spread_bps': 1.2, 'avg_volume': 800000},
            'GOOGL': {'price': 140.0, 'volatility': 0.022, 'spread_bps': 1.5, 'avg_volume': 600000},
            'TSLA': {'price': 240.0, 'volatility': 0.035, 'spread_bps': 2.0, 'avg_volume': 1200000},
            'JPM': {'price': 155.0, 'volatility': 0.015, 'spread_bps': 0.8, 'avg_volume': 500000}
        }

    def generate_intraday_volume_curve(self, n_buckets=390):
        """
        Generate realistic U-shaped intraday volume curve
        High at open, low mid-day, high at close
        """
        x = np.linspace(0, 1, n_buckets)
        # U-shaped curve using polynomial
        volume_curve = 2.5 * (x ** 2) - 2.5 * x + 1.5
        # Add some randomness
        volume_curve += np.random.normal(0, 0.1, n_buckets)
        volume_curve = np.maximum(volume_curve, 0.3)  # Floor at 30% of mean
        return volume_curve / volume_curve.mean()  # Normalize to mean=1

    def generate_day_data(self, symbol, date, is_trade_day=False,
                          trade_intensity=1.0, trade_start_pct=0.3, trade_end_pct=0.7):
        """
        Generate one day of tick data for a symbol

        is_trade_day: If True, adds detectable footprint
        trade_intensity: How much extra volume on trade days (1.0 = normal, 2.0 = double)
        trade_start_pct: When trading starts (0.3 = 30% into day)
        trade_end_pct: When trading ends (0.7 = 70% into day)
        """
        params = self.symbols[symbol]

        # Market hours: 9:30 AM to 4:00 PM = 390 minutes
        market_open = datetime.combine(date, time(9, 30))
        n_minutes = 390

        # Generate volume curve
        base_volume_curve = self.generate_intraday_volume_curve(n_minutes)

        # If trade day, add footprint in specified window
        if is_trade_day:
            trade_start_idx = int(n_minutes * trade_start_pct)
            trade_end_idx = int(n_minutes * trade_end_pct)

            # Add volume spike during trading window
            volume_multiplier = np.ones(n_minutes)
            volume_multiplier[trade_start_idx:trade_end_idx] *= trade_intensity

            # Add some "leakage" - slight volume increase before/after
            if trade_start_idx > 10:
                volume_multiplier[trade_start_idx-10:trade_start_idx] *= 1.1
            if trade_end_idx < n_minutes - 10:
                volume_multiplier[trade_end_idx:trade_end_idx+10] *= 1.1

            base_volume_curve *= volume_multiplier

        # Generate tick data
        ticks = []
        current_price = params['price']

        for minute in range(n_minutes):
            timestamp = market_open + timedelta(minutes=minute)

            # Number of ticks this minute (Poisson distributed)
            n_ticks = max(1, int(np.random.poisson(5 * base_volume_curve[minute])))

            for tick in range(n_ticks):
                # Price movement (random walk with mean reversion)
                price_change = np.random.normal(0, params['volatility'] * params['price'] / 100)
                current_price += price_change
                current_price = max(current_price, params['price'] * 0.95)  # Floor
                current_price = min(current_price, params['price'] * 1.05)  # Ceiling

                # Spread (wider during low volume)
                spread_multiplier = 1.0 + (1.5 - base_volume_curve[minute])
                spread = params['price'] * params['spread_bps'] / 10000 * spread_multiplier
                spread = max(spread, 0.01)

                mid_price = current_price
                bid = mid_price - spread / 2
                ask = mid_price + spread / 2

                # Volume (log-normal distribution)
                volume = int(np.random.lognormal(
                    np.log(params['avg_volume'] / n_minutes / 5),
                    0.8
                ) * base_volume_curve[minute])
                volume = max(volume, 100)

                # Trade direction (slight buy pressure on trade days during window)
                if is_trade_day and trade_start_idx <= minute < trade_end_idx:
                    trade_direction = np.random.choice([1, -1], p=[0.55, 0.45])  # 55% buy
                else:
                    trade_direction = np.random.choice([1, -1], p=[0.50, 0.50])  # 50/50

                # Execution price (between bid/ask)
                if trade_direction == 1:  # Buy
                    price = ask - np.random.uniform(0, spread * 0.3)
                else:  # Sell
                    price = bid + np.random.uniform(0, spread * 0.3)

                # Depth (inverse relationship with volume)
                depth_multiplier = 2.0 - base_volume_curve[minute]
                bid_size = int(np.random.lognormal(8, 1) * depth_multiplier)
                ask_size = int(np.random.lognormal(8, 1) * depth_multiplier)

                # Order book imbalance (if trade day, slight bid-side pressure)
                if is_trade_day and trade_start_idx <= minute < trade_end_idx:
                    bid_size = int(bid_size * 1.15)  # More bid depth

                # Quote updates (more during volatile periods)
                quote_update = 1 if np.random.random() < 0.3 else 0

                tick_data = {
                    'symbol': symbol,
                    'date': date,
                    'timestamp': timestamp + timedelta(seconds=tick * (60 / n_ticks)),
                    'price': round(price, 2),
                    'volume': volume,
                    'bid': round(bid, 2),
                    'ask': round(ask, 2),
                    'bid_size': bid_size,
                    'ask_size': ask_size,
                    'trade_direction': trade_direction,
                    'quote_update': quote_update
                }

                ticks.append(tick_data)

        return pd.DataFrame(ticks)

    def generate_multi_day_data(self, symbols, start_date, n_days=30,
                                trade_days_per_symbol=None):
        """
        Generate multiple days for multiple symbols

        trade_days_per_symbol: Dict[symbol, List[date]] - which days you traded each symbol
        """
        if trade_days_per_symbol is None:
            trade_days_per_symbol = {}

        all_data = []
        dates = [start_date + timedelta(days=i) for i in range(n_days)]

        # Filter out weekends
        dates = [d for d in dates if d.weekday() < 5]

        for symbol in symbols:
            symbol_trade_days = trade_days_per_symbol.get(symbol, set())

            for date in dates:
                is_trade_day = date in symbol_trade_days

                # Vary trade intensity across days
                if is_trade_day:
                    trade_intensity = np.random.uniform(1.3, 1.8)
                    trade_start = np.random.uniform(0.2, 0.4)
                    trade_end = np.random.uniform(0.6, 0.8)
                else:
                    trade_intensity = 1.0
                    trade_start = 0.3
                    trade_end = 0.7

                day_data = self.generate_day_data(
                    symbol, date, is_trade_day,
                    trade_intensity, trade_start, trade_end
                )
                all_data.append(day_data)

        return pd.concat(all_data, ignore_index=True)


# ============================================================================
# FOOTPRINT ANALYZER (from previous code)
# ============================================================================

@dataclass
class FootprintMetrics:
    """All metrics for a single time bucket

    The field names match the keys of the dict built by
    ComprehensiveFootprintAnalyzer.compute_bucket_metrics. That method returns
    a plain dict; nothing in this file constructs a FootprintMetrics instance.
    """
    # Activity within the bucket
    volume: float
    trade_count: int
    # Price level and movement
    vwap: float
    price_return: float
    volatility: float
    # Quoted spread, absolute and relative to mid price
    spread: float
    spread_pct: float
    # Displayed size on each side of the book and their normalised difference
    bid_depth: float
    ask_depth: float
    depth_imbalance: float
    # Signed buy/sell volume balance
    trade_imbalance: float
    quote_intensity: int
    price_impact: float
    effective_spread: float
    realized_spread: float


class ComprehensiveFootprintAnalyzer:
    def __init__(self, lookback_days=21, bucket_minutes=1):
        self.lookback_days = lookback_days
        self.bucket_minutes = bucket_minutes

        self.metric_groups = {
            'volume_metrics': ['volume', 'trade_count', 'trade_imbalance'],
            'price_metrics': ['vwap', 'price_return', 'price_impact'],
            'volatility_metrics': ['volatility', 'realized_spread'],
            'liquidity_metrics': ['spread', 'spread_pct', 'effective_spread'],
            'depth_metrics': ['bid_depth', 'ask_depth', 'depth_imbalance'],
            'activity_metrics': ['quote_intensity']
        }

        self.metric_weights = {
            'volume': 0.20,
            'trade_count': 0.10,
            'price_return': 0.15,
            'volatility': 0.15,
            'spread': 0.10,
            'depth_imbalance': 0.10,
            'trade_imbalance': 0.10,
            'price_impact': 0.10
        }

    def compute_bucket_metrics(self, bucket_data, prev_bucket_data=None):
        """Compute all metrics for a single time bucket"""
        if len(bucket_data) == 0:
            return None

        metrics = {}

        # Volume metrics
        metrics['volume'] = bucket_data['volume'].sum()
        metrics['trade_count'] = len(bucket_data)
        metrics['vwap'] = (bucket_data['price'] * bucket_data['volume']).sum() / metrics['volume'] if metrics['volume'] > 0 else bucket_data['price'].mean()

        # Trade direction imbalance
        if 'trade_direction' in bucket_data.columns:
            buy_vol = bucket_data[bucket_data['trade_direction'] == 1]['volume'].sum()
            sell_vol = bucket_data[bucket_data['trade_direction'] == -1]['volume'].sum()
            total_vol = buy_vol + sell_vol
            metrics['trade_imbalance'] = (buy_vol - sell_vol) / total_vol if total_vol > 0 else 0
        else:
            metrics['trade_imbalance'] = 0

        # Price metrics
        if prev_bucket_data is not None and len(prev_bucket_data) > 0:
            prev_vwap = (prev_bucket_data['price'] * prev_bucket_data['volume']).sum() / prev_bucket_data['volume'].sum()
            metrics['price_return'] = (metrics['vwap'] - prev_vwap) / prev_vwap if prev_vwap > 0 else 0
        else:
            metrics['price_return'] = 0

        metrics['volatility'] = bucket_data['price'].std() if len(bucket_data) > 1 else 0

        # Spread metrics
        if 'bid' in bucket_data.columns and 'ask' in bucket_data.columns:
            metrics['spread'] = (bucket_data['ask'] - bucket_data['bid']).mean()
            mid_price = (bucket_data['bid'] + bucket_data['ask']) / 2
            metrics['spread_pct'] = (metrics['spread'] / mid_price.mean()) * 100 if mid_price.mean() > 0 else 0
            metrics['effective_spread'] = 2 * np.abs(bucket_data['price'] - mid_price).mean()
        else:
            metrics['spread'] = 0
            metrics['spread_pct'] = 0
            metrics['effective_spread'] = 0

        # Depth metrics
        if 'bid_size' in bucket_data.columns and 'ask_size' in bucket_data.columns:
            metrics['bid_depth'] = bucket_data['bid_size'].mean()
            metrics['ask_depth'] = bucket_data['ask_size'].mean()
            total_depth = metrics['bid_depth'] + metrics['ask_depth']
            metrics['depth_imbalance'] = (metrics['bid_depth'] - metrics['ask_depth']) / total_depth if total_depth > 0 else 0
        else:
            metrics['bid_depth'] = 0
            metrics['ask_depth'] = 0
            metrics['depth_imbalance'] = 0

        # Quote intensity
        if 'quote_update' in bucket_data.columns:
            metrics['quote_intensity'] = bucket_data['quote_update'].sum()
        else:
            metrics['quote_intensity'] = 0

        # Price impact
        if metrics['volume'] > 0 and len(bucket_data) > 1:
            price_move = np.abs(bucket_data['price'].iloc[-1] - bucket_data['price'].iloc[0])
            metrics['price_impact'] = (price_move / bucket_data['price'].iloc[0]) / (metrics['volume'] / 1e6) if bucket_data['price'].iloc[0] > 0 else 0
        else:
            metrics['price_impact'] = 0

        metrics['realized_spread'] = 0

        return metrics

    def create_comprehensive_baseline(self, historical_data, trade_dates, current_date, symbol):
        """Create baseline profile with ALL metrics"""
        # Filter for this symbol
        symbol_data = historical_data[historical_data['symbol'] == symbol].copy()

        # Get prior dates excluding trading days
        all_dates = sorted(symbol_data['date'].unique())
        all_dates = [d for d in all_dates if d < current_date]

        non_trade_dates = [d for d in all_dates if d not in trade_dates]
        baseline_dates = non_trade_dates[-self.lookback_days:]

        if len(baseline_dates) < self.lookback_days:
            print(f"‚ö†Ô∏è Warning: Only {len(baseline_dates)} non-trading days available for {symbol}")

        baseline_data = symbol_data[symbol_data['date'].isin(baseline_dates)].copy()

        # Create time buckets
        baseline_data['time_bucket'] = pd.to_datetime(
            baseline_data['timestamp']
        ).dt.floor(f'{self.bucket_minutes}min').dt.time

        # Compute metrics for each date and time bucket
        baseline_profiles = []

        for date in baseline_dates:
            date_data = baseline_data[baseline_data['date'] == date]
            time_buckets = sorted(date_data['time_bucket'].unique())

            prev_bucket_data = None
            for tb in time_buckets:
                bucket_data = date_data[date_data['time_bucket'] == tb]
                metrics = self.compute_bucket_metrics(bucket_data, prev_bucket_data)

                if metrics:
                    metrics['date'] = date
                    metrics['time_bucket'] = tb
                    baseline_profiles.append(metrics)

                prev_bucket_data = bucket_data

        baseline_df = pd.DataFrame(baseline_profiles)

        if len(baseline_df) == 0:
            return None, None

        # Average across all baseline dates
        baseline_avg = baseline_df.groupby('time_bucket').agg({
            metric: ['mean', 'std', 'median']
            for metric in self.metric_weights.keys()
        }).reset_index()

        # Flatten column names
        baseline_avg.columns = ['_'.join(col).strip('_') for col in baseline_avg.columns.values]

        return baseline_avg, baseline_df

    def compute_trade_day_profile(self, trade_day_data):
        """Compute full metric profile for trade day"""
        trade_day_data = trade_day_data.copy()
        trade_day_data['time_bucket'] = pd.to_datetime(
            trade_day_data['timestamp']
        ).dt.floor(f'{self.bucket_minutes}min').dt.time

        time_buckets = sorted(trade_day_data['time_bucket'].unique())
        trade_profiles = []

        prev_bucket_data = None
        for tb in time_buckets:
            bucket_data = trade_day_data[trade_day_data['time_bucket'] == tb]
            metrics = self.compute_bucket_metrics(bucket_data, prev_bucket_data)

            if metrics:
                metrics['time_bucket'] = tb
                trade_profiles.append(metrics)

            prev_bucket_data = bucket_data

        return pd.DataFrame(trade_profiles)

    def compute_metric_similarity(self, trade_values, baseline_mean, baseline_std):
        """Compute multiple similarity measures for a single metric"""
        # Normalize
        trade_norm = trade_values / (trade_values.sum() + 1e-10)
        baseline_norm = baseline_mean / (baseline_mean.sum() + 1e-10)

        similarity_scores = {}

        # Correlation
        if len(trade_norm) > 1 and len(baseline_norm) > 1:
            similarity_scores['pearson'] = stats.pearsonr(trade_norm, baseline_norm)[0]
            similarity_scores['spearman'] = stats.spearmanr(trade_norm, baseline_norm)[0]
        else:
            similarity_scores['pearson'] = np.nan
            similarity_scores['spearman'] = np.nan

        # Distance metrics
        similarity_scores['euclidean'] = euclidean(trade_norm, baseline_norm)
        similarity_scores['cosine_sim'] = 1 - cosine(trade_norm, baseline_norm)

        # KS test
        ks_stat, ks_pval = stats.ks_2samp(trade_values, baseline_mean)
        similarity_scores['ks_statistic'] = ks_stat
        similarity_scores['ks_pvalue'] = ks_pval

        # KL divergence
        similarity_scores['kl_divergence'] = stats.entropy(
            trade_norm + 1e-10,
            baseline_norm + 1e-10
        )

        # Z-scores
        if baseline_std is not None and (baseline_std > 0).any():
            z_scores = (trade_values - baseline_mean) / (baseline_std + 1e-10)
            similarity_scores['mean_z_score'] = np.abs(z_scores).mean()
            similarity_scores['max_z_score'] = np.abs(z_scores).max()
        else:
            similarity_scores['mean_z_score'] = np.nan
            similarity_scores['max_z_score'] = np.nan

        # Wasserstein
        similarity_scores['wasserstein'] = stats.wasserstein_distance(trade_norm, baseline_norm)

        return similarity_scores

    def analyze_full_footprint(self, trade_day_profile, baseline_avg):
        """Compare trade day vs baseline across ALL metrics"""
        # Merge on time_bucket
        merged = trade_day_profile.merge(
            baseline_avg,
            on='time_bucket',
            suffixes=('_trade', '_baseline')
        )

        results = {}

        # Analyze each metric
        for metric in self.metric_weights.keys():
            trade_col = f'{metric}_trade' if f'{metric}_trade' in merged.columns else metric
            baseline_mean_col = f'{metric}_mean'
            baseline_std_col = f'{metric}_std'

            if trade_col in merged.columns and baseline_mean_col in merged.columns:
                trade_values = merged[trade_col].values
                baseline_mean = merged[baseline_mean_col].values
                baseline_std = merged[baseline_std_col].values if baseline_std_col in merged.columns else None

                results[metric] = self.compute_metric_similarity(
                    trade_values, baseline_mean, baseline_std
                )

        return results, merged

    def compute_composite_score(self, metric_results):
        """Aggregate similarity across all metrics"""
        composite_scores = {
            'weighted_correlation': 0,
            'weighted_distance': 0,
            'weighted_ks_pval': 0,
            'suspicious_metrics': []
        }

        total_weight = 0

        for metric, weight in self.metric_weights.items():
            if metric in metric_results:
                result = metric_results[metric]

                # Correlation component
                if not np.isnan(result.get('pearson', np.nan)):
                    composite_scores['weighted_correlation'] += weight * result['pearson']
                    total_weight += weight

                # Distance component
                eucl = result.get('euclidean', 0)
                composite_scores['weighted_distance'] += weight * (1 / (1 + eucl))

                # Statistical significance
                ks_pval = result.get('ks_pvalue', 1)
                composite_scores['weighted_ks_pval'] += weight * ks_pval

                # Flag suspicious metrics
                if result.get('pearson', 1) < 0.7 or result.get('ks_pvalue', 1) < 0.05:
                    composite_scores['suspicious_metrics'].append({
                        'metric': metric,
                        'correlation': result.get('pearson', np.nan),
                        'ks_pvalue': result.get('ks_pvalue', np.nan),
                        'mean_z_score': result.get('mean_z_score', np.nan)
                    })

        # Normalize
        if total_weight > 0:
            composite_scores['weighted_correlation'] /= total_weight

        # Detectability score (0-100)
        detectability = 100 * (1 - composite_scores['weighted_correlation'])
        composite_scores['detectability_score'] = detectability

        return composite_scores

    def generate_report(self, metric_results, composite_scores, symbol, trade_date):
        """Generate human-readable report"""
        print("=" * 80)
        print(f"VWAP FOOTPRINT ANALYSIS - {symbol} - {trade_date}")
        print("=" * 80)
        print()

        print(f"üìä COMPOSITE DETECTABILITY SCORE: {composite_scores['detectability_score']:.1f}/100")
        print()

        if composite_scores['detectability_score'] < 20:
            print("‚úÖ EXCELLENT: Your footprint blends in very well with normal market activity")
        elif composite_scores['detectability_score'] < 40:
            print("‚úì GOOD: Footprint is reasonably well disguised")
        elif composite_scores['detectability_score'] < 60:
            print("‚ö†Ô∏è MODERATE: Some detectability - consider randomization")
        else:
            print("üö® HIGH RISK: Footprint is highly detectable - likely exploitable!")

        print()
        print("-" * 80)
        print("METRIC-BY-METRIC ANALYSIS")
        print("-" * 80)
        print()

        # Sort metrics by correlation
        sorted_metrics = sorted(
            metric_results.items(),
            key=lambda x: x[1].get('pearson', 0)
        )

        for metric, scores in sorted_metrics:
            corr = scores.get('pearson', np.nan)
            ks_pval = scores.get('ks_pvalue', np.nan)
            mean_z = scores.get('mean_z_score', np.nan)

            status = "‚úÖ" if corr > 0.8 else "‚ö†Ô∏è" if corr > 0.6 else "üö®"

            print(f"{status} {metric.upper()}")
            print(f"   Correlation:    {corr:.3f}")
            print(f"   KS p-value:     {ks_pval:.4f}")
            print(f"   Mean Z-score:   {mean_z:.2f}")
            print(f"   Euclidean dist: {scores.get('euclidean', np.nan):.4f}")
            print()

        if composite_scores['suspicious_metrics']:
            print("-" * 80)
            print("üîç SUSPICIOUS METRICS (requiring attention)")
            print("-" * 80)
            for item in composite_scores['suspicious_metrics']:
                print(f"  ‚Ä¢ {item['metric']}: correlation={item['correlation']:.3f}, "
                      f"p-value={item['ks_pvalue']:.4f}")
            print()

        print("=" * 80)
        print()


def plot_comprehensive_footprint(merged_data, metric_results, symbol, trade_date,
                                 lookback_days=None):
    """
    Create comprehensive visualization

    merged_data: trade-day metrics joined with baseline <metric>_mean / <metric>_std columns
    metric_results: dict of metric name -> similarity scores (pearson and ks_pvalue are shown)
    symbol, trade_date: used in the figure title
    lookback_days: optional baseline length shown in the legend. The label
        used to hard-code "21-day avg" whatever lookback the analyzer was
        configured with; when omitted a neutral label is used instead.

    Returns the matplotlib figure: one row per key metric, with a time-series
    comparison on the left and a distribution comparison on the right.
    """
    key_metrics = ['volume', 'volatility', 'spread', 'depth_imbalance',
                   'trade_imbalance', 'price_impact']

    if lookback_days is None:
        baseline_label = 'Baseline (historical avg)'
    else:
        baseline_label = f'Baseline ({lookback_days}-day avg)'

    annotation_box = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    n_metrics = len(key_metrics)
    fig, axes = plt.subplots(n_metrics, 2, figsize=(16, 4*n_metrics))
    x = range(len(merged_data))

    for idx, metric in enumerate(key_metrics):
        # The merge only adds the '_trade' suffix on a name clash, so accept either
        trade_col = f'{metric}_trade' if f'{metric}_trade' in merged_data.columns else metric
        baseline_col = f'{metric}_mean'
        std_col = f'{metric}_std'
        has_baseline = baseline_col in merged_data.columns
        has_trade = trade_col in merged_data.columns
        pretty_name = metric.replace("_", " ").title()

        # Left: Time series comparison
        ax1 = axes[idx, 0]

        if has_baseline:
            ax1.plot(x, merged_data[baseline_col],
                     label=baseline_label, linewidth=2, alpha=0.7, color='blue')
            if std_col in merged_data.columns:
                # Shade one baseline standard deviation either side of the mean
                ax1.fill_between(x,
                                 merged_data[baseline_col] - merged_data[std_col],
                                 merged_data[baseline_col] + merged_data[std_col],
                                 alpha=0.2, color='blue')

        if has_trade:
            ax1.plot(x, merged_data[trade_col],
                     label='Trade Day', linewidth=2, alpha=0.8, color='red')

        # Add correlation score
        if metric in metric_results:
            corr = metric_results[metric].get('pearson', np.nan)
            ax1.text(0.02, 0.98, f'Correlation: {corr:.3f}',
                     transform=ax1.transAxes, fontsize=10,
                     verticalalignment='top', bbox=annotation_box)

        ax1.set_title(f'{pretty_name} - Time Series')
        ax1.set_xlabel('Time Bucket (minutes)')
        ax1.set_ylabel(metric)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Right: Distribution comparison
        ax2 = axes[idx, 1]

        if has_baseline and has_trade:
            ax2.hist(merged_data[baseline_col], bins=20, alpha=0.5,
                     label='Baseline', color='blue', density=True)
            ax2.hist(merged_data[trade_col], bins=20, alpha=0.5,
                     label='Trade Day', color='red', density=True)

            # Add KS test result
            if metric in metric_results:
                ks_pval = metric_results[metric].get('ks_pvalue', np.nan)
                ax2.text(0.02, 0.98, f'KS p-value: {ks_pval:.4f}',
                         transform=ax2.transAxes, fontsize=10,
                         verticalalignment='top', bbox=annotation_box)

        ax2.set_title(f'{pretty_name} - Distribution')
        ax2.set_xlabel(metric)
        ax2.set_ylabel('Density')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    plt.suptitle(f'Comprehensive Footprint Analysis - {symbol} - {trade_date}',
                 fontsize=16, y=1.001)
    plt.tight_layout()
    return fig


def plot_multi_symbol_heatmap(all_results):
    """
    Draw a symbols-by-dates heatmap of detectability scores.

    all_results: Dict[(symbol, date), composite_scores]

    Returns the matplotlib figure. Symbol/date pairs without a result are
    masked out of the heatmap.
    """
    # Axis labels: every symbol and every date that appears in any key
    symbols = sorted({key[0] for key in all_results})
    dates = sorted({key[1] for key in all_results})
    row_of = {symbol: row for row, symbol in enumerate(symbols)}
    col_of = {date: col for col, date in enumerate(dates)}

    # Start from "no result" everywhere, then fill in the cells we have
    matrix = np.full((len(symbols), len(dates)), np.nan)
    for (symbol, date), scores in all_results.items():
        matrix[row_of[symbol], col_of[date]] = scores['detectability_score']

    fig_width = max(12, len(dates)*0.8)
    fig_height = max(6, len(symbols)*0.6)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    sns.heatmap(matrix,
                xticklabels=[d.strftime('%m/%d') for d in dates],
                yticklabels=symbols,
                annot=True,
                fmt='.1f',
                cmap='RdYlGn_r',  # Red = high detectability (bad), Green = low (good)
                center=50,
                vmin=0, vmax=100,
                mask=np.isnan(matrix),  # hide cells without a result
                ax=ax,
                cbar_kws={'label': 'Detectability Score (0-100)'})

    ax.set_title('Multi-Symbol Footprint Detectability Heatmap\nGreen=Well Disguised, Red=Highly Detectable')
    ax.set_xlabel('Trade Date')
    ax.set_ylabel('Symbol')

    plt.tight_layout()
    return fig


# ============================================================================
# COMPLETE WORKING EXAMPLE
# ============================================================================

def main():
    """
    Run the end-to-end demo.

    Generates mock tick data for four symbols, analyses every simulated trade
    day against its baseline, prints a report and saves one chart per
    symbol/date, then saves a summary heatmap and prints per-symbol statistics.
    """
    rule = "=" * 80

    print("Generating mock market data...")
    print()

    # Initialize data generator
    generator = MarketDataGenerator(seed=42)

    # Symbols to analyze and the calendar range to simulate
    symbols = ['AAPL', 'MSFT', 'GOOGL', 'TSLA']
    start_date = datetime(2024, 10, 1).date()
    n_days = 35

    # Weekdays in the range
    all_dates = [
        day
        for day in (start_date + timedelta(days=offset) for offset in range(n_days))
        if day.weekday() < 5
    ]

    # Pick some random trade dates for each symbol among the last 15 weekdays
    np.random.seed(42)
    recent_dates = all_dates[-15:]
    trades_per_symbol = (('AAPL', 5), ('MSFT', 4), ('GOOGL', 6), ('TSLA', 3))
    trade_days_per_symbol = {
        ticker: set(np.random.choice(recent_dates, size=n_trades, replace=False))
        for ticker, n_trades in trades_per_symbol
    }

    # Generate data
    market_data = generator.generate_multi_day_data(
        symbols=symbols,
        start_date=start_date,
        n_days=n_days,
        trade_days_per_symbol=trade_days_per_symbol
    )

    print(f"‚úì Generated {len(market_data):,} ticks across {len(symbols)} symbols and {len(all_dates)} days")
    print(f"‚úì Data shape: {market_data.shape}")
    print(f"‚úì Date range: {market_data['date'].min()} to {market_data['date'].max()}")
    print()

    # Display sample
    print("Sample data:")
    print(market_data.head(10))
    print()

    # Initialize analyzer
    analyzer = ComprehensiveFootprintAnalyzer(lookback_days=15, bucket_minutes=5)

    # Run analysis for each symbol on their trade days
    all_results = {}

    for symbol in symbols:
        symbol_trade_dates = trade_days_per_symbol[symbol]

        print(f"\n{rule}")
        print(f"ANALYZING {symbol} - {len(symbol_trade_dates)} trade days")
        print(f"{rule}\n")

        for trade_date in sorted(symbol_trade_dates):
            # One failing date must not stop the remaining analyses
            try:
                # Create baseline
                baseline_avg, _baseline_full = analyzer.create_comprehensive_baseline(
                    historical_data=market_data,
                    trade_dates=symbol_trade_dates,
                    current_date=trade_date,
                    symbol=symbol
                )

                if baseline_avg is None:
                    print(f"‚ö†Ô∏è Skipping {symbol} {trade_date} - insufficient baseline data")
                    continue

                # Get trade day data
                is_this_day = (
                    (market_data['symbol'] == symbol) &
                    (market_data['date'] == trade_date)
                )
                trade_day_data = market_data[is_this_day]

                if len(trade_day_data) == 0:
                    print(f"‚ö†Ô∏è No data for {symbol} on {trade_date}")
                    continue

                # Profile the trade day, compare it with the baseline and score it
                trade_profile = analyzer.compute_trade_day_profile(trade_day_data)
                metric_results, merged_data = analyzer.analyze_full_footprint(
                    trade_profile, baseline_avg
                )
                composite_scores = analyzer.compute_composite_score(metric_results)

                # Store results
                all_results[(symbol, trade_date)] = composite_scores

                # Generate report
                analyzer.generate_report(metric_results, composite_scores, symbol, trade_date)

                # Create visualization
                fig = plot_comprehensive_footprint(merged_data, metric_results, symbol, trade_date)
                plt.savefig(f'footprint_{symbol}_{trade_date}.png', dpi=100, bbox_inches='tight')
                plt.close()

            except Exception as e:
                print(f"‚ùå Error analyzing {symbol} {trade_date}: {e}")
                import traceback
                traceback.print_exc()
                continue

    # Create summary heatmap
    if all_results:
        print("\n" + rule)
        print("GENERATING MULTI-SYMBOL SUMMARY")
        print(rule + "\n")

        fig = plot_multi_symbol_heatmap(all_results)
        plt.savefig('footprint_summary_heatmap.png', dpi=150, bbox_inches='tight')
        plt.close()

        # Summary statistics
        print("\nSUMMARY STATISTICS:")
        print("-" * 80)

        for symbol in symbols:
            symbol_scores = [
                scores['detectability_score']
                for (result_symbol, _), scores in all_results.items()
                if result_symbol == symbol
            ]
            if not symbol_scores:
                continue

            print(f"\n{symbol}:")
            print(f"  Mean detectability:   {np.mean(symbol_scores):.1f}")
            print(f"  Min detectability:    {np.min(symbol_scores):.1f}")
            print(f"  Max detectability:    {np.max(symbol_scores):.1f}")
            print(f"  Std deviation:        {np.std(symbol_scores):.1f}")

            high_risk_days = len([s for s in symbol_scores if s > 60])
            if high_risk_days > 0:
                print(f"  üö® High risk days:    {high_risk_days}/{len(symbol_scores)}")

    print("\n" + rule)
    print("ANALYSIS COMPLETE!")
    print(rule)
    print(f"\n‚úì Analyzed {len(all_results)} symbol-date combinations")
    print(f"‚úì Generated {len(all_results)} individual charts")
    print(f"‚úì Generated 1 summary heatmap")
    print("\nFiles created:")
    print("  - footprint_<SYMBOL>_<DATE>.png (individual analyses)")
    print("  - footprint_summary_heatmap.png (overview)")


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

Generating mock market data...

✓ Generated 202,534 ticks across 4 symbols and 25 days
✓ Data shape: (202534, 11)
✓ Date range: 2024-10-01 to 2024-11-04

Sample data:
  symbol        date           timestamp   price  volume     bid     ask  \
0   AAPL  2024-10-01 2024-10-01 09:30:00  180.00     458  180.00  180.02   
1   AAPL  2024-10-01 2024-10-01 09:30:15  179.99    1996  179.97  180.00   
2   AAPL  2024-10-01 2024-10-01 09:30:30  180.01    1002  179.99  180.01   
3   AAPL  2024-10-01 2024-10-01 09:30:45  179.97    1378  179.97  179.99   
4   AAPL  2024-10-01 2024-10-01 09:31:00  179.96    1028  179.96  179.98   
5   AAPL  2024-10-01 2024-10-01 09:31:05  179.98     456  179.96  179.98   
6   AAPL  2024-10-01 2024-10-01 09:31:10  180.06     277  180.05  180.07   
7   AAPL  2024-10-01 2024-10-01 09:31:15  180.05    1582  180.05  180.07   
8   AAPL  2024-10-01 2024-10-01 09:31:20  180.04     630  180.04  180.06   
9   AAPL  2024-10-01 2024-10-01 09:31:25  180.01     264  180.00  1