In [None]:
from contextlib import contextmanager
from ipywidgets import interact
from matplotlib import pyplot as plt
import math
import numpy as np
import random
import struct
import tensorflow as tf

In [None]:
@contextmanager
def np_error_mode(all=None, **kwargs):
    """Temporarily change NumPy's floating-point error handling.

    Wraps ``np.seterr`` and restores the previous settings on exit, even if
    the body raises.

    Args:
        all: Mode for every error category (e.g. ``'ignore'``, ``'warn'``,
            ``'raise'``); ``None`` leaves categories not named in ``kwargs``
            unchanged.
        **kwargs: Per-category modes forwarded to ``np.seterr``
            (``divide``, ``over``, ``under``, ``invalid``).
    """
    # `all` is a named parameter, so it never lands in **kwargs and has to be
    # forwarded explicitly -- otherwise np_error_mode('ignore') is a no-op.
    old = np.seterr(all=all, **kwargs)
    try:
        yield
    finally:
        np.seterr(**old)

In [None]:
def get_next_z_coeff(z_coeff):
    """Advance a Mandelbrot polynomial one step: coefficients of p(c)**2 + c.

    Index m of a coefficient list is the coefficient of c**(m + 1) (there is
    no constant term).  An input of length n yields a list of length 2 * n.
    """
    size = len(z_coeff)
    result = [0.] * (2 * size)
    # c**(a+1) * c**(b+1) == c**(a+b+2), which lives at index a + b + 1.
    for a in range(size):
        for b in range(size):
            result[a + b + 1] += z_coeff[a] * z_coeff[b]
    # The square never reaches index 0 (the c**1 slot); that slot is the "+ c".
    assert result[0] == 0.
    result[0] = 1.
    return result

# z_coeffs[k] holds the polynomial in c for the (k + 1)-th iterate
# (z_1 = c, z_{n+1} = z_n**2 + c); entry m is the coefficient of c**(m + 1).
z_coeffs = [[1.]]
for i in range(8):
    z_coeffs.append(get_next_z_coeff(z_coeffs[-1]))

In [None]:
@interact(
    x=(-2., 2., 0.01),
    y=(-2., 2., 0.01),
    iterations=(0, 50),
    zoom=(-6., 2., 0.1),
)
def f(
    show_basis=True,
    show_terms=True,
    show_sum=True,
    x=0.27,
    y=0.44,
    iterations=19,
    zoom=0.5,
):
    """Plot the orbit of c = x + 1j*y next to its power-series decomposition.

    Drawn on one set of axes:
    * ``show_sum``: the orbit z_{n+1} = z_n**2 + c starting from z_0 = c
      (red dashed), stopped early once |z| > 10;
    * ``show_basis``: the powers c**1 .. c**(2**k) (blue dots) and a curve
      through c**t for fractional t (blue dashed), where
      k = min(iterations, len(z_coeffs) - 1);
    * ``show_terms``: those powers weighted by the coefficients in
      ``z_coeffs[k]`` (green).

    ``zoom`` is a log2 zoom factor: the visible half-width is 2 * 0.5**zoom.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))
    ax.set_aspect('equal')
    lim = 2. * 0.5 ** zoom
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.axhline(y=0, color='k', lw=0.5)
    ax.axvline(x=0, color='k', lw=0.5)
    ax.set_xlabel('Re')
    ax.set_ylabel('Im')
    ax.set_title('c = {:.2f}{:+.2f}i, iterations = {}'.format(x, y, iterations))

    z0 = x + 1j * y

    # Reference circles of radius 1 and 2.
    ax.add_artist(plt.Circle((0., 0.), 1., fill=False))
    ax.add_artist(plt.Circle((0., 0.), 2., fill=False))

    if show_sum:
        # Collect the orbit in a list and convert once, instead of
        # reallocating the whole array with np.append on every step.
        orbit = [np.complex128(z0)]
        with np_error_mode('ignore'):
            for _ in range(iterations):
                orbit.append(orbit[-1] ** 2 + orbit[0])
                if abs(orbit[-1]) > 10:
                    break
        zs = np.array(orbit)
        ax.plot(zs.real, zs.imag, 'r--', lw=1.)

    # z_coeffs only has a limited number of precomputed iterates.
    term_iters = min(iterations, len(z_coeffs) - 1)

    if show_basis or show_terms:
        # Integer powers c**1 .. c**(2**term_iters), one per coefficient.
        tis = np.arange(1, 2 ** term_iters + 1, 1)
        basis = np.array([z0 ** ti for ti in tis])

        if show_basis:
            ts = np.arange(1, 2 ** term_iters + 1, 0.1)
            basis_lines = np.array([z0 ** t for t in ts])
            ax.plot(basis_lines.real, basis_lines.imag, 'b--', lw=1.)
            ax.plot(basis.real, basis.imag, 'bo')

        if show_terms:
            z_coeff = np.array(z_coeffs[term_iters])
            terms = basis * z_coeff
            ax.plot(terms.real, terms.imag, 'g--')
            ax.plot(terms.real, terms.imag, 'go', ms=4.)

    ax.plot(z0.real, z0.imag, 'ro')


In [None]:
def float_to_fixed(f, bits=60):
    """Encode a float as a Python-int fixed-point value with `bits` fractional bits.

    Returns ``int(f * 2**bits)``; ``int`` truncates toward zero.  Python ints
    are unbounded, so this never overflows.
    """
    return int(f * (1 << bits))

def fixed_to_float(i, bits=60):
    """Decode a Python-int fixed-point value with `bits` fractional bits to a float."""
    return float(i) / (1 << bits)

def fixed_multiply(i1, i2, bits=60):
    """Multiply two fixed-point ints that each carry `bits` fractional bits.

    Each operand is shifted right before multiplying (mimicking a 64-bit
    budget) so that the two shifts together remove exactly `bits` bits:
    ``bits // 2`` from ``i1`` and ``bits - bits // 2`` from ``i2``.  This keeps
    the result correctly scaled even for odd `bits`.  The dropped low bits
    make the product slightly inaccurate.
    """
    half = bits // 2
    return (i1 >> half) * (i2 >> (bits - half))


In [None]:
# Reference product in plain floating point ...
a = 0.5
b = 0.3
c = a * b

# ... and the same product through the Python-int fixed-point helpers.
i, j = float_to_fixed(a), float_to_fixed(b)
k = fixed_multiply(i, j)

(c, fixed_to_float(k))

In [None]:
FIXED_BITS = 50  # default number of fractional bits for the TF int64 fixed-point helpers


def float_complex_square(i, j):
    """Square the complex number ``i + j*1j``; returns ``(real, imag)``."""
    real = i ** 2 - j ** 2
    imag = 2 * i * j
    return real, imag


def float_complex_abs(i, j):
    """Squared magnitude ``|i + j*1j| ** 2`` (no square root is taken, despite the name)."""
    squared_magnitude = i ** 2 + j ** 2
    return squared_magnitude


def float_to_fixed(f, bits=FIXED_BITS):
    """Convert a float (or float tensor) to an int64 fixed-point tensor.

    The result is ``f * 2**bits`` cast to ``tf.int64`` (the cast truncates the
    fractional part).

    Non-tensor inputs are converted to float64 first: ``tf.cast`` on a bare
    Python float goes through ``tf.convert_to_tensor``, which defaults to
    float32 and would discard most of the 53-bit mantissa before scaling.
    """
    if not tf.is_tensor(f):
        f = np.asarray(f, dtype=np.float64)
    scaled = tf.cast(f, tf.float64) * (1 << bits)
    return tf.cast(scaled, tf.int64)


def fixed_to_float(i, bits=FIXED_BITS):
    """Convert a fixed-point integer (tensor) to float64 by dividing out 2**bits."""
    scale = 1 << bits
    as_float = tf.cast(i, tf.float64)
    return as_float / scale


def fixed_multiply(i1, i2, bits=FIXED_BITS):
    """Multiply two int64 fixed-point tensors that each carry `bits` fractional bits.

    Each operand is shifted right before multiplying so the int64 product
    has room: ``bits // 2`` bits from ``i1`` and ``bits - bits // 2`` from
    ``i2``, i.e. exactly `bits` in total.  Shifting both by ``bits // 2``
    would leave the result scaled by 2**(bits + 1) whenever `bits` is odd.
    The dropped low bits make the product slightly inaccurate, and it still
    overflows int64 once |a * b| reaches about 2**(63 - bits).
    """
    half = bits // 2
    return (tf.bitwise.right_shift(i1, half)
            * tf.bitwise.right_shift(i2, bits - half))


def fixed_square(i, bits=FIXED_BITS):
    """Square a fixed-point value (same truncation as ``fixed_multiply(i, i)``)."""
    return fixed_multiply(i1=i, i2=i, bits=bits)


def fixed_complex_multiply(i1, j1, i2, j2, bits=FIXED_BITS):
    """Fixed-point product of ``i1 + j1*1j`` and ``i2 + j2*1j``; returns ``(real, imag)``."""
    real = fixed_multiply(i1, i2, bits) - fixed_multiply(j1, j2, bits)
    imag = fixed_multiply(i1, j2, bits) + fixed_multiply(i2, j1, bits)
    return real, imag


def fixed_complex_square(i, j, bits=FIXED_BITS):
    """Fixed-point square of ``i + j*1j``; returns ``(real, imag)``."""
    return fixed_complex_multiply(i1=i, j1=j, i2=i, j2=j, bits=bits)


def fixed_complex_abs(i, j, bits=FIXED_BITS):
    """Fixed-point squared magnitude ``|i + j*1j| ** 2`` (no square root is taken)."""
    real_sq = fixed_square(i, bits)
    imag_sq = fixed_square(j, bits)
    return real_sq + imag_sq

In [None]:
# Smaller operands, this time through the TensorFlow int64 helpers.
a, b = 0.05, 0.03
c = a * b

i, j = float_to_fixed(a), float_to_fixed(b)
k = fixed_multiply(i, j)

(c, fixed_to_float(k))

In [None]:
"""
Working out how many carry "digits" grade-school (long) multiplication
requires, as groundwork for fixed-point arbitrary-precision arithmetic.

Suppose you wish to multiply two 512-bit numbers together when you only
have 32-bit multiplication at your disposal.  The 512-bit numbers can be
broken up into sixteen 32-bit numbers.  In grade school, one learns to
multiply by similar means: multiply each digit by every digit of the other
number and add the results with shifts and carries.  The grade-school method
uses base-10 digits with values 0-9, but in this case our digits are
base-(2**32), with values from 0 to (2**32)-1.

The conclusion from the code snippet below is that multiplying a pair of
two-digit numbers seems to require a single carry digit, while multiplying a
pair of numbers with more than two digits seems to generally require two
digits for each carried term.  For 32-bit multiplication this means that the
carry needs to be 64-bit, and that special care must be taken to "carry the
carry", for lack of a better phrase.
"""

def get_max_remainder(base, term):
    """Estimated carry into column `term` from all lower columns.

    Computes floor(sum_{i < term} i * (base - 1)**2 * base**(i - term)) with
    exact integer arithmetic; a float sum loses precision for large bases
    such as 2**32, where (base - 1)**2 already exceeds 53 bits.
    """
    numerator = sum(i * (base - 1) ** 2 * base ** i for i in range(term))
    return numerator // base ** term


def get_max_digit(base, term):
    """Bound on one column's value: `term` full products plus the incoming carry.

    The carry estimate must use the same `base` (it was previously hard-coded
    to 10, which is wrong for any other base).
    """
    return term * (base - 1) ** 2 + get_max_remainder(base, term)


base = 1 << 32
terms = range(1, 20)
max_digits = [get_max_digit(base, term) for term in terms]
# (term count, bound, floor(log_base(bound))): a value of 1 means the column
# fits in two base-`base` digits, i.e. a single carry digit.
[(term, x, int(math.log(x) / math.log(base))) for term, x in zip(terms, max_digits)]

In [None]:
def arbitrary_precision_mult(a, b, base):
    """Schoolbook product of two equal-length little-endian digit lists.

    Returns the 2 * len(a) little-endian digits of a * b in `base`, using a
    single running accumulator that carries into the next column.  The
    ``acc < base ** 3`` assertion tests the hypothesis that this accumulator
    never needs more than three digits.
    """
    assert all(x < base for x in a)
    assert all(x < base for x in b)
    assert len(a) == len(b)  # TODO: pad arrays if necessary, or adjust index range.
    n = len(a)
    digits = []
    acc = 0
    for col in range(2 * n):
        first = max(col - n + 1, 0)
        last = min(col, n - 1)
        acc += sum(a[k] * b[col - k] for k in range(first, last + 1))
        assert acc < (base ** 3)
        acc, digit = divmod(acc, base)
        digits.append(digit)
    assert acc == 0
    return digits

# Worst case in base 10: every digit is 9.  Digit lists are little-endian,
# so a smaller check is 213 * 427 against [3, 1, 2] and [7, 2, 4].
(9999999999 ** 2, arbitrary_precision_mult([9] * 10, [9] * 10, 10))

In [None]:
def mult_with_overflow(a, b, base):
    """Schoolbook product of equal-length little-endian digit lists, two-level carry.

    Same result as a plain schoolbook multiply (2 * len(a) little-endian
    digits), but the running column value is split in two: ``low`` is the
    current column's digit, kept in [0, base), and ``high`` is everything
    that has spilled past it, in units of `base`.  The assertions check
    that ``high`` stays below base**2.  Products that end up negative
    (possible with a negative digit) leave a non-zero carry and fail the
    final assertions.
    """
    assert all(x < base for x in a)
    assert all(x < base for x in b)
    assert len(a) == len(b)  # TODO: pad arrays if necessary, or adjust index range.
    size = len(a)
    digits = [0] * (2 * size)
    low = 0
    high = 0
    for col in range(2 * size):
        for k in range(max(col - size + 1, 0), min(col, size - 1) + 1):
            product = a[k] * b[col - k]
            assert product < base ** 2
            spill, low = divmod(low + product, base)
            high += spill
        assert low < base
        assert high < base ** 2
        digits[col] = low
        # Shift one column: the next digit comes out of `high`.
        high, low = divmod(high, base)
    assert low == 0, low
    assert high == 0, high
    return digits

# Compare with Python's exact product; digit lists are little-endian
# (213 -> [3, 1, 2]).  Worst case to try: [9] * 10 squared vs 9999999999 ** 2.
# (Calls mult_with_overflow -- there is no mult_with_better_carry in this notebook.)
213 * 427, mult_with_overflow([3, 1, 2], [7, 2, 4], 10)

In [None]:
FXA_BITS = 12   # bits per word
FXA_WORDS = 4   # words per number; the last word holds the integer part


def float_to_fxa(f):
    """Split a float into FXA_WORDS base-2**FXA_BITS words, least significant first.

    The last word is floor(f), a signed Python int that is not limited to
    FXA_BITS bits.  Each earlier word holds the next FXA_BITS fractional
    bits, in [0, 2**FXA_BITS).  Bits beyond the first word are truncated.
    """
    words = [0] * FXA_WORDS
    for i in range(FXA_WORDS):
        floor = np.floor(f)
        words[FXA_WORDS - i - 1] = int(floor)
        f = (f - floor) * 2 ** FXA_BITS
    return words


def fxa_to_float(fxa):
    """Inverse of float_to_fxa: sum every word scaled by its place value.

    All words are used; reading only the top two keeps just FXA_BITS of
    the fractional bits.
    """
    top = len(fxa) - 1
    # Summing from the least significant word up limits rounding error.
    return sum((word * 2.0 ** (FXA_BITS * (k - top)) for k, word in enumerate(fxa)), 0.0)

# Round-trip a value through FXA.  (13 * 2**-12 is a value that lives
# entirely in the first fractional word.)
f = 13.5
fxa = float_to_fxa(f)
(f, fxa, fxa_to_float(fxa))

In [None]:
def fxa_multiply(a, b):
    """Multiply two FXA numbers (as produced by float_to_fxa).

    The words are multiplied with mult_with_overflow in base 2**FXA_BITS.
    The 2 * FXA_WORDS product words are then cut back to FXA_WORDS by
    dropping the FXA_WORDS - 1 lowest words (extra fractional precision,
    truncated) and the single highest word (integer overflow, discarded).

    mult_with_overflow cannot produce a negative product: its final carry
    assertions fail.  So both operands are multiplied as magnitudes and the
    product is negated when exactly one of them is negative.  The result is
    therefore truncated toward zero.
    """
    negative = False
    if a and a[-1] < 0:
        a = _fxa_negate(a)
        negative = not negative
    if b and b[-1] < 0:
        b = _fxa_negate(b)
        negative = not negative
    c = mult_with_overflow(a, b, 2 ** FXA_BITS)
    product = c[:-1][-FXA_WORDS:]
    return _fxa_negate(product) if negative else product


def _fxa_negate(a):
    """Return the FXA words of -a; every word but the last stays in [0, 2**FXA_BITS)."""
    if not a:
        return []
    base = 2 ** FXA_BITS
    out = [0] * len(a)
    borrow = 0
    # Schoolbook 0 - a: each lower word borrows from the next one up.
    for k in range(len(a) - 1):
        total = borrow - a[k]
        out[k] = total % base
        borrow = total // base
    out[-1] = borrow - a[-1]
    return out

# Compare the FXA product with the float product.
a, b = 0.5, -0.3
(a * b, fxa_to_float(fxa_multiply(float_to_fxa(a), float_to_fxa(b))))

In [None]:
def fxa_add(a, b):
    """Add two FXA numbers word by word, propagating carries upward.

    The lower words are reduced into [0, 2**FXA_BITS).  The last (integer)
    word is left unreduced and signed, matching float_to_fxa and
    fxa_to_float.  Reducing it modulo 2**FXA_BITS would turn every negative
    result (e.g. -0.3 + -0.3) into a large positive one.

    Assumes ``len(b) == len(a)``.
    """
    base = 2 ** FXA_BITS
    n = len(a)
    out = [0] * n
    carry = 0
    for i in range(n - 1):
        carry += a[i] + b[i]
        out[i] = carry % base
        carry >>= FXA_BITS  # arithmetic shift == floor division, also for negative carries
    if n:
        out[-1] = carry + a[-1] + b[-1]
    return out

# Compare the FXA sum with the float sum.
a, b = 0.5, -0.3
(a + b, fxa_to_float(fxa_add(float_to_fxa(a), float_to_fxa(b))))

In [None]:
def float_to_binary(f):
    """Return the 32-bit IEEE-754 single-precision pattern of `f` as a '0'/'1' string.

    `f` is packed as a big-endian float32 ('!f'), so the string starts with
    the sign bit.
    """
    bytes_ = struct.pack('!f', f)
    return ''.join(['{:08b}'.format(x) for x in bytes_])


def decode_float_bits(f):
    """Split float_to_binary(f) into (sign, exponent, mantissa) bit strings of 1, 8 and 23 bits."""
    bits = float_to_binary(f)
    sign = bits[0:1]
    exponent = bits[1:9]
    mantissa = bits[9:]
    return sign, exponent, mantissa


def decode_float_ints(f):
    """Like decode_float_bits, but with each field parsed as an unsigned int."""
    sign_bits, exponent_bits, mantissa_bits = decode_float_bits(f)
    sign = int(sign_bits)
    exponent = int(exponent_bits, 2)
    mantissa = int(mantissa_bits, 2)
    return sign, exponent, mantissa


def decode_float(f):
    """Decode `f` (as float32) into ``(exponent, mantissa)`` with
    ``f == mantissa * 2.0 ** exponent``; the sign is folded into `mantissa`.

    Normal numbers have an implicit leading 1 and an exponent bias of 127.
    Subnormals and zero (stored exponent 0) have no implicit 1 and a fixed
    exponent of -126.  Decoding them like normal numbers would report
    2**-127 * (1 + m * 2**-23), which is wrong.  Infinities and NaNs (stored
    exponent 255) are not special-cased.
    """
    sign, exponent, mantissa = decode_float_ints(f)
    if exponent == 0:
        exponent = -126
        mantissa = mantissa * 2 ** -23
    else:
        exponent -= 127
        mantissa = mantissa * 2 ** -23 + 1
    if sign:
        mantissa *= -1
    return exponent, mantissa


# 2**-149 is the smallest positive float32 subnormal; 3230 of them still have
# a stored exponent of 0.  (For a normal number, try (100 + 223) * 2**-10.)
denorm = 2**-(126+23)
f = (1000 * denorm) + (2230 * denorm)
(decode_float_bits(f), decode_float(f), decode_float_ints(f))

In [None]:
# Confirm that float32 can be used to do 12-bit multiply and 23-bit add with lossless carry:

def test_float32_multiply_safety(bits):
    for i in range(10000):
        a = random.randint(-2**bits, 2**bits)
        b = random.randint(-2**bits, 2**bits)
        assert a * b == int((np.float32(a) * np.float32(b))), (a, b)
        
def test_float32_add_safety(bits):
    for i in range(100000):
        a = random.randint(-2**bits, 2**bits)
        b = random.randint(-2**bits, 2**bits)
        assert a + b == int((np.float32(a) + np.float32(b))), (a, b)

# 12-bit operands: |a * b| <= 2**24, which float32 represents exactly.
# (13-bit operands are expected to fail.)
test_float32_multiply_safety(12)

# 23-bit operands: |a + b| <= 2**24, likewise exact.
# (24-bit operands are expected to fail.)
test_float32_add_safety(23)