In [1]:
import sys

sys.path.append("..")
import jabble.model
import jabble.dataset
import jabble.loss
import jabble.physics 
import astropy.units as u

import h5py
import matplotlib.pyplot as plt
import scipy.optimize

from jaxopt import GaussNewton
import jax.numpy as jnp
import jax
import numpy as np
from mpl_axes_aligner import align
import os
import jabble.physics

import jax.config

# Turn on JAX's 64-bit mode so arrays are created as float64 instead of the
# default float32.
jax.config.update("jax_enable_x64", True)



In [2]:
import os
import datetime

# Outputs of this session are collected in a per-day folder: ../out/<yy-mm-dd>.
today = datetime.date.today()
out_dir = os.path.join("..", "out", format(today, "%y-%m-%d"))
os.makedirs(out_dir, exist_ok=True)

In [3]:
# Open both e2ds HDF5 files read-only. The handles are not closed in the
# cells visible here, so they stay open for the rest of the session.
# NOTE(review): file_b is not used in the cells visible here.
file_b = h5py.File("../data/barnards_e2ds.hdf5", "r")
file_p = h5py.File("../data/51peg_e2ds.hdf5"   , "r")

In [4]:
# List the top-level datasets stored in the 51 Peg file.
file_p.keys()

<KeysViewHDF5 ['airms', 'bervs', 'data', 'dates', 'drifts', 'filelist', 'ivars', 'pipeline_rvs', 'pipeline_sigmas', 'xs']>

In [8]:
def get_dataset(file,orders,device):
    """Build a jabble dataset plus per-spectrum metadata from an e2ds file.

    Every ``[iii, jjj]`` slice along the first two axes of ``file["data"]``
    becomes one spectrum of the returned dataset, with an all-``False`` mask.
    The initial shift and the airmass of each spectrum are looked up with the
    second index, ``jjj``.

    Parameters
    ----------
    file : mapping of array-like datasets (e.g. an open ``h5py.File``)
        Must provide "data", "xs", "ivars", "bervs", "airms",
        "pipeline_rvs" and "pipeline_sigmas".
    orders :
        Currently unused; kept so existing callers keep working.
        NOTE(review): presumably intended to select a subset of orders --
        confirm before relying on it.
    device : jax device
        Device on which the dataset and all returned arrays are placed.

    Returns
    -------
    tuple
        ``(dataset, init_shifts, airmass, targ_rv, targ_err)`` where
        ``targ_rv`` / ``targ_err`` are the file's "pipeline_rvs" /
        "pipeline_sigmas" arrays.
    """
    # Read every dataset from the file exactly once. The original code hit
    # the file on each loop iteration (twice for "data": once for the values
    # and once more only to obtain the slice shape for the mask).
    flux = file["data"][...]
    grid = file["xs"][...]
    ivar = file["ivars"][...]
    bervs = file["bervs"][...]
    airms = file["airms"][...]

    ys = []
    xs = []
    yivar = []
    mask = []

    init_shifts = []
    airmass = []

    # Shape of a single [iii, jjj] slice; identical for every spectrum.
    slice_shape = flux.shape[2:]

    for iii in range(flux.shape[0]):
        for jjj in range(flux.shape[1]):
            ys.append(jnp.array(flux[iii,jjj,:]))
            xs.append(jnp.array(grid[iii,jjj,:]))
            yivar.append(jnp.array(ivar[iii,jjj,:]))
            # Nothing is masked out initially.
            mask.append(jnp.zeros(slice_shape,dtype=bool))
            init_shifts.append(jabble.physics.shifts(bervs[jjj]))
            airmass.append(airms[jjj])

    init_shifts = jnp.array(init_shifts)
    airmass = jnp.array(airmass)

    targ_rv = jax.device_put(jnp.array(file['pipeline_rvs']),device)
    targ_err = jax.device_put(jnp.array(file['pipeline_sigmas']),device)

    dataset = jabble.dataset.Data(jabble.dataset.Data.from_lists(xs,ys,yivar,mask).dataframes)
    dataset.to_device(device)
    init_shifts = jax.device_put(init_shifts,device)
    airmass = jax.device_put(airmass,device)

    return dataset, init_shifts, airmass, targ_rv, targ_err

In [9]:
# CPU devices visible to JAX; the first one is passed to get_dataset later on.
cpus = jax.devices("cpu")
# gpus = jax.devices("gpu")
# Chi-square loss object from jabble (not used in the cells visible here).
loss = jabble.loss.ChiSquare()

In [12]:
# Load the 51 Peg spectra onto the first CPU device. The second argument
# (orders=0) is accepted by get_dataset but never read inside it.
dataset_p, shifts_p, airmass_p, targ_rv_p, targ_err_p = get_dataset(file_p,0,cpus[0])
# NOTE(review): the model path points at a fixed dated directory ('24-04-03'),
# not at today's out_dir -- presumably a model saved by an earlier run; confirm.
model_name_p = os.path.join('..','out','24-04-03','peg51.mdl')
model_p = jabble.model.load(model_name_p)

5577 Å: night-sky emission line

6563 Å: H-alpha (Hα)

In [19]:
# Debugging aid: dump the instance attributes of the second component of model_p.
model_p[1].__dict__

{'_fit': False,
 'func_evals': [],
 'history': [],
 'save_history': False,
 'loss_history': [],
 'save_loss': [],
 'results': [],
 'models': [<jabble.model.IrwinHallModel_vmap at 0x1465a2d4e290>,
  <jabble.model.StretchingModel at 0x1465a2d4ebd0>],
 'parameters_per_model': Array([139024.,      0.], dtype=float64)}

In [13]:
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'AdditiveModel' object has no attribute 'size'", so this
# call currently fails and will stop a Restart & Run All.
model_p.display()

AttributeError: 'AdditiveModel' object has no attribute 'size'

In [None]:
def rv_plot(model,data,targ_vel,targ_err,mjds,model_name):
    """Plot reference RVs against the RVs fitted by ``model`` and save the figure.

    Two error-bar series are drawn against ``mjds``: the reference velocities
    (red, legend label 'HARPS RV') and the velocities derived from the fitted
    shift parameters ``model[1][0].p`` (black, legend label 'Jabble RV').
    The means of both series are printed, and the figure is written to
    ``<out_dir>/02-<name>-vel.png`` before being shown.

    Parameters
    ----------
    model : jabble model
        Indexed as ``model[1][0]`` to reach the component whose ``.p``
        holds the per-epoch shift parameters.
    data : jabble dataset
        Passed through to ``f_info``.
    targ_vel, targ_err : array-like
        Reference velocities and their uncertainties, one per epoch.
    mjds : array-like
        x-axis values (the axis is labelled "MJD").
    model_name : str
        Name used in the output file name. A path is accepted; only its
        final component is used so the PNG always lands inside ``out_dir``.

    Notes
    -----
    Relies on the notebook-level names ``out_dir`` and ``f_info``; both must
    be defined before this function is called.
    """
    fig, ax = plt.subplots(
        1,
        figsize=(6, 4),
        facecolor=(1, 1, 1),
        dpi=300,
    )

    shift_model = model[1][0]
    # presumably the Fisher information of each shift parameter -- TODO confirm
    fischer_info = f_info(shift_model,model,data)

    # d(velocity)/d(shift) for every epoch, used to propagate the shift
    # uncertainty into a velocity uncertainty.
    dvddx = jnp.array(
        [jax.grad(jabble.physics.velocities)(x) for x in shift_model.p]
    )
    # Error bars must be non-negative, so use the magnitude of the derivative.
    verr = np.sqrt(1 / fischer_info) * np.abs(dvddx)
    estimate_vel = jabble.physics.velocities(shift_model.p)

    # fmt must be passed by keyword: the fourth positional argument of
    # Axes.errorbar is xerr, not the format string.
    ax.errorbar(
        mjds,
        targ_vel,
        yerr=targ_err,
        fmt=".r",
        zorder=1,
        alpha=0.5,
        ms=6,
        label='HARPS RV'
    )
    print(np.mean(targ_vel),estimate_vel.mean())
    ax.errorbar(
        mjds,
        estimate_vel,
        yerr=verr,
        fmt='.k',
        zorder=1,
        alpha=0.5,
        ms=6,
        label='Jabble RV'
    )
    fig.legend()
    ax.set_ylabel("RV [$m/s$]")
    ax.set_xlabel( "MJD")

    # Strip any directory part so a path-like model_name cannot redirect the
    # output outside out_dir or into a non-existent sub-directory.
    plot_file = "02-{}-vel.png".format(os.path.basename(model_name))
    fig.savefig(os.path.join(out_dir, plot_file))
    plt.show()

In [None]:
# NOTE(review): `mjds` is not defined in any cell visible here -- presumably
# the observation dates (file_p["dates"]?); confirm and define it before this
# call. Also note that model_name_p is a file path rather than a bare name;
# check the resulting PNG file name.
rv_plot(model_p,dataset_p,targ_rv_p,targ_err_p,mjds,model_name_p)

In [None]:
def make_bary_plot(model,dataset,plt_epoches,lmin,lmax,lrange,plt_name,bcs):
    """Plot data, model components and residuals for selected epochs.

    One column of panels is drawn per entry of ``plt_epoches``:

    * top row -- the data minus ``model[0]`` evaluated on the data grid
      (black dots), ``model[1]`` (red, the stellar model) and ``model[2]``
      (blue, the telluric model) evaluated on a grid ten times denser than
      the data;
    * bottom row -- the data minus the full ``model`` (step plot).

    The figure is titled 'Barycentric Rest Frame', saved to ``plt_name`` and
    then shown. ``model.fix()`` is called before any evaluation.

    Parameters
    ----------
    model : jabble model
        Callable as ``model([], x, epoch)`` and indexable for its components.
    dataset : jabble dataset
        Provides ``xs[epoch]`` and ``ys[epoch]``.
    plt_epoches : sequence of int
        Epoch indices to plot; must not be empty.
    lmin, lmax : float
        Plot limits; their natural logs bound the x axis.
    lrange : array-like
        Tick positions, in the same units as ``lmin``/``lmax``; ticks are
        placed at ``log(lrange)`` and labelled with the un-logged values.
    plt_name : str
        Path the figure is saved to.
    bcs :
        Currently unused; kept for backward compatibility with callers.
    """
    n_panels = len(plt_epoches)
    # squeeze=False keeps `axes` two-dimensional even when only one epoch is
    # requested, so the axes[row, col] indexing below always works.
    fig, axes = plt.subplots(
        2,
        n_panels,
        figsize=(4*n_panels,4),
        sharex=True,
        sharey='row',
        facecolor=(1, 1, 1),
        height_ratios=[4,1],
        dpi=200,
        squeeze=False,
    )

    model.fix()

    # Tick positions and labels are the same for every panel.
    tick_positions = np.log(lrange)
    tick_labels = ['{}'.format(x) for x in lrange]

    for ii, plt_epoch in enumerate(plt_epoches):
        top, bottom = axes[0,ii], axes[1,ii]
        x_obs = dataset.xs[plt_epoch]
        y_obs = dataset.ys[plt_epoch]

        # Dense grid for smooth model curves.
        xplot = np.linspace(np.log(lmin),np.log(lmax),x_obs.shape[0]*10)
        yplot_norm_stel = model[1]([],xplot,plt_epoch)
        yplot_norm_tell = model[2]([],xplot,plt_epoch)
        yhat = model[0]([],x_obs,plt_epoch)
        top.set_xlim(xplot.min(),xplot.max())

        # Data with the model[0] component removed
        top.plot(x_obs,y_obs - yhat,'.k',zorder=1,alpha=0.1,ms=3)

        # Stellar Model
        top.plot(xplot,yplot_norm_stel,'-r',linewidth=1.2,zorder=10,alpha=0.7,ms=6)
        # Telluric Model
        top.plot(xplot,yplot_norm_tell,'-b',linewidth=1.2,zorder=10,alpha=0.7,ms=6)

        # Residuals against the full model
        bottom.step(x_obs,y_obs - model([],x_obs,plt_epoch),\
                    'k',where='mid',zorder=1,alpha=0.3,ms=3)

        top.set_ylim(-2,1)
        bottom.set_ylim(-0.1,0.1)
        top.set_xticks(tick_positions)
        top.set_xticklabels(tick_labels)

    # Set the title before saving so that it is part of the written file.
    fig.suptitle('Barycentric Rest Frame')
    fig.savefig(plt_name,dpi=200,bbox_inches='tight')
    plt.show()