# Imports and paths

In [45]:
import os
import sys

# Project root: the directory one level above the current working directory
ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Register the project root on sys.path (only once) so its modules can be imported
if ROOT not in sys.path:
    sys.path.append(ROOT)

import polars as pl
import pandas as pd
import numpy as np
import json

In [46]:
from config import TRAIN_PARQUET_PATH, TEST_PARQUET_PATH, PRODUCTS_PARQUET_PATH, USERS_DATA_PATH, PRODUCTS_PARQUET_PATH_IMPUTED, ALL_PRODS_TRAINTEST_PATH
from src.data.loaders import PolarsLoader

## Users data

In [47]:
# Read the users parquet file, then rename `country` to `user_country`
users = pl.read_parquet(USERS_DATA_PATH, low_memory=True)
users = users.rename({'country': 'user_country'})

# Quick sanity checks: preview rows, per-column null counts, distinct user ids
print(users.head())
print(users.null_count())
print("Number of users:", users['user_id'].n_unique())

shape: (5, 5)
┌──────────────┬─────┬─────┬───────────┬─────────┐
│ user_country ┆ R   ┆ F   ┆ M         ┆ user_id │
│ ---          ┆ --- ┆ --- ┆ ---       ┆ ---     │
│ i8           ┆ i16 ┆ i16 ┆ f32       ┆ i32     │
╞══════════════╪═════╪═════╪═══════════╪═════════╡
│ 25           ┆ 74  ┆ 86  ┆ 11.64094  ┆ 180365  │
│ 25           ┆ 79  ┆ 5   ┆ 30.283333 ┆ 430101  │
│ 25           ┆ 0   ┆ 35  ┆ 47.25     ┆ 134206  │
│ 25           ┆ 0   ┆ 138 ┆ 46.604679 ┆ 180364  │
│ 25           ┆ 1   ┆ 24  ┆ 66.113075 ┆ 430100  │
└──────────────┴─────┴─────┴───────────┴─────────┘
shape: (1, 5)
┌──────────────┬─────┬─────┬─────┬─────────┐
│ user_country ┆ R   ┆ F   ┆ M   ┆ user_id │
│ ---          ┆ --- ┆ --- ┆ --- ┆ ---     │
│ u32          ┆ u32 ┆ u32 ┆ u32 ┆ u32     │
╞══════════════╪═════╪═════╪═════╪═════════╡
│ 0            ┆ 0   ┆ 0   ┆ 0   ┆ 0       │
└──────────────┴─────┴─────┴─────┴─────────┘
Number of users: 557006


In [48]:
# Keep a single row per user: after sorting, the first row of each user_id
# is the one with the highest F, ties broken by the highest R.
users = (
    users
    .sort(['user_id', 'F', 'R'], descending=[False, True, True])
    .group_by('user_id')
    .agg(pl.all().first())
)

In [49]:
def create_initial_extra_features(df: pl.DataFrame) -> pl.DataFrame:
    """Add ratio, country-relative and flag features derived from R, F and M.

    Args:
        df: Frame containing the columns ``R``, ``F``, ``M`` and
            ``user_country``.

    Returns:
        A new frame holding the input columns plus, in this order:

        * ``avg_value_per_purchase`` (M / F), ``purchase_rate`` (F / R),
          ``spend_rate_per_day`` (M / R), ``total_value_frequency`` (M * F),
          ``avg_days_between_purchases`` (R / F);
        * ``country_avg_monetary``, ``country_avg_frequency``,
          ``country_avg_recency``: per-country means of M, F and R;
        * ``relative_monetary_value``, ``relative_frequency``,
          ``relative_recency``: each metric divided by its country mean;
        * ``is_high_value_incountry``, ``is_frequent_buyer_incountry``,
          ``is_recent_customer_incountry``: 1 when the metric is strictly
          above its country mean, else 0;
        * ``is_high_value``, ``is_frequent_buyer``, ``is_recent_customer``:
          the same comparison against the mean over the whole frame.

        NaN and +/-Inf values in the four row-level ratio columns are
        replaced with 0.
    """
    r, f, m = pl.col('R'), pl.col('F'), pl.col('M')

    # Row-level ratios / products of the raw metrics. A zero denominator
    # yields Inf (or NaN for 0 / 0); those values are cleaned up at the end.
    df = df.with_columns([
        (m / f).alias('avg_value_per_purchase'),
        (f / r).cast(pl.Float32).alias('purchase_rate'),
        (m / r).alias('spend_rate_per_day'),
        (m * f).cast(pl.Float32).alias('total_value_frequency'),
        (r / f).cast(pl.Float32).alias('avg_days_between_purchases'),
    ])

    # Per-country means, joined back so every row carries its country's stats.
    # NOTE(review): this is an inner join; rows with a null user_country would
    # presumably not match and be dropped - confirm none exist upstream.
    country_stats = df.group_by('user_country').agg([
        m.mean().alias('country_avg_monetary'),
        f.mean().cast(pl.Float32).alias('country_avg_frequency'),
        r.mean().cast(pl.Float32).alias('country_avg_recency'),
    ])
    df = df.join(country_stats, on='user_country').with_columns([
        (m / pl.col('country_avg_monetary')).alias('relative_monetary_value'),
        (f / pl.col('country_avg_frequency')).cast(pl.Float32).alias('relative_frequency'),
        (r / pl.col('country_avg_recency')).cast(pl.Float32).alias('relative_recency'),
    ])

    # Binary "above the mean" flags: first against the user's own country,
    # then against the whole frame.
    # NOTE(review): `is_recent_customer*` is 1 when R is ABOVE the mean; whether
    # that really means "recent" depends on how R is defined - confirm.
    flag_specs = [
        ('M', 'is_high_value'),
        ('F', 'is_frequent_buyer'),
        ('R', 'is_recent_customer'),
    ]
    in_country_flags = [
        (pl.col(metric) > pl.col(metric).mean().over('user_country'))
        .cast(pl.Int8)
        .alias(f'{flag}_incountry')
        for metric, flag in flag_specs
    ]
    global_flags = [
        (pl.col(metric) > pl.col(metric).mean()).cast(pl.Int8).alias(flag)
        for metric, flag in flag_specs
    ]
    df = df.with_columns(in_country_flags + global_flags)

    # Replace NaN and Inf values with 0 in every row-level ratio column.
    # `avg_value_per_purchase` is included because it divides by F just like
    # `avg_days_between_purchases` and so shares the same F == 0 hazard.
    ratio_columns = [
        'avg_value_per_purchase',
        'purchase_rate',
        'spend_rate_per_day',
        'avg_days_between_purchases',
    ]
    df = df.with_columns([
        pl.when(pl.col(name).is_nan() | pl.col(name).is_infinite())
        .then(0)
        .otherwise(pl.col(name))
        .alias(name)
        for name in ratio_columns
    ])

    return df

# Build the engineered feature table from the one-row-per-user frame
users_eng = create_initial_extra_features(users)

In [50]:
def _rank_within_country(metric: str):
    """Return an expression ranking `metric` in descending order per `user_country`.

    The largest value of the metric inside a country receives rank 1.
    """
    return pl.col(metric).rank(descending=True).over('user_country')


def _segment_within_country(metric: str, n_segments: int = 5):
    """Return an expression bucketing `metric` into `n_segments` rank buckets per country.

    Bucket 1 holds the highest values of the metric within the country and
    bucket `n_segments` the lowest. The bucket is
    ``floor((rank - 1) * n_segments / group_size) + 1``.
    """
    rank = _rank_within_country(metric)
    # Number of ranked rows in the user's country. A rank never exceeds this
    # value, so the bucket index always stays inside 1..n_segments.
    group_size = pl.col(metric).count().over('user_country')
    return ((rank - 1) * n_segments / group_size).floor() + 1


def create_rfm_segments_and_ranks(users_df):
    """Add per-country quintile segments and ranks for the R, F and M columns.

    Args:
        users_df: polars DataFrame with the columns ``R``, ``F``, ``M`` and
            ``user_country``.

    Returns:
        A new polars DataFrame with six extra columns:

        * ``r_segment``, ``f_segment``, ``m_segment`` (UInt8): quintile of the
          user's descending rank inside their country; 1 = highest values of
          the metric, 5 = lowest.
        * ``r_rank_in_country``, ``f_rank_in_country``, ``m_rank_in_country``
          (Int32): descending rank of the metric inside the user's country
          (rank 1 = highest value).

    Notes:
        The quintiles are computed from the rank and the size of the user's
        own country. Running ``pandas.qcut`` after the window expression cut
        the whole column of per-country ranks at once, which put every user
        of a small country into segment 1. The bucket boundaries here are
        equal-width slices of the rank range, not ``qcut``'s interpolated
        quantile edges, and heavy ties can no longer raise a "bin edges must
        be unique" error.
    """
    return users_df.with_columns([
        # Quintile segments (1 = highest values of the metric in the country).
        # NOTE(review): for R the original author labels segment 1 "most
        # recent"; that holds only if a higher R means more recent - confirm.
        _segment_within_country('R').alias('r_segment').cast(pl.UInt8),
        _segment_within_country('F').alias('f_segment').cast(pl.UInt8),
        _segment_within_country('M').alias('m_segment').cast(pl.UInt8),

        # Ordinal ranks inside the country (not percentiles); the Int32 cast
        # truncates any fractional rank produced by ties.
        _rank_within_country('R').alias('r_rank_in_country').cast(pl.Int32),
        _rank_within_country('F').alias('f_rank_in_country').cast(pl.Int32),
        _rank_within_country('M').alias('m_rank_in_country').cast(pl.Int32),
    ])

# Reading the three segment columns together as a three-digit code
# (presumably in r_segment, f_segment, m_segment order; 1 = top bucket of a
# metric, 5 = bottom bucket) lets you identify groups such as:
# - Premium customers (111): top bucket in all three metrics
# - Lost customers (555): bottom bucket in all three metrics
# - High value but inactive (511): good monetary value but haven't bought recently

users_eng = create_rfm_segments_and_ranks(users_eng)

In [51]:
users_eng

user_id,user_country,R,F,M,avg_value_per_purchase,purchase_rate,spend_rate_per_day,total_value_frequency,avg_days_between_purchases,country_avg_monetary,country_avg_frequency,country_avg_recency,relative_monetary_value,relative_frequency,relative_recency,is_high_value_incountry,is_frequent_buyer_incountry,is_recent_customer_incountry,is_high_value,is_frequent_buyer,is_recent_customer,r_segment,f_segment,m_segment,r_rank_in_country,f_rank_in_country,m_rank_in_country
i32,i8,i16,i16,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i8,i8,i8,i8,u8,u8,u8,i32,i32,i32
1,25,60,18,40.518333,2.251019,0.3,0.675306,729.330017,3.333333,42.604546,37.78318,56.788792,0.951033,0.476402,1.056547,0,0,1,0,0,1,2,3,3,128112,318758,243018
2,25,2,37,38.485115,1.040138,18.5,19.242558,1423.949219,0.054054,42.604546,37.78318,56.788792,0.90331,0.979272,0.035218,0,0,0,0,0,0,5,2,3,473983,186318,268906
3,25,11,64,80.771408,1.262053,5.818182,7.342855,5169.370117,0.171875,42.604546,37.78318,56.788792,1.89584,1.693875,0.1937,1,1,0,0,1,0,3,1,1,331106,91650,31976
4,25,43,18,70.28611,3.904784,0.418605,1.634561,1265.150024,2.388889,42.604546,37.78318,56.788792,1.649733,0.476402,0.757192,1,0,0,0,0,0,2,3,1,164114,318758,52180
5,25,214,2,93.220001,46.610001,0.009346,0.435607,186.440002,107.0,42.604546,37.78318,56.788792,2.18803,0.052934,3.768349,1,0,1,0,0,1,1,5,1,30093,522181,18891
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
557002,25,0,8,23.286667,2.910833,0.0,0.0,186.293335,0.0,42.604546,37.78318,56.788792,0.546577,0.211734,0.0,0,0,0,0,0,0,5,4,5,534571,428650,465473
557003,25,53,12,41.223331,3.435278,0.226415,0.777799,494.679993,4.416667,42.604546,37.78318,56.788792,0.967581,0.317602,0.933283,0,0,0,0,0,0,2,4,3,140733,380127,235047
557004,25,147,12,62.479168,5.206597,0.081633,0.425028,749.75,12.25,42.604546,37.78318,56.788792,1.466491,0.317602,2.588539,1,0,1,0,0,1,1,4,1,53616,380127,77321
557005,25,2,108,21.617975,0.200166,54.0,10.808988,2334.741211,0.018519,42.604546,37.78318,56.788792,0.50741,2.858415,0.035218,0,1,0,0,1,0,5,1,5,473983,34287,482218
