# Derivations of biases using sympy

This notebook presents some of the paper's derivations using sympy, which may help readers to confirm the derivations.

In [None]:
# Copyright (c) 2025 Graphcore Ltd. All rights reserved.
%load_ext autoreload
%autoreload 2
from sympy import (
    Rational,
    symbols,
    summation,
    simplify,
    expand,
    Sum,
    Integral,
    Piecewise,
    Heaviside
)
import sympy

# Pattern wildcards used with sympy's match()/replace() in later cells
_1, _2, _3 = (sympy.Wild(name) for name in ("_1", "_2", "_3"))


# Write eq("a",b) to display "a = b"
def eq(nm: str, expr):
    """Return the sympy equation ``nm = expr`` for display.

    Args:
        nm: Text used as the name of the symbol on the left-hand side.
        expr: Expression placed on the right-hand side.

    Returns:
        ``sympy.Eq(sympy.Symbol(nm), expr)``.
    """
    label = sympy.Symbol(nm)
    return sympy.Eq(label, expr)

# Sympy constants
# Exact rational constants, so expressions never pick up Python floats
one, half = Rational(1), Rational(1, 2)

# Sympy pretty printing: Prettier than srepr, more truthful than repr
import pprint

def _pprint_expr(self, x, stream, indent, allowance, context, level):
    """Pretty-print a sympy expression as ``FuncName(arg, arg, ...)``.

    Modified from pprint dict https://github.com/python/cpython/blob/3.7/Lib/pprint.py#L194

    Expressions without ``args`` are written with ``repr``. Otherwise the
    name of ``x.func`` is written, followed by the arguments laid out by the
    printer's own ``_format_items`` so that wrapped lines align after the
    opening parenthesis.
    """
    if not x.args:
        stream.write(repr(x))
        return

    head = x.func.__name__
    stream.write(head + "(")
    # Continuation lines are indented past the function name.
    self._format_items(
        x.args, stream, indent + len(head), allowance + 1, context, level
    )
    stream.write(")")


# Register the handler under the __repr__ of sympy expressions.
pprint.PrettyPrinter._dispatch[sympy.Expr.__repr__] = _pprint_expr

# Example: a narrow width forces wrapping, which shows the indentation
from pprint import pp

x, i, j = symbols("x i j")
ex = Heaviside(i + j - i**2)
# Reuse the step function as summand and as both upper limits
ex = Sum(ex, (i, j, ex), (j, 0, ex))
pp(ex, width=40)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Sum(Heaviside(-i**2 + i + j),
    (i, j, Heaviside(-i**2 + i + j)),
    (j, 0, Heaviside(-i**2 + i + j)))


# Case 1: Infinite precision inputs, limited-precision random variables

## Computing the bias: SRFF

Let $N$ be the number of SR bits, and $0 \le n < 2^N$ the integer value of the supplied bits.

The bias averaged over the range $x\in(0,1)$, corresponding to the range between two successive floats is
$$\begin{darray}{rcl}
2^{-N}\sum_{n=0}^{2^N-1} \int_0^1 \mathbb1[RoundAway_{SRFF}(x,n)] \mathrm dx - \frac12& = &
2^{-N}\sum_{n=0}^{2^N-1} \int_0^1 \mathbb1[x + n\times2^{-N} \ge 1] \mathrm dx - \frac12\\
& = & 2^{-N}\sum_{n=0}^{2^N-1} \int_{1 - n\times2^{-N}}^1 \mathrm dx - \frac12\\
& = & -2^{-(N+1)}
\end{darray}$$
which we can show in sympy...

In [2]:
# Define variables
x, n = symbols("x n", real=True, positive=True)
N = symbols("N", integer=True, positive=True)

# Bias for SRFF: average over the 2**N values of n of the measure of
# {x in (0, 1) : x + n * 2**-N >= 1}, minus the ideal value 1/2
bias_n = Integral(1, (x, 1 - n * 2 ** (-N), 1))
bias = 2**-N * summation(bias_n, (n, 0, 2**N - 1)) - half
display(eq("Bias", bias))

# Evaluate the integral and the sum once, then reuse the result for both
# the display and the check (the evaluation is the expensive step)
bias_evaluated = bias.doit()
display(eq("...", simplify(simplify(expand(bias_evaluated)))))

# Assert expected = computed
expected_bias = -(2 ** (-(N + 1)))
assert expand(bias_evaluated - expected_bias) == 0, "SRFF bias should be -2**-(N+1)"

KeyboardInterrupt: 

So, the bias is nonzero - values are generally rounded towards zero - and the bias reduces to zero as $N \rightarrow \infty$.

## Bias computation: SRF

We can perform the same sort of calculation for the first-order correction, dubbed SRF in the paper:

In [None]:
# Bias for SRF: same as SRFF, but with n shifted by one half
bias_n = Integral(1, (x, 1 - (n + half) * 2 ** (-N), 1))
bias = 2**-N * summation(bias_n, (n, 0, 2**N - 1)) - half
display(eq("Bias", bias))

# Evaluate the integral and the sum once, then reuse the result for both
# the display and the check
bias_evaluated = bias.doit()
display(eq("...", simplify(expand(bias_evaluated))))

# Assert that bias is zero as expected
assert expand(bias_evaluated) == 0, "SRF bias should be exactly zero"

Eq(Bias, -1/2 + Sum(Integral(1, (x, 1 - (n + 1/2)/2**N, 1)), (n, 0, 2**N - 1))/2**N)

Eq(..., 0)


# Case 2: Finite-precision inputs and limited-precision random variables

The above derivations model the inputs `v` as being infinite precision, when in practice they may come from a more limited range (e.g. bfloat16 has P=8).

The following derivation shows that SRF is biased in that case.

## Bias of SRFF, finite-precision inputs

Following the paper, this calculation is in two parts.  We first compute the bias at a single x value:
$$
\begin{darray}{rcl}
bias_{SRFF}(x) &=& 2^{-N}\sum_{n=0}^{2^N-1} \mathbb 1[R_{SRFF}(x,n)] - x\\
&=&2^{-N}\sum_{n=0}^{2^N-1} \mathbb 1[x + n \times 2^{-N} \ge 1] - x
\end{darray}
$$

In [None]:
# Symbols: x and n are positive reals; i, N and D are positive integers
x, n = symbols("x n", real=True, positive=True)
i, N, D = symbols("i N D", integer=True, positive=True)

# Heaviside(v, 1) is given the value 1 at v = 0, so it plays the role of
# the indicator 1[x + n * 2**-N >= 1] from the formula above
rounds_away = Heaviside(x + n * 2**-N - 1, 1)
bias_x = Sum(rounds_away, (n, 0, 2**N - 1)) * 2**-N - x

# Show the unevaluated expression
display(eq("bias_{SRFF}(x)", bias_x))


Eq(bias_{SRFF}(x), -x + Sum(Heaviside(x - 1 + n/2**N, 1), (n, 0, 2**N - 1))/2**N)

Then the expectation over the $2^D$ input values is:
$$
bias_{SRFF,D} = 2^{-D} \sum_{i=0}^{2^D-1} bias_{SRFF}(x := 2^{-D}i)
$$

In [None]:
# Average bias_x over the 2**D input values x = i * 2**-D, i = 0 .. 2**D - 1
i_sum_lo, i_sum_hi = 0, 2**D - 1
bias_at_input = bias_x.subs(x, 2**-D * i)
bias = 2**-D * Sum(bias_at_input, (i, i_sum_lo, i_sum_hi))
display(bias)

# Expand, then simplify twice, and show the result
bias1 = simplify(simplify(expand(bias)))
display(bias1)

Sum(Sum(Heaviside(-1 + n/2**N + i/2**D, 1), (n, 0, 2**N - 1))/2**N - i/2**D, (i, 0, 2**D - 1))/2**D

2**(-D - 1) + 2**(-D - N)*Sum(Heaviside(-1 + n/2**N + i/2**D, 1), (n, 0, 2**N - 1), (i, 0, 2**D - 1)) - 1/2

In [None]:
import re

# def rules(x):
#   subs(Heaviside(x, 1), Piecewise((1, x >= 0), (0, True)))
#   subs(Heaviside(x, 1), Heaviside(x * 2**D))

def swapsum(x):
    """Return a copy of ``x`` with the first two limits of each ``Sum`` exchanged.

    The expression tree is rebuilt recursively. For a ``Sum`` with at least
    two limits, the first and second limits are swapped, provided the bounds
    of neither limit mention the other limit's index variable. Exchanging
    dependent limits would change the value of the sum, so such sums are
    rebuilt with their limits in the original order, as are sums with a
    single limit.

    Args:
        x: A sympy expression.

    Returns:
        The rebuilt expression.
    """
    # Leaves (symbols, numbers) have no args and are returned unchanged.
    if not x.args:
        return x

    # Rebuild the children first so nested sums are handled as well.
    args = tuple(map(swapsum, x.args))

    if x.func == Sum and len(args) >= 3:
        summand, first, second, *rest = args
        first_var, second_var = first.args[0], second.args[0]
        # Only bounds (everything after the index variable) can create a
        # dependency between the two limits.
        first_uses_second = any(
            second_var in bound.free_symbols for bound in first.args[1:]
        )
        second_uses_first = any(
            first_var in bound.free_symbols for bound in second.args[1:]
        )
        if not (first_uses_second or second_uses_first):
            return x.func(summand, second, first, *rest)

    return x.func(*args)

# Start again from `bias`, so this cell does not depend on an earlier bias1
bias1 = simplify(simplify(expand(bias)))
display(eq("expanded", bias1))

# Exchange the order of the two summations
bias1 = swapsum(bias1)
display(bias1)

# H[x] = H[kx] for positive k: scale the Heaviside argument by k = 2**D
scaled_step = Heaviside(_1 * 2**D, _2)
bias1 = bias1.replace(Heaviside(_1, _2), scaled_step)
display(eq(r"replace[H[v] \rightarrow H[2^{D} v]]", bias1))

# Expand and simplify again after the rescaling
bias1 = simplify(simplify(expand(bias1)))
display(eq(r"simplified", bias1))



Eq(expanded, 2**(-D - 1) + 2**(-D - N)*Sum(Heaviside(-1 + n/2**N + i/2**D, 1), (n, 0, 2**N - 1), (i, 0, 2**D - 1)) - 1/2)

2**(-D - 1) + 2**(-D - N)*Sum(Heaviside(-1 + n/2**N + i/2**D, 1), (i, 0, 2**D - 1), (n, 0, 2**N - 1)) - 1/2

Eq(replace[H[v] \rightarrow H[2^{D} v]], 2**(-D - 1) + 2**(-D - N)*Sum(Heaviside(2**D*(-1 + n/2**N + i/2**D), 1), (i, 0, 2**D - 1), (n, 0, 2**N - 1)) - 1/2)

Eq(simplified, 2**(-D - 1) + 2**(-D - N)*Sum(Heaviside(-2**D + 2**(D - N)*n + i, 1), (i, 0, 2**D - 1), (n, 0, 2**N - 1)) - 1/2)

In [None]:

# Now replace sum_i=lo to hi H[i + k]
#             sum_i=lo to hi  {i >= -k}
#        with sum_i=max(lo,-k) to hi of 1
# Can't use sympy replace...
def match(x):
    """Return True when ``x`` has the form ``Sum(Heaviside(i + k, 1), (i, lo, hi), ...)``.

    The shape being recognised is::

        Sum(
            Heaviside(Add(i, _1), 1),
            Tuple(i, lo, hi),
            ...
        )

    The offset ``k`` (the wildcard ``_1``) must not itself contain ``i``.
    Otherwise "i >= -k" is not a simple lower bound on ``i`` and rewriting
    the sum as a sum with a raised lower limit would be wrong.

    Args:
        x: A sympy expression visited by ``replace``.

    Returns:
        bool: whether ``x`` should be rewritten.
    """
    if x.func != Sum:
        return False

    summand = x.args[0].match(Heaviside(i + _1, 1))
    if not summand:
        return False

    if not x.args[1].match((i, _2, _3)):
        return False

    # Reject offsets that depend on the summation index.
    return i not in summand[_1].free_symbols

def replace(x):
    """Rewrite ``Sum(Heaviside(i + k, 1), (i, lo, hi), ...)`` as a sum of ones.

    The result has the shape::

        Sum(
            1,
            Tuple(i, Max(lo, -k), hi),
            ...
        )

    Any further limits of the original sum are carried over unchanged.

    NOTE(review): counting the terms with ``i >= -k`` by raising the lower
    limit to ``-k`` is exact only when ``-k`` is an integer. In this notebook
    ``-k = 2**D - 2**(D - N) * n``, which is an integer for every ``n`` only
    when ``D >= N``. Confirm that assumption is intended.

    Args:
        x: A ``Sum`` for which ``match(x)`` succeeded.

    Returns:
        The rewritten ``Sum``.
    """
    summand, first_limit, *outer_limits = x.args
    offset = summand.match(Heaviside(i + _1, 1))[_1]
    bounds = first_limit.match((i, _2, _3))
    lower = sympy.Max(bounds[_2], -offset)
    return Sum(one, (i, lower, bounds[_3]), *outer_limits)

# Rewrite every matching sum of Heaviside terms as a sum of ones
bias1 = bias1.replace(match, replace)
display(eq("replace sum(H[])", bias1))

# Let sympy evaluate the inner sum of ones
bias1 = simplify(bias1)
display(eq("replace sum(H[])", bias1))


Eq(replace sum(H[]), 2**(-D - 1) + 2**(-D - N)*Sum(1, (i, Max(0, 2**D - 2**(D - N)*n), 2**D - 1), (n, 0, 2**N - 1)) - 1/2)

Eq(replace sum(H[]), 2**(-D - 1) + 2**(-D - N)*Sum(2**D - Max(0, 2**D - 2**(D - N)*n), (n, 0, 2**N - 1)) - 1/2)

In [None]:
# max(a, b) = max(k*a, k*b) / k for positive k; here k = 2**-D
rescaled_max = sympy.Max(_1 * 2**-D, _2 * 2**-D) * 2**D
bias1 = simplify(bias1.replace(sympy.Max(_1, _2), rescaled_max))
display(eq(r"replace: max(a,b) \rightarrow max(ka, kb)/k", bias1))

Eq(replace: max(a,b) \rightarrow max(ka, kb)/k, 2**(-D - 1) - 1/2 + Sum(1 - Max(0, 2**D*(1 - n/2**N))/2**D, (n, 0, 2**N - 1))/2**N)

In [None]:
# Since n < 2^N, the term (1 - 2^-N n) is positive, so Max(0, .) can be
# replaced by its second argument
bias1 = expand(bias1.replace(sympy.Max(0, _1), _1))
display(eq(r"bias_SRFF,D", bias1))


Eq(bias_SRFF,D, -1/(2*2**N) + 1/(2*2**D))