In [None]:
import os
import time
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import broadcast

# 1. Environment Setup
# Extra Maven packages (AWS SDK S3 client + hadoop-aws) handed to spark-submit
# through PYSPARK_SUBMIT_ARGS; this is assigned before the SparkContext below
# is created.
s3a_packages = ",".join([
    "com.amazonaws:aws-java-sdk-s3:1.12.196",
    "org.apache.hadoop:hadoop-aws:3.3.1",
])
os.environ["PYSPARK_SUBMIT_ARGS"] = f"--packages {s3a_packages} pyspark-shell"

# 2. Spark Session Initialization
# We start with a generic name; we will update the UI later for history tracking
conf = SparkConf().setAppName('ISIC_Comparison_Study')
# getOrCreate() reuses a context that is already running, so re-executing this
# cell does not fail with "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(conf=conf)
# `builder` is a class-level attribute: go through it directly instead of
# instantiating a SparkSession only to reach it. The builder attaches to the
# SparkContext obtained on the previous line.
spark = SparkSession.builder.getOrCreate()

# 3. AWS Configuration
hadoopConf = sc._jsc.hadoopConfiguration()
hadoopConf.set('fs.s3a.impl', 'org.apache.hadoop.fs.s3a.S3AFileSystem')
# Supply credentials through fs.s3a.access.key / fs.s3a.secret.key, through the
# AWS environment variables, or through an IAM role when running on AWS
# infrastructure; the provider list below is tried in that order.
#
# The "spark.hadoop." prefix is only understood on SparkConf keys. On the
# Hadoop Configuration object the option must be set under its bare
# "fs.s3a..." name, otherwise S3A never reads it.
hadoopConf.set(
    'fs.s3a.aws.credentials.provider',
    ','.join([
        'org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider',
        'com.amazonaws.auth.EnvironmentVariableCredentialsProvider',
        'org.apache.hadoop.fs.s3a.auth.IAMInstanceCredentialsProvider',
    ]),
)

# Path Definitions
# All locations live in one bucket: raw CSV inputs, the intermediate Parquet
# copy of the scaled metadata, and the base prefix for the result CSVs.
_S3_BUCKET = "s3a://skin-milan"
_RAW_PREFIX = f"{_S3_BUCKET}/raw-data"

s3_path_metadata = f"{_RAW_PREFIX}/ISIC_2019_Training_Metadata.csv"
s3_path_labels = f"{_RAW_PREFIX}/ISIC_2019_Training_GroundTruth.csv"
s3_output_base = f"{_S3_BUCKET}/final-results"
s3_parquet_path = f"{_S3_BUCKET}/processed-data/metadata_cache.parquet"

# 4. Load Raw Data
# Both CSVs are read with the first row treated as the header and with column
# types inferred by Spark.
df_metadata = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(s3_path_metadata)
)
df_labels = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(s3_path_labels)
)

# 5. Data Synthesis (Scaling)
# Every metadata row is replicated 100 times by exploding a constant
# 100-element array; the helper column is dropped again once each replicated
# row has received its own generated id.
replica_index = F.array([F.lit(n) for n in range(100)])
df_large_metadata = (
    df_metadata
    .withColumn("dummy", F.explode(replica_index))
    .withColumn("patient_record_id", F.monotonically_increasing_id())
    .drop("dummy")
)

# --- PHASE 1: NON-OPTIMIZED VERSION ---
# Baseline: plain inner join on "image" followed by the aggregation, with no
# hints. The timer spans from just before the plan is defined until the CSV
# write (the action that actually runs the job) has returned.
sc.setJobGroup("Baseline", "Non-Optimized Shuffle Run") # Helps in Spark UI/History
print("\n--- STARTING NON-OPTIMIZED SHUFFLES ---")
# perf_counter() is monotonic and high-resolution, so the measured interval is
# not affected by system clock adjustments the way time.time() can be.
start_raw = time.perf_counter()

df_joined_raw = df_large_metadata.join(df_labels, on="image", how="inner")

# Per (anatom_site_general, sex) group: mean age, sum of the MEL column and
# number of rows; risk_factor is their ratio, sorted highest first.
df_final_risk_raw = df_joined_raw.groupBy("anatom_site_general", "sex") \
    .agg(
        F.avg("age_approx").alias("avg_age"),
        F.sum("MEL").alias("total_melanoma"),
        F.count("image").alias("total_cases")
    ) \
    .withColumn("risk_factor", F.col("total_melanoma") / F.col("total_cases")) \
    .orderBy(F.desc("risk_factor"))

# Save Final Non-Optimized Result to S3
df_final_risk_raw.write.mode("overwrite").csv(f"{s3_output_base}/non_optimized_results")

raw_duration = round(time.perf_counter() - start_raw, 2)
print(f"Non-Optimized Pipeline Time: {raw_duration} seconds")


# --- PHASE 2: OPTIMIZED VERSION ---
# Same join + aggregation as the baseline, but the scaled metadata is read back
# from Parquet, the labels table is broadcast, and the joined data is
# repartitioned before the aggregation.
sc.setJobGroup("Optimized", "Broadcast and Partition Run")
print("\n--- STARTING OPTIMIZED SHUFFLES ---")

# Step A: Optimization - Use Parquet for intermediate storage
# The Parquet write and re-read happen before the timer starts, so they are
# not included in the reported duration.
df_large_metadata.write.mode("overwrite").parquet(s3_parquet_path)
df_opt_metadata = spark.read.parquet(s3_parquet_path)

# perf_counter() is monotonic and high-resolution, so the measured interval is
# not affected by system clock adjustments the way time.time() can be.
start_opt = time.perf_counter()

# Step B: Optimization - Broadcast Join
df_joined_opt = df_opt_metadata.join(broadcast(df_labels), on="image", how="inner")

# Step C: Optimization - Explicit Repartitioning
# NOTE(review): this repartition is applied to the joined rows *before* they
# are aggregated, and only on one of the two grouping columns. In the run
# recorded with this notebook the phase was slower than the baseline
# (6.57s vs 5.62s) - worth confirming whether this step actually helps.
df_final_risk_opt = df_joined_opt.repartition("anatom_site_general") \
    .groupBy("anatom_site_general", "sex") \
    .agg(
        F.avg("age_approx").alias("avg_age"),
        F.sum("MEL").alias("total_melanoma"),
        F.count("image").alias("total_cases")
    ) \
    .withColumn("risk_factor", F.col("total_melanoma") / F.col("total_cases")) \
    .orderBy(F.desc("risk_factor"))

# Save Final Optimized Result to S3
df_final_risk_opt.write.mode("overwrite").csv(f"{s3_output_base}/optimized_results")

opt_duration = round(time.perf_counter() - start_opt, 2)
print(f"Optimized Pipeline Time: {opt_duration} seconds")

# --- HISTORY & PLANS ---
# Print the physical plan of each pipeline under its own banner, then a
# one-line comparison of the two measured durations.
for plan_label, result_df in (
    ("NON-OPTIMIZED", df_final_risk_raw),
    ("OPTIMIZED", df_final_risk_opt),
):
    print(f"\n--- PHYSICAL EXECUTION PLAN ({plan_label}) ---")
    result_df.explain()

summary_line = f"\nSummary: Baseline ({raw_duration}s) vs Optimized ({opt_duration}s)"
print(summary_line)


--- STARTING NON-OPTIMIZED SHUFFLES ---
Non-Optimized Pipeline Time: 5.62 seconds

--- STARTING OPTIMIZED SHUFFLES ---
Optimized Pipeline Time: 6.57 seconds

--- PHYSICAL EXECUTION PLAN (NON-OPTIMIZED) ---
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [risk_factor#129 DESC NULLS LAST], true, 0
   +- Exchange rangepartitioning(risk_factor#129 DESC NULLS LAST, 200), ENSURE_REQUIREMENTS, [plan_id=708]
      +- Project [anatom_site_general#19, sex#21, avg_age#119, total_melanoma#121, total_cases#123L, (total_melanoma#121 / cast(total_cases#123L as double)) AS risk_factor#129]
         +- HashAggregate(keys=[anatom_site_general#19, sex#21], functions=[avg(age_approx#18), sum(MEL#45), count(image#17)])
            +- Exchange hashpartitioning(anatom_site_general#19, sex#21, 200), ENSURE_REQUIREMENTS, [plan_id=704]
               +- HashAggregate(keys=[anatom_site_general#19, sex#21], functions=[partial_avg(age_approx#18), partial_sum(MEL#45), partial_count(image#17)])
    