# In [None]:
import os
import boto3
from botocore.client import Config
from PIL import Image
import numpy as np
import json
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, FloatType

# Spark ML Imports
from pyspark.ml.feature import StringIndexer, PCA
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors, VectorUDT

print("=== Starting MinIO Image Training Script (Pandas UDF + Spark ML) ===")

# -----------------------------
# Configuration and Data Loading
# -----------------------------
# Connection settings may be overridden through environment variables so that
# real credentials never have to be edited into the script. When a variable is
# unset, the value this script has always used is kept as the fallback.
print("Configuring MinIO client...")
minio_endpoint = os.environ.get('MINIO_ENDPOINT', 'http://minio:9000')
minio_access_key = os.environ.get('MINIO_ACCESS_KEY', 'minioadmin')
minio_secret_key = os.environ.get('MINIO_SECRET_KEY', 'minioadmin')
bucket_name = os.environ.get('MINIO_BUCKET', 'dev')
local_image_dir = '/tmp/minio_images/'

print("Initializing MinIO client...")
# MinIO speaks the S3 API, so a regular boto3 S3 client pointed at the MinIO
# endpoint is used; requests are signed with signature version s3v4.
s3 = boto3.client(
    's3',
    endpoint_url=minio_endpoint,
    aws_access_key_id=minio_access_key,
    aws_secret_access_key=minio_secret_key,
    config=Config(signature_version='s3v4')
)

if os.path.exists(local_image_dir):
    print("Local image directory exists:", local_image_dir)
else:
    print("Creating local directory for images:", local_image_dir)
    # exist_ok avoids a crash if the directory appears between the existence
    # check above and this call.
    os.makedirs(local_image_dir, exist_ok=True)

print("Listing and downloading images from MinIO bucket...")
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
downloaded_files = []  # each entry: (local_file_path, label)

# A single list_objects_v2 call returns at most 1000 keys, so iterate over
# every page of the listing instead of reading only the first response.
objects_seen = 0
for page in s3.get_paginator('list_objects_v2').paginate(Bucket=bucket_name):
    # A page for an empty bucket/prefix carries no 'Contents' entry at all.
    for obj in page.get('Contents', []):
        objects_seen += 1
        key = obj['Key']  # e.g., "antelope/27a5369441.jpg"
        # str.endswith accepts a tuple, so one call checks every extension.
        if not key.lower().endswith(tuple(image_extensions)):
            continue
        # Mirror the bucket layout under local_image_dir.
        local_file_path = os.path.join(local_image_dir, key)
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        s3.download_file(bucket_name, key, local_file_path)
        # The first path component of the key is the class label; keys that
        # have no "/" at all are labelled "unknown".
        label = key.split('/')[0] if "/" in key else "unknown"
        downloaded_files.append((local_file_path, label))

if not objects_seen:
    print("No objects found in bucket:", bucket_name)

total_downloaded = len(downloaded_files)
print("Finished downloading images. Total images downloaded:", total_downloaded)

print("Walking local image directory to list images...")
image_data_list = []  # each entry: (file_path, label)
for root, dirs, files in os.walk(local_image_dir):
    # Sorting in place fixes the traversal order, so the list (and therefore
    # the seeded train/test split built from it) is the same on every run.
    dirs.sort()
    for file in sorted(files):
        if not file.lower().endswith(tuple(image_extensions)):
            continue
        file_path = os.path.join(root, file)
        rel_path = os.path.relpath(file_path, local_image_dir)
        path_parts = rel_path.split(os.sep)
        # The top-level sub-directory is the label. A file sitting directly in
        # local_image_dir has no such directory; label it "unknown" (the same
        # fallback the download loop uses) instead of using its own filename.
        label = path_parts[0] if len(path_parts) > 1 else "unknown"
        image_data_list.append((file_path, label))
print("Total image files found (by walking directory):", len(image_data_list))

# -----------------------------
# Create Spark Session and DataFrame
# -----------------------------
if not image_data_list:
    # With only column names as the schema, Spark has to infer the column
    # types from the rows and cannot do so for an empty list. Stop here with
    # a clear message rather than failing later inside createDataFrame.
    print("No image files found under", local_image_dir, "- nothing to train on.")
    raise SystemExit(1)

print("Creating Spark session...")
spark = SparkSession.builder.appName("MinIO Image Training PandasUDF + Spark ML").getOrCreate()

print("Creating Spark DataFrame with image paths and labels...")
# Each row is an (image_path, label) pair; pixel data is attached later.
df = spark.createDataFrame(image_data_list, schema=["image_path", "label"]).persist()
df.show(5, truncate=False)

# -----------------------------
# Define Image Processing via Pandas UDF
# -----------------------------
print("Defining image loading function...")

def load_image(file_path):
    """Load an image file as a normalized 128x128 RGB array.

    The image is resized to 128x128, converted to three-channel RGB and
    scaled to the range [0, 1]. Loading is best-effort: any failure is
    reported on stdout and an all-zero array is returned, so one unreadable
    file does not abort the whole batch.

    Args:
        file_path: Path of the image file to read.

    Returns:
        A float32 numpy array of shape (128, 128, 3).
    """
    try:
        # The context manager closes the underlying file handle, which would
        # otherwise stay open until garbage collection.
        with Image.open(file_path) as src:
            resized = src.resize((128, 128))  # resize to fixed size
            # convert("RGB") replicates grayscale into three channels and
            # drops the alpha channel of RGBA, and also decodes palette,
            # CMYK, LA and 1-bit images correctly instead of treating their
            # raw channel data as RGB values.
            img = np.array(resized.convert("RGB"))
        # Debug: print out shape and dtype for each image load (optional, may slow processing)
        print("Loaded image:", file_path, "-> shape:", img.shape, "dtype:", img.dtype)
        if img.shape != (128, 128, 3):
            raise ValueError(f"Unexpected image shape: {img.shape}")
        return img.astype("float32") / 255.0  # Normalize image
    except Exception as e:
        # Deliberately broad: any decoding problem falls back to a blank image.
        print(f"Error loading image {file_path}: {e}")
        return np.zeros((128, 128, 3), dtype=np.float32)

print("Defining Pandas UDF to preprocess images...")

@pandas_udf(ArrayType(FloatType()))
def process_image_udf(path_series: pd.Series) -> pd.Series:
    """Turn a batch of image paths into flat lists of pixel values.

    Args:
        path_series: Image file paths, one per row of the batch.

    Returns:
        A Series aligned with ``path_series`` whose elements are the
        flattened output of ``load_image`` as plain Python lists of floats.
    """
    def _flat_pixels(path):
        # load_image yields a numpy array; flatten it and convert to a list
        # so it fits the declared ArrayType(FloatType()) return type.
        pixels = load_image(path)
        return pixels.flatten().tolist()

    return path_series.apply(_flat_pixels)

print("Applying Pandas UDF to DataFrame (this may take a while)...")
# Build the column expression first, then attach it as "image_data".
image_data_column = process_image_udf("image_path")
df = df.withColumn("image_data", image_data_column)
# Preview only the lightweight columns; image_data is left out on purpose
# because every row holds a very large array.
preview_df = df.select("image_path", "label")
preview_df.show(n=5, truncate=50)

# -----------------------------
# Convert Feature Array to Dense Vector
# -----------------------------
print("Converting image_data to DenseVector...")

def to_vector(arr):
    """Convert a sequence of floats into a dense ML vector.

    Args:
        arr: The values to wrap, or None.

    Returns:
        ``Vectors.dense(arr)``, or None when ``arr`` is None.
    """
    # Pass missing values through untouched.
    if arr is None:
        return None
    return Vectors.dense(arr)

from pyspark.sql.functions import udf as spark_udf
# Wrap to_vector as a Spark UDF whose declared return type is an ML vector.
vector_udf = spark_udf(to_vector, returnType=VectorUDT())
# Derive the "features" vector column from the "image_data" array column.
features_column = vector_udf("image_data")
df = df.withColumn("features", features_column)
preview_df = df.select("image_path", "label", "features")
preview_df.show(n=5, truncate=50)

# -----------------------------
# Label Indexing
# -----------------------------
print("Indexing labels...")
# Map each string label to a numeric index in a new "label_index" column.
# The indexer is fitted on the full DataFrame so that every label present in
# the data receives an index.
indexer = StringIndexer(inputCol="label", outputCol="label_index")
label_indexer = indexer.fit(df)
df = label_indexer.transform(df)
df.select("label", "label_index").show(5)

# -----------------------------
# Dimensionality Reduction with PCA
# -----------------------------
print("Applying PCA for dimensionality reduction...")
# Original feature dimension = 128*128*3 = 49152. Reducing to a lower dimension, e.g., k=100.
pca = PCA(k=100, inputCol="features", outputCol="pca_features")

# -----------------------------
# Classification using RandomForest
# -----------------------------
print("Defining RandomForestClassifier...")
# The seed is pinned (to the same value used for the train/test split) so
# that the trained forest, and therefore the reported accuracy, is
# reproducible from run to run.
rf = RandomForestClassifier(
    featuresCol="pca_features",
    labelCol="label_index",
    numTrees=10,
    seed=42,
)

# Create the ML pipeline: PCA transformation followed by RF classifier
pipeline = Pipeline(stages=[pca, rf])

# -----------------------------
# Split Data, Train Model and Evaluate
# -----------------------------
print("Splitting data into training and test sets...")
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

# try/finally guarantees the Spark session is stopped on every exit path,
# including failures while predicting or evaluating.
try:
    print("Training the Spark ML pipeline model...")
    try:
        model = pipeline.fit(train_df)
    except Exception as e:
        # Top-level boundary of the script: report the failure and exit with
        # a non-zero status. SystemExit is raised directly because the exit()
        # helper is installed by the site module and is not always available.
        print("Error during model training:", e)
        raise SystemExit(1) from e

    print("Making predictions on test set...")
    predictions = model.transform(test_df)
    predictions.select("label", "label_index", "prediction").show(5)

    print("Evaluating model performance...")
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label_index",
        predictionCol="prediction",
        metricName="accuracy",
    )
    accuracy = evaluator.evaluate(predictions)
    print("Test set accuracy:", accuracy)

    print("=== MinIO Image Training and Spark ML Classification Finished ===")
finally:
    spark.stop()