In [1]:
"""Plague Classification Model with ArcGIS Data"""

import os
import shutil
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.callbacks import Callback, EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model
from PIL import Image
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm

# Parámetros de la red
EPOCHS = 100
IMAGE_SIZE = (128, 128)
INPUT_SHAPE = (128, 128, 3)
SEED = 123
BATCH_SIZE = 32
BUFFER_SIZE = 250
LEARNING_RATE = 0.001

# from google.colab import drive
# drive.mount('/content/drive')

# Directorio de las imágenes
images_dir = '../last_data'


# Cargar dataset de imágenes
train_ds = image_dataset_from_directory(
    images_dir,
    labels="inferred",
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    validation_split=0.2,
    subset="training",
    seed=SEED,
    shuffle=True
)

validation_ds = image_dataset_from_directory(
    images_dir,
    labels="inferred",
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    validation_split=0.2,
    subset="validation",
    seed=SEED
)

# Obtener los nombres de las clases antes de mapear
class_names = train_ds.class_names

# Definir la función para aplicar el filtro Sharpen
def apply_sharpen_filter(image, label):
    # Definir el kernel de Sharpen
    sharpen_kernel = tf.constant([[0, -1, 0],
                                  [-1, 5, -1],
                                  [0, -1, 0]], dtype=tf.float32)
    sharpen_kernel = sharpen_kernel[:, :, tf.newaxis, tf.newaxis]  # [3, 3, 1, 1]
    sharpen_kernel = tf.tile(sharpen_kernel, [1, 1, 3, 1])  # [3, 3, 3, 1]

    # Asegurarse de que la imagen es de tipo float32
    image = tf.cast(image, tf.float32)

    # Aplicar la convolución depthwise
    sharpened_image = tf.nn.depthwise_conv2d(image, sharpen_kernel, strides=[1,1,1,1], padding='SAME')

    # Clip y conversión de tipo
    sharpened_image = tf.clip_by_value(sharpened_image, 0, 255)
    sharpened_image = tf.cast(sharpened_image, tf.uint8)

    return sharpened_image, label

# Aplicar el filtro Sharpen a los datasets
train_ds = train_ds.map(apply_sharpen_filter, num_parallel_calls=tf.data.AUTOTUNE)
validation_ds = validation_ds.map(apply_sharpen_filter, num_parallel_calls=tf.data.AUTOTUNE)

# Paso 1: Filtrar el dataset por cada clase usando tf.reduce_any para obtener un tensor booleano escalar
chinche_salivosa_ds = train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 0)))
clororis_ds = train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 1)))
hoja_sana_ds = train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 2)))
roya_naranja_ds = train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 3)))
roya_purpura_ds = train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 4)))

# Paso 2: Contar ejemplos manualmente
def count_examples(dataset):
    return dataset.reduce(0, lambda x, _: x + 1).numpy()

# Obtener el tamaño de la clase mayoritaria
chinche_salivosa_size = count_examples(chinche_salivosa_ds)
clororis_size = count_examples(clororis_ds)
hoja_sana_size = count_examples(hoja_sana_ds)
roya_naranja_size = count_examples(roya_naranja_ds)
roya_purpura_size = count_examples(roya_purpura_ds)

majority_class_size = max(
    chinche_salivosa_size,
    clororis_size,
    hoja_sana_size,
    roya_naranja_size,
    roya_purpura_size
)

# Paso 3: Sobremuestrear las clases minoritarias
chinche_salivosa_ds = chinche_salivosa_ds.repeat().take(majority_class_size)
clororis_ds = clororis_ds.repeat().take(majority_class_size)
hoja_sana_ds = hoja_sana_ds.repeat().take(majority_class_size)
roya_naranja_ds = roya_naranja_ds.repeat().take(majority_class_size)
roya_purpura_ds = roya_purpura_ds.repeat().take(majority_class_size)

# Paso 4: Concatenar los datasets sobremuestreados
oversampled_train_ds = chinche_salivosa_ds.concatenate(clororis_ds)
oversampled_train_ds = oversampled_train_ds.concatenate(hoja_sana_ds)
oversampled_train_ds = oversampled_train_ds.concatenate(roya_naranja_ds)
oversampled_train_ds = oversampled_train_ds.concatenate(roya_purpura_ds)

# Paso 5: Aplicar las operaciones de optimización de cache y prefetch
oversampled_train_ds = oversampled_train_ds.cache().shuffle(BUFFER_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)
validation_ds = validation_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

# Paso 6: Verificar la nueva distribución de clases
# Filtrar el dataset final por clase para contar los ejemplos nuevamente
oversampled_chinche_salivosa_ds = oversampled_train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 0)))
oversampled_clororis_ds = oversampled_train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 1)))
oversampled_hoja_sana_ds = oversampled_train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 2)))
oversampled_roya_naranja_ds = oversampled_train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 3)))
oversampled_roya_purpura_ds = oversampled_train_ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, 4)))

# Contar la cantidad de ejemplos en cada clase
oversampled_chinche_salivosa_size = count_examples(oversampled_chinche_salivosa_ds)
oversampled_clororis_size = count_examples(oversampled_clororis_ds)
oversampled_hoja_sana_size = count_examples(oversampled_hoja_sana_ds)
oversampled_roya_naranja_size = count_examples(oversampled_roya_naranja_ds)
oversampled_roya_purpura_size = count_examples(oversampled_roya_purpura_ds)

# Mostrar la distribución de clases final
print(f"Nueva distribución de clases después del sobremuestreo:")
print(f"Chinche salivosa: {oversampled_chinche_salivosa_size}")
print(f"Clororis: {oversampled_clororis_size}")
print(f"Hoja sana: {oversampled_hoja_sana_size}")
print(f"Roya naranja: {oversampled_roya_naranja_size}")
print(f"Roya purpura: {oversampled_roya_purpura_size}")

# Cargar el modelo base (MobileNetV2)
from tensorflow.keras.applications import MobileNetV2

base_model = MobileNetV2(input_shape=INPUT_SHAPE,
                         include_top=False,
                         weights='imagenet')

# Ajuste de las capas del modelo base
for layer in base_model.layers[:100]:  # Ajusta según sea necesario
    layer.trainable = False

# Definir el modelo
model = tf.keras.models.Sequential([
    tf.keras.layers.Rescaling(1./255),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01)),
    tf.keras.layers.Dropout(0.5),  # Aumentamos el Dropout a 0.5
    tf.keras.layers.Dense(len(class_names), activation='softmax')  # Cambiamos la activación a 'softmax'
])

# Compilar el modelo
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=LEARNING_RATE,
    decay_steps=10000,
    decay_rate=0.9,
    staircase=True
)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Definir los callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)

# Definir el callback TQDMCallback
class TQDMCallback(Callback):
    def __init__(self, epochs):
        self.epochs = epochs
        self.progbar = tqdm(total=epochs, desc='Entrenando', unit='época', ncols=90)

    def on_epoch_end(self, epoch, logs=None):
        self.progbar.update(1)
        self.progbar.set_postfix({
            'loss': f"{logs.get('loss'):.4f}",
            'accuracy': f"{logs.get('accuracy'):.4f}",
            'val_loss': f"{logs.get('val_loss'):.4f}",
            'val_accuracy': f"{logs.get('val_accuracy'):.4f}"
        })

    def on_train_end(self, logs=None):
        self.progbar.close()

# Entrenamiento del modelo con los callbacks
history = model.fit(
    oversampled_train_ds,
    validation_data=validation_ds,
    epochs=EPOCHS,
    callbacks=[TQDMCallback(EPOCHS), early_stopping, checkpoint],
)

# Guardar el modelo en formato .h5 y .keras
model.save('final_model.keras')
# model.save('final_model.h5')

# Gráfica de la pérdida y precisión
metrics = history.history
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plt.plot(history.epoch, metrics['loss'], label='training')
plt.plot(history.epoch, metrics['val_loss'], label='validation')
plt.legend()
plt.ylim([0, max(plt.ylim())])
plt.ylabel('Loss')
plt.xlabel('Epoch')

plt.subplot(1, 2, 2)
plt.plot(history.epoch, metrics['accuracy'], label='training')
plt.plot(history.epoch, metrics['val_accuracy'], label='validation')
plt.legend()
plt.ylim([0, 1])
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.show()

# Evaluación en el conjunto de test
test_ds = validation_ds.shard(num_shards=2, index=1)
test_results = model.evaluate(test_ds, return_dict=True)
print("Resultados de evaluación en test set:")
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")

# Matriz de confusión
y_pred = model.predict(test_ds)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.concatenate([y for x, y in test_ds], axis=0)

conf_matrix = tf.math.confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, xticklabels=class_names, yticklabels=class_names, annot=True, fmt='g')
plt.xlabel('Predicción')
plt.ylabel('Etiqueta verdadera')
plt.title('Matriz de Confusión')
plt.show()

# Reporte de clasificación
from sklearn.metrics import classification_report
print("Reporte de clasificación:")
print(classification_report(y_true, y_pred_classes, target_names=class_names))


Found 4308 files belonging to 5 classes.
Using 3447 files for training.
Found 4308 files belonging to 5 classes.
Using 861 files for validation.
Nueva distribución de clases después del sobremuestreo:
Chinche salivosa: 539
Clororis: 533
Hoja sana: 540
Roya naranja: 540
Roya purpura: 540


Entrenando:   0%|                                              | 0/100 [00:00<?, ?época/s]

Epoch 1/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 256ms/step - accuracy: 0.7785 - loss: 1.3963 - val_accuracy: 0.5528 - val_loss: 7.5917
Epoch 2/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m127s[0m 235ms/step - accuracy: 0.9355 - loss: 0.2802 - val_accuracy: 0.6446 - val_loss: 4.2091
Epoch 3/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 239ms/step - accuracy: 0.9706 - loss: 0.1429 - val_accuracy: 0.7921 - val_loss: 3.1822
Epoch 4/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 253ms/step - accuracy: 0.9806 - loss: 0.1067 - val_accuracy: 0.7851 - val_loss: 3.4550
Epoch 5/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 240ms/step - accuracy: 0.9827 - loss: 0.1087 - val_accuracy: 0.5168 - val_loss: 8.0962
Epoch 6/100
[1m540/540[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 234ms/step - accuracy: 0.9810 - loss: 0.1094 - val_accuracy: 0.8223 - val_loss: 1.398

: 