# In [2]:
import torch
import torch.nn.functional as F

def image_cost_function(predictions, targets, task_type='classification', weight=None, *, smooth=1.0):
    """
    General cost function for image-related tasks.

    Args:
        predictions: Tensor of model outputs.
            - 'classification': raw logits, passed straight to ``F.cross_entropy``.
            - 'segmentation': raw logits. A sigmoid is applied here before the
              Dice computation, so do not pass probabilities (they would be
              squashed twice).
            - 'generation': predicted pixel values, compared directly to
              ``targets`` with mean squared error.
        targets: Ground truth.
            - 'classification': class labels as accepted by ``F.cross_entropy``.
            - 'segmentation': a 0/1 mask combined element-wise with the
              predicted probabilities. The examples use the same shape as
              ``predictions``.
            - 'generation': an image with the same shape as ``predictions``.
        task_type: The type of task: 'classification', 'segmentation', or 'generation'.
        weight: Optional per-class weights (e.g. for imbalanced datasets),
            forwarded to ``F.cross_entropy``. Only used for 'classification';
            ignored for the other task types.
        smooth: Additive smoothing term for the Dice loss (keyword-only,
            default 1.0). Keeps the ratio finite when both the predictions and
            the mask are all zeros. Only used for 'segmentation'.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        ValueError: If ``task_type`` is not one of the supported values.
    """
    if task_type == 'classification':
        # cross_entropy applies log-softmax internally, so raw logits are expected.
        return F.cross_entropy(predictions, targets, weight=weight)

    if task_type == 'segmentation':
        # Soft Dice loss: sigmoid maps logits to probabilities in (0, 1).
        probs = torch.sigmoid(predictions)
        # Sums run over every element of the tensors. All samples in the batch
        # are pooled into one Dice score rather than averaged per sample.
        intersection = torch.sum(probs * targets)
        cardinality = torch.sum(probs) + torch.sum(targets)
        return 1 - (2.0 * intersection + smooth) / (cardinality + smooth)

    if task_type == 'generation':
        # Pixel-wise mean squared error (default 'mean' reduction).
        return F.mse_loss(predictions, targets)

    raise ValueError(
        f"Unsupported task type {task_type!r}. "
        "Choose from 'classification', 'segmentation', or 'generation'."
    )

# ---------------------------------------------------------------------------
# Example 1: classification -- two samples, three classes, raw logits.
# ---------------------------------------------------------------------------
predictions_classification = torch.tensor(
    [
        [2.0, 1.0, 0.1],
        [0.5, 2.5, 0.3],
    ],
    requires_grad=True,
)
targets_classification = torch.tensor([0, 1])  # true class index for each sample

classification_loss = image_cost_function(
    predictions_classification,
    targets_classification,
    task_type='classification',
)
print("Classification Loss:", classification_loss.item())

# ---------------------------------------------------------------------------
# Example 2: segmentation -- a 1x1x256x256 map of random logits scored
# against a random 0/1 mask.
# ---------------------------------------------------------------------------
predictions_segmentation, targets_segmentation = (
    torch.randn(1, 1, 256, 256, requires_grad=True),
    torch.randint(0, 2, (1, 1, 256, 256)).float(),
)

segmentation_loss = image_cost_function(
    predictions_segmentation,
    targets_segmentation,
    task_type='segmentation',
)
print("Segmentation Loss:", segmentation_loss.item())

# ---------------------------------------------------------------------------
# Example 3: generation -- a random 1x3x64x64 "predicted" image compared
# against a random "ground truth" image of the same shape.
# ---------------------------------------------------------------------------
predictions_generation, targets_generation = (
    torch.rand(1, 3, 64, 64, requires_grad=True),
    torch.rand(1, 3, 64, 64),
)

generation_loss = image_cost_function(
    predictions_generation,
    targets_generation,
    task_type='generation',
)
print("Generation Loss:", generation_loss.item())


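# Output of the cell above. Unless the RNG was seeded in an earlier cell, the
# segmentation and generation values change from run to run because their
# inputs come from torch.randn / torch.randint / torch.rand.
# Classification Loss: 0.31853973865509033
# Segmentation Loss: 0.5007040500640869
# Generation Loss: 0.16457389295101166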
