[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gouthamgo/FineTuning/blob/main/lessons/module3_advanced/03_custom_loss_functions.ipynb)

# 🎲 Custom Loss Functions: Get Creative!

**Duration:** 1 hour  
**Level:** Advanced  
**Prerequisites:** Module 2 complete

---

## When Standard Loss Functions Aren't Enough 🎯

Most of the time, `CrossEntropyLoss` works great. But sometimes you need something special:

**Situations where custom loss functions shine:**
- 📊 **Class imbalance** - Some classes way more common than others
- 🎯 **Custom business metrics** - Optimize for what actually matters
- 🤝 **Multi-objective learning** - Balance multiple goals
- 💰 **Cost-sensitive prediction** - Some errors cost more than others
- 🎪 **Ranking/Ordering tasks** - Not just classification

Today we'll build custom loss functions from scratch and see when to use them!

Let's get creative! 🎨

## 🧠 Understanding Loss Functions

### What IS a loss function?

It's how we tell the model "you're doing this wrong, fix it!"

**Simple analogy:**
- Model makes prediction
- Loss function compares to true answer
- Returns a number (higher = worse)
- The optimizer adjusts the model's weights (using the loss gradients) to make that number smaller

**Example - CrossEntropyLoss:**
```python
# True label: class 1 (positive)
# Model predicts (after softmax): [0.2, 0.8] (80% confident it's class 1)
# Loss: low (good prediction!)

# Model predicts (after softmax): [0.9, 0.1] (90% confident it's class 0)
# Loss: high (bad prediction, needs fixing!)
```
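The two cases above can be checked numerically. `CrossEntropyLoss` expects raw logits, so this sketch passes the log of each probability vector as the logits. Because softmax of `log p` returns `p`, the loss reduces to `-log(p_true)`:

```python
import torch
import torch.nn.functional as F

target = torch.tensor([1])  # true label: class 1

# log-probabilities act as logits: softmax(log p) == p when p sums to 1
good_logits = torch.log(torch.tensor([[0.2, 0.8]]))
bad_logits = torch.log(torch.tensor([[0.9, 0.1]]))

good_loss = F.cross_entropy(good_logits, target)  # -log(0.8)
bad_loss = F.cross_entropy(bad_logits, target)    # -log(0.1)

print(f"Good prediction loss: {good_loss.item():.4f}")  # → 0.2231
print(f"Bad prediction loss:  {bad_loss.item():.4f}")   # → 2.3026
```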

Now let's build our own! 💪

In [None]:
!pip install -q torch transformers datasets numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

## 🎯 Custom Loss #1: Weighted Cross-Entropy

**Problem:** You have 1000 negative examples and only 100 positive ones.

A standard loss treats every example equally → the model can reach a low loss by predicting "negative" almost every time (about 91% accuracy) while barely learning the positive class!

**Solution:** Give more weight to the minority class
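How does the weight actually enter the loss? With the default `reduction='mean'`, PyTorch's weighted cross-entropy is a *weighted average*. Each example's loss is multiplied by the weight of its true class, and the sum is divided by the total of those weights, not by the batch size. A small sketch that checks this by hand (the logits and weights are illustrative values):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0], [1.0, 3.0], [0.5, 0.0]])
targets = torch.tensor([0, 1, 1])
class_weights = torch.tensor([1.0, 9.0])  # minority class 1 weighted 9x

# Built-in weighted cross-entropy
builtin = F.cross_entropy(logits, targets, weight=class_weights)

# Same thing by hand: per-example loss, scaled by the weight of its true class
per_example = F.cross_entropy(logits, targets, reduction="none")
w = class_weights[targets]
manual = (w * per_example).sum() / w.sum()  # normalised by total weight, not batch size

print(f"built-in: {builtin.item():.4f}, manual: {manual.item():.4f}")
```

Because of this normalisation, the weights mainly change how hard each example pulls on the gradients. They do much less to change the overall size of the loss.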

In [None]:
class WeightedCrossEntropyLoss(nn.Module):
    """Custom loss that weights classes differently"""
    
    def __init__(self, weights):
        super().__init__()
        self.weights = torch.tensor(weights, dtype=torch.float32)
    
    def forward(self, predictions, targets):
        """
        predictions: model logits [batch_size, num_classes]
        targets: true labels [batch_size]
        """
        # Move weights to same device as predictions
        weights = self.weights.to(predictions.device)
        
        # Use PyTorch's cross entropy with weights
        return F.cross_entropy(predictions, targets, weight=weights)

# Example usage
print("üéØ Weighted Cross-Entropy Example\n")

# Simulate imbalanced data: 90% class 0, 10% class 1
# We want to give class 1 more importance
class_weights = [1.0, 9.0]  # Class 1 is 9x more important

# Create loss function
weighted_loss = WeightedCrossEntropyLoss(class_weights)
standard_loss = nn.CrossEntropyLoss()

# Test predictions
predictions = torch.tensor([[2.0, -1.0], [1.0, 3.0]])  # 2 examples
targets = torch.tensor([0, 1])  # True labels

print(f"Standard loss: {standard_loss(predictions, targets):.4f}")
print(f"Weighted loss: {weighted_loss(predictions, targets):.4f}")
print("\nüí° Weighted loss is higher because we care more about class 1 errors!")

## 🔥 Custom Loss #2: Focal Loss

**Problem:** Easy, well-classified examples (often from the majority class) are so numerous that their small losses add up and drown out the few hard examples

**Solution:** Focal Loss - reduces loss for well-classified examples, focuses on hard cases

**When to use:** Imbalanced classes + want model to focus on hard examples

Focal Loss was introduced by **Facebook AI Research** (Lin et al., 2017) for dense object detection in RetinaNet!
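The paper writes the loss as FL(p_t) = -α_t (1 - p_t)^γ log(p_t). Here p_t is the probability the model assigns to the true class, and α_t is a class weight. The paper defines α_t for the binary case; a per-class vector is a common multi-class generalisation. The class below uses a single scalar `alpha`. The sketch that follows shows a per-class variant, computing p_t as `exp(-CE)` so no separate softmax is needed. The function name and `alpha` values are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def focal_loss_per_class(logits, targets, alpha, gamma=2.0):
    """Focal loss with a per-class alpha vector (a sketch, not a reference implementation)."""
    ce = F.cross_entropy(logits, targets, reduction="none")  # -log(p_t)
    p_t = torch.exp(-ce)                                     # probability of the true class
    alpha_t = alpha[targets]                                 # weight of each example's true class
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

logits = torch.tensor([[3.0, -1.0], [0.5, 0.4], [-1.0, 2.0]])
targets = torch.tensor([0, 0, 1])
alpha = torch.tensor([0.25, 0.75])  # illustrative: rarer class 1 weighted higher

print(f"Focal loss: {focal_loss_per_class(logits, targets, alpha):.4f}")

# Sanity check: with gamma=0 and alpha=1 it reduces to ordinary cross-entropy
plain = focal_loss_per_class(logits, targets, torch.ones(2), gamma=0.0)
print(f"gamma=0: {plain:.4f} vs CE: {F.cross_entropy(logits, targets):.4f}")
```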

In [None]:
class FocalLoss(nn.Module):
    """Focal Loss - focuses on hard examples
    
    Paper: https://arxiv.org/abs/1708.02002
    Used by: Facebook AI Research for object detection
    """
    
    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # Global scale factor (the paper uses a per-class alpha_t)
        self.gamma = gamma  # Focusing parameter (higher = more focus on hard examples)
    
    def forward(self, predictions, targets):
        # Get probabilities
        probs = F.softmax(predictions, dim=-1)
        
        # Get probability of correct class
        batch_size = targets.size(0)
        correct_class_probs = probs[range(batch_size), targets]
        
        # Focal weight: (1 - p)^gamma
        # If p is high (easy example), weight is low
        # If p is low (hard example), weight is high
        focal_weight = (1 - correct_class_probs) ** self.gamma
        
        # Calculate cross entropy
        ce_loss = F.cross_entropy(predictions, targets, reduction='none')
        
        # Apply focal weight
        focal_loss = self.alpha * focal_weight * ce_loss
        
        return focal_loss.mean()

# Demonstrate the difference
print("üî• Focal Loss vs Standard Loss\n")

focal_loss = FocalLoss(alpha=1.0, gamma=2.0)
ce_loss = nn.CrossEntropyLoss()

# Easy example: model is confident and correct
easy_pred = torch.tensor([[3.0, -1.0]])  # ~98% probability on class 0
easy_target = torch.tensor([0])

print("Easy example (model confident & correct):")
print(f"  Standard loss: {ce_loss(easy_pred, easy_target):.2e}")
print(f"  Focal loss: {focal_loss(easy_pred, easy_target):.2e}")
print("  → Focal loss is ~3000x smaller (easy examples barely count)\n")

# Hard example: model is uncertain
hard_pred = torch.tensor([[0.5, 0.4]])  # Uncertain between classes
hard_target = torch.tensor([0])

print("Hard example (model uncertain):")
print(f"  Standard loss: {ce_loss(hard_pred, hard_target):.4f}")
print(f"  Focal loss: {focal_loss(hard_pred, hard_target):.4f}")
print("  → Focal loss is only ~4x smaller, so hard examples now dominate the total\n")

print("💡 This is how Focal Loss helps with imbalanced data!")

## 💰 Custom Loss #3: Cost-Sensitive Loss

**Problem:** Not all errors are equal!

**Example:** Medical diagnosis (illustrative costs)
- False Negative (missing cancer) = VERY BAD ($1M+ cost, could be fatal)
- False Positive (extra test) = Not great but okay ($100 cost)

**Solution:** Custom loss that penalizes expensive errors more
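The class below loops over the batch in Python. The same expected cost, E[cost] = Σ_j p_j · C[y, j], can be computed in one batched expression by indexing the cost matrix with the targets. A sketch using the same illustrative cost matrix as the example below (the function name is hypothetical):

```python
import torch
import torch.nn.functional as F

# Rows = true label, columns = predicted label (illustrative costs)
cost_matrix = torch.tensor([[0.0, 100.0],
                            [10000.0, 0.0]])

def expected_cost_loss(logits, targets, cost_matrix):
    probs = F.softmax(logits, dim=-1)          # [batch, classes]
    costs = cost_matrix[targets]               # row of costs for each true label
    return (probs * costs).sum(dim=-1).mean()  # expected cost, averaged over the batch

logits = torch.tensor([[2.0, -2.0], [-2.0, 2.0]])
targets = torch.tensor([1, 0])  # a missed diagnosis, then a false alarm

loss = expected_cost_loss(logits, targets, cost_matrix)
print(f"Expected cost: {loss.item():.2f}")
```

One design note: this loss is a plain expectation rather than a logarithm. Its gradient therefore shrinks when the softmax saturates, so a confidently wrong prediction gets only a weak correction. That is one reason to consider pairing it with a cross-entropy term rather than using it alone.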

In [None]:
class CostSensitiveLoss(nn.Module):
    """Loss function that considers cost of different errors"""
    
    def __init__(self, cost_matrix):
        """
        cost_matrix: [num_classes, num_classes]
        cost_matrix[i][j] = cost of predicting j when true label is i
        """
        super().__init__()
        self.cost_matrix = torch.tensor(cost_matrix, dtype=torch.float32)
    
    def forward(self, predictions, targets):
        batch_size, num_classes = predictions.shape
        
        # Get probabilities
        probs = F.softmax(predictions, dim=-1)
        
        # Move cost matrix to same device
        cost_matrix = self.cost_matrix.to(predictions.device)
        
        # Calculate expected cost for each example
        losses = []
        for i in range(batch_size):
            true_class = targets[i]
            # Expected cost = sum over j of P(predict j) * cost[true_class][j]
            expected_cost = (probs[i] * cost_matrix[true_class]).sum()
            losses.append(expected_cost)
        
        return torch.stack(losses).mean()

# Medical example
print("üíä Medical Diagnosis Example\n")

# Cost matrix for binary classification (healthy vs sick)
# Rows = true label, Columns = predicted label
cost_matrix = [
    [0,    100],   # True=Healthy, Pred=Healthy (0 cost) or Pred=Sick (100 cost - unnecessary treatment)
    [10000, 0]     # True=Sick, Pred=Healthy (10000 cost - missed diagnosis!) or Pred=Sick (0 cost)
]

cost_loss = CostSensitiveLoss(cost_matrix)
standard_loss = nn.CrossEntropyLoss()

# Case 1: Model predicts healthy for a sick patient (VERY BAD!)
pred_healthy = torch.tensor([[2.0, -2.0]])  # Confident about healthy
true_sick = torch.tensor([1])  # Actually sick

print("Case 1: Predicting healthy for sick patient")
print(f"  Standard loss: {standard_loss(pred_healthy, true_sick):.4f}")
print(f"  Cost-sensitive loss: {cost_loss(pred_healthy, true_sick):.4f}")
print("  → Cost-sensitive loss is MUCH higher (this error is expensive!)\n")

# Case 2: Model predicts sick for a healthy patient (not great but okay)
pred_sick = torch.tensor([[-2.0, 2.0]])  # Confident about sick
true_healthy = torch.tensor([0])  # Actually healthy

print("Case 2: Predicting sick for healthy patient")
print(f"  Standard loss: {standard_loss(pred_sick, true_healthy):.4f}")
print(f"  Cost-sensitive loss: {cost_loss(pred_sick, true_healthy):.4f}")
print("  → Cost-sensitive loss is lower (this error is less expensive)\n")

print("üí° Model learns to be more conservative - better safe than sorry!")

## 🎪 Custom Loss #4: Ranking Loss (Triplet Loss)

**Use case:** You want to learn similarity, not just classify

**Examples:**
- Face recognition (same person = similar, different person = dissimilar)
- Search engines (relevant docs = similar to query)
- Recommendation systems

**How it works:**
- Anchor: your reference point
- Positive: similar example
- Negative: dissimilar example

Goal: Make the anchor closer to the positive than to the negative, by at least a margin
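In symbols, let d be the Euclidean distance and m the margin. The loss for one triplet is then max(0, d(a, p) - d(a, n) + m). PyTorch ships this as `torch.nn.TripletMarginLoss`. A quick sketch comparing the built-in with the formula written out by hand (random embeddings, so the numbers are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
anchor = torch.randn(8, 16)
positive = anchor + 0.5 * torch.randn(8, 16)  # near the anchor
negative = anchor + 1.0 * torch.randn(8, 16)  # farther on average; some triplets may still violate the margin

margin = 1.0
builtin = nn.TripletMarginLoss(margin=margin, p=2)(anchor, positive, negative)

# max(0, d(a, p) - d(a, n) + margin), averaged over the batch
manual = F.relu(
    F.pairwise_distance(anchor, positive) - F.pairwise_distance(anchor, negative) + margin
).mean()

print(f"built-in: {builtin.item():.4f}, by hand: {manual.item():.4f}")
```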

In [None]:
class TripletLoss(nn.Module):
    """Triplet Loss for learning embeddings
    
    Used in: Face recognition, semantic search, recommendations
    """
    
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin  # How much closer should positive be?
    
    def forward(self, anchor, positive, negative):
        """
        anchor: embeddings of anchor examples [batch_size, embedding_dim]
        positive: embeddings of similar examples [batch_size, embedding_dim]
        negative: embeddings of dissimilar examples [batch_size, embedding_dim]
        """
        # Distance between anchor and positive (should be small)
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        
        # Distance between anchor and negative (should be large)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)
        
        # Loss: we want pos_dist < neg_dist - margin
        # If already satisfied, loss = 0
        # If not, loss = how much we need to improve
        loss = F.relu(pos_dist - neg_dist + self.margin)
        
        return loss.mean()

# Example: Learning document similarity
print("üìö Document Similarity Example\n")

triplet_loss = TripletLoss(margin=1.0)

# Simulate embeddings (in reality, these come from your model)
anchor = torch.randn(4, 128)  # 4 anchor documents, 128-dim embeddings
positive = anchor + torch.randn(4, 128) * 0.1  # Similar docs (close to anchor)
negative = torch.randn(4, 128)  # Dissimilar docs (random)

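loss = triplet_loss(anchor, positive, negative)
print(f"Mean anchor-positive distance: {F.pairwise_distance(anchor, positive).mean():.2f}")
print(f"Mean anchor-negative distance: {F.pairwise_distance(anchor, negative).mean():.2f}")
print(f"Triplet loss: {loss.item():.4f}  (likely 0: these random negatives are already far beyond the margin)")

# Visualize the concept
print("\n📊 Visual concept:")
print("\n  Anchor    Positive    Negative")
print("    •  ----  •             ")
print("          (close)          ")
print("    •  --------------------  •")
print("              (far)          ")
print("\n💡 Goal: Keep positive close, push negative far!")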

## 🎯 Custom Loss #5: Multi-Task Loss

**Use case:** Training on multiple tasks simultaneously

**Challenge:** Different tasks have different loss scales!

**Solution:** Combine losses with learned weights
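What do the "learned weights" converge to? For a fixed task loss L, minimizing e^(-s)·L + s over the log-variance s gives e^(-s) = 1/L. Each task therefore ends up weighted by the inverse of its loss. The sketch below optimizes only the log-variances against three fixed, made-up task losses (the same values as the example below). The `MultiTaskLoss` here is a vectorised copy of the same formula, so the block runs on its own:

```python
import torch
import torch.nn as nn

class MultiTaskLoss(nn.Module):
    def __init__(self, num_tasks):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))  # s_i = log sigma_i^2

    def forward(self, losses):
        losses = torch.stack(losses)
        return (torch.exp(-self.log_vars) * losses + self.log_vars).sum()

mtl = MultiTaskLoss(num_tasks=3)
task_losses = [torch.tensor(0.5), torch.tensor(2.0), torch.tensor(0.1)]  # held fixed
optimizer = torch.optim.SGD(mtl.parameters(), lr=0.1)

for _ in range(500):
    optimizer.zero_grad()
    mtl(task_losses).backward()
    optimizer.step()

weights = torch.exp(-mtl.log_vars).detach()
print(f"Learned weights: {weights}")  # approaches 1/L = [2.0, 0.5, 10.0]
```

In real training, the task losses shrink as the model improves, so the weights keep adapting. The sketch freezes the losses only to make the fixed point visible.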

In [None]:
class MultiTaskLoss(nn.Module):
    """Automatically balance multiple task losses
    
    Based on: "Multi-Task Learning Using Uncertainty to Weigh Losses"
    Paper: https://arxiv.org/abs/1705.07115
    """
    
    def __init__(self, num_tasks):
        super().__init__()
        # Learnable log variance for each task
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))
    
    def forward(self, losses):
        """
        losses: list of losses for each task
        """
        total_loss = 0
        
        for i, loss in enumerate(losses):
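            # Learned task weight: precision = exp(-log_var) = 1 / sigma^2
            precision = torch.exp(-self.log_vars[i])
            # The + log_var term penalises large variances, so the model can't shrink every weight to zero
            # (a common simplified form of Kendall et al.'s objective, without the 1/2 factors)
            total_loss += precision * loss + self.log_vars[i]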
        
        return total_loss

# Example with 3 tasks
print("üéØ Multi-Task Loss Example\n")

mtl_loss = MultiTaskLoss(num_tasks=3)

# Simulate losses from 3 different tasks
task1_loss = torch.tensor(0.5)  # Sentiment analysis
task2_loss = torch.tensor(2.0)  # Topic classification
task3_loss = torch.tensor(0.1)  # Length prediction

combined_loss = mtl_loss([task1_loss, task2_loss, task3_loss])

print(f"Task 1 loss: {task1_loss:.4f}")
print(f"Task 2 loss: {task2_loss:.4f}")
print(f"Task 3 loss: {task3_loss:.4f}")
print(f"\nCombined loss: {combined_loss:.4f}")
print(f"\nTask weights at initialization: {torch.exp(-mtl_loss.log_vars).detach()}")
print("\n💡 All weights start at 1.0. Add mtl_loss.parameters() to your optimizer so they are learned during training!")

## 🛠️ How to Use Custom Loss in Training

Super easy with the Hugging Face Trainer!
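Every loss in this notebook is an ordinary `nn.Module`. Outside of `Trainer`, it drops into a plain PyTorch loop exactly where `nn.CrossEntropyLoss()` would go. Below is a minimal sketch on made-up, imbalanced toy data. The focal loss is re-declared in an equivalent form so the block runs on its own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Equivalent to the FocalLoss above: alpha * (1 - p_t)^gamma * CE, averaged over the batch."""
    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        p_t = torch.exp(-ce)
        return (self.alpha * (1 - p_t) ** self.gamma * ce).mean()

# Toy imbalanced data: 2-D points, roughly 10% in class 1
torch.manual_seed(0)
x = torch.randn(200, 2)
y = (x[:, 0] + x[:, 1] > 1.8).long()

model = nn.Linear(2, 2)
loss_fn = FocalLoss(gamma=2.0)  # drop-in replacement for nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # logits [200, 2], labels [200]
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```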

In [None]:
from transformers import Trainer, TrainingArguments

class CustomLossTrainer(Trainer):
    """Trainer with custom loss function"""
    
    def __init__(self, *args, custom_loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_loss_fn = custom_loss_fn
    
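    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions
        if self.custom_loss_fn is None:
            # No custom loss: fall back to the Trainer's default behaviour
            return super().compute_loss(model, inputs, return_outputs=return_outputs, **kwargs)
        
        # Remove labels so the model doesn't compute its own loss
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        
        # Apply the custom loss to the raw logits
        loss = self.custom_loss_fn(outputs.logits, labels)
        
        return (loss, outputs) if return_outputs else loss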

# Example usage
print("\nüìù How to use in practice:\n")
print("""
# Create your custom loss
focal_loss = FocalLoss(alpha=1.0, gamma=2.0)

# Use it in training
trainer = CustomLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    custom_loss_fn=focal_loss  # ← Your custom loss!
)

trainer.train()
""")

print("✅ That's it! Model now trains with your custom loss!")

## 🎓 When to Use Which Loss?

| Problem | Loss Function | Why? |
|---------|--------------|------|
| **Imbalanced classes** | Weighted CE or Focal Loss | Prevents model from ignoring minority class |
| **Hard examples drowned out by easy ones** | Focal Loss | Down-weights easy examples (an alternative to hard negative mining) |
| **Cost-sensitive errors** | Custom Cost Matrix | Different errors have different impacts |
| **Learning similarities** | Triplet Loss | For embeddings/search/recommendations |
| **Multiple tasks** | Multi-Task Loss | Automatically balances task importance |
| **Ranking problems** | Ranking Loss (e.g., margin ranking) | When order matters more than class |
| **Outlier robustness (regression)** | Huber Loss | Quadratic (MSE-like) for small errors, linear (MAE-like) for large ones |
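The last two rows have no example earlier in the notebook. Both losses are built into PyTorch as `nn.HuberLoss` (PyTorch 1.9 or newer assumed) and `nn.MarginRankingLoss`. A small sketch with hand-picked numbers:

```python
import torch
import torch.nn as nn

# Huber loss: quadratic for |error| <= delta, linear beyond it
pred = torch.tensor([0.5, 10.0])  # second prediction is an outlier
target = torch.zeros(2)

mse = nn.MSELoss()(pred, target)               # (0.25 + 100) / 2
huber = nn.HuberLoss(delta=1.0)(pred, target)  # (0.5 * 0.5**2 + 1.0 * (10 - 0.5)) / 2
print(f"MSE: {mse:.4f}, Huber: {huber:.4f}")   # the outlier dominates MSE far more

# Margin ranking loss: want score1 > score2 by at least the margin when y = 1
score1 = torch.tensor([2.0, 0.5])
score2 = torch.tensor([1.0, 1.0])
y = torch.ones(2)  # item 1 should rank above item 2
rank = nn.MarginRankingLoss(margin=0.5)(score1, score2, y)  # mean of max(0, -y * (s1 - s2) + margin)
print(f"Margin ranking loss: {rank:.4f}")
```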

**Default choice:** CrossEntropyLoss (for most classification tasks, it's all you need!)

**When to customize:**
- You have a specific business need (e.g., false negatives are very expensive)
- Standard loss isn't working well
- You're doing something unique (similarity learning, multi-task, etc.)

## 💡 Pro Tips

### 1. **Start Simple**
Always try standard loss first. Only customize if you have a good reason.

### 2. **Monitor Multiple Metrics**
```python
# Don't just look at loss!
metrics = {
    'loss': loss.item(),
    'accuracy': accuracy,
    'f1_score': f1,
    'per_class_accuracy': per_class_acc
}
```
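The `accuracy`, `f1` and `per_class_acc` values above are placeholders. One way to compute per-class numbers with plain PyTorch is to build a confusion matrix. The helper below is hypothetical, and libraries such as scikit-learn offer equivalent functions. Per-class recall is the usual meaning of "per-class accuracy":

```python
import torch

def per_class_metrics(preds, labels, num_classes):
    """Per-class precision, recall and F1 from predicted and true class indices."""
    # confusion[i, j] = number of examples with true class i predicted as j
    confusion = torch.zeros(num_classes, num_classes)
    for t, p in zip(labels.tolist(), preds.tolist()):
        confusion[t, p] += 1
    tp = confusion.diag()
    precision = tp / confusion.sum(dim=0).clamp(min=1)  # column sums = predicted counts
    recall = tp / confusion.sum(dim=1).clamp(min=1)     # row sums = true counts
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)
    return precision, recall, f1

# Toy imbalanced example: always predicting class 0 looks fine on accuracy alone
labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
preds = torch.zeros(10, dtype=torch.long)

precision, recall, f1 = per_class_metrics(preds, labels, num_classes=2)
accuracy = (preds == labels).float().mean()
print(f"accuracy: {accuracy:.2f}, per-class recall: {recall.tolist()}, macro F1: {f1.mean():.2f}")
```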

### 3. **Validate on Business Metrics**
```python
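# Example: medical diagnosis (class 1 = sick); toy tensors for illustration
predictions, labels = torch.tensor([0, 1, 0, 0]), torch.tensor([1, 1, 0, 1])
false_negatives = ((predictions == 0) & (labels == 1)).sum().item()
threshold = 1  # acceptable number of missed diagnoses (example value)
if false_negatives > threshold:
    # Increase the weight on the "sick" class, then retrain
    class_weights[1] *= 2.0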
```

### 4. **Experiment with Hyperparameters**
```python
# Try different values
for gamma in [0.5, 1.0, 2.0, 5.0]:
    loss_fn = FocalLoss(gamma=gamma)
    # Train and compare
```

### 5. **Combine Losses**
```python
# Sometimes you want multiple objectives
total_loss = 0.7 * classification_loss + 0.3 * regularization_loss
```

## 🎉 You're Now a Loss Function Expert!

You learned:
- ✅ How loss functions work
- ✅ 5 powerful custom losses (Weighted, Focal, Cost-Sensitive, Triplet, Multi-Task)
- ✅ When to use each one
- ✅ How to implement them in PyTorch
- ✅ How to use them with Hugging Face Trainer

**Real-world applications:**
- Medical diagnosis (cost-sensitive)
- Fraud detection (focal loss for rare events)
- Face recognition (triplet loss)
- Multi-task NLP (multi-task loss)

**Interview question you can now answer:**

Q: "Have you ever used custom loss functions?"

A (example answer; swap in your own project and measured numbers): "Yes! In my customer support bot project, I used weighted cross-entropy because we had imbalanced categories - 80% billing questions vs 20% technical. By weighting the technical class higher (weight=4.0), I improved F1 score on technical questions from 0.65 to 0.82 while maintaining overall accuracy."

**That's the kind of answer that gets you hired!** 🚀

---

**Next up:** Real-world projects where you'll use these techniques! 💼

## 📚 Further Reading

**Papers:**
- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
- [Multi-Task Learning Using Uncertainty to Weigh Losses](https://arxiv.org/abs/1705.07115)
- [FaceNet: A Unified Embedding for Face Recognition](https://arxiv.org/abs/1503.03832) (Triplet Loss)

**Resources:**
- [PyTorch Loss Functions](https://pytorch.org/docs/stable/nn.html#loss-functions)
- [Papers With Code - Loss Functions](https://paperswithcode.com/methods/category/loss-functions)

Now go optimize those losses! 💪