In [None]:
import os
import logging
from datetime import datetime
import numpy as np


from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, when, desc, avg, min, udf, percent_rank
)
from pyspark.sql.types import ( StringType, FloatType
)
from pyspark.ml.feature import (
    VectorAssembler, StandardScaler, StringIndexer, Bucketizer
)
from pyspark.ml.feature import Tokenizer, HashingTF, IDF
from pyspark.sql.functions import concat_ws
from pyspark.ml.clustering import KMeans
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
from pyspark.ml import Pipeline
from pyspark.sql.window import Window
# Configuration du logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("modelisation.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Paramètres globaux
DATA_DIR = "./data/processed/parquet"
MODELS_DIR = "./models"
RESULTS_DIR = "./results"

# Création des répertoires de sortie s'ils n'existent pas
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Création de la session Spark
def create_spark_session():
    """Crée et retourne une session Spark configurée"""
    return SparkSession.builder \
        .appName("E-commerce Models") \
        .config("spark.driver.memory", "8g") \
        .config("spark.executor.memory", "8g") \
        .config("spark.sql.session.timeZone", "UTC") \
        .config("spark.sql.shuffle.partitions", "20") \
        .config("spark.default.parallelism", "20") \
        .master("local[*]") \
        .getOrCreate()

# Initialisation de la session Spark
spark = create_spark_session()
logger.info("Session Spark initialisée")

2025-05-18 16:42:35,920 - INFO - Session Spark initialisée


In [11]:
def find_latest_file(directory, pattern):
    """Trouve le fichier le plus récent dans le répertoire correspondant au pattern"""
    files = [f for f in os.listdir(directory) if pattern in f]
    if not files:
        return None
    
    latest_file = None
    latest_time = 0
    
    for file in files:
        file_path = os.path.join(directory, file)
        file_time = os.path.getmtime(file_path)
        if file_time > latest_time:
            latest_time = file_time
            latest_file = file
            
    return latest_file

# Trouver les fichiers les plus récents
latest_cleaned = find_latest_file(DATA_DIR, "cleaned_data")
latest_user_behavior = find_latest_file(DATA_DIR, "user_behavior")
latest_recommendation = find_latest_file(DATA_DIR, "recommendation_data")
latest_product = find_latest_file(DATA_DIR, "product_data")
latest_time_series = find_latest_file(DATA_DIR, "time_series_data")

# Chargement des données
logger.info("Chargement des données prétraitées")

# Données nettoyées
cleaned_df = spark.read.parquet(os.path.join(DATA_DIR, latest_cleaned))
logger.info(f"Données nettoyées chargées: {cleaned_df.count()} lignes")
cleaned_df.printSchema()
cleaned_df.show(5)

# Comportements utilisateurs
user_behavior_df = spark.read.parquet(os.path.join(DATA_DIR, latest_user_behavior))
logger.info(f"Comportements utilisateurs chargés: {user_behavior_df.count()} lignes")
user_behavior_df.printSchema()
user_behavior_df.show(5)

# Données de recommandation
recommendation_df = spark.read.parquet(os.path.join(DATA_DIR, latest_recommendation))
logger.info(f"Données de recommandation chargées: {recommendation_df.count()} lignes")
recommendation_df.printSchema()
recommendation_df.show(5)

# Données produits
product_df = spark.read.parquet(os.path.join(DATA_DIR, latest_product))
logger.info(f"Données produits chargées: {product_df.count()} lignes")
product_df.printSchema()
product_df.show(5)

# Données temporelles
time_series_df = spark.read.parquet(os.path.join(DATA_DIR, latest_time_series))
logger.info(f"Données temporelles chargées: {time_series_df.count()} lignes")
time_series_df.printSchema()
time_series_df.show(5)

2025-05-18 16:00:46,162 - INFO - Chargement des données prétraitées
2025-05-18 16:00:50,443 - INFO - Données nettoyées chargées: 1000000 lignes


root
 |-- event_time: timestamp (nullable = true)
 |-- event_type: string (nullable = true)
 |-- product_id: string (nullable = true)
 |-- category_id: string (nullable = true)
 |-- category_code: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: double (nullable = true)
 |-- user_id: string (nullable = true)
 |-- user_session: string (nullable = true)
 |-- hour: integer (nullable = true)
 |-- minute: integer (nullable = true)
 |-- second: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- date: string (nullable = true)
 |-- hour_bucket: string (nullable = true)

+-------------------+----------+----------+-------------------+--------------------+--------+-------+---------+--------------------+----+------+------+---+-----+---------+----------+-------------------+
|         event_time|event_type|product_id|        category_id|       category_code|   brand|  price|  user

2025-05-18 16:00:52,623 - INFO - Comportements utilisateurs chargés: 163024 lignes


root
 |-- user_id: string (nullable = true)
 |-- nb_events: long (nullable = true)
 |-- nb_views: long (nullable = true)
 |-- nb_carts: long (nullable = true)
 |-- nb_purchases: long (nullable = true)
 |-- nb_removes: long (nullable = true)
 |-- avg_price_viewed: double (nullable = true)
 |-- avg_price_purchased: double (nullable = true)
 |-- nb_sessions: long (nullable = true)
 |-- first_seen: timestamp (nullable = true)
 |-- last_seen: timestamp (nullable = true)
 |-- recency: integer (nullable = true)
 |-- frequency: long (nullable = true)
 |-- monetary: double (nullable = true)
 |-- viewed_categories: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- viewed_brands: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- conversion_rate: double (nullable = true)
 |-- cart_abandonment: double (nullable = true)
 |-- engagement_days: integer (nullable = true)

+---------+---------+--------+--------+------------+----------+----------------

2025-05-18 16:00:53,325 - INFO - Données de recommandation chargées: 1000000 lignes


root
 |-- user_id: string (nullable = true)
 |-- product_id: string (nullable = true)
 |-- event_type: string (nullable = true)
 |-- price: double (nullable = true)
 |-- event_time: timestamp (nullable = true)
 |-- interaction_score: integer (nullable = true)

+---------+----------+----------+-------+-------------------+-----------------+
|  user_id|product_id|event_type|  price|         event_time|interaction_score|
+---------+----------+----------+-------+-------------------+-----------------+
|541312140|  44600062|      view|  35.79|2019-10-01 00:00:00|                1|
|554748717|   3900821|      view|   33.2|2019-10-01 00:00:00|                1|
|519107250|  17200506|      view|  543.1|2019-10-01 00:00:01|                1|
|550050854|   1307067|      view| 251.74|2019-10-01 00:00:01|                1|
|535871217|   1004237|      view|1081.98|2019-10-01 00:00:04|                1|
+---------+----------+----------+-------+-------------------+-----------------+
only showing top 5 

2025-05-18 16:00:54,152 - INFO - Données produits chargées: 66603 lignes


root
 |-- product_id: string (nullable = true)
 |-- category_id: string (nullable = true)
 |-- category_code: string (nullable = true)
 |-- brand: string (nullable = true)
 |-- price: double (nullable = true)

+----------+-------------------+--------------------+--------+-------+
|product_id|        category_id|       category_code|   brand|  price|
+----------+-------------------+--------------------+--------+-------+
|   1004235|2053013555631882655|electronics.smart...|   apple|1169.25|
|   5400748|2053013552989470973|             unknown|     svc|  88.17|
|  13200019|2053013557192163841|furniture.bedroom...|      sv| 386.08|
|   3800422|2053013566176363511|     appliances.iron| polaris|  23.14|
|   6000267|2053013560807654091|auto.accessories....|starline| 420.09|
+----------+-------------------+--------------------+--------+-------+
only showing top 5 rows



2025-05-18 16:00:54,882 - INFO - Données temporelles chargées: 17 lignes


root
 |-- hour_bucket: string (nullable = true)
 |-- total_events: long (nullable = true)
 |-- unique_users: long (nullable = true)
 |-- views: long (nullable = true)
 |-- carts: long (nullable = true)
 |-- purchases: long (nullable = true)
 |-- removes: long (nullable = true)
 |-- avg_price: double (nullable = true)

+-------------------+------------+------------+-----+-----+---------+-------+------------------+
|        hour_bucket|total_events|unique_users|views|carts|purchases|removes|         avg_price|
+-------------------+------------+------------+-----+-----+---------+-------+------------------+
|2019-10-01 00:00:00|        1083|         383| 1070|    3|       10|      0| 303.2282086795937|
|2019-10-01 01:00:00|         121|         102|  121|    0|        0|      0|327.07438016528926|
|2019-10-01 02:00:00|       22886|        5378|22326|  244|      316|      0| 289.5163129346923|
|2019-10-01 03:00:00|       49409|       10514|47951|  613|      845|      0|283.44011712259385|
|

In [None]:
def prepare_rfm_segmentation(user_df):
    """Prépare les données pour la segmentation RFM"""
    logger.info("Préparation des données pour segmentation RFM")
    
    # Filtrer les utilisateurs avec au moins une interaction
    rfm_df = user_df.filter(col("nb_events") > 0)
    
    # Convertir les valeurs manquantes/nulles en 0 pour les métriques RFM
    rfm_df = rfm_df.fillna({
        "recency": 30,  # Valeur max si jamais vu
        "frequency": 0,  # Pas d'achats
        "monetary": 0    # Pas de dépenses
    })
    rfm_df = rfm_df.withColumn(
        "is_active",
        when(col("frequency") > 0, 1).otherwise(0)
    )
    
    # Créer des buckets pour les métriques RFM
    # 1 = meilleur, 5 = pire
    
    # Récence (inversée: plus petit = meilleur)
    recency_quantiles = rfm_df.approxQuantile("recency", [0.2, 0.4, 0.6, 0.8], 0.01)
    recency_quantiles = sorted(list(set([0] + recency_quantiles + [31])))  # Ajouter min et max, enlever les doublons
    
    # Fréquence
    frequency_quantiles = rfm_df.approxQuantile("frequency", [0.2, 0.4, 0.6, 0.8], 0.01)
    frequency_quantiles = sorted(list(set([0] + frequency_quantiles)))  # Min et max, enlever les doublons
    if len(frequency_quantiles) <= 1:
        frequency_quantiles.append(1.0)
    frequency_quantiles.append(float('inf'))
    
    # Valeur monétaire
    monetary_quantiles = rfm_df.approxQuantile("monetary", [0.2, 0.4, 0.6, 0.8], 0.01)
    monetary_quantiles = sorted(list(set([0] + monetary_quantiles)))  # Min et max, enlever les doublons
    if len(monetary_quantiles) <= 1:
        monetary_quantiles.append(1.0)
    monetary_quantiles.append(float('inf'))
    
    logger.info(f"Quantiles récence: {recency_quantiles}")
    logger.info(f"Quantiles fréquence: {frequency_quantiles}")
    logger.info(f"Quantiles monétaire: {monetary_quantiles}")
    
    # Inversée pour récence (plus bas = meilleur score)
    recency_bucketizer = Bucketizer(
        splits=recency_quantiles, 
        inputCol="recency", 
        outputCol="recency_score"
    )
    
    # Pas inversée (plus haut = meilleur score)
    frequency_bucketizer = Bucketizer(
        splits=frequency_quantiles, 
        inputCol="frequency", 
        outputCol="frequency_score"
    )
    
    # Pas inversée (plus haut = meilleur score)
    monetary_bucketizer = Bucketizer(
        splits=monetary_quantiles, 
        inputCol="monetary", 
        outputCol="monetary_score"
    )
    
    # Appliquer les bucketizers
    rfm_df = recency_bucketizer.transform(rfm_df)
    rfm_df = frequency_bucketizer.transform(rfm_df)
    rfm_df = monetary_bucketizer.transform(rfm_df)
    
    # Inverser le score de récence (5 - score) pour qu'un score élevé soit meilleur
    rfm_df = rfm_df.withColumn("recency_score", 5.0 - col("recency_score"))
    
    # Ajouter 1 aux scores pour qu'ils commencent à 1 (et pas à 0)
    rfm_df = rfm_df.withColumn("recency_score", col("recency_score") + 1)
    rfm_df = rfm_df.withColumn("frequency_score", col("frequency_score") + 1)
    rfm_df = rfm_df.withColumn("monetary_score", col("monetary_score") + 1)
    
    # Calcul du score RFM global
    rfm_df = rfm_df.withColumn(
        "rfm_score", 
        col("recency_score") * 100 + col("frequency_score") * 10 + col("monetary_score")
    )
    
    # Segmentation RFM basique
    rfm_df = rfm_df.withColumn(
        "rfm_segment",
        when((col("recency_score") >= 4) & (col("frequency_score") >= 4) & (col("monetary_score") >= 4), "Champions")
        .when((col("recency_score") >= 3) & (col("frequency_score") >= 3) & (col("monetary_score") >= 3), "Loyal Customers")
        .when((col("recency_score") >= 3) & (col("frequency_score") <= 2) & (col("monetary_score") <= 2), "Potential Loyalists")
        .when((col("recency_score") <= 2) & (col("frequency_score") >= 3) & (col("monetary_score") >= 3), "At Risk")
        .when((col("recency_score") <= 2) & (col("frequency_score") <= 2) & (col("monetary_score") <= 2), "Hibernating")
        .when((col("recency_score") >= 4) & (col("frequency_score") <= 2), "New Customers")
        .when((col("recency_score") <= 2) & (col("frequency_score") >= 4) & (col("monetary_score") >= 4), "Cannot Lose Them")
        .when((col("recency_score") >= 4) & (col("frequency_score") >= 3) & (col("monetary_score") < 3), "Promising")
        .when((col("recency_score") >= 3) & (col("frequency_score") <= 2) & (col("monetary_score") >= 3), "Need Attention")
        .otherwise("Others")
    )
    
    logger.info("Distribution des segments RFM:")
    rfm_df.groupBy("rfm_segment").count().orderBy(desc("count")).show()
    
    return rfm_df

# Exécuter la segmentation RFM
rfm_segmentation = prepare_rfm_segmentation(user_behavior_df)
rfm_segmentation.select("user_id", "recency", "frequency", "monetary", 
                       "recency_score", "frequency_score", "monetary_score", 
                       "rfm_score", "rfm_segment").show(10)

2025-05-18 16:02:09,427 - INFO - Préparation des données pour segmentation RFM
2025-05-18 16:02:10,396 - INFO - Quantiles récence: [0, 30.0, 31]
2025-05-18 16:02:10,397 - INFO - Quantiles fréquence: [0, 1.0, inf]
2025-05-18 16:02:10,398 - INFO - Quantiles monétaire: [0, 1.0, inf]
2025-05-18 16:02:10,665 - INFO - Distribution des segments RFM:


+-------------------+------+
|        rfm_segment| count|
+-------------------+------+
|Potential Loyalists|163024|
+-------------------+------+

+---------+-------+---------+--------+-------------+---------------+--------------+---------+-------------------+
|  user_id|recency|frequency|monetary|recency_score|frequency_score|monetary_score|rfm_score|        rfm_segment|
+---------+-------+---------+--------+-------------+---------------+--------------+---------+-------------------+
|337535108|     30|        0|     0.0|          5.0|            1.0|           1.0|    511.0|Potential Loyalists|
|410824220|     30|        0|     0.0|          5.0|            1.0|           1.0|    511.0|Potential Loyalists|
|418592979|     30|        0|     0.0|          5.0|            1.0|           1.0|    511.0|Potential Loyalists|
|420339201|     30|        0|     0.0|          5.0|            1.0|           1.0|    511.0|Potential Loyalists|
|430276841|     30|        0|     0.0|          5.0|    

In [None]:
def prepare_behavioral_clustering(user_df):
    """Prépare les données pour le clustering comportemental"""
    logger.info("Préparation des données pour clustering comportemental")
    
    # Sélection des features comportementales pertinentes
    behavior_features = [
        "nb_events", "nb_views", "nb_carts", "nb_purchases", "nb_removes",
        "avg_price_viewed", "avg_price_purchased", "nb_sessions",
        "conversion_rate", "cart_abandonment", "engagement_days"
    ]
    
    # Filtrer uniquement les utilisateurs avec suffisamment d'interactions
    clustering_df = user_df.filter(col("nb_events") >= 2)
    logger.info(f"Utilisateurs avec au moins 2 événements: {clustering_df.count()}")
    
    # Remplacer les valeurs nulles par des zéros ou moyennes selon le cas
    clustering_df = clustering_df.na.fill({
        "avg_price_viewed": 0,
        "avg_price_purchased": 0,
        "conversion_rate": 0,
        "cart_abandonment": 0,
        "engagement_days": 1
    })
    
    # Assembler les features en vecteurs
    assembler = VectorAssembler(
        inputCols=behavior_features,
        outputCol="features_raw",
        handleInvalid="skip"
    )
    clustering_df = assembler.transform(clustering_df)
    
    # Standardiser les features pour le clustering
    scaler = StandardScaler(
        inputCol="features_raw", 
        outputCol="features",
        withStd=True, 
        withMean=True
    )
    
    # Appliquer le scaling
    clustering_pipeline = Pipeline(stages=[scaler])
    clustering_model = clustering_pipeline.fit(clustering_df)
    clustering_df = clustering_model.transform(clustering_df)
    
    # Afficher quelques statistiques sur les features
    logger.info("Statistiques sur les features pour le clustering:")
    for feature in behavior_features:
        clustering_df.select(
            avg(feature).alias(f"avg_{feature}"),
            min(feature).alias(f"min_{feature}"),
            max(feature).alias(f"max_{feature}")
        ).show()
    
    return clustering_df, behavior_features

# Exécuter la préparation pour le clustering
behavior_clustering_df, behavior_features = prepare_behavioral_clustering(user_behavior_df)
behavior_clustering_df.select("user_id", *behavior_features, "features").show(5)

2025-05-18 16:11:09,939 - INFO - Préparation des données pour clustering comportemental
2025-05-18 16:11:10,107 - INFO - Utilisateurs avec au moins 2 événements: 117529
2025-05-18 16:11:10,615 - INFO - Statistiques sur les features pour le clustering:


+-----------------+-------------+-------------+
|    avg_nb_events|min_nb_events|max_nb_events|
+-----------------+-------------+-------------+
|8.121442367415701|            2|          339|
+-----------------+-------------+-------------+

+------------------+------------+------------+
|      avg_nb_views|min_nb_views|max_nb_views|
+------------------+------------+------------+
|7.8537467348484205|           0|         339|
+------------------+------------+------------+

+------------------+------------+------------+
|      avg_nb_carts|min_nb_carts|max_nb_carts|
+------------------+------------+------------+
|0.1245479839018455|           0|         116|
+------------------+------------+------------+

+-------------------+----------------+----------------+
|   avg_nb_purchases|min_nb_purchases|max_nb_purchases|
+-------------------+----------------+----------------+
|0.14314764866543578|               0|              26|
+-------------------+----------------+----------------+

+-----

In [None]:
def train_kmeans_model(df, feature_col="features", k_values=range(2, 11)):
    """Entraîne et évalue plusieurs modèles K-means avec différentes valeurs de k"""
    logger.info("Entraînement des modèles K-means")
    
    # Liste pour stocker les résultats
    silhouette_scores = []
    models = {}
    
    # Évaluateur pour le clustering
    evaluator = ClusteringEvaluator(
        predictionCol="prediction", 
        featuresCol=feature_col,
        metricName="silhouette"
    )
    
    # Tester différentes valeurs de k
    for k in k_values:
        logger.info(f"Essai avec k={k}")
        
        # Créer et entraîner le modèle
        kmeans = KMeans(
            k=k, 
            seed=42, 
            featuresCol=feature_col,
            maxIter=20,
            tol=1e-4
        )
        model = kmeans.fit(df)
        
        # Faire des prédictions
        predictions = model.transform(df)
        
        # Évaluer le modèle
        silhouette = evaluator.evaluate(predictions)
        logger.info(f"Silhouette pour k={k}: {silhouette}")
        
        # Stocker les résultats
        silhouette_scores.append(silhouette)
        models[k] = model
    
    # Trouver la meilleure valeur de k
    best_k = k_values[np.argmax(silhouette_scores)]
    best_score = np.max(silhouette_scores)
    best_model = models[best_k]
    
    logger.info(f"Meilleur modèle: k={best_k} avec silhouette={best_score}")
    
    # Générer les prédictions avec le meilleur modèle
    results = best_model.transform(df)
    
    # Afficher la distribution des clusters
    logger.info("Distribution des clusters:")
    results.groupBy("prediction").count().orderBy("prediction").show()
    
    return best_model, results, best_k, silhouette_scores

# Exécuter l'entraînement du modèle K-means
kmeans_model, cluster_results, best_k, silhouette_scores = train_kmeans_model(
    behavior_clustering_df, feature_col="features", k_values=range(2, 9)
)

# Visualiser les scores silhouette (à travers le logging)
for k, score in zip(range(2, 9), silhouette_scores):
    logger.info(f"K={k}, Silhouette Score={score}")

# Sauvegarder le modèle
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = f"{MODELS_DIR}/kmeans_behavioral_{best_k}_clusters_{timestamp_str}"
kmeans_model.save(model_path)
logger.info(f"Modèle K-means sauvegardé: {model_path}")

# Analyser les caractéristiques des clusters
def analyze_clusters(df, cluster_col="prediction", feature_cols=None):
    """Analyse les caractéristiques des clusters"""
    logger.info("Analyse des caractéristiques des clusters")
    
    # Calculer les moyennes par cluster
    agg_exprs = [count("*").alias("cluster_size")]
    for col_name in feature_cols:
        agg_exprs.append(avg(col(col_name)).alias(f"avg_{col_name}"))

    cluster_stats = df.groupBy(cluster_col).agg(*agg_exprs).orderBy(cluster_col)

    # Afficher les statistiques par cluster
    cluster_stats.show(truncate=False)
    return cluster_stats

# Analyser les clusters obtenus
cluster_stats = analyze_clusters(
    cluster_results.select("user_id", "prediction", *behavior_features),
    cluster_col="prediction", 
    feature_cols=behavior_features
)

2025-05-18 16:39:21,321 - INFO - Entraînement des modèles K-means
2025-05-18 16:39:21,328 - INFO - Essai avec k=2
2025-05-18 16:39:25,929 - INFO - Silhouette pour k=2: 0.7834441765805065
2025-05-18 16:39:25,930 - INFO - Essai avec k=3
2025-05-18 16:39:29,654 - INFO - Silhouette pour k=3: 0.7244368921023959
2025-05-18 16:39:29,655 - INFO - Essai avec k=4
2025-05-18 16:39:33,403 - INFO - Silhouette pour k=4: 0.7282655495178941
2025-05-18 16:39:33,404 - INFO - Essai avec k=5
2025-05-18 16:39:37,075 - INFO - Silhouette pour k=5: 0.5213942468016235
2025-05-18 16:39:37,076 - INFO - Essai avec k=6
2025-05-18 16:39:40,669 - INFO - Silhouette pour k=6: 0.5486512480531769
2025-05-18 16:39:40,670 - INFO - Essai avec k=7
2025-05-18 16:39:44,318 - INFO - Silhouette pour k=7: 0.5640408649173314
2025-05-18 16:39:44,319 - INFO - Essai avec k=8
2025-05-18 16:39:48,047 - INFO - Silhouette pour k=8: 0.5795547723995433
2025-05-18 16:39:48,048 - INFO - Meilleur modèle: k=2 avec silhouette=0.783444176580506

+----------+------+
|prediction| count|
+----------+------+
|         0|107807|
|         1|  9722|
+----------+------+



2025-05-18 16:39:48,719 - INFO - Modèle K-means sauvegardé: ./models/kmeans_behavioral_2_clusters_20250518_163948
2025-05-18 16:39:48,741 - INFO - Analyse des caractéristiques des clusters


+----------+------------+------------------+-----------------+--------------------+--------------------+--------------+--------------------+-----------------------+------------------+---------------------+--------------------+-------------------+
|prediction|cluster_size|avg_nb_events     |avg_nb_views     |avg_nb_carts        |avg_nb_purchases    |avg_nb_removes|avg_avg_price_viewed|avg_avg_price_purchased|avg_nb_sessions   |avg_conversion_rate  |avg_cart_abandonment|avg_engagement_days|
+----------+------------+------------------+-----------------+--------------------+--------------------+--------------+--------------------+-----------------------+------------------+---------------------+--------------------+-------------------+
|0         |107807      |7.872633502462734 |7.803556355338707|0.043587151112636474|0.025489996011390726|0.0           |320.807282761595    |2.2345311528935974     |1.5037613513037187|0.0027813540821527696|0.02762498415378099 |1.0                |
|1         |

In [45]:
def combine_segmentations(cluster_df, rfm_df):
    """Combine les segmentations RFM et clustering comportemental"""
    logger.info("Combinaison des segmentations RFM et clustering")
    
    # Joindre les deux dataframes sur user_id
    combined_df = cluster_df.select(
        "user_id", "prediction"
    ).join(
        rfm_df.select("user_id", "rfm_segment", "rfm_score", 
                      "recency_score", "frequency_score", "monetary_score"),
        on="user_id", 
        how="inner"
    )
    
    # Renommer les colonnes pour plus de clarté
    combined_df = combined_df.withColumnRenamed("prediction", "behavior_cluster")
    
    # Analyser la distribution conjointe
    logger.info("Distribution conjointe des segments RFM et clusters comportementaux:")
    pivot_table = combined_df.groupBy("behavior_cluster") \
        .pivot("rfm_segment") \
        .agg(count("*")) \
        .na.fill(0)
    
    pivot_table.show(truncate=False)
    
    # Nommer les clusters comportementaux en fonction de leurs caractéristiques
    # Cette partie nécessite une analyse manuelle des statistiques des clusters
    
    # Étiquettes des clusters en fonction des statistiques précédemment calculées
    cluster_labels = {
        # Ces étiquettes sont des exemples et doivent être adaptées en fonction des résultats réels
        0: "Explorateurs Occasionnels",
        1: "Acheteurs Fidèles",
        2: "Visiteurs Fréquents",
        3: "Acheteurs à Fort Panier",
        4: "Visiteurs Uniques",
        5: "Convertisseurs Efficaces",
        6: "Indécis (Abandon Panier)",
        7: "Browsers Passifs",
        8: "Acheteurs Impulsifs"
    }
    
    # Créer une fonction UDF pour appliquer les étiquettes
    cluster_label_udf = udf(lambda cluster_id: cluster_labels.get(cluster_id, f"Cluster {cluster_id}"), StringType())
    
    # Appliquer les étiquettes aux clusters
    combined_df = combined_df.withColumn(
        "behavior_segment", 
        cluster_label_udf(col("behavior_cluster"))
    )
    
    # Afficher la distribution des segments comportementaux
    logger.info("Distribution des segments comportementaux:")
    combined_df.groupBy("behavior_segment").count().orderBy(desc("count")).show(truncate=False)
    
    # Calculer l'affinité entre segments RFM et comportementaux
    logger.info("Affinité entre segments RFM et comportementaux:")
    combined_df.groupBy("behavior_segment", "rfm_segment").count().orderBy(desc("count")).show(20, truncate=False)
    
    return combined_df

# Joindre les résultats de clustering avec les utilisateurs d'origine
user_clusters = cluster_results.select("user_id", "prediction")

# Combiner avec la segmentation RFM
combined_segments = combine_segmentations(user_clusters, rfm_segmentation)

# Sauvegarder les segmentations combinées
combined_output_path = f"{RESULTS_DIR}/combined_segmentation_{timestamp_str}.parquet"
combined_segments.write.mode("overwrite").format("parquet").save(combined_output_path)
logger.info(f"Segmentations combinées sauvegardées: {combined_output_path}")

# Créer un dataframe de profils utilisateurs pour les recommandations
user_profiles = combined_segments.select(
    "user_id", "behavior_cluster", "behavior_segment", "rfm_segment", "rfm_score"
)

# Afficher quelques exemples de profils utilisateurs
logger.info("Exemples de profils utilisateurs:")
user_profiles.show(10, truncate=False)

2025-05-18 16:55:02,073 - INFO - Combinaison des segmentations RFM et clustering
2025-05-18 16:55:02,138 - INFO - Distribution conjointe des segments RFM et clusters comportementaux:
2025-05-18 16:55:04,632 - INFO - Distribution des segments comportementaux:


+----------------+-------------------+
|behavior_cluster|Potential Loyalists|
+----------------+-------------------+
|1               |9722               |
|0               |107807             |
+----------------+-------------------+



Py4JJavaError: An error occurred while calling o12821.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 5 in stage 12012.0 failed 1 times, most recent failure: Lost task 5.0 in stage 12012.0 (TID 79299) (host.docker.internal executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:701)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:745)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:698)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:663)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:585)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:543)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 29 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:701)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:745)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:698)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:663)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:585)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:543)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 29 more


In [None]:
def prepare_recommendation_data(recom_df):
    """Prépare les données pour le système de recommandation"""
    logger.info("Préparation des données pour le système de recommandation")
    
    # Indexer les utilisateurs et produits pour ALS
    user_indexer = StringIndexer(
        inputCol="user_id", 
        outputCol="user_idx",
        handleInvalid="skip"
    )
    
    product_indexer = StringIndexer(
        inputCol="product_id", 
        outputCol="product_idx",
        handleInvalid="skip"
    )
    
    # Créer le pipeline de préparation
    pipeline = Pipeline(stages=[user_indexer, product_indexer])
    pipeline_model = pipeline.fit(recom_df)
    als_df = pipeline_model.transform(recom_df)
    
    # Afficher un aperçu des données préparées
    logger.info("Aperçu des données préparées pour ALS:")
    als_df.select("user_id", "user_idx", "product_id", "product_idx", 
                 "event_type", "interaction_score").show(5)
    
    # Calculer quelques statistiques utiles
    unique_users = als_df.select("user_id").distinct().count()
    unique_products = als_df.select("product_id").distinct().count()
    total_interactions = als_df.count()
    
    logger.info(f"Statistiques du dataset de recommandation:")
    logger.info(f"Utilisateurs uniques: {unique_users}")
    logger.info(f"Produits uniques: {unique_products}")
    logger.info(f"Interactions totales: {total_interactions}")
    logger.info(f"Densité: {total_interactions / (unique_users * unique_products) * 100:.6f}%")
    
    # Diviser les données en ensembles d'entraînement et de test
    train_df, test_df = als_df.randomSplit([0.8, 0.2], seed=42)
    
    logger.info(f"Ensemble d'entraînement: {train_df.count()} lignes")
    logger.info(f"Ensemble de test: {test_df.count()} lignes")
    
    return train_df, test_df, pipeline_model

# Préparer les données pour les recommandations
als_train_df, als_test_df, als_pipeline = prepare_recommendation_data(recommendation_df)

2025-05-18 16:10:47,530 - INFO - Préparation des données pour le système de recommandation
2025-05-18 16:10:51,730 - INFO - Aperçu des données préparées pour ALS:


+---------+--------+----------+-----------+----------+-----------------+
|  user_id|user_idx|product_id|product_idx|event_type|interaction_score|
+---------+--------+----------+-----------+----------+-----------------+
|541312140|107029.0|  44600062|    28727.0|      view|                1|
|554748717| 88872.0|   3900821|     1179.0|      view|                1|
|519107250|  8415.0|  17200506|     3015.0|      view|                1|
|550050854| 42331.0|   1307067|      117.0|      view|                1|
|535871217| 27228.0|   1004237|       40.0|      view|                1|
+---------+--------+----------+-----------+----------+-----------------+
only showing top 5 rows



2025-05-18 16:10:56,996 - INFO - Statistiques du dataset de recommandation:
2025-05-18 16:10:56,997 - INFO - Utilisateurs uniques: 163024
2025-05-18 16:10:56,997 - INFO - Produits uniques: 63322
2025-05-18 16:10:56,998 - INFO - Interactions totales: 1000000
2025-05-18 16:10:56,999 - INFO - Densité: 0.009687%
2025-05-18 16:10:59,244 - INFO - Ensemble d'entraînement: 800279 lignes
2025-05-18 16:11:00,726 - INFO - Ensemble de test: 199721 lignes


In [27]:
def train_als_model(train_df, test_df):
    """Entraîne et évalue le modèle ALS"""
    logger.info("Entraînement du modèle ALS")
    
    # Hyperparamètres à tester
    als_models = {}
    ranks = [10, 20, 30]
    reg_params = [0.01, 0.1, 1.0]
    
    best_model = None
    best_rmse = float('inf')
    
    for rank in ranks:
        for reg_param in reg_params:
            logger.info(f"Essai avec rank={rank}, regParam={reg_param}")
            
            als = ALS(
                rank=rank,
                maxIter=15,
                regParam=reg_param,
                userCol="user_idx",
                itemCol="product_idx",
                ratingCol="interaction_score",
                coldStartStrategy="drop",
                nonnegative=True,
                implicitPrefs=True
            )
            
            model = als.fit(train_df)
            predictions = model.transform(test_df)
            
            # Évaluation
            evaluator = RegressionEvaluator(
                metricName="rmse",
                labelCol="interaction_score",
                predictionCol="prediction"
            )
            rmse = evaluator.evaluate(predictions)
            
            logger.info(f"RMSE pour rank={rank}, regParam={reg_param}: {rmse}")
            
            if rmse < best_rmse:
                best_rmse = rmse
                best_model = model
    
    logger.info(f"Meilleur modèle - RMSE: {best_rmse}")
    
    # Sauvegarde du modèle
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = f"{MODELS_DIR}/als_model_{timestamp_str}"
    best_model.save(model_path)
    logger.info(f"Modèle ALS sauvegardé: {model_path}")
    
    return best_model

# Exécuter l'entraînement ALS
als_model = train_als_model(als_train_df, als_test_df)

2025-05-18 16:24:25,139 - INFO - Entraînement du modèle ALS
2025-05-18 16:24:25,141 - INFO - Essai avec rank=10, regParam=0.01
2025-05-18 16:25:32,342 - INFO - RMSE pour rank=10, regParam=0.01: 1.687387446788512
2025-05-18 16:25:32,343 - INFO - Essai avec rank=10, regParam=0.1
2025-05-18 16:26:20,404 - INFO - RMSE pour rank=10, regParam=0.1: 1.6891501157680544
2025-05-18 16:26:20,405 - INFO - Essai avec rank=10, regParam=1.0
2025-05-18 16:27:09,275 - INFO - RMSE pour rank=10, regParam=1.0: 1.730507037766797
2025-05-18 16:27:09,277 - INFO - Essai avec rank=20, regParam=0.01
2025-05-18 16:28:06,149 - INFO - RMSE pour rank=20, regParam=0.01: 1.6673586953989568
2025-05-18 16:28:06,151 - INFO - Essai avec rank=20, regParam=0.1
2025-05-18 16:28:56,647 - INFO - RMSE pour rank=20, regParam=0.1: 1.669009329707226
2025-05-18 16:28:56,648 - INFO - Essai avec rank=20, regParam=1.0
2025-05-18 16:29:52,043 - INFO - RMSE pour rank=20, regParam=1.0: 1.7204041942327433
2025-05-18 16:29:52,045 - INFO - 

In [33]:
def prepare_content_based_features(product_df):
    """Prépare les caractéristiques produits pour le content-based filtering"""
    logger.info("Préparation des caractéristiques produits")
    
    # Combiner les métadonnées
    product_features = product_df.withColumn(
        "product_features",
        concat_ws(" ", 
            col("category_code"), 
            col("brand"), 
            col("price").cast("string")
        )
    )
    
    # TF-IDF
    tokenizer = Tokenizer(inputCol="product_features", outputCol="tokens")
    hashing_tf = HashingTF(inputCol="tokens", outputCol="raw_features", numFeatures=1000)
    idf = IDF(inputCol="raw_features", outputCol="features")
    
    pipeline = Pipeline(stages=[tokenizer, hashing_tf, idf])
    model = pipeline.fit(product_features)
    product_features = model.transform(product_features)
    
    return product_features

# Utilisation
product_features = prepare_content_based_features(product_df)

2025-05-18 16:40:42,419 - INFO - Préparation des caractéristiques produits


In [None]:
def get_user_preferences(user_id, als_model):
    """Récupère le vecteur latent sous forme de liste"""
    user_factors = als_model.userFactors.filter(col("id") == user_id)
    if user_factors.count() == 0:
        return None
    return user_factors.first().features.tolist()

def cosine_similarity(v1, v2):
    """Calcule la similarité cosinus entre deux vecteurs"""
    return float(v1.dot(v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))

def combine_recommendations(als_recs, content_recs, alpha=0.7):
    """Combine les recommandations avec pondération"""
    combined = als_recs.union(content_recs).groupBy("product_id").agg(
        (alpha * max("rating")).alias("als_score"),
        ((1 - alpha) * max("similarity")).alias("content_score"),
        (alpha * max("rating") + (1 - alpha) * max("similarity")).alias("combined_score")
    ).orderBy(desc("combined_score"))
    
    return combined

def hybrid_recommendations(user_id, als_model, product_features, num_recs=10):
    """Combine les recommandations collaboratives et basées sur le contenu"""
    # Conversion préalable du user_id en DataFrame
    user_df = spark.createDataFrame([(user_id,)], ["user_id"])
    
    # Récupération des recommandations ALS (avec gestion de la sérialisation)
    als_recs = als_model.recommendForUserSubset(user_df, num_recs)
    
    # Extraction du vecteur utilisateur sous forme de liste Python
    user_vector = get_user_preferences(user_id, als_model)
    if user_vector is None:
        return als_recs
    
    user_array = user_vector.tolist()  # Conversion en liste sérialisable
    
    # Définition de l'UDF avec capture de la liste Python
    def similarity_udf_wrapper(features):
        user_vec = np.array(user_array)
        product_vec = np.array(features.toArray())
        dot_product = np.dot(user_vec, product_vec)
        norm_product = np.linalg.norm(product_vec)
        return float(dot_product / (np.linalg.norm(user_vec) * norm_product))
    
    similarity_udf = udf(similarity_udf_wrapper, FloatType())
    
    # Calcul des similarités
    content_recs = product_features.withColumn(
        "similarity", similarity_udf(col("features"))
    ).select("product_id", "similarity").orderBy(desc("similarity")).limit(num_recs)
    
    # Combinaison des résultats
    als_df = als_recs.selectExpr(
        "product_id", 
        "explode(recommendations).rating"
    )
    
    return combine_recommendations(als_df, content_recs)    

### 2. Implémentation complète de enhanced_rfm_segmentation ###
def enhanced_rfm_segmentation(user_df):
    """Amélioration de la segmentation RFM avec validation"""
    # Calcul des percentiles dynamiques
    window_r = Window.orderBy(desc("recency"))
    window_fm = Window.orderBy(col("frequency"), col("monetary"))
    
    rfm_df = user_df.withColumn(
        "r_percentile", percent_rank().over(window_r)
    ).withColumn(
        "f_percentile", percent_rank().over(window_fm)
    ).withColumn(
        "m_percentile", percent_rank().over(window_fm)
    )
    
    # Calcul du score pondéré
    rfm_df = rfm_df.withColumn(
        "rfm_score",
        (col("r_percentile") * 0.4 + 
         col("f_percentile") * 0.3 + 
         col("m_percentile") * 0.3)
    )
    
    # Clustering avec validation
    assembler = VectorAssembler(inputCols=["rfm_score"], outputCol="features")
    kmeans = KMeans(k=5, seed=42)
    pipeline = Pipeline(stages=[assembler, kmeans])
    model = pipeline.fit(rfm_df)
    
    # Évaluation
    predictions = model.transform(rfm_df)
    evaluator = ClusteringEvaluator()
    silhouette = evaluator.evaluate(predictions)
    logger.info(f"Silhouette Score pour RFM amélioré: {silhouette}")
    
    return predictions

### 3. Intégration dans le flux principal ###
# Après l'entraînement du modèle ALS et la segmentation RFM initiale

# Segmentation RFM améliorée
enhanced_rfm = enhanced_rfm_segmentation(user_behavior_df)
enhanced_rfm.select("user_id", "rfm_score", "prediction").show(5)

# Exemple d'utilisation des recommandations hybrides
sample_user = str(user_behavior_df.first().user_id)

hybrid_recs = hybrid_recommendations(
    user_id=sample_user,
    als_model=als_model,
    product_features=product_features,
    num_recs=10
)

logger.info("Recommandations hybrides pour l'utilisateur %s:", sample_user)
hybrid_recs.show(10)

# Sauvegarde des résultats
hybrid_recs.write.mode("overwrite").parquet(f"{RESULTS_DIR}/hybrid_recommendations")
enhanced_rfm.write.mode("overwrite").parquet(f"{RESULTS_DIR}/enhanced_rfm_segments")

2025-05-18 16:54:03,738 - INFO - Silhouette Score pour RFM amélioré: 0.9795136175414187


+---------+---------+----------+
|  user_id|rfm_score|prediction|
+---------+---------+----------+
|337535108|      0.0|         0|
|410824220|      0.0|         0|
|418592979|      0.0|         0|
|420339201|      0.0|         0|
|430276841|      0.0|         0|
+---------+---------+----------+
only showing top 5 rows



AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `user_idx` cannot be resolved. Did you mean one of the following? [`user_id`].;
'Project ['user_idx]
+- LogicalRDD [user_id#22230], false
