In [None]:
# former version
import functools
import logging
import os
import re
from datetime import datetime
from typing import Dict, List, Tuple

import joblib
import numpy as np
import polars as pl
import xgboost as xgb
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    precision_recall_curve, precision_score, recall_score,
    roc_auc_score
)

# Configure Polars for memory usage.
# These calls change process-wide Polars settings as soon as this module runs.
# NOTE(review): the chunk size only matters for the streaming engine; every
# .collect() in this file is a plain (non-streaming) collect, so this may have
# no effect here — confirm whether streaming collection was intended.
pl.Config.set_streaming_chunk_size(1_000_000)
# Display-only setting: limits how many characters of string values are shown
# when a DataFrame is printed.
pl.Config.set_fmt_str_lengths(50)

class LazyDatasetLoader:
    """Parquet loader that, in debug mode, restricts every dataset to one shared FID sample.

    In debug mode the "base FIDs" used for filtering come from, in order of
    precedence: an explicit ``set_base_fids`` call, the ``fid`` column of
    ``<checkpoint_dir>/profile_features.parquet``, or a ``sample_size``-row
    sample of ``profile_with_addresses`` (plus any ``fids_to_ensure``).
    Once established, every dataset loaded afterwards is filtered to those FIDs.
    Outside debug mode datasets are loaded in full.

    Only the most recently loaded frame is kept (``_cached_dataset``); it is
    released before each new load and by ``clear_cache``.
    """

    def __init__(self, data_path: str, checkpoint_dir: str, debug_mode: bool = True, sample_size: int = 700_000, fids_to_ensure: List[int] = None):
        """
        Args:
            data_path: Directory containing the source parquet files.
            checkpoint_dir: Directory that may contain ``profile_features.parquet``.
            debug_mode: When True, sample / FID-filter datasets instead of loading them fully.
            sample_size: Row limit used when no base FIDs are available yet.
            fids_to_ensure: FIDs force-included in the initial profile sample.
        """
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self._cached_dataset = None
        self._cached_name = None
        self.debug_mode = debug_mode
        self.sample_size = sample_size
        self.base_fids = None
        self.fids_to_ensure = fids_to_ensure

    def set_base_fids(self, fids):
        """Set the FIDs every subsequent debug-mode load is filtered to."""
        self.base_fids = fids
        print(f"Set base FIDs: {len(fids)} records")

    def get_checkpoint_fids(self):
        """Load base FIDs from the profile checkpoint.

        Returns:
            True if ``<checkpoint_dir>/profile_features.parquet`` exists and has
            a ``fid`` column (``self.base_fids`` is then set), else False.
        """
        profile_checkpoint = f"{self.checkpoint_dir}/profile_features.parquet"
        if os.path.exists(profile_checkpoint):
            df = pl.read_parquet(profile_checkpoint)
            if 'fid' in df.columns:
                self.base_fids = df['fid']
                print(f"Loaded base FIDs from checkpoint: {len(self.base_fids)} records")
                return True
        return False

    def get_dataset(self, name: str, columns: List[str] = None, source="farcaster") -> pl.DataFrame:
        """Load one parquet dataset, restricted to the base FIDs in debug mode.

        Args:
            name: Dataset name embedded in the parquet file name.
            columns: Optional column projection; ``None``/empty loads all columns.
            source: ``"farcaster"`` or ``"nindexer"``; selects the file-name pattern.

        Returns:
            The collected DataFrame (also kept as ``self._cached_dataset``).

        Raises:
            ValueError: If ``source`` is not a supported value.
            Exception: Any scan/collect error is printed and re-raised.
        """
        # Release the previously loaded frame before materialising the next one.
        self.clear_cache()

        if source == "farcaster":
            path = f"{self.data_path}/farcaster-{name}-0-1733162400.parquet"
        elif source == "nindexer":
            path = f"{self.data_path}/nindexer-{name}-0-1733508243.parquet"
        else:
            # Without this branch `path` would be unbound below.
            raise ValueError(f"Unknown source {source!r}; expected 'farcaster' or 'nindexer'")

        try:
            scan_query = pl.scan_parquet(path)
            if columns:
                scan_query = scan_query.select(columns)

            if not self.debug_mode:
                dataset = scan_query.collect()
            else:
                if self.base_fids is None:
                    # Sets self.base_fids on success; if it does, fall through to
                    # the filtering branch instead of returning nothing.
                    self.get_checkpoint_fids()

                if self.base_fids is not None:
                    print(f"Filtering {name} by {len(self.base_fids)} base FIDs")
                    dataset = (scan_query
                        .filter(pl.col('fid').is_in(self.base_fids))
                        .collect())
                elif name == 'profile_with_addresses':
                    dataset = scan_query.limit(self.sample_size).collect()
                    if self.fids_to_ensure is not None and len(self.fids_to_ensure) > 0:
                        ensured = scan_query.filter(pl.col('fid').is_in(self.fids_to_ensure)).collect()
                        if len(ensured) > 0:
                            # maintain_order keeps the sample reproducible run to run.
                            dataset = (pl.concat([dataset, ensured], how='diagonal')
                                       .unique(subset='fid', maintain_order=True))

                    self.base_fids = dataset['fid']
                    print(f"Established new base FIDs from {name}: {len(self.base_fids)} records")
                else:
                    print(f"Warning: No base FIDs available for {name}")
                    dataset = scan_query.limit(self.sample_size).collect()

            self._cached_dataset = dataset
            self._cached_name = name
            print(f"Loaded {name}: {len(dataset)} records")
            return dataset

        except Exception as e:
            print(f"Error loading {name}: {str(e)}")
            raise

    def clear_cache(self):
        """Drop the reference to the most recently loaded dataset."""
        self._cached_dataset = None
        self._cached_name = None

class FeatureSet:
    """Describe one named group of features: its logic version and the feature sets it builds on.

    ``checkpoint_path`` and ``last_modified`` start as None; they are filled in
    later (FeatureEngineering._init_checkpoints sets them from the checkpoint
    directory).
    """

    def __init__(self, name: str, version: str, dependencies: List[str] = None):
        upstream = dependencies if dependencies else []
        # `version` identifies the feature-calculation logic, not the data.
        self.name, self.version = name, version
        self.dependencies = upstream
        self.checkpoint_path = self.last_modified = None

class FeatureEngineering:
    """Enhanced bot detection system"""
    
    def __init__(self, data_path: str, checkpoint_dir: str, fids_to_ensure: List[int] = None):
        """Create the dataset loader and register every feature set with its version and dependencies.

        Args:
            data_path: Directory holding the source parquet files (passed to the loader).
            checkpoint_dir: Directory for per-feature-set checkpoint files.
            fids_to_ensure: FIDs the loader force-includes in its debug sample.
        """
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self.loader = LazyDatasetLoader(data_path, checkpoint_dir, fids_to_ensure=fids_to_ensure)
        self.fids_to_ensure = fids_to_ensure

        # (name, logic version, upstream feature sets); registration order is kept.
        registry = (
            # Base features
            ('profile', '1.0', ()),
            ('network', '1.0', ()),
            ('temporal', '1.0', ('network',)),
            # Activity features
            ('cast', '1.0', ()),
            ('reaction', '1.0', ()),
            ('channel', '1.0', ()),
            ('verification', '1.0', ()),
            # Account features
            ('user_data', '1.0', ()),
            ('storage', '1.0', ()),
            ('signers', '1.0', ()),
            # Interaction patterns
            ('engagement', '1.0', ('cast', 'reaction', 'channel')),
            ('mentions', '1.0', ('cast', 'network')),
            ('reply_patterns', '1.0', ('cast', 'temporal')),
            # Network quality
            ('network_quality', '1.0', ('network', 'engagement')),
            ('power_user_interaction', '1.0', ('network', 'temporal')),
            ('cluster_analysis', '1.0', ('network', 'engagement')),
            # Behavioral patterns
            ('activity_patterns', '1.0', ('temporal', 'cast', 'reaction')),
            ('update_behavior', '1.0', ('user_data', 'profile')),
            ('verification_patterns', '1.0', ('verification', 'temporal')),
            # Meta features
            ('authenticity', '2.0', ('profile', 'network', 'channel', 'verification',
                                     'engagement', 'network_quality', 'activity_patterns')),
            ('influence', '1.0', ('network', 'engagement', 'power_user_interaction')),
            # Final derived features
            ('derived', '2.0', ('network', 'temporal', 'authenticity',
                                'engagement', 'network_quality', 'influence')),
            # nindexer features
            ('enhanced_network', '1.0', ('network',)),
            ('enhanced_profile', '1.0', ('profile',)),
            ('neynar_score', '1.0', ()),
            ('name_patterns', '1.0', ('profile',)),
            ('content_patterns', '1.0', ('cast',)),
            ('advanced_temporal', '1.0', ('temporal', 'cast', 'reaction')),
            ('reward_gaming', '1.0', ('cast', 'reaction', 'temporal')),
            ('engagement_authenticity', '1.0', ('network', 'cast', 'reaction')),
        )
        self.feature_sets = {
            set_name: FeatureSet(set_name, set_version, list(upstream))
            for set_name, set_version, upstream in registry
        }

        # Attach checkpoint paths / modification times to each feature set.
        self._init_checkpoints()

    def _analyze_name_patterns(self, text: str) -> Dict[str, int]:
        """Flag bot-like traits in a username or display name.

        Args:
            text: The name to inspect; ``None`` or ``""`` yields all-zero flags.

        Returns:
            Dict of 0/1 flags (key order fixed):
                random_numbers: 4 or more consecutive digits.
                wallet_pattern: a ``0x`` + 40-hex-digit address.
                excessive_symbols: 2 or more of ``_``, ``.``, ``-`` in a row.
                airdrop_terms: contains airdrop/farm/degen/wojak (case-insensitive).
                has_year: contains a year 2010-2029.
        """
        if not text:
            return dict.fromkeys(
                ('random_numbers', 'wallet_pattern', 'excessive_symbols', 'airdrop_terms', 'has_year'),
                0,
            )

        lowered = text.lower()
        # re.search stops at the first hit; only presence matters for these flags.
        return {
            'random_numbers': int(re.search(r'\d{4,}', text) is not None),
            'wallet_pattern': int(re.search(r'0x[a-fA-F0-9]{40}', text) is not None),
            'excessive_symbols': int(re.search(r'[_.\-]{2,}', text) is not None),
            'airdrop_terms': int(any(term in lowered for term in ('airdrop', 'farm', 'degen', 'wojak'))),
            'has_year': int(re.search(r'20[12]\d', text) is not None),
        }

    def _analyze_content_patterns(self, text: str) -> Dict[str, int]:
        """Flag spam/bot-like traits in a cast's text (matching is done on the lowercased text).

        Every call returns the same seven keys, so callers can index any of them
        (the empty-text branch previously omitted two keys).

        Args:
            text: Cast text; ``None`` or ``""`` yields all-zero flags.

        Returns:
            Dict of 0/1 flags:
                template_structure: a ``[...]``, ``{...}`` or ``<...>`` span.
                multiple_cta: more than 2 of click/join/follow/claim/grab.
                urgency_terms: any of hurry/limited/fast/quick/soon/ending.
                excessive_emojis: more than 5 chars in U+1F300-U+1F9FF.
                price_mentions: ``$<digits>`` or ``<digits>$``.
                excessive_symbols: 2 or more of ``_``, ``.``, ``-`` in a row.
                airdrop_terms: contains airdrop/farm/degen/wojak.
        """
        if not text:
            return dict.fromkeys(
                ('template_structure', 'multiple_cta', 'urgency_terms', 'excessive_emojis',
                 'price_mentions', 'excessive_symbols', 'airdrop_terms'),
                0,
            )

        text = text.lower()
        return {
            'template_structure': int(re.search(r'\[.*?\]|\{.*?\}|\<.*?\>', text) is not None),
            # These two need match counts, so findall is kept.
            'multiple_cta': int(len(re.findall(r'click|join|follow|claim|grab', text)) > 2),
            'urgency_terms': int(re.search(r'hurry|limited|fast|quick|soon|ending', text) is not None),
            'excessive_emojis': int(len(re.findall(r'[\U0001F300-\U0001F9FF]', text)) > 5),
            'price_mentions': int(re.search(r'\$\d+|\d+\$', text) is not None),
            'excessive_symbols': int(re.search(r'[_.\-]{2,}', text) is not None),
            # `text` is already lowercased above.
            'airdrop_terms': int(any(term in text for term in ('airdrop', 'farm', 'degen', 'wojak'))),
        }
        
    def validate_dimensions(func):
        """Decorator for ``method(self, df, ...)`` that returns a DataFrame.

        Defined in the class body without ``self``; presumably applied as
        ``@validate_dimensions`` to methods later in the class — confirm at the
        use sites. The wrapper:
          * prints a warning if the result's row count differs from ``df``'s
            (it does not repair the mismatch);
          * returns the result with nulls filled by 0;
          * prints and re-raises any exception from the wrapped call.
        ``functools.wraps`` keeps the wrapped method's name and docstring.
        """
        @functools.wraps(func)
        def wrapper(self, df: pl.DataFrame, *args, **kwargs):
            input_shape = len(df)
            try:
                result = func(self, df, *args, **kwargs)
                if len(result) != input_shape:
                    print(f"Warning: Shape mismatch in {func.__name__}. Input: {input_shape}, Output: {len(result)}")
                    # Don't force join or filtering here. Just warn.
                return result.fill_null(0)
            except Exception as e:
                print(f"Error in {func.__name__}: {str(e)}")
                raise
        return wrapper

        
    def get_dataset_columns(self, name: str, source: str = "farcaster") -> List[str]:
        """Return a source parquet file's column names without loading its rows.

        Args:
            name: Dataset name embedded in the parquet file name.
            source: ``"farcaster"`` (default; the previous hard-coded behaviour) or
                ``"nindexer"`` — the same file-name patterns
                ``LazyDatasetLoader.get_dataset`` uses.

        Raises:
            ValueError: If ``source`` is not a supported value.
        """
        if source == "farcaster":
            path = f"{self.data_path}/farcaster-{name}-0-1733162400.parquet"
        elif source == "nindexer":
            path = f"{self.data_path}/nindexer-{name}-0-1733508243.parquet"
        else:
            raise ValueError(f"Unknown source {source!r}; expected 'farcaster' or 'nindexer'")
        # scan_parquet is lazy: only the schema is resolved here.
        return pl.scan_parquet(path).columns
        
    def _init_checkpoints(self):
        """Ensure the checkpoint directory exists and give every feature set its checkpoint path.

        Path pattern: ``<checkpoint_dir>/<name>_features.parquet``. When that
        file already exists, ``last_modified`` is set to ``os.path.getmtime``
        of it; otherwise ``last_modified`` is left untouched.
        """
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        for set_name in self.feature_sets:
            fset = self.feature_sets[set_name]
            fset.checkpoint_path = f"{self.checkpoint_dir}/{set_name}_features.parquet"
            if not os.path.exists(fset.checkpoint_path):
                continue
            fset.last_modified = os.path.getmtime(fset.checkpoint_path)
    
    def _needs_rebuild(self, feature_set: FeatureSet) -> bool:
        """Return True when ``feature_set`` has no checkpoint file on disk.

        Only existence is checked; ``version``, ``dependencies`` and
        ``last_modified`` are not consulted.
        """
        checkpoint_exists = os.path.exists(feature_set.checkpoint_path)
        return not checkpoint_exists


    def extract_profile_features(self) -> pl.DataFrame:
        """Load ``profile_with_addresses`` and derive basic per-profile features.

        Profiles with a null or empty ``fname`` are dropped, and ``fid`` is cast to
        Int64. Added Int32 columns:
            has_ens: ``fname`` ends with ``.eth``.
            has_bio: ``bio`` is non-null and non-empty.
            has_avatar: ``avatar_url`` is non-null.
            verification_count: number of comma-separated entries in
                ``verified_addresses``; 0 when it is null, ``""`` or ``"[]"``.
            has_display_name: ``display_name`` is non-null.
        """
        profiles = self.loader.get_dataset('profile_with_addresses',
            ['fid', 'fname', 'bio', 'avatar_url', 'verified_addresses', 'display_name'])

        # Filter valid profiles and cast fid type immediately
        profiles = (profiles
            .filter(pl.col('fname').is_not_null() & (pl.col('fname') != ""))
            .with_columns(pl.col('fid').cast(pl.Int64)))

        addresses = pl.col('verified_addresses')
        no_addresses = addresses.is_null() | addresses.is_in(['', '[]'])

        df = profiles.with_columns([
            pl.col('fname').str.contains(r'\.eth$').cast(pl.Int32).alias('has_ens'),
            (pl.col('bio').is_not_null() & (pl.col('bio') != "")).cast(pl.Int32).alias('has_bio'),
            pl.col('avatar_url').is_not_null().cast(pl.Int32).alias('has_avatar'),
            # Entries = separators + 1. Adding 1 to a contains(',') boolean
            # would cap the count at 2.
            pl.when(no_addresses)
            .then(0)
            .otherwise(addresses.str.count_matches(',') + 1)
            .cast(pl.Int32)
            .alias('verification_count'),
            pl.col('display_name').is_not_null().cast(pl.Int32).alias('has_display_name'),
        ])

        self.loader.clear_cache()
        return df
    def add_blocking_behavior(self, df: pl.DataFrame) -> pl.DataFrame:
        """Left-join per-blocker blocking statistics onto ``df``; nulls become 0.

        Adds, keyed on the blocker's FID:
            blocks_made: rows in ``blocks`` with this FID as ``blocker_fid``.
            unique_blocks: distinct ``blocked_fid`` values.
            block_repeat_ratio: ``blocks_made / (unique_blocks + 1)``.
        """
        # NOTE(review): in debug mode with base FIDs set, LazyDatasetLoader filters
        # on a `fid` column, which this projection does not contain — confirm
        # that `blocks` can be loaded in that mode.
        blocks = self.loader.get_dataset('blocks', ['blocker_fid', 'blocked_fid'])

        blocking_features = (
            blocks.group_by('blocker_fid')
            .agg([
                # pl.len() replaces the deprecated pl.count(), matching the rest of the file.
                pl.len().alias('blocks_made'),
                pl.n_unique('blocked_fid').alias('unique_blocks'),
            ])
            .with_columns([
                (pl.col('blocks_made') / (pl.col('unique_blocks') + 1)).alias('block_repeat_ratio')
            ])
            .rename({'blocker_fid': 'fid'})
        )

        self.loader.clear_cache()
        return df.join(blocking_features, on='fid', how='left').fill_null(0)

    def add_enhanced_verification_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add on-chain and platform verification features per FID.

        Columns added (defaults used when a dataset is empty; nulls from the left
        joins are filled with 0 at the end):
            total_verifications, eth_verifications, verification_timing_std —
                from non-deleted ``verifications`` rows (std of the gaps, in
                seconds, between consecutive verification timestamps);
            platforms_verified, first_platform_verification,
            last_platform_verification, verification_span_days — from
                ``account_verifications``.

        Feature columns are joined first and defaults are added only for columns
        still missing. Pre-creating same-named defaults makes the join put the
        real values into ``*_right`` columns.

        Raises:
            Any error from loading or computing is printed and re-raised.
        """
        defaults = {
            'total_verifications': 0,
            'eth_verifications': 0,
            'verification_timing_std': 0.0,
            'platforms_verified': 0,
            'first_platform_verification': None,
            'last_platform_verification': None,
            'verification_span_days': 0,
        }
        try:
            result = df

            # Process on-chain verifications
            verifications = self.loader.get_dataset('verifications',
                ['fid', 'claim', 'timestamp', 'deleted_at'])

            if verifications is not None and len(verifications) > 0:
                verif_features = (
                    verifications
                    .filter(pl.col('deleted_at').is_null())
                    .with_columns(pl.col('timestamp').cast(pl.Datetime))
                    .group_by('fid')
                    .agg([
                        pl.len().alias('total_verifications'),
                        pl.col('claim').str.contains('ethSignature').sum().alias('eth_verifications'),
                        # Sort so gaps are chronological. Drop the leading null gap
                        # rather than counting it as a zero-length gap.
                        pl.col('timestamp')
                            .sort()
                            .diff()
                            .dt.total_seconds()
                            .cast(pl.Float64)
                            .drop_nulls()
                            .std()
                            .fill_null(0)
                            .alias('verification_timing_std'),
                    ])
                    .unique(subset=['fid'])
                )
                result = result.join(verif_features, on='fid', how='left')

            # Process platform verifications
            acc_verifications = self.loader.get_dataset('account_verifications',
                ['fid', 'platform', 'platform_username', 'verified_at'])

            if acc_verifications is not None and len(acc_verifications) > 0:
                platform_features = (
                    acc_verifications
                    .with_columns(pl.col('verified_at').cast(pl.Datetime))
                    .group_by('fid')
                    .agg([
                        pl.n_unique('platform').alias('platforms_verified'),
                        pl.col('verified_at').min().alias('first_platform_verification'),
                        pl.col('verified_at').max().alias('last_platform_verification'),
                    ])
                    .unique(subset=['fid'])
                )
                result = result.join(platform_features, on='fid', how='left')

                # Null for FIDs without platform verifications -> 0.0 days.
                result = result.with_columns(
                    (pl.col('last_platform_verification') - pl.col('first_platform_verification'))
                    .dt.total_days()
                    .cast(pl.Float64)
                    .fill_null(0.0)
                    .alias('verification_span_days')
                )

            # Guarantee a stable output schema when a dataset was empty.
            missing = [
                pl.lit(value).alias(column)
                for column, value in defaults.items()
                if column not in result.columns
            ]
            if missing:
                result = result.with_columns(missing)

            self.loader.clear_cache()
            return result.fill_null(0)

        except Exception as e:
            print(f"Error in verification features: {str(e)}")
            raise

    def add_cast_behavior_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID cast behaviour features: volume, replies, links, media and spam signals.

        Loads ``casts`` through ``self.loader``, keeps non-deleted casts whose
        ``fid`` appears in ``df``, derives per-cast columns, aggregates them per
        ``fid`` and left-joins the result onto ``df`` (nulls filled with 0).

        Output columns include:
          * counts: ``cast_count``, ``reply_count``, ``mentions_count``,
            ``total_links``, ``casts_with_links``, ``total_media``,
            ``casts_with_media``, ``casts_with_both``;
          * per-cast ratios: ``link_ratio``, ``media_ratio``,
            ``airdrop_mention_ratio``, ``spam_keyword_ratio``,
            ``multimedia_ratio`` and the content-pattern ratios;
          * per-cast means: ``avg_cast_length``, ``avg_at_symbol_ratio``,
            ``avg_dollar_symbol_ratio``, ``avg_link_ratio``;
          * derived rates: ``link_usage_rate``, ``media_usage_rate`` (0 when
            there are no casts), ``avg_links_per_link_cast``,
            ``avg_media_per_media_cast``.

        Args:
            df: Frame with a ``fid`` column.

        Returns:
            ``df`` with the cast feature columns appended.

        Raises:
            Any error from loading or computing is printed and re-raised.
        """
        spam_keywords = ['airdrop', 'money', 'rewards', 'claim', 'moxie', 'nft', 'drop']
        # _analyze_content_patterns key -> output ratio column.
        content_ratio_names = {
            'template_structure': 'template_usage_ratio',
            'multiple_cta': 'cta_heavy_ratio',
            'urgency_terms': 'urgency_ratio',
            'excessive_emojis': 'emoji_spam_ratio',
            'price_mentions': 'price_mention_ratio',
            'excessive_symbols': 'symbol_spam_ratio',
            'airdrop_terms': 'airdrop_term_ratio',
        }
        content_dtype = pl.Struct([pl.Field(key, pl.Int64) for key in content_ratio_names])

        def content_flags(value):
            # Fixed key set so every row matches content_dtype, even if the
            # analyser omits a key for some input.
            flags = self._analyze_content_patterns(value)
            return {key: int(flags.get(key, 0)) for key in content_ratio_names}

        text = pl.col('text')
        text_len = text.str.len_chars()

        def share_of_chars(pattern):
            # Regex matches per character: 0.0 for "", null for null text
            # (nulls are skipped by mean()).
            return (pl.when(text_len == 0)
                    .then(0.0)
                    .otherwise(text.str.count_matches(pattern) / text_len))

        try:
            base_fids = df['fid']
            print(f"Processing casts for {len(base_fids)} FIDs")

            # Get casts with all needed fields
            casts_df = self.loader.get_dataset('casts', columns=[
                'fid', 'text', 'parent_hash', 'mentions', 'deleted_at',
                'timestamp', 'embeds'  # Adding embeds for media detection
            ])

            # Only FIDs present in df survive the final left join, so drop the rest early.
            valid_casts = casts_df.filter(
                pl.col('deleted_at').is_null() & pl.col('fid').is_in(base_fids)
            )

            spam_counts = {
                word: text.str.to_lowercase().str.count_matches(word).fill_null(0)
                for word in spam_keywords
            }

            per_cast = (
                valid_casts
                .with_columns([
                    text_len.fill_null(0).alias('cast_length'),
                    pl.col('parent_hash').is_not_null().cast(pl.Int32).alias('is_reply'),
                    (pl.col('mentions').is_not_null()
                     & (pl.col('mentions') != '')
                     & (pl.col('mentions') != '[]')).cast(pl.Int32).alias('has_mentions'),
                    # Count URLs, not URL prefixes: "https://www.x" is one link.
                    text.str.count_matches(r'(?i)https?://\S+|www\.\S+').fill_null(0).alias('link_count'),
                    # NOTE(review): assumes `embeds` is a string column (e.g. serialised JSON) — confirm.
                    pl.col('embeds').str.to_lowercase().str.count_matches('image').fill_null(0).alias('media_count'),
                    spam_counts['airdrop'].alias('airdrop_mentions'),
                    pl.sum_horizontal(list(spam_counts.values())).alias('spam_keyword_count'),
                    share_of_chars('@').alias('at_symbol_ratio'),
                    share_of_chars(r'\$').alias('dollar_symbol_ratio'),
                    share_of_chars(r'https?://').alias('url_char_ratio'),
                    text.map_elements(content_flags, return_dtype=content_dtype).alias('content_patterns'),
                ])
                # Separate step: with_columns cannot reference columns created in the same call.
                .with_columns(
                    ((pl.col('link_count') > 0) & (pl.col('media_count') > 0))
                    .cast(pl.Int32)
                    .alias('has_link_and_media')
                )
            )

            n_casts = pl.len()
            cast_features = per_cast.group_by('fid').agg([
                n_casts.alias('cast_count'),
                pl.col('cast_length').mean().alias('avg_cast_length'),
                pl.col('is_reply').sum().alias('reply_count'),
                pl.col('has_mentions').sum().alias('mentions_count'),

                pl.col('link_count').sum().alias('total_links'),
                (pl.col('link_count') > 0).sum().alias('casts_with_links'),
                (pl.col('link_count').sum() / n_casts).alias('link_ratio'),

                pl.col('media_count').sum().alias('total_media'),
                (pl.col('media_count') > 0).sum().alias('casts_with_media'),
                (pl.col('media_count').sum() / n_casts).alias('media_ratio'),

                (pl.col('airdrop_mentions').sum() / n_casts).alias('airdrop_mention_ratio'),
                (pl.col('spam_keyword_count').sum() / n_casts).alias('spam_keyword_ratio'),

                pl.col('at_symbol_ratio').mean().alias('avg_at_symbol_ratio'),
                pl.col('dollar_symbol_ratio').mean().alias('avg_dollar_symbol_ratio'),
                pl.col('url_char_ratio').mean().alias('avg_link_ratio'),

                pl.col('has_link_and_media').sum().alias('casts_with_both'),
                (pl.col('has_link_and_media').sum() / n_casts).alias('multimedia_ratio'),

                *[
                    (pl.col('content_patterns').struct.field(key).sum() / n_casts).alias(ratio_name)
                    for key, ratio_name in content_ratio_names.items()
                ],
            ])

            cast_features = cast_features.unique(subset=['fid'])
            self.loader.clear_cache()

            # Join and handle nulls
            result = df.join(cast_features, on='fid', how='left').fill_null(0)

            # Guard the per-cast rates so FIDs with no casts get 0 instead of NaN (0/0).
            has_casts = pl.col('cast_count') > 0
            result = result.with_columns([
                pl.when(has_casts)
                .then(pl.col('casts_with_links') / pl.col('cast_count'))
                .otherwise(0.0)
                .alias('link_usage_rate'),
                pl.when(has_casts)
                .then(pl.col('casts_with_media') / pl.col('cast_count'))
                .otherwise(0.0)
                .alias('media_usage_rate'),
                (pl.col('total_links') / (pl.col('casts_with_links') + 1)).alias('avg_links_per_link_cast'),
                (pl.col('total_media') / (pl.col('casts_with_media') + 1)).alias('avg_media_per_media_cast'),
            ])

            return result

        except Exception as e:
            print(f"Error in cast behavior: {str(e)}")
            raise


    def add_influence_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add influence_score, engagement_rate and follower_growth_rate to ``df``.

        Missing input columns (follower_count, following_count, total_reactions,
        cast_count) are added as 0. Nulls in those inputs are treated as 0 in
        every formula:
            follow_time_span_hours = last_follow - first_follow in hours, or 0
                when either bound is null or the columns are absent.
            influence_score = (0.4*followers + 0.3*reactions + 0.3*casts)
                / (following + 1).
            engagement_rate = reactions / casts, 0 when there are no casts.
            follower_growth_rate = followers / follow_time_span_hours, 0 when
                the span is not positive.

        Raises:
            Any error is printed and re-raised.
        """
        try:
            required_cols = ['follower_count', 'following_count', 'total_reactions', 'cast_count']
            absent = [pl.lit(0).alias(col) for col in required_cols if col not in df.columns]
            if absent:
                df = df.with_columns(absent)

            if 'first_follow' in df.columns and 'last_follow' in df.columns:
                df = df.with_columns([
                    pl.when(pl.col('last_follow').is_not_null() & pl.col('first_follow').is_not_null())
                    .then((pl.col('last_follow') - pl.col('first_follow')).dt.total_hours())
                    .otherwise(0)
                    .alias('follow_time_span_hours')
                ])
            else:
                df = df.with_columns(pl.lit(0).alias('follow_time_span_hours'))

            followers = pl.col('follower_count').fill_null(0)
            following = pl.col('following_count').fill_null(0)
            reactions = pl.col('total_reactions').fill_null(0)
            casts = pl.col('cast_count').fill_null(0)
            span_hours = pl.col('follow_time_span_hours')

            df = df.with_columns([
                ((followers * 0.4 + reactions * 0.3 + casts * 0.3) / (following + 1))
                .alias('influence_score'),
                pl.when(casts > 0)
                .then(reactions / casts)
                .otherwise(0)
                .alias('engagement_rate'),
                pl.when(span_hours > 0)
                .then(followers / span_hours)
                .otherwise(0)
                .alias('follower_growth_rate'),
            ])

            return df

        except Exception as e:
            print(f"Error in influence features: {str(e)}")
            raise

    def add_storage_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Left-join per-FID storage statistics onto ``df``; nulls become 0.

        Uses only ``storage`` rows whose ``deleted_at`` is null and adds
        ``avg_storage_units``, ``max_storage_units`` (mean / max of ``units``)
        and ``storage_update_count`` (row count).
        """
        storage = self.loader.get_dataset('storage', ['fid', 'units', 'deleted_at'])

        live_rows = storage.filter(pl.col('deleted_at').is_null())
        units = pl.col('units')
        per_fid = live_rows.group_by('fid').agg([
            units.mean().alias('avg_storage_units'),
            units.max().alias('max_storage_units'),
            pl.len().alias('storage_update_count'),
        ])

        self.loader.clear_cache()
        per_fid = per_fid.unique(subset=['fid'])
        merged = df.join(per_fid, on='fid', how='left')
        return merged.fill_null(0)
    def add_user_data_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Left-join per-FID user_data update statistics onto ``df``; nulls become 0.

        Adds ``total_user_data_updates`` (non-deleted rows) and
        ``avg_update_interval`` (mean gap between consecutive timestamps,
        converted with ``dt.total_hours()``). If the dataset is empty, both
        columns are added as 0.

        Raises:
            Any error is printed and re-raised.
        """
        try:
            user_data = self.loader.get_dataset('user_data',
                ['fid', 'type', 'timestamp', 'deleted_at'])

            if user_data is None or len(user_data) == 0:
                self.loader.clear_cache()
                return df.with_columns([
                    pl.lit(0).alias('total_user_data_updates'),
                    pl.lit(0.0).alias('avg_update_interval')
                ])

            update_features = (
                user_data.filter(pl.col('deleted_at').is_null())
                .group_by('fid')
                .agg([
                    pl.len().alias('total_user_data_updates'),
                    # Nothing here guarantees source rows are in time order;
                    # sort so diff() yields chronological gaps.
                    pl.col('timestamp').sort().diff().mean().dt.total_hours().fill_null(0)
                        .alias('avg_update_interval')
                ])
            )

            self.loader.clear_cache()
            update_features = update_features.unique(subset=['fid'])
            return df.join(update_features, on='fid', how='left').fill_null(0)
        except Exception as e:
            print(f"Error in user_data features: {str(e)}")
            raise
    def add_signer_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach per-FID signer statistics to ``df``.

        From ``signers`` rows with a null ``deleted_at``, adds:

        * ``signer_count`` - number of such rows.
        * ``avg_hours_between_signers`` / ``std_hours_between_signers`` -
          mean and standard deviation of the gaps between chronologically
          consecutive signer timestamps, converted with ``dt.total_hours``.

        The features are left-joined on ``fid`` and nulls in every column are
        filled with 0.
        """
        signers = self.loader.get_dataset('signers', 
            ['fid', 'timestamp', 'deleted_at'])
        
        gap = pl.col('timestamp').diff()
        signer_features = (
            signers.filter(pl.col('deleted_at').is_null())
            # Sort so diff() measures chronological gaps; group_by preserves
            # the within-group row order of this sort.
            .sort(['fid', 'timestamp'])
            .group_by('fid')
            .agg([
                # pl.count() is deprecated in Polars; pl.len() is the equivalent
                # per-group row count used by the other feature builders.
                pl.len().alias('signer_count'),
                gap.mean().dt.total_hours().alias('avg_hours_between_signers'),
                gap.std().dt.total_hours().alias('std_hours_between_signers')
            ])
        )
        
        self.loader.clear_cache()
        signer_features = signer_features.unique(subset=['fid']) 
        return df.join(signer_features, on='fid', how='left').fill_null(0)
        
    def add_reaction_patterns(self, df: pl.DataFrame) -> pl.DataFrame:
        """Join per-FID reaction statistics and ratios onto ``df``.

        Only reactions made by FIDs present in ``df`` and with a null
        ``deleted_at`` are used. Timestamps are cast to ``pl.Datetime`` and
        sorted before grouping, so the gap statistics are chronological.

        Aggregates: ``total_reactions``, ``like_count`` (reaction_type 1),
        ``recast_count`` (reaction_type 2), ``unique_users_reacted_to``,
        ``avg_hours_between_reactions`` and ``std_hours_between_reactions``.
        Ratios computed after the join: ``like_ratio``, ``recast_ratio``,
        ``reaction_diversity`` and ``likes_to_recasts_ratio`` (each with +1
        in the denominator).

        Nulls after the join are filled with 0. Exceptions are printed and
        re-raised.
        """
        try:
            fids_in_frame = df['fid']
            print(f"Processing reactions for {len(fids_in_frame)} FIDs")

            raw_reactions = self.loader.get_dataset('reactions', 
                ['fid', 'reaction_type', 'target_fid', 'timestamp', 'deleted_at'])

            # Narrow to the FIDs being featurised before any heavier work.
            live_reactions = (
                raw_reactions
                .filter(pl.col('fid').is_in(fids_in_frame))
                .filter(pl.col('deleted_at').is_null())
                .with_columns([pl.col('timestamp').cast(pl.Datetime)])
                .sort('timestamp')
            )

            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            per_fid = live_reactions.group_by('fid').agg([
                pl.len().alias('total_reactions'),
                (pl.col('reaction_type') == 1).sum().alias('like_count'),
                (pl.col('reaction_type') == 2).sum().alias('recast_count'),
                pl.n_unique('target_fid').alias('unique_users_reacted_to'),
                gap_hours.mean().alias('avg_hours_between_reactions'),
                gap_hours.std().alias('std_hours_between_reactions'),
            ]).unique(subset=['fid'])

            # Ratios are computed on the joined frame so they exist for every row of df.
            merged = df.join(per_fid, on='fid', how='left', coalesce=True).fill_null(0)

            smoothed_total = pl.col('total_reactions') + 1
            merged = merged.with_columns([
                (pl.col('like_count') / smoothed_total).alias('like_ratio'),
                (pl.col('recast_count') / smoothed_total).alias('recast_ratio'),
                (pl.col('unique_users_reacted_to') / smoothed_total).alias('reaction_diversity'),
                (pl.col('like_count') / (pl.col('recast_count') + 1)).alias('likes_to_recasts_ratio'),
            ])

            # Sanity check: the left join should not change the row count.
            if len(merged) != len(df):
                print(f"Warning: Reaction features shape mismatch. Expected {len(df)}, got {len(merged)}")
                merged = merged.filter(pl.col('fid').is_in(fids_in_frame))

            self.loader.clear_cache()
            return merged

        except Exception as e:
            print(f"Error in reaction patterns: {str(e)}")
            raise
            return df

    def build_network_quality_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add counts of each FID's cast interactions with power users.

        First ensures ``engagement_score``, ``following_count`` and
        ``follower_count`` via ``self._validate_and_ensure_features`` (with
        defaults 0.0, 0 and 0), then adds:

        * ``power_reply_count`` - non-deleted casts whose ``parent_fid`` is a
          power-user FID.
        * ``power_mentions_count`` - non-deleted casts whose ``mentions``
          string contains any power-user FID as a whole number.

        Both columns are constant 0 when the ``power_users`` or ``casts``
        dataset is ``None`` or empty. Otherwise the counts are left-joined on
        ``fid`` and nulls in every column are filled with 0. Exceptions are
        printed and re-raised.
        """
        try:
            # Ensure required base metrics exist
            base_metrics = {
                'engagement_score': 0.0,
                'following_count': 0,
                'follower_count': 0
            }
            
            result = self._validate_and_ensure_features(df, base_metrics)
            zero_power_columns = [
                pl.lit(0).alias('power_reply_count'),
                pl.lit(0).alias('power_mentions_count')
            ]
            
            # Load power users
            power_users = self.loader.get_dataset('power_users', ['fid'])
            if power_users is None or len(power_users) == 0:
                return result.with_columns(zero_power_columns)
            
            power_fids = power_users['fid'].cast(pl.Int64).unique()
            casts = self.loader.get_dataset('casts', 
                ['fid', 'parent_fid', 'mentions', 'deleted_at'])

            if casts is None or len(casts) == 0:
                # Emit the same columns as every other branch.
                return result.with_columns(zero_power_columns)

            # Match ANY power-user FID, not just the first one, and only as a
            # whole number: a non-digit (or string edge) must surround it, so
            # fid 12 does not match "123".
            fid_alternatives = "|".join(
                str(fid) for fid in power_fids.drop_nulls().to_list()
            )
            if fid_alternatives:
                mentions_power_user = pl.col('mentions').is_not_null() & pl.col('mentions').str.contains(
                    rf"(?:^|\D)(?:{fid_alternatives})(?:\D|$)"
                )
            else:
                # An empty alternation would match every string.
                mentions_power_user = pl.lit(False)

            power_metrics = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('parent_fid').cast(pl.Int64).is_in(power_fids)
                        .alias('is_power_reply'),
                    pl.when(mentions_power_user)
                    .then(1)
                    .otherwise(0)
                    .alias('has_power_mention')
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reply').alias('power_reply_count'),
                    pl.sum('has_power_mention').alias('power_mentions_count')
                ])
            )
            
            result = result.join(power_metrics, on='fid', how='left').fill_null(0)
            return result
            
        except Exception as e:
            print(f"Error in network quality features: {str(e)}")
            raise
            return df.with_columns([
                pl.lit(0).alias('power_reply_count'),
                pl.lit(0).alias('power_mentions_count')
            ])

    def add_network_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Join follow-graph counts and follower/following ratios onto ``df``.

        Uses ``links`` rows whose ``deleted_at`` is null. Outgoing edges
        (grouped by ``fid``) give ``following_count``,
        ``unique_following_count``, ``first_follow`` and ``last_follow``.
        Incoming edges (grouped by ``target_fid``) give ``follower_count`` and
        ``unique_follower_count``. Each is left-joined on ``fid`` with nulls
        filled with 0, then ``follower_ratio``, ``unique_follower_ratio`` and
        their ``log1p`` variants are added (each with +1 in the denominator).

        Exceptions are printed and re-raised.
        """
        try:
            links = self.loader.get_dataset('links', 
                ['fid', 'target_fid', 'timestamp', 'deleted_at'])
            active = links.filter(pl.col('deleted_at').is_null())

            # Outgoing edges, keyed by the acting fid.
            outgoing = active.group_by('fid').agg([
                pl.len().alias('following_count'),
                pl.n_unique('target_fid').alias('unique_following_count'),
                pl.col('timestamp').min().alias('first_follow'),
                pl.col('timestamp').max().alias('last_follow'),
            ]).fill_null(0)

            # Incoming edges, re-keyed from target_fid to fid for the join.
            incoming = (
                active.group_by('target_fid')
                .agg([
                    pl.len().alias('follower_count'),
                    pl.n_unique('fid').alias('unique_follower_count'),
                ])
                .rename({'target_fid': 'fid'})
                .fill_null(0)
            )

            result = df
            for edge_stats in (outgoing, incoming):
                result = result.join(edge_stats, on='fid', how='left').fill_null(0)

            ratio = pl.col('follower_count') / (pl.col('following_count') + 1)
            unique_ratio = pl.col('unique_follower_count') / (pl.col('unique_following_count') + 1)
            return result.with_columns([
                ratio.alias('follower_ratio'),
                unique_ratio.alias('unique_follower_ratio'),
                ratio.log1p().alias('follower_ratio_log'),
                unique_ratio.log1p().alias('unique_follower_ratio_log'),
            ])

        except Exception as e:
            print(f"Error in network features: {str(e)}")
            raise
            return df.with_columns([
                pl.lit(0).alias('following_count'),
                pl.lit(0).alias('unique_following_count'),
                pl.lit(0).alias('follower_count'),
                pl.lit(0).alias('unique_follower_count'),
                pl.lit(0.0).alias('follower_ratio'),
                pl.lit(0.0).alias('unique_follower_ratio'),
                pl.lit(0.0).alias('follower_ratio_log'),
                pl.lit(0.0).alias('unique_follower_ratio_log')
            ])
    def add_temporal_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach link-timing features (burstiness, gaps, velocity) to ``df``.

        Uses ``links`` rows with a null ``deleted_at`` and a non-null
        ``timestamp`` (cast to ``pl.Datetime``). Rows are sorted by fid and
        timestamp so that ``diff()`` yields chronological gaps.

        Per-FID aggregates: ``total_activity``, ``avg_hours_between_actions``,
        ``std_hours_between_actions``, ``weekday_variance`` (std of weekday),
        ``rapid_actions`` / ``actions_in_bursts`` (gaps under 1 hour),
        ``long_gaps`` (gaps over 24 hours), ``p90_time_between_actions``,
        ``p10_time_between_actions`` and ``time_span`` (max - min, in hours).

        Derived ratios: ``burst_activity_ratio``, ``activity_spread``,
        ``temporal_irregularity`` and ``follow_velocity``. ``activity_spread``
        is 0 when the average gap is not positive, instead of NaN/inf from a
        division by zero. Nulls in the output are filled with 0; exceptions
        are printed and re-raised.
        """
        try:
            links = self.loader.get_dataset('links', ['fid', 'timestamp', 'deleted_at'])
            
            # Ensure timestamp is datetime type
            valid_links = (links
                .filter(pl.col('deleted_at').is_null())
                .filter(pl.col('timestamp').is_not_null())
                .with_columns([
                    pl.col('timestamp').cast(pl.Datetime).alias('timestamp')
                ])
                # diff() is order-dependent: unsorted rows produce negative gaps
                # that are miscounted as "rapid" actions, and a mean gap of -1
                # makes temporal_irregularity divide by zero. group_by keeps the
                # within-group order of this sort.
                .sort(['fid', 'timestamp']))
            
            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            temporal_features = (valid_links
                .group_by('fid')
                .agg([
                    # Basic temporal features
                    pl.len().alias('total_activity'),
                    gap_hours.mean().alias('avg_hours_between_actions'),
                    gap_hours.std().alias('std_hours_between_actions'),
                    pl.col('timestamp').dt.weekday().std().alias('weekday_variance'),
                    (gap_hours < 1).sum().alias('rapid_actions'),
                    (gap_hours > 24).sum().alias('long_gaps'),
                    
                    # Gap distribution tails
                    gap_hours.quantile(0.9).alias('p90_time_between_actions'),
                    gap_hours.quantile(0.1).alias('p10_time_between_actions'),
                    
                    # Same definition as rapid_actions; kept as its own column for the burst ratio.
                    (gap_hours < 1).sum().alias('actions_in_bursts'),
                    
                    # Overall span, used for velocity
                    (pl.col('timestamp').max() - pl.col('timestamp').min()).dt.total_hours().alias('time_span')
                ]))
            
            # Add derived temporal metrics
            result = df.join(temporal_features, on='fid', how='left').fill_null(0)
            result = result.with_columns([
                # Burst activity ratio
                (pl.col('actions_in_bursts') / (pl.col('total_activity') + 1)).alias('burst_activity_ratio'),
                
                # Activity spread (actual timespan vs. expected even distribution).
                # fill_null() does not clean NaN/inf, so guard the zero-average case
                # (e.g. FIDs with no links, whose nulls were filled with 0 above).
                pl.when(pl.col('avg_hours_between_actions') > 0)
                    .then(pl.col('time_span') / ((pl.col('total_activity') + 1) * pl.col('avg_hours_between_actions')))
                    .otherwise(0.0)
                    .alias('activity_spread'),
                
                # Temporal irregularity (variation in action timing)
                (pl.col('std_hours_between_actions') / (pl.col('avg_hours_between_actions') + 1)).alias('temporal_irregularity'),
                
                # Follow velocity (follows per hour)
                (pl.col('total_activity') / (pl.col('time_span') + 1)).alias('follow_velocity')
            ])
            
            return result.fill_null(0)
                
        except Exception as e:
            print(f"Error in temporal features: {str(e)}")
            raise
            return df.fill_null(0)

    # NOTE(review): add_advanced_temporal_features was disabled by commenting
    # out only its ``def`` line. That left its body as unreachable statements
    # inside add_temporal_features. It is preserved below, fully commented out;
    # restore it as a method if the feature is wanted again.
    # def add_advanced_temporal_features(self, df: pl.DataFrame) -> pl.DataFrame:
    #     """Add advanced temporal features for bot detection"""
    #     try:
    #         activities = []
    #
    #         # Collect cast timestamps
    #         casts = self.loader.get_dataset('casts', ['fid', 'timestamp', 'deleted_at'])
    #         if casts is not None:
    #             valid_casts = casts.filter(pl.col('deleted_at').is_null())
    #             activities.append(valid_casts.select(['fid', 'timestamp']))
    #
    #         # Collect reaction timestamps
    #         reactions = self.loader.get_dataset('reactions', ['fid', 'timestamp', 'deleted_at'])
    #         if reactions is not None:
    #             valid_reactions = reactions.filter(pl.col('deleted_at').is_null())
    #             activities.append(valid_reactions.select(['fid', 'timestamp']))
    #
    #         if not activities:
    #             return df
    #
    #         # Combine all activities
    #         all_activities = pl.concat(activities)
    #
    #         temporal_features = (all_activities
    #             .sort(['fid', 'timestamp'])
    #             .group_by('fid')
    #             .agg([
    #                 # Robotic timing detection
    #                 (pl.col('timestamp').diff().dt.total_seconds().std() < 1)
    #                     .cast(pl.Int32)
    #                     .alias('has_robotic_timing'),
    #
    #                 # Rapid actions
    #                 (pl.col('timestamp').diff().dt.total_seconds() < 2)
    #                     .sum()
    #                     .alias('rapid_action_count'),
    #
    #                 # Activity bursts
    #                 (pl.col('timestamp').diff().dt.total_hours().gt(24).sum())
    #                     .alias('long_dormancy_periods'),
    #
    #                 # Time between bursts
    #                 pl.col('timestamp').diff().dt.total_hours().mean().alias('avg_burst_interval')
    #             ]))
    #
    #         return df.join(temporal_features, on='fid', how='left').fill_null(0)
    #
    #     except Exception as e:
    #         print(f"Error in advanced temporal features: {str(e)}")
    #         raise
        
    def add_power_user_interaction_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add counts and a ratio of each FID's interactions with power users.

        Power users come from the ``warpcast_power_users`` dataset. Adds:

        * ``power_user_replies`` - non-deleted casts whose ``parent_fid`` is a
          power-user FID.
        * ``power_user_mentions`` - non-deleted casts whose ``mentions`` string
          contains any power-user FID as a whole number.
        * ``power_user_reactions`` - non-deleted reactions whose ``target_fid``
          is a power-user FID.
        * ``total_casts`` / ``total_reactions`` - non-deleted cast and reaction
          counts per FID.
        * ``power_user_interaction_ratio`` - (replies + mentions + reactions)
          / (total_casts + total_reactions + 1).

        When no power users are found, only the three power-user counts and
        the ratio are added, as constant 0. Nulls in the output are filled
        with 0; exceptions are printed and re-raised.
        """
        try:
            # Load power users
            power_users = self.loader.get_dataset('warpcast_power_users', ['fid'])
            if power_users is None or len(power_users) == 0:
                print("Warning: No power users found")
                return df.with_columns([
                    pl.lit(0).alias('power_user_replies'),
                    pl.lit(0).alias('power_user_mentions'),
                    pl.lit(0).alias('power_user_reactions'),
                    pl.lit(0).alias('power_user_interaction_ratio')
                ])
            
            # Ensure power_fids are Int64
            power_fids = power_users['fid'].cast(pl.Int64).unique()

            # Match ANY power-user FID, not just the first one, and only as a
            # whole number: a non-digit (or string edge) must surround it, so
            # fid 12 does not match "123".
            fid_alternatives = "|".join(
                str(fid) for fid in power_fids.drop_nulls().to_list()
            )
            if fid_alternatives:
                mentions_power_user = pl.col('mentions').is_not_null() & pl.col('mentions').str.contains(
                    rf"(?:^|\D)(?:{fid_alternatives})(?:\D|$)"
                )
            else:
                # An empty alternation would match every string.
                mentions_power_user = pl.lit(False)
            
            # Get interactions with power users
            casts = self.loader.get_dataset('casts', 
                ['fid', 'parent_fid', 'mentions', 'timestamp', 'deleted_at'])
            
            power_cast_features = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('parent_fid').cast(pl.Int64).is_in(power_fids).alias('is_power_reply'),
                    pl.when(mentions_power_user)
                    .then(1)
                    .otherwise(0)
                    .alias('has_power_mention')
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reply').alias('power_user_replies'),
                    pl.sum('has_power_mention').alias('power_user_mentions'),
                    pl.len().alias('total_casts')
                ])
            )
            
            # Get reaction data
            reactions = self.loader.get_dataset('reactions', 
                ['fid', 'target_fid', 'timestamp', 'deleted_at'])
            
            power_reaction_features = (
                reactions.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('target_fid').cast(pl.Int64).is_in(power_fids).alias('is_power_reaction')
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reaction').alias('power_user_reactions'),
                    pl.len().alias('total_reactions')
                ])
            )
            
            # Join features
            result = df.join(power_cast_features, on='fid', how='left')
            result = result.join(power_reaction_features, on='fid', how='left')
            
            # FIDs without casts/reactions get 0 counts
            result = result.with_columns([
                pl.col('power_user_replies').fill_null(0),
                pl.col('power_user_mentions').fill_null(0),
                pl.col('power_user_reactions').fill_null(0),
                pl.col('total_casts').fill_null(0),
                pl.col('total_reactions').fill_null(0)
            ])
            
            # Calculate overall interaction ratio
            result = result.with_columns([
                ((pl.col('power_user_replies') + 
                pl.col('power_user_mentions') + 
                pl.col('power_user_reactions')) / 
                (pl.col('total_casts') + pl.col('total_reactions') + 1)
                ).alias('power_user_interaction_ratio')
            ])
            
            return result.fill_null(0)
            
        except Exception as e:
            print(f"Error in power user interaction features: {str(e)}")
            raise
            return df.with_columns([
                pl.lit(0).alias('power_user_replies'),
                pl.lit(0).alias('power_user_mentions'),
                pl.lit(0).alias('power_user_reactions'),
                pl.lit(0).alias('power_user_interaction_ratio')
            ])

    def add_activity_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID hour-of-day / weekday activity-distribution features.

        Combines non-deleted ``casts`` and ``reactions`` rows (``fid`` and
        ``timestamp``) and, per FID, adds:

        * ``hour_diversity`` - standard deviation of the number of actions in
          each distinct hour of day the FID was active (0 if undefined).
        * ``weekday_diversity`` - the same, over weekdays.
        * ``total_activities`` - number of combined actions.

        If either dataset is ``None``, or no combined rows remain, the three
        columns are added as 0.0. Otherwise they are left-joined on ``fid``
        and nulls in every column are filled with 0. Exceptions are printed
        and re-raised.
        """
        try:
            print("Processing activity patterns...")
            
            # Get activity data
            casts = self.loader.get_dataset('casts', ['fid', 'timestamp', 'deleted_at'])
            reactions = self.loader.get_dataset('reactions', ['fid', 'timestamp', 'deleted_at'])
            
            # Initialize result with default values
            result = df.with_columns([
                pl.lit(0.0).alias('hour_diversity'),
                pl.lit(0.0).alias('weekday_diversity'),
                pl.lit(0.0).alias('total_activities')
            ])
            
            # Process activities if data exists
            if casts is not None and reactions is not None:
                # Combine valid activities
                activities = pl.concat([
                    casts.filter(pl.col('deleted_at').is_null())
                        .select(['fid', 'timestamp']),
                    reactions.filter(pl.col('deleted_at').is_null())
                        .select(['fid', 'timestamp'])
                ])
                
                if len(activities) > 0:
                    as_datetime = pl.col('timestamp').cast(pl.Datetime)
                    activity_features = (activities
                        .with_columns([
                            as_datetime.dt.hour().alias('hour'),
                            as_datetime.dt.weekday().alias('weekday')
                        ])
                        .group_by('fid')
                        .agg([
                            # value_counts() yields a struct column (value, count), so
                            # std() over it is not the spread of the counts.
                            # unique_counts() yields the plain per-value counts.
                            pl.col('hour').unique_counts()
                                .std().fill_null(0).alias('hour_diversity'),
                            pl.col('weekday').unique_counts()
                                .std().fill_null(0).alias('weekday_diversity'),
                            pl.len().alias('total_activities')
                        ])
                    )
                    
                    # Replace the defaults with the calculated features
                    result = df.join(activity_features, on='fid', how='left').fill_null(0)
            
            print("Activity patterns calculated successfully")
            return result
            
        except Exception as e:
            print(f"Error in activity patterns: {str(e)}")
            raise
            return df.with_columns([
                pl.lit(0.0).alias('hour_diversity'),
                pl.lit(0.0).alias('weekday_diversity'),
                pl.lit(0.0).alias('total_activities')
            ])
    def verify_matrix(self, df: pl.DataFrame):
        """Raise ``ValueError`` unless every column of ``df`` is Float64 or Int64.

        Columns are checked in schema order. A list-typed column is reported
        with a dedicated "still a list type" message; any other non-Float64 /
        non-Int64 column is reported as "not numeric". Returns ``None`` when
        all columns pass.
        """
        allowed_dtypes = [pl.Float64, pl.Int64]
        for name, dtype in df.schema.items():
            # List columns get a more specific message than the generic check below.
            if str(dtype).startswith('List'):
                raise ValueError(f"Column {name} is still a list type: {dtype}")
            if dtype not in allowed_dtypes:
                raise ValueError(f"Column {name} is not numeric: {dtype}")
                
    def add_mentions_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach per-FID mention statistics to ``df``.

        Reads ``casts`` (``fid``, ``mentions``, ``deleted_at``) for the FIDs in
        ``df`` and treats ``mentions`` as a JSON-array string. Null, ``''`` and
        ``'[]'`` count as "no mentions". For non-deleted casts, adds:

        * ``casts_with_mentions`` - casts with a non-empty mentions value.
        * ``total_mentions`` - sum of decoded array lengths.
        * ``avg_mentions_per_cast`` - mean array length over all casts.
        * ``mention_frequency`` - casts_with_mentions / (cast_count + 1).
        * ``mention_ratio`` - avg_mentions_per_cast / (cast_count + 1).

        ``df`` must already contain ``cast_count``. Nulls after the join are
        filled with 0; exceptions are printed and re-raised.
        """
        try:
            base_fids = df['fid']
            print(f"Processing mentions for {len(base_fids)} FIDs")
            
            casts = self.loader.get_dataset('casts', ['fid', 'mentions', 'deleted_at'])
            
            # Filter by base FIDs first
            casts = casts.filter(pl.col('fid').is_in(base_fids))

            mentions_present = (
                pl.col('mentions').is_not_null() &
                (pl.col('mentions') != '') &
                (pl.col('mentions') != '[]')
            )
            # when/then evaluates its `then` branch on every row, so decoding the
            # raw column would also hand '' to the JSON decoder, which can fail.
            # Only rows that look like non-empty arrays are passed to json_decode.
            decodable_mentions = pl.when(mentions_present).then(pl.col('mentions')).otherwise(None)
            
            mention_features = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.when(mentions_present)
                    .then(decodable_mentions.str.json_decode().list.len())
                    .otherwise(0)
                    .alias('mention_count'),
                    
                    # Flag for casts with mentions
                    mentions_present.cast(pl.Int32).alias('has_mentions')
                ])
                .group_by('fid')
                .agg([
                    pl.col('has_mentions').sum().alias('casts_with_mentions'),
                    pl.col('mention_count').sum().alias('total_mentions'),
                    pl.col('mention_count').mean().alias('avg_mentions_per_cast')
                ])
            )
            
            # Join and add ratios
            result = df.join(mention_features, on='fid', how='left', coalesce=True).fill_null(0)
            
            result = result.with_columns([
                (pl.col('casts_with_mentions') / (pl.col('cast_count') + 1)).alias('mention_frequency'),
                (pl.col('avg_mentions_per_cast') / (pl.col('cast_count') + 1)).alias('mention_ratio')
            ])
            
            print(f"Mentions features complete. Shape: {result.shape}")
            self.loader.clear_cache()
            return result
            
        except Exception as e:
            print(f"Error in mentions features: {str(e)}")
            raise
            return df

    def add_reply_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach per-FID reply statistics to ``df``.

        Uses non-deleted ``casts`` rows with a non-null ``parent_hash``
        (replies) and adds:

        * ``total_replies`` - number of replies.
        * ``unique_users_replied_to`` - distinct ``parent_fid`` values.
        * ``avg_seconds_between_replies`` / ``std_seconds_between_replies`` -
          mean and std of gaps between chronologically consecutive replies.
        * ``reply_diversity`` - unique_users_replied_to / total_replies.
        * ``reply_timing_variability`` - std / mean of the gaps, or 0 when the
          mean gap is not positive.

        Features are left-joined on ``fid`` and nulls in every column are
        filled with 0.
        """
        casts = self.loader.get_dataset('casts', 
            ['fid', 'parent_hash', 'parent_fid', 'timestamp', 'deleted_at'])
        
        gap = pl.col('timestamp').diff()
        reply_features = (
            casts.filter(pl.col('deleted_at').is_null())
            .filter(pl.col('parent_hash').is_not_null())
            # Sort so diff() measures chronological gaps (group_by keeps the
            # within-group order); unsorted rows can give a negative mean.
            .sort(['fid', 'timestamp'])
            .group_by('fid')
            .agg([
                pl.len().alias('total_replies'),
                pl.n_unique('parent_fid').alias('unique_users_replied_to'),
                gap.mean().dt.total_seconds()
                    .alias('avg_seconds_between_replies'),
                gap.std().dt.total_seconds()
                    .alias('std_seconds_between_replies')
            ])
            .with_columns([
                (pl.col('unique_users_replied_to') / pl.col('total_replies'))
                    .alias('reply_diversity'),
                # fill_null(0) after the join leaves NaN/inf untouched, so the
                # zero-mean case (e.g. replies sharing one timestamp) is guarded here.
                pl.when(pl.col('avg_seconds_between_replies') > 0)
                    .then(pl.col('std_seconds_between_replies') /
                          pl.col('avg_seconds_between_replies'))
                    .otherwise(0.0)
                    .alias('reply_timing_variability')
            ])
        )
        
        self.loader.clear_cache()
        return df.join(reply_features, on='fid', how='left').fill_null(0)

    # def add_cluster_analysis_features(self, df: pl.DataFrame) -> pl.DataFrame:
    #     """Analyze network clustering with updated functions"""
    #     try:
    #         links = self.loader.get_dataset('links', 
    #             ['fid', 'target_fid', 'deleted_at'])
            
    #         valid_links = links.filter(pl.col('deleted_at').is_null())
            
    #         # Calculate clustering features
    #         cluster_features = (
    #             valid_links.join(
    #                 valid_links.rename({'fid': 'mutual_fid', 'target_fid': 'mutual_target'}),
    #                 left_on='target_fid',
    #                 right_on='mutual_fid'
    #             )
    #             .group_by('fid')
    #             .agg([
    #                 pl.n_unique('mutual_target').alias('mutual_connections'),
    #                 pl.len().alias('potential_triangles')
    #             ])
    #             .with_columns([
    #                 (pl.col('mutual_connections') / (pl.col('potential_triangles') + 1))
    #                 .alias('clustering_coefficient')
    #             ])
    #         )
            
    #         self.loader.clear_cache()
    #         return df.join(cluster_features, on='fid', how='left').fill_null(0)
            
    #     except Exception as e:
    #         print(f"Error in cluster analysis: {str(e)}")
    #         raise
    #         return df.with_columns([
    #             pl.lit(0).alias('mutual_connections'),
    #             pl.lit(0).alias('potential_triangles'),
    #             pl.lit(0.0).alias('clustering_coefficient')
    #         ])

    def add_authenticity_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add a weighted ``authenticity_score`` column to a copy of ``df``.

        Missing input columns are created from ``pl.lit`` defaults (each one
        is announced with a print), and nulls in them are filled with the
        same defaults. The score combines three intermediate terms, which are
        dropped before returning:

        * profile completeness (weight 0.4): mean of ``has_bio``,
          ``has_avatar``, ``has_ens`` and ``verification_count > 0``.
        * network balance (weight 0.3): ``1 - |following - followers| /
          (following + followers)``, or 0 when that sum is not positive.
        * update naturalness (weight 0.3): ``1 - clip(profile_update_consistency,
          0, 1)`` when ``total_updates > 0``, else 0.

        ``total_updates`` and ``profile_update_consistency`` are cast to
        Float64 in the output. Exceptions are printed and re-raised.
        """
        try:
            print("Building authenticity features...")

            column_defaults = {
                'has_bio': 0,
                'has_avatar': 0,
                'verification_count': 0,
                'has_ens': 0,
                'following_count': 0.0,
                'follower_count': 0.0,
                'total_updates': 0,
                'avg_update_interval': 0.0,
                'profile_update_consistency': 0.0
            }

            result = df.clone()
            for name, fallback in column_defaults.items():
                if name not in result.columns:
                    print(f"Adding missing column {name} with default {fallback}")
                    result = result.with_columns(pl.lit(fallback).alias(name))
                # Existing columns may still carry nulls; replace them with the default.
                result = result.with_columns(pl.col(name).fill_null(fallback).alias(name))

            result = result.with_columns(
                pl.col('total_updates').cast(pl.Float64).fill_null(0),
                pl.col('profile_update_consistency').cast(pl.Float64).fill_null(0)
            )

            following = pl.col('following_count').fill_null(0)
            followers = pl.col('follower_count').fill_null(0)
            network_size = following + followers

            completeness = (
                pl.col('has_bio').fill_null(0)
                + pl.col('has_avatar').fill_null(0)
                + pl.col('has_ens').fill_null(0)
                + (pl.col('verification_count').fill_null(0) > 0).cast(pl.Int64)
            ) / 4.0
            balance = (
                pl.when(network_size > 0)
                .then(1.0 - (following - followers).abs() / network_size)
                .otherwise(0.0)
            )
            naturalness = (
                pl.when(pl.col('total_updates') > 0)
                .then(1.0 - pl.col('profile_update_consistency').clip(0.0, 1.0))
                .otherwise(0.0)
            )
            result = result.with_columns([
                completeness.alias('profile_completeness'),
                balance.alias('network_balance'),
                naturalness.alias('update_naturalness')
            ])

            weighted_score = (
                pl.col('profile_completeness').fill_null(0.0) * 0.4
                + pl.col('network_balance').fill_null(0.0) * 0.3
                + pl.col('update_naturalness').fill_null(0.0) * 0.3
            )
            result = result.with_columns([weighted_score.alias('authenticity_score')])

            print("Authenticity features completed successfully")
            return result.drop(['profile_completeness', 'network_balance', 'update_naturalness'])

        except Exception as e:
            print(f"Error in authenticity features: {str(e)}")
            raise
            return df.with_columns(pl.lit(0.0).alias('authenticity_score'))
    def add_update_behavior_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID profile-update timing features to *df*.

        Columns added (always present in the output):
            profile_update_consistency: update_time_std / avg_update_interval
                (0 when the average interval is not positive).
            total_updates: number of non-deleted, timestamped user_data rows.
            avg_update_interval: mean gap between consecutive updates, in hours.
            update_time_std: standard deviation of those gaps, in hours.

        When no usable user_data exists, the columns are filled with their
        defaults (0 / 0.0). Otherwise they are Float64 with 0 for FIDs that
        have no updates.

        Raises:
            Any exception from loading or transforming data is printed and re-raised.
        """
        feature_defaults = {
            'profile_update_consistency': 0.0,
            'total_updates': 0,
            'avg_update_interval': 0.0,
            'update_time_std': 0.0,
        }
        try:
            print("Building update behavior features...")

            # Drop stale copies of the feature columns first. Joining a frame that
            # already holds these names makes polars put the computed values into
            # "_right"-suffixed duplicates and leave the old values in place.
            base = df.drop([c for c in feature_defaults if c in df.columns])

            def with_defaults() -> pl.DataFrame:
                return base.with_columns(
                    [pl.lit(v).alias(k) for k, v in feature_defaults.items()]
                )

            user_data = self.loader.get_dataset('user_data', ['fid', 'timestamp', 'deleted_at'])
            if user_data is None or len(user_data) == 0:
                return with_defaults()

            valid_updates = (user_data
                .filter(pl.col('deleted_at').is_null() & pl.col('timestamp').is_not_null())
                .with_columns(pl.col('timestamp').cast(pl.Datetime)))
            if len(valid_updates) == 0:
                return with_defaults()

            # Hours between consecutive updates. Sorting by (fid, timestamp) first
            # makes diff() measure chronological gaps within each FID.
            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            update_metrics = (valid_updates
                .sort(['fid', 'timestamp'])
                .group_by('fid')
                .agg([
                    pl.len().alias('total_updates'),
                    gap_hours.mean().alias('avg_update_interval'),
                    gap_hours.std().alias('update_time_std'),
                ])
                .with_columns([
                    pl.col(c).cast(pl.Float64).fill_null(0)
                    for c in ('total_updates', 'avg_update_interval', 'update_time_std')
                ])
                .with_columns(
                    pl.when(pl.col('avg_update_interval') > 0)
                    .then(pl.col('update_time_std') / pl.col('avg_update_interval'))
                    .otherwise(0.0)
                    .alias('profile_update_consistency')
                ))

            result = (base
                .join(update_metrics, on='fid', how='left')
                .with_columns([
                    pl.col(c).cast(pl.Float64).fill_null(0) for c in feature_defaults
                ])
                # Same column order as the default-only path.
                .select([*base.columns, *feature_defaults]))

            self.loader.clear_cache()
            print("Update behavior features completed successfully")
            return result

        except Exception as e:
            print(f"Error in update behavior features: {str(e)}")
            raise

    def add_verification_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add timing patterns of on-chain and platform verifications per FID.

        Columns added (always present; nulls replaced by the defaults below):
            avg_hours_between_verifications / std_hours_between_verifications:
                mean / std of hour gaps between consecutive non-deleted
                ``verifications`` rows of a FID.
            rapid_verifications: number of those gaps shorter than 1 hour.
            avg_hours_between_platform_verifs / std_hours_between_platform_verifs:
                mean / std of hour gaps between consecutive
                ``account_verifications.verified_at`` values of a FID.

        Raises:
            Any exception is printed and re-raised.
        """
        default_cols = {
            'avg_hours_between_verifications': 0.0,
            'std_hours_between_verifications': 0.0,
            'rapid_verifications': 0,
            'avg_hours_between_platform_verifs': 0.0,
            'std_hours_between_platform_verifs': 0.0,
        }
        try:
            result = df.clone()

            # On-chain verification patterns
            verifications = self.loader.get_dataset('verifications',
                ['fid', 'timestamp', 'deleted_at'])

            if verifications is not None and len(verifications) > 0:
                valid_verifs = verifications.filter(pl.col('deleted_at').is_null())

                if len(valid_verifs) > 0:
                    # Sort before diff(). Otherwise gaps follow file order and can
                    # be negative, which also inflates rapid_verifications.
                    gap_hours = pl.col('timestamp').diff().dt.total_hours()
                    verif_patterns = (
                        valid_verifs
                        .with_columns(pl.col('timestamp').cast(pl.Datetime))
                        .sort(['fid', 'timestamp'])
                        .group_by('fid')
                        .agg([
                            gap_hours.mean().fill_null(0)
                                .alias('avg_hours_between_verifications'),
                            gap_hours.std().fill_null(0)
                                .alias('std_hours_between_verifications'),
                            (gap_hours < 1).sum().fill_null(0)
                                .alias('rapid_verifications'),
                        ])
                    )
                    result = result.join(verif_patterns, on='fid', how='left')

            # Platform (account) verification patterns
            acc_verifications = self.loader.get_dataset('account_verifications',
                ['fid', 'verified_at'])

            if acc_verifications is not None and len(acc_verifications) > 0:
                platform_gap_hours = pl.col('verified_at').diff().dt.total_hours()
                platform_patterns = (
                    acc_verifications
                    .with_columns(pl.col('verified_at').cast(pl.Datetime))
                    .sort(['fid', 'verified_at'])
                    .group_by('fid')  # one row per fid, so no dedup is needed
                    .agg([
                        platform_gap_hours.mean().fill_null(0)
                            .alias('avg_hours_between_platform_verifs'),
                        platform_gap_hours.std().fill_null(0)
                            .alias('std_hours_between_platform_verifs'),
                    ])
                )
                result = result.join(platform_patterns, on='fid', how='left')

            # Guarantee every feature column exists, with nulls (unmatched FIDs) filled.
            result = result.with_columns([
                pl.col(col).fill_null(default) if col in result.columns
                else pl.lit(default).alias(col)
                for col, default in default_cols.items()
            ])

            self.loader.clear_cache()
            return result

        except Exception as e:
            print(f"Error in verification patterns: {str(e)}")
            raise
    def _validate_required_columns(self, df: pl.DataFrame, required_cols: List[str]):
        """Raise ``ValueError`` naming every entry of *required_cols* absent from *df*.

        Returns None when all required columns are present.
        """
        present = set(df.columns)
        absent = [name for name in required_cols if name not in present]
        if not absent:
            return
        raise ValueError(f"Missing required columns: {absent}")
            
    def _get_feature_build_order(self):
        """Return feature-set names so that each follows its dependencies.

        Performs a depth-first post-order walk over ``self.feature_sets``, in its
        iteration order, following each entry's ``dependencies``. Every name is
        emitted once. A name listed as a dependency but absent from
        ``self.feature_sets`` raises KeyError. Cycles are not reported: a name
        already visited is treated as done, so members of a cycle simply appear
        in DFS finish order.
        """
        ordered = []
        seen = set()

        for root in self.feature_sets:
            if root in seen:
                continue
            seen.add(root)
            # Explicit stack of (name, iterator over its remaining dependencies).
            stack = [(root, iter(self.feature_sets[root].dependencies))]
            while stack:
                node, pending = stack[-1]
                for dep in pending:
                    if dep not in seen:
                        seen.add(dep)
                        stack.append((dep, iter(self.feature_sets[dep].dependencies)))
                        break
                else:
                    # All dependencies of `node` are finished; emit it.
                    stack.pop()
                    ordered.append(node)

        return ordered


    def _validate_feature_addition(self, original_df: pl.DataFrame,
                                new_df: pl.DataFrame,
                                base_fids: pl.Series,
                                feature_name: str) -> pl.DataFrame:
        """Check the frame returned by a feature builder and normalise new columns.

        Args:
            original_df: Frame before the feature step ran.
            new_df: Frame returned by the feature step.
            base_fids: FIDs the feature matrix is restricted to.
            feature_name: Step name, used in messages.

        Returns:
            *new_df* with columns absent from *original_df* cast to Float64 and
            null-filled with 0 (only those ``self._is_numeric_dtype`` accepts).
            If the row count still differs from *original_df* after filtering to
            *base_fids*, *original_df* is returned unchanged and the step's
            features are dropped.

        Raises:
            RuntimeError: if *new_df* is None.
            Any exception raised while casting the new columns is printed and re-raised.
        """
        if new_df is None:
            print(f"Error: {feature_name} returned None")
            # The original bare ``raise`` had no active exception and produced an
            # anonymous RuntimeError; raise the same type with a useful message.
            raise RuntimeError(f"{feature_name} returned None")

        if len(new_df) != len(original_df):
            print(f"Warning: Shape mismatch in {feature_name}. Expected {len(original_df)}, got {len(new_df)}")
            new_df = new_df.filter(pl.col('fid').is_in(base_fids))
            if len(new_df) != len(original_df):
                print(f"Warning: discarding {feature_name} features; "
                      f"{len(new_df)} rows remain after filtering to base FIDs")
                return original_df

        # Cast numeric columns introduced by this step and handle nulls.
        new_cols = [c for c in new_df.columns if c not in original_df.columns]
        if new_cols:
            try:
                new_df = new_df.with_columns([
                    pl.col(c).cast(pl.Float64).fill_null(0)
                    for c in new_cols
                    if self._is_numeric_dtype(new_df[c].dtype)
                ])
            except Exception as e:
                print(f"Error casting columns in {feature_name}: {str(e)}")
                raise

        return new_df


    def add_enhanced_channel_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add channel-follow and channel-membership features per FID.

        Columns added (always present, appended in this order):
            unique_channels_followed: distinct channels with a non-deleted follow.
            rapid_channel_follows: consecutive follows less than 60 s apart.
            channel_follow_hour_std: std of follow counts across the hours of day
                in which the FID followed channels.
            channel_memberships / unique_channel_memberships: non-deleted
                membership rows / distinct channels among them.
            channel_follow_burst_ratio: rapid_channel_follows / (unique_channels_followed + 1).
            channel_engagement_ratio: channel_memberships / (unique_channel_memberships + 1).

        FIDs without data get 0. Only these columns are null-filled; other
        columns of *df* are left untouched.

        Raises:
            Any exception is printed and re-raised.
        """
        raw_defaults = {
            'unique_channels_followed': 0,
            'rapid_channel_follows': 0,
            'channel_follow_hour_std': 0.0,
            'channel_memberships': 0,
            'unique_channel_memberships': 0,
        }
        derived_names = ['channel_follow_burst_ratio', 'channel_engagement_ratio']
        try:
            # Remove stale feature columns so joins cannot collide with them.
            # A collision sends the computed values to "_right" duplicates.
            stale = [c for c in [*raw_defaults, *derived_names] if c in df.columns]
            result = df.drop(stale)
            base_columns = result.columns

            # Channel follows
            channel_follows = self.loader.get_dataset('channel_follows',
                ['fid', 'channel_id', 'timestamp', 'deleted_at'])

            if channel_follows is not None and len(channel_follows) > 0:
                follow_features = (
                    channel_follows
                    .filter(pl.col('deleted_at').is_null())
                    .with_columns(pl.col('timestamp').cast(pl.Datetime))
                    # Chronological order so diff() yields real gaps between follows.
                    .sort(['fid', 'timestamp'])
                    .group_by('fid')
                    .agg([
                        pl.col('channel_id').n_unique().alias('unique_channels_followed'),
                        (pl.col('timestamp').diff().dt.total_seconds() < 60)
                            .sum().alias('rapid_channel_follows'),
                        # unique_counts() gives the follow count per distinct hour.
                        # value_counts() returns a struct, whose std is meaningless.
                        pl.col('timestamp').dt.hour().unique_counts()
                            .cast(pl.Float64).std().alias('channel_follow_hour_std'),
                    ])
                )
                if len(follow_features) > 0:
                    result = result.join(follow_features, on='fid', how='left')

            # Channel memberships
            channel_members = self.loader.get_dataset('channel_members',
                ['fid', 'channel_id', 'deleted_at'])

            if channel_members is not None and len(channel_members) > 0:
                member_features = (
                    channel_members
                    .filter(pl.col('deleted_at').is_null())
                    .group_by('fid')
                    .agg([
                        pl.len().alias('channel_memberships'),
                        pl.col('channel_id').n_unique().alias('unique_channel_memberships'),
                    ])
                )
                if len(member_features) > 0:
                    result = result.join(member_features, on='fid', how='left')

            # Ensure every raw feature exists and fill nulls only in these columns.
            result = result.with_columns([
                pl.col(col).fill_null(default) if col in result.columns
                else pl.lit(default).alias(col)
                for col, default in raw_defaults.items()
            ])

            # Derived ratios; +1 in the denominator avoids division by zero.
            result = result.with_columns([
                (pl.col('rapid_channel_follows') / pl.col('unique_channels_followed').add(1))
                    .alias('channel_follow_burst_ratio'),
                (pl.col('channel_memberships') / pl.col('unique_channel_memberships').add(1))
                    .alias('channel_engagement_ratio'),
            ])

            self.loader.clear_cache()
            return result.select([*base_columns, *raw_defaults, *derived_names])

        except Exception as e:
            print(f"Error in channel features: {str(e)}")
            raise

    def add_engagement_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Derive engagement metrics from activity counts produced by earlier steps.

        Reads ``cast_count``, ``total_reactions`` and ``channel_memberships``. Each
        is created as the constant 0 when absent, or null-filled with 0 when
        present. Then adds:
            engagement_score: mean of the three counts.
            creation_consumption_ratio: cast_count / (total_reactions + 1).

        Raises:
            Any exception is printed and re-raised.
        """
        try:
            print("Processing engagement features...")

            # Input columns and the value used when missing or null.
            inputs = {
                'cast_count': 0,
                'total_reactions': 0,
                'channel_memberships': 0
            }

            enriched = df.clone()
            for name, fallback in inputs.items():
                if name in enriched.columns:
                    expr = pl.col(name).fill_null(fallback)
                else:
                    expr = pl.lit(fallback).alias(name)
                enriched = enriched.with_columns(expr)

            casts = pl.col('cast_count')
            reactions = pl.col('total_reactions')
            memberships = pl.col('channel_memberships')

            score = ((casts + reactions + memberships) / 3.0).alias('engagement_score')
            # +1 keeps the ratio finite for users with no reactions.
            balance = (casts / reactions.add(1)).alias('creation_consumption_ratio')

            return enriched.with_columns([score, balance])

        except Exception as e:
            print(f"Error in engagement features: {str(e)}")
            raise
    def _validate_and_ensure_features(self, df: pl.DataFrame, 
                                required_features: Dict[str, float]) -> pl.DataFrame:
        """Return a copy of *df* in which every key of *required_features* is a column.

        A feature missing from *df* is appended as a constant column of its mapped
        default, and a message is printed. A feature already present has its nulls
        replaced by that default.
        """
        output = df.clone()

        for name, fallback in required_features.items():
            if name in output.columns:
                patched = (
                    pl.when(pl.col(name).is_null())
                    .then(pl.lit(fallback))
                    .otherwise(pl.col(name))
                    .alias(name)
                )
                output = output.with_columns(patched)
                continue

            print(f"Adding missing feature {name} with default value {fallback}")
            output = output.with_columns(pl.lit(fallback).alias(name))

        return output
    def _load_checkpoint(self, feature_set: FeatureSet, base_fids: pl.Series) -> pl.DataFrame:
        """Load a feature-set checkpoint restricted to *base_fids*.

        ``fid`` is cast to Int64. Every numeric column is cast to Float64 with
        nulls filled by 0. Non-numeric columns (lists, nested types, strings,
        temporals) are left as read. Rows whose ``fid`` is not in *base_fids* are
        dropped. For the 'authenticity' and 'update_behavior' sets, the result is
        also passed through ``self._validate_sensitive_checkpoint``.

        Raises:
            Any exception from reading or transforming the checkpoint is printed
            and re-raised.
        """
        try:
            print(f"Loading checkpoint: {feature_set.checkpoint_path}")
            checkpoint_df = pl.read_parquet(feature_set.checkpoint_path)

            # Use the dtype's own numeric test rather than substring-matching its
            # repr. Reprs such as Struct({'x': Int64}), Array(Int64, ...) or Enum
            # categories containing "int" matched before, and casting those to
            # Float64 fails.
            numeric_cols = [
                name for name, dtype in checkpoint_df.schema.items()
                if name != 'fid' and dtype.is_numeric()
            ]
            checkpoint_df = checkpoint_df.with_columns(
                [pl.col('fid').cast(pl.Int64)]
                + [pl.col(c).cast(pl.Float64).fill_null(0) for c in numeric_cols]
            )

            # Debug info
            print(f"Checkpoint fid type: {checkpoint_df['fid'].dtype}")
            print(f"Base fids type: {base_fids.dtype}")

            # Match the checkpoint's Int64 fid so is_in compares like with like.
            base_fids = base_fids.cast(pl.Int64)

            filtered_df = checkpoint_df.filter(pl.col('fid').is_in(base_fids))
            print(f"Filtered checkpoint from {len(checkpoint_df)} to {len(filtered_df)} rows")

            # Special handling for certain feature sets
            if feature_set.name in ['authenticity', 'update_behavior']:
                filtered_df = self._validate_sensitive_checkpoint(filtered_df, feature_set.name)

            return filtered_df

        except Exception as e:
            print(f"Error loading checkpoint: {str(e)}")
            raise

    def _validate_checkpoint_compatibility(self, checkpoint_df: pl.DataFrame, 
                                        base_fids: pl.Series) -> bool:
        """Return True when *checkpoint_df* can replace a freshly built feature set.

        A checkpoint is compatible when:
          * it is non-empty and has a ``fid`` column;
          * every FID in *base_fids* appears in it (compared as Int64);
          * every numeric column can be cast to Float64.
        Non-numeric columns (lists, nested types, strings, temporals) are not
        checked. Any unexpected error is printed and treated as incompatible.
        """
        try:
            if checkpoint_df is None or len(checkpoint_df) == 0:
                return False

            if 'fid' not in checkpoint_df.columns:
                return False

            checkpoint_fids = checkpoint_df['fid'].cast(pl.Int64)
            base_fids = base_fids.cast(pl.Int64)

            # Base FIDs not covered by the checkpoint (unique and sorted), computed
            # natively in polars instead of round-tripping through numpy.
            missing_fids = (base_fids
                .filter(~base_fids.is_in(checkpoint_fids))
                .unique()
                .sort())
            if len(missing_fids) > 0:
                print(f"Missing FIDs in checkpoint: {missing_fids}")
                return False

            for col, dtype in checkpoint_df.schema.items():
                # dtype.is_numeric() rather than matching "int" in the dtype repr.
                # Struct/Array/Enum reprs can contain "int", and their failed
                # Float64 cast would wrongly reject the checkpoint.
                if col == 'fid' or not dtype.is_numeric():
                    continue
                try:
                    checkpoint_df.select(pl.col(col).cast(pl.Float64))
                except Exception as e:
                    print(f"Column {col} failed type validation: {str(e)}")
                    return False

            return True

        except Exception as e:
            print(f"Error validating checkpoint compatibility: {str(e)}")
            return False
    def add_nindexer_enhanced_network_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add follow-graph features from the nindexer ``follows`` / ``follow_counts`` data.

        Columns added (always all seven, Float64, 0 when no data is available):
            network_age_hours: span between a FID's first and last follow timestamp.
            total_follows: number of non-deleted follow rows.
            follow_rate_per_hour: total_follows / (network_age_hours + 1).
            avg_follow_latency_seconds: mean of (created_at - timestamp).
            latest_follower_count / latest_following_count: values from the row
                with the greatest ``created_at`` in ``follow_counts``.
            latest_follow_ratio: latest_follower_count / (latest_following_count + 1).

        The two sources are processed independently, so counts are used even
        when ``follows`` is empty. Only these columns are null-filled.

        Raises:
            Any exception is printed and re-raised.
        """
        feature_cols = [
            'network_age_hours',
            'total_follows',
            'follow_rate_per_hour',
            'avg_follow_latency_seconds',
            'latest_follower_count',
            'latest_following_count',
            'latest_follow_ratio',
        ]
        try:
            follows = self.loader.get_dataset('follows', 
                ['fid', 'target_fid', 'timestamp', 'created_at', 'deleted_at'], 
                source="nindexer")
            follow_counts = self.loader.get_dataset('follow_counts',
                ['fid', 'follower_count', 'following_count', 'created_at'], 
                source="nindexer")

            # Drop stale copies so joins cannot produce "_right" duplicates.
            base = df.drop([c for c in feature_cols if c in df.columns])
            result = base

            if follows is not None and len(follows) > 0:
                follow_metrics = (follows
                    .filter(pl.col('deleted_at').is_null())
                    .with_columns([
                        pl.col('timestamp').cast(pl.Datetime),
                        pl.col('created_at').cast(pl.Datetime)
                    ])
                    .group_by('fid')
                    .agg([
                        (pl.col('timestamp').max() - pl.col('timestamp').min())
                            .dt.total_hours()
                            .cast(pl.Float64)
                            .alias('network_age_hours'),
                        pl.len().alias('total_follows'),
                        (pl.col('created_at') - pl.col('timestamp'))
                            .dt.total_seconds()
                            .mean()
                            .cast(pl.Float64)
                            .alias('avg_follow_latency_seconds')
                    ])
                    .with_columns(
                        (pl.col('total_follows') / (pl.col('network_age_hours') + 1))
                        .alias('follow_rate_per_hour')
                    ))
                result = result.join(follow_metrics, on='fid', how='left')

            if follow_counts is not None and len(follow_counts) > 0:
                count_metrics = (follow_counts
                    # Sorted so last() picks the most recent row within each fid.
                    .sort('created_at')
                    .group_by('fid')
                    .agg([
                        pl.col('follower_count').last().cast(pl.Float64)
                            .alias('latest_follower_count'),
                        pl.col('following_count').last().cast(pl.Float64)
                            .alias('latest_following_count')
                    ])
                    .with_columns(
                        (pl.col('latest_follower_count') /
                         (pl.col('latest_following_count') + 1))
                        .alias('latest_follow_ratio')
                    ))
                result = result.join(count_metrics, on='fid', how='left')

            # Same schema whatever data was available. Fill nulls only in the
            # feature columns, not in unrelated columns of df.
            result = result.with_columns([
                pl.col(c).cast(pl.Float64).fill_null(0.0) if c in result.columns
                else pl.lit(0.0).alias(c)
                for c in feature_cols
            ])
            return result.select([*base.columns, *feature_cols])

        except Exception as e:
            print(f"Error in enhanced network features: {str(e)}")
            raise

    def add_nindexer_enhanced_profile_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add profile-completeness features from the nindexer ``profiles`` data.

        Columns added:
            additional_profile_fields: number of optional fields (url, location)
                present in at least one of the FID's profile rows (0-2, Float64).
            profile_age_hours: span from earliest created_at to latest updated_at.
            location: the first location value seen for the FID.
            has_location_info: 1 when that location is non-null, else 0.

        Numeric columns are always present and null-filled with 0. When no
        profile data exists, only the numeric columns are added, as defaults.

        Raises:
            Any exception is printed and re-raised.
        """
        numeric_defaults = {
            'additional_profile_fields': 0.0,
            'profile_age_hours': 0.0,
            'has_location_info': 0,
        }
        try:
            profiles = self.loader.get_dataset('profiles', 
                ['fid', 'bio', 'pfp_url', 'url', 'username', 
                'location', 'created_at', 'updated_at'], 
                source="nindexer")

            if profiles is None or len(profiles) == 0:
                return df.with_columns([
                    pl.lit(v).alias(k)
                    for k, v in numeric_defaults.items()
                    if k not in df.columns
                ])

            profile_metrics = (profiles
                .with_columns([
                    pl.col('created_at').cast(pl.Datetime),
                    pl.col('updated_at').cast(pl.Datetime),
                    pl.col('url').is_not_null().cast(pl.Int32).alias('has_url'),
                    pl.col('location').is_not_null().cast(pl.Int32).alias('has_location')
                ])
                .group_by('fid')
                .agg([
                    # Reduce each flag to one scalar per fid. The previous
                    # un-aggregated sum produced a list column inside agg().
                    (pl.col('has_url').max() + pl.col('has_location').max())
                        .cast(pl.Float64)
                        .alias('additional_profile_fields'),
                    (pl.col('updated_at').max() - pl.col('created_at').min())
                        .dt.total_hours()
                        .cast(pl.Float64)
                        .alias('profile_age_hours'),
                    pl.col('location')
                        .first()
                        .alias('location')
                ]))

            result = df.join(profile_metrics, on='fid', how='left')

            # Always emit has_location_info so the schema does not depend on
            # whether any location happens to be present.
            result = result.with_columns(
                pl.col('location').is_not_null().cast(pl.Int32).alias('has_location_info')
            )

            # Fill nulls only in this step's numeric features.
            return result.with_columns([
                pl.col(c).fill_null(v) for c, v in numeric_defaults.items()
            ])

        except Exception as e:
            print(f"Error in enhanced profile features: {str(e)}")
            raise
    def add_nindexer_neynar_score_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add Neynar user-score features and their divergence from authenticity_score.

        Columns added:
            neynar_score: score of the FID's most recent row (by created_at).
            avg_neynar_score / neynar_score_std: mean / std over all rows.
            score_trend: latest score minus earliest score.
            score_divergence, relative_score_diff: only when ``authenticity_score``
                is already a column of *df*.

        Null scores are treated as 0 and FIDs without scores get 0. When no score
        data exists, or any error occurs, the four base columns are added as 0.0.
        Unlike sibling steps, errors are printed and swallowed, not re-raised.
        """
        base_score_cols = ['neynar_score', 'avg_neynar_score', 'neynar_score_std', 'score_trend']
        try:
            scores = self.loader.get_dataset('neynar_user_scores', 
                ['fid', 'score', 'created_at'], source="nindexer")

            if scores is None or len(scores) == 0:
                # Same columns as the error path, so the schema does not depend on data.
                return df.with_columns([
                    pl.lit(0.0).alias(c) for c in base_score_cols if c not in df.columns
                ])

            score_features = (scores
                .with_columns([
                    pl.col('created_at').cast(pl.Datetime),
                    pl.col('score').fill_null(0.0)
                ])
                # Chronological order per fid so last()/first() mean latest/earliest.
                .sort(['fid', 'created_at'])
                .group_by('fid')
                .agg([
                    pl.col('score').last().alias('neynar_score'),
                    pl.col('score').mean().alias('avg_neynar_score'),
                    pl.col('score').std().alias('neynar_score_std'),
                    (pl.col('score').last() - pl.col('score').first()).alias('score_trend')
                ]))

            result = df.join(score_features, on='fid', how='left')

            if 'authenticity_score' in result.columns:
                neynar = pl.col('neynar_score').cast(pl.Float64)
                authenticity = pl.col('authenticity_score').cast(pl.Float64)
                result = result.with_columns([
                    (neynar - authenticity).abs().alias('score_divergence'),
                    # 1e-6 guards against division by zero.
                    ((neynar - authenticity) / (authenticity + 1e-6))
                        .alias('relative_score_diff')
                ])

            fill_cols = list(base_score_cols)
            if 'score_divergence' in result.columns:
                fill_cols += ['score_divergence', 'relative_score_diff']
            return result.with_columns([pl.col(c).fill_null(0.0) for c in fill_cols])

        except Exception as e:
            print(f"Error in neynar score features: {str(e)}")
            # Best-effort: return the original frame with default columns.
            return df.with_columns([
                pl.lit(0.0).alias('neynar_score'),
                pl.lit(0.0).alias('avg_neynar_score'),
                pl.lit(0.0).alias('neynar_score_std'),
                pl.lit(0.0).alias('score_trend')
            ])
    def add_name_pattern_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-row suspicious-pattern flags derived from ``fname`` and ``bio``.

        Adds constant 0 placeholder columns random_numbers, wallet_pattern,
        excessive_symbols, airdrop_terms and has_year (not populated here).
        ``fname`` and ``bio`` are each passed through
        ``self._analyze_name_patterns`` into ``*_content_patterns`` columns. The
        per-row values ``<src>_<pattern>`` are then read out as Float64 for
        src in (fname, bio).

        NOTE(review): the ``*_content_patterns`` columns are declared Utf8, yet
        each element is later indexed by a string key. Confirm what
        _analyze_name_patterns returns and which keys (prefixed or not) it uses.

        Raises:
            Any exception is printed and re-raised.
        """
        from operator import itemgetter

        try:
            print("Building name pattern features...")

            result = df.clone().with_columns([
                pl.lit(0).alias('random_numbers'),
                pl.lit(0).alias('wallet_pattern'),
                pl.lit(0).alias('excessive_symbols'),
                pl.lit(0).alias('airdrop_terms'),
                pl.lit(0).alias('has_year')
            ])

            result = result.with_columns([
                pl.col('fname').map_elements(self._analyze_name_patterns, return_dtype=pl.Utf8).alias('fname_content_patterns'),
                pl.col('bio').map_elements(self._analyze_name_patterns, return_dtype=pl.Utf8).alias('bio_content_patterns'),
            ])

            pattern_names = ('random_numbers', 'wallet_pattern', 'excessive_symbols',
                             'airdrop_terms', 'has_year')
            extract_exprs = []
            for source in ('fname', 'bio'):
                for pattern in pattern_names:
                    key = f'{source}_{pattern}'
                    # Keep the value per row. The previous `.sum() / pl.len()`
                    # inside with_columns broadcast one frame-wide mean to every
                    # row, making each feature a constant. itemgetter binds `key`
                    # now, avoiding late-binding closures in this loop.
                    extract_exprs.append(
                        pl.col(f'{source}_content_patterns')
                        .map_elements(itemgetter(key), return_dtype=pl.Float64)
                        .alias(key)
                    )
            result = result.with_columns(extract_exprs)

            return result

        except Exception as e:
            print(f"Error in name pattern features: {str(e)}")
            raise

    def build_feature_matrix(self) -> pl.DataFrame:
        """Build the full feature matrix by chaining every feature set onto the profile base.

        The profile features form the base frame. Every other entry of
        ``feature_sequence`` is either rebuilt with its builder function (when
        ``_needs_rebuild`` says so, or when one of its declared dependencies has
        not been built/loaded in this run) or loaded from its checkpoint, and its
        new columns are left-joined onto the base by ``fid``.

        Returns:
            The joined feature matrix, with nulls filled with 0 and ``fid`` cast
            to Int64.

        Raises:
            Any exception raised while building, loading or joining a feature
            set is printed and re-raised.
        """
        print("Starting feature extraction...")

        try:
            # Load or build profile features (the base every other set joins onto)
            if self._needs_rebuild(self.feature_sets['profile']):
                print("Building profile features...")
                df = self.extract_profile_features()
            else:
                print("Loading profile features from checkpoint...")
                df = pl.read_parquet(self.feature_sets['profile'].checkpoint_path)

            # Setup base configuration: restrict to the base FID set and share it
            # with the loader so later loads filter consistently.
            df = df.with_columns(pl.col('fid').cast(pl.Int64))
            base_fids = df['fid'].cast(pl.Int64).unique()
            df = df.filter(pl.col('fid').is_in(base_fids))
            self.loader.set_base_fids(base_fids)
            print(f"Base shape: {df.shape}")

            # Define dependencies between features
            dependencies = {
                'engagement': ['cast', 'reaction', 'channel'],
                'network_quality': ['network', 'engagement'],
                'activity_patterns': ['temporal', 'cast', 'reaction'],
                'mentions': ['cast'],
                'reply_patterns': ['cast'],
                'update_behavior': ['user_data'],
                'verification_patterns': ['verification'],
                'authenticity': ['profile', 'network', 'verification', 'engagement']
            }

            # Track successfully built (or loaded) features
            built_features = {'profile'}

            feature_sequence = [
                ('network', self.add_network_features),
                ('temporal', self.add_temporal_features),
                ('cast', self.add_cast_behavior_features),
                ('reaction', self.add_reaction_patterns),
                ('channel', self.add_enhanced_channel_features),
                ('user_data', self.add_user_data_features),
                ('verification', self.add_enhanced_verification_features),
                ('engagement', self.add_engagement_features),
                ('network_quality', self.build_network_quality_features),
                ('activity_patterns', self.add_activity_patterns_features),
                ('influence', self.add_influence_features),
                ('mentions', self.add_mentions_features),
                ('reply_patterns', self.add_reply_patterns_features),
                ('power_user_interaction', self.add_power_user_interaction_features),
                # ('cluster_analysis', self.add_cluster_analysis_features),
                ('update_behavior', self.add_update_behavior_features),
                ('verification_patterns', self.add_verification_patterns_features),
                ('authenticity', self.add_authenticity_features),
                ('storage', self.add_storage_features),
                ('derived', self._add_derived_features),
                ('enhanced_network', self.add_nindexer_enhanced_network_features),
                ('enhanced_profile', self.add_nindexer_enhanced_profile_features),
                ('neynar_score', self.add_nindexer_neynar_score_features),
                ('name_patterns', self.add_name_pattern_features),
                # ('content_patterns', self.add_cast_behavior_features),  # Modified version
                # ('advanced_temporal', self.add_advanced_temporal_features),
                # ('reward_gaming', self.add_reward_gaming_features),
                # ('engagement_authenticity', self.add_engagement_authenticity_features)
            ]

            for feature_name, feature_func in feature_sequence:
                feature_set = self.feature_sets[feature_name]
                current_cols = set(df.columns)

                try:
                    # A feature whose dependencies are not available yet is rebuilt
                    should_rebuild = self._needs_rebuild(feature_set)
                    missing_deps = [dep for dep in dependencies.get(feature_name, [])
                                    if dep not in built_features]
                    if missing_deps:
                        print(f"Missing dependencies for {feature_name}: {missing_deps}")
                        print(f"Currently built features: {built_features}")
                        should_rebuild = True

                    if should_rebuild:
                        print(f"Building {feature_name} features...")
                        new_df = feature_func(df)

                        if new_df is not None:
                            # Validate and safely join new features
                            new_df = self._validate_checkpoint(new_df, feature_name)
                            new_df = new_df.with_columns(pl.col('fid').cast(pl.Int64))

                            # Only save and update if validation passes
                            if self._validate_checkpoint_compatibility(new_df, base_fids):
                                self._save_checkpoint(new_df, feature_set)
                                df = self._safe_join_features(df, new_df, feature_name)
                                built_features.add(feature_name)
                                print(f"Successfully built and saved {feature_name}")
                            else:
                                print(f"Skipping {feature_name}: built features failed the compatibility check")
                    else:
                        print(f"Loading {feature_name} features from checkpoint...")
                        checkpoint_df = self._load_checkpoint(feature_set, base_fids)

                        if checkpoint_df is not None:
                            new_cols = [c for c in checkpoint_df.columns if c not in current_cols]
                            if new_cols:
                                print(f"Adding {len(new_cols)} new columns from {feature_name}")
                                # Use safe join for checkpoint data too
                                df = self._safe_join_features(
                                    df,
                                    checkpoint_df.select(['fid'] + new_cols),
                                    feature_name
                                )
                            # A checkpoint whose feature columns are already present
                            # still provides this feature set; marking it built keeps
                            # dependents from being rebuilt needlessly.
                            if any(c != 'fid' for c in checkpoint_df.columns):
                                built_features.add(feature_name)
                        else:
                            print(f"Failed to load {feature_name} checkpoint, forcing rebuild...")
                            new_df = feature_func(df)
                            if new_df is not None:
                                new_df = self._validate_checkpoint(new_df, feature_name)
                                new_df = new_df.with_columns(pl.col('fid').cast(pl.Int64))
                                self._save_checkpoint(new_df, feature_set)
                                df = self._safe_join_features(df, new_df, feature_name)
                                built_features.add(feature_name)

                    print(f"Shape after {feature_name}: {df.shape}")

                except Exception as e:
                    print(f"Error in {feature_name}: {str(e)}")
                    raise

            # Final validation
            df = df.fill_null(0)
            df = df.with_columns(pl.col('fid').cast(pl.Int64))

            # self.verify_matrix(df)

            return df

        except Exception as e:
            print(f"Critical error: {str(e)}")
            raise
        
    def _validate_checkpoint(self, df: pl.DataFrame, name: str) -> pl.DataFrame:
        """Normalise a feature frame's dtypes before it is saved or joined.

        ``fid`` is cast to Int64. Every other column that ``_is_numeric_dtype``
        classifies as numeric (list columns are excluded) is cast to Float64
        with nulls filled with 0. All other columns are left untouched.

        Args:
            df: Feature frame to normalise.
            name: Feature-set name, used only in the error message.

        Returns:
            The normalised frame.

        Raises:
            Re-raises any exception from the casts after printing it.
        """
        try:
            # Always ensure fid is Int64 first
            if 'fid' in df.columns:
                df = df.with_columns(pl.col('fid').cast(pl.Int64))

            # Reuse the shared dtype classifier instead of duplicating its logic
            numeric_cols = [
                col for col, dtype in df.schema.items()
                if col != 'fid' and self._is_numeric_dtype(dtype)
            ]

            if numeric_cols:
                df = df.with_columns([
                    pl.col(col).cast(pl.Float64).fill_null(0)
                    for col in numeric_cols
                ])

            return df

        except Exception as e:
            print(f"Error validating checkpoint {name}: {str(e)}")
            raise

    def _is_numeric_dtype(self, dtype) -> bool:
        """Check if a Polars dtype is numeric, excluding list types"""
        # Convert dtype to string for comparison
        dtype_str = str(dtype).lower()
        # Exclude list types
        if 'list' in dtype_str:
            return False
        return any(num_type in dtype_str 
                for num_type in ['int', 'float', 'decimal'])

    def _safe_join_features(self, df: pl.DataFrame,
                            new_features: pl.DataFrame,
                            feature_name: str) -> pl.DataFrame:
        """Left-join the not-yet-present columns of *new_features* onto *df* by ``fid``.

        Only non-``fid`` columns that *df* does not already have are joined.
        *new_features* is reduced to one row per ``fid`` (the first occurrence)
        before the join, so the left join cannot duplicate rows of *df*. After
        the join, nulls in the added columns are filled with ``[]`` for list
        columns and ``0.0`` for numeric columns.

        Args:
            df: Base frame; it is expected to already have an Int64 ``fid``.
            new_features: Frame providing the candidate columns plus ``fid``.
            feature_name: Feature-set name, used only in log messages.

        Returns:
            The joined frame, or *df* unchanged when there is nothing to add.

        Raises:
            Re-raises any exception from the join after printing it.
        """
        try:
            if new_features is None or len(new_features) == 0:
                print(f"No valid features to join for {feature_name}")
                return df

            existing_cols = set(df.columns)
            new_cols = [c for c in new_features.columns
                        if c != 'fid' and c not in existing_cols]

            if not new_cols:
                print(f"No new columns to add from {feature_name}")
                return df

            # Cast fid before de-duplicating so uniqueness is judged on the same
            # keys the join uses, and keep the first row per fid so the result is
            # deterministic (the default keep='any' is not).
            right = (
                new_features.select(['fid'] + new_cols)
                .with_columns(pl.col('fid').cast(pl.Int64))
                .unique(subset=['fid'], keep='first', maintain_order=True)
            )
            result = df.join(right, on='fid', how='left')

            # One pass fills both nulls that came with new_features and the
            # nulls introduced by fids missing from new_features.
            fill_exprs = []
            for col in new_cols:
                dtype = result.schema[col]
                if str(dtype).lower().startswith('list'):
                    fill_exprs.append(pl.col(col).fill_null([]))
                elif self._is_numeric_dtype(dtype):
                    fill_exprs.append(pl.col(col).fill_null(0.0))
            if fill_exprs:
                result = result.with_columns(fill_exprs)

            return result

        except Exception as e:
            print(f"Error joining {feature_name}: {str(e)}")
            raise

    def _validate_feature_dependencies(self, feature_name: str, 
                                built_features: set) -> bool:
        """Validate feature dependencies are met"""
        if feature_name not in self.feature_sets:
            return False
            
        feature_set = self.feature_sets[feature_name]
        for dep in feature_set.dependencies:
            if dep not in built_features:
                print(f"Missing dependency {dep} for {feature_name}")
                return False
                
        return True
    def _save_checkpoint(self, df: pl.DataFrame, feature_set: 'FeatureSet'):
        """Write *df* to ``feature_set.checkpoint_path`` and record the file's mtime.

        The frame goes through ``_validate_checkpoint`` first, so the parquet
        file always holds the normalised dtypes.
        """
        target_path = feature_set.checkpoint_path
        validated = self._validate_checkpoint(df, feature_set.name)
        validated.write_parquet(target_path)
        # Keep the in-memory feature set in sync with what is on disk
        feature_set.last_modified = os.path.getmtime(target_path)



    def _validate_sensitive_checkpoint(self, df: pl.DataFrame, feature_name: str) -> pl.DataFrame:
        """Guarantee the expected columns of sensitive feature sets exist and are null-free.

        For ``authenticity`` and ``update_behavior``, each expected column has
        its nulls filled with its default. A missing column is created as a
        literal column holding that default. Other feature names pass through
        unchanged.

        Raises:
            Re-raises any exception after printing it.
        """
        try:
            defaults_by_feature = {
                'authenticity': {
                    'authenticity_score': 0.0,
                    'profile_completeness': 0.0,
                    'network_balance': 0.0,
                    'update_naturalness': 0.0
                },
                'update_behavior': {
                    'profile_update_consistency': 0.0,
                    'total_updates': 0,
                    'avg_update_interval': 0.0,
                    'update_time_std': 0.0
                }
            }

            for column, fallback in defaults_by_feature.get(feature_name, {}).items():
                if column in df.columns:
                    expr = pl.col(column).fill_null(fallback)
                else:
                    expr = pl.lit(fallback).alias(column)
                df = df.with_columns(expr)

            return df

        except Exception as e:
            print(f"Error validating sensitive checkpoint {feature_name}: {str(e)}")
            raise
    def _add_derived_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Derive log-scaled, balance and capped follow-graph features from *df*.

        Any missing base column (following/follower counts, follower ratios,
        follow velocity) is created with a 0.0 default, and nulls in those
        columns are filled with that default. The method then adds:

          * ``<col>_log`` = ``log1p`` of follower_ratio, unique_follower_ratio
            and follow_velocity;
          * ``has_more_followers`` = 1 if follower_count > following_count, else 0;
          * ``follow_balance_ratio`` = |following - follower| / (following + follower + 1);
          * ``<col>_capped`` = the same three columns clipped to [0, their 99th percentile].

        Returns:
            A copy of *df* with the base columns ensured and the derived columns added.

        Raises:
            Re-raises any exception after printing it.
        """
        try:
            print("Building derived features...")
            out = df.clone()

            base_defaults = {
                'following_count': 0.0,
                'follower_count': 0.0,
                'follower_ratio': 0.0,
                'unique_follower_ratio': 0.0,
                'follow_velocity': 0.0
            }

            absent = [name for name in base_defaults if name not in out.columns]
            for name in absent:
                print(f"Adding missing column {name} with default {base_defaults[name]}")
            if absent:
                out = out.with_columns([pl.lit(base_defaults[name]).alias(name) for name in absent])

            out = out.with_columns([
                pl.col(name).fill_null(fallback).alias(name)
                for name, fallback in base_defaults.items()
            ])

            ratio_cols = ('follower_ratio', 'unique_follower_ratio', 'follow_velocity')
            followers = pl.col('follower_count').fill_null(0)
            following = pl.col('following_count').fill_null(0)

            out = out.with_columns([
                *[pl.col(name).fill_null(0.0).log1p().alias(f'{name}_log') for name in ratio_cols],
                pl.when(followers > following).then(1).otherwise(0).alias('has_more_followers'),
                ((following - followers).abs() / (following + followers + 1)).alias('follow_balance_ratio'),
            ])

            # Cap each ratio at its own 99th percentile (computed with nulls as 0)
            cap_targets = [name for name in ratio_cols if name in out.columns]
            if cap_targets:
                upper_caps = out.select([
                    pl.col(name).fill_null(0.0).quantile(0.99).alias(name)
                    for name in cap_targets
                ]).row(0)
                out = out.with_columns([
                    pl.col(name).fill_null(0.0).clip(0.0, cap).alias(f'{name}_capped')
                    for name, cap in zip(cap_targets, upper_caps)
                ])

            print("Derived features completed successfully")
            return out

        except Exception as e:
            print(f"Error in derived features: {str(e)}")
            raise

In [None]:
# notebook code
import optuna
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    precision_recall_curve, precision_score, recall_score,
    roc_auc_score
)
import xgboost as xgb
from lightgbm import LGBMClassifier
import shap
from scipy import stats
import numpy as np
from typing import Dict, List, Tuple
import polars as pl

class SybilDetectionSystem:
    def __init__(self,
                 feature_engineering: 'FeatureEngineering',
                 confidence_thresholds: Dict[str, float] = None,
                 authenticity_thresholds: Dict[str, float] = None):
        """Create an untrained detector around a feature-engineering pipeline.

        Args:
            feature_engineering: Pipeline object stored for later use.
            confidence_thresholds: 'high'/'medium'/'low' cut-offs; a falsy value
                (None or an empty dict) selects 0.95 / 0.85 / 0.70.
            authenticity_thresholds: 'high'/'medium'/'low' cut-offs; a falsy
                value selects 0.8 / 0.6 / 0.4.
        """
        default_confidence = {'high': 0.95, 'medium': 0.85, 'low': 0.70}
        default_authenticity = {'high': 0.8, 'medium': 0.6, 'low': 0.4}

        self.feature_engineering = feature_engineering

        # Model state, populated by train()
        self.model = None
        self.feature_names = None
        self.scaler = StandardScaler()
        self.base_models = {}

        self.confidence_thresholds = (confidence_thresholds
                                      if confidence_thresholds else default_confidence)
        self.authenticity_thresholds = (authenticity_thresholds
                                        if authenticity_thresholds else default_authenticity)

        # Explainability caches
        self.feature_importance = {}
        self.shap_values = {}
        self.shap_explainers = {}
            
    def prepare_features(self, df: pl.DataFrame, scale: bool = False, *,
                         exclude_cols: List[str] = None) -> Tuple[np.ndarray, List[str]]:
        """Turn the Float*/Int* columns of *df* into a cleaned numpy feature matrix.

        For every selected column:
          1. nulls become 0, then +/-inf become null;
          2. values are clipped to [max(q01, mean - 3*std), min(q99, mean + 3*std)],
             with the statistics computed on the finite values only;
          3. the remaining nulls (the former infinities) get the column median
             (0 if the column has no finite values);
          4. the column is cast to Float64 unless it is already Float64/Int64.

        Args:
            df: Input frame. Columns whose dtype name starts with ``Float`` or
                ``Int`` are used; unsigned, decimal, list and other dtypes are ignored.
            scale: When True, ``self.scaler`` is fit on and applied to the matrix.
            exclude_cols: Optional names to drop from the selection (e.g. the
                ``fid`` identifier or a label column). The default excludes
                nothing, matching the previous behaviour.

        Returns:
            ``(feature_array, feature_names)``.

        Raises:
            ValueError: if infinite values survive preprocessing. Any other
            error is printed together with the available columns and re-raised.
        """
        try:
            excluded = set(exclude_cols or ())
            # NOTE(review): without exclude_cols an Int64 identifier such as
            # 'fid' is selected as a model feature — confirm with callers.
            valid_cols = [
                col for col, dtype in df.schema.items()
                if col not in excluded and str(dtype).startswith(('Float', 'Int'))
            ]

            print(f"\nTotal numeric features available: {len(valid_cols)}")

            # Only Float*/Int* columns are selected above, so no list columns
            # can reach this point.
            features = df.select(valid_cols).fill_null(0)

            for col in valid_cols:
                # Replace infinities *before* computing statistics: a single inf
                # makes the mean infinite and the std NaN.
                features = features.with_columns(
                    pl.when(pl.col(col).is_infinite())
                    .then(pl.lit(None))
                    .otherwise(pl.col(col))
                    .alias(col)
                )

                q01, q99, mean_val, std_val = features.select(
                    pl.col(col).quantile(0.01).alias('q01'),
                    pl.col(col).quantile(0.99).alias('q99'),
                    pl.col(col).mean().alias('mean'),
                    pl.col(col).std().alias('std'),
                ).row(0)

                # Use the tighter of the percentile and 3-sigma bounds. std is null
                # for a single value; every statistic is null when no finite
                # values remain.
                lower_bound, upper_bound = q01, q99
                if mean_val is not None and std_val is not None:
                    lower_bound = max(q01, mean_val - 3 * std_val)
                    upper_bound = min(q99, mean_val + 3 * std_val)
                if lower_bound is not None and upper_bound is not None:
                    features = features.with_columns(
                        pl.col(col).clip(lower_bound, upper_bound).alias(col)
                    )

                # Fill remaining nulls (former infinities) with the median; the
                # original `select(...)[0][0]` indexed rows, not a scalar.
                median_val = features.select(pl.col(col).median()).item()
                features = features.with_columns(
                    pl.col(col).fill_null(median_val if median_val is not None else 0).alias(col)
                )

                # Convert to numeric if needed
                if features[col].dtype not in [pl.Float64, pl.Int64]:
                    features = features.with_columns(
                        pl.col(col).cast(pl.Float64).alias(col)
                    )

            feature_array = features.to_numpy()

            if scale:
                # NOTE(review): the scaler is re-fit on every call, including
                # calls on evaluation data — confirm this is intended.
                feature_array = self.scaler.fit_transform(feature_array)

            print(f"\nFinal feature matrix shape: {feature_array.shape}")
            print(f"Using {len(valid_cols)} features")

            # Verify no infinite values remain
            if np.any(np.isinf(feature_array)):
                raise ValueError("Infinite values still present after preprocessing")

            return feature_array, valid_cols

        except Exception as e:
            print(f"Error preparing features: {str(e)}")
            print(f"Available columns: {df.columns}")
            raise
    def train(self, X_train: np.ndarray, y_train: np.ndarray, feature_names: List[str]):
        """Tune, calibrate and ensemble the base models, then report stability.

        For each base model (XGBoost, RandomForest, LightGBM), Optuna runs
        ``get_hyperparameters`` for up to 50 trials or 600 s, scoring mean 5-fold
        stratified average precision. The best configuration, on top of the
        base settings below, is then refit on all of ``X_train``. That model is
        wrapped in a SHAP ``TreeExplainer`` and in a 5-fold
        ``CalibratedClassifierCV``. A stacked meta-learner and a soft-voting
        ensemble over the three calibrated models plus the meta-learner are
        then fit.

        Args:
            X_train: Training feature matrix.
            y_train: Binary training labels.
            feature_names: Column names of ``X_train`` (stored for explanations).

        Returns:
            ``self``.
        """
        self.feature_names = feature_names
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

        # Base settings; Optuna only tunes the keys returned by
        # get_hyperparameters, so these must carry over to the final model.
        base_model_configs = {
            'xgb': xgb.XGBClassifier(eval_metric='auc', random_state=42),
            'rf': RandomForestClassifier(n_jobs=-1, random_state=42, class_weight='balanced'),
            'lgbm': LGBMClassifier(n_jobs=-1, random_state=42, class_weight='balanced')
        }

        # Train and calibrate base models
        for name, model in base_model_configs.items():
            print(f"\nStarting training for {name}...")
            study = optuna.create_study(direction='maximize', study_name=f'optuna_{name}')

            def objective(trial):
                params = self.get_hyperparameters(name, trial)
                model.set_params(**params)
                cv_scores = cross_val_score(model, X_train, y_train, cv=cv, scoring='average_precision')
                return cv_scores.mean()

            study.optimize(objective, n_trials=50, timeout=600)
            print(f"Best parameters for {name}: {study.best_params}")
            print(f"Best CV score for {name}: {study.best_value}")

            # Rebuild from the base settings plus the best trial. The previous
            # `type(model)(**study.best_params)` dropped random_state,
            # eval_metric, n_jobs and class_weight, so the refit model differed
            # from the configuration that was scored during tuning.
            best_model = type(model)(**{**model.get_params(deep=False), **study.best_params})
            best_model.fit(X_train, y_train)

            # Store SHAP explainer (on the uncalibrated, fitted model)
            try:
                explainer = shap.TreeExplainer(best_model)
                self.shap_explainers[name] = explainer
                print(f"SHAP explainer created for {name}")
            except Exception as e:
                print(f"Error creating SHAP explainer for {name}: {str(e)}")
                self.shap_explainers[name] = None

            # Calibrate model
            print(f"Calibrating {name}...")
            calibrated_model = CalibratedClassifierCV(best_model, cv=5)
            calibrated_model.fit(X_train, y_train)
            self.base_models[name] = calibrated_model

        # Build stacked model
        print("\nBuilding stacked model...")
        self.build_stacked_model(X_train, y_train)

        # Create final ensemble. VotingClassifier clones and refits every
        # estimator on X_train, so the meta-learner here is refit on the raw
        # features, not on stacked predictions.
        print("\nCreating ensemble...")
        self.model = VotingClassifier(
            estimators=[
                (name, model) for name, model in self.base_models.items()
            ] + [('meta_learner', self.meta_learner)],
            voting='soft',
            weights=[0.25, 0.25, 0.25, 0.25]
        )
        self.model.fit(X_train, y_train)
        print("Ensemble training complete")

        # Get stability metrics (was the global `detector`, which only worked
        # when a notebook variable of that name happened to exist)
        stability_results = self.add_cross_validation_stability(X_train, y_train)
        print("\nCross-validation Stability Metrics:")
        print(f"Mean prediction variance: {stability_results['mean_prediction_variance']:.4f}")
        print(f"Max prediction variance: {stability_results['max_prediction_variance']:.4f}")
        print(f"Mean prediction range: {stability_results['mean_prediction_range']:.4f}")
        print(f"Percentage of stable predictions: {stability_results['stable_prediction_percentage']:.2%}")

        # NOTE(review): this split is drawn from the data every model was just
        # trained on, so the weight optimisation and the "final evaluation"
        # below are in-sample. The 80% part is named X_val and the 20% part
        # X_test. A genuine hold-out split before training is needed for
        # unbiased numbers.
        X_val, X_test, y_val, y_test = train_test_split(
            X_train, y_train, test_size=0.2, random_state=42
        )

        # Optimize ensemble weights
        weights = self.optimize_ensemble_weights(X_val, y_val)
        print("Optimized model weights:", weights)

        # Final evaluation
        final_predictions, unstable_indices = self.predict_with_stability(X_test)
        print(f"Number of unstable predictions: {len(unstable_indices)}")

        return self

    def build_stacked_model(self, X: np.ndarray, y: np.ndarray):
        """Fit a LightGBM meta-learner on base-model probabilities plus the raw features.

        The meta feature matrix is ``[p_model_1, ..., p_model_k, X]``, where each
        ``p_model_i`` is the positive-class probability from a fitted model in
        ``self.base_models``. The fitted learner is stored in ``self.meta_learner``.

        NOTE(review): another ``build_stacked_model`` is defined further down in
        this class and replaces this one when the class is created. The base
        probabilities are also in-sample, because the base models were fit on
        the same X.
        """
        probability_columns = [
            np.asarray(fitted.predict_proba(X)[:, 1], dtype=np.float64)
            for fitted in self.base_models.values()
        ]
        stacked_inputs = np.column_stack(probability_columns + [X])

        learner = LGBMClassifier(
            n_estimators=100,
            learning_rate=0.01,
            max_depth=3,
            num_leaves=8,
            feature_fraction=0.8,
            bagging_fraction=0.8,
            random_state=42
        )
        learner.fit(stacked_inputs, y)
        self.meta_learner = learner

    def get_feature_explanations(self, model_name: str, X: np.ndarray, instance_index: int,
                                 top_k: int = 5) -> Dict:
        """Return the strongest SHAP contributions for one row of *X*.

        Args:
            model_name: Key into ``self.shap_explainers`` (the names used in
                ``train`` are 'xgb', 'rf' and 'lgbm').
            X: Feature matrix in the column order of ``self.feature_names``.
            instance_index: Row of *X* to explain.
            top_k: How many features to return (default 5, as before).

        Returns:
            ``{feature_name: shap_value}`` for the *top_k* features with the
            largest absolute SHAP value, largest first. Returns ``{}`` for an
            out-of-range index, a missing explainer, or any SHAP error (which
            is printed).
        """
        try:
            if instance_index < 0 or instance_index >= X.shape[0]:
                print(f"Invalid instance index: {instance_index}")
                return {}

            explainer = self.shap_explainers.get(model_name)
            if explainer is None:
                print(f"No SHAP explainer available for {model_name}")
                return {}

            shap_values = explainer.shap_values(X[instance_index:instance_index + 1])

            # Older SHAP returns one array per class as a list.
            if isinstance(shap_values, list):
                shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
            shap_values = np.asarray(shap_values)
            # Some SHAP versions return (rows, features, classes) instead; keep
            # the positive class so the per-feature vector stays 1-D.
            if shap_values.ndim == 3:
                shap_values = shap_values[..., 1] if shap_values.shape[-1] > 1 else shap_values[..., 0]

            shap_instance = shap_values[0]
            top_indices = np.argsort(np.abs(shap_instance))[::-1][:max(top_k, 0)]

            return {
                self.feature_names[i]: float(shap_instance[i])
                for i in top_indices
            }

        except Exception as e:
            print(f"Error getting SHAP explanations: {str(e)}")
            return {}

    def _calculate_feature_importance(self):
        """Average per-model feature importances into ``self.feature_importance``.

        Each base model's importances are obtained via
        ``_model_feature_importances``. They are normalised to sum to 1, so
        LightGBM split counts and impurity-based scores share a scale. The
        result is the mean over the models that provided importances, and it
        replaces (rather than adds to) any previous ``self.feature_importance``.
        Errors are printed, not raised.
        """
        try:
            totals: Dict[str, float] = {}
            contributing = 0

            for name, model in self.base_models.items():
                importances = self._model_feature_importances(model)
                if importances is None:
                    print(f"No feature_importances_ attribute for model {name}.")
                    continue

                total = float(np.sum(importances))
                if total > 0:
                    importances = importances / total
                for feat, imp in zip(self.feature_names, importances):
                    totals[feat] = totals.get(feat, 0.0) + float(imp)
                contributing += 1

            if contributing:
                self.feature_importance = {k: v / contributing for k, v in totals.items()}
                print("Feature importance calculated.")
            else:
                print("No models available to calculate feature importance.")
        except Exception as e:
            print(f"Error calculating feature importance: {str(e)}")

    @staticmethod
    def _model_feature_importances(model):
        """Return *model*'s feature importances as a float array, or None if unavailable.

        ``CalibratedClassifierCV`` has no ``base_estimator_``. Its fitted
        per-fold models live in ``calibrated_classifiers_`` (attribute
        ``estimator``, or ``base_estimator`` in older scikit-learn). Their
        importances are averaged. If none are available, the estimator the
        wrapper was constructed with is tried.
        """
        if not isinstance(model, CalibratedClassifierCV):
            importances = getattr(model, 'feature_importances_', None)
            return None if importances is None else np.asarray(importances, dtype=float)

        per_fold = []
        for calibrated in getattr(model, 'calibrated_classifiers_', None) or []:
            fitted = getattr(calibrated, 'estimator', None)
            if fitted is None:
                fitted = getattr(calibrated, 'base_estimator', None)
            # Unfitted estimators raise NotFittedError (an AttributeError), so
            # getattr's default covers them too.
            importances = getattr(fitted, 'feature_importances_', None)
            if importances is not None:
                per_fold.append(np.asarray(importances, dtype=float))
        if per_fold:
            return np.mean(per_fold, axis=0)

        for attr in ('estimator', 'base_estimator'):
            importances = getattr(getattr(model, attr, None), 'feature_importances_', None)
            if importances is not None:
                return np.asarray(importances, dtype=float)
        return None
    def select_important_features(self, X: np.ndarray, y: np.ndarray, feature_names: List[str],
                                  threshold: float = 0.01) -> List[str]:
        """Keep features whose mean |SHAP| exceeds *threshold* times the largest mean |SHAP|.

        A 100-tree XGBoost classifier (random_state=42) is fit on (X, y) and
        explained with ``shap.TreeExplainer``. When SHAP returns a per-class
        list, the class-1 array is used.

        Returns:
            The selected feature names, in the order of *feature_names*.
        """
        probe = xgb.XGBClassifier(n_estimators=100, random_state=42)
        probe.fit(X, y)

        raw_shap = shap.TreeExplainer(probe).shap_values(X)
        per_sample = raw_shap[1] if isinstance(raw_shap, list) else raw_shap
        mean_abs = np.abs(per_sample).mean(0)

        scores = dict(zip(feature_names, mean_abs))
        cutoff = threshold * np.max(mean_abs)
        kept = [feat for feat, score in scores.items() if score > cutoff]

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        print(f"\nSelected {len(kept)}/{len(feature_names)} features")
        print("Top 10 features:", ranked[:10])
        return kept

    def get_hyperparameters(self, model_name: str, trial: optuna.Trial) -> Dict:
        """Sample one hyperparameter configuration for *model_name* from *trial*.

        'xgb' and 'lgbm' have dedicated search spaces; any other name uses the
        RandomForest space. Each space keeps a fixed order of ``suggest_*`` calls.
        """
        suggest_int = trial.suggest_int
        suggest_float = trial.suggest_float

        if model_name == 'xgb':
            return dict(
                max_depth=suggest_int('max_depth', 3, 7),
                learning_rate=suggest_float('learning_rate', 0.01, 0.1),
                n_estimators=suggest_int('n_estimators', 100, 500),
                min_child_weight=suggest_int('min_child_weight', 1, 7),
                subsample=suggest_float('subsample', 0.6, 1.0),
                colsample_bytree=suggest_float('colsample_bytree', 0.6, 1.0),
                reg_alpha=suggest_float('reg_alpha', 0, 10),
                reg_lambda=suggest_float('reg_lambda', 1, 10),
                scale_pos_weight=suggest_float('scale_pos_weight', 1.0, 10.0),
                gamma=suggest_float('gamma', 0, 5),
            )

        if model_name == 'lgbm':
            return dict(
                n_estimators=suggest_int('n_estimators', 100, 500),
                max_depth=suggest_int('max_depth', 3, 7),
                learning_rate=suggest_float('learning_rate', 0.01, 0.1),
                num_leaves=suggest_int('num_leaves', 20, 100),
                feature_fraction=suggest_float('feature_fraction', 0.6, 1.0),
                bagging_fraction=suggest_float('bagging_fraction', 0.6, 1.0),
                min_child_samples=suggest_int('min_child_samples', 5, 30),
                lambda_l1=suggest_float('lambda_l1', 0, 10),
                lambda_l2=suggest_float('lambda_l2', 0, 10),
            )

        # RandomForest (fallback for every other model name)
        return dict(
            n_estimators=suggest_int('n_estimators', 100, 500),
            max_depth=suggest_int('max_depth', 3, 7),
            min_samples_split=suggest_int('min_samples_split', 2, 10),
            min_samples_leaf=suggest_int('min_samples_leaf', 1, 5),
            max_features=trial.suggest_categorical('max_features', ['sqrt', 'log2']),
        )

    def analyze_feature_interactions(self, X: np.ndarray, top_k: int = 10) -> List[Tuple[str, str, float]]:
        """Rank pairwise feature interactions of the 'xgb' base model by SHAP strength.

        Args:
            X: Feature matrix whose columns follow ``self.feature_names``.
            top_k: Number of strongest pairs to return.

        Returns:
            Up to ``top_k`` ``(feature_a, feature_b, strength)`` tuples, strongest
            first, where strength is the mean absolute SHAP interaction value of
            the pair over the rows of ``X``.
        """
        model = self.base_models['xgb']
        if hasattr(model, 'calibrated_classifiers_'):
            # CalibratedClassifierCV keeps its fitted models per fold; TreeExplainer
            # needs the raw tree model. scikit-learn 1.2 renamed the attribute
            # base_estimator -> estimator, so accept either.
            calibrated = model.calibrated_classifiers_[0]
            inner = getattr(calibrated, 'estimator', None)
            if inner is None:
                inner = getattr(calibrated, 'base_estimator', None)
            if inner is not None:
                model = inner
        elif hasattr(model, 'base_estimator_'):
            model = model.base_estimator_

        explainer = shap.TreeExplainer(model)
        interaction_values = explainer.shap_interaction_values(X)
        if isinstance(interaction_values, list):
            # Some SHAP versions return one array per class; use the positive class.
            interaction_values = (interaction_values[1] if len(interaction_values) > 1
                                  else interaction_values[0])
        interaction_values = np.asarray(interaction_values)

        n_features = len(self.feature_names)
        interactions = []
        for i in range(n_features):
            # Mean |interaction| of feature i with every later feature, one row at a
            # time so only an (n_samples, n_features) temporary is allocated.
            row_strengths = np.abs(interaction_values[:, i, i + 1:]).mean(axis=0)
            for j, strength in enumerate(row_strengths, start=i + 1):
                interactions.append((self.feature_names[i], self.feature_names[j], float(strength)))

        return sorted(interactions, key=lambda item: item[2], reverse=True)[:top_k]

    def build_stacked_model(self, X: np.ndarray, y: np.ndarray):
        """Fit a LightGBM meta-learner on base-model probabilities plus the raw features.

        The meta-learner input is ``[p_1, ..., p_k, X]``, where each ``p`` is a base
        model's positive-class probability (``self.base_models`` in dict order).
        The fitted learner is stored in ``self.meta_learner``.

        NOTE(review): the base models score the same ``X`` the meta-learner is
        trained on. If they were fitted on it too, these are in-sample
        predictions and the meta-learner will over-trust them. Out-of-fold
        predictions (e.g. ``cross_val_predict``) would avoid that leakage.
        """
        # One row per base model, one column per sample.
        base_preds = np.zeros((len(self.base_models), len(X)))
        for row, model in enumerate(self.base_models.values()):
            base_preds[row] = model.predict_proba(X)[:, 1]

        meta_features = np.column_stack([
            base_preds.T,  # Base predictions
            X  # Original features
        ])

        meta_learner = LGBMClassifier(
            n_estimators=100,
            learning_rate=0.01,
            max_depth=3,
            num_leaves=8,
            feature_fraction=0.8,
            bagging_fraction=0.8,
            # bagging_fraction is ignored by LightGBM unless bagging_freq > 0.
            bagging_freq=1,
            random_state=42
        )

        meta_learner.fit(meta_features, y)
        self.meta_learner = meta_learner
    def predict_with_uncertainty(self, features: np.ndarray,
                                 authenticity_features: np.ndarray) -> List[Dict]:
        """Score each row with the base-model ensemble and attach uncertainty estimates.

        Args:
            features: Model input matrix, one row per account.
            authenticity_features: Row-aligned matrix. Its first four columns
                (authenticity, engagement quality, natural behaviour, account
                stability) are averaged into the authenticity score.

        Returns:
            One dict per row with ``is_bot``, ``is_authentic``, ``bot_probability``
            (mean base-model probability), ``authenticity_score``, ``confidence``
            (from ``_assess_confidence``), ``uncertainty`` (std of the base-model
            probabilities) and a 95% normal ``prediction_interval`` clipped to [0, 1].

        Raises:
            ValueError: if the two inputs have different numbers of rows.
        """
        from statistics import NormalDist

        if len(features) != len(authenticity_features):
            raise ValueError(
                f"features has {len(features)} rows but authenticity_features has "
                f"{len(authenticity_features)}; they must be row-aligned"
            )

        # Positive-class probability from every base model: (n_models, n_rows).
        predictions = np.array([
            model.predict_proba(features)[:, 1] for model in self.base_models.values()
        ])
        mean_probs = predictions.mean(axis=0)
        std_probs = predictions.std(axis=0)

        # 95% normal interval mean +/- z*std, computed directly so zero spread (all
        # models agree) yields a point interval instead of NaN, and clipped because
        # a probability cannot leave [0, 1].
        z = NormalDist().inv_cdf(0.975)
        lower = np.clip(mean_probs - z * std_probs, 0.0, 1.0)
        upper = np.clip(mean_probs + z * std_probs, 0.0, 1.0)

        results = []
        for i, (prob, std, auth_scores) in enumerate(zip(mean_probs, std_probs, authenticity_features)):
            authenticity_score = np.mean([
                auth_scores[0],  # authenticity_score
                auth_scores[1],  # engagement_quality
                auth_scores[2],  # natural_behavior_score
                auth_scores[3]   # account_stability
            ])

            # Spread relative to the mean probability (absolute spread when prob is 0).
            model_uncertainty = std / prob if prob > 0 else std
            confidence = self._assess_confidence(prob, authenticity_score, model_uncertainty)

            results.append({
                'is_bot': bool(prob >= 0.5),
                'is_authentic': bool(authenticity_score >= self.authenticity_thresholds['medium']),
                'bot_probability': float(prob),
                'authenticity_score': float(authenticity_score),
                'confidence': confidence,
                'uncertainty': float(std),
                'prediction_interval': (float(lower[i]), float(upper[i]))
            })

        return results

    def _assess_confidence(self, prob: float, authenticity: float, 
                          uncertainty: float) -> str:
        """Enhanced confidence assessment with uncertainty consideration"""
        # Adjust thresholds based on uncertainty
        uncertainty_penalty = uncertainty * 2
        
        if prob <= 0.1 and authenticity >= self.authenticity_thresholds['high'] and uncertainty < 0.1:
            return 'high_authentic'
        elif prob <= 0.2 and authenticity >= self.authenticity_thresholds['medium'] and uncertainty < 0.15:
            return 'medium_authentic'
        elif prob >= (self.confidence_thresholds['high'] + uncertainty_penalty):
            return 'high_bot'
        elif prob >= (self.confidence_thresholds['medium'] + uncertainty_penalty):
            return 'medium_bot'
        else:
            return 'uncertain'
    def get_feature_explanations(self, model_name: str, X: np.ndarray, instance_index: int,
                                 top_n: int = 5) -> Dict:
        """Explain one prediction with SHAP values from a base model's underlying tree model.

        Args:
            model_name: Key into ``self.base_models``.
            X: Feature matrix; columns follow ``self.feature_names``.
            instance_index: Row of ``X`` to explain.
            top_n: Number of largest-magnitude contributions to return
                (default 5, the previously fixed value).

        Returns:
            ``{feature_name: shap_value}`` for the ``top_n`` features with the
            largest absolute SHAP value (positive class when SHAP reports
            per-class values), largest first. Returns ``{}`` when the index or
            model name is invalid or SHAP fails; the reason is printed.
        """
        try:
            if instance_index < 0 or instance_index >= X.shape[0]:
                print(f"Error: instance_index {instance_index} is out of bounds for test set with size {X.shape[0]}.")
                return {}

            if model_name not in self.base_models:
                print(f"No model available for {model_name}.")
                return {}

            model = self.base_models[model_name]

            if hasattr(model, 'calibrated_classifiers_'):
                # CalibratedClassifierCV: explain the first fold's fitted estimator.
                # scikit-learn 1.2 renamed its attribute base_estimator -> estimator.
                calibrated = model.calibrated_classifiers_[0]
                base_model = getattr(calibrated, 'estimator', None)
                if base_model is None:
                    base_model = calibrated.base_estimator
                print(f"Using base estimator from calibrated classifier for {model_name}")
            else:
                base_model = model
                print(f"Using model directly for {model_name}")

            try:
                print(f"Creating SHAP explainer for model type: {type(base_model)}")
                explainer = shap.TreeExplainer(base_model)

                # Explain a single row, keeping the 2-D shape SHAP expects.
                instance_data = X[instance_index:instance_index + 1]
                print(f"Calculating SHAP values for instance shape: {instance_data.shape}")

                shap_vals = explainer.shap_values(instance_data)

                if isinstance(shap_vals, list):
                    # Older SHAP: one (n_samples, n_features) array per class.
                    per_class = shap_vals[1] if len(shap_vals) > 1 else shap_vals[0]
                    shap_instance = np.asarray(per_class)[0]
                else:
                    shap_instance = np.asarray(shap_vals)[0]
                    if shap_instance.ndim == 2:
                        # Newer SHAP may return (n_samples, n_features, n_classes);
                        # keep the last (positive) class column.
                        shap_instance = shap_instance[:, -1]

                # Indices of the largest |SHAP| values, largest first.
                top_indices = np.argsort(np.abs(shap_instance))[::-1][:max(top_n, 0)]
                return {
                    self.feature_names[i]: float(shap_instance[i])
                    for i in top_indices
                }

            except Exception as e:
                print(f"Error calculating SHAP values for {model_name}: {str(e)}")
                print(f"Model type: {type(base_model)}")
                return {}

        except Exception as e:
            print(f"Error getting feature explanations: {str(e)}")
            return {}
    def add_cross_validation_stability(self, X: np.ndarray, y: np.ndarray, n_splits: int = 5) -> Dict[str, float]:
        """
        Measure how much predictions change when the model is trained on different CV folds.

        For each stratified fold, a fresh clone of ``self.model`` is fitted on
        that fold's training split and then scores every row of ``X``. Each
        sample therefore receives ``n_splits`` probabilities, one per fold model.
        ``self.model`` itself is left untouched.

        Returns:
            Mean/max per-sample variance and range of those probabilities, plus
            ``stable_prediction_percentage``: the fraction (0-1, not 0-100) of
            samples whose variance is below 0.1.
        """
        from sklearn.base import clone

        kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # Rows = samples, columns = fold models; NaN marks a missing prediction.
        all_predictions = np.full((len(X), n_splits), np.nan)

        for fold_idx, (train_idx, _val_idx) in enumerate(kf.split(X, y)):
            # Clone so the caller's fitted ensemble is not overwritten by a fold model.
            fold_model = clone(self.model)
            fold_model.fit(X[train_idx], y[train_idx])
            # Score every sample, not only the held-out fold: with a single
            # out-of-fold prediction per row the variance would always be zero.
            all_predictions[:, fold_idx] = fold_model.predict_proba(X)[:, 1]

        sample_variances = np.nanvar(all_predictions, axis=1)
        sample_ranges = np.nanmax(all_predictions, axis=1) - np.nanmin(all_predictions, axis=1)

        stability_metrics = {
            'mean_prediction_variance': float(np.mean(sample_variances)),
            'max_prediction_variance': float(np.max(sample_variances)),
            'mean_prediction_range': float(np.mean(sample_ranges)),
            'max_prediction_range': float(np.max(sample_ranges)),
            'stable_prediction_percentage': float(np.mean(sample_variances < 0.1))
        }

        has_prediction = ~np.isnan(all_predictions)
        print(f"\nPrediction matrix shape: {all_predictions.shape}")
        print(f"Number of samples with predictions: {np.sum(has_prediction.any(axis=1))}")
        print(f"Average predictions per sample: {has_prediction.sum(axis=1).mean():.2f}")

        return stability_metrics
    def optimize_ensemble_weights(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Weight a soft-voting ensemble by each member's ROC AUC on (X, y) and refit it.

        Every base model and ``self.meta_learner`` is scored by ROC AUC. The
        weights are those scores normalised to sum to 1: base models first, in
        dict order, then the meta-learner. ``self.model`` is replaced by a soft
        ``VotingClassifier`` with those weights and fitted on ``(X, y)``.

        NOTE(review): VotingClassifier refits clones of every estimator on raw
        ``X``, so the 'meta_learner' member is retrained as a plain classifier on
        the original features, not on stacked predictions. The AUCs are also
        measured on the same data used for fitting.

        Returns:
            The weights. If anything fails, the error is printed and equal
            weights are returned, one per base model plus the meta-learner.
        """
        try:
            model_scores = {}

            # Predict once per base model; reused for its AUC and for the meta-features.
            base_preds = np.zeros((len(self.base_models), len(X)))
            for i, (name, model) in enumerate(self.base_models.items()):
                base_preds[i] = model.predict_proba(X)[:, 1]
                score = roc_auc_score(y, base_preds[i])
                model_scores[name] = score
                print(f"{name} ROC AUC: {score:.4f}")

            meta_features = np.column_stack([base_preds.T, X])
            meta_score = roc_auc_score(y, self.meta_learner.predict_proba(meta_features)[:, 1])
            model_scores['meta_learner'] = meta_score
            print(f"Meta learner ROC AUC: {meta_score:.4f}")

            total_score = sum(model_scores.values())
            weights = [score / total_score for score in model_scores.values()]

            self.model = VotingClassifier(
                estimators=list(self.base_models.items()) + [('meta_learner', self.meta_learner)],
                voting='soft',
                weights=weights
            )
            self.model.fit(X, y)  # Refit with new weights

            return weights

        except Exception as e:
            print(f"Error optimizing weights: {str(e)}")
            # Equal weights sized to the actual ensemble (base models + meta-learner)
            # rather than assuming exactly four members.
            n_members = len(getattr(self, 'base_models', None) or {}) + 1
            return [1.0 / n_members] * n_members

    def predict_with_stability(self, X: np.ndarray) -> Tuple[np.ndarray, List[int]]:
        """Predict class probabilities, shrinking them toward uniform where base models disagree.

        The std of the base models' positive-class probabilities measures
        disagreement. Rows with std > 0.2 are reported as unstable. Each row's
        ensemble probabilities are pulled toward the uniform distribution by a
        factor ``1 - clip(std, 0, 0.5)``, so every row still sums to 1.

        Returns:
            ``(adjusted_probabilities, unstable_row_indices)``. On error, the
            error is printed and the unadjusted ``self.model.predict_proba(X)``
            is returned with an empty list.
        """
        try:
            # One row per base model, one column per sample.
            base_predictions = np.zeros((len(self.base_models), len(X)))
            for i, model in enumerate(self.base_models.values()):
                base_predictions[i] = model.predict_proba(X)[:, 1]

            std_predictions = np.std(base_predictions, axis=0)

            # Identify unstable predictions (high variance between models)
            unstable_indices = np.where(std_predictions > 0.2)[0]

            predictions = self.model.predict_proba(X)

            # Shrink toward the uniform distribution instead of scaling every column
            # down: plain scaling left rows summing to < 1 and made unstable rows
            # look less like the positive class rather than less certain.
            confidence_adjustments = 1 - np.clip(std_predictions, 0, 0.5)
            uniform = 1.0 / predictions.shape[1]
            adjusted_predictions = uniform + (predictions - uniform) * confidence_adjustments.reshape(-1, 1)

            return adjusted_predictions, unstable_indices.tolist()

        except Exception as e:
            print(f"Error in prediction with stability: {str(e)}")
            return self.model.predict_proba(X), []

In [None]:
# former version
import logging
import os
import re
from datetime import datetime
from typing import Dict, List, Tuple

import polars as pl
import numpy as np
import joblib
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    precision_recall_curve, precision_score, recall_score,
    roc_auc_score
)
import xgboost as xgb
from lightgbm import LGBMClassifier

# Global Polars settings for this session:
# - truncate string values to 50 characters when DataFrames are printed
# - use 1,000,000-row chunks in the streaming engine
pl.Config.set_fmt_str_lengths(50)
pl.Config.set_streaming_chunk_size(1_000_000)

class LazyDatasetLoader:
    """Memory-efficient dataset loader with checkpoint-based FID filtering.

    Datasets are scanned lazily from parquet files under ``data_path``, and only
    the most recently loaded one is kept. In debug mode every dataset is
    restricted to one shared set of "base FIDs" so that all feature tables cover
    the same accounts. Base FIDs come from, in order of preference:

    1. an explicit ``set_base_fids`` call;
    2. the ``fid`` column of ``<checkpoint_dir>/profile_features.parquet``;
    3. a ``sample_size``-row sample of ``profile_with_addresses``, extended with
       ``fids_to_ensure``.
    """

    # Export file-name template per data source.
    _FILE_TEMPLATES = {
        "farcaster": "farcaster-{name}-0-1733162400.parquet",
        "nindexer": "nindexer-{name}-0-1733508243.parquet",
    }

    def __init__(self, data_path: str, checkpoint_dir: str, debug_mode: bool = True, sample_size: int = 700_000, fids_to_ensure: List[int] = None):
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self._cached_dataset = None
        self._cached_name = None
        self.debug_mode = debug_mode
        self.sample_size = sample_size
        self.base_fids = None
        # FIDs force-included in the debug-mode profile sample.
        self.fids_to_ensure = fids_to_ensure

    def set_base_fids(self, fids):
        """Set base FIDs to maintain consistent filtering"""
        self.base_fids = fids
        print(f"Set base FIDs: {len(fids)} records")

    def get_checkpoint_fids(self):
        """Load base FIDs from the profile checkpoint.

        Returns:
            True if the checkpoint exists and has a ``fid`` column (and
            ``self.base_fids`` was set from it), else False.
        """
        profile_checkpoint = f"{self.checkpoint_dir}/profile_features.parquet"
        if not os.path.exists(profile_checkpoint):
            return False
        # Inspect the schema and read only the fid column, not the whole feature table.
        if 'fid' not in pl.read_parquet_schema(profile_checkpoint):
            return False
        self.base_fids = pl.read_parquet(profile_checkpoint, columns=['fid'])['fid']
        print(f"Loaded base FIDs from checkpoint: {len(self.base_fids)} records")
        return True

    def get_dataset(self, name: str, columns: List[str] = None, source="farcaster") -> pl.DataFrame:
        """Load one dataset, filtered to the base FIDs when in debug mode.

        Args:
            name: dataset name embedded in the export file name.
            columns: optional column subset to read.
            source: ``"farcaster"`` or ``"nindexer"``.

        Returns:
            The collected DataFrame (also kept as the cached dataset).

        Raises:
            ValueError: for an unknown ``source``.
            Any loading error is printed and re-raised.
        """
        # Release the previous dataset first so two are never held at once.
        self.clear_cache()

        template = self._FILE_TEMPLATES.get(source)
        if template is None:
            raise ValueError(
                f"Unknown source {source!r}; expected one of {sorted(self._FILE_TEMPLATES)}"
            )
        path = f"{self.data_path}/{template.format(name=name)}"

        try:
            scan_query = pl.scan_parquet(path)
            if columns:
                scan_query = scan_query.select(columns)

            if self.debug_mode:
                dataset = self._collect_debug_subset(name, scan_query)
            else:
                dataset = scan_query.collect()

            self._cached_dataset = dataset
            self._cached_name = name
            print(f"Loaded {name}: {len(dataset)} records")
            return dataset

        except Exception as e:
            print(f"Error loading {name}: {str(e)}")
            raise

    def _collect_debug_subset(self, name: str, scan_query) -> pl.DataFrame:
        """Collect the debug-mode subset of one dataset, establishing base FIDs if needed."""
        if self.base_fids is None:
            # Reuse the FIDs of an existing profile checkpoint when there is one.
            self.get_checkpoint_fids()

        if self.base_fids is not None:
            print(f"Filtering {name} by {len(self.base_fids)} base FIDs")
            return scan_query.filter(pl.col('fid').is_in(self.base_fids)).collect()

        if name != 'profile_with_addresses':
            print(f"Warning: No base FIDs available for {name}")
            return scan_query.limit(self.sample_size).collect()

        # First profile load: sample rows, force-include requested FIDs, and make the
        # resulting FIDs the base set for every later dataset.
        dataset = scan_query.limit(self.sample_size).collect()
        if self.fids_to_ensure is not None and len(self.fids_to_ensure) > 0:
            ensured = scan_query.filter(pl.col('fid').is_in(self.fids_to_ensure)).collect()
            if len(ensured) > 0:
                dataset = (pl.concat([dataset, ensured], how='diagonal')
                           .unique(subset='fid', maintain_order=True))

        self.base_fids = dataset['fid']
        print(f"Established new base FIDs from {name}: {len(self.base_fids)} records")
        return dataset

    def clear_cache(self):
        """Clear the cached dataset"""
        self._cached_dataset = None
        self._cached_name = None

class FeatureSet:
    """Metadata for one group of engineered features.

    Holds the group's name, the version of its calculation logic, the names of
    the feature sets it depends on, and checkpoint bookkeeping (path and
    last-modified time) that starts as None and is filled in later.
    """

    def __init__(self, name: str, version: str, dependencies: List[str] = None):
        self.name = name
        # Version of the feature calculation logic.
        self.version = version
        # A missing or empty dependency list becomes a fresh empty list.
        self.dependencies = dependencies if dependencies else []
        # Checkpoint bookkeeping, unset until checkpoints are initialised.
        self.checkpoint_path = None
        self.last_modified = None

class FeatureEngineering:
    """Enhanced bot detection system"""
    
    def __init__(self, data_path: str, checkpoint_dir: str, fids_to_ensure: List[int] = None):
        """Create the dataset loader, declare every feature set, and resolve checkpoints.

        Args:
            data_path: Directory holding the raw parquet exports.
            checkpoint_dir: Directory for per-feature-set checkpoint files.
            fids_to_ensure: FIDs passed to the loader to be force-included in
                its debug-mode sample.
        """
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self.loader = LazyDatasetLoader(data_path, checkpoint_dir, fids_to_ensure=fids_to_ensure)
        self.fids_to_ensure = fids_to_ensure

        # (name, calculation-logic version, feature sets it depends on)
        feature_specs = [
            # Base features
            ('profile', '1.0', None),
            ('network', '1.0', None),
            ('temporal', '1.0', ['network']),

            # Activity features
            ('cast', '1.0', None),
            ('reaction', '1.0', None),
            ('channel', '1.0', None),
            ('verification', '1.0', None),

            # Account features
            ('user_data', '1.0', None),
            ('storage', '1.0', None),
            ('signers', '1.0', None),

            # Interaction patterns
            ('engagement', '1.0', ['cast', 'reaction', 'channel']),
            ('mentions', '1.0', ['cast', 'network']),
            ('reply_patterns', '1.0', ['cast', 'temporal']),

            # Network quality
            ('network_quality', '1.0', ['network', 'engagement']),
            ('power_user_interaction', '1.0', ['network', 'temporal']),
            ('cluster_analysis', '1.0', ['network', 'engagement']),

            # Behavioral patterns
            ('activity_patterns', '1.0', ['temporal', 'cast', 'reaction']),
            ('update_behavior', '1.0', ['user_data', 'profile']),
            ('verification_patterns', '1.0', ['verification', 'temporal']),

            # Meta features
            ('authenticity', '2.0', [
                'profile', 'network', 'channel', 'verification',
                'engagement', 'network_quality', 'activity_patterns'
            ]),
            ('influence', '1.0', ['network', 'engagement', 'power_user_interaction']),

            # Final derived features
            ('derived', '2.0', [
                'network', 'temporal', 'authenticity',
                'engagement', 'network_quality', 'influence'
            ]),

            # nindexer features
            ('enhanced_network', '1.0', ['network']),
            ('enhanced_profile', '1.0', ['profile']),
            ('neynar_score', '1.0', None),

            ('name_patterns', '1.0', ['profile']),
            ('content_patterns', '1.0', ['cast']),
            ('advanced_temporal', '1.0', ['temporal', 'cast', 'reaction']),
            ('reward_gaming', '1.0', ['cast', 'reaction', 'temporal']),
            ('engagement_authenticity', '1.0', ['network', 'cast', 'reaction']),
        ]
        self.feature_sets = {
            set_name: FeatureSet(set_name, version, deps)
            for set_name, version, deps in feature_specs
        }

        # Create the checkpoint directory and attach each set's checkpoint path.
        self._init_checkpoints()

    def _analyze_name_patterns(self, text: str) -> Dict[str, int]:
        """Analyze username/display name patterns"""
        if not text:
            return {
                'random_numbers': 0,
                'wallet_pattern': 0,
                'excessive_symbols': 0,
                'airdrop_terms': 0,
                'has_year': 0
            }
        
        return {
            'random_numbers': int(bool(re.findall(r'\d{4,}', text))),
            'wallet_pattern': int(bool(re.findall(r'0x[a-fA-F0-9]{40}', text))),
            'excessive_symbols': int(bool(re.findall(r'[_.\-]{2,}', text))),
            'airdrop_terms': int(any(term in text.lower() for term in ['airdrop', 'farm', 'degen', 'wojak'])),
            'has_year': int(bool(re.findall(r'20[12]\d', text)))
        }

    def _analyze_content_patterns(self, text: str) -> Dict[str, int]:
        """Analyze content for spam/bot patterns"""
        if not text:
            return {
                'template_structure': 0,
                'multiple_cta': 0,
                'urgency_terms': 0,
                'excessive_emojis': 0,
                'price_mentions': 0
            }
        
        text = text.lower()
        return {
            'template_structure': int(bool(re.findall(r'\[.*?\]|\{.*?\}|\<.*?\>', text))),
            'multiple_cta': int(len(re.findall(r'click|join|follow|claim|grab', text)) > 2),
            'urgency_terms': int(bool(re.findall(r'hurry|limited|fast|quick|soon|ending', text))),
            'excessive_emojis': int(len(re.findall(r'[\U0001F300-\U0001F9FF]', text)) > 5),
            'price_mentions': int(bool(re.findall(r'\$\d+|\d+\$', text))),
            'excessive_symbols': int(bool(re.findall(r'[_.\-]{2,}', text))),
            'airdrop_terms': int(any(term in text.lower() for term in ['airdrop', 'farm', 'degen', 'wojak'])),
        }
        
    def validate_dimensions(func):
        """Decorator for ``(self, df, ...) -> DataFrame`` feature methods.

        Prints a warning (without correcting anything) when the result's row
        count differs from the input's, returns the result with nulls filled
        with 0, and re-raises any exception after printing which method failed.
        """
        from functools import wraps

        @wraps(func)  # keep the wrapped method's __name__/__doc__ for logs and introspection
        def wrapper(self, df: "pl.DataFrame", *args, **kwargs):
            input_shape = len(df)
            try:
                result = func(self, df, *args, **kwargs)
                if len(result) != input_shape:
                    print(f"Warning: Shape mismatch in {func.__name__}. Input: {input_shape}, Output: {len(result)}")
                    # Don't force join or filtering here. Just warn.
                return result.fill_null(0)
            except Exception as e:
                print(f"Error in {func.__name__}: {str(e)}")
                raise
        return wrapper

        
    def get_dataset_columns(self, name: str, source: str = "farcaster") -> List[str]:
        """Return a raw dataset's column names by reading only its parquet schema.

        Args:
            name: Dataset name embedded in the export file name.
            source: ``"farcaster"`` (default) or ``"nindexer"``, using the same
                file-name patterns as ``LazyDatasetLoader.get_dataset``.

        Raises:
            ValueError: for an unknown ``source``.
        """
        if source == "farcaster":
            path = f"{self.data_path}/farcaster-{name}-0-1733162400.parquet"
        elif source == "nindexer":
            path = f"{self.data_path}/nindexer-{name}-0-1733508243.parquet"
        else:
            raise ValueError(f"Unknown source {source!r}; expected 'farcaster' or 'nindexer'")
        # The schema comes from the file metadata; no row data is read.
        return list(pl.read_parquet_schema(path))
        
    def _init_checkpoints(self):
        """Initialize checkpoint paths and check existing files"""
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        
        for name, feature_set in self.feature_sets.items():
            path = f"{self.checkpoint_dir}/{name}_features.parquet"
            feature_set.checkpoint_path = path
            
            if os.path.exists(path):
                feature_set.last_modified = os.path.getmtime(path)
    
    def _needs_rebuild(self, feature_set: FeatureSet) -> bool:
        """Report whether a feature set has to be recomputed.

        Only the presence of the checkpoint file is checked: a missing file
        means rebuild. ``version`` and ``dependencies`` are not consulted.
        """
        return not os.path.exists(feature_set.checkpoint_path)


    def extract_profile_features(self) -> pl.DataFrame:
        """Build per-FID profile features from the ``profile_with_addresses`` dataset.

        Profiles without a non-empty ``fname`` are dropped, and ``fid`` is cast to
        Int64. The result adds these columns:

        * ``has_ens``: fname ends in ``.eth``;
        * ``has_bio``: bio is non-null and non-empty;
        * ``has_avatar``: avatar_url is non-null;
        * ``has_display_name``: display_name is non-null;
        * ``verification_count``: number of entries in ``verified_addresses``.

        The four ``has_*`` columns are 0/1 flags.
        """
        profiles = self.loader.get_dataset('profile_with_addresses',
            ['fid', 'fname', 'bio', 'avatar_url', 'verified_addresses', 'display_name'])

        # Filter valid profiles and cast fid type immediately
        profiles = (profiles
            .filter(pl.col('fname').is_not_null() & (pl.col('fname') != ""))
            .with_columns(pl.col('fid').cast(pl.Int64)))

        # verified_addresses is a serialised list (presumably '["0x..","0x.."]' —
        # TODO confirm). Count elements as commas + 1; null, "" and "[]" mean none.
        # The previous expression added a boolean to 1, capping every
        # multi-address list at 2.
        addresses = pl.col('verified_addresses')
        verification_count = (
            pl.when(addresses.is_null() | (addresses == '') | (addresses == '[]'))
            .then(0)
            .otherwise(addresses.str.count_matches(',') + 1)
            .cast(pl.Int32)
        )

        df = profiles.with_columns([
            pl.col('fname').str.contains(r'\.eth$').cast(pl.Int32).alias('has_ens'),
            (pl.col('bio').is_not_null() & (pl.col('bio') != "")).cast(pl.Int32).alias('has_bio'),
            pl.col('avatar_url').is_not_null().cast(pl.Int32).alias('has_avatar'),
            verification_count.alias('verification_count'),
            (pl.col('display_name').is_not_null()).cast(pl.Int32).alias('has_display_name')
        ])

        self.loader.clear_cache()
        return df
    def add_blocking_behavior(self, df: pl.DataFrame) -> pl.DataFrame:
        """Join per-FID blocking statistics (with the FID as blocker) onto ``df``.

        Adds three columns:

        * ``blocks_made``: number of block rows;
        * ``unique_blocks``: number of distinct blocked FIDs;
        * ``block_repeat_ratio``: ``blocks_made / (unique_blocks + 1)``.

        FIDs with no blocks get 0. The final ``fill_null(0)`` applies to every
        column of the joined frame, not only the new ones.
        """
        blocks = self.loader.get_dataset('blocks', ['blocker_fid', 'blocked_fid'])

        blocking_features = (
            blocks.group_by('blocker_fid')
            .agg([
                # pl.len() replaces the deprecated pl.count() (same UInt32 row count).
                pl.len().alias('blocks_made'),
                pl.n_unique('blocked_fid').alias('unique_blocks')
            ])
            .with_columns([
                (pl.col('blocks_made') / (pl.col('unique_blocks') + 1)).alias('block_repeat_ratio')
            ])
            .rename({'blocker_fid': 'fid'})
            # Match df's key dtype (profiles cast fid to Int64) so the join cannot
            # fail on mismatched integer types.
            .with_columns(pl.col('fid').cast(df.schema['fid']))
        )

        self.loader.clear_cache()
        return df.join(blocking_features, on='fid', how='left').fill_null(0)

    def add_enhanced_verification_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Join on-chain and platform verification features onto ``df`` (one row per FID).

        Adds the following columns, in this order, after df's own columns:

        * ``total_verifications``, ``eth_verifications``: counts of non-deleted
          on-chain verifications.
        * ``verification_timing_std``: std (seconds) of gaps between
          time-ordered on-chain verifications, with the first missing gap
          counted as 0.
        * ``platforms_verified``, ``first_platform_verification``,
          ``last_platform_verification``: from account verifications.
        * ``verification_span_days``: days between the first and last platform
          verification.

        FIDs without data, or a source dataset that is empty, get the defaults
        below. All nulls are finally filled with 0.

        Raises:
            Any error, after printing it.
        """
        # Output column -> value used when a whole source dataset is empty.
        defaults = {
            'total_verifications': pl.lit(0),
            'eth_verifications': pl.lit(0),
            'verification_timing_std': pl.lit(0.0),
            'platforms_verified': pl.lit(0),
            'first_platform_verification': pl.lit(None),
            'last_platform_verification': pl.lit(None),
            'verification_span_days': pl.lit(0),
        }
        try:
            # Start from df without any of the output columns. Pre-filling them
            # (as before) made the joins emit "<col>_right" duplicates, leaving
            # the real values out of the expected columns.
            result = df.drop([col for col in defaults if col in df.columns])

            # Process on-chain verifications
            verifications = self.loader.get_dataset('verifications',
                ['fid', 'claim', 'timestamp', 'deleted_at'])

            if verifications is not None and len(verifications) > 0:
                verif_features = (
                    verifications
                    .filter(pl.col('deleted_at').is_null())
                    .with_columns([
                        pl.col('timestamp').cast(pl.Datetime)
                    ])
                    .group_by('fid')
                    .agg([
                        pl.len().alias('total_verifications'),
                        pl.col('claim').str.contains('ethSignature').sum().alias('eth_verifications'),
                        # Gaps must come from time-ordered rows, so sort before diff().
                        pl.col('timestamp')
                            .sort()
                            .diff()
                            .dt.total_seconds()
                            .cast(pl.Float64)
                            .fill_null(0)
                            .std()
                            .fill_null(0)
                            .alias('verification_timing_std')
                    ])
                )
                result = result.join(verif_features, on='fid', how='left')

            # Process platform verifications
            acc_verifications = self.loader.get_dataset('account_verifications',
                ['fid', 'platform', 'platform_username', 'verified_at'])

            if acc_verifications is not None and len(acc_verifications) > 0:
                platform_features = (
                    acc_verifications
                    .with_columns([
                        pl.col('verified_at').cast(pl.Datetime)
                    ])
                    .group_by('fid')
                    .agg([
                        pl.n_unique('platform').alias('platforms_verified'),
                        pl.col('verified_at').min().alias('first_platform_verification'),
                        pl.col('verified_at').max().alias('last_platform_verification')
                    ])
                )
                result = result.join(platform_features, on='fid', how='left')

                # Days between first and last platform verification; 0.0 when unknown.
                result = result.with_columns([
                    (pl.col('last_platform_verification') - pl.col('first_platform_verification'))
                        .dt.total_days()
                        .cast(pl.Float64)
                        .fill_null(0.0)
                        .alias('verification_span_days')
                ])

            # Guarantee every output column exists even when a source was empty.
            missing = [expr.alias(col) for col, expr in defaults.items() if col not in result.columns]
            if missing:
                result = result.with_columns(missing)

            # df's columns first, then the feature columns in declaration order.
            result = result.select(
                [col for col in df.columns if col not in defaults] + list(defaults)
            )

            self.loader.clear_cache()
            return result.fill_null(0)

        except Exception as e:
            print(f"Error in verification features: {str(e)}")
            raise

    def add_cast_behavior_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID cast behaviour features, including link, media and spam signals.

        Loads the ``casts`` dataset via ``self.loader``, keeps rows whose
        ``deleted_at`` is null and whose ``fid`` occurs in ``df``, derives
        per-cast signals, aggregates them per ``fid`` and left-joins the
        aggregates onto ``df``; nulls in the joined frame are filled with 0.

        Per-cast signals:
            * ``link_count``: URL-like tokens (``http(s)://...`` or ``www....``,
              case-insensitive) in ``text``.
            * ``media_count``: occurrences of ``"image"`` in the lower-cased
              ``embeds`` value.
            * counts of the spam keywords below (case-insensitive and
              non-overlapping, so ``"drop"`` also counts inside ``"airdrop"``).
            * ``@`` / ``$`` / ``http(s)://`` occurrences divided by text length.
            * selected values returned by ``self._analyze_content_patterns(text)``.

        Args:
            df: Frame with a ``fid`` column.

        Returns:
            ``df`` with the cast behaviour columns appended.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        spam_keywords = ['airdrop', 'money', 'rewards', 'claim', 'moxie', 'nft', 'drop']
        # Keys read from the mapping returned by self._analyze_content_patterns,
        # and the per-FID ratio column each one feeds.
        content_pattern_columns = {
            'template_structure': 'template_usage_ratio',
            'multiple_cta': 'cta_heavy_ratio',
            'urgency_terms': 'urgency_ratio',
            'excessive_emojis': 'emoji_spam_ratio',
            'price_mentions': 'price_mention_ratio',
            'excessive_symbols': 'symbol_spam_ratio',
            'airdrop_terms': 'airdrop_term_ratio',
        }

        def count_media(embeds) -> int:
            """Count 'image' occurrences in a serialized embeds value; 0 if unusable."""
            try:
                if not embeds or embeds == '[]':
                    return 0
                return embeds.lower().count('image')
            except (AttributeError, TypeError, ValueError):
                # Non-string embeds carry no countable text.
                return 0

        try:
            base_fids = df['fid']
            print(f"Processing casts for {len(base_fids)} FIDs")

            casts_df = self.loader.get_dataset('casts', columns=[
                'fid', 'text', 'parent_hash', 'mentions', 'deleted_at',
                'timestamp', 'embeds'
            ])
            # Live casts of the FIDs we are building features for. The left join
            # below drops other FIDs anyway, so filter before the per-cast work.
            valid_casts = casts_df.filter(
                pl.col('deleted_at').is_null() & pl.col('fid').is_in(base_fids)
            )

            text = pl.col('text')
            text_len = text.str.len_chars()
            spam_columns = [f'_spam_{word}' for word in spam_keywords]

            per_cast = valid_casts.with_columns([
                text_len.fill_null(0).cast(pl.Int64).alias('cast_length'),
                pl.col('parent_hash').is_not_null().cast(pl.Int32).alias('is_reply'),
                (pl.col('mentions').is_not_null()
                 & (pl.col('mentions') != '')
                 & (pl.col('mentions') != '[]')).cast(pl.Int32).alias('has_mentions'),
                text.str.count_matches(r'(?i)https?://\S+|www\.\S+')
                    .fill_null(0).cast(pl.Int32).alias('link_count'),
                pl.col('embeds').map_elements(count_media, return_dtype=pl.Int32)
                    .fill_null(0).alias('media_count'),
                *[
                    text.str.count_matches(f'(?i){word}').fill_null(0).alias(column)
                    for word, column in zip(spam_keywords, spam_columns)
                ],
                # Null text -> null ratio (ignored by mean()); empty text gives
                # 0 / 0 = NaN, which is mapped to 0.
                (text.str.count_matches('@') / text_len).fill_nan(0.0).alias('_at_ratio'),
                (text.str.count_matches(r'\$') / text_len).fill_nan(0.0).alias('_dollar_ratio'),
                (text.str.count_matches(r'https?://') / text_len).fill_nan(0.0).alias('_url_ratio'),
            ])

            per_cast = per_cast.with_columns([
                # Parenthesised comparisons: '&' binds tighter than '>'.
                ((pl.col('link_count') > 0) & (pl.col('media_count') > 0))
                    .cast(pl.Int32).alias('has_link_and_media'),
                pl.sum_horizontal(spam_columns).alias('_spam_total'),
            ])

            # _analyze_content_patterns is a Python helper returning a mapping.
            # Apply it once per non-null text and keep only the keys we
            # aggregate, as explicit Float64 columns.
            patterns = [
                None if value is None else self._analyze_content_patterns(value)
                for value in per_cast['text'].to_list()
            ]
            per_cast = per_cast.with_columns([
                pl.Series(
                    f'_pattern_{key}',
                    [None if p is None or p[key] is None else float(p[key]) for p in patterns],
                    dtype=pl.Float64,
                )
                for key in content_pattern_columns
            ])

            cast_features = per_cast.group_by('fid').agg([
                pl.len().alias('cast_count'),
                pl.col('cast_length').mean().alias('avg_cast_length'),
                pl.col('is_reply').sum().alias('reply_count'),
                pl.col('has_mentions').sum().alias('mentions_count'),

                # Links: totals, casts containing links, mean links per cast.
                pl.col('link_count').sum().alias('total_links'),
                (pl.col('link_count') > 0).sum().alias('casts_with_links'),
                (pl.col('link_count').sum() / pl.len()).alias('link_ratio'),

                # Media: totals, casts containing media, mean media per cast.
                pl.col('media_count').sum().alias('total_media'),
                (pl.col('media_count') > 0).sum().alias('casts_with_media'),
                (pl.col('media_count').sum() / pl.len()).alias('media_ratio'),

                # Spam keyword occurrences per cast.
                (pl.col('_spam_airdrop').sum() / pl.len()).alias('airdrop_mention_ratio'),
                (pl.col('_spam_total').sum() / pl.len()).alias('spam_keyword_ratio'),

                # Symbol usage.
                pl.col('_at_ratio').mean().alias('avg_at_symbol_ratio'),
                pl.col('_dollar_ratio').mean().alias('avg_dollar_symbol_ratio'),
                pl.col('_url_ratio').mean().alias('avg_link_ratio'),

                # Casts with both links and media.
                pl.col('has_link_and_media').sum().alias('casts_with_both'),
                (pl.col('has_link_and_media').sum() / pl.len()).alias('multimedia_ratio'),

                # Content-pattern scores summed per FID, divided by cast count.
                *[
                    (pl.col(f'_pattern_{key}').sum() / pl.len()).alias(column)
                    for key, column in content_pattern_columns.items()
                ],
            ])

            result = df.join(cast_features, on='fid', how='left').fill_null(0)

            has_casts = pl.col('cast_count') > 0
            result = result.with_columns([
                # Share of casts containing links / media (0 for FIDs without casts).
                pl.when(has_casts)
                    .then(pl.col('casts_with_links') / pl.col('cast_count'))
                    .otherwise(0.0)
                    .alias('link_usage_rate'),
                pl.when(has_casts)
                    .then(pl.col('casts_with_media') / pl.col('cast_count'))
                    .otherwise(0.0)
                    .alias('media_usage_rate'),
                # "+ 1" keeps the denominators non-zero.
                (pl.col('total_links') / (pl.col('casts_with_links') + 1)).alias('avg_links_per_link_cast'),
                (pl.col('total_media') / (pl.col('casts_with_media') + 1)).alias('avg_media_per_media_cast'),
            ])

            self.loader.clear_cache()
            return result

        except Exception as e:
            print(f"Error in cast behavior: {str(e)}")
            raise


    def add_influence_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add influence, engagement and follower-growth features.

        Missing input columns ('follower_count', 'following_count',
        'total_reactions', 'cast_count') are created as literal 0, and nulls in
        them are treated as 0 in every formula below.

        Adds:
            follow_time_span_hours: whole hours between 'first_follow' and
                'last_follow' when both columns exist and are non-null, else 0.
            influence_score: (0.4*followers + 0.3*reactions + 0.3*casts)
                / (following + 1).
            engagement_rate: reactions / casts, 0 when there are no casts.
            follower_growth_rate: followers / follow_time_span_hours, 0 when
                the span is not positive.

        Raises:
            Any exception raised while computing; it is printed and re-raised.
        """
        try:
            for col in ('follower_count', 'following_count', 'total_reactions', 'cast_count'):
                if col not in df.columns:
                    df = df.with_columns(pl.lit(0).alias(col))

            if 'first_follow' in df.columns and 'last_follow' in df.columns:
                span_hours = (
                    pl.when(pl.col('last_follow').is_not_null() & pl.col('first_follow').is_not_null())
                    .then((pl.col('last_follow') - pl.col('first_follow')).dt.total_hours())
                    .otherwise(0)
                )
            else:
                span_hours = pl.lit(0)
            df = df.with_columns(span_hours.alias('follow_time_span_hours'))

            # Null-safe inputs, so no formula propagates nulls.
            followers = pl.col('follower_count').fill_null(0)
            following = pl.col('following_count').fill_null(0)
            reactions = pl.col('total_reactions').fill_null(0)
            casts = pl.col('cast_count').fill_null(0)
            span = pl.col('follow_time_span_hours').fill_null(0)

            return df.with_columns([
                ((followers * 0.4 + reactions * 0.3 + casts * 0.3) / (following + 1))
                    .alias('influence_score'),
                pl.when(casts > 0)
                    .then(reactions / casts)
                    .otherwise(0)
                    .alias('engagement_rate'),
                pl.when(span > 0)
                    .then(followers / span)
                    .otherwise(0)
                    .alias('follower_growth_rate'),
            ])

        except Exception as e:
            print(f"Error in influence features: {str(e)}")
            raise

    def add_storage_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Join per-FID storage statistics onto ``df``.

        Reads the ``storage`` dataset (fid, units, deleted_at) and ignores rows
        with a non-null ``deleted_at``. Per FID it computes the mean and max of
        ``units`` and the number of rows. The loader cache is cleared before the
        join, and nulls in the joined frame are filled with 0.
        """
        active_storage = self.loader.get_dataset(
            'storage', ['fid', 'units', 'deleted_at']
        ).filter(pl.col('deleted_at').is_null())

        units = pl.col('units')
        per_fid = active_storage.group_by('fid').agg([
            units.mean().alias('avg_storage_units'),
            units.max().alias('max_storage_units'),
            pl.len().alias('storage_update_count'),
        ])

        self.loader.clear_cache()
        per_fid = per_fid.unique(subset=['fid'])
        return df.join(per_fid, on='fid', how='left').fill_null(0)
    def add_user_data_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add profile-update activity features from the ``user_data`` dataset.

        Adds:
            total_user_data_updates: number of rows with a null ``deleted_at``
                per FID.
            avg_update_interval: mean gap between that FID's consecutive
                update timestamps, in whole hours (0 when undefined).

        If the dataset is missing or empty, both columns are added as 0.
        Nulls in the joined frame are filled with 0.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            user_data = self.loader.get_dataset('user_data',
                ['fid', 'type', 'timestamp', 'deleted_at'])

            if user_data is None or len(user_data) == 0:
                return df.with_columns([
                    pl.lit(0).alias('total_user_data_updates'),
                    pl.lit(0.0).alias('avg_update_interval')
                ])

            update_features = (
                user_data.filter(pl.col('deleted_at').is_null())
                .group_by('fid')
                .agg([
                    pl.len().alias('total_user_data_updates'),
                    # Sort within the group first: diff() over unordered
                    # timestamps yields meaningless (even negative) gaps.
                    pl.col('timestamp').sort().diff().mean().dt.total_hours().fill_null(0)
                        .alias('avg_update_interval')
                ])
            )

            self.loader.clear_cache()
            update_features = update_features.unique(subset=['fid'])
            return df.join(update_features, on='fid', how='left').fill_null(0)
        except Exception as e:
            print(f"Error in user_data features: {str(e)}")
            raise
    def add_signer_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add signer-activity features from the ``signers`` dataset.

        For rows with a null ``deleted_at``, per FID:
            signer_count: number of signer rows.
            avg_hours_between_signers / std_hours_between_signers: mean and
                standard deviation of the gaps between consecutive signer
                timestamps, each gap expressed in whole hours (the same
                convention add_reaction_patterns uses).

        Nulls in the joined frame (no signers, a single signer) are filled with 0.
        """
        signers = self.loader.get_dataset('signers',
            ['fid', 'timestamp', 'deleted_at'])

        # Sort inside each group so gaps are chronological and non-negative,
        # and convert to numeric hours before aggregating.
        gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()
        signer_features = (
            signers.filter(pl.col('deleted_at').is_null())
            .group_by('fid')
            .agg([
                pl.len().alias('signer_count'),
                gap_hours.mean().alias('avg_hours_between_signers'),
                gap_hours.std().alias('std_hours_between_signers')
            ])
        )

        self.loader.clear_cache()
        signer_features = signer_features.unique(subset=['fid'])
        return df.join(signer_features, on='fid', how='left').fill_null(0)
        
    def add_reaction_patterns(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach per-FID reaction statistics to ``df`` and check its row count.

        Only reactions authored by FIDs present in ``df`` and with a null
        ``deleted_at`` are used. Rows are ordered by timestamp before grouping,
        so the inter-reaction gaps (whole hours) are chronological. Reaction
        types 1 and 2 are counted as likes and recasts respectively. Ratio
        columns are derived after the left join, with ``+ 1`` in the
        denominators.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            fids_in_scope = df['fid']
            print(f"Processing reactions for {len(fids_in_scope)} FIDs")

            raw_reactions = self.loader.get_dataset(
                'reactions',
                ['fid', 'reaction_type', 'target_fid', 'timestamp', 'deleted_at'],
            )
            scoped = raw_reactions.filter(pl.col('fid').is_in(fids_in_scope))
            chronological = (
                scoped.filter(pl.col('deleted_at').is_null())
                .with_columns([pl.col('timestamp').cast(pl.Datetime)])
                .sort('timestamp')
            )

            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            per_fid = chronological.group_by('fid').agg([
                pl.len().alias('total_reactions'),
                (pl.col('reaction_type') == 1).sum().alias('like_count'),
                (pl.col('reaction_type') == 2).sum().alias('recast_count'),
                pl.n_unique('target_fid').alias('unique_users_reacted_to'),
                gap_hours.mean().alias('avg_hours_between_reactions'),
                gap_hours.std().alias('std_hours_between_reactions'),
            ]).unique(subset=['fid'])

            # Ratios are computed on the joined frame, one row per row of df.
            merged = df.join(per_fid, on='fid', how='left', coalesce=True).fill_null(0)
            reactions_plus_one = pl.col('total_reactions') + 1
            merged = merged.with_columns([
                (pl.col('like_count') / reactions_plus_one).alias('like_ratio'),
                (pl.col('recast_count') / reactions_plus_one).alias('recast_ratio'),
                (pl.col('unique_users_reacted_to') / reactions_plus_one).alias('reaction_diversity'),
                (pl.col('like_count') / (pl.col('recast_count') + 1)).alias('likes_to_recasts_ratio'),
            ])

            if len(merged) != len(df):
                print(f"Warning: Reaction features shape mismatch. Expected {len(df)}, got {len(merged)}")
                merged = merged.filter(pl.col('fid').is_in(fids_in_scope))

            self.loader.clear_cache()
            return merged

        except Exception as e:
            print(f"Error in reaction patterns: {str(e)}")
            raise

    def build_network_quality_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID counts of replies to, and mentions of, power users.

        Ensures 'engagement_score', 'following_count' and 'follower_count'
        exist via ``self._validate_and_ensure_features``. It then loads the
        ``power_users`` and ``casts`` datasets and, over casts with a null
        ``deleted_at``, counts per FID:
            power_reply_count: casts whose ``parent_fid`` is a power user.
            power_mentions_count: casts whose ``mentions`` contain any
                power-user FID as a whole number.

        Both columns are 0 when there are no power users or no casts.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        zero_power_columns = [
            pl.lit(0).alias('power_reply_count'),
            pl.lit(0).alias('power_mentions_count'),
        ]
        try:
            base_metrics = {
                'engagement_score': 0.0,
                'following_count': 0,
                'follower_count': 0
            }
            result = self._validate_and_ensure_features(df, base_metrics)

            # NOTE(review): add_power_user_interaction_features reads
            # 'warpcast_power_users' while this reads 'power_users' — confirm
            # both dataset names are intended.
            power_users = self.loader.get_dataset('power_users', ['fid'])
            if power_users is None or len(power_users) == 0:
                return result.with_columns(zero_power_columns)

            power_fid_list = power_users['fid'].cast(pl.Int64).unique().to_list()
            casts = self.loader.get_dataset('casts',
                ['fid', 'parent_fid', 'mentions', 'deleted_at'])
            if casts is None or len(casts) == 0:
                return result.with_columns(zero_power_columns)

            # Check every mentioned FID against the whole power-user set, as an
            # exact integer (a substring test would match fid 12 inside "123").
            # Assumes 'mentions' is a string listing FIDs, e.g. "[12, 34]" —
            # TODO confirm against the casts schema.
            mentions_power_user = (
                pl.col('mentions')
                .str.extract_all(r'\d+')
                .list.eval(pl.element().cast(pl.Int64, strict=False).is_in(power_fid_list))
                .list.any()
                .fill_null(False)
            )

            power_metrics = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('parent_fid').cast(pl.Int64).is_in(power_fid_list)
                        .alias('is_power_reply'),
                    mentions_power_user.cast(pl.Int32).alias('has_power_mention'),
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reply').alias('power_reply_count'),
                    pl.sum('has_power_mention').alias('power_mentions_count')
                ])
            )

            return result.join(power_metrics, on='fid', how='left').fill_null(0)

        except Exception as e:
            print(f"Error in network quality features: {str(e)}")
            raise

    def add_network_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Attach follow-graph statistics (out-edges, in-edges, ratios) to ``df``.

        Uses ``links`` rows with a null ``deleted_at``.
            Outgoing (grouped by ``fid``): following_count,
                unique_following_count, first_follow, last_follow.
            Incoming (grouped by ``target_fid``, renamed to ``fid``):
                follower_count, unique_follower_count.
        Both are left-joined with nulls filled by 0. Then follower/following
        ratios (denominator + 1) and their ``log1p`` transforms are added.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            active_links = self.loader.get_dataset(
                'links', ['fid', 'target_fid', 'timestamp', 'deleted_at']
            ).filter(pl.col('deleted_at').is_null())

            # Edges leaving each FID.
            outgoing = active_links.group_by('fid').agg([
                pl.len().alias('following_count'),
                pl.n_unique('target_fid').alias('unique_following_count'),
                pl.col('timestamp').min().alias('first_follow'),
                pl.col('timestamp').max().alias('last_follow'),
            ]).fill_null(0)

            # Edges arriving at each FID, keyed back to 'fid' for the join.
            incoming = (
                active_links.group_by('target_fid')
                .agg([
                    pl.len().alias('follower_count'),
                    pl.n_unique('fid').alias('unique_follower_count'),
                ])
                .rename({'target_fid': 'fid'})
                .fill_null(0)
            )

            enriched = df
            for side in (outgoing, incoming):
                enriched = enriched.join(side, on='fid', how='left').fill_null(0)

            follower_ratio = pl.col('follower_count') / (pl.col('following_count') + 1)
            unique_ratio = pl.col('unique_follower_count') / (pl.col('unique_following_count') + 1)
            return enriched.with_columns([
                follower_ratio.alias('follower_ratio'),
                unique_ratio.alias('unique_follower_ratio'),
                follower_ratio.log1p().alias('follower_ratio_log'),
                unique_ratio.log1p().alias('unique_follower_ratio_log'),
            ])

        except Exception as e:
            print(f"Error in network features: {str(e)}")
            raise
    def add_temporal_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add follow-timing features (gaps, bursts, spread) from ``links``.

        Uses ``links`` rows with a null ``deleted_at`` and non-null timestamp,
        cast to ``pl.Datetime`` and ordered chronologically. Per FID it computes
        gap statistics between consecutive actions, in whole hours. Derived
        ratios are then added after a left join. Nulls are filled with 0, and
        ratios with a zero denominator are 0.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            links = self.loader.get_dataset('links', ['fid', 'timestamp', 'deleted_at'])

            valid_links = (links
                .filter(pl.col('deleted_at').is_null())
                .filter(pl.col('timestamp').is_not_null())
                .with_columns([
                    pl.col('timestamp').cast(pl.Datetime).alias('timestamp')
                ])
                # diff() below must see each FID's actions in time order;
                # group_by preserves the input row order within each group.
                .sort('timestamp'))

            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            temporal_features = (valid_links
                .group_by('fid')
                .agg([
                    # Basic temporal features
                    pl.len().alias('total_activity'),
                    gap_hours.mean().alias('avg_hours_between_actions'),
                    gap_hours.std().alias('std_hours_between_actions'),
                    pl.col('timestamp').dt.weekday().std().alias('weekday_variance'),
                    (gap_hours < 1).sum().alias('rapid_actions'),
                    (gap_hours > 24).sum().alias('long_gaps'),

                    # Gap distribution tails
                    gap_hours.quantile(0.9).alias('p90_time_between_actions'),
                    gap_hours.quantile(0.1).alias('p10_time_between_actions'),

                    # Actions within 1 hour of the previous one
                    (gap_hours < 1).sum().alias('actions_in_bursts'),

                    # Span between first and last action
                    (pl.col('timestamp').max() - pl.col('timestamp').min()).dt.total_hours().alias('time_span')
                ]))

            result = df.join(temporal_features, on='fid', how='left').fill_null(0)
            avg_gap = pl.col('avg_hours_between_actions')
            result = result.with_columns([
                # Burst activity ratio
                (pl.col('actions_in_bursts') / (pl.col('total_activity') + 1)).alias('burst_activity_ratio'),

                # Actual span relative to an even spacing at the mean gap; guarded
                # because a zero mean gap would otherwise yield inf/NaN, which
                # fill_null does not replace.
                pl.when(avg_gap > 0)
                    .then(pl.col('time_span') / ((pl.col('total_activity') + 1) * avg_gap))
                    .otherwise(0.0)
                    .alias('activity_spread'),

                # Temporal irregularity (variation in action timing)
                (pl.col('std_hours_between_actions') / (avg_gap + 1)).alias('temporal_irregularity'),

                # Follow velocity (follows per hour)
                (pl.col('total_activity') / (pl.col('time_span') + 1)).alias('follow_velocity')
            ])

            return result.fill_null(0)

        except Exception as e:
            print(f"Error in temporal features: {str(e)}")
            raise

    def add_advanced_temporal_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add timing features for bot detection from casts and reactions.

        NOTE(review): this method's ``def`` line had been commented out, leaving
        its body as unreachable code at the end of add_temporal_features. It is
        restored here as a standalone method. No caller is visible in this
        section, so confirm whether it should be wired into the pipeline.

        Combines live (null ``deleted_at``) cast and reaction timestamps and
        orders them per FID. Then adds:
            has_robotic_timing: 1 when the std of gaps is below 1 second.
            rapid_action_count: gaps shorter than 2 seconds.
            long_dormancy_periods: gaps longer than 24 hours.
            avg_burst_interval: mean gap in whole hours.
        Returns ``df`` unchanged if neither dataset is available.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            activities = []
            for dataset in ('casts', 'reactions'):
                frame = self.loader.get_dataset(dataset, ['fid', 'timestamp', 'deleted_at'])
                if frame is not None:
                    activities.append(
                        frame.filter(pl.col('deleted_at').is_null())
                        .select([pl.col('fid'), pl.col('timestamp').cast(pl.Datetime)])
                    )

            if not activities:
                return df

            # 'vertical_relaxed' tolerates differing fid/timestamp dtypes across datasets.
            all_activities = (
                pl.concat(activities, how='vertical_relaxed')
                .filter(pl.col('timestamp').is_not_null())
            )

            gap_seconds = pl.col('timestamp').diff().dt.total_seconds()
            gap_hours = pl.col('timestamp').diff().dt.total_hours()
            temporal_features = (all_activities
                .sort(['fid', 'timestamp'])
                .group_by('fid')
                .agg([
                    (gap_seconds.std() < 1).cast(pl.Int32).alias('has_robotic_timing'),
                    (gap_seconds < 2).sum().alias('rapid_action_count'),
                    (gap_hours > 24).sum().alias('long_dormancy_periods'),
                    gap_hours.mean().alias('avg_burst_interval')
                ]))

            return df.join(temporal_features, on='fid', how='left').fill_null(0)

        except Exception as e:
            print(f"Error in advanced temporal features: {str(e)}")
            raise
        
    def add_power_user_interaction_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add counts and a ratio of interactions with power users.

        Power users come from the ``warpcast_power_users`` dataset. Over live
        (null ``deleted_at``) rows, per FID:
            power_user_replies: casts whose ``parent_fid`` is a power user.
            power_user_mentions: casts whose ``mentions`` contain any
                power-user FID as a whole number.
            power_user_reactions: reactions whose ``target_fid`` is a power user.
            total_casts / total_reactions: all live casts / reactions.
            power_user_interaction_ratio: (replies + mentions + reactions)
                / (total_casts + total_reactions + 1).
        With no power users, the four power_user_* columns are added as 0.

        NOTE(review): if ``df`` already has 'total_casts' or 'total_reactions'
        (e.g. from add_reaction_patterns), the join adds '_right'-suffixed
        copies and the ratio uses the pre-existing column.

        Raises:
            Any exception raised while loading or computing; it is printed
            and re-raised.
        """
        try:
            power_users = self.loader.get_dataset('warpcast_power_users', ['fid'])
            if power_users is None or len(power_users) == 0:
                print("Warning: No power users found")
                return df.with_columns([
                    pl.lit(0).alias('power_user_replies'),
                    pl.lit(0).alias('power_user_mentions'),
                    pl.lit(0).alias('power_user_reactions'),
                    pl.lit(0).alias('power_user_interaction_ratio')
                ])

            power_fid_list = power_users['fid'].cast(pl.Int64).unique().to_list()

            casts = self.loader.get_dataset('casts',
                ['fid', 'parent_fid', 'mentions', 'timestamp', 'deleted_at'])

            # Check every mentioned FID against the whole power-user set, as an
            # exact integer (a substring test on only the first power fid
            # misses all others and matches fid 12 inside "123").
            # Assumes 'mentions' is a string listing FIDs, e.g. "[12, 34]" —
            # TODO confirm against the casts schema.
            mentions_power_user = (
                pl.col('mentions')
                .str.extract_all(r'\d+')
                .list.eval(pl.element().cast(pl.Int64, strict=False).is_in(power_fid_list))
                .list.any()
                .fill_null(False)
            )

            power_cast_features = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('parent_fid').cast(pl.Int64).is_in(power_fid_list).alias('is_power_reply'),
                    mentions_power_user.cast(pl.Int32).alias('has_power_mention'),
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reply').alias('power_user_replies'),
                    pl.sum('has_power_mention').alias('power_user_mentions'),
                    pl.len().alias('total_casts')
                ])
            )

            reactions = self.loader.get_dataset('reactions',
                ['fid', 'target_fid', 'timestamp', 'deleted_at'])

            power_reaction_features = (
                reactions.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    pl.col('target_fid').cast(pl.Int64).is_in(power_fid_list).alias('is_power_reaction')
                ])
                .group_by('fid')
                .agg([
                    pl.sum('is_power_reaction').alias('power_user_reactions'),
                    pl.len().alias('total_reactions')
                ])
            )

            result = df.join(power_cast_features, on='fid', how='left')
            result = result.join(power_reaction_features, on='fid', how='left')

            result = result.with_columns([
                pl.col('power_user_replies').fill_null(0),
                pl.col('power_user_mentions').fill_null(0),
                pl.col('power_user_reactions').fill_null(0),
                pl.col('total_casts').fill_null(0),
                pl.col('total_reactions').fill_null(0)
            ])

            result = result.with_columns([
                ((pl.col('power_user_replies')
                  + pl.col('power_user_mentions')
                  + pl.col('power_user_reactions'))
                 / (pl.col('total_casts') + pl.col('total_reactions') + 1)
                ).alias('power_user_interaction_ratio')
            ])

            return result.fill_null(0)

        except Exception as e:
            print(f"Error in power user interaction features: {str(e)}")
            raise

    def add_activity_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID activity timing features.

        Non-deleted casts and reactions are combined, and for each ``fid``:

        * ``hour_diversity`` -- standard deviation of the activity counts per
          distinct hour of day (0 when only one hour occurs);
        * ``weekday_diversity`` -- same, per distinct weekday;
        * ``total_activities`` -- number of non-deleted casts + reactions.

        If either dataset is unavailable or there is no activity, the three
        columns are added as 0.0 placeholders. FIDs without activity get 0.

        Args:
            df: Feature frame with a ``fid`` column.

        Returns:
            ``df`` with the three activity columns added.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        try:
            print("Processing activity patterns...")

            casts = self.loader.get_dataset('casts', ['fid', 'timestamp', 'deleted_at'])
            reactions = self.loader.get_dataset('reactions', ['fid', 'timestamp', 'deleted_at'])

            # Placeholders used when no activity data can be computed.
            result = df.with_columns([
                pl.lit(0.0).alias('hour_diversity'),
                pl.lit(0.0).alias('weekday_diversity'),
                pl.lit(0.0).alias('total_activities')
            ])

            if casts is not None and reactions is not None:
                activities = pl.concat([
                    casts.filter(pl.col('deleted_at').is_null())
                        .select(['fid', 'timestamp']),
                    reactions.filter(pl.col('deleted_at').is_null())
                        .select(['fid', 'timestamp'])
                ])

                if len(activities) > 0:
                    timestamp = pl.col('timestamp').cast(pl.Datetime)
                    activity_features = (activities
                        .with_columns([
                            timestamp.dt.hour().alias('hour'),
                            timestamp.dt.weekday().alias('weekday')
                        ])
                        .group_by('fid')
                        .agg([
                            # unique_counts() yields one plain count per distinct
                            # value; value_counts() yields a struct column whose
                            # std() is not the spread of those counts.
                            pl.col('hour').unique_counts()
                                .std().fill_null(0).alias('hour_diversity'),
                            pl.col('weekday').unique_counts()
                                .std().fill_null(0).alias('weekday_diversity'),
                            pl.len().alias('total_activities')
                        ])
                    )

                    # Join onto the original df (not the placeholder frame) so the
                    # computed columns do not clash with the placeholder names.
                    result = df.join(activity_features, on='fid', how='left').fill_null(0)

            print("Activity patterns calculated successfully")
            return result

        except Exception as e:
            print(f"Error in activity patterns: {str(e)}")
            raise

    def verify_matrix(self, df: pl.DataFrame):
        """Verify that every column of the final feature matrix is numeric.

        Args:
            df: Final feature matrix.

        Raises:
            ValueError: if a column is still a list type, or has any other
                non-numeric dtype.
        """
        for col in df.columns:
            dtype = df[col].dtype
            if str(dtype).startswith('List'):
                raise ValueError(f"Column {col} is still a list type: {dtype}")
            # Accept every numeric width (UInt32 from pl.len(), Int32, Float32, ...),
            # not only Float64/Int64, which were previously misreported as
            # "not numeric".
            if not dtype.is_numeric():
                raise ValueError(f"Column {col} is not numeric: {dtype}")
                
    def add_mentions_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID mention statistics computed from non-deleted casts.

        A cast "has mentions" when its ``mentions`` value is non-null and is
        neither ``''`` nor ``'[]'``. For such casts the mention count is the
        length of the JSON array stored in ``mentions``; otherwise it is 0.

        Adds ``casts_with_mentions``, ``total_mentions`` and
        ``avg_mentions_per_cast`` (0 for FIDs without casts), plus the derived
        ``mention_frequency = casts_with_mentions / (cast_count + 1)`` and
        ``mention_ratio = avg_mentions_per_cast / (cast_count + 1)``.

        Args:
            df: Feature frame with ``fid`` and ``cast_count`` columns
                (``cast_count`` is read here, not created).

        Returns:
            ``df`` joined with the mention features; nulls filled with 0.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        try:
            base_fids = df['fid']
            print(f"Processing mentions for {len(base_fids)} FIDs")

            casts = self.loader.get_dataset('casts', ['fid', 'mentions', 'deleted_at'])

            # Restrict to the FIDs present in df before any heavy work.
            casts = casts.filter(pl.col('fid').is_in(base_fids))

            has_mentions = (
                pl.col('mentions').is_not_null()
                & (pl.col('mentions') != '')
                & (pl.col('mentions') != '[]')
            )

            # when/then/otherwise evaluates every branch over the whole column,
            # so decoding the raw column would also try to parse '' and other
            # excluded values. Substitute a valid empty array first, then decode.
            mention_count = (
                pl.when(has_mentions)
                .then(pl.col('mentions'))
                .otherwise(pl.lit('[]'))
                .str.json_decode()
                .list.len()
            )

            mention_features = (
                casts.filter(pl.col('deleted_at').is_null())
                .with_columns([
                    mention_count.alias('mention_count'),
                    has_mentions.cast(pl.Int32).alias('has_mentions')
                ])
                .group_by('fid')
                .agg([
                    pl.col('has_mentions').sum().alias('casts_with_mentions'),
                    pl.col('mention_count').sum().alias('total_mentions'),
                    pl.col('mention_count').mean().alias('avg_mentions_per_cast')
                ])
            )

            result = df.join(mention_features, on='fid', how='left', coalesce=True).fill_null(0)

            result = result.with_columns([
                (pl.col('casts_with_mentions') / (pl.col('cast_count') + 1)).alias('mention_frequency'),
                # NOTE(review): avg_mentions_per_cast is already a per-cast value;
                # dividing by cast_count again may be unintended -- confirm.
                (pl.col('avg_mentions_per_cast') / (pl.col('cast_count') + 1)).alias('mention_ratio')
            ])

            print(f"Mentions features complete. Shape: {result.shape}")
            self.loader.clear_cache()
            return result

        except Exception as e:
            print(f"Error in mentions features: {str(e)}")
            raise

    def add_reply_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-FID reply statistics from non-deleted reply casts.

        A reply is a non-deleted cast with a non-null ``parent_hash``. Per
        ``fid`` this adds:

        * ``total_replies`` and ``unique_users_replied_to`` (distinct ``parent_fid``);
        * ``avg_seconds_between_replies`` / ``std_seconds_between_replies`` --
          mean / std of the gaps between chronologically consecutive replies;
        * ``reply_diversity = unique_users_replied_to / total_replies``;
        * ``reply_timing_variability = std / avg`` gap, or 0 when the mean
          gap is not positive (single reply, or all replies at one instant).

        Args:
            df: Feature frame with a ``fid`` column.

        Returns:
            ``df`` left-joined with the reply features; nulls filled with 0.
        """
        casts = self.loader.get_dataset('casts',
            ['fid', 'parent_hash', 'parent_fid', 'timestamp', 'deleted_at'])

        # diff() follows row order, so sort each FID's timestamps first;
        # unsorted rows produce negative and meaningless gaps.
        reply_gaps = pl.col('timestamp').sort().diff()

        reply_features = (
            casts.filter(pl.col('deleted_at').is_null())
            .filter(pl.col('parent_hash').is_not_null())
            .group_by('fid')
            .agg([
                pl.len().alias('total_replies'),
                pl.n_unique('parent_fid').alias('unique_users_replied_to'),
                reply_gaps.mean().dt.total_seconds()
                    .alias('avg_seconds_between_replies'),
                reply_gaps.std().dt.total_seconds()
                    .alias('std_seconds_between_replies')
            ])
            .with_columns([
                (pl.col('unique_users_replied_to') / pl.col('total_replies'))
                    .alias('reply_diversity'),
                # Guard the division: a zero mean gap would give NaN/inf, which
                # fill_null below does not replace.
                pl.when(pl.col('avg_seconds_between_replies') > 0)
                .then(pl.col('std_seconds_between_replies') /
                      pl.col('avg_seconds_between_replies'))
                .otherwise(0.0)
                .alias('reply_timing_variability')
            ])
        )

        self.loader.clear_cache()
        return df.join(reply_features, on='fid', how='left').fill_null(0)

    # def add_cluster_analysis_features(self, df: pl.DataFrame) -> pl.DataFrame:
    #     """Analyze network clustering with updated functions"""
    #     try:
    #         links = self.loader.get_dataset('links', 
    #             ['fid', 'target_fid', 'deleted_at'])
            
    #         valid_links = links.filter(pl.col('deleted_at').is_null())
            
    #         # Calculate clustering features
    #         cluster_features = (
    #             valid_links.join(
    #                 valid_links.rename({'fid': 'mutual_fid', 'target_fid': 'mutual_target'}),
    #                 left_on='target_fid',
    #                 right_on='mutual_fid'
    #             )
    #             .group_by('fid')
    #             .agg([
    #                 pl.n_unique('mutual_target').alias('mutual_connections'),
    #                 pl.len().alias('potential_triangles')
    #             ])
    #             .with_columns([
    #                 (pl.col('mutual_connections') / (pl.col('potential_triangles') + 1))
    #                 .alias('clustering_coefficient')
    #             ])
    #         )
            
    #         self.loader.clear_cache()
    #         return df.join(cluster_features, on='fid', how='left').fill_null(0)
            
    #     except Exception as e:
    #         print(f"Error in cluster analysis: {str(e)}")
    #         raise
    #         return df.with_columns([
    #             pl.lit(0).alias('mutual_connections'),
    #             pl.lit(0).alias('potential_triangles'),
    #             pl.lit(0.0).alias('clustering_coefficient')
    #         ])

    def add_authenticity_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute a weighted ``authenticity_score`` per row.

        The score is ``0.4 * completeness + 0.3 * balance + 0.3 * naturalness``:

        * completeness -- mean of ``has_bio``, ``has_avatar``, ``has_ens`` and
          ``verification_count > 0``;
        * balance -- ``1 - |following - followers| / (following + followers)``,
          or 0 when both counts are 0;
        * naturalness -- ``1 - clip(profile_update_consistency, 0, 1)`` when
          ``total_updates > 0``, else 0.

        Every input column listed in the defaults table is added when missing
        and has its nulls filled with that default. ``total_updates`` and
        ``profile_update_consistency`` are cast to Float64. The three
        intermediate signals are dropped before returning.

        Args:
            df: Feature frame.

        Returns:
            ``df`` plus any added input columns and ``authenticity_score``.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        fallbacks = {
            'has_bio': 0,
            'has_avatar': 0,
            'verification_count': 0,
            'has_ens': 0,
            'following_count': 0.0,
            'follower_count': 0.0,
            'total_updates': 0,
            'avg_update_interval': 0.0,
            'profile_update_consistency': 0.0
        }

        try:
            print("Building authenticity features...")

            frame = df.clone()
            for name, fallback in fallbacks.items():
                if name not in frame.columns:
                    print(f"Adding missing column {name} with default {fallback}")
                    frame = frame.with_columns(pl.lit(fallback).alias(name))
                frame = frame.with_columns(pl.col(name).fill_null(fallback).alias(name))

            frame = frame.with_columns(
                pl.col('total_updates').cast(pl.Float64).fill_null(0),
                pl.col('profile_update_consistency').cast(pl.Float64).fill_null(0)
            )

            def filled(name):
                return pl.col(name).fill_null(0)

            completeness = (
                filled('has_bio')
                + filled('has_avatar')
                + filled('has_ens')
                + (filled('verification_count') > 0).cast(pl.Int64)
            ) / 4.0

            following = filled('following_count')
            followers = filled('follower_count')
            network_total = following + followers
            balance = (
                pl.when(network_total > 0)
                .then(1.0 - (following - followers).abs() / network_total)
                .otherwise(0.0)
            )

            naturalness = (
                pl.when(pl.col('total_updates') > 0)
                .then(1.0 - pl.col('profile_update_consistency').clip(0.0, 1.0))
                .otherwise(0.0)
            )

            frame = frame.with_columns([
                completeness.alias('profile_completeness'),
                balance.alias('network_balance'),
                naturalness.alias('update_naturalness')
            ])

            weighted = (
                pl.col('profile_completeness').fill_null(0.0) * 0.4
                + pl.col('network_balance').fill_null(0.0) * 0.3
                + pl.col('update_naturalness').fill_null(0.0) * 0.3
            )
            frame = frame.with_columns([weighted.alias('authenticity_score')])

            print("Authenticity features completed successfully")
            return frame.drop(['profile_completeness', 'network_balance', 'update_naturalness'])

        except Exception as e:
            print(f"Error in authenticity features: {str(e)}")
            raise

    def add_update_behavior_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add profile-update timing features per FID.

        From non-deleted ``user_data`` rows with a timestamp, per ``fid``:

        * ``total_updates`` -- number of updates;
        * ``avg_update_interval`` / ``update_time_std`` -- mean / std of the
          hours between consecutive updates (timestamps sorted per FID);
        * ``profile_update_consistency`` -- ``update_time_std / avg_update_interval``
          when the mean interval is positive, else 0.

        FIDs without updates get 0. If ``user_data`` is unavailable or has no
        usable rows, the four columns are returned as placeholders
        (``total_updates`` is an integer 0 in that case; otherwise all four
        are Float64). Existing columns with these names in ``df`` are
        overwritten.

        Args:
            df: Feature frame with a ``fid`` column.

        Returns:
            ``df`` with the four update-behavior columns.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        defaults = {
            'profile_update_consistency': 0.0,
            'total_updates': 0,
            'avg_update_interval': 0.0,
            'update_time_std': 0.0,
        }
        try:
            print("Building update behavior features...")

            # Placeholder columns, also the return value when there is no data.
            result = df.clone().with_columns(
                [pl.lit(value).alias(name) for name, value in defaults.items()]
            )

            user_data = self.loader.get_dataset('user_data', ['fid', 'timestamp', 'deleted_at'])
            if user_data is None or len(user_data) == 0:
                return result

            valid_updates = (user_data
                .filter(pl.col('deleted_at').is_null())
                .filter(pl.col('timestamp').is_not_null())
                .with_columns([
                    pl.col('timestamp').cast(pl.Datetime).alias('timestamp')
                ]))

            if len(valid_updates) == 0:
                return result

            # group_by keeps row order within each group, so sorting by
            # (fid, timestamp) makes diff() measure consecutive updates.
            interval_hours = pl.col('timestamp').diff().dt.total_hours()
            update_metrics = (valid_updates
                .sort(['fid', 'timestamp'])
                .group_by('fid')
                .agg([
                    pl.len().alias('total_updates'),
                    interval_hours.mean().alias('avg_update_interval'),
                    interval_hours.std().alias('update_time_std')
                ])
                .with_columns([
                    pl.col('total_updates').cast(pl.Float64).fill_null(0),
                    pl.col('avg_update_interval').cast(pl.Float64).fill_null(0),
                    pl.col('update_time_std').cast(pl.Float64).fill_null(0)
                ])
                .with_columns([
                    pl.when(pl.col('avg_update_interval') > 0)
                    .then(pl.col('update_time_std') / pl.col('avg_update_interval'))
                    .otherwise(0.0)
                    .alias('profile_update_consistency')
                ]))

            # ``result`` already has placeholder columns with the same names as
            # ``update_metrics``. A plain join would append the computed values as
            # ``*_right`` columns and leave the placeholders at 0. Join with an
            # explicit suffix and move the computed values into the placeholder
            # slots, which keeps their column positions.
            suffix = '__computed'
            joined = result.join(update_metrics, on='fid', how='left', suffix=suffix)
            result = (joined
                .with_columns([
                    pl.col(f'{name}{suffix}').cast(pl.Float64).fill_null(0).alias(name)
                    for name in defaults
                ])
                .drop([f'{name}{suffix}' for name in defaults]))

            self.loader.clear_cache()
            print("Update behavior features completed successfully")
            return result

        except Exception as e:
            print(f"Error in update behavior features: {str(e)}")
            raise

    def add_verification_patterns_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add timing statistics for on-chain and platform verifications.

        Per ``fid``:

        * from non-deleted ``verifications``: mean and std of the hours
          between consecutive ``timestamp`` values, and ``rapid_verifications``
          (the number of consecutive gaps shorter than one hour);
        * from ``account_verifications``: mean and std of the hours between
          consecutive ``verified_at`` values.

        Timestamps are sorted within each FID before differencing. Columns
        that could not be computed are added with 0 defaults, and nulls in
        all of them are filled with the same defaults.

        Args:
            df: Feature frame with a ``fid`` column.

        Returns:
            ``df`` with the five verification-timing columns.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        default_cols = {
            'avg_hours_between_verifications': 0.0,
            'std_hours_between_verifications': 0.0,
            'rapid_verifications': 0,
            'avg_hours_between_platform_verifs': 0.0,
            'std_hours_between_platform_verifs': 0.0
        }
        try:
            result = df.clone()

            # On-chain verification patterns
            verifications = self.loader.get_dataset('verifications',
                ['fid', 'timestamp', 'deleted_at'])

            if verifications is not None and len(verifications) > 0:
                valid_verifs = verifications.filter(pl.col('deleted_at').is_null())

                if len(valid_verifs) > 0:
                    # diff() follows row order: sort so gaps are between
                    # chronologically consecutive verifications. Unsorted rows
                    # give negative gaps, which "< 1 hour" would count as rapid.
                    gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()
                    verif_patterns = (
                        valid_verifs
                        .with_columns(pl.col('timestamp').cast(pl.Datetime))
                        .group_by('fid')
                        .agg([
                            gap_hours.mean().fill_null(0)
                                .alias('avg_hours_between_verifications'),
                            gap_hours.std().fill_null(0)
                                .alias('std_hours_between_verifications'),
                            (gap_hours < 1).sum().fill_null(0)
                                .alias('rapid_verifications')
                        ])
                    )
                    result = result.join(verif_patterns, on='fid', how='left')

            # Platform verification patterns
            acc_verifications = self.loader.get_dataset('account_verifications',
                ['fid', 'verified_at'])

            if acc_verifications is not None and len(acc_verifications) > 0:
                platform_gap_hours = pl.col('verified_at').sort().diff().dt.total_hours()
                # group_by already yields one row per fid, so no dedup is needed.
                platform_patterns = (
                    acc_verifications
                    .with_columns(pl.col('verified_at').cast(pl.Datetime))
                    .group_by('fid')
                    .agg([
                        platform_gap_hours.mean().fill_null(0)
                            .alias('avg_hours_between_platform_verifs'),
                        platform_gap_hours.std().fill_null(0)
                            .alias('std_hours_between_platform_verifs')
                    ])
                )
                result = result.join(platform_patterns, on='fid', how='left')

            # Add missing columns with defaults; fill nulls in present ones.
            for col, default in default_cols.items():
                if col not in result.columns:
                    result = result.with_columns(pl.lit(default).alias(col))
                else:
                    result = result.with_columns(pl.col(col).fill_null(default))

            self.loader.clear_cache()
            return result

        except Exception as e:
            print(f"Error in verification patterns: {str(e)}")
            raise

    def _validate_required_columns(self, df: pl.DataFrame, required_cols: List[str]):
        """Validate required columns exist"""
        missing = [col for col in required_cols if col not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")
            
    def _get_feature_build_order(self):
        """Get correct feature build order based on dependencies"""
        visited = set()
        order = []
        
        def visit(name):
            if name in visited:
                return
            visited.add(name)
            feature_set = self.feature_sets[name]
            for dep in feature_set.dependencies:
                visit(dep)
            order.append(name)
        
        for name in self.feature_sets:
            visit(name)
        return order


    def _validate_feature_addition(self, original_df: pl.DataFrame, 
                                new_df: pl.DataFrame,
                                base_fids: pl.Series,
                                feature_name: str) -> pl.DataFrame:
        """Validate and fix feature addition results"""
        if new_df is None:
            print(f"Error: {feature_name} returned None")
            raise
            return original_df
            
        if len(new_df) != len(original_df):
            print(f"Warning: Shape mismatch in {feature_name}. Expected {len(original_df)}, got {len(new_df)}")
            new_df = new_df.filter(pl.col('fid').is_in(base_fids))
            if len(new_df) != len(original_df):
                return original_df
                
        # Cast numeric columns and handle nulls
        new_cols = [c for c in new_df.columns if c not in original_df.columns]
        if new_cols:
            try:
                new_df = new_df.with_columns([
                    pl.col(c).cast(pl.Float64).fill_null(0) 
                    for c in new_cols 
                    if self._is_numeric_dtype(new_df[c].dtype)
                ])
            except Exception as e:
                print(f"Error casting columns in {feature_name}: {str(e)}")
                raise
                return original_df
                
        return new_df


    def add_enhanced_channel_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add channel follow/membership features per FID.

        From non-deleted ``channel_follows``: ``unique_channels_followed``,
        ``rapid_channel_follows`` (consecutive follows less than 60 s apart,
        timestamps sorted per FID) and ``channel_follow_hour_std`` (std of the
        follow counts per distinct hour of day).

        From non-deleted ``channel_members``: ``channel_memberships`` and
        ``unique_channel_memberships``.

        Derived: ``channel_follow_burst_ratio = rapid / (unique_followed + 1)``
        and ``channel_engagement_ratio = memberships / (unique_memberships + 1)``.

        All seven columns start as 0 placeholders. Computed values overwrite
        them in place, and nulls are filled with 0.

        Args:
            df: Feature frame with a ``fid`` column.

        Returns:
            ``df`` with the seven channel columns.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        placeholders = {
            'unique_channels_followed': 0,
            'rapid_channel_follows': 0,
            'channel_follow_hour_std': 0.0,
            'channel_memberships': 0,
            'unique_channel_memberships': 0,
            'channel_follow_burst_ratio': 0.0,
            'channel_engagement_ratio': 0.0,
        }

        def overlay(base: pl.DataFrame, features: pl.DataFrame) -> pl.DataFrame:
            """Left-join ``features`` onto ``base``, overwriting same-named columns in place."""
            # A plain join would append clashing columns as ``*_right`` and leave
            # the placeholders untouched, so all features would stay 0.
            suffix = '__computed'
            clashing = [c for c in features.columns if c != 'fid' and c in base.columns]
            joined = base.join(features, on='fid', how='left', suffix=suffix)
            return (joined
                .with_columns([pl.col(f'{c}{suffix}').alias(c) for c in clashing])
                .drop([f'{c}{suffix}' for c in clashing])
                .fill_null(0))

        try:
            result = df.with_columns(
                [pl.lit(value).alias(name) for name, value in placeholders.items()]
            )

            channel_follows = self.loader.get_dataset('channel_follows',
                ['fid', 'channel_id', 'timestamp', 'deleted_at'])

            if channel_follows is not None and len(channel_follows) > 0:
                follow_features = (
                    channel_follows.filter(pl.col('deleted_at').is_null())
                    .group_by('fid')
                    .agg([
                        pl.n_unique('channel_id').alias('unique_channels_followed'),
                        # Sort so diff() measures gaps between consecutive follows.
                        (pl.col('timestamp').sort().diff().dt.total_seconds() < 60)
                            .sum().alias('rapid_channel_follows'),
                        # unique_counts() gives plain per-hour counts; value_counts()
                        # would give a struct column whose std() is meaningless.
                        pl.col('timestamp').dt.hour().unique_counts()
                            .std().alias('channel_follow_hour_std')
                    ])
                )
                if len(follow_features) > 0:
                    result = overlay(result, follow_features)

            channel_members = self.loader.get_dataset('channel_members',
                ['fid', 'channel_id', 'deleted_at'])

            if channel_members is not None and len(channel_members) > 0:
                member_features = (
                    channel_members.filter(pl.col('deleted_at').is_null())
                    .group_by('fid')
                    .agg([
                        pl.len().alias('channel_memberships'),
                        pl.n_unique('channel_id').alias('unique_channel_memberships')
                    ])
                )
                if len(member_features) > 0:
                    result = overlay(result, member_features)

            result = result.with_columns([
                (pl.col('rapid_channel_follows') / pl.col('unique_channels_followed').add(1))
                    .alias('channel_follow_burst_ratio'),
                (pl.col('channel_memberships') / pl.col('unique_channel_memberships').add(1))
                    .alias('channel_engagement_ratio')
            ])

            self.loader.clear_cache()
            return result

        except Exception as e:
            print(f"Error in channel features: {str(e)}")
            raise

    def add_engagement_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add ``engagement_score`` and ``creation_consumption_ratio``.

        ``cast_count``, ``total_reactions`` and ``channel_memberships`` are
        added as 0 when absent, and their nulls are filled with 0. Then:

        * ``engagement_score`` = mean of those three counts;
        * ``creation_consumption_ratio`` = ``cast_count / (total_reactions + 1)``.

        Args:
            df: Feature frame.

        Returns:
            ``df`` with the base count columns ensured and the two metrics added.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        try:
            print("Processing engagement features...")

            frame = df.clone()
            for name in ('cast_count', 'total_reactions', 'channel_memberships'):
                if name in frame.columns:
                    frame = frame.with_columns(pl.col(name).fill_null(0))
                else:
                    frame = frame.with_columns(pl.lit(0).alias(name))

            casts = pl.col('cast_count')
            reactions = pl.col('total_reactions')
            memberships = pl.col('channel_memberships')

            return frame.with_columns([
                ((casts + reactions + memberships) / 3.0).alias('engagement_score'),
                (casts / reactions.add(1)).alias('creation_consumption_ratio')
            ])

        except Exception as e:
            print(f"Error in engagement features: {str(e)}")
            raise

    def _validate_and_ensure_features(self, df: pl.DataFrame,
                                required_features: Dict[str, float]) -> pl.DataFrame:
        """Guarantee that every required feature column exists and has no nulls.

        For each ``feature -> default`` pair, a missing column is appended as
        the constant ``default`` (with a console notice). An existing column
        has its nulls replaced by ``default``.

        Args:
            df: Feature frame; it is cloned, not modified.
            required_features: Mapping of column name to default value.

        Returns:
            The completed copy of ``df``.
        """
        ensured = df.clone()

        for feature, default_value in required_features.items():
            if feature in ensured.columns:
                replacement = (
                    pl.when(pl.col(feature).is_null())
                    .then(pl.lit(default_value))
                    .otherwise(pl.col(feature))
                )
            else:
                print(f"Adding missing feature {feature} with default value {default_value}")
                replacement = pl.lit(default_value)
            ensured = ensured.with_columns(replacement.alias(feature))

        return ensured

    def _load_checkpoint(self, feature_set: FeatureSet, base_fids: pl.Series) -> pl.DataFrame:
        """Load a feature-set checkpoint and restrict it to ``base_fids``.

        * ``fid`` is cast to Int64 on both sides.
        * Non-list columns whose dtype name contains ``int``, ``float`` or
          ``decimal`` are cast to Float64, with nulls filled with 0. List
          columns are left as-is.
        * For the ``authenticity`` and ``update_behavior`` sets, the filtered
          frame is passed through ``self._validate_sensitive_checkpoint``.

        Args:
            feature_set: Provides ``checkpoint_path`` and ``name``.
            base_fids: FIDs to keep.

        Returns:
            The filtered checkpoint frame.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        numeric_markers = ('int', 'float', 'decimal')
        try:
            print(f"Loading checkpoint: {feature_set.checkpoint_path}")
            checkpoint_df = (pl.read_parquet(feature_set.checkpoint_path)
                             .with_columns(pl.col('fid').cast(pl.Int64)))

            # Collect the numeric, non-list columns, then cast them in one pass.
            to_float = []
            for col in checkpoint_df.columns:
                if col == 'fid':
                    continue
                dtype_str = str(checkpoint_df[col].dtype).lower()
                if 'list' in dtype_str:
                    continue
                if any(marker in dtype_str for marker in numeric_markers):
                    to_float.append(col)

            if to_float:
                checkpoint_df = checkpoint_df.with_columns(
                    [pl.col(col).cast(pl.Float64).fill_null(0) for col in to_float]
                )

            print(f"Checkpoint fid type: {checkpoint_df['fid'].dtype}")
            print(f"Base fids type: {base_fids.dtype}")

            # checkpoint fid is already Int64; align base_fids to match.
            base_fids = base_fids.cast(pl.Int64)

            filtered_df = checkpoint_df.filter(pl.col('fid').is_in(base_fids))
            print(f"Filtered checkpoint from {len(checkpoint_df)} to {len(filtered_df)} rows")

            if feature_set.name in ('authenticity', 'update_behavior'):
                filtered_df = self._validate_sensitive_checkpoint(filtered_df, feature_set.name)

            return filtered_df

        except Exception as e:
            print(f"Error loading checkpoint: {str(e)}")
            raise

    def _validate_checkpoint_compatibility(self, checkpoint_df: pl.DataFrame,
                                        base_fids: pl.Series) -> bool:
        """Check whether a checkpoint frame can serve the given ``base_fids``.

        Returns False when the frame is None or empty, lacks ``fid``, misses
        any of ``base_fids`` (compared as Int64), or has a numeric non-list
        column that cannot be cast to Float64. Any unexpected error is
        printed and also yields False.

        Args:
            checkpoint_df: Loaded checkpoint frame (may be None).
            base_fids: FIDs the checkpoint must cover.

        Returns:
            True if compatible, False otherwise.
        """
        numeric_markers = ('int', 'float', 'decimal')
        try:
            if checkpoint_df is None or len(checkpoint_df) == 0 or 'fid' not in checkpoint_df.columns:
                return False

            checkpoint_fids = checkpoint_df['fid'].cast(pl.Int64)
            wanted_fids = base_fids.cast(pl.Int64)

            missing_fids = pl.Series(np.setdiff1d(wanted_fids, checkpoint_fids))
            if len(missing_fids) > 0:
                print(f"Missing FIDs in checkpoint: {missing_fids}")
                return False

            for col in checkpoint_df.columns:
                if col == 'fid':
                    continue
                dtype_str = str(checkpoint_df[col].dtype).lower()
                is_list = 'list' in dtype_str
                is_numeric = any(marker in dtype_str for marker in numeric_markers)
                if is_list or not is_numeric:
                    continue
                try:
                    # Probe only: the cast result is discarded.
                    checkpoint_df.select(pl.col(col).cast(pl.Float64))
                except Exception as e:
                    print(f"Column {col} failed type validation: {str(e)}")
                    return False

            return True

        except Exception as e:
            print(f"Error validating checkpoint compatibility: {str(e)}")
            return False
    def add_nindexer_enhanced_network_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add follow-graph features from the nindexer ``follows``/``follow_counts`` datasets.

        Per ``fid`` this adds:

        * ``network_age_hours`` - hours between the first and last non-deleted follow.
        * ``total_follows`` - number of non-deleted follow rows.
        * ``follow_rate_per_hour`` - ``total_follows / (network_age_hours + 1)``.
        * ``avg_follow_latency_seconds`` - mean of ``created_at - timestamp``.
        * ``latest_follower_count`` / ``latest_following_count`` - values from the
          ``follow_counts`` row with the greatest ``created_at``.
        * ``latest_follow_ratio`` - ``latest_follower_count / (latest_following_count + 1)``.

        The two source datasets are processed independently. All seven columns
        are always present in the output; a column whose source dataset is
        missing or empty is 0.0. This keeps the feature schema independent of
        which datasets happen to be available.

        Args:
            df: Frame keyed by ``fid``; the new features are left-joined onto it.

        Returns:
            ``df`` plus the columns above, with nulls filled by 0.

        Raises:
            Exception: any error is printed and re-raised.
        """
        follow_feature_cols = [
            'network_age_hours',
            'total_follows',
            'follow_rate_per_hour',
            'avg_follow_latency_seconds',
        ]
        count_feature_cols = [
            'latest_follower_count',
            'latest_following_count',
            'latest_follow_ratio',
        ]
        try:
            follows = self.loader.get_dataset('follows', 
                ['fid', 'target_fid', 'timestamp', 'created_at', 'deleted_at'], 
                source="nindexer")
            follow_counts = self.loader.get_dataset('follow_counts',
                ['fid', 'follower_count', 'following_count', 'created_at'], 
                source="nindexer")

            has_follows = follows is not None and len(follows) > 0
            has_counts = follow_counts is not None and len(follow_counts) > 0

            if not (has_follows or has_counts):
                # No source data at all: emit the full schema with neutral values.
                return df.with_columns([
                    pl.lit(0.0).alias(col)
                    for col in follow_feature_cols + count_feature_cols
                ])

            result = df

            if has_follows:
                # Follow edges with a deleted_at timestamp are ignored.
                valid_follows = follows.filter(pl.col('deleted_at').is_null())

                follow_metrics = (valid_follows
                    .with_columns([
                        pl.col('timestamp').cast(pl.Datetime),
                        pl.col('created_at').cast(pl.Datetime)
                    ])
                    .group_by('fid')
                    .agg([
                        (pl.col('timestamp').max() - pl.col('timestamp').min())
                            .dt.total_hours()
                            .cast(pl.Float64)
                            .alias('network_age_hours'),
                        pl.len().alias('total_follows'),
                        (pl.col('created_at') - pl.col('timestamp'))
                            .dt.total_seconds()
                            .mean()
                            .cast(pl.Float64)
                            .alias('avg_follow_latency_seconds')
                    ])
                    .with_columns([
                        # +1 avoids division by zero for single-follow accounts.
                        (pl.col('total_follows') / 
                        (pl.col('network_age_hours') + 1))
                        .alias('follow_rate_per_hour')
                    ]))
                result = result.join(follow_metrics, on='fid', how='left')

            if has_counts:
                count_metrics = (follow_counts
                    .sort('created_at')  # Sort so last() picks the most recent row per fid
                    .group_by('fid')
                    .agg([
                        pl.col('follower_count')
                            .last()
                            .cast(pl.Float64)
                            .alias('latest_follower_count'),
                        pl.col('following_count')
                            .last()
                            .cast(pl.Float64)
                            .alias('latest_following_count')
                    ])
                    .with_columns([
                        (pl.col('latest_follower_count') / 
                        (pl.col('latest_following_count') + 1))
                        .alias('latest_follow_ratio')
                    ]))
                result = result.join(count_metrics, on='fid', how='left')

            # An unavailable dataset still contributes its columns (as 0.0) so the
            # output schema is the same on every run.
            missing_cols = [
                col for col in follow_feature_cols + count_feature_cols
                if col not in result.columns
            ]
            if missing_cols:
                result = result.with_columns([pl.lit(0.0).alias(col) for col in missing_cols])

            return result.fill_null(0)

        except Exception as e:
            print(f"Error in enhanced network features: {str(e)}")
            raise

    def add_nindexer_enhanced_profile_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add profile-completeness features from the nindexer ``profiles`` dataset.

        Per ``fid`` this adds:

        * ``additional_profile_fields`` - how many of ``url``/``location`` are
          non-null, taking the maximum over the fid's profile rows (0-2).
        * ``profile_age_hours`` - hours between the earliest ``created_at`` and
          the latest ``updated_at``.
        * ``location`` - the ``location`` of the fid's first profile row.
        * ``has_location_info`` - 1 when the joined frame's ``location`` is
          non-null, else 0. This column is always added, so the schema does not
          depend on the data.

        Args:
            df: Frame keyed by ``fid``; the profile features are left-joined onto it.

        Returns:
            ``df`` plus the columns above, with nulls filled by 0. ``df`` is
            returned unchanged when the dataset is missing or empty.

        Raises:
            Exception: any error is printed and re-raised.
        """
        try:
            profiles = self.loader.get_dataset('profiles', 
                ['fid', 'bio', 'pfp_url', 'url', 'username', 
                'location', 'created_at', 'updated_at'], 
                source="nindexer")

            if profiles is None or len(profiles) == 0:
                return df

            profile_metrics = (profiles
                .with_columns([
                    pl.col('created_at').cast(pl.Datetime),
                    pl.col('updated_at').cast(pl.Datetime),
                    pl.col('url').is_not_null().cast(pl.Int32).alias('has_url'),
                    pl.col('location').is_not_null().cast(pl.Int32).alias('has_location')
                ])
                .group_by('fid')
                .agg([
                    # Without an aggregation, an expression inside agg() yields a
                    # List column per fid, which numeric-only feature selection
                    # later drops. max() collapses it to one scalar.
                    (pl.col('has_url') + pl.col('has_location'))
                        .max()
                        .cast(pl.Float64)
                        .alias('additional_profile_fields'),
                    (pl.col('updated_at').max() - pl.col('created_at').min())
                        .dt.total_hours()
                        .cast(pl.Float64)
                        .alias('profile_age_hours'),
                    pl.col('location')
                        .first()
                        .alias('location')
                ]))

            result = df.join(profile_metrics, on='fid', how='left')

            # Always emitted (all zeros when no location is known) for a stable schema.
            result = result.with_columns([
                pl.col('location')
                    .is_not_null()
                    .cast(pl.Int32)
                    .alias('has_location_info')
            ])

            return result.fill_null(0)

        except Exception as e:
            print(f"Error in enhanced profile features: {str(e)}")
            raise
    def add_nindexer_neynar_score_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add per-fid Neynar score features from the nindexer ``neynar_user_scores`` dataset.

        Score rows are sorted by ``created_at`` before grouping, so that
        ``last()``/``first()`` mean "most recent"/"oldest". Null scores are
        treated as 0.0. Columns added:

        * ``neynar_score`` - most recent score.
        * ``avg_neynar_score`` - mean score.
        * ``neynar_score_std`` - standard deviation of the scores.
        * ``score_trend`` - most recent minus oldest score.

        If ``df`` already has ``authenticity_score``, these are also added:

        * ``score_divergence`` - ``|neynar_score - authenticity_score|``.
        * ``relative_score_diff`` -
          ``(neynar_score - authenticity_score) / (authenticity_score + 1e-6)``.

        Nulls in the added columns are filled with 0.0.

        Args:
            df: Frame keyed by ``fid``.

        Returns:
            The enriched frame. ``df`` is returned unchanged when the dataset is
            missing or empty. On any exception the error is printed and ``df`` is
            returned with the four base columns set to 0.0.
        """
        try:
            scores = self.loader.get_dataset('neynar_user_scores', 
                ['fid', 'score', 'created_at'], source="nindexer")

            if scores is not None and len(scores) > 0:
                score_features = (scores
                    .with_columns([
                        pl.col('created_at').cast(pl.Datetime),
                        # Missing scores count as 0.0.
                        pl.when(pl.col('score').is_null())
                        .then(0.0)
                        .otherwise(pl.col('score'))
                        .alias('score')
                    ])
                    # group_by keeps the row order within each group, so sorting
                    # here makes last()/first() pick the newest/oldest score.
                    .sort('created_at')
                    .group_by('fid')
                    .agg([
                        # Latest score
                        pl.col('score').last().alias('neynar_score'),
                        # Average score over time
                        pl.col('score').mean().alias('avg_neynar_score'),
                        # Score stability
                        pl.col('score').std().alias('neynar_score_std'),
                        # Score trend (positive or negative)
                        (pl.col('score').last() - pl.col('score').first()).alias('score_trend')
                    ]))

                result = df.join(score_features, on='fid', how='left')

                # Compare with authenticity score if it exists
                if 'authenticity_score' in result.columns:
                    result = result.with_columns([
                        (pl.col('neynar_score').cast(pl.Float64) - 
                        pl.col('authenticity_score').cast(pl.Float64))
                        .abs()
                        .alias('score_divergence'),

                        # 1e-6 guards against division by zero.
                        ((pl.col('neynar_score').cast(pl.Float64) - 
                        pl.col('authenticity_score').cast(pl.Float64)) /
                        (pl.col('authenticity_score').cast(pl.Float64) + 1e-6))
                        .alias('relative_score_diff')
                    ])

                # Fill any remaining nulls with 0 (e.g. fids without scores,
                # std of a single observation).
                result = result.with_columns([
                    pl.col('neynar_score').fill_null(0.0),
                    pl.col('avg_neynar_score').fill_null(0.0),
                    pl.col('neynar_score_std').fill_null(0.0),
                    pl.col('score_trend').fill_null(0.0)
                ])

                if 'score_divergence' in result.columns:
                    result = result.with_columns([
                        pl.col('score_divergence').fill_null(0.0),
                        pl.col('relative_score_diff').fill_null(0.0)
                    ])

                return result

            return df

        except Exception as e:
            print(f"Error in neynar score features: {str(e)}")
            # Return original dataframe with default columns if error occurs
            return df.with_columns([
                pl.lit(0.0).alias('neynar_score'),
                pl.lit(0.0).alias('avg_neynar_score'),
                pl.lit(0.0).alias('neynar_score_std'),
                pl.lit(0.0).alias('score_trend')
            ])
    def add_name_pattern_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add username/bio pattern features computed by ``self._analyze_name_patterns``.

        Steps:

        1. Add placeholder integer columns ``random_numbers``, ``wallet_pattern``,
           ``excessive_symbols``, ``airdrop_terms`` and ``has_year``, all 0.
        2. Map ``self._analyze_name_patterns`` over ``fname`` and ``bio`` into
           ``fname_content_patterns`` / ``bio_content_patterns``.
        3. For each of the five pattern keys, add ``fname_<key>`` and
           ``bio_<key>``. Each is read from the analyzer result under the key
           ``"<source>_<key>"``.

        Args:
            df: Frame containing ``fname`` and ``bio`` columns.

        Returns:
            ``df`` with the placeholder, ``*_content_patterns`` and
            ``fname_*``/``bio_*`` columns added.

        Raises:
            Exception: any error is printed and re-raised.
        """
        pattern_keys = (
            'random_numbers',
            'wallet_pattern',
            'excessive_symbols',
            'airdrop_terms',
            'has_year',
        )
        try:
            print("Building name pattern features...")

            # NOTE(review): these placeholder columns are never updated below —
            # confirm whether they are still needed.
            result = df.clone().with_columns([pl.lit(0).alias(key) for key in pattern_keys])

            # NOTE(review): return_dtype is Utf8, yet the values are indexed by key
            # below. Presumably _analyze_name_patterns returns a mapping; confirm
            # the intended dtype.
            result = result.with_columns([
                pl.col('fname').map_elements(self._analyze_name_patterns, return_dtype=pl.Utf8).alias('fname_content_patterns'),
                pl.col('bio').map_elements(self._analyze_name_patterns, return_dtype=pl.Utf8).alias('bio_content_patterns'),
            ])

            # NOTE(review): `.sum() / pl.len()` is a frame-wide aggregate that is
            # broadcast to every row, so each fname_*/bio_* column holds the same
            # value for all users. Confirm this is intended. A per-user value would
            # drop the `.sum() / pl.len()`.
            # `name=...` binds the key at definition time; a plain closure would
            # late-bind and every lambda would read the last key.
            result = result.with_columns([
                (pl.col(f'{source}_content_patterns')
                    .map_elements(lambda patterns, name=f'{source}_{key}': patterns[name],
                                  return_dtype=pl.Float64)
                    .sum() / pl.len())
                .alias(f'{source}_{key}')
                for source in ('fname', 'bio')
                for key in pattern_keys
            ])

            return result

        except Exception as e:
            print(f"Error in name pattern features: {str(e)}")
            raise

    def build_feature_matrix(self) -> pl.DataFrame:
        """Build the per-``fid`` feature matrix, reusing checkpoints where possible.

        The profile feature set forms the base frame; it is built, or read from
        its checkpoint. Each entry of ``feature_sequence`` is then handled in
        order:

        * It is rebuilt when ``self._needs_rebuild`` says so, or when one of its
          dependencies (see ``dependencies``) has not been built or loaded
          earlier in this run.
        * Otherwise it is loaded from its checkpoint.

        New columns are left-joined onto the base frame with
        ``self._safe_join_features``.

        Returns:
            The joined feature matrix with nulls filled by 0 and ``fid`` as Int64.

        Raises:
            Exception: any error raised while building a feature set is printed
            and re-raised.
        """
        print("Starting feature extraction...")

        try:
            # Load or build profile features
            if self._needs_rebuild(self.feature_sets['profile']):
                print("Building profile features...")
                df = self.extract_profile_features()
            else:
                print("Loading profile features from checkpoint...")
                df = pl.read_parquet(self.feature_sets['profile'].checkpoint_path)

            # Setup base configuration
            df = df.with_columns(pl.col('fid').cast(pl.Int64))
            base_fids = df['fid'].cast(pl.Int64).unique()
            df = df.filter(pl.col('fid').is_in(base_fids))
            self.loader.set_base_fids(base_fids)
            print(f"Base shape: {df.shape}")

            # Define dependencies between features
            dependencies = {
                'engagement': ['cast', 'reaction', 'channel'],
                'network_quality': ['network', 'engagement'],
                'activity_patterns': ['temporal', 'cast', 'reaction'],
                'mentions': ['cast'],
                'reply_patterns': ['cast'],
                'update_behavior': ['user_data'],
                'verification_patterns': ['verification'],
                'authenticity': ['profile', 'network', 'verification', 'engagement']
            }

            # Track successfully built or loaded features
            built_features = {'profile'}

            feature_sequence = [
                ('network', self.add_network_features),
                ('temporal', self.add_temporal_features),
                ('cast', self.add_cast_behavior_features),
                ('reaction', self.add_reaction_patterns),
                ('channel', self.add_enhanced_channel_features),
                ('user_data', self.add_user_data_features),
                ('verification', self.add_enhanced_verification_features),
                ('engagement', self.add_engagement_features),
                ('network_quality', self.build_network_quality_features),
                ('activity_patterns', self.add_activity_patterns_features),
                ('influence', self.add_influence_features),
                ('mentions', self.add_mentions_features),
                ('reply_patterns', self.add_reply_patterns_features),
                ('power_user_interaction', self.add_power_user_interaction_features),
                # ('cluster_analysis', self.add_cluster_analysis_features),
                ('update_behavior', self.add_update_behavior_features),
                ('verification_patterns', self.add_verification_patterns_features),
                ('authenticity', self.add_authenticity_features),
                ('storage', self.add_storage_features),
                ('derived', self._add_derived_features),
                ('enhanced_network', self.add_nindexer_enhanced_network_features),
                ('enhanced_profile', self.add_nindexer_enhanced_profile_features),
                ('neynar_score', self.add_nindexer_neynar_score_features),
                ('name_patterns', self.add_name_pattern_features),
                # ('content_patterns', self.add_cast_behavior_features),  # Modified version
                # ('advanced_temporal', self.add_advanced_temporal_features),
                # ('reward_gaming', self.add_reward_gaming_features),
                # ('engagement_authenticity', self.add_engagement_authenticity_features)
            ]

            for feature_name, feature_func in feature_sequence:
                feature_set = self.feature_sets[feature_name]
                current_cols = set(df.columns)

                try:
                    # A feature set whose dependencies were not built/loaded in
                    # this run is rebuilt rather than trusted from its checkpoint.
                    should_rebuild = self._needs_rebuild(feature_set)
                    missing_deps = [
                        dep for dep in dependencies.get(feature_name, [])
                        if dep not in built_features
                    ]
                    if missing_deps:
                        print(f"Missing dependencies for {feature_name}: {missing_deps}")
                        print(f"Currently built features: {built_features}")
                        should_rebuild = True

                    if should_rebuild:
                        print(f"Building {feature_name} features...")
                        df = self._build_and_join_feature_set(
                            df, feature_name, feature_func, feature_set,
                            base_fids, built_features
                        )
                    else:
                        print(f"Loading {feature_name} features from checkpoint...")
                        checkpoint_df = self._load_checkpoint(feature_set, base_fids)

                        if checkpoint_df is not None:
                            new_cols = [c for c in checkpoint_df.columns if c not in current_cols]
                            if new_cols:
                                print(f"Adding {len(new_cols)} new columns from {feature_name}")
                                df = self._safe_join_features(
                                    df,
                                    checkpoint_df.select(['fid'] + new_cols),
                                    feature_name
                                )
                            # A loaded checkpoint satisfies its dependents even when
                            # all of its columns are already present.
                            built_features.add(feature_name)
                        else:
                            print(f"Failed to load {feature_name} checkpoint, forcing rebuild...")
                            df = self._build_and_join_feature_set(
                                df, feature_name, feature_func, feature_set,
                                base_fids, built_features
                            )

                    print(f"Shape after {feature_name}: {df.shape}")

                except Exception as e:
                    print(f"Error in {feature_name}: {str(e)}")
                    raise

            # Final validation
            df = df.fill_null(0)
            df = df.with_columns(pl.col('fid').cast(pl.Int64))

            # self.verify_matrix(df)

            return df

        except Exception as e:
            print(f"Critical error: {str(e)}")
            raise

    def _build_and_join_feature_set(self, df: pl.DataFrame, feature_name: str,
                                    feature_func, feature_set: 'FeatureSet',
                                    base_fids, built_features: set) -> pl.DataFrame:
        """Run one feature builder, checkpoint its validated output and join its new columns onto ``df``.

        Returns ``df`` unchanged when the builder returns ``None`` or its output
        fails ``self._validate_checkpoint_compatibility``. On success,
        ``feature_name`` is added to ``built_features``.
        """
        new_df = feature_func(df)
        if new_df is None:
            return df

        new_df = self._validate_checkpoint(new_df, feature_name)
        new_df = new_df.with_columns(pl.col('fid').cast(pl.Int64))

        # Only save and join when the output is compatible with the base fids.
        if not self._validate_checkpoint_compatibility(new_df, base_fids):
            print(f"Skipping {feature_name}: output failed checkpoint compatibility validation")
            return df

        self._save_checkpoint(new_df, feature_set)
        df = self._safe_join_features(df, new_df, feature_name)
        built_features.add(feature_name)
        print(f"Successfully built and saved {feature_name}")
        return df
        
    def _validate_checkpoint(self, df: pl.DataFrame, name: str) -> pl.DataFrame:
        """Normalise column dtypes of a feature frame before it is checkpointed or joined.

        Dtypes are changed as follows:

        * ``fid`` (if present) is cast to Int64.
        * Every other column that ``self._is_numeric_dtype`` classifies as
          numeric (list columns never are) is cast to Float64, with nulls
          replaced by 0.
        * All remaining columns pass through unchanged.

        Args:
            df: Feature frame to normalise.
            name: Feature-set name, used only in the error message.

        Returns:
            The normalised frame.

        Raises:
            Exception: any casting error is printed and re-raised.
        """
        try:
            # Always ensure fid is Int64 first
            if 'fid' in df.columns:
                df = df.with_columns(pl.col('fid').cast(pl.Int64))

            # Single source of truth for "numeric" shared with _safe_join_features.
            numeric_cols = [
                col for col, dtype in df.schema.items()
                if col != 'fid' and self._is_numeric_dtype(dtype)
            ]

            if numeric_cols:
                df = df.with_columns([
                    pl.col(col).cast(pl.Float64).fill_null(0)
                    for col in numeric_cols
                ])

            return df

        except Exception as e:
            print(f"Error validating checkpoint {name}: {str(e)}")
            raise

    def _is_numeric_dtype(self, dtype) -> bool:
        """Check if a Polars dtype is numeric, excluding list types"""
        # Convert dtype to string for comparison
        dtype_str = str(dtype).lower()
        # Exclude list types
        if 'list' in dtype_str:
            return False
        return any(num_type in dtype_str 
                for num_type in ['int', 'float', 'decimal'])

    def _safe_join_features(self, df: pl.DataFrame, 
                        new_features: pl.DataFrame,
                        feature_name: str) -> pl.DataFrame:
        """Left-join onto ``df`` the columns of ``new_features`` that ``df`` does not have yet.

        Join behaviour:

        * Rows are matched on ``fid``, cast to Int64 on the feature side.
        * When ``new_features`` has several rows for one ``fid``, the first one
          is kept, so the join never multiplies rows in ``df`` and the choice is
          deterministic.
        * After the join, nulls in the added columns are replaced: list columns
          get an empty list, numeric columns get 0.0. Other columns keep their
          nulls.

        Args:
            df: Base frame (must contain ``fid``).
            new_features: Candidate feature frame (must contain ``fid``).
            feature_name: Feature-set name, used in log messages.

        Returns:
            The joined frame, or ``df`` unchanged when there is nothing new to add.

        Raises:
            Exception: any error is printed and re-raised.
        """
        try:
            if new_features is None or len(new_features) == 0:
                print(f"No valid features to join for {feature_name}")
                return df

            existing_cols = set(df.columns)
            new_cols = [c for c in new_features.columns
                        if c != 'fid' and c not in existing_cols]

            if not new_cols:
                print(f"No new columns to add from {feature_name}")
                return df

            # keep='first' + maintain_order make de-duplication deterministic; the
            # default keep='any' gives no guarantee about which row survives.
            features = (
                new_features
                .select(['fid'] + new_cols)
                .unique(subset=['fid'], keep='first', maintain_order=True)
                .with_columns(pl.col('fid').cast(pl.Int64))
            )
            result = df.join(features, on='fid', how='left')

            # One post-join pass covers nulls already present in new_features as
            # well as nulls introduced for fids missing from new_features.
            fill_exprs = []
            for col in new_cols:
                dtype = result.schema[col]
                if 'list' in str(dtype).lower():
                    fill_exprs.append(pl.col(col).fill_null([]))
                elif self._is_numeric_dtype(dtype):
                    fill_exprs.append(pl.col(col).fill_null(0.0))
            if fill_exprs:
                result = result.with_columns(fill_exprs)

            return result

        except Exception as e:
            print(f"Error joining {feature_name}: {str(e)}")
            raise

    def _validate_feature_dependencies(self, feature_name: str, 
                                built_features: set) -> bool:
        """Validate feature dependencies are met"""
        if feature_name not in self.feature_sets:
            return False
            
        feature_set = self.feature_sets[feature_name]
        for dep in feature_set.dependencies:
            if dep not in built_features:
                print(f"Missing dependency {dep} for {feature_name}")
                return False
                
        return True
    def _save_checkpoint(self, df: pl.DataFrame, feature_set: 'FeatureSet'):
        """Validate ``df`` and write it as ``feature_set``'s parquet checkpoint.

        Steps:

        1. Normalise the frame with ``self._validate_checkpoint``.
        2. Write it to a temporary sibling file.
        3. Move that file over ``feature_set.checkpoint_path`` with
           ``os.replace``, so an interrupted write never leaves a truncated
           checkpoint behind.
        4. Set ``feature_set.last_modified`` to the new file's mtime.

        Raises:
            Exception: any validation or write error is propagated. The temporary
            file is removed on a best-effort basis first.
        """
        # Validate before saving
        df = self._validate_checkpoint(df, feature_set.name)

        checkpoint_path = feature_set.checkpoint_path
        tmp_path = f"{checkpoint_path}.tmp"
        try:
            df.write_parquet(tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            # Also covers KeyboardInterrupt (e.g. a notebook interrupt mid-write).
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

        feature_set.last_modified = os.path.getmtime(checkpoint_path)



    def _validate_sensitive_checkpoint(self, df: pl.DataFrame, feature_name: str) -> pl.DataFrame:
        """Guarantee that the sensitive columns of a feature set exist and are non-null.

        Applies only to the ``authenticity`` and ``update_behavior`` feature
        sets. Each listed column is null-filled with its default when present,
        or added as a constant default column when absent. Frames for any other
        ``feature_name`` are returned unchanged.

        Args:
            df: Feature frame to check.
            feature_name: Name of the feature set ``df`` belongs to.

        Returns:
            The frame with the sensitive columns guaranteed.

        Raises:
            Exception: any error is printed and re-raised.
        """
        # Initialize sensitive columns with safe defaults
        sensitive_defaults = {
            'authenticity': {
                'authenticity_score': 0.0,
                'profile_completeness': 0.0,
                'network_balance': 0.0,
                'update_naturalness': 0.0
            },
            'update_behavior': {
                'profile_update_consistency': 0.0,
                'total_updates': 0,
                'avg_update_interval': 0.0,
                'update_time_std': 0.0
            }
        }

        try:
            defaults = sensitive_defaults.get(feature_name)
            if defaults:
                # Each expression targets a distinct column, so they can be
                # applied in one with_columns call.
                df = df.with_columns([
                    pl.col(col).fill_null(default) if col in df.columns
                    else pl.lit(default).alias(col)
                    for col, default in defaults.items()
                ])

            return df

        except Exception as e:
            print(f"Error validating sensitive checkpoint {feature_name}: {str(e)}")
            raise
    def _add_derived_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add follow-graph features derived from existing columns.

        Steps:

        1. Ensure the required columns (``following_count``, ``follower_count``,
           ``follower_ratio``, ``unique_follower_ratio``, ``follow_velocity``)
           exist and have no nulls. Missing columns are added as 0.0 and nulls
           are filled with 0.0.
        2. Add:

           * ``<ratio>_log`` - ``log1p`` of the three ratio/velocity columns.
           * ``has_more_followers`` - 1 if ``follower_count > following_count``,
             else 0.
           * ``follow_balance_ratio`` -
             ``|following - follower| / (following + follower + 1)``.
           * ``<ratio>_capped`` - the three ratio/velocity columns clipped to
             ``[0, p99]``, where p99 is each column's 0.99 quantile.

        Args:
            df: Feature frame.

        Returns:
            A copy of ``df`` with the derived columns added.

        Raises:
            Exception: any error is printed and re-raised.
        """
        required_cols = {
            'following_count': 0.0,
            'follower_count': 0.0,
            'follower_ratio': 0.0,
            'unique_follower_ratio': 0.0,
            'follow_velocity': 0.0
        }
        capped_cols = ['follower_ratio', 'unique_follower_ratio', 'follow_velocity']

        try:
            print("Building derived features...")
            result = df.clone()

            # Initialize missing columns
            for col, default in required_cols.items():
                if col not in result.columns:
                    print(f"Adding missing column {col} with default {default}")
                    result = result.with_columns(pl.lit(default).alias(col))

            # After this point the required columns are guaranteed null-free,
            # so the expressions below need no further fill_null calls.
            result = result.with_columns([
                pl.col(col).fill_null(default).alias(col)
                for col, default in required_cols.items()
            ])

            result = result.with_columns([
                # Log transformations
                pl.col('follower_ratio').log1p().alias('follower_ratio_log'),
                pl.col('unique_follower_ratio').log1p().alias('unique_follower_ratio_log'),
                pl.col('follow_velocity').log1p().alias('follow_velocity_log'),

                # Binary flag
                (pl.when(pl.col('follower_count') > pl.col('following_count'))
                .then(1)
                .otherwise(0)
                ).alias('has_more_followers'),

                # Balance ratio; +1 keeps the denominator non-zero
                ((pl.col('following_count') - pl.col('follower_count')).abs() / 
                (pl.col('following_count') + pl.col('follower_count') + 1)
                ).alias('follow_balance_ratio')
            ])

            # Cap extreme values at the 99th percentile. All three quantiles are
            # computed in a single pass over the frame.
            p99 = result.select([
                pl.col(col).quantile(0.99).alias(col) for col in capped_cols
            ]).row(0, named=True)
            result = result.with_columns([
                pl.col(col).clip(0.0, p99[col]).alias(f'{col}_capped')
                for col in capped_cols
            ])

            print("Derived features completed successfully")
            return result

        except Exception as e:
            print(f"Error in derived features: {str(e)}")
            raise

In [None]:
# notebook code
import optuna
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score,
    precision_recall_curve, precision_score, recall_score,
    roc_auc_score
)
import xgboost as xgb
from lightgbm import LGBMClassifier
import shap
from scipy import stats
import numpy as np
from typing import Dict, List, Tuple
import polars as pl

class SybilDetectionSystem:
    """Bot/sybil classifier built from tuned tree models, calibration, stacking and SHAP.

    ``prepare_features`` turns a Polars frame into a clean numeric matrix.
    ``train`` tunes XGBoost, RandomForest and LightGBM with Optuna, calibrates
    them, fits an LGBM meta-learner and combines everything in a soft-voting
    ensemble stored in ``self.model``. The remaining methods explain, score and
    post-process predictions of the trained models.
    """

    def __init__(self, 
                 feature_engineering: 'FeatureEngineering',
                 confidence_thresholds: Dict[str, float] = None,
                 authenticity_thresholds: Dict[str, float] = None):
        """Store configuration; no model is trained here.

        Args:
            feature_engineering: Feature pipeline object (kept for callers; not
                used by this class's methods).
            confidence_thresholds: Bot-probability cut-offs keyed 'high',
                'medium', 'low'. Defaults to 0.95 / 0.85 / 0.70.
            authenticity_thresholds: Authenticity-score cut-offs keyed 'high',
                'medium', 'low'. Defaults to 0.8 / 0.6 / 0.4.
        """
        self.feature_engineering = feature_engineering
        self.model = None            # soft-voting ensemble, set by train()
        self.meta_learner = None     # LGBM stacker, set by build_stacked_model()
        self.feature_names = None    # column order the models were trained on
        self.scaler = StandardScaler()
        self.confidence_thresholds = confidence_thresholds or {
            'high': 0.95,
            'medium': 0.85,
            'low': 0.70
        }
        self.authenticity_thresholds = authenticity_thresholds or {
            'high': 0.8,
            'medium': 0.6,
            'low': 0.4
        }
        self.feature_importance = {}
        self.shap_values = {}
        self.base_models = {}        # name -> fitted CalibratedClassifierCV
        self.shap_explainers = {}    # name -> shap.TreeExplainer (or None on failure)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _unwrap_estimator(model):
        """Return the tree model wrapped by a CalibratedClassifierCV, or *model* itself.

        scikit-learn >= 1.2 exposes the wrapped model as ``estimator``; older
        releases used ``base_estimator``, which holds the placeholder string
        ``"deprecated"`` in the transition releases, hence the ``str`` check.
        Returns None when no wrapped model can be found.
        """
        if not isinstance(model, CalibratedClassifierCV):
            return model
        for attr in ('estimator', 'base_estimator'):
            inner = getattr(model, attr, None)
            if inner is not None and not isinstance(inner, str):
                return inner
        return None

    @staticmethod
    def _positive_class_shap(shap_vals) -> np.ndarray:
        """Normalise SHAP output to a (n_samples, n_features) array for the positive class.

        Depending on model type and shap version, ``shap_values`` is a 2-D
        array, a per-class list of 2-D arrays, or a 3-D
        (n_samples, n_features, n_classes) array.
        """
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1] if len(shap_vals) > 1 else shap_vals[0]
        shap_vals = np.asarray(shap_vals)
        if shap_vals.ndim == 3:
            shap_vals = shap_vals[..., -1]
        return shap_vals

    def _base_predictions(self, X: np.ndarray) -> np.ndarray:
        """Positive-class probability of every base model, shape (n_models, n_samples)."""
        base_preds = np.zeros((len(self.base_models), len(X)))
        for i, model in enumerate(self.base_models.values()):
            base_preds[i] = model.predict_proba(X)[:, 1]
        return base_preds

    def _meta_features(self, X: np.ndarray) -> np.ndarray:
        """Meta-learner input: base-model probabilities followed by the raw features."""
        return np.column_stack([self._base_predictions(X).T, X])

    def _make_ensemble(self, weights: List[float] = None) -> VotingClassifier:
        """Soft-voting ensemble over the base models plus the meta-learner.

        NOTE: VotingClassifier clones and refits every estimator on the raw
        features when fitted, so inside the ensemble the meta-learner acts as a
        plain LGBM on X rather than as a stacker.
        """
        estimators = list(self.base_models.items()) + [('meta_learner', self.meta_learner)]
        if weights is None:
            weights = [1.0 / len(estimators)] * len(estimators)
        return VotingClassifier(estimators=estimators, voting='soft', weights=weights)

    def _get_explainer(self, model_name: str):
        """Return a SHAP TreeExplainer for *model_name*, or None after printing why not.

        Prefers the explainer cached by ``train``; otherwise builds one on the
        estimator wrapped by the calibrated base model.
        """
        explainer = self.shap_explainers.get(model_name)
        if explainer is not None:
            return explainer
        if model_name not in self.base_models:
            print(f"No model available for {model_name}.")
            return None
        base_model = self._unwrap_estimator(self.base_models[model_name])
        if base_model is None:
            print(f"Model {model_name} does not have a base estimator.")
            return None
        print(f"Creating SHAP explainer for model type: {type(base_model)}")
        return shap.TreeExplainer(base_model)

    # --------------------------------------------------------------- features

    @staticmethod
    def _clip_outliers(features: pl.DataFrame, cols: List[str]) -> pl.DataFrame:
        """Clip each column to [max(q01, mean - 3*std), min(q99, mean + 3*std)].

        Statistics ignore nulls. Columns with no non-null values are left
        alone; when std is undefined (fewer than two values) only the
        percentiles are used.
        """
        q01 = features.select(pl.col(cols).quantile(0.01)).row(0)
        q99 = features.select(pl.col(cols).quantile(0.99)).row(0)
        means = features.select(pl.col(cols).mean()).row(0)
        stds = features.select(pl.col(cols).std()).row(0)

        clip_exprs = []
        for col, lower, upper, mean_val, std_val in zip(cols, q01, q99, means, stds):
            if lower is None or upper is None:
                continue
            if mean_val is not None and std_val is not None:
                lower = max(lower, mean_val - 3 * std_val)
                upper = min(upper, mean_val + 3 * std_val)
            clip_exprs.append(pl.col(col).clip(lower, upper).alias(col))
        return features.with_columns(clip_exprs) if clip_exprs else features

    def prepare_features(self, df: pl.DataFrame, scale: bool = False) -> Tuple[np.ndarray, List[str]]:
        """Convert the numeric columns of *df* into a clean float feature matrix.

        Keeps Float*/Int* columns only (UInt, Boolean, string and list columns
        are ignored). Pre-existing nulls become 0 and non-finite values
        (+/-inf, NaN) become nulls. Each column is then clipped to
        ``[max(q01, mean - 3*std), min(q99, mean + 3*std)]``, computed on the
        finite values. The remaining nulls are filled with the column median,
        or 0 if the column has no finite values.

        Args:
            df: Input frame.
            scale: If True, standardise with ``self.scaler`` (re-fitted on *df*).

        Returns:
            ``(feature_array, feature_names)``; column ``i`` of the array is
            ``feature_names[i]``.

        Raises:
            ValueError: If infinite values survive preprocessing.
        """
        try:
            # The selection rule must stay stable: trained models consume
            # columns positionally.
            valid_cols = [col for col in df.columns
                          if str(df[col].dtype).startswith(('Float', 'Int'))]

            print(f"\nTotal numeric features available: {len(valid_cols)}")

            # Work in Float64 so integer columns accept float clip bounds.
            features = df.select([
                pl.col(col).fill_null(0).cast(pl.Float64) for col in valid_cols
            ])
            # Null out non-finite values *before* computing statistics;
            # otherwise inf/NaN poison the mean/std/quantiles and disable
            # clipping for that column.
            features = features.with_columns([
                pl.when(pl.col(col).is_finite()).then(pl.col(col)).otherwise(None).alias(col)
                for col in valid_cols
            ])

            if valid_cols:
                features = self._clip_outliers(features, valid_cols)
                medians = features.select(pl.col(valid_cols).median()).row(0)
                features = features.with_columns([
                    pl.col(col).fill_null(0.0 if median_val is None else median_val).alias(col)
                    for col, median_val in zip(valid_cols, medians)
                ])

            feature_array = features.to_numpy()

            if scale:
                feature_array = self.scaler.fit_transform(feature_array)

            print(f"\nFinal feature matrix shape: {feature_array.shape}")
            print(f"Using {len(valid_cols)} features")

            if np.any(np.isinf(feature_array)):
                raise ValueError("Infinite values still present after preprocessing")

            return feature_array, valid_cols

        except Exception as e:
            print(f"Error preparing features: {str(e)}")
            print(f"Available columns: {df.columns}")
            raise

    # --------------------------------------------------------------- training

    def _fit_base_model(self, name: str, template, X: np.ndarray, y: np.ndarray, cv):
        """Tune *template* with Optuna, refit the best config, cache its SHAP explainer, calibrate it.

        Returns:
            The fitted CalibratedClassifierCV wrapping the tuned model.
        """
        from sklearn.base import clone

        print(f"\nStarting training for {name}...")
        study = optuna.create_study(direction='maximize', study_name=f'optuna_{name}')

        def objective(trial):
            params = self.get_hyperparameters(name, trial)
            candidate = clone(template).set_params(**params)
            cv_scores = cross_val_score(candidate, X, y, cv=cv, scoring='average_precision')
            return cv_scores.mean()

        study.optimize(objective, n_trials=50, timeout=600)
        print(f"Best parameters for {name}: {study.best_params}")
        print(f"Best CV score for {name}: {study.best_value}")

        # Rebuild from the template so its fixed settings (random_state, n_jobs,
        # class_weight, eval_metric) are kept; FixedTrial replays
        # get_hyperparameters so constant entries it returns are applied too.
        best_params = self.get_hyperparameters(name, optuna.trial.FixedTrial(study.best_params))
        best_model = clone(template).set_params(**best_params)
        best_model.fit(X, y)

        try:
            self.shap_explainers[name] = shap.TreeExplainer(best_model)
            print(f"SHAP explainer created for {name}")
        except Exception as e:
            print(f"Error creating SHAP explainer for {name}: {str(e)}")
            self.shap_explainers[name] = None

        print(f"Calibrating {name}...")
        calibrated_model = CalibratedClassifierCV(best_model, cv=5)
        calibrated_model.fit(X, y)
        return calibrated_model

    def train(self, X_train: np.ndarray, y_train: np.ndarray, feature_names: List[str]):
        """Tune, calibrate, stack and ensemble the base models, then report diagnostics.

        Each base model gets 50 Optuna trials (600 s timeout) scored by 5-fold
        average precision.

        NOTE: the validation/test split below is carved out of the training
        data the ensemble was just fitted on, so those scores are in-sample.
        ``optimize_ensemble_weights`` refits ``self.model`` on the larger
        (80 %) part of that split.

        Returns:
            self
        """
        self.feature_names = feature_names
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

        base_model_configs = {
            'xgb': xgb.XGBClassifier(eval_metric='auc', random_state=42),
            'rf': RandomForestClassifier(n_jobs=-1, random_state=42, class_weight='balanced'),
            'lgbm': LGBMClassifier(n_jobs=-1, random_state=42, class_weight='balanced')
        }

        for name, template in base_model_configs.items():
            self.base_models[name] = self._fit_base_model(name, template, X_train, y_train, cv)

        print("\nBuilding stacked model...")
        self.build_stacked_model(X_train, y_train)

        print("\nCreating ensemble...")
        self.model = self._make_ensemble()
        self.model.fit(X_train, y_train)
        print("Ensemble training complete")

        stability_results = self.add_cross_validation_stability(X_train, y_train)
        print("\nCross-validation Stability Metrics:")
        print(f"Mean prediction variance: {stability_results['mean_prediction_variance']:.4f}")
        print(f"Max prediction variance: {stability_results['max_prediction_variance']:.4f}")
        print(f"Mean prediction range: {stability_results['mean_prediction_range']:.4f}")
        print(f"Percentage of stable predictions: {stability_results['stable_prediction_percentage']:.2%}")

        X_val, X_test, y_val, y_test = train_test_split(
            X_train, y_train, test_size=0.2, random_state=42
        )

        weights = self.optimize_ensemble_weights(X_val, y_val)
        print("Optimized model weights:", weights)

        _, unstable_indices = self.predict_with_stability(X_test)
        print(f"Number of unstable predictions: {len(unstable_indices)}")

        return self

    def build_stacked_model(self, X: np.ndarray, y: np.ndarray):
        """Fit an LGBM meta-learner on base-model probabilities plus the raw features.

        NOTE(review): the base models were fitted on this same X, so the
        probabilities fed to the meta-learner are largely in-sample;
        out-of-fold predictions would be the textbook alternative.
        """
        meta_features = self._meta_features(X)
        meta_learner = LGBMClassifier(
            n_estimators=100,
            learning_rate=0.01,
            max_depth=3,
            num_leaves=8,
            feature_fraction=0.8,
            bagging_fraction=0.8,
            bagging_freq=1,  # LightGBM ignores bagging_fraction unless bagging_freq > 0
            random_state=42
        )
        meta_learner.fit(meta_features, y)
        self.meta_learner = meta_learner

    def get_hyperparameters(self, model_name: str, trial: optuna.Trial) -> Dict:
        """Return the Optuna search space for 'xgb', 'lgbm' or (any other name) RandomForest.

        Every entry is drawn from *trial* except LightGBM's constant
        ``bagging_freq``.
        """
        if model_name == 'xgb':
            return {
                'max_depth': trial.suggest_int('max_depth', 3, 7),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
                'reg_alpha': trial.suggest_float('reg_alpha', 0, 10),
                'reg_lambda': trial.suggest_float('reg_lambda', 1, 10),
                'scale_pos_weight': trial.suggest_float('scale_pos_weight', 1.0, 10.0),
                'gamma': trial.suggest_float('gamma', 0, 5)
            }
        elif model_name == 'lgbm':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'max_depth': trial.suggest_int('max_depth', 3, 7),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
                'num_leaves': trial.suggest_int('num_leaves', 20, 100),
                'feature_fraction': trial.suggest_float('feature_fraction', 0.6, 1.0),
                'bagging_fraction': trial.suggest_float('bagging_fraction', 0.6, 1.0),
                # Without this, LightGBM silently ignores the tuned bagging_fraction.
                'bagging_freq': 1,
                'min_child_samples': trial.suggest_int('min_child_samples', 5, 30),
                'lambda_l1': trial.suggest_float('lambda_l1', 0, 10),
                'lambda_l2': trial.suggest_float('lambda_l2', 0, 10)
            }
        else:  # RandomForest
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'max_depth': trial.suggest_int('max_depth', 3, 7),
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
                'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2'])
            }

    # ----------------------------------------------------------- explanations

    def get_feature_explanations(self, model_name: str, X: np.ndarray, instance_index: int) -> Dict:
        """Return the five largest-magnitude SHAP contributions for row *instance_index* of *X*.

        Returns:
            ``{feature_name: shap_value}`` ordered by decreasing |value|, or
            ``{}`` (after printing the reason) on any failure.
        """
        try:
            if instance_index < 0 or instance_index >= X.shape[0]:
                print(f"Error: instance_index {instance_index} is out of bounds for test set with size {X.shape[0]}.")
                return {}

            explainer = self._get_explainer(model_name)
            if explainer is None:
                return {}

            instance_data = X[instance_index:instance_index + 1]
            print(f"Calculating SHAP values for instance shape: {instance_data.shape}")
            shap_instance = self._positive_class_shap(explainer.shap_values(instance_data))[0]

            top_indices = np.argsort(np.abs(shap_instance))[-5:][::-1]
            return {
                self.feature_names[i]: float(shap_instance[i])
                for i in top_indices
            }

        except Exception as e:
            print(f"Error getting feature explanations: {str(e)}")
            return {}

    def _calculate_feature_importance(self):
        """Average the base models' ``feature_importances_`` into ``self.feature_importance``.

        Recomputed from scratch on each call. Each model's importances are
        normalised to sum to 1 first, because different libraries report them
        on different scales. Models without importances are skipped and do not
        count toward the average. Errors are printed, not raised.
        """
        try:
            totals: Dict[str, float] = {}
            contributing = 0
            for name, model in self.base_models.items():
                estimator = self._unwrap_estimator(model)
                if estimator is None:
                    print(f"Model {name} does not have a base estimator.")
                    continue

                importances = getattr(estimator, 'feature_importances_', None)
                if importances is None:
                    print(f"No feature_importances_ attribute for model {name}.")
                    continue

                importances = np.asarray(importances, dtype=float)
                total = importances.sum()
                if total > 0:
                    importances = importances / total
                for feat, imp in zip(self.feature_names, importances):
                    totals[feat] = totals.get(feat, 0.0) + float(imp)
                contributing += 1

            if contributing:
                self.feature_importance = {k: v / contributing for k, v in totals.items()}
                print("Feature importance calculated.")
            else:
                print("No models available to calculate feature importance.")
        except Exception as e:
            print(f"Error calculating feature importance: {str(e)}")

    def select_important_features(self, X: np.ndarray, y: np.ndarray, feature_names: List[str], 
                                threshold: float = 0.01) -> List[str]:
        """Return names whose mean |SHAP| exceeds ``threshold`` times the largest mean |SHAP|.

        The ranking comes from a throw-away 100-tree XGBoost model fitted on (X, y).
        """
        model = xgb.XGBClassifier(n_estimators=100, random_state=42)
        model.fit(X, y)

        shap_values = self._positive_class_shap(shap.TreeExplainer(model).shap_values(X))
        importance_vals = np.abs(shap_values).mean(axis=0)

        importance = dict(zip(feature_names, importance_vals))
        cutoff = threshold * np.max(importance_vals)
        selected_features = [f for f, imp in importance.items() if imp > cutoff]

        print(f"\nSelected {len(selected_features)}/{len(feature_names)} features")
        print("Top 10 features:", sorted(importance.items(), key=lambda x: x[1], reverse=True)[:10])
        return selected_features

    def analyze_feature_interactions(self, X: np.ndarray, top_k: int = 10) -> List[Tuple[str, str, float]]:
        """Return the *top_k* feature pairs with the largest mean |SHAP interaction| (XGBoost model).

        Returns ``[]`` when no XGBoost explainer is available.
        """
        explainer = self._get_explainer('xgb')
        if explainer is None:
            return []

        interaction_values = explainer.shap_interaction_values(X)
        if isinstance(interaction_values, list):
            interaction_values = interaction_values[1] if len(interaction_values) > 1 else interaction_values[0]

        # Mean absolute interaction per pair: (n_features, n_features).
        strength = np.abs(np.asarray(interaction_values)).mean(axis=0)
        rows, cols = np.triu_indices(len(self.feature_names), k=1)
        interactions = [
            (self.feature_names[i], self.feature_names[j], float(strength[i, j]))
            for i, j in zip(rows, cols)
        ]
        return sorted(interactions, key=lambda x: x[2], reverse=True)[:top_k]

    # ------------------------------------------------------------- prediction

    def predict_with_uncertainty(self, features: np.ndarray, 
                                 authenticity_features: np.ndarray) -> List[Dict]:
        """Give each sample a bot verdict, using base-model disagreement as its uncertainty.

        Args:
            features: Feature matrix in ``self.feature_names`` order.
            authenticity_features: One row per sample. The mean of its first
                four entries is the authenticity score.

        Returns:
            One dict per sample with the keys ``is_bot``, ``is_authentic``,
            ``bot_probability``, ``authenticity_score`` and ``confidence``.
            It also has ``uncertainty``, the std of the base-model
            probabilities. The last key, ``prediction_interval``, is a
            normal-approximation 95 % interval clipped to [0, 1].

        Raises:
            ValueError: If the two inputs have different numbers of rows.
        """
        if len(features) != len(authenticity_features):
            raise ValueError(
                f"features has {len(features)} rows but authenticity_features has "
                f"{len(authenticity_features)}"
            )

        predictions = self._base_predictions(features)
        mean_probs = predictions.mean(axis=0)
        std_probs = predictions.std(axis=0)

        # Computed directly: scipy's norm.interval returns NaN for scale == 0
        # (all models agree), and the raw interval can leave [0, 1].
        z = stats.norm.ppf(0.975)
        lower = np.clip(mean_probs - z * std_probs, 0.0, 1.0)
        upper = np.clip(mean_probs + z * std_probs, 0.0, 1.0)

        results = []
        for i, (prob, std, auth_scores) in enumerate(zip(mean_probs, std_probs, authenticity_features)):
            authenticity_score = np.mean([
                auth_scores[0],  # authenticity_score
                auth_scores[1],  # engagement_quality
                auth_scores[2],  # natural_behavior_score
                auth_scores[3]   # account_stability
            ])

            # Relative spread; falls back to the absolute spread when prob == 0.
            model_uncertainty = std / prob if prob > 0 else std
            confidence = self._assess_confidence(prob, authenticity_score, model_uncertainty)

            results.append({
                'is_bot': bool(prob >= 0.5),
                'is_authentic': bool(authenticity_score >= self.authenticity_thresholds['medium']),
                'bot_probability': float(prob),
                'authenticity_score': float(authenticity_score),
                'confidence': confidence,
                'uncertainty': float(std),
                'prediction_interval': (float(lower[i]), float(upper[i]))
            })

        return results

    def _assess_confidence(self, prob: float, authenticity: float, 
                          uncertainty: float) -> str:
        """Map probability, authenticity and uncertainty to a confidence label.

        The bot thresholds are raised by ``2 * uncertainty``. The 'low'
        threshold entries are not consulted.
        """
        uncertainty_penalty = uncertainty * 2

        if prob <= 0.1 and authenticity >= self.authenticity_thresholds['high'] and uncertainty < 0.1:
            return 'high_authentic'
        elif prob <= 0.2 and authenticity >= self.authenticity_thresholds['medium'] and uncertainty < 0.15:
            return 'medium_authentic'
        elif prob >= (self.confidence_thresholds['high'] + uncertainty_penalty):
            return 'high_bot'
        elif prob >= (self.confidence_thresholds['medium'] + uncertainty_penalty):
            return 'medium_bot'
        else:
            return 'uncertain'

    def add_cross_validation_stability(self, X: np.ndarray, y: np.ndarray, n_splits: int = 5) -> Dict[str, float]:
        """Measure how much ensemble predictions move when the training data changes.

        A clone of ``self.model`` is fitted on each of ``n_splits`` stratified
        training folds, and every fold model scores every sample. Each sample
        therefore gets ``n_splits`` predictions: one out-of-fold, the rest from
        fold models that saw it. ``self.model`` itself is not modified.

        Returns:
            Mean/max per-sample variance, mean/max per-sample range (max - min),
            and the fraction of samples whose variance is below 0.1.

        Raises:
            RuntimeError: If no ensemble has been built yet.
        """
        from sklearn.base import clone

        if self.model is None:
            raise RuntimeError("train() must be called before add_cross_validation_stability()")

        kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        all_predictions = np.full((len(X), n_splits), np.nan)

        for fold_idx, (train_idx, _) in enumerate(kf.split(X, y)):
            fold_model = clone(self.model)
            fold_model.fit(X[train_idx], y[train_idx])
            # Score all rows, not only the held-out fold: with held-out-only
            # scoring each row got a single prediction and zero variance by
            # construction.
            all_predictions[:, fold_idx] = fold_model.predict_proba(X)[:, 1]

        sample_variances = np.nanvar(all_predictions, axis=1)
        sample_ranges = np.nanmax(all_predictions, axis=1) - np.nanmin(all_predictions, axis=1)

        stability_metrics = {
            'mean_prediction_variance': float(np.mean(sample_variances)),
            'max_prediction_variance': float(np.max(sample_variances)),
            'mean_prediction_range': float(np.mean(sample_ranges)),
            'max_prediction_range': float(np.max(sample_ranges)),
            'stable_prediction_percentage': float(np.mean(sample_variances < 0.1))
        }

        print(f"\nPrediction matrix shape: {all_predictions.shape}")
        print(f"Number of samples with predictions: {np.sum(~np.isnan(all_predictions.mean(axis=1)))}")
        print(f"Average predictions per sample: {np.sum(~np.isnan(all_predictions), axis=1).mean():.2f}")

        return stability_metrics

    def optimize_ensemble_weights(self, X: np.ndarray, y: np.ndarray) -> List[float]:
        """Re-weight the ensemble by each member's ROC AUC on (X, y), then refit it on (X, y).

        The meta-learner is scored on its stacked inputs. Weights are the AUCs
        normalised to sum to 1. On any error the message is printed,
        ``self.model`` is left unchanged, and uniform weights are returned.
        """
        try:
            base_preds = self._base_predictions(X)
            model_scores = {}
            for name, preds in zip(self.base_models, base_preds):
                score = roc_auc_score(y, preds)
                model_scores[name] = score
                print(f"{name} ROC AUC: {score:.4f}")

            meta_features = np.column_stack([base_preds.T, X])
            meta_score = roc_auc_score(y, self.meta_learner.predict_proba(meta_features)[:, 1])
            model_scores['meta_learner'] = meta_score
            print(f"Meta learner ROC AUC: {meta_score:.4f}")

            total_score = sum(model_scores.values())
            weights = [score / total_score for score in model_scores.values()]

            # Fit before assigning so a failure leaves the previous model intact.
            ensemble = self._make_ensemble(weights)
            ensemble.fit(X, y)
            self.model = ensemble

            return weights

        except Exception as e:
            print(f"Error optimizing weights: {str(e)}")
            n_members = len(self.base_models) + 1
            return [1.0 / n_members] * n_members

    def predict_with_stability(self, X: np.ndarray) -> Tuple[np.ndarray, List[int]]:
        """Return ensemble probabilities damped by base-model disagreement.

        Returns:
            ``(adjusted_proba, unstable_indices)``. ``adjusted_proba`` is
            ``self.model.predict_proba(X)`` with each row multiplied by
            ``1 - clip(std, 0, 0.5)``, where std is taken over the base
            models' positive-class probabilities. Rows therefore no longer sum
            to 1. ``unstable_indices`` lists the rows where that std exceeds
            0.2. On error the unadjusted probabilities and ``[]`` are returned.
        """
        try:
            base_predictions = self._base_predictions(X)
            std_predictions = np.std(base_predictions, axis=0)

            unstable_indices = np.where(std_predictions > 0.2)[0]

            predictions = self.model.predict_proba(X)
            confidence_adjustments = 1 - np.clip(std_predictions, 0, 0.5)
            adjusted_predictions = predictions * confidence_adjustments.reshape(-1, 1)

            return adjusted_predictions, unstable_indices.tolist()

        except Exception as e:
            print(f"Error in prediction with stability: {str(e)}")
            return self.model.predict_proba(X), []

In [None]:

# Re-initialize after building
feature_eng = FeatureEngineering("data", "checkpoints")
detector = SybilDetectionSystem(feature_eng)

# Build feature matrix
matrix = feature_eng.build_feature_matrix()
print("Feature matrix built")

# Load labels; cast fid to Int64 on both sides so the join keys match
labels_df = pl.read_csv('data/labels.csv')
labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))
matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))

# Inner join keeps only labeled fids
data = matrix.join(labels_df, on='fid', how='inner')

# Extract features and labels
X, feature_names = detector.prepare_features(data.drop(['bot', 'fid']))
y = data['bot'].to_numpy()

# Stratified train/test split; fids travel with their rows
from sklearn.model_selection import train_test_split
fids = data['fid'].to_numpy()

X_train, X_test, y_train, y_test, train_fids, test_fids = train_test_split(
    X, y, fids,
    test_size=0.2,
    random_state=42,
    stratify=y  # ensures class distribution is preserved
)

# Define a path to save the model checkpoint
model_checkpoint_path = "checkpoints/sybil_detector_model.pkl"
os.makedirs("checkpoints", exist_ok=True)

# Reuse a checkpoint only if it was trained on exactly these feature columns.
# The models consume columns positionally, so a stale checkpoint would either
# crash or silently mis-map features.
# NOTE: joblib.load unpickles arbitrary objects; only load checkpoints you created.
checkpoint_loaded = False
if os.path.exists(model_checkpoint_path):
    cached_detector = joblib.load(model_checkpoint_path)
    if list(getattr(cached_detector, 'feature_names', None) or []) == list(feature_names):
        detector = cached_detector
        checkpoint_loaded = True
        print("Loaded model checkpoint. Skipping training.")
    else:
        print("Model checkpoint was trained on a different feature set. Retraining.")

if not checkpoint_loaded:
    # Train the model using the training set only
    detector.train(X_train, y_train, feature_names)

    # Save the model checkpoint
    joblib.dump(detector, model_checkpoint_path)
    print(f"Model checkpoint saved to {model_checkpoint_path}")

# Evaluate on the test set
y_pred_proba = detector.model.predict_proba(X_test)[:, 1]
y_pred = (y_pred_proba >= 0.5).astype(int)

import matplotlib.pyplot as plt

# SHAP summary of the XGBoost base model on X_test
xgb_explainer = detector.shap_explainers.get('xgb')
if xgb_explainer is None:
    print("No SHAP explainer available for xgb; skipping SHAP summary plot.")
else:
    shap_values_test = xgb_explainer.shap_values(X_test)
    # Normalise to (n_samples, n_features) for the positive class: shap may
    # return a per-class list or a 3-D (samples, features, classes) array.
    if isinstance(shap_values_test, list):
        shap_values_test = shap_values_test[1] if len(shap_values_test) > 1 else shap_values_test[0]
    shap_values_test = np.asarray(shap_values_test)
    if shap_values_test.ndim == 3:
        shap_values_test = shap_values_test[..., -1]
    assert shap_values_test.shape[0] == X_test.shape[0], "Mismatch in rows between shap_values and X_test!"
    shap.summary_plot(shap_values_test, X_test, feature_names=feature_names)
    plt.show()

# Compute evaluation metrics
test_roc_auc = roc_auc_score(y_test, y_pred_proba)
test_f1 = f1_score(y_test, y_pred)
test_precision = precision_score(y_test, y_pred)
test_recall = recall_score(y_test, y_pred)

print("\nTest ROC AUC:", test_roc_auc)
print("Test F1 Score:", test_f1)
print("Test Precision:", test_precision)
print("Test Recall:", test_recall)

# SHAP explanations for a specific instance
instance_index = 0
model_name = 'xgb'
explanations = detector.get_feature_explanations(model_name, X_test, instance_index)
print(f"\nSHAP explanations for instance {instance_index} from {model_name}:")
print(explanations)

from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Authentic', 'Bot'], yticklabels=['Authentic', 'Bot'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

from sklearn.metrics import matthews_corrcoef, cohen_kappa_score

mcc = matthews_corrcoef(y_test, y_pred)
kappa = cohen_kappa_score(y_test, y_pred)

print("Matthews Correlation Coefficient:", mcc)
print("Cohen's Kappa:", kappa)


In [None]:
import random

# Every FID in the full feature matrix, and every FID that carries a label.
all_fids, labeled_fids = matrix['fid'].unique(), data['fid'].unique()
all_fids_np, labeled_fids_np = all_fids.to_numpy(), labeled_fids.to_numpy()

# FIDs present in the matrix but absent from the labeled set, and their rows.
unlabeled_fids = np.setdiff1d(all_fids_np, labeled_fids_np)
unlabeled_data = matrix.filter(pl.col('fid').is_in(unlabeled_fids))

# Profiles that have an fname, reduced to one row per fid.
ds = feature_eng.loader.get_dataset('profile_with_addresses')
valid_profiles = (
    ds
    .filter(pl.col('fname').is_not_null())
    .unique(subset=['fid'])
)
valid_fids = valid_profiles['fid'].unique()

# Keep only unlabeled rows whose account has an fname.
unlabeled_data_filtered = unlabeled_data.filter(pl.col('fid').is_in(valid_fids))

In [None]:
# Describe profile_update_consistency over the unlabeled pool, when present.
if 'profile_update_consistency' not in unlabeled_data_filtered.columns:
    print("Column 'profile_update_consistency' not found in unlabeled_data_filtered.")
else:
    print(unlabeled_data_filtered.select(['profile_update_consistency']).describe())

# Bot-vs-human mean/std for a few features of the labeled data.
for feature in ['profile_update_consistency', 'influence_score', 'follower_count']:
    if feature not in data.columns:
        print(f"\nFeature '{feature}' does not exist in the data DataFrame.")
        continue

    bot_values = data.filter(pl.col('bot') == 1)[feature]
    human_values = data.filter(pl.col('bot') == 0)[feature]

    print(f"\n{feature} statistics:")
    for label, value in (
        ("Bot mean:", bot_values.mean()),
        ("Human mean:", human_values.mean()),
        ("Bot std:", bot_values.std()),
        ("Human std:", human_values.std()),
    ):
        print(label, value)

# 7. Build the model's feature matrix; fid is an identifier, not a feature.
X_unlabeled_filtered, valid_features = detector.prepare_features(unlabeled_data_filtered.drop("fid"))

# Sanity checks: feature rows must map one-to-one onto unlabeled accounts.
print(f"unlabeled_data_filtered shape: {unlabeled_data_filtered.shape}")
print(f"X_unlabeled_filtered shape: {X_unlabeled_filtered.shape}")
print(f"model feature length: {len(detector.feature_names)}")
rows_match = X_unlabeled_filtered.shape[0] == len(unlabeled_data_filtered)
if not rows_match:
    print("Warning: Length mismatch between unlabeled_data_filtered and X_unlabeled_filtered!")

In [None]:
# Score every unlabeled, fname-bearing account; column 1 is the bot probability.
y_pred_proba_filtered = detector.model.predict_proba(X_unlabeled_filtered)[:, 1]
y_pred_filtered = (y_pred_proba_filtered >= 0.5).astype(int)

# Compute SHAP values for unlabeled data
# Extract the underlying xgb model from the ensemble if needed
# NOTE(review): shap_values_unlabeled is not read in this cell; the per-account
# explanations below come from detector.get_feature_explanations instead.
xgb_model = detector.base_models['xgb'].base_estimator_ if hasattr(detector.base_models['xgb'], 'base_estimator_') else detector.base_models['xgb']

explainer_unlabeled = shap.TreeExplainer(xgb_model.estimator)
shap_values_unlabeled = explainer_unlabeled.shap_values(X_unlabeled_filtered)


# 8. Get SHAP explanations for a random sample of predicted-human accounts.
# fids_array must stay row-aligned with X_unlabeled_filtered and the prediction
# arrays, so index i in each refers to the same account. Sampling fids
# independently of the predictions would attach every prediction and
# explanation to an unrelated account. Later cells also index this array with
# full-length masks.
fids_array = unlabeled_data_filtered['fid'].to_numpy()
if len(fids_array) != X_unlabeled_filtered.shape[0]:
    raise ValueError(
        f"Row mismatch: {len(fids_array)} fids vs {X_unlabeled_filtered.shape[0]} "
        "feature rows; predictions cannot be mapped back to accounts."
    )

# Sample up to 10 random accounts among those predicted human.
human_indices = np.flatnonzero(y_pred_filtered == 0)
sample_size = min(10, len(human_indices))
sampled_indices = random.sample(human_indices.tolist(), sample_size)

has_authenticity = 'authenticity_score' in unlabeled_data_filtered.columns
human_explanations = []
for i in sampled_indices:
    fid = fids_array[i]

    # SHAP contributions for row i via the detector's 'xgb' explainer.
    shap_values = detector.get_feature_explanations('xgb', X_unlabeled_filtered, i)

    # Row i of unlabeled_data_filtered is the same account as row i of X.
    authenticity_score = unlabeled_data_filtered['authenticity_score'][i] if has_authenticity else None
    if authenticity_score is None:
        authenticity_score = 0.0  # default when the column or value is missing

    # valid_profiles was de-duplicated on fid, so at most one match is expected.
    fname_series = valid_profiles.filter(pl.col('fid') == fid)['fname']
    fname = fname_series[0] if len(fname_series) > 0 else "Unknown"

    # Only negative impact features (presumably pushing away from 'bot').
    negative_reasons = {k: v for k, v in shap_values.items() if v < 0}

    human_explanations.append({
        'fid': fid,
        'fname': fname,
        'authenticity_score': authenticity_score,
        'human_probability': 1 - y_pred_proba_filtered[i],
        'reasons': negative_reasons
    })

# Print results (human_explanations is already a random sample).
random_accounts = human_explanations

print("\nHuman accounts with explanations (random sample):")
for i, exp in enumerate(random_accounts, start=1):
    print(f"\n{i}. {exp['fname']} (FID: {exp['fid']})")
    print(f"Human Probability: {exp['human_probability']:.2%}")
    print(f"Authenticity Score: {exp['authenticity_score']:.2f}")
    print("Top reasons for human classification:")
    sorted_reasons = sorted(exp['reasons'].items(), key=lambda x: abs(x[1]), reverse=True)
    for feature, impact in sorted_reasons:
        print(f"  - {feature}: {abs(impact):.3f}")

In [None]:
# Class split over the whole unlabeled pool (0 = human, 1 = bot at the 0.5 cut).
total_profiles = len(y_pred_filtered)
humans_count = (y_pred_filtered == 0).sum()
bots_count = (y_pred_filtered == 1).sum()

print("\nOverall Distribution Analysis:")
print(f"Total profiles analyzed: {total_profiles:,}")
for label, class_count in (("Predicted humans", humans_count), ("Predicted bots", bots_count)):
    print(f"{label}: {class_count:,} ({(class_count/total_profiles)*100:.1f}%)")

# Summary statistics of the raw bot probabilities.
print("\nPrediction Probability Analysis:")
print(f"Mean bot probability: {y_pred_proba_filtered.mean():.3f}")
print(f"Median bot probability: {np.median(y_pred_proba_filtered):.3f}")
print(f"Std dev of bot probability: {y_pred_proba_filtered.std():.3f}")

# Five equal-width probability buckets; np.histogram's last bin includes 1.0.
buckets = np.histogram(y_pred_proba_filtered, bins=[0, 0.2, 0.4, 0.6, 0.8, 1.0])
bucket_counts = buckets[0]
bucket_ranges = ['0-20%', '20-40%', '40-60%', '60-80%', '80-100%']

print("\nProbability Distribution Breakdown:")
for range_name, count in zip(bucket_ranges, bucket_counts):
    percentage = (count/total_profiles)*100
    print(f"{range_name}: {count:,} accounts ({percentage:.1f}%)")

# Accounts outside the 0.2-0.8 "uncertain" band.
high_conf_human = (y_pred_proba_filtered < 0.2).sum()
high_conf_bot = (y_pred_proba_filtered > 0.8).sum()

print("\nHigh Confidence Predictions:")
for label, conf_count in (
    ("High confidence humans (p < 0.2)", high_conf_human),
    ("High confidence bots (p > 0.8)", high_conf_bot),
):
    print(f"{label}: {conf_count:,} ({(conf_count/total_profiles)*100:.1f}%)")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Histogram of the bot probabilities predicted for the unlabeled pool
# (y_pred_proba_filtered, from an earlier cell). Bars are colour-coded with the
# same 0.2 / 0.8 cut-offs used in the text summary.

# Create the figure and axis
fig, ax = plt.subplots(figsize=(12, 6))

# Create histogram
# 21 edges -> 20 equal-width bins over [0, 1]. `bins` is rebound below to the
# edges ax.hist returns.
bins = np.linspace(0, 1, 21)  # 20 bins for smooth distribution
n, bins, patches = ax.hist(y_pred_proba_filtered, bins=bins, edgecolor='black', alpha=0.7)

# Customize colors based on probability ranges
# Each bar is coloured by its centre: its left edge plus half the (constant)
# bin width bins[1] - bins[0].
for i, patch in enumerate(patches):
    bin_center = bins[i] + (bins[1] - bins[0])/2
    if bin_center < 0.2:
        patch.set_facecolor('#2ecc71')  # Green for human predictions
    elif bin_center > 0.8:
        patch.set_facecolor('#e74c3c')  # Red for bot predictions
    else:
        patch.set_facecolor('#3498db')  # Blue for uncertain predictions

# Add vertical lines for key thresholds
ax.axvline(x=0.2, color='#27ae60', linestyle='--', alpha=0.5, label='Human threshold (0.2)')
ax.axvline(x=0.8, color='#c0392b', linestyle='--', alpha=0.5, label='Bot threshold (0.8)')

# Customize the plot
ax.set_title('Distribution of Bot Probability Scores', pad=20, fontsize=14)
ax.set_xlabel('Bot Probability Score', fontsize=12)
ax.set_ylabel('Number of Accounts', fontsize=12)

# Add grid for better readability
ax.grid(True, alpha=0.3)

# Add legend
ax.legend()

# Add text annotations for key statistics
# total_profiles comes from the earlier distribution-summary cell.
stats_text = (
    f'Total Profiles: {total_profiles:,}\n'
    f'Mean Probability: {y_pred_proba_filtered.mean():.3f}\n'
    f'Median Probability: {np.median(y_pred_proba_filtered):.3f}\n'
    f'Std Dev: {y_pred_proba_filtered.std():.3f}'
)
# transform=ax.transAxes makes (0.95, 0.95) axes-relative, i.e. the top-right corner.
plt.text(0.95, 0.95, stats_text,
         transform=ax.transAxes,
         verticalalignment='top',
         horizontalalignment='right',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Adjust layout to prevent text cutoff
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
def get_segment_features(segment: str) -> List[str]:
    """Return the feature names for an activity segment.

    Every segment shares four profile/verification columns. 'active' and
    'low_activity' add their own columns; any other name is treated as
    'dormant'. A fresh list is returned on every call.
    """
    common = ['has_ens', 'has_bio', 'has_avatar', 'verification_count']
    extras_by_segment = {
        'active': [
            'cast_timing_entropy',
            'reply_ratio',
            'mention_patterns',
            'engagement_rate',
        ],
        'low_activity': [
            'network_growth_rate',
            'initial_behavior_pattern',
            'verification_sequence',
        ],
    }
    dormant_extras = [
        'follower_growth_velocity',
        'network_structure',
        'profile_completion_sequence',
    ]
    return common + extras_by_segment.get(segment, dormant_extras)

In [None]:
import polars as pl
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any
from sklearn.ensemble import RandomForestClassifier
from lightgbm import LGBMClassifier
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, precision_recall_fscore_support
from sklearn.model_selection import cross_validate

def segment_users_by_behavior(matrix: pl.DataFrame) -> Dict[str, pl.DataFrame]:
    """Split users into behavioural segments by posting activity and print a summary.

    Segments:
      - power_users:    cast_count >= 20 and reply_count >= 5
      - casual_users:   5 <= cast_count < 20
      - one_time_users: 0 < cast_count < 5
      - lurkers:        cast_count == 0

    NOTE(review): the segments are not exhaustive. Users with cast_count >= 20
    but reply_count < 5, and rows with a null cast_count, fall into no segment.

    For each non-empty segment that has a ``bot`` column, the bot share and the
    mean cast/follower/following counts are printed too.

    Returns:
        Mapping of segment name -> filtered DataFrame.
    """
    segments = {
        'power_users': matrix.filter(
            (pl.col('cast_count') >= 20) & 
            (pl.col('reply_count') >= 5)
        ),
        'casual_users': matrix.filter(
            (pl.col('cast_count') >= 5) & 
            (pl.col('cast_count') < 20)
        ),
        'one_time_users': matrix.filter(
            (pl.col('cast_count') > 0) & 
            (pl.col('cast_count') < 5)
        ),
        'lurkers': matrix.filter(pl.col('cast_count') == 0)
    }
    
    total = len(matrix)
    print("\nUser Segment Distribution:")
    for name, segment in segments.items():
        size = len(segment)
        # Guard: an empty input would otherwise raise ZeroDivisionError.
        share = size / total * 100 if total else 0.0
        if size > 0 and 'bot' in segment.columns:
            bot_pct = (segment.filter(pl.col('bot') == 1).shape[0] / size) * 100
            print(f"{name}: {size:,} users ({share:.1f}%) - {bot_pct:.1f}% bots")
            
            metrics = segment.select([
                pl.col('cast_count').mean(),
                pl.col('follower_count').mean(),
                pl.col('following_count').mean()
            ]).to_numpy()[0]
            
            print(f"  Avg casts: {metrics[0]:.1f}")
            print(f"  Avg followers: {metrics[1]:.1f}")
            print(f"  Avg following: {metrics[2]:.1f}")
        else:
            print(f"{name}: {size:,} users ({share:.1f}%)")
            
    return segments

def get_segment_specific_features(segment_name: str) -> List[str]:
    """Return the feature columns to use for one behaviour segment.

    All segments share the profile/network base columns. The four known
    segments ('power_users', 'casual_users', 'one_time_users', 'lurkers') add
    their own activity columns; any other name gets the base columns only.
    """
    shared = [
        'has_ens', 'has_bio', 'has_avatar', 'verification_count',
        'following_count', 'follower_count', 'follower_ratio',
        'unique_follower_ratio', 'authenticity_score',
    ]

    per_segment = {
        'power_users': [
            'cast_count', 'total_reactions', 'avg_cast_length',
            'reply_count', 'mentions_count', 'engagement_score',
            'weekday_diversity', 'hour_diversity', 'rapid_actions',
            'avg_hours_between_actions', 'std_hours_between_actions',
            'power_user_interaction_ratio', 'influence_score',
        ],
        'casual_users': [
            'cast_count', 'total_reactions', 'engagement_score',
            'reply_count', 'rapid_actions', 'avg_hours_between_actions',
            'avg_cast_length', 'mentions_count',
        ],
        'one_time_users': [
            'cast_count', 'total_reactions', 'profile_update_consistency',
            'follower_growth_rate',
        ],
        'lurkers': [
            'profile_update_consistency', 'network_balance',
            'follower_growth_rate', 'profile_completeness',
        ],
    }

    extra = per_segment[segment_name] if segment_name in per_segment else []
    return [*shared, *extra]

class RapidModelEvaluator:
    """Cross-validate candidate classifiers and tabulate their mean test scores.

    Precision, recall and F1 are all reported for the positive class
    (label 1 = bot). They are therefore mutually consistent and comparable with
    the binary test-set metrics computed elsewhere in the notebook.
    """

    def __init__(self, n_cv_splits: int = 5):
        # Fold count passed straight to sklearn's cross_validate.
        self.n_cv_splits = n_cv_splits
        # name -> {'test_scores': {...}, 'fit_time': float}; reset by compare_models.
        self.results = {}
        
    def evaluate_model(self, name: str, model: Any, X: np.ndarray, y: np.ndarray) -> Dict:
        """Cross-validate *model* on (X, y) and store its mean scores under *name*.

        Returns:
            {'test_scores': {'precision', 'recall', 'f1'}, 'fit_time'}, each a
            mean over the CV folds.
        """
        # Built-in scorer names all score the positive class (binary average).
        # Custom scorers that macro-average precision/recall would be
        # inconsistent with the binary 'f1'. Plain strings also pickle cleanly
        # for n_jobs=-1.
        scoring = {
            'precision': 'precision',
            'recall': 'recall',
            'f1': 'f1'
        }
        
        cv_results = cross_validate(
            model, X, y,
            cv=self.n_cv_splits,
            scoring=scoring,
            return_train_score=False,  # train scores were never used
            n_jobs=-1
        )
        
        self.results[name] = {
            'test_scores': {
                'precision': cv_results['test_precision'].mean(),
                'recall': cv_results['test_recall'].mean(),
                'f1': cv_results['test_f1'].mean()
            },
            'fit_time': cv_results['fit_time'].mean()
        }
        
        return self.results[name]
    
    def compare_models(self, segment_name: str, models: Dict[str, Any], 
                      X: np.ndarray, y: np.ndarray) -> pd.DataFrame:
        """Evaluate every model in *models* and return their scores, best F1 first.

        The returned DataFrame is indexed by model name, with columns
        precision, recall, f1 and fit_time.
        """
        self.results = {}  # Reset results for each segment
        for name, model in models.items():
            print(f"Evaluating {name} on {segment_name} segment...")
            self.evaluate_model(name, model, X, y)
        
        results_df = pd.DataFrame.from_dict(
            {k: v['test_scores'] for k, v in self.results.items()}, 
            orient='index'
        )
        results_df['fit_time'] = [v['fit_time'] for v in self.results.values()]
        
        return results_df.sort_values('f1', ascending=False)

def prepare_segment_features(segment: pl.DataFrame, segment_name: str) -> Tuple[np.ndarray, Any, List[str]]:
    """Build a standardized feature matrix for one behaviour segment.

    Args:
        segment: rows belonging to the segment.
        segment_name: key understood by get_segment_specific_features.

    Returns:
        (X, y, valid_features):
          - X is the feature matrix over the segment's available feature
            columns, with nulls filled with 0 and standardized. It is left
            unscaled when it has zero rows or zero columns.
          - y is the ``bot`` column as a numpy array, or None when the segment
            has no ``bot`` column.
          - valid_features lists the columns used, in X's column order.

    NOTE(review): the scaler is fit on the whole segment before any
    cross-validation, so fold statistics leak slightly into validation folds.
    """
    feature_cols = get_segment_specific_features(segment_name)
    valid_features = [col for col in feature_cols if col in segment.columns]
    print(f"\nUsing {len(valid_features)} features for {segment_name}:", valid_features)
    
    X = segment.select(valid_features).fill_null(0).to_numpy()
    y = segment['bot'].to_numpy() if 'bot' in segment.columns else None
    
    # StandardScaler rejects arrays with zero samples or zero features; return
    # such a matrix unscaled so callers can skip the segment instead of crashing.
    if X.shape[0] > 0 and X.shape[1] > 0:
        X = StandardScaler().fit_transform(X)
    
    return X, y, valid_features

def evaluate_segments(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Dict:
    """Cross-validate several classifiers on each behavioural segment of the labeled data.

    Args:
        matrix: per-fid feature matrix.
        labels_df: per-fid labels (must contain ``fid`` and ``bot``).

    Returns:
        segment name -> DataFrame from RapidModelEvaluator.compare_models
        (sorted by F1). Segments that are empty, have a single class, have no
        usable features, or have too few samples of a class are skipped.
    """
    # Ensure consistent FID type
    matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))
    labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))
    
    # Join with labels
    data = matrix.join(labels_df, on='fid', how='inner')
    
    # Create segments
    segmented_users = segment_users_by_behavior(data)
    
    evaluator = RapidModelEvaluator()
    segment_results = {}
    
    for segment_name, segment_data in segmented_users.items():
        print(f"\nProcessing {segment_name} segment...")

        if len(segment_data) == 0:
            print(f"Skipping {segment_name} - empty segment")
            continue
        
        X, y, valid_features = prepare_segment_features(segment_data, segment_name)
        
        if y is None or len(np.unique(y)) < 2:
            print(f"Skipping {segment_name} - insufficient labels")
            continue

        if X.shape[1] == 0:
            print(f"Skipping {segment_name} - no usable features")
            continue

        # Stratified k-fold needs at least n_cv_splits members in every class.
        _, class_counts = np.unique(y, return_counts=True)
        if class_counts.min() < evaluator.n_cv_splits:
            print(f"Skipping {segment_name} - fewer than {evaluator.n_cv_splits} samples in a class")
            continue
        
        # Define models with balanced class weights
        models = {
            'xgb': xgb.XGBClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                scale_pos_weight=np.sum(y == 0) / np.sum(y == 1),
                random_state=42
            ),
            'lgbm': LGBMClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'rf': RandomForestClassifier(
                n_estimators=100,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'logistic': LogisticRegression(
                max_iter=1000,
                class_weight='balanced',
                random_state=42
            )
        }
        
        results = evaluator.compare_models(segment_name, models, X, y)
        segment_results[segment_name] = results
        
        print(f"\nResults for {segment_name}:")
        print(results)
        
        # Print feature importance for best model (best-effort: failures are reported, not raised)
        try:
            best_model_name = results.index[0]
            best_model = models[best_model_name]

            # cross_validate fits clones, so the estimators in `models` are still
            # unfitted. Fit the winner on the whole segment before reading importances.
            best_model.fit(X, y)
            
            if hasattr(best_model, 'feature_importances_'):
                importances = pd.DataFrame({
                    'feature': valid_features,
                    'importance': best_model.feature_importances_
                }).sort_values('importance', ascending=False)
                
                print(f"\nTop 10 features for {segment_name}:")
                print(importances.head(10))
            else:
                print(f"\nNo feature importances available for {best_model_name}")
                
        except Exception as e:
            print(f"Error calculating feature importance: {str(e)}")
    
    return segment_results


results = evaluate_segments(matrix, labels_df)

# Each result frame is sorted by F1, so its first row is the winning model.
print("\nBest models per segment:")
for segment_name, result_df in results.items():
    top_row = result_df.iloc[0]
    best_model = result_df.index[0]
    best_f1, best_precision, best_recall = top_row['f1'], top_row['precision'], top_row['recall']

    print(f"{segment_name}:")
    print(f"  Best model: {best_model}")
    print(f"  F1: {best_f1:.3f}")
    print(f"  Precision: {best_precision:.3f}")
    print(f"  Recall: {best_recall:.3f}")

In [None]:
import polars as pl
import numpy as np
from typing import Dict, List, Tuple
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

def analyze_full_distribution(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Tuple[Dict, Dict]:
    """Print segment distributions for the full and the labeled data, plus label coverage.

    Returns:
        (full_segments, labeled_segments), each as returned by
        segment_users_by_behavior.
    """
    print("=== Full Dataset Distribution ===")
    full_segments = segment_users_by_behavior(matrix)
    
    print("\n=== Labeled Dataset Distribution ===")
    # Cast fid on both sides, as evaluate_segments does, so the join does not
    # fail on mismatched integer key dtypes.
    labeled_data = matrix.with_columns(pl.col('fid').cast(pl.Int64)).join(
        labels_df.with_columns(pl.col('fid').cast(pl.Int64)), on='fid', how='inner'
    )
    labeled_segments = segment_users_by_behavior(labeled_data)
    
    # Calculate coverage: share of each segment's accounts that carry a label
    print("\n=== Label Coverage by Segment ===")
    for segment_name, full_segment in full_segments.items():
        full_count = len(full_segment)
        labeled_segment = labeled_segments.get(segment_name)
        labeled_count = len(labeled_segment) if labeled_segment is not None else 0
        coverage = (labeled_count / full_count * 100) if full_count > 0 else 0
        print(f"{segment_name}: {coverage:.1f}% labeled ({labeled_count}/{full_count})")
        
    return full_segments, labeled_segments
def get_unlabeled_samples(matrix: pl.DataFrame, 
                         labels_df: pl.DataFrame, 
                         samples_per_segment: int = 50,
                         use_isolation_forest: bool = True,
                         seed: int = 42) -> Dict[str, pl.DataFrame]:
    """Draw a per-segment sample of unlabeled accounts for manual labeling.

    Large segments (more than ``samples_per_segment`` rows) are sampled with an
    IsolationForest when ``use_isolation_forest`` is set and anomaly features
    are available. In that case up to a quarter of the sample comes from
    flagged anomalies and the rest from inliers. Otherwise the segment is
    sampled at random.

    Args:
        matrix: per-fid feature matrix.
        labels_df: already-labeled fids; these are excluded.
        samples_per_segment: target sample size per segment.
        use_isolation_forest: enable anomaly-stratified sampling.
        seed: seeds the IsolationForest, the anomaly/inlier draw and the random
            fallbacks, so the labeling sample is reproducible.

    Returns:
        segment name -> sampled rows (empty segments are omitted).
    """
    # Get unlabeled FIDs (cast both sides so is_in compares like dtypes)
    labeled_fids = labels_df['fid'].cast(pl.Int64).unique()
    unlabeled = matrix.filter(~pl.col('fid').cast(pl.Int64).is_in(labeled_fids))
    
    # Segment unlabeled data
    segments = segment_users_by_behavior(unlabeled)
    
    # Features for anomaly detection
    anomaly_features = [
        'cast_count', 'follower_count', 'following_count',
        'authenticity_score', 'total_reactions', 'rapid_actions',
        'avg_hours_between_actions', 'std_hours_between_actions'
    ]

    # One seeded generator for all index draws (was the unseeded global RNG).
    rng = np.random.default_rng(seed)
    
    samples = {}
    for name, segment in segments.items():
        print(f"\nProcessing {name} segment ({len(segment)} users)")
        
        if len(segment) == 0:
            continue
            
        if use_isolation_forest and len(segment) > samples_per_segment:
            valid_features = [f for f in anomaly_features if f in segment.columns]
            if valid_features:
                X = segment.select(valid_features).fill_null(0).to_numpy()
                X = StandardScaler().fit_transform(X)
                
                iso_forest = IsolationForest(
                    n_estimators=100,
                    contamination=0.1,  # Assume 10% anomalies
                    random_state=seed
                )
                
                # fit_predict labels anomalies -1 and inliers 1
                scores = iso_forest.fit_predict(X)
                anomaly_indices = np.where(scores == -1)[0]
                normal_indices = np.where(scores == 1)[0]
                
                # Up to a quarter of the sample from anomalies, rest from inliers
                n_anomalies = min(samples_per_segment // 4, len(anomaly_indices))
                n_normal = samples_per_segment - n_anomalies
                
                selected_indices = np.concatenate([
                    rng.choice(anomaly_indices, n_anomalies, replace=False),
                    rng.choice(normal_indices, n_normal, replace=False)
                ])
                
                # Create a row number column and filter by selected indices
                samples[name] = (segment
                    .with_row_count("row_nr")
                    .filter(pl.col("row_nr").is_in(selected_indices))
                    .drop("row_nr"))
                
                print(f"Selected {n_anomalies} potential anomalies and {n_normal} normal cases")
            else:
                # Fallback to random sampling if features not available
                samples[name] = segment.sample(n=samples_per_segment, seed=seed)
        else:
            # Small segment (or anomaly sampling disabled): random sample, capped at its size
            n_samples = min(samples_per_segment, len(segment))
            samples[name] = segment.sample(n=n_samples, seed=seed)
        
        print(f"Final sample size: {len(samples[name])}")
        
        # Print some statistics about the sample (only when all columns exist)
        stat_columns = ('cast_count', 'follower_count', 'following_count')
        if all(col in segment.columns for col in stat_columns):
            stats = samples[name].select([
                pl.col('cast_count').mean().alias('avg_casts'),
                pl.col('follower_count').mean().alias('avg_followers'),
                pl.col('following_count').mean().alias('avg_following')
            ])
            print("Sample statistics:")
            print(f"  Avg casts: {stats['avg_casts'][0]:.1f}")
            print(f"  Avg followers: {stats['avg_followers'][0]:.1f}")
            print(f"  Avg following: {stats['avg_following'][0]:.1f}")
    
    return samples

def export_samples_for_labeling(samples: Dict[str, pl.DataFrame], 
                              output_path: str = "samples_for_labeling.csv"):
    """Write all per-segment samples to one CSV for manual labeling.

    Each sample frame is tagged with a ``segment`` column holding its segment
    name. Only those identifying/summary columns listed below that exist are
    written.

    Args:
        samples: segment name -> sampled rows (e.g. from get_unlabeled_samples).
        output_path: destination CSV path.
    """
    # pl.concat raises on an empty list (e.g. every segment was empty).
    if not samples:
        print(f"\nNo samples to export; {output_path} was not written")
        return

    all_samples = pl.concat([
        segment.with_columns(pl.lit(name).alias('segment'))
        for name, segment in samples.items()
    ])
    
    export_columns = [
        'fid', 'segment', 'fname',
        'cast_count', 'follower_count', 'following_count',
        'total_reactions', 'authenticity_score'
    ]
    valid_columns = [col for col in export_columns if col in all_samples.columns]
    
    all_samples.select(valid_columns).write_csv(output_path)
    print(f"\nExported {len(all_samples)} samples to {output_path}")
    
def suggest_priority_accounts(matrix: pl.DataFrame, 
                              labels_df: pl.DataFrame,
                              n_suggestions: int = 50) -> pl.DataFrame:
    """Rank unlabeled accounts by a simple influence score and return the top ones.

    Each available influence feature is null-filled with 0 and divided by its
    column maximum. A column whose maximum is <= 0 contributes 0 rather than
    NaN/inf. The per-account sum of these scaled values is the influence_score.

    Returns:
        Up to ``n_suggestions`` rows, highest influence first, with whichever of
        fid, fname, influence_score, follower_count, cast_count and
        total_reactions exist. If no influence feature exists, the first
        ``n_suggestions`` unlabeled rows are returned unchanged.
    """
    # Get unlabeled accounts (cast both sides so is_in compares like dtypes)
    labeled_fids = labels_df['fid'].cast(pl.Int64).unique()
    unlabeled = matrix.filter(~pl.col('fid').cast(pl.Int64).is_in(labeled_fids))
    
    influence_features = [
        'follower_count', 'following_count', 'cast_count',
        'total_reactions', 'authenticity_score'
    ]
    valid_features = [f for f in influence_features if f in unlabeled.columns]
    
    if not valid_features:
        print("No valid features found for influence calculation")
        return unlabeled.head(n_suggestions)

    def _max_scaled(name: str) -> pl.Expr:
        """Column divided by its max; 0 when the max is not positive (avoids NaN/inf)."""
        col = pl.col(name).fill_null(0)
        col_max = col.max()
        return pl.when(col_max > 0).then(col / col_max).otherwise(0.0).alias(name)

    normalized = unlabeled.select([
        'fid',
        *(_max_scaled(f) for f in valid_features)
    ])
    
    # Simple influence score - sum of the max-scaled features
    influence_df = normalized.with_columns([
        pl.fold(acc=pl.lit(0.0), function=lambda acc, x: acc + x, exprs=[pl.col(f) for f in valid_features])
        .alias('influence_score')
    ])
    
    suggestions = (influence_df
        .sort('influence_score', descending=True)
        .head(n_suggestions)
        .select(['fid', 'influence_score']))
    
    # Join back with original features (incl. 'fname' if present) and re-sort,
    # since the join is not relied upon to preserve ranking order.
    result = (suggestions
        .join(unlabeled, on='fid', how='left', suffix="_unlabeled")
        .sort('influence_score', descending=True))
    
    if 'fname' not in result.columns:
        print("Warning: 'fname' column not found in the dataset.")
    wanted = ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    return result.select([col for col in wanted if col in result.columns])


# Analyze current distribution
full_segments, labeled_segments = analyze_full_distribution(matrix, labels_df)

# Get samples for each segment
samples = get_unlabeled_samples(matrix, labels_df, samples_per_segment=50)

# Export samples for labeling
export_samples_for_labeling(samples)

# Get priority suggestions
priority_accounts = suggest_priority_accounts(matrix, labels_df, n_suggestions=50)

print("\nTop accounts to consider for labeling:")
# suggest_priority_accounts drops 'fname' when the data lacks it, so select
# only the preview columns that are actually present.
preview_columns = [
    col for col in ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    if col in priority_accounts.columns
]
print(priority_accounts.select(preview_columns).head(10))

In [None]:
# Row-aligned fids for the full unlabeled prediction set: position i matches
# y_pred_proba_filtered[i] and row i of X_unlabeled_filtered. The boolean mask
# below has one entry per row, so it must index an array of the same length.
aligned_fids = unlabeled_data_filtered['fid'].cast(pl.Int64).to_numpy()

# First get indices where bot probability is high
bot_mask = y_pred_proba_filtered > 0.9
bot_fids = aligned_fids[bot_mask]

# Ensure FIDs are the same type (Int64)
high_conf_bots = unlabeled_data_filtered.filter(pl.col('fid').cast(pl.Int64).is_in(bot_fids))
bot_profiles = valid_profiles.with_columns(pl.col('fid').cast(pl.Int64)).filter(pl.col('fid').is_in(bot_fids))

# Sample up to 10 random high-confidence bots (a true random sample rather
# than a contiguous slice starting at a random offset).
sample_size = min(10, len(bot_profiles))
random_bot_sample = bot_profiles.sample(n=sample_size)

print(f"\nSample of High Confidence Bot Predictions (from {len(bot_fids)} total high-confidence bots):")
print("-" * 50)

# Columns are read by name; positional indexes break if the profile schema changes.
for row in random_bot_sample.iter_rows(named=True):
    fname = row['fname']
    fid = row['fid']
    display_name = row.get('display_name')
    
    # Get features for this bot (first row if the fid is duplicated)
    bot_features = high_conf_bots.filter(pl.col('fid').cast(pl.Int64) == fid).head(1)
    
    # Row position in the aligned arrays, for probability and SHAP values
    fid_idx = int(np.flatnonzero(aligned_fids == fid)[0])
    bot_prob = y_pred_proba_filtered[fid_idx]
    
    # Get SHAP values explaining why it's classified as a bot
    shap_values = detector.get_feature_explanations('xgb', X_unlabeled_filtered, fid_idx)
    
    print(f"\nUsername: {fname}")
    print(f"Display Name: {display_name}")
    print(f"FID: {fid}")
    print(f"Bot Probability: {bot_prob:.2%}")
    
    print("\nTop reasons for bot classification:")
    # Sort by absolute value but only show positive values (contributing to bot classification)
    sorted_reasons = sorted(shap_values.items(), key=lambda x: abs(x[1]), reverse=True)
    for feature, impact in sorted_reasons:
        if impact > 0:  # Only show features pushing towards bot classification
            print(f"  - {feature}: {impact:.3f}")
    
    print("\nKey Behavioral Metrics:")
    metrics_to_check = [
        'rapid_actions', 
        'std_hours_between_actions',
        'avg_hours_between_actions',
        'following_count',
        'follower_count',
        'total_activity',
        'hour_diversity',
        'weekday_diversity'
    ]
    
    for metric in metrics_to_check:
        if metric in bot_features.columns:
            val = bot_features.select(metric).item()
            if isinstance(val, (int, float)):
                print(f"{metric}: {val:.2f}")
            else:
                print(f"{metric}: {val}")
            
    print(f"\nProfile link: https://warpcast.com/{fname}")
    print("-" * 50)

print("\nTo verify these accounts, check for:")
print("1. Highly regular posting patterns (low std_hours_between_actions)")
print("2. Unusually high activity rates (high rapid_actions)")
print("3. Unnatural timing patterns (low hour_diversity and weekday_diversity)")
print("4. Suspicious follower/following ratios")

In [None]:
def get_segment_features(segment: str) -> List[str]:
    """Return the feature names for an activity segment.

    Every segment shares four profile/verification columns. 'active' and
    'low_activity' add their own columns; any other name is treated as
    'dormant'. A fresh list is returned on every call.
    """
    common = ['has_ens', 'has_bio', 'has_avatar', 'verification_count']
    extras_by_segment = {
        'active': [
            'cast_timing_entropy',
            'reply_ratio',
            'mention_patterns',
            'engagement_rate',
        ],
        'low_activity': [
            'network_growth_rate',
            'initial_behavior_pattern',
            'verification_sequence',
        ],
    }
    dormant_extras = [
        'follower_growth_velocity',
        'network_structure',
        'profile_completion_sequence',
    ]
    return common + extras_by_segment.get(segment, dormant_extras)

In [None]:
import polars as pl
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any
from sklearn.ensemble import RandomForestClassifier
from lightgbm import LGBMClassifier
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, precision_recall_fscore_support
from sklearn.model_selection import cross_validate

def segment_users_by_behavior(matrix: pl.DataFrame) -> Dict[str, pl.DataFrame]:
    """Segment users based on behavioral patterns (casting activity).

    Each segment is an independent ``filter`` of ``matrix``:

    * ``power_users``    -- ``cast_count >= 20`` and ``reply_count >= 5``
    * ``casual_users``   -- ``5 <= cast_count < 20``
    * ``one_time_users`` -- ``0 < cast_count < 5``
    * ``lurkers``        -- ``cast_count == 0``

    The filters are mutually exclusive but not exhaustive: rows with
    ``cast_count >= 20`` and ``reply_count < 5`` (or a null ``cast_count`` /
    ``reply_count``) match none of them. Their count is printed as
    ``unassigned`` so the gap is visible.

    A distribution summary is printed. When a ``bot`` column is present, the
    bot share and the mean cast/follower/following counts are also printed
    for each non-empty segment.

    Args:
        matrix: Feature matrix. It must contain ``cast_count`` and
            ``reply_count``, plus ``follower_count`` and ``following_count``
            when ``bot`` is present.

    Returns:
        Mapping of segment name to its filtered DataFrame.
    """
    segments = {
        'power_users': matrix.filter(
            (pl.col('cast_count') >= 20) &
            (pl.col('reply_count') >= 5)
        ),
        'casual_users': matrix.filter(
            (pl.col('cast_count') >= 5) &
            (pl.col('cast_count') < 20)
        ),
        'one_time_users': matrix.filter(
            (pl.col('cast_count') > 0) &
            (pl.col('cast_count') < 5)
        ),
        'lurkers': matrix.filter(pl.col('cast_count') == 0)
    }

    total = len(matrix)

    def _share(count: int) -> float:
        """Percentage of ``matrix`` rows; 0.0 for an empty matrix instead of dividing by zero."""
        return count / total * 100 if total else 0.0

    print("\nUser Segment Distribution:")
    for name, segment in segments.items():
        size = len(segment)
        if size > 0 and 'bot' in segment.columns:
            bot_pct = (segment.filter(pl.col('bot') == 1).shape[0] / size) * 100
            print(f"{name}: {size:,} users ({_share(size):.1f}%) - {bot_pct:.1f}% bots")

            metrics = segment.select([
                pl.col('cast_count').mean(),
                pl.col('follower_count').mean(),
                pl.col('following_count').mean()
            ]).to_numpy()[0]

            print(f"  Avg casts: {metrics[0]:.1f}")
            print(f"  Avg followers: {metrics[1]:.1f}")
            print(f"  Avg following: {metrics[2]:.1f}")
        else:
            print(f"{name}: {size:,} users ({_share(size):.1f}%)")

    # Segments are disjoint, so whatever is not covered by their sizes fell
    # through every filter; report it rather than dropping it silently.
    unassigned = total - sum(len(segment) for segment in segments.values())
    if unassigned > 0:
        print(f"unassigned: {unassigned:,} users ({_share(unassigned):.1f}%) - matched no segment")

    return segments

def get_segment_specific_features(segment_name: str) -> List[str]:
    """Return the model feature columns for a behavioural segment.

    The list always starts with nine shared profile/network features. It then
    continues with the extras registered for ``segment_name``:
    ``'power_users'``, ``'casual_users'``, ``'one_time_users'`` or
    ``'lurkers'``, the same keys used by the segmentation step. An unknown
    name yields only the shared features.

    Args:
        segment_name: Segment key.

    Returns:
        A freshly built list of column names. The names are not checked
        against any DataFrame.
    """
    shared_columns = (
        'has_ens', 'has_bio', 'has_avatar', 'verification_count',
        'following_count', 'follower_count', 'follower_ratio',
        'unique_follower_ratio', 'authenticity_score',
    )
    extras_by_segment = {
        'power_users': (
            'cast_count', 'total_reactions', 'avg_cast_length',
            'reply_count', 'mentions_count', 'engagement_score',
            'weekday_diversity', 'hour_diversity', 'rapid_actions',
            'avg_hours_between_actions', 'std_hours_between_actions',
            'power_user_interaction_ratio', 'influence_score',
        ),
        'casual_users': (
            'cast_count', 'total_reactions', 'engagement_score',
            'reply_count', 'rapid_actions', 'avg_hours_between_actions',
            'avg_cast_length', 'mentions_count',
        ),
        'one_time_users': (
            'cast_count', 'total_reactions', 'profile_update_consistency',
            'follower_growth_rate',
        ),
        'lurkers': (
            'profile_update_consistency', 'network_balance',
            'follower_growth_rate', 'profile_completeness',
        ),
    }
    return [*shared_columns, *extras_by_segment.get(segment_name, ())]

class RapidModelEvaluator:
    """Compare classifiers with k-fold cross-validation.

    All three reported metrics use scikit-learn's built-in binary scorers, so
    precision, recall and F1 describe the same (positive, ``1``) class.
    """

    def __init__(self, n_cv_splits: int = 5):
        # Number of folds passed to ``cross_validate(cv=...)``.
        self.n_cv_splits = n_cv_splits
        # name -> {'test_scores': {...}, 'fit_time': float}; reset by compare_models.
        self.results = {}

    def evaluate_model(self, name: str, model: Any, X: np.ndarray, y: np.ndarray) -> Dict:
        """Cross-validate one model and record its mean test scores.

        ``cross_validate`` fits clones of ``model``, so ``model`` itself is
        left unfitted.

        Args:
            name: Key under which the result is stored in ``self.results``.
            model: Unfitted scikit-learn compatible classifier.
            X: Feature matrix.
            y: Binary labels (positive class ``1``).

        Returns:
            ``{'test_scores': {'precision', 'recall', 'f1'}, 'fit_time'}``
            holding fold means.
        """
        # Use the matching built-in binary scorers. Macro-averaged
        # precision/recall (averaged over both classes) would not be
        # comparable with the binary, positive-class 'f1' score.
        scoring = {'precision': 'precision', 'recall': 'recall', 'f1': 'f1'}

        # Train scores were never used, so they are not computed.
        cv_results = cross_validate(
            model, X, y,
            cv=self.n_cv_splits,
            scoring=scoring,
            n_jobs=-1,
        )

        self.results[name] = {
            'test_scores': {
                metric: cv_results[f'test_{metric}'].mean() for metric in scoring
            },
            'fit_time': cv_results['fit_time'].mean(),
        }

        return self.results[name]

    def compare_models(self, segment_name: str, models: Dict[str, Any],
                       X: np.ndarray, y: np.ndarray) -> pd.DataFrame:
        """Evaluate every model in ``models`` on the same data.

        ``self.results`` is reset first, so only this call's models appear.

        Args:
            segment_name: Label used in progress messages only.
            models: Mapping of model name to unfitted estimator.
            X: Feature matrix.
            y: Binary labels.

        Returns:
            DataFrame indexed by model name with ``precision``, ``recall``,
            ``f1`` and ``fit_time`` columns, sorted by ``f1`` descending.
            It is empty, but still has those columns, when ``models`` is empty.
        """
        self.results = {}  # Reset results for each segment
        for name, model in models.items():
            print(f"Evaluating {name} on {segment_name} segment...")
            self.evaluate_model(name, model, X, y)

        rows = {
            name: {**res['test_scores'], 'fit_time': res['fit_time']}
            for name, res in self.results.items()
        }
        # Explicit columns keep the frame well-formed (and sortable by 'f1')
        # even when no model was evaluated.
        results_df = pd.DataFrame.from_dict(
            rows,
            orient='index',
            columns=['precision', 'recall', 'f1', 'fit_time'],
        )

        return results_df.sort_values('f1', ascending=False)

def prepare_segment_features(segment: pl.DataFrame, segment_name: str) -> Tuple[np.ndarray, Any, List[str]]:
    """Prepare a standardized feature matrix and labels for a segment.

    Only the segment's configured features that are present in ``segment``
    are used. Nulls are filled with 0 before conversion to NumPy.

    NOTE(review): the scaler is fitted on the whole segment before any
    cross-validation split, which slightly leaks fold statistics (harmless
    for tree models, minor for logistic regression).

    Args:
        segment: Rows of one behavioural segment.
        segment_name: Segment key used to look up the feature list.

    Returns:
        ``(X, y, valid_features)``:

        * ``X`` is the feature array. It is standardized whenever it has at
          least one row and one column.
        * ``y`` is the ``bot`` column as a NumPy array, or ``None`` when the
          segment has no ``bot`` column.
        * ``valid_features`` lists the columns used, in ``X`` column order.
    """
    feature_cols = get_segment_specific_features(segment_name)
    valid_features = [col for col in feature_cols if col in segment.columns]
    print(f"\nUsing {len(valid_features)} features for {segment_name}:", valid_features)

    X = segment.select(valid_features).fill_null(0).to_numpy()
    y = segment['bot'].to_numpy() if 'bot' in segment.columns else None

    # StandardScaler raises on arrays with zero rows or zero columns. Empty
    # segments (e.g. no labelled lurkers) must reach the caller's
    # "insufficient labels" check instead of crashing here.
    if X.shape[0] > 0 and X.shape[1] > 0:
        X = StandardScaler().fit_transform(X)

    return X, y, valid_features

def evaluate_segments(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Dict:
    """Cross-validate several classifiers separately on each behavioural segment.

    Steps:

    1. Inner-join the matrix with the labels on ``fid`` (both cast to
       ``Int64`` first).
    2. Split the result with ``segment_users_by_behavior``.
    3. For every segment containing at least two label classes, compare
       XGBoost, LightGBM, random-forest and logistic-regression models with
       ``RapidModelEvaluator``.
    4. Refit the best model (by F1) on the whole segment and print its top
       feature importances when it exposes them.

    Args:
        matrix: Feature matrix with a ``fid`` column.
        labels_df: Labels with ``fid`` and ``bot`` columns.

    Returns:
        Mapping of segment name to its comparison DataFrame (sorted by F1,
        descending). Segments lacking two label classes are omitted.
    """
    # Ensure consistent FID type (joining on mismatched key dtypes can fail)
    matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))
    labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))

    # Join with labels
    data = matrix.join(labels_df, on='fid', how='inner')

    # Create segments
    segmented_users = segment_users_by_behavior(data)

    evaluator = RapidModelEvaluator()
    segment_results = {}

    for segment_name, segment_data in segmented_users.items():
        print(f"\nProcessing {segment_name} segment...")

        X, y, valid_features = prepare_segment_features(segment_data, segment_name)

        # Cross-validation needs both classes present.
        if y is None or len(np.unique(y)) < 2:
            print(f"Skipping {segment_name} - insufficient labels")
            continue

        # XGBoost has no class_weight parameter; weight positives by the
        # negative/positive ratio instead.
        scale_pos_weight = np.sum(y == 0) / np.sum(y == 1)

        # Define models with balanced class weights
        models = {
            'xgb': xgb.XGBClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                scale_pos_weight=scale_pos_weight,
                random_state=42
            ),
            'lgbm': LGBMClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'rf': RandomForestClassifier(
                n_estimators=100,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'logistic': LogisticRegression(
                max_iter=1000,
                class_weight='balanced',
                random_state=42
            )
        }

        # Evaluate models
        results = evaluator.compare_models(segment_name, models, X, y)
        segment_results[segment_name] = results

        print(f"\nResults for {segment_name}:")
        print(results)

        # Print feature importance for best model
        try:
            best_model_name = results.index[0]
            best_model = models[best_model_name]

            # cross_validate only fits clones, so the estimators in `models`
            # are still unfitted here and reading feature_importances_ fails
            # (hasattr returns False). Fit the winner on the whole segment
            # first.
            best_model.fit(X, y)

            if hasattr(best_model, 'feature_importances_'):
                importances = pd.DataFrame({
                    'feature': valid_features,
                    'importance': best_model.feature_importances_
                }).sort_values('importance', ascending=False)

                print(f"\nTop 10 features for {segment_name}:")
                print(importances.head(10))
            else:
                print(f"\nNo feature importances available for {best_model_name}")

        except Exception as e:
            # Best-effort diagnostics: a failure here must not abort the
            # remaining segments (the results are already stored above).
            print(f"Error calculating feature importance: {str(e)}")

    return segment_results


# Run the per-segment model comparison and summarise each segment's winner.
# Each result DataFrame is sorted by F1 descending (see compare_models), so
# row 0 is the best model. Segments skipped for insufficient labels are absent.
results = evaluate_segments(matrix, labels_df)

print("\nBest models per segment:")
for segment_name, result_df in results.items():
    best_model = result_df.index[0]
    best_f1 = result_df.iloc[0]['f1']
    best_precision = result_df.iloc[0]['precision']
    best_recall = result_df.iloc[0]['recall']
    print(f"{segment_name}:")
    print(f"  Best model: {best_model}")
    print(f"  F1: {best_f1:.3f}")
    print(f"  Precision: {best_precision:.3f}")
    print(f"  Recall: {best_recall:.3f}")

In [None]:
import polars as pl
import numpy as np
from typing import Dict, List, Tuple
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

def analyze_full_distribution(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Tuple[Dict, Dict]:
    """Compare the segment distribution of the full and the labelled dataset.

    Both the full matrix and its inner join with ``labels_df`` are segmented
    with ``segment_users_by_behavior``. For every segment, the percentage of
    users that carry a label is then printed.

    Args:
        matrix: Full feature matrix with a ``fid`` column.
        labels_df: Labels with a ``fid`` column.

    Returns:
        ``(full_segments, labeled_segments)``: segment-name -> DataFrame maps.
    """
    print("=== Full Dataset Distribution ===")
    full_segments = segment_users_by_behavior(matrix)

    print("\n=== Labeled Dataset Distribution ===")
    # Align the join-key dtype on both sides first, as evaluate_segments does.
    # A join on mismatched integer fid dtypes can raise in polars.
    labeled_data = matrix.with_columns(pl.col('fid').cast(pl.Int64)).join(
        labels_df.with_columns(pl.col('fid').cast(pl.Int64)),
        on='fid',
        how='inner',
    )
    labeled_segments = segment_users_by_behavior(labeled_data)

    # Calculate coverage
    print("\n=== Label Coverage by Segment ===")
    for segment_name, full_segment in full_segments.items():
        full_count = len(full_segment)
        labeled_count = len(labeled_segments.get(segment_name, pl.DataFrame()))
        coverage = (labeled_count / full_count * 100) if full_count > 0 else 0
        print(f"{segment_name}: {coverage:.1f}% labeled ({labeled_count}/{full_count})")

    return full_segments, labeled_segments


def get_unlabeled_samples(matrix: pl.DataFrame, 
                         labels_df: pl.DataFrame, 
                         samples_per_segment: int = 50,
                         use_isolation_forest: bool = True) -> Dict[str, pl.DataFrame]:
    """Draw per-segment samples of unlabeled users for manual labelling.

    Unlabeled rows (``fid`` not in ``labels_df``) are split with
    ``segment_users_by_behavior``. Sampling then depends on segment size:

    * Empty segments are skipped.
    * When ``use_isolation_forest`` is set and a segment has more than
      ``samples_per_segment`` rows, an IsolationForest (10% contamination)
      flags outliers on the available activity features. Up to a quarter of
      the sample comes from the flagged rows and the rest from the inliers.
    * Otherwise a seeded random sample of up to ``samples_per_segment`` rows
      is taken.

    Every random step is seeded, so repeated calls on the same data select
    the same rows.

    Args:
        matrix: Full feature matrix with a ``fid`` column.
        labels_df: Already-labelled users (``fid`` column).
        samples_per_segment: Target sample size per segment.
        use_isolation_forest: Whether to mix in IsolationForest outliers.

    Returns:
        Mapping of segment name to its sampled rows.
    """
    # Get unlabeled FIDs
    labeled_fids = labels_df['fid'].unique()
    unlabeled = matrix.filter(~pl.col('fid').is_in(labeled_fids))

    # Segment unlabeled data
    segments = segment_users_by_behavior(unlabeled)

    # Features for anomaly detection
    anomaly_features = [
        'cast_count', 'follower_count', 'following_count',
        'authenticity_score', 'total_reactions', 'rapid_actions',
        'avg_hours_between_actions', 'std_hours_between_actions'
    ]
    stat_columns = ('cast_count', 'follower_count', 'following_count')

    # Seeded generator for the index draws below. The IsolationForest and the
    # polars .sample calls are seeded too, so the exported set is reproducible.
    rng = np.random.default_rng(42)

    samples = {}
    for name, segment in segments.items():
        print(f"\nProcessing {name} segment ({len(segment)} users)")

        if len(segment) == 0:
            continue

        if use_isolation_forest and len(segment) > samples_per_segment:
            # Prepare features for anomaly detection
            valid_features = [f for f in anomaly_features if f in segment.columns]
            if len(valid_features) > 0:
                X = segment.select(valid_features).fill_null(0).to_numpy()
                X = StandardScaler().fit_transform(X)

                # Use Isolation Forest to identify anomalies
                iso_forest = IsolationForest(
                    n_estimators=100,
                    contamination=0.1,  # Assume 10% anomalies
                    random_state=42
                )

                # fit_predict labels outliers -1 and inliers 1.
                scores = iso_forest.fit_predict(X)
                anomaly_indices = np.where(scores == -1)[0]
                normal_indices = np.where(scores == 1)[0]

                # Sample both anomalies and normal cases; cap each draw at the
                # rows available so choice(replace=False) cannot fail.
                n_anomalies = min(samples_per_segment // 4, len(anomaly_indices))
                n_normal = min(samples_per_segment - n_anomalies, len(normal_indices))

                selected_indices = np.concatenate([
                    rng.choice(anomaly_indices, n_anomalies, replace=False),
                    rng.choice(normal_indices, n_normal, replace=False)
                ])

                # Select the chosen row positions via a temporary row number.
                samples[name] = (segment
                    .with_row_count("row_nr")
                    .filter(pl.col("row_nr").is_in(selected_indices.tolist()))
                    .drop("row_nr"))

                print(f"Selected {n_anomalies} potential anomalies and {n_normal} normal cases")
            else:
                # Fallback to random sampling if features not available
                samples[name] = segment.sample(n=samples_per_segment, seed=42)
        else:
            # Small segments (or no IsolationForest): random sample, capped at segment size
            n_samples = min(samples_per_segment, len(segment))
            samples[name] = segment.sample(n=n_samples, seed=42)

        print(f"Final sample size: {len(samples[name])}")

        # Print some statistics about the sample (only when all columns exist)
        if all(col in segment.columns for col in stat_columns):
            stats = samples[name].select([
                pl.col('cast_count').mean().alias('avg_casts'),
                pl.col('follower_count').mean().alias('avg_followers'),
                pl.col('following_count').mean().alias('avg_following')
            ])
            print("Sample statistics:")
            print(f"  Avg casts: {stats['avg_casts'][0]:.1f}")
            print(f"  Avg followers: {stats['avg_followers'][0]:.1f}")
            print(f"  Avg following: {stats['avg_following'][0]:.1f}")

    return samples

def export_samples_for_labeling(samples: Dict[str, pl.DataFrame], 
                              output_path: str = "samples_for_labeling.csv"):
    """Write sampled rows to a CSV file for manual labelling.

    Each frame is tagged with a ``segment`` column holding its dict key, and
    the frames are stacked vertically (so they must share a schema). Only the
    existing columns among ``fid``, ``segment``, ``fname``, ``cast_count``,
    ``follower_count``, ``following_count``, ``total_reactions`` and
    ``authenticity_score`` are written.

    Args:
        samples: Mapping of segment name to sampled rows.
        output_path: Destination CSV path (overwritten when written).
    """
    # pl.concat raises on an empty list, e.g. when every segment was empty.
    if not samples:
        print(f"\nNo samples to export; {output_path} was not written")
        return

    # Combine all samples
    all_samples = pl.concat([
        segment.with_columns(pl.lit(name).alias('segment'))
        for name, segment in samples.items()
    ])

    # Select relevant columns for labeling
    export_columns = [
        'fid', 'segment', 'fname',
        'cast_count', 'follower_count', 'following_count',
        'total_reactions', 'authenticity_score'
    ]

    # Only include columns that exist
    valid_columns = [col for col in export_columns if col in all_samples.columns]

    # Export to CSV
    all_samples.select(valid_columns).write_csv(output_path)
    print(f"\nExported {len(all_samples)} samples to {output_path}")
    
def suggest_priority_accounts(matrix: pl.DataFrame, 
                              labels_df: pl.DataFrame,
                              n_suggestions: int = 50) -> pl.DataFrame:
    """Rank unlabeled accounts by a simple influence score for labelling.

    The score is built from whichever of ``follower_count``,
    ``following_count``, ``cast_count``, ``total_reactions`` and
    ``authenticity_score`` exist:

    1. Each feature is null-filled with 0 and divided by its column maximum.
       A feature whose maximum is 0 or null contributes 0.
    2. The scaled values are summed into ``influence_score``.
    3. The top ``n_suggestions`` accounts are returned.

    No model uncertainty is involved; the ranking is influence-only.

    Args:
        matrix: Full feature matrix with a ``fid`` column.
        labels_df: Already-labelled users (``fid`` column).
        n_suggestions: Number of accounts to return.

    Returns:
        The top accounts with ``fid``, ``influence_score``, ``fname`` (when
        present) and whichever of ``follower_count``, ``cast_count`` and
        ``total_reactions`` exist. When no influence feature exists, the
        first ``n_suggestions`` unlabeled rows are returned instead.
    """
    # Get unlabeled accounts
    labeled_fids = labels_df['fid'].unique()
    unlabeled = matrix.filter(~pl.col('fid').is_in(labeled_fids))

    influence_features = [
        'follower_count', 'following_count', 'cast_count',
        'total_reactions', 'authenticity_score'
    ]

    # Only use available features
    valid_features = [f for f in influence_features if f in unlabeled.columns]

    if not valid_features:
        print("No valid features found for influence calculation")
        return unlabeled.head(n_suggestions)

    # Column maxima are computed once. A zero/null maximum (all-zero or empty
    # column) would turn the division into NaN/inf values that can float to
    # the top of the descending sort, so such a column contributes zeros
    # (col * 0.0 keeps the column length).
    maxima = unlabeled.select(
        [pl.col(f).fill_null(0).max() for f in valid_features]
    ).row(0)
    normalized = unlabeled.select([
        'fid',
        *(
            ((pl.col(f).fill_null(0) / m) if m else (pl.col(f).fill_null(0) * 0.0)).alias(f)
            for f, m in zip(valid_features, maxima)
        ),
    ])

    # Influence score: sum of the scaled features (ranking is the same as for
    # their average).
    influence_df = normalized.with_columns([
        pl.fold(acc=0, function=lambda acc, x: acc + x, exprs=[pl.col(f) for f in valid_features])
        .alias('influence_score')
    ])

    # Top accounts, keeping only the columns needed for the join back
    suggestions = (influence_df
                   .sort('influence_score', descending=True)
                   .head(n_suggestions)
                   .select(['fid', 'influence_score']))

    # Join back with the original features (including 'fname' when present).
    # The suffix keeps this computed influence_score if the matrix also has one.
    result = suggestions.join(unlabeled, on='fid', how='left', suffix="_unlabeled")

    if 'fname' not in result.columns:
        print("Warning: 'fname' column not found in the dataset.")

    output_columns = ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    return result.select([col for col in output_columns if col in result.columns])


# Analyze current distribution
full_segments, labeled_segments = analyze_full_distribution(matrix, labels_df)

# Get samples for each segment
samples = get_unlabeled_samples(matrix, labels_df, samples_per_segment=50)

# Export samples for labeling
export_samples_for_labeling(samples)

# Get priority suggestions
priority_accounts = suggest_priority_accounts(matrix, labels_df, n_suggestions=50)

print("\nTop accounts to consider for labeling:")
# suggest_priority_accounts omits 'fname' when the matrix lacks it, and returns
# raw rows when no influence feature exists. Select only the display columns
# that are actually present instead of raising on a missing one.
priority_display_cols = [
    col for col in ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    if col in priority_accounts.columns
]
print(priority_accounts.select(priority_display_cols).head(10))

In [None]:

# End-to-end training/evaluation cell:
#   1. Build the feature matrix and inner-join it with the labels.
#   2. Make a stratified 80/20 split.
#   3. Train the detector, or load it from the checkpoint.
#   4. Report test-set metrics, a SHAP summary and a confusion matrix.
# NOTE(review): `shap` is not imported in this cell; it must already be in
# scope from an earlier cell.
# Re-initialize after building
feature_eng = FeatureEngineering("data", "checkpoints")
detector = SybilDetectionSystem(feature_eng)

# Build feature matrix
matrix = feature_eng.build_feature_matrix()
print("Feature matrix built")

# Load labels
labels_df = pl.read_csv('data/labels.csv')
labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))

# Ensure matrix fid is Int64
matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))

# Join with matching types
data = matrix.join(labels_df, on='fid', how='inner')

# Extract features and labels
X, feature_names = detector.prepare_features(data.drop(['bot', 'fid']))
y = data['bot'].to_numpy()

# Perform train/test split
from sklearn.model_selection import train_test_split
fids = data['fid'].to_numpy()

# fids are split alongside X/y so each test row can be traced back to its account.
X_train, X_test, y_train, y_test, train_fids, test_fids = train_test_split(
    X, y, fids,
    test_size=0.2,
    random_state=42,
    stratify=y  # ensures class distribution is preserved
)

# Define a path to save the model checkpoint
model_checkpoint_path = "checkpoints/sybil_detector_model.pkl"
os.makedirs("checkpoints", exist_ok=True)

# Check if a model checkpoint already exists
# NOTE(review): the cached detector is reused without checking that it was
# trained on this exact split and feature set. If it was not, the test metrics
# below may be optimistic (test rows could have been seen in training).
# joblib.load unpickles the file, so only load checkpoints from a trusted location.
if os.path.exists(model_checkpoint_path):
    # Load the detector object (including the trained model)
    detector = joblib.load(model_checkpoint_path)
    print("Loaded model checkpoint. Skipping training.")
else:
    # Train the model using the training set only
    detector.train(X_train, y_train, feature_names)

    # Save the model checkpoint
    joblib.dump(detector, model_checkpoint_path)
    print(f"Model checkpoint saved to {model_checkpoint_path}")

# Evaluate on the test set
y_pred_proba = detector.model.predict_proba(X_test)[:, 1]
# Fixed 0.5 decision threshold on the class-1 (bot) probability.
y_pred = (y_pred_proba >= 0.5).astype(int)

# Compute and plot SHAP for X_test:
shap_values_test = detector.shap_explainers['xgb'].shap_values(X_test)
# If a list of per-class arrays comes back, keep index 1 (presumably the bot class).
if isinstance(shap_values_test, list) and len(shap_values_test) > 1:
    shap_values_test = shap_values_test[1]
# Sanity check (note: asserts are stripped under `python -O`).
assert shap_values_test.shape[0] == X_test.shape[0], "Mismatch in rows between shap_values and X_test!"
shap.summary_plot(shap_values_test, X_test, feature_names=feature_names)
import matplotlib.pyplot as plt
plt.show()

# Compute evaluation metrics
test_roc_auc = roc_auc_score(y_test, y_pred_proba)
test_f1 = f1_score(y_test, y_pred)
test_precision = precision_score(y_test, y_pred)
test_recall = recall_score(y_test, y_pred)

print("\nTest ROC AUC:", test_roc_auc)
print("Test F1 Score:", test_f1)
print("Test Precision:", test_precision)
print("Test Recall:", test_recall)

# SHAP explanations for a specific instance
instance_index = 0
model_name = 'xgb'
explanations = detector.get_feature_explanations(model_name, X_test, instance_index)
print(f"\nSHAP explanations for instance {instance_index} from {model_name}:")
print(explanations)

from sklearn.metrics import confusion_matrix
import seaborn as sns

# Rows are actual labels, columns predicted labels (0 = Authentic, 1 = Bot).
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Authentic', 'Bot'], yticklabels=['Authentic', 'Bot'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

from sklearn.metrics import matthews_corrcoef, cohen_kappa_score

mcc = matthews_corrcoef(y_test, y_pred)
kappa = cohen_kappa_score(y_test, y_pred)

print("Matthews Correlation Coefficient:", mcc)
print("Cohen's Kappa:", kappa)


In [None]:
import random

# Build `unlabeled_data_filtered`: rows of the full matrix whose fid is not in
# the labelled `data` and which have a non-null fname in the profile dataset.
# 1. Get all FIDs from the full matrix
all_fids = matrix['fid'].unique()
labeled_fids = data['fid'].unique()

# 2. Convert to numpy arrays for set operations
all_fids_np = all_fids.to_numpy()
labeled_fids_np = labeled_fids.to_numpy()

# 3. Find FIDs that aren't in the labeled dataset
unlabeled_fids = np.setdiff1d(all_fids_np, labeled_fids_np)

# 4. Filter matrix to get unlabeled data
unlabeled_data = matrix.filter(pl.col('fid').is_in(unlabeled_fids))

# 5. Get profiles with fnames
ds = feature_eng.loader.get_dataset('profile_with_addresses')
valid_profiles = ds.filter(pl.col('fname').is_not_null())

# Ensure valid_profiles has unique fids
# NOTE(review): which row is kept per duplicated fid is presumably arbitrary
# (polars `unique` default); confirm if the fname choice matters.
valid_profiles = valid_profiles.unique(subset=['fid'])

# 6. Collect the distinct fids that have an fname (deduplicated above)
valid_fids = valid_profiles['fid'].unique()

# 7. Filter unlabeled data to only include profiles with fnames
unlabeled_data_filtered = unlabeled_data.filter(pl.col('fid').is_in(valid_fids))

In [None]:
# Diagnostics on a few features, then build model inputs for the filtered
# unlabeled rows.
# Check if 'profile_update_consistency' column exists
if 'profile_update_consistency' in unlabeled_data_filtered.columns:
    print(unlabeled_data_filtered.select(['profile_update_consistency']).describe())
else:
    print("Column 'profile_update_consistency' not found in unlabeled_data_filtered.")

# Compute statistics for given features
# Compares bot (bot == 1) and human (bot == 0) rows of the labelled `data`.
for feature in ['profile_update_consistency', 'influence_score', 'follower_count']:
    if feature in data.columns:
        bot_values = data.filter(pl.col('bot') == 1)[feature]
        human_values = data.filter(pl.col('bot') == 0)[feature]

        # Print statistics
        print(f"\n{feature} statistics:")
        print("Bot mean:", bot_values.mean())
        print("Human mean:", human_values.mean())
        print("Bot std:", bot_values.std())
        print("Human std:", human_values.std())
    else:
        print(f"\nFeature '{feature}' does not exist in the data DataFrame.")

# 8. Prepare features and predictions for unlabeled data
# `fid` is dropped before building model input.
# NOTE(review): this rebinds the global `valid_features`. Later cells also
# assume row i of X_unlabeled_filtered corresponds to row i of
# unlabeled_data_filtered; confirm prepare_features never drops or reorders rows.
X_unlabeled_filtered, valid_features = detector.prepare_features(unlabeled_data_filtered.drop("fid"))

# Debug prints to ensure lengths match
print(f"unlabeled_data_filtered shape: {unlabeled_data_filtered.shape}")
print(f"X_unlabeled_filtered shape: {X_unlabeled_filtered.shape}")
print(f"model feature length: {len(detector.feature_names)}")
if X_unlabeled_filtered.shape[0] != len(unlabeled_data_filtered):
    print("Warning: Length mismatch between unlabeled_data_filtered and X_unlabeled_filtered!")

In [None]:
# Score every filtered unlabeled profile with the trained model.
y_pred_proba_filtered = detector.model.predict_proba(X_unlabeled_filtered)[:, 1]
y_pred_filtered = (y_pred_proba_filtered >= 0.5).astype(int)

# Compute SHAP values for unlabeled data
# Extract the underlying xgb model from the ensemble if needed
xgb_model = detector.base_models['xgb'].base_estimator_ if hasattr(detector.base_models['xgb'], 'base_estimator_') else detector.base_models['xgb']

# NOTE(review): `shap_values_unlabeled` is not used in this cell; the
# per-account explanations below come from detector.get_feature_explanations.
# It is kept only in case a later cell relies on it.
explainer_unlabeled = shap.TreeExplainer(xgb_model.estimator)
shap_values_unlabeled = explainer_unlabeled.shap_values(X_unlabeled_filtered)


# 9. Get SHAP explanations for human predictions on up to 10 random rows.
# Sample row *positions*, so that the prediction (y_pred_filtered[i]), the
# model input row (X_unlabeled_filtered[i]) and the fid all belong to the same
# account. Zipping independently sampled fids with the first predictions
# would attach each explanation to an unrelated user.
human_explanations = []
n_rows = min(len(unlabeled_data_filtered), X_unlabeled_filtered.shape[0])
sample_positions = np.random.choice(n_rows, size=min(10, n_rows), replace=False)
fids_array = unlabeled_data_filtered['fid'].to_numpy()[sample_positions]

for i, fid in zip(sample_positions.tolist(), fids_array):
    if y_pred_filtered[i] == 0:  # If predicted human
        # Per-feature SHAP contributions for row i
        shap_values = detector.get_feature_explanations('xgb', X_unlabeled_filtered, i)

        # The sampled row itself (positional, so duplicate fids cannot interfere)
        row = unlabeled_data_filtered.slice(i, 1)

        # Handle authenticity_score
        if 'authenticity_score' in row.columns and len(row) == 1:
            authenticity_score = row['authenticity_score'].item()
        else:
            authenticity_score = 0.0  # default

        # Get fname for this fid
        fname_series = valid_profiles.filter(pl.col('fid') == fid)['fname']
        if len(fname_series) == 1:
            fname = fname_series.item()
        elif len(fname_series) == 0:
            fname = "Unknown"
        else:
            fname = fname_series[0]

        # Only negative impact features (those pushing towards "human")
        negative_reasons = {k: v for k, v in shap_values.items() if v < 0}

        human_explanations.append({
            'fid': fid,
            'fname': fname,
            'authenticity_score': authenticity_score,
            'human_probability': 1 - y_pred_proba_filtered[i],
            'reasons': negative_reasons
        })

# Print results
sample_size = min(10, len(human_explanations))
random_accounts = random.sample(human_explanations, sample_size)

print("\nHuman accounts with explanations (random sample):")
for i, exp in enumerate(random_accounts, start=1):
    print(f"\n{i}. {exp['fname']} (FID: {exp['fid']})")
    print(f"Human Probability: {exp['human_probability']:.2%}")
    print(f"Authenticity Score: {exp['authenticity_score']:.2f}")
    print("Top reasons for human classification:")
    sorted_reasons = sorted(exp['reasons'].items(), key=lambda x: abs(x[1]), reverse=True)
    for feature, impact in sorted_reasons:
        print(f"  - {feature}: {abs(impact):.3f}")

In [None]:
# Calculate distribution for entire unlabeled dataset
total_profiles = len(y_pred_filtered)
humans_count = (y_pred_filtered == 0).sum()
bots_count = (y_pred_filtered == 1).sum()

print("\nOverall Distribution Analysis:")
print(f"Total profiles analyzed: {total_profiles:,}")
print(f"Predicted humans: {humans_count:,} ({(humans_count/total_profiles)*100:.1f}%)")
print(f"Predicted bots: {bots_count:,} ({(bots_count/total_profiles)*100:.1f}%)")

# Probability distribution analysis
print("\nPrediction Probability Analysis:")
print(f"Mean bot probability: {y_pred_proba_filtered.mean():.3f}")
print(f"Median bot probability: {np.median(y_pred_proba_filtered):.3f}")
print(f"Std dev of bot probability: {y_pred_proba_filtered.std():.3f}")

# Distribution buckets for more detailed view
buckets = np.histogram(y_pred_proba_filtered, bins=[0, 0.2, 0.4, 0.6, 0.8, 1.0])
bucket_counts = buckets[0]
bucket_ranges = ['0-20%', '20-40%', '40-60%', '60-80%', '80-100%']

print("\nProbability Distribution Breakdown:")
for range_name, count in zip(bucket_ranges, bucket_counts):
    percentage = (count/total_profiles)*100
    print(f"{range_name}: {count:,} accounts ({percentage:.1f}%)")

# High confidence predictions
high_conf_human = (y_pred_proba_filtered < 0.2).sum()
high_conf_bot = (y_pred_proba_filtered > 0.8).sum()

print("\nHigh Confidence Predictions:")
print(f"High confidence humans (p < 0.2): {high_conf_human:,} ({(high_conf_human/total_profiles)*100:.1f}%)")
print(f"High confidence bots (p > 0.8): {high_conf_bot:,} ({(high_conf_bot/total_profiles)*100:.1f}%)")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Histogram of bot-probability scores.
#   - Bars centred below 0.2 are green (likely human).
#   - Bars centred above 0.8 are red (likely bot).
#   - Everything else is blue (uncertain).
# Uses total_profiles from the distribution-summary cell.
# Create the figure and axis
fig, ax = plt.subplots(figsize=(12, 6))

# Create histogram
bins = np.linspace(0, 1, 21)  # 20 bins for smooth distribution
# ax.hist returns the counts, the bin edges (same values as passed) and the bar patches.
n, bins, patches = ax.hist(y_pred_proba_filtered, bins=bins, edgecolor='black', alpha=0.7)

# Customize colors based on probability ranges
for i, patch in enumerate(patches):
    # Centre of bar i (all bins share the same width).
    bin_center = bins[i] + (bins[1] - bins[0])/2
    if bin_center < 0.2:
        patch.set_facecolor('#2ecc71')  # Green for human predictions
    elif bin_center > 0.8:
        patch.set_facecolor('#e74c3c')  # Red for bot predictions
    else:
        patch.set_facecolor('#3498db')  # Blue for uncertain predictions

# Add vertical lines for key thresholds
ax.axvline(x=0.2, color='#27ae60', linestyle='--', alpha=0.5, label='Human threshold (0.2)')
ax.axvline(x=0.8, color='#c0392b', linestyle='--', alpha=0.5, label='Bot threshold (0.8)')

# Customize the plot
ax.set_title('Distribution of Bot Probability Scores', pad=20, fontsize=14)
ax.set_xlabel('Bot Probability Score', fontsize=12)
ax.set_ylabel('Number of Accounts', fontsize=12)

# Add grid for better readability
ax.grid(True, alpha=0.3)

# Add legend
ax.legend()

# Add text annotations for key statistics
stats_text = (
    f'Total Profiles: {total_profiles:,}\n'
    f'Mean Probability: {y_pred_proba_filtered.mean():.3f}\n'
    f'Median Probability: {np.median(y_pred_proba_filtered):.3f}\n'
    f'Std Dev: {y_pred_proba_filtered.std():.3f}'
)
# Axes-relative coordinates: top-right corner of the plot area.
plt.text(0.95, 0.95, stats_text,
         transform=ax.transAxes,
         verticalalignment='top',
         horizontalalignment='right',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Adjust layout to prevent text cutoff
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
def get_segment_features(segment: str) -> List[str]:
    """Return feature names for a coarse activity segment.

    All segments share four profile features (ENS, bio, avatar and
    verification count). ``'active'`` and ``'low_activity'`` each add their
    own extras; every other value -- ``'dormant'`` included -- gets the
    dormant extras.

    Args:
        segment: Segment label.

    Returns:
        A new list: the shared features followed by the segment-specific ones.
    """
    shared = ['has_ens', 'has_bio', 'has_avatar', 'verification_count']
    known_segments = (
        ('active', ['cast_timing_entropy', 'reply_ratio',
                    'mention_patterns', 'engagement_rate']),
        ('low_activity', ['network_growth_rate', 'initial_behavior_pattern',
                          'verification_sequence']),
    )
    for label, extras in known_segments:
        if segment == label:
            return shared + extras
    # Fallback: anything unrecognised is treated as the dormant segment.
    return shared + ['follower_growth_velocity', 'network_structure',
                     'profile_completion_sequence']

In [None]:
import polars as pl
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any
from sklearn.ensemble import RandomForestClassifier
from lightgbm import LGBMClassifier
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, precision_recall_fscore_support
from sklearn.model_selection import cross_validate

def segment_users_by_behavior(matrix: pl.DataFrame) -> Dict[str, pl.DataFrame]:
    """Segment users based on behavioral patterns (casting activity).

    Each segment is an independent ``filter`` of ``matrix``:

    * ``power_users``    -- ``cast_count >= 20`` and ``reply_count >= 5``
    * ``casual_users``   -- ``5 <= cast_count < 20``
    * ``one_time_users`` -- ``0 < cast_count < 5``
    * ``lurkers``        -- ``cast_count == 0``

    The filters are mutually exclusive but not exhaustive: rows with
    ``cast_count >= 20`` and ``reply_count < 5`` (or a null ``cast_count`` /
    ``reply_count``) match none of them. Their count is printed as
    ``unassigned`` so the gap is visible.

    A distribution summary is printed. When a ``bot`` column is present, the
    bot share and the mean cast/follower/following counts are also printed
    for each non-empty segment.

    Args:
        matrix: Feature matrix. It must contain ``cast_count`` and
            ``reply_count``, plus ``follower_count`` and ``following_count``
            when ``bot`` is present.

    Returns:
        Mapping of segment name to its filtered DataFrame.
    """
    segments = {
        'power_users': matrix.filter(
            (pl.col('cast_count') >= 20) &
            (pl.col('reply_count') >= 5)
        ),
        'casual_users': matrix.filter(
            (pl.col('cast_count') >= 5) &
            (pl.col('cast_count') < 20)
        ),
        'one_time_users': matrix.filter(
            (pl.col('cast_count') > 0) &
            (pl.col('cast_count') < 5)
        ),
        'lurkers': matrix.filter(pl.col('cast_count') == 0)
    }

    total = len(matrix)

    def _share(count: int) -> float:
        """Percentage of ``matrix`` rows; 0.0 for an empty matrix instead of dividing by zero."""
        return count / total * 100 if total else 0.0

    print("\nUser Segment Distribution:")
    for name, segment in segments.items():
        size = len(segment)
        if size > 0 and 'bot' in segment.columns:
            bot_pct = (segment.filter(pl.col('bot') == 1).shape[0] / size) * 100
            print(f"{name}: {size:,} users ({_share(size):.1f}%) - {bot_pct:.1f}% bots")

            metrics = segment.select([
                pl.col('cast_count').mean(),
                pl.col('follower_count').mean(),
                pl.col('following_count').mean()
            ]).to_numpy()[0]

            print(f"  Avg casts: {metrics[0]:.1f}")
            print(f"  Avg followers: {metrics[1]:.1f}")
            print(f"  Avg following: {metrics[2]:.1f}")
        else:
            print(f"{name}: {size:,} users ({_share(size):.1f}%)")

    # Segments are disjoint, so whatever is not covered by their sizes fell
    # through every filter; report it rather than dropping it silently.
    unassigned = total - sum(len(segment) for segment in segments.values())
    if unassigned > 0:
        print(f"unassigned: {unassigned:,} users ({_share(unassigned):.1f}%) - matched no segment")

    return segments

def get_segment_specific_features(segment_name: str) -> List[str]:
    """Return the model feature columns to use for one behavior segment.

    Every segment gets the shared profile/network features; the four known
    segments (``power_users``, ``casual_users``, ``one_time_users``,
    ``lurkers``) append their own activity features. An unknown segment name
    yields only the shared features.
    """
    shared_profile_features = [
        'has_ens', 'has_bio', 'has_avatar', 'verification_count',
        'following_count', 'follower_count', 'follower_ratio',
        'unique_follower_ratio', 'authenticity_score',
    ]

    if segment_name == 'power_users':
        extra = [
            'cast_count', 'total_reactions', 'avg_cast_length',
            'reply_count', 'mentions_count', 'engagement_score',
            'weekday_diversity', 'hour_diversity', 'rapid_actions',
            'avg_hours_between_actions', 'std_hours_between_actions',
            'power_user_interaction_ratio', 'influence_score',
        ]
    elif segment_name == 'casual_users':
        extra = [
            'cast_count', 'total_reactions', 'engagement_score',
            'reply_count', 'rapid_actions', 'avg_hours_between_actions',
            'avg_cast_length', 'mentions_count',
        ]
    elif segment_name == 'one_time_users':
        extra = [
            'cast_count', 'total_reactions', 'profile_update_consistency',
            'follower_growth_rate',
        ]
    elif segment_name == 'lurkers':
        extra = [
            'profile_update_consistency', 'network_balance',
            'follower_growth_rate', 'profile_completeness',
        ]
    else:
        extra = []

    return shared_profile_features + extra

class RapidModelEvaluator:
    """Cross-validate classifiers and collect precision / recall / F1 scores.

    All three metrics use the same averaging, so they describe the same
    quantity. The default ``average='binary'`` scores only the positive class
    (label 1), which is also what sklearn's ``'f1'`` scorer measures. Pass,
    for example, ``average='macro'`` to average over both classes instead.
    """

    def __init__(self, n_cv_splits: int = 5, average: str = 'binary'):
        # Passed as ``cv=``; for classifiers sklearn turns an int into StratifiedKFold.
        self.n_cv_splits = n_cv_splits
        # 'binary' or an sklearn scorer averaging suffix ('macro', 'micro', 'weighted').
        self.average = average
        # model name -> {'test_scores': {...}, 'fit_time': mean fit seconds}
        self.results = {}

    def _scoring(self) -> Dict[str, str]:
        """Map output metric names to sklearn scorer strings for ``self.average``."""
        suffix = '' if self.average == 'binary' else f'_{self.average}'
        return {metric: f'{metric}{suffix}' for metric in ('precision', 'recall', 'f1')}

    def evaluate_model(self, name: str, model: Any, X: np.ndarray, y: np.ndarray) -> Dict:
        """Cross-validate ``model`` on ``(X, y)`` and store its fold-mean scores.

        Args:
            name: Key under which the result is stored in ``self.results``.
            model: sklearn-compatible estimator. ``cross_validate`` fits
                clones, so ``model`` itself is left unfitted.
            X: Feature matrix.
            y: Target labels.

        Returns:
            ``{'test_scores': {'precision', 'recall', 'f1'}, 'fit_time'}``,
            each averaged over the CV folds.
        """
        # Train scores were never read, so they are not requested.
        cv_results = cross_validate(
            model, X, y,
            cv=self.n_cv_splits,
            scoring=self._scoring(),
            n_jobs=-1
        )

        self.results[name] = {
            'test_scores': {
                'precision': cv_results['test_precision'].mean(),
                'recall': cv_results['test_recall'].mean(),
                'f1': cv_results['test_f1'].mean()
            },
            'fit_time': cv_results['fit_time'].mean()
        }

        return self.results[name]

    def compare_models(self, segment_name: str, models: Dict[str, Any],
                      X: np.ndarray, y: np.ndarray) -> pd.DataFrame:
        """Evaluate every model in ``models`` and return a comparison table.

        Returns:
            DataFrame indexed by model name with columns ``precision``,
            ``recall``, ``f1`` and ``fit_time``, sorted by ``f1`` descending.
        """
        self.results = {}  # Reset results for each segment
        for name, model in models.items():
            print(f"Evaluating {name} on {segment_name} segment...")
            self.evaluate_model(name, model, X, y)

        rows = {
            name: {**res['test_scores'], 'fit_time': res['fit_time']}
            for name, res in self.results.items()
        }
        results_df = pd.DataFrame.from_dict(rows, orient='index')
        # from_dict yields no columns for an empty mapping; keep the schema so
        # sort_values (and callers reading these columns) still work.
        results_df = results_df.reindex(columns=['precision', 'recall', 'f1', 'fit_time'])

        return results_df.sort_values('f1', ascending=False)

def prepare_segment_features(segment: pl.DataFrame, segment_name: str) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Build the standardized feature matrix and labels for one segment.

    Only the columns from ``get_segment_specific_features(segment_name)``
    that exist in ``segment`` are used; nulls are filled with 0 before
    standardizing.

    NOTE(review): the scaler is fit on the whole segment before cross-
    validation, so fold statistics leak slightly into the test folds.

    Returns:
        ``(X, y, valid_features)`` where ``y`` is the ``bot`` column as a
        numpy array, or ``None`` when ``segment`` has no ``bot`` column.
    """
    feature_cols = get_segment_specific_features(segment_name)
    valid_features = [col for col in feature_cols if col in segment.columns]
    print(f"\nUsing {len(valid_features)} features for {segment_name}:", valid_features)

    X = segment.select(valid_features).fill_null(0).to_numpy()
    y = segment['bot'].to_numpy() if 'bot' in segment.columns else None

    # StandardScaler raises on arrays with zero rows or zero columns. Empty
    # segments are skipped by the caller based on y, so return X unscaled.
    if X.ndim == 2 and X.shape[0] > 0 and X.shape[1] > 0:
        X = StandardScaler().fit_transform(X)

    return X, y, valid_features

def evaluate_segments(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Dict:
    """Cross-validate a fixed set of classifiers on each behavioral segment.

    ``matrix`` and ``labels_df`` are inner-joined on ``fid`` (both cast to
    Int64) and split with ``segment_users_by_behavior``. Every segment whose
    labels contain at least two distinct values is scored with
    ``RapidModelEvaluator.compare_models``. The best model is the first row
    of the F1-sorted table. It is then refit on the whole segment, and its
    top-10 ``feature_importances_`` are printed (or ``|coef_|`` for linear
    models).

    Returns:
        Mapping of segment name -> results DataFrame from ``compare_models``;
        skipped segments are absent.
    """
    # Ensure consistent FID type
    matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))
    labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))

    data = matrix.join(labels_df, on='fid', how='inner')
    segmented_users = segment_users_by_behavior(data)

    evaluator = RapidModelEvaluator()
    segment_results = {}

    for segment_name, segment_data in segmented_users.items():
        print(f"\nProcessing {segment_name} segment...")

        X, y, valid_features = prepare_segment_features(segment_data, segment_name)

        if y is None or len(np.unique(y)) < 2:
            print(f"Skipping {segment_name} - insufficient labels")
            continue

        # Define models with balanced class weights
        models = {
            'xgb': xgb.XGBClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                scale_pos_weight=np.sum(y == 0) / np.sum(y == 1),
                random_state=42
            ),
            'lgbm': LGBMClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'rf': RandomForestClassifier(
                n_estimators=100,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'logistic': LogisticRegression(
                max_iter=1000,
                class_weight='balanced',
                random_state=42
            )
        }

        results = evaluator.compare_models(segment_name, models, X, y)
        segment_results[segment_name] = results

        print(f"\nResults for {segment_name}:")
        print(results)

        # Print feature importance for best model
        try:
            best_model_name = results.index[0]
            best_model = models[best_model_name]
            # cross_validate only fits clones, so this estimator is still
            # unfitted; without refitting, hasattr() below is always False.
            best_model.fit(X, y)

            if hasattr(best_model, 'feature_importances_'):
                scores = best_model.feature_importances_
            elif hasattr(best_model, 'coef_'):
                # Inputs are standardized, so coefficient magnitudes are comparable.
                scores = np.abs(best_model.coef_).ravel()
            else:
                print(f"\nNo feature importances available for {best_model_name}")
                continue

            importances = pd.DataFrame({
                'feature': valid_features,
                'importance': scores
            }).sort_values('importance', ascending=False)

            print(f"\nTop 10 features for {segment_name}:")
            print(importances.head(10))

        except Exception as e:
            # Diagnostics are best-effort: one failure must not stop the other segments.
            print(f"Error calculating feature importance: {str(e)}")

    return segment_results


results = evaluate_segments(matrix, labels_df)

print("\nBest models per segment:")
for segment_name, result_df in results.items():
    # Tables come back sorted by F1 (descending), so row 0 is the winner.
    top_row = result_df.iloc[0]
    best_model = result_df.index[0]
    best_f1, best_precision, best_recall = top_row['f1'], top_row['precision'], top_row['recall']
    print("\n".join([
        f"{segment_name}:",
        f"  Best model: {best_model}",
        f"  F1: {best_f1:.3f}",
        f"  Precision: {best_precision:.3f}",
        f"  Recall: {best_recall:.3f}",
    ]))

In [None]:
import polars as pl
import numpy as np
from typing import Dict, List, Tuple
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

def analyze_full_distribution(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Tuple[Dict, Dict]:
    """Compare segment sizes in the full matrix with the labeled subset.

    Prints the segment distribution of ``matrix``, then that of its inner
    join with ``labels_df``, then the labeled share of each segment.

    Returns:
        ``(full_segments, labeled_segments)`` -- the two dicts produced by
        ``segment_users_by_behavior``.
    """
    print("=== Full Dataset Distribution ===")
    full_segments = segment_users_by_behavior(matrix)

    print("\n=== Labeled Dataset Distribution ===")
    # Cast fid on both sides (as evaluate_segments does) so the join does not
    # fail on mismatched fid dtypes between the two frames.
    labeled_data = matrix.with_columns(pl.col('fid').cast(pl.Int64)).join(
        labels_df.with_columns(pl.col('fid').cast(pl.Int64)),
        on='fid',
        how='inner'
    )
    labeled_segments = segment_users_by_behavior(labeled_data)

    print("\n=== Label Coverage by Segment ===")
    for segment_name, full_segment in full_segments.items():
        full_count = len(full_segment)
        labeled_count = len(labeled_segments.get(segment_name, pl.DataFrame()))
        coverage = (labeled_count / full_count * 100) if full_count > 0 else 0
        print(f"{segment_name}: {coverage:.1f}% labeled ({labeled_count}/{full_count})")

    return full_segments, labeled_segments
def get_unlabeled_samples(matrix: pl.DataFrame,
                         labels_df: pl.DataFrame,
                         samples_per_segment: int = 50,
                         use_isolation_forest: bool = True) -> Dict[str, pl.DataFrame]:
    """Draw a per-segment sample of unlabeled users for manual labeling.

    Rows of ``matrix`` whose ``fid`` appears in ``labels_df`` are removed, and
    the rest are split with ``segment_users_by_behavior``. A segment larger
    than ``samples_per_segment`` is handled as follows when
    ``use_isolation_forest`` is set and at least one anomaly feature exists:
    an IsolationForest (10% contamination) flags outliers, and up to a
    quarter of the sample is drawn from them, with the rest drawn from inliers.
    Otherwise up to ``samples_per_segment`` rows are sampled uniformly.
    All sampling is seeded (42), so results are reproducible.

    Returns:
        Segment name -> sampled rows (empty segments are omitted).
    """
    # Compare fids as Int64 on both sides, as evaluate_segments does.
    labeled_fids = labels_df['fid'].cast(pl.Int64).unique()
    unlabeled = matrix.filter(~pl.col('fid').cast(pl.Int64).is_in(labeled_fids))

    segments = segment_users_by_behavior(unlabeled)

    # Features for anomaly detection
    anomaly_features = [
        'cast_count', 'follower_count', 'following_count',
        'authenticity_score', 'total_reactions', 'rapid_actions',
        'avg_hours_between_actions', 'std_hours_between_actions'
    ]

    # Seeded so the isolation-forest branch is as reproducible as the
    # segment.sample(seed=42) branch.
    rng = np.random.default_rng(42)

    samples = {}
    for name, segment in segments.items():
        print(f"\nProcessing {name} segment ({len(segment)} users)")

        if len(segment) == 0:
            continue

        valid_features = [f for f in anomaly_features if f in segment.columns]
        if use_isolation_forest and len(segment) > samples_per_segment and valid_features:
            X = segment.select(valid_features).fill_null(0).to_numpy()
            X = StandardScaler().fit_transform(X)

            iso_forest = IsolationForest(
                n_estimators=100,
                contamination=0.1,  # Assume 10% anomalies
                random_state=42
            )

            # fit_predict labels outliers -1 and inliers 1
            scores = iso_forest.fit_predict(X)
            anomaly_indices = np.flatnonzero(scores == -1)
            normal_indices = np.flatnonzero(scores == 1)

            n_anomalies = min(samples_per_segment // 4, len(anomaly_indices))
            # Never request more inliers than exist: choice(replace=False) would raise.
            n_normal = min(samples_per_segment - n_anomalies, len(normal_indices))

            selected_indices = np.concatenate([
                rng.choice(anomaly_indices, n_anomalies, replace=False),
                rng.choice(normal_indices, n_normal, replace=False)
            ])

            # Attach positional row numbers so the selected positions can be kept.
            samples[name] = (segment
                .with_row_count("row_nr")
                .filter(pl.col("row_nr").is_in(selected_indices))
                .drop("row_nr"))

            print(f"Selected {n_anomalies} potential anomalies and {n_normal} normal cases")
        else:
            # Uniform sample: small segments, isolation forest disabled, or no
            # anomaly features available. Takes the whole segment if it is smaller.
            n_samples = min(samples_per_segment, len(segment))
            samples[name] = segment.sample(n=n_samples, seed=42)

        print(f"Final sample size: {len(samples[name])}")

        # Print some statistics about the sample
        if 'cast_count' in segment.columns:
            stats = samples[name].select([
                pl.col('cast_count').mean().alias('avg_casts'),
                pl.col('follower_count').mean().alias('avg_followers'),
                pl.col('following_count').mean().alias('avg_following')
            ])
            print("Sample statistics:")
            print(f"  Avg casts: {stats['avg_casts'][0]:.1f}")
            print(f"  Avg followers: {stats['avg_followers'][0]:.1f}")
            print(f"  Avg following: {stats['avg_following'][0]:.1f}")

    return samples

def export_samples_for_labeling(samples: Dict[str, pl.DataFrame],
                              output_path: str = "samples_for_labeling.csv"):
    """Write the per-segment samples to a single CSV for manual labeling.

    Rows from every segment are stacked with an added ``segment`` column.
    Only the review columns that exist are written: ``fid``, ``segment``,
    ``fname``, the activity counts and ``authenticity_score``.

    Args:
        samples: Segment name -> sampled rows, e.g. the output of
            ``get_unlabeled_samples``.
        output_path: Destination CSV path.
    """
    if not samples:
        # pl.concat raises on an empty list; there is nothing to export.
        print("\nNo samples to export")
        return

    all_samples = pl.concat([
        segment.with_columns(pl.lit(name).alias('segment'))
        for name, segment in samples.items()
    ])

    export_columns = [
        'fid', 'segment', 'fname',
        'cast_count', 'follower_count', 'following_count',
        'total_reactions', 'authenticity_score'
    ]
    valid_columns = [col for col in export_columns if col in all_samples.columns]

    all_samples.select(valid_columns).write_csv(output_path)
    print(f"\nExported {len(all_samples)} samples to {output_path}")
    
def suggest_priority_accounts(matrix: pl.DataFrame,
                              labels_df: pl.DataFrame,
                              n_suggestions: int = 50) -> pl.DataFrame:
    """Rank unlabeled accounts by a simple influence score for labeling.

    Each available influence feature is null-filled with 0 and divided by its
    column maximum, and the normalized values are summed into
    ``influence_score``. No model uncertainty is used.

    Returns:
        The top ``n_suggestions`` accounts with whichever of ``fid``,
        ``fname``, ``influence_score``, ``follower_count``, ``cast_count`` and
        ``total_reactions`` exist. If no influence feature is available, the
        first ``n_suggestions`` unlabeled rows are returned unchanged.
    """
    # Compare fids as Int64 on both sides, as evaluate_segments does.
    labeled_fids = labels_df['fid'].cast(pl.Int64).unique()
    unlabeled = matrix.filter(~pl.col('fid').cast(pl.Int64).is_in(labeled_fids))

    influence_features = [
        'follower_count', 'following_count', 'cast_count',
        'total_reactions', 'authenticity_score'
    ]
    valid_features = [f for f in influence_features if f in unlabeled.columns]

    if not valid_features:
        print("No valid features found for influence calculation")
        return unlabeled.head(n_suggestions)

    # Column maxima, computed once. A max of 0 or null would produce 0/0 = NaN,
    # and NaN sorts first in a descending sort, so such columns contribute 0.
    maxima = unlabeled.select([pl.col(f).fill_null(0).max() for f in valid_features]).row(0)
    normalized_exprs = [
        (pl.col(f).fill_null(0) / m).alias(f) if m is not None and m > 0
        else pl.lit(0.0).alias(f)
        for f, m in zip(valid_features, maxima)
    ]
    normalized = unlabeled.select(['fid', *normalized_exprs])

    # Influence score = sum of the normalized features.
    influence_df = normalized.with_columns([
        pl.fold(acc=0, function=lambda acc, x: acc + x, exprs=[pl.col(f) for f in valid_features])
        .alias('influence_score')
    ])

    suggestions = (influence_df
        .sort('influence_score', descending=True)
        .head(n_suggestions)
        .select(['fid', 'influence_score']))

    # Join back the original columns (e.g. fname); the suffix keeps any existing
    # influence_score column in the matrix from clashing with the new one.
    result = suggestions.join(unlabeled, on='fid', how='left', suffix="_unlabeled")

    if 'fname' not in result.columns:
        print("Warning: 'fname' column not found in the dataset.")
    wanted = ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    return result.select([col for col in wanted if col in result.columns])


# Analyze current distribution
full_segments, labeled_segments = analyze_full_distribution(matrix, labels_df)

# Get samples for each segment
samples = get_unlabeled_samples(matrix, labels_df, samples_per_segment=50)

# Export samples for labeling
export_samples_for_labeling(samples)

# Get priority suggestions
priority_accounts = suggest_priority_accounts(matrix, labels_df, n_suggestions=50)

print("\nTop accounts to consider for labeling:")
# suggest_priority_accounts can return a frame without 'fname' (or, in its
# fallback path, without 'influence_score'); show only the columns present.
display_columns = [
    col for col in ['fid', 'fname', 'influence_score', 'follower_count', 'cast_count', 'total_reactions']
    if col in priority_accounts.columns
]
print(priority_accounts.select(display_columns).head(10))

In [None]:
# First get indices where bot probability is high
bot_mask = y_pred_proba_filtered > 0.9
bot_fids = fids_array[bot_mask]

# Ensure FIDs are the same type (Int64)
high_conf_bots = unlabeled_data_filtered.filter(pl.col('fid').cast(pl.Int64).is_in(bot_fids))
bot_profiles = valid_profiles.with_columns(pl.col('fid').cast(pl.Int64)).filter(pl.col('fid').is_in(bot_fids))

# Draw up to 10 distinct profiles uniformly at random. Taking a slice from one
# random offset would return consecutive rows (fewer than 10 near the end),
# and indexing the first random offset fails when there are no bots at all.
sample_size = min(10, len(bot_profiles))
random_bot_sample = bot_profiles.sample(n=sample_size)
# Look fid up by name rather than assuming a fixed column position.
fid_position = random_bot_sample.columns.index('fid')

print(f"\nSample of High Confidence Bot Predictions (from {len(bot_fids)} total high-confidence bots):")
print("-" * 50)

for row in random_bot_sample.iter_rows():
    fname = row[0]  # fname is first column
    display_name = row[1]  # display_name is second column
    fid = row[fid_position]

    # Get features for this bot
    bot_features = high_conf_bots.filter(pl.col('fid') == fid)

    # Find index in original arrays for probability and SHAP values
    fid_idx = np.where(fids_array == fid)[0][0]
    bot_prob = y_pred_proba_filtered[fid_idx]

    # Get SHAP values explaining why it's classified as a bot
    shap_values = detector.get_feature_explanations('xgb', X_unlabeled_filtered, fid_idx)

    print(f"\nUsername: {fname}")
    print(f"Display Name: {display_name}")
    print(f"FID: {fid}")
    print(f"Bot Probability: {bot_prob:.2%}")

    print("\nTop reasons for bot classification:")
    # Sort by absolute value but only show positive values (contributing to bot classification)
    sorted_reasons = sorted(shap_values.items(), key=lambda x: abs(x[1]), reverse=True)
    for feature, impact in sorted_reasons:
        if impact > 0:
            print(f"  - {feature}: {impact:.3f}")

    print("\nKey Behavioral Metrics:")
    metrics_to_check = [
        'rapid_actions',
        'std_hours_between_actions',
        'avg_hours_between_actions',
        'following_count',
        'follower_count',
        'total_activity',
        'hour_diversity',
        'weekday_diversity'
    ]

    for metric in metrics_to_check:
        # .item() raises unless there is exactly one row; skip missing features
        # and take the first row if the fid appears more than once.
        if metric not in bot_features.columns or bot_features.height == 0:
            continue
        val = bot_features[metric][0]
        if isinstance(val, (int, float)):
            print(f"{metric}: {val:.2f}")
        else:
            print(f"{metric}: {val}")

    print(f"\nProfile link: https://warpcast.com/{fname}")
    print("-" * 50)

print("\nTo verify these accounts, check for:")
print("1. Highly regular posting patterns (low std_hours_between_actions)")
print("2. Unusually high activity rates (high rapid_actions)")
print("3. Unnatural timing patterns (low hour_diversity and weekday_diversity)")
print("4. Suspicious follower/following ratios")

In [None]:
def get_segment_features(segment: str) -> List[str]:
    """Return candidate feature names for an activity-level segment.

    All segments share the profile/verification features. ``'active'`` and
    ``'low_activity'`` add their own extras; any other value is treated as
    the dormant segment.
    """
    extras_by_segment = {
        'active': [
            'cast_timing_entropy',
            'reply_ratio',
            'mention_patterns',
            'engagement_rate',
        ],
        'low_activity': [
            'network_growth_rate',
            'initial_behavior_pattern',
            'verification_sequence',
        ],
    }
    dormant_extras = [
        'follower_growth_velocity',
        'network_structure',
        'profile_completion_sequence',
    ]
    shared = ['has_ens', 'has_bio', 'has_avatar', 'verification_count']
    return shared + extras_by_segment.get(segment, dormant_extras)

In [None]:
import polars as pl
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any
from sklearn.ensemble import RandomForestClassifier
from lightgbm import LGBMClassifier
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, precision_recall_fscore_support
from sklearn.model_selection import cross_validate

def segment_users_by_behavior(matrix: pl.DataFrame) -> Dict[str, pl.DataFrame]:
    """Split users into behavioral segments by casting activity and report them.

    Segments:
      * ``power_users``    -- ``cast_count >= 20`` and ``reply_count >= 5``
      * ``casual_users``   -- ``5 <= cast_count < 20``
      * ``one_time_users`` -- ``0 < cast_count < 5``
      * ``lurkers``        -- ``cast_count == 0``

    NOTE(review): rows with ``cast_count >= 20`` but ``reply_count < 5`` (and
    rows with a null ``cast_count``) match no segment and are dropped from
    the result -- confirm that this is intended.

    For each segment the size and share of ``matrix`` is printed; when a
    ``bot`` column is present the bot rate and the mean ``cast_count`` /
    ``follower_count`` / ``following_count`` are printed as well.

    Args:
        matrix: Per-user feature frame with ``cast_count`` and
            ``reply_count`` columns.

    Returns:
        Dict mapping segment name to the matching rows of ``matrix``.
    """
    segments = {
        'power_users': matrix.filter(
            (pl.col('cast_count') >= 20) &
            (pl.col('reply_count') >= 5)
        ),
        'casual_users': matrix.filter(
            (pl.col('cast_count') >= 5) &
            (pl.col('cast_count') < 20)
        ),
        'one_time_users': matrix.filter(
            (pl.col('cast_count') > 0) &
            (pl.col('cast_count') < 5)
        ),
        'lurkers': matrix.filter(pl.col('cast_count') == 0)
    }

    total = len(matrix)
    print("\nUser Segment Distribution:")
    for name, segment in segments.items():
        size = len(segment)
        # An empty input frame would otherwise raise ZeroDivisionError here.
        share = size / total * 100 if total else 0.0
        if size > 0 and 'bot' in segment.columns:
            bot_pct = (segment.filter(pl.col('bot') == 1).shape[0] / size) * 100
            print(f"{name}: {size:,} users ({share:.1f}%) - {bot_pct:.1f}% bots")

            metrics = segment.select([
                pl.col('cast_count').mean(),
                pl.col('follower_count').mean(),
                pl.col('following_count').mean()
            ]).to_numpy()[0]

            print(f"  Avg casts: {metrics[0]:.1f}")
            print(f"  Avg followers: {metrics[1]:.1f}")
            print(f"  Avg following: {metrics[2]:.1f}")
        else:
            print(f"{name}: {size:,} users ({share:.1f}%)")

    return segments

def get_segment_specific_features(segment_name: str) -> List[str]:
    """Return the model feature columns to use for one behavior segment.

    Every segment gets the shared profile/network features; the four known
    segments (``power_users``, ``casual_users``, ``one_time_users``,
    ``lurkers``) append their own activity features. An unknown segment name
    yields only the shared features.
    """
    shared_profile_features = [
        'has_ens', 'has_bio', 'has_avatar', 'verification_count',
        'following_count', 'follower_count', 'follower_ratio',
        'unique_follower_ratio', 'authenticity_score',
    ]

    if segment_name == 'power_users':
        extra = [
            'cast_count', 'total_reactions', 'avg_cast_length',
            'reply_count', 'mentions_count', 'engagement_score',
            'weekday_diversity', 'hour_diversity', 'rapid_actions',
            'avg_hours_between_actions', 'std_hours_between_actions',
            'power_user_interaction_ratio', 'influence_score',
        ]
    elif segment_name == 'casual_users':
        extra = [
            'cast_count', 'total_reactions', 'engagement_score',
            'reply_count', 'rapid_actions', 'avg_hours_between_actions',
            'avg_cast_length', 'mentions_count',
        ]
    elif segment_name == 'one_time_users':
        extra = [
            'cast_count', 'total_reactions', 'profile_update_consistency',
            'follower_growth_rate',
        ]
    elif segment_name == 'lurkers':
        extra = [
            'profile_update_consistency', 'network_balance',
            'follower_growth_rate', 'profile_completeness',
        ]
    else:
        extra = []

    return shared_profile_features + extra

class RapidModelEvaluator:
    """Cross-validate classifiers and collect precision / recall / F1 scores.

    All three metrics use the same averaging, so they describe the same
    quantity. The default ``average='binary'`` scores only the positive class
    (label 1), which is also what sklearn's ``'f1'`` scorer measures. Pass,
    for example, ``average='macro'`` to average over both classes instead.
    """

    def __init__(self, n_cv_splits: int = 5, average: str = 'binary'):
        # Passed as ``cv=``; for classifiers sklearn turns an int into StratifiedKFold.
        self.n_cv_splits = n_cv_splits
        # 'binary' or an sklearn scorer averaging suffix ('macro', 'micro', 'weighted').
        self.average = average
        # model name -> {'test_scores': {...}, 'fit_time': mean fit seconds}
        self.results = {}

    def _scoring(self) -> Dict[str, str]:
        """Map output metric names to sklearn scorer strings for ``self.average``."""
        suffix = '' if self.average == 'binary' else f'_{self.average}'
        return {metric: f'{metric}{suffix}' for metric in ('precision', 'recall', 'f1')}

    def evaluate_model(self, name: str, model: Any, X: np.ndarray, y: np.ndarray) -> Dict:
        """Cross-validate ``model`` on ``(X, y)`` and store its fold-mean scores.

        Args:
            name: Key under which the result is stored in ``self.results``.
            model: sklearn-compatible estimator. ``cross_validate`` fits
                clones, so ``model`` itself is left unfitted.
            X: Feature matrix.
            y: Target labels.

        Returns:
            ``{'test_scores': {'precision', 'recall', 'f1'}, 'fit_time'}``,
            each averaged over the CV folds.
        """
        # Train scores were never read, so they are not requested.
        cv_results = cross_validate(
            model, X, y,
            cv=self.n_cv_splits,
            scoring=self._scoring(),
            n_jobs=-1
        )

        self.results[name] = {
            'test_scores': {
                'precision': cv_results['test_precision'].mean(),
                'recall': cv_results['test_recall'].mean(),
                'f1': cv_results['test_f1'].mean()
            },
            'fit_time': cv_results['fit_time'].mean()
        }

        return self.results[name]

    def compare_models(self, segment_name: str, models: Dict[str, Any],
                      X: np.ndarray, y: np.ndarray) -> pd.DataFrame:
        """Evaluate every model in ``models`` and return a comparison table.

        Returns:
            DataFrame indexed by model name with columns ``precision``,
            ``recall``, ``f1`` and ``fit_time``, sorted by ``f1`` descending.
        """
        self.results = {}  # Reset results for each segment
        for name, model in models.items():
            print(f"Evaluating {name} on {segment_name} segment...")
            self.evaluate_model(name, model, X, y)

        rows = {
            name: {**res['test_scores'], 'fit_time': res['fit_time']}
            for name, res in self.results.items()
        }
        results_df = pd.DataFrame.from_dict(rows, orient='index')
        # from_dict yields no columns for an empty mapping; keep the schema so
        # sort_values (and callers reading these columns) still work.
        results_df = results_df.reindex(columns=['precision', 'recall', 'f1', 'fit_time'])

        return results_df.sort_values('f1', ascending=False)

def prepare_segment_features(segment: pl.DataFrame, segment_name: str) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Build the standardized feature matrix and labels for one segment.

    Only the columns from ``get_segment_specific_features(segment_name)``
    that exist in ``segment`` are used; nulls are filled with 0 before
    standardizing.

    NOTE(review): the scaler is fit on the whole segment before cross-
    validation, so fold statistics leak slightly into the test folds.

    Returns:
        ``(X, y, valid_features)`` where ``y`` is the ``bot`` column as a
        numpy array, or ``None`` when ``segment`` has no ``bot`` column.
    """
    feature_cols = get_segment_specific_features(segment_name)
    valid_features = [col for col in feature_cols if col in segment.columns]
    print(f"\nUsing {len(valid_features)} features for {segment_name}:", valid_features)

    X = segment.select(valid_features).fill_null(0).to_numpy()
    y = segment['bot'].to_numpy() if 'bot' in segment.columns else None

    # StandardScaler raises on arrays with zero rows or zero columns. Empty
    # segments are skipped by the caller based on y, so return X unscaled.
    if X.ndim == 2 and X.shape[0] > 0 and X.shape[1] > 0:
        X = StandardScaler().fit_transform(X)

    return X, y, valid_features

def evaluate_segments(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Dict:
    """Cross-validate a fixed set of classifiers on each behavioral segment.

    ``matrix`` and ``labels_df`` are inner-joined on ``fid`` (both cast to
    Int64) and split with ``segment_users_by_behavior``. Every segment whose
    labels contain at least two distinct values is scored with
    ``RapidModelEvaluator.compare_models``. The best model is the first row
    of the F1-sorted table. It is then refit on the whole segment, and its
    top-10 ``feature_importances_`` are printed (or ``|coef_|`` for linear
    models).

    Returns:
        Mapping of segment name -> results DataFrame from ``compare_models``;
        skipped segments are absent.
    """
    # Ensure consistent FID type
    matrix = matrix.with_columns(pl.col('fid').cast(pl.Int64))
    labels_df = labels_df.with_columns(pl.col('fid').cast(pl.Int64))

    data = matrix.join(labels_df, on='fid', how='inner')
    segmented_users = segment_users_by_behavior(data)

    evaluator = RapidModelEvaluator()
    segment_results = {}

    for segment_name, segment_data in segmented_users.items():
        print(f"\nProcessing {segment_name} segment...")

        X, y, valid_features = prepare_segment_features(segment_data, segment_name)

        if y is None or len(np.unique(y)) < 2:
            print(f"Skipping {segment_name} - insufficient labels")
            continue

        # Define models with balanced class weights
        models = {
            'xgb': xgb.XGBClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                scale_pos_weight=np.sum(y == 0) / np.sum(y == 1),
                random_state=42
            ),
            'lgbm': LGBMClassifier(
                n_estimators=100,
                learning_rate=0.05,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'rf': RandomForestClassifier(
                n_estimators=100,
                max_depth=4,
                class_weight='balanced',
                random_state=42
            ),
            'logistic': LogisticRegression(
                max_iter=1000,
                class_weight='balanced',
                random_state=42
            )
        }

        results = evaluator.compare_models(segment_name, models, X, y)
        segment_results[segment_name] = results

        print(f"\nResults for {segment_name}:")
        print(results)

        # Print feature importance for best model
        try:
            best_model_name = results.index[0]
            best_model = models[best_model_name]
            # cross_validate only fits clones, so this estimator is still
            # unfitted; without refitting, hasattr() below is always False.
            best_model.fit(X, y)

            if hasattr(best_model, 'feature_importances_'):
                scores = best_model.feature_importances_
            elif hasattr(best_model, 'coef_'):
                # Inputs are standardized, so coefficient magnitudes are comparable.
                scores = np.abs(best_model.coef_).ravel()
            else:
                print(f"\nNo feature importances available for {best_model_name}")
                continue

            importances = pd.DataFrame({
                'feature': valid_features,
                'importance': scores
            }).sort_values('importance', ascending=False)

            print(f"\nTop 10 features for {segment_name}:")
            print(importances.head(10))

        except Exception as e:
            # Diagnostics are best-effort: one failure must not stop the other segments.
            print(f"Error calculating feature importance: {str(e)}")

    return segment_results


results = evaluate_segments(matrix, labels_df)

print("\nBest models per segment:")
for segment_name, result_df in results.items():
    # Tables come back sorted by F1 (descending), so row 0 is the winner.
    top_row = result_df.iloc[0]
    best_model = result_df.index[0]
    best_f1, best_precision, best_recall = top_row['f1'], top_row['precision'], top_row['recall']
    print("\n".join([
        f"{segment_name}:",
        f"  Best model: {best_model}",
        f"  F1: {best_f1:.3f}",
        f"  Precision: {best_precision:.3f}",
        f"  Recall: {best_recall:.3f}",
    ]))

In [None]:
import polars as pl
import numpy as np
from typing import Dict, List, Tuple
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler

def analyze_full_distribution(matrix: pl.DataFrame, labels_df: pl.DataFrame) -> Tuple[Dict, Dict]:
    """Compare segment sizes in the full matrix with the labeled subset.

    Prints the segment distribution of ``matrix``, then that of its inner
    join with ``labels_df``, then the labeled share of each segment.

    Returns:
        ``(full_segments, labeled_segments)`` -- the two dicts produced by
        ``segment_users_by_behavior``.
    """
    print("=== Full Dataset Distribution ===")
    full_segments = segment_users_by_behavior(matrix)

    print("\n=== Labeled Dataset Distribution ===")
    # Cast fid on both sides (as evaluate_segments does) so the join does not
    # fail on mismatched fid dtypes between the two frames.
    labeled_data = matrix.with_columns(pl.col('fid').cast(pl.Int64)).join(
        labels_df.with_columns(pl.col('fid').cast(pl.Int64)),
        on='fid',
        how='inner'
    )
    labeled_segments = segment_users_by_behavior(labeled_data)

    print("\n=== Label Coverage by Segment ===")
    for segment_name, full_segment in full_segments.items():
        full_count = len(full_segment)
        labeled_count = len(labeled_segments.get(segment_name, pl.DataFrame()))
        coverage = (labeled_count / full_count * 100) if full_count > 0 else 0
        print(f"{segment_name}: {coverage:.1f}% labeled ({labeled_count}/{full_count})")

    return full_segments, labeled_segments
def _sample_with_isolation_forest(segment: pl.DataFrame,
                                  features: List[str],
                                  n_samples: int,
                                  rng: np.random.Generator,
                                  random_state: int) -> pl.DataFrame:
    """Draw ``n_samples`` rows from ``segment``, oversampling IsolationForest outliers.

    Up to a quarter of the draw comes from rows the forest flags as anomalies
    (``fit_predict == -1``); the rest come from inliers. If there are too few
    inliers, the shortfall is filled from the remaining anomalies. Expects
    ``len(segment) > n_samples``, which the caller guarantees.
    """
    X = segment.select(features).fill_null(0).to_numpy()
    X = StandardScaler().fit_transform(X)

    iso_forest = IsolationForest(
        n_estimators=100,
        contamination=0.1,  # Assume 10% anomalies
        random_state=random_state,
    )
    scores = iso_forest.fit_predict(X)
    anomaly_indices = np.flatnonzero(scores == -1)
    normal_indices = np.flatnonzero(scores == 1)

    n_anomalies = min(n_samples // 4, len(anomaly_indices))
    # Sampling without replacement raises ValueError when asked for more rows
    # than exist, so cap the inlier draw at what is available ...
    n_normal = min(n_samples - n_anomalies, len(normal_indices))
    # ... and top up from the remaining anomalies so the draw still reaches
    # n_samples whenever the segment is large enough.
    n_anomalies = min(n_samples - n_normal, len(anomaly_indices))

    selected_indices = np.concatenate([
        rng.choice(anomaly_indices, n_anomalies, replace=False),
        rng.choice(normal_indices, n_normal, replace=False),
    ])

    # A boolean row mask keeps the segment's original row order.
    mask = np.zeros(len(segment), dtype=bool)
    mask[selected_indices] = True

    print(f"Selected {n_anomalies} potential anomalies and {n_normal} normal cases")
    return segment.filter(pl.Series(mask))


def _print_sample_stats(sample: pl.DataFrame) -> None:
    """Print mean cast/follower/following counts of ``sample`` when those columns exist."""
    required = ('cast_count', 'follower_count', 'following_count')
    # All three columns are selected below, so require all of them (and rows).
    if len(sample) == 0 or not all(col in sample.columns for col in required):
        return
    stats = sample.select([
        pl.col('cast_count').mean().alias('avg_casts'),
        pl.col('follower_count').mean().alias('avg_followers'),
        pl.col('following_count').mean().alias('avg_following')
    ])
    print("Sample statistics:")
    print(f"  Avg casts: {stats['avg_casts'][0]:.1f}")
    print(f"  Avg followers: {stats['avg_followers'][0]:.1f}")
    print(f"  Avg following: {stats['avg_following'][0]:.1f}")


def get_unlabeled_samples(matrix: pl.DataFrame, 
                         labels_df: pl.DataFrame, 
                         samples_per_segment: int = 50,
                         use_isolation_forest: bool = True,
                         random_state: int = 42) -> Dict[str, pl.DataFrame]:
    """Draw a per-segment sample of unlabeled users for manual labeling.

    Users whose ``fid`` appears in ``labels_df`` are excluded. The rest are
    split with ``segment_users_by_behavior``, and each segment is sampled:

    * If ``use_isolation_forest`` is set, the segment has more than
      ``samples_per_segment`` rows, and at least one anomaly feature column
      exists, an IsolationForest scores the rows. Up to a quarter of the
      sample is then drawn from the flagged anomalies.
    * Otherwise up to ``samples_per_segment`` rows are drawn uniformly.

    Args:
        matrix: Feature matrix with an ``fid`` column.
        labels_df: Already-labeled users; only its ``fid`` column is read.
        samples_per_segment: Maximum number of rows to return per segment.
        use_isolation_forest: Whether to oversample IsolationForest outliers.
        random_state: Seed for the IsolationForest and every random draw, so
            repeated runs select the same rows.

    Returns:
        Mapping from segment name to its sampled rows. Empty segments are
        omitted.
    """
    labeled_fids = labels_df['fid'].unique()
    unlabeled = matrix.filter(~pl.col('fid').is_in(labeled_fids))

    segments = segment_users_by_behavior(unlabeled)

    # Features for anomaly detection
    anomaly_features = [
        'cast_count', 'follower_count', 'following_count',
        'authenticity_score', 'total_reactions', 'rapid_actions',
        'avg_hours_between_actions', 'std_hours_between_actions'
    ]

    # A single seeded generator shared across segments makes selection reproducible.
    rng = np.random.default_rng(random_state)

    samples = {}
    for name, segment in segments.items():
        print(f"\nProcessing {name} segment ({len(segment)} users)")

        if len(segment) == 0:
            continue

        if use_isolation_forest and len(segment) > samples_per_segment:
            valid_features = [f for f in anomaly_features if f in segment.columns]
            if valid_features:
                samples[name] = _sample_with_isolation_forest(
                    segment, valid_features, samples_per_segment, rng, random_state
                )
            else:
                # Fallback to random sampling if features not available
                samples[name] = segment.sample(n=samples_per_segment, seed=random_state)
        else:
            # Small segment (taken whole) or isolation forest disabled (uniform draw)
            n_samples = min(samples_per_segment, len(segment))
            samples[name] = segment.sample(n=n_samples, seed=random_state)

        print(f"Final sample size: {len(samples[name])}")
        _print_sample_stats(samples[name])

    return samples

def export_samples_for_labeling(samples: Dict[str, pl.DataFrame], 
                              output_path: str = "samples_for_labeling.csv"):
    """Write all segment samples to a single CSV for manual labeling.

    Each frame is tagged with a ``segment`` column holding its dict key. The
    frames are stacked vertically, so they must share a schema. Only the
    labeling-relevant columns that actually exist are written.

    Args:
        samples: Segment name -> sampled rows (e.g. from ``get_unlabeled_samples``).
        output_path: Destination CSV path.
    """
    if not samples:
        # pl.concat raises on an empty list, and there is nothing to write anyway.
        print("\nNo samples to export")
        return

    # Combine all samples, remembering which segment each row came from
    all_samples = pl.concat([
        segment.with_columns(pl.lit(name).alias('segment'))
        for name, segment in samples.items()
    ])

    # Select relevant columns for labeling
    export_columns = [
        'fid', 'segment', 'fname',
        'cast_count', 'follower_count', 'following_count',
        'total_reactions', 'authenticity_score'
    ]

    # Only include columns that exist
    valid_columns = [col for col in export_columns if col in all_samples.columns]

    all_samples.select(valid_columns).write_csv(output_path)
    print(f"\nExported {len(all_samples)} samples to {output_path}")
    
def suggest_priority_accounts(matrix: pl.DataFrame, 
                              labels_df: pl.DataFrame,
                              n_suggestions: int = 50) -> pl.DataFrame:
    """Rank unlabeled accounts by a simple influence score to prioritize labeling.

    Each available influence feature (nulls treated as 0) is divided by its
    maximum over the unlabeled rows. ``influence_score`` is the mean of those
    normalized values. The top ``n_suggestions`` accounts are joined back to
    their original features.

    Args:
        matrix: Feature matrix with an ``fid`` column.
        labels_df: Already-labeled users; only its ``fid`` column is read.
        n_suggestions: Number of accounts to return.

    Returns:
        Up to ``n_suggestions`` rows sorted by ``influence_score`` (descending).
        The result contains whichever of ``fid``, ``fname``, ``influence_score``,
        ``follower_count``, ``cast_count`` and ``total_reactions`` exist. If
        none of the influence features exist, the first ``n_suggestions``
        unlabeled rows are returned unchanged, without ``influence_score``.
    """
    # Get unlabeled accounts
    labeled_fids = labels_df['fid'].unique()
    unlabeled = matrix.filter(~pl.col('fid').is_in(labeled_fids))

    influence_features = [
        'follower_count', 'following_count', 'cast_count',
        'total_reactions', 'authenticity_score'
    ]
    valid_features = [f for f in influence_features if f in unlabeled.columns]

    if not valid_features:
        print("No valid features found for influence calculation")
        return unlabeled.head(n_suggestions)

    # Per-feature maxima, computed once; None when there are no unlabeled rows.
    maxima = unlabeled.select(
        [pl.col(f).fill_null(0).max() for f in valid_features]
    ).row(0)

    normalized_exprs = []
    for feature, col_max in zip(valid_features, maxima):
        if col_max:
            expr = pl.col(feature).fill_null(0) / col_max
        else:
            # An all-zero (or empty) column would give 0/0 = NaN on every row,
            # making every influence_score NaN; contribute 0 instead.
            expr = pl.lit(0.0)
        normalized_exprs.append(expr.alias(feature))
    normalized = unlabeled.select(['fid', *normalized_exprs])

    # Influence score = average of the normalized features
    influence_df = normalized.with_columns([
        (pl.fold(acc=pl.lit(0.0), function=lambda acc, x: acc + x,
                 exprs=[pl.col(f) for f in valid_features])
         / len(valid_features))
        .alias('influence_score')
    ])

    suggestions = (influence_df
                   .sort('influence_score', descending=True)
                   .head(n_suggestions)
                   .select(['fid', 'influence_score']))

    # Join back with original features (including 'fname'), then re-sort:
    # do not rely on the join preserving the ranking order.
    result = (suggestions
              .join(unlabeled, on='fid', how='left', suffix="_unlabeled")
              .sort('influence_score', descending=True))

    if 'fname' not in result.columns:
        print("Warning: 'fname' column not found in the dataset.")

    # Only the influence features that exist were used, so select defensively.
    output_columns = [
        col for col in ['fid', 'fname', 'influence_score',
                        'follower_count', 'cast_count', 'total_reactions']
        if col in result.columns
    ]
    return result.select(output_columns)


# Analyze current distribution
full_segments, labeled_segments = analyze_full_distribution(matrix, labels_df)

# Get samples for each segment
samples = get_unlabeled_samples(matrix, labels_df, samples_per_segment=50)

# Export samples for labeling
export_samples_for_labeling(samples)

# Get priority suggestions
priority_accounts = suggest_priority_accounts(matrix, labels_df, n_suggestions=50)

print("\nTop accounts to consider for labeling:")
# suggest_priority_accounts drops 'fname' when the matrix lacks it, and its
# no-features fallback has no 'influence_score'; show only the columns present.
display_columns = [
    col for col in ['fid', 'fname', 'influence_score',
                    'follower_count', 'cast_count', 'total_reactions']
    if col in priority_accounts.columns
]
print(priority_accounts.select(display_columns).head(10))