In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Many architectures have shifts where the right-hand-side is signed. A negative
# RHS is the same as a positive shift in the other direction.
#
# The shift amount goes through 2.0**y (a float power), so negative shift
# amounts work even when y is an integer numpy array; numpy raises
# "Integers to negative integer powers are not allowed" for 2**np.array([-1]).
#
# The rounding variants round ties toward +infinity, i.e. floor(v + 1/2). This
# is what hardware rounding shifts compute as (v + 2^(y-1)) >> y, and how ARM's
# sqrdmulh rounds. np.round would instead round ties to even.
def shift_right(x, y):
  """Arithmetic shift right: floor(x / 2^y). A negative y shifts left."""
  return np.floor(x / 2.0**y)

def shift_left(x, y):
  """Shift left: floor(x * 2^y). A negative y is a truncating shift right."""
  return np.floor(x * 2.0**y)

def rounding_shift_right(x, y):
  """Rounding shift right: floor(x / 2^y + 1/2), with ties rounded toward +inf."""
  return np.floor(x / 2.0**y + 0.5)

def rounding_shift_left(x, y):
  """Rounding shift left: floor(x * 2^y + 1/2). It only rounds when y < 0."""
  return np.floor(x * 2.0**y + 0.5)

def bitwise_and(x, y):
  """Emulate x & y arithmetically, for a low-bit mask y = 2^k - 1 only.

  This matches a true bitwise AND only when y + 1 is a power of two. For such
  masks, taking x modulo 2^k keeps the low k bits. Like two's complement, it
  maps negative x into [0, y].
  """
  modulus = y + 1  # 2^k when y is a mask 2^k - 1
  return np.mod(x, modulus)

# This is sqrdmulh on ARM
def multiply_2x_high(x, y):
  """Q15 fixed-point multiply: rounding_shift_right(x * y, 15).

  This is the 16-bit form of a rounding doubling multiply that keeps the high
  half: (2*x*y) >> 16 == (x*y) >> 15. Unlike the instruction, it does not
  saturate the result.
  """
  product = x * y
  return rounding_shift_right(product, 15)

def relative_error(x, y):
  """Signed relative error (x - y) / max(|x|, |y|).

  The denominator uses magnitudes. With np.maximum(x, y), a negative input
  would flip the sign of the error or shrink the denominator toward the
  1e-3 guard. The 1e-3 term keeps the division finite when both values are 0.
  """
  return (x - y) / (np.maximum(np.abs(x), np.abs(y)) + 1e-3)

def plot_results(x, exact, approxs, title, logx = False, logy = False, relative = False, log2_xscale = 0, log2_yscale = 0):
  """Plot fixed-point approximations against a reference, plus their error.

  Draws two stacked axes. The top one shows `exact` and every array in
  `approxs`. The bottom one shows each approximation's error: relative when
  `relative` is set, otherwise absolute.

  x is divided by 2^log2_xscale, and each approximation by 2^log2_yscale,
  before plotting. `exact` is first rounded to the 2^-log2_yscale grid, so it
  sits on the same grid as the approximations.
  logx: log scale on both x axes. logy: log scale on the top y axis only.
  """
  x_unit = 2**log2_xscale
  y_unit = 2**log2_yscale
  x_real = x / x_unit
  exact = np.round(exact * y_unit) / y_unit
  approx_reals = [approx / y_unit for approx in approxs]

  fig, (value_ax, error_ax) = plt.subplots(2, 1)
  for ax in (value_ax, error_ax):
    ax.set_xlabel('x')
    if logx:
      ax.set_xscale('log')

  value_ax.set_ylabel(title)
  if logy:
    value_ax.set_yscale('log')
  value_ax.plot(x_real, exact)
  for approx_real in approx_reals:
    value_ax.plot(x_real, approx_real)

  error_ax.set_ylabel('relative error' if relative else 'error')
  for approx_real in approx_reals:
    if relative:
      error = relative_error(approx_real, exact)
    else:
      error = approx_real - exact
    error_ax.plot(x_real, error)

def eval_poly(x, p, q):
  """Evaluate sum_k p[k] * t^k in fixed point.

  x is t with q fractional bits. p holds the coefficients lowest order first;
  the fit above quantizes them with 15 fractional bits (Q15). x is rescaled to
  Q15, each power and product uses the Q15 multiply multiply_2x_high, and the
  Q15 sum is rescaled back to q fractional bits.
  """
  x_q15 = rounding_shift_left(x, 15 - q)
  acc = p[0]
  power = x_q15
  for k in range(1, len(p)):
    acc = acc + multiply_2x_high(p[k], power)
    # Next power of t. The one computed on the final iteration goes unused.
    power = multiply_2x_high(power, x_q15)
  return rounding_shift_right(acc, 15 - q)

# Fit cubic polynomials to log2(1 + t) and 2^t - 1 for t in [0, 1]. They are
# sampled at points + 1 evenly spaced values x = 1 + t in [1, 2].
# eval_poly expects coefficients lowest order first, while np.polyfit returns
# them highest order first; hence the [::-1].
points = 6
degree = 3

log2_poly_x = np.arange(points, 2 * points + 1) / points
exp2_poly_x = np.arange(points, 2 * points + 1) / points
log2_poly_y = np.log2(log2_poly_x)
exp2_poly_y = np.exp2(exp2_poly_x - 1) - 1

log2_poly = np.polyfit(log2_poly_x - 1, log2_poly_y, degree)[::-1]
exp2_poly = np.polyfit(exp2_poly_x - 1, exp2_poly_y, degree)[::-1]
print(log2_poly)
print(exp2_poly)

# Quantize the coefficients to Q15 (15 fractional bits).
log2_poly = np.round(log2_poly * 2**15)
exp2_poly = np.round(exp2_poly * 2**15)
# 2^0 - 1 is exactly 0, so pin the constant term to 0 rather than keep the
# small fitted offset.
exp2_poly[0] = 0
print(log2_poly)
print(exp2_poly)

In [None]:
# Approximate N*log2(x/2^q_x), where N = 2^q, and the intermediate computations are
# restricted to be integers.
def approx_log2(x, q, q_x = 0):
  """Approximate 2^q * log2(x / 2^q_x) using integer-style operations.

  x holds the raw integers of a fixed-point value with q_x fractional bits, so
  the value is x / 2^q_x. The result has q fractional bits. Non-positive x has
  no logarithm: those lanes use floor(log2(x)) = -1 and yield a finite but
  meaningless result.
  """
  positive = x > 0
  # This can be computed with count_leading_zeros. The np.log2 argument is
  # evaluated on every lane before np.select chooses, so non-positive lanes get
  # a dummy 1 to avoid divide-by-zero/invalid warnings. np.select then replaces
  # those lanes with -1.
  floor_log2_x = np.select([positive], [np.floor(np.log2(np.where(positive, x, 1)))], [-1])

  # floor_log2_x is the integer part of log2(x) = log2(x/2^q_x) + q_x. Subtract
  # that q_x offset now, before scaling by the result quantization.
  result = shift_left(floor_log2_x - q_x, q)

  # Normalize x so its leading 1 sits at bit q, then keep the q bits below it.
  # This gives the mantissa fraction m = x/2^floor_log2_x - 1 (truncated), with
  # q fractional bits.
  frac = bitwise_and(shift_right(x, floor_log2_x - q), 2**q - 1)

  # Add log2(1 + m) from the fitted polynomial, also with q fractional bits.
  return result + eval_poly(frac, log2_poly, q)

q_x = 2
q = 15
x = np.arange(1, 10000)
# Reference: log2 of the real value x / 2^q_x represented by the raw integers.
log2_x = np.log2(x / 2**q_x)
approx_log2_x = approx_log2(x, q, q_x)

plot_results(x, log2_x, [approx_log2_x], 'log2(x)', logx=True, log2_xscale=q_x, log2_yscale=q)

In [None]:

# Approximate 2^(x/2^q_x)*2^q
def approx_exp2(x, q_x, q):
  """Approximate 2^(x / 2^q_x) * 2^q using integer-style operations.

  x is a fixed-point exponent with q_x fractional bits, and the result has q
  fractional bits. The exponent is split as i + f, with i an integer and f in
  [0, 1). Then 2^(i + f) = 2^i * (1 + (2^f - 1)), where 2^f - 1 comes from
  exp2_poly.
  """
  whole = shift_right(x, q_x)
  # The fractional bits of x, in [0, 2^q_x).
  frac_bits = x - shift_left(whole, q_x)
  # 2^f - 1, with q_x fractional bits.
  exp2_frac_minus_1 = eval_poly(frac_bits, exp2_poly, q_x)
  # 2^whole scaled by 2^q.
  scale = shift_left(1, whole + q)
  correction = rounding_shift_right(scale * exp2_frac_minus_1, q_x)
  return scale + correction

q = 15
q_x = 10
x = np.arange(-4000, 2000)
exact = np.exp2(x / 2**q_x)
approx_exp2_x = approx_exp2(x, q_x, q)

plot_results(x, exact, [approx_exp2_x], '2^x', logx=False, logy=True, relative=True, log2_xscale=q_x, log2_yscale=q)


In [None]:
q = 15
# Round trip: log2 with q fractional bits, then exp2 back to an integer (output q = 0).
x = 10 * np.arange(10, 10000)
round_trip_x = approx_exp2(approx_log2(x, q), q, 0)

plot_results(x, x, [round_trip_x], '2^log2(x)', relative=True, logx=True, logy=True)

In [None]:
# Approximate 2^q*sqrt(2^(x/2^q_x))
def sqrt_approx_exp2(x, q_x, q):
  """Approximate 2^q * sqrt(2^(x / 2^q_x)).

  sqrt(2^v) = 2^(v/2). Halving the exponent is the same as reading x with one
  more fractional bit, so this is approx_exp2 with q_x + 1.
  """
  half_exponent_q = q_x + 1
  return approx_exp2(x, half_exponent_q, q)

q_x = 8
q = 11
x = np.arange(-1000, 2000)
exact = np.sqrt(np.exp2(x / 2**q_x))
approx_exp2_x = sqrt_approx_exp2(x, q_x, q)

plot_results(x, exact, [approx_exp2_x], 'sqrt(2^x)', log2_xscale=q_x, log2_yscale=q, relative=True)


In [None]:
# Approximate sqrt(x) = 2^((1/2)*log2(x))
def approx_sqrt(x, q):
  """Approximate sqrt(x) * 2^q for integer x, via 2^((1/2)*log2(x))."""
  # log2(x) will never be larger than 32, for 32-bit x. So to make the result
  # fit in a 16-bit integer, we can make the precision 2^16/32 = 2048.
  q_x = 11
  # log2(x) with q_x - 1 fractional bits is the same integer as (1/2)*log2(x)
  # with q_x fractional bits, so the halving costs nothing.
  return approx_exp2(approx_log2(x, q_x - 1), q_x, q)

q = 15
# Perfect squares, so the exact roots are the integers 1..9999.
x = np.arange(1, 10000)**2
sqrt_x = np.sqrt(x)
approx_sqrt_x = approx_sqrt(x, q)

plot_results(x, sqrt_x, [approx_sqrt_x], 'sqrt(x)', relative=True, log2_yscale=q)


In [None]:
# Approximate 2^31/sqrt(x) = 2^(31 - (1/2)*log2(x))
def approx_reciprocal_sqrt(x):
  """Approximate 2^31 / sqrt(x).

  log2(x) computed with q - 1 = 14 fractional bits is the same integer as
  (1/2)*log2(x) with q = 15 fractional bits, so the halving is free. Negating
  it and exponentiating with a 2^31 output scale gives the result.
  """
  q = 15
  log2_sqrt_x = approx_log2(x, q - 1)
  return approx_exp2(-log2_sqrt_x, q, 31)

x = np.arange(1, 10000)**2
inv_sqrt_x = 1 / np.sqrt(x)
approx_reciprocal_sqrt_x = approx_reciprocal_sqrt(x)

plot_results(x, inv_sqrt_x, [approx_reciprocal_sqrt_x], '1/sqrt(x)', logx=True, logy=True, relative=True, log2_yscale=31)


In [None]:
# Approximate 2^31/x = 2^(31 - log2(x))
def approx_reciprocal(x):
  """Approximate 2^31 / x.

  log2(x) is computed with q = 15 fractional bits, negated, and exponentiated
  with a 2^31 output scale. The result exceeds 2^31 for x < 1; the caller
  below only uses x >= 1.
  """
  q = 15
  log2_x = approx_log2(x, q)
  return approx_exp2(-log2_x, q, 31)

# Geometrically spaced x from 1 up to 1.01^1999.
x = 1.01**np.arange(2000)
inv_x = 1 / x
approx_inv_x = approx_reciprocal(x)
# This is ~sqrt(2) times more accurate, but maybe not practical for large x.
approx_inv_sqrt_x2 = approx_reciprocal_sqrt(x*x)

plot_results(x, inv_x, [approx_inv_x], '1/x', logx=True, logy=True, relative=True, log2_yscale=31)
plot_results(x, inv_x, [approx_inv_sqrt_x2], '1/x', logx=True, logy=True, relative=True, log2_yscale=31)


In [None]:
# Approximate log2(exp2(x) + c)
def approx_log2_exp2_plus_constant(x, c, q_x, q):
  """Approximate 2^q * log2(2^(x / 2^q_x) + c) using integer-style operations.

  x has q_x fractional bits, and the result has q fractional bits. Where the
  integer part of x is at least 30 - q_exp, the result is x rescaled to q bits
  (the large-x asymptote log2(2^x + c) ~= x) instead of the exp2/log2 path.
  """
  # When x/2^q_x is large, approx_exp2 below will overflow. But when it is large
  # we don't need it to be very precise.
  # TODO: q_exp could adapt to x, e.g.
  #   np.minimum(16, 16 - np.floor(np.log2(np.maximum(x, 1))))
  q_exp = 16
  one = 2**q_exp

  # c + 2^x, with q_exp fractional bits.
  sum_fixed = one * c + approx_exp2(x, q_x, q_exp)
  # Stand-in for int32 overflow. This maps into [0, 2^31) rather than
  # reproducing two's-complement wraparound. For small |c| (e.g. +-1), lanes
  # large enough to reach it are replaced by `line` below.
  sum_fixed = np.mod(sum_fixed, 2**31)

  raw = approx_log2(sum_fixed, q, q_exp)
  # Large-x asymptote: x itself, rescaled from q_x to q fractional bits.
  line = rounding_shift_right(x, q_x - q)

  # Use the exp2/log2 path while 2^x * 2^q_exp stays below 2^30.
  threshold = 30 - q_exp
  use_raw = shift_right(x, q_x) < threshold
  return np.select([use_raw], [raw], line)

def approx_log2p1_exp2(x, q_x, q):
  """Approximate 2^q * log2(2^(x / 2^q_x) + 1), a base-2 softplus."""
  return approx_log2_exp2_plus_constant(x, c=1, q_x=q_x, q=q)

def approx_log2m1_exp2(x, q_x, q):
  """Approximate 2^q * log2(2^(x / 2^q_x) - 1). Only meaningful for x > 0."""
  return approx_log2_exp2_plus_constant(x, c=-1, q_x=q_x, q=q)

q_x = 11
q = 15

# log2(2^x + 1), over both signs of x.
x = 8 * np.arange(-4000, 4000)
exact = np.log2(np.exp2(x / 2**q_x) + 1)
approx = approx_log2p1_exp2(x, q_x, q)
plot_results(x, exact, [approx], 'log2(2^x + 1)', log2_xscale=q_x, log2_yscale=q)

# log2(2^x - 1) is only real-valued for x > 0.
x = 8 * np.arange(1, 4000)
exact = np.log2(np.exp2(x / 2**q_x) - 1)
approx = approx_log2m1_exp2(x, q_x, q)
plot_results(x, exact, [approx], 'log2(2^x - 1)', log2_xscale=q_x, log2_yscale=q)


In [None]:
# Approximate logistic(x) = 1/(e^-x + 1)
# = 2^log2(1/(e^-x + 1))
# = 2^-log2(e^-x + 1)
def approx_logistic(x, q_x, q):
  """Approximate 2^q / (1 + e^-(x / 2^q_x))."""
  # Change base: e^-x = 2^(-x*log2(e)). The constant is stored with 14
  # fractional bits (log2(e)*2^15 ~= 47274 would not fit in int16). So the Q15
  # multiply returns -x*log2(e) with one fewer fractional bit: q_x - 1.
  neg_log2e_q14 = np.round(-np.log2(np.exp(1)) * 2**14)
  x2 = multiply_2x_high(x, neg_log2e_q14)
  q_exp = 11
  # log2(2^x2 + 1) = log2(e^-x + 1), with q_exp fractional bits.
  log2_d = approx_log2p1_exp2(x2, q_x - 1, q_exp)
  # 2^-log2(e^-x + 1) = 1/(e^-x + 1), scaled by 2^q.
  return approx_exp2(-log2_d, q_exp, q)

q_x = 11
q = 15
x = 8 * np.arange(-4000, 4000)
exact = 1 / (1 + np.exp(-x / 2**q_x))
approx = approx_logistic(x, q_x, q)

plot_results(x, exact, [approx], '1/(1 + e^-x)', log2_xscale=q_x, log2_yscale=q)

In [None]:
# Approximate tanh(x) = (e^2x - 1)/(e^2x + 1)
# = 2^log2((e^2x - 1)/(e^2x + 1))
# = 2^(log2(e^2x - 1) - log2(e^2x + 1))
def approx_tanh(x, q_x, q):
  """Approximate 2^q * tanh(x / 2^q_x).

  The ratio is evaluated on |x| and the sign is restored at the end, since
  tanh is odd. np.sign(0) = 0 zeroes the result at x = 0.
  """
  # Change base: e^(2|x|) = 2^(2|x|*log2(e)). Multiplying by log2(e) stored with
  # 14 fractional bits yields |x|*log2(e) with q_x - 1 fractional bits.
  # Reading that same integer with q_x - 2 fractional bits doubles it.
  abs_x_base2 = multiply_2x_high(np.abs(x), np.round(np.log2(np.exp(1)) * 2**14))
  q_exp = 11
  log2_n = approx_log2m1_exp2(abs_x_base2, q_x - 2, q_exp)
  log2_d = approx_log2p1_exp2(abs_x_base2, q_x - 2, q_exp)
  # Saturate at int16, whose representable range is [-2^15, 2^15 - 1].
  int16_min = -(2**15)
  int16_max = 2**15 - 1
  log2_n = np.clip(log2_n, int16_min, int16_max)
  log2_d = np.clip(log2_d, int16_min, int16_max)
  return np.sign(x) * approx_exp2(log2_n - log2_d, q_exp, q)

q_x = 12
q = 15
x = 8 * np.arange(-4000, 4000)
exact = np.tanh(x / 2**q_x)
approx = approx_tanh(x, q_x, q)

# Plain floating-point degree-6 polynomial fit of tanh, for comparison.
# NOTE(review): the fit only samples [0, 2.95], while x / 2^q_x spans about
# [-7.8, 7.8]. approx2 is computed here but is not passed to plot_results.
points = 20
poly_x = np.arange(0, points * 3) / points
poly_y = np.tanh(poly_x)
poly = np.polyfit(poly_x, poly_y, 6)
approx2 = np.polyval(poly, x / 2**q_x) * 2**q

plot_results(x, exact, [approx], 'tanh(x)', log2_xscale=q_x, log2_yscale=q)