# Preprocessing

In [2]:
%load_ext sagemaker_studio_analytics_extension.magics
%sm_analytics emr connect --cluster-id j-3MNOUDE8S2V30 --auth-type None   

Successfully read emr cluster(j-3MNOUDE8S2V30) details
Initiating EMR connection..
Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
0,application_1677493527171_0002,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.
{"namespace": "sagemaker-analytics", "cluster_id": "j-3MNOUDE8S2V30", "error_message": null, "success": true, "service": "emr", "operation": "connect"}


In [3]:
%%configure -f

{ "conf":
    {
        "spark.pyspark.python": "python3",
        "spark.pyspark.virtualenv.enabled": "true",
        "spark.pyspark.virtualenv.type":"native",
        "spark.pyspark.virtualenv.bin.path":"/usr/bin/virtualenv",
        "spark.sql.legacy.parquet.datetimeRebaseModeInRead": "LEGACY",
        "spark.driver.memory": "10000M"
    }
}

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
1,application_1677493527171_0003,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
1,application_1677493527171_0003,pyspark,idle,Link,Link,,✔


In [4]:
# Pyspark functions
from pyspark.sql.functions import col, udf, isnan, when, count, lit, mean, min, max, sum, round, format_number, abs, current_date, dayofweek, size, dayofyear, \
                                    datediff, to_date, year, to_utc_timestamp, unix_timestamp, coalesce, months_between, substring, rank, countDistinct, collect_set, \
                                    first, quarter, dense_rank, desc, month, row_number, asc

from pyspark.sql import DataFrame
from pyspark.sql.types import *
from pyspark.sql.window import Window
from functools import reduce

# Other
import os
import pandas as pd
import numpy as np
import datetime
from typing import Dict
from matplotlib import pyplot as plt
pd.set_option('display.max_columns', None)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

#### Preprocessing Functions

In [45]:
def get_last_s3_partition(s3_dir_customer: str, grep_flag: bool=False, grep: str='') -> str:
    
    """ This functions assumes that the last partition is the last one that appears on `aws s3 ls` command

    Parameters
    ----------
    s3_dir_customer: s3 path with churn data ending with '/'

    Returns
    -------
    complete path to the last partition
    """
    
    var = os.popen(f"aws s3 ls {s3_dir_customer} |tail -n 1 | awk '{{print $2}}'").read()
    os.path.join(s3_dir_customer,var[:-1])

    return os.path.join(s3_dir_customer,var[:-1])

def rename_multiple_columns(df: DataFrame, rename_col: Dict[str, str]) -> DataFrame:

    """This function change names in columns of a pyspark dataframe.

    Parameters
    ----------
    dataframe: Dataframe with the columns we want to change.
    rename_col: Dictionary where its keys are the old names of the columns
    and the values are the new names of the columns.

    Returns
    -------
    Dataframe with the new names updated.

    """

    for old_name, new_name in rename_col.items():
        df = df.withColumnRenamed(old_name, new_name)

    return df

def create_null_field(column: str, null_format: object) -> object:
    
    """This function change a value by null.

    Parameters
    ----------
    column: column with value to change.
    null_format: value to change inside column

    Returns
    -------
    Column with null values instead of null_format.
    
    """
    
    return when(column.startswith(null_format), lit(None)).otherwise(column)

def eliminate_time(date_time: str, sep: str) -> str:
    
    """This function split a datetime string to eliminate time.

    Parameters
    ----------
    date_time: datetime string
    sep: separator between date and time

    Returns
    -------
    String with date
    
    """

    if not date_time is None:
        return date_time.strip().split(sep)[0]
    else:
        return date_time
    
def substitute_values_column(df: DataFrame, column: str, dict_values_condition: dict) -> DataFrame:
    
    """This function change the values inside a column based on a condition.

    Parameters
    ----------
    dataframe: original dataframe
    column: column to change values
    dict_values_condition: list of dictionaries with dictionaries
    with this structure --> {'value_1': x, 'value_2': y, 'col_condition': z}
    where value_1 is the value to change, value_2 the substitute and
    col_condition the column to meet the condition
    
    Returns
    -------
    Dataframe with column values changed
    
    """

    for dict_value in dict_values_condition:
        if dict_value['value_1'] == 'null':
            df = df.withColumn(column, when(col(dict_value['col_condition']).isNull(),
                                                     dict_value['value_2']).otherwise(col(column)))
        else:
            df = df.withColumn(column, when(col(dict_value['col_condition']) == dict_value['value_1'],
                                                         dict_value['value_2']).otherwise(col(column)))
    
    return df

def coalesce_two_columns(df: DataFrame, dict_columns: dict) -> DataFrame:
    
    """This function coalesce two columns.

    Parameters
    ----------
    dataframe: original dataframe
    dict_columns: {col1: col2}
    
    Returns
    -------
    Dataframe two columns coalesced
    
    """

    for column in dict_columns.keys():
        df = df.withColumn(column, coalesce(col(column), dict_columns[column]))
        
    return df

def season_extraction(date: object) -> str:

    """This function extracts season from a date value.

    Parameters
    ----------
    date: datetime
    
    Returns
    -------
    String with season
    
    """

    # "day of year" ranges for the northern hemisphere
    spring = range(80, 172)
    summer = range(172, 264)
    fall = range(264, 355)
    # winter = everything else
    if date in spring:
        season = 'spring'
    elif date in summer:
        season = 'summer'
    elif date in fall:
        season = 'fall'
    else:
        season = 'winter'

    return season

def read_s3_data_churn(s3_dir_customer: str, columns: str) -> DataFrame:
    
    """This function read churn data from S3, filter and select columns

    Parameters
    ----------
    s3_dir_customer: s3 path with churn data
    columns: list of columns to filter out
    
    Returns
    -------
    pyspark Dataframe with all churn information not processed
    
    """

    #Read and rename desired columns
    df = spark.read.csv(s3_dir_customer, header = 'True')

    # Select desired columns
    df = df.select(columns)
    
    return df

def translate_class_code(df: DataFrame) -> DataFrame:
    
    """Translate class codes in each main class (economy, premium, business)

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    
    Returns
    -------
    pyspark Dataframe with sold_class_code translated to main class
    
    """
      
    economy = ["A", "B", "F", "G", "H", "K", "L", "M", "N", "O", "Q", "S", "V", "Y", "Z", "X"]
    premium = ["E", "P", "T", "W"]
    business = ["C", "D", "I", "J", "R", "U"]
    
    df = df.withColumn('sold_class_code', when(col('sold_class_code').isin(economy),
                                               'economy').otherwise(col('sold_class_code')))
    df = df.withColumn('sold_class_code', when(col('sold_class_code').isin(premium),
                                               'premium').otherwise(col('sold_class_code')))
    df = df.withColumn('sold_class_code', when(col('sold_class_code').isin(business),
                                               'business').otherwise(col('sold_class_code')))
    
    return df

def cast_columns(df: DataFrame) -> DataFrame:
    
    """Cast all pyspark columns

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    
    Returns
    -------
    pyspark Dataframe with int, double and date columns casted
    
    """
    
    # Cast gross_revenue and total_payment to double
    df = df.withColumn("gross_revenue_eur",df.gross_revenue_eur.cast('double'))
    df = df.withColumn("eur_bags",df.eur_bags.cast('double'))
    df = df.withColumn("eur_seats",df.eur_seats.cast('double'))
    df = df.withColumn("eur_upgs",df.eur_upgs.cast('double'))
    df = df.withColumn("eur_others",df.eur_others.cast('double'))
    df = df.withColumn("revenue_avios",df.revenue_avios.cast('double'))
    df = df.withColumn("num_bags",df.num_bags.cast('int'))
    df = df.withColumn("num_seats",df.num_seats.cast('int'))
    df = df.withColumn("num_upgs",df.num_upgs.cast('int'))
    df = df.withColumn("num_others",df.num_others.cast('int'))
    # Cast disruptions
    df = df.withColumn("flag_misconnection_misc",when(col('flag_misconnection_misc')==1,lit(1)).otherwise(lit(0)).cast('int'))
    df = df.withColumn("flag_dng",when(col('flag_dng')==1,lit(1)).otherwise(lit(0)).cast('int'))
    df = df.withColumn("cancelled",when(col('cancelled')==1,lit(1)).otherwise(lit(0)).cast('int'))
    df = df.withColumn("delayed_minutes_arrival",df.delayed_minutes_arrival.cast('int'))
    # Cast marketing permission and is_corporate
    df = df.withColumn("mkt_permission",df.mkt_permission.cast('int'))
    df = df.withColumn("is_corporate",df.is_corporate.cast('int'))
    # Cast ticket_date to date
    eliminateTimeUDF = udf(lambda x: eliminate_time(x, ' '))
    df = df.withColumn('ticket_sale_date_cast', to_date(col("ticket_sale_date"),"yyyy-MM-dd").alias("date_2"))
    df = df.withColumn('date_creation_pnr_resiber', to_date(eliminateTimeUDF(col("date_creation_pnr_resiber")),
                                                            "yyyy-MM-dd").alias("date_1"))
    df = df.withColumn('date', to_date(col("birth_date"), "yyyy-MM-dd").alias("date"))
    df = df.withColumn('loc_dep_date', to_date(col("loc_dep_date"),"yyyy-MM-dd"))
    df = df.withColumn('loc_arr_date', to_date(col("loc_arr_date"),"yyyy-MM-dd"))
    df = df.withColumn('date_creation_idgoldenrecord', to_date(col("date_creation_idgoldenrecord"),"yyyy-MM-dd"))
    
    return df

def email_agencies_process(df: DataFrame) -> DataFrame:
    
    """This function eliminates rows with email operative containing
    words from the black list that tries to eliminate agencies from the
    final client scope

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    
    Returns
    -------
    pyspark Dataframe filtering some agencies
    
    """

    black_list_agencies = 'viaje|trip|flight|viaxe|halcon|emisiones|reserva|booking|billete|' + \
                       'travel|tour|venta|agencia|crew|hotel|viaggi|junta|xunta|dreamgo|' + \
                        'enjoyebre|airtip|melisur|ltnspain|unififi|triporate|book|' + \
                        'fly|viaja|pentamundos|vacaciones|ndc-communication|mundoterra|embarcate|' + \
                        'agency|grupogea|vectalia|partocrs|@airbus|@aer|millenniummenorca|@reattiva|' + \
                        'compras|pasajes|coniltur|playasenator|melillaexpress|' + \
                        'abramar|rascadomarin|revistaestretegas|@esky|confirmation|@calrom|confirmacion|' + \
                        'vacaciones|turmar|kanvoy|vuelo|undefined|iberiaexpress|gestion|@aer|@ulysse|@turatlantica|geographica|geostar|' + \
                        'terminal9|murimar|hceivissa|gaselec|@anjoca|grupogea|betweencongresos|protecmedia|recepcion|' + \
                        'grupolesaca|vacanze|voyage|pisamundavecindario|administracion|overture|' + \
                        'tierrasdelmundo|planning|properties|vuela|crucero|@dondetellevo|@iberostar|' + \
                        'nuevasrutas|persiguiendoelviento|viatges|ticket|airdepartment|e-savia|@reedmackay|' + \
                        '@globalia|.edreams.com|@trailfinders'
    df = df.withColumn('perc_email_agency', when(col('email_operative').rlike(black_list_agencies), lit(1)).otherwise(lit(0)))
    
    w1 = Window.partitionBy('email_operative')
    df = df.withColumn("n_ticket_email", count("*").over(w1))
    
    return df

def extract_features_non_filtered_data(df: DataFrame, year_1: str, year_2: str, year_3: str) -> DataFrame:
    
    """Extract window features (exchanges, not_travelled and refunded) before any filter

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    year_1: flag for year - 1 from training date
    year_2: flag for year - 2 from training date
    year_3: flag for year - 2 from training date
    
    Returns
    -------
    pyspark Dataframe with coupon_usage_code as flag feature with 3 new columns
    
    """
    
    # id golden features level with all information (Exchanges, Not Travelled)
    w1 = Window.partitionBy('cid')
    df = df.withColumn('n_exchanges', sum(when((col('coupon_usage_code') == 'E') & \
                                               col('year').isin([int(year_3), int(year_2),int(year_1)]),
                                               lit(1)).otherwise(lit(0))).over(w1))
    df = df.withColumn('n_not_travelled', sum(when((col('coupon_usage_code') == 'N') & \
                                                   col('year').isin([int(year_3), int(year_2), int(year_1)]),
                                                   lit(1)).otherwise(lit(0))).over(w1))
    df = df.withColumn('n_refunded', sum(when((col('coupon_usage_code') == 'R') & \
                                              col('year').isin([int(year_3), int(year_2), int(year_1)]),
                                              lit(1)).otherwise(lit(0))).over(w1))
    
    return df

def extract_general_features(df: DataFrame, year_1: str, year_2: str, year_3: str) -> DataFrame:
    
    """Extract features window features after filtering out T and N coupon usage code

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    year_1: flag for year - 1 from training date
    year_2: flag for year - 2 from training date
    year_3: flag for year - 2 from training date
    
    Returns
    -------
    pyspark Dataframe with sold_class_code translated, flight spent deviation and position
    
    """
    
    # Class code
    df = translate_class_code(df)
    # Mean, max and position of flight spent in with respect other passengers
    w1 = Window.partitionBy('op_carrier_code', 'op_flight_num', 'loc_dep_date', 'sold_class_code')
    df = df.withColumn("gross_revenue_valid", when((col('op_carrier_code').isNotNull()) & \
                                                   (col('year').isin([int(year_3), int(year_2), int(year_1)])),
                                                   col('gross_revenue_eur')).otherwise(np.nan))
    df = df.withColumn("mean_flight_spent", mean("gross_revenue_valid").over(w1))
    df = df.withColumn('deviation_flight_spent', (col('gross_revenue_valid')-col('mean_flight_spent'))/col('mean_flight_spent'))
    df = df.withColumn("rank", row_number().over(w1.orderBy(desc("gross_revenue_valid"))))
    df = df.withColumn("total_passenger_class", count('rank').over(w1))
    
    return df

def extract_general_features_after(df: DataFrame) -> DataFrame:

    # Op code IB %
    w1 = Window.partitionBy('cid')
    df = df.withColumn("IB_count", sum((when((col('op_carrier_code') == 'IB') & (col('coupon_usage_code') == 'T'),
                                                                                lit(1)).otherwise(lit(0)))).over(w1))
    df = df.withColumn("total_op_count", sum((when((col('coupon_usage_code') == 'T'),
                                                   lit(1)).otherwise(lit(0)))).over(w1))
    df = df.withColumn("IB_pctg", col('IB_count')/(col('total_op_count')))
    
    # Extract if passenger is resident
    df = df.withColumn("is_resident", when(col('pax_type_seg').contains('RESIDENT'), lit(1)).otherwise(lit(0)))
    
    return df

def extract_rank_features(df: DataFrame) -> DataFrame:
    
    """Extract order of purchase and flight by id_golden_record

    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe with ranking of flights and purchases
    
    """
    
    # Purchase and flights ranking
    w1 = Window.partitionBy('cid')
    df =  df.withColumn("rank_purchase", dense_rank().over(w1.orderBy(desc("date_creation_pnr_resiber"), 'pnr_resiber')))
    df =  df.withColumn("rank_flight", dense_rank().over(w1.orderBy(desc("loc_arr_date"))))
    
    return df

def filter_initial_dataframe(df: DataFrame) -> DataFrame:
    
    """Filter out first requeriments for churn analysis

    Parameters
    ----------
    df_customer_data_sel: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe without crew data, cast, coalesce of pnr date and pnrs
    and change 0001-01-01 by null values

    """

    # Change 0001-01-01 date by null values
    df = df.withColumn('date_creation_pnr_resiber', create_null_field(col('date_creation_pnr_resiber'), '0001-01-01'))

    # Filter rows with CREW customers
    df = df.where(~col("last_name_1").contains("CREWANE"))

    # Cast columns
    df = cast_columns(df)

    # Colaesce columns for null values
    dict_columns = {'date_creation_pnr_resiber': 'ticket_sale_date_cast',
                    'pnr_resiber': 'pnr_amadeus'}
    df = coalesce_two_columns(df, dict_columns)

    return df

def create_present_future_flags(df: DataFrame, year_1: str, year_2: str, year_3: str, date_limit: object,
                                date_3year: object, date_2year: object, date_1year: object, date_future: object=None) -> DataFrame:
    
    """Create columns with present and future flags
     
    Parameters
    ----------
    df_customer_data_filtered: pyspark DataFrame
    year_1: flag with value -1
    year_2: flag with value -2
    year_3: flag with value -3
    date_limit: limit for labeling future transactions
    date_3year: year-3 date
    date_2year: year-2 date
    date_1year: year-1 date
    date_future: limit prediction date
    
    Returns
    -------
    pyspark Dataframe with column year indicating periods (0 for future)
    
    """
    
    # Select years for training and testing
    df = df.withColumn('year', year('date_creation_pnr_resiber'))
    if date_future is not None:
        df = df.where((col('date_creation_pnr_resiber') > date_limit) & \
                      (col('date_creation_pnr_resiber') <= date_future))
    else:
        df = df.where((col('date_creation_pnr_resiber') > date_limit) & \
                      (col('date_creation_pnr_resiber') <= date_1year))
    df = df.withColumn('year', when((col('date_creation_pnr_resiber') > date_limit) & \
                                    (col('date_creation_pnr_resiber') <= date_3year), year_3).\
                       when((col('date_creation_pnr_resiber') > date_3year) & \
                            (col('date_creation_pnr_resiber') <= date_2year), year_2).\
                       when((col('date_creation_pnr_resiber') > date_2year) & \
                            (col('date_creation_pnr_resiber') <= date_1year), year_1).otherwise(lit(0)))
    
    return df

def group_by_od(df: DataFrame) -> DataFrame:
    
    """Groupby train dataset by od with logic selection
     
    Parameters
    ----------
    df_customer_data_filtered_train: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe grouped by od
    
    """
    
    df = df.groupby('cid', 'pnr_resiber', 'date_creation_pnr_resiber', 'itinerary_od', 'year').agg(
        sum(col('gross_revenue_eur')).alias('gross_revenue_eur'),
        first(col('prime_ticket_num')).alias('prime_ticket_num'),
        sum(col('num_bags')).alias('num_bags'),
        sum(col('eur_bags')).alias('eur_bags'),
        sum(col('num_seats')).alias('num_seats'),
        sum(col('eur_seats')).alias('eur_seats'),
        sum(col('num_upgs')).alias('num_upgs'),
        sum(col('eur_upgs')).alias('eur_upgs'),
        sum(col('num_others')).alias('num_others'),
        sum(col('eur_others')).alias('eur_others'),
        max(col('ff_num')).alias('ff_num'),
        first(col('date')).alias('date'),
        first(col('pax_type_ind')).alias('pax_type_ind'),
        first(col('ind_reason_business')).alias('ind_reason_business'),
        first(col('num_days_anticipation')).alias('num_days_anticipation'),
        first(col('num_hours_in_destination')).alias('num_hours_in_destination'),
        min(col('loc_dep_date')).alias('loc_dep_date'),
        max(col('loc_arr_date')).alias('loc_arr_date'),
        min(col('loc_dep_time')).alias('loc_dep_time'),
        max(col('loc_arr_time')).alias('loc_arr_time'),
        first(col('date_creation_idgoldenrecord')).alias('date_creation_idgoldenrecord'),
        first(col('min_haul')).alias('min_haul'),
        max(when(col('ff_group') == 'IBP', col('ff_tier')).otherwise(0)).alias('ff_tier'),
        max(when(col('ff_group') == 'IBP', col('ff_group')).otherwise(0)).alias('ff_group'),
        sum(col('flag_misconnection_misc')).alias('flag_misconnection_misc'),
        sum(col('flag_dng')).alias('flag_dng'),
        max(col('delayed_minutes_arrival')).alias('max_delayed_minutes_arrival'),
        min(col('delayed_minutes_arrival')).alias('min_delayed_minutes_arrival'),
        sum(col('cancelled')).alias('cancelled'),
        max(col('mkt_permission')).alias('mkt_permission'),
        max(col('purchases_2015')).alias('purchases_2015'),
        max(col('purchases_2016')).alias('purchases_2016'),
        max(col('purchases_2017')).alias('purchases_2017'),
        first(col('n_exchanges')).alias('n_exchanges'),
        first(col('n_not_travelled')).alias('n_not_travelled'),
        first(col('n_refunded')).alias('n_refunded'),
        max(col('deviation_flight_spent')).alias('deviation_flight_spent'),
        first('ind_direct_sale').alias('ind_direct_sale'),
        first('is_corporate').alias('is_corporate'),
        first('destination_city_od').alias('destination_city_od'),
        sum(col('revenue_avios')).alias('revenue_avios'),
        max(col('coupon_usage_code')).alias('coupon_usage_code'),
        max(col('delayed_minutes_arrival')).alias('delayed_minutes_arrival'),
        first(col('IB_pctg')).alias('IB_pctg'),
        max(col('is_resident')).alias('is_resident'),
        max(col('perc_email_agency')).alias('perc_email_agency'),
        first(col('n_ticket_email')).alias('n_ticket_email'),
        first(col('point_of_sale')).alias('point_of_sale'))
    
    return df

def class_code_generation(df: DataFrame) -> DataFrame:
    
    """Groupby cid an creates 3 columns indicating
    the % of each class flown (must sum 1 all columns)
     
    Parameters
    ----------
    df_customer_data_filtered_train: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe grouped by cid with 3 columns by class
    
    """
    
    df = df.groupby('pnr_resiber', 'date_creation_pnr_resiber', 'cid', 'sold_class_code', 'year').count()
    df = df.groupby('cid').agg((sum(when((col('sold_class_code') == 'economy'),
                                                    lit(1)).otherwise(0))/sum(lit(1))).alias('economy'),
                                          (sum(when((col('sold_class_code') == 'premium'),
                                                    lit(1)).otherwise(0))/sum(lit(1))).alias('premium'),
                                          (sum(when((col('sold_class_code') == 'business'),
                                                    lit(1)).otherwise(0))/sum(lit(1))).alias('business'))
    df = df.withColumnRenamed('cid', 'cid2')
    
    return df

def extract_od_information(df: DataFrame, date_train: str) -> DataFrame:
    
    """Process that extract features at od level after grouping it.
    Feature examples_ seniority, age, weekend flights %, etc.
     
    Parameters
    ----------
    df_customer_data_filtered_train: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe with new features columns at od level
    
    """
    
    # Extract seniority
    df = df.withColumn("seniority_1", round(datediff(to_date(lit(date_train)),col("date_creation_idgoldenrecord"))))
    df = df.withColumn("seniority_2", round(datediff(to_date(lit(date_train)),col("date_creation_pnr_resiber"))))
    df = df.withColumn("seniority", when((col('seniority_1') < 0) | (col('seniority_1').isNull()), col('seniority_2')).otherwise(col('seniority_1')))
    
    # Extract age
    df = df.withColumn("age", round(months_between(to_date(lit(date_train)),col("date"))/lit(12)))
    df = df.withColumn("age", when((col('age') >= 100) | (col('age') <= 0), lit(None)).otherwise(col('age')))
    
    # Extract anticipation days
    df = df.withColumn('num_days_anticipation', datediff(col('loc_dep_date'),col('date_creation_pnr_resiber')))
    
    # Extract weekend or weekday departure date (0 weekday, 1 weekend)
    df = df.withColumn('loc_dep_date_day_week', dayofweek('loc_dep_date'))
    df = df.withColumn('loc_dep_date_day_weekend', when(col('loc_dep_date_day_week').isin([1,7,6]),
                                                        lit(1)).otherwise(0))
    df = df.withColumn('loc_dep_date_day_week', when(col('loc_dep_date_day_week').isin([2,3,4,5]),
                                                     lit(1)).otherwise(0))
    
    # Extract season of year
    df = df.withColumn('day_of_year', dayofyear(col('loc_dep_date')))
    extractSeasonUDF = udf(lambda x: season_extraction(x))
    df = df.withColumn('season', extractSeasonUDF(col("day_of_year")))
    
    # Apply window function to obtain values for conditions
    w1 = Window.partitionBy('pnr_resiber', 'date_creation_pnr_resiber', 'cid')
    w2 = Window.partitionBy('pnr_resiber', 'date_creation_pnr_resiber')
    w3 = Window.partitionBy('cid')
    w4 = Window.partitionBy('cid', 'itinerary_od_by_pnr')
    w5 = Window.partitionBy('cid', 'destination_city_od')
    w6 = Window.partitionBy('cid', 'pnr_resiber', 'date_creation_pnr_resiber',
                            'prime_ticket_num').orderBy(asc('loc_arr_time'))
    # Min and max departure and arrival date
    df = df.withColumn("max_loc_arr_date", max("loc_arr_date").over(w1))
    df = df.withColumn("min_loc_dep_date", min("loc_dep_date").over(w1))
    df = df.withColumn("min_loc_arr_date", min("loc_arr_date").over(w1))
    
    # Extract number of passenger
    df = df.withColumn("n_passengers", size(collect_set("cid").over(w2)))
    
    # Delay flag
    df = df.withColumn("flight_delay",when((col('delayed_minutes_arrival')>=15) & (col('min_haul') == 'SH'),
                                           lit(1)).otherwise(lit(0)))
    df = df.withColumn("flight_delay",when((col('delayed_minutes_arrival')>=20) & (col('min_haul') == 'MH'),
                                           lit(1)).otherwise(col('flight_delay')))
    df = df.withColumn("flight_delay",when((col('delayed_minutes_arrival')>=30) & (col('min_haul') == 'LH'),
                                           lit(1)).otherwise(col('flight_delay')))
    
    # Extract first od from pnr
    df = df.withColumn("itinerary_od_by_pnr", first("itinerary_od").over(w1.orderBy('loc_arr_date')))
    df = df.withColumn("od_distinct", size(collect_set("itinerary_od_by_pnr").over(w3)))
    df = df.withColumn("n_ods", size(collect_set("pnr_resiber").over(w4)))
    df = df.withColumn("n_ods_total", size(collect_set("pnr_resiber").over(w3)))
    
    # Extract number of distinct destination od
    df = df.withColumn("n_destination_city_od", size(collect_set("destination_city_od").over(w3)))
    
    # Extract most common destination city
    df = df.withColumn("n_order_od", row_number().over(w6))
    df = df.withColumn("n_order_od", when(col('coupon_usage_code') == 'N', lit(None)).otherwise(col('n_order_od')))
    df = df.withColumn("n_order_od_max", max(col('n_order_od')).over(w6))
    df = df.withColumn("destination_city_od_filt", when((col('n_order_od') == col('n_order_od_max')) & (col('n_order_od') != 1),
                                                        lit(None)).otherwise(col('destination_city_od')))
    df = df.withColumn("n_dest_cid", count("destination_city_od_filt").over(w5))
    df = df.withColumn("n_most_od_flown", max(when(col('destination_city_od_filt') != 'MAD',
                                                   col("n_dest_cid")).otherwise(lit(0))).over(w3))
    df = df.withColumn("most_od_flown", max(when(col('n_most_od_flown') == col('n_dest_cid'),
                                                 col('destination_city_od_filt')).otherwise(lit(None))).over(w3))
    
    w1 = Window.partitionBy('cid')
    # Extract flag delay last flight
    df = df.withColumn('n_flight_order_flown', (when(col('coupon_usage_code') == 'T',
                                               row_number().over(w1.orderBy(desc('loc_dep_date')))).otherwise(lit(None))))
    df = df.withColumn('min_n_flight_order_flown', min('n_flight_order_flown').over(w1))
    df = df.withColumn('flag_last_flight_delay',
                       max(when((col('n_flight_order_flown') == col('min_n_flight_order_flown')) & \
                                (col('min_haul') == 'SH') & \
                                (col('delayed_minutes_arrival') > 15), lit(1)) \
                                .otherwise(when((col('n_flight_order_flown') == col('min_n_flight_order_flown')) & \
                                                (col('min_haul').isin('MH', 'LH')) & \
                                                (col('delayed_minutes_arrival') > 30),
                                                lit(1)).otherwise(lit(0)))).over(w1))
    df = df.withColumn('n_flight_order', row_number().over(w1.orderBy(desc('loc_dep_date'))))
    df = df.withColumn('min_n_flight_order', min('n_flight_order').over(w1))
    df = df.withColumn('flag_last_flight_cancelled',
                       max(when((col('n_flight_order') == col('min_n_flight_order')) & \
                                (col('cancelled') == 1), lit(1)) \
                                .otherwise(lit(0))).over(w1))

    return df

def extract_flags_weeks(df: DataFrame, date_1year: object, date_last_week: object) -> DataFrame:
    
    """Creates a flag indicating if the row comes from a last week flight
     
    Parameters
    ----------
    df_customer_data_filtered_train: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe with new flag columns indicating if it
    comes from last week
    
    """
    
    df = df.withColumn('last_week', when((col('loc_dep_date') >= date_last_week) & \
                                         (col('loc_dep_date') <= date_1year), lit(1)).otherwise(0))
    
    return df

def extract_features_cid_level(df: DataFrame) -> DataFrame:
    
    """Groupby a dataframe at od level by cid to extract
    features such as flight deviation spent, mkt permission or disruptions
     
    Parameters
    ----------
    df_customer_data_filtered_train_od: pyspark DataFrame

    Returns
    -------
    pyspark Dataframe with features columns at od level
    
    """
    
    df = df.groupby('cid').agg((mean(col('deviation_flight_spent'))).alias('mean_deviation_flight_spent'),
                                          max(col('mkt_permission')).alias('mkt_permission'),
                                          max(col('purchases_2015')).alias('purchases_2015'),
                                          max(col('purchases_2016')).alias('purchases_2016'),
                                          max(col('purchases_2017')).alias('purchases_2017'),
                                          sum(when(col('last_week') == 1, col('flag_misconnection_misc')).otherwise(lit(0))).alias('flag_misconnection_misc_last_week'),
                                          sum(when(col('last_week') == 1, col('flag_dng')).otherwise(lit(0))).alias('flag_dng_last_week'),
                                          sum(when(col('last_week') == 1, col('cancelled')).otherwise(lit(0))).alias('flight_cnld_last_week'),
                                          sum(when(col('last_week') == 1, col('flight_delay')).otherwise(lit(0))).alias('flight_delay_last_week'),
                                          first('most_od_flown').alias('most_od_flown'))

    df = df.withColumnRenamed('cid', 'cid2')

    return df

def target_dataframe(df: DataFrame) -> DataFrame:
    
    """Creates target columns for CLTV from a dataframe
     
    Parameters
    ----------
    df_customer_data_filtered_future: pyspark DataFrame for testing

    Returns
    -------
    pyspark Dataframe with targets columns
    
    """
    
    df = df.groupby('cid').agg(sum(col('gross_revenue_eur') + (col('eur_bags') + col('eur_seats') + col('eur_upgs') + col('eur_others'))).alias('gross_revenue_future'))
    df = df.withColumnRenamed('cid', 'cid2')

    return df

def train_dataframe(df: DataFrame) -> DataFrame:
    
    """Creates dataframe at pnr level, cid level
    with all posible features and creates flags for last Qs
     
    Parameters
    ----------
    df_customer_data_filtered_train_od: pyspark DataFrame for training

    Returns
    -------
    pyspark Dataframe with new features at pnr/cid level
    
    """
    
    df = df.groupby('pnr_resiber', 'date_creation_pnr_resiber', 'year', 'cid')\
        .agg(sum(col('gross_revenue_eur')).alias('gross_revenue_eur'),
         countDistinct(col('itinerary_od')).alias('n_flights'),
         sum((col('num_bags') + col('num_seats') + col('num_upgs') + col('num_others'))).alias('n_ancillaries'),
         sum((col('eur_bags') + col('eur_seats') + col('eur_upgs') + col('eur_others'))).alias('eur_ancillaries'),
         sum((col('num_bags'))).alias('num_bags'), sum((col('num_seats'))).alias('num_seats'), sum((col('num_upgs'))).alias('num_upgs'),
         sum((col('num_others'))).alias('num_others'), sum((col('eur_bags'))).alias('eur_bags'), sum((col('eur_seats'))).alias('eur_seats'),
         sum((col('eur_upgs'))).alias('eur_upgs'), sum((col('eur_others'))).alias('eur_others'),
         max('ff_num').alias('flag_ibplus'),
         max('pax_type_ind').alias('pax_type_ind'),
         max('age').alias('age'),
         min(col('num_days_anticipation')).alias('num_days_anticipation'),
         min(col('num_hours_in_destination')).alias('num_hours_in_destination'),
         max(col('ind_reason_business')).alias('ind_reason_business'),
         sum(col('loc_dep_date_day_weekend')).alias('loc_dep_date_day_weekend'),
         (max('n_passengers')-1).alias('n_passengers'),
         first(col('season')).alias('season'),
         max(col('seniority')).alias('seniority'),
         max(col('max_loc_arr_date')).alias('max_loc_arr_date'),
         max(col('min_haul')).alias('haul'),
         max(when(col('ff_group') == 'IBP', col('ff_tier')).otherwise(0)).alias('ff_tier'),
         max(col('max_delayed_minutes_arrival')).alias('max_minutes_delay_arr'),
         min(col('min_delayed_minutes_arrival')).alias('min_minutes_delay_arr'),
         first(col('n_exchanges')).alias('n_exchanges'),
         first(col('n_not_travelled')).alias('n_not_travelled'),
         first(col('n_refunded')).alias('n_refunded'),
         first(col('rank_purchase')).alias('rank_purchase'),
         first(col('rank_flight')).alias('rank_flight'),
         first('n_ods').alias('n_ods'),
         first('n_ods_total').alias('n_ods_total'),
         first('ind_direct_sale').alias('ind_direct_sale'),
         first('is_corporate').alias('is_corporate'),
         first('n_destination_city_od').alias('n_destination_city_od'),
         sum(col('revenue_avios')).alias('revenue_avios'),
         sum(col('flight_delay')).alias('flight_delay'),
         first('IB_pctg').alias('IB_pctg'),
         max('is_resident').alias('is_resident'),
         max('flag_last_flight_delay').alias('flag_last_flight_delay'),
         max('perc_email_agency').alias('perc_email_agency'),
         first('n_ticket_email').alias('n_ticket_email'),
         first(col('point_of_sale')).alias('point_of_sale'))

    return df

def preprocess_data(df: DataFrame, date_train: str, date_1year: object, date_2year: object, date_3year: object,
                    date_limit: object, date_last_week: object, date_future: object=None) -> object:
    
    """Main function that preprocess original dataset and extract all posible
    features dividing the data between training and future (label) dataset
     
    Parameters
    ----------
    df: original pyspark DataFrame
    date_future: date limit for labeling data
    date_train: string that indicates limit for training
    date_1year: date limit for year-1
    date_2year: date limit for year-2
    date_3year: date limit for year-3
    date_limit: last date for training data
    date_last_week: date limit for recent week
    
    Returns
    -------
    two pyspark dataframes containing features at pnr level and label dataset
    
    """
    
    # Filter out requirements
    df = filter_initial_dataframe(df)
    # Select years for training and testing
    year_1, year_2, year_3 = -1, -2, -3
    df = create_present_future_flags(df, year_1, year_2, year_3, date_limit,
                                     date_3year, date_2year, date_1year, date_future)
    # Extract general features from non-filtered data
    df = extract_features_non_filtered_data(df, year_1, year_2, year_3)
    # Use Travelled and Not Travelled and eliminate frees
    df = df.where((col('coupon_usage_code').isin(['T', 'N'])) & (col('revenue_pax_ind') == 'Y'))
    # Extract general features from flights
    df = extract_general_features(df, year_1, year_2, year_3)
    # Filter rows with no cid
    df = df.where(col("cid").isNotNull())
    # Filter agencies
    df = email_agencies_process(df)
    # Divide dataframe between past and future
    df_train = df.where(col('year').isin([year_1, year_2, year_3]))
    df_future = df.where(col('year') == 0)
    # Replace DO by SH in Haul feature
    df_train = df_train.withColumn('haul', when(col('haul') == 'DO', 'SH').otherwise(col('haul')))
    # Extract general features without future
    df_train = extract_general_features_after(df_train)
    # Min and max departure and arrival date by od
    w1 = Window.partitionBy('cid', 'pnr_resiber', 'date_creation_pnr_resiber', 'itinerary_od')
    df_train = df_train.withColumn("min_loc_dep_date", min("loc_dep_date").over(w1))
    df_train = df_train.withColumn("max_loc_arr_date", max("loc_arr_date").over(w1))
    df_train = df_train.withColumn("min_haul", min("haul").over(w1))

    # IBPlus Tier
    dict_values_condition = [{'col_condition': 'ff_tier', 'value_1': 'null', 'value_2': lit(0)},
                             {'col_condition': 'ff_tier', 'value_1': 'Clasica', 'value_2': lit(1)},
                             {'col_condition': 'ff_tier', 'value_1': 'Plata', 'value_2': lit(2)},
                             {'col_condition': 'ff_tier', 'value_1': 'Oro', 'value_2': lit(3)},
                             {'col_condition': 'ff_tier', 'value_1': 'Platino', 'value_2': lit(4)},
                             {'col_condition': 'ff_tier', 'value_1': 'Infinita', 'value_2': lit(5)},
                             {'col_condition': 'ff_tier', 'value_1': 'Infinita Prime', 'value_2': lit(6)},
                             {'col_condition': 'ff_tier', 'value_1': 'Singular', 'value_2': lit(7)}]
    df_train = substitute_values_column(df_train, 'ff_tier', dict_values_condition)
    
    # Extract class code at flight level (Not OD level)
    df_class_code = class_code_generation(df_train)
    
    # Group all data by od
    df_train_od = group_by_od(df_train)
        
    # Extract ranking purchase and flights
    df_train_od = extract_rank_features(df_train_od)
    
    # Extract n_passenger and od information
    df_train_od = extract_od_information(df_train_od, date_train)
    
    # Create flag column for last week
    df_train_od = extract_flags_weeks(df_train_od, date_1year, date_last_week)

    # Extract information at cid level
    df_cid = extract_features_cid_level(df_train_od)

    # Aggregate data with customer logic (PNR + DATE + CID = 1 PURCHASE) and select important variables
    # Target
    df_future = target_dataframe(df_future)
    # Train
    df_train = train_dataframe(df_train_od)
    # Join first level features
    df_train = df_train.join(df_class_code,df_train["cid"] == df_class_code["cid2"], "left_outer")
    df_train = df_train.join(df_cid,df_train["cid"] == df_cid["cid2"],  "left_outer")
    
    return df_train, df_future

def post_features_extraction(df: DataFrame) -> DataFrame:
    
    """Function that extract 4 columns with mean spend after grouping all
    data at cid level
     
    Parameters
    ----------
    df_training_features: pyspark Dataframe with all features
    
    Returns
    -------
    Final pyspark dataframe with all features
    
    """
    
    df = df.withColumn('mean_spend', col('gross_revenue_and_ancillaries')/col('frequency'))
    df = df.withColumn('mean_spend_year3', col('gross_revenue_and_ancillaries_year3')/col('frequency_year3'))
    df = df.withColumn('mean_spend_year2', col('gross_revenue_and_ancillaries_year2')/col('frequency_year2'))
    df = df.withColumn('mean_spend_year1', col('gross_revenue_and_ancillaries_year1')/col('frequency_year1'))
    df = df.fillna(0, subset=['mean_spend_year3', 'mean_spend_year2', 'mean_spend_year1'])
    
    return df
    
def features_frequency(year_1: str, year_2: str, year_3: str, date_train: str) -> object:
    
    """Function that creates conditions about frequency or purchase and flight
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    date_train: string containing limit date for training
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    # Frequency
    f_total = (sum(lit(1))).alias('frequency')
    f_3 = (sum(when(col('year') == int(year_3), lit(1)).otherwise(0))).alias('frequency_year3')
    f_2 = (sum(when(col('year') == int(year_2), lit(1)).otherwise(0))).alias('frequency_year2')
    f_1 = (sum(when(col('year') == int(year_1), lit(1)).otherwise(0))).alias('frequency_year1')
    
    # Last purchase distance
    last_purchase = (datediff(to_date(lit(date_train)), max(when((col('rank_purchase') == 1), col('date_creation_pnr_resiber')).otherwise(to_date(lit('01-01-0001')))))).alias('last_purchase')
    last2_purchase = (datediff(to_date(lit(date_train)), max(when((col('rank_purchase') == 2), col('date_creation_pnr_resiber')).otherwise(to_date(lit('01-01-0001')))))).alias('last2_purchase')
    last3_purchase = (datediff(to_date(lit(date_train)), max(when((col('rank_purchase') == 3), col('date_creation_pnr_resiber')).otherwise(to_date(lit('01-01-0001')))))).alias('last3_purchase')
    # Last flight
    last_flight = (datediff(to_date(lit(date_train)), max((col('max_loc_arr_date'))))).alias('last_flight')

    purchases_2015 = max(col('purchases_2015')).alias('purchases_2015')
    purchases_2016 = max(col('purchases_2016')).alias('purchases_2016')
    purchases_2017 = max(col('purchases_2017')).alias('purchases_2017')
    
    return (f_total, f_3, f_2, f_1, last_purchase, last2_purchase, last3_purchase,
            last_flight, purchases_2015, purchases_2016, purchases_2017)

def features_gross_revenue(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about gross revenue
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    # Year
    gr_total = (sum(col('gross_revenue_eur'))).alias('gross_revenue_eur')
    gr_anc_total = (sum(col('gross_revenue_eur') + col('eur_ancillaries'))).alias('gross_revenue_and_ancillaries')
    gr_3 = (sum(when(col('year') == int(year_3), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_year3')
    gr_anc_3 = (sum(when(col('year') == int(year_3), col('gross_revenue_eur') + col('eur_ancillaries')).otherwise(0))).alias('gross_revenue_and_ancillaries_year3')
    gr_2 = (sum(when(col('year') == int(year_2), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_year2')
    gr_anc_2 = (sum(when(col('year') == int(year_2), col('gross_revenue_eur') + col('eur_ancillaries')).otherwise(0))).alias('gross_revenue_and_ancillaries_year2')
    gr_1 = (sum(when(col('year') == int(year_1), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_year1')
    gr_anc_1 = (sum(when(col('year') == int(year_1), col('gross_revenue_eur') + col('eur_ancillaries')).otherwise(0))).alias('gross_revenue_and_ancillaries_year1')
    
    # Gross revenue by haul
    gr_1_sh = (sum(when((col('year') == int(year_1)) & (col('haul') == 'SH'), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_sh_year1')
    gr_1_mh = (sum(when((col('year') == int(year_1)) & (col('haul') == 'MH'), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_mh_year1')
    gr_1_lh = (sum(when((col('year') == int(year_1)) & (col('haul') == 'LH'), col('gross_revenue_eur')).otherwise(0))).alias('gross_revenue_lh_year1')
    
    # Last spend
    last_spent = (max(when((col('rank_purchase') == 1), col('gross_revenue_eur') + col('eur_ancillaries')).otherwise(0))).alias('last_spent')

    return (gr_total, gr_3, gr_2, gr_1, gr_anc_total, gr_anc_3, gr_anc_2, gr_anc_1, last_spent, gr_1_sh, gr_1_mh, gr_1_lh)

def features_number_ancillaries(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about number of ancillaries
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    n_anc_total = (sum(col('n_ancillaries'))).alias('n_ancillaries')
    n_bags_total = (sum(col('num_bags'))).alias('num_bags')
    n_seats_total = (sum(col('num_seats'))).alias('num_seats')
    n_upgs_total = (sum(col('num_upgs'))).alias('num_upgs')
    n_others_total = (sum(col('num_others'))).alias('num_others')
    n_anc_3 = (sum(when(col('year') == int(year_3), col('n_ancillaries')).otherwise(0))).alias('n_ancillaries_year3')
    n_anc_2 = (sum(when(col('year') == int(year_2), col('n_ancillaries')).otherwise(0))).alias('n_ancillaries_year2')
    n_anc_1 = (sum(when(col('year') == int(year_1), col('n_ancillaries')).otherwise(0))).alias('n_ancillaries_year1')

    return n_anc_total, n_anc_3, n_anc_2, n_anc_1, n_bags_total, n_seats_total, n_upgs_total, n_others_total

def features_euros_ancillaries(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about ancillaries prices
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    eur_anc_total = (sum(col('eur_ancillaries'))).alias('eur_ancillaries')
    eur_bags_total = (sum(col('eur_bags'))).alias('eur_bags')
    eur_seats_total = (sum(col('eur_seats'))).alias('eur_seats')
    eur_upgs_total = (sum(col('eur_upgs'))).alias('eur_upgs')
    eur_others_total = (sum(col('eur_others'))).alias('eur_others')
    eur_anc_3 = (sum(when(col('year') == int(year_3), col('eur_ancillaries')).otherwise(0))).alias('eur_ancillaries_year3')
    eur_anc_2 = (sum(when(col('year') == int(year_2), col('eur_ancillaries')).otherwise(0))).alias('eur_ancillaries_year2')
    eur_anc_1 = (sum(when(col('year') == int(year_1), col('eur_ancillaries')).otherwise(0))).alias('eur_ancillaries_year1')

    return eur_anc_total, eur_anc_3, eur_anc_2, eur_anc_1, eur_bags_total, eur_seats_total, eur_upgs_total, eur_others_total

def features_days_anticipation(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about days of anticipation
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    n_anticip_total = (sum(when((col('num_days_anticipation').isNotNull()), col('num_days_anticipation')).otherwise(0))\
                       /sum(when((col('num_days_anticipation').isNotNull()), lit(1)).otherwise(0))).alias('num_days_anticipation')
    n_anticip_3 = (sum(when((col('year') == int(year_3)) & (col('num_days_anticipation').isNotNull()),
                            col('num_days_anticipation')).otherwise(0))\
                   /sum(when((col('year') == int(year_3))  & (col('num_days_anticipation').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_days_anticipation_year3')
    n_anticip_2 = (sum(when((col('year') == int(year_2)) & (col('num_days_anticipation').isNotNull()),
                            col('num_days_anticipation')).otherwise(0))\
                   /sum(when((col('year') == int(year_2))  & (col('num_days_anticipation').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_days_anticipation_year2')
    n_anticip_1 = (sum(when((col('year') == int(year_1)) & (col('num_days_anticipation').isNotNull()),
                            col('num_days_anticipation')).otherwise(0))\
                   /sum(when((col('year') == int(year_1))  & (col('num_days_anticipation').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_days_anticipation_year1')
    

    return n_anticip_total, n_anticip_3, n_anticip_2, n_anticip_1

def features_hours_destination(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about hours in destination
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    hours_dest_total = (sum(when((col('num_hours_in_destination').isNotNull()), col('num_hours_in_destination')).otherwise(0))\
                       /sum(when((col('num_hours_in_destination').isNotNull()), lit(1)).otherwise(0))).alias('num_hours_in_destination')
    hours_dest_3 = (sum(when((col('year') == int(year_3)) & (col('num_hours_in_destination').isNotNull()),
                            col('num_hours_in_destination')).otherwise(0))\
                   /sum(when((col('year') == int(year_3))  & (col('num_hours_in_destination').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_hours_in_destination_year3')
    hours_dest_2 = (sum(when((col('year') == int(year_2)) & (col('num_hours_in_destination').isNotNull()),
                            col('num_hours_in_destination')).otherwise(0))\
                   /sum(when((col('year') == int(year_2))  & (col('num_hours_in_destination').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_hours_in_destination_year2')
    hours_dest_1 = (sum(when((col('year') == int(year_1)) & (col('num_hours_in_destination').isNotNull()),
                            col('num_hours_in_destination')).otherwise(0))\
                   /sum(when((col('year') == int(year_1))  & (col('num_hours_in_destination').isNotNull()),
                             lit(1)).otherwise(0))).alias('num_hours_in_destination_year1')
    

    return hours_dest_total, hours_dest_3, hours_dest_2, hours_dest_1

def features_reason_business(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about % business reason
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    ind_reason_total = (sum(when((col('ind_reason_business').isNotNull()), col('ind_reason_business')).otherwise(0))\
                       /sum(when((col('ind_reason_business').isNotNull()), lit(1)).otherwise(0))).alias('ind_reason_business')
    ind_reason_3 = (sum(when((col('year') == int(year_3)) & (col('ind_reason_business').isNotNull()),
                            col('ind_reason_business')).otherwise(0))\
                   /sum(when((col('year') == int(year_3))  & (col('ind_reason_business').isNotNull()),
                             lit(1)).otherwise(0))).alias('ind_reason_business_year3')
    ind_reason_2 = (sum(when((col('year') == int(year_2)) & (col('ind_reason_business').isNotNull()),
                            col('ind_reason_business')).otherwise(0))\
                   /sum(when((col('year') == int(year_2))  & (col('ind_reason_business').isNotNull()),
                             lit(1)).otherwise(0))).alias('ind_reason_business_year2')
    ind_reason_1 = (sum(when((col('year') == int(year_1)) & (col('ind_reason_business').isNotNull()),
                            col('ind_reason_business')).otherwise(0))\
                   /sum(when((col('year') == int(year_1))  & (col('ind_reason_business').isNotNull()),
                             lit(1)).otherwise(0))).alias('ind_reason_business_year1')
    

    return ind_reason_total, ind_reason_3, ind_reason_2, ind_reason_1

def features_dep_weekend(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about % departure at weekend
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    dep_weekend_total = (sum(col('loc_dep_date_day_weekend'))/sum(col('n_flights'))).alias('loc_dep_weekend')
    dep_weekend_3 = (sum(when(col('year') == int(year_3), col('loc_dep_date_day_weekend')).otherwise(0))\
                     /sum(when(col('year') == int(year_3), col('n_flights')).otherwise(0))).alias('loc_dep_weekend_year3')
    dep_weekend_2 = (sum(when(col('year') == int(year_2), col('loc_dep_date_day_weekend')).otherwise(0))\
                     /sum(when(col('year') == int(year_2), col('n_flights')).otherwise(0))).alias('loc_dep_weekend_year2')
    dep_weekend_1 = (sum(when(col('year') == int(year_1), col('loc_dep_date_day_weekend')).otherwise(0))\
                     /sum(when(col('year') == int(year_1), col('n_flights')).otherwise(0))).alias('loc_dep_weekend_year1')
    
    return dep_weekend_total, dep_weekend_3, dep_weekend_2, dep_weekend_1

def features_ow_flights(year_1: str, year_2: str, year_3: str) -> object:
    
    """Function that creates conditions about % one way flights
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    ow_flight_total = (sum(when((col('n_flights') == 1), lit(1)).otherwise(0))/sum(lit(1))).alias('ow_flights')
    ow_flight_3 = (sum(when((col('year') == int(year_3)) & (col('n_flights') == 1), lit(1)).otherwise(0))\
                     /sum(when(col('year') == int(year_3), lit(1)).otherwise(0))).alias('ow_flights_year3')
    ow_flight_2 = (sum(when((col('year') == int(year_2)) & (col('n_flights') == 1), lit(1)).otherwise(0))\
                     /sum(when(col('year') == int(year_2), lit(1)).otherwise(0))).alias('ow_flights_year2')
    ow_flight_1 = (sum(when((col('year') == int(year_1)) & (col('n_flights') == 1), lit(1)).otherwise(0))\
                     /sum(when(col('year') == int(year_1), lit(1)).otherwise(0))).alias('ow_flights_year1')
    

    return ow_flight_total, ow_flight_3, ow_flight_2, ow_flight_1

def features_ibplus() -> object:
    
    """Function that creates conditions about ibplus condition
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    ibplus = (max(when((col('flag_ibplus').isNotNull()), lit(1)).otherwise(0))).alias('flag_ibplus')
    is_corporate = (sum(col('is_corporate'))/sum(when((col('is_corporate').isNotNull()), lit(1)).otherwise(0))).alias('is_corporate')
    revenue_avios = sum(col('revenue_avios')).alias('revenue_avios')
    
    return ibplus, is_corporate, revenue_avios

def features_type_ind() -> object:
    
    """Function that creates conditions about passenger type
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    type_ind = max('pax_type_ind').alias('pax_type_ind')
    is_resident = max('is_resident').alias('is_resident')
    
    return type_ind, is_resident

def features_age() -> object:
    
    """Function that creates conditions about age
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    age = max('age').alias('age')
    
    return age

def features_n_passengers() -> object:
    
    """Function that creates conditions about % number of passenger
    (alone, couple or groups)
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    alone_flights = (sum(when((col('n_passengers') == 0), lit(1)).otherwise(0))/sum(lit(1))).alias('alone_flights')
    couple_flights = (sum(when((col('n_passengers') == 1), lit(1)).otherwise(0))/sum(lit(1))).alias('couple_flights')
    group_flights = (sum(when((col('n_passengers') > 1), lit(1)).otherwise(0))/sum(lit(1))).alias('group_flights')
    
    return alone_flights, couple_flights, group_flights

def features_season() -> object:
    
    """Function that creates conditions about % flights in each
    season
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    summer_flights = (sum(when((col('season') == 'summer'), lit(1)).otherwise(0))/sum(lit(1))).alias('summer_flights')
    winter_flights = (sum(when((col('season') == 'winter'), lit(1)).otherwise(0))/sum(lit(1))).alias('winter_flights')
    fall_flights = (sum(when((col('season') == 'fall'), lit(1)).otherwise(0))/sum(lit(1))).alias('fall_flights')
    spring_flights = (sum(when((col('season') == 'spring'), lit(1)).otherwise(0))/sum(lit(1))).alias('spring_flights')
    
    return summer_flights, winter_flights, fall_flights, spring_flights

def features_seniority() -> object:
    
    """Function that creates conditions about seniority
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    seniority = max(col('seniority')).alias('seniority')
    
    return seniority

def features_haul() -> object:
    
    """Function that creates conditions about % of each haul
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    sh_haul = (sum(when((col('haul') == 'SH') & (col('haul').isNotNull()), lit(1)).otherwise(0))/sum(when((col('haul').isNotNull()), lit(1)).otherwise(0))).alias('sh_haul')
    mh_haul = (sum(when((col('haul') == 'MH') & (col('haul').isNotNull()), lit(1)).otherwise(0))/sum(when((col('haul').isNotNull()), lit(1)).otherwise(0))).alias('mh_haul')
    lh_haul = (sum(when((col('haul') == 'LH') & (col('haul').isNotNull()), lit(1)).otherwise(0))/sum(when((col('haul').isNotNull()), lit(1)).otherwise(0))).alias('lh_haul')
    
    return sh_haul, mh_haul, lh_haul

def features_class_code() -> object:
    
    """Function that creates conditions about % class flown
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    economy = first(col('economy')).alias('economy')
    premium = first(col('premium')).alias('premium')
    business = first(col('business')).alias('business')

    return economy, premium, business

def features_tier() -> object:
    
    """Function that creates conditions about tier condition
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    ff_tier = max(col('ff_tier')).alias('ff_tier')
    
    return ff_tier

def features_flight_spent() -> object:
    
    """Function that creates conditions about max, min and deviation
    of flight spent/position with respect flights
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    dev_spent_flight = first(col('mean_deviation_flight_spent')).alias('mean_deviation_flight_spent')
    
    return dev_spent_flight

def features_disruptions() -> object:
    
    """Function that creates conditions about disruptions
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    flag_misconnection_misc_last_week = first(col('flag_misconnection_misc_last_week')).alias('flag_misconnection_misc_last_week')
    flag_dng_last_week= first(col('flag_dng_last_week')).alias('flag_dng_last_week')
    flight_cnld_last_week = first(col('flight_cnld_last_week')).alias('flight_cnld_last_week')
    flight_delay_last_week = first(col('flight_delay_last_week')).alias('flight_delay_last_week')
    flag_last_flight_delay = first(col('flag_last_flight_delay')).alias('flag_last_flight_delay')

    return (flag_misconnection_misc_last_week, flag_dng_last_week, flight_cnld_last_week, flight_delay_last_week, flag_last_flight_delay)

def features_flight_time(year_1, year_2, year_3):
    
    """Function that creates conditions about delays
     
    Parameters
    ----------
    year_1: flag indicating year-1
    year_2: flag indicating year-2
    year_3: flag indicating year-3
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    max_minutes_delay_arr_1 = (max(when(col('year') == int(year_1), col('max_minutes_delay_arr')).otherwise(-np.inf))).alias('max_minutes_delay_arr_year1')
    max_minutes_delay_arr_2 = (max(when(col('year') == int(year_2), col('max_minutes_delay_arr')).otherwise(-np.inf))).alias('max_minutes_delay_arr_year2')
    max_minutes_delay_arr_3 = (max(when(col('year') == int(year_3), col('max_minutes_delay_arr')).otherwise(-np.inf))).alias('max_minutes_delay_arr_year3')
    min_minutes_delay_arr_1 = (min(when(col('year') == int(year_1), col('min_minutes_delay_arr')).otherwise(np.nan))).alias('min_minutes_delay_arr_year1')
    min_minutes_delay_arr_2 = (min(when(col('year') == int(year_2), col('min_minutes_delay_arr')).otherwise(np.nan))).alias('min_minutes_delay_arr_year2')
    min_minutes_delay_arr_3 = (min(when(col('year') == int(year_3), col('min_minutes_delay_arr')).otherwise(np.nan))).alias('min_minutes_delay_arr_year3')
    total_minutes_delay_arr_1 = (sum(when(col('year') == int(year_1), col('max_minutes_delay_arr')).otherwise(0))).alias('total_minutes_delay_arr_year1')
    total_minutes_delay_arr_2 = (sum(when(col('year') == int(year_2), col('max_minutes_delay_arr')).otherwise(0))).alias('total_minutes_delay_arr_year2')
    total_minutes_delay_arr_3 = (sum(when(col('year') == int(year_3), col('max_minutes_delay_arr')).otherwise(0))).alias('total_minutes_delay_arr_year3')
    
    return (max_minutes_delay_arr_1, max_minutes_delay_arr_2, max_minutes_delay_arr_3, min_minutes_delay_arr_1, min_minutes_delay_arr_2,
            min_minutes_delay_arr_3, total_minutes_delay_arr_1, total_minutes_delay_arr_2, total_minutes_delay_arr_3)

def features_mkt_permission() -> object:
    
    """Function that creates conditions about marketing permission
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    mkt_permission = max(col('mkt_permission')).alias('mkt_permission')
    
    return mkt_permission

def features_type_travel() -> object:
    
    """Function that creates conditions about number of not travelled, refunded
    and exchanges
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    n_exchanges = first(col('n_exchanges')).alias('n_exchanges')
    n_not_travelled = first(col('n_not_travelled')).alias('n_not_travelled')
    n_refunded = first(col('n_refunded')).alias('n_refunded')

    return n_exchanges, n_not_travelled, n_refunded

def features_ods() -> object:
    
    """Function that creates conditions about number of ods
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    most_od_bought = (max(col('n_ods'))/first(col('n_ods_total'))).alias('most_od_bought')
    n_destination = first('n_destination_city_od').alias('n_destination_city_od')
    most_od_flown = first('most_od_flown').alias('most_od_flown')
    point_of_sale = first(col('point_of_sale')).alias('point_of_sale')

    return most_od_bought, n_destination, most_od_flown, point_of_sale

def sales_channel() -> object:
    
    """Function that creates conditions about sale channel
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    dir_ind = (sum(col('ind_direct_sale'))/sum(when((col('ind_direct_sale').isNotNull()), lit(1)).otherwise(0))).alias('ind_direct_sale')
    
    return dir_ind

def features_nps() -> object:
    
    """Function that creates conditions about nps
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    nps_100 = mean(col('nps_100')).alias('nps_100')
    
    return nps_100

def features_op_carrier() -> object:
    
    """Function that creates conditions about IB_pctg
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    IB_pctg = first(col('IB_pctg')).alias('IB_pctg')
    
    return IB_pctg

def features_perc_email_agency() -> object:
    
    """Function that creates conditions about IB_pctg
     
    Parameters
    ----------
    
    Returns
    -------
    Functions and conditions for grouping by at cid level
    
    """
    
    perc_email_agency = (sum(col('perc_email_agency'))/sum(when((col('perc_email_agency').isNotNull()), lit(1)).otherwise(0))).alias('perc_email_agency')
    n_ticket_email = first(col('n_ticket_email')).alias('n_ticket_email')
    
    return perc_email_agency, n_ticket_email

def create_final_features(df_ticket: DataFrame, df_future: DataFrame, date_train: str, po_execution: bool) -> DataFrame:
    
    """Function that creates conditions about marketing permission
     
    Parameters
    ----------
    df_customer_data_filtered_ticket: training dataframe at pnr/cid level
    df_customer_data_filtered_future: dataframe for labeling
    date_train: string with limit date for training
    
    Returns
    -------
    Final pyspark Dataframe with all features and cid for training
    
    """
        
    ## Extract features for groupby
    year_1, year_2, year_3 = -1, -2, -3
    # Frequency
    (f_total, f_3, f_2, f_1, last_purchase, last2_purchase, last3_purchase,
            last_flight, purchases_2015, purchases_2016, purchases_2017) = features_frequency(year_1, year_2, year_3, date_train)
    # Gross Revenue
    (gr_total, gr_3, gr_2, gr_1, gr_anc_total, gr_anc_3, gr_anc_2, 
    gr_anc_1, last_spent, gr_1_sh, gr_1_mh, gr_1_lh) = features_gross_revenue(year_1, year_2, year_3)
    # Number of ancillaries
    n_anc_total, n_anc_3, n_anc_2, n_anc_1, n_bags_total, n_seats_total, n_upgs_total, n_others_total = features_number_ancillaries(year_1, year_2, year_3)
    # Euros in ancillaries
    eur_anc_total, eur_anc_3, eur_anc_2, eur_anc_1, eur_bags_total, eur_seats_total, eur_upgs_total, eur_others_total = features_euros_ancillaries(year_1, year_2, year_3)
    # Number of days anticipation
    n_anticip_total, n_anticip_3, n_anticip_2, n_anticip_1 = features_days_anticipation(year_1, year_2, year_3)
    # Number of hours in destination
    hours_dest_total, hours_dest_3, hours_dest_2, hours_dest_1 = features_hours_destination(year_1, year_2, year_3)
    # Number of hours in destination
    ind_reason_total, ind_reason_3, ind_reason_2, ind_reason_1 = features_reason_business(year_1, year_2, year_3)
    # Percentage of flights with departure at weekend
    dep_weekend_total, dep_weekend_3, dep_weekend_2, dep_weekend_1 = features_dep_weekend(year_1, year_2, year_3)
    # Percentage of OW flights
    ow_flight_total, ow_flight_3, ow_flight_2, ow_flight_1 = features_ow_flights(year_1, year_2, year_3)
    # IBPlus flag
    ibplus, is_corporate, revenue_avios = features_ibplus()
    # Number of passengers
    alone_flights, couple_flights, group_flights = features_n_passengers()
    # Type Indicator passenger (C, I ,A)
    type_ind, is_resident = features_type_ind()
    # Age
    age = features_age()
    # Season flight percentage
    summer_flights, winter_flights, fall_flights, spring_flights = features_season()
    # Seniority
    seniority = features_seniority()
    # Haul percentage
    sh_haul, mh_haul, lh_haul = features_haul()
    # Class code
    economy, premium, business = features_class_code()
    # Tier
    ff_tier = features_tier()
    # Flight spent
    dev_spent_flight = features_flight_spent()
    # Disruptions
    (flag_misconnection_misc_last_week, flag_dng_last_week, flight_cnld_last_week,
     flight_delay_last_week, flag_last_flight_delay) = features_disruptions()
    # Delays
    (max_minutes_delay_arr_1, max_minutes_delay_arr_2, max_minutes_delay_arr_3, min_minutes_delay_arr_1, min_minutes_delay_arr_2,
     min_minutes_delay_arr_3, total_minutes_delay_arr_1, total_minutes_delay_arr_2, total_minutes_delay_arr_3) = features_flight_time(year_1, year_2, year_3)
    # Mkt permission
    mkt_permission = features_mkt_permission()
    # Exchanges and travelled
    n_exchanges, n_not_travelled, n_refunded = features_type_travel()
    # Features ods
    most_od_bought, n_destination, most_od_flown, point_of_sale = features_ods()
    # Sales channel
    dir_ind = sales_channel()
    # OP Carrier code percentage
    IB_pctg = features_op_carrier()
    # Flag email agencie
    perc_email_agency, n_ticket_email = features_perc_email_agency()
    
    # Groupby all features by id golden record
    df_ticket_year = df_ticket.groupby('cid').agg(f_total, gr_total, n_anc_total, eur_anc_total,
                                                 n_anticip_total, hours_dest_total, ind_reason_total,
                                                 dep_weekend_total, ow_flight_total, gr_anc_total, gr_anc_3, gr_anc_2,
                                                 gr_anc_1, f_3, f_2, f_1, ibplus, type_ind, age, alone_flights,
                                                 couple_flights, group_flights, summer_flights, winter_flights,
                                                 fall_flights, spring_flights, seniority, last_purchase, last2_purchase,
                                                 last3_purchase, last_flight, sh_haul, mh_haul, lh_haul, economy, premium,
                                                 business, ff_tier, dev_spent_flight,
                                                 max_minutes_delay_arr_1, mkt_permission, n_exchanges, n_not_travelled,
                                                 n_refunded, purchases_2015, purchases_2016, purchases_2017,
                                                 most_od_bought, most_od_flown, dir_ind, last_spent, n_destination, is_corporate, revenue_avios,
                                                 flag_misconnection_misc_last_week, flag_dng_last_week, flight_cnld_last_week,
                                                 flight_delay_last_week, IB_pctg, is_resident, flag_last_flight_delay, perc_email_agency, n_ticket_email,
                                                 gr_1_sh, gr_1_mh, gr_1_lh, point_of_sale)
    
    # Filter out clients with more than one purchase or more than one purchase + at least one purchase in year_mean
    df_features = df_ticket_year.where((col('frequency_year1') > 0) & ((col('frequency_year1') + col('frequency_year2') + col('frequency_year3')) > 1))
    # Add target gross revenue future
    if not po_execution:
        df_features = df_features.join(df_future, df_features["cid"] == df_future["cid2"], "left_outer")
        df_features = df_features.fillna(value=0, subset=['gross_revenue_future'])
        df_features = df_features.drop('cid2')
    # Extract post features group
    df_features = post_features_extraction(df_features)
    # Churn label
    if not po_execution:
        df_features = df_features.withColumn('churn', when(col('gross_revenue_future') == 0, lit(1)).otherwise(lit(0)))
    
    # Replace np.inf, string AAAAAA by nan
    df_features = df_features.replace(-np.inf, np.nan)
    df_features = df_features.replace(np.inf, np.nan)
    df_features = df_features.na.fill(value=0,subset=["purchases_2015", "purchases_2016", "purchases_2017", "revenue_avios"])
    
    # Eliminate possible agencies and isolated CIDs
    df_features = df_features.where(~(((col('perc_email_agency') > 0.7) | (col('ind_direct_sale') < 0.2)) & (col('frequency') < 4) & (col('n_ticket_email') > 1000)))

    return df_features

def execute_etl_churn(s3_dir_customer: str, columns: list, po_execution: bool=True) -> DataFrame:
    
    """Function execute all pipeline for etl churn features
     
    Parameters
    ----------
    s3_dir_customer: s3 path for reading data
    columns: columns for reading
    
    Returns
    -------
    Dataframe with all features
    
    """
    
    if po_execution:
        year_fin = 2023
        month_fin = 12
        day_fin = 31
        date_future = None
    else:
        year_fin = 2022
        month_fin = 12
        day_fin = 31
        date_future = datetime.datetime(year_fin, month_fin, day_fin)
    if day_fin == 29 and month_fin == 2:
        date_1year = datetime.datetime(year_fin - 1, month_fin, day_fin - 1)
        date_2year = datetime.datetime(year_fin - 2, month_fin, day_fin - 1)
        date_3year = datetime.datetime(year_fin - 3, month_fin, day_fin - 1)
        date_limit = datetime.datetime(year_fin - 4, month_fin, day_fin - 1)
    else:
        date_1year = datetime.datetime(year_fin - 1, month_fin, day_fin)
        date_2year = datetime.datetime(year_fin - 2, month_fin, day_fin)
        date_3year = datetime.datetime(year_fin - 3, month_fin, day_fin)
        date_limit = datetime.datetime(year_fin - 4, month_fin, day_fin)
    date_train = str(date_1year).split()[0]
    date_last_week = date_1year - datetime.timedelta(days=7)
    print(date_limit, date_1year, date_future)

    # Read data
    df = read_s3_data_churn(s3_dir_customer, columns)
    # Execute all code
    df_prc, df_future = preprocess_data(df, date_train, date_1year, date_2year, date_3year, date_limit,
                                        date_last_week, date_future)
    df_customer_data_prc_features = create_final_features(df_prc, df_future, date_train, po_execution)
    
    return df_customer_data_prc_features, df_prc

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Execution

Global parameters

In [38]:
# Data path from S3
po_execution = False
s3_dir_customer = 's3://iberia-data-lake/customer/customer_models_data_v2/'
# s3_dir_customer_part = get_last_s3_partition(s3_dir_customer)
s3_dir_customer_part = 's3://iberia-data-lake/customer/customer_models_data_v2/insert_date_ci=2023-02-19/'
columns = ['cid', 'coupon_usage_code', 'prime_ticket_num', 'coupon_num', 'pnr_resiber', 'pnr_amadeus',
           'gross_revenue_eur','ticket_sale_date', 'date_creation_pnr_resiber', 'num_bags', 'eur_bags', 'num_seats',
           'eur_seats', 'num_upgs', 'eur_upgs', 'num_others', 'eur_others', 'ff_num',
           'birth_date', 'pax_type_ind', 'num_days_anticipation', 'ind_reason_business',
           'num_hours_in_destination', 'loc_dep_date', 'loc_arr_date', 'loc_arr_time', 'loc_dep_time',
           'itinerary_od', 'date_creation_idgoldenrecord', 'haul', 'sold_class_code', 'ff_tier',
           'pax_demand_space', 'ff_group', 'op_carrier_code', 'op_flight_num', 'flag_misconnection_misc', 'flag_dng', 'delayed_minutes_arrival',
           'cancelled', 'mkt_permission', 'purchases_2015', 'purchases_2016', 'purchases_2017', 'ind_direct_sale', 'destination_city_od',
           'is_corporate', 'revenue_avios', 'pax_type_seg', 'op_carrier_code', 'email_operative', 'point_of_sale']

df_customer_data_prc_features, df_prc = execute_etl_churn(s3_dir_customer_part, columns, po_execution)
df_prc = df_prc.cache()
df_customer_data_prc_features = df_customer_data_prc_features.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

2018-12-31 00:00:00 2021-12-31 00:00:00 2022-12-31 00:00:00

In [39]:
df_customer_data_prc_features.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1901268

In [42]:
%%pretty
df_customer_data_prc_features.groupby('flag_ibplus').count().show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

flag_ibplus,count
1,848377
0,1052891


In [40]:
%%pretty
df_customer_data_prc_features.groupby('flag_ibplus_2').count().show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

flag_ibplus_2,count
1,848377
0,1052891


In [41]:
%%pretty
df_customer_data_prc_features.where(col('flag_ibplus') != col('flag_ibplus_2')).show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

cid,frequency,gross_revenue_eur,n_ancillaries,eur_ancillaries,num_days_anticipation,num_hours_in_destination,ind_reason_business,loc_dep_weekend,ow_flights,gross_revenue_and_ancillaries,gross_revenue_and_ancillaries_year3,gross_revenue_and_ancillaries_year2,gross_revenue_and_ancillaries_year1,frequency_year3,frequency_year2,frequency_year1,flag_ibplus,flag_ibplus_2,pax_type_ind,age,alone_flights,couple_flights,group_flights,summer_flights,winter_flights,fall_flights,spring_flights,seniority,last_purchase,last2_purchase,last3_purchase,last_flight,sh_haul,mh_haul,lh_haul,economy,premium,business,ff_tier,mean_deviation_flight_spent,max_minutes_delay_arr_year1,mkt_permission,n_exchanges,n_not_travelled,n_refunded,purchases_2015,purchases_2016,purchases_2017,most_od_bought,most_od_flown,ind_direct_sale,last_spent,n_destination_city_od,is_corporate,revenue_avios,flag_misconnection_misc_last_week,flag_dng_last_week,flight_cnld_last_week,flight_delay_last_week,IB_pctg,is_resident,flag_last_flight_delay,perc_email_agency,n_ticket_email,gross_revenue_sh_year1,gross_revenue_mh_year1,gross_revenue_lh_year1,point_of_sale,gross_revenue_future,mean_spend,mean_spend_year3,mean_spend_year2,mean_spend_year1,churn


In [19]:
df_customer_data_prc_features.repartition(20).write.mode('overwrite').orc('s3://iberia-data-lake/customer/churn_model_v2/train_data_1/')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [46]:
# Data path from S3
po_execution = True
s3_dir_customer = 's3://iberia-data-lake/customer/customer_models_data_v2/'
# s3_dir_customer_part = get_last_s3_partition(s3_dir_customer)
s3_dir_customer_part = 's3://iberia-data-lake/customer/customer_models_data_v2/insert_date_ci=2023-02-19/'
columns = ['cid', 'coupon_usage_code', 'prime_ticket_num', 'coupon_num', 'pnr_resiber', 'pnr_amadeus',
           'gross_revenue_eur','ticket_sale_date', 'date_creation_pnr_resiber', 'num_bags', 'eur_bags', 'num_seats',
           'eur_seats', 'num_upgs', 'eur_upgs', 'num_others', 'eur_others', 'ff_num',
           'birth_date', 'pax_type_ind', 'num_days_anticipation', 'ind_reason_business',
           'num_hours_in_destination', 'loc_dep_date', 'loc_arr_date', 'loc_arr_time', 'loc_dep_time',
           'itinerary_od', 'date_creation_idgoldenrecord', 'haul', 'sold_class_code', 'ff_tier',
           'pax_demand_space', 'ff_group', 'op_carrier_code', 'op_flight_num', 'flag_misconnection_misc', 'flag_dng', 'delayed_minutes_arrival',
           'cancelled', 'mkt_permission', 'purchases_2015', 'purchases_2016', 'purchases_2017', 'ind_direct_sale', 'destination_city_od',
           'is_corporate', 'revenue_avios', 'pax_type_seg', 'op_carrier_code', 'email_operative', 'point_of_sale']

df_customer_data_prc_features, df_prc = execute_etl_churn(s3_dir_customer_part, columns, po_execution)
df_prc = df_prc.cache()
df_customer_data_prc_features = df_customer_data_prc_features.cache()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

2019-12-31 00:00:00 2022-12-31 00:00:00 None

In [51]:
df_customer_data_prc_features.where(col('flag_ibplus') == 1).count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1139060

In [11]:
df_customer_data_prc_features.repartition(20).write.mode('overwrite').orc('s3://iberia-data-lake/customer/churn_model_v2/po_data_2/')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
%%pretty
df_customer_data_prc_features.show(30)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

cid,frequency,gross_revenue_eur,n_ancillaries,eur_ancillaries,num_days_anticipation,num_hours_in_destination,ind_reason_business,loc_dep_weekend,ow_flights,gross_revenue_and_ancillaries,gross_revenue_and_ancillaries_year3,gross_revenue_and_ancillaries_year2,gross_revenue_and_ancillaries_year1,frequency_year3,frequency_year2,frequency_year1,flag_ibplus,pax_type_ind,age,alone_flights,couple_flights,group_flights,summer_flights,winter_flights,fall_flights,spring_flights,seniority,last_purchase,last2_purchase,last3_purchase,last_flight,sh_haul,mh_haul,lh_haul,economy,premium,business,ff_tier,mean_deviation_flight_spent,max_minutes_delay_arr_year1,mkt_permission,n_exchanges,n_not_travelled,n_refunded,purchases_2015,purchases_2016,purchases_2017,most_od_bought,most_od_flown,ind_direct_sale,last_spent,n_destination_city_od,is_corporate,revenue_avios,flag_misconnection_misc_last_week,flag_dng_last_week,flight_cnld_last_week,flight_delay_last_week,IB_pctg,is_resident,flag_last_flight_delay,perc_email_agency,n_ticket_email,gross_revenue_sh_year1,gross_revenue_mh_year1,gross_revenue_lh_year1,point_of_sale,mean_spend,mean_spend_year3,mean_spend_year2,mean_spend_year1
10022361,2,570.0,1,74.74,10.5,2072.1666,0.5,0.0,0.5,644.74,0.0,564.74,80.0,0,1,1,0,A,24.0,1.0,0.0,0.0,0.5,0.5,0.0,0.0,656,121,656,,16,0.0,0.5,0.5,1.0,0.0,0.0,0,0.18078783420207964,-27.0,0,0,1,0,0,0,0,0.5,VCE,0.5,80.0,2,0.5,0.0,0,0,0,0,1.0,0,0,0.0,3,0.0,80.0,0.0,ES,322.37,0.0,564.74,80.0
10048703,2,419.0,2,36.0,5.0,42.8333,0.0,0.5,0.0,455.0,0.0,188.0,267.0,0,1,1,1,A,53.0,0.0,1.0,0.0,0.5,0.0,0.5,0.0,494,44,494,,38,0.0,1.0,0.0,1.0,0.0,0.0,2,0.7316833302817913,-11.0,0,0,0,0,0,0,0,0.5,,1.0,267.0,3,0.0,0.0,0,0,0,0,1.0,0,0,0.0,2,0.0,267.0,0.0,ES,227.5,0.0,188.0,267.0
10050017,2,66.0,0,0.0,62.0,,0.0,1.0,1.0,66.0,0.0,0.0,66.0,0,0,2,0,A,,0.0,1.0,0.0,0.0,1.0,0.0,0.0,26,26,26,,-36,0.5,0.5,0.0,1.0,0.0,0.0,0,0.05666933367375418,18.0,0,0,0,0,0,0,0,0.5,MAD,1.0,18.0,2,0.0,0.0,0,0,0,0,0.5,0,0,0.0,2,48.0,18.0,0.0,NL,33.0,0.0,0.0,33.0
10056932,2,29.0,1,40.0,19.0,176.75,0.0,0.0,0.5,69.0,0.0,19.0,50.0,0,1,1,0,A,,1.0,0.0,0.0,0.0,0.5,0.5,0.0,721,56,721,,31,0.0,1.0,0.0,1.0,0.0,0.0,0,-0.7527035807438452,10.0,1,0,0,0,0,0,0,1.0,NTE,1.0,50.0,2,0.0,0.0,0,0,0,0,0.0,0,0,0.0,4,0.0,10.0,0.0,ES,34.5,0.0,19.0,50.0
10058857,3,192.44,0,0.0,35.333333333333336,,0.6666666666666666,0.6666666666666666,1.0,192.44,0.0,33.8,158.64,0,1,2,0,A,,1.0,0.0,0.0,0.3333333333333333,0.6666666666666666,0.0,0.0,380,196,339,380.0,166,0.6666666666666666,0.3333333333333333,0.0,1.0,0.0,0.0,0,-0.03595035771176919,9.0,1,0,1,0,0,0,0,0.6666666666666666,VGO,1.0,125.0,2,0.0,0.0,0,0,0,0,0.6666666666666666,0,0,0.0,4,33.64,125.0,0.0,ES,64.14666666666666,0.0,33.8,79.32
10071852,2,965.78,0,0.0,28.0,233.3333,1.0,0.0,0.0,965.78,0.0,0.0,965.78,0,0,2,0,A,28.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,133,133,135,,95,0.0,0.0,1.0,0.6666666666666666,0.3333333333333333,0.0,0,,-2.0,0,0,0,0,0,0,0,1.0,BRU,0.0,144.2,2,0.5,0.0,0,0,0,0,0.6666666666666666,0,0,0.0,4,0.0,0.0,965.78,CO,482.89,0.0,0.0,482.89
10117317,2,234.55,0,0.0,5.5,,1.0,0.0,1.0,234.55,0.0,0.0,234.55,0,0,2,0,A,,1.0,0.0,0.0,0.0,0.0,0.5,0.5,214,78,214,,74,1.0,0.0,0.0,1.0,0.0,0.0,0,0.8643230537599548,17.0,1,0,0,0,0,0,0,0.5,MAD,1.0,180.0,2,0.0,0.0,0,0,0,0,1.0,0,1,0.0,3,234.55,0.0,0.0,ES,117.275,0.0,0.0,117.275
10117752,2,315.02,1,79.47,24.5,13.5,1.0,0.6666666666666666,0.5,394.49,0.0,0.0,394.49,0,0,2,0,A,24.0,1.0,0.0,0.0,0.5,0.0,0.0,0.5,1873,200,255,,166,0.0,0.5,0.5,1.0,0.0,0.0,0,-0.2484193077748118,3.0,1,0,2,0,0,1,0,0.5,MIL,1.0,379.49,3,0.0,0.0,0,0,0,0,0.0,0,0,0.0,7,0.0,15.0,300.02,US,197.245,0.0,0.0,197.245
10117972,5,511.59,0,0.0,23.8,86.708325,0.0,0.5555555555555556,0.2,511.59,0.0,0.0,511.59,0,0,5,0,A,,1.0,0.0,0.0,0.2,0.2,0.0,0.6,288,42,203,241.0,-2,1.0,0.0,0.0,1.0,0.0,0.0,0,-0.25241424000043,-2.0,0,0,2,0,0,0,0,0.8,PMI,1.0,154.53,2,0.0,0.0,0,0,0,0,0.0,1,0,0.0,9,511.59,0.0,0.0,ES,102.318,0.0,0.0,102.318
10120594,2,365.3,4,128.06,35.5,1450.16665,0.0,0.25,0.0,493.36,0.0,172.92000000000002,320.44,0,1,1,0,A,55.0,1.0,0.0,0.0,0.5,0.0,0.0,0.5,677,160,677,,32,0.0,0.0,1.0,1.0,0.0,0.0,0,-0.6202216834043185,7.0,0,0,4,0,0,0,0,1.0,MUC,1.0,320.44,2,0.0,0.0,0,0,0,0,0.75,0,0,0.0,2,0.0,0.0,243.02,MX,246.68,0.0,172.92000000000002,320.44
