# In [1]:
import torch
import torch.nn.functional as F

def image_cost_function(predictions, targets, task_type='classification', weight=None, *, smooth=1.0):
    """
    Compute the training loss for an image-related task.

    The loss is selected by ``task_type``:

    * ``'classification'`` -- cross-entropy via ``F.cross_entropy``.
    * ``'segmentation'`` -- soft Dice loss. ``torch.sigmoid`` is applied to
      ``predictions`` and the overlap is summed over *every* element of the
      tensors (batch included), so one Dice score is produced for the whole
      batch rather than one per sample.
    * ``'generation'`` -- mean squared error via ``F.mse_loss``.

    Args:
        predictions: Model output tensor. For 'classification' and
            'segmentation' this must be raw logits (cross-entropy and the
            sigmoid below both expect un-normalised scores, not
            probabilities). For 'generation' it is compared element-wise
            against ``targets``.
        targets: Ground-truth tensor. For 'classification' it is forwarded
            unchanged to ``F.cross_entropy``. For 'segmentation' it is
            multiplied element-wise with the sigmoid output -- presumably a
            0/1 mask with the same shape as ``predictions``; the shape is not
            checked here, so a mismatch would broadcast silently.
        task_type: One of 'classification', 'segmentation', or 'generation'.
        weight: Optional per-class weights forwarded to ``F.cross_entropy``.
            Only used for 'classification'; ignored by the other task types.
        smooth: Keyword-only additive smoothing term for the Dice loss, added
            to both numerator and denominator. With the default of 1.0 the
            ratio stays defined when both prediction and target sums are
            zero. Only used for 'segmentation'.

    Returns:
        The loss tensor for the selected task (the default 'mean' reduction
        is used for cross-entropy and MSE; the Dice loss is built from full
        sums).

    Raises:
        ValueError: If ``task_type`` is not one of the supported values.
    """
    if task_type == 'classification':
        return F.cross_entropy(predictions, targets, weight=weight)

    if task_type == 'segmentation':
        # Squash logits into (0, 1) so they act as soft mask probabilities.
        # A new local is used instead of rebinding the `predictions` argument.
        probabilities = torch.sigmoid(predictions)
        intersection = torch.sum(probabilities * targets)
        # Not a true set union: this is the Dice denominator |P| + |T|.
        denominator = torch.sum(probabilities) + torch.sum(targets)
        return 1 - (2.0 * intersection + smooth) / (denominator + smooth)

    if task_type == 'generation':
        return F.mse_loss(predictions, targets)

    raise ValueError("Unsupported task type. Choose from 'classification', 'segmentation', or 'generation'.")
