# Clasificación de Muestras Geológicas con Deep Learning (Google Colab)

Dataset: 5 clases (calcite, pyrite, quartz, Rocks, superficies_texturizadas)

Total: ~19,800 imágenes procedentes de tres fuentes:
- Minerales: calcite, pyrite, quartz (clases independientes)
- Rocas: agrupadas en una sola clase (Rocks)
- Superficies texturizadas: cracked, porous, wrinkled (agrupadas en una clase)

Modelos: CustomCNN (baseline) + ResNet18 (transfer learning)

**NOTA**: Asegúrate de activar GPU en Runtime > Change runtime type > Hardware accelerator > GPU

In [None]:
# Check the available GPU by running the nvidia-smi shell command
!nvidia-smi

In [None]:
# Clone the repository (optional, if you use GitHub)
# !git clone https://github.com/tu-usuario/tu-repo.git
# %cd tu-repo

# Or mount Google Drive (if the project lives there)
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/AF

In [None]:
# Configuración inicial
import sys
sys.path.append('/content')

import torch
import matplotlib.pyplot as plt
import random

from src.dataset import get_dataloaders
from src.models import CustomCNN, get_resnet18
from src.train import train_model
from src.evaluate import evaluate_model, plot_confusion_matrix, plot_training_history
from src.utils import predict_from_dataset, visualize_prediction_from_dataset, visualize_dataset_samples

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")
if DEVICE.type == 'cuda':
    # Report the name of CUDA device 0
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Carga del Dataset

In [None]:
# Load the geological image data and build the train/validation/test loaders.
DATA_DIR = 'data'
BATCH_SIZE = 32

# val_split and test_split are both passed as 0.15
train_loader, val_loader, test_loader, class_names = get_dataloaders(
    DATA_DIR, batch_size=BATCH_SIZE, val_split=0.15, test_split=0.15
)

# Summary: number of classes and number of samples behind each loader
print(f"Clases: {len(class_names)}")
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} samples: {len(split_loader.dataset)}")

In [None]:
# List the dataset classes, numbered starting at 1
print("Clases geológicas:")
for position, label in enumerate(class_names, start=1):
    print(f"{position}. {label}")

## 2. Visualización del Dataset

In [None]:
# Show a random sample of 16 images from the dataset.
# NOTE(review): train_loader.dataset.dataset is presumably the underlying
# dataset behind the training split wrapper, so samples may come from any
# split, not only training — confirm against src.dataset.
visualize_dataset_samples(train_loader.dataset.dataset, class_names, n_samples=16)

## 3. Entrenamiento Custom CNN

Arquitectura: 4 bloques convolucionales (32→64→128→256)

Dataset: ~19,800 imágenes, 5 clases geológicas

Configuración óptima: 30 epochs, lr=0.001, weight_decay=1e-4

In [None]:
# Train the CustomCNN baseline.
# The banner is emitted with a single print; sep="\n" keeps one item per line.
print(
    "=" * 70,
    "ENTRENAMIENTO CUSTOM CNN",
    "=" * 70,
    "Configuración óptima: 30 epochs, lr=0.001, weight_decay=1e-4",
    "=" * 70,
    sep="\n",
)

# One output unit per class in the dataset
model_cnn = CustomCNN(num_classes=len(class_names))

history_cnn = train_model(
    model_cnn, train_loader, val_loader,
    epochs=30, lr=0.001, weight_decay=1e-4, device=DEVICE,
)

# Report the last recorded entry of each metric in the training history
print("\nResultados CustomCNN:")
print(f"  Train Accuracy: {history_cnn['train_acc'][-1]:.2f}%")
print(f"  Val Accuracy: {history_cnn['val_acc'][-1]:.2f}%")
print(f"  Val Loss: {history_cnn['val_loss'][-1]:.4f}")

In [None]:
# Plot the CustomCNN training curves from its recorded history
plot_training_history(history_cnn)

## 4. Entrenamiento ResNet18 (Transfer Learning)

Transfer Learning desde ImageNet

Frozen: conv1, bn1, layer1, layer2

Trainable: layer3, layer4, fc

Configuración óptima: 15 epochs, lr=0.0005, weight_decay=1e-4

In [None]:
# Train ResNet18 with transfer learning.
# The banner is emitted with a single print; sep="\n" keeps one item per line.
print(
    "=" * 70,
    "ENTRENAMIENTO RESNET18 (TRANSFER LEARNING)",
    "=" * 70,
    "Configuración óptima: 15 epochs, lr=0.0005, weight_decay=1e-4",
    "=" * 70,
    sep="\n",
)

# Pretrained network with layer freezing enabled; one output per dataset class
model_resnet = get_resnet18(
    num_classes=len(class_names),
    pretrained=True,
    freeze_layers=True,
)

history_resnet = train_model(
    model_resnet, train_loader, val_loader,
    epochs=15, lr=0.0005, weight_decay=1e-4, device=DEVICE,
)

# Report the last recorded entry of each metric in the training history
print("\nResultados ResNet18:")
print(f"  Train Accuracy: {history_resnet['train_acc'][-1]:.2f}%")
print(f"  Val Accuracy: {history_resnet['val_acc'][-1]:.2f}%")
print(f"  Val Loss: {history_resnet['val_loss'][-1]:.4f}")

In [None]:
# Plot the ResNet18 training curves from its recorded history
plot_training_history(history_resnet)

## 5. Evaluación en Test Set

In [None]:
# Evaluate the CustomCNN on the test loader and keep the results for later cells
print("=" * 70, "EVALUACIÓN: CUSTOM CNN", "=" * 70, sep="\n")
results_cnn = evaluate_model(
    model_cnn,
    test_loader,
    class_names,
    device=DEVICE,
)

In [None]:
# Evaluate ResNet18 on the test loader and keep the results for later cells
print("\n" + "=" * 70, "EVALUACIÓN: RESNET18", "=" * 70, sep="\n")
results_resnet = evaluate_model(
    model_resnet,
    test_loader,
    class_names,
    device=DEVICE,
)

## 6. Comparación Final

In [None]:
# Side-by-side test-set accuracy of both models
print("\n" + "=" * 70, "COMPARACIÓN FINAL - TEST SET", "=" * 70, sep="\n")
print(f"Custom CNN: {results_cnn['accuracy']:.2f}%")
print(f"ResNet18: {results_resnet['accuracy']:.2f}%")

# Difference between the two accuracies (ResNet18 minus CustomCNN)
accuracy_gain = results_resnet['accuracy'] - results_cnn['accuracy']
print(f"\nMejora con Transfer Learning: {accuracy_gain:.2f}%")
print("=" * 70)

## 7. Matrices de Confusión

In [None]:
# Confusion matrix for the CustomCNN, built from its test-set labels and predictions
print("Matriz de Confusión - Custom CNN")
plot_confusion_matrix(
    results_cnn['labels'], results_cnn['predictions'], class_names, figsize=(10, 8)
)

In [None]:
# Confusion matrix for ResNet18, built from its test-set labels and predictions
print("Matriz de Confusión - ResNet18")
plot_confusion_matrix(
    results_resnet['labels'], results_resnet['predictions'], class_names, figsize=(10, 8)
)

## 8. Predicción en Imagen del Test Set

In [None]:
# Predict with the CustomCNN on one randomly chosen image of the test set.
# NOTE(review): assumes test_loader.dataset is a split wrapper (such as
# torch.utils.data.Subset) exposing `.dataset` (the underlying dataset) and
# `.indices` (the positions belonging to the test split) — confirm against
# src.dataset.
test_split = test_loader.dataset
test_dataset = test_split.dataset

# Draw the index from the test split only. Sampling over the whole underlying
# dataset could pick an image that was used for training or validation.
split_indices = getattr(test_split, 'indices', None)
if split_indices is not None:
    random_idx = int(random.choice(split_indices))
else:
    # The wrapper exposes no split indices: sample over the unwrapped dataset
    random_idx = random.randint(0, len(test_dataset) - 1)

predictions_cnn, true_label, image = predict_from_dataset(
    test_dataset,
    model_cnn,
    class_names,
    random_idx,
    device=DEVICE,
    top_k=5
)

print("PREDICCIÓN CUSTOM CNN")
print(f"True Label: {true_label}")
print("\nTop 5 predictions:")
for i, (class_name, prob) in enumerate(predictions_cnn, 1):
    print(f"{i}. {class_name}: {prob:.2f}%")

# Only the first three predictions are shown next to the image
visualize_prediction_from_dataset(image, predictions_cnn[:3], true_label)

In [None]:
# Run ResNet18 on the same sample (same dataset and index) used for the CustomCNN.
# Only the predictions are kept; the other two returned values are discarded.
predictions_resnet, _, _ = predict_from_dataset(
    test_dataset, model_resnet, class_names, random_idx, device=DEVICE, top_k=5
)

print("PREDICCIÓN RESNET18")
print(f"True Label: {true_label}")
print("\nTop 5 predictions:")
for rank, (label, confidence) in enumerate(predictions_resnet, start=1):
    print(f"{rank}. {label}: {confidence:.2f}%")

# Only the first three predictions are shown next to the image
top_three = predictions_resnet[:3]
visualize_prediction_from_dataset(image, top_three, true_label)

## 9. Guardar Ambos Modelos

In [None]:
# Save the state_dict of both trained models under the models/ directory
import os

cnn_path = 'models/custom_cnn.pth'
resnet_path = 'models/resnet18.pth'

# Create the output directory if it does not exist yet
os.makedirs('models', exist_ok=True)

# CustomCNN weights
torch.save(model_cnn.state_dict(), cnn_path)
print(f"CustomCNN guardado en: {cnn_path}")
print(f"  Test Accuracy: {results_cnn['accuracy']:.2f}%")

# ResNet18 weights
torch.save(model_resnet.state_dict(), resnet_path)
print(f"\nResNet18 guardado en: {resnet_path}")
print(f"  Test Accuracy: {results_resnet['accuracy']:.2f}%")

In [None]:
# Download the saved model files through Colab's files helper (optional)
from google.colab import files

for checkpoint_path in ('models/custom_cnn.pth', 'models/resnet18.pth'):
    files.download(checkpoint_path)