# In [None]:
from tqdm import tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import numpy as np

def dynamic_lambda(t, T, schedule='linear'):
    """
    Compute dynamic lambda(t) based on the iteration number and schedule type.

    Every supported schedule starts at 0 for ``t == 1`` and grows
    monotonically towards 0.5 as ``t`` approaches ``T + 1``:

    * ``'linear'``: ``(t - 1) / (2 * T)``
    * ``'log'``:    ``0.5 * log2(1 + (t - 1) / T)``
    * ``'exp'``:    ``0.5 * (2 ** ((t - 1) / T) - 1)``

    Args:
        t (int): Current iteration (1-based index).
        T (int): Total number of attack iterations. Must be non-zero.
        schedule (str): Type of schedule ('linear', 'log', 'exp').

    Returns:
        float: Lambda value for the current iteration.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
        ZeroDivisionError: If ``T`` is 0.
    """
    if schedule == 'linear':
        return (t - 1) / (2 * T)
    elif schedule == 'log':
        # Cast so that every branch returns a plain Python float.
        return float(0.5 * np.log2(1 + (t - 1) / T))
    elif schedule == 'exp':
        # Exponential ramp 2**x - 1 on x in [0, 1]; the previous expression
        # (2 * x - 1) was linear and produced a negative weight at t == 1.
        return 0.5 * (2 ** ((t - 1) / T) - 1)
    else:
        raise ValueError("Unsupported schedule type. Choose from 'linear', 'log', 'exp'.")

def segpgd_attack(model, loader: DataLoader, criterion, linf_bound, num_pgd_steps=10, device="cuda", schedule='linear'):
    """
    Perform SegPGD attack on the segmentation model and evaluate its robustness.

    For every batch a perturbation is drawn uniformly from the l_inf ball of
    radius ``linf_bound`` and refined for ``num_pgd_steps`` signed-gradient
    ascent steps of size ``linf_bound / num_pgd_steps``. At each step the loss
    is ``(1 - lambda_t) * loss_correct + lambda_t * loss_incorrect``, where the
    two terms are ``criterion`` evaluated on the currently correctly /
    incorrectly classified pixels and ``lambda_t`` comes from
    ``dynamic_lambda``. The model is put in eval mode and left there.

    Args:
        model (nn.Module): Trained segmentation model. Called on a 4-D input
            and expected to return logits with the class axis at dim 1.
        loader (DataLoader): DataLoader yielding ``(inputs, labels)`` pairs.
            Labels are presumably class indices of shape (batch_size, H, W).
        criterion (nn.Module): Loss function (e.g., CrossEntropyLoss), called
            as ``criterion(logits[N, C], targets[N])``.
        linf_bound (float): Maximum perturbation (epsilon).
        num_pgd_steps (int): Number of PGD attack steps. With 0 steps only the
            random initial perturbation is evaluated.
        device (str): Device to perform computations on ('cuda' or 'cpu').
        schedule (str): Schedule type for lambda(t).

    Returns:
        float: mIoU (mean Intersection over Union, in percent) after the
        attack, averaged over samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_miou = 0.0
    total_samples = 0
    step_size = linf_bound / num_pgd_steps if num_pgd_steps > 0 else 0.0

    # NOTE: the attack loop must run with autograd enabled; wrapping it in
    # torch.no_grad() would make it impossible to differentiate the loss
    # with respect to the perturbation.
    for vectors, labels in tqdm(loader, desc="SegPGD Attack", total=len(loader)):
        vectors, labels = vectors.to(device), labels.to(device)
        batch_size = vectors.shape[0]
        labels_flat = labels.reshape(-1)  # Shape: (batch_size*H*W)

        # Initialize perturbations randomly within the l_inf ball
        perts = torch.empty_like(vectors).uniform_(-linf_bound, linf_bound)
        perts.requires_grad_(True)

        for t in range(1, num_pgd_steps + 1):
            # Compute model predictions
            outputs = model(vectors + perts)  # Shape: (batch_size, num_classes, H, W)

            # Split pixels into currently correct / incorrect predictions
            preds = torch.argmax(outputs, dim=1)  # Shape: (batch_size, H, W)
            correct = (preds == labels).reshape(-1)
            incorrect = ~correct

            # Compute dynamic lambda(t)
            lambda_t = dynamic_lambda(t, num_pgd_steps, schedule=schedule)

            # Flatten logits to (batch_size*H*W, num_classes)
            logits_flat = outputs.permute(0, 2, 3, 1).reshape(-1, outputs.shape[1])

            # Loss on each pixel group; an empty group contributes zero
            if correct.any():
                loss_correct = criterion(logits_flat[correct], labels_flat[correct])
            else:
                loss_correct = torch.zeros((), device=logits_flat.device)

            if incorrect.any():
                loss_incorrect = criterion(logits_flat[incorrect], labels_flat[incorrect])
            else:
                loss_incorrect = torch.zeros((), device=logits_flat.device)

            # Weighted loss
            loss = (1 - lambda_t) * loss_correct + lambda_t * loss_incorrect

            # Differentiate w.r.t. the perturbation only, so no gradients
            # are accumulated into the model parameters.
            (perts_grad,) = torch.autograd.grad(loss, perts)

            with torch.no_grad():
                # Signed-gradient ascent step, then project back onto the
                # l_inf ball (equivalent to clamping vectors + perts to
                # [vectors - eps, vectors + eps]).
                perts = perts + step_size * perts_grad.sign()
                perts = perts.clamp(-linf_bound, linf_bound)
            perts.requires_grad_(True)

        # After attack iterations, compute mIoU
        with torch.no_grad():
            adversarial_outputs = model(vectors + perts)
            adversarial_preds = torch.argmax(adversarial_outputs, dim=1)  # Shape: (batch_size, H, W)

            # Take the class count from the final logits so this also works
            # when num_pgd_steps == 0 and the attack loop never ran.
            miou = compute_miou(adversarial_preds, labels, num_classes=adversarial_outputs.shape[1])
            total_miou += miou * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("segpgd_attack received an empty loader; no samples to evaluate.")

    average_miou = total_miou / total_samples
    print(f'Average mIoU after SegPGD attack with l_inf_bound={linf_bound}: {average_miou:.2f}%')
    return average_miou

def compute_miou(preds, labels, num_classes):
    """
    Compute mean Intersection over Union (mIoU) for a batch of predictions and labels.

    IoU is computed per class over the whole batch. Classes that appear in
    neither the predictions nor the labels (empty union) are excluded from
    the mean instead of being counted as an IoU of zero.

    Args:
        preds (torch.Tensor): Predicted labels. Shape: (batch_size, H, W)
        labels (torch.Tensor): Ground truth labels. Shape: (batch_size, H, W)
        num_classes (int): Number of classes; class ids 0..num_classes-1 are scored.

    Returns:
        float: mIoU score in percent (0-100). 0.0 if no class in range is
        present in either tensor.
    """
    iou_sum = 0.0
    evaluated_classes = 0
    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (labels == cls)
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            continue  # Class absent from both prediction and ground truth
        intersection = (pred_inds & target_inds).sum().item()
        iou_sum += intersection / union
        evaluated_classes += 1

    if evaluated_classes == 0:
        return 0.0
    # Average only over the classes that were actually evaluated.
    return iou_sum / evaluated_classes * 100