## 🔸 Log-Sum-Exp Trick

### ❓ What is log-sum-exp?

The naive computation of:

$$
\log\left( \sum_{i=1}^{n} e^{z_i} \right)
$$

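can overflow when some $ z_i $ is large: $ e^{1000} \approx 2 \times 10^{434} $ is far beyond the largest float64 value (about $ 1.8 \times 10^{308} $), so it evaluates to `inf` 🔥. Conversely, if every $ z_i $ is very negative (e.g., $ z_i = -1000 $), each $ e^{z_i} $ underflows to $ 0 $ and the log returns $ -\infty $.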
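To see where "large" starts in double precision, here is a quick check with Python's standard `math` module (an illustration added to these notes; the helper name `exp_overflows` is hypothetical). `math.exp` raises `OverflowError` once the result would exceed the largest float64, which happens just above $ \ln(\text{max float64}) \approx 709.78 $.

```python
import math
import sys

# Largest finite float64 and the largest exponent whose exp() still fits.
max_float = sys.float_info.max        # about 1.797e308
max_exponent = math.log(max_float)    # about 709.78

def exp_overflows(v: float) -> bool:
    """Return True if math.exp(v) cannot be represented as a float64."""
    try:
        math.exp(v)
        return False
    except OverflowError:
        return True

for v in (700.0, 709.0, 710.0, 1000.0):
    print(v, exp_overflows(v))
```

NumPy's `np.exp` behaves differently: it does not raise, it returns `inf` and emits a `RuntimeWarning`.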

To stabilize this, we use:

$$
\log\left( \sum_{i=1}^{n} e^{z_i} \right) = m + \log\left( \sum_{i=1}^{n} e^{z_i - m} \right)
$$

where:

$$
m = \max(z_1, z_2, \dots, z_n)
$$

This is the **log-sum-exp trick**.

---

### ✅ Purpose of the Log-Sum-Exp Trick

- **Numerical stability** when taking logs of sums of exponentials
- **Used in**:
  - Softmax denominator (`log(∑ e^z)` in log-softmax)
  - Log-likelihood computations
  - Energy-based models, variational inference
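As an example of the first use case, here is a minimal plain-Python sketch of log-softmax and a cross-entropy loss built on the trick (the names `log_softmax`, `cross_entropy`, and `logits` are hypothetical, chosen for illustration). Since $ \log \mathrm{softmax}(z)_i = z_i - \log\sum_j e^{z_j} $, the loss for a target class never exponentiates a large value.

```python
import math

def logsumexp(z: list[float]) -> float:
    """Stable log(sum(exp(z))) using the max-shift trick."""
    m = max(z)
    return m + math.log(sum(math.exp(zi - m) for zi in z))

def log_softmax(z: list[float]) -> list[float]:
    """log softmax(z)_i = z_i - logsumexp(z)."""
    lse = logsumexp(z)
    return [zi - lse for zi in z]

def cross_entropy(z: list[float], target: int) -> float:
    """Negative log-probability of class `target` under softmax(z)."""
    return logsumexp(z) - z[target]

logits = [1000.0, 1001.0, 1002.0]
print(log_softmax(logits))
print(cross_entropy(logits, target=2))
```

A naive `-math.log(math.exp(z[t]) / sum(...))` would raise `OverflowError` on these logits; the log-space form stays finite.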

---

Let’s walk through **why** this identity holds.

---

### 🔧 Step-by-step Explanation

Let:
- $ z = [z_1, z_2, \dots, z_n] $
- $ m = \max(z_1, z_2, \dots, z_n) $

We want to compute:
$$
\log\left(\sum_{i=1}^{n} e^{z_i} \right)
$$

But this can **overflow** if any $ z_i $ is large (e.g., $ z_i = 1000 \Rightarrow e^{z_i} $ exceeds the float64 range and becomes `inf`).

---

### 💡 Trick: Factor out the largest value

We rewrite the sum by factoring $ e^m $ out of every term, using $ e^{z_i} = e^m \cdot e^{z_i - m} $:

$$
\sum_{i=1}^{n} e^{z_i} = e^m \sum_{i=1}^{n} e^{z_i - m}
$$

Then take the log of both sides:

$$
\log\left( \sum_{i=1}^{n} e^{z_i} \right)
= \log\left( e^m \sum_{i=1}^{n} e^{z_i - m} \right)
$$

Now use the identity:

$$
\log(a \cdot b) = \log(a) + \log(b)
$$

So:

$$
= \log(e^m) + \log\left( \sum_{i=1}^{n} e^{z_i - m} \right)
= m + \log\left( \sum_{i=1}^{n} e^{z_i - m} \right)
$$

🎉 And that's the trick!

---

### ✅ Why it's stable

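- **Before**: Large $ z_i $ → huge $ e^{z_i} $ → overflow to `inf`; very negative $ z_i $ → every $ e^{z_i} $ underflows to $ 0 $ → $ \log 0 = -\infty $
- **After**: $ z_i - m \le 0 $ → $ e^{z_i - m} \le 1 $ → no overflow
- **Also**: the term with $ z_i = m $ contributes $ e^0 = 1 $, so the sum lies in $ [1, n] $ and the log is never taken of $ 0 $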

Even if `z = [1000, 1001, 1002]`, subtracting `m = 1002` gives `[-2, -1, 0]`, which are totally safe to exponentiate.
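Both failure modes, and the fix, side by side in NumPy (a sketch added for illustration; `naive_lse` and `stable_lse` are hypothetical helper names):

```python
import numpy as np

def naive_lse(x):
    """log(sum(exp(x))) evaluated literally."""
    return np.log(np.sum(np.exp(x)))

def stable_lse(x):
    """Max-shifted version: every exponent is <= 0 and the sum is >= 1."""
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

cases = {
    "large": np.array([1000.0, 1001.0, 1002.0]),
    "very negative": np.array([-1000.0, -1001.0, -1002.0]),
}

results = {}
# Silence the overflow and log(0) warnings that the naive version triggers.
with np.errstate(over="ignore", divide="ignore"):
    for name, x in cases.items():
        results[name] = (naive_lse(x), stable_lse(x))

for name, (naive, stable) in results.items():
    print(f"{name:>13}: naive = {naive}, stable = {stable}")
```

The naive column comes out as `inf` and `-inf`, while the shifted version returns finite values in both cases.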

---




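```python
import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])
m = torch.max(z)  # shift by the largest element
# Every exponent z - m is <= 0, so torch.exp cannot overflow.
log_sum_exp_stable = m + torch.log(torch.sum(torch.exp(z - m)))
print(log_sum_exp_stable)
```

```
tensor(1002.4076)
```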


Reference: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/

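```python
import numpy as np

def logsumexp(x):
    """Stable log(sum(exp(x))) over all elements of x."""
    c = x.max()  # shift by the maximum so every exponent is <= 0
    return c + np.log(np.sum(np.exp(x - c)))
```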
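The `logsumexp` above reduces over the whole array. A batched variant that returns one value per row (e.g. one softmax denominator per example) is a common need. A minimal sketch, with a hypothetical helper name `logsumexp_axis`; it also guards slices whose maximum is not finite, since for an all-`-inf` row the shift `x - c` would otherwise compute `-inf - (-inf) = nan`.

```python
import numpy as np

def logsumexp_axis(x, axis=-1):
    """Stable log-sum-exp along one axis (hypothetical helper, for illustration)."""
    x = np.asarray(x, dtype=float)
    c = np.max(x, axis=axis, keepdims=True)
    # If a slice's max is not finite (e.g. the whole row is -inf), x - c
    # would produce nan; shifting by 0 there yields the correct -inf instead.
    c = np.where(np.isfinite(c), c, 0.0)
    with np.errstate(divide="ignore"):  # log(0) -> -inf for all -inf rows
        s = np.log(np.sum(np.exp(x - c), axis=axis, keepdims=True))
    return np.squeeze(s + c, axis=axis)

batch = np.array([[1000.0, 1000.0, 1000.0],
                  [-1000.0, -1000.0, 1000.0],
                  [-np.inf, -np.inf, -np.inf]])
print(logsumexp_axis(batch, axis=1))
```

If SciPy is available, `scipy.special.logsumexp` provides the same reduction, with `axis`, `keepdims`, and optional weights `b`.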

```python
x = np.array([1000, 1000, 1000])
np.exp(x)  # overflows; NumPy also emits "RuntimeWarning: overflow encountered in exp"
```

```
array([inf, inf, inf])
```

```python
print(logsumexp(x))
np.exp(x - logsumexp(x))  # softmax of x: exp(x_i) / sum_j exp(x_j), computed stably
```

```
1001.0986122886682
array([0.33333333, 0.33333333, 0.33333333])
```
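NumPy already ships a stable pairwise primitive, `np.logaddexp(a, b)` $ = \log(e^a + e^b) $, and reducing it over an array gives the full log-sum-exp. A small cross-check against the hand-written function (redefined here so the snippet runs on its own):

```python
import numpy as np

def logsumexp(x):
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))

x = np.array([1000.0, 1000.0, 1000.0])
# np.logaddexp.reduce folds log(e^a + e^b) pairwise over the array.
builtin = np.logaddexp.reduce(x)
print(logsumexp(x), builtin)
```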


```python
x = np.array([-1000, -1000, -1000])
print(np.exp(x))                 # every term underflows to 0
print(logsumexp(x))
print(np.exp(x - logsumexp(x)))  # softmax is still recovered exactly
```

```
[0. 0. 0.]
-998.9013877113318
[0.33333333 0.33333333 0.33333333]
```


```python
x = np.array([-1000, -1000, 1000])
print(np.exp(x))                 # overflow warning; the -1000 terms underflow to 0
print(logsumexp(x))
print(np.exp(x - logsumexp(x)))
```

```
[ 0.  0. inf]
1000.0
[0. 0. 1.]
```
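For contrast, the textbook softmax `np.exp(x) / np.sum(np.exp(x))` breaks on the last two inputs: with `[-1000, -1000, 1000]` the last entry is `inf / inf = nan`, and with `[-1000, -1000, -1000]` every term underflows, so each entry is `0 / 0 = nan`. A minimal sketch comparing it with the log-sum-exp form used in the cells above (`naive_softmax` and `stable_softmax` are hypothetical names):

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / np.sum(e)

def stable_softmax(x):
    # exp(x - logsumexp(x)), with logsumexp computed via the max shift.
    c = np.max(x)
    lse = c + np.log(np.sum(np.exp(x - c)))
    return np.exp(x - lse)

inputs = [np.array([-1000.0, -1000.0, 1000.0]),
          np.array([-1000.0, -1000.0, -1000.0])]

# Silence the overflow and inf/inf, 0/0 warnings from the naive version.
with np.errstate(over="ignore", invalid="ignore"):
    naive = [naive_softmax(x) for x in inputs]
stable = [stable_softmax(x) for x in inputs]

for x, n, s in zip(inputs, naive, stable):
    print(x, n, s)
```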
