In [49]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, size, col
import json
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml import Pipeline
from pyspark.sql import functions as F

In [50]:
#Saprksession starts here
spark = SparkSession.builder.appName("FoodML").getOrCreate()

In [51]:
#read json
usda_fnds_json = spark.read.format('json').option('multiline', 'true').load("FoodData_Central_foundation_food_json_2025-04-24.json")

In [52]:
# checking the schema
usda_fnds_json.printSchema()

root
 |-- FoundationFoods: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- dataType: string (nullable = true)
 |    |    |-- description: string (nullable = true)
 |    |    |-- fdcId: long (nullable = true)
 |    |    |-- foodAttributes: array (nullable = true)
 |    |    |    |-- element: string (containsNull = true)
 |    |    |-- foodCategory: struct (nullable = true)
 |    |    |    |-- description: string (nullable = true)
 |    |    |-- foodClass: string (nullable = true)
 |    |    |-- foodNutrients: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- amount: double (nullable = true)
 |    |    |    |    |-- dataPoints: long (nullable = true)
 |    |    |    |    |-- foodNutrientDerivation: struct (nullable = true)
 |    |    |    |    |    |-- code: string (nullable = true)
 |    |    |    |    |    |-- description: string (nullable = true)
 |    |    |    |    |    |-- foodNutrientSo

In [53]:
df_foods = usda_fnds_json.select(explode("FoundationFoods").alias("food"))

# explode nutrients
df_nutrients = df_foods.select(
    "food.fdcId",
    "food.description",
    explode("food.foodNutrients").alias("nutrient")
)

df_nutrients.printSchema()

# only nutrient name and value, and dimension is needed, for e.g. gram, miligram
df_nutrients =  df_nutrients.select(
    "fdcId",
    "description",
    col("nutrient.nutrient.name").alias("nutrient_name"),
    col("nutrient.amount").alias("nutrient_value"),
    col("nutrient.nutrient.unitName").alias("nutrient_unit")
)

df_nutrients.printSchema()
df_nutrients.show(200)

root
 |-- fdcId: long (nullable = true)
 |-- description: string (nullable = true)
 |-- nutrient: struct (nullable = true)
 |    |-- amount: double (nullable = true)
 |    |-- dataPoints: long (nullable = true)
 |    |-- foodNutrientDerivation: struct (nullable = true)
 |    |    |-- code: string (nullable = true)
 |    |    |-- description: string (nullable = true)
 |    |    |-- foodNutrientSource: struct (nullable = true)
 |    |    |    |-- code: string (nullable = true)
 |    |    |    |-- description: string (nullable = true)
 |    |    |    |-- id: long (nullable = true)
 |    |-- footnote: string (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- max: double (nullable = true)
 |    |-- median: double (nullable = true)
 |    |-- min: double (nullable = true)
 |    |-- nutrient: struct (nullable = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- number: string (nullable = true)
 |    |    |-- rank: long (nu

In [54]:
# pivot
df_nutrients_pivot = df_nutrients.groupBy("fdcId", "description")\
    .pivot("nutrient_name")\
    .agg(F.first("nutrient_value"))

df_final = df_nutrients_pivot.fillna(0)

df_final.show(5)

+-------+--------------------+------------------------------+-------------------------+-------------------------------------+----------------------------------+-------+--------+-----+-------------+-----------+---------------+---------------+-------+------+--------+--------------+-----------+-----------+-----------+---------------------------+--------------------------+---------------+--------------+---------------+-----------+-------------+-----------------------------------+----------------------------+----------------------------------+---------------------------+--------------+-----------+----------+----------+--------------------+-------------------+--------+-------+--------+-------+-------------------+--------------------+------+--------------------------------+---------------------------------+------------------+-------------------+--------------+----------+-------------+----------------------------------+----------------------------------+----------------------------+-----------

In [63]:
#some cleaning
from pyspark.sql.functions import col

def sanitize_column_name(name):
    return name.replace(" ", "_") \
               .replace(",", "") \
               .replace("(", "") \
               .replace(")", "") \
               .replace(".", "") \
               .replace("/", "_") \
               .replace("-", "_")

cleaned_cols = [col(f"`{c}`").alias(sanitize_column_name(c)) for c in df_final.columns]
df_final = df_final.select(*cleaned_cols)


In [64]:

# feature + label

label_indexer = StringIndexer(inputCol="description",outputCol= "label")

nutrient_cols =[col for col in df_final.columns if col not in["fdcId", "description"]]

assembler = VectorAssembler(inputCols=nutrient_cols, outputCol="features")

In [65]:
#pipeline
pipeline = Pipeline(stages =[label_indexer, assembler])

df_rdy = pipeline.fit(df_final).transform(df_final)
df_rdy.select("label", "features").show(truncate=False)

+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [66]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder

# Train/Test split
train, test = df_rdy.randomSplit([0.8, 0.2], seed=42)

# model
rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=50)

# train
model = rf.fit(train)

# prediction
predictions = model.transform(test)

# evaluation
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy = {accuracy:.3f}")

Test Accuracy = 0.000


In [67]:
predictions.groupBy("label", "prediction").count().orderBy("label", "prediction").show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
| 13.0|     270.0|    1|
| 17.0|      25.0|    1|
| 33.0|     110.0|    1|
| 41.0|      58.0|    1|
| 48.0|     186.0|    1|
| 49.0|      52.0|    1|
| 50.0|     270.0|    1|
| 51.0|     186.0|    1|
| 55.0|      91.0|    1|
| 57.0|      52.0|    1|
| 69.0|     173.0|    1|
| 86.0|      98.0|    1|
| 87.0|      98.0|    1|
| 89.0|      91.0|    1|
| 95.0|      92.0|    1|
| 96.0|      92.0|    1|
|100.0|     134.0|    1|
|106.0|     270.0|    1|
|113.0|     286.0|    1|
|115.0|     279.0|    1|
+-----+----------+-----+
only showing top 20 rows

