In [None]:
import gc
import os
import re
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as cs

from sklearn.cluster import KMeans, DBSCAN, MiniBatchKMeans
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder

# import warnings
# warnings.filterwarnings('ignore')

In [None]:
# Customer attributes for clustering analysis
# CUSTOMER_ATTRIBUTES = ['profileId', 'companyID', 'sex', 'nationality', 'frequentFlyer', 'isVip', 'bySelf', 'corporateTariffCode']

# Columns considered unneeded. Entries written as '^...$' are regular expressions
# (Polars treats such names as regex patterns in pl.col / pl.exclude).
# Every pattern must be its own list element: without a separating comma Python
# concatenates adjacent string literals into a single pattern that matches nothing.
# NOTE(review): 'ranker_id', 'totalPrice', 'legs0_duration' and the segment
# arrivalTo/baggage columns are read by the feature code below, so these can only
# be dropped after feature extraction; 'frequent_flyer' also differs from the
# 'frequentFlyer' column used below — confirm against the caller.
UNNEEDED_ATTRIBUTES = [
    'ranker_id', 'isAccess3D', 'totalPrice', 'taxes', 'legs0_arrivalAt', 'legs0_duration', 'frequent_flyer',
    r'^legs1_(departureAt|arrivalAt|duration)$',
    r'^legs[01]_segments[0-3]_(operatingCarrier_code|aircraft_code|flightNumber)$',
    r'^legs[01]_segments[0-3]_(arrivalTo|baggage|seats).*$',
]
# Presumably the pandas index column carried over from parquet files — confirm.
POLARS_INDEX_COL = ['__index_level_0__']
# Airports counted as major hubs by the hub_preference customer feature.
MAJOR_HUBS = ['ATL', 'DXB', 'DFW', 'HND', 'LHR', 'DEN', 'ORD', 'IST', 'PVG', 'ICN', 'CDG', 'JFK', 'CLT', 'MEX', 'SFO', 'EWR', 'MIA', 'BKK', 'GRU', 'HKG']


def get_cabin_class_columns(df: pl.DataFrame) -> List[str]:
    """Return the names of all per-segment cabin class columns of ``df``.

    A column qualifies when its name starts with ``legs`` and ends with
    ``_cabinClass``; the frame's column order is preserved.
    """
    def is_cabin_class(name: str) -> bool:
        return name.startswith('legs') and name.endswith('_cabinClass')

    return list(filter(is_cabin_class, df.columns))


def create_customer_aggregation_features() -> List[pl.Expr]:
    """Return per-customer attribute and search-behaviour aggregations.

    Static attributes take the first non-null value within each group; search
    behaviour is summarised as counts and rates over the group's rows.
    """
    static_attributes = (
        'companyID', 'sex', 'nationality', 'frequentFlyer',
        'isVip', 'bySelf', 'corporateTariffCode',
    )
    expressions = [pl.col(name).drop_nulls().first().alias(name) for name in static_attributes]

    # Normalized frequent-flyer program: the substring '- ЮТэйр ЗАО' (UTair) is
    # replaced with 'UT', and customers with no program get '' instead of null.
    first_program = pl.col('frequentFlyer').drop_nulls().first()
    expressions.append(
        first_program.str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('ff_normalized')
    )

    # Search behaviour metrics.
    expressions.extend([
        pl.len().alias('total_searches'),
        # Share of rows that carry a return leg.
        pl.col('legs1_departureAt').is_not_null().mean().alias('roundtrip_preference'),
        pl.col('searchRoute').drop_nulls().n_unique().alias('unique_routes_searched'),
    ])
    return expressions


def create_booking_lead_time_features() -> List[pl.Expr]:
    """Return min/max/mean/median booking lead time (days) per group.

    Lead time is the outbound departure timestamp minus the request date,
    divided by a one-day duration and cast to Int32.
    NOTE(review): this relies on Polars dividing a Duration by a Duration;
    confirm the installed version supports it (``.dt.total_days()`` is the
    documented alternative).
    """
    departure = pl.col('legs0_departureAt').str.to_datetime()
    requested = pl.col('requestDate').cast(pl.Datetime)
    lead_days = ((departure - requested) / pl.duration(days=1)).cast(pl.Int32)

    stat_aliases = (
        ('min', 'min_booking_lead_days'),
        ('max', 'max_booking_lead_days'),
        ('mean', 'avg_booking_lead_days'),
        ('median', 'median_booking_lead_days'),
    )
    return [getattr(lead_days, stat)().alias(alias) for stat, alias in stat_aliases]


def create_travel_preference_features() -> List[pl.Expr]:
    """Return the most common value and distinct-value count per group for the
    first outbound segment's departure airport and marketing carrier."""
    specs = (
        ('legs0_segments0_departureFrom_airport_iata', 'most_common_departure_airport', 'unique_departure_airports'),
        ('legs0_segments0_marketingCarrier_code', 'most_common_carrier', 'unique_carriers_used'),
    )
    features = []
    for source, mode_alias, unique_alias in specs:
        present = pl.col(source).drop_nulls()
        # mode() may return several values on ties; the first one is kept.
        features.append(present.mode().first().alias(mode_alias))
        features.append(present.n_unique().alias(unique_alias))
    return features


def create_cabin_class_features(cabin_class_cols: List[str]) -> List[pl.Expr]:
    """Return min/max/average cabin class statistics per group.

    Args:
        cabin_class_cols: Names of the per-segment cabin class columns.

    Returns:
        Expressions aliased ``min_cabin_class``, ``max_cabin_class`` and
        ``avg_cabin_class``; null literals when no columns are given.
    """
    if not cabin_class_cols:
        return [
            pl.lit(None).alias(alias)
            for alias in ('min_cabin_class', 'max_cabin_class', 'avg_cabin_class')
        ]

    segment_cols = [pl.col(col) for col in cabin_class_cols]
    # Reduce across segments within each row first, then across the group's rows.
    row_min = pl.min_horizontal(segment_cols)
    row_max = pl.max_horizontal(segment_cols)
    row_mean = pl.mean_horizontal(segment_cols)
    return [
        row_min.min().alias('min_cabin_class'),
        row_max.max().alias('max_cabin_class'),
        row_mean.mean().alias('avg_cabin_class'),
    ]


def create_temporal_preference_features() -> List[pl.Expr]:
    """Create temporal preference features for departure patterns.

    All expressions reduce to one value per group, so they can be used inside
    ``group_by(...).agg(...)``. Native expressions are used instead of
    ``map_elements``: in a group-by context ``map_elements`` passes each group
    as a whole ``Series``, so element-wise lambdas such as ``1 if x >= 5 else 0``
    fail there.
    """
    departure = pl.col('legs0_departureAt').str.to_datetime()
    # Polars weekday() is ISO: Monday=1 ... Saturday=6, Sunday=7.
    weekday = departure.dt.weekday()
    hour = departure.dt.hour()

    return [
        # Weekday preference (most common day of week for departures)
        weekday.mode().first().alias('weekday_preference'),

        # Weekend travel rate (share of Saturday/Sunday departures; nulls ignored)
        (weekday >= 6).cast(pl.Int8).mean().alias('weekend_travel_rate'),

        # Spread of departure hours (a standard deviation, despite the name)
        hour.std().alias('time_of_day_variance'),

        # Night flight preference (flights departing 22:00-06:00)
        ((hour >= 22) | (hour < 6)).cast(pl.Int8).mean().alias('night_flight_preference'),
    ]


def create_route_specific_features() -> List[pl.Expr]:
    """Create features related to route preferences and characteristics.

    Every expression reduces to a single value per group, so the result is a
    scalar column after ``group_by(...).agg(...)``. Row-level quantities are
    averaged over the group's rows. Without that step, aggregation yields a
    List column per customer, which cannot be converted to a numeric array.
    Native expressions replace ``map_elements``. In a group-by context
    ``map_elements`` hands the UDF a whole ``Series`` per group, which breaks
    element-wise lambdas.
    """
    route = pl.col('searchRoute')
    distinct_ratio = route.n_unique() / pl.len()

    departs_hub = pl.col('legs0_segments0_departureFrom_airport_iata').is_in(MAJOR_HUBS).cast(pl.Float64)
    arrives_hub = pl.col('legs0_segments0_arrivalTo_airport_iata').is_in(MAJOR_HUBS).cast(pl.Float64)

    duration_hours = pl.col('legs0_duration').str.extract(r'^(\d+):(\d+)', 1).cast(pl.Int32)

    # Routes look like 'AAABBB'; same first letter of origin/destination IATA
    # is used as a (rough) domestic proxy. Empty routes count as not domestic.
    same_first_letter = (
        (route.str.len_chars() > 0)
        & (route.str.slice(0, 1) == route.str.slice(3, 1))
    )

    return [
        # Route loyalty: 1 - distinct routes / searches (higher = more loyal)
        pl.when(distinct_ratio > 0).then(1 - distinct_ratio).otherwise(0).alias('route_loyalty'),

        # Hub preference: per-row share of hub endpoints, averaged over rows
        pl.mean_horizontal([departs_hub, arrives_hub]).mean().alias('hub_preference'),

        # Connection tolerance (preference for flights with connections)
        # NOTE(review): 'total_segments' is not a raw column in the code shown;
        # confirm it exists on the frame passed to the aggregation.
        pl.col('total_segments').mean().alias('connection_tolerance'),

        # Short haul preference: 1 - hours/12 clipped to [0, 1], averaged over rows
        (1 - duration_hours / 12).clip(0, 1).mean().alias('short_haul_preference'),

        # Domestic/international ratio (nulls ignored)
        same_first_letter.cast(pl.Int8).mean().alias('domestic_international_ratio'),
    ]


def create_price_sensitivity_features() -> List[pl.Expr]:
    """Create features related to price sensitivity and patterns.

    All expressions reduce to one value per group.
    NOTE(review): 'price_percentile', 'total_duration' and 'price_tier' are not
    created by any code shown here — confirm they exist on the input frame.
    """
    price_tier_std = pl.col('price_tier').std()

    return [
        # Price position preference (typical percentile chosen)
        pl.col('price_percentile').mean().alias('price_position_preference'),

        # Price to duration sensitivity
        # Higher values mean more willing to pay for shorter flights
        pl.covar(
            pl.col('totalPrice'),
            pl.col('total_duration') * -1  # Negative so higher = more sensitive
        ).alias('price_to_duration_sensitivity'),

        # Premium economy preference (assuming cabin class 2 is premium economy)
        pl.mean_horizontal([
            pl.col(f'legs0_segments{i}_cabinClass') == 2
            for i in range(4)
        ]).mean().alias('premium_economy_preference'),

        # Consistent price tier: 1 - min(std / 3, 1), or 0.5 when std is undefined
        # (e.g. a single row). Native when/then replaces map_elements, which
        # receives a whole Series per group inside an aggregation.
        pl.when(price_tier_std.is_null())
        .then(0.5)
        .otherwise(1 - (price_tier_std / 3).clip(0, 1))
        .alias('consistent_price_tier'),
    ]


def create_service_preference_features() -> List[pl.Expr]:
    """Create features related to service preferences.

    Both expressions reduce to one value per group. The baggage allowance is
    first averaged across the outbound segments of each row. That per-row
    value is then averaged across the group's rows. The original left it
    row-level, which produced a List column after aggregation.
    """
    baggage_per_row = pl.concat_list([
        pl.col(f'legs0_segments{i}_baggageAllowance_quantity')
        for i in range(4)
    ]).list.mean()

    return [
        # Baggage preference (average baggage allowance across the group's rows)
        baggage_per_row.mean().alias('baggage_preference'),

        # Loyalty program utilization (share of rows with a frequent-flyer value)
        pl.col('frequentFlyer').is_not_null().mean().alias('loyalty_program_utilization')
    ]


def create_derived_metrics() -> List[pl.Expr]:
    """Return composite metrics computed from already-aggregated customer columns.

    Intended for ``with_columns`` on the customer-level frame.
    NOTE(review): 'totalPrice' and 'booking_rate' are not aliased by any of the
    aggregation helpers in this file, so on the customer-level frame
    'price_flexibility_index' will likely fail with a missing column. Even if
    present, ``totalPrice.std()`` in ``with_columns`` is a frame-wide value, not
    per customer — confirm the intended inputs.
    """
    price_position = pl.col('price_position_preference')
    lead_span_days = pl.col('max_booking_lead_days') - pl.col('min_booking_lead_days')

    metrics = {
        # Higher price variance per (clipped) booking rate = more flexible
        'price_flexibility_index':
            pl.col('totalPrice').std() / pl.col('booking_rate').clip(0.01, 1),
        # Higher = more emphasis on convenient times
        'convenience_priority_score':
            (1 - pl.col('time_of_day_variance')) * 10 + pl.col('price_to_duration_sensitivity') * 5,
        # Higher = more loyal, less price sensitive
        'loyalty_vs_price_index':
            pl.col('loyalty_program_utilization') * 10 - price_position / 10,
        # Inverse of the lead-time range (1.0 when min == max)
        'planning_consistency_score':
            1 / (lead_span_days + 1),
        # Combination of cabin class and price position
        'luxury_index':
            pl.col('avg_cabin_class') * 20 + price_position / 2,
    }
    return [expr.alias(name) for name, expr in metrics.items()]


def extract_customer_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Extract customer features for clustering analysis.
    Aggregates by profileId to create customer-level features.

    Args:
        df: Row-level frame containing ``profileId`` and the raw columns read
            by the ``create_*_features`` helpers.

    Returns:
        One row per ``profileId`` with the aggregated features followed by the
        metrics from ``create_derived_metrics``. If ``df`` is non-empty and
        already has a ``total_searches`` column it is returned unchanged.
    """
    # Already aggregated: do not re-group customer-level rows.
    if df.height > 0 and 'total_searches' in df.columns:
        return df

    # Use the shared helper so the cabin-column rule lives in one place.
    cabin_class_cols = get_cabin_class_columns(df)

    feature_exprs = [
        *create_customer_aggregation_features(),
        *create_booking_lead_time_features(),
        *create_travel_preference_features(),
        *create_cabin_class_features(cabin_class_cols),
        *create_temporal_preference_features(),
        *create_route_specific_features(),
        *create_price_sensitivity_features(),
        *create_service_preference_features(),
    ]

    # One lazy group-by pass computes every aggregated feature.
    base_features = df.lazy().group_by('profileId').agg(feature_exprs).collect()

    # Derived metrics reference the aggregated columns, so they run afterwards.
    enhanced_features = base_features.with_columns(create_derived_metrics())

    print(f"Generated {len(enhanced_features.columns)} customer features for {len(enhanced_features)} customers")
    return enhanced_features


In [None]:
import polars as pl
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
import lightgbm as lgb
import joblib
from typing import Dict, List, Tuple, Optional
import gc
from pathlib import Path

class ScalableFlightRecommendationModel:
    """
    Memory-efficient flight recommendation model designed for large datasets (10M+ rows).
    Uses streaming processing, chunked operations, and optimized data structures.

    Workflow:
      1. ``create_customer_segments_streaming`` clusters customer-level features.
      2. ``prepare_training_data_chunked`` builds row features in chunks of whole
         search sessions (``ranker_id``), so no session is split across chunks.
      3. ``fit`` trains a LightGBM lambdarank model over those sessions.
      4. ``predict_proba_chunked`` scores rows and returns them in ``df`` row order.
    """

    # Customer-level columns considered for segmentation; the subset present at
    # fit time is remembered and reused at prediction time.
    SEGMENTATION_FEATURES = [
        'total_searches', 'isVip', 'roundtrip_preference',
        'avg_booking_lead_days', 'unique_carriers_used',
        'weekend_travel_rate', 'route_loyalty', 'hub_preference', 'connection_tolerance',
    ]

    # Candidate model features; the subset present in the data becomes
    # ``self.flight_features``.
    CANDIDATE_MODEL_FEATURES = [
        # Customer features (select key ones to reduce dimensionality)
        'total_searches', 'roundtrip_preference', 'avg_booking_lead_days',
        'weekend_travel_rate', 'route_loyalty', 'customer_segment',

        # Flight features
        'price_rank_pct', 'price_ratio_to_min', 'duration_hours', 'total_segments',
        'departure_hour', 'booking_lead_days', 'total_duration_mins',
        'is_daytime', 'is_weekend', 'has_connections',

        # Interactions
        'daytime_alignment', 'weekend_alignment', 'price_preference_match', 'carrier_loyalty_match',
    ]

    def __init__(self, n_customer_segments=50, chunk_size=100000, random_state=42):
        self.n_customer_segments = n_customer_segments
        self.chunk_size = chunk_size
        self.random_state = random_state

        # Models
        self.customer_segmentation_model = None
        self.global_model = None

        # Preprocessing
        self.customer_scaler = StandardScaler()
        self.segment_centroids = None
        # Segmentation columns chosen at fit time, so prediction feeds the
        # scaler the same layout it was fitted on.
        self.segmentation_features = []

        # Feature lists
        self.customer_features = []
        self.flight_features = []

    def create_flight_features_batch(self, df_chunk: pl.DataFrame) -> pl.DataFrame:
        """
        Create flight-level features for a chunk of rows.

        Price features are always recomputed relative to the search session
        (``ranker_id``). Other features are added only when their output column is
        missing and their source columns are present.

        Args:
            df_chunk: Rows of one or more complete search sessions.

        Returns:
            ``df_chunk`` with the feature columns appended.
        """
        existing = set(df_chunk.columns)

        # Price features (always relative to search session)
        new_columns = [
            (pl.col('totalPrice').rank(method='ordinal').over('ranker_id') /
             pl.col('totalPrice').count().over('ranker_id')).alias('price_rank_pct'),
            (pl.col('totalPrice') / pl.col('totalPrice').min().over('ranker_id')).alias('price_ratio_to_min'),
        ]

        if 'duration_hours' not in existing and 'legs0_duration' in existing:
            duration_parts = pl.col('legs0_duration').str.split(':')
            new_columns += [
                duration_parts.list.get(0).cast(pl.Int32, strict=False).alias('duration_hours'),
                duration_parts.list.get(1).cast(pl.Int32, strict=False).alias('duration_minutes'),
            ]

        if 'total_segments' not in existing:
            # Number of outbound segments that have a departure airport.
            new_columns.append(
                sum([pl.col(f'legs0_segments{i}_departureFrom_airport_iata').is_not_null().cast(pl.Int8)
                     for i in range(4)]).alias('total_segments')
            )

        if 'departure_hour' not in existing and 'legs0_departureAt' in existing:
            departure = pl.col('legs0_departureAt').str.to_datetime()
            new_columns += [
                departure.dt.hour().alias('departure_hour'),
                # ISO weekday: Monday=1 ... Sunday=7
                departure.dt.weekday().alias('departure_weekday'),
            ]

        if 'booking_lead_days' not in existing and {'legs0_departureAt', 'requestDate'} <= existing:
            new_columns.append(
                ((pl.col('legs0_departureAt').str.to_datetime() -
                  pl.col('requestDate').cast(pl.Datetime)) / pl.duration(days=1)).cast(pl.Int32).alias('booking_lead_days')
            )

        if 'primary_carrier' not in existing and 'legs0_segments0_marketingCarrier_code' in existing:
            new_columns.append(
                pl.col('legs0_segments0_marketingCarrier_code').fill_null('unknown').alias('primary_carrier')
            )

        df_chunk = df_chunk.with_columns(new_columns)

        # Derived features - based on what is available after the step above
        available = set(df_chunk.columns)
        derived_columns = []

        if 'departure_hour' in available and 'is_daytime' not in available:
            # 06:00-21:59; the complement of the 22:00-06:00 "night" window used
            # by the customer night_flight_preference feature.
            derived_columns.append(
                pl.col('departure_hour').is_between(6, 22, closed='left').cast(pl.Int8).alias('is_daytime')
            )

        if 'departure_weekday' in available and 'is_weekend' not in available:
            # Saturday (6) and Sunday (7) in ISO numbering.
            derived_columns.append((pl.col('departure_weekday') >= 6).cast(pl.Int8).alias('is_weekend'))

        if 'total_segments' in available and 'has_connections' not in available:
            derived_columns.append((pl.col('total_segments') > 1).cast(pl.Int8).alias('has_connections'))

        if {'duration_hours', 'duration_minutes'} <= available and 'total_duration_mins' not in available:
            derived_columns.append(
                (pl.col('duration_hours').fill_null(0) * 60 + pl.col('duration_minutes').fill_null(0)).alias('total_duration_mins')
            )

        if derived_columns:
            df_chunk = df_chunk.with_columns(derived_columns)

        return df_chunk

    @staticmethod
    def _add_interaction_features(df_chunk: pl.DataFrame) -> pl.DataFrame:
        """Add customer x flight interaction columns whose inputs are present."""
        available = set(df_chunk.columns)
        interactions = []
        if {'is_daytime', 'night_flight_preference'} <= available:
            interactions.append(
                (pl.col('is_daytime') * (1.0 - pl.col('night_flight_preference').fill_null(0.5))).alias('daytime_alignment'))
        if {'is_weekend', 'weekend_travel_rate'} <= available:
            interactions.append(
                (pl.col('is_weekend') * pl.col('weekend_travel_rate').fill_null(0.5)).alias('weekend_alignment'))
        if {'price_rank_pct', 'price_position_preference'} <= available:
            interactions.append(
                (pl.col('price_rank_pct') * pl.col('price_position_preference').fill_null(0.5)).alias('price_preference_match'))
        if {'primary_carrier', 'most_common_carrier'} <= available:
            interactions.append(
                (pl.col('primary_carrier') == pl.col('most_common_carrier').fill_null('unknown')).cast(pl.Int8).alias('carrier_loyalty_match'))
        return df_chunk.with_columns(interactions) if interactions else df_chunk

    def _prepare_model_frame(self, df_chunk: pl.DataFrame, customer_with_segments: pl.DataFrame) -> pl.DataFrame:
        """Add flight features, join customer features/segments, add interactions."""
        df_chunk = self.create_flight_features_batch(df_chunk)
        df_chunk = df_chunk.join(customer_with_segments, on='profileId', how='left')
        return self._add_interaction_features(df_chunk)

    def _feature_matrix(self, df_chunk: pl.DataFrame) -> np.ndarray:
        """Return ``self.flight_features`` of ``df_chunk`` as a 2-D array.

        Nulls become 0. Feature columns missing from the chunk are filled with
        zeros, so the column layout always matches training.
        """
        missing = [col for col in self.flight_features if col not in df_chunk.columns]
        if missing:
            print(f"Missing features filled with zeros: {missing}")
            df_chunk = df_chunk.with_columns([pl.lit(0.0).alias(col) for col in missing])
        return df_chunk.select(self.flight_features).fill_null(0).to_numpy()

    @staticmethod
    def _group_sizes(groups: np.ndarray) -> np.ndarray:
        """Return the length of each run of consecutive equal ids, in order of appearance."""
        if len(groups) == 0:
            return np.zeros(0, dtype=np.int64)
        change_points = np.flatnonzero(groups[1:] != groups[:-1]) + 1
        return np.diff(np.concatenate(([0], change_points, [len(groups)])))

    def _assign_segments(self, customer_features_df: pl.DataFrame) -> np.ndarray:
        """Assign each customer row to one of the segments learned in ``fit``."""
        from sklearn.metrics.pairwise import euclidean_distances

        feature_names = getattr(self, 'segmentation_features', None) or [
            f for f in self.SEGMENTATION_FEATURES if f in customer_features_df.columns
        ]
        missing = [f for f in feature_names if f not in customer_features_df.columns]
        frame = customer_features_df
        if missing:
            print(f"Missing segmentation features filled with zeros: {missing}")
            frame = frame.with_columns([pl.lit(0.0).alias(f) for f in missing])

        X_segment = np.nan_to_num(frame.select(feature_names).to_numpy(), nan=0.0)
        X_segment_scaled = self.customer_scaler.transform(X_segment)

        if hasattr(self.customer_segmentation_model, 'cluster_centers_'):
            # KMeans-style model with a native predict.
            return self.customer_segmentation_model.predict(X_segment_scaled)

        # Agglomerative clustering has no predict: use the nearest stored centroid.
        centroids = getattr(self, 'segment_centroids', None)
        if centroids is None:
            print("Warning: No segment centroids found. Using default assignment.")
            return np.zeros(len(X_segment_scaled), dtype=int)
        return np.argmin(euclidean_distances(X_segment_scaled, centroids), axis=1)

    def create_customer_segments_streaming(self, customer_features_df: pl.DataFrame) -> Dict:
        """Create customer segments using Agglomerative Clustering based on evaluation results.

        Args:
            customer_features_df: One row per customer.

        Returns:
            Dict with the following keys:

            - ``segments``: label per row of ``customer_features_df``.
            - ``centroids``: mean scaled feature vector per label, in ascending
              label order.
            - ``feature_names``: the segmentation columns used.
            - ``cluster_profiles``: pandas DataFrame of per-cluster means and sizes.
            - ``n_clusters``: number of clusters.

        Raises:
            ValueError: if ``customer_features_df`` has no rows.
        """
        from sklearn.cluster import AgglomerativeClustering
        from sklearn.neighbors import kneighbors_graph

        available_features = [f for f in self.SEGMENTATION_FEATURES if f in customer_features_df.columns]
        print(f"Using {len(available_features)} features for Agglomerative clustering: {available_features}")
        # Remember the layout so prediction scales exactly the same columns.
        self.segmentation_features = available_features

        X_segment = np.nan_to_num(customer_features_df.select(available_features).to_numpy(), nan=0.0)
        n_customers = len(X_segment)
        if n_customers == 0:
            raise ValueError("Cannot segment an empty customer feature frame")

        X_segment_scaled = self.customer_scaler.fit_transform(X_segment)

        print(f"Clustering {n_customers} customers using Agglomerative Clustering...")

        # A sparse k-NN connectivity graph keeps ward linkage tractable on large inputs.
        connectivity = None
        if n_customers > 10000:
            print("Creating connectivity graph for large dataset...")
            connectivity = kneighbors_graph(X_segment_scaled, n_neighbors=10, include_self=False)

        # 5-10 clusters scaled with data size (10 was optimal in the earlier
        # evaluation), capped at the number of customers so small inputs work.
        # NOTE(review): self.n_customer_segments is not used here — confirm intended.
        n_clusters = min(10, max(5, n_customers // 1000), n_customers)

        self.customer_segmentation_model = AgglomerativeClustering(
            n_clusters=n_clusters,
            connectivity=connectivity,
            linkage='ward',  # Generally works well with scaled features
        )
        segments = self.customer_segmentation_model.fit_predict(X_segment_scaled)

        # Agglomerative clustering has no centroids; store the mean scaled vector
        # of each label so new customers can be assigned to the nearest one.
        centroids = np.array([
            X_segment_scaled[segments == label].mean(axis=0)
            for label in np.unique(segments)
        ])

        # Cluster profiles over whichever profile columns exist.
        profile_columns = [
            c for c in ('total_searches', 'isVip', 'roundtrip_preference',
                        'avg_booking_lead_days', 'unique_carriers_used')
            if c in customer_features_df.columns
        ]
        customer_df = (
            customer_features_df
            .with_columns(pl.Series('cluster', segments))
            .select(['cluster', *profile_columns])
            .to_pandas()
        )

        print("\n📊 CLUSTER PROFILES:")
        print("=" * 50)

        grouped = customer_df.groupby('cluster')
        if profile_columns:
            cluster_profiles = grouped.agg({c: 'mean' for c in profile_columns}).round(2)
            cluster_profiles['size'] = grouped.size()
        else:
            cluster_profiles = grouped.size().to_frame('size')

        print(cluster_profiles)

        return {
            'segments': segments,
            'centroids': centroids,
            'feature_names': available_features,
            'cluster_profiles': cluster_profiles,
            'n_clusters': n_clusters,
        }

    def prepare_training_data_chunked(self, df: pl.DataFrame, customer_features_df: pl.DataFrame,
                                    segment_info: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data in chunks to manage memory.

        Returns:
            ``(X, y, groups)`` with the rows ordered so each ``ranker_id`` forms a
            single contiguous run, as LightGBM ranking requires. ``groups`` holds
            the ``ranker_id`` of every row. Also sets ``self.flight_features``.

        Raises:
            ValueError: if no rows could be assembled.
        """
        print("Adding segments to customer features...")
        customer_with_segments = customer_features_df.with_columns([
            pl.Series('customer_segment', segment_info['segments'])
        ])

        # Fix the feature list once from a small sample so all chunks agree.
        print("Determining available features from data sample...")
        sample_ranker_ids = df.select('ranker_id').unique().limit(10)['ranker_id'].to_list()
        df_sample = self._prepare_model_frame(
            df.filter(pl.col('ranker_id').is_in(sample_ranker_ids)), customer_with_segments
        )
        self.flight_features = [col for col in self.CANDIDATE_MODEL_FEATURES if col in df_sample.columns]
        print(f"Will use {len(self.flight_features)} features: {self.flight_features}")
        del df_sample

        print(f"Processing {len(df)} rows in chunks of {self.chunk_size}...")

        X_chunks, y_chunks, group_chunks = [], [], []

        # Chunk by session id so no search session is split across chunks.
        ranker_ids = df.select('ranker_id').unique().sort('ranker_id')['ranker_id'].to_list()

        for chunk_number, start in enumerate(range(0, len(ranker_ids), self.chunk_size), start=1):
            chunk_ranker_ids = ranker_ids[start:start + self.chunk_size]
            df_chunk = df.filter(pl.col('ranker_id').is_in(chunk_ranker_ids))

            print(f"Processing chunk {chunk_number}: {len(df_chunk)} rows")

            # Sorting makes every session one contiguous run of rows; the group
            # sizes computed in ``fit`` depend on this.
            df_chunk = self._prepare_model_frame(df_chunk, customer_with_segments).sort('ranker_id')

            X_chunks.append(self._feature_matrix(df_chunk))
            y_chunks.append(df_chunk.select('selected').to_numpy().ravel())
            group_chunks.append(df_chunk.select('ranker_id').to_numpy().ravel())

            del df_chunk
            gc.collect()

        if not X_chunks:
            raise ValueError("No training rows found (empty 'ranker_id' column?)")

        print("Combining chunks...")
        X = np.vstack(X_chunks)
        y = np.concatenate(y_chunks)
        groups = np.concatenate(group_chunks)

        print(f"Final training data: {X.shape[0]} rows, {X.shape[1]} features")
        print(f"Feature consistency check: Expected {len(self.flight_features)} features, got {X.shape[1]}")

        return X, y, groups

    def fit(self, df: pl.DataFrame, customer_features_df: pl.DataFrame,
            validation_fraction: float = 0.1) -> 'ScalableFlightRecommendationModel':
        """Fit the model using memory-efficient processing.

        Args:
            df: Row-level training data with ``ranker_id``, ``profileId`` and ``selected``.
            customer_features_df: Customer-level features (one row per ``profileId``).
            validation_fraction: Fraction of whole search sessions (the last ones in
                ``ranker_id`` order) held out as the early-stopping evaluation set.
                LightGBM's early stopping needs an evaluation set. With ``0``, or
                fewer than two sessions, the model trains on everything without
                early stopping.

        Returns:
            ``self``.
        """
        print("Creating customer segments...")
        segment_info = self.create_customer_segments_streaming(customer_features_df)  # Agglomerative Clustering

        # Store centroids for prediction (needed for Agglomerative Clustering)
        self.segment_centroids = segment_info['centroids']

        print("Preparing training data...")
        X, y, groups = self.prepare_training_data_chunked(df, customer_features_df, segment_info)

        print("Training LightGBM model...")

        # Group sizes in order of appearance (rows are contiguous per session).
        group_sizes = self._group_sizes(groups)
        n_groups = len(group_sizes)
        n_val_groups = max(0, min(int(n_groups * validation_fraction), n_groups - 1))

        callbacks = [lgb.log_evaluation(100)]
        fit_kwargs = {}
        if n_val_groups > 0:
            n_train_groups = n_groups - n_val_groups
            split_row = int(group_sizes[:n_train_groups].sum())
            fit_kwargs = {
                'eval_set': [(X[split_row:], y[split_row:])],
                'eval_group': [group_sizes[n_train_groups:]],
            }
            callbacks.insert(0, lgb.early_stopping(50))
            X, y, group_sizes = X[:split_row], y[:split_row], group_sizes[:n_train_groups]
        else:
            print("No validation sessions held out; training without early stopping.")

        # LightGBM with memory-efficient settings
        self.global_model = lgb.LGBMRanker(
            objective='lambdarank',
            metric='ndcg',
            boosting_type='gbdt',
            num_leaves=31,
            learning_rate=0.1,
            feature_fraction=0.8,
            bagging_fraction=0.8,
            bagging_freq=5,
            verbose=-1,
            random_state=self.random_state,
            n_jobs=4,
            force_row_wise=True  # Memory efficient
        )

        self.global_model.fit(
            X, y,
            group=group_sizes,
            callbacks=callbacks,
            **fit_kwargs,
        )

        print("Training complete!")
        print(f"Final model uses {len(self.flight_features)} features")
        print(f"Customer segments: {segment_info['n_clusters']}")

        return self

    def predict_proba_chunked(self, df: pl.DataFrame, customer_features_df: pl.DataFrame) -> np.ndarray:
        """Make predictions in chunks to manage memory.

        Returns:
            One score per row of ``df``, aligned with ``df``'s row order. Rows
            that fall in no chunk (e.g. a null ``ranker_id``) keep NaN.
        """
        print("Predicting customer segments...")
        segments = self._assign_segments(customer_features_df)

        customer_with_segments = customer_features_df.with_columns([
            pl.Series('customer_segment', segments)
        ])

        print(f"Making predictions for {len(df)} rows using {len(self.flight_features)} features...")
        print(f"Expected features: {self.flight_features}")

        # Track original row positions so chunked output can be put back in order.
        row_index_col = '__prediction_row__'
        indexed = df.with_row_index(row_index_col)

        ranker_ids = df.select('ranker_id').unique().sort('ranker_id')['ranker_id'].to_list()
        chunk_starts = range(0, len(ranker_ids), self.chunk_size)

        final_predictions = np.full(len(df), np.nan)

        for chunk_number, start in enumerate(chunk_starts, start=1):
            chunk_ranker_ids = ranker_ids[start:start + self.chunk_size]

            df_chunk = indexed.filter(pl.col('ranker_id').is_in(chunk_ranker_ids))
            df_chunk = self._prepare_model_frame(df_chunk, customer_with_segments)

            # Same feature list and order as training.
            X_chunk = self._feature_matrix(df_chunk)
            row_positions = df_chunk[row_index_col].to_numpy()
            final_predictions[row_positions] = self.global_model.predict(X_chunk)

            print(f"Processed chunk {chunk_number}/{len(chunk_starts)}")

            del df_chunk, X_chunk
            gc.collect()

        print(f"Generated {len(final_predictions)} predictions")
        return final_predictions

    def save_model(self, filepath: str):
        """Save the fitted state (models, scaler, centroids, feature lists) with joblib."""
        model_data = {
            'customer_segmentation_model': self.customer_segmentation_model,
            'global_model': self.global_model,
            'customer_scaler': self.customer_scaler,
            # Needed to assign segments at prediction time for Agglomerative models.
            'segment_centroids': self.segment_centroids,
            'segmentation_features': self.segmentation_features,
            'flight_features': self.flight_features,
            'n_customer_segments': self.n_customer_segments,
            'chunk_size': self.chunk_size,
            'random_state': self.random_state
        }
        joblib.dump(model_data, filepath, compress=3)

    def load_model(self, filepath: str):
        """Load a saved model by setting every stored key as an attribute.

        NOTE: joblib uses pickle — only load files from trusted sources.
        """
        model_data = joblib.load(filepath)
        for key, value in model_data.items():
            setattr(self, key, value)


# Memory-efficient training function
def train_scalable_model(train_df: pl.DataFrame, test_df: pl.DataFrame = None,
                         model_path: str = 'scalable_flight_model.joblib') -> Tuple[np.ndarray, ScalableFlightRecommendationModel]:
    """
    Fit a ScalableFlightRecommendationModel and optionally score a test set.

    Author's rough sizing notes (not verified here): ~4-6GB peak memory and
    ~30-60 minutes of training for ~18M rows.

    Args:
        train_df: Training rows; customer features are derived from it with
            ``extract_customer_features``.
        test_df: Optional rows to score. When ``None`` no predictions are made.
        model_path: Where the fitted model is saved. Pass ``None`` to skip
            saving. Defaults to the previously hard-coded file name.

    Returns:
        ``(predictions, model)``; ``predictions`` is ``None`` when ``test_df``
        is not given.
    """
    print("Extracting customer features...")
    train_customer_features = extract_customer_features(train_df)

    print("Initializing scalable model...")
    model = ScalableFlightRecommendationModel(
        n_customer_segments=50,  # Increased for better granularity
        chunk_size=50000  # Adjust based on available RAM
    )

    print("Training model...")
    model.fit(train_df, train_customer_features)

    # Save right after fitting so a failure while scoring the test set does
    # not discard an expensive training run.
    if model_path is not None:
        print("Saving model...")
        model.save_model(model_path)

    predictions = None
    if test_df is not None:
        print("Making predictions...")
        test_customer_features = extract_customer_features(test_df)
        predictions = model.predict_proba_chunked(test_df, test_customer_features)

    return predictions, model


# Performance monitoring utilities
class PerformanceMonitor:
    """Back-of-the-envelope memory and configuration heuristics for a training run.

    All figures are rough estimates, not measurements.
    """

    @staticmethod
    def estimate_memory_usage(n_rows: int, n_features: int, bytes_per_value: int = 8) -> str:
        """Estimate peak memory for a dense ``n_rows x n_features`` feature matrix.

        Args:
            n_rows: Number of rows in the feature matrix.
            n_features: Number of feature columns.
            bytes_per_value: Bytes per matrix cell (8 for float64, the default;
                4 for float32).

        Returns:
            A string of the form ``"Estimated peak memory usage: X.XGB"``.
        """
        gib = 1024 ** 3
        base_data = (n_rows * n_features * bytes_per_value) / gib
        working_memory = base_data * 2  # heuristic allowance for intermediate copies
        model_memory = 0.5  # heuristic allowance (GB) for the fitted models

        total = base_data + working_memory + model_memory
        return f"Estimated peak memory usage: {total:.1f}GB"

    @staticmethod
    def get_recommendations(n_rows: int) -> Dict[str, object]:
        """Return a suggested configuration for a dataset with ``n_rows`` rows.

        Tiers: more than 50M rows, more than 10M rows, and everything else.
        Thresholds are exclusive (exactly 50M falls into the middle tier).

        Returns:
            A fresh dict with ``chunk_size``, ``n_segments``, ``early_stopping``,
            ``feature_fraction`` and a human-readable ``recommendation``.
        """
        if n_rows > 50_000_000:
            return {
                'chunk_size': 25000,
                'n_segments': 100,
                'early_stopping': 30,
                'feature_fraction': 0.6,
                'recommendation': 'Consider using a cluster or high-memory machine (32GB+ RAM)'
            }
        if n_rows > 10_000_000:
            return {
                'chunk_size': 50000,
                'n_segments': 50,
                'early_stopping': 50,
                'feature_fraction': 0.8,
                'recommendation': 'Should work well on 16GB+ RAM machine'
            }
        return {
            'chunk_size': 100000,
            'n_segments': 25,
            'early_stopping': 100,
            'feature_fraction': 0.9,
            'recommendation': 'Standard configuration should work fine'
        }

# Optimized customer feature extraction maintaining lazy evaluation
def extract_customer_features_scalable(df: pl.DataFrame, chunk_size: int = 100000) -> pl.DataFrame:
    """
    Aggregate one row of behavioural features per ``profileId`` in a single lazy pass.

    Every per-row transformation is a native Polars expression: no Python UDFs,
    and no window functions inside the group-by aggregation.

    Args:
        df: Search-level rows. The full feature set reads ``profileId``,
            ``companyID``, ``sex``, ``nationality``, ``frequentFlyer``, ``isVip``,
            ``bySelf``, ``corporateTariffCode``, ``searchRoute``, ``requestDate``,
            ``legs0_departureAt`` / ``legs1_departureAt`` (``legs0_departureAt``
            is parsed with ``str.to_datetime``), and
            ``legs0_segments0_departureFrom_airport_iata`` /
            ``legs0_segments0_marketingCarrier_code``. ``legs*_cabinClass``
            columns and ``totalPrice`` (ranked within ``ranker_id``) are used
            when present.
        chunk_size: Not used; kept so existing callers keep working.

    Returns:
        A DataFrame with one row per ``profileId``. If the full aggregation
        raises, a reduced set of four features is returned instead.
    """
    print(f"Extracting customer features from {len(df)} rows using lazy evaluation...")

    cabin_class_cols = get_cabin_class_columns(df)
    # Only count the connection segments that actually exist in this frame.
    segment_cols = [
        f'legs0_segments{i}_departureFrom_airport_iata'
        for i in range(4)
        if f'legs0_segments{i}_departureFrom_airport_iata' in df.columns
    ]

    # Shared parsed-departure expressions, defined once for readability.
    departure_dt = pl.col('legs0_departureAt').str.to_datetime()
    departure_weekday = departure_dt.dt.weekday()
    departure_hour = departure_dt.dt.hour()

    customer_aggs = [
        # Basic attributes (first non-null)
        pl.col('companyID').drop_nulls().first().alias('companyID'),
        pl.col('sex').drop_nulls().first().alias('sex'),
        pl.col('nationality').drop_nulls().first().alias('nationality'),
        pl.col('frequentFlyer').drop_nulls().first().alias('frequentFlyer'),
        pl.col('isVip').drop_nulls().first().alias('isVip'),
        pl.col('bySelf').drop_nulls().first().alias('bySelf'),
        pl.col('corporateTariffCode').drop_nulls().first().alias('corporateTariffCode'),

        # Normalized frequent flyer
        pl.col('frequentFlyer').drop_nulls().first().str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('ff_normalized'),

        # Search behavior
        pl.len().alias('total_searches'),
        pl.col('legs1_departureAt').is_not_null().mean().alias('roundtrip_preference'),
        pl.col('searchRoute').drop_nulls().n_unique().alias('unique_routes_searched'),

        # Booking lead time in whole days (truncated toward zero)
        (departure_dt - pl.col('requestDate').cast(pl.Datetime)).dt.total_days()
        .cast(pl.Int32).mean().alias('avg_booking_lead_days'),

        # Travel preferences (most common values)
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().mode().first().alias('most_common_departure_airport'),
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().n_unique().alias('unique_departure_airports'),
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().mode().first().alias('most_common_carrier'),
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().n_unique().alias('unique_carriers_used'),

        # Temporal preferences.
        # Polars' dt.weekday() is ISO (Monday=1 ... Sunday=7), so Sat/Sun are >= 6.
        departure_weekday.mode().first().alias('weekday_preference'),
        (departure_weekday >= 6).mean().alias('weekend_travel_rate'),

        # Night flight preference: departures from 22:00 up to 05:59
        ((departure_hour >= 22) | (departure_hour < 6)).mean().alias('night_flight_preference'),

        # Route loyalty (simplified)
        (1 - (pl.col('searchRoute').n_unique() / pl.len().clip(1, None))).alias('route_loyalty'),

        # Hub preference
        pl.col('legs0_segments0_departureFrom_airport_iata').is_in(MAJOR_HUBS).mean().alias('hub_preference'),
    ]

    # Connection tolerance (average number of populated segments on leg 0)
    if segment_cols:
        customer_aggs.append(
            pl.sum_horizontal([pl.col(col).is_not_null().cast(pl.Int8) for col in segment_cols])
            .mean().alias('connection_tolerance')
        )

    # Cabin class features if available
    if cabin_class_cols:
        customer_aggs.append(
            pl.min_horizontal([pl.col(col) for col in cabin_class_cols]).mean().alias('avg_cabin_class')
        )

    try:
        lazy_df = df.lazy()

        if 'totalPrice' in df.columns:
            # Window expressions (`.over`) are rejected inside group_by().agg(),
            # so the within-search price percentile is computed per row first
            # and only averaged in the aggregation.
            lazy_df = lazy_df.with_columns(
                (
                    pl.col('totalPrice').rank(method='ordinal').over('ranker_id')
                    / pl.col('totalPrice').count().over('ranker_id')
                ).alias('_price_position')
            )
            customer_aggs.append(pl.col('_price_position').mean().alias('price_position_preference'))

        enhanced_features = lazy_df.group_by('profileId').agg(customer_aggs).collect()

    except Exception as e:
        # Best-effort: report the failure and fall back to a minimal feature set.
        print(f"Error in lazy aggregation: {e}")
        simple_aggs = [
            pl.len().alias('total_searches'),
            pl.col('legs1_departureAt').is_not_null().mean().alias('roundtrip_preference'),
            pl.col('searchRoute').drop_nulls().n_unique().alias('unique_routes_searched'),
            pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().mode().first().alias('most_common_departure_airport'),
        ]
        enhanced_features = df.lazy().group_by('profileId').agg(simple_aggs).collect()

    print(f"Generated {len(enhanced_features.columns)} customer features for {len(enhanced_features)} customers")
    return enhanced_features


# Alternative: Use your original function directly for consistency
def use_original_extract_customer_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Build customer features with ``extract_customer_features``, falling back on error.

    If ``extract_customer_features`` raises any exception, the error is printed
    and ``extract_customer_features_scalable`` is used instead. The fallback's
    column set may differ from the primary function's.
    """
    print("Using original extract_customer_features function...")

    try:
        return extract_customer_features(df)
    except Exception as err:
        print(f"Original function failed: {err}")
        print("Falling back to scalable version...")

    # Only reached when the primary extractor raised.
    return extract_customer_features_scalable(df)


# Updated training function with choice of feature extraction
def train_scalable_model_with_original_features(train_df: pl.DataFrame, test_df: pl.DataFrame = None,
                                              use_original_features: bool = True,
                                              model_path: str = 'scalable_flight_model.joblib') -> Tuple[np.ndarray, ScalableFlightRecommendationModel]:
    """
    Training pipeline with a choice of customer-feature extractor.

    Args:
        train_df: Training rows.
        test_df: Optional rows to score. When ``None`` no predictions are made.
        use_original_features: If True, features come from
            ``use_original_extract_customer_features`` (which falls back to the
            scalable extractor on error). If False, ``extract_customer_features_scalable``
            is used. The same extractor is applied to train and test data.
        model_path: Where the fitted model is saved. Pass ``None`` to skip
            saving. Defaults to the previously hard-coded file name.

    Returns:
        ``(predictions, model)``; ``predictions`` is ``None`` when ``test_df``
        is not given.
    """
    if use_original_features:
        print("Using your original extract_customer_features function...")
        extract_features = use_original_extract_customer_features
    else:
        print("Using optimized scalable feature extraction...")
        extract_features = extract_customer_features_scalable

    train_customer_features = extract_features(train_df)

    print("Initializing scalable model...")
    model = ScalableFlightRecommendationModel(
        n_customer_segments=50,
        chunk_size=50000
    )

    print("Training model...")
    model.fit(train_df, train_customer_features)

    # Save right after fitting so a failure while scoring the test set does
    # not discard an expensive training run.
    if model_path is not None:
        print("Saving model...")
        model.save_model(model_path)

    predictions = None
    if test_df is not None:
        print("Making predictions...")
        test_customer_features = extract_features(test_df)
        predictions = model.predict_proba_chunked(test_df, test_customer_features)

    return predictions, model


# Usage example with performance monitoring
# Size an 18M-row, 25-feature run before starting it.
print(PerformanceMonitor.estimate_memory_usage(n_rows=18_000_000, n_features=25))
print(PerformanceMonitor.get_recommendations(n_rows=18_000_000))

# NOTE(review): `train_scalable_model_optimized` is not defined in the visible
# part of this file — confirm it exists, or point this hint at an existing trainer.
print("\nFor 18M rows, use: train_scalable_model_optimized() instead of train_scalable_model()")