In [1]:
# Configuración inicial
import os
import random
import sys

import matplotlib.pyplot as plt
import torch

# Make the project-local `src` package importable from the notebook's folder.
sys.path.append('..')

from src.dataset import get_dataloaders
from src.models import CustomCNN, get_resnet18
from src.train import train_model
from src.evaluate import evaluate_model, plot_confusion_matrix, plot_training_history
from src.utils import predict_from_dataset, visualize_prediction_from_dataset, visualize_dataset_samples

# Seed every RNG this notebook uses directly so that "Restart Kernel -> Run All"
# is repeatable: `random` drives the random test-image pick further down, and
# `torch` drives weight initialisation and any torch-based shuffling/splitting.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cpu


In [None]:
# Load the geological image data and build the train / validation / test loaders.
DATA_DIR = '../data'
BATCH_SIZE = 32

# 15% of the data is requested for validation and 15% for test.
split_options = dict(batch_size=BATCH_SIZE, val_split=0.15, test_split=0.15)
train_loader, val_loader, test_loader, class_names = get_dataloaders(DATA_DIR, **split_options)

# Report the number of classes and the size of each split.
print(f"Clases: {len(class_names)}")
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} samples: {len(split_loader.dataset)}")

In [None]:
# List the dataset's class names, numbered from 1.
numbered_classes = [f"{position}. {label}" for position, label in enumerate(class_names, start=1)]
print("Clases geológicas:", *numbered_classes, sep="\n")

In [None]:
# Show a grid of 16 sample images from the dataset.
# `train_loader.dataset.dataset` reaches through the loader's dataset to the object it
# wraps — presumably the full, un-split dataset; verify against `get_dataloaders`.
visualize_dataset_samples(train_loader.dataset.dataset, class_names, n_samples=16)

In [None]:
# Train the CustomCNN with its chosen configuration.
# The hyperparameters are declared once and used for both the banner and the
# `train_model` call, so the printed values can never drift from the ones used.
CNN_EPOCHS = 30
CNN_LR = 0.001
CNN_WEIGHT_DECAY = 1e-4

print("="*70)
print("CUSTOM CNN - CONFIGURACIÓN ÓPTIMA")
print("="*70)
print(f"Epochs: {CNN_EPOCHS} | LR: {CNN_LR:g} | Weight Decay: {CNN_WEIGHT_DECAY:g}")
print("="*70)

model_cnn = CustomCNN(num_classes=len(class_names))

history_cnn = train_model(
    model_cnn,
    train_loader,
    val_loader,
    epochs=CNN_EPOCHS,
    lr=CNN_LR,
    weight_decay=CNN_WEIGHT_DECAY,
    device=DEVICE
)

# Summarise the metrics of the LAST epoch (not necessarily the best one).
cnn_train_acc = history_cnn['train_acc'][-1]
cnn_val_acc = history_cnn['val_acc'][-1]
cnn_val_loss = history_cnn['val_loss'][-1]

print("\nResultados CustomCNN:")
print(f"  Train Accuracy: {cnn_train_acc:.2f}%")
print(f"  Val Accuracy: {cnn_val_acc:.2f}%")
print(f"  Val Loss: {cnn_val_loss:.4f}")
# "Overfitting" here is simply the train-minus-validation accuracy gap.
print(f"  Overfitting: {cnn_train_acc - cnn_val_acc:.2f}%")

In [None]:
# Plot the CustomCNN training curves from the history returned by `train_model`.
plot_training_history(history_cnn)

In [None]:
# Train ResNet18 via transfer learning with its chosen configuration.
# The hyperparameters are declared once and used for both the banner and the
# `train_model` call, so the printed values can never drift from the ones used.
RESNET_EPOCHS = 15
RESNET_LR = 0.0005
RESNET_WEIGHT_DECAY = 1e-4

print("="*70)
print("RESNET18 (TRANSFER LEARNING) - CONFIGURACIÓN ÓPTIMA")
print("="*70)
print(f"Epochs: {RESNET_EPOCHS} | LR: {RESNET_LR:g} | Weight Decay: {RESNET_WEIGHT_DECAY:g}")
print("="*70)

# Built with pretrained=True and freeze_layers=True (see `get_resnet18` for what is frozen).
model_resnet = get_resnet18(num_classes=len(class_names), pretrained=True, freeze_layers=True)

history_resnet = train_model(
    model_resnet,
    train_loader,
    val_loader,
    epochs=RESNET_EPOCHS,
    lr=RESNET_LR,
    weight_decay=RESNET_WEIGHT_DECAY,
    device=DEVICE
)

# Summarise the metrics of the LAST epoch (not necessarily the best one).
resnet_train_acc = history_resnet['train_acc'][-1]
resnet_val_acc = history_resnet['val_acc'][-1]
resnet_val_loss = history_resnet['val_loss'][-1]

print("\nResultados ResNet18:")
print(f"  Train Accuracy: {resnet_train_acc:.2f}%")
print(f"  Val Accuracy: {resnet_val_acc:.2f}%")
print(f"  Val Loss: {resnet_val_loss:.4f}")
# "Overfitting" here is simply the train-minus-validation accuracy gap.
print(f"  Overfitting: {resnet_train_acc - resnet_val_acc:.2f}%")

In [None]:
# Plot the ResNet18 training curves from the history returned by `train_model`.
plot_training_history(history_resnet)

In [None]:
# Side-by-side comparison of train vs. validation accuracy for both models.
# (`plt` is already imported in the notebook's first cell, so it is not re-imported here.)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
fig.suptitle('Comparación: CustomCNN vs ResNet18', fontsize=16, fontweight='bold')

# One (panel title, training history) pair per subplot, left to right.
comparison_panels = [
    ('Custom CNN\n(epochs=30, lr=0.001)', history_cnn),
    ('ResNet18 (Transfer Learning)\n(epochs=15, lr=0.0005)', history_resnet),
]

for ax, (panel_title, panel_history) in zip(axes, comparison_panels):
    ax.plot(panel_history['train_acc'], label='Train Accuracy', linewidth=2)
    ax.plot(panel_history['val_acc'], label='Val Accuracy', linewidth=2)
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Evaluate the CustomCNN on the held-out test set.
cnn_banner = "=" * 70
print(cnn_banner, "EVALUACIÓN: CUSTOM CNN", cnn_banner, sep="\n")
results_cnn = evaluate_model(model_cnn, test_loader, class_names, device=DEVICE)

In [None]:
# Evaluate ResNet18 on the held-out test set.
resnet_banner = "=" * 70
print("\n" + resnet_banner, "EVALUACIÓN: RESNET18", resnet_banner, sep="\n")
results_resnet = evaluate_model(model_resnet, test_loader, class_names, device=DEVICE)

In [None]:
# Final comparison of both models on the test set, then pick the winner.
rule = "=" * 70
cnn_test_acc = results_cnn['accuracy']
resnet_test_acc = results_resnet['accuracy']

print("\n" + rule)
print("COMPARACIÓN FINAL - TEST SET")
print(rule)
print(f"Custom CNN: {cnn_test_acc:.2f}%")
print(f"ResNet18: {resnet_test_acc:.2f}%")
print(f"\nMejora con Transfer Learning: {resnet_test_acc - cnn_test_acc:.2f}%")
print(rule)

# Winner selection: ResNet18 only wins on a strictly higher test accuracy;
# a tie goes to the Custom CNN.
# NOTE(review): the model is chosen AND reported on the same test set, which makes
# the reported accuracy optimistic — consider selecting on validation accuracy.
resnet_wins = resnet_test_acc > cnn_test_acc
winner_name, winner_model, winner_results = (
    ("ResNet18", model_resnet, results_resnet)
    if resnet_wins
    else ("Custom CNN", model_cnn, results_cnn)
)

print(f"\nMODELO GANADOR: {winner_name}")
print(f"Test Accuracy: {winner_results['accuracy']:.2f}%")

In [None]:
# Confusion matrix of the winning model, built from the labels and predictions
# stored in `winner_results`. This cell used to be followed by an exact copy of
# itself that rendered the same figure twice; the two cells are merged into one.
plot_confusion_matrix(
    winner_results['labels'],
    winner_results['predictions'],
    class_names,
    figsize=(10, 8)
)

In [None]:
# Predict on one randomly chosen image from the TEST split.
TOP_K = 5

test_subset = test_loader.dataset
# `predict_from_dataset` keeps receiving the object wrapped by the test split, as before.
test_dataset = test_subset.dataset

# If the test split is a `torch.utils.data.Subset` (or anything else exposing `.indices`),
# the wrapped dataset also holds the train/validation images, so the index must be drawn
# from the split's own indices. Sampling over `len(test_dataset)` would otherwise be able
# to pick an image the model was trained on.
# NOTE(review): assumes `get_dataloaders` builds the splits as Subset-like objects — confirm.
subset_indices = getattr(test_subset, "indices", None)
if subset_indices is not None:
    random_idx = int(subset_indices[random.randrange(len(subset_indices))])
else:
    # No index mapping available: every position of the wrapped object is used as-is.
    random_idx = random.randrange(len(test_dataset))

predictions, true_label, image = predict_from_dataset(
    test_dataset,
    winner_model,
    class_names,
    random_idx,
    device=DEVICE,
    top_k=TOP_K
)

print(f"True Label: {true_label}")
print(f"\nTop {TOP_K} predictions:")
for i, (class_name, prob) in enumerate(predictions, 1):
    print(f"{i}. {class_name}: {prob:.2f}%")

In [None]:
# Visualize the prediction: the image, its true label, and only the first three
# entries of `predictions`.
visualize_prediction_from_dataset(image, predictions[:3], true_label)

In [None]:
# Save the winning model's weights (state_dict only, not the full model object).
MODEL_PATH = '../models/best_model.pth'

# `torch.save` does not create missing parent directories, so make sure the
# target folder exists before writing (a fresh checkout may not have it).
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(winner_model.state_dict(), MODEL_PATH)

print("="*70)
print("MODELO GUARDADO")
print("="*70)
print(f"Modelo: {winner_name}")
print(f"Path: {MODEL_PATH}")
print(f"Test Accuracy: {winner_results['accuracy']:.2f}%")
print("="*70)