### **Log-Softmax Trick: Numerical Stability in Log-Probabilities**

The **log-softmax trick** is a technique used to improve numerical stability when computing the logarithm of a softmax function. Instead of computing:

$$
\log(\text{softmax}(\mathbf{z}))
$$

directly, we use a numerically stable formulation.

---

### **1. The Problem: Direct Log of Softmax**
Given a vector of logits $ \mathbf{z} = [z_1, z_2, \ldots, z_n] $, the softmax function is:

$$
p_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}
$$

Taking the logarithm:

$$
\log p_i = \log \left( \frac{\exp(z_i)}{\sum_j \exp(z_j)} \right)
$$

Using the logarithm property:

$$
\log p_i = z_i - \log \left( \sum_j \exp(z_j) \right)
$$

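However, computing $ \sum_j \exp(z_j) $ directly can lead to **numerical instability** when the logits are large in magnitude. The exponential function grows (and decays) very fast, which can cause:
- **Overflow**: When logits are large, $ \exp(z_i) $ can exceed the largest representable float. For `float32`, any $ z_i $ above about 88.7 gives `inf`.
- **Underflow**: When logits are very negative, $ \exp(z_i) $ rounds to 0. If every term underflows, the sum is 0 and $ \log(0) = -\infty $. Likewise, computing $ p_i $ first and then taking its log turns any probability that has underflowed to 0 into $ -\infty $.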

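To see both failure modes concretely, here is a minimal pure-Python sketch of the unshifted formula (the helper names `naive_log_softmax` and `try_naive` are ours, for illustration only). Python floats are 64-bit, so the thresholds differ from `float32` (overflow starts near 709.8 rather than 88.7), but the behavior is the same:

```python
import math

def naive_log_softmax(z):
    # Unshifted formula: log p_i = z_i - log(sum_j exp(z_j))
    total = sum(math.exp(zj) for zj in z)
    return [zi - math.log(total) for zi in z]

def try_naive(z):
    # Return the result, or the name of the error the naive formula hits
    try:
        return naive_log_softmax(z)
    except (OverflowError, ValueError) as err:
        return type(err).__name__

print(try_naive([1.0, 2.0, 3.0]))              # moderate logits: works
print(try_naive([1000.0, 1001.0, 1002.0]))     # math.exp(1000.0) raises OverflowError
print(try_naive([-1000.0, -1001.0, -1002.0]))  # every exp rounds to 0.0, so log(0.0) raises ValueError
```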
---

### **2. The Trick: Log-Softmax with a Shift**
To stabilize the computation, we subtract the **maximum logit value** from all logits before applying softmax. Define:

$$
m = \max_j z_j
$$

Then rewrite the softmax function:

$$
p_i = \frac{\exp(z_i - m)}{\sum_j \exp(z_j - m)}
$$
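
The rewrite is valid because the factor $ e^{-m} $ cancels between numerator and denominator:

$$
\frac{\exp(z_i - m)}{\sum_j \exp(z_j - m)}
= \frac{e^{-m} \exp(z_i)}{e^{-m} \sum_j \exp(z_j)}
= \frac{\exp(z_i)}{\sum_j \exp(z_j)}
$$

The identity holds for any constant $ m $; choosing the maximum is what keeps every exponent at or below 0.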

Taking the logarithm:

$$
\log p_i = (z_i - m) - \log \left( \sum_j \exp(z_j - m) \right)
$$

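This transformation **leaves the result mathematically unchanged** but improves numerical stability in floating point because:
1. **Preventing Overflow**: Since $ z_i - m \le 0 $ for every $ i $ (the largest logit becomes 0), each $ \exp(z_i - m) \le 1 $, so exponentiation cannot overflow.
2. **Preventing Underflow**: The largest term is $ \exp(0) = 1 $, so the sum is at least 1 and its logarithm is finite and non-negative. Individual terms may still underflow to 0, but they are negligible in the sum, and $ \log p_i $ is computed as $ (z_i - m) $ minus a finite number rather than as the log of a tiny probability.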

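A minimal pure-Python sketch of the shifted formula (the function name `stable_log_softmax` is hypothetical, not a library API) handles the inputs that break the unshifted version:

```python
import math

def stable_log_softmax(z):
    # log p_i = (z_i - m) - log(sum_j exp(z_j - m)), with m = max_j z_j
    m = max(z)
    shifted = [zi - m for zi in z]                         # every entry is <= 0
    log_sum = math.log(sum(math.exp(s) for s in shifted))  # sum >= 1, so log_sum >= 0
    return [s - log_sum for s in shifted]

# Shifting all logits by a constant does not change the result
for z in ([1000.0, 1001.0, 1002.0], [-1002.0, -1001.0, -1000.0], [0.0, 1.0, 2.0]):
    print([round(v, 4) for v in stable_log_softmax(z)])  # [-2.4076, -1.4076, -0.4076] each time

# The tiny term exp(-1000) underflows to 0.0, but no log(0) is ever taken
print(stable_log_softmax([0.0, -1000.0]))  # [0.0, -1000.0]
```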
---

### **3. Log-Softmax in PyTorch**
PyTorch provides a built-in numerically stable implementation:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1000.0, 1001.0, 1002.0])  # exp() of these overflows float32 if computed naively
log_probs = F.log_softmax(logits, dim=-1)
print(log_probs)
```

Prefer `F.log_softmax` over `torch.log(torch.softmax(logits, dim=-1))`: it computes the log-probabilities directly in the shifted form above, so it never takes the log of a probability that has underflowed to 0.

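The difference matters most when one probability is tiny. In the sketch below (logit values chosen for illustration), the second class has probability about $ e^{-200} \approx 10^{-87} $, far below the smallest positive `float32` value (about $ 1.4 \times 10^{-45} $). `torch.softmax` rounds it to 0, so taking the log yields `-inf`, whereas `F.log_softmax` returns the exact value $ -200 $:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.0, -200.0])

naive = torch.log(torch.softmax(logits, dim=-1))  # p_1 rounds to 0.0, and log(0.0) = -inf
stable = F.log_softmax(logits, dim=-1)            # stays in log space throughout

print("log(softmax):", naive)
print("log_softmax: ", stable)
```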
---

### **4. Summary**
- Computing `log(softmax(x))` naively is unstable: `exp(x)` can overflow for large logits, and probabilities that underflow to 0 become `-inf` after the log.
- The **log-softmax trick** rewrites the equation by subtracting the max logit value before exponentiation.
- PyTorch’s `F.log_softmax()` is numerically stable and should be preferred over `torch.log(torch.softmax())`.

This trick is crucial in deep learning, particularly in **categorical cross-entropy loss** and **policy gradient methods** like PPO.
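
For example, PyTorch's `F.cross_entropy` takes raw logits and, per its documentation, combines `log_softmax` and `nll_loss`, so for a single example it equals the negative log-softmax value of the target class. A small check, using illustrative logits and target:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1]])  # batch of one example, three classes
target = torch.tensor([0])                # index of the correct class

ce = F.cross_entropy(logits, target)                   # mean over the batch (size 1 here)
manual = -F.log_softmax(logits, dim=-1)[0, target[0]]  # -log p_target

print(ce.item(), manual.item())  # both about 0.4170
```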

Running the Section 3 snippet prints the following. Only differences between logits matter, so these are also the log-softmax values of `[0, 1, 2]`:

```
tensor([-2.4076, -1.4076, -0.4076])
```


Because `F.log_softmax` is differentiable, autograd can return the gradient of a chosen action's log-probability with respect to the logits, the quantity that policy-gradient methods backpropagate:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1], requires_grad=True)  # Example logits
log_probs = F.log_softmax(logits, dim=-1)  # Compute log-softmax directly
chosen_action = 0  # Suppose action 0 was chosen

log_prob = log_probs[chosen_action]  # Extract log-probability of chosen action
log_prob.backward()  # Populate logits.grad with d(log_prob)/d(logits)

print("Log-Softmax values:", log_probs)
print("Gradients:", logits.grad)  # Should match (delta_i,a - p_i)
```

This prints:

```
Log-Softmax values: tensor([-0.4170, -1.4170, -2.3170], grad_fn=<LogSoftmaxBackward0>)
Gradients: tensor([ 0.3410, -0.2424, -0.0986])
```

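The expected gradient in the comment follows from differentiating $ \log p_a = z_a - \log \sum_j \exp(z_j) $ from Section 1 with respect to each logit $ z_i $, where $ a $ is the chosen action:

$$
\frac{\partial \log p_a}{\partial z_i}
= \frac{\partial z_a}{\partial z_i} - \frac{\partial}{\partial z_i} \log \left( \sum_j \exp(z_j) \right)
= \delta_{ia} - \frac{\exp(z_i)}{\sum_j \exp(z_j)}
= \delta_{ia} - p_i
$$

With $ a = 0 $ and $ \mathbf{p} = [0.6590, 0.2424, 0.0986] $, this gives $ [0.3410, -0.2424, -0.0986] $, matching the autograd output.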

The same values can be reproduced by hand from the formulas in Section 1, with the gradient written out analytically:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1], requires_grad=True)  # Example logits

# Step 1: Compute softmax probabilities (unshifted formula; fine for small logits like these)
exp_logits = torch.exp(logits)
sum_exp_logits = torch.sum(exp_logits)
probs = exp_logits / sum_exp_logits

# Step 2: Compute log-softmax as z_i - log(sum_j exp(z_j))
log_probs = logits - torch.log(sum_exp_logits)

# Choose an action (action 0); log_prob is the quantity whose gradient Step 3 computes
chosen_action = 0
log_prob = log_probs[chosen_action]

# Step 3: Compute the gradient d(log_prob)/d(logits) analytically as (delta_i,a - p_i)
one_hot = torch.eye(len(probs))[chosen_action]  # delta_i,a as a one-hot vector
grad_manual = one_hot - probs

# Print results
print("Softmax Probabilities:", probs)
print("Log-Softmax Values:", log_probs)
print("Manual Gradients:", grad_manual)
```

This prints:

```
Softmax Probabilities: tensor([0.6590, 0.2424, 0.0986], grad_fn=<DivBackward0>)
Log-Softmax Values: tensor([-0.4170, -1.4170, -2.3170], grad_fn=<SubBackward0>)
Manual Gradients: tensor([ 0.3410, -0.2424, -0.0986], grad_fn=<SubBackward0>)
```

The manual gradients match the autograd gradients computed with `F.log_softmax` above.

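The manual computation uses the unshifted formula, which works for small logits but not for large ones. A sketch of the fix, assuming `torch.logsumexp` (which PyTorch documents as numerically stabilized) in place of `torch.log(torch.sum(torch.exp(...)))`, reusing the logits `[1000, 1001, 1002]` from Section 3:

```python
import torch

logits = torch.tensor([1000.0, 1001.0, 1002.0])

# Unshifted formula from the manual computation: torch.exp(1000.) is inf in float32,
# so the log-sum is inf and every log-probability becomes -inf
naive = logits - torch.log(torch.sum(torch.exp(logits)))

# Stable variant: logsumexp applies the max-shift internally
stable = logits - torch.logsumexp(logits, dim=-1)

print("naive: ", naive)
print("stable:", stable)
```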