- Understand Spark configuration tuning for better performance during data processing.
- Understand the amount of resources needed for a job.
- Use increasing job loads for test purposes (500k -> 1M -> 2M records)

# Streaming:
`maxFilesPerTrigger` (e.g. 1) -> sets the maximum number of new files Spark Structured Streaming picks up each time a trigger fires (i.e. per micro-batch).

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, FloatType
from pyspark.sql import SparkSession

# Settings for the "streaming" Spark session: S3A access to the MinIO endpoint
# (path-style addressing, static access/secret keys) plus the local jars the
# job ships with. Applied to the builder in the order listed.
_SPARK_SETTINGS = {
    "spark.hadoop.fs.s3a.endpoint": "http://minio-service:9000",
    "spark.hadoop.fs.s3a.access.key": "minioadmin",
    "spark.hadoop.fs.s3a.secret.key": "minioadmin",
    "spark.hadoop.fs.s3a.path.style.access": "true",
    "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
    # "spark.jars.packages": "org.mongodb.spark:mongo-spark-connector_2.12:10.1.1",
    "spark.jars": ",".join(
        [
            "/opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar",
            "/opt/spark/jars/delta-core_2.12-2.4.0.jar",
            "/opt/spark/jars/delta-storage-2.4.0.jar",
            "/opt/spark/jars/hadoop-aws-3.3.4.jar",
        ]
    ),
}

_session_builder = SparkSession.builder.appName("streaming")
for _setting_key, _setting_value in _SPARK_SETTINGS.items():
    _session_builder = _session_builder.config(_setting_key, _setting_value)

# Reuses an already-running session if there is one, otherwise creates it.
spark = _session_builder.getOrCreate()


# Column name -> logical type label ("String", "Number" or "Date") for the
# input CSV. Insertion order is carried over into the Spark schema built from
# this mapping, so entries presumably mirror the CSV column order - keep new
# columns in file order.
schemaTypes = {
    "MSISDN": "String",
    "NID": "String",
    "NID_LEN": "Number",
    "CUSTOMERIDNAME": "String",
    "EMAIL": "String",
    "GROSS_DATE": "Date",
    "CUSTOMER_TYPE": "String",
    "USERTYPE": "String",
    "SUBSCRIBERCAT": "String",
    "SUBSCRIBERSUBCAT": "String",
    "SEGMENT": "String",
    "POSTPAID_TARIFF": "String",
    "POSTPAID_MAINPRODUCT": "String",
    "POSTPAID_SUBPRODUCT": "String",
    "NATIONALITY": "String",
    "AON_DAYS": "Number",
    "AON_MONTH": "Number",
    "AON_YEAR": "Number",
    "ALTERNATE_MOB_NUM": "String",
    "FAVOURITE_LOCATION": "String",
    "DAYS_INACT": "Number",
    "BAL_1": "Number",
    "BAL_2": "Number",
    "BAL_3": "Number",
    "DATA_USERS": "String",
    "PACK_USERS": "String",
    "PACK_90D": "Number",
    "PACK_30D": "Number",
    "DAILY_PACK_90D": "Number",
    "DAILY_PACK_30D": "Number",
    "WEEKLY_PACK_90D": "Number",
    "WEEKLY_PACK_30D": "Number",
    "MONTHLY_PACK_90D": "Number",
    "MONTHLY_PACK_30D": "Number",
    "RECHARGE_VAL_30D": "Number",
    "RECHARGE_VAL_D1": "Number",
    "RECHARGE_VAL_D2": "Number",
    "RECHARGE_VAL_D3": "Number",
    "AIRTIME_USG": "Number",
    "YOUTH_PACK": "String",
    "TOTAL_ARPU": "Number",
    "ONNET_REVENUE": "Number",
    "XNET_REVENUE": "Number",
    "ONNET_TOTAL_MINS_90D": "Number",
    "ONNET_TOTAL_MINS_30D": "Number",
    "XNET_TOTAL_MINS_90D": "Number",
    "XNET_TOTAL_MINS_30D": "Number",
    "VOICE_INCOMING_ARPU": "Number",
    "SMS_ARPU": "Number",
    "DATA_USG_GB_90D": "Number",
    "DATA_USG_GB_30D": "Number",
    "DATA_PAYG_USG_GB_90D": "Number",
    "DATA_PAYG_USG_GB_30D": "Number",
    "TOTAL_DATA_ARPU_90D": "Number",
    "TOTAL_DATA_ARPU_30D": "Number",
    "TOTAL_DATA_PAYG_ARPU_90D": "Number",
    "TOTAL_DATA_PAYG_ARPU_30D": "Number",
    "ILD_ARPU": "Number",
    "ILD_SMS_ARPU": "Number",
    "DEVICE_SUBTYPE_1": "String",
    "DEVICE_SUBTYPE_2": "String",
    "DEVICE_SUBTYPE_3": "String",
    "OS_SYSTEM_1": "String",
    "OS_SYSTEM_2": "String",
    "OS_SYSTEM_3": "String",
    "TECHNOLOGY_1": "String",
    "TECHNOLOGY_2": "String",
    "TECHNOLOGY_3": "String",
    "SIM_TYPE_1": "String",
    "SIM_TYPE_2": "String",
    "SIM_TYPE_3": "String",
    "DUAL_SIM_1": "String",
    "DUAL_SIM_2": "String",
    "DUAL_SIM_3": "String",
    "ESIM_1": "String",
    "ESIM_2": "String",
    "ESIM_3": "String",
}

def spark_type(dtype):
    """Translate a logical type label from ``schemaTypes`` into a Spark type.

    Only ``"Number"`` becomes a numeric column (``FloatType``). ``"String"``,
    ``"Date"`` and any unrecognised label all map to ``StringType``, so date
    columns are read as raw text rather than parsed.

    Args:
        dtype: Logical type label, e.g. ``"String"``, ``"Number"`` or ``"Date"``.

    Returns:
        A new ``FloatType`` or ``StringType`` instance.
    """
    return FloatType() if dtype == "Number" else StringType()

# Explicit schema for the CSV stream: one nullable column per schemaTypes
# entry, added in the mapping's insertion order.
schema = StructType()
for _column_name, _logical_type in schemaTypes.items():
    schema.add(StructField(_column_name, spark_type(_logical_type), True))

# Use 50 partitions (rather than Spark's default of 200) for the output of
# wide transformations such as the groupBy in inspect_partition_count.
spark.conf.set("spark.sql.shuffle.partitions", 50)

# sf = (
#     spark.read.parquet("s3a://batch-etl-pipeline/final/part-00000-31f26158-5a47-4602-8a52-bc7563c2ff3f-c000.snappy.parquet")
# )
# print("\n\n\n\n\n\n", sf.rdd.getNumPartitions(), "\n\n\n\n\n\n")  # This will return the number of partitions in the DataFrame.
# sf.show(2)  # This will show the first 2 rows of the DataFrame.
# print("\n\n\n\n\n\n", sf.rdd.getNumPartitions(), "\n\n\n\n\n\n")  # This will return the number of partitions in the DataFrame.



# sf.write.parquet("s3a://batch-etl-pipeline/3i")
# exit()

# Streaming CSV source over the MinIO prefix. The schema is supplied up front,
# the header row is enabled, and each micro-batch is limited to one new file.
_csv_stream_reader = (
    spark.readStream
    .schema(schema)
    .format("csv")
    .option("maxFilesPerTrigger", 1)  # Process one file at a time
    .option("header", "true")
)
df = _csv_stream_reader.load("s3a://batch-etl-pipeline/zain/")

# df.rdd.getNumPartitions()
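# df.show()  # ❌ Not allowed on a streaming DataFrame
# In Structured Streaming a query must be started with .writeStream ... .start(); calling .show() or .collect() directly on a streaming DataFrame raises an AnalysisException.
# ✅ Inside foreachBatch each micro-batch is a regular (batch) DataFrame, so batch_df.rdd and actions such as .show() can be used there to inspect it.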

"""

Just like spark.read(...), the spark.readStream(...) method reads data in parallel by splitting input files into partitions.
df.rdd.getNumPartitions() -> Returns the number of partitions in the DataFrame.


* ✅ **`.repartition`**: This **explicitly forces a cluster-wide shuffle** to redistribute data **before the next transformation** (or action).
* ✅ **`.coalesce`**: This **reduces the number of partitions**, **avoiding a full shuffle** by merging existing partitions with minimal data movement.
* ✅ **`spark.sql.shuffle.partitions`**: This controls the **default number of partitions** used **after wide transformations** (like `groupBy`, `join`, `distinct`, etc.) **when you don't specify partitioning explicitly**.
spark.conf.set("spark.sql.shuffle.partitions", 50)

---

Min partitions from sparkContext.defaultParallelism
> On local mode: usually = number of cores.
> On cluster mode: total number of cores across the cluster.

If you're reading a single 193MB CSV file and you're seeing 8 partitions, it's most likely because:
spark.sparkContext.defaultParallelism == 8

---

.repartition(n) → Performs a full shuffle itself and leaves the resulting DataFrame with exactly n partitions (independent of spark.sql.shuffle.partitions)
spark.sql.shuffle.partitions → Controls number of shuffle output partitions during wide transformations

---

spark.sql.files.maxPartitionBytes	128MB	
The maximum number of bytes to pack into a single partition when reading files. This configuration is effective only when using file-based sources such as Parquet, JSON and ORC.


spark.default.parallelism:  Default number of partitions in RDDs returned by transformations like join, reduceByKey, and parallelize when not set by user.	
Local mode: number of cores on the local machine
Others: total number of cores on all executor nodes or 2, whichever is larger



🔹 When Should You Repartition?
Too few partitions? → Not enough parallelism → poor CPU utilization.

Too many small partitions? → Overhead from task scheduling and I/O.

Good rule of thumb: aim for 100–200 MB per partition for optimal performance.
"""

# Print the default parallelism level, which is usually equal to the number of
# cores available to the application. The trailing newlines only make the value
# easier to spot in the driver log.
print(spark.sparkContext.defaultParallelism, "\n\n\n\n\n\n\n\n\n")
# spark.sparkContext.defaultParallelism  # This will also return the default parallelism level.
from pyspark.sql.functions import col, concat, lit

def inspect_partition_count(batch_df, batch_id):
    """Print a micro-batch's partition count before and after a shuffle.

    Written as a ``foreachBatch`` callback: ``batch_df`` is a regular (batch)
    DataFrame here, so ``.rdd`` and actions are allowed. The function prints
    the incoming partition count, coalesces to 3 partitions, forces one job,
    then groups by ``NID`` and prints the partition count of the grouped
    result together with its first 2 rows.

    Args:
        batch_df: DataFrame containing the rows of the current micro-batch.
        batch_id: Identifier of the micro-batch; only used in the printed output.
    """
    print(f"\n\n\n\n\n\n\n\n✅ Batch {batch_id} - Num Partitions: {batch_df.rdd.getNumPartitions()}")
    # batch_df = batch_df.repartition(2)
    coalesced_df = batch_df.coalesce(3)

    # Force the coalesced plan to run so its tasks are visible in the Spark UI.
    # count() is used rather than collect(): collect() would ship every row of
    # the micro-batch to the driver just to throw it away, risking driver OOM.
    coalesced_df.count()

    nid_counts = coalesced_df.groupBy("NID").count()

    print(f"\n\n\n\n\n\n\n\n✅ ---- Batch {batch_id} - Num Partitions: {nid_counts.rdd.getNumPartitions()} \n\n\n\n")
    nid_counts.show(2)
    # # batch_df.select("MSISDN").show()  # Example operation to inspect data
    # phon = batch_df.withColumn("Phone", concat(lit("230"), col("MSISDN")))  # Example operation to write data
    # phon.select("Phone").show(10)



# A streaming DataFrame only runs once a writeStream query is started.
# foreachBatch is the sink, so no .format(...) is set: a format given before
# .foreachBatch(...) is replaced by it and would only suggest, wrongly, that
# rows are also written to the console.
query = (
    df.writeStream
    .foreachBatch(inspect_partition_count)
    .outputMode("append")
    # Progress is tracked in the checkpoint; files already recorded there are
    # not reprocessed when the query restarts.
    .option("checkpointLocation", "s3a://batch-etl-pipeline/checkpoints1/upsert-csv/")
    .trigger(processingTime="10 seconds")
    .start()
)

# query = (
#     df.writeStream
#     .format("mongodb")
#     .outputMode("append")
#     .option("checkpointLocation", "s3a://batch-etl-pipeline/checkpoints/upsert-csv/")
#     .option("spark.mongodb.connection.uri", "mongodb://host.docker.internal:27017/")
#     .option("spark.mongodb.database", "backend")        # ✅ specify your DB
#     .option("spark.mongodb.collection", "new")    # ✅ specify your collection
#     .trigger(processingTime="10 seconds")
#     .start()
# )

# Block the driver until the streaming query is stopped or fails; without this
# the script would reach its end while the query is still running.
query.awaitTermination()










# spark-submit --master spark://spark-master:7077 --deploy-mode client --jars /opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar,/opt/spark/jars/delta-core_2.12-2.4.0.jar,/opt/spark/jars/delta-storage-2.4.0.jar,/opt/spark/jars/hadoop-aws-3.3.4.jar test.py

# MINIO_INPUT_PATH=s3a://batch-etl-pipeline/split-csvs/
# CHECKPOINT_LOCATION=s3a://batch-etl-pipeline/checkpoints/upsert-csv/

# MONGO_USER=EMTEL_CAMPAIGN_PROD
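# MONGO_PASSWORD=<redacted>  # never commit real credentials; supply via environment variable or a secrets manager (and rotate the value that was committed here)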
# MONGO_HOSTS=172.26.64.42:27017,172.26.64.44:27017,172.26.64.43:27017,172.26.64.162:27017,172.26.64.163:27017
# MONGO_PORT=27017
# MONGO_DB=EMTEL_CAMPAIGN_PROD
# MONGO_COLLECTION=userdatas-test
# REPLICA_SET_NAME=EMTEL

# MINIO_S3A_ACCESS_KEY=minioadmin
# MINIO_S3A_SECRET_KEY=minioadmin

# SPARK_APP_NAME=CSV-Upsert-Job-UAT
# SPARK_JARS=/opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar,/opt/spark/jars/delta-core_2.12-2.4.0.jar,/opt/spark/jars/delta-storage-2.4.0.jar,/opt/spark/jars/hadoop-aws-3.3.4.jar

# STREAMING_PROCESSING_TIME=20
# REPARTITION_COUNT=200

# df = spark.readStream.format("csv").option("header", "true").schema("").load("path/to/input/directory")




"""
spark-submit \
  --master spark://spark-master:7077 \
  --deploy-mode client \
  --packages org.mongodb.spark:mongo-spark-connector_2.12:10.2.0 \
  --jars /opt/spark/jars/aws-java-sdk-bundle-1.12.262.jar,/opt/spark/jars/delta-core_2.12-2.4.0.jar,/opt/spark/jars/delta-storage-2.4.0.jar,/opt/spark/jars/hadoop-aws-3.3.4.jar \
  test.py


"""