# Classification de fruits avec MobileNetV2 sur un cluster EMR

In [1]:
# Report which Spark runtime this session is bound to (version and YARN application id).
spark_context = spark.sparkContext
print(f"Spark Version: {spark.version}")
print(f"Application ID: {spark_context.applicationId}")

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
45,application_1763712756826_0046,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Spark Version: 3.5.6-amzn-1
Application ID: application_1763712756826_0046

In [2]:
# Smoke test: run a trivial distributed job to confirm the executors respond.
smoke_values = list(range(1, 6))
simple_rdd = spark.sparkContext.parallelize(smoke_values)
result = simple_rdd.count()
print(f"Test RDD count: {result}")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Test RDD count: 5

In [3]:
# S3 configuration: every input/output location is derived from a single bucket name.
import time
start_time = time.time()  # NOTE(review): not referenced by the visible cells

BUCKET_NAME = "p8-fruits-jademayalb"
_s3_root = f"s3://{BUCKET_NAME}"
PATH_Data = f"{_s3_root}/Test"
PATH_Result = f"{_s3_root}/Results"
PATH_Result_PCA = f"{_s3_root}/Results_PCA"

for _config_line in ("Configuration AWS S3",
                     f"Bucket: {BUCKET_NAME}",
                     f"Data: {PATH_Data}",
                     f"Results: {PATH_Result}"):
    print(_config_line)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Configuration AWS S3
Bucket: p8-fruits-jademayalb
Data: s3://p8-fruits-jademayalb/Test
Results: s3://p8-fruits-jademayalb/Results

In [4]:
# Load every *.jpg file under PATH_Data (sub-folders included) as binary rows.
images = (
    spark.read.format("binaryFile")
    .option("pathGlobFilter", "*.jpg")
    .option("recursiveFileLookup", "true")
    .load(PATH_Data)
)

print("Comptage images...")
image_count = images.count()
print(f"Images trouvées: {image_count}")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Chargement images depuis S3...
Comptage images...
Images trouvées: 22687

In [5]:
from pyspark.sql.functions import element_at, split

# The class label is the parent folder name: .../<label>/<file>.jpg
path_parts = split(images['path'], '/')
images = images.withColumn('label', element_at(path_parts, -2))

# Sanity checks: schema, a few sample rows, then the per-class distribution.
images.printSchema()
images.select('path', 'label').show(5, False)
images.groupBy('label').count().show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- path: string (nullable = true)
 |-- modificationTime: timestamp (nullable = true)
 |-- length: long (nullable = true)
 |-- content: binary (nullable = true)
 |-- label: string (nullable = true)

+-------------------------------------------------------+----------+
|path                                                   |label     |
+-------------------------------------------------------+----------+
|s3://p8-fruits-jademayalb/Test/Watermelon/r_106_100.jpg|Watermelon|
|s3://p8-fruits-jademayalb/Test/Watermelon/r_109_100.jpg|Watermelon|
|s3://p8-fruits-jademayalb/Test/Watermelon/r_108_100.jpg|Watermelon|
|s3://p8-fruits-jademayalb/Test/Watermelon/r_107_100.jpg|Watermelon|
|s3://p8-fruits-jademayalb/Test/Watermelon/r_95_100.jpg |Watermelon|
+-------------------------------------------------------+----------+
only showing top 5 rows

+--------------+-----+
|         label|count|
+--------------+-----+
|    Watermelon|  157|
|Pineapple Mini|  163|
|   Cauliflower|  234|
| Cucumber 

In [6]:
# Cache the full image index, then draw a reproducible random subset of about
# `target_sample` images for the feature-extraction run.
images.cache()
total_images = images.count()
print(f"Total images: {total_images}")

target_sample = 3000
if total_images == 0:
    # Guard the division below; there is nothing to sample.
    raise ValueError("Aucune image à échantillonner: le DataFrame images est vide")

sample_fraction = target_sample / total_images
print(f"Fraction: {sample_fraction:.3f}")

# Oversample by 20% so limit() can still reach target_sample rows after the
# Bernoulli sampling. The fraction is clamped to 1.0 because sampling without
# replacement only accepts fractions in [0, 1]. Without the clamp, this fails
# when the dataset has fewer than 1.2 * target_sample rows, and also when the
# cell is re-run on an `images` that was already replaced by the sample.
oversample_fraction = min(1.0, sample_fraction * 1.2)
sampled_images = (
    images.sample(withReplacement=False, fraction=oversample_fraction, seed=42)
    .limit(target_sample)
    # Freeze the subset. Without caching, every action below (and the
    # collect() in the extraction cell) re-evaluates sample + limit.
    .cache()
)

final_count = sampled_images.count()
print(f"Échantillon: {final_count} images")

print("Top 10 classes échantillon:")
sampled_images.groupBy('label').count().orderBy('count', ascending=False).show(10)

# Downstream cells read `images`, so point it at the subset.
images = sampled_images

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Total images: 22687
Fraction: 0.132
Échantillon: 3000 images
Top 10 classes échantillon:
+------------------+-----+
|             label|count|
+------------------+-----+
|     Pepper Orange|   53|
|               Fig|   50|
|          Tomato 3|   43|
|       Cauliflower|   42|
|            Walnut|   42|
|          Tomato 1|   42|
|Melon Piel de Sapo|   41|
|        Pear Stone|   40|
|  Strawberry Wedge|   39|
|        Clementine|   37|
+------------------+-----+
only showing top 10 rows

In [8]:
# Drop TensorFlow/Keras modules already imported in this interpreter so that the
# next cell re-imports them from the package paths configured below.
# NOTE(review): deleting sys.modules entries does not unload compiled extensions
# already loaded into the process; only a kernel restart fully resets TensorFlow.
import os
import sys
import gc
import importlib

# Top-level package prefixes to purge. Matching on the top-level package name,
# rather than on a substring anywhere in the dotted path, leaves unrelated
# modules loaded. Examples are `platform` and `encodings.utf_8`, whose names
# contain "tf".
TF_PACKAGE_PREFIXES = ('tensorflow', 'keras', 'tf_', 'optree')


def is_tensorflow_module(module_name):
    """Return True if `module_name` belongs to a TensorFlow/Keras/optree package.

    `module_name` is a dotted module path, as used for keys of `sys.modules`.
    """
    top_level = module_name.split('.', 1)[0].lower()
    return top_level == 'tf' or top_level.startswith(TF_PACKAGE_PREFIXES)


modules_to_clean = [name for name in list(sys.modules) if is_tensorflow_module(name)]
for module in modules_to_clean:
    sys.modules.pop(module, None)

gc.collect()
importlib.invalidate_caches()

# Prepend the extra package directories that exist on this machine to sys.path.
paths_to_add = [
    '/tmp/python_packages',
    '/tmp/python_user/lib/python3.9/site-packages',
    '/usr/local/lib64/python3.9/site-packages',
    '/usr/local/lib/python3.9/site-packages'
]

for path in paths_to_add:
    if os.path.exists(path) and path not in sys.path:
        sys.path.insert(0, path)

# Select the Keras backend and lower TensorFlow's C++ log verbosity.
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
# Imports TensorFlow
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras import Model
import numpy as np
import io
from PIL import Image

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
# Load MobileNetV2 (ImageNet weights, classifier head included), then build a
# feature extractor whose output is the second-to-last layer. The recorded
# output of this cell shows that layer is 1280-D.
model = MobileNetV2(weights='imagenet', include_top=True, input_shape=(224, 224, 3))
print(f"MobileNetV2 chargé: {model.output_shape}")

penultimate_output = model.layers[-2].output
feature_model = Model(inputs=model.input, outputs=penultimate_output)
print(f"Feature model: {feature_model.output_shape}")

# Broadcast the extractor weights to the executors.
# NOTE(review): `weights_broadcast` is not read by any visible cell; the
# extraction loop calls feature_model.predict on the driver instead.
model_weights = feature_model.get_weights()
weights_broadcast = spark.sparkContext.broadcast(model_weights)
weights_size_mb = sum(w.nbytes for w in model_weights) / (1024 * 1024)

# len(model_weights) counts weight arrays, not layers.
print(f"Layers broadcastés: {len(model_weights)}")
print(f"Taille mémoire: {weights_size_mb:.1f} MB")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Initialisation MobileNetV2...
MobileNetV2 chargé: (None, 1000)
Feature model: (None, 1280)
Layers broadcastés: 260
Taille mémoire: 8.6 MB

In [12]:
# Collect the sampled (path, label) pairs on the driver and plan the batching.
import time

batch_size = 100
extraction_start = time.time()

image_paths = images.select("path", "label").collect()
total_sample = len(image_paths)
total_batches = -(-total_sample // batch_size)  # ceiling division

# Rough duration estimate; the constant 20 is an assumed throughput in images/second.
expected_minutes = total_sample / 20 / 60

print(f"Images échantillon: {total_sample}")
print(f"Batch size: {batch_size}")
print(f"Total batches: {total_batches}")
print(f"Temps estimé: {expected_minutes:.1f} minutes")

# Accumulators filled by the extraction loop.
all_features, all_labels, all_paths = [], [], []
errors_count = 0

# Batch-by-batch feature extraction on the driver. For each batch: read the
# image bytes from S3, preprocess them with PIL, run MobileNetV2 once, and
# append the features, labels and paths to the accumulators.
for batch_start in range(0, total_sample, batch_size):
    batch_paths = image_paths[batch_start:batch_start + batch_size]
    batch_num = batch_start // batch_size + 1
    # Per-image failures in this batch; added to errors_count only if the batch completes.
    batch_errors = 0

    try:
        current_paths = [item['path'] for item in batch_paths]
        # O(1) path -> label lookup, replacing a linear scan per image.
        label_by_path = {item['path']: item['label'] for item in batch_paths}
        batch_df = spark.read.format("binaryFile").load(current_paths)
        batch_data = batch_df.select("path", "content").collect()
        # Files requested but not returned by Spark would otherwise vanish silently.
        batch_errors += max(len(current_paths) - len(batch_data), 0)

        batch_images = []
        batch_labels = []
        batch_paths_clean = []
        for data_item in batch_data:
            try:
                label = label_by_path[data_item['path']]
                with Image.open(io.BytesIO(data_item['content'])) as raw_img:
                    # Convert before resizing so non-RGB images (palette,
                    # greyscale, RGBA) are resampled in RGB space.
                    rgb_img = raw_img if raw_img.mode == 'RGB' else raw_img.convert('RGB')
                    img = rgb_img.resize((224, 224))
                batch_images.append(preprocess_input(img_to_array(img)))
                batch_labels.append(label)
                batch_paths_clean.append(data_item['path'])
            except Exception:
                # Best effort: an undecodable or unmatched image is counted and skipped.
                batch_errors += 1

        if batch_images:
            batch_features = [
                f.flatten().tolist()
                for f in feature_model.predict(np.stack(batch_images), verbose=0)
            ]
        else:
            batch_features = []
    except Exception as e:
        # The whole batch is lost. Per-image failures already tallied in
        # batch_errors are discarded here, so nothing is counted twice.
        print(f"Erreur batch {batch_num}: {str(e)[:60]}")
        errors_count += len(batch_paths)
        continue

    errors_count += batch_errors
    all_features.extend(batch_features)
    all_labels.extend(batch_labels)
    all_paths.extend(batch_paths_clean)

    # Progress monitoring (every 5 batches and on the last one).
    elapsed = time.time() - extraction_start
    processed = len(all_features)
    rate = processed / elapsed if elapsed > 0 else 0

    if batch_num % 5 == 0 or batch_num == total_batches:
        remaining = total_sample - processed
        eta_minutes = (remaining / rate / 60) if rate > 0 else 0
        progress_pct = (processed / total_sample) * 100

        print(f"Batch {batch_num}/{total_batches} | "
              f"Progress: {processed}/{total_sample} ({progress_pct:.1f}%) | "
              f"Rate: {rate:.1f} img/s | "
              f"ETA: {eta_minutes:.1f}min")

    # Drop per-batch references before the next iteration to bound driver memory.
    del batch_data, batch_images, batch_features
    gc.collect()

# Extraction summary. Stop the cell if no usable features were produced.
extraction_time = time.time() - extraction_start
successful_images = len(all_features)

print(f"Images traitées: {successful_images}")
print(f"Images échouées: {errors_count}")
print(f"Temps total: {extraction_time:.1f}s ({extraction_time/60:.1f} min)")
print(f"Performance: {successful_images/extraction_time:.1f} img/sec")

if successful_images == 0:
    print("ERREUR: Aucune image traitée")
    # Raise instead of calling exit(). Depending on the REPL, exit() either
    # tears down the interpreter that hosts the Spark session, or only schedules
    # a shutdown while the rest of the cell keeps running on empty data.
    raise RuntimeError("Aucune image traitée: extraction impossible")

# Turn the driver-side Python lists into a cached Spark DataFrame (path, label, features).
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType

schema_features = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("path", StringType()),
        ("label", StringType()),
        ("features", ArrayType(FloatType())),
    )
])

data_rows = [row for row in zip(all_paths, all_labels, all_features)]
df_features = spark.createDataFrame(data_rows, schema_features).cache()

feature_count = df_features.count()
print(f"DataFrame créé: {feature_count} lignes")

# Structure check: schema plus a few (path, label) rows.
df_features.printSchema()
df_features.select("path", "label").show(5, False)

# Reduce the MobileNetV2 feature vectors to PCA_K principal components with Spark ML.
from pyspark.ml.feature import PCA
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf

PCA_K = 100  # number of principal components to keep

pca_start = time.time()


def array_to_vector(arr):
    """Convert a Python list of floats into a dense ML vector, the input type PCA expects."""
    return Vectors.dense(arr)


vector_udf = udf(array_to_vector, VectorUDT())

df_vectors = df_features.withColumn("features_vector", vector_udf("features"))

pca = PCA(k=PCA_K, inputCol="features_vector", outputCol="pca_features")
pca_model = pca.fit(df_vectors)

df_pca = pca_model.transform(df_vectors)
df_final = df_pca.select("path", "label", "pca_features")

pca_time = time.time() - pca_start
print(f"PCA terminé: {pca_time:.1f}s")

# Explained-variance report: explainedVariance holds, per kept component, the
# proportion of variance it explains.
explained_variance = pca_model.explainedVariance
total_variance = sum(explained_variance)
cumulative_variance = np.cumsum(explained_variance.toArray())

print(f"Variance totale: {total_variance:.4f} ({total_variance*100:.2f}%)")
# Only report milestones that exist for the chosen k (PCA_K may be below 50).
for n_components in (10, 50):
    if n_components <= len(cumulative_variance):
        cum_var = cumulative_variance[n_components - 1]
        print(f"Variance {n_components} comp: {cum_var:.4f} ({cum_var*100:.2f}%)")

# Persist the raw features, the PCA features and the run metadata to S3 as Parquet.
# Import before the try-block: the final summary later in this cell calls
# datetime.now(), and an import placed after a failing write would never run.
from datetime import datetime

PATH_Result_Sample = f's3://{BUCKET_NAME}/Results_Sample_3k'
PATH_Result_PCA_Sample = f's3://{BUCKET_NAME}/Results_PCA_Sample_3k'
PATH_Metadata_Sample = f's3://{BUCKET_NAME}/Metadata_Sample_3k'

save_start = time.time()

try:
    df_features.write.mode("overwrite").parquet(PATH_Result_Sample)

    df_final.write.mode("overwrite").parquet(PATH_Result_PCA_Sample)

    # Experiment metadata as (parameter, value) string pairs. Dimensions are
    # derived from the data and the fitted model instead of being hard-coded.
    original_dim = str(len(all_features[0])) if all_features else "1280"
    metadata_data = [
        ("experiment_date", datetime.now().strftime('%Y-%m-%d %H:%M:%S')),
        ("sample_size", str(successful_images)),
        ("extraction_time_minutes", str(extraction_time/60)),
        ("pca_time_seconds", str(pca_time)),
        ("performance_img_per_sec", str(successful_images/extraction_time)),
        ("pca_variance_explained", str(total_variance)),
        ("model_used", "MobileNetV2"),
        ("feature_dimension_original", original_dim),
        ("feature_dimension_pca", str(pca_model.getK())),
        ("batch_size", str(batch_size)),
        ("errors_count", str(errors_count)),
        ("sampling_method", "simple_random")
    ]

    schema_metadata = StructType([
        StructField("parameter", StringType(), True),
        StructField("value", StringType(), True)
    ])

    df_metadata = spark.createDataFrame(metadata_data, schema_metadata)
    df_metadata.write.mode("overwrite").parquet(PATH_Metadata_Sample)

    save_time = time.time() - save_start
    print(f"Sauvegarde terminée: {save_time:.1f}s")

except Exception as e:
    # Best effort: report the failure and let the verification step run anyway.
    print(f"Erreur sauvegarde: {e}")

# Read the written datasets back from S3 to confirm they are present and readable.
def _read_and_count(parquet_path):
    """Load the Parquet dataset at `parquet_path`; return (DataFrame, row count)."""
    frame = spark.read.parquet(parquet_path)
    return frame, frame.count()


try:
    test_features, features_count = _read_and_count(PATH_Result_Sample)
    print(f"Features 1280D: {features_count} lignes")

    test_pca, pca_count = _read_and_count(PATH_Result_PCA_Sample)
    print(f"Features PCA 100D: {pca_count} lignes")

    print("Structure PCA:")
    test_pca.printSchema()
    test_pca.select("path", "label").show(5, False)

    test_metadata = spark.read.parquet(PATH_Metadata_Sample)
    print(f"Métadonnées: {test_metadata.count()} paramètres")
    test_metadata.show(15, False)

except Exception as e:
    print(f"Erreur vérification: {e}")

# Final run summary: counts, timings and output locations.
# Local import: the only other import of datetime in this cell sits inside the
# save try-block and is skipped if an earlier write fails.
from datetime import datetime

# Measured from the start of feature extraction, not from the configuration cell.
total_pipeline_time = time.time() - extraction_start

print(f"Images traitées: {successful_images:,}")
print(f"Temps extraction: {extraction_time/60:.1f} minutes")
print(f"Temps PCA: {pca_time:.1f} secondes")
print(f"Temps total: {total_pipeline_time/60:.1f} minutes")
print(f"Performance: {successful_images/extraction_time:.1f} img/s")

print(f"- Features 1280D: {PATH_Result_Sample}")
print(f"- Features PCA 100D: {PATH_Result_PCA_Sample}")
print(f"- Métadonnées: {PATH_Metadata_Sample}")

print(f"\nTerminé: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Images échantillon: 3000
Batch size: 100
Total batches: 30
Temps estimé: 2.5 minutes
Batch 5/30 | Progress: 500/3000 (16.7%) | Rate: 16.6 img/s | ETA: 2.5min
Batch 10/30 | Progress: 1000/3000 (33.3%) | Rate: 17.2 img/s | ETA: 1.9min
Batch 15/30 | Progress: 1500/3000 (50.0%) | Rate: 17.2 img/s | ETA: 1.5min
Batch 20/30 | Progress: 2000/3000 (66.7%) | Rate: 17.2 img/s | ETA: 1.0min
Batch 25/30 | Progress: 2500/3000 (83.3%) | Rate: 17.3 img/s | ETA: 0.5min
Batch 30/30 | Progress: 3000/3000 (100.0%) | Rate: 17.5 img/s | ETA: 0.0min
Images traitées: 3000
Images échouées: 0
Temps total: 171.3s (2.9 min)
Performance: 17.5 img/sec
DataFrame créé: 3000 lignes
root
 |-- path: string (nullable = true)
 |-- label: string (nullable = true)
 |-- features: array (nullable = true)
 |    |-- element: float (containsNull = true)

+-------------------------------------------------------+----------+
|path                                                   |label     |
+-------------------------------------