In [1]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint
import pyspark
import pyspark.sql.functions as F

from pyspark.sql.functions import col
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType


In [2]:
# Initialize SparkSession
spark = pyspark.sql.SparkSession.builder \
    .appName("dev") \
    .config("spark.driver.memory", "4g") \
    .master("local[*]") \
    .getOrCreate()

# Set log level to ERROR to hide warnings
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/29 13:22:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:

df = (spark.read
      .option("header", True)
      .option("inferSchema", True)
      .parquet("datamart/bronze/user_logs/year=2015/month=01"))


In [4]:
df_members = (spark.read
      .option("header", True)
      .option("inferSchema", True)
      .parquet("datamart/bronze/members"))

In [4]:
# ========= 0) Setup =========
from pyspark.sql import functions as F

SNAPSHOT_DATE_STR = "2017-02-28"   # cutoff cho train
SNAPSHOT_YEAR     = 2017           # (không dùng nếu bạn giữ rule 14–68)

# ========= 1) Field Format =========
dfm = (
    df_members
      .withColumn("msno", F.lower(F.trim(F.col("msno"))))
      .withColumn("city", F.col("city").cast("int"))
      .withColumn("bd", F.col("bd").cast("int"))
      .withColumn("gender", F.lower(F.trim(F.col("gender"))))
      .withColumn("registered_via", F.col("registered_via").cast("int"))
      .withColumn("registration_init_time", F.col("registration_init_time").cast("string"))
)

# ========= 2) Date =========
dfm = dfm.withColumn(
    "registration_date",
    F.to_date(F.col("registration_init_time"), "yyyyMMdd")
)

# ========= 3) City clean =========
dfm = dfm.withColumn(
    "city_clean",
    F.when(F.col("city") <= 0, None).otherwise(F.col("city"))
)

# ========= 4.1) Gender clean =========
dfm = dfm.withColumn(
    "gender_norm",
    F.when(F.col("gender").isin("male", "female"), F.col("gender")).otherwise(F.lit("unknown"))
)

# ========= 4.2) Gender one-hot =========
dfm = (dfm
    .drop("gender_male","gender_female","gender_unknown")
    .withColumn("gender_male",    (F.col("gender_norm")=="male").cast("int"))
    .withColumn("gender_female",  (F.col("gender_norm")=="female").cast("int"))
    .withColumn("gender_unknown", (F.col("gender_norm")=="unknown").cast("int"))
)

# ========= 5) BD clean rule: 14–68 & count >= 1000 =========
bd_hist = dfm.groupBy("bd").agg(F.count("*").alias("bd_count"))
dfm = (
    dfm.join(bd_hist, on="bd", how="left")
       .withColumn("bd_count", F.coalesce(F.col("bd_count"), F.lit(0)))
       .withColumn(
           "bd_clean",
           F.when((F.col("bd").between(14, 68)) & (F.col("bd_count") >= 1000), F.col("bd"))
            .otherwise(F.lit(None).cast("int"))
       )
       .drop("bd_count")
)


# ========= 6) Tenure to cutoff =========
dfm = dfm.withColumn(
    "tenure_days_at_snapshot",
    F.datediff(F.to_date(F.lit(SNAPSHOT_DATE_STR)), F.col("registration_date"))
)

# ========= 7) Frequency enrich (Silver+) =========
# 7a) registered_via frequency
total_cnt = dfm.count()  # nếu bảng rất lớn, có thể approx bằng sample ratio
via_freq = (
    dfm.groupBy("registered_via")
       .agg((F.count("*") / F.lit(total_cnt)).alias("registered_via_freq"))
)

# 7b) city frequency
city_freq = (
    dfm.groupBy("city_clean")
       .agg((F.count("*") / F.lit(total_cnt)).alias("city_freq"))
)

# 7c) Join freq
dfm = (
    dfm.drop("registered_via_freq", "city_freq")
       .join(via_freq, on="registered_via", how="left")
       .join(city_freq, on="city_clean",  how="left")
       .fillna({"registered_via_freq": 0.0, "city_freq": 0.0})
)

# ========= 8) SILVER (clean + enrich) =========
silver_cols = [
    "msno",
    "city_clean",
    "bd_clean",
    "gender_norm", "gender_male","gender_female","gender_unknown",
    "registered_via",
    "registration_date",
    "tenure_days_at_snapshot",
    "registered_via_freq",
    "city_freq"
]
silver_members = dfm.select(*silver_cols)

In [None]:
# ----- Layer 2 -------

In [5]:
from pyspark.sql import functions as F

cols_to_drop = [c for c in [
    "bd", "bd_clean",
    "gender", "gender_norm",
    "gender_male", "gender_female",
    "gender_unknown", "gender_unknown_flag"
] if c in silver_members.columns]

df2 = silver_members.drop(*cols_to_drop)

In [6]:
df2 = df2.withColumn("city_clean", F.col("city_clean").cast("string"))

In [7]:
if "registered_via" in df2.columns:
    df2 = df2.withColumn("registered_via", F.col("registered_via").cast("string"))

In [8]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline

pipe_silver = Pipeline(stages=[
    StringIndexer(inputCol="city_clean",      outputCol="city_idx", handleInvalid="keep"),
    StringIndexer(inputCol="registered_via",  outputCol="via_idx",  handleInvalid="keep"),
    OneHotEncoder(inputCols=["city_idx","via_idx"], outputCols=["city_oh","via_oh"], dropLast=True)
])

sil_model = pipe_silver.fit(df2)
df_silver = (sil_model.transform(df2)
             .select(*df2.columns, "city_idx","via_idx","city_oh","via_oh"))

                                                                                

In [9]:
df_silver.printSchema()
df_silver.selectExpr(
    "count(*) as n_rows",
    "sum(case when city_clean is null then 1 else 0 end) as n_city_null",
    "sum(case when registered_via is null then 1 else 0 end) as n_via_null"
).show()

df_silver.select("city_clean","city_idx","city_oh",
                 "registered_via","via_idx","via_oh").limit(5).show(truncate=False)

df_silver.limit(10).show(truncate=False)

root
 |-- msno: string (nullable = true)
 |-- city_clean: string (nullable = true)
 |-- registered_via: string (nullable = true)
 |-- registration_date: date (nullable = true)
 |-- tenure_days_at_snapshot: integer (nullable = true)
 |-- registered_via_freq: double (nullable = false)
 |-- city_freq: double (nullable = false)
 |-- city_idx: double (nullable = false)
 |-- via_idx: double (nullable = false)
 |-- city_oh: vector (nullable = true)
 |-- via_oh: vector (nullable = true)

+-------+-----------+----------+
| n_rows|n_city_null|n_via_null|
+-------+-----------+----------+
|6769473|          0|         0|
+-------+-----------+----------+

+----------+--------+--------------+--------------+-------+--------------+
|city_clean|city_idx|city_oh       |registered_via|via_idx|via_oh        |
+----------+--------+--------------+--------------+-------+--------------+
|14        |7.0     |(21,[7],[1.0])|9             |2.0    |(18,[2],[1.0])|
|13        |2.0     |(21,[2],[1.0])|9            

                                                                                

+--------------------------------------------+----------+--------------+-----------------+-----------------------+--------------------+------------------+--------+-------+--------------+---------------+
|msno                                        |city_clean|registered_via|registration_date|tenure_days_at_snapshot|registered_via_freq |city_freq         |city_idx|via_idx|city_oh       |via_oh         |
+--------------------------------------------+----------+--------------+-----------------+-----------------------+--------------------+------------------+--------+-------+--------------+---------------+
|grugjyjs++wxah75n0kcpxr4qf9hb9mv0hbm1wkzfji=|1         |1             |2016-06-08       |265                    |6.352045425101776E-6|0.7097045811394772|0.0     |14.0   |(21,[0],[1.0])|(18,[14],[1.0])|
|yxl0onnyqtkz3aio+e+3hqjia3kvyglhes9+pfpsspg=|1         |1             |2016-07-26       |217                    |6.352045425101776E-6|0.7097045811394772|0.0     |14.0   |(21,[0],[1.0])|(1

In [10]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

output_path = "datamart/silver/members"

(
    df_silver
      .coalesce(10)             # tùy: giảm số file output, nếu local có thể coalesce(10)
      .write
      .mode("overwrite")         # ghi đè nếu có
      .option("compression", "snappy")
      .parquet(output_path)
)

print(f"✅ Silver layer saved to {output_path}")

                                                                                

✅ Silver layer saved to datamart/silver/members


# Train

In [61]:
df_train = (spark.read
      .option("header", True)
      .option("inferSchema", True)
      .csv("data/train.csv"))

In [62]:
df_train.show()

+--------------------+--------+
|                msno|is_churn|
+--------------------+--------+
|waLDQMmcOu2jLDaV1...|       1|
|QA7uiXy8vIbUSPOkC...|       1|
|fGwBva6hikQmTJzrb...|       1|
|mT5V8rEpa+8wuqi6x...|       1|
|XaPhtGLk/5UvvOYHc...|       1|
|GBy8qSz16X5iYWD+3...|       1|
|lYLh7TdkWpIoQs3i3...|       1|
|T0FF6lumjKcqEO0O+...|       1|
|Nb1ZGEmagQeba5E+n...|       1|
|MkuWz0Nq6/Oq5fKqR...|       1|
|I8dFN2EjFN1mt4Xel...|       1|
|0Ip2rzeoa44alqEw3...|       1|
|piVhWxrWDmiNQFY6x...|       1|
|wEUOkYvyz3xTOx2p9...|       1|
|xt4EjWRyXBMgEgKBJ...|       1|
|QS3ob4zLlWcWzBIlb...|       1|
|9iW/UpqRoviya9CQh...|       1|
|d7QVMhAzjj4yc1Ojj...|       1|
|uV7rJjHPrpNssDMmY...|       1|
|TZxhkfZ9NwxqnUrNs...|       1|
+--------------------+--------+
only showing top 20 rows



In [63]:
df_train.count()

992931

In [64]:
df_train_v2 = (spark.read
      .option("header", True)
      .option("inferSchema", True)
      .csv("data/train_v2.csv"))

                                                                                

In [66]:
df_train_v2.count()

970960

In [9]:
df_members.head()

NameError: name 'df_members' is not defined

In [None]:
df_train_v2.show()

In [1]:
df_train_v2.show(5)

NameError: name 'df_train_v2' is not defined