Skip to content

Commit

Permalink
Rewrote the ResInterp class
Browse files Browse the repository at this point in the history
  • Loading branch information
benbaror committed Feb 16, 2018
1 parent 08ac5c5 commit dcada12
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 74 deletions.
90 changes: 21 additions & 69 deletions src/scrrpy/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, sma, gamma=1.75, mbh_mass=4e6, star_mass=1.0, j_grid_size=128
self.sma = sma
self.gr_factor = 1.0
self.j = np.logspace(np.log10(self.jlc(self.sma)), 0, j_grid_size + 1)[:-1]
self.omega = abs(self.nu_p(self.sma, self.j))
self.omega = self.nu_p(self.sma, self.j)

@lru_cache()
def _res_intrp(self, ratio):
Expand Down Expand Up @@ -183,8 +183,8 @@ def parallel_drr(pos, seed, j, omega):

def _drr(self, j, omega, lnnp, neval=1e3, tol=0.0):
ratio = lnnp[1] / lnnp[-1]
get_jf1 = self._res_intrp(ratio).get_jf1(omega * ratio)
get_jf2 = self._res_intrp(ratio).get_jf2(omega * ratio)
get_jf1 = self._res_intrp(ratio)(omega * ratio)
get_jf2 = self._res_intrp(-ratio)(- omega * ratio)

@vegas.batchintegrand
def c_lnnp1(x):
Expand All @@ -210,6 +210,7 @@ def c_lnnp2(x):
true_anomaly[:, ix2])
return res

self.c_lnnp1 = c_lnnp1
integ = vegas.Integrator(5 * [[0, 1]])

if get_jf1 is None:
Expand Down Expand Up @@ -343,74 +344,25 @@ def __init__(self, cusp, omega, gr_factor=1.0):
"""
"""
self._cusp = cusp
self.size = 1000
self._cusp.gr_factor = gr_factor
self.omega = omega
self._af = np.logspace(np.log10(self._cusp.rg),
np.log10(self._cusp.rh),
1000)

# self._jf = np.logspace(np.log10(self._cusp.jlc(self._cusp.rh)),
# 0, 1001)[:-1]

def get_j(_nup):
jf = self._jf[_nup > 0]
_nup = _nup[_nup > 0]
s = np.argsort(_nup)
j = np.interp(self.omega, _nup[s], jf[s], left=0, right=0)

# j[self.omega < nup.min()] = 0
# j[self.omega > nup.max()] = 0
return j

# The minimal a at which omega changes sign.
a_gr1 = self._cusp.a_gr1
# The minimal at which omega intersects nu_p
self._af = np.logspace(np.log10(self._cusp.rg),
np.log10(self._cusp.rh),
1000)

a_min = self._af[(self._af < a_gr1) *
(omega.max() < self._cusp.nu_p1(self._af))].max()
self._af = np.logspace(np.log10(a_min),
np.log10(self._cusp.rh),
1000)

self._j1 = np.zeros([self._af.size, self.omega.size])
self._j2 = np.zeros([self._af.size, self.omega.size])

last = 0
for i, a in enumerate(self._af[self._af < a_gr1]):
self._jf = np.logspace(np.log10(self._cusp.jlc(a)),
0, 1001)[:-1]
nup = self._cusp.nu_p(a, self._jf)
self._j1[i, :] = get_j(nup)
last += 1
# last += 1

for i, a in enumerate(self._af[self._af > a_gr1]):
self._jf = np.logspace(np.log10(self._cusp.jlc(a)),
0, 1001)[:-1]
nup = self._cusp.nu_p(a, self._jf)
self._j1[i + last, :] = get_j(nup)
if any(nup < 0):
self._j2[i + last, :] = get_j(-nup)

def get_jf1(self, omega):
i = np.argmin(abs(self.omega - omega))
if abs(self.omega[i] - omega) > 1e-8:
raise ValueError
j = self._j1[:, i]
ix = np.where(j > 0)[0]
if len(ix) > 0:
return lambda af: np.interp(af,
self._af[ix], j[ix], left=0, right=0)

def get_jf2(self, omega):
i = np.argmin(abs(self.omega - omega))
if abs(self.omega[i] - omega) > 1e-8:
raise ValueError
j = self._j2[:, i]
ix = np.where(j > 0)[0]
if len(ix) > 0:
return lambda af: np.interp(af,
self._af[ix], j[ix], left=0, right=0)
self.x = np.logspace(-5, 0, self.size + 1)[:-1][::-1]
j_grid = []
for af in self._af:
jlc = self._cusp.jlc(af)
j = (1 - jlc) * self.x + jlc
nup = self._cusp.nu_p(af, j)
j_grid.append(np.exp(np.interp(self.omega, nup, np.log(j), left=-np.inf, right=-np.inf)))
j_grid = np.array(list(zip(*j_grid)))
self.j_grid = dict(zip(omega, j_grid))

def __call__(self, omega):
j = self.j_grid[omega]
ix = j > 0
if sum(ix) >= 1:
return lambda af: np.interp(np.log(af), np.log(self._af[ix]), j[ix], left=0, right=0)
pass
70 changes: 70 additions & 0 deletions src/scrrpy/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
from numba import float64
from numba import guvectorize
from numba import vectorize


def interp_reg(x, f):
"""
Linear interpolation over a regular grid
Parameters
----------
x : array,
assumed to be a linear regular grid
f : array, func
:return: an interpolation function
"""

x0 = x[0]
dx_inv = 1 / (x[1] - x0)

@vectorize([float64(float64)], nopython=True)
def _interp_reg(x_intp):
x_dx = (x_intp - x0) * dx_inv
ind = int(x_dx)
w = x_dx - ind
return f[ind] * (1 - w) + f[ind + 1] * w

return _interp_reg


def interp_reg_semilogx_vec(x, f, log=np.log):
"""
Linear interpolation over a logarithmic regular grid
Parameters
----------
x : array
assumed to be a logarithmic regular grid
f : array
:return: an interpolation function
"""

x0 = x[0]
dx_inv = 1 / log(x[1] / x0)

@vectorize([float64(float64)])
def _interp_reg(x_intp):
x_dx = log(x_intp / x0) * dx_inv
ind = int(x_dx)
w = x_dx - ind
return f[ind] * (1 - w) + f[ind + 1] * w

return _interp_reg


@guvectorize([(float64[:], float64[:], float64, float64, float64[:])], '(n),(m),(),()->(n)')
def interp_reg_semilogx(x_intp, f, x0, x1, res):
"""
Linear interpolation over a logarithmic regular grid
"""
dx_inv = 1 / np.log(x1 / x0)
for i in range(x_intp.size):
x_dx = np.log(x_intp[i] / x0) * dx_inv
ind = int(x_dx)
w = x_dx - ind
res[i] = f[ind] * (1 - w) + f[ind + 1] * w
32 changes: 27 additions & 5 deletions tests/test_drr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from scrrpy.drr import DRR
from scrrpy.drr import ResInterp


def test_io():
Expand All @@ -10,8 +11,8 @@ def test_io():
drr.save("test.hdf5")
drr = DRR.from_file("test.hdf5")
d_load, d_load_err = drr(drr.l_max, neval=drr.neval, tol=drr.tol, progress_bar=False)
np.testing.assert_array_max_ulp(d, d_load)
np.testing.assert_array_max_ulp(d_err, d_load_err)
np.testing.assert_array_almost_equal_nulp(d, d_load)
np.testing.assert_array_almost_equal_nulp(d_err, d_load_err)


def test_drr_parallel():
Expand All @@ -27,8 +28,7 @@ def test_drr_parallel():
assert (d > 0).all(), d.min()
assert drr.neval == neval
assert drr.l_max == l_max
np.testing.assert_almost_equal(d_mean*1e10, 1.755037479682267, 6)

np.testing.assert_almost_equal(d_mean*1e10, 1.7550221461604085, 6)

def test_drr():
np.random.seed(1234)
Expand All @@ -43,4 +43,26 @@ def test_drr():
assert (d > 0).all(), d.min()
assert drr.neval == neval
assert drr.l_max == l_max
np.testing.assert_almost_equal(d_mean*1e10, 1.7540880485435062, 6)
np.testing.assert_almost_equal(d_mean*1e10, 1.7536283542080702, 6)


def test_res_int(tol=5e-2):
drr = DRR(0.1)
jlc = drr.jlc(drr.sma)
j = np.logspace(np.log10(jlc), 0, 11)[:-1]
omega = drr.nu_p(drr.sma, j)
res_int = ResInterp(drr, omega)
f = res_int(omega[5])
af = np.logspace(-5,1, 100)
jf = f(af)
assert jf.max() < 1.0
assert jf.min() >= 0
assert abs(1 - drr.nu_p(af[jf>0], jf[jf>0])/omega[5]).max() < tol

def test_res_int_no_solution():
drr = DRR(0.1)
jlc = drr.jlc(drr.sma)
j = np.logspace(np.log10(jlc), 0, 11)[:-1]
res_int = ResInterp(drr, [-10])
f = res_int(-10)
assert f is None
30 changes: 30 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np

from scrrpy.utils import interp_reg
from scrrpy.utils import interp_reg_semilogx
from scrrpy.utils import interp_reg_semilogx_vec


def test_interp_reg(tol=1e-4):
x = np.linspace(0, 1, 100)
x_intp = np.random.rand(1000)
y = x**2
f_intp = interp_reg(x, y)(x_intp)
err = max(abs(x_intp**2 - f_intp))
assert err < tol, err

def test_interp_reg_semilogx(tol=1e-3):
x = np.logspace(-5, 5, 1000)
x_intp = 10**((1 - 2*np.random.rand(1000))*5)
y = np.log(x)**2
f_intp = interp_reg_semilogx(x_intp, y, x[0], x[1])
err = max(abs(np.log(x_intp)**2 - f_intp))
assert err < tol, err

def test_interp_reg_semilogx_vec(tol=1e-3):
x = np.logspace(-5, 5, 1000)
x_intp = 10**((1 - 2*np.random.rand(1000))*5)
y = np.log(x)**2
f_intp = interp_reg_semilogx_vec(x, y)(x_intp)
err = max(abs(np.log(x_intp)**2 - f_intp))
assert err < tol, err

0 comments on commit dcada12

Please sign in to comment.