<a href="https://colab.research.google.com/github/mariamcs/Customer_Churn/blob/main/Customer_Churn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Data Creation Step**

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand, when, round as spark_round

# Create Spark session
spark = SparkSession.builder.appName("SimulatedNetflixChurn").getOrCreate()

# Number of rows
n = 100000

# Generate DataFrame
df = (
    spark.range(0, n)
    .withColumn("daily_watch_minutes", spark_round(rand() * 300, 1))
    .withColumn("avg_session_length", spark_round(rand() * 90 + 10, 1))
    .withColumn("last_login_days_ago", (rand() * 60).cast("int"))
    .withColumn("binge_sessions_last_30d", (rand() * 10).cast("int"))
    .withColumn("completion_rate", spark_round(rand(), 2))

    .withColumn("plan_type", when(rand() < 0.6, "Standard").when(rand() < 0.85, "Premium").otherwise("Basic"))
    .withColumn("tenure_months", (rand() * 48).cast("int"))
    .withColumn("price_per_hour_watched", spark_round(rand() * 0.5 + 0.2, 2))
    .withColumn("billing_failures_last_90d", (rand() * 3).cast("int"))
    .withColumn("upgrades_last_6mo", (rand() * 2).cast("int"))

    .withColumn("has_kids_profile", when(rand() < 0.3, 1).otherwise(0))
    .withColumn("uses_download_feature", when(rand() < 0.5, 1).otherwise(0))
    .withColumn("simultaneous_streams_used", (rand() * 4 + 1).cast("int"))
    .withColumn("primary_device_type", when(rand() < 0.4, "Smart TV").when(rand() < 0.7, "Mobile").otherwise("Laptop"))
    .withColumn("geo_consistency_score", spark_round(rand(), 2))

    .withColumn("support_tickets_last_6mo", (rand() * 5).cast("int"))
    .withColumn("cancel_reason_code", when(rand() < 0.1, "Pricing").when(rand() < 0.2, "Content").when(rand() < 0.3, "Tech Issues").otherwise("None"))
    .withColumn("issue_resolution_time_avg", spark_round(rand() * 48, 1))

    .withColumn("churned", when(rand() < 0.2, 1).otherwise(0))
)

df.show(5)


+---+-------------------+------------------+-------------------+-----------------------+---------------+---------+-------------+----------------------+-------------------------+-----------------+----------------+---------------------+-------------------------+-------------------+---------------------+------------------------+------------------+-------------------------+-------+
| id|daily_watch_minutes|avg_session_length|last_login_days_ago|binge_sessions_last_30d|completion_rate|plan_type|tenure_months|price_per_hour_watched|billing_failures_last_90d|upgrades_last_6mo|has_kids_profile|uses_download_feature|simultaneous_streams_used|primary_device_type|geo_consistency_score|support_tickets_last_6mo|cancel_reason_code|issue_resolution_time_avg|churned|
+---+-------------------+------------------+-------------------+-----------------------+---------------+---------+-------------+----------------------+-------------------------+-----------------+----------------+---------------------+----

# **Data Cleaning Step**

In [5]:
# Drop rows with nulls (simple method, can be replaced with imputation if needed)
df_clean = df.na.drop()

# Optionally: You could also handle outliers or convert types here
# Example: Log-transform skewed columns, or clip values


# **Data engineering Step**


In [6]:
df_eng = df_clean \
    .withColumn("engagement_ratio", col("daily_watch_minutes") / (col("avg_session_length") + 1)) \
    .withColumn("recent_login_flag", when(col("last_login_days_ago") < 7, 1).otherwise(0)) \
    .withColumn("binge_watcher", when(col("binge_sessions_last_30d") > 5, 1).otherwise(0)) \
    .withColumn("low_completion_flag", when(col("completion_rate") < 0.5, 1).otherwise(0)) \
    .withColumn("tenure_group", when(col("tenure_months") < 6, "New")
                .when(col("tenure_months") < 12, "6-12mo")
                .when(col("tenure_months") < 24, "12-24mo")
                .otherwise("24mo+")) \
    .withColumn("support_contact_flag", when(col("support_tickets_last_6mo") > 0, 1).otherwise(0)) \
    .withColumn("high_resolution_delay", when(col("issue_resolution_time_avg") > 24, 1).otherwise(0))


Encoding Categorical Features

In [7]:
categorical = ["plan_type", "primary_device_type", "cancel_reason_code", "tenure_group"]
indexers = [StringIndexer(inputCol=c, outputCol=c+"_idx", handleInvalid="keep") for c in categorical]
encoders = [OneHotEncoder(inputCol=c+"_idx", outputCol=c+"_vec") for c in categorical]


# **Preparing the Data for Model**

In [8]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

engineered = ["engagement_ratio", "recent_login_flag", "binge_watcher", "low_completion_flag",
              "support_contact_flag", "high_resolution_delay"]

numerics = [c for c in df.columns if c not in categorical + ["churned"]] + engineered
features = numerics + [c + "_vec" for c in categorical]

assembler = VectorAssembler(inputCols=features, outputCol="features")
pipeline = Pipeline(stages=indexers + encoders + [assembler])

model = pipeline.fit(df_eng)
df_ready = model.transform(df_eng)
df_ready.select("features", "churned").show(5, truncate=False)


+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+
|features                                                                                                                                                                        |churned|
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+
|[0.0,9.8,81.6,56.0,6.0,0.14,38.0,0.56,2.0,1.0,1.0,0.0,4.0,0.25,1.0,34.2,0.11864406779661019,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0]        |0      |
|(36,[0,1,2,3,4,5,7,8,9,12,13,14,15,16,19,20,21,22,27,30,34],[1.0,141.9,67.1,56.0,2.0,0.24,0.31,1.0,1.0,4.0,0.15,2.0,39.3,2.083700440528635,1.0,1.0,1.0,1.0,1.0,1.0,1.0])        |0      |
|(36,[0,1,2,3,4,5,6,7,8,12,13,14,15,16,18,19,20,21,23,26,28,33],[