In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, trim, initcap, when

# 1. Initialize Spark
spark = SparkSession.builder \
    .appName("KenyaAgriAnalysis") \
    .getOrCreate()

# 2. Load Dataset
# Note: Use the actual path to your saved csv
df = spark.read.csv("agri_data.csv", header=True, inferSchema=True)

# 3. Clean and Transform
# - trim() removes the space in "Narok "
# - initcap() fixes "narok" to "Narok"
# - fillna() handles the missing Maize production in Trans Nzoia
df_cleaned = df.withColumn("County", initcap(trim(col("County")))) \
               .fillna({"Production_Tonnes": 0})

# 4. Spark SQL Analysis
df_cleaned.createOrReplaceTempView("kenya_crops")

# Query: Find the most productive counties for Cereal crops (Maize, Wheat, Rice)
results = spark.sql("""
    SELECT 
        County, 
        SUM(Production_Tonnes) as Total_Tonnes,
        ROUND(AVG(Production_Tonnes / Area_Hectares), 2) as Yield_Efficiency
    FROM kenya_crops
    WHERE Crop IN ('Maize', 'Wheat', 'Rice') AND Year = 2023
    GROUP BY County
    ORDER BY Total_Tonnes DESC
""")

results.show()

# Stop Session
spark.stop()

+-----------+------------+----------------+
|     County|Total_Tonnes|Yield_Efficiency|
+-----------+------------+----------------+
|Trans Nzoia|      480000|            4.57|
|Uasin Gishu|      450000|            4.59|
|      Narok|      280000|            3.29|
|     Nakuru|      185000|            4.11|
|  Kirinyaga|       75000|            6.25|
+-----------+------------+----------------+

