In [1]:
%config Completer.use_jedi=False
import collections
from typing import *
import functools as ft
import torch as T
import numpy as np

In [2]:
Shape = Tuple[int, ...]

def bootstrap_scaling_factors(op: Callable[..., T.Tensor], args: Dict[str, Shape],
                              n_reps: int) -> Dict[str, np.ndarray]:
    results = collections.defaultdict(list)
    for _ in range(n_reps):
        inputs = [T.tensor(v, requires_grad=True)
                  if isinstance(v, np.ndarray) else
                  T.randn(v, requires_grad=True)
                  for v in args.values()]
        output = op(*inputs)
        output.backward(T.randn_like(output))
        results["y"].append(float(1 / T.std(output)))
        for arg, input in zip(args, inputs):
            results[f"grad_{arg}"].append(float(1 / T.std(input.grad)))
    return {k: np.array(v) for k, v in results.items()}

def show(factors: Dict[str, np.ndarray]) -> None:
    for k, samples in factors.items():
        samples = np.array(samples)
        confidence = 2 * np.std(samples) / np.sqrt(len(samples) - 1) if 2 <= samples.size else float("NaN")
        print(f"   {k:<10} {np.mean(samples):<8.4g} ± {confidence:.2g}")

In [3]:
for activation in [T.nn.functional.gelu, T.tanh, T.sigmoid]:
    print(f"### {activation.__name__}")
    show(bootstrap_scaling_factors(activation, dict(x=(int(1e6),)), n_reps=100))

### gelu


  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag


   y          1.701    ± 0.00035
   grad_x     1.481    ± 0.00035
### tanh
   y          1.593    ± 0.00013
   grad_x     1.468    ± 0.00027
### sigmoid
   y          4.802    ± 0.00049
   grad_x     4.722    ± 0.00077


In [4]:
print("### relu")
print("Expected:")
show(dict(y=(2 / (1 - 1/np.pi)) ** 0.5, grad_x=2 ** 0.5))
print("Simulated:")
show(bootstrap_scaling_factors(T.nn.functional.relu, dict(x=(int(1e6),)), n_reps=100))

### relu
Expected:
   y          1.713    ± nan
   grad_x     1.414    ± nan
Simulated:
   y          1.713    ± 0.00037
   grad_x     1.414    ± 0.00027


In [5]:
print("### reduce_sum")
N = 512
show(dict(y=N ** -0.5, grad_x=1))
show(bootstrap_scaling_factors(ft.partial(T.sum, dim=1), dict(x=(int(1e4), N)), n_reps=100))

### reduce_sum
   y          0.04419  ± nan
   grad_x     1        ± nan
   y          0.04421  ± 6.8e-05
   grad_x     0.9997   ± 0.0013


In [6]:
print("### weighted_sum")
weight = T.rand(1000, 512)
x = T.randn_like(weight, requires_grad=True)
y = T.sum(weight * x, 1)
y.backward(T.randn_like(y))

with T.no_grad():
    show(dict(
        sy=T.std(T.sum(weight ** 2, 1) ** -0.5 * y),
        sgrad_x=T.std(weight ** -1 * x.grad),
    ))

### weighted_sum
   sy         0.9957   ± nan
   sgrad_x    0.9998   ± nan


In [7]:
print("### matmul")
B, M, N = 128, 256, 512
show(dict(y=M ** -0.5, grad_x=N ** -0.5, grad_w=B ** -0.5))
show(bootstrap_scaling_factors(T.matmul, dict(x=(B, M), w=(M, N)), n_reps=1000))

### matmul
   y          0.0625   ± nan
   grad_x     0.04419  ± nan
   grad_w     0.08839  ± nan
   y          0.0625   ± 2e-05
   grad_x     0.0442   ± 1.5e-05
   grad_w     0.08839  ± 2.8e-05


In [8]:
print("### layer_norm")
B, N = 200, 512
show(dict(y=1, grad_x=1, grad_w=B ** -0.5, grad_b=B ** -0.5))
show(bootstrap_scaling_factors(
    lambda x, w, b: T.nn.functional.layer_norm(x, (N,), w, b),
    dict(x=(B, N), w=np.ones(N, np.float32), b=np.zeros(N, np.float32)), n_reps=100))

### layer_norm
   y          1        ± nan
   grad_x     1        ± nan
   grad_w     0.07071  ± nan
   grad_b     0.07071  ± nan
   y          1        ± 1.2e-08
   grad_x     0.9992   ± 0.00061
   grad_w     0.07089  ± 0.00046
   grad_b     0.07093  ± 0.00047


In [9]:
print("### softmax_ce")
B, S = 1000, 5
show(dict(y=1, grad_x=S/np.sqrt(S-1)))
show(bootstrap_scaling_factors(
    lambda x: T.nn.functional.cross_entropy(x, T.randint(S, size=(B,)), reduction="none"), dict(x=np.zeros((B, S), np.float32)), n_reps=100))

### softmax_ce
   y          1        ± nan
   grad_x     2.5      ± nan
   y          inf      ± nan
   grad_x     2.501    ± 0.01


  x = asanyarray(arr - arrmean)
