### Imports

In [58]:
import gc
import re
from typing import List, Dict, Tuple

import numpy as np
import polars as pl
import polars.selectors as cs

import joblib
import lightgbm as lgb
from scipy.stats import zscore
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.ensemble import IsolationForest
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
import umap

### Get the Training Data Set

In [2]:
# data = pl.concat([train, pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/test.parquet')])
train_df = pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/train.parquet')

In [3]:
train_df.head(1000)

Id,bySelf,companyID,corporateTariffCode,frequentFlyer,nationality,isAccess3D,isVip,legs0_arrivalAt,legs0_departureAt,legs0_duration,legs0_segments0_aircraft_code,legs0_segments0_arrivalTo_airport_city_iata,legs0_segments0_arrivalTo_airport_iata,legs0_segments0_baggageAllowance_quantity,legs0_segments0_baggageAllowance_weightMeasurementType,legs0_segments0_cabinClass,legs0_segments0_departureFrom_airport_iata,legs0_segments0_duration,legs0_segments0_flightNumber,legs0_segments0_marketingCarrier_code,legs0_segments0_operatingCarrier_code,legs0_segments0_seatsAvailable,legs0_segments1_aircraft_code,legs0_segments1_arrivalTo_airport_city_iata,legs0_segments1_arrivalTo_airport_iata,legs0_segments1_baggageAllowance_quantity,legs0_segments1_baggageAllowance_weightMeasurementType,legs0_segments1_cabinClass,legs0_segments1_departureFrom_airport_iata,legs0_segments1_duration,legs0_segments1_flightNumber,legs0_segments1_marketingCarrier_code,legs0_segments1_operatingCarrier_code,legs0_segments1_seatsAvailable,legs0_segments2_aircraft_code,legs0_segments2_arrivalTo_airport_city_iata,…,legs1_segments2_baggageAllowance_weightMeasurementType,legs1_segments2_cabinClass,legs1_segments2_departureFrom_airport_iata,legs1_segments2_duration,legs1_segments2_flightNumber,legs1_segments2_marketingCarrier_code,legs1_segments2_operatingCarrier_code,legs1_segments2_seatsAvailable,legs1_segments3_aircraft_code,legs1_segments3_arrivalTo_airport_city_iata,legs1_segments3_arrivalTo_airport_iata,legs1_segments3_baggageAllowance_quantity,legs1_segments3_baggageAllowance_weightMeasurementType,legs1_segments3_cabinClass,legs1_segments3_departureFrom_airport_iata,legs1_segments3_duration,legs1_segments3_flightNumber,legs1_segments3_marketingCarrier_code,legs1_segments3_operatingCarrier_code,legs1_segments3_seatsAvailable,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,requestDate,searchRoute,sex,taxes,totalPrice,selected,__index_level_0__
i64,bool,i64,i64,str,i64,bool,bool,str,str,str,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,…,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,datetime[ns],str,bool,f64,f64,i64,i64
0,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T16:20:00""","""2024-06-15T15:40:00""","""02:40:00""","""YK2""","""KJA""","""KJA""",1.0,0.0,1.0,"""TLK""","""02:40:00""","""216""","""KV""","""KV""",9.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,370.0,16884.0,1,0
1,true,57323,123,"""S7/SU/UT""",36,true,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,51125.0,0,1
2,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,53695.0,0,2
3,true,57323,123,"""S7/SU/UT""",36,true,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,81880.0,0,3
4,true,57323,,"""S7/SU/UT""",36,false,false,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",true,2240.0,86070.0,0,4
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
995,true,42622,,,36,false,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",0.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",0.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,4000.0,,1.0,0.0,,0.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,54837.0,0,995
996,true,42622,115,,36,true,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,71662.0,0,996
997,true,42622,,,36,false,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,76067.0,0,997
998,true,42622,115,,36,true,false,"""2024-05-21T07:15:00""","""2024-05-20T17:45:00""","""11:30:00""","""32A""","""MOW""","""DME""",1.0,0.0,1.0,"""OVB""","""04:20:00""","""2516""","""S7""","""S7""",9.0,"""32N""","""HTA""","""HTA""",1.0,0.0,1.0,"""DME""","""06:10:00""","""3045""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,0.0,1,2812817,"""1e9c2a3788a74ab7b80890ef712773…",2024-05-17 06:39:30,"""OVBHTA/HTAOVB""",true,2612.0,116322.0,0,998


In [4]:
train_df.shape

(18145372, 127)

In [5]:
print(f"the number of new profiles in train data is {len(train_df.select("profileId").unique())}")
print(f'the number of searches in train data is {len(train_df.select('ranker_id').unique())}')

the number of new profiles in train data is 32922
the number of searches in train data is 105539


In [6]:
flight_counts = train_df.group_by('ranker_id').agg(pl.len().alias('flight_count'))

print(f'The minimum number of searches in train data is {flight_counts['flight_count'].min()}')
print(f'The maximum number of searches in train data is {flight_counts['flight_count'].max()}')

del flight_counts
_ = gc.collect()

The minimum number of searches in train data is 1
The maximum number of searches in train data is 8236


### Split into Train and Test Data Sets

In [7]:
def stratified_flight_count_split(df, test_size=0.2, random_state=42):
    """
    Stratified split based on number of flights per search (ranker_id).
    Ensures train/test sets have similar distributions of search sizes.
    """

    # Calculate flights per search
    search_flight_counts = (
        df
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
    )

    # Create stratification bins based on your distribution
    # Adjust these thresholds based on your specific data
    search_flight_counts = search_flight_counts.with_columns(
        pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))     # Most searches
        .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
        .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
        .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
        .otherwise(pl.lit('very_large'))                                     # High-flight searches
        .alias('size_category')
    )

    # Print distribution for verification
    size_dist = search_flight_counts.group_by('size_category').len().sort('size_category')
    print("Search size distribution:")
    for row in size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    # Split within each size category
    train_searches = []
    test_searches = []

    for category in ['very_small', 'small', 'medium', 'large', 'very_large']:
        category_searches = (
            search_flight_counts
            .filter(pl.col('size_category') == category)
            .select('ranker_id')
            .to_pandas()['ranker_id'].tolist()
        )

        if len(category_searches) > 1:
            train_cat, test_cat = train_test_split(
                category_searches,
                test_size=test_size,
                random_state=random_state,
                stratify=None  # No further stratification within category
            )
            train_searches.extend(train_cat)
            test_searches.extend(test_cat)
        elif len(category_searches) == 1:
            # Single search goes to train
            train_searches.extend(category_searches)

    # Filter original data
    train_data = df.filter(pl.col('ranker_id').is_in(train_searches))
    test_data = df.filter(pl.col('ranker_id').is_in(test_searches))

    # Verification: Check distribution preservation
    print(f"\nSplit results:")
    print(f"Train searches: {len(train_searches):,}")
    print(f"Test searches: {len(test_searches):,}")
    print(f"Train rows: {train_data.height:,}")
    print(f"Test rows: {test_data.height:,}")

    # Verify stratification worked
    train_size_dist = (
        train_data
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
        .with_columns(
            pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))
            .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
            .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
            .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
            .otherwise(pl.lit('very_large'))
            .alias('size_category')
        )
        .group_by('size_category')
        .len()
        .sort('size_category')
    )

    test_size_dist = (
        test_data
        .group_by('ranker_id')
        .agg(pl.len().alias('flight_count'))
        .with_columns(
            pl.when(pl.col('flight_count') <= 50).then(pl.lit('very_small'))
            .when(pl.col('flight_count') <= 200).then(pl.lit('small'))
            .when(pl.col('flight_count') <= 500).then(pl.lit('medium'))
            .when(pl.col('flight_count') <= 1000).then(pl.lit('large'))
            .otherwise(pl.lit('very_large'))
            .alias('size_category')
        )
        .group_by('size_category')
        .len()
        .sort('size_category')
    )

    print(f"\nTrain set distribution:")
    for row in train_size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    print(f"\nTest set distribution:")
    for row in test_size_dist.iter_rows(named=True):
        print(f"  {row['size_category']}: {row['len']:,} searches")

    return train_data, test_data


In [8]:
train_df, test_df = stratified_flight_count_split(train_df)

Search size distribution:
  large: 5,035 searches
  medium: 13,896 searches
  small: 30,360 searches
  very_large: 3,141 searches
  very_small: 53,107 searches

Split results:
Train searches: 84,429
Test searches: 21,110
Train rows: 14,567,293
Test rows: 3,578,079

Train set distribution:
  large: 4,028 searches
  medium: 11,116 searches
  small: 24,288 searches
  very_large: 2,512 searches
  very_small: 42,485 searches

Test set distribution:
  large: 1,007 searches
  medium: 2,780 searches
  small: 6,072 searches
  very_large: 629 searches
  very_small: 10,622 searches


In [9]:
# Save the train/test splits
train_df.write_parquet('data/train_df.parquet')
test_df.write_parquet('data/test_df.parquet')

# Remove test data from memory
del test_df
gc.collect()

0

In [10]:
# Check if train_df exists and load if not
try:
    train_df.head()
except NameError:
    train_df = pl.read_parquet('data/train_df.parquet')
    print(f'loaded train_df: {train_df.shape}')

train_df.head(100)

Id,bySelf,companyID,corporateTariffCode,frequentFlyer,nationality,isAccess3D,isVip,legs0_arrivalAt,legs0_departureAt,legs0_duration,legs0_segments0_aircraft_code,legs0_segments0_arrivalTo_airport_city_iata,legs0_segments0_arrivalTo_airport_iata,legs0_segments0_baggageAllowance_quantity,legs0_segments0_baggageAllowance_weightMeasurementType,legs0_segments0_cabinClass,legs0_segments0_departureFrom_airport_iata,legs0_segments0_duration,legs0_segments0_flightNumber,legs0_segments0_marketingCarrier_code,legs0_segments0_operatingCarrier_code,legs0_segments0_seatsAvailable,legs0_segments1_aircraft_code,legs0_segments1_arrivalTo_airport_city_iata,legs0_segments1_arrivalTo_airport_iata,legs0_segments1_baggageAllowance_quantity,legs0_segments1_baggageAllowance_weightMeasurementType,legs0_segments1_cabinClass,legs0_segments1_departureFrom_airport_iata,legs0_segments1_duration,legs0_segments1_flightNumber,legs0_segments1_marketingCarrier_code,legs0_segments1_operatingCarrier_code,legs0_segments1_seatsAvailable,legs0_segments2_aircraft_code,legs0_segments2_arrivalTo_airport_city_iata,…,legs1_segments2_baggageAllowance_weightMeasurementType,legs1_segments2_cabinClass,legs1_segments2_departureFrom_airport_iata,legs1_segments2_duration,legs1_segments2_flightNumber,legs1_segments2_marketingCarrier_code,legs1_segments2_operatingCarrier_code,legs1_segments2_seatsAvailable,legs1_segments3_aircraft_code,legs1_segments3_arrivalTo_airport_city_iata,legs1_segments3_arrivalTo_airport_iata,legs1_segments3_baggageAllowance_quantity,legs1_segments3_baggageAllowance_weightMeasurementType,legs1_segments3_cabinClass,legs1_segments3_departureFrom_airport_iata,legs1_segments3_duration,legs1_segments3_flightNumber,legs1_segments3_marketingCarrier_code,legs1_segments3_operatingCarrier_code,legs1_segments3_seatsAvailable,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,requestDate,searchRoute,sex,taxes,totalPrice,selected,__index_level_0__
i64,bool,i64,i64,str,i64,bool,bool,str,str,str,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,…,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,datetime[ns],str,bool,f64,f64,i64,i64
25,true,57323,123,,36,true,false,"""2024-06-15T21:35:00""","""2024-06-15T11:25:00""","""12:10:00""","""E70""","""OVB""","""OVB""",0.0,0.0,1.0,"""TOF""","""00:50:00""","""5322""","""S7""","""S7""",9.0,"""E70""","""NJC""","""NJC""",0.0,0.0,1.0,"""OVB""","""01:30:00""","""5329""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,4000.0,,1.0,0.0,,0.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",2024-05-17 03:09:59,"""TOFNJC""",true,444.0,9944.0,0,25
26,true,57323,,,36,false,false,"""2024-06-15T21:35:00""","""2024-06-15T11:25:00""","""12:10:00""","""E70""","""OVB""","""OVB""",0.0,0.0,1.0,"""TOF""","""00:50:00""","""5322""","""S7""","""S7""",9.0,"""E70""","""NJC""","""NJC""",0.0,0.0,1.0,"""OVB""","""01:30:00""","""5329""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,4000.0,,1.0,0.0,,0.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",2024-05-17 03:09:59,"""TOFNJC""",true,444.0,10444.0,0,26
27,true,57323,123,,36,true,false,"""2024-06-15T21:35:00""","""2024-06-15T11:25:00""","""12:10:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TOF""","""00:50:00""","""5322""","""S7""","""S7""",9.0,"""E70""","""NJC""","""NJC""",1.0,0.0,1.0,"""OVB""","""01:30:00""","""5329""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",2024-05-17 03:09:59,"""TOFNJC""",true,444.0,12284.0,0,27
28,true,57323,,,36,false,false,"""2024-06-15T21:35:00""","""2024-06-15T11:25:00""","""12:10:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TOF""","""00:50:00""","""5322""","""S7""","""S7""",9.0,"""E70""","""NJC""","""NJC""",1.0,0.0,1.0,"""OVB""","""01:30:00""","""5329""","""S7""","""S7""",1.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",2024-05-17 03:09:59,"""TOFNJC""",true,444.0,12904.0,0,28
29,true,57323,123,,36,true,false,"""2024-06-15T21:35:00""","""2024-06-15T11:25:00""","""12:10:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TOF""","""00:50:00""","""5322""","""S7""","""S7""",9.0,"""E70""","""NJC""","""NJC""",1.0,0.0,1.0,"""OVB""","""01:30:00""","""5329""","""S7""","""S7""",9.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",2024-05-17 03:09:59,"""TOFNJC""",true,870.0,26335.0,0,29
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
120,true,59096,,,36,false,false,"""2024-06-28T11:40:00""","""2024-06-28T10:10:00""","""05:30:00""","""73H""","""OVB""","""OVB""",1.0,0.0,1.0,"""GDX""","""05:30:00""","""5220""","""S7""","""S7""",6.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",2024-05-17 04:02:26,"""GDXOVB""",true,719.0,25239.0,1,120
121,true,59096,,,36,false,false,"""2024-06-28T11:40:00""","""2024-06-28T10:10:00""","""05:30:00""","""73H""","""OVB""","""OVB""",1.0,0.0,1.0,"""GDX""","""05:30:00""","""5220""","""S7""","""S7""",6.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",2024-05-17 04:02:26,"""GDXOVB""",true,719.0,38689.0,0,121
122,true,59096,,,36,false,false,"""2024-06-28T11:40:00""","""2024-06-28T10:10:00""","""05:30:00""","""73H""","""OVB""","""OVB""",1.0,0.0,2.0,"""GDX""","""05:30:00""","""5220""","""S7""","""S7""",4.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,4000.0,,1.0,0.0,,0.0,0.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",2024-05-17 04:02:26,"""GDXOVB""",true,719.0,44174.0,0,122
123,true,59096,,,36,false,false,"""2024-06-28T11:40:00""","""2024-06-28T10:10:00""","""05:30:00""","""73H""","""OVB""","""OVB""",1.0,0.0,2.0,"""GDX""","""05:30:00""","""5220""","""S7""","""S7""",4.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,0.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",2024-05-17 04:02:26,"""GDXOVB""",true,719.0,56639.0,0,123


### Utilities

In [11]:
MAJOR_HUBS = ['ATL','DXB','DFW','HND','LHR','DEN','ORD','IST','PVG','ICN','CDG', 'JFK','CLT','MEX','SFO','EWR','MIA','BKK','GRU','HKG']

def camel_to_snake(name):
    """Convert camelCase or PascalCase to snake_case"""
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1)
    return s2.lower()


def convert_columns_to_snake_case(df):
    """Convert all column names in a polars DataFrame to snake_case"""
    return df.rename({col: camel_to_snake(col) for col in df.columns})


def parse_duration_to_minutes(duration_col: str) -> pl.Expr:
    """Parse duration string to minutes (handles format like '2.04:20' - D.HH:MM)."""
    return (
        pl.when(pl.col(duration_col).is_not_null() & (pl.col(duration_col) != ""))
        .then(
            pl.when(pl.col(duration_col).str.contains(r'\.'))
            .then(
                # Format: D.HH:MM:SS (e.g., "1.00:30:00", "2.09:45:00")
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 1).cast(pl.Int32, strict=False).fill_null(0) * 1440 +  # Days
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 2).cast(pl.Int32, strict=False).fill_null(0) * 60 +   # Hours
                pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 3).cast(pl.Int32, strict=False).fill_null(0) +        # Minutes
                (pl.col(duration_col).str.extract(r'^(\d+)\.(\d{2}):(\d{2}):(\d{2})$', 4).cast(pl.Int32, strict=False).fill_null(0) / 60).round(0).cast(pl.Int32, strict=False)  # Seconds
            )
            .otherwise(
                # Format: HH:MM:SS (e.g., "07:25:00", "17:55:00")
                pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 1).cast(pl.Int32, strict=False).fill_null(0) * 60 +   # Hours
                pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 2).cast(pl.Int32, strict=False).fill_null(0) +        # Minutes
                (pl.col(duration_col).str.extract(r'^(\d{2}):(\d{2}):(\d{2})$', 3).cast(pl.Int32, strict=False).fill_null(0) / 60).round(0).cast(pl.Int32, strict=False)  # Seconds
            )
        )
        .otherwise(0)
    )

### Engineer Customer Features

In [12]:
# Customer attributes for clustering analysis

def create_customer_aggregation_features() -> List[pl.Expr]:
    """Create customer aggregation expressions for basic attributes and search behavior."""
    return [
        # Basic customer attributes (take first non-null value per customer)
        pl.col('companyID').drop_nulls().first().alias('company_id'),
        pl.col('sex').drop_nulls().first().cast(pl.Int8).alias('sex'),
        pl.col('nationality').drop_nulls().first().alias('nationality'),
        pl.col('frequentFlyer').drop_nulls().first().str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('frequent_flyer'),
        pl.col('isVip').drop_nulls().first().cast(pl.Int8).alias('is_vip'),
        pl.col('bySelf').drop_nulls().first().cast(pl.Int8).alias('by_self'),
        pl.col('corporateTariffCode').is_not_null().cast(pl.Int8).max().alias('has_corp_codes'),

        # Search behavior metrics
        pl.len().alias('total_searches'),
        pl.col('legs1_departureAt').is_not_null().mean().alias('roundtrip_preference'),
        pl.col('searchRoute').drop_nulls().n_unique().alias('unique_routes_searched'),
    ]


def create_booking_lead_time_features() -> List[pl.Expr]:
    """Create booking lead time statistics."""
    # Calculate booking lead time in days
    booking_lead_expr = (
        (pl.col('legs0_departureAt').str.to_datetime() -
         pl.col('requestDate').cast(pl.Datetime)) / pl.duration(days=1)
    ).cast(pl.Int32)

    return [
        booking_lead_expr.min().alias('min_booking_lead_days'),
        booking_lead_expr.max().alias('max_booking_lead_days'),
        booking_lead_expr.mean().alias('avg_booking_lead_days'),
        booking_lead_expr.median().alias('median_booking_lead_days'),
    ]


def create_travel_preference_features() -> List[pl.Expr]:
    """Create travel preference features for most common airports and carriers."""
    return [
        # Most common departure airport
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().mode().first().alias('most_common_departure_airport'),
        pl.col('legs0_segments0_departureFrom_airport_iata').drop_nulls().n_unique().alias('unique_departure_airports'),

        # Most common marketing carrier
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().mode().first().alias('most_common_carrier'),
        pl.col('legs0_segments0_marketingCarrier_code').drop_nulls().n_unique().alias('unique_carriers_used'),
    ]


def create_cabin_class_features(cabin_class_cols: List[str]) -> List[pl.Expr]:
    """Create cabin class preference statistics."""
    if not cabin_class_cols:
        # Return default values if no cabin class columns found
        return [
            pl.lit(None).alias('min_cabin_class'),
            pl.lit(None).alias('max_cabin_class'),
            pl.lit(None).alias('avg_cabin_class'),
        ]

    return [
        # Cabin class statistics across all segments
        pl.min_horizontal([pl.col(col) for col in cabin_class_cols]).min().alias('min_cabin_class'),
        pl.max_horizontal([pl.col(col) for col in cabin_class_cols]).max().alias('max_cabin_class'),
        pl.mean_horizontal([pl.col(col) for col in cabin_class_cols]).mean().alias('avg_cabin_class'),
    ]


def create_temporal_preference_features() -> List[pl.Expr]:
    """Create temporal preference features for departure patterns."""
    return [
        # Weekday preference (most common day of week for departures)
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday()
          .mode().first().alias('weekday_preference'),

        # Weekend travel rate (percentage of weekend departures - 5=Sat, 6=Sun)
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday()
          .map_elements(lambda x: 1 if x >= 5 else 0, return_dtype=pl.Int8)
          .mean().alias('weekend_travel_rate'),

        # Time of day variance (how consistent are their departure times)
        pl.col('legs0_departureAt').str.to_datetime().dt.hour()
          .std().alias('time_of_day_variance'),

        # Night flight preference (flights departing 22:00-06:00)
        pl.col('legs0_departureAt').str.to_datetime().dt.hour()
          .map_elements(lambda x: 1 if (x >= 22 or x < 6) else 0, return_dtype=pl.Int8)
          .mean().alias('night_flight_preference')
    ]


def create_route_specific_features() -> List[pl.Expr]:
    """Create features related to route preferences and characteristics."""
    return [
        # Route loyalty
        (1 - (pl.col('searchRoute').n_unique() / pl.len().clip(1, None))).alias('route_loyalty'),

        # Hub preference (preference for major hub airports)
        (
            pl.col('legs0_segments0_departureFrom_airport_iata').is_in(MAJOR_HUBS) |
            pl.col('legs0_segments0_arrivalTo_airport_iata').is_in(MAJOR_HUBS)
        ).cast(pl.Int8).mean().alias('hub_preference'),

        # Short haul preference
        ((1 - (pl.col('leg0_duration_minutes').mean() / 180)) * 0.7 +
         pl.when(pl.col('leg0_duration_minutes') <= 180).then(1).otherwise(0).mean() * 0.3)
        .clip(0, 1).alias('short_haul_preference'),

        # Connection tolerance (average segments)
        pl.sum_horizontal([
            pl.col(f'legs0_segments{i}_departureFrom_airport_iata').is_not_null().cast(pl.Int8)
            for i in range(4)  # assuming max 4 segments
        ]).mean().alias('connection_tolerance'),

        # Preference for longer vs shorter flights (duration quartile preference)
        pl.when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q25'))
        .then(1)  # Short flights
        .when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q50'))
        .then(2)  # Medium-short flights
        .when(pl.col('leg0_duration_minutes') <= pl.col('leg0_duration_q75'))
        .then(3)  # Medium-long flights
        .otherwise(4)  # Long flights
        .mode().first().alias('preferred_duration_quartile')
    ]


def create_price_sensitivity_features() -> List[pl.Expr]:
    """Create features related to price sensitivity and patterns."""
    return [
        # Basic correlation between price and duration
        pl.corr(
            pl.col('totalPrice'),
            pl.col('trip_duration_minutes')
        ).fill_null(0).alias('price_to_duration_sensitivity'),

        # Price per minute metric (average across all flights)
        (pl.col('totalPrice') / pl.col('trip_duration_minutes').clip(1, None)).mean().alias('avg_price_per_minute'),

        # Consistency of price-per-minute (lower std = more consistent valuation)
        (pl.col('totalPrice') / pl.col('trip_duration_minutes').clip(1, None)).std().fill_null(0).alias('price_per_minute_variance'),

        # Price position within searches (percentile rank)
        pl.col('price_percentile').mean().alias('price_position_preference'),

        # Premium economy preference (assuming cabin class 2 is premium economy)
        pl.mean_horizontal([
            pl.col(f'legs0_segments{i}_cabinClass') == 2
            for i in range(4)
        ]).mean().alias('premium_economy_preference'),

        # Consistent price tier
        (1 - (
            pl.when(pl.col('totalPrice') <= pl.col('price_q25'))
            .then(1)  # Budget tier
            .when(pl.col('totalPrice') <= pl.col('price_q50'))
            .then(2)  # Economy tier
            .when(pl.col('totalPrice') <= pl.col('price_q75'))
            .then(3)  # Premium tier
            .otherwise(4)  # Luxury tier
            .std()
            .fill_null(0)  # Handle null case
            / 3
        ).clip(0, 1)).alias('consistent_price_tier'),

        # Most common price tier
        pl.when(pl.col('totalPrice') <= pl.col('price_q25'))
        .then(1)  # Budget tier
        .when(pl.col('totalPrice') <= pl.col('price_q50'))
        .then(2)  # Economy tier
        .when(pl.col('totalPrice') <= pl.col('price_q75'))
        .then(3)  # Premium tier
        .otherwise(4)  # Luxury tier
        .mode().first().alias('preferred_price_tier'),
    ]


def create_service_preference_features() -> List[pl.Expr]:
    """Create features related to service preferences."""
    # First get all the relevant column names for type and quantity
    type_cols = [
        f'legs{leg}_segments{seg}_baggageAllowance_weightMeasurementType'
        for leg in range(2) for seg in range(4)
    ]
    qty_cols = [
        f'legs{leg}_segments{seg}_baggageAllowance_quantity'
        for leg in range(2) for seg in range(4)
    ]

    return [
        # Baggage quantity preference (average of minimum bags allowed per flight option)
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 0)
            .then(pl.col(qty_col))
            .otherwise(pl.lit(None))
            for type_col, qty_col in zip(type_cols, qty_cols)
        ]).mean().fill_null(0).alias('baggage_qty_preference'),

        # Baggage weight preference (average of minimum weight allowed per flight option)
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 1)
            .then(pl.col(qty_col))
            .otherwise(pl.lit(None))
            for type_col, qty_col in zip(type_cols, qty_cols)
        ]).mean().fill_null(0).fill_nan(0).alias('baggage_weight_preference'),

        # Loyalty program utilization
        pl.col('frequentFlyer').is_not_null().mean().alias('loyalty_program_utilization')
    ]


def create_derived_metrics() -> List[pl.Expr]:
    """Create complex derived metrics from combinations of features."""
    return [
        # Convenience priority score (higher = more emphasis on convenient times)
        ((1 - pl.col('time_of_day_variance')) * 10 +
         pl.col('price_to_duration_sensitivity') * 5)
        .alias('convenience_priority_score'),

        # Loyalty vs price index (higher = more loyal, less price sensitive)
        (pl.col('loyalty_program_utilization') * 10 -
         pl.col('price_position_preference') / 10)
        .alias('loyalty_vs_price_index'),

        # Planning consistency score (inverse of lead time variance)
        (1 / (pl.col('max_booking_lead_days') - pl.col('min_booking_lead_days') + 1))
        .alias('planning_consistency_score'),

        # Luxury index (combination of cabin class and price tier)
        (pl.col('avg_cabin_class') * 20 +
         pl.col('price_position_preference') / 2)
        .alias('luxury_index'),

        # Create advanced features
        (pl.col('total_searches') / pl.col('unique_routes_searched').clip(1)).alias('search_intensity_per_route'),
        (pl.col('max_booking_lead_days') - pl.col('min_booking_lead_days')).alias('lead_time_variance'),
        (pl.col('avg_booking_lead_days') / pl.col('median_booking_lead_days').clip(1)).alias('lead_time_skew'),
        (pl.col('unique_carriers_used') / pl.col('total_searches').clip(1)).alias('carrier_diversity'),
        (pl.col('unique_departure_airports') / pl.col('total_searches').clip(1)).alias('airport_diversity'),
        (pl.col('max_cabin_class') - pl.col('min_cabin_class')).alias('cabin_class_range'),
        (pl.col('is_vip').cast(pl.Int8) * 2 + pl.col('has_corp_codes').is_not_null().cast(pl.Int8)).alias('customer_tier'),
    ]


def add_trip_duration(df: pl.DataFrame) -> pl.DataFrame:
    """Add trip duration to the DataFrame."""
    leg0_duration_minutes = parse_duration_to_minutes('legs0_duration')
    leg1_duration_minutes = parse_duration_to_minutes('legs1_duration')
    trip_duration_minutes = leg0_duration_minutes + leg1_duration_minutes

    return df.with_columns([
        leg0_duration_minutes.alias('leg0_duration_minutes'),
        leg1_duration_minutes.alias('leg1_duration_minutes'),
        trip_duration_minutes.alias('trip_duration_minutes'),
    ])


def create_windows_based_features(df) -> pl.DataFrame:
    # Add window-based features like price_percentile if they don't exist
    return (df.with_columns([
        # calculate price percentile over search session
        ((pl.col('totalPrice').rank(method='min').over('ranker_id') - 1) /
        (pl.col('totalPrice').count().over('ranker_id') - 1) * 100)
        .fill_null(50.0).alias('price_percentile'),

        # calculate price quartiles over profileId
        pl.col('totalPrice').quantile(0.25).over('profileId').alias('price_q25'),
        pl.col('totalPrice').quantile(0.50).over('profileId').alias('price_q50'),
        pl.col('totalPrice').quantile(0.75).over('profileId').alias('price_q75'),

        # calculate leg0_duration quartiles over profileId
        pl.col('leg0_duration_minutes').quantile(0.25).over('profileId').alias('leg0_duration_q25'),
        pl.col('leg0_duration_minutes').quantile(0.50).over('profileId').alias('leg0_duration_q50'),
        pl.col('leg0_duration_minutes').quantile(0.75).over('profileId').alias('leg0_duration_q75'),
    ]))


def create_interaction_features() -> List[pl.Expr]:
    # Create customer/business interaction features
    return [
        # Create VIP interactions
        (pl.col('search_intensity_per_route') * pl.col('is_vip')).alias('vip_search_intensity'),
        (pl.col('carrier_diversity') * pl.col('is_vip')).alias('vip_carrier_diversity'),
        (pl.col('avg_cabin_class') * pl.col('is_vip')).alias('vip_cabin_preference'),

        # Create corporate interactions
        (pl.col('total_searches') * pl.col('has_corp_codes')).alias('corp_search_volume'),
        (pl.col('roundtrip_preference') * pl.col('has_corp_codes')).alias('corp_roundtrip_pref'),
        (pl.col('lead_time_variance') * pl.col('has_corp_codes')).alias('corp_planning_variance'),
    ]

def extract_customer_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Extract customer features for clustering analysis.
    Aggregates by profileId to create customer-level features.
    """
    # Check if already processed
    if df.height > 0 and 'total_searches' in df.columns:
        return df

    # Get cabin class columns
    cabin_class_cols = [col for col in df.columns if col.startswith('legs') and col.endswith('_cabinClass')]

    # Create lazy frame and group by profileId
    lazy_df = create_windows_based_features(add_trip_duration(df)).lazy().group_by('profileId')

    # Apply customer feature groups
    lazy_df = lazy_df.agg([
        *create_customer_aggregation_features(),
        *create_booking_lead_time_features(),
        *create_travel_preference_features(),
        *create_cabin_class_features(cabin_class_cols),
        *create_temporal_preference_features(),
        *create_route_specific_features(),
        *create_price_sensitivity_features(),
        *create_service_preference_features()
    ])

    # Materialize to generate the basic features
    # base_features = customer_features.collect()

    # Add the derived metrics that depend on the generated features
    lazy_df = lazy_df.with_columns(create_derived_metrics())

    # Add interactive features
    lazy_df = lazy_df.with_columns(create_interaction_features())

    # print(f"Generated {len(enhanced_features.columns)} customer features for {len(enhanced_features)} customers")
    return lazy_df.collect().fill_null(0).fill_nan(0)

In [13]:
# 1. Feature Engineering
cust_data = extract_customer_features(train_df)
print(f'Generated {len(cust_data.columns)} customer features for {len(cust_data)} customers')

Generated 58 customer features for 30134 customers


In [14]:
cust_data.head(100)

profileId,company_id,sex,nationality,frequent_flyer,is_vip,by_self,has_corp_codes,total_searches,roundtrip_preference,unique_routes_searched,min_booking_lead_days,max_booking_lead_days,avg_booking_lead_days,median_booking_lead_days,most_common_departure_airport,unique_departure_airports,most_common_carrier,unique_carriers_used,min_cabin_class,max_cabin_class,avg_cabin_class,weekday_preference,weekend_travel_rate,time_of_day_variance,night_flight_preference,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance
i64,i64,i8,i64,str,i8,i8,i8,u32,f64,u32,i32,i32,f64,f64,str,u32,str,u32,f64,f64,f64,i8,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32
3113429,59946,1,36,"""SU""",0,1,0,644,0.0,2,3,7,5.259317,4.0,"""LED""",4,"""SU""",8,1.0,2.0,1.012422,1,0.0,5.527011,0.142857,0.996894,0.0,0.60292,1.015528,1,0.072807,102.757725,56.887385,48.538632,0.012422,0.623435,1,0.762422,0.0,1.0,-44.906072,5.146137,0.2,44.517763,322.0,4,1.314829,0.012422,0.006211,1.0,1,0.0,0.0,0.0,0,0.0,0
3549259,57320,1,36,"""""",0,1,1,33,0.0,1,5,6,5.757576,6.0,"""PEE""",1,"""SU""",4,1.0,2.0,1.363636,7,1.0,6.863937,0.545455,0.969697,0.0,0.0,2.0,2,0.024393,89.819134,58.589391,47.727273,0.363636,0.617843,2,1.121212,0.0,0.0,-58.517406,-4.772727,0.5,51.136364,33.0,1,0.959596,0.121212,0.030303,1.0,1,0.0,0.0,0.0,33,0.0,1
1119689,27507,1,36,"""""",0,1,1,10,0.0,1,7,8,7.7,8.0,"""NBC""",1,"""SU""",3,1.0,1.0,1.0,4,0.0,5.606544,0.3,0.9,0.0,0.0075,1.1,2,0.869421,67.663817,31.925174,46.666667,0.0,0.648636,1,0.8,0.0,0.0,-41.718332,-4.666667,0.5,43.333333,10.0,1,0.9625,0.3,0.1,0.0,1,0.0,0.0,0.0,10,0.0,1
3548342,63016,1,36,"""""",0,1,0,34,0.0,1,1,1,1.0,1.0,"""BQS""",1,"""U6""",3,1.0,2.0,1.176471,5,1.0,1.711084,0.0,0.970588,0.0,0.0,1.676471,1,0.101882,55.218442,36.147153,49.108734,0.176471,0.622897,2,1.117647,0.0,0.0,-6.601426,-4.910873,1.0,48.083779,34.0,0,1.0,0.088235,0.029412,1.0,1,0.0,0.0,0.0,0,0.0,0
495090,45562,1,36,"""SU""",0,1,1,471,0.774947,4,10,22,12.452229,11.0,"""SVO""",7,"""MU""",14,1.0,1.0,1.0,1,0.163482,4.693865,0.07431,0.991507,0.460722,0.0,1.868365,1,0.473284,67.448329,45.614359,47.443058,0.0,0.639223,3,0.725552,26.298701,1.0,-34.572231,5.255694,0.076923,43.721529,117.75,12,1.132021,0.029724,0.014862,0.0,1,0.0,0.0,0.0,471,0.774947,12
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2370715,57323,1,36,"""""",0,1,1,181,0.0,2,3,5,4.40884,4.0,"""SVO""",5,"""SU""",8,1.0,2.0,1.287293,5,1.0,5.05067,0.116022,0.98895,0.0,0.0,1.259669,1,0.424391,186.720993,165.467673,48.154652,0.287293,0.634413,3,1.243094,0.0,0.0,-38.384746,-4.815465,0.333333,49.823182,90.5,2,1.10221,0.044199,0.027624,1.0,1,0.0,0.0,0.0,181,0.0,2
3412414,42174,1,36,"""S7/SU""",0,1,1,51,0.196078,4,26,40,34.784314,33.0,"""BQS""",3,"""SU""",3,1.0,1.0,1.0,6,1.0,5.446064,0.176471,0.921569,0.0,0.0,1.764706,1,0.303648,41.04609,28.687825,49.705882,0.0,0.627293,3,0.823529,0.0,1.0,-42.942393,5.029412,0.066667,44.852941,12.75,14,1.05407,0.058824,0.058824,0.0,1,0.0,0.0,0.0,51,0.196078,14
3547026,62262,1,36,"""SU""",0,1,0,464,1.0,1,30,31,30.549569,31.0,"""DME""",4,"""U6""",5,1.0,2.0,1.072917,4,0.0,6.617567,0.344828,0.997845,0.0,0.0,1.700431,1,0.154816,49.496154,61.528717,49.050421,0.067888,0.634882,3,0.918103,0.0,1.0,-55.401591,5.094958,0.5,45.983544,464.0,1,0.98547,0.010776,0.008621,1.0,1,0.0,0.0,0.0,0,0.0,0
1237003,25312,0,36,"""""",0,1,1,12,1.0,1,61,61,61.0,61.0,"""VOG""",1,"""SU""",1,1.0,1.0,1.0,7,1.0,3.074824,0.0,0.916667,0.0,0.403704,1.0,2,0.038906,81.504759,15.298634,46.969697,0.0,0.628453,1,1.5,0.0,0.0,-20.553714,-4.69697,1.0,43.484848,12.0,0,1.0,0.083333,0.083333,0.0,1,0.0,0.0,0.0,12,1.0,0


In [15]:
def remove_outliers(df, method='isolation_forest', contamination=0.05):
    """Remove outliers using various methods"""

    # Get numeric columns only
    numeric_cols = [col for col, dtype in df.schema.items() if dtype.is_numeric()]
    df_numeric = df.select(numeric_cols)

    if method == 'isolation_forest':
        iso_forest = IsolationForest(contamination=contamination, random_state=42)
        outlier_labels = iso_forest.fit_predict(df_numeric.to_pandas())
        mask = outlier_labels == 1

    elif method == 'zscore':
        # Remove rows where any feature has |z-score| > 3
        df_pd = df_numeric.to_pandas()
        z_scores = np.abs(zscore(df_pd, nan_policy='omit'))
        mask = (z_scores < 3).all(axis=1)

    elif method == 'iqr':
        # Remove rows outside 1.5*IQR for any feature
        df_pd = df_numeric.to_pandas()
        Q1 = df_pd.quantile(0.25)
        Q3 = df_pd.quantile(0.75)
        IQR = Q3 - Q1
        mask = ~((df_pd < (Q1 - 1.5 * IQR)) | (df_pd > (Q3 + 1.5 * IQR))).any(axis=1)

    # Get outlier indices
    outlier_indices = np.where(~mask)[0]

    cleaned_df = df.filter(pl.Series('mask', mask))
    print(f"Removed {len(df) - len(cleaned_df):,} outliers ({(len(df) - len(cleaned_df))/len(df)*100:.1f}%)")

    return cleaned_df, outlier_indices

def encode_features(df):
    # Get categorical columns
    categorical_cols = [col for col, dtype in df.schema.items() if not dtype.is_numeric() and col not in ['id', 'ranker_id', 'request_date']]

    encoded_df = df.clone()
    encoders = {}

    for col in categorical_cols:
        if col == 'frequent_flyer':
            # Special handling for frequent flyer programs (list feature)
            # Create binary features for most common programs
            all_programs = []
            for programs in df[col].to_list():
                if isinstance(programs, list):
                    all_programs.extend(programs)

            from collections import Counter
            top_programs = Counter(all_programs).most_common(10)

            for program, _ in top_programs:
                encoded_df = encoded_df.with_columns([
                    pl.col(col).map_elements(
                        lambda x: 1 if isinstance(x, list) and program in x else 0,
                        return_dtype=pl.Int8
                    ).alias(f'ff_{program}')
                ])

            # Add count of total programs
            encoded_df = encoded_df.with_columns([
                pl.col(col).map_elements(
                    lambda x: len(x) if isinstance(x, list) else 0,
                    return_dtype=pl.Int8
                ).alias('ff_program_count')
            ])

        else:
            # Target encoding for high-cardinality categoricals like airports/carriers
            if col in ['most_common_departure_airport', 'most_common_carrier']:
                # Use frequency encoding
                value_counts = df[col].value_counts()
                freq_map = {row[0]: row[1] for row in value_counts.rows()}

                encoded_df = encoded_df.with_columns([
                    pl.col(col).replace_strict(freq_map, default=0).alias(f'{col}_frequency')
                ])
            else:
                # Standard label encoding for low-cardinality features
                unique_values = df[col].fill_null('MISSING').unique().sort().to_list()
                encoders[col] = {val: idx for idx, val in enumerate(unique_values)}

                encoded_df = encoded_df.with_columns([
                    pl.col(col).fill_null('MISSING').replace(encoders[col]).alias(f'{col}_encoded')

                ])

    # Remove original categorical columns
    final_df = encoded_df.select(pl.exclude(categorical_cols + ['frequent_flyer', 'most_common_departure_airport', 'most_common_carrier']))

    return final_df.fill_null(0), encoders

def dimensionality_reduction(scaled_features, method='pca', n_components=50):
    """Apply dimensionality reduction before clustering"""

    if method == 'pca':
        reducer = PCA(n_components=n_components, random_state=42)

    elif method == 'truncated_svd':
        reducer = TruncatedSVD(n_components=n_components, random_state=42)

    elif method == 'umap':
        reducer = umap.UMAP(n_components=n_components, random_state=42, n_neighbors=15)

    reduced_features = reducer.fit_transform(scaled_features)

    if hasattr(reducer, 'explained_variance_ratio_'):
        total_variance = reducer.explained_variance_ratio_.sum()
        print(f"{method.upper()} retained {total_variance:.3f} of total variance with {n_components} components")

    return reduced_features, reducer


def generate_clusters(features, n_clusters=10):
    agg = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
    labels = agg.fit_predict(features)

    # Create pseudo-centroids for use in prediction, since AgglomerativeClustering doesn't have centroids
    unique_segments = np.unique(labels)
    centroids = []
    for segment_id in unique_segments:
        segment_mask = labels == segment_id
        centroid = features[segment_mask].mean(axis=0)
        centroids.append(centroid)

    centroids = np.array(centroids)

    score = silhouette_score(features, labels)

    return {
        "labels": labels,
        "centroids": centroids,
        "model": agg,
        "silhouette": score,
        "n_clusters": n_clusters
    }


def gen_optimum_num_clusters(features):
    best_agg_score = -1
    best_agg_n = 0
    best_centroids = None

    for n_clusters in [9, 10, 11, 12, 13, 14, 15, 16, 17]:
        agg = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
        labels = agg.fit_predict(features)

        score = silhouette_score(features, labels)
        if score > best_agg_score:
            best_agg_score = score
            best_agg_n = n_clusters

            # Create pseudo-centroids for use in prediction, since AgglomerativeClustering doesn't have centroids
            unique_segments = np.unique(labels)
            centroids = []
            for segment_id in unique_segments:
                segment_mask = labels == segment_id
                centroid = features[segment_mask].mean(axis=0)
                centroids.append(centroid)
            best_centroids = np.array(centroids)

    best_agg = AgglomerativeClustering(n_clusters=best_agg_n, linkage='ward')
    agg_labels = best_agg.fit_predict(features)

    print(f"Best number of clusters: {best_agg_n}, with Silhouette score: {best_agg_score:.4f}")

    return {
        'labels': agg_labels,
        'centroids': best_centroids,
        'model': best_agg,
        'silhouette': silhouette_score(features, agg_labels),
        'n_clusters': len(set(agg_labels))
    }

In [16]:
def analyze_clusters(original_df, cluster_labels, outlier_indices=None):
    """Analyze the final clusters"""
    # Add cluster labels to original dataframe for analysis
    if original_df is not None:

        # Check if sizes match
        if len(cluster_labels) != (len(original_df) - (len(outlier_indices) if outlier_indices is not None else 0)):
            print(f"Warning: Cluster labels size ({len(cluster_labels)}) doesn't match DataFrame size ({len(original_df)})")
            print("This likely happened because outliers were removed during clustering and outlier indices not provided or incorrect.")

        else:
            if outlier_indices is not None and len(outlier_indices) > 0:
                mask = ~pl.Series(range(len(original_df))).is_in(outlier_indices)
                analysis_df = original_df.filter(mask)
            else:
                analysis_df = original_df.head(len(cluster_labels))

            # Now add cluster labels safely
            analysis_df = analysis_df.with_columns(pl.Series('cluster', cluster_labels))

            # Cluster profiling
            print("\n📊 CLUSTER PROFILES:")
            print("=" * 50)

            cluster_profiles = analysis_df.group_by('cluster').agg([
                pl.col('total_searches').mean().alias('avg_searches'),
                pl.col('is_vip').mean().alias('vip_rate'),
                pl.col('roundtrip_preference').mean().alias('roundtrip_rate'),
                pl.col('avg_booking_lead_days').mean().alias('avg_lead_days'),
                pl.col('unique_carriers_used').mean().alias('avg_carriers'),
                pl.len().alias('size')
            ]).sort('cluster')

            print(cluster_profiles)

### Generate Customer Clusters

In [17]:
# 2. Remove outliers
cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='isolation_forest', contamination=0.05)
# cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='zscore')
# cust_data_cleaned, outlier_indices = remove_outliers(cust_data, method='iqr')

Removed 1,507 outliers (5.0%)


In [18]:
# 3. Advanced encoding
cust_data_encoded, encoders = encode_features(cust_data_cleaned)

In [19]:
# 4. Feature scaling with robust scaler
scaler = RobustScaler()  # Less sensitive to outliers than StandardScaler
scaled_features = scaler.fit_transform(cust_data_encoded.to_pandas())

In [20]:
# 5. Dimensionality reduction
n_components = 0.95  # retain components that explain 95% of variance
# n_components = 30    # retain 30 components

reducer = PCA(n_components=n_components, random_state=42)
reduced_features = reducer.fit_transform(scaled_features)

if hasattr(reducer, 'explained_variance_ratio_'):
    total_variance = reducer.explained_variance_ratio_.sum()
    print(f"PCA retained {total_variance:.3f} of total variance with {reduced_features.shape[1]} components")

PCA retained 0.960 of total variance with 4 components


In [21]:
# 6. Apply clustering algorithm
# clustering_results = generate_clusters(reduced_features)
clustering_results = gen_optimum_num_clusters(reduced_features)

# 7. Clustering results
print("\n🏆 CLUSTERING RESULTS:")
print(f"  Clusters: {clustering_results['n_clusters']:2d}| Silhouette: {clustering_results['silhouette']:.4f} | ")

Best number of clusters: 13, with Silhouette score: 0.8090

🏆 CLUSTERING RESULTS:
  Clusters: 13| Silhouette: 0.8090 | 


In [22]:
analyze_clusters(cust_data_cleaned, clustering_results['labels'])


📊 CLUSTER PROFILES:
shape: (13, 7)
┌─────────┬──────────────┬──────────┬────────────────┬───────────────┬──────────────┬───────┐
│ cluster ┆ avg_searches ┆ vip_rate ┆ roundtrip_rate ┆ avg_lead_days ┆ avg_carriers ┆ size  │
│ ---     ┆ ---          ┆ ---      ┆ ---            ┆ ---           ┆ ---          ┆ ---   │
│ i64     ┆ f64          ┆ f64      ┆ f64            ┆ f64           ┆ f64          ┆ u32   │
╞═════════╪══════════════╪══════════╪════════════════╪═══════════════╪══════════════╪═══════╡
│ 0       ┆ 820.364922   ┆ 0.0      ┆ 0.483473       ┆ 11.265869     ┆ 6.323904     ┆ 707   │
│ 1       ┆ 293.627119   ┆ 1.0      ┆ 0.731554       ┆ 13.552792     ┆ 4.983051     ┆ 59    │
│ 2       ┆ 390.175765   ┆ 0.0      ┆ 0.449503       ┆ 15.455421     ┆ 7.692308     ┆ 2418  │
│ 3       ┆ 535.615385   ┆ 1.0      ┆ 0.86867        ┆ 11.837962     ┆ 5.730769     ┆ 26    │
│ 4       ┆ 295.227273   ┆ 0.0      ┆ 0.406171       ┆ 17.016755     ┆ 5.613636     ┆ 44    │
│ …       ┆ …           

In [23]:
# Add cluster labels to the customer data set
cust_data_cleaned = cust_data_cleaned.with_columns(pl.Series('customer_segment', clustering_results['labels']))

In [24]:
cust_data_cleaned.group_by('customer_segment').len()

customer_segment,len
i64,u32
0,707
12,82
8,5
1,59
4,44
…,…
5,4739
11,2
10,19573
9,970


In [25]:
# Save customer features
cust_data_cleaned.write_parquet('data/train_customer_features.parquet')

In [26]:
# Save the best Agglomerative model
joblib.dump(clustering_results, 'data/agglomerative_model.joblib')

['data/agglomerative_model.joblib']

In [27]:
# Remove unneeded objects to free memory
del cust_data
del cust_data_cleaned
del cust_data_encoded
del scaled_features
del reduced_features
del scaler
del clustering_results
gc.collect()

12

### Engineer Flight Features

In [28]:
CUSTOMER_FEATURES = ['companyID', 'sex', 'nationality', 'isVip', 'bySelf']
UNNEEDED_FEATURES = [
    'frequentFlyer', 'frequent_flyer', 'isAccess3D', 'requestDate', 'searchRoute', 'totalPrice', 'taxes', 'legs0_arrivalAt', 'legs0_duration'
]
UNNEEDED_FEATURES_REGEX = r'^legs[01].*$|^leg0_duration_q.*$|^price_q.*$'
POLARS_INDEX_COL = ['__index_level_0__']


# function to get column groups for the flight segments features that will be used in engineering the new features
def get_column_groups(df: pl.DataFrame) -> Dict[str, List[str]]:
    columns = df.columns
    return {
        'leg0_duration': [col for col in columns if col.startswith('legs0_segments') and col.endswith('_duration')],
        'leg1_duration': [col for col in columns if col.startswith('legs1_segments') and col.endswith('_duration')],
        'mkt_carrier': [col for col in columns if col.endswith('marketingCarrier_code')],
        'op_carrier': [col for col in columns if col.endswith('operatingCarrier_code')],
        'aircraft': [col for col in columns if col.endswith('aircraft_code')],
        'airport': [col for col in columns if 'airport_iata' in col or 'airport_city_iata' in col],
        'seats_avail': [col for col in columns if col.endswith('seatsAvailable')],
        'baggage_type': [col for col in columns if 'baggageAllowance_weightMeasurementType' in col],
        'baggage_allowance': [col for col in columns if 'baggageAllowance_quantity' in col],
        'cabin_class': [col for col in columns if 'cabinClass' in col]
    }


def create_basic_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create basic customer and route characteristics features."""
    return [
        pl.col('isAccess3D').cast(pl.Int32).alias('is_access3D'),

        # Normalized frequentFlyer program, addressing null values as null strings, and translating UT program
        pl.col('frequentFlyer').str.replace('- ЮТэйр ЗАО', 'UT').fill_null('').alias('frequent_flyer'),

        # Route characteristics
        pl.col('legs1_departureAt').is_not_null().cast(pl.Int8).alias('is_roundtrip'),
        pl.col('searchRoute').str.slice(0, 3).alias('route_origin'),

        # Hub features
        pl.col('searchRoute').str.slice(0, 3).is_in(MAJOR_HUBS).cast(pl.Int32).alias('origin_is_major_hub'),
        pl.col('searchRoute').str.slice(3, 3).is_in(MAJOR_HUBS).cast(pl.Int32).alias('destination_is_major_hub'),
        pl.max_horizontal([
            pl.col(col).is_in(MAJOR_HUBS) for col in col_groups['airport']
        ]).cast(pl.Int32).alias('includes_major_hub'),
    ]


def create_segment_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create flight segment related features."""
    return [
        # Number of segments per leg
        pl.concat_list([pl.col(col) for col in col_groups['leg0_duration']])
        .list.drop_nulls().list.len().alias('leg0_num_segments'),

        pl.concat_list([pl.col(col) for col in col_groups['leg1_duration']])
        .list.drop_nulls().list.len().alias('leg1_num_segments'),

        # Total segments
        pl.sum_horizontal([pl.col(col).is_not_null() for col in col_groups['aircraft']])
        .alias('total_segments'),

        # Flight time in minutes (sum of segment durations)
        pl.sum_horizontal([
            parse_duration_to_minutes(col)
            for col in col_groups['leg0_duration']
        ]).alias('leg0_flight_time_min'),

        pl.sum_horizontal([
            parse_duration_to_minutes(col)
            for col in col_groups['leg1_duration']
        ]).alias('leg1_flight_time_min'),
    ]


def create_time_features() -> List[pl.Expr]:
    """Create time-based features."""
    return [
        # Booking lead time
        ((pl.col('legs0_departureAt').str.to_datetime() -
            pl.col('requestDate').cast(pl.Datetime)) / pl.duration(days=1)).cast(pl.Int32).alias('booking_lead_days'),

        # Trip duration features
        parse_duration_to_minutes('legs0_duration').alias('leg0_duration_minutes'),
        parse_duration_to_minutes('legs1_duration').alias('leg1_duration_minutes'),
        (parse_duration_to_minutes('legs0_duration') + parse_duration_to_minutes('legs1_duration')).alias('trip_duration_minutes'),

        # Departure/arrival hours
        pl.col('legs0_departureAt').str.to_datetime().dt.hour().alias('leg0_departure_hour'),
        pl.col('legs0_departureAt').str.to_datetime().dt.weekday().alias('leg0_departure_weekday'),
        pl.col('legs0_arrivalAt').str.to_datetime().dt.hour().alias('leg0_arrival_hour'),
        pl.col('legs0_arrivalAt').str.to_datetime().dt.weekday().alias('leg0_arrival_weekday'),
        pl.col('legs1_departureAt').str.to_datetime().dt.hour().alias('leg1_departure_hour'),
        pl.col('legs1_arrivalAt').str.to_datetime().dt.hour().alias('leg1_arrival_hour'),
        pl.col('legs1_departureAt').str.to_datetime().dt.weekday().alias('leg1_departure_weekday'),
        pl.col('legs1_arrivalAt').str.to_datetime().dt.weekday().alias('leg1_arrival_weekday'),
    ]


def create_carrier_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create carrier and aircraft features."""
    # Create combined carrier column group for carrier features
    all_carrier_cols = col_groups['mkt_carrier'] + col_groups['op_carrier']

    # Create frequent flyer matching expressions
    ff_matches_mkt = [
        pl.when(pl.col(col).is_not_null() & (pl.col(col) != ''))
        .then(pl.col('frequent_flyer').str.contains(pl.col(col)))
        .otherwise(False)
        for col in col_groups['mkt_carrier']
    ]
    ff_matches_op = [
        pl.when(pl.col(col).is_not_null() & (pl.col(col) != ''))
        .then(pl.col('frequent_flyer').str.contains(pl.col(col)))
        .otherwise(False)
        for col in col_groups['op_carrier']
    ]

    return [
        # Carrier features
        pl.concat_list([pl.col(col) for col in all_carrier_cols])
        .list.drop_nulls().list.unique().list.len().alias('unique_carriers'),

        pl.coalesce([pl.col(col) for col in all_carrier_cols]).alias('primary_carrier'),
        pl.col('legs0_segments0_marketingCarrier_code').alias('marketing_carrier'),

        # Frequent flyer matching (check if FF programs match carriers)
        pl.max_horizontal(ff_matches_mkt).cast(pl.Int32).alias('has_mkt_ff'),
        pl.max_horizontal(ff_matches_op).cast(pl.Int32).alias('has_op_ff'),

        # Aircraft features
        pl.concat_list([pl.col(col) for col in col_groups['aircraft']])
        .list.drop_nulls().list.unique().list.len().alias('aircraft_diversity'),

        pl.coalesce([pl.col(col) for col in col_groups['aircraft']]).alias('primary_aircraft'),
    ]


def create_service_features(col_groups: Dict[str, List[str]]) -> List[pl.Expr]:
    """Create service-related features (fees, cabin, seats, baggage)."""
    return [
        # Cancellation/exchange fees
        (
            ((pl.col('miniRules0_monetaryAmount') > 0) & pl.col('miniRules0_monetaryAmount').is_not_null()) |
            ((pl.col('miniRules0_percentage') > 0) & pl.col('miniRules0_percentage').is_not_null())
        ).cast(pl.Int32).alias('has_cancellation_fee'),

        (
            ((pl.col('miniRules1_monetaryAmount') > 0) & pl.col('miniRules1_monetaryAmount').is_not_null()) |
            ((pl.col('miniRules1_percentage') > 0) & pl.col('miniRules1_percentage').is_not_null())
        ).cast(pl.Int32).alias('has_exchange_fee'),

        # Cabin class features
        pl.max_horizontal([pl.col(col) for col in col_groups['cabin_class']]).alias('max_cabin_class'),
        pl.mean_horizontal([pl.col(col) for col in col_groups['cabin_class']]).alias('avg_cabin_class'),

        # Seat availability (using for understanding popularity/scarcity)
        pl.min_horizontal([pl.col(col) for col in col_groups['seats_avail']]).alias('min_seats_available'),
        pl.sum_horizontal([pl.col(col).fill_null(0) for col in col_groups['seats_avail']]).alias('total_seats_available'),

        # Baggage allowance
        pl.min_horizontal([
            pl.when(pl.col(type_col) == 0).then(pl.col(qty_col)).otherwise(pl.lit(None))
            for type_col, qty_col in zip(col_groups['baggage_type'], col_groups['baggage_allowance'])
        ]).fill_null(0).alias("baggage_allowance_quantity"),

        pl.min_horizontal([
            pl.when(pl.col(type_col) == 1).then(pl.col(qty_col)).otherwise(pl.lit(None))
            for type_col, qty_col in zip(col_groups['baggage_type'], col_groups['baggage_allowance'])
        ]).fill_null(0).alias("baggage_allowance_weight")
    ]


def add_route_popularity(lazy_df: pl.LazyFrame) -> pl.LazyFrame:
    """Add search route popularity features."""
    return lazy_df.with_columns([
        pl.len().over('searchRoute').alias('route_popularity')
    ]).with_columns([
        pl.col('route_popularity').log1p().alias('route_popularity_log')
    ])


def create_window_based_flight_features(lazy_df: pl.LazyFrame) -> pl.LazyFrame:
    return lazy_df.with_columns([
        # calculate price percentile over search session
        ((pl.col('totalPrice').rank(method='min').over('ranker_id') - 1) /
        (pl.col('totalPrice').count().over('ranker_id') - 1) * 100)
        .fill_null(50.0).alias('price_percentile'),

        # Price features (always relative to search session)
        (pl.col('totalPrice').rank(method='ordinal').over('ranker_id') /
         pl.col('totalPrice').count().over('ranker_id')).alias('price_rank_pct'),
        (pl.col('totalPrice') / pl.col('totalPrice').min().over('ranker_id')).alias('price_ratio_to_min'),
    ]).with_columns([

        # Route popularity features
        pl.len().over('searchRoute').alias('route_popularity')
    ]).with_columns([
        pl.col('route_popularity').log1p().alias('route_popularity_log')
    ])


def create_derived_flight_features() -> List[pl.Expr]:
    return [
        pl.col('leg0_departure_hour').is_between(6, 22).cast(pl.Int8).alias('is_daytime'),
        (pl.col('leg1_departure_weekday') >= 5).cast(pl.Int8).alias('is_weekend'),
        (pl.col('leg0_departure_hour').is_between(6, 22) &
         (~pl.col('leg0_arrival_hour').is_between(6, 22)))
        .cast(pl.Int8).alias('is_redeye'),
        (pl.col('total_segments') > 1).cast(pl.Int8).alias('has_connections'),
    ]

def extract_flight_features(df: pl.DataFrame) -> pl.DataFrame:
    """ Create flight-level features"""
    # Check if already processed
    if df.height > 0 and 'total_duration_mins' in df.columns:
        return df


    # Get column groups
    col_groups = get_column_groups(df)

    # Start with lazy frame, dropping unnecessary columns
    lazy_df = df.drop(CUSTOMER_FEATURES + POLARS_INDEX_COL, strict=False).lazy()
    lazy_df = create_window_based_flight_features(lazy_df)

    # Apply all feature groups
    lazy_df = lazy_df.with_columns([
        *create_basic_features(col_groups),
        *create_segment_features(col_groups),
        *create_time_features(),
        *create_service_features(col_groups)
    ])

    # Add carrier features (requires frequent_flyer from create_basic_features() step)
    lazy_df = lazy_df.with_columns(create_carrier_features(col_groups))

    # Add derived features
    lazy_df = lazy_df.with_columns(create_derived_flight_features())

    # Add window based derived flight features
    lazy_df = create_window_based_flight_features(lazy_df)

    # Materialize to generate new features, drop unused features, and fill nulls
    return (lazy_df
            .collect()
            .select(~cs.matches(UNNEEDED_FEATURES_REGEX) & ~cs.by_name(UNNEEDED_FEATURES))
            .fill_null(0)
            )


In [29]:
train_df = extract_flight_features(train_df)

In [30]:
train_df.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,leg1_arrival_hour,leg1_departure_weekday,leg1_arrival_weekday,has_cancellation_fee,has_exchange_fee,max_cabin_class,avg_cabin_class,min_seats_available,total_seats_available,baggage_allowance_quantity,baggage_allowance_weight,unique_carriers,primary_carrier,marketing_carrier,has_mkt_ff,has_op_ff,aircraft_diversity,primary_aircraft,is_daytime,is_weekend,is_redeye,has_connections
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,i8,i8,i8,i32,i32,f64,f64,f64,f64,f64,f64,u32,str,str,i32,i32,u32,str,i8,i8,i8,i8
25,123,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",0,0.0,0.066667,1.0,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,0,0,0,1,0,1.0,1.0,1.0,10.0,0.0,0.0,1,"""S7""","""S7""",0,0,1,"""E70""",1,0,0,1
26,0,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",0,7.142857,0.133333,1.050282,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,0,0,0,1,0,1.0,1.0,1.0,10.0,0.0,0.0,1,"""S7""","""S7""",0,0,1,"""E70""",1,0,0,1
27,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",0,35.714286,0.4,1.235318,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,0,0,0,1,1,1.0,1.0,1.0,10.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""E70""",1,0,0,1
28,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",0,42.857143,0.466667,1.297667,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,0,0,0,1,1,1.0,1.0,1.0,10.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""E70""",1,0,0,1
29,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,2087904,"""905909166d934c618ad55ab7f5cea5…",0,78.571429,0.8,2.648331,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,0,0,0,0,0,1.0,1.0,9.0,18.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""E70""",1,0,0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
120,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",1,4.347826,0.083333,1.293048,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,0,0,0,1,1,1.0,1.0,6.0,6.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""73H""",1,0,0,0
121,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",0,47.826087,0.5,1.98212,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,0,0,0,0,0,1.0,1.0,6.0,6.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""73H""",1,0,0,0
122,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",0,52.173913,0.541667,2.263128,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,0,0,0,1,0,2.0,2.0,4.0,4.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""73H""",1,0,0,0
123,0,2300.0,0.0,1.0,3500.0,0.0,1.0,0.0,1,3380969,"""e0f9319a8b3048cdb1c974395e599e…",0,60.869565,0.625,2.901737,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,0,0,0,1,1,2.0,2.0,4.0,4.0,1.0,0.0,1,"""S7""","""S7""",0,0,1,"""73H""",1,0,0,0


In [31]:
# Save flight features for train data
train_df.write_parquet('data/train_flight_features.parquet')

### Prepare Training Data

Now combining customer features with clustering segmentation results and flight features

In [67]:
def interactive_features()  -> List[pl.Expr]:
    return [
        (pl.col('is_daytime') * (1.0 - pl.col('night_flight_preference').fill_null(0.5))).alias('daytime_alignment'),
        (pl.col('is_weekend') * pl.col('weekend_travel_rate').fill_null(0.5)).alias('weekend_alignment'),
        (pl.col('price_rank_pct') * pl.col('price_position_preference').fill_null(0.5)).alias('price_preference_match'),
        (pl.col('primary_carrier') == pl.col('most_common_carrier').fill_null('unknown')).cast(pl.Int8).alias('carrier_loyalty_match'),
    ]

def prepare_combined_data(flight_features, cust_features):
    if not isinstance(cust_features, pl.LazyFrame):
        cust_features = cust_features.lazy()

    # Join the customer features with the flight features, removing outlier profiles
    flight_features = (
        flight_features.lazy()
        .join(cust_features, on='profileId', how='inner').drop(pl.col('profileId'))
        .with_columns([
            *interactive_features(),  # Additional interactive features between customer and flight search data sets
        ])
        .collect()
    )

    return flight_features

In [39]:
# Prepare training data by combining customer features with flight features
train_df = prepare_combined_data(train_df, pl.scan_parquet('data/train_customer_features.parquet'))  # loading cust_features lazy

FileNotFoundError: The system cannot find the file specified. (os error 2): data/train_full_features.parquet

Resolved plan until failure:

	---> FAILED HERE RESOLVING 'sink' <---
DF ["Id", "corporateTariffCode", "miniRules0_monetaryAmount", "miniRules0_percentage", ...]; PROJECT */121 COLUMNS

In [34]:
train_df.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8
25,123,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,0.0,0.066667,1.0,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.958333,0.0,0.0,1.75,1,0.283051,56.281214,62.121473,50.0,0.0,0.609132,2,0.666667,0.0,0.0,2.568737,-5.0,0.2,45.0,24.0,4,0.956897,0.083333,0.041667,0.0,1,0.0,0.0,0.0,24,0.0,4,10,1.0,0.0,3.333333,1
26,0,4000.0,0.0,1.0,0.0,0.0,0.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,7.142857,0.133333,1.050282,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.958333,0.0,0.0,1.75,1,0.283051,56.281214,62.121473,50.0,0.0,0.609132,2,0.666667,0.0,0.0,2.568737,-5.0,0.2,45.0,24.0,4,0.956897,0.083333,0.041667,0.0,1,0.0,0.0,0.0,24,0.0,4,10,1.0,0.0,6.666667,1
27,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,35.714286,0.4,1.235318,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.958333,0.0,0.0,1.75,1,0.283051,56.281214,62.121473,50.0,0.0,0.609132,2,0.666667,0.0,0.0,2.568737,-5.0,0.2,45.0,24.0,4,0.956897,0.083333,0.041667,0.0,1,0.0,0.0,0.0,24,0.0,4,10,1.0,0.0,20.0,1
28,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,42.857143,0.466667,1.297667,32,3.496508,0,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.958333,0.0,0.0,1.75,1,0.283051,56.281214,62.121473,50.0,0.0,0.609132,2,0.666667,0.0,0.0,2.568737,-5.0,0.2,45.0,24.0,4,0.956897,0.083333,0.041667,0.0,1,0.0,0.0,0.0,24,0.0,4,10,1.0,0.0,23.333333,1
29,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""905909166d934c618ad55ab7f5cea5…",0,78.571429,0.8,2.648331,32,3.496508,1,0,"""TOF""",0,0,0,2,0,2,140,0,29,730,0,730,11,6,21,6,0,…,0.958333,0.0,0.0,1.75,1,0.283051,56.281214,62.121473,50.0,0.0,0.609132,2,0.666667,0.0,0.0,2.568737,-5.0,0.2,45.0,24.0,4,0.956897,0.083333,0.041667,0.0,1,0.0,0.0,0.0,24,0.0,4,10,1.0,0.0,40.0,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
120,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""e0f9319a8b3048cdb1c974395e599e…",1,4.347826,0.083333,1.293048,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,10,0.701923,0.0,4.112522,0
121,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,47.826087,0.5,1.98212,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,10,0.701923,0.0,24.675132,0
122,0,4000.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,52.173913,0.541667,2.263128,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,10,0.701923,0.0,26.731393,0
123,0,2300.0,0.0,1.0,3500.0,0.0,1.0,0.0,1,"""e0f9319a8b3048cdb1c974395e599e…",0,60.869565,0.625,2.901737,452,6.115892,0,0,"""GDX""",0,0,0,1,0,1,330,0,42,330,0,330,10,5,11,5,0,…,0.971154,0.0,0.0,1.740385,2,0.375849,98.484632,62.123105,49.350265,0.461538,0.619922,1,1.144231,0.0,0.0,-34.951512,-4.935026,0.022222,53.905902,34.666667,44,0.875862,0.019231,0.019231,1.0,1,0.0,0.0,0.0,0,0.0,0,10,0.701923,0.0,30.843915,0


In [35]:
train_df.shape

(12796940, 121)

In [36]:
def prepare_ranking_data(df):
    # Sort the data sets for ranking
    sort_indices = df['ranker_id'].arg_sort()
    X = df[sort_indices]
    y = X['selected'].to_numpy()

    # Get search group sizes for chunked ranking
    group_sizes = (
        X
        .group_by('ranker_id', maintain_order=True)
        .len()
        .select('len')
        .to_numpy()
        .flatten()
    )

    return X.drop(['ranker_id','selected']), y, group_sizes


def train_lgb_ranker(X, y, group_sizes):
    """
    Train LGBRanker with proper group structure
    """
    categorical_features = [col for col in X.columns if not X[col].dtype.is_numeric()]
    X = X.with_columns([pl.col(col).cast(pl.Categorical) for col in categorical_features])

    # Convert to pandas for LightGBM compatibility
    X = X.to_pandas()

    train_data = lgb.Dataset(
        X, label=y,
        group=group_sizes,  # This tells LGB how many samples per search
        categorical_feature=categorical_features
    )

    params = {
        'objective': 'lambdarank',
        'metric': 'ndcg',
        'ndcg_eval_at': [3],
        'boosting_type': 'gbdt',
        'num_leaves': 255,
        'learning_rate': 0.1,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'min_data_in_leaf': 50,
        'num_threads': -1,
        'verbose': -1
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=1000,
        # callbacks=[lgb.early_stopping(100), lgb.log_evaluation(100)]
        callbacks=[lgb.log_evaluation(100)]
    )

    return model

In [37]:
def calculate_hit_rate_at_k(y_true, y_pred, groups, k=3):
    """
    Calculate Hit Rate @ K for ranking model

    Parameters:
    - y_true: actual binary labels (1 if selected, 0 if not)
    - y_pred: predicted scores from ranker
    - groups: group sizes (number of flights per search)
    - k: number of top predictions to consider (3 in your case)

    Returns:
    - hit_rate: fraction of searches where at least 1 of top-k predictions was correct
    """
    hits = 0
    total_searches = 0
    start_idx = 0

    for group_size in groups:
        # Extract data for this search group
        group_true = y_true[start_idx:start_idx + group_size]
        group_pred = y_pred[start_idx:start_idx + group_size]

        # Get indices of actual selections
        actual_selections = set(np.where(group_true == 1)[0])

        # Get top-k predictions (indices sorted by prediction score)
        top_k_indices = np.argsort(group_pred)[-k:]  # Get indices of top-k scores
        top_k_indices = set(top_k_indices)

        # Check if any top-k prediction matches actual selection
        if len(actual_selections.intersection(top_k_indices)) > 0:
            hits += 1

        total_searches += 1
        start_idx += group_size

    hit_rate = hits / total_searches if total_searches > 0 else 0
    return hit_rate

def detailed_hit_rate_evaluation(y_true, y_pred, groups, ranker_ids=None, k=3):
    """
    More detailed evaluation with additional metrics
    """
    hits = 0
    total_searches = 0
    hit_details = []
    start_idx = 0

    for i, group_size in enumerate(groups):
        group_true = y_true[start_idx:start_idx + group_size]
        group_pred = y_pred[start_idx:start_idx + group_size]

        # Actual selections
        actual_selections = np.where(group_true == 1)[0]
        num_actual_selections = len(actual_selections)

        # Top-k predictions
        top_k_indices = np.argsort(group_pred)[-k:]

        # Calculate hit
        hit = len(set(actual_selections).intersection(set(top_k_indices))) > 0

        if hit:
            hits += 1

        # Store details for analysis
        search_id = ranker_ids[i] if ranker_ids is not None else i
        hit_details.append({
            'search_id': search_id,
            'group_size': group_size,
            'num_actual_selections': num_actual_selections,
            'hit': hit,
            'top_k_scores': group_pred[top_k_indices].tolist(),
            'actual_selections_scores': group_pred[actual_selections].tolist() if len(actual_selections) > 0 else []
        })

        total_searches += 1
        start_idx += group_size

    hit_rate = hits / total_searches

    return hit_rate, hit_details

In [38]:
def evaluate_model_hit_rate(model, X_val, y_val, group_sizes_val, k=3):
    """
    Evaluate trained LGBRanker model using HitRate@K
    """
    # Get predictions
    y_pred = model.predict(X_val)

    # Calculate hit rate
    hit_rate = calculate_hit_rate_at_k(y_val, y_pred, group_sizes_val, k=k)

    print(f"Hit Rate @ {k}: {hit_rate:.4f}")
    return hit_rate

def cross_validate_with_hit_rate(df, n_folds=5):
    """
    Cross-validation preserving search groups and calculating HitRate@3
    """
    # Get unique search IDs
    unique_searches = df['ranker_id'].unique()
    np.random.shuffle(unique_searches)

    fold_size = len(unique_searches) // n_folds
    hit_rates = []

    for fold in range(n_folds):
        print(f"\nFold {fold + 1}/{n_folds}")

        # Split search IDs
        start_idx = fold * fold_size
        end_idx = start_idx + fold_size if fold < n_folds - 1 else len(unique_searches)
        val_searches = unique_searches[start_idx:end_idx]
        train_searches = np.setdiff1d(unique_searches, val_searches)

        # Split data
        train = df[df['ranker_id'].isin(train_searches)]
        val = df[df['ranker_id'].isin(val_searches)]

        # Prepare training data
        X_train, y_train, groups_train = prepare_ranking_data(train)
        X_val, y_val, groups_val = prepare_ranking_data(val)

        # Train model
        model = train_lgb_ranker(X_train, y_train, groups_train)

        # Evaluate
        hit_rate = evaluate_model_hit_rate(model, X_val, y_val, groups_val)
        hit_rates.append(hit_rate)

    print(f"\nCross-validation results:")
    print(f"Mean Hit Rate @ 3: {np.mean(hit_rates):.4f} (+/- {np.std(hit_rates)*2:.4f})")
    return hit_rates

In [40]:
try:
    train_df.head()
except NameError:
    train_df = pl.read_parquet('data/train_flight_features.parquet')
    print(f'loaded flight features from parquet file: {train_df.shape}')

In [41]:
# Prepare Train data for ranking
X, y, group_sizes = prepare_ranking_data(train_df)

In [42]:
# Remove the feature dataframe to free up memory
del train_df
gc.collect()

2866

In [43]:
# Train model
model = train_lgb_ranker(X, y, group_sizes)

In [45]:
# Save the trained flight selection model
joblib.dump({'lgb_ranker_model': model}, 'data/lgb_ranker_model.joblib')

['data/lgb_ranker_model.joblib']

In [46]:
del X
del y
del group_sizes
gc.collect()

8

### Prepare Testing Data

In [47]:
# Load Test data set
try:
    test_df.head()
except NameError:
    test_df = pl.read_parquet('data/test_df.parquet')
    print(f'loaded test data set features from parquet file: {test_df.shape}')

test_df.head()

loaded test data set features from parquet file: (3578079, 127)


Id,bySelf,companyID,corporateTariffCode,frequentFlyer,nationality,isAccess3D,isVip,legs0_arrivalAt,legs0_departureAt,legs0_duration,legs0_segments0_aircraft_code,legs0_segments0_arrivalTo_airport_city_iata,legs0_segments0_arrivalTo_airport_iata,legs0_segments0_baggageAllowance_quantity,legs0_segments0_baggageAllowance_weightMeasurementType,legs0_segments0_cabinClass,legs0_segments0_departureFrom_airport_iata,legs0_segments0_duration,legs0_segments0_flightNumber,legs0_segments0_marketingCarrier_code,legs0_segments0_operatingCarrier_code,legs0_segments0_seatsAvailable,legs0_segments1_aircraft_code,legs0_segments1_arrivalTo_airport_city_iata,legs0_segments1_arrivalTo_airport_iata,legs0_segments1_baggageAllowance_quantity,legs0_segments1_baggageAllowance_weightMeasurementType,legs0_segments1_cabinClass,legs0_segments1_departureFrom_airport_iata,legs0_segments1_duration,legs0_segments1_flightNumber,legs0_segments1_marketingCarrier_code,legs0_segments1_operatingCarrier_code,legs0_segments1_seatsAvailable,legs0_segments2_aircraft_code,legs0_segments2_arrivalTo_airport_city_iata,…,legs1_segments2_baggageAllowance_weightMeasurementType,legs1_segments2_cabinClass,legs1_segments2_departureFrom_airport_iata,legs1_segments2_duration,legs1_segments2_flightNumber,legs1_segments2_marketingCarrier_code,legs1_segments2_operatingCarrier_code,legs1_segments2_seatsAvailable,legs1_segments3_aircraft_code,legs1_segments3_arrivalTo_airport_city_iata,legs1_segments3_arrivalTo_airport_iata,legs1_segments3_baggageAllowance_quantity,legs1_segments3_baggageAllowance_weightMeasurementType,legs1_segments3_cabinClass,legs1_segments3_departureFrom_airport_iata,legs1_segments3_duration,legs1_segments3_flightNumber,legs1_segments3_marketingCarrier_code,legs1_segments3_operatingCarrier_code,legs1_segments3_seatsAvailable,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,requestDate,searchRoute,sex,taxes,totalPrice,selected,__index_level_0__
i64,bool,i64,i64,str,i64,bool,bool,str,str,str,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,str,str,…,f64,f64,str,str,str,str,str,f64,str,str,str,f64,f64,f64,str,str,str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,str,datetime[ns],str,bool,f64,f64,i64,i64
0,True,57323,,"""S7/SU/UT""",36,False,False,"""2024-06-15T16:20:00""","""2024-06-15T15:40:00""","""02:40:00""","""YK2""","""KJA""","""KJA""",1.0,0.0,1.0,"""TLK""","""02:40:00""","""216""","""KV""","""KV""",9.0,,,,,,,,,,,,,,,…,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",True,370.0,16884.0,1,0
1,True,57323,123.0,"""S7/SU/UT""",36,True,False,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",True,2240.0,51125.0,0,1
2,True,57323,,"""S7/SU/UT""",36,False,False,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,2300.0,,1.0,3500.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",True,2240.0,53695.0,0,2
3,True,57323,123.0,"""S7/SU/UT""",36,True,False,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",True,2240.0,81880.0,0,3
4,True,57323,,"""S7/SU/UT""",36,False,False,"""2024-06-15T14:50:00""","""2024-06-15T09:25:00""","""07:25:00""","""E70""","""OVB""","""OVB""",1.0,0.0,1.0,"""TLK""","""02:50:00""","""5358""","""S7""","""S7""",4.0,"""E70""","""KJA""","""KJA""",1.0,0.0,1.0,"""OVB""","""01:20:00""","""5311""","""S7""","""S7""",4.0,,,…,,,,,,,,,,,,,,,,,,,,,0.0,,1.0,0.0,,1.0,1.0,1,2087645,"""98ce0dabf6964640b63079fbafd42c…",2024-05-17 03:03:08,"""TLKKJA/KJATLK""",True,2240.0,86070.0,0,4


In [48]:
# 1. Customer Feature Engineering
cust_data = extract_customer_features(test_df)
print(f'Generated {len(cust_data.columns)} customer features for {len(cust_data)} customers')

Generated 58 customer features for 13737 customers


In [49]:
# 2. Encode customer features
cust_data_encoded, encoders = encode_features(cust_data)

In [50]:
# 3. Feature scaling with robust scaler
scaler = RobustScaler()  # Less sensitive to outliers than StandardScaler
scaled_features = scaler.fit_transform(cust_data_encoded.to_pandas())

In [51]:
# 5. Dimensionality reduction
reduced_features = reducer.transform(scaled_features)

In [52]:
def predict_cluster_with_centroids(X, centroids):
    """
    Predict clusters by assigning to nearest centroid
    """
    # Calculate distances to all centroids
    distances = euclidean_distances(X, centroids)

    # Assign to nearest centroid
    predictions = np.argmin(distances, axis=1)

    # Get minimum distances as confidence measure
    min_distances = np.min(distances, axis=1)

    return predictions, min_distances

In [55]:
# Load Agglomeratvie clustering results
try:
    clustering_results
except NameError:
    clustering_results = joblib.load('data/agglomerative_model.joblib')
    print(f'Loaded Agglomerative Clustering results from joblib file for {clustering_results['n_clusters']} clusters')

Loaded Agglomerative Clustering results from joblib file for 13 clusters


In [56]:
clustering_results['centroids']

array([[-1.24273156e+00, -3.03288865e-01,  2.28449450e-01,
         1.97093818e+01],
       [ 1.43168141e+02,  5.77070263e+00, -4.73924696e+00,
        -3.42338477e-01],
       [-1.26560328e+00,  5.20151388e+00,  2.19157665e+01,
        -8.58971982e-01],
       [ 3.26875577e+02,  3.57270664e+00, -2.01288788e+00,
        -8.16860155e-01],
       [-1.27178258e+00,  7.65988756e+00,  2.10366181e+00,
         9.46996169e+01],
       [-1.23395722e+00,  1.49329056e+01, -1.01125054e+01,
        -1.05501720e+00],
       [ 2.14578443e+03, -5.28202443e+00,  1.22952358e+00,
         1.25320408e-01],
       [ 7.50084935e+03, -4.11633599e+00,  6.08475708e+00,
         1.21818876e+00],
       [ 6.47464083e+02, -3.70442198e+00,  3.36457252e+00,
        -5.81744011e-01],
       [-1.25811816e+00,  2.50306922e+01,  1.26782596e+01,
        -2.06504145e+00],
       [-1.24038775e+00, -5.51903054e+00, -8.84127860e-01,
        -4.58817192e-01],
       [ 1.01127126e+03, -5.50154576e+00,  8.06479583e-02,
      

In [62]:
# 6. Apply clustering algorithm
# clustering_results = generate_clusters(reduced_features)
cluster_segments = predict_cluster_with_centroids(reduced_features, clustering_results['centroids'])

In [63]:
cluster_segments

(array([10, 10, 10, ..., 10, 10, 10], shape=(13737,)),
 array([1.05595054, 0.05719075, 0.70023978, ..., 0.89272074, 0.35659446,
        3.97282917], shape=(13737,)))

In [64]:
# Add predicted cluster labels to the customer data
cust_data = cust_data.with_columns(pl.Series('customer_segment', cluster_segments[0]))

In [65]:
cust_data.head(1000)

profileId,company_id,sex,nationality,frequent_flyer,is_vip,by_self,has_corp_codes,total_searches,roundtrip_preference,unique_routes_searched,min_booking_lead_days,max_booking_lead_days,avg_booking_lead_days,median_booking_lead_days,most_common_departure_airport,unique_departure_airports,most_common_carrier,unique_carriers_used,min_cabin_class,max_cabin_class,avg_cabin_class,weekday_preference,weekend_travel_rate,time_of_day_variance,night_flight_preference,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment
i64,i64,i8,i64,str,i8,i8,i8,u32,f64,u32,i32,i32,f64,f64,str,u32,str,u32,f64,f64,f64,i8,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64
3559623,60537,1,36,"""""",0,1,0,492,1.0,1,6,7,6.597561,7.0,"""SVX""",1,"""SU""",5,1.0,1.0,1.0,4,0.0,6.371772,0.164634,0.997967,0.0,0.0,1.091463,1,0.643154,66.247593,25.158026,49.398937,0.0,0.626457,1,0.97561,0.0,0.0,-50.501944,-4.939894,0.5,44.699468,492.0,1,0.942509,0.010163,0.002033,0.0,1,0.0,0.0,0.0,0,0.0,0,10
1279879,42174,1,36,"""""",0,1,1,39,1.0,1,63,64,63.384615,63.0,"""BQS""",1,"""S7""",2,1.0,2.0,1.34188,6,1.0,1.478592,0.0,0.974359,0.0,0.0,2.0,1,0.133988,42.440184,26.446119,47.165992,0.307692,0.627574,1,0.897436,0.0,0.0,-4.115981,-4.716599,0.5,50.420603,39.0,1,1.006105,0.051282,0.025641,1.0,1,0.0,0.0,0.0,39,1.0,1,10
3477700,61081,1,36,"""""",0,1,1,647,0.877898,2,9,12,9.867079,10.0,"""UFA""",1,"""SU""",4,1.0,2.0,1.420402,7,0.877898,5.924169,0.126739,0.996909,0.0,0.362223,1.041731,1,-0.067285,184.236894,130.986442,47.917913,0.420402,0.629689,2,1.210201,0.0,0.0,-49.578112,-4.791791,0.25,52.366994,323.5,3,0.986708,0.006182,0.001546,1.0,1,0.0,0.0,0.0,647,0.877898,3,10
1835911,42620,0,36,"""""",0,1,1,918,1.0,1,19,20,19.663399,20.0,"""SVX""",1,"""SU""",6,1.0,2.0,1.37845,3,0.0,7.036581,0.265795,0.998911,0.0,0.243791,1.057734,1,0.053755,140.333237,113.00584,48.330732,0.379085,0.628769,2,1.162309,0.0,0.0,-60.097034,-4.833073,0.5,51.734356,918.0,1,0.98317,0.006536,0.001089,1.0,1,0.0,0.0,0.0,918,1.0,1,10
3420140,57323,1,36,"""""",0,1,1,120,0.0,2,19,23,20.916667,22.0,"""KUF""",5,"""SU""",6,1.0,2.0,1.466667,2,0.0,5.141663,0.125,0.983333,0.0,0.248588,1.083333,1,0.05234,251.258222,191.3653,46.899002,0.466667,0.657184,3,1.475,0.0,0.0,-41.154924,-4.6899,0.2,52.782834,60.0,4,0.950758,0.05,0.041667,1.0,1,0.0,0.0,0.0,120,0.0,4,10
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2869628,61061,1,36,"""""",0,1,1,52,1.0,1,6,6,6.0,6.0,"""NBC""",1,"""SU""",1,1.0,1.0,1.0,1,0.0,0.0,0.0,0.980769,0.0,0.0,2.0,1,-0.213242,53.957156,19.934868,46.53092,0.0,0.604576,1,0.923077,0.0,0.0,8.933789,-4.653092,1.0,43.26546,52.0,0,1.0,0.019231,0.019231,0.0,1,0.0,0.0,0.0,52,1.0,0,10
2073838,53375,1,36,"""""",0,1,1,35,0.0,2,2,9,5.142857,2.0,"""KJA""",2,"""S7""",3,1.0,2.0,1.157143,3,0.0,5.450002,0.257143,0.942857,0.0,0.0,2.0,1,-0.318778,78.941272,44.809153,46.37013,0.157143,0.627761,3,0.685714,0.0,0.0,-46.09391,-4.637013,0.125,46.327922,17.5,7,2.571429,0.085714,0.057143,1.0,1,0.0,0.0,0.0,35,0.0,7,0
3454242,42174,1,36,"""""",0,1,1,41,0.0,1,104,104,104.0,104.0,"""LED""",1,"""U6""",4,1.0,2.0,1.097561,2,0.0,4.283832,0.097561,0.97561,0.0,0.0,2.0,3,-0.112272,38.56369,34.596927,49.268293,0.097561,0.612877,1,0.97561,0.0,0.0,-33.399685,-4.926829,1.0,46.585366,41.0,0,1.0,0.097561,0.02439,1.0,1,0.0,0.0,0.0,41,0.0,0,10
1699474,24728,1,36,"""SU""",0,1,1,169,1.0,1,5,6,5.733728,6.0,"""SVO""",4,"""SU""",5,1.0,1.0,1.0,7,1.0,7.053782,0.514793,0.994083,0.0,0.0,1.597633,1,0.496965,50.249449,20.170251,49.510425,0.0,0.62946,2,0.739645,0.0,1.0,-58.052997,5.048957,0.5,44.755213,169.0,1,0.955621,0.029586,0.023669,0.0,1,0.0,0.0,0.0,169,1.0,1,10


In [66]:
cust_data.group_by('customer_segment').len()

customer_segment,len
i64,u32
6,3
12,147
1,115
8,30
0,446
2,1479
10,11368
3,66
4,61
11,22


In [68]:
# Extract Flight Features
test_df = extract_flight_features(test_df)

In [69]:
# Combine customer and flight data
test_df = prepare_combined_data(test_df, cust_data)

In [72]:
test_df.head(1000)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,route_loyalty,hub_preference,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8
0,0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",1,0.0,0.04,1.0,25,3.258097,0,1,"""TLK""",0,0,0,1,1,2,160,155,29,160,155,315,15,6,16,6,9,…,0.963415,0.0,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,1.850789,0
1,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,4.166667,0.08,3.028015,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.963415,0.0,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,3.701578,1
2,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,29.166667,0.32,3.18023,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.963415,0.0,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,14.806313,1
3,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,54.166667,0.56,4.849562,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.963415,0.0,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,25.911047,1
4,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,79.166667,0.8,5.097726,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.963415,0.0,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,37.015782,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
8200,101,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1,"""5ba369d9c3e94218b9987f7618da47…",0,73.929009,0.743276,3.040191,584293,13.27816,1,1,"""MOW""",0,0,0,1,1,2,90,85,3,90,85,175,16,1,18,1,23,…,0.999369,0.0,0.62215,1.0,1,0.511022,154.587657,128.950598,48.122981,0.404161,0.628129,3,0.858344,0.0,0.0,-45.061876,-4.812298,0.083333,52.184651,1586.0,11,0.856572,0.001051,0.000631,3.0,1,0.0,0.0,0.0,4758,1.0,11,10,0.868432,0.0,35.76867,1
8201,101,4600.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""5ba369d9c3e94218b9987f7618da47…",0,84.455324,0.848411,9.0733,584293,13.27816,1,1,"""MOW""",0,0,0,1,1,2,90,85,3,90,85,175,16,1,18,1,23,…,0.999369,0.0,0.62215,1.0,1,0.511022,154.587657,128.950598,48.122981,0.404161,0.628129,3,0.858344,0.0,0.0,-45.061876,-4.812298,0.083333,52.184651,1586.0,11,0.856572,0.001051,0.000631,3.0,1,0.0,0.0,0.0,4758,1.0,11,10,0.868432,0.0,40.828054,1
8202,101,4600.0,0.0,1.0,4600.0,0.0,1.0,0.0,1,"""5ba369d9c3e94218b9987f7618da47…",0,88.984088,0.893643,10.418909,584293,13.27816,1,1,"""MOW""",0,0,0,1,1,2,90,85,3,90,85,175,16,1,18,1,23,…,0.999369,0.0,0.62215,1.0,1,0.511022,154.587657,128.950598,48.122981,0.404161,0.628129,3,0.858344,0.0,0.0,-45.061876,-4.812298,0.083333,52.184651,1586.0,11,0.856572,0.001051,0.000631,3.0,1,0.0,0.0,0.0,4758,1.0,11,10,0.868432,0.0,43.004766,1
8203,101,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1,"""5ba369d9c3e94218b9987f7618da47…",0,94.492044,0.948655,14.880666,584293,13.27816,1,1,"""MOW""",0,0,0,1,1,2,90,85,3,90,85,175,16,1,18,1,23,…,0.999369,0.0,0.62215,1.0,1,0.511022,154.587657,128.950598,48.122981,0.404161,0.628129,3,0.858344,0.0,0.0,-45.061876,-4.812298,0.083333,52.184651,1586.0,11,0.856572,0.001051,0.000631,3.0,1,0.0,0.0,0.0,4758,1.0,11,10,0.868432,0.0,45.652119,1


In [70]:
# After training your model
def complete_evaluation(model, test_df):
    """
    Complete evaluation workflow
    """
    print("Preparing test data...")
    X_test, y_test, groups_test = prepare_ranking_data(test_df)

    # Prepare testing data for ranking
    categorical_features = [col for col in X_test.columns if not X_test[col].dtype.is_numeric()]
    X_test = X_test.with_columns([pl.col(col).cast(pl.Categorical) for col in categorical_features])

    # Convert to pandas for LightGBM compatibility
    X_test = X_test.to_pandas()

    print("Making predictions...")
    y_pred = model.predict(X_test)

    print("Calculating Hit Rate @ 3...")
    hit_rate, details = detailed_hit_rate_evaluation(
        y_test,
        y_pred,
        groups_test,
        ranker_ids=test_df['ranker_id'].unique()
    )

    print(f"Overall Hit Rate @ 3: {hit_rate:.4f}")

    return hit_rate, details, y_pred


In [73]:
# Usage
hit_rate, details, y_pred = complete_evaluation(model, test_df)

Preparing test data...
Making predictions...
Calculating Hit Rate @ 3...
Overall Hit Rate @ 3: 0.5848


In [74]:
test_df = test_df.with_columns(pl.Series('y_pred', y_pred))

In [75]:
test_df = test_df.with_columns(pl.col('y_pred').rank(method='ordinal', descending=True).over('ranker_id').alias('flight_rank'))

In [76]:
test_df.head(100)

Id,corporateTariffCode,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,ranker_id,selected,price_percentile,price_rank_pct,price_ratio_to_min,route_popularity,route_popularity_log,is_access3D,is_roundtrip,route_origin,origin_is_major_hub,destination_is_major_hub,includes_major_hub,leg0_num_segments,leg1_num_segments,total_segments,leg0_flight_time_min,leg1_flight_time_min,booking_lead_days,leg0_duration_minutes,leg1_duration_minutes,trip_duration_minutes,leg0_departure_hour,leg0_departure_weekday,leg0_arrival_hour,leg0_arrival_weekday,leg1_departure_hour,…,short_haul_preference,connection_tolerance,preferred_duration_quartile,price_to_duration_sensitivity,avg_price_per_minute,price_per_minute_variance,price_position_preference,premium_economy_preference,consistent_price_tier,preferred_price_tier,baggage_qty_preference,baggage_weight_preference,loyalty_program_utilization,convenience_priority_score,loyalty_vs_price_index,planning_consistency_score,luxury_index,search_intensity_per_route,lead_time_variance,lead_time_skew,carrier_diversity,airport_diversity,cabin_class_range,customer_tier,vip_search_intensity,vip_carrier_diversity,vip_cabin_preference,corp_search_volume,corp_roundtrip_pref,corp_planning_variance,customer_segment,daytime_alignment,weekend_alignment,price_preference_match,carrier_loyalty_match,y_pred,flight_rank
i64,i64,f64,f64,f64,f64,f64,f64,f64,i64,str,i64,f64,f64,f64,u32,f64,i32,i8,str,i32,i32,i32,u32,u32,u32,i32,i32,i32,i32,i32,i32,i8,i8,i8,i8,i8,…,f64,f64,i32,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,f64,f64,f64,f64,i32,f64,f64,f64,f64,i8,f64,f64,f64,i64,f64,i32,i64,f64,f64,f64,i8,f64,u32
0,0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",1,0.0,0.04,1.0,25,3.258097,0,1,"""TLK""",0,0,0,1,1,2,160,155,29,160,155,315,15,6,16,6,9,…,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,1.850789,0,-6.17485,23
1,123,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,4.166667,0.08,3.028015,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,3.701578,1,-12.451694,25
2,0,2300.0,0.0,1.0,3500.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,29.166667,0.32,3.18023,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,14.806313,1,-2.238555,4
3,123,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,54.166667,0.56,4.849562,25,3.258097,1,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,25.911047,1,-1.800006,2
4,0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""98ce0dabf6964640b63079fbafd42c…",0,79.166667,0.8,5.097726,25,3.258097,0,1,"""TLK""",0,0,0,2,2,4,250,245,29,445,505,950,9,6,14,6,22,…,0.0,1.97561,2,-0.333777,93.40632,72.972179,46.269727,0.182927,0.633535,2,1.02439,0.0,1.0,-33.657433,5.373027,0.052632,46.7934,27.333333,18,1.421748,0.073171,0.036585,1.0,1,0.0,0.0,0.0,82,0.304878,18,0,1.0,0.0,37.015782,1,-2.841849,6
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1358,177,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1,"""629321276e5748a5a7ea5e3f91642f…",0,60.526316,0.692308,2.730658,649,6.476972,1,0,"""SGC""",0,0,0,2,0,2,340,0,0,665,0,665,4,6,15,6,0,…,0.0,1.871795,1,-0.04197,103.538343,60.999025,47.975709,0.307692,0.62838,1,1.307692,0.0,0.0,-38.244081,-4.797571,0.5,50.1417,39.0,1,0.74359,0.076923,0.025641,1.0,1,0.0,0.0,0.0,39,0.0,1,10,0.0,0.0,33.213952,1,-1.656742,12
1359,177,4600.0,0.0,1.0,0.0,0.0,0.0,0.0,1,"""629321276e5748a5a7ea5e3f91642f…",0,71.052632,0.794872,3.731367,649,6.476972,1,0,"""SGC""",0,0,0,2,0,2,340,0,0,665,0,665,4,6,15,6,0,…,0.0,1.871795,1,-0.04197,103.538343,60.999025,47.975709,0.307692,0.62838,1,1.307692,0.0,0.0,-38.244081,-4.797571,0.5,50.1417,39.0,1,0.74359,0.076923,0.025641,1.0,1,0.0,0.0,0.0,39,0.0,1,10,0.0,0.0,38.134538,1,-5.032743,26
1360,177,4600.0,0.0,1.0,4600.0,0.0,1.0,0.0,1,"""629321276e5748a5a7ea5e3f91642f…",0,81.578947,0.897436,3.944961,649,6.476972,1,0,"""SGC""",0,0,0,2,0,2,340,0,0,665,0,665,4,6,15,6,0,…,0.0,1.871795,1,-0.04197,103.538343,60.999025,47.975709,0.307692,0.62838,1,1.307692,0.0,0.0,-38.244081,-4.797571,0.5,50.1417,39.0,1,0.74359,0.076923,0.025641,1.0,1,0.0,0.0,0.0,39,0.0,1,10,0.0,0.0,43.055123,1,-0.098127,5
1361,177,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1,"""629321276e5748a5a7ea5e3f91642f…",0,92.105263,1.0,4.961246,649,6.476972,1,0,"""SGC""",0,0,0,2,0,2,340,0,0,665,0,665,4,6,15,6,0,…,0.0,1.871795,1,-0.04197,103.538343,60.999025,47.975709,0.307692,0.62838,1,1.307692,0.0,0.0,-38.244081,-4.797571,0.5,50.1417,39.0,1,0.74359,0.076923,0.025641,1.0,1,0.0,0.0,0.0,39,0.0,1,10,0.0,0.0,47.975709,1,0.027688,4


In [77]:
test_df.group_by(pl.col('ranker_id')).agg(pl.len().alias('count')).sort('count', descending=False).filter(pl.col('count') > 10).head(1000)

ranker_id,count
str,u32
"""ef446e20cacb468982678f62dd2a8b…",11
"""b60508f3e0a6478991fc22bcb7c2a9…",11
"""1577560fb50c4b90a403d44fd28d0c…",11
"""0d918b652375456883b9534046f84e…",11
"""fb231c505f8d4aa3a625ea8e975a1b…",11
…,…
"""5d1ba56f95794ccba0ab61efa837ca…",15
"""2bcbd7dea6954a9b9d794cc2beb372…",15
"""724ed6663bf744cb9bc9c547c72a6b…",15
"""1ec0016042d94b8eaa79801e22dc42…",15


In [78]:
next((item for item in details if item['search_id'] == "ef446e20cacb468982678f62dd2a8b71"), None)

{'search_id': 'ef446e20cacb468982678f62dd2a8b71',
 'group_size': np.uint32(3),
 'num_actual_selections': 1,
 'hit': True,
 'top_k_scores': [-7.444889495298393, -6.945598460298722, -0.5184594304742536],
 'actual_selections_scores': [-0.5184594304742536]}

In [79]:
test_df.filter(pl.col('ranker_id') == "ef446e20cacb468982678f62dd2a8b71")['selected','y_pred', 'flight_rank'].head(100)

selected,y_pred,flight_rank
i64,f64,u32
1,-11.855496,3
0,-13.001767,9
0,-11.868254,4
0,-11.693528,2
0,-12.63086,7
…,…,…
0,-14.169297,11
0,-13.331807,10
0,-11.34904,1
0,-12.039074,6


### Review Sample Submission Format

In [2]:
sample = pl.read_parquet('/kaggle/input/aeroclub-recsys-2025/sample_submission.parquet')

In [3]:
sample.head(100)

Id,ranker_id,selected,__index_level_0__
i64,str,i64,i64
18144679,"""c9373e5f772e43d593dd6ad2fa90f6…",178,18144679
18144680,"""c9373e5f772e43d593dd6ad2fa90f6…",363,18144680
18144681,"""c9373e5f772e43d593dd6ad2fa90f6…",277,18144681
18144682,"""c9373e5f772e43d593dd6ad2fa90f6…",183,18144682
18144683,"""c9373e5f772e43d593dd6ad2fa90f6…",55,18144683
…,…,…,…
18144774,"""c9373e5f772e43d593dd6ad2fa90f6…",324,18144774
18144775,"""c9373e5f772e43d593dd6ad2fa90f6…",59,18144775
18144776,"""c9373e5f772e43d593dd6ad2fa90f6…",147,18144776
18144777,"""c9373e5f772e43d593dd6ad2fa90f6…",233,18144777


In [5]:
sample.filter(pl.col('ranker_id') == "c9373e5f772e43d593dd6ad2fa90f67a").sort("selected", descending=False)

Id,ranker_id,selected,__index_level_0__
i64,str,i64,i64
18144870,"""c9373e5f772e43d593dd6ad2fa90f6…",1,18144870
18145030,"""c9373e5f772e43d593dd6ad2fa90f6…",2,18145030
18144936,"""c9373e5f772e43d593dd6ad2fa90f6…",3,18144936
18144748,"""c9373e5f772e43d593dd6ad2fa90f6…",4,18144748
18144823,"""c9373e5f772e43d593dd6ad2fa90f6…",5,18144823
…,…,…,…
18145000,"""c9373e5f772e43d593dd6ad2fa90f6…",408,18145000
18144751,"""c9373e5f772e43d593dd6ad2fa90f6…",409,18144751
18144918,"""c9373e5f772e43d593dd6ad2fa90f6…",410,18144918
18144688,"""c9373e5f772e43d593dd6ad2fa90f6…",411,18144688
