# In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import wandb
from torch.optim.lr_scheduler import CyclicLR

class UNetTrainer:
    """Drives training, validation, testing and reporting for a segmentation model.

    The trainer keeps a private snapshot of the weights that achieved the
    lowest validation loss, writes that snapshot to disk, and logs per-epoch
    losses to Weights & Biases (``wandb``).
    """

    def __init__(self, model, device, optimizer, criterion, train_loader, val_loader=None, test_loader=None,
                 lr_step_size=10, lr_gamma=0.1, epochs=10, filename_model = 'best_model.pth', folder_pretrained= 'pretrained_model/'):
        """
        UNet Trainer class that handles training, validation, and testing.

        Args:
        - model: The UNet model to be trained (moved to ``device`` here).
        - device: The computing device (CPU or GPU).
        - optimizer: The optimizer to use (SGD, Adam, etc.).
        - criterion: The loss function, called as ``criterion(outputs, masks)``.
        - train_loader: DataLoader for training, yielding ``(images, masks)``.
        - val_loader: DataLoader for validation (optional).
        - test_loader: DataLoader for testing (optional).
        - lr_step_size: ``step_size_up`` of the CyclicLR scheduler (SGD only).
        - lr_gamma: Stored on the instance; not used by the CyclicLR scheduler.
        - epochs: Number of training epochs.
        - filename_model: File name of the saved checkpoint.
        - folder_pretrained: Prefix (normally a folder ending in '/') that is
          concatenated with ``filename_model`` to form the checkpoint path.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.filename_model = filename_model
        self.folder_pretrained = folder_pretrained
        # Learning rate scheduler only for SGD
        self.scheduler = None
        if isinstance(self.optimizer, optim.SGD):
            self.scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size_up=lr_step_size, mode='triangular')

        self.best_val_loss = float('inf')
        # state_dict() hands back references to the live parameter tensors, so
        # it must be deep-copied to obtain a snapshot that training cannot
        # silently overwrite.
        self.best_model_wts = copy.deepcopy(self.model.state_dict())
        # Save the losses
        self.train_losses = []
        self.val_losses = []

    def _checkpoint_path(self):
        """Returns the checkpoint path, creating its parent directory if missing."""
        path = self.folder_pretrained + self.filename_model
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        return path

    def train(self):
        """Trains the UNet model with validation.

        For every epoch: runs one training pass, optionally validates, steps
        the scheduler (if any), records the losses, checkpoints the model when
        the validation loss improves, and logs to wandb. After the last epoch
        the best validated weights are loaded back into the model; when no
        validation improvement was ever recorded the final weights are kept.
        """
        for epoch in range(self.epochs):
            # Training phase
            self.model.train()
            train_loss = self._train_one_epoch()

            # Validation phase (if applicable). `is not None` avoids calling
            # len() on the loader, which fails for length-less datasets.
            val_loss = self._validate() if self.val_loader is not None else None

            # Print losses and update scheduler if needed
            if val_loss is not None:
                print(f"Epoch [{epoch+1}/{self.epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            else:
                print(f"Epoch [{epoch+1}/{self.epochs}] - Train Loss: {train_loss:.4f}")

            if self.scheduler:
                self.scheduler.step()
                print(f"Learning Rate: {self.scheduler.get_last_lr()[0]:.6f}")

            # Append losses to the lists
            self.train_losses.append(train_loss)
            if val_loss is not None:
                self.val_losses.append(val_loss)

            # Save best model
            if val_loss is not None and val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                # Snapshot (not alias) the weights so later epochs cannot
                # overwrite the recorded best state.
                self.best_model_wts = copy.deepcopy(self.model.state_dict())
                torch.save(self.best_model_wts, self._checkpoint_path())
                print(f"New best model saved with Val Loss: {self.best_val_loss:.4f}")

            # Log to Weights & Biases (wandb)
            wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})

        print("Training completed.")
        if self.best_val_loss < float('inf'):
            self.model.load_state_dict(self.best_model_wts)
            print("Loaded best model weights.")
        else:
            # Without a recorded best, the snapshot is still the initial
            # weights; restoring it would discard the training just done.
            print("No validated best model recorded; keeping final weights.")

    def _train_one_epoch(self):
        """Runs one epoch of training.

        Returns:
        - The mean per-batch training loss (float).

        Raises:
        - ValueError: if the training loader yields no batches.
        """
        running_loss = 0.0
        num_batches = 0
        for images, masks in self.train_loader:
            images, masks = images.to(self.device), masks.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, masks)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("Training DataLoader produced no batches.")
        return running_loss / num_batches

    def _validate(self):
        """Runs validation.

        Puts the model in eval mode and evaluates the criterion on every
        validation batch without tracking gradients.

        Returns:
        - The mean per-batch validation loss (float), or None when the
          validation loader yields no batches.
        """
        self.model.eval()
        running_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for images, masks in self.val_loader:
                images, masks = images.to(self.device), masks.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, masks)
                running_loss += loss.item()
                num_batches += 1

        if num_batches == 0:
            return None
        return running_loss / num_batches

    def plot_losses(self):
        """Plots the training and validation losses recorded by train()."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Training Loss', color='blue', linestyle='-', marker='o')
        # Only draw the validation curve when validation losses were recorded.
        if self.val_losses:
            plt.plot(self.val_losses, label='Validation Loss', color='red', linestyle='-', marker='x')
        plt.title('Training and Validation Losses')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.show()

    def test(self):
        """Runs the model on the test set and returns predictions.

        Returns:
        - A tuple ``(images, preds, targets)`` of CPU tensors, each the
          concatenation along dim 0 of the corresponding per-batch tensors.

        Raises:
        - ValueError: if no test loader was provided or it yields no batches.
        """
        if self.test_loader is None:
            raise ValueError("Test DataLoader not provided.")

        self.model.eval()
        all_images, all_preds, all_targets = [], [], []

        with torch.no_grad():
            for images, masks in self.test_loader:
                images, masks = images.to(self.device), masks.to(self.device)
                outputs = self.model(images)

                all_images.append(images.cpu())
                all_preds.append(outputs.cpu())
                all_targets.append(masks.cpu())

        if not all_images:
            # torch.cat on an empty list would raise an opaque RuntimeError.
            raise ValueError("Test DataLoader produced no batches.")

        return torch.cat(all_images, dim=0), torch.cat(all_preds, dim=0), torch.cat(all_targets, dim=0)

    @staticmethod
    def compute_model_stats(model, input_tensor):
        """
        Computes FLOPs, total parameters, and trainable parameters of the model.

        NOTE(review): FlopCountAnalysis is not imported in the visible part of
        this file; presumably it comes from fvcore.nn - confirm.

        Returns:
        - flop_count: The FlopCountAnalysis object built for ``model`` and
          ``input_tensor`` (not an int; the count itself is not evaluated here).
        - total_params (int): Total number of parameters in the model.
        - trainable_params (int): Number of trainable parameters in the model.
        """
        # Compute FLOPs
        flop_count = FlopCountAnalysis(model, input_tensor)

        # Compute total and trainable parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"Total Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")

        return flop_count, total_params, trainable_params

    @staticmethod
    def visualize_results(images, targets, preds, num_samples=3, save_path=None):
        """Shows (or saves) image / ground truth / prediction triplets.

        Args:
        - images: Indexable collection of image tensors; a leading dimension
          of size 1 or 3 is treated as grayscale or RGB channels.
        - targets: Indexable collection of ground-truth mask tensors.
        - preds: Indexable collection of prediction tensors; binarised at 0.5.
        - num_samples: Maximum number of samples to draw (capped by the
          amount of data available).
        - save_path: If given, figures are written to this PDF file and the
          file is passed to wandb.save; otherwise figures are shown.
        """
        pdf = None
        if save_path is not None:
            # Imported here so the PDF backend is only needed when saving.
            from matplotlib.backends.backend_pdf import PdfPages
            pdf = PdfPages(save_path)

        # Never index past the end of the shortest input.
        count = min(num_samples, len(images), len(targets), len(preds))

        try:
            for i in range(count):
                # detach().cpu() lets tensors that are on GPU or attached to
                # the autograd graph be converted to numpy.
                image = images[i].detach().cpu().numpy()
                target = targets[i].detach().cpu().squeeze(0).numpy()
                pred = preds[i].detach().cpu().squeeze(0).numpy()
                pred_binary = (pred > 0.5).astype(np.uint8)

                # Adjust layout for grayscale or RGB display
                if image.shape[0] == 1:
                    image = image.squeeze(0)
                elif image.shape[0] == 3:
                    image = image.transpose(1, 2, 0)

                fig, axs = plt.subplots(1, 3, figsize=(15, 5))
                axs[0].imshow(image, cmap="gray" if image.ndim == 2 else None)
                axs[0].set_title("Original Image")
                axs[0].axis('off')

                axs[1].imshow(target, cmap="gray")
                axs[1].set_title("Ground Truth")
                axs[1].axis('off')

                axs[2].imshow(pred_binary, cmap="gray")
                axs[2].set_title("Prediction")
                axs[2].axis('off')

                plt.tight_layout()

                if pdf is not None:
                    pdf.savefig(fig)  # Store the figure as a PDF page
                    plt.close(fig)    # Close the figure to free memory
                else:
                    plt.show()
        finally:
            # Close the PDF even if drawing a sample failed.
            if pdf is not None:
                pdf.close()

        if save_path is not None:
            print(f"Resultados guardados en PDF: {save_path}")
            wandb.save(save_path)

    def save_model_wanDB(self):
        """Saves the current model weights to disk and uploads the file to wandb."""
        path = self._checkpoint_path()
        torch.save(self.model.state_dict(), path)
        # Upload the model file to WandB
        wandb.save(path)
        print("Model saved and uploaded to WandB")

    def evaluate_metrics(self, threshold=0.1):
        """
        Evaluates the model and logs the metrics to WandB.

        NOTE(review): SegmentationEvaluator is not defined in the visible part
        of this file; its constructor/evaluate contract should be confirmed.

        Args:
        - threshold: Value forwarded to SegmentationEvaluator.

        Returns:
        - The metrics object returned by ``evaluator.evaluate(visualize=True)``.

        Raises:
        - ValueError: if no test loader was provided.
        """
        if self.test_loader is None:
            raise ValueError("Test DataLoader not provided.")

        evaluator = SegmentationEvaluator(self.model, self.test_loader, threshold, device=self.device)
        metrics = evaluator.evaluate(visualize=True)
        print(metrics)
        wandb.log(metrics)  # Log the metrics to WandB
        return metrics