In [1]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint
import pyspark
import pyspark.sql.functions as F

from pyspark.sql.functions import to_date, col, lit
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType

In [2]:
# Build (or reuse, via getOrCreate) the Spark session used by this notebook.
spark = (
    pyspark.sql.SparkSession.builder
    .master("local[*]")
    .appName("dev")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)
# Only log messages at ERROR level from here on, to keep cell output readable.
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/29 06:54:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/29 06:54:12 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/10/29 06:54:12 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


# Gold label store

In [3]:
# Load the members table from the silver layer of the datamart.
# Parquet files store their own schema and have no header row, so the
# CSV-only "header" / "inferSchema" reader options are not needed here.
df_members = spark.read.parquet("datamart/silver/members")

In [4]:
# Preview the first five member rows.
df_members.show(n=5)

                                                                                

+--------------------+----------+--------------+-----------------+-----------------------+--------------------+------------------+--------+-------+--------------+---------------+
|                msno|city_clean|registered_via|registration_date|tenure_days_at_snapshot| registered_via_freq|         city_freq|city_idx|via_idx|       city_oh|         via_oh|
+--------------------+----------+--------------+-----------------+-----------------------+--------------------+------------------+--------+-------+--------------+---------------+
|Z1SBrlbnzZzQZtlS3...|         1|             1|       2016-06-12|                    261|6.352045425101776E-6|0.7097045811394772|     0.0|   14.0|(21,[0],[1.0])|(18,[14],[1.0])|
|2GkgHuwB+NCVnpSRS...|         1|             1|       2016-01-19|                    406|6.352045425101776E-6|0.7097045811394772|     0.0|   14.0|(21,[0],[1.0])|(18,[14],[1.0])|
|X1AmJaNJ1bpGEgxLv...|         1|            13|       2016-12-27|                     63| 8.058234370681

In [18]:
# Load the transactions table from the silver layer of the datamart.
# Parquet files store their own schema and have no header row, so the
# CSV-only "header" / "inferSchema" reader options are not needed here.
df_transactions = spark.read.parquet("datamart/silver/transactions")

In [19]:
# Preview the first five transaction rows.
df_transactions.show(n=5)

+--------------------+-----------------+-----------------+---------------+------------------+-------------+----------------+----------------------+---------+-------------------+---+--------------------+----+-----+
|                msno|payment_method_id|payment_plan_days|plan_list_price|actual_amount_paid|is_auto_renew|transaction_date|membership_expire_date|is_cancel|        source_file|day|      transaction_id|year|month|
+--------------------+-----------------+-----------------+---------------+------------------+-------------+----------------+----------------------+---------+-------------------+---+--------------------+----+-----+
|VIHHLepUEMXxF2iel...|               41|               30|            149|               149|            1|      2016-11-08|            2016-12-08|        0|   transactions.csv|  8|84c76dce-ee16-403...|2016|   11|
|TkAhpvvz+vU7LBuVH...|               26|                1|              0|                 0|            0|      2016-11-23|            2016-11-

# Set Today's date (aka inference date)

This date is the training cutoff: anything after this date will not be used.

In [20]:
# Inference ("today") date, as a YYYY-MM-DD string; the filtering cells below
# compare against it.
inference_date = "2015-12-15"
# Fail fast (ValueError) on a malformed date string instead of letting it
# propagate into the Spark date comparisons further down.
datetime.strptime(inference_date, "%Y-%m-%d")
print(f"Today's date is set to: {inference_date}")

Today's date is set to: 2015-12-15


### Filter members
Fetch all users who registered on or before the inference date.

In [21]:
# The inference date as a Spark date expression, shared by the filter and the
# tenure calculation below.
_cutoff_date = F.to_date(F.lit(inference_date))
_registration = F.col("registration_date")

registered_users = (
    df_members
    .withColumn("registration_date", F.to_date("registration_date"))
    # Keep members whose registration date is on or before the inference date.
    .where(_registration <= _cutoff_date)
    # Recompute tenure as days from registration to the inference date,
    # overwriting any existing column of the same name.
    .withColumn("tenure_days_at_snapshot", F.datediff(_cutoff_date, _registration))
    .select(
        "msno",
        "registration_date",
        "tenure_days_at_snapshot",
        "registered_via",
        "city_clean",
        "via_oh",
        "city_oh",
    )
)

_registered_count = registered_users.count()
print("Registered users up to", inference_date, ":", _registered_count)
registered_users.show(5, truncate=False)

Registered users up to 2015-12-15 : 3924570




+--------------------------------------------+-----------------+-----------------------+--------------+----------+---------------+--------------+
|msno                                        |registration_date|tenure_days_at_snapshot|registered_via|city_clean|via_oh         |city_oh       |
+--------------------------------------------+-----------------+-----------------------+--------------+----------+---------------+--------------+
|2Dd2zaY3UQAbsnOW2ffCse4HYFpnBKkA3GNenPUYxdE=|2014-11-01       |409                    |16            |1         |(18,[12],[1.0])|(21,[0],[1.0])|
|TgFv/gOLcbCTq1K5UrsCxZuozau6mWpTlGolwm6t1/g=|2014-10-23       |418                    |16            |1         |(18,[12],[1.0])|(21,[0],[1.0])|
|ZZuGBLOObobJjHLd5aFRq0K/2jZk9mrZNdIY8UO7lrs=|2014-12-16       |364                    |16            |1         |(18,[12],[1.0])|(21,[0],[1.0])|
|bh4qCoZL8kgUmYBBjglJm03CYujRqL4WdS1ybD4kSLc=|2014-12-21       |359                    |16            |1         |(18,[12],[

                                                                                

### Filter Transactions
Keep only transactions whose `membership_expire_date` is on or before the inference date, then take each member's latest expiry date.

In [None]:
# Latest membership expiry date per member (msno), considering only expiry
# dates on or before the inference date. `F` comes from the notebook's import
# cell, so it is not re-imported here.
# NOTE(review): the cutoff is applied to membership_expire_date, not
# transaction_date, so transactions made before the inference date whose
# expiry falls after it are dropped — confirm this is intended.
latest_expiry = (
    df_transactions
    .filter(F.col("membership_expire_date") <= F.to_date(F.lit(inference_date)))
    .groupBy("msno")
    .agg(F.max("membership_expire_date").alias("latest_expiry"))
)