In [1]:
import jax.numpy as jnp
from jax import jit as jjit

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from astropy.io import ascii as astro_ascii
import pickle, scipy, copy, numpy as np, pandas as pd
from dynesty import plotting as dyplot
import sys, math
# Make modules that live under ../scripts/ (relative to this notebook) importable.
alfpydir = '../scripts/'
# Guarded so that re-running this cell does not keep pushing duplicate
# entries onto sys.path.
if alfpydir not in sys.path:
    sys.path.insert(1, alfpydir)

In [3]:
from str2arr import alfobj

# Load the pickled `alfvar` object.  The context manager closes the file
# handle, which the previous bare open() call left dangling.
# NOTE: pickle.load can execute arbitrary code while unpickling -- only load
# pickle files that come from a trusted source.
with open('../pickle/alfvar_sspgrid_alff90_test1.p', "rb") as alfvar_file:
    alfvar = pickle.load(alfvar_file)

# Parameter object used by the cells below (e.g. pos.logage in the timing cells).
pos = alfobj()
pos.sigma = 200
pos.zh = -0.2
pos.logage = math.log10(8.)  # log10 of 8 (presumably an age in Gyr -- confirm)
pos.imf1 = 1.301
pos.imf2 = 2.301
pos.imf3 = 0.0801
#spec_p = getmodel(pos, alfvar = alfvar)

### update get_dv

In [36]:
def contnormspec(lam, flx, err, il1, il2, coeff=False, return_poly=False, 
                 npolymax = 10, npow = None):
    """
    Fit a polynomial continuum to a spectrum over the window [il1, il2].

    The polynomial degree is one per 100 wavelength units of the window,
    capped at ``npolymax``, unless ``npow`` is given explicitly.  The fit is a
    weighted linear least-squares fit in the shifted coordinate
    ``lam - (il1 + il2) / 2`` and uses only samples inside the window whose
    flux is finite.

    Note: this routine does not divide the flux by the continuum; it only
    returns the degree and, on request, the fit itself.

    Parameters
    ----------
    lam : 1-D array
        Wavelength grid (passed to the external helper ``locate``, which is
        presumably expecting it sorted in ascending order -- confirm).
    flx : 1-D array
        Flux sampled on ``lam``.
    err : 1-D array
        1-sigma uncertainty of ``flx``; samples are weighted by ``1/err``.
    il1, il2 : float
        Lower / upper wavelength limits of the fitting window.
    coeff : bool
        If true, include the polynomial coefficients in the return value.
    return_poly : bool
        If true, include the polynomial evaluated on the full ``lam`` grid.
    npolymax : int
        Upper limit on the automatically chosen degree.
    npow : int or None
        Polynomial degree; ``None`` selects it from the window width.

    Returns
    -------
    npow                  if neither ``coeff`` nor ``return_poly`` is set
    (npow, tcoeff)        if only ``coeff`` is set
    (npow, poly)          if only ``return_poly`` is set
    (npow, tcoeff, poly)  if both are set

    ``tcoeff`` holds ``npow + 1`` coefficients, highest power first
    (np.polyfit convention), for the polynomial in ``lam - (il1 + il2) / 2``.
    """
    poly_dlam = 100.
    buff = 0.0

    n1 = lam.size

    # One polynomial degree per poly_dlam of wavelength, but never more than
    # npolymax so the fit does not get out of hand.
    if npow is None:
        npow = min((il2-il1)//poly_dlam, npolymax)
    # poly_dlam is a float, so the floor division above yields e.g. 5.0;
    # the degree must be a plain integer.
    npow = int(npow)
    print('npow=',npow)

    # Index range [i1, i2) covering the window, clamped to the array bounds.
    # `locate` is defined outside this block.
    i1 = min(max(locate(lam, il1-buff),0), n1-2)
    i2 = min(max(locate(lam, il2+buff),1), n1-1)+1   
    ml = (il1+il2)/2.0

    # Simple linear least-squares polynomial fit on the finite flux samples.
    ind = np.isfinite(flx[i1:i2])
    # np.polyfit applies `w` to the *unsquared* residual, so Gaussian
    # uncertainties require w = 1/sigma (not 1/sigma**2).
    # full=True keeps polyfit from emitting a RankWarning for poorly
    # conditioned fits; only the coefficients (element 0) are needed.
    tcoeff = np.polyfit(x = lam[i1:i2][ind]-ml, 
                        y = flx[i1:i2][ind], 
                        deg = npow, 
                        w = 1./err[i1:i2][ind], 
                        full = True)[0]

    # Evaluate the fitted polynomial on the whole wavelength grid.
    poly = np.poly1d(tcoeff)(lam-ml)

    if coeff and return_poly:
        return npow, tcoeff, poly
    if coeff:
        return npow, tcoeff
    if return_poly:
        # This combination previously fell through and returned None.
        return npow, poly
    return npow


In [41]:
# Fit the continuum polynomial over the 4400-5500 window of one spectrum from
# the grid, with unit errors; with both flags set the result is the tuple
# (npow, coefficients, polynomial on the full grid).
# NOTE(review): `sspgrid` is not defined in the cells visible here --
# presumably alfvar.sspgrid; confirm it is assigned before this cell runs.
res = contnormspec(
    lam=sspgrid.lam,
    flx=sspgrid.logssp[:, 0, 0, 0, 0],
    err=np.ones(sspgrid.lam.shape),
    il1=4400,
    il2=5500,
    coeff=True,
    return_poly=True,
)

npow= 10


In [51]:
# res[1] is the coefficient array returned by contnormspec (coeff=True);
# the recorded output (11,) is npow + 1 entries for npow = 10.
res[1].shape

(11,)

In [23]:
# Baseline timing: look up pos.logage in sspgrid.logagegrid with `locate`.
# NOTE(review): `locate` and `sspgrid` are not defined in the cells visible here.
%timeit locate(sspgrid.logagegrid, pos.logage)

317 ns ± 14 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [8]:
# JAX-array copy of the log-age grid, used to time `locate` on a jax.numpy input.
tem = jnp.asarray(alfvar.sspgrid.logagegrid)



In [84]:
# Same lookup with the grid held as a JAX array (`tem`).  The recorded run is
# ~6.18 ms per call versus ~317 ns for the sspgrid.logagegrid timing, i.e.
# `locate` as written is far slower on a JAX array.
%timeit locate(tem, pos.logage)

6.18 ms ± 458 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [27]:
# Micro-benchmark: squaring the wavelength grid with np.square (recorded ~1.93 µs).
%timeit np.square(sspgrid.lam)

1.93 µs ± 68.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [29]:
# Same squaring written with the ** operator (recorded ~2 µs, i.e. on par
# with the np.square timing).
%timeit sspgrid.lam**2

2 µs ± 60.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
