In [None]:
"""Plague Classification Model with ArcGIS Data - Undersampling Implementation"""

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.preprocessing import image_dataset_from_directory
from PIL import Image

# Network and training hyperparameters.
EPOCHS = 50  # maximum number of epochs passed to model.fit
IMAGE_SIZE = (128, 128)  # size every image is resized to when loaded
# Model input shape: derived from IMAGE_SIZE (plus 3 colour channels) so the
# loader and the network can never disagree about the image dimensions.
INPUT_SHAPE = (*IMAGE_SIZE, 3)
SEED = 123  # seed shared by the training/validation split
BATCH_SIZE = 32
BUFFER_SIZE = 1000  # shuffle buffer size for the tf.data pipelines
LEARNING_RATE = 1e-4  # learning rate of the Adam optimizer

# Root directory of the images; passed to image_dataset_from_directory.
images_dir = '../arcgis-survey-images'

# Load the image dataset from disk, already batched with BATCH_SIZE.
# Both subsets are built with the same split fraction and seed; only the
# `subset` argument (and the explicit shuffle on the training side) differs.
split_kwargs = dict(
    labels="inferred",
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    validation_split=0.2,
    seed=SEED,
)

train_ds = image_dataset_from_directory(
    images_dir, subset="training", shuffle=True, **split_kwargs
)

validation_ds = image_dataset_from_directory(
    images_dir, subset="validation", **split_kwargs
)

# Keep the class names now: later pipeline transformations return plain
# tf.data datasets that no longer carry this attribute.
class_names = train_ds.class_names

# Gather every training label (one full pass over train_ds) to measure how
# many examples each class has and find the size of the smallest class.
label_batches = [batch_labels.numpy() for _, batch_labels in train_ds]
all_labels = np.concatenate(label_batches, axis=0)
unique_classes, counts = np.unique(all_labels, return_counts=True)
class_counts = dict(zip(unique_classes, counts))
min_count = counts.min()
print(f"Tamaño de la clase minoritaria: {min_count}")

def _is_class(target):
    """Return a tf.data filter predicate that keeps examples labelled `target`.

    Building the predicate through a factory binds `target` explicitly,
    instead of relying on a lambda that closes over the loop variable.
    """
    def predicate(_image, label):
        return label == target
    return predicate


# Un-batch the dataset so that filtering works on individual examples.
train_ds_unbatched = train_ds.unbatch()

# One filtered dataset per class, keyed by the class index.
datasets_per_class = {
    class_index: train_ds_unbatched.filter(_is_class(class_index))
    for class_index in unique_classes
}

# Undersample the majority classes down to the size of the minority class.
balanced_datasets = []
for class_index, dataset in datasets_per_class.items():
    if class_counts[class_index] > min_count:
        # Random sample of `min_count` examples. The shuffle is seeded so the
        # selected subset is reproducible, like the train/validation split.
        sampled_ds = dataset.shuffle(BUFFER_SIZE, seed=SEED).take(min_count)
    else:
        sampled_ds = dataset
    balanced_datasets.append(sampled_ds)

# Chain the per-class datasets into a single balanced dataset. The result is
# ordered class by class, so it must be shuffled before training.
balanced_train_ds = balanced_datasets[0]
for ds in balanced_datasets[1:]:
    balanced_train_ds = balanced_train_ds.concatenate(ds)

# Random augmentation applied to the balanced training images.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal_and_vertical'),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.1),
])

# After undersampling, every class contributes exactly `min_count` examples.
# unbatch()/filter() make the cardinality unknown to tf.data, so declare it
# explicitly; Keras can then report the number of steps per epoch instead of
# "Unknown".
balanced_train_ds_sized = balanced_train_ds.apply(
    tf.data.experimental.assert_cardinality(int(min_count) * len(unique_classes))
)

# Pipeline order matters:
#   1. cache()   - stores the decoded, undersampled images BEFORE any random
#                  transformation. Caching after augmentation would freeze
#                  the augmented images of the first pass and replay them on
#                  every epoch, which defeats the purpose of augmenting.
#   2. shuffle() - mixes the class-by-class concatenated examples.
#                  NOTE: BUFFER_SIZE must be at least the dataset size for a
#                  thorough shuffle of class-ordered data.
#   3. batch()   - batches first so augmentation runs once per batch.
#   4. map()     - fresh random augmentation on every epoch.
#   5. prefetch()- overlaps input preparation with training.
augmented_balanced_train_ds = (
    balanced_train_ds_sized
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (data_augmentation(x, training=True), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
validation_ds = validation_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

# Sanity check: count the examples of each class in the final training
# pipeline to confirm that the classes are balanced.
balanced_class_counts = np.zeros(len(class_names))
for _, batch_labels in augmented_balanced_train_ds:
    # Per-batch histogram of the integer labels, padded to one bin per class.
    balanced_class_counts += np.bincount(
        batch_labels.numpy(), minlength=len(class_names)
    )

print("\nDistribución de clases en el dataset balanceado:")
for class_name, n_examples in zip(class_names, balanced_class_counts):
    print(f"Clase: {class_name}, Número de ejemplos: {int(n_examples)}")

# Backbone: MobileNetV2 with ImageNet weights and without its classifier head.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=INPUT_SHAPE,
    include_top=False,
    weights='imagenet',
)
base_model.trainable = True  # the whole backbone is fine-tuned


def _dense_block(units):
    """Return the layers of one head block: L2-regularized Dense, BatchNorm, Dropout."""
    return [
        tf.keras.layers.Dense(
            units,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.01),
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
    ]


# Full classifier: rescaling -> backbone -> pooling -> two dense blocks ->
# softmax over the classes.
model = tf.keras.models.Sequential([
    # x / 127.5 - 1 maps pixel values in [0, 255] to [-1, 1].
    tf.keras.layers.Rescaling(1./127.5, offset=-1),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    *_dense_block(128),
    *_dense_block(64),
    tf.keras.layers.Dense(len(class_names), activation='softmax'),
])

# Labels are integer class indices and the model outputs probabilities
# (softmax), hence sparse categorical cross-entropy with from_logits=False.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Callbacks, both driven by the validation loss:
#   - stop after 5 epochs without improvement and restore the best weights;
#   - multiply the learning rate by 0.2 after 3 epochs without improvement,
#     never going below 1e-6.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-6,
)
training_callbacks = [early_stopping, reduce_lr]

# Train the model on the balanced training set.
history = model.fit(
    augmented_balanced_train_ds,
    epochs=EPOCHS,
    validation_data=validation_ds,
    callbacks=training_callbacks,
)


Found 3289 files belonging to 5 classes.
Using 2632 files for training.
Found 3289 files belonging to 5 classes.
Using 657 files for validation.
Tamaño de la clase minoritaria: 181

Distribución de clases en el dataset balanceado:
Clase: Chinche salivosa, Número de ejemplos: 181
Clase: Clororis, Número de ejemplos: 181
Clase: Hoja sana, Número de ejemplos: 181
Clase: Roya naranja, Número de ejemplos: 181
Clase: Roya purpura, Número de ejemplos: 181
Epoch 1/50
      4/Unknown 293s 2s/step - accuracy: 0.1784 - loss: 6.0480

In [None]:
# Save the model in the .keras format.
output_path = 'Under.keras'
model.save(output_path)