# Set Up

In [121]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from statsmodels.tsa.stattools import grangercausalitytests
import warnings
import json
from datetime import datetime

# Make the project root importable so `src.utils` resolves whether the
# notebook is started from the repository root or from `notebooks/`.
current_dir = os.getcwd()
if current_dir.endswith('notebooks'):
    parent_dir = os.path.dirname(current_dir)
else:
    parent_dir = current_dir

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Project helpers — importable only after the sys.path adjustment above.
from src.utils.data_loader import load_main_dataset, load_trade_data

# Folder for analysis artifacts (relative to the current working directory).
results_dir = 'results/trader_analysis'
os.makedirs(results_dir, exist_ok=True)

# Hide every warning globally to keep the notebook output readable.
warnings.filterwarnings('ignore')

# Global plotting style used by later figures.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")

# Cell 2: Helper functions
def calculate_gini(values):
    """
    Calculate the Gini coefficient of a collection of non-negative values.

    Uses the closed form for ascending-sorted values x_1..x_n:
        G = (n + 1 - 2 * sum_i (n + 1 - i) * x_i / sum_i x_i) / n

    Parameters
    ----------
    values : array-like
        Values to measure (e.g. per-trader volume). Assumed non-negative;
        multi-dimensional input is flattened. NaNs propagate to a NaN result.

    Returns
    -------
    float
        0.0 for perfect equality, up to (n - 1) / n when one value holds the
        whole total. Returns 0.0 for fewer than two values or a zero total.
    """
    sorted_values = np.sort(np.asarray(values, dtype=float).ravel())
    n = sorted_values.size
    total = sorted_values.sum()
    if n <= 1 or total == 0:
        return 0.0

    # Rank 1 is the smallest value; smaller values get the larger weights.
    ranks = np.arange(1, n + 1)
    return (n + 1 - 2 * np.sum((n + 1 - ranks) * sorted_values) / total) / n

In [122]:
# Cell 3: Load Market Data
print("Loading main dataset...")
market_data = load_main_dataset('data/cleaned_election_data.csv')

if market_data is not None:
    # Row count is not the market count: in the recorded run the file had
    # 1,048,575 rows (Excel's sheet limit) but only 489 non-null trader metrics,
    # so most rows look like empty padding (TODO confirm in the CSV).
    # Report how many rows actually carry a market id as well.
    if 'id' in market_data.columns:
        n_markets = int(market_data['id'].notna().sum())
    else:
        n_markets = len(market_data)
    print(f"Successfully loaded {len(market_data)} rows ({n_markets} with a market id)")

    # Explore trader-related metrics (substring match, so e.g. 'last_trade_price' also appears)
    trader_cols = [col for col in market_data.columns
                   if any(term in col.lower() for term in ['trader', 'trade', 'concentration'])]

    print("\nAvailable trader-related columns:")
    for col in trader_cols:
        print(f"- {col}")

    # Display basic statistics for key trader metrics
    trader_metrics = ['unique_traders_count', 'trader_to_trade_ratio',
                      'two_way_traders_ratio', 'new_trader_influx']
    available_metrics = [col for col in trader_metrics if col in market_data.columns]

    if available_metrics:
        print("\nSummary statistics for trader metrics:")
        print(market_data[available_metrics].describe())
else:
    print("Failed to load data/cleaned_election_data.csv")

Loading main dataset...
Loaded dataset with 1048575 rows and 54 columns
Successfully loaded 1048575 markets

Available trader-related columns:
- last_trade_price
- unique_traders_count
- trader_to_trade_ratio
- two_way_traders_ratio
- trader_concentration
- new_trader_influx
- comment_per_trader

Summary statistics for trader metrics:
       unique_traders_count  trader_to_trade_ratio  two_way_traders_ratio  \
count            489.000000             489.000000             489.000000   
mean            2192.934560               4.409763               0.312478   
std             5783.946414               3.273783               0.230296   
min               31.000000               1.227342               0.007680   
25%              295.000000               2.786284               0.134066   
50%              610.000000               3.389163               0.229508   
75%             1824.000000               4.644178               0.476971   
max            72183.000000              24.593

# Load Trade Data

In [123]:
def load_trade_data_for_analysis(market_ids=None, max_trades_per_market=None, market_data=None):
    """
    Load trade data for specific market IDs

    For each market, ``load_trade_data(market_id)`` is tried first. If it
    returns nothing, the per-token parquet files found through
    ``get_token_ids_for_market`` / ``find_token_id_file`` are read instead.
    A failure in one market is reported and that market is skipped.

    Parameters:
    -----------
    market_ids : list, optional
        List of specific market IDs to load (every id in the main dataset if None)
    max_trades_per_market : int, optional
        Maximum number of trades per market (None for all trades). Larger
        markets are down-sampled with a fixed seed (random_state=42).
    market_data : pd.DataFrame, optional
        Already-loaded main dataset. When None it is read from
        'data/cleaned_election_data.csv'; pass it in to avoid re-reading the CSV.

    Returns:
    --------
    pd.DataFrame or None
        Combined trade data with ``market_id`` (float), ``maker_id``/``taker_id``
        aliases, ``trader_id`` and ``trade_amount`` columns, or None if no trades
        could be loaded.

    Notes:
    ------
    ``trader_id`` is the maker of each trade (the taker only when no maker column
    exists), so taker-side activity is not attributed to traders.
    """
    from src.utils.data_loader import load_trade_data

    if market_data is None:
        print("Loading market data...")
        market_data = load_main_dataset('data/cleaned_election_data.csv')

    if market_data is None:
        print("Failed to load market data")
        return None

    # If market_ids not provided, use all markets
    if market_ids is None:
        market_ids = market_data['id'].tolist()

    print(f"Selected {len(market_ids)} markets for analysis")

    all_trades = []
    for i, market_id in enumerate(market_ids):
        try:
            market_name = _market_display_name(market_data, market_id)
            print(f"\nLoading trades for market {i+1}/{len(market_ids)}: {market_name}")
            print(f"Market ID: {market_id}")

            # Preferred path: the project's loader
            trades = load_trade_data(market_id)
            if trades is not None and len(trades) > 0:
                print(f"Successfully loaded {len(trades)} trades directly")
                trades['market_id'] = float(market_id)
            else:
                print("No trades found using default method, trying alternative approaches")
                trades = _load_trades_from_token_files(market_id, market_data)
                if trades is None:
                    continue

            all_trades.append(_sample_trades(trades, max_trades_per_market))
        except Exception as e:
            # Best effort: one failing market should not abort the whole load
            print(f"Error loading trades for market {market_id}: {e}")

    if not all_trades:
        print("No trade data loaded for any markets")
        return None

    # Combine all trade data
    combined_trades = pd.concat(all_trades, ignore_index=True)

    print(f"\nTotal trades loaded: {len(combined_trades)} from {len(all_trades)} markets")
    print("\nMarket-wise trade counts:")
    print(combined_trades['market_id'].value_counts())

    _standardize_trade_columns(combined_trades)

    unique_makers = combined_trades['maker_id'].nunique() if 'maker_id' in combined_trades.columns else 0
    unique_takers = combined_trades['taker_id'].nunique() if 'taker_id' in combined_trades.columns else 0
    unique_traders = combined_trades['trader_id'].nunique() if 'trader_id' in combined_trades.columns else 0

    print(f"Unique traders identified: {unique_traders} (makers: {unique_makers}, takers: {unique_takers})")

    return combined_trades


def _market_display_name(market_data, market_id):
    """Return the market's question text, or a generic label when it is unavailable."""
    if 'question' in market_data.columns:
        matches = market_data.loc[market_data['id'] == market_id, 'question']
        # An id missing from the dataset must not abort loading its trades
        if len(matches) > 0:
            return matches.iloc[0]
    return f"Market {market_id}"


def _load_trades_from_token_files(market_id, market_data):
    """Fallback loader: concatenate the parquet trade files of every token of a market."""
    from src.utils.data_loader import find_token_id_file, get_token_ids_for_market

    token_ids = get_token_ids_for_market(market_id, main_df=market_data)
    if token_ids is None or len(token_ids) == 0:
        print(f"No token IDs found for market {market_id}")
        return None

    print(f"Found {len(token_ids)} token IDs for market {market_id}")

    token_frames = []
    for token_id in token_ids:
        try:
            token_file = find_token_id_file(token_id)
            if not token_file:
                continue
            print(f"Found token file: {os.path.basename(token_file)}")

            token_trades = pd.read_parquet(token_file)
            token_trades['market_id'] = float(market_id)
            token_trades['token_id'] = token_id
            token_frames.append(token_trades)
            print(f"Loaded {len(token_trades)} trades for token {token_id}")
        except Exception as e:
            print(f"Error loading token {token_id}: {e}")

    if not token_frames:
        print(f"No trade data found for any tokens in market {market_id}")
        return None
    return pd.concat(token_frames, ignore_index=True)


def _sample_trades(trades, max_trades):
    """Down-sample to at most ``max_trades`` rows with a fixed seed; otherwise return unchanged."""
    if max_trades is not None and len(trades) > max_trades:
        print(f"Sampling {max_trades} trades from {len(trades)} total")
        return trades.sample(max_trades, random_state=42)
    return trades


def _standardize_trade_columns(trades):
    """Add maker_id/taker_id aliases plus trader_id and trade_amount columns, in place."""
    if 'maker' in trades.columns and 'maker_id' not in trades.columns:
        trades['maker_id'] = trades['maker']
    if 'taker' in trades.columns and 'taker_id' not in trades.columns:
        trades['taker_id'] = trades['taker']

    # Attribute each trade to its maker; fall back to the taker instead of
    # raising KeyError when the data has no maker column at all.
    if 'maker_id' in trades.columns:
        trades['trader_id'] = trades['maker_id']
    elif 'taker_id' in trades.columns:
        trades['trader_id'] = trades['taker_id']
    else:
        print("Warning: no maker/taker columns found; trader_id not created")

    if 'trade_amount' not in trades.columns:
        if 'size' in trades.columns:
            # size may arrive as strings; unparseable values become NaN
            trades['trade_amount'] = pd.to_numeric(trades['size'], errors='coerce')
        else:
            trades['trade_amount'] = 1.0

In [124]:
def scale_trade_volume(trades_df, base_unit_threshold=10000, scaling_factor=1e6):
    """
    Scale the trade volume data to appropriate units

    ``trade_amount`` is taken from the first available source, in this order:
    an existing ``trade_amount`` column, then ``size``, then ``makerAmountFilled``
    (used only when ``takerAmountFilled`` is also present). If the median of the
    chosen column exceeds ``base_unit_threshold``, the values are assumed to be
    in base units and are divided by ``scaling_factor``.

    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data (not modified; a copy is returned)
    base_unit_threshold : float, default 10000
        Median above which values are treated as base units (heuristic)
    scaling_factor : float, default 1e6
        Divisor for base-unit values (1e6 presumably matches a 6-decimal token
        such as USDC — confirm for the data source)

    Returns:
    --------
    pd.DataFrame
        Copy with a numeric ``trade_amount`` column, where present, in which
        negative values are set to NaN. When an existing ``trade_amount`` is
        rescaled, the pre-scaling values are kept in ``trade_amount_original``.
    """
    print("Scaling trade volume data...")

    # Create a copy to avoid modifying the original
    df = trades_df.copy()

    if 'trade_amount' in df.columns:
        # Coerce first: trade_amount may have been copied verbatim from a string
        # 'size' column, and median()/division are not defined on strings.
        df['trade_amount'] = pd.to_numeric(df['trade_amount'], errors='coerce')

        # A very large median suggests values are in base units
        if df['trade_amount'].median() > base_unit_threshold:
            print(f"Applying scaling factor of {scaling_factor:,.0f} to trade_amount")
            df['trade_amount_original'] = df['trade_amount']
            df['trade_amount'] = df['trade_amount'] / scaling_factor
            print(f"Volume before scaling: {df['trade_amount_original'].sum():,.2f}")
            print(f"Volume after scaling: {df['trade_amount'].sum():,.2f}")

    # No trade_amount column: derive it from size
    elif 'size' in df.columns:
        df['size'] = pd.to_numeric(df['size'], errors='coerce')
        if df['size'].median() > base_unit_threshold:
            print(f"Creating trade_amount from size with scaling factor of {scaling_factor:,.0f}")
            df['trade_amount'] = df['size'] / scaling_factor
        else:
            print("Using size directly as trade_amount")
            df['trade_amount'] = df['size']

    # Otherwise fall back to the maker-side filled amount
    elif all(col in df.columns for col in ['makerAmountFilled', 'takerAmountFilled']):
        df['makerAmountFilled'] = pd.to_numeric(df['makerAmountFilled'], errors='coerce')
        if df['makerAmountFilled'].median() > base_unit_threshold:
            print(f"Creating trade_amount from filled amounts with scaling factor of {scaling_factor:,.0f}")
            df['trade_amount'] = df['makerAmountFilled'] / scaling_factor
        else:
            print("Using makerAmountFilled directly as trade_amount")
            df['trade_amount'] = df['makerAmountFilled']

    if 'trade_amount' in df.columns:
        # Negative volumes are invalid; mask() upcasts to float as needed instead
        # of assigning NaN into an integer column.
        df['trade_amount'] = df['trade_amount'].mask(df['trade_amount'] < 0)

        # Flag (but keep) values more than 3 standard deviations above the mean
        mean_amount = df['trade_amount'].mean()
        std_amount = df['trade_amount'].std()
        upper_limit = mean_amount + 3 * std_amount
        n_outliers = int((df['trade_amount'] > upper_limit).sum())
        if n_outliers > 0:
            print(f"Identified {n_outliers} potential outliers (> {upper_limit:.2f})")

        print("\nTrade amount summary statistics:")
        print(df['trade_amount'].describe())

    return df

## Define Target Markets

In [125]:
# Exact `question` strings of the markets to analyse
target_markets = [
    "Will Donald Trump win the 2024 US Presidential Election?", 
    "Will Kamala Harris win the 2024 US Presidential Election?"
]

# Load main dataset (re-read so this cell does not depend on earlier cell state)
market_data = load_main_dataset('data/cleaned_election_data.csv')
if market_data is None:
    raise RuntimeError("Failed to load data/cleaned_election_data.csv")

# Filter to specific markets; exact string matching silently drops typos, so report them
selected_markets = market_data[market_data['question'].isin(target_markets)]
missing_targets = sorted(set(target_markets) - set(selected_markets['question']))
if missing_targets:
    print(f"Warning: no market found for {missing_targets}")

# De-duplicate so a repeated row does not load the same market's trades twice
market_ids = selected_markets['id'].drop_duplicates().tolist()

# Load ALL trades for these markets
trade_data = load_trade_data_for_analysis(
    market_ids=market_ids, 
    max_trades_per_market=None  # Load all trades
)


Loaded dataset with 1048575 rows and 54 columns
Loading market data...
Loaded dataset with 1048575 rows and 54 columns
Selected 2 markets for analysis

Loading trades for market 1/2: Will Donald Trump win the 2024 US Presidential Election?
Market ID: 253591.0
Loaded dataset with 1048575 rows and 54 columns
Successfully loaded 1185000 trades directly

Loading trades for market 2/2: Will Kamala Harris win the 2024 US Presidential Election?
Market ID: 253597.0
Loaded dataset with 1048575 rows and 54 columns
No trade data found for market 253597.0
No trades found using default method, trying alternative approaches
Found 2 token IDs for market 253597.0
Found token file: 69236923620077691027083946871148646972011131466059644796654161903044970987404.parquet 18-30-01-221.parquet
Loaded 802000 trades for token 69236923620077691027083946871148646972011131466059644796654161903044970987404
Found token file: 87584955359245246404952128082451897287778571240979823316620093987046202296181.parquet 18-30-

# Trader Classification and Analysis

## 1. Basic Market Overview


In [126]:
# load_trade_data_for_analysis returns None when no trades could be loaded;
# fail with a clear message instead of a TypeError on len(None).
if trade_data is None or len(trade_data) == 0:
    raise RuntimeError("No trade data loaded - check the market selection and trade files above")

print("Market Trade Statistics:")
print(f"Total Trades: {len(trade_data)}")
print(f"Unique Markets: {trade_data['market_id'].nunique()}")
print(f"Unique Traders: {trade_data['trader_id'].nunique()}")

Market Trade Statistics:
Total Trades: 2431000
Unique Markets: 2
Unique Traders: 113124



## 2. Identify Potential Whale Traders


In [127]:
def identify_whales(trades_df, default_threshold=0.01, generate_plots=True,
                    output_path='whale_definition_analysis.png'):
    """
    Identify whale traders with visualization of different definitions

    Traders are ranked by total ``trade_amount``. The top ``default_threshold``
    fraction of traders (at least one, at most all) is returned as whales.
    Two families of alternative definitions are printed and returned for
    comparison:
    - top 0.1% / 1% / 5% / 10% of traders;
    - the fewest top traders needed to cover 50 / 75 / 90 / 95% of volume.

    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with ``trader_id`` and ``trade_amount`` columns
        (amounts are coerced to numbers; unparseable values are ignored)
    default_threshold : float
        Fraction of traders treated as whales (e.g. 0.01 for the top 1%)
    generate_plots : bool
        Whether to save the comparison figure
    output_path : str
        Path of the saved figure when ``generate_plots`` is True

    Returns:
    --------
    tuple
        ``(whale_ids, whale_results)``. The result is ``([], {})`` when required
        columns are missing or there is no positive volume to rank.
    """
    print("Analyzing trader concentration and identifying whales...")

    # Ensure we have the necessary columns
    if 'trader_id' not in trades_df.columns or 'trade_amount' not in trades_df.columns:
        print("Error: Missing required columns (trader_id, trade_amount)")
        return [], {}

    # Total volume per trader, largest first. Coercing to numbers prevents sum()
    # from string-concatenating a text-typed amount column.
    amounts = pd.to_numeric(trades_df['trade_amount'], errors='coerce')
    trader_volumes = (
        amounts.groupby(trades_df['trader_id'])
        .sum()
        .sort_values(ascending=False)
    )
    total_volume = trader_volumes.sum()
    total_traders = len(trader_volumes)

    print(f"Total traders: {total_traders:,}")
    print(f"Total volume: {total_volume:,.2f}")

    # Every share computed below divides by these totals
    if total_traders == 0 or not total_volume > 0:
        print("Error: no positive trading volume to analyze")
        return [], {}

    cumulative_volumes = trader_volumes.cumsum()
    cumulative_percentages = cumulative_volumes / total_volume * 100

    trader_analysis = pd.DataFrame({
        'trader_id': trader_volumes.index,
        'volume': trader_volumes.values,
        'cumulative_volume': cumulative_volumes.values,
        'volume_pct': trader_volumes.values / total_volume * 100,
        'cumulative_pct': cumulative_percentages.values
    })

    gini = calculate_gini(trader_volumes.values)
    print(f"Volume concentration (Gini coefficient): {gini:.4f}")

    # Definition 1: top X% of traders by volume
    threshold_metrics = []
    for threshold in [0.001, 0.01, 0.05, 0.1]:
        num_whales = max(1, int(total_traders * threshold))
        whale_volume = trader_volumes.iloc[:num_whales].sum()
        whale_volume_pct = whale_volume / total_volume * 100

        threshold_metrics.append({
            'threshold': threshold,
            'threshold_label': f"Top {threshold*100:.1f}%",
            'num_whales': num_whales,
            'whale_volume': whale_volume,
            'whale_volume_pct': whale_volume_pct,
            'trader_pct': num_whales / total_traders * 100
        })

        print(f"Top {threshold*100:.1f}% definition ({num_whales:,} traders): {whale_volume_pct:.2f}% of volume")

    # Definition 2: fewest top traders whose combined volume reaches X%
    coverage_metrics = []
    for pct in [50, 75, 90, 95]:
        # Traders still below the target, plus the one that crosses it
        traders_needed = min(int((cumulative_percentages < pct).sum()) + 1, total_traders)
        actual_pct = cumulative_percentages.iloc[traders_needed - 1]

        coverage_metrics.append({
            'volume_threshold': pct,
            'threshold_label': f"{pct}% Volume",
            'num_traders': traders_needed,
            'actual_volume_pct': actual_pct,
            'trader_pct': traders_needed / total_traders * 100
        })

        print(f"Traders needed for {pct}% volume: {traders_needed:,} ({traders_needed/total_traders*100:.4f}% of all traders)")

    if generate_plots:
        _plot_whale_definitions(cumulative_percentages, threshold_metrics, gini, output_path)
        print(f"Whale definition analysis visualizations saved to {output_path}")

    # Apply the selected definition, clamped so a threshold > 1 cannot overrun
    num_whales = min(max(1, int(total_traders * default_threshold)), total_traders)
    whale_ids = trader_volumes.head(num_whales).index.tolist()

    print(f"\nUsing top {default_threshold*100:.1f}% definition: {num_whales:,} whales")
    print(f"Selected whale threshold volume: {trader_volumes.iloc[num_whales - 1]:.2f}")

    return whale_ids, {
        'trader_analysis': trader_analysis,
        'threshold_metrics': threshold_metrics,
        'coverage_metrics': coverage_metrics,
        'gini': gini,
        'selected_threshold': default_threshold,
        'selected_num_whales': num_whales
    }


def _plot_whale_definitions(cumulative_percentages, threshold_metrics, gini, output_path):
    """Save a two-panel figure: volume concentration curve and whale-definition bars."""
    n = len(cumulative_percentages)

    # n + 1 points so the curve starts at (0, 0); x and y must have equal length
    x_curve = np.linspace(0, 100, n + 1)
    y_curve = np.insert(cumulative_percentages.values, 0, 0)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # 1. Concentration (Lorenz-style) curve; traders ordered from largest volume
    ax1.plot(x_curve, y_curve, 'b-', linewidth=2, label='Volume distribution')
    ax1.plot([0, 100], [0, 100], 'k--', label='Perfect equality')
    ax1.fill_between(x_curve, y_curve, x_curve, alpha=0.2)

    for p in [90, 95, 99, 99.9]:
        top_share = 100 - p
        # k = number of traders in the top `top_share` percent
        k = int(n * top_share / 100)
        x = k / n * 100
        # Cumulative share held by exactly those k traders (y_curve[k])
        y = cumulative_percentages.iloc[k - 1] if k > 0 else 0.0

        ax1.plot([x, x], [0, y], 'r--', alpha=0.5)
        ax1.plot([0, x], [y, y], 'r--', alpha=0.5)
        # :g avoids float noise such as 'Top 0.09999999999999432%'
        ax1.text(x + 1, 10 + (p - 90) * 3, f'Top {top_share:g}%', fontsize=10)

    ax1.set_title(f'Trading Volume Distribution (Gini: {gini:.4f})')
    ax1.set_xlabel('Cumulative % of Traders')
    ax1.set_ylabel('Cumulative % of Volume')
    ax1.grid(alpha=0.3)
    ax1.legend()

    # 2. Share of traders vs share of volume for each "top X%" definition
    definitions = pd.DataFrame(threshold_metrics)
    bar_width = 0.35
    positions = np.arange(len(definitions))

    ax2.bar(positions - bar_width / 2, definitions['trader_pct'],
            bar_width, label='% of Traders', color='skyblue')
    ax2.bar(positions + bar_width / 2, definitions['whale_volume_pct'],
            bar_width, label='% of Volume', color='orange')

    ax2.set_xticks(positions)
    ax2.set_xticklabels(definitions['threshold_label'])

    for i, (trader_pct, volume_pct) in enumerate(
            zip(definitions['trader_pct'], definitions['whale_volume_pct'])):
        ax2.text(i - bar_width / 2, trader_pct + 1, f"{trader_pct:.2f}%", ha='center', fontsize=9)
        ax2.text(i + bar_width / 2, volume_pct + 1, f"{volume_pct:.2f}%", ha='center', fontsize=9)

    ax2.set_title('Whale Definitions Comparison')
    ax2.set_ylabel('Percentage')
    ax2.set_ylim(0, 100)
    ax2.grid(axis='y', alpha=0.3)
    ax2.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


# Gini coefficient

In [128]:

def calculate_gini(values):
    """
    Calculate the Gini coefficient of a collection of non-negative values.

    NOTE(review): this redefines the identical helper from the setup cell, and
    the later definition silently wins; consider keeping only one copy.

    Uses the closed form for ascending-sorted values x_1..x_n:
        G = (n + 1 - 2 * sum_i (n + 1 - i) * x_i / sum_i x_i) / n

    Parameters
    ----------
    values : array-like
        Values to measure (e.g. per-trader volume). Assumed non-negative;
        multi-dimensional input is flattened. NaNs propagate to a NaN result.

    Returns
    -------
    float
        0.0 for perfect equality, up to (n - 1) / n when one value holds the
        whole total. Returns 0.0 for fewer than two values or a zero total.
    """
    sorted_values = np.sort(np.asarray(values, dtype=float).ravel())
    n = sorted_values.size
    total = sorted_values.sum()
    if n <= 1 or total == 0:
        return 0.0

    # Rank 1 is the smallest value; smaller values get the larger weights.
    ranks = np.arange(1, n + 1)
    return (n + 1 - 2 * np.sum((n + 1 - ranks) * sorted_values) / total) / n


# Price Impact

In [129]:
def analyze_whale_impact(trades_df, whale_ids):
    """
    Analyze the impact of whale trades on market prices
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade-level data
    whale_ids : list
        List of whale trader IDs
    
    Returns:
    --------
    dict
        Dictionary with whale impact analysis results
    """
    print("Analyzing whale trade impact...")
    
    # Verify whale_ids is not None and not empty
    if whale_ids is None or len(whale_ids) == 0:
        print("Error: No whale trader IDs provided")
        return None
    
    # Make a copy of the data
    df = trades_df.copy()
    
    # Check for required columns
    required_cols = ['trader_id', 'price']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        print(f"Error: Missing required columns: {missing_cols}")
        return None
    
    # Clean price data
    df['price'] = pd.to_numeric(df['price'], errors='coerce')
    df = df.dropna(subset=['price'])
    
    # Convert timestamp to datetime if needed
    if 'timestamp' in df.columns:
        if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
            try:
                df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
                df = df.dropna(subset=['timestamp'])
                print(f"Converted {len(df)} valid timestamps")
            except Exception as e:
                print(f"Warning: Could not convert timestamps to datetime: {e}")
                # Create sequential index as timestamp substitute
                df = df.sort_index().reset_index(drop=True)
                df['timestamp'] = df.index
    else:
        print("No timestamp column found. Creating sequential index.")
        df = df.reset_index(drop=True)
        df['timestamp'] = df.index
    
    # Sort by timestamp
    df = df.sort_values('timestamp')
    
    # Add whale indicator
    df['is_whale'] = df['trader_id'].isin(whale_ids)
    
    # Separate whale and non-whale trades
    whale_trades = df[df['is_whale']]
    non_whale_trades = df[~df['is_whale']]
    
    print(f"Found {len(whale_trades)} whale trades and {len(non_whale_trades)} non-whale trades")
    
    # Calculate price changes
    df['price_change'] = df['price'].diff()
    
    # Calculate impact by market if market_id is available
    if 'market_id' in df.columns:
        print("Analyzing price impact by market...")
        market_impacts = []
        
        for market_id, market_df in df.groupby('market_id'):
            # Skip markets with too few trades
            if len(market_df) < 10:
                continue
                
            # Sort by timestamp
            market_df = market_df.sort_values('timestamp')
            
            # Calculate price changes
            market_df['price_change'] = market_df['price'].diff()
            
            # Separate whale and non-whale trades
            market_whale_trades = market_df[market_df['is_whale']]
            market_non_whale_trades = market_df[~market_df['is_whale']]
            
            # Skip markets with no whale trades
            if len(market_whale_trades) == 0:
                continue
                
            # Calculate average metrics
            whale_avg_change = market_whale_trades['price_change'].mean()
            non_whale_avg_change = market_non_whale_trades['price_change'].mean()
            
            # Calculate directional impact
            whale_pos_pct = (market_whale_trades['price_change'] > 0).mean() * 100
            whale_neg_pct = (market_whale_trades['price_change'] < 0).mean() * 100
            non_whale_pos_pct = (market_non_whale_trades['price_change'] > 0).mean() * 100
            non_whale_neg_pct = (market_non_whale_trades['price_change'] < 0).mean() * 100
            
            market_impacts.append({
                'market_id': market_id,
                'total_trades': len(market_df),
                'whale_trades': len(market_whale_trades),
                'non_whale_trades': len(market_non_whale_trades),
                'whale_avg_change': whale_avg_change,
                'non_whale_avg_change': non_whale_avg_change,
                'whale_pos_pct': whale_pos_pct,
                'whale_neg_pct': whale_neg_pct,
                'non_whale_pos_pct': non_whale_pos_pct,
                'non_whale_neg_pct': non_whale_neg_pct
            })
        
        # Create markets DataFrame
        if market_impacts:
            markets_df = pd.DataFrame(market_impacts)
            
            # Calculate weighted averages
            weighted_whale_impact = np.average(
                markets_df['whale_avg_change'].fillna(0),
                weights=markets_df['whale_trades']
            )
            
            weighted_non_whale_impact = np.average(
                markets_df['non_whale_avg_change'].fillna(0),
                weights=markets_df['non_whale_trades']
            )
            
            print(f"\nWeighted average whale price impact: {weighted_whale_impact:.6f}")
            print(f"Weighted average non-whale price impact: {weighted_non_whale_impact:.6f}")
            
            # Calculate impact ratio if possible
            if weighted_non_whale_impact != 0:
                impact_ratio = weighted_whale_impact / weighted_non_whale_impact
                print(f"Impact ratio (whale/non-whale): {impact_ratio:.4f}")
            else:
                impact_ratio = None
                print("Impact ratio cannot be calculated (division by zero)")
            
            # Create visualization
            plt.figure(figsize=(15, 12))
            
            # 1. Market-by-market comparison
            plt.subplot(2, 1, 1)
            
            # Sort markets by whale impact
            sorted_markets = markets_df.sort_values('whale_avg_change')
            
            # Plot whale vs non-whale impact by market
            plt.scatter(range(len(sorted_markets)), sorted_markets['whale_avg_change'], 
                       label='Whale impact', alpha=0.7, s=50, color='blue')
            plt.scatter(range(len(sorted_markets)), sorted_markets['non_whale_avg_change'], 
                       label='Non-whale impact', alpha=0.7, s=50, color='orange')
            
            plt.axhline(y=0, color='r', linestyle='--')
            plt.title('Price Impact by Market')
            plt.xlabel('Markets (sorted by whale impact)')
            plt.ylabel('Average Price Change')
            plt.legend()
            plt.grid(alpha=0.3)
            
            # 2. Direction comparison
            plt.subplot(2, 1, 2)
            
            # Calculate average positive/negative percentages
            avg_whale_pos = markets_df['whale_pos_pct'].mean()
            avg_whale_neg = markets_df['whale_neg_pct'].mean()
            avg_nonwhale_pos = markets_df['non_whale_pos_pct'].mean()
            avg_nonwhale_neg = markets_df['non_whale_neg_pct'].mean()
            
            # Plot directional impact
            labels = ['Whale', 'Non-whale']
            pos_values = [avg_whale_pos, avg_nonwhale_pos]
            neg_values = [avg_whale_neg, avg_nonwhale_neg]
            neutral_values = [100 - avg_whale_pos - avg_whale_neg, 
                             100 - avg_nonwhale_pos - avg_nonwhale_neg]
            
            width = 0.35
            x = np.arange(len(labels))
            
            plt.bar(x, pos_values, width, label='Positive impact', color='green')
            plt.bar(x, neg_values, width, bottom=pos_values, label='Negative impact', color='red')
            plt.bar(x, neutral_values, width, 
                   bottom=[pos_values[i] + neg_values[i] for i in range(len(pos_values))], 
                   label='Neutral', color='gray')
            
            plt.title('Direction of Price Impact')
            plt.ylabel('Percentage of Trades')
            plt.xlabel('Trader Type')
            plt.xticks(x, labels)
            plt.legend()
            
            plt.tight_layout()
            plt.savefig('whale_impact_analysis.png', dpi=300)
            plt.close()
            
            print("Whale impact analysis visualization saved to whale_impact_analysis.png")
            
            return {
                'market_impacts': markets_df.to_dict('records'),
                'weighted_whale_impact': weighted_whale_impact,
                'weighted_non_whale_impact': weighted_non_whale_impact,
                'impact_ratio': impact_ratio,
                'direction_metrics': {
                    'whale_positive_pct': avg_whale_pos,
                    'whale_negative_pct': avg_whale_neg,
                    'non_whale_positive_pct': avg_nonwhale_pos,
                    'non_whale_negative_pct': avg_nonwhale_neg
                }
            }
    
    # If market_id not available, perform overall analysis
    print("Analyzing overall price changes...")
    
    # Calculate metrics
    whale_avg_change = whale_trades['price_change'].mean()
    whale_median_change = whale_trades['price_change'].median()
    whale_std_change = whale_trades['price_change'].std()
    
    non_whale_avg_change = non_whale_trades['price_change'].mean()
    non_whale_median_change = non_whale_trades['price_change'].median()
    non_whale_std_change = non_whale_trades['price_change'].std()
    
    print(f"\nWhale trades average price change: {whale_avg_change:.6f}")
    print(f"Non-whale trades average price change: {non_whale_avg_change:.6f}")
    
    # Calculate direction metrics
    whale_pos_pct = (whale_trades['price_change'] > 0).mean() * 100
    whale_neg_pct = (whale_trades['price_change'] < 0).mean() * 100
    non_whale_pos_pct = (non_whale_trades['price_change'] > 0).mean() * 100
    non_whale_neg_pct = (non_whale_trades['price_change'] < 0).mean() * 100
    
    print(f"Whale trades causing price increases: {whale_pos_pct:.2f}%")
    print(f"Whale trades causing price decreases: {whale_neg_pct:.2f}%")
    
    # Calculate following behavior
    df['next_is_whale'] = df['is_whale'].shift(-1)
    df['prev_is_whale'] = df['is_whale'].shift(1)
    
    # Calculate price direction
    df['price_direction'] = np.sign(df['price_change'])
    df['next_price_direction'] = df['price_direction'].shift(-1)
    df['prev_price_direction'] = df['price_direction'].shift(1)
    
    # Calculate how often non-whales follow whale direction
    whale_followed = df[(df['prev_is_whale']) & (~df['is_whale']) & 
                      (df['price_direction'] == df['prev_price_direction'])]
    whale_trades_with_followers = df[df['prev_is_whale'] & ~df['is_whale']]
    
    if len(whale_trades_with_followers) > 0:
        following_ratio = len(whale_followed) / len(whale_trades_with_followers)
        print(f"Non-whale traders follow whale price direction: {following_ratio:.2%} of the time")
    else:
        following_ratio = None
        print("Could not calculate following ratio")
    
    # Create visualization
    plt.figure(figsize=(15, 10))
    
    # 1. Price change distribution
    plt.subplot(2, 1, 1)
    
    # Calculate bins for histogram
    bin_width = max(whale_std_change, non_whale_std_change) / 5
    bins = np.arange(
        min(whale_trades['price_change'].min(), non_whale_trades['price_change'].min()) - bin_width,
        max(whale_trades['price_change'].max(), non_whale_trades['price_change'].max()) + bin_width,
        bin_width
    )
    
    # Plot histograms
    plt.hist(whale_trades['price_change'].dropna(), bins=bins, alpha=0.5, 
             label=f'Whale trades (mean={whale_avg_change:.6f})', color='blue')
    plt.hist(non_whale_trades['price_change'].dropna(), bins=bins, alpha=0.5, 
             label=f'Non-whale trades (mean={non_whale_avg_change:.6f})', color='orange')
    
    plt.axvline(x=0, color='r', linestyle='--')
    plt.axvline(x=whale_avg_change, color='blue', linestyle='-')
    plt.axvline(x=non_whale_avg_change, color='orange', linestyle='-')
    
    plt.title('Distribution of Price Changes')
    plt.xlabel('Price Change')
    plt.ylabel('Frequency')
    plt.legend()
    plt.grid(alpha=0.3)
    
    # 2. Direction comparison
    plt.subplot(2, 1, 2)
    
    # Plot directional impact
    labels = ['Whale', 'Non-whale']
    pos_values = [whale_pos_pct, non_whale_pos_pct]
    neg_values = [whale_neg_pct, non_whale_neg_pct]
    neutral_values = [100 - whale_pos_pct - whale_neg_pct, 
                     100 - non_whale_pos_pct - non_whale_neg_pct]
    
    width = 0.35
    x = np.arange(len(labels))
    
    plt.bar(x, pos_values, width, label='Positive impact', color='green')
    plt.bar(x, neg_values, width, bottom=pos_values, label='Negative impact', color='red')
    plt.bar(x, neutral_values, width, 
           bottom=[pos_values[i] + neg_values[i] for i in range(len(pos_values))], 
           label='Neutral', color='gray')
    
    # Add value labels
    for i, v in enumerate(pos_values):
        plt.text(i, v/2, f"{v:.1f}%", ha='center', color='white', fontweight='bold')
    
    for i, v in enumerate(neg_values):
        plt.text(i, pos_values[i] + v/2, f"{v:.1f}%", ha='center', color='white', fontweight='bold')
    
    plt.title('Direction of Price Impact')
    plt.ylabel('Percentage of Trades')
    plt.xlabel('Trader Type')
    plt.xticks(x, labels)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('whale_impact_analysis.png', dpi=300)
    plt.close()
    
    print("Whale impact analysis visualization saved to whale_impact_analysis.png")
    
    return {
        'whale_impact': {
            'avg_change': whale_avg_change,
            'median_change': whale_median_change,
            'std_change': whale_std_change,
            'positive_pct': whale_pos_pct,
            'negative_pct': whale_neg_pct
        },
        'non_whale_impact': {
            'avg_change': non_whale_avg_change,
            'median_change': non_whale_median_change,
            'std_change': non_whale_std_change,
            'positive_pct': non_whale_pos_pct,
            'negative_pct': non_whale_neg_pct
        },
        'following_ratio': following_ratio
    }

# Usage: analyze_whale_impact is called from run_whale_analysis_pipeline (see "Run Analysis" below).

# Trading Inequality

In [131]:
def visualize_trading_inequality(trader_analysis_df, save_path='trading_inequality_analysis.png'):
    """
    Create detailed visualizations of trading inequality.

    Draws a 2x2 figure (Lorenz curve, log-scaled volume histogram, volume share
    per trader category, rank/volume log-log plot with an optional power-law
    fit), saves it to ``save_path`` and closes it.

    Parameters:
    -----------
    trader_analysis_df : pd.DataFrame
        DataFrame with trader volume analysis; one row per trader with a
        numeric ``volume`` column.
    save_path : str
        Path to save the visualization

    Returns:
    --------
    dict or None
        Dictionary with inequality metrics (``gini``, ``percentile_data``,
        ``top_trader_shares``), or None when there are no traders or the
        total volume is not positive (every share would divide by zero).
    """
    print("Generating trading inequality visualizations...")

    total_traders = len(trader_analysis_df)
    total_volume = trader_analysis_df['volume'].sum()
    if total_traders == 0 or not total_volume > 0:
        print("Cannot analyze trading inequality: no traders or no positive total volume")
        return None

    # Ascending order so cumulative sums run from the smallest trader up (Lorenz convention)
    sorted_df = trader_analysis_df.sort_values('volume', ascending=True)
    volumes = sorted_df['volume']

    # Share of volume held by traders strictly above each volume percentile
    percentile_data = {}
    for p in [50, 90, 95, 99, 99.9]:
        threshold = np.percentile(volumes, p)
        above = volumes[volumes > threshold]
        traders_above = len(above)
        volume_share = above.sum() / total_volume * 100

        percentile_data[p] = {
            'threshold': threshold,
            'traders_above': traders_above,
            'traders_pct': traders_above / total_traders * 100,
            'volume_share': volume_share
        }

        print(f"Top {100-p:.1f}% of traders (volume > {threshold:.2f}) control {volume_share:.2f}% of volume")

    fig, axs = plt.subplots(2, 2, figsize=(16, 14))

    # 1. Lorenz curve (top left)
    ax1 = axs[0, 0]

    # Prepend the origin so the curve (and its area) starts at (0, 0)
    x_lorenz = np.linspace(0, 100, total_traders + 1)
    y_lorenz = np.insert(volumes.cumsum().to_numpy() / total_volume * 100, 0, 0)

    ax1.plot(x_lorenz, y_lorenz, 'b-', linewidth=2, label='Volume distribution')
    ax1.plot([0, 100], [0, 100], 'k--', label='Perfect equality')
    ax1.fill_between(x_lorenz, y_lorenz, x_lorenz, alpha=0.2, color='blue')

    # Gini = 1 - (area under Lorenz curve) / (area under equality line = 100*100/2).
    # The trapezoid rule is written out so it works on NumPy 1.x and 2.x alike
    # (np.trapz was renamed np.trapezoid in NumPy 2.0).
    lorenz_area = np.sum((y_lorenz[1:] + y_lorenz[:-1]) * np.diff(x_lorenz)) / 2
    gini = 1 - lorenz_area / 5000

    for p in [90, 95, 99, 99.9]:
        # The top (100 - p)% of traders begin where the bottom p% end, i.e. at
        # x = p on the curve; mark that point with guide lines.
        y_mark = np.interp(p, x_lorenz, y_lorenz)
        ax1.plot([p, p], [0, y_mark], 'r--', alpha=0.5)
        ax1.plot([0, p], [y_mark, y_mark], 'r--', alpha=0.5)
        ax1.text(p - 0.5, y_mark + 1, f'Top {100 - p:.1f}%', fontsize=9, ha='right')

    ax1.set_title(f'Trading Volume Lorenz Curve (Gini: {gini:.4f})')
    ax1.set_xlabel('Cumulative % of Traders')
    ax1.set_ylabel('Cumulative % of Volume')
    ax1.grid(alpha=0.3)
    ax1.legend()

    # 2. Volume distribution histogram (top right), log scale; +1 keeps zero volumes finite
    ax2 = axs[0, 1]
    log_volumes = np.log10(volumes + 1)

    ax2.hist(log_volumes, bins=50, alpha=0.7, color='skyblue')
    ax2.set_title('Trading Volume Distribution (Log Scale)')
    ax2.set_xlabel('Log10(Volume)')
    ax2.set_ylabel('Number of Traders')

    for p in [50, 90, 95, 99]:
        threshold = np.log10(np.percentile(volumes, p) + 1)
        ax2.axvline(threshold, color='red', linestyle='--', alpha=0.5)
        ax2.text(threshold + 0.1, ax2.get_ylim()[1]*0.9, f'{p}th', fontsize=9, rotation=90)

    ax2.grid(alpha=0.3)

    # 3. Top trader concentration (bottom left)
    ax3 = axs[1, 0]
    top_n_categories = ['Top 0.1%', 'Top 0.1-1%', 'Top 1-5%', 'Top 5-10%', 'Bottom 90%']

    # Trader counts per tier (at least one trader so tiny samples still plot)
    n_01pct = max(1, int(total_traders * 0.001))
    n_1pct = max(1, int(total_traders * 0.01))
    n_5pct = max(1, int(total_traders * 0.05))
    n_10pct = max(1, int(total_traders * 0.1))

    # Volume per tier; a tier is empty when its bounds collapse on small samples
    vol_01pct = sorted_df.iloc[-n_01pct:]['volume'].sum()
    vol_01_1pct = sorted_df.iloc[-n_1pct:-n_01pct]['volume'].sum() if n_1pct > n_01pct else 0
    vol_1_5pct = sorted_df.iloc[-n_5pct:-n_1pct]['volume'].sum() if n_5pct > n_1pct else 0
    vol_5_10pct = sorted_df.iloc[-n_10pct:-n_5pct]['volume'].sum() if n_10pct > n_5pct else 0
    vol_bottom90pct = sorted_df.iloc[:-n_10pct]['volume'].sum() if len(sorted_df) > n_10pct else 0

    pct_01pct = vol_01pct / total_volume * 100
    pct_01_1pct = vol_01_1pct / total_volume * 100
    pct_1_5pct = vol_1_5pct / total_volume * 100
    pct_5_10pct = vol_5_10pct / total_volume * 100
    pct_bottom90pct = vol_bottom90pct / total_volume * 100

    volume_shares = [pct_01pct, pct_01_1pct, pct_1_5pct, pct_5_10pct, pct_bottom90pct]

    bars = ax3.bar(top_n_categories, volume_shares, color=['darkred', 'red', 'orange', 'gold', 'green'])
    for bar in bars:
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                 f'{height:.1f}%', ha='center', fontsize=10)

    ax3.set_title('Volume Distribution by Trader Category')
    ax3.set_ylabel('% of Total Volume')
    ax3.set_ylim(0, max(volume_shares) * 1.1)
    ax3.grid(axis='y', alpha=0.3)

    # 4. Power law plot (bottom right): volume vs rank, largest trader first
    ax4 = axs[1, 1]
    ranked_volumes = sorted(volumes, reverse=True)
    ranks = np.arange(1, len(ranked_volumes) + 1)
    ax4.loglog(ranks, ranked_volumes, 'o', markersize=2, alpha=0.5)

    fitted = False
    try:
        from scipy import stats
    except ImportError:
        print("Could not fit power law (scipy.stats not available)")
    else:
        # Log-log fit needs positive volumes; ranked descending, so they are the leading entries
        positive_volumes = np.asarray(ranked_volumes, dtype=float)
        positive_volumes = positive_volumes[positive_volumes > 0]
        if len(positive_volumes) >= 2:
            positive_ranks = np.arange(1, len(positive_volumes) + 1)
            slope, intercept, r_value, _p_value, _std_err = stats.linregress(
                np.log10(positive_ranks), np.log10(positive_volumes)
            )
            x_fit = np.logspace(0, np.log10(len(positive_ranks)), 100)
            y_fit = 10**(intercept) * x_fit**slope
            ax4.plot(x_fit, y_fit, 'r-', linewidth=2,
                     label=f'Power law fit: α={-slope:.2f}, R²={r_value**2:.2f}')
            fitted = True
            print(f"Power law exponent: α={-slope:.2f}, R²={r_value**2:.2f}")
        else:
            print("Could not fit power law (fewer than two traders with positive volume)")

    ax4.set_title('Trader Volume Rank Distribution')
    ax4.set_xlabel('Trader Rank')
    ax4.set_ylabel('Volume')
    ax4.grid(True, which="both", ls="-", alpha=0.3)
    if fitted:
        ax4.legend()

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

    print(f"Trading inequality visualizations saved to {save_path}")

    return {
        'gini': gini,
        'percentile_data': percentile_data,
        'top_trader_shares': {
            'top_0.1pct': pct_01pct,
            'top_0.1_1pct': pct_01_1pct,
            'top_1_5pct': pct_1_5pct,
            'top_5_10pct': pct_5_10pct,
            'bottom_90pct': pct_bottom90pct
        }
    }

# Behavior Over Time

In [132]:
def _bucket_timestamps(timestamps, time_unit):
    """Map each timestamp to the start of its ``time_unit`` period."""
    try:
        # Fixed-size units ('D', 'h', '15min', ...) keep the original dtype and timezone.
        return timestamps.dt.floor(time_unit)
    except ValueError:
        # Calendar units such as 'W' or 'M' are non-fixed frequencies that floor()
        # rejects; bucket through periods instead (timezone information is dropped).
        return timestamps.dt.to_period(time_unit).dt.start_time


def _describe_trend(direction):
    """Turn the sign of a smoothed activity slope into a printable label."""
    if pd.isna(direction):
        return 'Undetermined'
    if direction > 0:
        return 'Increasing'
    if direction < 0:
        return 'Decreasing'
    return 'Flat'


def analyze_trader_behavior_over_time(trades_df, whale_ids, time_unit='D'):
    """
    Analyze how trading patterns evolve over time for whales vs non-whales

    Trades are bucketed into ``time_unit`` periods and, per period and trader
    type, trade counts, unique traders, volume and (if present) price stats are
    computed. Saves a three-panel plot to ``trader_behavior_over_time.png``.

    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data; needs ``timestamp``, ``trader_id`` and
        ``trade_amount`` columns, ``price`` is optional
    whale_ids : list
        List of whale trader IDs
    time_unit : str
        Time unit for aggregation ('D' for day, 'W' for week, etc.)

    Returns:
    --------
    dict or None
        ``time_data`` records plus activity correlation, 1-period lead-lag
        correlations and trend directions (each None when not computable);
        None when timestamps are missing or fewer than two period/type rows exist.
    """
    print(f"Analyzing trader behavior over time (unit: {time_unit})...")

    df = trades_df.copy()

    if 'timestamp' not in df.columns:
        print("Error: No timestamp column available")
        return None

    if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
        try:
            df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
            df = df.dropna(subset=['timestamp'])
            print(f"Converted {len(df)} timestamps to datetime format")
        except Exception as e:
            print(f"Error converting timestamps: {e}")
            return None

    df['is_whale'] = df['trader_id'].isin(whale_ids)

    if 'price' in df.columns:
        df['price'] = pd.to_numeric(df['price'], errors='coerce')

    df['date'] = _bucket_timestamps(df['timestamp'], time_unit)

    # One metrics row per (period, trader type); groupby never yields empty groups
    time_metrics = []
    for (date, is_whale), group in df.groupby(['date', 'is_whale']):
        metrics = {
            'date': date,
            'is_whale': is_whale,
            'trader_type': 'Whale' if is_whale else 'Non-whale',
            'trade_count': len(group),
            'unique_traders': group['trader_id'].nunique(),
            'volume': group['trade_amount'].sum(),
            'avg_trade_size': group['trade_amount'].mean()
        }

        if 'price' in group.columns:
            metrics.update({
                'avg_price': group['price'].mean(),
                'price_std': group['price'].std(),
                'price_range': group['price'].max() - group['price'].min() if len(group) > 1 else 0
            })

        time_metrics.append(metrics)

    time_df = pd.DataFrame(time_metrics)

    if len(time_df) < 2:
        print("Not enough temporal data to analyze")
        return None

    # Three stacked panels share the same pivot-and-plot recipe
    panels = [
        ('trade_count', 'Trading Activity Over Time', 'Number of Trades'),
        ('volume', 'Trading Volume Over Time', 'Volume'),
        ('unique_traders', 'Trader Participation Over Time', 'Number of Unique Traders'),
    ]
    fig, axes = plt.subplots(3, 1, figsize=(15, 12))
    pivots = {}
    for ax, (value_col, title, ylabel) in zip(axes, panels):
        pivot = time_df.pivot_table(
            index='date', columns='trader_type', values=value_col
        ).fillna(0)
        pivot.plot(ax=ax)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.grid(alpha=0.3)
        pivots[value_col] = pivot

    fig.tight_layout()
    fig.savefig('trader_behavior_over_time.png', dpi=300)
    plt.close(fig)

    print("Temporal analysis visualization saved to trader_behavior_over_time.png")

    pivot_activity = pivots['trade_count']

    activity_correlation = None
    whale_lead_corr = None
    nonwhale_lead_corr = None
    whale_trend_direction = None
    nonwhale_trend_direction = None

    if {'Whale', 'Non-whale'}.issubset(pivot_activity.columns):
        activity_correlation = pivot_activity['Whale'].corr(pivot_activity['Non-whale'])
        print(f"Correlation between whale and non-whale activity: {activity_correlation:.4f}")

        # NOTE: periods with no trades at all are absent from the pivot, so
        # shift(1) lags by one *observed* period, not necessarily one time_unit.
        whale_lead_corr = pivot_activity['Whale'].shift(1).corr(pivot_activity['Non-whale'])
        nonwhale_lead_corr = pivot_activity['Non-whale'].shift(1).corr(pivot_activity['Whale'])

        if pd.isna(whale_lead_corr) or pd.isna(nonwhale_lead_corr):
            print("Lead-lag relationship could not be determined (insufficient variation)")
        elif whale_lead_corr > nonwhale_lead_corr:
            print(f"Whales appear to lead non-whale activity (correlation: {whale_lead_corr:.4f})")
        else:
            print(f"Non-whales appear to lead whale activity (correlation: {nonwhale_lead_corr:.4f})")

    # Trend = sign of the mean step of a 5-period rolling average
    if len(pivot_activity) >= 5:
        if 'Whale' in pivot_activity.columns:
            whale_trend_direction = np.sign(pivot_activity['Whale'].rolling(5).mean().diff().mean())
            print(f"Whale activity trend: {_describe_trend(whale_trend_direction)}")

        if 'Non-whale' in pivot_activity.columns:
            nonwhale_trend_direction = np.sign(pivot_activity['Non-whale'].rolling(5).mean().diff().mean())
            print(f"Non-whale activity trend: {_describe_trend(nonwhale_trend_direction)}")

    return {
        'time_data': time_df.to_dict('records'),
        'activity_correlation': activity_correlation,
        'whale_lead_correlation': whale_lead_corr,
        'nonwhale_lead_correlation': nonwhale_lead_corr,
        'whale_trend_direction': whale_trend_direction,
        'nonwhale_trend_direction': nonwhale_trend_direction
    }

# Market Accuracy by Whale Activity

In [133]:

def analyze_market_accuracy_by_whale_activity(trades_df, market_data, whale_ids, min_trades=10):
    """
    Analyze how whale activity correlates with market prediction accuracy

    For every market with at least ``min_trades`` trades, whale participation
    metrics are computed, inner-joined to ``market_data`` on ``market_id`` and
    correlated with ``brier_score``. Markets are also split into whale-activity
    quartiles (by ``whale_ratio``) and the mean Brier score per quartile reported.

    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data; needs ``market_id`` and ``trader_id``,
        ``trade_amount`` is optional (volume metrics are 0 without it)
    market_data : pd.DataFrame
        DataFrame with market-level data including ``market_id`` and
        ``brier_score`` (presumably a standard Brier score, lower = more
        accurate -- confirm against the data source)
    whale_ids : list
        List of whale trader IDs
    min_trades : int, default 10
        Markets with fewer trades than this are skipped

    Returns:
    --------
    dict or None
        ``market_metrics`` (merged per-market records), ``correlations``
        (metric -> Pearson correlation with brier_score) and
        ``accuracy_by_activity`` (quartile label -> mean brier_score; empty
        when fewer than two markets remain). None when required columns are
        missing, no market has enough trades, or nothing merges.
    """
    print("Analyzing market accuracy by whale activity...")

    required_cols = ['market_id', 'brier_score']
    missing_cols = [col for col in required_cols if col not in market_data.columns]

    if missing_cols:
        print(f"Error: Missing required columns in market_data: {missing_cols}")
        print("Required columns: market_id, brier_score (or other accuracy metric)")
        return None

    df = trades_df.copy()
    df['is_whale'] = df['trader_id'].isin(whale_ids)
    has_amount = 'trade_amount' in df.columns

    market_metrics = []
    for market_id, market_trades in df.groupby('market_id'):
        if len(market_trades) < min_trades:
            continue

        whale_trades = market_trades[market_trades['is_whale']]
        non_whale_trades = market_trades[~market_trades['is_whale']]

        whale_volume = whale_trades['trade_amount'].sum() if has_amount else 0
        non_whale_volume = non_whale_trades['trade_amount'].sum() if has_amount else 0
        total_volume = market_trades['trade_amount'].sum() if has_amount else 0

        market_metrics.append({
            'market_id': market_id,
            'total_trades': len(market_trades),
            'whale_trades': len(whale_trades),
            'non_whale_trades': len(non_whale_trades),
            # groupby groups are never empty, so this division is safe
            'whale_ratio': len(whale_trades) / len(market_trades),
            'unique_whales': whale_trades['trader_id'].nunique(),
            'unique_non_whales': non_whale_trades['trader_id'].nunique(),
            'whale_volume': whale_volume,
            'non_whale_volume': non_whale_volume,
            'whale_volume_ratio': whale_volume / total_volume if total_volume > 0 else 0,
        })

    # Without this guard an empty, column-less frame makes merge() raise KeyError
    if not market_metrics:
        print(f"Error: No markets with at least {min_trades} trades to analyze")
        return None

    metrics_df = pd.DataFrame(market_metrics)

    merged_df = metrics_df.merge(
        market_data[['market_id', 'brier_score']],
        on='market_id',
        how='inner'
    )

    if len(merged_df) == 0:
        print("Error: Could not merge trade metrics with market accuracy data")
        return None

    correlations = {}
    for metric in ['whale_ratio', 'unique_whales', 'whale_volume_ratio']:
        corr = merged_df[metric].corr(merged_df['brier_score'])
        correlations[f"{metric}_correlation"] = corr
        print(f"Correlation between {metric} and Brier score: {corr:.4f}")

    quartile_labels = ['Low', 'Medium-Low', 'Medium-High', 'High']
    try:
        activity_bins = pd.qcut(merged_df['whale_ratio'], q=4, labels=quartile_labels)
    except ValueError:
        # Tied ratios (e.g. many markets without whales) give duplicate bin
        # edges; ranking first breaks the ties so every quartile has a distinct edge.
        if len(merged_df) >= 2:
            activity_bins = pd.qcut(
                merged_df['whale_ratio'].rank(method='first'), q=4, labels=quartile_labels
            )
        else:
            activity_bins = None

    if activity_bins is not None:
        merged_df['whale_activity_quantile'] = activity_bins
        # observed=False keeps empty quartiles (reported as NaN), matching the historical default
        accuracy_by_activity = merged_df.groupby(
            'whale_activity_quantile', observed=False
        )['brier_score'].mean()

        print("\nAverage Brier score by whale activity level:")
        for activity, score in accuracy_by_activity.items():
            print(f"{activity} whale activity: {score:.4f}")
    else:
        print("Too few markets to split into whale-activity quartiles")
        accuracy_by_activity = pd.Series(dtype=float)

    return {
        'market_metrics': merged_df.to_dict('records'),
        'correlations': correlations,
        'accuracy_by_activity': accuracy_by_activity.to_dict()
    }


# Whale Sentiment Analysis

In [134]:
def _average_pairwise_direction_correlation(market_df):
    """Mean correlation, over distinct whale pairs, of trade directions keyed by timestamp."""
    timestamps = market_df['timestamp']
    if not pd.api.types.is_datetime64_any_dtype(timestamps):
        timestamps = pd.to_datetime(timestamps, errors='coerce')

    pivot = market_df.assign(timestamp=timestamps).pivot_table(
        index='timestamp',
        columns='trader_id',
        values='direction',
        aggfunc='last'
    )
    corr_values = pivot.corr().to_numpy()

    # Strict upper triangle = each distinct pair exactly once. (np.triu would
    # instead zero the diagonal and lower triangle, dragging the mean toward 0.)
    pair_corrs = corr_values[np.triu_indices_from(corr_values, k=1)]
    pair_corrs = pair_corrs[~np.isnan(pair_corrs)]
    return pair_corrs.mean() if pair_corrs.size else np.nan


def analyze_whale_sentiment_alignment(trades_df, whale_ids):
    """
    Analyze whether whales tend to agree with each other or take opposing positions

    Per market with >= 10 whale trades from >= 2 distinct whales (counted after
    dropping trades whose ``side`` is neither buy nor sell), each whale's mean
    direction (buy=+1, sell=-1) is averaged into an alignment score; its
    absolute value is the consensus strength. With >= 3 whales and a
    ``timestamp`` column, the average pairwise correlation of whale directions
    over time is computed as well. Saves ``whale_sentiment_alignment.png``.

    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data (``trader_id``, ``side``, ``market_id``,
        optionally ``timestamp``); ``side`` is matched case-insensitively
    whale_ids : list
        List of whale trader IDs

    Returns:
    --------
    dict or None
        ``market_alignment`` records plus ``avg_consensus_strength``,
        ``avg_pairwise_correlation`` and ``bullish_markets_pct``; None when
        ``side``/``market_id`` is missing, there are fewer than 100 whale
        trades, or no market qualifies.
    """
    print("Analyzing whale sentiment alignment...")

    df = trades_df[trades_df['trader_id'].isin(whale_ids)].copy()

    if 'side' not in df.columns:
        print("Error: No trade direction information (side column) available")
        return None
    direction_col = 'side'

    if 'market_id' not in df.columns:
        print("Error: market_id column required for alignment analysis")
        return None

    if len(df) < 100:
        print(f"Too few whale trades to analyze alignment: {len(df)}")
        return None

    market_alignment = []

    for market_id, market_df in df.groupby('market_id'):
        if len(market_df) < 10 or market_df['trader_id'].nunique() < 2:
            continue

        # Normalise labels so 'BUY' / ' Sell ' count like 'buy' / 'sell';
        # anything else maps to NaN and is dropped. assign() avoids writing into a groupby slice.
        normalized_side = market_df[direction_col].astype(str).str.strip().str.lower()
        market_df = market_df.assign(direction=normalized_side.map({'buy': 1, 'sell': -1}))
        market_df = market_df.dropna(subset=['direction'])

        # Re-apply both thresholds to the trades that actually carry a direction
        if len(market_df) < 10 or market_df['trader_id'].nunique() < 2:
            continue

        whale_sentiments = market_df.groupby('trader_id')['direction'].mean()

        # Overall lean in [-1, 1]; its magnitude is used as consensus strength
        alignment_score = whale_sentiments.mean()
        consensus_strength = np.abs(alignment_score)

        if alignment_score > 0:
            sentiment = "Bullish"
        elif alignment_score < 0:
            sentiment = "Bearish"
        else:
            sentiment = "Neutral"

        avg_pairwise_corr = np.nan
        if len(whale_sentiments) >= 3 and 'timestamp' in market_df.columns:
            avg_pairwise_corr = _average_pairwise_direction_correlation(market_df)

        market_alignment.append({
            'market_id': market_id,
            'unique_whales': market_df['trader_id'].nunique(),
            'total_whale_trades': len(market_df),
            'alignment_score': alignment_score,
            'consensus_strength': consensus_strength,
            'sentiment': sentiment,
            'avg_pairwise_correlation': avg_pairwise_corr
        })

    if not market_alignment:
        print("No markets with sufficient data for alignment analysis")
        return None

    alignment_df = pd.DataFrame(market_alignment)

    avg_consensus = alignment_df['consensus_strength'].mean()
    mean_pairwise_corr = alignment_df['avg_pairwise_correlation'].mean()

    print(f"Average whale consensus strength: {avg_consensus:.4f}")
    print(f"Average pairwise correlation between whales: {mean_pairwise_corr:.4f}")

    bullish_pct = (alignment_df['alignment_score'] > 0).mean() * 100
    print(f"Whales are bullish in {bullish_pct:.1f}% of markets")

    fig, (ax_consensus, ax_corr) = plt.subplots(2, 1, figsize=(15, 10))

    # 1. Consensus strength distribution
    ax_consensus.hist(alignment_df['consensus_strength'], bins=20, alpha=0.7)
    ax_consensus.axvline(x=avg_consensus, color='r', linestyle='--',
                         label=f'Average consensus: {avg_consensus:.4f}')
    ax_consensus.set_title('Whale Consensus Strength Distribution')
    ax_consensus.set_xlabel('Consensus Strength (0=Disagreement, 1=Agreement)')
    ax_consensus.set_ylabel('Number of Markets')
    ax_consensus.legend()
    ax_consensus.grid(alpha=0.3)

    # 2. Pairwise correlation distribution
    non_nan_corrs = alignment_df['avg_pairwise_correlation'].dropna()
    if len(non_nan_corrs) > 0:
        ax_corr.hist(non_nan_corrs, bins=20, alpha=0.7)
        ax_corr.axvline(x=mean_pairwise_corr, color='r', linestyle='--',
                        label=f'Average correlation: {mean_pairwise_corr:.4f}')
        ax_corr.set_title('Whale Pairwise Correlation Distribution')
        ax_corr.set_xlabel('Average Pairwise Correlation')
        ax_corr.set_ylabel('Number of Markets')
        ax_corr.legend()
        ax_corr.grid(alpha=0.3)
    else:
        ax_corr.text(0.5, 0.5, "Insufficient data for pairwise correlation analysis",
                     ha='center', va='center', fontsize=12, transform=ax_corr.transAxes)

    fig.tight_layout()
    fig.savefig('whale_sentiment_alignment.png', dpi=300)
    plt.close(fig)

    print("Whale sentiment alignment visualization saved to whale_sentiment_alignment.png")

    return {
        'market_alignment': alignment_df.to_dict('records'),
        'avg_consensus_strength': avg_consensus,
        'avg_pairwise_correlation': mean_pairwise_corr,
        'bullish_markets_pct': bullish_pct
    }

# Lorenz Curve

In [135]:
def create_lorenz_curve_visualization(whale_analysis_df, save_path='trading_volume_lorenz.png'):
    """
    Create Lorenz curve visualization for trader volume distribution

    Parameters:
    -----------
    whale_analysis_df : pd.DataFrame
        DataFrame with trader volume analysis; one row per trader with a
        numeric ``volume`` column (rows with missing volume are ignored)
    save_path : str
        Where the figure is written (defaults to the historical file name)

    Returns:
    --------
    float or None
        Gini coefficient of the volume distribution, or None when there are
        no traders or the total volume is not positive.
    """
    volumes = whale_analysis_df['volume'].dropna().sort_values().to_numpy(dtype=float)
    total_volume = volumes.sum()
    if volumes.size == 0 or not total_volume > 0:
        print("Cannot draw Lorenz curve: no traders or no positive total volume")
        return None

    # Cumulative shares, with the origin prepended so the curve starts at (0, 0)
    trader_pct = np.linspace(0, 100, volumes.size + 1)
    volume_pct = np.concatenate([[0.0], np.cumsum(volumes) / total_volume * 100])

    fig, ax = plt.subplots(figsize=(10, 8))

    ax.plot(trader_pct, volume_pct, label='Trading volume distribution')
    ax.plot([0, 100], [0, 100], 'k--', label='Perfect equality')
    # Shaded gap between equality line and curve is what the Gini measures
    ax.fill_between(trader_pct, trader_pct, volume_pct, alpha=0.2)

    # Gini = 1 - (area under curve) / 5000, where 5000 = 100*100/2 is the area
    # under the equality line. Trapezoid rule written out (np.trapz was renamed in NumPy 2.0).
    lorenz_area = np.sum((volume_pct[1:] + volume_pct[:-1]) * np.diff(trader_pct)) / 2
    gini = 1 - lorenz_area / 5000

    for p in [90, 95, 99, 99.9]:
        # The top (100 - p)% of traders begin where the bottom p% end: x = p on the curve
        y_mark = np.interp(p, trader_pct, volume_pct)
        ax.plot([p, p], [0, y_mark], 'r--', alpha=0.5)
        ax.plot([0, p], [y_mark, y_mark], 'r--', alpha=0.5)
        ax.text(p - 1, y_mark + 2, f'Top {100 - p:.1f}%', fontsize=10, ha='right')

    ax.set_title(f'Trading Volume Distribution (Gini Coefficient: {gini:.4f})')
    ax.set_xlabel('Cumulative % of Traders')
    ax.set_ylabel('Cumulative % of Volume')
    ax.grid(alpha=0.3)
    ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

    print(f"Lorenz curve visualization saved as {save_path}")
    return gini

# Run Analysis

In [138]:
def _print_whale_analysis_summary(all_results):
    """
    Print the summary section for the combined whale-analysis results dict.

    Every metric is optional. A section is skipped when its result is falsy,
    and a single line is skipped when its key is missing or None, so partial
    output from one analysis step never aborts the summary. Zero-valued
    metrics (e.g. an impact ratio of 0.0) are still printed.
    """
    whale_def_results = all_results.get('whale_definition')
    inequality_results = all_results.get('inequality')
    impact_results = all_results.get('price_impact')
    time_results = all_results.get('temporal')
    alignment_results = all_results.get('alignment')

    print("\n" + "="*80)
    print("WHALE ANALYSIS SUMMARY")
    print("="*80)

    # Whale definition
    if whale_def_results:
        threshold = whale_def_results.get('selected_threshold')
        if threshold is not None:
            print(f"\nWhale Definition: Top {threshold*100:.1f}% of traders")
        num_whales = whale_def_results.get('selected_num_whales')
        if num_whales is not None:
            print(f"Number of Whales: {num_whales:,}")

    # Inequality
    if inequality_results:
        gini = inequality_results.get('gini')
        if gini is not None:
            print(f"\nTrading Inequality (Gini): {gini:.4f}")
        percentile_data = inequality_results.get('percentile_data') or {}
        for p in [99, 95, 90]:
            if p in percentile_data:
                data = percentile_data[p]
                print(f"Top {100-p:.1f}% of traders control {data['volume_share']:.2f}% of volume")

    # Price impact
    if impact_results:
        if 'weighted_whale_impact' in impact_results:
            print(f"\nWeighted Average Whale Price Impact: {impact_results['weighted_whale_impact']:.6f}")
            non_whale_impact = impact_results.get('weighted_non_whale_impact')
            if non_whale_impact is not None:
                print(f"Weighted Average Non-Whale Price Impact: {non_whale_impact:.6f}")
            impact_ratio = impact_results.get('impact_ratio')
            if impact_ratio is not None:
                print(f"Impact Ratio (Whale/Non-Whale): {impact_ratio:.4f}")
        elif 'whale_impact' in impact_results:
            print(f"\nWhale Average Price Change: {impact_results['whale_impact']['avg_change']:.6f}")
            if 'non_whale_impact' in impact_results:
                print(f"Non-Whale Average Price Change: {impact_results['non_whale_impact']['avg_change']:.6f}")

        following_ratio = impact_results.get('following_ratio')
        if following_ratio is not None:
            print(f"Non-Whales Follow Whale Direction: {following_ratio:.2%} of the time")

    # Temporal patterns
    if time_results:
        activity_correlation = time_results.get('activity_correlation')
        if activity_correlation is not None:
            print(f"\nWhale/Non-Whale Activity Correlation: {activity_correlation:.4f}")

        whale_lead = time_results.get('whale_lead_correlation')
        nonwhale_lead = time_results.get('nonwhale_lead_correlation')
        if whale_lead is not None and nonwhale_lead is not None:
            lead_type = "Whales" if whale_lead > nonwhale_lead else "Non-whales"
            print(f"Leading Influence: {lead_type} appear to lead trading activity")

    # Sentiment alignment
    if alignment_results:
        consensus = alignment_results.get('avg_consensus_strength')
        if consensus is not None:
            print(f"\nWhale Consensus Strength: {consensus:.4f}")
        pairwise = alignment_results.get('avg_pairwise_correlation')
        if pairwise is not None:
            print(f"Whale Pairwise Correlation: {pairwise:.4f}")
        bullish_pct = alignment_results.get('bullish_markets_pct')
        if bullish_pct is not None:
            print(f"Bullish Sentiment: {bullish_pct:.1f}% of markets")


def run_whale_analysis_pipeline(trades_df, generate_plots=True):
    """
    Run the complete whale analysis pipeline and print a summary.

    The steps are, in order: scale trade volumes, identify whales, visualize
    trading inequality, measure whale price impact, compare temporal activity,
    and measure whale sentiment alignment. Each step is delegated to a helper
    defined elsewhere in the notebook.

    Parameters
    ----------
    trades_df : pd.DataFrame
        DataFrame with trade data.
    generate_plots : bool
        Whether to generate visualization plots. This flag is forwarded to
        ``identify_whales`` only.

    Returns
    -------
    dict
        Keys ``whale_definition``, ``inequality``, ``price_impact``,
        ``temporal`` and ``alignment``. Each holds the raw result of the
        corresponding step; ``inequality`` is None when the whale step
        produced no ``trader_analysis``.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE WHALE TRADER ANALYSIS")
    print("="*80)

    # Step 1: Scale trade volumes
    print("\n1. SCALING TRADE VOLUMES")
    print("-"*40)
    cleaned_trades = scale_trade_volume(trades_df)

    # Step 2: Identify whales
    print("\n2. IDENTIFYING WHALE TRADERS")
    print("-"*40)
    whale_ids, whale_def_results = identify_whales(cleaned_trades, default_threshold=0.01, generate_plots=generate_plots)

    # Step 3: Visualize trading inequality (needs the per-trader table from step 2)
    print("\n3. ANALYZING TRADING INEQUALITY")
    print("-"*40)
    if whale_def_results and 'trader_analysis' in whale_def_results:
        inequality_results = visualize_trading_inequality(whale_def_results['trader_analysis'])
    else:
        print("Skipping inequality visualization as trader analysis is not available")
        inequality_results = None

    # Step 4: Analyze whale impact
    print("\n4. ANALYZING WHALE TRADE IMPACT")
    print("-"*40)
    impact_results = analyze_whale_impact(cleaned_trades, whale_ids)

    # Step 5: Analyze temporal behavior
    print("\n5. ANALYZING TEMPORAL TRADING PATTERNS")
    print("-"*40)
    time_results = analyze_trader_behavior_over_time(cleaned_trades, whale_ids)

    # Step 6: Analyze whale alignment
    print("\n6. ANALYZING WHALE SENTIMENT ALIGNMENT")
    print("-"*40)
    alignment_results = analyze_whale_sentiment_alignment(cleaned_trades, whale_ids)

    all_results = {
        'whale_definition': whale_def_results,
        'inequality': inequality_results,
        'price_impact': impact_results,
        'temporal': time_results,
        'alignment': alignment_results
    }

    _print_whale_analysis_summary(all_results)

    print("\n" + "="*80)
    print("Analysis Complete. All visualizations saved.")
    print("="*80)

    return all_results

# 4. Trader Classification


In [None]:

def classify_traders(trades_df, whale_quantile=0.95, activity_quantile=0.8, size_quantile=0.8):
    """
    Classify traders with simple quantile rules on their aggregate trading behaviour.

    Rules are applied in priority order; each trader gets the first label that matches:

    1. ``'Whale'``: total volume above the ``whale_quantile`` of all traders.
    2. ``'Active Trader'``: trade count above the ``activity_quantile``.
    3. ``'Large Trade Trader'``: average trade size above the ``size_quantile``.
    4. ``'Casual Trader'``: everyone else.

    Parameters
    ----------
    trades_df : pd.DataFrame
        Trade data with ``trader_id``, ``trade_amount`` and ``price`` columns.
    whale_quantile, activity_quantile, size_quantile : float, optional
        Quantile cut-offs (in [0, 1]) for the three rules above. Comparisons
        are strict (``>``).

    Returns
    -------
    pd.DataFrame
        One row per trader with columns ``trader_id``, ``total_volume``,
        ``avg_trade_size``, ``trade_count``, ``avg_price``,
        ``price_volatility`` (std of price; NaN for single-trade traders) and
        ``trader_type``.
    """
    trader_metrics = trades_df.groupby('trader_id').agg({
        'trade_amount': ['sum', 'mean', 'count'],
        'price': ['mean', 'std']
    }).reset_index()

    # Flatten the (column, aggregation) MultiIndex produced by .agg
    trader_metrics.columns = ['trader_id', 'total_volume', 'avg_trade_size', 'trade_count',
                               'avg_price', 'price_volatility']

    # Compute the thresholds once over all traders instead of once per row.
    volume_cut = trader_metrics['total_volume'].quantile(whale_quantile)
    count_cut = trader_metrics['trade_count'].quantile(activity_quantile)
    size_cut = trader_metrics['avg_trade_size'].quantile(size_quantile)

    # np.select takes the first true condition, matching an if/elif priority chain.
    trader_metrics['trader_type'] = np.select(
        [
            (trader_metrics['total_volume'] > volume_cut).to_numpy(),
            (trader_metrics['trade_count'] > count_cut).to_numpy(),
            (trader_metrics['avg_trade_size'] > size_cut).to_numpy(),
        ],
        ['Whale', 'Active Trader', 'Large Trade Trader'],
        default='Casual Trader',
    )

    print("\nTrader Classification:")
    print(trader_metrics['trader_type'].value_counts(normalize=True))

    return trader_metrics

# Run the rule-based classifier defined in this cell; `trade_data` is assumed to
# be loaded by an earlier cell. A later cell redefines `classify_traders` with a
# clustering-based version, so this call depends on running cells in order.
trader_classification = classify_traders(trades_df=trade_data)



# 5. Trader Classification Visualization


In [None]:

import matplotlib.pyplot as plt
import seaborn as sns

# Four-panel dashboard of the rule-based `trader_classification` table built above.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
volume_ax, type_ax, box_ax, frequency_ax = axes.ravel()

# Trader volume distribution
sns.histplot(trader_classification['total_volume'], bins=50, kde=True, ax=volume_ax)
volume_ax.set(title='Trader Volume Distribution', xlabel='Total Trading Volume', ylabel='Frequency')

# Trader type distribution (blank y-label hides the column name pandas adds to pies)
trader_classification['trader_type'].value_counts().plot(kind='pie', autopct='%1.1f%%', ax=type_ax)
type_ax.set(title='Trader Type Distribution', ylabel='')

# Total trading volume by trader type
sns.boxplot(x='trader_type', y='total_volume', data=trader_classification, ax=box_ax)
box_ax.set(title='Volume by Trader Type', xlabel='Trader Type', ylabel='Total Trading Volume')
box_ax.tick_params(axis='x', labelrotation=45)

# Trade frequency distribution
sns.histplot(trader_classification['trade_count'], bins=50, kde=True, ax=frequency_ax)
frequency_ax.set(title='Trade Frequency Distribution', xlabel='Number of Trades', ylabel='Frequency')

fig.tight_layout()
fig.savefig('trader_analysis_plots.png')
plt.close(fig)

print("\nAnalysis complete. Visualization saved as trader_analysis_plots.png")

In [None]:
def _trade_rows_by_trader(trades_df):
    """
    Map each trader id to the row positions of its trades, or return None if no id columns exist.

    With ``maker_id``/``taker_id`` columns, a trader's rows are the union of
    the rows where it was maker or taker. A self-matched row is counted once,
    while genuinely repeated trades are kept. Ids come out in first-appearance
    order, which keeps the clustering input (and therefore KMeans output)
    reproducible across runs.
    """
    columns = trades_df.columns
    if 'maker_id' in columns and 'taker_id' in columns:
        maker_rows = trades_df.groupby('maker_id', sort=False).indices
        taker_rows = trades_df.groupby('taker_id', sort=False).indices
        trader_ids = pd.concat([trades_df['maker_id'], trades_df['taker_id']]).dropna().unique()
        print(f"Analyzing {len(trader_ids)} unique traders from maker/taker columns")
        no_rows = np.array([], dtype=np.intp)
        return {
            trader_id: np.union1d(maker_rows.get(trader_id, no_rows), taker_rows.get(trader_id, no_rows))
            for trader_id in trader_ids
        }
    if 'trader_id' in columns:
        trader_rows = trades_df.groupby('trader_id', sort=False).indices
        print(f"Analyzing {len(trader_rows)} unique traders from trader_id column")
        return dict(trader_rows)
    return None


def _compute_trader_features(trader_id, trader_trades):
    """Compute the behavioural feature dict for one trader's trades (one clustering row)."""
    trade_count = len(trader_trades)

    # Trade size metrics: prefer `trade_amount`, fall back to `size`
    size_col = next((c for c in ('trade_amount', 'size') if c in trader_trades.columns), None)
    if size_col is not None:
        sizes = trader_trades[size_col]
        avg_trade_size = sizes.mean()
        total_volume = sizes.sum()
        trade_size_volatility = sizes.std() / avg_trade_size if avg_trade_size > 0 else 0
    else:
        avg_trade_size = total_volume = trade_size_volatility = np.nan

    # Market diversity: distinct markets, and the largest share of trades in one market
    if 'market_id' in trader_trades.columns:
        market_count = trader_trades['market_id'].nunique()
        market_concentration = (trader_trades.groupby('market_id').size() / trade_count).max()
    else:
        market_count = 1
        market_concentration = 1.0

    # Timing: mean gap in minutes, and the coefficient of variation of the gaps.
    # Both std and mean are in seconds, so the ratio is unitless.
    avg_time_between_trades = np.nan
    trade_timing_regularity = np.nan
    if 'timestamp' in trader_trades.columns and trade_count > 1:
        timestamps = pd.to_datetime(trader_trades['timestamp'], errors='coerce').dropna().sort_values()
        gap_seconds = timestamps.diff().dropna().dt.total_seconds()
        if len(gap_seconds) > 0:
            mean_gap_seconds = gap_seconds.mean()
            avg_time_between_trades = mean_gap_seconds / 60
            if mean_gap_seconds > 0:
                trade_timing_regularity = gap_seconds.std() / mean_gap_seconds

    # Direction bias: share of buys among buy/sell trades (0.5 when unknown).
    # NOTE(review): for maker rows, `side` may describe the taker's direction;
    # confirm against the trade schema.
    buy_ratio = 0.5
    direction_col = next((c for c in ('side', 'trade_direction') if c in trader_trades.columns), None)
    if direction_col is not None:
        buy_count = (trader_trades[direction_col] == 'buy').sum()
        sell_count = (trader_trades[direction_col] == 'sell').sum()
        if buy_count + sell_count > 0:
            buy_ratio = buy_count / (buy_count + sell_count)

    return {
        'trader_id': trader_id,
        'trade_count': trade_count,
        'avg_trade_size': avg_trade_size,
        'total_volume': total_volume,
        'trade_size_volatility': trade_size_volatility,
        'market_count': market_count,
        'market_concentration': market_concentration,
        'avg_time_between_trades': avg_time_between_trades,
        'trade_timing_regularity': trade_timing_regularity,
        'buy_ratio': buy_ratio
    }


def _name_cluster(z_scores):
    """Name a cluster after its most extreme feature z-score; |z| <= 1 means 'Balanced Traders'."""
    feature, z = max(z_scores.items(), key=lambda item: abs(item[1]))
    if feature == 'buy_ratio':
        if z > 1:
            return "Bullish Traders"
        if z < -1:
            return "Bearish Traders"
        return "Balanced Traders"
    high_value_names = {
        'trade_count': "High Frequency Traders",
        'avg_trade_size': "Whale Traders",
        'market_concentration': "Market Specialists",
        'trade_size_volatility': "Opportunistic Traders",
    }
    if z > 1 and feature in high_value_names:
        return high_value_names[feature]
    return "Balanced Traders"


def _plot_trader_clusters(X_scaled, trader_df, cluster_profiles, available_features):
    """Save the PCA scatter, radar profile and size bar charts of the clusters into `results_dir`."""
    cluster_names = cluster_profiles['type']  # Series: cluster id -> display name

    # 1. PCA projection (needs at least two samples and two features)
    if min(X_scaled.shape) >= 2:
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)
        viz_df = pd.DataFrame({
            'PC1': X_pca[:, 0],
            'PC2': X_pca[:, 1],
            'Cluster': trader_df['cluster'].to_numpy(),
            'Type': trader_df['cluster'].map(cluster_names).to_numpy()
        })
        explained_variance = pca.explained_variance_ratio_

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.scatterplot(data=viz_df, x='PC1', y='PC2', hue='Type', palette='viridis', s=50, alpha=0.7, ax=ax)
        ax.set_title('Trader Types - PCA Visualization')
        ax.set_xlabel(f'PC1 ({explained_variance[0]:.1%} variance)')
        ax.set_ylabel(f'PC2 ({explained_variance[1]:.1%} variance)')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        fig.tight_layout()
        fig.savefig(os.path.join(results_dir, 'trader_clusters_pca.png'), dpi=300, bbox_inches='tight')
        plt.close(fig)

    # 2. Radar chart of cluster profiles, each feature scaled by its max across clusters
    radar_df = cluster_profiles[available_features].copy()
    for feature in available_features:
        feature_max = radar_df[feature].max()
        if feature_max > 0:
            radar_df[feature] = radar_df[feature] / feature_max

    angles = np.linspace(0, 2 * np.pi, len(available_features), endpoint=False).tolist()
    angles += angles[:1]  # close the polygon

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, polar=True)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(available_features, size=12)
    for cluster_id, name in cluster_names.items():
        values = radar_df.loc[cluster_id].tolist()
        values += values[:1]
        ax.plot(angles, values, linewidth=2, label=name)
        ax.fill(angles, values, alpha=0.1)
    ax.set_title('Trader Type Profiles', size=15)
    ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
    fig.savefig(os.path.join(results_dir, 'trader_type_radar.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)

    # 3. Bar chart of cluster sizes, largest first
    sorted_profiles = cluster_profiles.sort_values('size', ascending=False)
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(
        range(len(sorted_profiles)),
        sorted_profiles['size'],
        tick_label=[f"{name}\n({pct:.1f}%)"
                    for name, pct in zip(sorted_profiles['type'], sorted_profiles['percentage'])]
    )
    ax.set_title('Trader Type Distribution')
    ax.set_ylabel('Number of Traders')
    fig.tight_layout()
    fig.savefig(os.path.join(results_dir, 'trader_type_distribution.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)


def classify_traders(trades_df, n_clusters=5, min_trades=3):
    """
    Cluster traders into behavioural types with K-means.

    This redefinition replaces the rule-based ``classify_traders`` from the
    earlier cell for every cell run after this one.

    Parameters
    ----------
    trades_df : pd.DataFrame
        Trade data identified either by ``maker_id``/``taker_id`` or by
        ``trader_id``. Optional columns: ``trade_amount`` or ``size``,
        ``market_id``, ``timestamp``, and ``side`` or ``trade_direction``.
    n_clusters : int
        Requested number of clusters. It is reduced to the number of eligible
        traders when there are fewer, because KMeans cannot fit more
        clusters than samples.
    min_trades : int
        Traders with fewer trades than this are excluded.

    Returns
    -------
    dict or None
        ``trader_features``: per-trader features plus ``cluster`` label.
        ``cluster_profiles``: feature means per cluster plus ``size``,
        ``percentage`` and ``type``.
        ``cluster_names``: {cluster id: name}. Duplicate names get a
        ``#<id>`` suffix so clusters stay distinguishable.
        ``feature_importance``: Series giving the share of each standardised
        feature's variance explained by cluster membership, sorted descending.
        Returns None when there are no id columns, no eligible traders, or
        fewer than two usable features. Charts are saved to ``results_dir``.
    """
    rows_by_trader = _trade_rows_by_trader(trades_df)
    if rows_by_trader is None:
        print("Error: No trader identifier columns found")
        return None

    print("Calculating trader features...")
    trader_features = [
        _compute_trader_features(trader_id, trades_df.iloc[rows])
        for trader_id, rows in rows_by_trader.items()
        if len(rows) >= min_trades
    ]
    if not trader_features:
        print("No trader features calculated")
        return None

    trader_df = pd.DataFrame(trader_features)
    print(f"Calculated features for {len(trader_df)} traders")

    # Keep candidate features that exist and are not entirely NaN
    features_for_clustering = [
        'trade_count', 'avg_trade_size', 'market_concentration',
        'buy_ratio', 'trade_size_volatility'
    ]
    available_features = [
        f for f in features_for_clustering
        if f in trader_df.columns and not trader_df[f].isna().all()
    ]
    print(f"Using {len(available_features)} features for clustering: {available_features}")
    if len(available_features) < 2:
        print("Insufficient features for clustering")
        return None

    # Mean-impute missing values, then standardise
    X = trader_df[available_features].copy()
    X = X.fillna(X.mean())
    X_scaled = StandardScaler().fit_transform(X)

    fit_clusters = min(n_clusters, len(trader_df))
    if fit_clusters < n_clusters:
        print(f"Only {len(trader_df)} traders; reducing n_clusters from {n_clusters} to {fit_clusters}")

    print("Performing clustering...")
    kmeans = KMeans(n_clusters=fit_clusters, random_state=42, n_init=10)
    trader_df['cluster'] = kmeans.fit_predict(X_scaled)

    # Profiles use raw (unimputed) feature means, one row per label actually produced
    cluster_profiles = trader_df.groupby('cluster')[available_features].mean()
    cluster_sizes = trader_df['cluster'].value_counts().reindex(cluster_profiles.index)
    cluster_profiles['size'] = cluster_sizes
    cluster_profiles['percentage'] = 100 * cluster_sizes / cluster_sizes.sum()

    # Name clusters by how far each profile deviates from the other profiles
    feature_means = cluster_profiles[available_features].mean()
    feature_stds = cluster_profiles[available_features].std()
    base_names = {}
    for cluster_id, profile in cluster_profiles.iterrows():
        z_scores = {}
        for feature in available_features:
            spread = feature_stds[feature]
            z = (profile[feature] - feature_means[feature]) / spread if spread > 0 else 0.0
            z_scores[feature] = 0.0 if pd.isna(z) else z
        base_names[int(cluster_id)] = _name_cluster(z_scores)

    # Disambiguate repeated names so distinct clusters are not merged in the charts
    name_counts = pd.Series(list(base_names.values())).value_counts()
    cluster_names = {
        cluster_id: f"{name} #{cluster_id}" if name_counts[name] > 1 else name
        for cluster_id, name in base_names.items()
    }
    cluster_profiles['type'] = [cluster_names[int(i)] for i in cluster_profiles.index]

    # Feature importance: between-cluster variance / total variance per
    # standardised feature (an ANOVA-style R^2 in [0, 1]; constant features get 0).
    scaled = pd.DataFrame(X_scaled, columns=available_features, index=trader_df.index)
    centroid_means = scaled.groupby(trader_df['cluster']).mean()
    cluster_weights = trader_df['cluster'].value_counts(normalize=True).reindex(centroid_means.index)
    between_var = (centroid_means - scaled.mean()).pow(2).mul(cluster_weights, axis=0).sum()
    total_var = scaled.var(ddof=0)
    feature_importance = (
        (between_var / total_var.where(total_var > 0))
        .fillna(0.0)
        .sort_values(ascending=False)
    )

    _plot_trader_clusters(X_scaled, trader_df, cluster_profiles, available_features)

    return {
        'trader_features': trader_df,
        'cluster_profiles': cluster_profiles,
        'cluster_names': cluster_names,
        'feature_importance': feature_importance
    }