In [1]:
from CherryTreeDataset import CherryTreeDataset
from torchvision import transforms
from funciones_auxiliares import plot_spectra, analyze_image, analyze_tiff_metadata, PATH, crop_central_region, set_seed, seed_worker
from resnet_adapters import adapt_resnet_channels
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score

set_seed(42)
g = torch.Generator()
g.manual_seed(42)

Semillas aleatorias configuradas a: 42


<torch._C.Generator at 0x7f95b9f47b30>

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'El dispositivo seleccionado es {device}')

El dispositivo seleccionado es cuda


In [3]:
# Define las transformaciones si son necesarias
transform = transforms.Compose([
    transforms.Resize((1280,960)),
    #transforms.Lambda(lambda x: crop_central_region(x, center_ratio=0.8)),
    transforms.ToTensor(),
])

#formats = ('RGB.JPG', 'RED.TIF','GRE.TIF','NIR.TIF','REG.TIF')
#formats = ('RGB.JPG','NIR.TIF','REG.TIF')
formats = ('RGB.JPG',)
input_channels = 3
dataset = CherryTreeDataset(PATH, transform=transform, formats = formats, concatenate = True, balance=False)

healthy_indices = []
disease_indices = []
    
for i, (_, label) in enumerate(dataset.samples):
    if label == 0:  # Healthy
        healthy_indices.append(i)
    else:  # Disease
        disease_indices.append(i)
    
    # Dividir los índices de árboles sanos en entrenamiento y validación
np.random.shuffle(healthy_indices)

train_healthy = healthy_indices[:int(0.8 * len(healthy_indices))]
test_healthy = healthy_indices[int(0.8 * len(healthy_indices)):]
train_disease = disease_indices[:int(0.8 * len(disease_indices))]
test_disease = disease_indices[int(0.8 * len(disease_indices)):]
    
    # Los índices de árboles enfermos solo se usan para testing
    # Create datasets
train_dataset = Subset(dataset, train_healthy)
val_dataset = Subset(dataset, test_healthy + test_disease)  # Incluimos enfermos solo para evaluación
    
    # Crear dataloader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, worker_init_fn=seed_worker)
test_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, worker_init_fn=seed_worker)


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

class VariationalAutoencoder(nn.Module):
    def __init__(self, input_channels, latent_dim=128):
        super(VariationalAutoencoder, self).__init__()
        
        self.latent_dim = latent_dim
        
        # Encoder: Reducción progresiva de la dimensionalidad
        self.encoder = nn.Sequential(
            # Primera capa convolucional (1280x960 -> 640x480)
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            
            # Segunda capa convolucional (640x480 -> 320x240)
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Tercera capa convolucional (320x240 -> 160x120)
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # Cuarta capa convolucional (160x120 -> 80x60)
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # Quinta capa convolucional (80x60 -> 40x30)
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # Sexta capa convolucional (40x30 -> 20x15)
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )
        
        # Tamaño del feature map después del encoder
        self.feature_size = (20, 15)
        self.encoder_output_dim = 512 * self.feature_size[0] * self.feature_size[1]
        
        # Capas para calcular mu y logvar (parámetros de la distribución latente)
        self.fc_mu = nn.Linear(self.encoder_output_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.encoder_output_dim, latent_dim)
        
        # Capa para transformar el espacio latente de vuelta a la forma adecuada para el decoder
        self.fc_decoder = nn.Linear(latent_dim, 512 * self.feature_size[0] * self.feature_size[1])
        
        # Decodificador
        self.decoder = nn.Sequential(
            # Primera capa de transposición convolucional (20x15 -> 40x30)
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # Segunda capa de transposición convolucional (40x30 -> 80x60)
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # Tercera capa de transposición convolucional (80x60 -> 160x120)
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            # Cuarta capa de transposición convolucional (160x120 -> 320x240)
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Quinta capa de transposición convolucional (320x240 -> 640x480)
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            
            # Sexta capa de transposición convolucional (640x480 -> 1280x960)
            nn.ConvTranspose2d(16, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Para normalizar la salida entre 0 y 1
        )

    def encode(self, x):
        # Pasa la entrada por el encoder
        x = self.encoder(x)
        # Aplanar para las capas fully connected
        x = x.view(x.size(0), -1)
        # Devuelve mu y logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        # Técnica de reparametrización con mayor estabilidad numérica
        # Aseguramos que logvar esté en un rango razonable
        logvar = torch.clamp(logvar, min=-20, max=20)
        
        # Calculamos la desviación estándar
        std = torch.exp(0.5 * logvar)
        
        # Limitamos también la desviación estándar para evitar valores extremos
        std = torch.clamp(std, min=1e-6, max=10)
        
        # Generamos ruido aleatorio
        eps = torch.randn_like(std)
        
        # Técnica de reparametrización
        z = mu + eps * std
        return z
    
    def decode(self, z):
        # Transforma desde espacio latente a forma adecuada para el decoder
        z = self.fc_decoder(z)
        z = z.view(z.size(0), 512, self.feature_size[0], self.feature_size[1])
        # Pasa por el decoder
        return self.decoder(z)
    
    def forward(self, x):
        # Encode -> reparameterize -> decode
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

class VAEAnomalyDetector:
    def __init__(self, input_channels, device='cuda', latent_dim=128):
        self.device = device
        self.model = VariationalAutoencoder(input_channels, latent_dim).to(device)
        self.input_channels = input_channels
        self.threshold = None
        self.latent_dim = latent_dim
    
    def vae_loss(self, recon_x, x, mu, logvar, kld_weight=0.0001):
        """
        Función de pérdida para VAE: Reconstrucción + KL Divergence
        Con mejor estabilidad numérica
        """
        # Error de reconstrucción (MSE) - usando reducción 'mean' para estabilidad
        recon_loss = F.mse_loss(recon_x, x, reduction='mean') * x.size(0) * x.size(1) * x.size(2) * x.size(3)
        
        # Divergencia KL con clipping para estabilidad
        # Limitamos los valores extremos que podrían causar problemas numéricos
        logvar = torch.clamp(logvar, min=-20, max=20)
        mu = torch.clamp(mu, min=-20, max=20)
        
        # Fórmula KL estándar con mejor estabilidad numérica
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        # Verificamos que no haya NaNs en las pérdidas
        if torch.isnan(recon_loss) or torch.isnan(kld_loss):
            print("¡Alerta! NaN detectado en la función de pérdida")
            print(f"recon_loss: {recon_loss.item() if not torch.isnan(recon_loss) else 'NaN'}")
            print(f"kld_loss: {kld_loss.item() if not torch.isnan(kld_loss) else 'NaN'}")
            print(f"mu min/max: {mu.min().item()}/{mu.max().item()}")
            print(f"logvar min/max: {logvar.min().item()}/{logvar.max().item()}")
            # Proporcionar valores válidos si se detectan NaNs
            if torch.isnan(recon_loss):
                recon_loss = torch.tensor(1.0).to(recon_x.device)
            if torch.isnan(kld_loss):
                kld_loss = torch.tensor(1.0).to(recon_x.device)
        
        # Pérdida total - comenzamos con un peso KL muy bajo
        return recon_loss + kld_weight * kld_loss, recon_loss, kld_loss
    
    def train(self, train_loader, val_loader, epochs=50, lr=1e-4, weight_decay=1e-5, kld_weight=0.0001):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        
        best_val_loss = float('inf')
        train_losses = []
        val_losses = []
        
        for epoch in range(epochs):
            # Entrenamiento (train_loader ya solo contiene árboles sanos)
            self.model.train()
            train_loss = 0
            train_recon_loss = 0
            train_kld_loss = 0
            
            for data, _ in train_loader:
                data = data.to(self.device)
                optimizer.zero_grad()
                
                reconstructed, mu, logvar = self.model(data)
                loss, recon_loss, kld_loss = self.vae_loss(reconstructed, data, mu, logvar, kld_weight)
                
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                train_recon_loss += recon_loss.item()
                train_kld_loss += kld_loss.item()
            
            # Normalizar pérdidas
            train_loss /= len(train_loader.dataset)
            train_recon_loss /= len(train_loader.dataset)
            train_kld_loss /= len(train_loader.dataset)
            train_losses.append(train_loss)
            
            # Validación (solo evaluamos el error de reconstrucción en árboles sanos)
            self.model.eval()
            val_loss = 0
            val_recon_loss = 0
            val_kld_loss = 0
            
            with torch.no_grad():
                for data, labels in val_loader:
                    # Solo evaluar en árboles sanos
                    healthy_mask = labels == 0
                    if not any(healthy_mask):
                        continue
                        
                    data = data[healthy_mask].to(self.device)
                    reconstructed, mu, logvar = self.model(data)
                    
                    loss, recon_loss, kld_loss = self.vae_loss(reconstructed, data, mu, logvar, kld_weight)
                    
                    val_loss += loss.item()
                    val_recon_loss += recon_loss.item()
                    val_kld_loss += kld_loss.item()
            
            # Normalizar pérdidas
            n_healthy = sum((labels.cpu().numpy() == 0).sum() for _, labels in val_loader)
            if n_healthy > 0:
                val_loss /= n_healthy
                val_recon_loss /= n_healthy
                val_kld_loss /= n_healthy
                val_losses.append(val_loss)
                
                # Actualizar el scheduler
                scheduler.step(val_loss)
                
                # Guardar el mejor modelo
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(self.model.state_dict(), 'best_vae.pth')
                    print(f'Modelo guardado época: {epoch+1}')
            
            print(f'Epoch {epoch+1}/{epochs}')
            print(f'  Train Loss: {train_loss:.6f} (Recon: {train_recon_loss:.6f}, KLD: {train_kld_loss:.6f})')
            print(f'  Val Loss: {val_loss:.6f} (Recon: {val_recon_loss:.6f}, KLD: {val_kld_loss:.6f})')
        
        # Cargar el mejor modelo
        self.model.load_state_dict(torch.load('best_vae.pth'))
        return train_losses, val_losses
    
    def calculate_threshold(self, val_loader, percentile=95):
        """Calcula el umbral basado en los datos de validación de árboles sanos"""
        self.model.eval()
        reconstruction_errors = []
        
        with torch.no_grad():
            for data, labels in val_loader:
                # Solo usar árboles sanos para calibrar el umbral
                healthy_mask = labels == 0
                if not any(healthy_mask):
                    continue
                    
                healthy_data = data[healthy_mask].to(self.device)
                reconstructed, _, _ = self.model(healthy_data)
                
                # Calcular error por imagen
                for i in range(healthy_data.size(0)):
                    error = torch.mean((reconstructed[i] - healthy_data[i])**2).item()
                    reconstruction_errors.append(error)
        
        # Establecer umbral en el percentil especificado
        self.threshold = np.percentile(reconstruction_errors, percentile)
        print(f'Threshold set to: {self.threshold:.6f} (percentile {percentile})')
        return self.threshold
    
    def predict(self, data_loader):
        """Predice si un árbol está enfermo basándose en el error de reconstrucción"""
        if self.threshold is None:
            raise ValueError("Threshold must be set before prediction using calculate_threshold")
        
        self.model.eval()
        all_preds = []
        all_labels = []
        all_scores = []
        
        with torch.no_grad():
            for data, labels in data_loader:
                data = data.to(self.device)
                reconstructed, _, _ = self.model(data)
                
                # Calcular error de reconstrucción por imagen
                for i in range(data.size(0)):
                    recon_error = torch.mean((reconstructed[i] - data[i])**2).item()
                    # Si el error supera el umbral, clasificar como enfermo (anomalía)
                    prediction = 1 if recon_error > self.threshold else 0
                    all_preds.append(prediction)
                    all_labels.append(labels[i].item())
                    all_scores.append(recon_error)
        
        return np.array(all_preds), np.array(all_labels), np.array(all_scores)
    
    def evaluate(self, data_loader):
        """Evalúa el rendimiento del detector de anomalías"""
        predictions, labels, scores = self.predict(data_loader)
        
        # Calcular métricas
        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
        
        # Calcular AUC-ROC
        auc_roc = roc_auc_score(labels, scores)
        
        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'auc_roc': auc_roc
        }
        
        return results, predictions, labels, scores
    
    def generate_samples(self, num_samples=1):
        """Genera nuevas muestras desde el espacio latente"""
        self.model.eval()
        with torch.no_grad():
            # Muestrear del espacio latente (distribución normal)
            z = torch.randn(num_samples, self.latent_dim).to(self.device)
            # Decodificar
            samples = self.model.decode(z)
        return samples

In [None]:
beta=1.0
latent_dim=128
detector = VAEAnomalyDetector(
        input_channels=input_channels, 
        device=device, 
        latent_dim=latent_dim,
    )
print("Training VAE...")
train_losses, val_losses = detector.train(
        train_loader, 
        test_loader, 
        epochs=2, 
    )    # Calcular umbral
print("Calculating threshold...")
threshold = detector.calculate_threshold(train_loader) 
    # Evaluar en el conjunto de prueba
print("Evaluating model...")
results, predictions, labels, scores = detector.evaluate(test_loader)
    # Imprimir resultados
print("\nResults:")
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")

Training VAE...
Modelo guardado época: 1
Epoch 1/2
  Train Loss: 5456831.797658 (Recon: 5456810.296956, KLD: 215277.801531)
  Val Loss: 5158923.000000 (Recon: 5158922.003745, KLD: 9713.938817)
Modelo guardado época: 2
Epoch 2/2
  Train Loss: 5039498.792506 (Recon: 5039497.792974, KLD: 9221.306856)
  Val Loss: 4727160.554307 (Recon: 4727159.565543, KLD: 8574.221925)


  self.model.load_state_dict(torch.load('best_vae.pth'))


Calculating threshold...
