In [1]:
import numpy as np

#### Maximum (LogSumExp)

- https://en.wikipedia.org/wiki/Smooth_maximum
- https://en.wikipedia.org/wiki/LogSumExp

In [22]:
# Sample input vector; its largest entry is 9, at index 4.
x = np.array([3, 1, 4, 5, 9, 2, 6])

def maximum_approx(x, k):
  """Smooth approximation of max(x) using a scaled LogSumExp.

  Evaluates log(sum(exp(k * x))) / k. For positive k the result is an upper
  bound on the true maximum and tightens as k grows.

  Args:
    x: array-like of real numbers; all elements are reduced together.
    k: non-zero sharpness factor.

  Returns:
    A NumPy floating-point scalar with the smooth maximum.
  """
  scaled = np.asarray(x) * k
  # Shift by the largest exponent before calling exp so that large values of
  # k * x do not overflow to inf; the shift is added back after the log.
  shift = np.max(scaled) if scaled.size else 0.0
  if not np.isfinite(shift):
    # With inf/nan present the shift itself would produce nan, so fall back
    # to the unshifted formula.
    shift = 0.0
  return (shift + np.log(np.sum(np.exp(scaled - shift)))) / k

def maximum(x):
  """Exact (hard) maximum of x, used as the reference value."""
  largest = np.max(x)
  return largest

# Print the exact maximum, then the LogSumExp approximation for each k.
print(f"ground truth: {maximum(x)}")
for k in [1, 2, 4, 8]:
  print(f"LogSumExp (k={k}): {maximum_approx(x, k=k)}")

ground truth: 9
LogSumExp (k=1): 9.075633077398951
LogSumExp (k=2): 9.00143130092675
LogSumExp (k=4): 9.000001564706889
LogSumExp (k=8): 9.00000000000472


#### Softmax

- softmax 是 onehot(argmax) 的近似？
- 向量里每个值减去最大值之后，最大值变为 0，其余值变为负数，再经过一个 exp（最大值处约为 1，其余值接近 0），得到 softmax；把最大值换成平滑最大值 LogSumExp 时，结果恰好等于 softmax。


In [23]:
# Sample input vector; its largest entry is 9, at index 4.
x = np.array([3, 1, 4, 5, 9, 2, 6])

def softmax(x):
  """Reference softmax: exp(x) normalised so that the entries sum to 1.

  Args:
    x: array-like of real numbers; normalisation runs over all elements.

  Returns:
    NumPy array with the same shape as x.
  """
  values = np.asarray(x)
  if values.size == 0:
    # Nothing to normalise; return an empty float array.
    return np.exp(values)
  # Subtracting the maximum leaves the result mathematically unchanged but
  # keeps exp from overflowing to inf (which would give nan after division).
  exps = np.exp(values - np.max(values))
  return exps / exps.sum()

def softmax_approx(x, k=1):
  """Softmax built from the smooth maximum: exp(x - maximum_approx(x, k)).

  With the default k=1 this equals exp(x) / sum(exp(x)), i.e. the ordinary
  softmax.
  """
  smooth_max = maximum_approx(x, k=k)
  return np.exp(x - smooth_max)
  
# Print the reference softmax next to the LogSumExp-based version.
print(f"ground truth: {softmax(x)}")
print(f"approx softmax: {softmax_approx(x)}")

ground truth: [2.29819079e-03 3.11026302e-04 6.24713027e-03 1.69814607e-02
 9.27156339e-01 8.45457145e-04 4.61603960e-02]
approx softmax: [2.29819079e-03 3.11026302e-04 6.24713027e-03 1.69814607e-02
 9.27156339e-01 8.45457145e-04 4.61603960e-02]


#### Argmax

In [26]:
# Sample input vector; its largest entry is 9, at index 4.
x = np.array([3, 1, 4, 5, 9, 2, 6])

def argmax(x):
  """Exact index of the largest entry of x, used as the reference value."""
  position = np.argmax(x)
  return position

def argmax_approx(x):
  """Soft argmax: the sum of the indices 0..len(x)-1 weighted by softmax_approx(x)."""
  indices = np.arange(len(x))
  weights = softmax_approx(x)
  return np.sum(indices * weights)

# Print the exact argmax next to the softmax-weighted index.
print(f"ground truth: {argmax(x)}")
print(f"approx: {argmax_approx(x)}")

ground truth: 4
approx: 4.053564685884353
