In [None]:
# new version
import os
from typing import List, Optional

import polars as pl

# Set the row-chunk size used by Polars' streaming engine (the engine used by
# the collect(streaming=True) call in StreamingDataLoader.save_checkpoint) to
# help bound memory usage.
pl.Config.set_streaming_chunk_size(1_000_000)

class FeatureSet:
    """Track a feature set's calculation version, dependencies and checkpoint.

    Args:
        name: Identifier of the feature set.
        version: Version string of the feature-calculation logic.
        dependencies: Names of the feature sets this one is built from.
            ``None`` (the default) means no dependencies.
    """

    def __init__(self, name: str, version: str,
                 dependencies: Optional[List[str]] = None):
        self.name = name
        self.version = version  # Version of feature calculation logic
        self.dependencies: List[str] = dependencies or []
        # Path of this set's checkpoint file; assigned externally
        # (OptimizedFeatureEngineering._init_checkpoints sets it).
        self.checkpoint_path: Optional[str] = None
        # os.path.getmtime() of the checkpoint file when it exists, else None.
        self.last_modified: Optional[float] = None

class StreamingDataLoader:
    """Memory-efficient dataset loader built on Polars lazy parquet scans.

    Every dataset can be restricted to a shared set of "base" FIDs, so that all
    downstream feature tables describe the same users.

    Args:
        data_path: Directory containing ``farcaster-<name>-0-1733162400.parquet`` files.
        checkpoint_dir: Directory holding ``<name>_features.parquet`` checkpoints.
        debug_mode: If True (the default), base FIDs come from a sample of
            ``sample_size`` profile rows; see ``get_dataset``.
        sample_size: Number of rows sampled in debug mode.
    """

    def __init__(self, data_path: str, checkpoint_dir: str, debug_mode: bool = True, sample_size: int = 100):
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self.base_fids: Optional[pl.Series] = None  # FIDs datasets are filtered to
        self.debug_mode = debug_mode
        self.sample_size = sample_size

    def set_base_fids(self, fids: pl.Series):
        """Set (deduplicated) base FIDs used to filter downstream datasets."""
        self.base_fids = fids.unique()
        print(f"Set base FIDs: {len(self.base_fids)} records")

    def get_checkpoint_fids(self) -> bool:
        """Load base FIDs from the profile checkpoint if it exists.

        Returns:
            True if FIDs were loaded; False if the checkpoint file is missing
            or has no ``fid`` column.
        """
        profile_checkpoint = f"{self.checkpoint_dir}/profile_features.parquet"
        if os.path.exists(profile_checkpoint):
            df = pl.read_parquet(profile_checkpoint)
            if 'fid' in df.columns:
                self.base_fids = df['fid']
                print(f"Loaded base FIDs from checkpoint: {len(self.base_fids)} records")
                return True
        return False

    def get_dataset(self, name: str, columns: Optional[List[str]] = None,
                    batch_size: int = 100_000,
                    max_rows: Optional[int] = None) -> pl.LazyFrame:
        """Return dataset ``name`` as a LazyFrame, optionally projected and FID-filtered.

        Args:
            name: Dataset name used to build the parquet file path.
            columns: Columns to select; all columns when None/empty. Must
                include ``fid`` whenever FID filtering applies.
            batch_size: Kept for backward compatibility. It no longer limits
                the number of rows read; chunking is done by the streaming
                engine at collect time (see
                ``pl.Config.set_streaming_chunk_size``).
            max_rows: Optional hard cap on rows read from the file (passed to
                ``scan_parquet(n_rows=...)``); None reads the whole file.

        Debug mode: if no base FIDs are known yet, they are loaded from the
        profile checkpoint. Failing that, ``profile_with_addresses`` establishes
        them from its first ``sample_size`` rows, and any other dataset is
        simply truncated to ``sample_size`` rows. Once base FIDs are known, the
        dataset is filtered to them (outside debug mode too).

        Raises:
            Exception: any error while building the query is printed and re-raised.
        """
        path = f"{self.data_path}/farcaster-{name}-0-1733162400.parquet"

        try:
            # Do not pass n_rows=batch_size here: n_rows stops reading after
            # that many rows, which silently dropped everything past the first
            # "batch". Streaming collection already processes data in chunks.
            scan_query = pl.scan_parquet(path, n_rows=max_rows)

            # Apply early column selection if specified
            if columns:
                scan_query = scan_query.select(columns)

            if self.debug_mode:
                if self.base_fids is None and not self.get_checkpoint_fids():
                    if name == 'profile_with_addresses':
                        collected = scan_query.limit(self.sample_size).collect()
                        self.base_fids = collected['fid']
                        print(f"Established new base FIDs from {name}: {len(self.base_fids)} records")
                        return collected.lazy()
                    print(f"Warning: No base FIDs available for {name}")
                    return scan_query.limit(self.sample_size)
                # Base FIDs exist here, either set earlier or just recovered
                # from the checkpoint. The checkpoint case previously fell
                # through without filtering.
                print(f"Filtering {name} by {len(self.base_fids)} base FIDs")
                scan_query = scan_query.filter(pl.col('fid').is_in(self.base_fids))

            # Apply early FID filtering if available
            elif self.base_fids is not None:
                scan_query = scan_query.filter(pl.col('fid').is_in(self.base_fids))

            return scan_query

        except Exception as e:
            print(f"Error loading {name}: {str(e)}")
            raise

    def load_checkpoint(self, name: str) -> Optional[pl.LazyFrame]:
        """Return the checkpoint for ``name`` as a LazyFrame, or None if it does not exist."""
        path = f"{self.checkpoint_dir}/{name}_features.parquet"
        if os.path.exists(path):
            return pl.scan_parquet(path)
        return None

    def save_checkpoint(self, lf: pl.LazyFrame, name: str):
        """Collect ``lf`` with streaming and write it to ``<checkpoint_dir>/<name>_features.parquet``."""
        # The loader can be used on its own, so make sure the directory exists.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        path = f"{self.checkpoint_dir}/{name}_features.parquet"
        lf.collect(streaming=True).write_parquet(path)


class OptimizedFeatureEngineering:
    """Example class with memory-efficient feature engineering using streaming."""
    
    def __init__(self, data_path: str, checkpoint_dir: str):
        """Create the loader, register all feature sets and set up checkpoints.

        Args:
            data_path: Directory holding the raw ``farcaster-*.parquet`` files
                (passed to ``StreamingDataLoader``).
            checkpoint_dir: Directory for ``<name>_features.parquet``
                checkpoints; ``_init_checkpoints`` creates it if missing.

        Note: the loader is built with its default arguments, so its
        ``debug_mode`` is True.
        """
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self.loader = StreamingDataLoader(data_path, checkpoint_dir)

        # name -> FeatureSet; the dependency lists drive _needs_rebuild.
        self.feature_sets = {
            # Base features
            'profile': FeatureSet('profile', '1.0'),
            'network': FeatureSet('network', '1.0'),
            'temporal': FeatureSet('temporal', '1.0', ['network']),

            # Activity features
            'cast': FeatureSet('cast', '1.0'),
            'reaction': FeatureSet('reaction', '1.0'),
            'channel': FeatureSet('channel', '1.0'),
            'verification': FeatureSet('verification', '1.0'),

            # Account features
            'user_data': FeatureSet('user_data', '1.0'),
            'storage': FeatureSet('storage', '1.0'),
            'signers': FeatureSet('signers', '1.0'),

            # Interaction patterns
            'engagement': FeatureSet('engagement', '1.0',
                ['cast', 'reaction', 'channel']),
            'mentions': FeatureSet('mentions', '1.0',
                ['cast', 'network']),
            'reply_patterns': FeatureSet('reply_patterns', '1.0',
                ['cast', 'temporal']),

            # Network quality
            'network_quality': FeatureSet('network_quality', '1.0',
                ['network', 'engagement']),
            'power_user_interaction': FeatureSet('power_user_interaction', '1.0',
                ['network', 'temporal']),
            'cluster_analysis': FeatureSet('cluster_analysis', '1.0',
                ['network', 'engagement']),

            # Behavioral patterns
            'activity_patterns': FeatureSet('activity_patterns', '1.0',
                ['temporal', 'cast', 'reaction']),
            'update_behavior': FeatureSet('update_behavior', '1.0',
                ['user_data', 'profile']),
            'verification_patterns': FeatureSet('verification_patterns', '1.0',
                ['verification', 'temporal']),

            # Meta features
            # 'update_behavior' is required: add_authenticity_features reads
            # total_updates / profile_update_consistency, which only
            # add_update_behavior produces.
            'authenticity': FeatureSet('authenticity', '2.0', [
                'profile', 'network', 'channel', 'verification',
                'engagement', 'network_quality', 'activity_patterns',
                'update_behavior'
            ]),
            'influence': FeatureSet('influence', '1.0', [
                'network', 'engagement', 'power_user_interaction'
            ]),

            # Final derived features
            'derived': FeatureSet('derived', '2.0', [
                'network', 'temporal', 'authenticity',
                'engagement', 'network_quality', 'influence'
            ])
        }

        # Initialize checkpoint tracking
        self._init_checkpoints()


    def _init_checkpoints(self):
        """Ensure the checkpoint directory exists and record each set's checkpoint file.

        Every registered FeatureSet gets ``checkpoint_path`` set to
        ``<checkpoint_dir>/<key>_features.parquet``. Its ``last_modified`` is
        set to that file's mtime only when the file already exists.
        """
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        for key, fs in self.feature_sets.items():
            checkpoint_file = f"{self.checkpoint_dir}/{key}_features.parquet"
            fs.checkpoint_path = checkpoint_file
            if not os.path.exists(checkpoint_file):
                continue
            fs.last_modified = os.path.getmtime(checkpoint_file)

    def _needs_rebuild(self, feature_set: FeatureSet) -> bool:
        """Return True if ``feature_set``'s checkpoint is missing or stale.

        A checkpoint is stale when any dependency's checkpoint is missing or
        was modified after it. Dependency names not present in
        ``self.feature_sets`` are ignored.

        Modification times are read from disk, not from the cached
        ``last_modified`` values. Those are captured once in
        ``_init_checkpoints`` and are None for files that did not exist then;
        comparing None with a float raised TypeError.
        """
        own_path = feature_set.checkpoint_path
        if not own_path or not os.path.exists(own_path):
            return True
        own_mtime = os.path.getmtime(own_path)

        for dep in feature_set.dependencies:
            dep_set = self.feature_sets.get(dep)
            if dep_set is None:
                continue
            dep_path = dep_set.checkpoint_path
            # A dependency with no checkpoint will be (re)built, which makes
            # this checkpoint out of date.
            if not dep_path or not os.path.exists(dep_path):
                return True
            if os.path.getmtime(dep_path) > own_mtime:
                return True
        return False

    def _validate_feature_addition(self, original_df: pl.LazyFrame,
                                   new_df: pl.LazyFrame,
                                   base_fids: pl.Series,
                                   feature_name: str) -> pl.LazyFrame:
        """Return ``new_df`` if it is usable, otherwise fall back to ``original_df``.

        ``new_df`` is rejected when it is None (an error is printed), or when
        ``self._validate_checkpoint_compatibility(new_df, base_fids)`` is falsy.
        That method is not defined in the visible part of this file.
        """
        if new_df is None:
            print(f"Error: {feature_name} returned None")
            return original_df

        compatible = self._validate_checkpoint_compatibility(new_df, base_fids)
        return new_df if compatible else original_df
        
    def extract_profile_features(self) -> pl.LazyFrame:
        """Extract per-FID profile features from ``profile_with_addresses``.

        Rows with a null or empty ``fname`` are dropped. The result contains:

        * ``fid`` cast to Int64.
        * 0/1 flags: ``has_ens`` (fname ends with ``.eth``), ``has_bio``,
          ``has_avatar`` and ``has_display_name``.
        * ``verification_count``, derived from the ``verified_addresses`` text:
          0 when it is null, ``''`` or ``'[]'``, otherwise the number of commas
          plus one. This presumes a serialised address list; TODO confirm the
          format.

        The previous expression returned 2 for any value containing a comma,
        whatever the number of addresses, and 1 for an empty string.
        """
        profiles = self.loader.get_dataset('profile_with_addresses',
            ['fid', 'fname', 'bio', 'avatar_url', 'verified_addresses', 'display_name'])

        addresses = pl.col('verified_addresses')
        no_addresses = addresses.is_null() | addresses.is_in(['', '[]'])

        return (
            profiles
            .filter(pl.col('fname').is_not_null() & (pl.col('fname') != ""))
            .with_columns([
                pl.col('fid').cast(pl.Int64),
                pl.col('fname').str.contains(r'\.eth$').cast(pl.Int32).alias('has_ens'),
                (pl.col('bio').is_not_null() & (pl.col('bio') != "")).cast(pl.Int32).alias('has_bio'),
                pl.col('avatar_url').is_not_null().cast(pl.Int32).alias('has_avatar'),
                pl.when(no_addresses)
                  .then(0)
                  .otherwise(addresses.str.count_matches(',') + 1)
                  .cast(pl.Int32)
                  .alias('verification_count'),
                pl.col('display_name').is_not_null().cast(pl.Int32).alias('has_display_name')
            ])
        )

    def add_network_quality_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID counts of replies to, and mentions of, power users.

        Adds ``power_reply_count`` (casts whose ``parent_fid`` is a power user)
        and ``power_mentions_count`` (casts whose ``mentions`` text contains a
        power-user FID as a whole number). Deleted casts are ignored.

        Fixes over the previous version:

        * The mention regex is anchored with ``\\b``, so FID 12 no longer
          matches inside 123.
        * An empty power-user list no longer becomes the pattern ``''``, which
          matched every value.
        * The dataset name now matches add_power_user_interaction_features
          (``warpcast_power_users``).

        NOTE(review): the power-user list is loaded through ``self.loader`` and
        is therefore FID-filtered like every other dataset. In debug mode that
        restricts it to sampled users; confirm this is intended.
        """
        power_users = self.loader.get_dataset('warpcast_power_users', ['fid'])
        power_fids = (
            power_users.select('fid').collect()['fid']
            .cast(pl.Int64)
            .drop_nulls()
            .unique()
        )

        casts = self.loader.get_dataset('casts',
            ['fid', 'parent_fid', 'mentions', 'deleted_at'])

        if len(power_fids) > 0:
            # Whole-number match only: \b stops FID 12 matching "[123]".
            pattern = r"\b(?:" + "|".join(str(f) for f in power_fids.to_list()) + r")\b"
            has_power_mention = pl.col('mentions').str.contains(pattern)
        else:
            has_power_mention = pl.lit(False)

        power_metrics = (
            casts.filter(pl.col('deleted_at').is_null())
            .with_columns([
                pl.col('parent_fid').cast(pl.Int64).is_in(power_fids)
                    .alias('is_power_reply'),
                has_power_mention.alias('has_power_mention')
            ])
            .group_by('fid')
            .agg([
                pl.col('is_power_reply').sum().alias('power_reply_count'),
                pl.col('has_power_mention').sum().alias('power_mentions_count')
            ])
        )

        return base_lf.join(power_metrics, on='fid', how='left')
    def add_network_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join follow-graph counts and ratios built from the ``links`` dataset.

        Adds:

        * ``following_count`` / ``unique_following_count``: links made by the FID.
        * ``follower_count`` / ``unique_follower_count``: links whose
          ``target_fid`` is the FID.
        * ``follower_ratio`` / ``unique_follower_ratio``: followers divided by
          (following + 1).

        Deleted links are ignored. FIDs with no matching links get 0 counts
        instead of null, so the ratios are defined for them too. Previously, a
        user with followers but no following got a null ``follower_ratio``.

        NOTE(review): ``get_dataset`` filters ``links`` on ``fid`` (the
        follower) once base FIDs are set. Follower counts therefore only
        include followers that are themselves base FIDs, which is a heavy
        undercount in debug mode. Confirm this is intended.
        """
        links = self.loader.get_dataset('links',
            ['fid', 'target_fid', 'timestamp', 'deleted_at'])
        active_links = links.filter(pl.col('deleted_at').is_null())

        # Following patterns
        following = (
            active_links
            .group_by('fid')
            .agg([
                pl.len().alias('following_count'),
                pl.n_unique('target_fid').alias('unique_following_count')
            ])
        )

        # Follower patterns
        followers = (
            active_links
            .group_by('target_fid')
            .agg([
                pl.len().alias('follower_count'),
                pl.n_unique('fid').alias('unique_follower_count')
            ])
            .rename({'target_fid': 'fid'})
        )

        count_cols = ['following_count', 'unique_following_count',
                      'follower_count', 'unique_follower_count']
        result = (
            base_lf
            .join(following, on='fid', how='left')
            .join(followers, on='fid', how='left')
            # A FID absent from a grouped table has zero such links.
            .with_columns(pl.col(count_cols).fill_null(0))
        )

        # Derived network metrics
        return result.with_columns([
            (pl.col('follower_count') / (pl.col('following_count') + 1)).alias('follower_ratio'),
            (pl.col('unique_follower_count') / (pl.col('unique_following_count') + 1))
                .alias('unique_follower_ratio')
        ])
    def add_temporal_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID timing statistics of non-deleted, timestamped links.

        Gaps are measured in hours (``dt.total_hours``) between consecutive
        link timestamps of the same FID. Timestamps are sorted inside each
        group first, because ``diff`` depends on row order and the source rows
        are not guaranteed to be time-ordered.

        Adds ``total_activity``, the mean/std/p10/p90 of the gaps,
        ``rapid_actions`` (gaps < 1h) and ``long_gaps`` (gaps > 24h). It also
        adds ``weekday_variance``, which despite its name is the standard
        deviation of the weekday number.
        """
        links = self.loader.get_dataset('links',
            ['fid', 'timestamp', 'deleted_at'])

        gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()

        temporal_features = (
            links
            .filter(pl.col('deleted_at').is_null())
            .filter(pl.col('timestamp').is_not_null())
            .with_columns(pl.col('timestamp').cast(pl.Datetime))
            .group_by('fid')
            .agg([
                pl.len().alias('total_activity'),
                gap_hours.mean().alias('avg_hours_between_actions'),
                gap_hours.std().alias('std_hours_between_actions'),
                pl.col('timestamp').dt.weekday().std().alias('weekday_variance'),
                (gap_hours < 1).sum().alias('rapid_actions'),
                (gap_hours > 24).sum().alias('long_gaps'),
                gap_hours.quantile(0.9).alias('p90_time_between_actions'),
                gap_hours.quantile(0.1).alias('p10_time_between_actions')
            ])
        )

        return base_lf.join(temporal_features, on='fid', how='left')

    def add_cast_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID cast statistics computed over non-deleted casts.

        Adds:

        * ``cast_count``.
        * ``avg_cast_length``: characters, with null text counted as 0.
        * ``reply_count``: casts with a ``parent_hash``.
        * ``mentions_count``: casts whose ``mentions`` is non-null and not
          ``''`` or ``'[]'``.
        """
        casts = self.loader.get_dataset('casts',
            ['fid', 'text', 'parent_hash', 'mentions', 'deleted_at'])

        mentions = pl.col('mentions')
        mentions_present = mentions.is_not_null() & (mentions != '') & (mentions != '[]')

        text_length = (
            pl.when(pl.col('text').is_not_null())
            .then(pl.col('text').str.len_chars())
            .otherwise(0)
            .alias('cast_length')
        )
        reply_flag = pl.col('parent_hash').is_not_null().cast(pl.Int32).alias('is_reply')
        mention_flag = pl.when(mentions_present).then(1).otherwise(0).alias('has_mentions')

        per_cast = (
            casts
            .filter(pl.col('deleted_at').is_null())
            .with_columns([text_length, reply_flag, mention_flag])
        )

        cast_features = per_cast.group_by('fid').agg([
            pl.len().alias('cast_count'),
            pl.col('cast_length').mean().alias('avg_cast_length'),
            pl.col('is_reply').sum().alias('reply_count'),
            pl.col('has_mentions').sum().alias('mentions_count'),
        ])

        return base_lf.join(cast_features, on='fid', how='left')

    def add_reaction_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID reaction statistics over non-deleted reactions.

        Adds:

        * ``total_reactions``.
        * ``like_count``: ``reaction_type == 1``.
        * ``recast_count``: ``reaction_type == 2``.
        * ``unique_users_reacted_to``: distinct ``target_fid``.
        * The mean/std of hours between consecutive reactions.

        Timestamps are sorted within each FID before ``diff``, so the gaps are
        chronological; the source row order is not guaranteed.
        """
        reactions = self.loader.get_dataset('reactions',
            ['fid', 'reaction_type', 'target_fid', 'timestamp', 'deleted_at'])

        gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()

        reaction_features = (
            reactions
            .filter(pl.col('deleted_at').is_null())
            .with_columns(pl.col('timestamp').cast(pl.Datetime))
            .group_by('fid')
            .agg([
                pl.len().alias('total_reactions'),
                (pl.col('reaction_type') == 1).sum().alias('like_count'),
                (pl.col('reaction_type') == 2).sum().alias('recast_count'),
                pl.n_unique('target_fid').alias('unique_users_reacted_to'),
                gap_hours.mean().alias('avg_hours_between_reactions'),
                gap_hours.std().alias('std_hours_between_reactions')
            ])
        )

        return base_lf.join(reaction_features, on='fid', how='left')

    def add_verification_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join on-chain (``verifications``) and platform (``account_verifications``) stats.

        On-chain: ``total_verifications``, the mean/std hours between
        consecutive verifications, and ``rapid_verifications`` (gaps < 1h).
        Deleted rows are ignored. Timestamps are sorted within each FID before
        ``diff``, so the gaps are chronological.

        Platform: ``platforms_verified`` (distinct ``platform``) and the first
        and last ``verified_at``.

        NOTE(review): add_verification_patterns_features produces the same
        three on-chain gap columns; joining both would create suffixed
        duplicates.
        """
        verifications = self.loader.get_dataset('verifications',
            ['fid', 'timestamp', 'deleted_at'])

        acc_verifications = self.loader.get_dataset('account_verifications',
            ['fid', 'platform', 'verified_at'])

        gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()

        # Process on-chain verifications
        verif_features = (
            verifications
            .filter(pl.col('deleted_at').is_null())
            .with_columns(pl.col('timestamp').cast(pl.Datetime))
            .group_by('fid')
            .agg([
                pl.len().alias('total_verifications'),
                gap_hours.mean().alias('avg_hours_between_verifications'),
                gap_hours.std().alias('std_hours_between_verifications'),
                (gap_hours < 1).sum().alias('rapid_verifications')
            ])
        )

        # Process platform verifications
        platform_features = (
            acc_verifications
            .with_columns(pl.col('verified_at').cast(pl.Datetime))
            .group_by('fid')
            .agg([
                pl.n_unique('platform').alias('platforms_verified'),
                pl.col('verified_at').min().alias('first_platform_verification'),
                pl.col('verified_at').max().alias('last_platform_verification')
            ])
        )

        result = base_lf.join(verif_features, on='fid', how='left')
        return result.join(platform_features, on='fid', how='left')

    def add_authenticity_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Add profile_completeness, network_balance, update_naturalness and authenticity_score.

        Expects columns produced elsewhere:

        * ``has_bio``, ``has_avatar``, ``has_ens``, ``verification_count``
          (extract_profile_features).
        * ``following_count``, ``follower_count`` (add_network_features).
        * ``total_updates``, ``profile_update_consistency``
          (add_update_behavior).

        ``authenticity_score = 0.4 * completeness + 0.3 * balance + 0.3 * naturalness``.

        A null ``profile_update_consistency`` is treated as 0.0. This happens
        e.g. with exactly two updates, where the std of a single gap is null.
        Previously the null made ``authenticity_score`` null for those users,
        while users with one update got a score.
        """
        following = pl.col('following_count')
        followers = pl.col('follower_count')
        follow_total = following + followers

        return (
            base_lf
            .with_columns([
                # Profile completeness
                ((pl.col('has_bio') +
                  pl.col('has_avatar') +
                  pl.col('has_ens') +
                  (pl.col('verification_count') > 0).cast(pl.Int64)) / 4.0
                 ).alias('profile_completeness'),

                # Network balance: 1 when following == followers, toward 0 when one-sided
                (pl.when(follow_total > 0)
                 .then(1.0 - (following - followers).abs() / follow_total)
                 .otherwise(0.0)
                 ).alias('network_balance'),

                # Update naturalness
                (pl.when(pl.col('total_updates') > 0)
                 .then(1.0 - pl.col('profile_update_consistency').fill_null(0.0))
                 .otherwise(0.0)
                 ).alias('update_naturalness')
            ])
            .with_columns([
                (pl.col('profile_completeness') * 0.4 +
                 pl.col('network_balance') * 0.3 +
                 pl.col('update_naturalness') * 0.3
                 ).alias('authenticity_score')
            ])
        )

    def add_influence_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add influence_score, engagement_rate and follower_growth_rate.

        Any of ``follower_count``, ``following_count``, ``total_reactions`` or
        ``cast_count`` missing from ``df`` is added as a constant 0 column.
        ``follow_time_span_hours`` is computed from ``first_follow`` and
        ``last_follow`` when both columns exist (0 if either value is null);
        otherwise it is set to 0.

        Nulls in the count columns are treated as 0 in every metric.
        Previously only influence_score did this, so engagement_rate and
        follower_growth_rate could still come out null. Only ``.columns`` and
        ``.with_columns`` are used on ``df``.

        Raises:
            Exception: any error is printed and re-raised.
        """
        try:
            existing = set(df.columns)  # resolve the schema once

            # Ensure required columns exist
            required_cols = ['follower_count', 'following_count', 'total_reactions', 'cast_count']
            missing = [col for col in required_cols if col not in existing]
            if missing:
                df = df.with_columns([pl.lit(0).alias(col) for col in missing])

            # Calculate time span if possible
            if 'first_follow' in existing and 'last_follow' in existing:
                df = df.with_columns([
                    pl.when(pl.col('last_follow').is_not_null() & pl.col('first_follow').is_not_null())
                    .then((pl.col('last_follow') - pl.col('first_follow')).dt.total_hours())
                    .otherwise(0)
                    .alias('follow_time_span_hours')
                ])
            else:
                df = df.with_columns(pl.lit(0).alias('follow_time_span_hours'))

            followers = pl.col('follower_count').fill_null(0)
            following = pl.col('following_count').fill_null(0)
            reactions = pl.col('total_reactions').fill_null(0)
            casts = pl.col('cast_count').fill_null(0)
            span_hours = pl.col('follow_time_span_hours')

            return df.with_columns([
                ((followers * 0.4 + reactions * 0.3 + casts * 0.3) / (following + 1)
                 ).alias('influence_score'),

                # Guard the division: engagement is 0 without casts
                (pl.when(casts > 0)
                 .then(reactions / casts)
                 .otherwise(0)
                 ).alias('engagement_rate'),

                # Guard the division: growth is 0 without a positive span
                (pl.when(span_hours > 0)
                 .then(followers / span_hours)
                 .otherwise(0)
                 ).alias('follower_growth_rate')
            ])

        except Exception as e:
            print(f"Error in influence features: {str(e)}")
            raise


    def add_user_data_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID counts and mean interval of non-deleted ``user_data`` rows.

        Adds ``total_user_data_updates`` and ``avg_update_interval`` (mean
        hours between consecutive rows). Timestamps are sorted within each FID
        before ``diff``, so the intervals are chronological.

        NOTE(review): add_update_behavior also produces a column named
        ``avg_update_interval``; joining both would create a suffixed duplicate.
        """
        user_data = self.loader.get_dataset('user_data',
            ['fid', 'type', 'timestamp', 'deleted_at'])

        update_features = (
            user_data
            .filter(pl.col('deleted_at').is_null())
            .group_by('fid')
            .agg([
                pl.len().alias('total_user_data_updates'),
                pl.col('timestamp').sort().diff().dt.total_hours().mean()
                    .alias('avg_update_interval')
            ])
        )

        return base_lf.join(update_features, on='fid', how='left')

    def add_storage_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID storage statistics over non-deleted ``storage`` rows.

        Adds ``avg_storage_units``, ``max_storage_units`` and
        ``storage_update_count``.
        """
        storage = self.loader.get_dataset('storage',
            ['fid', 'units', 'deleted_at'])

        units = pl.col('units')
        live_rows = storage.filter(pl.col('deleted_at').is_null())

        storage_features = live_rows.group_by('fid').agg([
            units.mean().alias('avg_storage_units'),
            units.max().alias('max_storage_units'),
            pl.len().alias('storage_update_count'),
        ])

        return base_lf.join(storage_features, on='fid', how='left')

    def add_signer_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID signer statistics over non-deleted ``signers`` rows.

        Adds ``signer_count`` and the mean/std hours between consecutive
        signers. Timestamps are sorted within each FID before ``diff``, so the
        gaps are chronological.
        """
        signers = self.loader.get_dataset('signers',
            ['fid', 'timestamp', 'deleted_at'])

        gap_hours = pl.col('timestamp').sort().diff().dt.total_hours()

        signer_features = (
            signers
            .filter(pl.col('deleted_at').is_null())
            .group_by('fid')
            .agg([
                pl.len().alias('signer_count'),
                gap_hours.mean().alias('avg_hours_between_signers'),
                gap_hours.std().alias('std_hours_between_signers')
            ])
        )

        return base_lf.join(signer_features, on='fid', how='left')

    def add_mentions_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID mention statistics from non-deleted casts.

        Each cast's mention count is the list length of ``mentions`` decoded
        with ``str.json_decode``; null, ``''`` and ``'[]'`` count as 0. Adds
        ``total_mentions``, ``avg_mentions_per_cast`` and
        ``casts_with_mentions``.

        NOTE(review): if Polars evaluates the ``then`` branch on all rows before
        masking, json_decode also sees ``''`` values. Confirm those do not
        raise.
        """
        casts = self.loader.get_dataset('casts',
            ['fid', 'mentions', 'deleted_at'])

        raw = pl.col('mentions')
        non_empty = raw.is_not_null() & (raw != '') & (raw != '[]')
        per_cast_mentions = (
            pl.when(non_empty)
            .then(raw.str.json_decode().list.len())
            .otherwise(0)
            .alias('mention_count')
        )

        n_mentions = pl.col('mention_count')
        mention_features = (
            casts
            .filter(pl.col('deleted_at').is_null())
            .with_columns([per_cast_mentions])
            .group_by('fid')
            .agg([
                n_mentions.sum().alias('total_mentions'),
                n_mentions.mean().alias('avg_mentions_per_cast'),
                (n_mentions > 0).sum().alias('casts_with_mentions'),
            ])
        )

        return base_lf.join(mention_features, on='fid', how='left')

    def add_reply_patterns_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID reply statistics over non-deleted casts with a ``parent_hash``.

        Adds:

        * ``total_replies``.
        * ``unique_users_replied_to``: distinct ``parent_fid``.
        * Mean/std seconds between consecutive replies.
        * ``reply_diversity``: unique users / total replies.
        * ``reply_timing_variability``: std / mean of the gaps.

        Timestamps are sorted within each FID before ``diff``, so the gaps are
        chronological. ``reply_timing_variability`` is 0.0 when the mean gap is
        not positive (a single reply, or all replies in the same second),
        rather than null/inf/NaN. This mirrors how add_update_behavior guards
        profile_update_consistency.
        """
        casts = self.loader.get_dataset('casts',
            ['fid', 'parent_hash', 'parent_fid', 'timestamp', 'deleted_at'])

        gap_seconds = pl.col('timestamp').sort().diff().dt.total_seconds()
        mean_gap = pl.col('avg_seconds_between_replies')

        reply_features = (
            casts
            .filter(pl.col('deleted_at').is_null())
            .filter(pl.col('parent_hash').is_not_null())
            .group_by('fid')
            .agg([
                pl.len().alias('total_replies'),
                pl.n_unique('parent_fid').alias('unique_users_replied_to'),
                gap_seconds.mean().alias('avg_seconds_between_replies'),
                gap_seconds.std().alias('std_seconds_between_replies')
            ])
            .with_columns([
                (pl.col('unique_users_replied_to') / pl.col('total_replies'))
                    .alias('reply_diversity'),
                pl.when(mean_gap > 0)
                    .then(pl.col('std_seconds_between_replies') / mean_gap)
                    .otherwise(0.0)
                    .alias('reply_timing_variability')
            ])
        )

        return base_lf.join(reply_features, on='fid', how='left')

    def add_power_user_interaction_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID counts of casts replying to or mentioning power users.

        Adds:

        * ``power_user_replies``: ``parent_fid`` is a power user.
        * ``power_user_mentions``: ``mentions`` text contains a power-user FID
          as a whole number.
        * ``total_casts``.

        Deleted casts are ignored.

        The mention regex is anchored with ``\\b``, so FID 12 no longer matches
        inside 123. An empty power-user list no longer yields the pattern
        ``''``, which matched every value.

        NOTE(review): power users are loaded through ``self.loader`` and so are
        FID-filtered like every dataset (sampled in debug mode). Confirm this
        is intended.
        """
        # Load power users
        power_users = self.loader.get_dataset('warpcast_power_users', ['fid'])
        power_fids = (
            power_users.select('fid').collect()['fid']
            .cast(pl.Int64)
            .drop_nulls()
            .unique()
        )

        # Get interactions
        casts = self.loader.get_dataset('casts',
            ['fid', 'parent_fid', 'mentions', 'timestamp', 'deleted_at'])

        if len(power_fids) > 0:
            pattern = r"\b(?:" + "|".join(str(f) for f in power_fids.to_list()) + r")\b"
            has_power_mention = pl.col('mentions').str.contains(pattern)
        else:
            has_power_mention = pl.lit(False)

        power_cast_features = (
            casts
            .filter(pl.col('deleted_at').is_null())
            .with_columns([
                pl.col('parent_fid').cast(pl.Int64).is_in(power_fids)
                    .alias('is_power_reply'),
                has_power_mention.alias('has_power_mention')
            ])
            .group_by('fid')
            .agg([
                pl.col('is_power_reply').sum().alias('power_user_replies'),
                pl.col('has_power_mention').sum().alias('power_user_mentions'),
                pl.len().alias('total_casts')
            ])
        )

        return base_lf.join(power_cast_features, on='fid', how='left')

    def add_derived_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Add transformations of the network columns.

        Adds:

        * ``follower_ratio_log`` / ``unique_follower_ratio_log``: log1p of the
          ratios, with nulls replaced by 0 first.
        * ``has_more_followers``: 1 when follower_count > following_count.
        * ``follow_balance_ratio``:
          ``|following - followers| / (following + followers + 1)``.
        """
        followers = pl.col('follower_count')
        following = pl.col('following_count')

        ratio_log = pl.col('follower_ratio').fill_null(0).log1p().alias('follower_ratio_log')
        unique_ratio_log = (
            pl.col('unique_follower_ratio').fill_null(0).log1p()
            .alias('unique_follower_ratio_log')
        )
        more_followers = (followers > following).cast(pl.Int32).alias('has_more_followers')
        balance = (
            (following - followers).abs() / (following + followers + 1)
        ).alias('follow_balance_ratio')

        return base_lf.with_columns([ratio_log, unique_ratio_log, more_followers, balance])

    def add_engagement_metrics(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Add engagement_score, creation_consumption_ratio and interaction_diversity.

        Expects these columns, produced by the cast/reaction/channel feature
        methods: ``cast_count``, ``total_reactions``, ``channel_memberships``
        and ``unique_users_reacted_to``.

        These come from left joins, so a FID with no rows in one source has
        nulls there. Nulls are treated as 0 in the metrics; previously e.g. a
        user who never joined a channel got a null engagement_score.
        """
        casts = pl.col('cast_count').fill_null(0)
        reactions = pl.col('total_reactions').fill_null(0)
        memberships = pl.col('channel_memberships').fill_null(0)
        reacted_users = pl.col('unique_users_reacted_to').fill_null(0)

        return base_lf.with_columns([
            # Overall engagement score
            ((casts + reactions + memberships) / 3.0).alias('engagement_score'),

            # Activity balance
            (casts / (reactions + 1)).alias('creation_consumption_ratio'),

            # Interaction diversity
            (reacted_users / (reactions + 1)).alias('interaction_diversity')
        ])
    def add_activity_patterns(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID hour-of-day and weekday spread over casts and reactions.

        ``hour_diversity`` / ``weekday_diversity`` are the standard deviation
        of the number of activities in each distinct hour / weekday bucket;
        buckets with no activity are not included. ``total_activities`` counts
        non-deleted casts plus non-deleted reactions.

        ``unique_counts`` replaces ``value_counts``. ``value_counts`` returns a
        struct column (value, count), so ``.std()`` was not applied to the
        counts. Timestamps are cast before concatenation so both inputs share a
        dtype.
        """
        # Get activity data
        casts = self.loader.get_dataset('casts', ['fid', 'timestamp', 'deleted_at'])
        reactions = self.loader.get_dataset('reactions', ['fid', 'timestamp', 'deleted_at'])

        activity_cols = ['fid', pl.col('timestamp').cast(pl.Datetime)]

        # Combine valid activities
        activities = pl.concat([
            casts.filter(pl.col('deleted_at').is_null()).select(activity_cols),
            reactions.filter(pl.col('deleted_at').is_null()).select(activity_cols)
        ])

        activity_features = (
            activities
            .with_columns([
                pl.col('timestamp').dt.hour().alias('hour'),
                pl.col('timestamp').dt.weekday().alias('weekday')
            ])
            .group_by('fid')
            .agg([
                pl.col('hour').unique_counts().std().alias('hour_diversity'),
                pl.col('weekday').unique_counts().std().alias('weekday_diversity'),
                pl.len().alias('total_activities')
            ])
        )

        return base_lf.join(activity_features, on='fid', how='left')
    def add_update_behavior(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID profile-update cadence statistics from ``user_data``.

        Only non-deleted rows with a timestamp are used. Rows are sorted by
        (fid, timestamp) before grouping, so the gaps are chronological.

        Adds:

        * ``total_updates``.
        * ``avg_update_interval`` / ``update_time_std``: hours between
          consecutive updates.
        * ``profile_update_consistency``: std / mean of those gaps, or 0.0
          when the mean is not positive or is null.

        NOTE(review): add_user_data_features also produces
        ``avg_update_interval``; joining both would create a suffixed duplicate.
        """
        user_data = self.loader.get_dataset('user_data',
            ['fid', 'timestamp', 'deleted_at'])

        ordered_updates = (
            user_data
            .filter(pl.col('deleted_at').is_null())
            .filter(pl.col('timestamp').is_not_null())
            .with_columns([pl.col('timestamp').cast(pl.Datetime)])
            .sort(['fid', 'timestamp'])
        )

        interval_hours = pl.col('timestamp').diff().dt.total_hours()
        mean_interval = pl.col('avg_update_interval')

        per_fid = ordered_updates.group_by('fid').agg([
            pl.len().alias('total_updates'),
            interval_hours.mean().alias('avg_update_interval'),
            interval_hours.std().alias('update_time_std'),
        ])

        update_metrics = per_fid.with_columns([
            pl.when(mean_interval > 0)
              .then(pl.col('update_time_std') / mean_interval)
              .otherwise(0.0)
              .alias('profile_update_consistency')
        ])

        return base_lf.join(update_metrics, on='fid', how='left')
    def add_channel_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Join per-FID channel follow and membership statistics.

        From non-deleted ``channel_follows``:

        * ``unique_channels_followed``.
        * ``rapid_channel_follows``: consecutive follows < 60s apart, with
          timestamps sorted within each FID first.
        * ``channel_follow_hour_std``: std of the follow counts per distinct
          hour of day.

        From non-deleted ``channel_members``: ``channel_memberships`` and
        ``unique_channel_memberships``.

        ``unique_counts`` replaces ``value_counts``, whose struct output
        (value, count) made ``.std()`` not operate on the counts.
        """
        # Process channel follows
        channel_follows = self.loader.get_dataset('channel_follows',
            ['fid', 'channel_id', 'timestamp', 'deleted_at'])

        follow_features = (
            channel_follows
            .filter(pl.col('deleted_at').is_null())
            .group_by('fid')
            .agg([
                pl.n_unique('channel_id').alias('unique_channels_followed'),
                (pl.col('timestamp').sort().diff().dt.total_seconds() < 60)
                    .sum().alias('rapid_channel_follows'),
                pl.col('timestamp').dt.hour().unique_counts().std()
                    .alias('channel_follow_hour_std')
            ])
        )

        # Process memberships
        channel_members = self.loader.get_dataset('channel_members',
            ['fid', 'channel_id', 'deleted_at'])

        member_features = (
            channel_members
            .filter(pl.col('deleted_at').is_null())
            .group_by('fid')
            .agg([
                pl.len().alias('channel_memberships'),
                pl.n_unique('channel_id').alias('unique_channel_memberships')
            ])
        )

        result = base_lf.join(follow_features, on='fid', how='left')
        return result.join(member_features, on='fid', how='left')
    def add_verification_patterns_features(self, base_lf: pl.LazyFrame) -> pl.LazyFrame:
        """Add timing patterns of a FID's (non-deleted) verifications.

        Per ``fid``:

        * ``avg_hours_between_verifications`` / ``std_hours_between_verifications``
          -- mean and standard deviation of the gaps between chronologically
          consecutive verifications, in fractional hours.
        * ``rapid_verifications`` -- verifications made less than one hour
          after the previous one.

        Args:
            base_lf: Frame with a ``fid`` column to enrich.

        Returns:
            ``base_lf`` left-joined with the verification patterns on ``fid``.
        """
        verifications = self.loader.get_dataset('verifications',
            ['fid', 'timestamp', 'deleted_at'])

        # Rows in a group keep input order, which need not be chronological;
        # without sorting, the mean gap is wrong and negative gaps are
        # counted as "rapid". seconds / 3600 keeps sub-hour precision.
        verification_gap_hours = (
            pl.col('timestamp').sort().diff().dt.total_seconds() / 3600
        )

        # On-chain verification patterns
        verif_patterns = (
            verifications
            .filter(pl.col('deleted_at').is_null())
            .with_columns(pl.col('timestamp').cast(pl.Datetime))
            .group_by('fid')
            .agg([
                verification_gap_hours.mean()
                    .alias('avg_hours_between_verifications'),
                verification_gap_hours.std()
                    .alias('std_hours_between_verifications'),
                (verification_gap_hours < 1)
                    .sum().alias('rapid_verifications')
            ])
        )

        return base_lf.join(verif_patterns, on='fid', how='left')
    def build_feature_matrix(self) -> pl.DataFrame:
        """Assemble every feature family into one cleaned feature matrix.

        Profile features come from the loader's ``'profile'`` checkpoint when
        it returns one. Otherwise they are computed with
        ``extract_profile_features`` and checkpointed. Their ``fid`` values are
        registered with the loader as base FIDs. Every ``add_*`` feature step
        is then applied in order, and the lazy plan is collected in streaming
        mode and passed through ``_validate_and_clean_features``.

        Returns:
            The cleaned feature matrix.

        Raises:
            Re-raises any exception after printing it.
        """
        try:
            lazy_features = self.loader.load_checkpoint('profile')
            if lazy_features is None:
                print("Building profile features...")
                lazy_features = self.extract_profile_features()
                self.loader.save_checkpoint(lazy_features, 'profile')

            # The loader uses these FIDs to filter the downstream datasets.
            fid_column = lazy_features.select('fid').collect(streaming=True)['fid']
            self.loader.set_base_fids(fid_column)

            # Applied in this order; later steps may rely on columns that
            # earlier steps added.
            feature_steps = (
                self.add_network_features,
                self.add_temporal_features,
                self.add_cast_features,
                self.add_reaction_features,
                self.add_verification_features,
                self.add_user_data_features,
                self.add_storage_features,
                self.add_signer_features,
                self.add_mentions_features,
                self.add_update_behavior,
                self.add_verification_patterns_features,
                self.add_reply_patterns_features,
                self.add_power_user_interaction_features,
                self.add_influence_features,
                self.add_channel_features,
                self.add_engagement_metrics,
                self.add_authenticity_features,
                self.add_derived_features,
                self.add_network_quality_features,
                self.add_activity_patterns,
            )
            for add_step in feature_steps:
                lazy_features = add_step(lazy_features)

            print("Collecting final feature matrix...")
            cleaned = self._validate_and_clean_features(
                lazy_features.collect(streaming=True)
            )

            print(f"Feature matrix built with shape: {cleaned.shape}")
            return cleaned

        except Exception as e:
            print(f"Error building feature matrix: {str(e)}")
            raise

    def _validate_and_clean_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Normalise the collected feature matrix for modelling.

        * Casts ``fid`` to ``Int64``.
        * Keeps one row per ``fid`` (the first, preserving input order).
        * Fills nulls with 0 and casts to ``Float64`` for every numeric
          column other than ``fid``.

        Non-numeric columns (strings, structs, booleans) are returned
        unchanged.

        Args:
            df: Collected feature matrix with a ``fid`` column.

        Returns:
            The cleaned feature matrix.
        """
        # Ensure FID column is correct type
        df = df.with_columns(pl.col('fid').cast(pl.Int64))

        # Deterministic de-duplication: keep the first row per FID.
        df = df.unique(subset=['fid'], keep='first', maintain_order=True)

        # Left joins leave nulls for FIDs without matching activity in *any*
        # numeric dtype. Counts from pl.len()/n_unique are UInt32, so an
        # Int64/Float64 allow-list would leave their nulls in place.
        numeric_cols = [
            name for name, dtype in df.schema.items()
            if name != 'fid' and dtype.is_numeric()
        ]

        df = df.with_columns([
            pl.col(c).fill_null(0).cast(pl.Float64)
            for c in numeric_cols
        ])

        return df


# Build the feature matrix from the raw parquet files under data_path;
# the checkpoint directory must exist before the feature engineer uses it.
data_path, checkpoint_dir = "data", "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

feature_eng = OptimizedFeatureEngineering(data_path, checkpoint_dir)
feature_matrix = feature_eng.build_feature_matrix()

print("Feature matrix built with shape:", feature_matrix.shape)
print(feature_matrix.head())


Set base FIDs: 8 records
Filtering links by 8 base FIDs
Filtering links by 8 base FIDs
Filtering casts by 8 base FIDs
Filtering reactions by 8 base FIDs
Filtering verifications by 8 base FIDs
Filtering account_verifications by 8 base FIDs
Filtering user_data by 8 base FIDs
Filtering storage by 8 base FIDs
Filtering signers by 8 base FIDs
Filtering casts by 8 base FIDs
Filtering user_data by 8 base FIDs
Filtering verifications by 8 base FIDs
Filtering casts by 8 base FIDs
Filtering warpcast_power_users by 8 base FIDs
Filtering casts by 8 base FIDs


  if col not in df.columns:
  if 'first_follow' in df.columns and 'last_follow' in df.columns:


Filtering channel_follows by 8 base FIDs
Filtering channel_members by 8 base FIDs
Filtering power_users by 8 base FIDs
Filtering casts by 8 base FIDs
Filtering casts by 8 base FIDs
Filtering reactions by 8 base FIDs
Collecting final feature matrix...
Feature matrix built with shape: (8, 94)
Feature matrix built with shape: (8, 94)
shape: (5, 94)
┌───────┬────────┬────────────┬────────────┬───┬────────────┬────────────┬────────────┬────────────┐
│ fid   ┆ fname  ┆ bio        ┆ avatar_url ┆ … ┆ power_ment ┆ hour_diver ┆ weekday_di ┆ total_acti │
│ ---   ┆ ---    ┆ ---        ┆ ---        ┆   ┆ ions_count ┆ sity       ┆ versity    ┆ vities     │
│ i64   ┆ str    ┆ str        ┆ str        ┆   ┆ ---        ┆ ---        ┆ ---        ┆ ---        │
│       ┆        ┆            ┆            ┆   ┆ u32        ┆ struct[2]  ┆ struct[2]  ┆ u32        │
╞═══════╪════════╪════════════╪════════════╪═══╪════════════╪════════════╪════════════╪════════════╡
│ 23587 ┆ suzhen ┆ null       ┆ null       ┆ …

: 

In [None]:
import polars as pl
import numpy as np
from typing import List
import joblib
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import (
    roc_auc_score
)
import xgboost as xgb
from lightgbm import LGBMClassifier
import optuna
from sklearn.model_selection import StratifiedKFold
from sklearn.calibration import CalibratedClassifierCV
import shap
from scipy import stats

class StreamingSybilDetector:
    """Soft-voting ensemble of calibrated tree models for Sybil/bot detection.

    ``train`` works as follows:

    * Tunes XGBoost, random-forest and LightGBM classifiers with Optuna
      (mean 5-fold stratified ROC-AUC).
    * Refits each with its best parameters.
    * Records SHAP values for a sample of training rows.
    * Wraps each model in ``CalibratedClassifierCV``.
    * Combines the calibrated models in a soft ``VotingClassifier``.

    ``predict_streaming`` scores a frame slice by slice and reports the
    spread of the base models' probabilities next to their mean.
    """

    def __init__(self, feature_engineering: 'OptimizedFeatureEngineering'):
        self.feature_engineering = feature_engineering
        self.model = None             # fitted VotingClassifier, set by train()
        self.base_models = {}         # model name -> fitted CalibratedClassifierCV
        self.feature_names = None     # feature columns, in training order
        self.shap_explainers = {}     # model name -> shap.TreeExplainer
        self.shap_values = {}         # model name -> SHAP values of first 1000 training rows
        self.feature_importance = {}  # feature name -> mean normalised importance

    def prepare_streaming_features(self, df: pl.LazyFrame, required_cols: List[str] = None) -> pl.LazyFrame:
        """Select the requested feature columns that exist in ``df``.

        Args:
            df: Frame to select from.
            required_cols: Candidate feature columns; defaults to a small set
                of profile/network columns. Missing columns are skipped.

        Returns:
            ``df`` restricted to the candidate columns it contains, in the
            order given by ``required_cols``.

        Raises:
            ValueError: If none of the candidate columns are present.
        """
        if required_cols is None:
            required_cols = [
                'has_ens', 'has_bio', 'has_avatar', 'verification_count',
                'following_count', 'follower_count', 'follower_ratio',
                # Add other base feature columns as needed
            ]

        # LazyFrame.columns resolves the schema implicitly and emits a
        # PerformanceWarning in recent Polars; ask for it explicitly when
        # that API exists.
        if hasattr(df, 'collect_schema'):
            available_cols = set(df.collect_schema().names())
        else:
            available_cols = set(df.columns)
        valid_cols = [col for col in required_cols if col in available_cols]

        if not valid_cols:
            raise ValueError("No valid feature columns found")

        return df.select(valid_cols)

    def train(self, train_lf: pl.LazyFrame, labels: np.ndarray, feature_names: List[str]):
        """Tune, calibrate and ensemble the base models.

        Args:
            train_lf: Training rows; must contain every column in
                ``feature_names``.
            labels: Binary target, one entry per row of ``train_lf`` in the
                same order.
            feature_names: Feature columns to train on (also used by
                ``predict_streaming``).

        Raises:
            ValueError: If ``train_lf`` is empty or its row count differs
                from ``len(labels)``.
        """
        # Local import keeps the file's import cell unchanged.
        from sklearn.base import clone

        self.feature_names = feature_names
        labels = np.asarray(labels)

        # The feature matrix is identical for every model, trial and fold:
        # materialise it once instead of re-collecting it inside the CV loop.
        X_train = self._collect_feature_array(train_lf)
        if len(X_train) != len(labels):
            raise ValueError(
                f"train_lf has {len(X_train)} rows but {len(labels)} labels were given"
            )

        # A fixed random_state yields the same folds every time; compute once.
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        folds = list(cv.split(X_train, labels))

        # Templates: their constructor settings are kept for tuning and the
        # final fit (only the tuned hyperparameters are overridden).
        base_models = {
            'xgb': xgb.XGBClassifier(eval_metric='auc',
                                   use_label_encoder=False,
                                   random_state=42),
            'rf': RandomForestClassifier(n_jobs=-1,
                                       random_state=42,
                                       class_weight='balanced'),
            'lgbm': LGBMClassifier(n_jobs=-1,
                                  random_state=42,
                                  class_weight='balanced')
        }

        print("Starting model training...")
        for name, template in base_models.items():
            print(f"\nTraining {name}...")
            study = optuna.create_study(direction='maximize')

            # Default args bind the current loop values explicitly.
            def objective(trial, name=name, template=template):
                params = self._suggest_params(name, trial)
                candidate = clone(template).set_params(**params)
                scores = []
                for train_idx, val_idx in folds:
                    candidate.fit(X_train[train_idx], labels[train_idx])
                    val_pred = candidate.predict_proba(X_train[val_idx])[:, 1]
                    scores.append(roc_auc_score(labels[val_idx], val_pred))
                return np.mean(scores)

            # Optimize hyperparameters
            study.optimize(objective, n_trials=30, timeout=600)
            print(f"Best parameters for {name}: {study.best_params}")

            # Clone the template so untuned constructor settings
            # (random_state, class_weight, n_jobs, eval_metric) survive;
            # type(model)(**best_params) would silently drop them.
            best_model = clone(template).set_params(**study.best_params)
            best_model.fit(X_train, labels)

            # Generate SHAP values for feature importance
            try:
                explainer = shap.TreeExplainer(best_model)
                shap_vals = explainer.shap_values(X_train[:1000])  # Sample for memory efficiency
                self.shap_values[name] = shap_vals
                self.shap_explainers[name] = explainer
                print(f"SHAP values generated for {name}")
            except Exception as e:
                print(f"Error generating SHAP values for {name}: {str(e)}")

            # Calibrate probabilities
            calibrated_model = CalibratedClassifierCV(best_model, cv=5, n_jobs=-1)
            calibrated_model.fit(X_train, labels)
            self.base_models[name] = calibrated_model

        # Create final ensemble; VotingClassifier.fit trains clones of the
        # calibrated models on the full training set.
        self.model = VotingClassifier(
            estimators=list(self.base_models.items()),
            voting='soft'
        )
        self.model.fit(X_train, labels)

        self._calculate_feature_importance()
        print("Training complete!")

    def _collect_feature_array(self, lf: pl.LazyFrame) -> np.ndarray:
        """Collect ``self.feature_names`` from ``lf`` as a 2-D NumPy array.

        Raises:
            ValueError: If ``lf`` has no rows.
        """
        features = lf.collect(streaming=True).select(self.feature_names)
        if features.height == 0:
            raise ValueError("No training rows to collect")
        return features.to_numpy()

    @staticmethod
    def _suggest_params(name: str, trial) -> dict:
        """Sample the Optuna search space for base model ``name``."""
        if name == 'xgb':
            return {
                'max_depth': trial.suggest_int('max_depth', 3, 7),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0)
            }
        if name == 'rf':
            return {
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'max_depth': trial.suggest_int('max_depth', 3, 7),
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 10)
            }
        # lgbm
        return {
            'n_estimators': trial.suggest_int('n_estimators', 100, 500),
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.1),
            'num_leaves': trial.suggest_int('num_leaves', 20, 100)
        }

    def predict_streaming(self, test_lf: pl.LazyFrame, batch_size: int = 10000) -> pl.DataFrame:
        """Score ``test_lf`` slice by slice with every calibrated base model.

        Args:
            test_lf: Rows to score; must contain ``fid`` and every training
                feature column.
            batch_size: Rows per slice.

        Returns:
            One row per input row with these columns:

            * ``fid``.
            * ``bot_probability``: mean of the base models' probabilities.
            * ``prediction_uncertainty``: their standard deviation.
            * ``confidence_lower`` / ``confidence_upper``: a normal 95%
              interval around the mean, clipped to [0, 1].
            * ``is_bot``: ``bot_probability >= 0.5``.

        Raises:
            ValueError: If the model is not trained or ``test_lf`` is empty.
        """
        if self.model is None:
            raise ValueError("Model not trained")

        # Two-sided 95% standard-normal quantile (about 1.96).
        z = stats.norm.ppf(0.975)
        predictions = []
        total_batches = 0

        for batch in test_lf.collect(streaming=True).iter_slices(n_rows=batch_size):
            if len(batch) == 0:
                continue

            batch_features = batch.select(self.feature_names).to_numpy()
            batch_predictions = np.array([
                model.predict_proba(batch_features)[:, 1]
                for model in self.base_models.values()
            ])
            mean_probs = batch_predictions.mean(axis=0)
            std_probs = batch_predictions.std(axis=0)

            # stats.norm.interval yields NaN when scale == 0 (all models
            # agree) and is not bounded to [0, 1]; compute it directly.
            half_width = z * std_probs

            batch_results = pl.DataFrame({
                'fid': batch['fid'],
                'bot_probability': mean_probs,
                'prediction_uncertainty': std_probs,
                'confidence_lower': np.clip(mean_probs - half_width, 0.0, 1.0),
                'confidence_upper': np.clip(mean_probs + half_width, 0.0, 1.0),
                'is_bot': mean_probs >= 0.5
            })

            predictions.append(batch_results)
            total_batches += 1

            if total_batches % 10 == 0:
                print(f"Processed {total_batches} batches...")

        if not predictions:
            raise ValueError("No rows to predict")
        return pl.concat(predictions)

    @staticmethod
    def _estimator_importances(model):
        """Return ``model``'s feature importances as a float array, or None.

        Handles plain tree models and fitted ``CalibratedClassifierCV``
        wrappers, whose per-fold inner estimators are averaged.
        """
        if hasattr(model, 'feature_importances_'):
            return np.asarray(model.feature_importances_, dtype=float)

        fold_importances = []
        for calibrated in getattr(model, 'calibrated_classifiers_', []):
            # scikit-learn >= 1.2 stores the fitted model as ``estimator``;
            # older releases used ``base_estimator``.
            inner = getattr(calibrated, 'estimator', None)
            if inner is None:
                inner = getattr(calibrated, 'base_estimator', None)
            if inner is not None and hasattr(inner, 'feature_importances_'):
                fold_importances.append(
                    np.asarray(inner.feature_importances_, dtype=float)
                )

        if not fold_importances:
            return None
        return np.mean(fold_importances, axis=0)

    def _calculate_feature_importance(self):
        """Average per-model normalised feature importances into ``self.feature_importance``."""
        per_model = []
        for model in self.base_models.values():
            importances = self._estimator_importances(model)
            if importances is None or len(importances) != len(self.feature_names):
                continue
            total = importances.sum()
            # Normalise per model: LightGBM reports split counts by default
            # while RF/XGBoost report fractions, so raw values are not
            # comparable across models.
            per_model.append(importances / total if total > 0 else importances)

        if not per_model:
            self.feature_importance = {}
            return

        mean_importance = np.mean(per_model, axis=0)
        self.feature_importance = dict(zip(self.feature_names, mean_importance.tolist()))

    def save_model(self, path: str):
        """Persist the fitted models and feature metadata to ``path`` with joblib."""
        model_data = {
            'base_models': self.base_models,
            'ensemble_model': self.model,
            'feature_names': self.feature_names,
            'feature_importance': self.feature_importance
        }
        joblib.dump(model_data, path)

    def load_model(self, path: str):
        """Restore the state written by ``save_model``.

        ``joblib.load`` unpickles the file, which can execute arbitrary code;
        only load files from trusted sources.
        """
        model_data = joblib.load(path)
        self.base_models = model_data['base_models']
        self.model = model_data['ensemble_model']
        self.feature_names = model_data['feature_names']
        self.feature_importance = model_data['feature_importance']



# Initialize components
feature_eng = OptimizedFeatureEngineering("data", "checkpoints")
detector = StreamingSybilDetector(feature_eng)

# Build features using streaming
feature_matrix = feature_eng.build_feature_matrix()

# Load labels lazily; only the FID and the target are needed.
labels_lf = pl.scan_csv('data/labels.csv')

# Attach labels to features with a join on fid. pl.concat stacks frames
# vertically and fails on these mismatched schemas. The inner join keeps
# only FIDs that have both features and a non-null label; fid is cast to
# match the Int64 fid of the cleaned feature matrix.
data = (
    feature_matrix.lazy()
    .join(
        labels_lf
        .select(['fid', 'bot'])
        .with_columns(pl.col('fid').cast(pl.Int64))
        .filter(pl.col('bot').is_not_null()),
        on='fid',
        how='inner',
    )
    .collect(streaming=True)
)

# Split data
from sklearn.model_selection import train_test_split
labels = data['bot'].to_numpy()
train_idx, test_idx = train_test_split(
    range(len(data)),
    test_size=0.2,
    stratify=labels,
    random_state=42
)

train_data = data[train_idx].lazy()
test_data = data[test_idx].lazy()

# Only numeric/boolean columns can be fed to the models; string columns
# (fname, bio, avatar_url, ...) and struct columns cannot.
feature_names = [
    name for name, dtype in data.schema.items()
    if name not in ('fid', 'bot') and (dtype.is_numeric() or dtype == pl.Boolean)
]

# Train and evaluate
detector.train(train_data, labels[train_idx], feature_names)
predictions = detector.predict_streaming(test_data)

# Save model
detector.save_model("checkpoints/sybil_detector.joblib")