In [None]:
"""
AVANCE1:

üå± AGRO-AI VISION PRO - Sistema de Diagn√≥stico de Enfermedades en Plantas
==========================================================================
Sistema avanzado de detecci√≥n de enfermedades con:
- Transfer Learning (MobileNetV2)
- Data Augmentation
- Visualizaciones profesionales
- Matriz de confusi√≥n
- Predicciones con probabilidades
- Exportaci√≥n de reportes
- Interfaz mejorada para demo

Autor: M.I.N.D RESEARCH GROUP
Hackathon: [SAMSUMG INNOVATION CAMPUS]
"""

import os
import json
from google.colab import files
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.metrics import classification_report, confusion_matrix
import pathlib
from datetime import datetime

# ==========================================
# üîß CONFIGURACI√ìN INICIAL
# ==========================================
print("="*60)
print("üå± AGRO-AI VISION PRO - Inicializando Sistema")
print("="*60)

# Limpieza de configuraci√≥n anterior
if os.path.exists('/root/.kaggle'):
    !rm -rf /root/.kaggle
    print("‚úÖ Configuraci√≥n anterior limpiada")

# Cargar credenciales de Kaggle
print("\nüì• Cargue su archivo kaggle.json:")
uploaded = files.upload()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
print("‚úÖ Credenciales de Kaggle configuradas")

# Descargar dataset
print("\nüì¶ Descargando dataset PlantVillage...")
!kaggle datasets download -d emmarex/plantdisease
!unzip -q plantdisease.zip
print("‚úÖ Dataset descargado y descomprimido")

# ==========================================
# MODEL CONFIGURATION
# ==========================================
# Hyperparameters shared by the data pipeline and the training code.
BATCH_SIZE = 32
IMG_HEIGHT = IMG_WIDTH = 224  # 224x224 input fed to MobileNetV2
EPOCHS = 20  # NOTE(review): not referenced by the fit() calls in this cell
LEARNING_RATE = 0.001

# Use the extracted PlantVillage folder when present; otherwise fall back to
# the current working directory.
_dataset_found = os.path.exists('PlantVillage')
data_dir = pathlib.Path("PlantVillage" if _dataset_found else ".")
print("‚úÖ Dataset PlantVillage encontrado" if _dataset_found else "‚ö†Ô∏è Usando directorio actual")

# ==========================================
# DATA AUGMENTATION
# ==========================================
print("\nüé® Configurando Data Augmentation...")

# Random geometric and photometric perturbations. Keras random preprocessing
# layers only transform inputs when the model runs in training mode.
_augmentation_layers = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.2),
    layers.RandomBrightness(0.2),
]
data_augmentation = tf.keras.Sequential(_augmentation_layers)

# ==========================================
# DATA LOADING
# ==========================================
def _load_split(subset):
    """Load one side ("training" or "validation") of the seeded 80/20 split of data_dir."""
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset=subset,
        seed=123,  # both subsets must share the seed so the split is consistent
        image_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
    )


print("\nüìÇ Cargando datos de entrenamiento...")
train_ds = _load_split("training")

print("üìÇ Cargando datos de validaci√≥n...")
val_ds = _load_split("validation")

# Class names come from the sub-directory names, in the loader's order.
class_names = train_ds.class_names
num_classes = len(class_names)
print(f"\n‚úÖ {num_classes} clases detectadas")
print(f"üìã Primeras 10 clases: {class_names[:10]}")

# Input-pipeline performance: cache decoded batches, shuffle the training
# batches, and prefetch so loading overlaps with training.
# NOTE(review): cache() without a filename keeps the decoded images in memory;
# confirm this fits in the runtime's RAM for the full dataset.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache()
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

# ==========================================
# MODEL CONSTRUCTION (TRANSFER LEARNING)
# ==========================================
print("\nüèóÔ∏è Construyendo modelo con Transfer Learning (MobileNetV2)...")

def create_model(num_classes, img_height, img_width):
    """Build an image classifier on top of an ImageNet-pretrained MobileNetV2.

    Pipeline: module-level ``data_augmentation`` -> MobileNetV2
    ``preprocess_input`` -> frozen MobileNetV2 feature extractor ->
    GAP/BatchNorm/Dropout/Dense head -> softmax over ``num_classes``.

    Args:
        num_classes: number of output classes.
        img_height: input image height in pixels.
        img_width: input image width in pixels.

    Returns:
        Tuple ``(model, base_model)``. The backbone is returned separately so
        the caller can unfreeze part of it later for fine-tuning.
    """
    input_shape = (img_height, img_width, 3)

    # Pretrained backbone without its ImageNet classifier; frozen at first.
    backbone = MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
    backbone.trainable = False

    image_in = tf.keras.Input(shape=input_shape)
    features = data_augmentation(image_in)  # random layers act only while training
    features = tf.keras.applications.mobilenet_v2.preprocess_input(features)
    # training=False keeps the backbone's BatchNormalization in inference mode,
    # even after the backbone is made trainable.
    features = backbone(features, training=False)

    # Classification head.
    head = layers.GlobalAveragePooling2D()(features)
    head = layers.BatchNormalization()(head)
    head = layers.Dropout(0.3)(head)
    head = layers.Dense(512, activation='relu')(head)
    head = layers.BatchNormalization()(head)
    head = layers.Dropout(0.3)(head)
    head = layers.Dense(256, activation='relu')(head)
    head = layers.Dropout(0.2)(head)
    class_probs = layers.Dense(num_classes, activation='softmax')(head)

    return models.Model(image_in, class_probs), backbone

model, base_model = create_model(num_classes, IMG_HEIGHT, IMG_WIDTH)

# Compile for phase 1 (frozen backbone).
# The datasets are loaded with image_dataset_from_directory's default
# label_mode='int', so targets are integer class ids. Both the loss and the
# top-3 metric must therefore be the "sparse" variants:
# TopKCategoricalAccuracy expects one-hot targets and does not work with
# integer labels.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top_3_accuracy')]
)

print("\nüìã Arquitectura del Modelo:")
model.summary()

# ==========================================
# TRAINING CALLBACKS
# ==========================================
print("\nüéØ Configurando callbacks...")

# Stop once val_accuracy has not improved for 5 epochs. restore_best_weights
# asks Keras to put back the weights of the best monitored epoch.
early_stopping = EarlyStopping(monitor='val_accuracy', patience=5,
                               restore_best_weights=True, verbose=1)

# Halve the learning rate after 3 epochs without val_loss improvement,
# with a floor of 1e-7.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                              min_lr=1e-7, verbose=1)

# Write 'best_model.keras' only when val_accuracy improves.
checkpoint = ModelCheckpoint('best_model.keras', monitor='val_accuracy',
                             save_best_only=True, verbose=1)

callbacks = [early_stopping, reduce_lr, checkpoint]

# ==========================================
# PHASE 1: TRAIN THE NEW HEAD (BACKBONE FROZEN)
# ==========================================
_separator = "="*60
print("\n" + _separator)
print("üöÄ FASE 1: Entrenamiento con capas base congeladas")
print(_separator)

# Up to 10 epochs; early stopping may end the phase sooner.
history_phase1 = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1,
)

# ==========================================
# PHASE 2: FINE-TUNING
# ==========================================
print("\n" + "="*60)
print("üîì FASE 2: Fine-tuning - Descongelando capas superiores")
print("="*60)

# Unfreeze only the last 50 layers of the backbone; earlier layers stay frozen.
base_model.trainable = True
for layer in base_model.layers[:-50]:
    layer.trainable = False

# Count inside the backbone: `model.layers` lists the whole MobileNetV2 as a
# single layer, so counting there does not reflect the partial unfreezing.
trainable_base_layers = sum(layer.trainable for layer in base_model.layers)
print(f"Capas entrenables: {trainable_base_layers}")

# Changes to `trainable` only take effect after re-compiling. Use a 10x lower
# learning rate so the pretrained weights are only gently adjusted. Integer
# labels again require the sparse top-k metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE/10),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top_3_accuracy')]
)

# In fit(), `epochs` is the index of the final epoch, not a number of extra
# epochs. It must be offset by `initial_epoch`. Otherwise this phase runs zero
# epochs whenever phase 1 completed all of its epochs, and the history merge
# below then fails on an empty history.
FINE_TUNE_EPOCHS = 10
phase1_epochs = len(history_phase1.history['loss'])
history_phase2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=phase1_epochs + FINE_TUNE_EPOCHS,
    initial_epoch=phase1_epochs,
    callbacks=callbacks,
    verbose=1
)

# Concatenate both phases into one per-epoch history for plots and reports.
history = {
    key: history_phase1.history[key] + history_phase2.history[key]
    for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
}

print("\n‚úÖ Entrenamiento completado!")

# ==========================================
# TRAINING VISUALIZATION
# ==========================================
print("\nüìä Generando visualizaciones...")

def plot_training_history(history):
    """Draw a 2x2 training dashboard, save it to 'training_history.png' and show it.

    Panels:
    - accuracy curves;
    - loss curves;
    - a bar chart of the last-epoch train/validation accuracy;
    - a text summary of the key metrics.

    Args:
        history: dict of per-epoch lists containing at least 'accuracy',
            'val_accuracy', 'loss' and 'val_loss'.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    acc_ax, loss_ax = axes[0]
    bar_ax, text_ax = axes[1]

    # Top row: train vs. validation curves for accuracy and loss.
    curve_specs = (
        (acc_ax, 'accuracy', 'Train Accuracy', 'Val Accuracy', 'Model Accuracy', 'Accuracy'),
        (loss_ax, 'loss', 'Train Loss', 'Val Loss', 'Model Loss', 'Loss'),
    )
    for ax, key, train_label, val_label, title, ylabel in curve_specs:
        ax.plot(history[key], label=train_label, linewidth=2)
        ax.plot(history['val_' + key], label=val_label, linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Bottom-left: last-epoch accuracy, annotated above each bar.
    final_train_acc = history['accuracy'][-1]
    final_val_acc = history['val_accuracy'][-1]
    final_values = [final_train_acc, final_val_acc]
    bar_ax.bar(['Train', 'Validation'], final_values,
               color=['#2ecc71', '#3498db'], width=0.5)
    bar_ax.set_title('Final Accuracy Comparison', fontsize=14, fontweight='bold')
    bar_ax.set_ylabel('Accuracy')
    bar_ax.set_ylim([0, 1])
    for x_pos, value in enumerate(final_values):
        bar_ax.text(x_pos, value + 0.02, f'{value:.4f}', ha='center', fontweight='bold')

    # Bottom-right: plain-text summary instead of a plot.
    text_ax.axis('off')
    metrics_text = f"""
    üìä RESUMEN DE M√âTRICAS
    {'='*40}
    
    üéØ Accuracy Final (Train): {final_train_acc:.4f}
    ‚úÖ Accuracy Final (Val):   {final_val_acc:.4f}
    
    üìâ Loss Final (Train):     {history['loss'][-1]:.4f}
    üìâ Loss Final (Val):       {history['val_loss'][-1]:.4f}
    
    üìà Mejor Val Accuracy:     {max(history['val_accuracy']):.4f}
    üìâ Mejor Val Loss:         {min(history['val_loss']):.4f}
    
    üîÑ √âpocas Entrenadas:      {len(history['accuracy'])}
    """
    text_ax.text(0.1, 0.5, metrics_text, fontsize=11,
                 family='monospace', verticalalignment='center')

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("‚úÖ Gr√°fico guardado como 'training_history.png'")

plot_training_history(history)

# ==========================================
# CONFUSION MATRIX
# ==========================================
print("\nüéØ Generando matriz de confusi√≥n...")

def _collect_predictions(model, dataset):
    """Return (y_true, y_pred) integer class-id arrays over every batch of *dataset*."""
    y_true, y_pred = [], []
    for images, labels in dataset:
        # Keras recommends calling the model directly instead of predict() for
        # single batches inside a loop. training=False keeps augmentation and
        # dropout inactive, as predict() does.
        probs = model(images, training=False)
        y_pred.extend(np.argmax(np.asarray(probs), axis=1))
        y_true.extend(np.asarray(labels))
    return np.asarray(y_true, dtype=int), np.asarray(y_pred, dtype=int)


def generate_confusion_matrix(model, val_ds, class_names, top_n=20):
    """Plot and save the validation confusion matrix, then print the best classes by F1.

    Args:
        model: trained classifier that outputs per-class scores.
        val_ds: dataset yielding (images, integer labels) batches.
        class_names: class names indexed by class id.
        top_n: with more than ``top_n`` classes, only the ``top_n`` classes
            with the most true samples are plotted.

    Returns:
        ``(cm, report)``:
        - ``cm`` is the plotted (possibly reduced) confusion matrix;
        - ``report`` is sklearn's ``classification_report`` dict over all classes.

    Side effects: writes 'confusion_matrix.png', shows the figure and prints a
    top-10 F1 ranking.
    """
    y_true, y_pred = _collect_predictions(model, val_ds)

    # Pin the label set to every known class id. Rows and columns then line up
    # with class_names even when a class never appears in y_true or y_pred.
    # Without it the matrix shrinks, the index->name mapping breaks, and
    # classification_report raises because target_names no longer matches.
    all_labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)

    if len(class_names) > top_n:
        # Keep the classes with the most true samples (largest row sums).
        top_classes_idx = np.argsort(np.sum(cm, axis=1))[-top_n:]
        cm = cm[top_classes_idx][:, top_classes_idx]
        selected_classes = [class_names[i] for i in top_classes_idx]
        title = f'Matriz de Confusi√≥n - Top {top_n} Clases'
    else:
        selected_classes = class_names
        title = 'Matriz de Confusi√≥n'

    plt.figure(figsize=(16, 14))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=selected_classes,
                yticklabels=selected_classes,
                cbar_kws={'label': 'Frecuencia'})
    plt.title(title, fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicci√≥n', fontsize=12)
    plt.ylabel('Real', fontsize=12)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("‚úÖ Matriz guardada como 'confusion_matrix.png'")

    print("\nüìã Reporte de Clasificaci√≥n (Top 10 clases):")
    report = classification_report(y_true, y_pred, labels=all_labels,
                                   target_names=class_names,
                                   output_dict=True, zero_division=0)

    # Per-class entries only; the summary rows are not classes.
    summary_keys = {'accuracy', 'macro avg', 'weighted avg', 'micro avg'}
    class_scores = sorted(
        ((name, data['f1-score']) for name, data in report.items()
         if name not in summary_keys),
        key=lambda item: item[1],
        reverse=True,
    )

    print("\nüèÜ Top 10 Clases por F1-Score:")
    for rank, (class_name, score) in enumerate(class_scores[:10], 1):
        print(f"{rank:2d}. {class_name:40s} F1: {score:.4f}")

    return cm, report

cm, report = generate_confusion_matrix(model, val_ds, class_names)

# ==========================================
# PREDICTION SYSTEM
# ==========================================
print("\nüîç Configurando sistema de predicci√≥n...")

def diagnose_plant(model, image, class_names, top_k=5):
    """Diagnose one image and return its top-k classes with probabilities.

    Args:
        model: classifier whose ``predict`` returns per-class probabilities
            (the model built by ``create_model`` ends in a softmax layer).
        image: a single unbatched image (NumPy array or tensor).
        class_names: class names indexed by class id.
        top_k: number of ranked predictions to return.

    Returns:
        dict with:
        - 'predicted_class': name of the most likely class;
        - 'confidence': that class's probability;
        - 'top_predictions': list of (class_name, probability), best first.
    """
    # Add the batch dimension expected by model.predict.
    img_array = np.expand_dims(np.asarray(image), 0)

    # The model output is already a softmax distribution and is used as-is.
    # Applying softmax again would flatten it toward uniform and under-report
    # the confidence.
    probabilities = np.asarray(model.predict(img_array, verbose=0)[0])

    # Indices of the top_k largest probabilities, highest first.
    top_indices = np.argsort(probabilities)[-top_k:][::-1]
    top_probs = probabilities[top_indices]
    top_classes = [class_names[i] for i in top_indices]

    return {
        'predicted_class': top_classes[0],
        'confidence': float(top_probs[0]),
        'top_predictions': list(zip(top_classes, [float(p) for p in top_probs]))
    }

def visualize_prediction(image, result, class_names):
    """Show an image next to a horizontal bar chart of its top predictions.

    Args:
        image: the analyzed image as a NumPy array or tensor. It is cast to
            uint8 for display, so presumably it holds 0-255 pixel values.
        result: dict returned by ``diagnose_plant``.
        class_names: not used by the plot; kept for API compatibility.

    Returns:
        The matplotlib Figure. It is neither saved nor shown here.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # np.asarray also accepts tensors (e.g. dataset batch elements), which
    # have no .astype method.
    ax1.imshow(np.asarray(image).astype("uint8"))
    ax1.axis('off')
    ax1.set_title('Muestra Analizada', fontsize=14, fontweight='bold')

    # Probability bars, colored red->green by probability.
    classes, probs = zip(*result['top_predictions'])
    colors = plt.cm.RdYlGn(np.array(probs))
    y_pos = np.arange(len(classes))

    ax2.barh(y_pos, probs, color=colors)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels([c.replace('___', ' - ').replace('_', ' ') for c in classes])
    ax2.set_xlabel('Probabilidad', fontsize=12)
    # Title follows the actual number of predictions (top_k may differ from 5).
    ax2.set_title(f'Top {len(classes)} Predicciones', fontsize=14, fontweight='bold')
    ax2.set_xlim([0, 1])

    # Percentage label just to the right of each bar.
    for i, prob in enumerate(probs):
        ax2.text(prob + 0.02, i, f'{prob:.2%}',
                 va='center', fontweight='bold')

    ax2.grid(axis='x', alpha=0.3)
    ax2.invert_yaxis()  # best prediction on top

    plt.tight_layout()
    return fig

# ==========================================
# DEMO: REAL-TIME DIAGNOSIS
# ==========================================
_rule = "="*60
print("\n" + _rule)
print("üé¨ DEMO: Sistema de Diagn√≥stico en Tiempo Real")
print(_rule)

print("\nüî¨ Analizando 6 muestras aleatorias del conjunto de validaci√≥n...")


def _pretty_label(name):
    """Make a class folder name readable: '___' -> ' - ', then '_' -> ' '."""
    return name.replace('___', ' - ').replace('_', ' ')


fig, axes = plt.subplots(3, 2, figsize=(16, 20))
axes = axes.ravel()

# First validation batch; the six panels show its first six images.
image_batch, label_batch = next(iter(val_ds))

for idx, ax in enumerate(axes):
    image = image_batch[idx].numpy().astype("uint8")
    true_label = class_names[label_batch[idx]]
    result = diagnose_plant(model, image_batch[idx], class_names)

    ax.imshow(image)
    ax.axis('off')

    # Green title for a correct prediction, red otherwise.
    if result['predicted_class'] == true_label:
        color, status = '#2ecc71', '‚úÖ CORRECTO'
    else:
        color, status = '#e74c3c', '‚ùå INCORRECTO'

    title = "\n".join([
        status,
        f"Predicci√≥n: {_pretty_label(result['predicted_class'])}",
        f"Confianza: {result['confidence']:.1%}",
        f"Real: {_pretty_label(true_label)}",
    ])
    ax.set_title(title, fontsize=10, fontweight='bold', color=color, pad=10)

plt.suptitle('üå± AGRO-AI VISION PRO - Resultados de Diagn√≥stico',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('demo_predictions.png', dpi=300, bbox_inches='tight')
plt.show()
print("‚úÖ Demo guardado como 'demo_predictions.png'")

# ==========================================
# SAVE MODEL AND METADATA
# ==========================================
print("\nüíæ Guardando modelo y metadatos...")

model.save('agro_ai_vision_pro.keras')
print("‚úÖ Modelo guardado como 'agro_ai_vision_pro.keras'")

# Run summary written as JSON; key order below is the order in the file.
val_acc_curve = history['val_accuracy']
metadata = {
    'model_name': 'AGRO-AI VISION PRO',
    'version': '1.0.0',
    'created_at': datetime.now().isoformat(),
    'architecture': 'MobileNetV2 + Transfer Learning',
    'image_size': [IMG_HEIGHT, IMG_WIDTH],
    'num_classes': num_classes,
    'class_names': class_names,
    'final_accuracy': float(val_acc_curve[-1]),  # last epoch's validation accuracy
    'final_loss': float(history['val_loss'][-1]),
    'best_accuracy': float(max(val_acc_curve)),
    'epochs_trained': len(history['accuracy']),
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
}

with open('model_metadata.json', 'w') as metadata_file:
    json.dump(metadata, metadata_file, indent=2)
print("‚úÖ Metadatos guardados como 'model_metadata.json'")

# ==========================================
# FINAL REPORT
# ==========================================
_banner_line = "="*60
print("\n" + _banner_line)
print("üìà REPORTE FINAL DEL SISTEMA")
print(_banner_line)

# Human-readable summary built from the saved metadata and the live model.
_final_report = f"""
üéØ M√âTRICAS DE RENDIMIENTO:
   ‚Ä¢ Accuracy Final (Validaci√≥n): {metadata['final_accuracy']:.2%}
   ‚Ä¢ Mejor Accuracy:              {metadata['best_accuracy']:.2%}
   ‚Ä¢ Loss Final:                  {metadata['final_loss']:.4f}
   ‚Ä¢ √âpocas Entrenadas:           {metadata['epochs_trained']}

üß† ARQUITECTURA:
   ‚Ä¢ Modelo Base:                 {metadata['architecture']}
   ‚Ä¢ Tama√±o de Imagen:            {IMG_HEIGHT}x{IMG_WIDTH}
   ‚Ä¢ Clases Detectables:          {num_classes}
   ‚Ä¢ Par√°metros Totales:          {model.count_params():,}

üìä CAPACIDADES:
   ‚Ä¢ Transfer Learning con ImageNet
   ‚Ä¢ Data Augmentation avanzado
   ‚Ä¢ Fine-tuning adaptativo
   ‚Ä¢ Predicciones con Top-K
   ‚Ä¢ Matriz de confusi√≥n
   ‚Ä¢ Visualizaciones profesionales

üíæ ARCHIVOS GENERADOS:
   ‚Ä¢ agro_ai_vision_pro.keras     (Modelo entrenado)
   ‚Ä¢ model_metadata.json          (Metadatos)
   ‚Ä¢ training_history.png         (Gr√°ficos de entrenamiento)
   ‚Ä¢ confusion_matrix.png         (Matriz de confusi√≥n)
   ‚Ä¢ demo_predictions.png         (Demo de predicciones)
   ‚Ä¢ best_model.keras             (Mejor checkpoint)

üöÄ LISTO PARA PRODUCCI√ìN!
"""
print(_final_report)

for _line in (_banner_line, "‚úÖ Sistema AGRO-AI VISION PRO completamente entrenado!", _banner_line):
    print(_line)

# ==========================================
# DOWNLOAD GENERATED FILES
# ==========================================
print("\nüì• ¬ødescargar los archivos generados? (y/n)")
download = input()

# Tolerate surrounding whitespace and case. Also accept "yes" and the Spanish
# "s"/"si", since the prompt is in Spanish.
if download.strip().lower() in ('y', 'yes', 's', 'si'):
    for artifact in (
        'agro_ai_vision_pro.keras',
        'model_metadata.json',
        'training_history.png',
        'confusion_matrix.png',
        'demo_predictions.png',
        'best_model.keras',  # listed in the final report as a generated file
    ):
        # Skip a missing file instead of aborting the remaining downloads.
        if os.path.exists(artifact):
            files.download(artifact)
    print("‚úÖ Archivos descargados exitosamente!")
else:
    print("‚ÑπÔ∏è Puede descargar los archivos manualmente desde el panel de archivos")

print("\nüéâ ¬°Gracias por usar AGRO-AI VISION PRO!")
