In [None]:
from core.spark_utils import create_spark_session
from core.s3.s3_utils import S3Service
from core.s3.settings import S3Settings


# Build the settings object once and reuse it, so that the key and the secret
# are read from the same instance instead of two separately constructed ones.
s3_settings = S3Settings()

# Spark session created from the S3 key / secret held by the settings.
spark = create_spark_session(
    s3_settings.S3_KEY,
    s3_settings.S3_SECRET,
)

# S3 client used by the cells below to read data.
s3 = S3Service()


## CHARGING_STATUS_IDX : OEM Development

In [None]:
# Load the raw Mercedes-Benz time-series parquet through the S3 service
rss = s3.read_parquet_df_spark(
    spark,
    'raw_ts/mercedes-benz/time_series/raw_ts_spark.parquet',
)

In [None]:
# Keep only the rows of a single vehicle (VIN)

### INPUT ###
vin = 'W1N9N0CB8SJ120589'
rss_by_vin = rss.where(rss.vin == vin)

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, lead, sum, when, col, signum, dense_rank
from pyspark.sql.types import DoubleType
from pyspark.sql import DataFrame as DF

def _reassign_short_phases(self, df, min_duration_minutes=3):
    """
    Merge every phase shorter than `min_duration_minutes` into the most recent
    valid phase of the same vehicle, then renumber the resulting phases.

    Args:
        df (DataFrame): Spark DataFrame with the columns `vin`, `date`,
            `phase_id` and `total_phase_time`.
        min_duration_minutes (float): Minimum `total_phase_time` for a phase
            to be kept as a phase of its own.

    Returns:
        DataFrame: `df` with its original `phase_id` column replaced by the
        recomputed one (the new `phase_id` is appended as the last column).
    """
    # Running frame: all rows of the vehicle up to and including the current one.
    up_to_current = (
        Window.partitionBy("vin")
        .orderBy("date")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    # Ranking window used to compact the merged ids into consecutive integers.
    by_merged_id = Window.partitionBy('vin').orderBy("phase_id_updated")

    long_enough = F.col("total_phase_time") >= min_duration_minutes
    is_valid = F.col("is_valid_phase") == 1
    # phase_id on rows of valid phases, null on rows of short phases.
    own_id_if_valid = F.when(is_valid, F.col("phase_id"))

    return (
        df
        # 1. Flag the rows that belong to a sufficiently long phase.
        .withColumn("is_valid_phase", F.when(long_enough, 1).otherwise(0))
        # 2. Carry the id of the latest valid phase forward along the timeline.
        #    NOTE(review): rows located before the first valid phase have no
        #    earlier valid id, so this stays null for them - confirm that
        #    grouping them together is the intended outcome.
        .withColumn(
            "last_valid_phase_id",
            F.last(own_id_if_valid, ignorenulls=True).over(up_to_current),
        )
        # 3. Valid rows keep their id, short-phase rows take the carried one.
        .withColumn(
            "phase_id_updated",
            own_id_if_valid.otherwise(F.col("last_valid_phase_id")),
        )
        # 4. Renumber per vehicle with a dense rank, shifted to start at 0.
        .withColumn("phase_id_final", F.dense_rank().over(by_merged_id) - 1)
        # 5. Drop the helper columns and expose the result as `phase_id`.
        .drop("phase_id", "last_valid_phase_id", "is_valid_phase", "phase_id_updated")
        .withColumnRenamed("phase_id_final", "phase_id")
    )

def compute_charge_idx_bis(self, tss: DF) -> DF:
    """
    Label every row with a charging status and a running charging-status index.

    Per vehicle (`vin`), ordered by `date`, the function:
      1. computes `soc_diff` against the previous non-null `soc`;
      2. derives a forward-filled SoC `direction` (sign of `soc_diff`);
      3. groups consecutive rows sharing a direction into phases and merges
         short phases via `self._reassign_short_phases`;
      4. classifies each row as "charging", "discharging" or "idle";
      5. numbers the consecutive runs of identical status in
         `charging_status_idx`.

    Args:
        tss: Spark DataFrame; the columns `vin`, `date` and `soc` are read
            here. It is first passed through `self.compute_energy_added`.

    Returns:
        The DataFrame returned by `self.compute_energy_added`, extended with
        the columns `soc_diff`, `prev_date`, `time_gap_minutes`,
        `direction_raw`, `direction`, `direction_change`, `total_phase_time`,
        `phase_id`, `total_soc_diff`, `prev_phase`, `next_phase`,
        `charging_status`, `charging_status_change` and `charging_status_idx`.
    """
    tss = self.compute_energy_added(tss)

    # Windows over the timeline of one vehicle.
    by_time = Window.partitionBy("vin").orderBy("date")
    up_to_current = by_time.rowsBetween(Window.unboundedPreceding, 0)
    strictly_before = by_time.rowsBetween(Window.unboundedPreceding, -1)
    # Referenced by column name, so it applies to whichever `phase_id`
    # column the DataFrame carries at the time it is used.
    by_phase = Window.partitionBy("vin", "phase_id")

    # SoC delta against the last non-null SoC seen before the current row;
    # rows with a null SoC (or without any earlier SoC) get a null delta.
    previous_soc = F.last("soc", ignorenulls=True).over(strictly_before)
    tss = tss.withColumn(
        "soc_diff",
        F.when(F.col("soc").isNotNull(), F.col("soc") - previous_soc),
    )

    df = (
        tss
        # Gap to the previous row, in minutes.
        .withColumn("prev_date", F.lag("date").over(by_time))
        .withColumn(
            "time_gap_minutes",
            (F.unix_timestamp("date") - F.unix_timestamp("prev_date")) / 60,
        )
        # Sign of the SoC delta (null where the delta is null) ...
        .withColumn("direction_raw", F.signum("soc_diff"))
        # ... forward-filled so that every row carries the latest known sign.
        .withColumn(
            "direction",
            F.last("direction_raw", ignorenulls=True).over(up_to_current),
        )
        # A phase starts whenever the direction differs from the previous row;
        # comparisons involving null do not count as a change.
        .withColumn(
            "direction_change",
            F.when(
                F.col("direction") != F.lag("direction").over(by_time), 1
            ).otherwise(0),
        )
        # Cumulative number of changes = phase identifier.
        .withColumn("phase_id", F.sum("direction_change").over(up_to_current))
        # Sum of the row-to-row gaps inside the phase, in minutes.
        .withColumn("total_phase_time", F.sum("time_gap_minutes").over(by_phase))
    )

    # Merge the phases that are too short and renumber `phase_id`.
    df = self._reassign_short_phases(df)

    # Despite their names, `prev_phase` / `next_phase` hold the direction of
    # the previous / next ROW on the timeline, not of the neighbouring phase.
    same_neighbours = F.col("prev_phase") == F.col("next_phase")

    return (
        df
        # Net SoC change over the (merged) phase.
        .withColumn("total_soc_diff", F.sum("soc_diff").over(by_phase))
        .withColumn("prev_phase", F.lag("direction").over(by_time))
        .withColumn("next_phase", F.lead("direction").over(by_time))
        .withColumn(
            "charging_status",
            F.when(F.col("total_soc_diff") > 0.5, "charging")
            .when(F.col("total_soc_diff") < -0.5, "discharging")
            # Small net change: fall back on the local direction. The
            # comparisons are strict so that a flat SoC (direction 0 on both
            # sides) ends up "idle"; with `>= 0` followed by `<= 0` the first
            # branch captured direction 0 and flat rows were reported as
            # "charging".
            .when(same_neighbours & (F.col("prev_phase") > 0), "charging")
            .when(same_neighbours & (F.col("prev_phase") < 0), "discharging")
            .otherwise("idle"),
        )
        # Step 3: number the consecutive runs of identical charging status.
        .withColumn(
            "charging_status_change",
            F.when(
                F.col("charging_status") != F.lag("charging_status").over(by_time),
                1,
            ).otherwise(0),
        )
        .withColumn(
            "charging_status_idx",
            F.sum("charging_status_change").over(up_to_current),
        )
    )