In [1]:
from CherryTreeDataset import CherryTreeDataset
from torchvision import transforms
from funciones_auxiliares import plot_spectra, analyze_image, analyze_tiff_metadata, PATH, crop_central_region, set_seed, seed_worker
from resnet_adapters import adapt_resnet_channels
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
import torchvision.models as models
from torchvision.models import ResNet18_Weights, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score

# Seed the random number generators through the project helper `set_seed`
# (imported from funciones_auxiliares), then create a dedicated, seeded
# torch.Generator named `g` for later use.
set_seed(42)
g = torch.Generator()
# manual_seed returns the generator itself, which is why the notebook
# displays a Generator repr as this cell's output.
g.manual_seed(42)

Semillas aleatorias configuradas a: 42


<torch._C.Generator at 0x7fa8aaf57c70>

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'El dispositivo seleccionado es {device}')

El dispositivo seleccionado es cuda


In [3]:
# Preprocessing: resize every image to 1280x960 (height x width) and convert
# it to a tensor.
transform = transforms.Compose([
    transforms.Resize((1280, 960)),
    #transforms.Lambda(lambda x: crop_central_region(x, center_ratio=0.8)),
    transforms.ToTensor(),
])

# Alternative band selections that can be loaded instead of plain RGB:
#formats = ('RGB.JPG', 'RED.TIF','GRE.TIF','NIR.TIF','REG.TIF')
#formats = ('RGB.JPG','NIR.TIF','REG.TIF')
formats = ('RGB.JPG',)
input_channels = 3
dataset = CherryTreeDataset(PATH, transform=transform, formats=formats, concatenate=True, balance=False)

# Split sample indices by class: label 0 is healthy, anything else is diseased.
healthy_indices = []
disease_indices = []
for i, (_, label) in enumerate(dataset.samples):
    if label == 0:  # Healthy
        healthy_indices.append(i)
    else:  # Disease
        disease_indices.append(i)

# Shuffle the healthy indices before the 80/20 train/validation split.
np.random.shuffle(healthy_indices)

healthy_split = int(0.8 * len(healthy_indices))
disease_split = int(0.8 * len(disease_indices))
train_healthy = healthy_indices[:healthy_split]
test_healthy = healthy_indices[healthy_split:]
train_disease = disease_indices[:disease_split]
test_disease = disease_indices[disease_split:]

# The model is trained on healthy trees only; diseased trees appear
# exclusively in the evaluation subset.
train_dataset = Subset(dataset, train_healthy)
val_dataset = Subset(dataset, test_healthy + test_disease)

# Pass the seeded generator `g` alongside `worker_init_fn`: the generator
# drives both the shuffling order and the base seed handed to each worker, so
# without it the loaders are not reproducible even though seed_worker is set.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4,
                          worker_init_fn=seed_worker, generator=g)
test_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4,
                         worker_init_fn=seed_worker, generator=g)


In [4]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, input_channels):
        super(VariationalAutoencoder, self).__init__()
        
        # Encoder: Reducción progresiva de la dimensionalidad a través de más capas convolucionales
        self.encoder = nn.Sequential(
            # 1280x960 -> 640x480
            nn.Conv2d(input_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 640x480 -> 320x240
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 320x240 -> 160x120
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 160x120 -> 80x60
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 80x60 -> 40x30
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 40x30 -> 20x15
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Tamaño del feature map después del encoder: 20x15 con 512 canales
        self.encoder_output_h = 15
        self.encoder_output_w = 20
        self.encoder_output_channels = 512
        
        # Capa de proyección para reducir dimensionalidad antes de capas fully-connected
        self.projection = nn.Sequential(
            nn.AdaptiveAvgPool2d((8, 8)),  # Reduce a 8x8 independientemente del tamaño de entrada
            nn.Flatten(),
            nn.Linear(512 * 8 * 8, 1024),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Capas para generar la media y el log de la varianza
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_log_var = nn.Linear(1024, latent_dim)
        
        # Capa para pasar del espacio latente a la forma del decoder
        self.fc_decoder = nn.Linear(latent_dim, 512 * 8 * 8)
        
        # Capa de desproyección para recuperar la dimensionalidad antes de deconvoluciones
        self.unprojection = nn.Sequential(
            nn.Linear(512 * 8 * 8, 512 * self.encoder_output_h * self.encoder_output_w),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            # 20x15 -> 40x30
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 40x30 -> 80x60
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 80x60 -> 160x120
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 160x120 -> 320x240
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 320x240 -> 640x480
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 640x480 -> 1280x960
            nn.ConvTranspose2d(16, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

        # Inicialización de pesos
        self._initialize_weights()

    def forward(self, x):
        # Propagación hacia adelante del encoder y decoder
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def encode(self, x):
        # Codificar mediante convoluciones
        x = self.encoder(x)
        
        # Proyectar a menor dimensionalidad
        x = self.projection(x)
        
        # Calcular mu y log_var
        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)

    def reparameterize(self, mu, log_var):
        # Trick de reparametrización
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        # Del espacio latente a la forma para el decoder
        x = self.fc_decoder(z)
        x = x.view(x.size(0), 512, 8, 8)
        
        # Desproyectar a la dimensión original
        x = self.unprojection(x.flatten(1))
        x = x.view(x.size(0), 512, self.encoder_output_h, self.encoder_output_w)
        
        # Decodificar mediante transposición de convoluciones
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_va

    def _initialize_weights(self):
        # Inicialización de los pesos de manera más robusta
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu', generator=torch.Generator().manual_seed(42))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

class VAEAnomalyDetector:
    def __init__(self, input_channels, device='cuda', latent_dim=256,beta=1.0):
        self.device = device
        self.model = VariationalAutoencoder(input_channels).to(device)
        self.input_channels = input_channels
        self.threshold = None
        self.latent_dim = latent_dim
        self.beta = beta # Factor de peso para el término KL

    def loss_function(self, recon_x, x, mu, log_var):
        # Error de reconstrucción (MSE)
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
        
        # Divergencia KL
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Pérdida total
        total_loss = recon_loss + self.beta * kl_loss
        
        return total_loss, recon_loss, kl_loss
    
    def train(self, train_loader, val_loader, epochs=50, lr=1e-3, weight_decay=1e-5, beta_annealing=True):
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        
        # Para almacenar las pérdidas
        train_losses = []
        val_losses = []
        best_val_loss = float('inf')
        
        # Beta annealing: aumentar beta gradualmente
        initial_beta = 0.0
        final_beta = self.beta
        
        for epoch in range(epochs):
            # Actualizar beta si se usa annealing
            if beta_annealing:
                self.beta = initial_beta + (final_beta - initial_beta) * min(1.0, epoch / (epochs / 3))
            
            # ----- Entrenamiento -----
            self.model.train()
            train_loss = 0
            train_recon_loss = 0
            train_kl_loss = 0
            
            for data, _ in train_loader:
                # Solo usar árboles sanos para entrenar
                healthy_mask = _ == 0
                if not any(healthy_mask):
                    continue
                
                data = data[healthy_mask].to(self.device)
                optimizer.zero_grad()
                
                # Forward pass
                recon_batch, mu, log_var = self.model(data)
                
                # Calcular pérdida
                loss, recon, kl = self.loss_function(recon_batch, data, mu, log_var)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                train_recon_loss += recon.item()
                train_kl_loss += kl.item()
            
            n_batches = len(train_loader)
            train_loss /= n_batches
            train_recon_loss /= n_batches
            train_kl_loss /= n_batches
            train_losses.append(train_loss)
            
            # ----- Validación -----
            self.model.eval()
            val_loss = 0
            val_recon_loss = 0
            val_kl_loss = 0
            
            with torch.no_grad():
                for data, _ in val_loader:
                    healthy_mask = _ == 0
                    if not any(healthy_mask):
                        continue
                    
                    data = data[healthy_mask].to(self.device)
                    
                    # Forward pass
                    recon_batch, mu, log_var = self.model(data)
                    
                    # Calcular pérdida
                    loss, recon, kl = self.loss_function(recon_batch, data, mu, log_var)
                    
                    val_loss += loss.item()
                    val_recon_loss += recon.item()
                    val_kl_loss += kl.item()
            
            val_loss /= len(val_loader)
            val_recon_loss /= len(val_loader)
            val_kl_loss /= len(val_loader)
            val_losses.append(val_loss)
            
            # Actualizar scheduler
            scheduler.step(val_loss)
            
            # Guardar el mejor modelo
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_vae_anomaly_detector.pth')
            
            print(f'Epoch {epoch+1}/{epochs}, Beta: {self.beta:.2f}, Train Loss: {train_loss:.4f} '
                  f'(Recon: {train_recon_loss:.4f}, KL: {train_kl_loss:.4f}), '
                  f'Val Loss: {val_loss:.4f} (Recon: {val_recon_loss:.4f}, KL: {val_kl_loss:.4f})')
        
        # Cargar el mejor modelo
        self.model.load_state_dict(torch.load('best_vae_anomaly_detector.pth'))
        return train_losses, val_losses
    
    def anomaly_score(self, x):
        """Calcula la puntuación de anomalía para una imagen"""
        self.model.eval()
        with torch.no_grad():
            x_recon, mu, log_var = self.model(x)
            
            # Error de reconstrucción
            recon_error = F.mse_loss(x_recon, x, reduction='none').mean(dim=[1, 2, 3])
            
            # Divergencia KL
            kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
            
            # Puntuación de anomalía combinada
            score = recon_error + self.beta * kl_div
            
            return score, recon_error, kl_div

    def calculate_threshold(self, val_loader, percentile=95):
        """Calcula el umbral de detección basado en los datos de validación de árboles sanos"""
        self.model.eval()
        anomaly_scores = []
        
        with torch.no_grad():
            for data, labels in val_loader:
                # Solo usar árboles sanos para calibrar el umbral
                healthy_mask = labels == 0
                if not any(healthy_mask):
                    continue
                
                healthy_data = data[healthy_mask].to(self.device)
                scores, _, _ = self.anomaly_score(healthy_data)
                
                # Añadir puntuaciones a la lista
                anomaly_scores.extend(scores.cpu().numpy())
        
        # Establecer umbral en el percentil especificado
        self.threshold = np.percentile(anomaly_scores, percentile)
        print(f'Threshold set to: {self.threshold:.6f} (percentile {percentile})')
        return self.threshold
    
    def predict(self, data_loader):
        """Predice si un árbol está enfermo basándose en la puntuación de anomalía"""
        if self.threshold is None:
            raise ValueError("Threshold must be set before prediction using calculate_threshold")
        
        self.model.eval()
        all_preds = []
        all_labels = []
        all_scores = []
        
        with torch.no_grad():
            for data, labels in data_loader:
                data = data.to(self.device)
                scores, recon_errors, kl_divs = self.anomaly_score(data)
                
                # Clasificar como enfermo si la puntuación supera el umbral
                predictions = (scores > self.threshold).cpu().numpy().astype(int)
                
                all_preds.extend(predictions)
                all_labels.extend(labels.cpu().numpy())
                all_scores.extend(scores.cpu().numpy())
        
        return np.array(all_preds), np.array(all_labels), np.array(all_scores)
    
    def evaluate(self, data_loader):
        """Evalúa el rendimiento del detector de anomalías"""
        predictions, labels, scores = self.predict(data_loader)
        
        # Calcular métricas
        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
        
        # Calcular AUC-ROC
        auc_roc = roc_auc_score(labels, scores)
        
        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'auc_roc': auc_roc
        }
        
        return results, predictions, labels, scores

In [None]:
# Hyperparameters of the VAE anomaly detector.
beta = 1.0
latent_dim = 128

detector = VAEAnomalyDetector(input_channels=input_channels, device=device, latent_dim=latent_dim, beta=beta)

# Fit the VAE; the test loader doubles as the validation set here.
print("Training VAE...")
train_losses, val_losses = detector.train(train_loader, test_loader, epochs=2, beta_annealing=True)

# Calibrate the detection threshold.
# NOTE(review): the threshold is calibrated on the training loader, not on
# held-out healthy data — confirm this is intended.
print("Calculating threshold...")
threshold = detector.calculate_threshold(train_loader)

# Evaluate on the test set.
print("Evaluating model...")
results, predictions, labels, scores = detector.evaluate(test_loader)

# Report every metric with four decimals.
print("\nResults:")
for metric, value in results.items():
    print(f"{metric}: {value:.4f}")