In [1]:
import sys

sys.path.append("..")
import jabble.model
import jabble.dataset
import jabble.loss
import jabble.physics 
import astropy.units as u

import h5py
import matplotlib.pyplot as plt
import scipy.optimize

from jaxopt import GaussNewton
import jax.numpy as jnp
import jax
import numpy as np
from mpl_axes_aligner import align
import os
import jabble.physics

import jax.config

jax.config.update("jax_enable_x64", True)



In [2]:
import os
import datetime

today = datetime.date.today()
out_dir = os.path.join('..','out',today.strftime("%y-%m-%d"))
os.makedirs(out_dir,exist_ok=True)

In [3]:
file_b = h5py.File("../data/barnards_e2ds.hdf5", "r")
file_p = h5py.File("../data/51peg_e2ds.hdf5", "r")

In [4]:
class MyData(jabble.dataset.Data):
    def blockify(data,device):
        max_ind = np.max([len(dataframe.xs) for dataframe in data])
        xs    = np.zeros((len(data),max_ind))
        ys    = np.zeros((len(data),max_ind))
        yivar = np.zeros((len(data),max_ind))
        mask  = np.ones((len(data),max_ind))
    
        for i,dataframe in enumerate(data):
            frame_size = len(dataframe.xs)
            xs[i,:frame_size]    = dataframe.xs
            ys[i,:frame_size]    = dataframe.ys
            yivar[i,:frame_size] = dataframe.yivar
            mask[i,:frame_size]  = dataframe.mask

        xs    = jax.device_put(jnp.array(xs),device)
        ys    = jax.device_put(jnp.array(ys),device)
        yivar = jax.device_put(jnp.array(yivar),device)
        mask  = jax.device_put(jnp.array(mask,dtype=bool),device)

        return xs, ys, yivar, mask

In [5]:
def get_dataset(file,orders,device):
    ys = []
    xs = []
    yivar = []
    mask = []

    init_shifts = []
    airmass = []
    
    for iii in range(file["data"].shape[0]):
        for jjj in range(file["data"].shape[1]):
            ys.append(jnp.array(file["data"][iii,jjj,:]))
            xs.append(jnp.array(file["xs"][iii,jjj,:]))
            yivar.append(jnp.array(file["ivars"][iii,jjj,:]))
            mask.append(jnp.zeros(file["data"][iii,jjj,:].shape,dtype=bool))
            init_shifts.append(jabble.physics.shifts(file["bervs"][jjj]))
            airmass.append(file["airms"][jjj])
    
    init_shifts = jnp.array(init_shifts)
    airmass = jnp.array(airmass)
                         
    dataset = MyData(jabble.dataset.Data.from_lists(xs,ys,yivar,mask).dataframes)
    dataset.to_device(device)
    init_shifts = jax.device_put(init_shifts,device)
    airmass = jax.device_put(airmass,device)

    return dataset, init_shifts, airmass

In [6]:
def gpu_optimize(
        self, loss, data, device, options={}
    ):
        """
        Choosen optimizer for jabble is scipy.fmin_l_bfgs_b.
        optimizes all parameters in fit mode with respect to the loss function using jabble.Dataset

        Parameters
        ----------
        loss : `jabble.Loss`
            jabble.loss object, 
        data : `jabble.Dataset`
            jabble.Dataset, that is handed to the Loss function during optimization
        verbose : `bool`
            if true prints, loss, grad dot grad at every function
        save_history : `bool`
            if true, saves values of parameters at every function call
        save_loss : `bool`
            if true, saves loss array every function call of optimization
        options : 
            additional keyword options to be passed to scipy.fmin_l_bfgs_b


        Returns
        ----------
        d : `dict`
            Results from scipy.fmin_l_bgs_b call
        """

        func_grad = jax.value_and_grad(loss.loss_all, argnums=0)

        def val_gradient_function(p, *args):
            val, grad = func_grad(p, *args)

            return np.array(val, dtype="f8"), np.array(grad, dtype="f8")
        
        # blockify dataset
        # mask extra points added to block
        xs, ys, yivar, mask = data.blockify(device)

        ##########################################################
    
        x, f, d = scipy.optimize.fmin_l_bfgs_b(
            val_gradient_function, self.get_parameters(), None, (xs,ys,yivar,mask,self), **options
        )
        self.results.append(d)
        self._unpack(jax.device_put(jnp.array(x),device))
        return d


In [7]:
class MyChiSquare(jabble.loss.ChiSquare):
    def __call__(self, p, xs, ys, yivar, mask, i, model, *args):
        return jnp.where(~mask,yivar * (((ys - model(p,xs,i,*args))**2)),0.0)
    
    def loss_all(self,p,xs,ys,yivar,mask,model,*args):
        
        #blockify parameters
        #what if normalization model has different number of parameters per model
        #anything that is going to take the epoch index needs to blockified and be the only parameter
        #this is an issue with the normalization model because its epoch specific but the parameters vary by epoch
        # just putting in the zero below will assume the same number of parameters as the first one
        # not the one specified, whats the better way to do multiple epoch fitting without indices
        
        def _internal(xs_row,ys_row,yivar_row,mask_row,index):
            return self(p,xs_row,ys_row,yivar_row,mask_row,index,model,*args).sum()

        indices = jnp.arange(0,xs.shape[0],dtype=int)
        out = jax.vmap(_internal, in_axes=(0, 0, 0, 0, 0), out_axes=0)(xs, ys, yivar ,mask, indices)
        return out.sum()

In [8]:
def get_model(dataset,resolution,p_val,vel_padding,init_shifts,airmass,pts_per_wavelength,norm_p_val,device):
    def generate_norm_grid(xs,pts_per_wavelength):
        return np.linspace(
            np.min(xs),np.max(xs),
            int((np.exp(np.max(xs)) - np.exp((np.min(xs)))) * pts_per_wavelength)
        )
    norm_models = []
    for dataframe in dataset:
        norm_models.append(jabble.model.IrwinHallModel_vmap(generate_norm_grid(dataframe.xs,pts_per_wavelength),norm_p_val))
        
    dx = jabble.physics.delta_x(2 * resolution)
    x_grid = np.arange(np.min(np.concatenate(dataset.xs)), np.max(np.concatenate(dataset.xs)), step=dx, dtype="float64")
    
    model_grid = jabble.model.create_x_grid(
        x_grid, vel_padding.to(u.m/u.s).value, 2 * resolution
    )  
    model = jabble.model.CompositeModel(
        [
            jabble.model.ShiftingModel(init_shifts),
            jabble.model.IrwinHallModel_vmap(model_grid, p_val),
        ]
    ) + jabble.model.CompositeModel(
        [
            jabble.model.IrwinHallModel_vmap(model_grid, p_val),
            jabble.model.StretchingModel(airmass),
        ]
    ) #+ jabble.model.NormalizationModel(norm_models)

    model.to_device(device)
    return model

In [9]:
def train_cycle(model, dataset, loss):
    # Fit Normalization
    # model.fix()
    # model.fit(2)
    # res1 = gpu_optimize(model,loss,dataset,gpus[0])
    # print(res1)
    
    # Fit Stellar & Telluric Template
    model.fix()
    model.fit(0, 1)
    model.fit(1, 0)
    res1 = gpu_optimize(model,loss,dataset,gpus[0])#model.optimize(loss, dataset)
    # print(type(model[0][0].p))
    print(res1)
    
    # Fit RV
    model.fix()
    model.fit(0, 0)
    res1 = gpu_optimize(model,loss,dataset,gpus[0])# model.optimize(loss, dataset)
    # print(type(model_p[0][0].p))
    print(res1)

    # RV Parabola Fit
    # model.fix()
    # shift_search = jabble.physics.shifts(np.linspace(-10, 10, 100))
    # model[0][0].parabola_fit(shift_search, loss, model, dataset)
    # print(type(model_p[0][0].p))

    # Fit Everything
    model.fix()
    model.fit(0, 0)
    model.fit(0, 1)
    model.fit(1, 0)
    # model.fit(2)
    res1 = gpu_optimize(model,loss,dataset,gpus[0])#model.optimize(loss, dataset)
    # print(type(model_p[0][0].p))
    print(res1)

    return model

In [10]:
cpus = jax.devices("cpu")
gpus = jax.devices("gpu")
loss = MyChiSquare()

In [11]:
resolution = 115_000
p_val = 2
vel_padding = 100 * u.km / u.s

pts_per_wavelength = 1/10
norm_p_val = 4

model_names = [os.path.join(out_dir,'barnards.mdl'), os.path.join(out_dir,'peg51.mdl')]

# model_p = get_model(dataset_p,resolution,p_val,vel_padding,shifts_p,airmass_p,pts_per_wavelength,norm_p_val,gpus[0])
# model_b = get_model(dataset_b,resolution,p_val,vel_padding,shifts_b,airmass_b,pts_per_wavelength,norm_p_val,gpus[0])

In [12]:
def run(file,model_name):
    dataset, shifts, airmass = get_dataset(file,0,cpus[0])
    # dataset_p, shifts_p, airmass_p = get_dataset(file_p,0,cpus[0])

    model = get_model(dataset,resolution,p_val,vel_padding,shifts,airmass,pts_per_wavelength,norm_p_val,gpus[0])
    model = train_cycle(model, dataset, loss)
    # model_b = train_cycle(model_b, dataset_b, loss)
    
    
    jabble.model.save(model_name,model)
    # jabble.model.save(model_name_p,model_p)

In [13]:
for model_name,file in zip(model_names,[file_b,file_p]):
    run(file,model_name)

{'grad': array([0., 0., 0., ..., 0., 0., 0.]), 'task': 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH', 'funcalls': 728, 'nit': 708, 'warnflag': 0}


2024-04-03 20:11:58.523116: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.03GiB (rounded to 4331667456)requested by op 
2024-04-03 20:11:58.523366: W external/tsl/tsl/framework/bfc_allocator.cc:497] *******_***********************************************************************************_________
2024-04-03 20:11:58.523594: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4331667456 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  690.62MiB
              constant allocation:       176B
        maybe_live_out allocation:   29.50GiB
     preallocated temp allocation:       328B
                 total allocation:   30.17GiB
              total fragmentation:       528B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 12.10GiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4331667456 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  690.62MiB
              constant allocation:       176B
        maybe_live_out allocation:   29.50GiB
     preallocated temp allocation:       328B
                 total allocation:   30.17GiB
              total fragmentation:       528B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 12.10GiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/vmap(jit(polyval))/concatenate[dimension=0]" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/irwinhall.py" source_line=48
		XLA Label: fusion
		Shape: f64[3,22032,4096,6]
		==========================

	Buffer 2:
		Size: 4.03GiB
		XLA Label: fusion
		Shape: f64[22032,4096,6]
		==========================

	Buffer 3:
		Size: 4.03GiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/vmap(jit(polyval))/broadcast_in_dim[shape=(22032, 4096, 6) broadcast_dimensions=(1, 2)];jit(cardinal_vmap_model)/jit(main)/vmap(jit(polyval))/broadcast_in_dim[shape=(22032, 4096, 6) broadcast_dimensions=()]"
		XLA Label: fusion
		Shape: f64[22032,4096,6]
		==========================

	Buffer 4:
		Size: 4.03GiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/model.py" source_line=1016
		XLA Label: fusion
		Shape: f64[541458432,1]
		==========================

	Buffer 5:
		Size: 4.03GiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/and" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/irwinhall.py" source_line=48
		XLA Label: fusion
		Shape: f64[22032,4096,6]
		==========================

	Buffer 6:
		Size: 688.50MiB
		Entry Parameter Subshape: f64[22032,4096]
		==========================

	Buffer 7:
		Size: 688.50MiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/dot_general[dimension_numbers=(((2,), (2,)), ((0, 1), (0, 1))) precision=None preferred_element_type=float64]" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/model.py" source_line=1016
		XLA Label: fusion
		Shape: f64[22032,4096]
		==========================

	Buffer 8:
		Size: 516.38MiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/and" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/irwinhall.py" source_line=48
		XLA Label: fusion
		Shape: pred[22032,4096,6]
		==========================

	Buffer 9:
		Size: 86.06MiB
		Operator: op_name="jit(cardinal_vmap_model)/jit(main)/jit(remainder)/and" source_file="/home/mdd423/wobble_jax/notebooks/../jabble/model.py" source_line=1010
		XLA Label: fusion
		Shape: pred[22032,4096]
		==========================

	Buffer 10:
		Size: 1.06MiB
		Entry Parameter Subshape: f64[139025]
		==========================

	Buffer 11:
		Size: 1.06MiB
		Entry Parameter Subshape: f64[139025]
		==========================

	Buffer 12:
		Size: 72B
		XLA Label: tuple
		Shape: (f64[22032,4096], f64[22032,4096,6], f64[], pred[22032,4096], f64[22032,4096,6], /*index=5*/f64[22032,4096,6], f64[3,22032,4096,6], pred[22032,4096,6], f64[22032,4096,6])
		==========================

	Buffer 13:
		Size: 72B
		XLA Label: constant
		Shape: f64[3,3]
		==========================

	Buffer 14:
		Size: 72B
		XLA Label: tuple
		Shape: (f64[22032,4096], f64[22032,4096,6], f64[], pred[22032,4096], f64[22032,4096,6], /*index=5*/f64[22032,4096,6], f64[3,22032,4096,6], pred[22032,4096,6], f64[22032,4096,6])
		==========================

	Buffer 15:
		Size: 48B
		XLA Label: constant
		Shape: s64[6]
		==========================



5577 night sky line

In [None]:
# modelname = 'barnardsvmapmodel1.mdl'
# # model = jabble.model.load(modelname)
# jabble.model.save(modelname,model)

In [None]:
def make_plot(model,dataset,init_shifts,filename):
    x_window = np.log(4550) - np.log(4549)
    lmin = np.exp(dataset.xs[0,500])
    lmax = np.exp(dataset.xs[0,1500])
    lrange = np.arange(lmin,lmax,5)
    plt_unit = u.Angstrom
    epoches = 25
    r_plots = 5

    vel_epoch = 5
    fig, axes = plt.subplots(
        epoches // r_plots,
        r_plots,
        figsize=(8, 8),
        sharex=False,
        sharey=True,
        facecolor=(1, 1, 1),
        dpi=200,
    )
    # fig.suptitle(filenames[model_num])
    for plt_epoch in range((epoches // r_plots) * r_plots):
        xplot = np.linspace(np.log(lmin), np.log(lmax), dataset.xs.shape[1] * 10)
        axes[plt_epoch // r_plots, plt_epoch % r_plots].set_xlim(
            xplot.min() + model[0][0].p[plt_epoch],
            xplot.max() + model[0][0].p[plt_epoch],
        )

        # model_set[model_num].fix()
        # model_set[model_num].fit(0)
        # rv_model_deriv = jax.jacfwd(model_set[model_num], argnums=0)(model_set[model_num].get_parameters(),dataset.xs[plt_epoch,:],plt_epoch)
        # rv_loss_deriv = jax.jacfwd(loss, argnums=0)(model_set[model_num].get_parameters(),datasets[0],vel_epoch,model_set[model_num])

        model.fix()

        axes[plt_epoch // r_plots, plt_epoch % r_plots].errorbar(
            dataset.xs[plt_epoch, :],
            dataset.ys[plt_epoch, :],
            dataset.yerr[plt_epoch, :],
            fmt=".k",
            elinewidth=1.2,
            zorder=1,
            alpha=0.5,
            ms=3,
        )

        # true_model.fix()

        axes[plt_epoch // r_plots, plt_epoch % r_plots].plot(
            xplot,
            model([], xplot, plt_epoch),
            "-r",
            linewidth=1.2,
            zorder=2,
            alpha=0.5,
            ms=6,
        )
        # axes[plt_epoch // r_plots, plt_epoch % r_plots].plot(xplot,true_model([],xplot,plt_epoch),'-r',linewidth=1.2,zorder=1,alpha=0.5,ms=6)

        axes[plt_epoch // r_plots, plt_epoch % r_plots].set_ylim(-2, 1)
        #         axes[i,j].set_yticks([])
        axes[plt_epoch // r_plots, plt_epoch % r_plots].set_xticks(np.log(lrange))
        axes[plt_epoch // r_plots, plt_epoch % r_plots].set_xticklabels(
            ["{:2.0f}".format(x) for x in lrange]
        )

        res_ax = axes[plt_epoch // r_plots, plt_epoch % r_plots].twinx()
        residual = loss(
            model.get_parameters(),
            dataset,
            plt_epoch,
            model,
        )
        res_ax.step(
            dataset.xs[plt_epoch, :], residual, where="mid", alpha=0.3, label="residual"
        )
        res_ax.set_ylim(0.0, 20)
        res_ax.set_yticks([])
        # res_ax.step(model_set[i][j][1].xs+model_set[i][j][0].p[plt_epoch],\
        #             model_set[i][j].results[-2]['grad'][:],\
        #             where='mid',alpha=0.4,label='residual',zorder=-1)
        # res_ax.set_yticks([])

        # res_ax.step(x_grid,\
        #             rv_model_deriv[:,plt_epoch],\
        #             where='mid',alpha=0.4,label='RV Derivative',zorder=-1)

        #     res_ax.step(x_grid,\
        #                 rv_loss_deriv[:,plt_epoch],\
        #                 where='mid',alpha=0.4,label='RV Derivative',zorder=-1)

        #     align_yaxis(, 0, , 0)

        align.yaxes(
            axes[plt_epoch // r_plots, plt_epoch % r_plots], 0.0, res_ax, 0.0, 2.0 / 3.0
        )

    # res.get_shared_y_axes().join(ax1, ax3)
    fig.text(0.5, 0.04, "$\lambda$", ha="center")
    fig.text(0.04, 0.5, "y", va="center", rotation="vertical")
    # fig.text(0.96, 0.5, '$d \L /d \delta x$', va='center', rotation=270)
    # fig.text(0.96, 0.5, '$d f_{{{ji}}} /d \delta x_k$', va='center', rotation=270)
    fig.text(0.96, 0.5, "residuals", va="center", rotation=270)

    plt.savefig(
        os.path.join(out_dir, "02-res_{}.png".format(filename)),
        dpi=300,
        bbox_inches="tight",
    )
    plt.show()

6563 h alpha

In [None]:
filenames = ['51peg','barnards']
make_plot(model_b,dataset_b,shifts_b,filenames[1])

In [None]:
def rv_plot(model_set,datasets,shift_set,filenames,file_set):
    fig, ax = plt.subplots(
        len(model_set),
        figsize=(8, 8),
        facecolor=(1, 1, 1),
        dpi=300,
        sharey=True,
    )
     
    for i in range(len(model_set)):
        velocities = jabble.physics.velocities(shift_set[i]) * u.m/u.s
        epoches = datasets[i].xs.shape[0]
        epoch_range = np.arange(0, epoches, dtype=int)
        fischer_information = np.zeros(epoches)
        for e_num in range(epoches):
            model_set[i].fix()
            model_set[i].fit(0,0)
            temp = jax.jacfwd(model_set[i], argnums=0)(model_set[i].get_parameters(),datasets[i].xs[e_num,:],e_num)
            # print(temp.shape)
            fischer_information[e_num] = jnp.dot(
                temp[:, e_num] ** 2, datasets[i].yivar[e_num, :]
            )
    
        dvddx = jnp.array(
            [jax.grad(jabble.physics.velocities)(x) for x in model_set[i][0][0].p]
        )
        verr = np.sqrt(1 / fischer_information) * dvddx
        estimate_vel = jabble.physics.velocities(model_set[i][0][0].p)
        tv = velocities.to(u.m/u.s).value - velocities.to(u.m/u.s).value.mean()
        ev = estimate_vel - estimate_vel.mean()
        ax[i].errorbar(
            epoch_range,
            tv - tv,
            yerr=file_set[i]["pipeline_sigmas"][:],
            fmt=".r",
            elinewidth=2.2,
            zorder=1,
            alpha=0.5,
            ms=6,
        )
    
        ax[i].errorbar(epoch_range,tv - ev,yerr=verr,fmt='.k',elinewidth=2.2,zorder=1,alpha=0.5,ms=6)
    
        ax[i].set_title('{}'.format(filenames[i], model_set[i][1][0].p_val))
        ax[i].set_xlim(-0.5, epoches - 0.5)
    fig.text(0.04, 0.5, "$v_{truth} - v_{est}$ [$m/s$]", va="center", rotation="vertical")
    fig.text(0.5, 0.04, "epochs", ha="center")
    plt.savefig(os.path.join(out_dir, "02-dv-barn-51peg.png"))
    plt.show()

In [None]:
model_set = [model_p,  model_b]
datasets  = [dataset_p,dataset_b]
shift_set = [shifts_p, shifts_b]
file_set  = [file_p,   file_b]
rv_plot(model_set,datasets,shift_set,filenames,file_set)

In [None]:
def make_better_plot(model_set,datasets,file_set):
    

    fig, axes = plt.subplots(2*len(model_set),4,figsize=(4*4,4*len(model_set)),sharex=True,facecolor=(1, 1, 1),dpi=200,height_ratios=[4,1]*len(model_set))
    
    for jj,(model,dataset,file) in enumerate(zip(model_set,datasets,file_set)):
        x_window = np.log(4550) - np.log(4549)
        lmin = np.exp(dataset.xs[0,0])
        lmax = np.exp(dataset.xs[0,2000])
        lrange = np.arange(lmin,lmax,5)
        sort_airmasses = np.argsort(np.array(file['airms'][:]))
        plt_epochs = np.concatenate((sort_airmasses[:2],sort_airmasses[-2:]))
        
        
        
        
        offset = 1.0
        xplot = np.linspace(np.log(lmin)-x_window,np.log(lmax)+x_window,dataset.xs.shape[1]*10)
        for ii,plt_epoch in enumerate(plt_epochs):
            axes[2*jj,ii].set_xlim(xplot.min()+model[0][0].p[plt_epoch],xplot.max()+model[0][0].p[plt_epoch])
            
            model.fix()
            
            axes[2*jj,ii].errorbar(dataset.xs[plt_epoch,:],dataset.ys[plt_epoch,:],\
                                     dataset.yerr[plt_epoch,:],fmt='.k',elinewidth=1.2,zorder=1,alpha=0.5,ms=3)
            
            axes[2*jj,ii].plot(xplot,offset + model[0]([],xplot,plt_epoch),'-r',linewidth=1.2,zorder=2,alpha=0.7,ms=6)
            axes[2*jj,ii].plot(xplot,2*offset + model[1]([],xplot,plt_epoch),'-b',linewidth=1.2,zorder=2,alpha=0.7,ms=6)
            axes[2*jj,ii].plot(xplot,model[2]([],xplot,plt_epoch),'-m',linewidth=1.2,zorder=3,alpha=0.7,ms=6)
            
            # axes[0,ii].plot(xplot,2*offset + model[1]([],xplot,plt_epoch),'-b',linewidth=1.2,zorder=2,alpha=0.7,ms=6)
            # axes[0,ii].plot(xplot,offset + model([],xplot,plt_epoch),'-g',linewidth=1.2,zorder=2,alpha=0.7,ms=6)
            # axes[plt_epoch // r_plots, plt_epoch % r_plots].plot(xplot,true_model([],xplot,plt_epoch),'-r',linewidth=1.2,zorder=1,alpha=0.5,ms=6)
            
            
            axes[2*jj,ii].set_ylim(-2,3)
            axes[2*jj,ii].set_xticks([])
            # axes[0].set_yticks([])
            axes[2*jj+1,ii].set_xticks(np.log(lrange))
            axes[2*jj+1,ii].set_xticklabels(['{:2.0f}'.format(x) for x in lrange])
            
            axes[2*jj+1,ii].plot(dataset.xs[plt_epoch,:],dataset.ys[plt_epoch,:] - model([],dataset.xs[plt_epoch,:],plt_epoch),'.k',alpha=0.4,ms=1)
            
            axes[2*jj+1,ii].set_ylim(-0.1,0.1)
            axes[2*jj,ii].set_title('airmass = {}'.format(file['airms'][:][plt_epoch]))
        # res_ax = axes[plt_epoch // r_plots, plt_epoch % r_plots].twinx()
        # residual = loss(model_set[model_num].get_parameters(),dataset,plt_epoch,model_set[model_num])
        # res_ax.step(dataset.xs[plt_epoch,:],residual,where='mid',alpha=0.3,label='residual')
        # res_ax.set_ylim(0.0,20)
        # res_ax.set_yticks([])
        # res_ax.step(model_set[i][j][1].xs+model_set[i][j][0].p[plt_epoch],\
        #             model_set[i][j].results[-2]['grad'][:],\
        #             where='mid',alpha=0.4,label='residual',zorder=-1)
        # res_ax.set_yticks([])
        
        # res_ax.step(x_grid,\
        #             rv_model_deriv[:,plt_epoch],\
        #             where='mid',alpha=0.4,label='RV Derivative',zorder=-1)
            
        #     res_ax.step(x_grid,\
        #                 rv_loss_deriv[:,plt_epoch],\
        #                 where='mid',alpha=0.4,label='RV Derivative',zorder=-1)
            
        #     align_yaxis(, 0, , 0)
            
            # align.yaxes(axes[plt_epoch // r_plots, plt_epoch % r_plots], 0.0, res_ax, 0.0, 2./3.)
        
        # res.get_shared_y_axes().join(ax1, ax3)
        fig.text(0.5, 0.04, '$\lambda$', ha='center')
        # fig.text(0.04, 0.5, 'y', va='center', rotation='vertical')
        # fig.text(0.96, 0.5, '$d \L /d \delta x$', va='center', rotation=270)
        # fig.text(0.96, 0.5, '$d f_{{{ji}}} /d \delta x_k$', va='center', rotation=270)
        # fig.text(0.96, 0.5, 'residuals', va='center', rotation=270)
    
    plt.savefig(os.path.join(out_dir,'02-full-barn-51peg.png'),dpi=300,bbox_inches='tight')
    plt.show()

In [None]:
make_better_plot(model_set,datasets,file_set)

In [None]:
file_set

In [None]:
tell_loss = [[],[]]
for jjj, (dataset, model) in enumerate(zip(datasets,model_set)):
    for iii in range(dataset.ys.shape[0]):
        tell_loss[jjj].append(loss([],dataset,iii,model[0]).sum())

plt.plot(np.array(file_p['airms'][:]),tell_loss[0],'.k',label='51 peg')
plt.plot(np.array(file_b['airms'][:]),tell_loss[1],'.r',label='barnards')
# plt.ylim(0.0,5e4)

# plt.plot(np.array(file_p['airms'][:]),model_p[1][1].p,'.k',label='51 peg')
# plt.plot(np.array(file_b['airms'][:]),model_b[1][1].p,'.r',label='barnards')
plt.xlabel('airmass')
plt.ylabel('$\Sigma_* (y_* - \hat{y}_s(x_*)) I_{y*}$')
# plt.plot()
plt.legend()
plt.savefig(os.path.join(out_dir,'02-airmass_loss.png'),dpi=300,bbox_inches='tight')
plt.show()

In [None]:
plt.plot(np.array(file_p['airms'][:]),model_p[1][1].p,'.k',label='51 peg')
plt.plot(np.array(file_b['airms'][:]),model_b[1][1].p,'.r',label='barnards')
plt.xlabel('airmass')
plt.ylabel('~a')
x_space = np.linspace(np.min(np.array(file_b['airms'][:])),np.max(np.array(file_b['airms'][:])))
plt.plot(x_space,x_space,'-.k',alpha=0.3)
plt.legend()
plt.savefig(os.path.join(out_dir,'02-airmass_an.png'),dpi=300,bbox_inches='tight')
plt.show()

In [None]:
fig, axes = plt.subplots(1,figsize=(4,4),sharex=True,facecolor=(1, 1, 1),dpi=200)

plt_epoch = 10
x_window = np.log(4550) - np.log(4549)
lmin = np.exp(dataset_p.xs[0,500])
lmax = np.exp(dataset_p.xs[0,1500])
lrange = np.arange(lmin,lmax,5)
xplot = np.linspace(np.log(lmin)-x_window,np.log(lmax)+x_window,dataset_p.xs.shape[1]*10)
axes.plot(xplot,model_p[1]([],xplot,plt_epoch),'-b',linewidth=1.2,zorder=2,alpha=0.6,ms=6,label='51 peg')
axes.plot(xplot,0.05 + model_b[1]([],xplot,plt_epoch),'-r',linewidth=1.2,zorder=2,alpha=0.6,ms=6,label='barnard')
axes.legend()

axes.set_ylim(-0.2,0.1)
axes.set_xticks([])
axes.set_ylabel('log flux + offset')
axes.set_xlabel('$\lambda$')
axes.set_xticks(np.log(lrange))
axes.set_xticklabels(['{:2.0f}'.format(x) for x in lrange])
plt.title('just tellurics')
plt.savefig(os.path.join(out_dir,'02-airmass-tell.png'),dpi=300,bbox_inches='tight')
plt.show()

In [None]:
# model_p.fix()
# model_p.fit(0,1)
# e_num = 0
# dudth = jax.jacfwd(model_p, argnums=0)(model_p.get_parameters(),dataset_p.xs[e_num,:],e_num)
# ith   = dudth * dataset.yivar * dudth.T
# print(dudth)

In [None]:
def variation_info(self,model,dataset):
    f_info = np.zeros(dataset.xs.shape)
    model.fix()
    self.fit()
    for e_num in range(dataset.xs.shape[0]):
        duddx = jax.jacfwd(model, argnums=0)(model.get_parameters(),dataset.xs[e_num,:],e_num)
        f_info[e_num,:] =  jnp.dot(duddx[:,e_num]**2,dataset.yivar[e_num,:])
    return f_info