**Use Case 2:- Crop Yield Prediction**

**STEP 1:- Load & inspect the data**

In [0]:
# Read the crop-yield table from the metastore, then print its schema and a
# five-row preview so column names and types can be checked before processing.
df = spark.read.table("workspace.default.cropyield")
df.printSchema()
df.show(n=5)

root
 |-- Crop: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Cost of Cultivation (`/Hectare) A2+FL: double (nullable = true)
 |-- Cost of Cultivation (`/Hectare) C2: double (nullable = true)
 |-- Cost of Production (`/Quintal) C2: double (nullable = true)
 |-- Yield (Quintal/ Hectare) : double (nullable = true)

+-----+--------------+-------------------------------------+----------------------------------+---------------------------------+-------------------------+
| Crop|         State|Cost of Cultivation (`/Hectare) A2+FL|Cost of Cultivation (`/Hectare) C2|Cost of Production (`/Quintal) C2|Yield (Quintal/ Hectare) |
+-----+--------------+-------------------------------------+----------------------------------+---------------------------------+-------------------------+
|ARHAR| Uttar Pradesh|                              9794.05|                          23076.74|                          1941.55|                     9.83|
|ARHAR|     Karnataka|                  

**STEP 2:- Rename columns for consistency**

In [0]:
# Map each raw column header to a short upper-case identifier that is safe
# to use in SQL. The raw names contain spaces, backticks and parentheses, and
# the yield header carries a trailing space.
column_aliases = {
    "Crop": "CROP",
    "State": "STATE",
    "Cost of Cultivation (`/Hectare) A2+FL": "COST_A2FL",
    "Cost of Cultivation (`/Hectare) C2": "COST_C2",
    "Cost of Production (`/Quintal) C2": "COST_PROD",
    "Yield (Quintal/ Hectare) ": "YIELD",
}

# Apply the renames one by one, in the same order as listed above.
for raw_name, short_name in column_aliases.items():
    df = df.withColumnRenamed(raw_name, short_name)

**STEP 3:- Clean and filter data**

In [0]:
from pyspark.sql.functions import col, when
from pyspark.sql.types import DoubleType

# Cast every numeric column to double so downstream arithmetic and the ML
# stages see one consistent type.
for numeric_col in ("COST_A2FL", "COST_C2", "COST_PROD", "YIELD"):
    df = df.withColumn(numeric_col, col(numeric_col).cast(DoubleType()))

# Keep only rows whose yield is present and strictly positive.
has_valid_yield = col("YIELD").isNotNull() & (col("YIELD") > 0)
df = df.filter(has_valid_yield)

display(df.limit(5))

CROP,STATE,COST_A2FL,COST_C2,COST_PROD,YIELD
ARHAR,Uttar Pradesh,9794.05,23076.74,1941.55,9.83
ARHAR,Karnataka,10593.15,16528.68,2172.46,7.47
ARHAR,Gujarat,13468.82,19551.9,1898.3,9.59
ARHAR,Andhra Pradesh,17051.66,24171.65,3670.54,6.42
ARHAR,Maharashtra,17130.55,25270.26,2775.8,8.72


**STEP 4:- Feature engineering**

In [0]:
# Approximate median of YIELD (50th percentile, 1% relative error).
yield_quantiles = df.approxQuantile("YIELD", [0.5], 0.01)
median_yield = yield_quantiles[0]

# HIGH_YIELD is 1 when the row's yield is strictly above the median, else 0.
above_median = col("YIELD") > median_yield
df = df.withColumn("HIGH_YIELD", when(above_median, 1).otherwise(0))

preview_columns = ["CROP", "STATE", "YIELD", "HIGH_YIELD"]
display(df.select(*preview_columns).limit(5))

CROP,STATE,YIELD,HIGH_YIELD
ARHAR,Uttar Pradesh,9.83,0
ARHAR,Karnataka,7.47,0
ARHAR,Gujarat,9.59,0
ARHAR,Andhra Pradesh,6.42,0
ARHAR,Maharashtra,8.72,0


**STEP 5:- SQL exploratory analysis**

In [0]:
# Register the engineered DataFrame so it can be queried with SQL as "crops".
df.createOrReplaceTempView("crops")

# Per-crop summary: record count, mean yield, mean C2 cost and the share of
# records flagged HIGH_YIELD, ordered from highest to lowest mean yield.
crop_summary = spark.sql("""
  SELECT CROP, COUNT(*) AS total_records,
         ROUND(AVG(YIELD),2) AS avg_yield,
         ROUND(AVG(COST_C2),2) AS avg_cost,
         ROUND(SUM(HIGH_YIELD)/COUNT(*)*100,2) AS pct_high_yield
  FROM crops
  GROUP BY CROP
  ORDER BY avg_yield DESC
""")
crop_summary.show(20)

+--------------------+-------------+---------+--------+--------------+
|                CROP|total_records|avg_yield|avg_cost|pct_high_yield|
+--------------------+-------------+---------+--------+--------------+
|           SUGARCANE|            5|    790.5|79655.03|         100.0|
|               PADDY|            5|     46.3|35768.22|         100.0|
|               WHEAT|            4|     33.9|29923.08|         100.0|
|               MAIZE|            5|     30.8| 23837.3|          80.0|
|              COTTON|            5|    18.77| 42958.2|          80.0|
|RAPESEED AND MUSTARD|            5|    14.32|21223.43|          20.0|
|                GRAM|            5|    10.56|19308.77|          20.0|
|           GROUNDNUT|            5|    10.29|28188.08|           0.0|
|               ARHAR|            5|     8.41|21719.85|           0.0|
|               MOONG|            5|      4.2| 10776.4|           0.0|
+--------------------+-------------+---------+--------+--------------+



**STEP 6:- Prepare ML features**

In [0]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml import Pipeline

# Categorical columns: index the string labels. Unseen labels at transform
# time are kept (handleInvalid="keep") instead of raising.
crop_indexer = StringIndexer(
    inputCol="CROP",
    outputCol="CROP_idx",
    handleInvalid="keep",
)
state_indexer = StringIndexer(
    inputCol="STATE",
    outputCol="STATE_idx",
    handleInvalid="keep",
)

# One-hot encode the indexed categories.
crop_encoder = OneHotEncoder(inputCols=["CROP_idx"], outputCols=["CROP_ohe"])
state_encoder = OneHotEncoder(inputCols=["STATE_idx"], outputCols=["STATE_ohe"])

# Combine the three cost columns and both one-hot vectors into a single
# feature vector, then scale it.
assembler = VectorAssembler(
    inputCols=[
        "COST_A2FL",
        "COST_C2",
        "COST_PROD",
        "CROP_ohe",
        "STATE_ohe",
    ],
    outputCol="features",
)
scaler = StandardScaler(
    inputCol="features",
    outputCol="scaledFeatures",
)


**STEP 7:- Model training (Random Forest Regression)**

In [0]:
from pyspark.ml.regression import RandomForestRegressor

# Random forest regressor predicting YIELD from the scaled feature vector.
rf = RandomForestRegressor(
    featuresCol="scaledFeatures",
    labelCol="YIELD",
    numTrees=100,
    maxDepth=10,
    seed=42,
)

# Full pipeline: index -> one-hot encode -> assemble -> scale -> regress.
pipeline_stages = [
    crop_indexer,
    state_indexer,
    crop_encoder,
    state_encoder,
    assembler,
    scaler,
    rf,
]
pipeline = Pipeline(stages=pipeline_stages)

# 80/20 train/test split with a fixed seed; fit on train, score on test.
train, test = df.randomSplit([0.8, 0.2], seed=42)
model = pipeline.fit(train)
pred = model.transform(test)

In [0]:
from pyspark.ml.evaluation import RegressionEvaluator

# All three evaluators compare the YIELD label with the "prediction" column;
# only the metric differs.
evaluator_columns = dict(labelCol="YIELD", predictionCol="prediction")
rmse_evaluator = RegressionEvaluator(metricName="rmse", **evaluator_columns)
r2_evaluator = RegressionEvaluator(metricName="r2", **evaluator_columns)
mae_evaluator = RegressionEvaluator(metricName="mae", **evaluator_columns)

# Root mean squared error on the held-out predictions.
rmse = rmse_evaluator.evaluate(pred)
print("RMSE:", round(rmse, 3))

# Coefficient of determination (R-squared).
r2 = r2_evaluator.evaluate(pred)
print("R-squared:", round(r2, 3))

# Mean absolute error.
mae = mae_evaluator.evaluate(pred)
print("MAE:", round(mae, 3))

# Show the ten test rows with the largest absolute prediction error.
pred.createOrReplaceTempView("predictions")
worst_predictions = spark.sql("""
SELECT CROP, STATE, YIELD as actual_yield, ROUND(prediction,2) as predicted_yield,
       ROUND(ABS(YIELD - prediction),2) as error
FROM predictions
ORDER BY error DESC
LIMIT 10
""")
worst_predictions.show()

RMSE: 14.148
R-squared: 0.995
MAE: 9.281
+---------+--------------+------------+---------------+-----+
|     CROP|         STATE|actual_yield|predicted_yield|error|
+---------+--------------+------------+---------------+-----+
|SUGARCANE|Andhra Pradesh|      757.92|         720.36|37.56|
|    MAIZE|         Bihar|       42.95|          65.69|22.74|
|    PADDY|Andhra Pradesh|        56.0|          77.09|21.09|
|   COTTON|   Maharashtra|       12.69|          20.24| 7.55|
|    WHEAT|        Punjab|       39.83|           32.9| 6.93|
|    MAIZE|     Rajasthan|       23.56|          30.22| 6.66|
|    PADDY|   West Bengal|       39.04|           45.2| 6.16|
|    WHEAT|     Rajasthan|       37.19|          32.71| 4.48|
|    MAIZE| Uttar Pradesh|        13.7|          17.81| 4.11|
|     GRAM|     Rajasthan|        6.83|           9.01| 2.18|
+---------+--------------+------------+---------------+-----+



**STEP 8:- Save predictions and query results**

In [0]:
# Persist the key columns of the test-set predictions as a managed table,
# replacing any previous contents.
saved_columns = ["CROP", "STATE", "COST_C2", "YIELD", "prediction"]
(
    pred.select(*saved_columns)
    .write.mode("overwrite")
    .saveAsTable("crop_yield_predictions")
)

print("✅ Predictions saved as SQL table: crop_yield_predictions")

# Read the saved table back: actual vs. predicted mean yield per crop.
per_crop_comparison = spark.sql("""
SELECT CROP, ROUND(AVG(YIELD),2) AS actual_avg_yield, 
       ROUND(AVG(prediction),2) AS predicted_avg_yield
FROM crop_yield_predictions
GROUP BY CROP
ORDER BY actual_avg_yield DESC
""")
per_crop_comparison.show()

✅ Predictions saved as SQL table: crop_yield_predictions
+---------+----------------+-------------------+
|     CROP|actual_avg_yield|predicted_avg_yield|
+---------+----------------+-------------------+
|SUGARCANE|          757.92|             720.36|
|    PADDY|           47.52|              61.14|
|    WHEAT|           38.51|               32.8|
|    MAIZE|           26.74|              37.91|
|   COTTON|           12.69|              20.24|
|     GRAM|            8.56|               9.92|
|    MOONG|             6.3|               6.63|
+---------+----------------+-------------------+



**STEP 9:- Visualization**

In [0]:
# The three chart datasets are all read from the "crops" temp view and are
# previewed in the same order as before.
chart_queries = (
    # Raw cost vs. yield rows per crop
    """
SELECT CROP, COST_C2, YIELD
FROM crops
ORDER BY CROP
""",
    # Heatmap data: Average yield by Crop and State
    """
SELECT CROP, STATE, ROUND(AVG(YIELD),2) AS avg_yield
FROM crops
GROUP BY CROP, STATE
ORDER BY CROP, STATE
""",
    # Line chart: Yield trend across states for each crop
    """
SELECT STATE, CROP, ROUND(AVG(YIELD),2) AS avg_yield
FROM crops
GROUP BY STATE, CROP
ORDER BY STATE, CROP
""",
)

for chart_query in chart_queries:
    spark.sql(chart_query).show()

+---------+--------+-----+
|     CROP| COST_C2|YIELD|
+---------+--------+-----+
|    ARHAR|23076.74| 9.83|
|    ARHAR|25270.26| 8.72|
|    ARHAR| 19551.9| 9.59|
|    ARHAR|24171.65| 6.42|
|    ARHAR|16528.68| 7.47|
|   COTTON|33116.82|12.69|
|   COTTON|44756.72|17.83|
|   COTTON|42070.44|19.05|
|   COTTON|44018.18| 19.9|
|   COTTON|50828.83|24.39|
|     GRAM|21618.43|10.93|
|     GRAM|12610.85| 6.83|
|     GRAM|18679.33| 8.05|
|     GRAM|16873.17|10.29|
|     GRAM|26762.09|16.69|
|GROUNDNUT|30393.66|11.98|
|GROUNDNUT|30114.45|13.45|
|GROUNDNUT| 17314.2| 4.71|
|GROUNDNUT|30434.61|11.97|
|GROUNDNUT|32683.46| 9.33|
+---------+--------+-----+
only showing top 20 rows
+---------+--------------+---------+
|     CROP|         STATE|avg_yield|
+---------+--------------+---------+
|    ARHAR|Andhra Pradesh|     6.42|
|    ARHAR|       Gujarat|     9.59|
|    ARHAR|     Karnataka|     7.47|
|    ARHAR|   Maharashtra|     8.72|
|    ARHAR| Uttar Pradesh|     9.83|
|   COTTON|Andhra Pradesh|    1

**SCATTER PLOT**

In [0]:

# Scatter plot data: cultivation cost (C2) versus yield for every record.
# Defined in this cell so it runs on a fresh kernel; `cost_yield_df` was
# previously used here without being assigned, which raises NameError on
# Restart & Run All.
cost_yield_df = spark.sql("""
SELECT CROP, COST_C2, YIELD
FROM crops
ORDER BY CROP
""")

# Save as table
cost_yield_df.write.mode("overwrite").saveAsTable("crop_cost_yield_analysis")
print("✅ Table 1 saved: crop_cost_yield_analysis")

# Display the table
display(spark.table("crop_cost_yield_analysis"))

✅ Table 1 saved: crop_cost_yield_analysis


CROP,COST_C2,YIELD
ARHAR,23076.74,9.83
ARHAR,16528.68,7.47
ARHAR,19551.9,9.59
ARHAR,24171.65,6.42
ARHAR,25270.26,8.72
COTTON,33116.82,12.69
COTTON,50828.83,24.39
COTTON,44756.72,17.83
COTTON,42070.44,19.05
COTTON,44018.18,19.9


Databricks visualization. Run in Databricks to view.

**LINE GRAPH**

In [0]:
# Mean yield for every (state, crop) pair; this feeds the line chart.
state_crop_yield_df = spark.sql("""
SELECT STATE, CROP, ROUND(AVG(YIELD),2) AS avg_yield
FROM crops
GROUP BY STATE, CROP
ORDER BY STATE, CROP
""")

# Persist the aggregate as a managed table, replacing any earlier version.
trend_writer = state_crop_yield_df.write.mode("overwrite")
trend_writer.saveAsTable("state_crop_yield_trends")
print("✅ Table 3 saved: state_crop_yield_trends")

# Read the saved table back and render it for the Databricks visualization.
saved_trends = spark.table("state_crop_yield_trends")
display(saved_trends)

✅ Table 3 saved: state_crop_yield_trends


STATE,CROP,avg_yield
Andhra Pradesh,ARHAR,6.42
Andhra Pradesh,COTTON,17.83
Andhra Pradesh,GRAM,16.69
Andhra Pradesh,GROUNDNUT,11.97
Andhra Pradesh,MAIZE,42.68
Andhra Pradesh,MOONG,5.9
Andhra Pradesh,PADDY,56.0
Andhra Pradesh,SUGARCANE,757.92
Bihar,MAIZE,42.95
Gujarat,ARHAR,9.59


Databricks visualization. Run in Databricks to view.