Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cloudpickle
colorama>=0.4.4
flatbuffers
hypothesis
mpmath>=1.3
numpy>=1.22
pillow>=9.1.0
portpicker
Expand Down
315 changes: 308 additions & 7 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,11 +1467,11 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
>>> print(complex_plane_sample(np.complex64, 0, 3))
[[-inf -infj 0. -infj inf -infj]
[-inf-3.4028235e+38j 0.-3.4028235e+38j inf-3.4028235e+38j]
[-inf-2.0000052e+00j 0.-2.0000052e+00j inf-2.0000052e+00j]
[-inf-2.0000000e+00j 0.-2.0000000e+00j inf-2.0000000e+00j]
[-inf-1.1754944e-38j 0.-1.1754944e-38j inf-1.1754944e-38j]
[-inf+0.0000000e+00j 0.+0.0000000e+00j inf+0.0000000e+00j]
[-inf+1.1754944e-38j 0.+1.1754944e-38j inf+1.1754944e-38j]
[-inf+2.0000052e+00j 0.+2.0000052e+00j inf+2.0000052e+00j]
[-inf+2.0000000e+00j 0.+2.0000000e+00j inf+2.0000000e+00j]
[-inf+3.4028235e+38j 0.+3.4028235e+38j inf+3.4028235e+38j]
[-inf +infj 0. +infj inf +infj]]

Expand All @@ -1481,16 +1481,18 @@ def complex_plane_sample(dtype, size_re=10, size_im=None):
finfo = np.finfo(dtype)

def make_axis_points(size):
logmin = np.log10(abs(finfo.min))
logtiny = np.log10(finfo.tiny)
logmax = np.log10(finfo.max)
prec_dps_ratio = 3.3219280948873626
logmin = logmax = finfo.maxexp / prec_dps_ratio
logtiny = finfo.minexp / prec_dps_ratio
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)

with warnings.catch_warnings():
# Silence RuntimeWarning: overflow encountered in cast
warnings.simplefilter("ignore")
axis_points[1:size + 1] = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
axis_points[-size - 1:-1] = np.logspace(logtiny, logmax, size, dtype=finfo.dtype)
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
half_line = -half_neg_line[::-1]
axis_points[-size - 1:-1] = half_line
axis_points[1:size + 1] = half_neg_line

if size > 1:
axis_points[1] = finfo.min
Expand All @@ -1512,3 +1514,302 @@ def make_axis_points(size):
imag_part = imag_part.reshape((3 + 2 * size_im, -1)).repeat(3 + 2 * size_re, 1)

return real_part + imag_part


class vectorize_with_mpmath(np.vectorize):
"""Same as numpy.vectorize but using mpmath backend for function evaluation.
"""

map_float_to_complex = dict(float16='complex32', float32='complex64', float64='complex128', float128='complex256', longdouble='clongdouble')
map_complex_to_float = {v: k for k, v in map_float_to_complex.items()}

float_prec = dict(
# float16=11,
float32=24,
float64=53,
# float128=113,
# longdouble=113
)

float_minexp = dict(
float16=-14,
float32=-126,
float64=-1022,
float128=-16382
)

float_maxexp = dict(
float16=16,
float32=128,
float64=1024,
float128=16384,
)

def __init__(self, *args, **kwargs):
mpmath = kwargs.pop('mpmath', None)
if mpmath is None:
raise ValueError('vectorize_with_mpmath: no mpmath argument specified')
self.extra_prec_multiplier = kwargs.pop('extra_prec_multiplier', 0)
self.extra_prec = kwargs.pop('extra_prec', 0)
self.mpmath = mpmath
self.contexts = dict()
self.contexts_inv = dict()
for fp_format, prec in self.float_prec.items():
ctx = self.mpmath.mp.clone()
ctx.prec = prec
self.contexts[fp_format] = ctx
self.contexts_inv[ctx] = fp_format

super().__init__(*args, **kwargs)

def get_context(self, x):
if isinstance(x, (np.ndarray, np.floating, np.complexfloating)):
fp_format = str(x.dtype)
fp_format = self.map_complex_to_float.get(fp_format, fp_format)
return self.contexts[fp_format]
raise NotImplementedError(f'get mpmath context from {type(x).__name__} instance')

def nptomp(self, x):
"""Convert numpy array/scalar to an array/instance of mpmath number type.
"""
if isinstance(x, np.ndarray):
return np.fromiter(map(self.nptomp, x.flatten()), dtype=object).reshape(x.shape)
elif isinstance(x, np.floating):
mpmath = self.mpmath
ctx = self.get_context(x)
prec, rounding = ctx._prec_rounding
if np.isposinf(x):
return ctx.make_mpf(mpmath.libmp.finf)
elif np.isneginf(x):
return ctx.make_mpf(mpmath.libmp.fninf)
elif np.isnan(x):
return ctx.make_mpf(mpmath.libmp.fnan)
elif np.isfinite(x):
mantissa, exponent = np.frexp(x)
man = int(np.ldexp(mantissa, prec))
exp = int(exponent - prec)
r = ctx.make_mpf(mpmath.libmp.from_man_exp(man, exp, prec, rounding))
assert ctx.isfinite(r), r._mpf_
return r
elif isinstance(x, np.complexfloating):
re, im = self.nptomp(x.real), self.nptomp(x.imag)
return re.context.make_mpc((re._mpf_, im._mpf_))
raise NotImplementedError(f'convert {type(x).__name__} instance to mpmath number type')

def mptonp(self, x):
"""Convert mpmath instance to numpy array/scalar type.
"""
if isinstance(x, np.ndarray) and x.dtype.kind == 'O':
x_flat = x.flatten()
item = x_flat[0]
ctx = item.context
fp_format = self.contexts_inv[ctx]
if isinstance(item, ctx.mpc):
dtype = getattr(np, self.map_float_to_complex[fp_format])
elif isinstance(item, ctx.mpf):
dtype = getattr(np, fp_format)
else:
dtype = None
if dtype is not None:
return np.fromiter(map(self.mptonp, x_flat), dtype=dtype).reshape(x.shape)
elif isinstance(x, self.mpmath.ctx_mp.mpnumeric):
ctx = x.context
if isinstance(x, ctx.mpc):
fp_format = self.contexts_inv[ctx]
dtype = getattr(np, self.map_float_to_complex[fp_format])
r = dtype().reshape(1).view(getattr(np, fp_format))
r[0] = self.mptonp(x.real)
r[1] = self.mptonp(x.imag)
return r.view(dtype)[0]
elif isinstance(x, ctx.mpf):
fp_format = self.contexts_inv[ctx]
dtype = getattr(np, fp_format)
if ctx.isfinite(x):
sign, man, exp, bc = self.mpmath.libmp.normalize(*x._mpf_, *ctx._prec_rounding)
assert bc >= 0, (sign, man, exp, bc, x._mpf_)
if exp + bc < self.float_minexp[fp_format]:
return -ctx.zero if sign else ctx.zero
if exp + bc > self.float_maxexp[fp_format]:
return ctx.ninf if sign else ctx.inf
man = dtype(-man if sign else man)
r = np.ldexp(man, exp)
assert np.isfinite(r), (x, r, x._mpf_, man)
return r
elif ctx.isnan(x):
return dtype(np.nan)
elif ctx.isinf(x):
return dtype(-np.inf if x._mpf_[0] else np.inf)
raise NotImplementedError(f'convert {type(x)} instance to numpy floating point type')

def __call__(self, *args, **kwargs):
mp_args = []
context = None
for a in args:
if isinstance(a, (np.ndarray, np.floating, np.complexfloating)):
mp_args.append(self.nptomp(a))
if context is None:
context = self.get_context(a)
else:
assert context is self.get_context(a)
else:
mp_args.append(a)

extra_prec = int(context.prec * self.extra_prec_multiplier) + self.extra_prec
with context.extraprec(extra_prec):
result = super().__call__(*mp_args, **kwargs)

if isinstance(result, tuple):
lst = []
for r in result:
if ((isinstance(r, np.ndarray) and r.dtype.kind == 'O')
or isinstance(r, self.mpmath.ctx_mp.mpnumeric)):
r = self.mptonp(r)
lst.append(r)
return tuple(lst)

if ((isinstance(result, np.ndarray) and result.dtype.kind == 'O')
or isinstance(result, self.mpmath.ctx_mp.mpnumeric)):
return self.mptonp(result)

return result


class numpy_with_mpmath:
"""Namespace of universal functions on numpy arrays that use mpmath
backend for evaluation and return numpy arrays as outputs.
"""

_provides = [
'abs', 'absolute', 'sqrt', 'exp', 'expm1', 'exp2',
'log', 'log1p', 'log10', 'log2',
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
'square', 'positive', 'negative', 'conjugate', 'sign', 'sinc',
'normalize',
]

_mp_names = dict(
abs='absmin', absolute='absmin',
log='ln',
arcsin='asin', arccos='acos', arctan='atan',
arcsinh='asinh', arccosh='acosh', arctanh='atanh',
)

def __init__(self, mpmath, extra_prec_multiplier=0, extra_prec=0):
self.mpmath = mpmath

for name in self._provides:
mp_name = self._mp_names.get(name, name)

if hasattr(self, name):
op = getattr(self, name)
else:

def op(x, mp_name=mp_name):
return getattr(x.context, mp_name)(x)

setattr(self, name, vectorize_with_mpmath(op, mpmath=mpmath, extra_prec_multiplier=extra_prec_multiplier, extra_prec=extra_prec))

# The following function methods operate on mpmath number instances.
# The corresponding function names must be listed in
# numpy_with_mpmath._provides list.

def square(self, x):
return x * x

def positive(self, x):
return x

def negative(self, x):
return -x

def sqrt(self, x):
ctx = x.context
# workaround mpmath bugs:
if isinstance(x, ctx.mpc):
if ctx.isinf(x.real) and ctx.isinf(x.imag):
if x.real > 0: return x
ninf = x.real
inf = -ninf
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
return ctx.make_mpc((inf._mpf_, inf._mpf_))
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
if x.imag > 0:
inf = x.imag
return ctx.make_mpc((inf._mpf_, inf._mpf_))
else:
ninf = x.imag
inf = -ninf
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
return ctx.sqrt(x)

def expm1(self, x):
return x.context.expm1(x)

def log2(self, x):
return x.context.ln(x) / x.context.ln2

def log10(self, x):
return x.context.ln(x) / x.context.ln10

def exp2(self, x):
return x.context.exp(x * x.context.ln2)

def normalize(self, exact, reference, value):
"""Normalize reference and value using precision defined by the
difference of exact and reference.
"""
def worker(ctx, s, e, r, v):
ss, sm, se, sbc = s._mpf_
es, em, ee, ebc = e._mpf_
rs, rm, re, rbc = r._mpf_
vs, vm, ve, vbc = v._mpf_

if not (ctx.isfinite(e) and ctx.isfinite(r) and ctx.isfinite(v)):
return r, v

me = min(se, ee, re, ve)

# transform mantissa parts to the same exponent base
sm_e = sm << (se - me)
em_e = em << (ee - me)
rm_e = rm << (re - me)
vm_e = vm << (ve - me)

# find matching higher and non-matching lower bits of e and r
sm_b = bin(sm_e)[2:] if sm_e else ''
em_b = bin(em_e)[2:] if em_e else ''
rm_b = bin(rm_e)[2:] if rm_e else ''
vm_b = bin(vm_e)[2:] if vm_e else ''

m = max(len(sm_b), len(em_b), len(rm_b), len(vm_b))
em_b = '0' * (m - len(em_b)) + em_b
rm_b = '0' * (m - len(rm_b)) + rm_b

c1 = 0
for b0, b1 in zip(em_b, rm_b):
if b0 != b1:
break
c1 += 1
c0 = m - c1

# truncate r and v mantissa
rm_m = rm_e >> c0
vm_m = vm_e >> c0

# normalized r and v
nr = ctx.make_mpf((rs, rm_m, -c1, len(bin(rm_m)) - 2)) if rm_m else (-ctx.zero if rs else ctx.zero)
nv = ctx.make_mpf((vs, vm_m, -c1, len(bin(vm_m)) - 2)) if vm_m else (-ctx.zero if vs else ctx.zero)

return nr, nv

ctx = exact.context
scale = abs(exact)
if isinstance(exact, ctx.mpc):
rr, rv = worker(ctx, scale, exact.real, reference.real, value.real)
ir, iv = worker(ctx, scale, exact.imag, reference.imag, value.imag)
return ctx.make_mpc((rr._mpf_, ir._mpf_)), ctx.make_mpc((rv._mpf_, iv._mpf_))
elif isinstance(exact, ctx.mpf):
return worker(ctx, scale, exact, reference, value)
else:
assert 0 # unreachable
Loading