# Set Up

In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from statsmodels.tsa.stattools import grangercausalitytests
import warnings
import json
from datetime import datetime

# Add src directory to path for importing utility functions
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir) if current_dir.endswith('notebooks') else current_dir
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Import utility functions
from src.utils.data_loader import load_main_dataset, load_trade_data

# Create results directory
results_dir = 'results/trader_analysis'
os.makedirs(results_dir, exist_ok=True)

# Suppress warnings
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")

# Cell 2: Helper Functions (Remove function definitions, just define the functions)
def calculate_gini(values):
    if len(values) <= 1 or np.sum(values) == 0:
        return 0
    
    sorted_values = np.sort(values)
    n = len(sorted_values)
    cumsum = np.cumsum(sorted_values)
    return (n + 1 - 2 * np.sum((n + 1 - np.arange(1, n+1)) * sorted_values) / np.sum(sorted_values)) / n

In [None]:
# Cell 3: Load Market Data
print("Loading main dataset...")
market_data = load_main_dataset('data/cleaned_election_data.csv')

if market_data is not None:
    print(f"Successfully loaded {len(market_data)} markets")
    
    # Explore trader-related metrics
    trader_cols = [col for col in market_data.columns if any(term in col.lower() 
                   for term in ['trader', 'trade', 'concentration'])]
    
    print("\nAvailable trader-related columns:")
    for col in trader_cols:
        print(f"- {col}")
        
    # Display basic statistics for key trader metrics
    trader_metrics = ['unique_traders_count', 'trader_to_trade_ratio', 
                     'two_way_traders_ratio', 'new_trader_influx']
    available_metrics = [col for col in trader_metrics if col in market_data.columns]
    
    if available_metrics:
        print("\nSummary statistics for trader metrics:")
        print(market_data[available_metrics].describe())

# Load Trade Data

In [65]:
def load_trade_data_for_analysis(market_ids=None, max_trades_per_market=None):
    """
    Load trade data for specific market IDs
    
    Parameters:
    -----------
    market_ids : list, optional
        List of specific market IDs to load
    max_trades_per_market : int, optional
        Maximum number of trades per market (None for all trades)
    
    Returns:
    --------
    pd.DataFrame
        Combined trade data from specified markets
    """
    print("Loading market data...")
    market_data = load_main_dataset('data/cleaned_election_data.csv')
    
    if market_data is None:
        print("Failed to load market data")
        return None
    
    # If market_ids not provided, use all markets
    if market_ids is None:
        market_ids = market_data['id'].tolist()
    
    print(f"Selected {len(market_ids)} markets for analysis")
    
    from src.utils.data_loader import load_trade_data, get_token_ids_for_market
    
    all_trades = []
    for i, market_id in enumerate(market_ids):
        try:
            # Get market name if available
            market_name = market_data.loc[market_data['id'] == market_id, 'question'].iloc[0] \
                         if 'question' in market_data.columns else f"Market {market_id}"
            
            print(f"\nLoading trades for market {i+1}/{len(market_ids)}: {market_name}")
            print(f"Market ID: {market_id}")
            
            # Try to load trade data using utility function
            trades = load_trade_data(market_id)
            
            if trades is not None and len(trades) > 0:
                print(f"Successfully loaded {len(trades)} trades")
                
                # Add market identifier
                trades['market_id'] = market_id
                
                # Sample if max_trades_per_market is specified
                if max_trades_per_market is not None and len(trades) > max_trades_per_market:
                    print(f"Sampling {max_trades_per_market} trades from {len(trades)} total")
                    trades = trades.sample(max_trades_per_market, random_state=42)
                
                all_trades.append(trades)
            else:
                print("No trades found using default method, trying alternative approaches")
                
                # Try to get token IDs for this market
                token_ids = get_token_ids_for_market(market_id, main_df=market_data)
                
                if token_ids and len(token_ids) > 0:
                    print(f"Found {len(token_ids)} token IDs for market {market_id}")
                    
                    # Try to locate token files directly
                    from src.utils.data_loader import find_token_id_file
                    
                    market_trades = []
                    for token_id in token_ids:
                        try:
                            token_file = find_token_id_file(token_id)
                            if token_file:
                                print(f"Found token file: {os.path.basename(token_file)}")
                                
                                # Load this token's trades
                                import pyarrow.parquet as pq
                                token_trades = pq.read_table(token_file).to_pandas()
                                token_trades['market_id'] = market_id
                                token_trades['token_id'] = token_id
                                
                                market_trades.append(token_trades)
                                print(f"Loaded {len(token_trades)} trades for token {token_id}")
                        except Exception as e:
                            print(f"Error loading token {token_id}: {e}")
                    
                    if market_trades:
                        combined_market_trades = pd.concat(market_trades, ignore_index=True)
                        
                        # Sample if specified
                        if max_trades_per_market is not None and len(combined_market_trades) > max_trades_per_market:
                            print(f"Sampling {max_trades_per_market} trades from {len(combined_market_trades)} total")
                            combined_market_trades = combined_market_trades.sample(max_trades_per_market, random_state=42)
                        
                        all_trades.append(combined_market_trades)
                    else:
                        print(f"No trade data found for any tokens in market {market_id}")
                else:
                    print(f"No token IDs found for market {market_id}")
        except Exception as e:
            print(f"Error loading trades for market {market_id}: {e}")
    
    if not all_trades:
        print("No trade data loaded for any markets")
        return None
    
    # Combine all trade data
    combined_trades = pd.concat(all_trades, ignore_index=True)
    print(f"\nTotal trades loaded: {len(combined_trades)} from {len(all_trades)} markets")
    
    # Standardize and prepare trade data (same as before)
    if 'timestamp' in combined_trades.columns and not pd.api.types.is_datetime64_any_dtype(combined_trades['timestamp']):
        try:
            combined_trades['timestamp'] = pd.to_numeric(combined_trades['timestamp'], errors='coerce')
            combined_trades['timestamp'] = pd.to_datetime(combined_trades['timestamp'], unit='s', errors='coerce')
            combined_trades = combined_trades.dropna(subset=['timestamp'])
            print(f"Converted {len(combined_trades)} valid timestamps")
        except Exception as e:
            print(f"Error converting timestamps: {e}")
    
    # Standardize trader ID and other columns (same as before)
    if 'maker' in combined_trades.columns and 'maker_id' not in combined_trades.columns:
        combined_trades['maker_id'] = combined_trades['maker']
    if 'taker' in combined_trades.columns and 'taker_id' not in combined_trades.columns:
        combined_trades['taker_id'] = combined_trades['taker']
    
    # Create trader_id column
    combined_trades['trader_id'] = combined_trades['maker_id']
    
    # Create trade_amount column if not present
    if 'trade_amount' not in combined_trades.columns:
        if 'size' in combined_trades.columns:
            combined_trades['trade_amount'] = combined_trades['size']
        else:
            combined_trades['trade_amount'] = 1.0
    
    # Additional logging
    unique_makers = combined_trades['maker_id'].nunique() if 'maker_id' in combined_trades.columns else 0
    unique_takers = combined_trades['taker_id'].nunique() if 'taker_id' in combined_trades.columns else 0
    unique_traders = combined_trades['trader_id'].nunique() if 'trader_id' in combined_trades.columns else 0
    
    print(f"Unique traders identified: {unique_traders} (makers: {unique_makers}, takers: {unique_takers})")
    
    return combined_trades

In [66]:
target_markets = [
    "Will Donald Trump win the 2024 US Presidential Election?", 
    "Will Kamala Harris win the 2024 US Presidential Election?"
]

# Load main dataset
market_data = load_main_dataset('data/cleaned_election_data.csv')

# Filter to specific markets
selected_markets = market_data[market_data['question'].isin(target_markets)]
market_ids = selected_markets['id'].tolist()

# Load ALL trades for these markets
trade_data = load_trade_data_for_analysis(
    market_ids=market_ids, 
    max_trades_per_market=None  # Load all trades
)

print(f"\nTotal trades loaded: {len(trade_data)}")
print(f"Unique markets: {trade_data['market_id'].nunique()}")
print(f"Unique traders: {trade_data['trader_id'].nunique()}")

Loaded dataset with 1048575 rows and 54 columns
Loading market data...
Loaded dataset with 1048575 rows and 54 columns
Selected 2 markets for analysis

Loading trades for market 1/2: Will Donald Trump win the 2024 US Presidential Election?
Market ID: 253591.0
Loaded dataset with 1048575 rows and 54 columns
Successfully loaded 1185000 trades

Loading trades for market 2/2: Will Kamala Harris win the 2024 US Presidential Election?
Market ID: 253597.0
Loaded dataset with 1048575 rows and 54 columns
No trade data found for market 253597.0
No trades found using default method, trying alternative approaches
Found 2 token IDs for market 253597.0
Found token file: 69236923620077691027083946871148646972011131466059644796654161903044970987404.parquet 18-30-01-221.parquet
Loaded 802000 trades for token 69236923620077691027083946871148646972011131466059644796654161903044970987404
Found token file: 87584955359245246404952128082451897287778571240979823316620093987046202296181.parquet 18-30-01-396.pa

In [68]:
# Debug print statements to understand market loading
print("\nDetailed Market Information:")
for index, row in selected_markets.iterrows():
    print(f"Market ID: {row['id']}")
    print(f"Question: {row['question']}")
    print(f"Tokens: {get_token_ids_for_market(row['id'], main_df=market_data)}")
    print("-" * 50)

# Modify the loading to explicitly handle both markets
market_ids = selected_markets['id'].tolist()
print("\nMarket IDs:", market_ids)

trade_data = load_trade_data_for_analysis(
    market_ids=market_ids, 
    max_trades_per_market=None  # Load all trades
)

print(f"\nTotal trades loaded: {len(trade_data)}")
print(f"Unique markets: {trade_data['market_id'].nunique()}")
print(f"Unique traders: {trade_data['trader_id'].nunique()}")
print("\nMarket breakdown:")
print(trade_data['market_id'].value_counts())


Detailed Market Information:
Market ID: 253591.0
Question: Will Donald Trump win the 2024 US Presidential Election?


NameError: name 'get_token_ids_for_market' is not defined

In [None]:
# Trader Classification and Analysis

# 1. Basic Market Overview
print("Market Trade Statistics:")
print(f"Total Trades: {len(trade_data)}")
print(f"Unique Markets: {trade_data['market_id'].nunique()}")
print(f"Unique Traders: {trade_data['trader_id'].nunique()}")


In [None]:

# 2. Identify Potential Whale Traders
def identify_whales(trades_df, threshold=0.05):
    """
    Identify whale traders based on trading volume
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    threshold : float
        Percentage of top traders to consider as whales
    
    Returns:
    --------
    list
        List of whale trader IDs
    """
    # Group trades by trader and calculate total volume
    trader_volumes = trades_df.groupby('trader_id')['trade_amount'].sum().sort_values(ascending=False)
    
    # Calculate number of whales
    whale_count = max(1, int(len(trader_volumes) * threshold))
    
    # Identify whale traders
    whale_ids = trader_volumes.head(whale_count).index.tolist()
    
    print("\nWhale Trader Analysis:")
    print(f"Total Traders: {len(trader_volumes)}")
    print(f"Number of Whales: {whale_count}")
    print(f"Total Volume of Whales: {trader_volumes.head(whale_count).sum()}")
    print(f"Whale Volume Percentage: {trader_volumes.head(whale_count).sum() / trader_volumes.sum() * 100:.2f}%")
    
    return whale_ids

# Identify whales
whale_ids = identify_whales(trade_data)


In [None]:

# 3. Whale Trade Impact Analysis
def analyze_whale_impact(trades_df, whale_ids):
    """
    Analyze the impact of whale trades on market prices
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    whale_ids : list
        List of whale trader IDs
    
    Returns:
    --------
    dict
        Analysis of whale trade impacts
    """
    # Separate whale and non-whale trades
    whale_trades = trades_df[trades_df['trader_id'].isin(whale_ids)]
    non_whale_trades = trades_df[~trades_df['trader_id'].isin(whale_ids)]
    
    # Analyze price changes around whale trades
    def calculate_price_impact(trades):
        # Sort trades by timestamp
        trades_sorted = trades.sort_values('timestamp')
        
        # Calculate price changes
        trades_sorted['price_change'] = trades_sorted['price'].diff()
        
        return {
            'avg_price_change': trades_sorted['price_change'].mean(),
            'median_price_change': trades_sorted['price_change'].median(),
            'volatility': trades_sorted['price_change'].std()
        }
    
    whale_impact = calculate_price_impact(whale_trades)
    non_whale_impact = calculate_price_impact(non_whale_trades)
    
    print("\nWhale Trade Impact:")
    print("Whale Trades Price Changes:")
    for k, v in whale_impact.items():
        print(f"{k}: {v}")
    
    print("\nNon-Whale Trades Price Changes:")
    for k, v in non_whale_impact.items():
        print(f"{k}: {v}")
    
    return {
        'whale_impact': whale_impact,
        'non_whale_impact': non_whale_impact
    }

# Analyze whale trade impact
whale_impact_results = analyze_whale_impact(trade_data, whale_ids)


In [None]:

# 4. Trader Classification
def classify_traders(trades_df):
    """
    Classify traders based on their trading behavior
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    
    Returns:
    --------
    pd.DataFrame
        Trader classification results
    """
    # Calculate trader metrics
    trader_metrics = trades_df.groupby('trader_id').agg({
        'trade_amount': ['sum', 'mean', 'count'],
        'price': ['mean', 'std']
    }).reset_index()
    
    # Flatten column names
    trader_metrics.columns = ['trader_id', 'total_volume', 'avg_trade_size', 'trade_count', 
                               'avg_price', 'price_volatility']
    
    # Add classification logic
    def classify_trader(row):
        if row['total_volume'] > trader_metrics['total_volume'].quantile(0.95):
            return 'Whale'
        elif row['trade_count'] > trader_metrics['trade_count'].quantile(0.8):
            return 'Active Trader'
        elif row['avg_trade_size'] > trader_metrics['avg_trade_size'].quantile(0.8):
            return 'Large Trade Trader'
        else:
            return 'Casual Trader'
    
    trader_metrics['trader_type'] = trader_metrics.apply(classify_trader, axis=1)
    
    # Print classification summary
    print("\nTrader Classification:")
    print(trader_metrics['trader_type'].value_counts(normalize=True))
    
    return trader_metrics

# Classify traders
trader_classification = classify_traders(trade_data)


In [None]:

# 5. Market Dynamics Visualization
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(15, 10))

# Trader Volume Distribution
plt.subplot(2, 2, 1)
sns.histplot(trader_classification['total_volume'], bins=50, kde=True)
plt.title('Trader Volume Distribution')
plt.xlabel('Total Trading Volume')
plt.ylabel('Frequency')

# Trader Type Distribution
plt.subplot(2, 2, 2)
trader_classification['trader_type'].value_counts().plot(kind='pie', autopct='%1.1f%%')
plt.title('Trader Type Distribution')

# Price Changes by Trader Type
plt.subplot(2, 2, 3)
sns.boxplot(x='trader_type', y='total_volume', data=trader_classification)
plt.title('Volume by Trader Type')
plt.xticks(rotation=45)

# Trade Frequency Distribution
plt.subplot(2, 2, 4)
sns.histplot(trader_classification['trade_count'], bins=50, kde=True)
plt.title('Trade Frequency Distribution')
plt.xlabel('Number of Trades')
plt.ylabel('Frequency')

plt.tight_layout()
plt.savefig('trader_analysis_plots.png')
plt.close()

print("\nAnalysis complete. Visualization saved as trader_analysis_plots.png")

# Trader Classification

In [60]:
def classify_traders(trades_df, n_clusters=5):
    """
    Classify traders into different types based on their behavior
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    n_clusters : int
        Number of clusters to create
        
    Returns:
    --------
    dict
        Dictionary with classification results
    """
    # Create a combined trader ID using both maker and taker
    if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
        # Get all unique trader IDs from both maker and taker columns
        all_traders = set(trades_df['maker_id'].dropna().unique()) | set(trades_df['taker_id'].dropna().unique())
        all_trader_ids = list(all_traders)
        
        print(f"Analyzing {len(all_trader_ids)} unique traders from maker/taker columns")
    elif 'trader_id' in trades_df.columns:
        all_trader_ids = trades_df['trader_id'].unique()
        print(f"Analyzing {len(all_trader_ids)} unique traders from trader_id column")
    else:
        print("Error: No trader identifier columns found")
        return None
    
    print("Calculating trader features...")
    
    # Group by trader_id and calculate features
    trader_features = []
    
    # For each trader, calculate features based on their maker and taker activities
    for trader_id in all_trader_ids:
        # Get all trades where this trader was involved (as maker or taker)
        if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
            maker_trades = trades_df[trades_df['maker_id'] == trader_id]
            taker_trades = trades_df[trades_df['taker_id'] == trader_id]
            trader_trades = pd.concat([maker_trades, taker_trades]).drop_duplicates()
        else:
            trader_trades = trades_df[trades_df['trader_id'] == trader_id]
        
        # Skip traders with too few trades
        if len(trader_trades) < 3:
            continue
            
        # Basic activity metrics
        trade_count = len(trader_trades)
        
        # Trade size metrics
        if 'trade_amount' in trader_trades.columns:
            avg_trade_size = trader_trades['trade_amount'].mean()
            total_volume = trader_trades['trade_amount'].sum()
            trade_size_volatility = trader_trades['trade_amount'].std() / avg_trade_size if avg_trade_size > 0 else 0
        elif 'size' in trader_trades.columns:
            avg_trade_size = trader_trades['size'].mean()
            total_volume = trader_trades['size'].sum()
            trade_size_volatility = trader_trades['size'].std() / avg_trade_size if avg_trade_size > 0 else 0
        else:
            avg_trade_size = np.nan
            total_volume = np.nan
            trade_size_volatility = np.nan
        
        # Trader diversity (market participation)
        if 'market_id' in trader_trades.columns:
            market_count = trader_trades['market_id'].nunique()
            market_concentration = (trader_trades.groupby('market_id').size() / trade_count).max()
        else:
            market_count = 1
            market_concentration = 1.0
            
        # Trading frequency
        if 'timestamp' in trader_trades.columns:
            trader_trades = trader_trades.sort_values('timestamp')
            if len(trader_trades) > 1:
                # Convert timestamp to datetime if it's not already
                if not pd.api.types.is_datetime64_any_dtype(trader_trades['timestamp']):
                    trader_trades['timestamp'] = pd.to_datetime(trader_trades['timestamp'])
                
                # Calculate time between trades in minutes
                time_diffs = trader_trades['timestamp'].diff().dropna()
                if len(time_diffs) > 0:
                    try:
                        avg_time_between_trades = time_diffs.mean().total_seconds() / 60
                        trade_timing_regularity = time_diffs.std().total_seconds() / avg_time_between_trades if avg_time_between_trades > 0 else np.nan
                    except:
                        avg_time_between_trades = np.nan
                        trade_timing_regularity = np.nan
                else:
                    avg_time_between_trades = np.nan
                    trade_timing_regularity = np.nan
            else:
                avg_time_between_trades = np.nan
                trade_timing_regularity = np.nan
        else:
            avg_time_between_trades = np.nan
            trade_timing_regularity = np.nan
            
        # Trading direction bias
        if 'side' in trader_trades.columns:
            buy_count = (trader_trades['side'] == 'buy').sum()
            sell_count = (trader_trades['side'] == 'sell').sum()
            if buy_count + sell_count > 0:
                buy_ratio = buy_count / (buy_count + sell_count)
            else:
                buy_ratio = 0.5
        elif 'trade_direction' in trader_trades.columns:
            buy_count = (trader_trades['trade_direction'] == 'buy').sum()
            sell_count = (trader_trades['trade_direction'] == 'sell').sum()
            if buy_count + sell_count > 0:
                buy_ratio = buy_count / (buy_count + sell_count)
            else:
                buy_ratio = 0.5
        else:
            buy_ratio = 0.5
            
        # Store features
        trader_features.append({
            'trader_id': trader_id,
            'trade_count': trade_count,
            'avg_trade_size': avg_trade_size,
            'total_volume': total_volume,
            'trade_size_volatility': trade_size_volatility,
            'market_count': market_count,
            'market_concentration': market_concentration,
            'avg_time_between_trades': avg_time_between_trades,
            'trade_timing_regularity': trade_timing_regularity,
            'buy_ratio': buy_ratio
        })
    
    if not trader_features:
        print("No trader features calculated")
        return None
        
    # Create DataFrame
    trader_df = pd.DataFrame(trader_features)
    print(f"Calculated features for {len(trader_df)} traders")
    
    # Select features for clustering
    features_for_clustering = [
        'trade_count', 'avg_trade_size', 'market_concentration', 
        'buy_ratio', 'trade_size_volatility'
    ]
    
    # Filter to available features and remove any with all NaN values
    available_features = []
    for f in features_for_clustering:
        if f in trader_df.columns and not trader_df[f].isna().all():
            available_features.append(f)
    
    print(f"Using {len(available_features)} features for clustering: {available_features}")
    
    if len(available_features) < 2:
        print("Insufficient features for clustering")
        return None
        
    # Handle missing values
    X = trader_df[available_features].copy()
    X = X.fillna(X.mean())
    
    # Standardize features
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    
    print("Performing clustering...")
    
    # Apply K-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    trader_df['cluster'] = kmeans.fit_predict(X_scaled)
    
    # Calculate cluster profiles
    cluster_profiles = trader_df.groupby('cluster')[available_features].mean()
    cluster_sizes = trader_df['cluster'].value_counts().sort_index()
    cluster_profiles['size'] = cluster_sizes.values
    cluster_profiles['percentage'] = 100 * cluster_sizes / cluster_sizes.sum()
    
    # Interpret clusters
    cluster_names = {}
    for cluster_id in range(n_clusters):
        profile = cluster_profiles.loc[cluster_id]
        
        # Calculate z-scores for this cluster compared to others
        z_scores = {}
        for feature in available_features:
            feature_mean = cluster_profiles[feature].mean()
            feature_std = cluster_profiles[feature].std()
            if feature_std > 0:
                z_scores[feature] = (profile[feature] - feature_mean) / feature_std
            else:
                z_scores[feature] = 0
                
        # Determine cluster type based on most extreme z-scores
        top_feature = max(z_scores.items(), key=lambda x: abs(x[1]))
        
        if top_feature[0] == 'trade_count' and top_feature[1] > 1:
            name = "High Frequency Traders"
        elif top_feature[0] == 'avg_trade_size' and top_feature[1] > 1:
            name = "Whale Traders"
        elif top_feature[0] == 'market_concentration' and top_feature[1] > 1:
            name = "Market Specialists" 
        elif top_feature[0] == 'buy_ratio':
            if top_feature[1] > 1:
                name = "Bullish Traders"
            else:
                name = "Bearish Traders"
        elif top_feature[0] == 'trade_size_volatility' and top_feature[1] > 1:
            name = "Opportunistic Traders"
        else:
            name = "Balanced Traders"
            
        cluster_names[cluster_id] = name
    
    # Add names to profiles
    cluster_profiles['type'] = [cluster_names[i] for i in cluster_profiles.index]
    
    # Create visualizations
    
    # 1. PCA visualization of clusters
    if len(available_features) >= 2:
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)
        
        # Create visualization DataFrame
        viz_df = pd.DataFrame({
            'PC1': X_pca[:, 0],
            'PC2': X_pca[:, 1],
            'Cluster': trader_df['cluster'],
            'Type': trader_df['cluster'].map(cluster_names)
        })
        
        # Calculate explained variance
        explained_variance = pca.explained_variance_ratio_
        
        # Create PCA plot
        plt.figure(figsize=(10, 8))
        sns.scatterplot(data=viz_df, x='PC1', y='PC2', hue='Type', palette='viridis', s=50, alpha=0.7)
        plt.title('Trader Types - PCA Visualization')
        plt.xlabel(f'PC1 ({explained_variance[0]:.1%} variance)')
        plt.ylabel(f'PC2 ({explained_variance[1]:.1%} variance)')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, 'trader_clusters_pca.png'), dpi=300, bbox_inches='tight')
        plt.close()
        
        # 2. Radar chart of cluster profiles
        plt.figure(figsize=(12, 10))
        
        # Normalize profiles for radar chart
        radar_df = cluster_profiles[available_features].copy()
        for feature in available_features:
            feature_max = radar_df[feature].max()
            if feature_max > 0:
                radar_df[feature] = radar_df[feature] / feature_max
                
        # Number of features
        N = len(available_features)
        
        # Create angles for radar chart
        angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
        angles += angles[:1]  # Close the loop
        
        # Create subplot with polar projection
        ax = plt.subplot(111, polar=True)
        
        # Add feature labels
        plt.xticks(angles[:-1], available_features, size=12)
        
        # Plot each cluster
        for cluster_id, name in cluster_names.items():
            values = radar_df.loc[cluster_id].tolist()
            values += values[:1]  # Close the loop
            
            ax.plot(angles, values, linewidth=2, label=name)
            ax.fill(angles, values, alpha=0.1)
            
        plt.title('Trader Type Profiles', size=15)
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        plt.savefig(os.path.join(results_dir, 'trader_type_radar.png'), dpi=300, bbox_inches='tight')
        plt.close()
        
        # 3. Bar chart of cluster sizes
        plt.figure(figsize=(12, 6))
        
        # Sort by size
        sorted_profiles = cluster_profiles.sort_values('size', ascending=False)
        
        # Create bar chart
        plt.bar(
            range(len(sorted_profiles)), 
            sorted_profiles['size'],
            tick_label=[f"{cluster_names[i]}\n({sorted_profiles.loc[i, 'percentage']:.1f}%)" 
                       for i in sorted_profiles.index]
        )
        
        plt.title('Trader Type Distribution')
        plt.ylabel('Number of Traders')
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, 'trader_type_distribution.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    return {
        'trader_features': trader_df,
        'cluster_profiles': cluster_profiles,
        'cluster_names': cluster_names,
        'feature_importance': feature_importance
    }

# Whales

In [None]:
def identify_whales(trades_df, threshold=0.05, method='volume'):
    """
    Identify whale traders based on trading volume or frequency
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    threshold : float
        Threshold for defining whales (e.g., top 5%)
    method : str
        Method for identifying whales ('volume' or 'frequency')
        
    Returns:
    --------
    dict
        Dictionary with whale identification results
    """
    print(f"Identifying whale traders (top {threshold*100:.1f}%)...")
    
    # Create a trader metrics dataframe
    trader_metrics = []
    
    # Find all unique trader IDs
    if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
        all_traders = set(trades_df['maker_id'].dropna().unique()) | set(trades_df['taker_id'].dropna().unique())
    elif 'trader_id' in trades_df.columns:
        all_traders = set(trades_df['trader_id'].dropna().unique())
    else:
        print("Error: No trader identifier columns found")
        return None
    
    print(f"Calculating metrics for {len(all_traders)} traders...")
    
    # Calculate volume and frequency for each trader
    for trader_id in all_traders:
        # Get all trades where this trader was involved (as maker or taker)
        if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
            maker_trades = trades_df[trades_df['maker_id'] == trader_id]
            taker_trades = trades_df[trades_df['taker_id'] == trader_id]
            trader_trades = pd.concat([maker_trades, taker_trades]).drop_duplicates()
        else:
            trader_trades = trades_df[trades_df['trader_id'] == trader_id]
        
        # Calculate volume
        if 'trade_amount' in trader_trades.columns:
            total_volume = trader_trades['trade_amount'].sum()
        elif 'size' in trader_trades.columns:
            total_volume = trader_trades['size'].sum()
        else:
            total_volume = len(trader_trades)
        
        # Store metrics
        trader_metrics.append({
            'trader_id': trader_id,
            'volume': total_volume,
            'trade_count': len(trader_trades)
        })
    
    # Convert to DataFrame
    metrics_df = pd.DataFrame(trader_metrics)
    
    # Determine metric to use for whale identification
    if method == 'volume':
        metric_col = 'volume'
    else:  # frequency
        metric_col = 'trade_count'
    
    # Sort by metric in descending order
    metrics_df = metrics_df.sort_values(metric_col, ascending=False)
    
    # Define whales as top % of traders
    whale_count = max(1, int(len(metrics_df) * threshold))
    whales = metrics_df.head(whale_count).copy()
    whale_ids = whales['trader_id'].tolist()
    
    print(f"Identified {whale_count} whales out of {len(metrics_df)} traders")
    
    # Calculate whale concentration
    whale_concentration = whales[metric_col].sum() / metrics_df[metric_col].sum()
    print(f"Whales control {whale_concentration*100:.2f}% of total {method}")
    
    # Calculate average metric for whales vs non-whales
    whale_avg = whales[metric_col].mean()
    non_whale_avg = metrics_df.iloc[whale_count:][metric_col].mean() if len(metrics_df) > whale_count else 0
    whale_ratio = whale_avg / non_whale_avg if non_whale_avg > 0 else float('inf')
    
    print(f"Average whale {method} is {whale_ratio:.1f}x higher than non-whale {method}")
    
    # Visualize whale distribution
    plt.figure(figsize=(12, 6))
    
    # Create histogram of volume/frequency distribution with log scale
    plt.hist(metrics_df[metric_col], bins=50, log=True, alpha=0.7)
    
    # Add line for whale threshold
    min_whale_value = whales[metric_col].min()
    plt.axvline(min_whale_value, color='red', linestyle='--', 
                label=f'Whale threshold: {min_whale_value:.0f}')
    
    plt.title(f'Distribution of Trader {method.capitalize()} (Log Scale)')
    plt.xlabel(metric_col.capitalize())
    plt.ylabel('Count (Log Scale)')
    plt.legend()
    plt.savefig(os.path.join(results_dir, f'whale_distribution_{method}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # Create Lorenz curve
    plt.figure(figsize=(10, 6))
    
    # Sort values for Lorenz curve
    sorted_values = np.sort(metrics_df[metric_col].values)
    cumulative_pct = np.cumsum(sorted_values) / np.sum(sorted_values)
    
    # Plot Lorenz curve
    plt.plot([0] + list(range(1, len(cumulative_pct) + 1)), [0] + list(cumulative_pct), 
             label='Lorenz curve')
    
    # Plot line of equality
    plt.plot([0, len(cumulative_pct)], [0, 1], 'k--', label='Line of equality')
    
    # Highlight area representing Gini coefficient
    plt.fill_between(range(len(cumulative_pct) + 1), 
                    [0] + list(np.linspace(0, 1, len(cumulative_pct))),
                    [0] + list(cumulative_pct), 
                    alpha=0.2)
    
    # Calculate Gini coefficient
    gini = 1 - 2 * np.trapz(cumulative_pct) / len(cumulative_pct)
    
    plt.title(f'Lorenz Curve of Trader {method.capitalize()} Distribution (Gini: {gini:.4f})')
    plt.xlabel('Cumulative % of Traders')
    plt.ylabel(f'Cumulative % of {method.capitalize()}')
    plt.legend()
    plt.savefig(os.path.join(results_dir, f'lorenz_curve_{method}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    return {
        'whale_ids': whale_ids,
        'whales': whales,
        'whale_concentration': whale_concentration,
        'whale_to_non_whale_ratio': whale_ratio,
        'gini_coefficient': gini
    }


## Price Impact

In [None]:
def analyze_whale_price_impact(trades_df, whale_ids, window=5):
    """
    Analyze the price impact of whale trades
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    whale_ids : list
        List of whale trader IDs
    window : int
        Number of trades to consider for price impact
        
    Returns:
    --------
    dict
        Dictionary with price impact analysis results
    """
    if 'price' not in trades_df.columns:
        print("Error: price column not found")
        return None
    
    print("Analyzing price impact of whale trades...")
    
    # Identify whale trades
    if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
        whale_trades = trades_df[(trades_df['maker_id'].isin(whale_ids)) | 
                               (trades_df['taker_id'].isin(whale_ids))]
        non_whale_trades = trades_df[~((trades_df['maker_id'].isin(whale_ids)) | 
                                     (trades_df['taker_id'].isin(whale_ids)))]
    elif 'trader_id' in trades_df.columns:
        whale_trades = trades_df[trades_df['trader_id'].isin(whale_ids)]
        non_whale_trades = trades_df[~trades_df['trader_id'].isin(whale_ids)]
    else:
        print("Error: No trader identifier columns found")
        return None
    
    if len(whale_trades) == 0:
        print("No whale trades found")
        return None
        
    print(f"Found {len(whale_trades)} whale trades out of {len(trades_df)} total trades")
    
    # Ensure trades are sorted by timestamp
    if 'timestamp' in trades_df.columns:
        trades_df = trades_df.sort_values('timestamp')
    
    # Calculate price impact for each whale trade
    whale_impacts = []
    
    # Group by market_id to analyze within each market
    for market_id, market_trades in trades_df.groupby('market_id'):
        market_trades = market_trades.sort_values('timestamp').reset_index(drop=True)
        
        # Get whale trades in this market
        if 'maker_id' in market_trades.columns and 'taker_id' in market_trades.columns:
            market_whale_trades = market_trades[(market_trades['maker_id'].isin(whale_ids)) | 
                                             (market_trades['taker_id'].isin(whale_ids))]
        else:
            market_whale_trades = market_trades[market_trades['trader_id'].isin(whale_ids)]
        
        for _, whale_trade in market_whale_trades.iterrows():
            # Find the position of this trade in the market trades
            if 'timestamp' in whale_trade:
                # Find trades with the same timestamp
                try:
                    same_time_trades = market_trades[market_trades['timestamp'] == whale_trade['timestamp']]
                    if len(same_time_trades) == 0:
                        continue
                    
                    trade_idx = same_time_trades.index[0]  # Use the first one if multiple
                except:
                    # Skip if timestamp comparison fails
                    continue
            else:
                # Skip if no timestamp
                continue
                
            # Get price before trade (window trades before)
            before_indices = market_trades.index[market_trades.index < trade_idx]
            if len(before_indices) >= window:
                before_idx = before_indices[-window]  # window trades before
                price_before = market_trades.loc[before_idx, 'price']
            elif len(before_indices) > 0:
                before_idx = before_indices[0]  # first available trade
                price_before = market_trades.loc[before_idx, 'price']
            else:
                continue  # Skip if not enough prior trades
                
            # Get price after trade (window trades after)
            after_indices = market_trades.index[market_trades.index > trade_idx]
            if len(after_indices) >= window:
                after_idx = after_indices[window-1]  # window trades after
                price_after = market_trades.loc[after_idx, 'price']
            elif len(after_indices) > 0:
                after_idx = after_indices[-1]  # last available trade
                price_after = market_trades.loc[after_idx, 'price']
            else:
                continue  # Skip if not enough subsequent trades
                
            # Calculate price impact
            price_impact = price_after - price_before
            
            # Determine trader role (maker or taker)
            if 'maker_id' in market_trades.columns and whale_trade['maker_id'] in whale_ids:
                role = 'maker'
                trader_id = whale_trade['maker_id']
            elif 'taker_id' in market_trades.columns and whale_trade['taker_id'] in whale_ids:
                role = 'taker'
                trader_id = whale_trade['taker_id']
            else:
                role = 'unknown'
                trader_id = whale_trade.get('trader_id', 'unknown')
            
            # Calculate trade size
            if 'trade_amount' in whale_trade:
                trade_size = whale_trade['trade_amount']
            elif 'size' in whale_trade:
                trade_size = whale_trade['size']
            else:
                trade_size = 1.0
            
            # Store result
            whale_impacts.append({
                'market_id': market_id,
                'trader_id': trader_id,
                'role': role,
                'trade_size': trade_size,
                'timestamp': whale_trade['timestamp'] if 'timestamp' in whale_trade else None,
                'price_before': price_before,
                'trade_price': whale_trade['price'],
                'price_after': price_after,
                'price_impact': price_impact,
                'abs_price_impact': abs(price_impact),
                'relative_impact': price_impact / price_before if price_before > 0 else 0
            })
    
    if not whale_impacts:
        print("No whale price impacts could be calculated")
        return None
        
    # Convert to DataFrame
    impacts_df = pd.DataFrame(whale_impacts)
    
    # Calculate average impact
    avg_impact = impacts_df['price_impact'].mean()
    avg_abs_impact = impacts_df['abs_price_impact'].mean()
    avg_rel_impact = impacts_df['relative_impact'].mean() * 100  # as percentage
    
    print(f"Average price impact: {avg_impact:.6f}")
    print(f"Average absolute price impact: {avg_abs_impact:.6f}")
    print(f"Average relative impact: {avg_rel_impact:.2f}%")
    
    # Calculate positive and negative impacts
    positive_impacts = impacts_df[impacts_df['price_impact'] > 0]
    negative_impacts = impacts_df[impacts_df['price_impact'] < 0]
    
    positive_pct = len(positive_impacts) / len(impacts_df) * 100 if len(impacts_df) > 0 else 0
    print(f"Positive impacts: {positive_pct:.1f}% of trades")
    
    # Calculate following ratio
    following_ratio = calculate_following_ratio(trades_df, whale_ids)
    
    # Create histogram of price impacts
    plt.figure(figsize=(12, 6))
    sns.histplot(impacts_df['price_impact'], kde=True, bins=30)
    plt.axvline(0, color='red', linestyle='--', label='No impact')
    plt.axvline(avg_impact, color='green', linestyle='-', 
                label=f'Mean impact: {avg_impact:.4f}')
    plt.title('Distribution of Whale Trade Price Impacts')
    plt.xlabel('Price Impact')
    plt.ylabel('Frequency')
    plt.legend()
    plt.savefig(os.path.join(results_dir, 'whale_price_impacts.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    # Create scatter plot of trade size vs price impact
    plt.figure(figsize=(12, 6))
    plt.scatter(impacts_df['trade_size'], impacts_df['price_impact'], alpha=0.5)
    plt.axhline(0, color='red', linestyle='--')
    plt.title('Trade Size vs Price Impact')
    plt.xlabel('Trade Size')
    plt.ylabel('Price Impact')
    plt.savefig(os.path.join(results_dir, 'trade_size_vs_impact.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    return {
        'impacts': impacts_df,
        'avg_impact': avg_impact,
        'avg_abs_impact': avg_abs_impact,
        'avg_rel_impact': avg_rel_impact,
        'positive_pct': positive_pct,
        'following_ratio': following_ratio
    }


## Following Ratio

In [None]:

def calculate_following_ratio(trades_df, whale_ids, window_minutes=30):
    """
    Calculate how often non-whale traders follow whale trades in the same direction
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    whale_ids : list
        List of whale trader IDs
    window_minutes : int
        Time window in minutes to consider for following behavior
        
    Returns:
    --------
    float
        Ratio of non-whale trades following whale trades in same direction
    """
    # Check if we have trade direction information
    if 'side' in trades_df.columns:
        direction_col = 'side'
    elif 'trade_direction' in trades_df.columns:
        direction_col = 'trade_direction'
    else:
        print("Warning: No trade direction information available")
        return None
    
    # Ensure trades are sorted by timestamp
    if 'timestamp' not in trades_df.columns:
        print("Warning: No timestamp information for following ratio analysis")
        return None
    
    # Ensure timestamp is datetime
    if not pd.api.types.is_datetime64_any_dtype(trades_df['timestamp']):
        trades_df = trades_df.copy()
        trades_df['timestamp'] = pd.to_datetime(trades_df['timestamp'])
    
    trades_df = trades_df.sort_values('timestamp')
    
    # Identify whale and non-whale trades
    if 'maker_id' in trades_df.columns and 'taker_id' in trades_df.columns:
        whale_trades = trades_df[(trades_df['maker_id'].isin(whale_ids)) | 
                                (trades_df['taker_id'].isin(whale_ids))]
        non_whale_trades = trades_df[~((trades_df['maker_id'].isin(whale_ids)) | 
                                     (trades_df['taker_id'].isin(whale_ids)))]
    elif 'trader_id' in trades_df.columns:
        whale_trades = trades_df[trades_df['trader_id'].isin(whale_ids)]
        non_whale_trades = trades_df[~trades_df['trader_id'].isin(whale_ids)]
    else:
        print("Error: No trader identifier columns found")
        return None
    
    if len(whale_trades) == 0 or len(non_whale_trades) == 0:
        print("Not enough trades to calculate following ratio")
        return None
    
    following_count = 0
    total_count = 0
    
    # Group by market_id to analyze within each market
    for market_id, market_trades in trades_df.groupby('market_id'):
        market_trades = market_trades.sort_values('timestamp')
        
        # Get whale trades in this market
        if 'maker_id' in market_trades.columns and 'taker_id' in market_trades.columns:
            market_whale_trades = market_trades[(market_trades['maker_id'].isin(whale_ids)) | 
                                             (market_trades['taker_id'].isin(whale_ids))]
            market_non_whale_trades = market_trades[~((market_trades['maker_id'].isin(whale_ids)) | 
                                                   (market_trades['taker_id'].isin(whale_ids)))]
        else:
            market_whale_trades = market_trades[market_trades['trader_id'].isin(whale_ids)]
            market_non_whale_trades = market_trades[~market_trades['trader_id'].isin(whale_ids)]
        
        # Skip if not enough trades
        if len(market_whale_trades) == 0 or len(market_non_whale_trades) == 0:
            continue
        
        # Check each whale trade
        for _, whale_trade in market_whale_trades.iterrows():
            # Get whale trade direction
            if direction_col in whale_trade:
                whale_direction = whale_trade[direction_col]
            else:
                continue
            
            # Define time window for subsequent trades
            whale_time = whale_trade['timestamp']
            end_time = whale_time + pd.Timedelta(minutes=window_minutes)
            
            # Find non-whale trades within window
            subsequent_trades = market_non_whale_trades[
                (market_non_whale_trades['timestamp'] > whale_time) & 
                (market_non_whale_trades['timestamp'] <= end_time)
            ]
            
            if len(subsequent_trades) == 0:
                continue
            
            # Count trades in same direction
            same_direction_trades = subsequent_trades[subsequent_trades[direction_col] == whale_direction]
            
            following_count += len(same_direction_trades)
            total_count += len(subsequent_trades)
    
    # Calculate ratio
    following_ratio = following_count / total_count if total_count > 0 else 0
    print(f"Non-whale traders follow whale trade direction {following_ratio*100:.1f}% of the time")
    
    return following_ratio


## Granger Causality

In [None]:

def test_granger_causality(trades_df, whale_ids, max_lag=5):
    """
    Test if whale trades Granger-cause price movements
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    whale_ids : list
        List of whale trader IDs
    max_lag : int
        Maximum number of lags to test
        
    Returns:
    --------
    dict
        Dictionary with Granger causality results
    """
    from statsmodels.tsa.stattools import grangercausalitytests
    
    if 'price' not in trades_df.columns:
        print("Error: price column not found")
        return None
        
    if 'timestamp' not in trades_df.columns:
        print("Error: timestamp column required for time series analysis")
        return None
    
    print("Testing if whale trades Granger-cause price movements...")
    
    # Ensure timestamp is datetime type
    if not pd.api.types.is_datetime64_any_dtype(trades_df['timestamp']):
        trades_df = trades_df.copy()
        trades_df['timestamp'] = pd.to_datetime(trades_df['timestamp'])
    
    # Group by market_id to analyze each market separately
    market_results = {}
    
    for market_id, market_trades in trades_df.groupby('market_id'):
        print(f"Testing market {market_id}...")
        
        # Sort by timestamp
        market_trades = market_trades.sort_values('timestamp')
        
        # Create indicator for whale activity (1 if trade involves a whale, 0 otherwise)
        if 'maker_id' in market_trades.columns and 'taker_id' in market_trades.columns:
            market_trades['is_whale'] = ((market_trades['maker_id'].isin(whale_ids)) | 
                                       (market_trades['taker_id'].isin(whale_ids))).astype(int)
        else:
            market_trades['is_whale'] = market_trades['trader_id'].isin(whale_ids).astype(int)
        
        # Resample to regular intervals
        try:
            # Set timestamp as index
            market_trades = market_trades.set_index('timestamp')
            
            # Resample to 5-minute intervals
            prices = market_trades['price'].resample('5T').last().ffill()
            whale_activity = market_trades['is_whale'].resample('5T').sum().fillna(0)
            
            # Align the series and ensure sufficient data points
            aligned_data = pd.concat([prices, whale_activity], axis=1).dropna()
            aligned_data.columns = ['price', 'whale_activity']
            
            if len(aligned_data) <= max_lag + 2:
                print(f"  Insufficient data points for market {market_id} after resampling")
                continue
                
            # Run Granger causality test (whale_activity → price)
            try:
                gc_result = grangercausalitytests(
                    aligned_data[['price', 'whale_activity']], 
                    maxlag=max_lag, 
                    verbose=False
                )
                
                # Extract p-values for each lag
                p_values = {lag: result[0]['ssr_chi2test'][1] for lag, result in gc_result.items()}
                
                # Find significant lags (p < 0.05)
                significant_lags = [lag for lag, p in p_values.items() if p < 0.05]
                
                # Store results
                market_results[market_id] = {
                    'p_values': p_values,
                    'significant_lags': significant_lags,
                    'min_p_value': min(p_values.values()) if p_values else 1.0,
                    'best_lag': min(p_values, key=p_values.get) if p_values else None,
                    'is_significant': len(significant_lags) > 0
                }
                
                result_str = "SIGNIFICANT" if len(significant_lags) > 0 else "not significant"
                print(f"  Result: {result_str} (best lag: {market_results[market_id]['best_lag']})")
            except Exception as e:
                print(f"  Error in Granger causality test: {e}")
        except Exception as e:
            print(f"  Error processing market {market_id}: {e}")
    
    if not market_results:
        print("No Granger causality results obtained")
        return None
        
    # Summarize results
    significant_markets = sum(1 for r in market_results.values() if r['is_significant'])
    total_markets = len(market_results)
    
    significance_ratio = significant_markets / total_markets if total_markets > 0 else 0
    print(f"\nGranger causality is significant in {significant_markets} of {total_markets} markets ({significance_ratio*100:.1f}%)")
    
    # Calculate average lag of significant effects
    if significant_markets > 0:
        best_lags = [r['best_lag'] for r in market_results.values() if r['is_significant']]
        avg_lag = sum(best_lags) / len(best_lags) if best_lags else 0
        print(f"Average lag of significant effects: {avg_lag:.1f} intervals")
    
    # Create visualization of p-values
    plt.figure(figsize=(12, 6))
    
    # Extract min p-values for each market
    market_ids = list(market_results.keys())
    min_p_values = [market_results[m]['min_p_value'] for m in market_ids]
    
    # Sort by p-value
    sorted_indices = np.argsort(min_p_values)
    sorted_markets = [market_ids[i] for i in sorted_indices]
    sorted_p_values = [min_p_values[i] for i in sorted_indices]
    
    # Create bar chart
    bars = plt.bar(range(len(sorted_markets)), sorted_p_values, color='skyblue')
    
    # Highlight significant markets
    for i, p_value in enumerate(sorted_p_values):
        if p_value < 0.05:
            bars[i].set_color('green')
    
    # Add significance threshold line
    plt.axhline(0.05, color='red', linestyle='--', label='Significance threshold (p=0.05)')
    
    plt.title('Granger Causality Test: Do Whale Trades Predict Price Movements?')
    plt.ylabel('Minimum p-value')
    plt.xlabel('Markets (sorted by p-value)')
    plt.xticks([])  # Hide x-tick labels for cleaner visualization
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'granger_causality_results.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    return {
        'market_results': market_results,
        'significant_ratio': significance_ratio,
        'significant_markets': significant_markets,
        'total_markets': total_markets,
        'avg_lag': avg_lag if significant_markets > 0 else None
    }

## Trader Profitability

In [None]:
def analyze_trader_profitability(trades_df, classification_results=None):
    """
    Analyze trader profitability and success rates if profit data is available
    
    Parameters:
    -----------
    trades_df : pd.DataFrame
        DataFrame with trade data
    classification_results : dict, optional
        Dictionary with trader classification results
        
    Returns:
    --------
    dict
        Dictionary with profitability analysis results
    """
    if 'profit' not in trades_df.columns:
        print("Warning: profit column not found, trying to calculate")
        # Try to calculate profit using price changes
        if 'price' in trades_df.columns and 'trade_amount' in trades_df.columns and 'trade_direction' in trades_df.columns:
            try:
                # This is a simplified calculation and may not be accurate
                # A proper calculation would require entry and exit prices for each position
                trades_df['profit'] = trades_df['price'] * trades_df['trade_amount'] * (
                    trades_df['trade_direction'].map({'buy': 1, 'sell': -1})
                )
                print("Created simple profit proxy based on price and direction")
            except Exception as e:
                print(f"Error calculating profit: {e}")
                return None
        else:
            print("Cannot calculate or find profit information")
            return None
    
    print("Analyzing trader profitability...")
    
    # Calculate profitability by trader
    trader_profit = trades_df.groupby('trader_id')['profit'].agg(['sum', 'mean', 'count'])
    trader_profit.columns = ['total_profit', 'avg_profit_per_trade', 'trade_count']
    
    # Calculate win rate (if profit values are reliable)
    trader_win_rate = trades_df.groupby('trader_id')['profit'].apply(
        lambda x: (x > 0).mean()
    ).rename('win_rate')
    
    # Combine metrics
    trader_metrics = pd.concat([trader_profit, trader_win_rate], axis=1)
    
    # Sort by total profit
    trader_metrics = trader_metrics.sort_values('total_profit', ascending=False)
    
    # Calculate profit concentration (Gini coefficient)
    profit_gini = calculate_gini(trader_metrics['total_profit'].values)
    print(f"Profit concentration (Gini): {profit_gini:.4f}")
    
    # Calculate average metrics
    avg_win_rate = trader_metrics['win_rate'].mean()
    print(f"Average win rate: {avg_win_rate:.1%}")
    
    # If trader classification is available, analyze by type
    type_metrics = None
    if classification_results is not None and 'trader_features' in classification_results:
        # Merge with trader types
        trader_features = classification_results['trader_features']
        trader_metrics_with_type = trader_metrics.reset_index().merge(
            trader_features[['trader_id', 'cluster']],
            on='trader_id',
            how='left'
        )
        
        # Add trader type labels
        cluster_names = classification_results.get('cluster_names', {})
        trader_metrics_with_type['trader_type'] = trader_metrics_with_type['cluster'].map(
            lambda x: cluster_names.get(x, f"Cluster {x}")
        )
        
        # Calculate metrics by trader type
        type_metrics = trader_metrics_with_type.groupby('trader_type').agg({
            'total_profit': ['mean', 'sum'],
            'avg_profit_per_trade': 'mean',
            'win_rate': 'mean',
            'trade_count': ['mean', 'sum', 'count']
        })
        
        # Flatten column names
        type_metrics.columns = ['_'.join(col).strip() for col in type_metrics.columns.values]
        
        # Rename count column
        type_metrics = type_metrics.rename(columns={'trade_count_count': 'trader_count'})
        
        # Calculate profit per trader
        type_metrics['profit_per_trader'] = type_metrics['total_profit_sum'] / type_metrics['trader_count']
        
        print("\nProfitability by trader type:")
        print(type_metrics[['trader_count', 'win_rate_mean', 'profit_per_trader']])
        
        # Create visualization of profitability by trader type
        plt.figure(figsize=(12, 6))
        
        # Sort by profit per trader
        sorted_types = type_metrics.sort_values('profit_per_trader', ascending=False)
        
        # Create bar chart
        bars = plt.bar(sorted_types.index, sorted_types['profit_per_trader'])
        
        # Add trader count labels
        for i, bar in enumerate(bars):
            trader_count = sorted_types.iloc[i]['trader_count']
            plt.text(
                bar.get_x() + bar.get_width()/2, 
                bar.get_height() + (0.05 * sorted_types['profit_per_trader'].max()), 
                f"n={trader_count:.0f}",
                ha='center'
            )
        
        plt.title('Average Profit per Trader by Trader Type')
        plt.ylabel('Profit per Trader')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, 'profit_by_trader_type.png'), dpi=300, bbox_inches='tight')
    
    return {
        'trader_metrics': trader_metrics,
        'profit_gini': profit_gini,
        'avg_win_rate': avg_win_rate,
        'type_metrics': type_metrics
    }

# Run profitability analysis if trade data is available
profitability_results = None
if trade_data is not None and 'trader_id' in trade_data.columns:
    profitability_results = analyze_trader_profitability(
        trade_data, 
        classification_results=classification_results if 'classification_results' in locals() else None
    )

# Aggregated Analysis 

In [None]:

def run_trader_analysis():
    """
    Run the complete trader analysis pipeline and generate report
    """
    print("="*80)
    print("TRADER ANALYSIS FOR PREDICTION MARKETS")
    print("="*80)
    
    # Step 1: Load trade data
    print("\nStep 1: Loading trade data...")
    trade_data = load_trade_data_for_analysis(n_markets=5, max_trades_per_market=50000)
    
    if trade_data is None or len(trade_data) == 0:
        print("Error: No trade data available for analysis")
        return
    
    # Step 2: Classify trader types
    print("\nStep 2: Classifying trader types...")
    classification_results = classify_traders(trade_data, n_clusters=5)
    
    # Step 3: Identify whale traders
    print("\nStep 3: Identifying whale traders...")
    whale_results = identify_whales(trade_data, threshold=0.05, method='volume')
    
    if whale_results is None or 'whale_ids' not in whale_results:
        print("Error: Could not identify whale traders")
        return
    
    # Step 4: Analyze price impact
    print("\nStep 4: Analyzing price impact of whale traders...")
    impact_results = analyze_whale_price_impact(trade_data, whale_results['whale_ids'])
    
    # Step 5: Test Granger causality
    print("\nStep 5: Testing Granger causality...")
    gc_results = test_granger_causality(trade_data, whale_results['whale_ids'])
    
    # Step 6: Generate comprehensive report
    print("\nStep 6: Generating final report...")
    
    report = {
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'data_summary': {
            'total_trades': len(trade_data),
            'unique_traders': trade_data['trader_id'].nunique() if 'trader_id' in trade_data.columns else 0,
            'markets_analyzed': trade_data['market_id'].nunique() if 'market_id' in trade_data.columns else 0
        },
        'trader_types': {
            'num_types': len(classification_results['cluster_profiles']) if classification_results is not None else 0,
            'types': {name: float(classification_results['cluster_profiles'].loc[i, 'percentage']) 
                     for i, name in classification_results['cluster_names'].items()} 
                     if classification_results is not None else {}
        },
        "whale_analysis": {
            "whale_concentration": float(whale_results['whale_concentration']) if 'whale_results' in locals() and whale_results is not None else None,
            "whale_impact": float(impact_results['avg_impact']) if 'impact_results' in locals() and impact_results is not None else None,
            "following_ratio": float(impact_results['following_ratio']) if 'impact_results' in locals() and impact_results is not None else None
        },
        "causality": {
            "significant_ratio": float(gc_results['significant_ratio']) if 'gc_results' in locals() and gc_results is not None else None
        },
        "profitability": {
            "profit_gini": float(profitability_results['profit_gini']) if 'profitability_results' in locals() and profitability_results is not None else None
        }
    }
    
    # Save as JSON
    import json
    with open(os.path.join(results_dir, 'thesis_summary.json'), 'w') as f:
        json.dump(thesis_summary, f, indent=2)
    print(f"\nSaved thesis summary to {os.path.join(results_dir, 'thesis_summary.json')}")