### Imports

In [1]:
import gc
import re
from typing import List, Dict, Tuple
import uuid

import numpy as np
import polars as pl
import polars.selectors as cs

import lightgbm as lgb
from scipy.stats import zscore
from sklearn.cluster import AgglomerativeClustering, DBSCAN
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.ensemble import IsolationForest
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler, RobustScaler
import umap

### Get the Training Data Set

In [2]:
# data = pl.concat([train, pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/test.parquet')])
data = pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/train.parquet')

In [3]:
data.head(1000)

Id,bySelf,companyID,corporateTariffCode,frequentFlyer,nationality,isAccess3D,isVip,legs0_arrivalAt,legs0_departureAt,legs0_duration,legs0_segments0_aircraft_code,legs0_segments0_arrivalTo_airport_city_iata,legs0_segments0_arrivalTo_airport_iata,legs0_segments0_baggageAllowance_quantity,legs0_segments0_baggageAllowance_weightMeasurementType,legs0_segments0_cabinClass,legs0_segments0_departureFrom_airport_iata,legs0_segments0_duration,legs0_segments0_flightNumber,legs0_segments0_marketingCarrier_code,legs0_segments0_operatingCarrier_code,legs0_segments0_seatsAvailable,legs0_segments1_aircraft_code,legs0_segments1_arrivalTo_airport_city_iata,legs0_segments1_arrivalTo_airport_iata,legs0_segments1_baggageAllowance_quantity,legs0_segments1_baggageAllowance_weightMeasurementType,legs0_segments1_cabinClass,legs0_segments1_departureFrom_airport_iata,legs0_segments1_duration,legs0_segments1_flightNumber,legs0_segments1_marketingCarrier_code,legs0_segments1_operatingCarrier_code,legs0_segments1_seatsAvailable,legs0_segments2_aircraft_code,legs0_segments2_arrivalTo_airport_city_iata,…,legs1_segments2_baggageAllowance_weightMeasurementType,legs1_segments2_cabinClass,legs1_segments2_departureFrom_airport_iata,legs1_segments2_duration,legs1_segments2_flightNumber,legs1_segments2_marketingCarrier_code,legs1_segments2_operatingCarrier_code,legs1_segments2_seatsAvailable,legs1_segments3_aircraft_code,legs1_segments3_arrivalTo_airport_city_iata,legs1_segments3_arrivalTo_airport_iata,legs1_segments3_baggageAllowance_quantity,legs1_segments3_baggageAllowance_weightMeasurementType,legs1_segments3_cabinClass,legs1_segments3_departureFrom_airport_iata,legs1_segments3_duration,legs1_segments3_flightNumber,legs1_segments3_marketingCarrier_code,legs1_segments3_operatingCarrier_code,legs1_segments3_seatsAvailable,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,requestDate,searchRoute,sex,taxes,totalPrice,selected,__index_level_0__
i64,bool,i64,i64,str,i64,bool,bool,str,str,str,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,…,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,datetime[ns],str,bool,f64,f64,i64,i64
0,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T16:20:00""","""2024-06-15T15:40:00""","""02:40:00""","""YK2""","""KJA""","""KJA""",1.0,0.0,1.0,"""TLK""","""02:40:00""","""216""","""KV""","""KV""",9.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,370.0,16884.0,1,0
1,true,57323,123,"""S7/SU/UT""",36,true,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,51125.0,0,1
2,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,53695.0,0,2
3,true,57323,123,"""S7/SU/UT""",36,true,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,81880.0,0,3
4,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,86070.0,0,4
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
995,true,42622,,,36,false,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",0.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",0.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,4000.0,,1.0,0.0,,0.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,54837.0,0,995
996,true,42622,115,,36,true,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,71662.0,0,996
997,true,42622,,,36,false,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,76067.0,0,997
998,true,42622,115,,36,true,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,116322.0,0,998


In [4]:
print(f"the number of new profiles in train data is {len(data.select("profileId").unique())}")
print(f'the number of searches in train data is {len(data.select('ranker_id').unique())}')

the number of new profiles in train data is 32922
the number of searches in train data is 105539


In [6]:
flight_counts = data.group_by('ranker_id').agg(pl.len().alias('flight_count'))

print(f'The minimum number of searches in train data is {flight_counts['flight_count'].min()}')
print(f'The maximum number of searches in train data is {flight_counts['flight_count'].max()}')

del flight_counts
gc.collect()

The minimum number of searches in train data is 1
The maximum number of searches in train data is 8236


12

### Split into Train and Test Data Sets

In [7]:
def stratified_flight_count_split(df, test_size=0.2, random_state=42):
    """
    Stratified split based on number of flights per search (ranker_id).
    Ensures train/test sets have similar distributions of search sizes.
    """

    # Calculate flights per search
    search_flight_counts = (
        df
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
    )

    # Create stratification bins based on your distribution
    # Adjust these thresholds based on your specific data
    search_flight_counts = search_flight_counts.with_columns(
        pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))     # Most searches
        .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
        .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
        .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
        .otherwise(pl.lit('very_large'))                                     # High-flight searches
        .alias('size_category')
    )

    # Print distribution for verification
    size_dist = search_flight_counts.group_by('size_category').len().sort('size_category')
    print("Search size distribution:")
    for row in size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    # Split within each size category
    train_searches = []
    test_searches = []

    for category in ['very_small', 'small', 'medium', 'large', 'very_large']:
        category_searches = (
            search_flight_counts
            .filter(pl.col('size_category') == category)
            .select('ranker_id')
            .to_pandas()['ranker_id'].tolist()
        )

        if len(category_searches) > 1:
            train_cat, test_cat = train_test_split(
                category_searches,
                test_size=test_size,
                random_state=random_state,
                stratify=None  # No further stratification within category
            )
            train_searches.extend(train_cat)
            test_searches.extend(test_cat)
        elif len(category_searches) == 1:
            # Single search goes to train
            train_searches.extend(category_searches)

    # Filter original data
    train_data = df.filter(pl.col('ranker_id').is_in(train_searches))
    test_data = df.filter(pl.col('ranker_id').is_in(test_searches))

    # Verification: Check distribution preservation
    print(f"\nSplit results:")
    print(f"Train searches: {len(train_searches):,}")
    print(f"Test searches: {len(test_searches):,}")
    print(f"Train rows: {train_data.height:,}")
    print(f"Test rows: {test_data.height:,}")

    # Verify stratification worked
    train_size_dist = (
        train_data
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
        .with_columns(
            pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))
            .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
            .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
            .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
            .otherwise(pl.lit('very_large'))
            .alias('size_category')
        )
        .group_by('size_category')
        .len()
        .sort('size_category')
    )

    test_size_dist = (
        test_data
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
        .with_columns(
            pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))
            .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
            .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
            .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
            .otherwise(pl.lit('very_large'))
            .alias('size_category')
        )
        .group_by('size_category')
        .len()
        .sort('size_category')
    )

    print(f"\nTrain set distribution:")
    for row in train_size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    print(f"\nTest set distribution:")
    for row in test_size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    return train_data, test_data


In [8]:
train_df, test_df = stratified_flight_count_split(data)

Search size distribution:
  large: 5,035 searches
  medium: 13,896 searches
  small: 30,360 searches
  very_large: 3,141 searches
  very_small: 53,107 searches

Split results:
Train searches: 84,429
Test searches: 21,110
Train rows: 14,515,468
Test rows: 3,629,904

Train set distribution:
  large: 4,028 searches
  medium: 11,116 searches
  small: 24,288 searches
  very_large: 2,512 searches
  very_small: 42,485 searches

Test set distribution:
  large: 1,007 searches
  medium: 2,780 searches
  small: 6,072 searches
  very_large: 629 searches
  very_small: 10,622 searches


In [9]:
# Save the train/test splits
train_df.write_parquet('data/train_df.parquet')
test_df.write_parquet('data/test_df.parquet')

# Remove test data from memory
del test_df
gc.collect()

0

In [10]:
del data
gc.collect()

0

### Utilities

In [5]:
MAJOR_HUBS = ['ATL','DXB','DFW','HND','LHR','DEN','ORD','IST','PVG','ICN','CDG', 'JFK','CLT','MEX','SFO','EWR','MIA','BKK','GRU','HKG']

def camel_to_snake(name):
    """Convert camelCase or PascalCase to snake_case"""
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1)
    return s2.lower()


def convert_columns_to_snake_case(df):
    """Convert all column names in a polars DataFrame to snake_case"""
    return df.rename({col: camel_to_snake(col) for col in df.columns})


def parse_duration_to_minutes(duration_col: str) -> pl.Expr:
    """Parse duration string to minutes (handles format like '2.04:20' - D.HH:MM)."""
    return (
        pl.when(pl.col(duration_col).is_not_null() & (pl.col(duration_col) != ""))
        .then(
            pl.when(pl.col(duration_col).str.contains(r'\.'))
            .then(
                # Format: D.HH:MM:SS (e.g., "1.00:30:00", "2.09:45:00")
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 1).cast(pl.Int32, strict=False).fill_null(0) * 1440 +  # Days
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 2).cast(pl.Int32, strict=False).fill_null(0) * 60 +   # Hours
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 3).cast(pl.Int32, strict=False).fill_null(0) +        # Minutes
                (pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 4).cast(pl.Int32, strict=False).fill_null(0) / 60).round(0).cast(pl.Int32, strict=False)  # Seconds
            )
            .otherwise(
                # Format: HH:MM:SS (e.g., "07:25:00", "17:55:00")
                pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 1).cast(pl.Int32, strict=False).fill_null(0) * 60 +   # Hours
                pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 2).cast(pl.Int32, strict=False).fill_null(0) +        # Minutes
                (pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 3).cast(pl.Int32, strict=False).fill_null(0) / 60).round(0).cast(pl.Int32, strict=False)  # Seconds
            )
        )
        .otherwise(0)
    )

### Engineer Customer Features

In [6]:
# Customer attributes for clustering analysis

def create_customer_aggregation_features() -> List[pl.Expr]:
    """Create customer aggregation expressions for basic attributes and search behavior."""
    return [
        # Basic customer attributes (take first non-null value per customer)
        pl.col('companyID').drop_nulls().first().alias('company_id'),
        pl.col('sex').drop_nulls().first().cast(pl.Int8).alias('sex'),
        pl.col('nationality').drop_nulls().first().alias('nationality'),
        pl.col('frequentFlyer').drop_nulls().first().str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('frequent_flyer'),
        pl.col('isVip').drop_nulls().first().cast(pl.Int8).alias('is_vip'),
        pl.col('bySelf').drop_nulls().first().cast(pl.Int8).alias('by_self'),
        pl.col('corporateTariffCode').is_not_null().cast(pl.Int8).max().alias('has_corp_codes'),

        # Search behavior metrics
        pl.len().alias('total_searches'),
        pl.col('legs1_departureAt').is_not_null().mean().alias('roundtrip_preference'),
        pl.col('searchRoute').drop_nulls().n_unique().alias('unique_routes_searched'),
    ]


def create_booking_lead_time_features() -> List[pl.Expr]:
    """Create booking lead time statistics."""
    # Calculate booking lead time in days
    booking_lead_expr = (
        (pl.col('legs0_departureAt').str.to_datetime() -
         pl.col('requestDate').cast(pl.Datetime)) / pl.duration(days=1)
    ).cast(pl.Int32)

    return [
        booking_lead_expr.min().alias('min_booking_lead_days'),
        booking_lead_expr.max().alias('max_booking_lead_days'),
        booking_lead_expr.mean().alias('avg_booking_lead_days'),
        booking_lead_expr.median().alias('median_booking_lead_days'),
    ]


def create_travel_preference_features() -> List[pl.Expr]:
    """Create travel preference features for most common airports and carriers."""
    return [
        # Most common departure airport
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().mode().first().alias('most_common_departure_airport'),
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().n_unique().alias('unique_departure_airports'),

        # Most common marketing carrier
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().mode().first().alias('most_common_carrier'),
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().n_unique().alias('unique_carriers_used'),
    ]


def create_cabin_class_features(cabin_class_cols: List[str]) -> List[pl.Expr]:
    """Create cabin class preference statistics."""
    if not cabin_class_cols:
        # Return default values if no cabin class columns found
        return [
            pl.lit(None).alias('min_cabin_class'),
            pl.lit(None).alias('max_cabin_class'),
            pl.lit(None).alias('avg_cabin_class'),
        ]

    return [
        # Cabin class statistics across all segments
        pl.min_horizontal([pl.col(col) for col in cabin_class_cols]).min().alias('min_cabin_class'),
        pl.max_horizontal([pl.col(col) for col in cabin_class_cols]).max().alias('max_cabin_class'),
        pl.mean_horizontal([pl.col(col) for col in cabin_class_cols]).mean().alias('avg_cabin_class'),
    ]


def create_temporal_preference_features() -> List[pl.Expr]:
    """Create temporal preference features for departure patterns."""
    return [
        # Weekday preference (most common day of week for departures)
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday()
          .mode().first().alias('weekday_preference'),

        # Weekend travel rate (percentage of weekend departures - 5=Sat, 6=Sun)
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday()
          .map_elements(lambda x: 1 if x >= 5 else 0, return_dtype=pl.Int8)
          .mean().alias('weekend_travel_rate'),

        # Time of day variance (how consistent are their departure times)
        pl.col('legs0_departureAt').str.to_datetime().dt.hour()
          .std().alias('time_of_day_variance'),

        # Night flight preference (flights departing 22:00-06:00)
        pl.col('legs0_departureAt').str.to_datetime().dt.hour()
          .map_elements(lambda x: 1 if (x >= 22 or x < 6) else 0, return_dtype=pl.Int8)
          .mean().alias('night_flight_preference')
    ]


def create_route_specific_features() -> List[pl.Expr]:
    """Create features related to route preferences and characteristics."""
    return [
        # Route loyalty
        (1 - (pl.col('searchRoute').n_unique() / pl.len().clip(1, None))).alias('route_loyalty'),

        # Hub preference (preference for major hub airports)
        (
            pl.col('legs0_segments0_departureFrom_airport_iata').is_in(MAJOR_HUBS) |
            pl.col('legs0_segments0_arrivalTo_airport_iata').is_in(MAJOR_HUBS)
        ).cast(pl.Int8).mean().alias('hub_preference'),

        # Short haul preference
        ((1 - (pl.col('leg0_duration_minutes').mean() / 180)) * 0.7 +
         pl.when(pl.col('leg0_duration_minutes') <= 180).then(1).otherwise(0).mean() * 0.3)
        .clip(0, 1).alias('short_haul_preference'),

        # Connection tolerance (average segments)
        pl.sum_horizontal([
            pl.col(f'legs0_segments{i}_departureFrom_airport_iata').is_not_null().cast(pl.Int8)
            for i in range(4)  # assuming max 4 segments
        ]).mean().alias('connection_tolerance'),

        # Preference for longer vs shorter flights (duration quartile preference)
        pl.when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q25'))
        .then(1)  # Short flights
        .when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q50'))
        .then(2)  # Medium-short flights
        .when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q75'))
        .then(3)  # Medium-long flights
        .otherwise(4)  # Long flights
        .mode().first().alias('preferred_duration_quartile')
    ]


def create_price_sensitivity_features() -> List[pl.Expr]:
    """Create features related to price sensitivity and patterns."""
    return [
        # Basic correlation between price and duration
        pl.corr(
            pl.col('totalPrice'),
            pl.col('trip_duration_minutes')
        ).fill_null(0).alias('price_to_duration_sensitivity'),

        # Price per minute metric (average across all flights)
        (pl.col('totalPrice') / pl.col('trip_duration_minutes').clip(1, None)).mean().alias('avg_price_per_minute'),

        # Consistency of price-per-minute (lower std = more consistent valuation)
        (pl.col('totalPrice') / pl.col('trip_duration_minutes').clip(1, None)).std().fill_null(0).alias('price_per_minute_variance'),

        # Price position within searches (percentile rank)
        pl.col('price_percentile').mean().alias('price_position_preference'),

        # Premium economy preference (assuming cabin class 2 is premium economy)
        pl.mean_horizontal([
            pl.col(f'legs0_segments{i}_cabinClass') == 2
            for i in range(4)
        ]).mean().alias('premium_economy_preference'),

        # Consistent price tier
        (1 - (
            pl.when(pl.col('totalPrice') <= pl.col('price_q25'))
            .then(1)  # Budget tier
            .when(pl.col('totalPrice') <= pl.col('price_q50'))
            .then(2)  # Economy tier
            .when(pl.col('totalPrice') <= pl.col('price_q75'))
            .then(3)  # Premium tier
            .otherwise(4)  # Luxury tier
            .std()
            .fill_null(0)  # Handle null case
            / 3
        ).clip(0, 1)).alias('consistent_price_tier'),

        # Most common price tier
        pl.when(pl.col('totalPrice') <= pl.col('price_q25'))
        .then(1)  # Budget tier
        .when(pl.col('totalPrice') <= pl.col('price_q50'))
        .then(2)  # Economy tier
        .when(pl.col('totalPrice') <= pl.col('price_q75'))
        .then(3)  # Premium tier
        .otherwise(4)  # Luxury tier
        .mode().first().alias('preferred_price_tier'),
    ]


def create_service_preference_features() -> List[pl.Expr]:
    """Create features related to service preferences."""
    # First get all the relevant column names for type and quantity
    type_cols = [
        f'legs{leg}_segments{seg}_baggageAllowance_weightMeasurementType'
        for leg in range(2) for seg in range(4)
    ]
    qty_cols = [
        f'legs{leg}_segments{seg}_baggageAllowance_quantity'
        for leg in range(2) for seg in range(4)
    ]

    return [
        # Baggage quantity preference (average of minimum bags allowed per flight option)
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 0)
            .then(pl.col(qty_col))
            .otherwise(pl.lit(None))
            for type_col, qty_col in zip(type_cols, qty_cols)
        ]).mean().fill_null(0).alias('baggage_qty_preference'),

        # Baggage weight preference (average of minimum weight allowed per flight option)
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 1)
            .then(pl.col(qty_col))
            .otherwise(pl.lit(None))
            for type_col, qty_col in zip(type_cols, qty_cols)
        ]).mean().fill_null(0).fill_nan(0).alias('baggage_weight_preference'),

        # Loyalty program utilization
        pl.col('frequentFlyer').is_not_null().mean().alias('loyalty_program_utilization')
    ]


def create_derived_metrics() -> List[pl.Expr]:
    """Create complex derived metrics from combinations of features."""
    return [
        # Convenience priority score (higher = more emphasis on convenient times)
        ((1 - pl.col('time_of_day_variance')) * 10 +
         pl.col('price_to_duration_sensitivity') * 5)
        .alias('convenience_priority_score'),

        # Loyalty vs price index (higher = more loyal, less price sensitive)
        (pl.col('loyalty_program_utilization') * 10 -
         pl.col('price_position_preference') / 10)
        .alias('loyalty_vs_price_index'),

        # Planning consistency score (inverse of lead time variance)
        (1 / (pl.col('max_booking_lead_days') - pl.col('min_booking_lead_days') + 1))
        .alias('planning_consistency_score'),

        # Luxury index (combination of cabin class and price tier)
        (pl.col('avg_cabin_class') * 20 +
         pl.col('price_position_preference') / 2)
        .alias('luxury_index'),

        # Create advanced features
        (pl.col('total_searches') / pl.col('unique_routes_searched').clip(1)).alias('search_intensity_per_route'),
        (pl.col('max_booking_lead_days') - pl.col('min_booking_lead_days')).alias('lead_time_variance'),
        (pl.col('avg_booking_lead_days') / pl.col('median_booking_lead_days').clip(1)).alias('lead_time_skew'),
        (pl.col('unique_carriers_used') / pl.col('total_searches').clip(1)).alias('carrier_diversity'),
        (pl.col('unique_departure_airports') / pl.col('total_searches').clip(1)).alias('airport_diversity'),
        (pl.col('max_cabin_class') - pl.col('min_cabin_class')).alias('cabin_class_range'),
        (pl.col('is_vip').cast(pl.Int8) * 2 + pl.col('has_corp_codes').is_not_null().cast(pl.Int8)).alias('customer_tier'),
    ]


def add_trip_duration(df: pl.DataFrame) -> pl.DataFrame:
    """Add trip duration to the DataFrame."""
    leg0_duration_minutes = parse_duration_to_minutes('legs0_duration')
    leg1_duration_minutes = parse_duration_to_minutes('legs1_duration')
    trip_duration_minutes = leg0_duration_minutes + leg1_duration_minutes

    return df.with_columns([
        leg0_duration_minutes.alias('leg0_duration_minutes'),
        leg1_duration_minutes.alias('leg1_duration_minutes'),
        trip_duration_minutes.alias('trip_duration_minutes'),
    ])


def create_windows_based_features(df) -> pl.DataFrame:
    # Add window-based features like price_percentile if they don't exist
    return (df.with_columns([
        # calculate price percentile over search session
        ((pl.col('totalPrice').rank(method='min').over('ranker_id') - 1) /
        (pl.col('totalPrice').count().over('ranker_id') - 1) * 100)
        .fill_null(50.0).alias('price_percentile'),

        # calculate price quartiles over profileId
        pl.col('totalPrice').quantile(0.25).over('profileId').alias('price_q25'),
        pl.col('totalPrice').quantile(0.50).over('profileId').alias('price_q50'),
        pl.col('totalPrice').quantile(0.75).over('profileId').alias('price_q75'),

        # calculate leg0_duration quartiles over profileId
        pl.col('leg0_duration_minutes').quantile(0.25).over('profileId').alias('leg0_duration_q25'),
        pl.col('leg0_duration_minutes').quantile(0.50).over('profileId').alias('leg0_duration_q50'),
        pl.col('leg0_duration_minutes').quantile(0.75).over('profileId').alias('leg0_duration_q75'),
    ]))


def create_interaction_features() -> List[pl.Expr]:
    # Create customer/business interaction features
    return [
        # Create VIP interactions
        (pl.col('search_intensity_per_route') * pl.col('is_vip')).alias('vip_search_intensity'),
        (pl.col('carrier_diversity') * pl.col('is_vip')).alias('vip_carrier_diversity'),
        (pl.col('avg_cabin_class') * pl.col('is_vip')).alias('vip_cabin_preference'),

        # Create corporate interactions
        (pl.col('total_searches') * pl.col('has_corp_codes')).alias('corp_search_volume'),
        (pl.col('roundtrip_preference') * pl.col('has_corp_codes')).alias('corp_roundtrip_pref'),
        (pl.col('lead_time_variance') * pl.col('has_corp_codes')).alias('corp_planning_variance'),
    ]

def extract_customer_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Extract customer features for clustering analysis.
    Aggregates by profileId to create customer-level features.
    """
    # Check if already processed
    if df.height > 0 and 'total_searches' in df.columns:
        return df

    # Get cabin class columns
    cabin_class_cols = [col for col in df.columns if col.startswith('legs') and col.endswith('_cabinClass')]

    # Create lazy frame and group by profileId
    lazy_df = create_windows_based_features(add_trip_duration(df)).lazy().group_by('profileId')

    # Apply customer feature groups
    lazy_df = lazy_df.agg([
        *create_customer_aggregation_features(),
        *create_booking_lead_time_features(),
        *create_travel_preference_features(),
        *create_cabin_class_features(cabin_class_cols),
        *create_temporal_preference_features(),
        *create_route_specific_features(),
        *create_price_sensitivity_features(),
        *create_service_preference_features()
    ])

    # Materialize to generate the basic features
    # base_features = customer_features.collect()

    # Add the derived metrics that depend on the generated features
    lazy_df = lazy_df.with_columns(create_derived_metrics())

    # Add interactive features
    lazy_df = lazy_df.with_columns(create_interaction_features())

    # print(f"Generated {len(enhanced_features.columns)} customer features for {len(enhanced_features)} customers")
    return lazy_df.collect().fill_null(0).fill_nan(0)

In [7]:
# 1. Feature Engineering
cust_data = extract_customer_features(data)
print(f'Generated {len(cust_data.columns)} customer features for {len(cust_data)} customers')

Generated 58 customer features for 32922 customers


In [8]:
cust_data.head(100)

profileId,company_id,sex,nationality,frequent_flyer,is_vip,by_self,has_corp_codes,total_searches,roundtrip_preference,unique_routes_searched,min_booking_lead_days,max_booking_lead_days,avg_booking_lead_days,median_booking_lead_days,most_common_departure_airport,unique_departure_airports,most_common_carrier,unique_carriers_used,min_cabin_class,max_cabin_class,avg_cabin_class,weekday_preference,weekend_travel_rate,time_of_day_variance,night_flight_preference,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance
i64,i64,i8,i64,str,i8,i8,i8,u32,f64,u32,i32,i32,f64,f64,str,u32,str,u32,f64,f64,f64,i8,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32
2187737,61081,1,36,"""""",0,1,1,945,1.0,1,40,41,40.732275,41.0,"""SVX""",1,"""SU""",6,1.0,2.0,1.370899,3,0.0,7.066779,0.206349,0.998942,0.0,0.244042,1.058201,1,-0.006646,143.49524,109.121191,48.000852,0.370899,0.629458,2,1.17037,0.0,0.0,-60.70102,-4.800085,0.5,51.418415,945.0,1,0.99347,0.006349,0.001058,1.0,1,0.0,0.0,0.0,945,1.0,1
2103048,57074,1,36,"""""",0,1,0,522,0.0,3,0,6,0.597701,1.0,"""LED""",4,"""SU""",9,1.0,4.0,1.450192,3,0.0,5.031274,0.139847,0.994253,0.007663,0.635792,1.001916,1,0.000759,291.066118,230.135664,49.171021,0.415709,0.626398,1,1.130268,0.0,0.0,-40.308946,-4.917102,0.142857,53.589342,174.0,6,0.597701,0.017241,0.007663,3.0,1,0.0,0.0,0.0,0,0.0,0
2395632,59096,1,36,"""SU""",0,1,0,2948,0.523745,17,3,37,15.892809,21.0,"""LED""",10,"""SU""",17,1.0,4.0,1.358265,1,0.474898,5.633266,0.157056,0.994233,0.009498,0.087378,1.21133,1,0.432401,197.470881,184.515556,48.434932,0.354308,0.623995,1,0.976992,29.375,1.0,-44.170655,5.156507,0.028571,51.382776,173.411765,34,0.7568,0.005767,0.003392,3.0,1,0.0,0.0,0.0,0,0.0,0
3536933,60628,1,36,"""""",0,1,0,10,0.0,1,4,4,4.0,4.0,"""UFA""",1,"""S7""",3,1.0,1.0,1.0,2,0.0,3.529243,0.1,0.9,0.0,0.0,2.0,2,0.373171,81.742386,26.024476,47.777778,0.0,0.613499,2,1.0,0.0,0.0,-23.426576,-4.777778,1.0,43.888889,10.0,0,1.0,0.3,0.1,0.0,1,0.0,0.0,0.0,0,0.0,0
634576,42620,0,36,"""""",0,1,1,13,1.0,1,57,58,57.461538,57.0,"""KUF""",1,"""SU""",4,1.0,2.0,1.307692,3,0.0,3.782551,0.0,0.923077,0.0,0.0,1.0,2,0.299916,114.562559,83.135223,50.0,0.307692,0.602357,1,1.384615,0.0,0.0,-26.325933,-5.0,0.5,51.153846,13.0,1,1.008097,0.307692,0.076923,1.0,1,0.0,0.0,0.0,13,1.0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2359702,51409,1,36,"""""",0,1,0,1324,0.419184,3,0,49,27.710725,13.0,"""SVO""",5,"""SU""",8,1.0,2.0,1.007553,1,0.0,5.441444,0.148792,0.997734,0.0,0.358115,1.075529,1,0.518454,105.722217,82.016626,47.812388,0.007553,0.62641,1,0.77719,0.0,0.0,-41.822168,-4.781239,0.02,44.057252,441.333333,49,2.131594,0.006042,0.003776,1.0,1,0.0,0.0,0.0,0,0.0,0
3409127,57323,1,36,"""""",0,1,1,57,0.701754,2,29,54,46.54386,54.0,"""IJK""",2,"""SU""",2,1.0,2.0,1.473684,7,0.701754,4.853299,0.052632,0.964912,0.0,0.0,1.964912,1,0.456482,77.664534,51.221153,43.373257,0.473684,0.639476,3,1.54386,0.0,0.0,-36.250581,-4.337326,0.038462,51.160313,28.5,25,0.861923,0.035088,0.035088,1.0,1,0.0,0.0,0.0,57,0.701754,25
2408860,60844,1,36,"""""",0,1,0,897,0.337793,8,6,50,33.408027,40.0,"""SVO""",10,"""SU""",14,1.0,2.0,1.352843,2,0.276477,5.929663,0.113712,0.991081,0.0,0.063229,1.19398,1,0.027138,202.665938,186.50378,47.088763,0.352843,0.631352,3,1.141583,0.0,0.0,-49.160941,-4.708876,0.022222,50.601238,112.125,44,0.835201,0.015608,0.011148,1.0,1,0.0,0.0,0.0,0,0.0,0
2139225,25667,1,36,"""SU""",0,1,1,654,0.928135,9,2,49,29.984709,42.0,"""VVO""",3,"""S7""",8,1.0,4.0,1.179664,3,0.321101,4.212571,0.036697,0.986239,0.0,0.0,1.671254,1,0.471027,62.472372,49.458475,47.74407,0.085627,0.627496,3,0.900612,20.0,1.0,-29.770569,5.225593,0.020833,47.465307,72.666667,47,0.713922,0.012232,0.004587,3.0,1,0.0,0.0,0.0,654,0.928135,47


In [9]:
def remove_outliers(df, method='isolation_forest', contamination=0.05):
    """Remove outliers using various methods"""

    # Get numeric columns only
    numeric_cols = [col for col, dtype in df.schema.items() if dtype.is_numeric()]
    df_numeric = df.select(numeric_cols)

    if method == 'isolation_forest':
        iso_forest = IsolationForest(contamination=contamination, random_state=42)
        outlier_labels = iso_forest.fit_predict(df_numeric.to_pandas())
        mask = outlier_labels == 1

    elif method == 'zscore':
        # Remove rows where any feature has |z-score| > 3
        df_pd = df_numeric.to_pandas()
        z_scores = np.abs(zscore(df_pd, nan_policy='omit'))
        mask = (z_scores < 3).all(axis=1)

    elif method == 'iqr':
        # Remove rows outside 1.5*IQR for any feature
        df_pd = df_numeric.to_pandas()
        Q1 = df_pd.quantile(0.25)
        Q3 = df_pd.quantile(0.75)
        IQR = Q3 - Q1
        mask = ~((df_pd < (Q1 - 1.5 * IQR)) | (df_pd > (Q3 + 1.5 * IQR))).any(axis=1)

    # Get outlier indices
    outlier_indices = np.where(~mask)[0]

    cleaned_df = df.filter(pl.Series('mask', mask))
    print(f"Removed {len(df) - len(cleaned_df):,} outliers ({(len(df) - len(cleaned_df))/len(df)*100:.1f}%)")

    return cleaned_df, outlier_indices

def encode_features(df):
    # Get categorical columns
    categorical_cols = [col for col, dtype in df.schema.items() if not dtype.is_numeric() and col not in ['id', 'ranker_id', 'request_date']]

    encoded_df = df.clone()
    encoders = {}

    for col in categorical_cols:
        if col == 'frequent_flyer':
            # Special handling for frequent flyer programs (list feature)
            # Create binary features for most common programs
            all_programs = []
            for programs in df[col].to_list():
                if isinstance(programs, list):
                    all_programs.extend(programs)

            from collections import Counter
            top_programs = Counter(all_programs).most_common(10)

            for program, _ in top_programs:
                encoded_df = encoded_df.with_columns([
                    pl.col(col).map_elements(
                        lambda x: 1 if isinstance(x, list) and program in x else 0,
                        return_dtype=pl.Int8
                    ).alias(f'ff_{program}')
                ])

            # Add count of total programs
            encoded_df = encoded_df.with_columns([
                pl.col(col).map_elements(
                    lambda x: len(x) if isinstance(x, list) else 0,
                    return_dtype=pl.Int8
                ).alias('ff_program_count')
            ])

        else:
            # Target encoding for high-cardinality categoricals like airports/carriers
            if col in ['most_common_departure_airport', 'most_common_carrier']:
                # Use frequency encoding
                value_counts = df[col].value_counts()
                freq_map = {row[0]: row[1] for row in value_counts.rows()}

                encoded_df = encoded_df.with_columns([
                    pl.col(col).replace_strict(freq_map, default=0).alias(f'{col}_frequency')
                ])
            else:
                # Standard label encoding for low-cardinality features
                unique_values = df[col].fill_null('MISSING').unique().sort().to_list()
                encoders[col] = {val: idx for idx, val in enumerate(unique_values)}

                encoded_df = encoded_df.with_columns([
                    pl.col(col).fill_null('MISSING').replace(encoders[col]).alias(f'{col}_encoded')

                ])

    # Remove original categorical columns
    final_df = encoded_df.select(pl.exclude(categorical_cols + ['frequent_flyer', 'most_common_departure_airport', 'most_common_carrier']))

    return final_df.fill_null(0), encoders

def dimensionality_reduction(scaled_features, method='pca', n_components=50):
    """Apply dimensionality reduction before clustering"""

    if method == 'pca':
        reducer = PCA(n_components=n_components, random_state=42)

    elif method == 'truncated_svd':
        reducer = TruncatedSVD(n_components=n_components, random_state=42)

    elif method == 'umap':
        reducer = umap.UMAP(n_components=n_components, random_state=42, n_neighbors=15)

    reduced_features = reducer.fit_transform(scaled_features)

    if hasattr(reducer, 'explained_variance_ratio_'):
        total_variance = reducer.explained_variance_ratio_.sum()
        print(f"{method.upper()} retained {total_variance:.3f} of total variance with {n_components} components")

    return reduced_features, reducer


def generate_clusters(features, n_clusters=10):
    agg = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
    agg_segments = agg.fit_predict(features)

    # Create pseudo-centroids for use in prediction, since AgglomerativeClustering doesn't have centroids
    unique_segments = np.unique(agg_segments)
    centroids = []
    for segment_id in unique_segments:
        segment_mask = agg_segments == segment_id
        centroid = features[segment_mask].mean(axis=0)
        centroids.append(centroid)

    centroids = np.array(centroids)

    score = silhouette_score(features, agg_segments)

    return {
        "labels": agg_segments,
        "centroids": centroids,
        "model": agg,
        "silhouette": score,
        "n_clusters": n_clusters
    }

In [10]:
def analyze_clusters(original_df, cluster_labels, outlier_indices=None):
    """Analyze the final clusters"""
    # Add cluster labels to original dataframe for analysis
    if original_df is not None:

        # Check if sizes match
        if len(cluster_labels) != (len(original_df) - (len(outlier_indices) if outlier_indices is not None else 0)):
            print(f"Warning: Cluster labels size ({len(cluster_labels)}) doesn't match DataFrame size ({len(original_df)})")
            print("This likely happened because outliers were removed during clustering and outlier indices not provided or incorrect.")

        else:
            if outlier_indices is not None and len(outlier_indices) > 0:
                mask = ~pl.Series(range(len(original_df))).is_in(outlier_indices)
                analysis_df = original_df.filter(mask)
            else:
                analysis_df = original_df.head(len(cluster_labels))

            # Now add cluster labels safely
            analysis_df = analysis_df.with_columns(pl.Series('cluster', cluster_labels))

            # Cluster profiling
            print("\n📊 CLUSTER PROFILES:")
            print("=" * 50)

            cluster_profiles = analysis_df.group_by('cluster').agg([
                pl.col('total_searches').mean().alias('avg_searches'),
                pl.col('is_vip').mean().alias('vip_rate'),
                pl.col('roundtrip_preference').mean().alias('roundtrip_rate'),
                pl.col('avg_booking_lead_days').mean().alias('avg_lead_days'),
                pl.col('unique_carriers_used').mean().alias('avg_carriers'),
                pl.len().alias('size')
            ]).sort('cluster')

            print(cluster_profiles)

### Generate Customer Clusters

In [11]:
# 2. Remove outliers
cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='isolation_forest', contamination=0.05)
# cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='zscore')
# cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='iqr')

Removed 1,647 outliers (5.0%)


In [140]:
# 2b. Remove specific customer profiles that result in 1 or 2 member clusters
# Add profileId to removal list and get its index
# profiles_to_remove = [1361953]
# profile_mask = cust_data_cleaned['profileId'].is_in(profiles_to_remove)
# profile_indices = cust_data_cleaned.with_row_index().filter(profile_mask)['index'].to_list()

# Remove from cleaned data and add indices to outlier list
# cust_data_cleaned = cust_data_cleaned.filter(~profile_mask)
# outlier_indices = list(outlier_indices) + profile_indices

In [12]:
# 3. Advanced encoding
cust_data_encoded, encoders = encode_features(cust_data_cleaned)

In [62]:
# 3b. Do not remove outliers
# outlier_indices = []
# cust_data_encoded, encoders = encode_features(cust_data)

In [13]:
# 4. Feature scaling with robust scaler
scaler = RobustScaler()  # Less sensitive to outliers than StandardScaler
scaled_features = scaler.fit_transform(cust_data_encoded.to_pandas())

In [14]:
# 5. Dimensionality reduction
n_components = 0.95  # retain components that explain 95% of variance
# n_components = 30    # retain 30 components

reducer = PCA(n_components=n_components, random_state=42)
reduced_features = reducer.fit_transform(scaled_features)

if hasattr(reducer, 'explained_variance_ratio_'):
    total_variance = reducer.explained_variance_ratio_.sum()
    print(f"PCA retained {total_variance:.3f} of total variance with {reduced_features.shape[1]} components")

PCA retained 0.958 of total variance with 3 components


In [18]:
# 6. Apply clustering algorithm
# clustering_results = generate_clusters(reduced_features)
clustering_results = generate_clusters(reduced_features, n_clusters=14)

# 7. Clustering results
print("\n🏆 CLUSTERING RESULTS:")
print(f"  Clusters: {clustering_results['n_clusters']:2d}| Silhouette: {clustering_results['silhouette']:.4f} | ")


🏆 CLUSTERING RESULTS:
  Clusters: 14| Silhouette: 0.5808 | 


In [16]:
analyze_clusters(cust_data_cleaned, clustering_results['labels'])


📊 CLUSTER PROFILES:
shape: (14, 7)
┌─────────┬──────────────┬──────────┬────────────────┬───────────────┬──────────────┬───────┐
│ cluster ┆ avg_searches ┆ vip_rate ┆ roundtrip_rate ┆ avg_lead_days ┆ avg_carriers ┆ size  │
│ ---     ┆ ---          ┆ ---      ┆ ---            ┆ ---           ┆ ---          ┆ ---   │
│ i64     ┆ f64          ┆ f64      ┆ f64            ┆ f64           ┆ f64          ┆ u32   │
╞═════════╪══════════════╪══════════╪════════════════╪═══════════════╪══════════════╪═══════╡
│ 0       ┆ 580.375      ┆ 0.0      ┆ 0.520211       ┆ 11.457576     ┆ 6.240625     ┆ 320   │
│ 1       ┆ 1391.428571  ┆ 1.0      ┆ 1.0            ┆ 18.328623     ┆ 5.714286     ┆ 7     │
│ 2       ┆ 1001.214286  ┆ 1.0      ┆ 0.9144         ┆ 10.075906     ┆ 6.142857     ┆ 14    │
│ 3       ┆ 477.23813    ┆ 0.0      ┆ 0.465066       ┆ 15.622722     ┆ 8.107378     ┆ 2738  │
│ 4       ┆ 340.87718    ┆ 0.003067 ┆ 0.570304       ┆ 14.834675     ┆ 4.741422     ┆ 20868 │
│ …       ┆ …           

In [101]:
cust_data_cleaned.with_columns(pl.Series('customer_segment', clustering_results['labels'])).filter(pl.col('customer_segment') == 7).head(100)

profileId,company_id,sex,nationality,frequent_flyer,is_vip,by_self,has_corp_codes,total_searches,roundtrip_preference,unique_routes_searched,min_booking_lead_days,max_booking_lead_days,avg_booking_lead_days,median_booking_lead_days,most_common_departure_airport,unique_departure_airports,most_common_carrier,unique_carriers_used,min_cabin_class,max_cabin_class,avg_cabin_class,weekday_preference,weekend_travel_rate,time_of_day_variance,night_flight_preference,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment
i64,i64,i8,i64,str,i8,i8,i8,u32,f64,u32,i32,i32,f64,f64,str,u32,str,u32,f64,f64,f64,i8,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64
1361953,43628,0,36,"""""",1,1,0,7502,1.0,1,11,12,11.844308,12.0,"""SVO""",3,"""SU""",4,1.0,1.0,1.0,3,0.0,5.358169,0.143962,0.999867,0.0,0.650583,1.0,1,0.011864,74.821299,29.334728,48.613918,0.0,0.623578,1,0.730738,0.0,0.0,-43.522373,-4.861392,0.5,44.306959,7502.0,1,0.987026,0.000533,0.0004,0.0,3,7502.0,0.000533,1.0,0,0.0,0,7


In [117]:
def alternative_clustering_methods(features):
    """Try different clustering algorithms"""

    results = {}

    # 1. Gaussian Mixture Models
    print("Testing Gaussian Mixture Models...")
    best_gmm_score = -1
    best_gmm_n = 0

    for n_clusters in [6, 7, 8, 9, 10, 11]:
        gmm = GaussianMixture(
            n_components=n_clusters,
            max_iter=200,
            random_state=42,
            covariance_type='full')
        labels = gmm.fit_predict(features)

        if len(set(labels)) > 1:  # Ensure we have multiple clusters
            score = silhouette_score(features, labels)
            if score > best_gmm_score:
                best_gmm_score = score
                best_gmm_n = n_clusters

    # Fit best GMM
    best_gmm = GaussianMixture(n_components=best_gmm_n, random_state=42, covariance_type='full')
    gmm_labels = best_gmm.fit_predict(features)
    results['gmm'] = {
        'labels': gmm_labels,
        'silhouette': silhouette_score(features, gmm_labels),
        'n_clusters': len(set(gmm_labels))
    }

    # 2. DBSCAN
    print("Testing DBSCAN...")
    # Test different eps values
    best_dbscan_score = -1
    best_dbscan_eps = 0

    # Determine eps values based on feature data
    neighbors = NearestNeighbors(n_neighbors=50)
    distances, _ = neighbors.fit(features).kneighbors(features)
    k_distances = np.sort(distances[:, -1])

    # Use percentiles of actual distances in your data
    eps_values = [
        np.percentile(k_distances, 50),   # Median distance
        np.percentile(k_distances, 70),   # 70th percentile
        np.percentile(k_distances, 85),   # 85th percentile
        np.percentile(k_distances, 95),   # 95th percentile
    ]
    print(f"Suggested eps values based on feature data: {[f'{eps:.3f}' for eps in eps_values]}")

    for eps in eps_values:
        dbscan = DBSCAN(eps=eps, min_samples=50)
        labels = dbscan.fit_predict(features)

        if len(set(labels)) > 1 and -1 not in labels:  # Ensure valid clustering
            score = silhouette_score(features, labels)
            if score > best_dbscan_score:
                best_dbscan_score = score
                best_dbscan_eps = eps
            else:
                print(f'DBSCAN results score {score:.4f} with eps {eps:.2f} was not better than {best_dbscan_score:.4f}')
        else:
            print(f"DBSCAN did not produce any clusters for eps={eps}")

    if best_dbscan_eps > 0:
        best_dbscan = DBSCAN(eps=best_dbscan_eps, min_samples=50)
        dbscan_labels = best_dbscan.fit_predict(features)
        results['dbscan'] = {
            'labels': dbscan_labels,
            'silhouette': silhouette_score(features, dbscan_labels),
            'n_clusters': len(set(dbscan_labels)) - (1 if -1 in dbscan_labels else 0)
        }
    else:
        print('DBSCAN did not produce any clusters...')

    # 3. Agglomerative Clustering
    print("Testing Agglomerative Clustering...")
    best_agg_score = -1
    best_agg_n = 0

    for n_clusters in [6, 7, 8, 9, 10, 11]:
        agg = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
        labels = agg.fit_predict(features)

        score = silhouette_score(features, labels)
        if score > best_agg_score:
            best_agg_score = score
            best_agg_n = n_clusters

    best_agg = AgglomerativeClustering(n_clusters=best_agg_n, linkage='ward')
    agg_labels = best_agg.fit_predict(features)
    results['agglomerative'] = {
        'labels': agg_labels,
        'silhouette': silhouette_score(features, agg_labels),
        'n_clusters': len(set(agg_labels))
    }

    try:
        for method, m_results in results.items():
            print(f"{method.upper():15s} | Silhouette: {m_results['silhouette']:.4f} | Clusters: {m_results['n_clusters']:2d}")

        # Select best method
        best_method = max(results.items(), key=lambda x: x[1]['silhouette'])
        print(f"\n🥇 Best method: {best_method[0].upper()} (Silhouette: {best_method[1]['silhouette']:.4f})")
    except Exception as e:
        print(f"Error: {e}")

    return results


In [15]:
def optimum_num_clusters(features):
    best_agg_score = -1
    best_agg_n = 0

    for n_clusters in [9, 10, 11, 12, 13, 14, 15, 16, 17]:
        agg = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
        labels = agg.fit_predict(features)

        score = silhouette_score(features, labels)
        if score > best_agg_score:
            best_agg_score = score
            best_agg_n = n_clusters

    best_agg = AgglomerativeClustering(n_clusters=best_agg_n, linkage='ward')
    agg_labels = best_agg.fit_predict(features)

    print(f"Best number of clusters: {best_agg_n}, with Silhouette score: {best_agg_score:.4f}")

    return {
        'labels': agg_labels,
        'silhouette': silhouette_score(features, agg_labels),
        'n_clusters': len(set(agg_labels))
    }

In [16]:
opt_results = optimum_num_clusters(reduced_features)

Best number of clusters: 14, with Silhouette score: 0.9000


In [17]:
analyze_clusters(cust_data_encoded, opt_results['labels'])


📊 CLUSTER PROFILES:
shape: (14, 7)
┌─────────┬──────────────┬──────────┬────────────────┬───────────────┬──────────────┬───────┐
│ cluster ┆ avg_searches ┆ vip_rate ┆ roundtrip_rate ┆ avg_lead_days ┆ avg_carriers ┆ size  │
│ ---     ┆ ---          ┆ ---      ┆ ---            ┆ ---           ┆ ---          ┆ ---   │
│ i64     ┆ f64          ┆ f64      ┆ f64            ┆ f64           ┆ f64          ┆ u32   │
╞═════════╪══════════════╪══════════╪════════════════╪═══════════════╪══════════════╪═══════╡
│ 0       ┆ 507.730658   ┆ 0.000362 ┆ 0.466978       ┆ 15.340529     ┆ 8.086768     ┆ 2766  │
│ 1       ┆ 127.217822   ┆ 1.0      ┆ 0.436559       ┆ 13.501157     ┆ 4.475248     ┆ 101   │
│ 2       ┆ 835.2        ┆ 1.0      ┆ 0.840812       ┆ 12.870281     ┆ 7.15         ┆ 20    │
│ 3       ┆ 1517.888889  ┆ 1.0      ┆ 0.988383       ┆ 13.803028     ┆ 7.777778     ┆ 9     │
│ 4       ┆ 875.488869   ┆ 0.0      ┆ 0.607001       ┆ 17.168671     ┆ 10.642921    ┆ 1123  │
│ …       ┆ …           

In [19]:
cust_data_cleaned = cust_data_cleaned.with_columns(pl.Series('customer_segment', opt_results['labels']))

In [20]:
cust_data_cleaned.group_by('customer_segment').len()

customer_segment,len
i64,u32
3,9
2,20
6,5005
7,1
5,19
…,…
10,25
1,101
0,2766
8,41


In [21]:
# Get profileId for lone member of cluster 7
cust_data_cleaned.filter(pl.col('customer_segment') == 7)['profileId']

profileId
i64
1361953


In [22]:
# Check how many searches exist for the customer
data.filter(pl.col('profileId') == 1361953)['ranker_id'].unique()

ranker_id
str
"""d81120a52d824f00aa68939334a7d9…"


In [23]:
# Get customer data for lone member of cluster 7
dup_cust_data = data.filter(pl.col('profileId') == 1361953)
dup_cust = cust_data_cleaned.filter(pl.col('profileId') == 1361953)

# Generate new profileId and ranker_id for adding duplicate data
new_profile_id = (max(data['profileId'].unique()) + 1)
new_ranker_id = str(uuid.uuid4()).replace('-', '')
assert new_profile_id not in data['profileId'].unique().to_list(), f"ProfileId {new_profile_id} already exists"
assert new_ranker_id not in data['ranker_id'].unique().to_list(), f"RankerId {new_ranker_id} already exists"

# Create data sets for duplicated customer data rows and features
new_cust_data = dup_cust_data.with_columns([
    pl.lit(new_profile_id).cast(pl.Int64).alias('profileId'),
    pl.lit(new_ranker_id).alias('ranker_id')
])
new_cust = dup_cust.with_columns([
    pl.lit(new_profile_id).cast(pl.Int64).alias('profileId')
])

# Append new customer data to original data
data = data.vstack(new_cust_data)
cust_data_cleaned = cust_data_cleaned.vstack(new_cust)


In [24]:
# Save customer features
cust_data_cleaned.write_parquet('data/customer_features.parquet')

In [147]:
del cust_data
del cust_data_cleaned
del cust_data_encoded
del scaled_features
del reduced_features
del scaler
gc.collect()

1752

### Engineer Flight Features

In [148]:
CUSTOMER_FEATURES = ['companyID', 'sex', 'nationality', 'isVip', 'bySelf']
UNNEEDED_FEATURES = [
    'frequentFlyer', 'frequent_flyer', 'isAccess3D', 'requestDate', 'searchRoute', 'totalPrice', 'taxes', 'legs0_arrivalAt', 'legs0_duration'
]
UNNEEDED_FEATURES_REGEX = r'^legs[01].*$|^leg0_duration_q.*$|^price_q.*$'
POLARS_INDEX_COL = ['__index_level_0__']


# function to get column groups for the flight segments features that will be used in engineering the new features
def get_column_groups(df: pl.DataFrame) -> Dict[str, List[str]]:
    columns = df.columns
    return {
        'leg0_duration': [col for col in columns if col.startswith('legs0_segments') and col.endswith('_duration')],
        'leg1_duration': [col for col in columns if col.startswith('legs1_segments') and col.endswith('_duration')],
        'mkt_carrier': [col for col in columns if col.endswith('marketingCarrier_code')],
        'op_carrier': [col for col in columns if col.endswith('operatingCarrier_code')],
        'aircraft': [col for col in columns if col.endswith('aircraft_code')],
        'airport': [col for col in columns if 'airport_iata' in col or 'airport_city_iata' in col],
        'seats_avail': [col for col in columns if col.endswith('seatsAvailable')],
        'baggage_type': [col for col in columns if 'baggageAllowance_weightMeasurementType' in col],
        'baggage_allowance': [col for col in columns if 'baggageAllowance_quantity' in col],
        'cabin_class': [col for col in columns if 'cabinClass' in col]
    }


def create_basic_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create basic customer and route characteristics features."""
    return [
        pl.col('isAccess3D').cast(pl.Int32).alias('is_access3D'),

        # Normalized frequentFlyer program, addressing null values as null strings, and translating UT program
        pl.col('frequentFlyer').str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('frequent_flyer'),

        # Route characteristics
        pl.col('legs1_departureAt').is_not_null().cast(pl.Int8).alias('is_roundtrip'),
        pl.col('searchRoute').str.slice(0, 3).alias('route_origin'),

        # Hub features
        pl.col('searchRoute').str.slice(0, 3).is_in(MAJOR_HUBS).cast(pl.Int32).alias('origin_is_major_hub'),
        pl.col('searchRoute').str.slice(3, 3).is_in(MAJOR_HUBS).cast(pl.Int32).alias('destination_is_major_hub'),
        pl.max_horizontal([
            pl.col(col).is_in(MAJOR_HUBS) for col in col_groups['airport']
        ]).cast(pl.Int32).alias('includes_major_hub'),
    ]


def create_segment_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create flight segment related features."""
    return [
        # Number of segments per leg
        pl.concat_list([pl.col(col) for col in col_groups['leg0_duration']])
        .list.drop_nulls().list.len().alias('leg0_num_segments'),

        pl.concat_list([pl.col(col) for col in col_groups['leg1_duration']])
        .list.drop_nulls().list.len().alias('leg1_num_segments'),

        # Total segments
        pl.sum_horizontal([pl.col(col).is_not_null() for col in col_groups['aircraft']])
        .alias('total_segments'),

        # Flight time in minutes (sum of segment durations)
        pl.sum_horizontal([
            parse_duration_to_minutes(col)
            for col in col_groups['leg0_duration']
        ]).alias('leg0_flight_time_min'),

        pl.sum_horizontal([
            parse_duration_to_minutes(col)
            for col in col_groups['leg1_duration']
        ]).alias('leg1_flight_time_min'),
    ]


def create_time_features() -> List[pl.Expr]:
    """Create time-based features."""
    return [
        # Booking lead time
        ((pl.col('legs0_departureAt').str.to_datetime() -
            pl.col('requestDate').cast(pl.Datetime)) / pl.duration(days=1)).cast(pl.Int32).alias('booking_lead_days'),

        # Trip duration features
        parse_duration_to_minutes('legs0_duration').alias('leg0_duration_minutes'),
        parse_duration_to_minutes('legs1_duration').alias('leg1_duration_minutes'),
        (parse_duration_to_minutes('legs0_duration') + parse_duration_to_minutes('legs1_duration')).alias('trip_duration_minutes'),

        # Departure/arrival hours
        pl.col('legs0_departureAt').str.to_datetime().dt.hour().alias('leg0_departure_hour'),
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday().alias('leg0_departure_weekday'),
        pl.col('legs0_arrivalAt').str.to_datetime().dt.hour().alias('leg0_arrival_hour'),
        pl.col('legs0_arrivalAt').str.to_datetime().dt.weekday().alias('leg0_arrival_weekday'),
        pl.col('legs1_departureAt').str.to_datetime().dt.hour().alias('leg1_departure_hour'),
        pl.col('legs1_arrivalAt').str.to_datetime().dt.hour().alias('leg1_arrival_hour'),
        pl.col('legs1_departureAt').str.to_datetime().dt.weekday().alias('leg1_departure_weekday'),
        pl.col('legs1_arrivalAt').str.to_datetime().dt.weekday().alias('leg1_arrival_weekday'),
    ]


def create_carrier_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create carrier and aircraft features."""
    # Create combined carrier column group for carrier features
    all_carrier_cols = col_groups['mkt_carrier'] + col_groups['op_carrier']

    # Create frequent flyer matching expressions
    ff_matches_mkt = [
        pl.when(pl.col(col).is_not_null() & (pl.col(col) != ''))
        .then(pl.col('frequent_flyer').str.contains(pl.col(col)))
        .otherwise(False)
        for col in col_groups['mkt_carrier']
    ]
    ff_matches_op = [
        pl.when(pl.col(col).is_not_null() & (pl.col(col) != ''))
        .then(pl.col('frequent_flyer').str.contains(pl.col(col)))
        .otherwise(False)
        for col in col_groups['op_carrier']
    ]

    return [
        # Carrier features
        pl.concat_list([pl.col(col) for col in all_carrier_cols])
        .list.drop_nulls().list.unique().list.len().alias('unique_carriers'),

        pl.coalesce([pl.col(col) for col in all_carrier_cols]).alias('primary_carrier'),
        pl.col('legs0_segments0_marketingCarrier_code').alias('marketing_carrier'),

        # Frequent flyer matching (check if FF programs match carriers)
        pl.max_horizontal(ff_matches_mkt).cast(pl.Int32).alias('has_mkt_ff'),
        pl.max_horizontal(ff_matches_op).cast(pl.Int32).alias('has_op_ff'),

        # Aircraft features
        pl.concat_list([pl.col(col) for col in col_groups['aircraft']])
        .list.drop_nulls().list.unique().list.len().alias('aircraft_diversity'),

        pl.coalesce([pl.col(col) for col in col_groups['aircraft']]).alias('primary_aircraft'),
    ]


def create_service_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create service-related features (fees, cabin, seats, baggage)."""
    return [
        # Cancellation/exchange fees
        (
            ((pl.col('miniRules0_monetaryAmount') > 0) & pl.col('miniRules0_monetaryAmount').is_not_null()) |
            ((pl.col('miniRules0_percentage') > 0) & pl.col('miniRules0_percentage').is_not_null())
        ).cast(pl.Int32).alias('has_cancellation_fee'),

        (
            ((pl.col('miniRules1_monetaryAmount') > 0) & pl.col('miniRules1_monetaryAmount').is_not_null()) |
            ((pl.col('miniRules1_percentage') > 0) & pl.col('miniRules1_percentage').is_not_null())
        ).cast(pl.Int32).alias('has_exchange_fee'),

        # Cabin class features
        pl.max_horizontal([pl.col(col) for col in col_groups['cabin_class']]).alias('max_cabin_class'),
        pl.mean_horizontal([pl.col(col) for col in col_groups['cabin_class']]).alias('avg_cabin_class'),

        # Seat availability (using for understanding popularity/scarcity)
        pl.min_horizontal([pl.col(col) for col in col_groups['seats_avail']]).alias('min_seats_available'),
        pl.sum_horizontal([pl.col(col).fill_null(0) for col in col_groups['seats_avail']]).alias('total_seats_available'),

        # Baggage allowance
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 0).then(pl.col(qty_col)).otherwise(pl.lit(None))
            for type_col, qty_col in zip(col_groups['baggage_type'], col_groups['baggage_allowance'])
        ]).fill_null(0).alias("baggage_allowance_quantity"),

        pl.min_horizontal([
            pl.when(pl.col(type_col) == 1).then(pl.col(qty_col)).otherwise(pl.lit(None))
            for type_col, qty_col in zip(col_groups['baggage_type'], col_groups['baggage_allowance'])
        ]).fill_null(0).alias("baggage_allowance_weight")
    ]


def add_route_popularity(lazy_df: pl.LazyFrame) -> pl.LazyFrame:
    """Add search route popularity features."""
    return lazy_df.with_columns([
        pl.len().over('searchRoute').alias('route_popularity')
    ]).with_columns([
        pl.col('route_popularity').log1p().alias('route_popularity_log')
    ])


def create_window_based_flight_features(lazy_df: pl.LazyFrame) -> pl.LazyFrame:
    return lazy_df.with_columns([
        # calculate price percentile over search session
        ((pl.col('totalPrice').rank(method='min').over('ranker_id') - 1) /
        (pl.col('totalPrice').count().over('ranker_id') - 1) * 100)
        .fill_null(50.0).alias('price_percentile'),

        # Price features (always relative to search session)
        (pl.col('totalPrice').rank(method='ordinal').over('ranker_id') /
         pl.col('totalPrice').count().over('ranker_id')).alias('price_rank_pct'),
        (pl.col('totalPrice') / pl.col('totalPrice').min().over('ranker_id')).alias('price_ratio_to_min'),
    ]).with_columns([

        # Route popularity features
        pl.len().over('searchRoute').alias('route_popularity')
    ]).with_columns([
        pl.col('route_popularity').log1p().alias('route_popularity_log')
    ])


def create_derived_flight_features() -> List[pl.Expr]:
    return [
        pl.col('leg0_departure_hour').is_between(6, 22).cast(pl.Int8).alias('is_daytime'),
        (pl.col('leg1_departure_weekday') >= 5).cast(pl.Int8).alias('is_weekend'),
        (pl.col('leg0_departure_hour').is_between(6, 22) &
         (~pl.col('leg0_arrival_hour').is_between(6, 22)))
        .cast(pl.Int8).alias('is_redeye'),
        (pl.col('total_segments') > 1).cast(pl.Int8).alias('has_connections'),
    ]

def extract_flight_features(df: pl.DataFrame) -> pl.DataFrame:
    """ Create flight-level features"""
    # Check if already processed
    if df.height > 0 and 'total_duration_mins' in df.columns:
        return df


    # Get column groups
    col_groups = get_column_groups(df)

    # Start with lazy frame, dropping unnecessary columns
    lazy_df = df.drop(CUSTOMER_FEATURES + POLARS_INDEX_COL, strict=False).lazy()
    lazy_df = create_window_based_flight_features(lazy_df)

    # Apply all feature groups
    lazy_df = lazy_df.with_columns([
        *create_basic_features(col_groups),
        *create_segment_features(col_groups),
        *create_time_features(),
        *create_service_features(col_groups)
    ])

    # Add carrier features (requires frequent_flyer from create_basic_features() step)
    lazy_df = lazy_df.with_columns(create_carrier_features(col_groups))

    # Add derived features
    lazy_df = lazy_df.with_columns(create_derived_flight_features())

    # Add window based derived flight features
    lazy_df = create_window_based_flight_features(lazy_df)

    # Materialize to generate new features, drop unused features, and fill nulls
    return (lazy_df
            .collect()
            .select(~cs.matches(UNNEEDED_FEATURES_REGEX) & ~cs.by_name(UNNEEDED_FEATURES))
            .fill_null(0)
            )


In [149]:
data = extract_flight_features(data)

In [150]:
data.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,leg1_arrival_hour,leg1_departure_weekday,leg1_arrival_weekday,has_cancellation_fee,has_exchange_fee,max_cabin_class,avg_cabin_class,min_seats_available,total_seats_available,baggage_allowance_quantity,baggage_allowance_weight,unique_carriers,primary_carrier,marketing_carrier,has_mkt_ff,has_op_ff,aircraft_diversity,primary_aircraft,is_daytime,is_weekend,is_redeye,has_connections
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,i8,i8,i8,i32,i32,f64,f64,f64,f64,f64,f64,u32,str,str,i32,i32,u32,str,i8,i8,i8,i8
0,0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",1,0.0,0.04,1.0,25,3.258097,0,1,"""TLK""",0,0,0,1,1,2,160,155,29,160,155,315,15,6,16,6,9,14,2,2,0,0,1.0,1.0,9.0,18.0,1.0,0.0,1,"""KV""","""KV""",0,0,1,"""YK2""",1,0,0,1
1,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",0,4.166667,0.08,3.028015,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,8,2,3,1,1,1.0,1.0,4.0,26.0,1.0,0.0,1,"""S7""","""S7""",1,1,1,"""E70""",1,0,0,1
2,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",0,29.166667,0.32,3.18023,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,8,2,3,1,1,1.0,1.0,4.0,26.0,1.0,0.0,1,"""S7""","""S7""",1,1,1,"""E70""",1,0,0,1
3,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",0,54.166667,0.56,4.849562,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,8,2,3,0,0,1.0,1.0,4.0,26.0,1.0,0.0,1,"""S7""","""S7""",1,1,1,"""E70""",1,0,0,1
4,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",0,79.166667,0.8,5.097726,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,8,2,3,0,0,1.0,1.0,4.0,26.0,1.0,0.0,1,"""S7""","""S7""",1,1,1,"""E70""",1,0,0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
95,0,2300.0,0.0,1.0,3500.0,0.0,1.0,0.0,1,3382768,"""e04b757602824a4dbe227f1e67dbdb…",0,75.757576,0.764706,12.361303,291,5.676754,0,0,"""KHV""",0,0,0,2,0,2,645,0,3,1410,0,1410,8,1,8,2,0,0,0,0,1,1,2.0,1.5,3.0,12.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""321""",1,0,0,1
96,139,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,3382768,"""e04b757602824a4dbe227f1e67dbdb…",0,78.787879,0.794118,12.817417,291,5.676754,1,0,"""KHV""",0,0,0,2,0,2,645,0,3,1410,0,1410,8,1,8,2,0,0,0,0,1,0,2.0,2.0,3.0,6.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""321""",1,0,0,1
97,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,3382768,"""e04b757602824a4dbe227f1e67dbdb…",0,81.818182,0.823529,13.343914,291,5.676754,0,0,"""KHV""",0,0,0,2,0,2,645,0,3,1410,0,1410,8,1,8,2,0,0,0,0,1,0,2.0,2.0,3.0,6.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""321""",1,0,0,1
98,139,2300.0,0.0,1.0,3500.0,0.0,1.0,0.0,1,3382768,"""e04b757602824a4dbe227f1e67dbdb…",0,84.848485,0.852941,16.305548,291,5.676754,1,0,"""KHV""",0,0,0,2,0,2,645,0,3,1410,0,1410,8,1,8,2,0,0,0,0,1,1,2.0,2.0,3.0,6.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""321""",1,0,0,1


### Prepare Training Data

Now combining customer features with clustering segmentation results and flight features

In [151]:
def stratified_split(X, test_size=0.2, random_state=42):
    # Get unique search sessions and associated customer_segment
    search_sessions = (
        X
        .group_by('ranker_id')
        .agg(pl.col('customer_segment').first().alias('customer_segment'))
        .to_pandas()
    )

    print(f"Total unique search sessions: {len(search_sessions)}")
    print("Segment distribution:")
    segment_counts = search_sessions['customer_segment'].value_counts().sort_index()
    print(segment_counts)

    # Split on search sessions (ranker_id), stratify by customer_segment
    train_searches, test_searches = train_test_split(
        search_sessions['ranker_id'].values,
        test_size=test_size,
        stratify=search_sessions['customer_segment'].values,
        random_state=random_state
    )

    print(f"\nTrain searches: {len(train_searches)}")
    print(f"Test searches: {len(test_searches)}")

    # Filter feature data on search session splits
    X_train_mask = X['ranker_id'].is_in(train_searches)
    X_test_mask = X['ranker_id'].is_in(test_searches)

    X_train = X.filter(X_train_mask)
    X_test = X.filter(X_test_mask)

    return X_train, X_test


In [152]:
def prepare_training_data(flight_features, cust_data, cust_segments=None, outlier_indices=None, test_size=0.2):
    # if outliers were removed for clustering, then remove them for training
    if outlier_indices is not None and len(outlier_indices) > 0:
        mask = ~pl.Series(range(len(cust_data))).is_in(outlier_indices)
        cust_features = cust_data.filter(mask).lazy()
    else:
        cust_features = cust_data.lazy()

    # Add the cluster labels to the customer features if not already in the dataframe
    if cust_segments is not None and 'customer_segment' not in cust_features.columns:
        cust_features = cust_features.with_columns([
            pl.Series('customer_segment', cust_segments)
        ])

    # Join the customer features with the flight features, removing outlier profiles
    flight_features = flight_features.lazy().join(cust_features, on='profileId', how='inner').drop(pl.col('profileId'))
    flight_features = flight_features.with_columns([
        (pl.col('is_daytime') * (1.0 - pl.col('night_flight_preference').fill_null(0.5))).alias('daytime_alignment'),
        (pl.col('is_weekend') * pl.col('weekend_travel_rate').fill_null(0.5)).alias('weekend_alignment'),
        (pl.col('price_rank_pct') * pl.col('price_position_preference').fill_null(0.5)).alias('price_preference_match'),
        (pl.col('primary_carrier') == pl.col('most_common_carrier').fill_null('unknown')).cast(pl.Int8).alias('carrier_loyalty_match'),
    ])

    # Materialize the data
    X = flight_features.collect()

    # Generate and return Train/Test splits
    return stratified_split(X, test_size=test_size)

In [153]:
X_train, X_test = prepare_training_data(data, pl.read_parquet('data/customer_features.parquet'))

Total unique search sessions: 94543
Segment distribution:
customer_segment
0    16335
1    55546
2      186
3       10
4       50
5        3
6    19093
7     2483
8       79
9      758
Name: count, dtype: int64

Train searches: 75634
Test searches: 18909


In [154]:
X_train.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8
25,123,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,0.0,0.066667,1.0,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.928571,0.0,0.0,1.75,1,0.326639,52.734974,58.652095,50.0,0.0,0.601038,1,0.642857,0.0,0.0,-27.570182,-5.0,0.1,45.0,14.0,9,0.918719,0.071429,0.071429,0.0,1,0.0,0.0,0.0,28,0.0,9,1,0.857143,0.0,3.333333,1
26,0,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,7.142857,0.133333,1.050282,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.928571,0.0,0.0,1.75,1,0.326639,52.734974,58.652095,50.0,0.0,0.601038,1,0.642857,0.0,0.0,-27.570182,-5.0,0.1,45.0,14.0,9,0.918719,0.071429,0.071429,0.0,1,0.0,0.0,0.0,28,0.0,9,1,0.857143,0.0,6.666667,1
27,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,35.714286,0.4,1.235318,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.928571,0.0,0.0,1.75,1,0.326639,52.734974,58.652095,50.0,0.0,0.601038,1,0.642857,0.0,0.0,-27.570182,-5.0,0.1,45.0,14.0,9,0.918719,0.071429,0.071429,0.0,1,0.0,0.0,0.0,28,0.0,9,1,0.857143,0.0,20.0,1
28,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,42.857143,0.466667,1.297667,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.928571,0.0,0.0,1.75,1,0.326639,52.734974,58.652095,50.0,0.0,0.601038,1,0.642857,0.0,0.0,-27.570182,-5.0,0.1,45.0,14.0,9,0.918719,0.071429,0.071429,0.0,1,0.0,0.0,0.0,28,0.0,9,1,0.857143,0.0,23.333333,1
29,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,78.571429,0.8,2.648331,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.928571,0.0,0.0,1.75,1,0.326639,52.734974,58.652095,50.0,0.0,0.601038,1,0.642857,0.0,0.0,-27.570182,-5.0,0.1,45.0,14.0,9,0.918719,0.071429,0.071429,0.0,1,0.0,0.0,0.0,28,0.0,9,1,0.857143,0.0,40.0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
120,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""e0f9319a8b3048cdb1c974395e599e…",1,4.347826,0.083333,1.293048,535,6.284134,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,1,0.701923,0.0,4.112522,0
121,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,47.826087,0.5,1.98212,535,6.284134,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,1,0.701923,0.0,24.675132,0
122,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,52.173913,0.541667,2.263128,535,6.284134,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,1,0.701923,0.0,26.731393,0
123,0,2300.0,0.0,1.0,3500.0,0.0,1.0,0.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,60.869565,0.625,2.901737,535,6.284134,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,1,0.701923,0.0,30.843915,0


In [155]:
X_train.shape

(12880729, 121)

In [160]:
del data
gc.collect()

12

In [4]:
def prepare_ranking_data(X):
    # Sort the data sets for ranking
    sort_indices = X['ranker_id'].arg_sort()
    X = X[sort_indices]
    y = X['selected'].to_numpy()

    # Get search group sizes for chunked ranking
    group_sizes = (
        X
        .group_by('ranker_id', maintain_order=True)
        .len()
        .select('len')
        .to_numpy()
        .flatten()
    )

    return X.drop(['ranker_id','selected']), y, group_sizes


def train_lgb_ranker(X, y, group_sizes):
    """
    Train LGBRanker with proper group structure
    """
    categorical_features = [col for col in X.columns if not X[col].dtype.is_numeric()]
    X = X.with_columns([pl.col(col).cast(pl.Categorical) for col in categorical_features])

    # Convert to pandas for LightGBM compatibility
    X = X.to_pandas()

    train_data = lgb.Dataset(
        X, label=y,
        group=group_sizes,  # This tells LGB how many samples per search
        categorical_feature=categorical_features
    )

    params = {
        'objective': 'lambdarank',
        'metric': 'ndcg',
        'ndcg_eval_at': [3],
        'boosting_type': 'gbdt',
        'num_leaves': 255,
        'learning_rate': 0.1,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'min_data_in_leaf': 50,
        'num_threads': -1,
        'verbose': -1
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=1000,
        # callbacks=[lgb.early_stopping(100), lgb.log_evaluation(100)]
        callbacks=[lgb.log_evaluation(100)]
    )

    return model

In [5]:
def calculate_hit_rate_at_k(y_true, y_pred, groups, k=3):
    """
    Calculate Hit Rate @ K for ranking model

    Parameters:
    - y_true: actual binary labels (1 if selected, 0 if not)
    - y_pred: predicted scores from ranker
    - groups: group sizes (number of flights per search)
    - k: number of top predictions to consider (3 in your case)

    Returns:
    - hit_rate: fraction of searches where at least 1 of top-k predictions was correct
    """
    hits = 0
    total_searches = 0
    start_idx = 0

    for group_size in groups:
        # Extract data for this search group
        group_true = y_true[start_idx:start_idx + group_size]
        group_pred = y_pred[start_idx:start_idx + group_size]

        # Get indices of actual selections
        actual_selections = set(np.where(group_true == 1)[0])

        # Get top-k predictions (indices sorted by prediction score)
        top_k_indices = np.argsort(group_pred)[-k:]  # Get indices of top-k scores
        top_k_indices = set(top_k_indices)

        # Check if any top-k prediction matches actual selection
        if len(actual_selections.intersection(top_k_indices)) > 0:
            hits += 1

        total_searches += 1
        start_idx += group_size

    hit_rate = hits / total_searches if total_searches > 0 else 0
    return hit_rate

def detailed_hit_rate_evaluation(y_true, y_pred, groups, ranker_ids=None, k=3):
    """
    More detailed evaluation with additional metrics
    """
    hits = 0
    total_searches = 0
    hit_details = []
    start_idx = 0

    for i, group_size in enumerate(groups):
        group_true = y_true[start_idx:start_idx + group_size]
        group_pred = y_pred[start_idx:start_idx + group_size]

        # Actual selections
        actual_selections = np.where(group_true == 1)[0]
        num_actual_selections = len(actual_selections)

        # Top-k predictions
        top_k_indices = np.argsort(group_pred)[-k:]

        # Calculate hit
        hit = len(set(actual_selections).intersection(set(top_k_indices))) > 0

        if hit:
            hits += 1

        # Store details for analysis
        search_id = ranker_ids[i] if ranker_ids is not None else i
        hit_details.append({
            'search_id': search_id,
            'group_size': group_size,
            'num_actual_selections': num_actual_selections,
            'hit': hit,
            'top_k_scores': group_pred[top_k_indices].tolist(),
            'actual_selections_scores': group_pred[actual_selections].tolist() if len(actual_selections) > 0 else []
        })

        total_searches += 1
        start_idx += group_size

    hit_rate = hits / total_searches

    return hit_rate, hit_details

In [6]:
def evaluate_model_hit_rate(model, X_val, y_val, group_sizes_val, k=3):
    """
    Evaluate trained LGBRanker model using HitRate@K
    """
    # Get predictions
    y_pred = model.predict(X_val)

    # Calculate hit rate
    hit_rate = calculate_hit_rate_at_k(y_val, y_pred, group_sizes_val, k=k)

    print(f"Hit Rate @ {k}: {hit_rate:.4f}")
    return hit_rate

def cross_validate_with_hit_rate(df, n_folds=5):
    """
    Cross-validation preserving search groups and calculating HitRate@3
    """
    # Get unique search IDs
    unique_searches = df['ranker_id'].unique()
    np.random.shuffle(unique_searches)

    fold_size = len(unique_searches) // n_folds
    hit_rates = []

    for fold in range(n_folds):
        print(f"\nFold {fold + 1}/{n_folds}")

        # Split search IDs
        start_idx = fold * fold_size
        end_idx = start_idx + fold_size if fold < n_folds - 1 else len(unique_searches)
        val_searches = unique_searches[start_idx:end_idx]
        train_searches = np.setdiff1d(unique_searches, val_searches)

        # Split data
        train_df = df[df['ranker_id'].isin(train_searches)]
        val_df = df[df['ranker_id'].isin(val_searches)]

        # Prepare training data
        X_train, y_train, groups_train = prepare_ranking_data(train_df)
        X_val, y_val, groups_val = prepare_ranking_data(val_df)

        # Train model
        model = train_lgb_ranker(X_train, y_train, groups_train)

        # Evaluate
        hit_rate = evaluate_model_hit_rate(model, X_val, y_val, groups_val)
        hit_rates.append(hit_rate)

    print(f"\nCross-validation results:")
    print(f"Mean Hit Rate @ 3: {np.mean(hit_rates):.4f} (+/- {np.std(hit_rates)*2:.4f})")
    return hit_rates

In [None]:
X_train.write_parquet('data/X_train.parquet')
X_test.write_parquet('data/X_test.parquet')

In [None]:
del X_test
gc.collect()

In [2]:
X_train = pl.read_parquet('data/X_train.parquet')

In [7]:
# Prepare Train data for ranking
X, y, group_sizes = prepare_ranking_data(X_train)

del X_train
gc.collect()

26609

In [8]:
# Train model
model = train_lgb_ranker(X, y, group_sizes)

In [9]:
# After training your model
def complete_evaluation(model, test_df):
    """
    Complete evaluation workflow
    """
    print("Preparing test data...")
    X_test, y_test, groups_test = prepare_ranking_data(test_df)

    # Prepare testing data for ranking
    categorical_features = [col for col in X_test.columns if not X_test[col].dtype.is_numeric()]
    X_test = X_test.with_columns([pl.col(col).cast(pl.Categorical) for col in categorical_features])

    # Convert to pandas for LightGBM compatibility
    X_test = X_test.to_pandas()

    print("Making predictions...")
    y_pred = model.predict(X_test)

    print("Calculating Hit Rate @ 3...")
    hit_rate, details = detailed_hit_rate_evaluation(
        y_test,
        y_pred,
        groups_test,
        ranker_ids=test_df['ranker_id'].unique()
    )

    print(f"Overall Hit Rate @ 3: {hit_rate:.4f}")

    return hit_rate, details, y_pred


In [10]:
del X
del y
del group_sizes
gc.collect()

53

In [11]:
test_df = pl.read_parquet('data/X_test.parquet')

In [12]:
# Usage
hit_rate, details, y_pred = complete_evaluation(model, test_df)

Preparing test data...
Making predictions...
Calculating Hit Rate @ 3...
Overall Hit Rate @ 3: 0.6454


In [13]:
y_pred

array([ -4.69552722,  -3.93674295,  -4.02498223, ..., -12.10126926,
       -11.08110956, -12.51480663], shape=(3126995,))

In [32]:
test_df.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8
0,0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",1,0.0,0.04,1.0,25,3.258097,0,1,"""TLK""",0,0,0,1,1,2,160,155,29,160,155,315,15,6,16,6,9,…,0.962264,0.0,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,1.80224,0
1,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,4.166667,0.08,3.028015,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.962264,0.0,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,3.60448,1
2,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,29.166667,0.32,3.18023,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.962264,0.0,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,14.417919,1
3,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,54.166667,0.56,4.849562,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.962264,0.0,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,25.231359,1
4,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,79.166667,0.8,5.097726,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.962264,0.0,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,36.044799,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
378,44,4000.0,0.0,1.0,4000.0,0.0,1.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,91.803279,0.919355,4.921193,4522,8.416931,1,0,"""OVB""",0,0,0,2,0,2,445,0,15,835,0,835,21,6,10,7,0,…,0.983871,0.0,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.709677,0.0,43.58549,0
379,0,4000.0,0.0,1.0,4000.0,0.0,1.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,93.442623,0.935484,5.415335,4522,8.416931,0,0,"""OVB""",0,0,0,2,0,2,445,0,15,835,0,835,21,6,10,7,0,…,0.983871,0.0,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.709677,0.0,44.350148,0
380,115,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,55.737705,0.580645,3.461768,4522,8.416931,1,0,"""OVB""",0,0,0,2,0,2,240,0,15,1135,0,1135,23,6,18,7,0,…,0.983871,0.0,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.0,0.0,27.527678,1
381,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,68.852459,0.709677,3.670501,4522,8.416931,0,0,"""OVB""",0,0,0,2,0,2,240,0,15,1135,0,1135,23,6,18,7,0,…,0.983871,0.0,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.0,0.0,33.64494,1


In [14]:
test_df = test_df.with_columns(pl.Series('y_pred', y_pred))

In [21]:
test_df = test_df.with_columns(pl.col('y_pred').rank(method='ordinal', descending=True).over('ranker_id').alias('flight_rank'))

In [22]:
test_df.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match,y_pred,flight_rank
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8,f64,u32
0,0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",1,0.0,0.04,1.0,25,3.258097,0,1,"""TLK""",0,0,0,1,1,2,160,155,29,160,155,315,15,6,16,6,9,…,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,1.80224,0,-4.695527,13
1,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,4.166667,0.08,3.028015,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,3.60448,1,-3.936743,6
2,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,29.166667,0.32,3.18023,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,14.417919,1,-4.024982,7
3,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,54.166667,0.56,4.849562,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,25.231359,1,-4.133711,8
4,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,79.166667,0.8,5.097726,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.981132,1,-0.253891,93.370333,68.147027,45.055998,0.141509,0.619049,1,1.018868,0.0,1.0,-32.58008,5.4944,0.052632,45.358188,26.5,18,1.433237,0.056604,0.028302,1.0,1,0.0,0.0,0.0,106,0.235849,18,6,1.0,0.0,36.044799,1,-2.884729,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
378,44,4000.0,0.0,1.0,4000.0,0.0,1.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,91.803279,0.919355,4.921193,4522,8.416931,1,0,"""OVB""",0,0,0,2,0,2,445,0,15,835,0,835,21,6,10,7,0,…,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.709677,0.0,43.58549,0,-4.244531,18
379,0,4000.0,0.0,1.0,4000.0,0.0,1.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,93.442623,0.935484,5.415335,4522,8.416931,0,0,"""OVB""",0,0,0,2,0,2,445,0,15,835,0,835,21,6,10,7,0,…,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.709677,0.0,44.350148,0,-4.711212,24
380,115,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,55.737705,0.580645,3.461768,4522,8.416931,1,0,"""OVB""",0,0,0,2,0,2,240,0,15,1135,0,1135,23,6,18,7,0,…,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.0,0.0,27.527678,1,-5.123382,28
381,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""8539a5b6f7474362bd51378d180245…",0,68.852459,0.709677,3.670501,4522,8.416931,0,0,"""OVB""",0,0,0,2,0,2,240,0,15,1135,0,1135,23,6,18,7,0,…,0.0,1.290323,1,0.684296,88.889925,69.170345,47.408778,0.0,0.624162,1,0.66129,0.0,0.0,-57.009692,-4.740878,0.5,43.704389,62.0,1,0.987097,0.048387,0.016129,0.0,1,0.0,0.0,0.0,62,0.0,1,1,0.0,0.0,33.64494,1,-2.004948,5


In [16]:
test_df.group_by(pl.col('ranker_id')).agg(pl.len().alias('count')).sort('count', descending=False).filter(pl.col('count') > 10).head(1000)

ranker_id,count
str,u32
"""000f30002ec442d1b76b722e9d45a0…",11
"""e2cacfc868714ec7b345a43c6e9648…",11
"""d97b0f1e1d454a60850087480f148f…",11
"""8f1d7c8ef9414b8fb0a86ac35cc8d3…",11
"""ab5bf3d4476e44a181cad76796c318…",11
…,…
"""1ada8b80c5ab4e928fe535a69178ee…",15
"""303e009248f84e6ba96307bcdd932d…",15
"""2e45ea85166e4161933a3861c4a067…",15
"""0f8fddc4bca0471591779c357e2887…",15


In [24]:
test_df.filter(pl.col('ranker_id') == "000f30002ec442d1b76b722e9d45a000")['selected','y_pred', 'flight_rank'].head(100)

selected,y_pred,flight_rank
i64,f64,u32
1,-11.687449,1
0,-14.752849,8
0,-13.130669,3
0,-13.614557,5
0,-14.768519,9
…,…,…
0,-12.664439,2
0,-15.01319,10
0,-13.679878,6
0,-14.401925,7


In [19]:
next((item for item in details if item['search_id'] == "000f30002ec442d1b76b722e9d45a000"), None)

{'search_id': '000f30002ec442d1b76b722e9d45a000',
 'group_size': np.uint32(451),
 'num_actual_selections': 1,
 'hit': True,
 'top_k_scores': [-1.5757630599375563,
  -1.5121979278197502,
  -0.5738580460759077],
 'actual_selections_scores': [-1.5757630599375563]}

In [2]:
sample = pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/sample_submission.parquet')

In [3]:
sample.head(100)

Id,ranker_id,selected,__index_level_0__
i64,str,i64,i64
18144679,"""c9373e5f772e43d593dd6ad2fa90f6…",178,18144679
18144680,"""c9373e5f772e43d593dd6ad2fa90f6…",363,18144680
18144681,"""c9373e5f772e43d593dd6ad2fa90f6…",277,18144681
18144682,"""c9373e5f772e43d593dd6ad2fa90f6…",183,18144682
18144683,"""c9373e5f772e43d593dd6ad2fa90f6…",55,18144683
…,…,…,…
18144774,"""c9373e5f772e43d593dd6ad2fa90f6…",324,18144774
18144775,"""c9373e5f772e43d593dd6ad2fa90f6…",59,18144775
18144776,"""c9373e5f772e43d593dd6ad2fa90f6…",147,18144776
18144777,"""c9373e5f772e43d593dd6ad2fa90f6…",233,18144777


In [5]:
sample.filter(pl.col('ranker_id') == "c9373e5f772e43d593dd6ad2fa90f67a").sort("selected", descending=False)

Id,ranker_id,selected,__index_level_0__
i64,str,i64,i64
18144870,"""c9373e5f772e43d593dd6ad2fa90f6…",1,18144870
18145030,"""c9373e5f772e43d593dd6ad2fa90f6…",2,18145030
18144936,"""c9373e5f772e43d593dd6ad2fa90f6…",3,18144936
18144748,"""c9373e5f772e43d593dd6ad2fa90f6…",4,18144748
18144823,"""c9373e5f772e43d593dd6ad2fa90f6…",5,18144823
…,…,…,…
18145000,"""c9373e5f772e43d593dd6ad2fa90f6…",408,18145000
18144751,"""c9373e5f772e43d593dd6ad2fa90f6…",409,18144751
18144918,"""c9373e5f772e43d593dd6ad2fa90f6…",410,18144918
18144688,"""c9373e5f772e43d593dd6ad2fa90f6…",411,18144688
