Skip to content

Commit

Permalink
WIP towards caching PSF parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
sbailey committed Mar 11, 2018
1 parent 734ccff commit d528e88
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 26 deletions.
21 changes: 19 additions & 2 deletions py/specter/psf/gausshermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,27 @@ def _pgh(self, x, m=0, xc=0.0, sigma=1.0):
y = sp.erf(u/np.sqrt(2.))
return 0.5 * (y[1:] - y[0:-1])

def cache_params(self, spec_range, wavelengths):
"""
Cache PSF parameters to make future xypix calls faster

Args:
spec_range: indices (specmin, specmax) python-style indexing
wavelengths: float array of wavelengths
If called, future calls to xypix and projection_matrix may include
an optional `iwave_cache` index into the `wavelengths` array provided
here, which will be used to retrieve cached PSF parameters for faster
evaluation.
"""
## TODO: implement this.
pass

def _xypix(self, ispec, wavelength, iwave_cache=None):

def _xypix(self, ispec, wavelength):
## TODO: implement using iwave_cache (if not None)
## to lookup the cached values instead of calling
## self.XXX.eval(ispec, wavelength)

# x, y = self.xy(ispec, wavelength)
x = self._x.eval(ispec, wavelength)
Expand Down
2 changes: 1 addition & 1 deletion py/specter/psf/gausshermite2.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _pgh(self, x, m=0, xc=0.0, sigma=1.0):
return 0.5 * (y[1:] - y[0:-1])


def _xypix(self, ispec, wavelength):
def _xypix(self, ispec, wavelength, iwave_cache=None):

# x, y = self.xy(ispec, wavelength)
x = self.coeff['X'].eval(ispec, wavelength)
Expand Down
2 changes: 1 addition & 1 deletion py/specter/psf/monospot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, filename, spot=None, scale=1.0):
self._spot = spot.copy()
self._scale = scale

def _xypix(self, ispec, wavelength):
def _xypix(self, ispec, wavelength, iwave_cache=None):
"""
Return xslice, yslice, pix for PSF at spectrum ispec, wavelength
"""
Expand Down
2 changes: 1 addition & 1 deletion py/specter/psf/pixpsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, filename):
self.psfimage = fx[5].data.view(np.ndarray) #- [igroup, icoeff, iy, ix]
fx.close()

def _xypix(self, ispec, wavelength):
def _xypix(self, ispec, wavelength, iwave_cache=None):
"""
Evaluate PSF for a given spectrum and wavelength
Expand Down
75 changes: 55 additions & 20 deletions py/specter/psf/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,53 @@ def wdisp(self, ispec, wavelength):
#-------------------------------------------------------------------------
#- Evaluate the PSF into pixels

def pix(self, ispec, wavelength):
def cache_params(self, spec_range, wavelengths):
"""
Optionally cache PSF parameters to make future xypix calls faster
Args:
spec_range: indices (specmin, specmax) python-style indexing
wavelengths: float array of wavelengths
Subclasses may optionally implement this function. Subsequent calls
to self.xypix() and self.projection_matrix() may contain an
iwave_cache parameter giving an index into the cache.
If not implemented by the subclass, this function from the PSF base
class does nothing, and including iwave to the other calls is
harmless (but also not beneficial.)
"""
pass

def pix(self, ispec, wavelength, iwave_cache=None):
"""
Evaluate PSF for spectrum[ispec] at given wavelength
returns 2D array pixels[iy,ix]
If iwave_cache is set, it can be used as an index to the wavelength
array previously passed to cache PSF parameters in self.cache_xypix().
It is optional if subclasses want to use this or not.
also see xypix(ispec, wavelength)
"""
return self.xypix(ispec, wavelength)[2]
return self.xypix(ispec, wavelength, iwave_cache=None)[2]

def _xypix(self, ispec, wavelength):
def _xypix(self, ispec, wavelength, iwave_cache=None):
"""
Subclasses of PSF should implement this to return
xslice, yslice, pixels[iy,ix] for their particular
models. Don't worry about edge effects -- PSF.xypix
will take care of that.
If iwave_cache is set, it can be used as an index to the wavelength
array previously passed to cache PSF parameters in self.cache_xypix().
It is optional if subclasses want to use this or not.
"""
raise NotImplementedError

def xypix(self, ispec, wavelength, xmin=0, xmax=None, ymin=0, ymax=None):
def xypix(self, ispec, wavelength, xmin=0, xmax=None, ymin=0, ymax=None,
iwave_cache=None):
"""
Evaluate PSF for spectrum[ispec] at given wavelength
Expand All @@ -226,6 +253,9 @@ def xypix(self, ispec, wavelength, xmin=0, xmax=None, ymin=0, ymax=None):
if xmin or ymin are set, the slices are relative to those
minima (useful for simulating subimages)
if iwave is set, it refers to the index of the
wavelengths previously passed to self.cache_xypix().
"""
if xmax is None:
xmax = self.npix_x
Expand All @@ -237,17 +267,9 @@ def xypix(self, ispec, wavelength, xmin=0, xmax=None, ymin=0, ymax=None):
elif wavelength > self.wavelength(ispec, self.npix_y-0.5):
return slice(0,0), slice(ymax, ymax), np.zeros((0,0))

key = (ispec, wavelength)
try:
if key in self._cache:
xx, yy, ccdpix = self._cache[key]
else:
xx, yy, ccdpix = self._xypix(ispec, wavelength)
self._cache[key] = (xx, yy, ccdpix)
except AttributeError:
self._cache = CacheDict(2500)
xx, yy, ccdpix = self._xypix(ispec, wavelength)

xx, yy, ccdpix = self._xypix(ispec, wavelength,
iwave_cache=iwave_cache)

xlo, xhi = xx.start, xx.stop
ylo, yhi = yy.start, yy.stop

Expand Down Expand Up @@ -612,14 +634,18 @@ def wmax_all(self):
"""Maximum wavelength seen by all spectra"""
return self._wmax_all

def projection_matrix(self, spec_range, wavelengths, xyrange):
def projection_matrix(self, spec_range, wavelengths, xyrange, iwave_cache=None):
"""
Returns sparse projection matrix from flux to pixels
Inputs:
Args:
spec_range = (ispecmin, ispecmax) or scalar ispec
wavelengths = array_like wavelengths
xyrange = (xmin, xmax, ymin, ymax)
Options:
iwave_cache: index of wavelengths[0] in the possibly larger
wavelengths array previously passed to self.cache_xypix()
Usage:
xyrange = xmin, xmax, ymin, ymax
Expand All @@ -645,14 +671,23 @@ def projection_matrix(self, spec_range, wavelengths, xyrange):
A = np.zeros( (ny*nx, nspec*nflux) )
tmp = np.zeros((ny, nx))
for ispec in range(specmin, specmax):
for iflux, w in enumerate(wavelengths):
for iw, w in enumerate(wavelengths):
#- Are use using a pre-cached wavelength?
if iwave_cache is not None:
iwave = iwave_cache + iw
else:
iwave = None

#- Get subimage and index slices
xslice, yslice, pix = self.xypix(ispec, w, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
xslice, yslice, pix = self.xypix(ispec, w,
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax,
iwave_cache = iwave
)

#- If there is overlap with pix_range, put into sub-region of A
if pix.shape[0]>0 and pix.shape[1]>0:
tmp[yslice, xslice] = pix
ij = (ispec-specmin)*nflux + iflux
ij = (ispec-specmin)*nflux + iw
A[:, ij] = tmp.ravel()
tmp[yslice, xslice] = 0.0

Expand Down
2 changes: 1 addition & 1 deletion py/specter/psf/spotgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, filename):

fx.close()

def _xypix(self, ispec, wavelength):
def _xypix(self, ispec, wavelength, iwave_cache=None):
"""
Return xslice, yslice, pix for PSF at spectrum ispec, wavelength
"""
Expand Down
17 changes: 17 additions & 0 deletions py/specter/test/test_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,23 @@ def test_xyrange(self):
self.assertLess(xmin, xmax)
self.assertGreater(xmax, 0)

def test_cache(self):
ww = np.linspace(self.psf.wmin_all, self.psf.wmax_all)
spec_range = (3,5)
self.psf.cache_params(spec_range, ww)
for ispec in range(spec_range[0], spec_range[1]):
for iw, w in enumerate(ww):
#- Cached values should agree with direct calls
xx1, yy1, pix1 = self.psf.xypix(ispec, w)
xx2, yy2, pix2 = self.psf.xypix(ispec, w, iwave_cache=iw)
self.assertEqual(xx1, xx2)
self.assertEqual(yy1, yy2)
#- maybe need np.allclose, but let's start with np.all()
self.assertTrue(np.all(pix1 == pix2))

#- Calling something that isn't in the cache is still ok
xx, yy, pix = self.psf.xypix(0, ww[0]+1.0)

#- Test Pixellated PSF format
class TestPixPSF(GenericPSFTests,unittest.TestCase):
@classmethod
Expand Down

0 comments on commit d528e88

Please sign in to comment.