In [6]:
from CherryTreeDataset import CherryTreeDataset
from torchvision import transforms
from funciones_auxiliares import plot_spectra, analyze_image, analyze_tiff_metadata, PATH, crop_central_region, set_seed, seed_worker
from resnet_adapters import adapt_resnet_channels
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score

set_seed(42)
g = torch.Generator()
g.manual_seed(42)

Semillas aleatorias configuradas a: 42


<torch._C.Generator at 0x7fbd0f116bb0>

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'El dispositivo seleccionado es {device}')

El dispositivo seleccionado es cuda


In [8]:
# Define las transformaciones si son necesarias
transform = transforms.Compose([
    transforms.Resize((1280,960)),
    #transforms.Lambda(lambda x: crop_central_region(x, center_ratio=0.8)),
    transforms.ToTensor(),
])

formats = ('RGB.JPG', 'RED.TIF','GRE.TIF','NIR.TIF','REG.TIF')
#formats = ('RGB.JPG','NIR.TIF','REG.TIF')
#formats = ('RGB.JPG',)
input_channels = 7
dataset = CherryTreeDataset(PATH, transform=transform, formats = formats, concatenate = True, balance=False)

healthy_indices = []
disease_indices = []
    
for i, (_, label) in enumerate(dataset.samples):
    if label == 0:  # Healthy
        healthy_indices.append(i)
    else:  # Disease
        disease_indices.append(i)
    
    # Dividir los índices de árboles sanos en entrenamiento y validación
np.random.shuffle(healthy_indices)

train_healthy = healthy_indices[:int(0.8 * len(healthy_indices))]
test_healthy = healthy_indices[int(0.8 * len(healthy_indices)):]
train_disease = disease_indices[:int(0.8 * len(disease_indices))]
test_disease = disease_indices[int(0.8 * len(disease_indices)):]
    
    # Los índices de árboles enfermos solo se usan para testing
    # Create datasets
train_dataset = Subset(dataset, train_healthy)
val_dataset = Subset(dataset, test_healthy + test_disease)  # Incluimos enfermos solo para evaluación
    
    # Crear dataloader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, worker_init_fn=seed_worker)
test_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, worker_init_fn=seed_worker)


In [9]:
class ConvAutoencoder(nn.Module):
    def __init__(self, input_channels):
        super(ConvAutoencoder, self).__init__()
        
        # Encoder: Reducción progresiva de la dimensionalidad a través de más capas convolucionales
        self.encoder = nn.Sequential(
            # Primera capa convolucional
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Segunda capa convolucional
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # Tercera capa convolucional
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # Cuarta capa convolucional
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        
        # Decodificador
        self.decoder = nn.Sequential(
            # Primera capa de transposición convolucional
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # Segunda capa de transposición convolucional
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # Tercera capa de transposición convolucional
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Capa final para reconstruir la imagen original
            nn.ConvTranspose2d(32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Para normalizar la salida entre 0 y 1
        )
        # Inicialización de pesos
        self._initialize_weights()

    def forward(self, x):
        # Propagación hacia adelante del encoder y decoder
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def encode(self, x):
        return self.encoder(x)

    def _initialize_weights(self):
        # Inicialización de los pesos de manera más robusta
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu', generator=torch.Generator().manual_seed(42))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

class AnomalyDetector:
    def __init__(self, input_channels, device='cuda', latent_dim=256):
        self.device = device
        self.model = ConvAutoencoder(input_channels).to(device)
        self.input_channels = input_channels
        self.threshold = None
        self.latent_dim = latent_dim
    
    def train(self, train_loader, val_loader, epochs=50, lr=1e-3, weight_decay=1e-5):
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        
        best_val_loss = float('inf')
        train_losses = []
        val_losses = []
        
        for epoch in range(epochs):
            # Entrenamiento (train_loader ya solo contiene árboles sanos)
            self.model.train()
            train_loss = 0
            for data, _ in train_loader:
                data = data.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(data)
                loss = criterion(outputs, data)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            train_losses.append(train_loss)
            
            # Validación (solo evaluamos el error de reconstrucción en árboles sanos)
            self.model.eval()
            val_loss = 0
            with torch.no_grad():
                for data, labels in val_loader:
                    # Solo evaluar en árboles sanos
                    healthy_mask = labels == 0
                    if not any(healthy_mask):
                        continue
                        
                    data = data[healthy_mask].to(self.device)
                    outputs = self.model(data)
                    loss = criterion(outputs, data)
                    val_loss += loss.item()
            
            val_loss /= len(val_loader)
            val_losses.append(val_loss)
            
            # Actualizar el scheduler
            scheduler.step(val_loss)
            
            # Guardar el mejor modelo
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_autoencoder.pth')
                print(f'Modelo guardado epoca: {epoch+1}')
            
            print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')
        
        # Cargar el mejor modelo
        self.model.load_state_dict(torch.load('best_autoencoder.pth'))
        return train_losses, val_losses
    
    def calculate_threshold(self, val_loader, percentile=95):
        """Calcula el umbral de error de reconstrucción basado en los datos de validación de árboles sanos"""
        self.model.eval()
        reconstruction_errors = []
        
        with torch.no_grad():
            for data, labels in val_loader:
                # Solo usar árboles sanos para calibrar el umbral
                healthy_mask = labels == 0
                if not any(healthy_mask):
                    continue
                    
                healthy_data = data[healthy_mask].to(self.device)
                outputs = self.model(healthy_data)
                
                # Calcular error por imagen
                for i in range(healthy_data.size(0)):
                    error = torch.mean((outputs[i] - healthy_data[i])**2).item()
                    reconstruction_errors.append(error)
        
        # Establecer umbral en el percentil especificado
        self.threshold = np.percentile(reconstruction_errors, percentile)
        print(f'Threshold set to: {self.threshold:.6f} (percentile {percentile})')
        return self.threshold
    
    def predict(self, data_loader):
        """Predice si un árbol está enfermo basándose en el error de reconstrucción"""
        if self.threshold is None:
            raise ValueError("Threshold must be set before prediction using calculate_threshold")
        
        self.model.eval()
        all_preds = []
        all_labels = []
        all_scores = []
        
        with torch.no_grad():
            for data, labels in data_loader:
                data = data.to(self.device)
                outputs = self.model(data)
                
                # Calcular error de reconstrucción por imagen
                for i in range(data.size(0)):
                    recon_error = torch.mean((outputs[i] - data[i])**2).item()
                    # Si el error supera el umbral, clasificar como enfermo (anomalía)
                    prediction = 1 if recon_error > self.threshold else 0
                    all_preds.append(prediction)
                    all_labels.append(labels[i].item())
                    all_scores.append(recon_error)
        
        return np.array(all_preds), np.array(all_labels), np.array(all_scores)
    
    def evaluate(self, data_loader):
        """Evalúa el rendimiento del detector de anomalías"""
        predictions, labels, scores = self.predict(data_loader)
        
        # Calcular métricas
        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
        
        # Calcular AUC-ROC
        auc_roc = roc_auc_score(labels, scores)
        
        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'auc_roc': auc_roc
        }
        
        return results, predictions, labels, scores

In [13]:
#detector = AnomalyDetector(input_channels=input_channels, device=device)
print("Training autoencoder...")
train_losses, val_losses = detector.train(train_loader, test_loader, epochs=5) 
    # Calcular umbral
print("Calculating threshold...")
threshold = detector.calculate_threshold(train_loader, percentile=70) 
    # Evaluar en el conjunto de prueba
print("Evaluating model...")
results, predictions, labels, scores = detector.evaluate(test_loader)  
    # Imprimir resultados
print("\nResults:")
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")

Training autoencoder...
Modelo guardado epoca: 1
Epoch 1/5, Train Loss: 0.885666, Val Loss: 0.827089
Epoch 2/5, Train Loss: 0.878936, Val Loss: 0.827873
Modelo guardado epoca: 3
Epoch 3/5, Train Loss: 0.878583, Val Loss: 0.826702
Epoch 4/5, Train Loss: 0.878400, Val Loss: 0.827060
Epoch 5/5, Train Loss: 0.878547, Val Loss: 0.828194
Calculating threshold...


  self.model.load_state_dict(torch.load('best_autoencoder.pth'))


Threshold set to: 0.935662 (percentile 70)
Evaluating model...

Results:
accuracy: 0.7083
precision: 0.1500
recall: 0.6429
f1_score: 0.2432
auc_roc: 0.8032
