In [1]:
import os
import datetime
import glob
import pathlib
import sys
from dataclasses import dataclass
from collections import defaultdict

import jax
import jax.numpy as jnp
import h5py
import matplotlib.pyplot as plt
import numpy as np

from astropy.io import fits
import astropy.coordinates as coord
import astropy.time as atime
import astropy.units as u

sys.path.insert(0,'..')
import jabble.loss
import jabble.dataset
import jabble.model
import jabble.physics

import jax.config 
jax.config.update("jax_enable_x64", True)



In [2]:
allvisits = fits.open('/scratch/mdd423/wobble_jax/data/apogee/allVisit-dr17-synspec_rev1.fits')
allstar = fits.open('/scratch/mdd423/wobble_jax/data/apogee/allStar-dr17-synspec_rev1.fits')

In [3]:
rrl_mask = (
    ((allstar[1].data["APOGEE2_TARGET1"] & 2**24) != 0)
    & (allstar[1].data["NVISITS"] >= 5)
)
rrl = allstar[1].data[rrl_mask]

In [4]:
cache_path = pathlib.Path("/scratch/mdd423/wobble_jax/data/apogee")

In [5]:
def extract_dataset(apogee_id):
    def load_apVisit(visit):
#         root_path = f"sas/dr17/apogee/spectro/redux/dr17/visit/"

        if visit['TELESCOPE'] == 'apo25m':
            sorp = 'p'
        elif visit['TELESCOPE'] == 'lco25m':
            sorp = 's'
        else:
            raise NotImplementedError()

        sub_path = (f"{visit['TELESCOPE']}/" +
                    f"{visit['FIELD'].strip()}/" +
                    f"{int(visit['PLATE']):04d}/" +
                    f"{int(visit['MJD']):05d}/")
        filename = (f"a{sorp}Visit-dr17-{int(visit['PLATE']):04d}-" +
                    f"{int(visit['MJD']):05d}-" +
                    f"{int(visit['FIBERID']):03d}.fits")
        local_path = cache_path / visit['APOGEE_ID']
        local_path.mkdir(exist_ok=True, parents=True)

        local_file = local_path / filename

        hdul = fits.open(local_file)
        return hdul
    
    def load_apCframes(visit):
#         root_path = f"sas/dr17/apogee/spectro/redux/dr17/visit/"

        if visit['TELESCOPE'] == 'apo25m':
            sorp = 'p'
        elif visit['TELESCOPE'] == 'lco25m':
            sorp = 's'
        else:
            raise NotImplementedError()

        sub_path = (f"{visit['TELESCOPE']}/" +
                    f"{visit['FIELD'].strip()}/" +
                    f"{int(visit['PLATE']):04d}/" +
                    f"{int(visit['MJD']):05d}/")

        visit_hdul = load_apVisit(visit)
        frames = [int(visit_hdul[0].header[k]) for k in visit_hdul[0].header.keys()
                  if k.startswith('FRAME')]

        if len(frames) <= 1:
            return None

        hduls = defaultdict(dict)
        for frame in frames:
            for chip in ['a', 'b', 'c']:
                filename = f'a{sorp}Cframe-{chip}-{frame:08d}.fits'

                local_path = (cache_path /
                              f"{int(visit['PLATE']):04d}" /
                              f"{int(visit['MJD']):05d}/")
                local_path.mkdir(exist_ok=True, parents=True)

                local_file = local_path / filename
                hduls[frame][chip] = fits.open(local_file)

        return hduls
    
    visits = allvisits[1].data[allvisits[1].data["APOGEE_ID"] == apogee_id]
    
    fluxes = []
    errores = []
    masks = []
    waves = []
    times = []
    
    for visit in visits:
        apvisitframe = load_apVisit(visit)
        apcframes    = load_apCframes(visit)
        
        targ_id = apvisitframe[0].header['OBJID']
        
#         print(apcframes)
        # print(apcframes.keys())
        for key in apcframes.keys():
            subframe_images = apcframes[key]
#             print(subframe_images)
            for subkey in subframe_images.keys():
                subframe = subframe_images[subkey]
                fiber_id = np.array(subframe[11].data['FIBERID'])[np.where(targ_id == np.array(subframe[11].data['OBJECT']))]

                flux = np.array(subframe[1].data)[fiber_id,:].flatten()
                wave = np.array(subframe[4].data)[fiber_id,:].flatten()
                err  = np.array(np.array(subframe[2].data)[fiber_id,:]).flatten()

                # print(flux.shape)
                mask_1 = np.array(subframe[3].data,dtype=bool)[fiber_id,:].flatten()
                # print(subframe[3].header)
                mask_1 = ~((16639 & mask_1) == 0)
            
                # print(np.stack(np.unique(subframe[3].data,return_counts=True)))
                
                mask_2 = flux <= 0.0
                mask_3 = err <= 0.0
                # mask_4 = np.isnan(flux)
                # mask_5 = np.isnan(wave)
                # mask_6 = np.isnan(err)
                mask_full = (mask_1 + mask_2 + mask_3).astype(bool) #+ mask_4 + mask_5 + mask_6).astype(bool)
                
                masks.append(mask_full)
                # print(jnp.sum(masks[-1]))
                fluxes.append(np.log(flux).flatten())
                errores.append((1/(err/flux)**2))
                waves.append(np.log(wave))

                fluxes[-1][mask_full] = 0.0
                errores[-1][mask_full] = 0.0
                waves[-1][mask_full] = np.min(waves[-1])
                
                times.append(atime.Time(subframe[0].header['DATE-OBS']))
    return jabble.dataset.Data.from_lists(waves, fluxes, errores, masks), np.array(times)

In [6]:
today = datetime.date.today()
out_dir = os.path.join('..','out',today.strftime("%y-%m-%d"))
os.makedirs(out_dir,exist_ok=True)

In [7]:
def fit_cycle(model,loss,data,device):
    # Fit Norm
    model.fix()
    model.fit(0,1)
    res = model.gpu_optimize(loss, data, device)
    print(res)
    
    # Fit Stellar
    model.fix()
    model.fit(1,1)
    res = model.gpu_optimize(loss, data, device)
    print(res)
    
    # Fit RVs
    model.fix()
    model.fit(1,0)
    res = model.gpu_optimize(loss, data, device)
    print(res)
    
    # Fit Everything
    model.fix()
    model.fit(0,1)
    model.fit(1,0)
    model.fit(1,1)
    res = model.gpu_optimize(loss, data, device)
    print(res)
    
    return model

In [8]:
def get_model(dataset,resolution,p_val,vel_padding,init_shifts):
    # dx = jabble.physics.delta_x(2 * resolution)
    # x_grid = np.arange(np.min(xs), np.max(xs), step=dx, dtype="float64")
    
    model_grid = jabble.model.create_x_grid(
        np.concatenate([dataframe.xs for dataframe in data]), vel_padding, 2 * resolution
    )  

    model = jabble.model.CompositeModel(
        [
            jabble.model.ShiftingModel(init_shifts),
            jabble.model.IrwinHallModel_vmap(model_grid, p_val),
        ]
    )
    
    return model

In [9]:
subframe = extract_dataset(rrl[0]["APOGEE_ID"])

  fluxes.append(np.log(flux).flatten())
  fluxes.append(np.log(flux).flatten())
  errores.append((1/(err/flux)**2))


In [10]:
# subframe[3].header

In [11]:
cpu = jax.devices('cpu')[0]
gpu = jax.devices('gpu')[0]

In [12]:
loss = jabble.loss.ChiSquare()

In [13]:
def get_normalization_model(dataset,norm_p_val,pts_per_wavelength):
    len_xs = np.max([np.max(dataframe.xs) - np.min(dataframe.xs) for dataframe in dataset])
    min_xs = np.min([np.min(dataframe.xs) for dataframe in dataset])
    max_xs = np.max([np.max(dataframe.xs) for dataframe in dataset])

    shifts = jnp.array([dataframe.xs.min() - min_xs for dataframe in dataset])

    x_num = int((np.exp(max_xs) - np.exp(min_xs)) * pts_per_wavelength)
    x_spacing = len_xs/x_num
    x_grid = jnp.linspace(-x_spacing,len_xs+x_spacing,x_num+2) + min_xs
    
    model = jabble.model.IrwinHallModel_vmap(x_grid, norm_p_val)
    size  = len(dataset)
    
    norm_model = MyNormalizationModel(model,size)
    return jabble.model.ShiftingModel(shifts).composite(norm_model)

In [14]:
class MyNormalizationModel(jabble.model.Model):
    def __init__(self, model, size):
        super(MyNormalizationModel, self).__init__()
        self.p     = jnp.repeat(model.p,size)
        self.model = model
        self.parameters_per_model = jnp.empty(size,dtype=int)
        self.size  = size
        self.update_parameters_per()

    def update_parameters_per(self):
        for i in range(self.size):
            self.parameters_per_model = self.parameters_per_model.at[i].set(self.model.get_parameters().shape[0])
        self.create_param_bool()
    
    def get_parameters(self):
            
        x  = super(MyNormalizationModel, self).get_parameters()
        self.update_parameters_per()
        return x   

    def fit(self):
        """
        Sets model into fitting model. All parameters will be varied during next optimization call.
        """
        self._fit = True
        self.model.fit()
        self.update_parameters_per()

    def fix(self):
        """
        Sets model into fitting model. All parameters will be varied during next optimization call.
        """
        self._fit = False
        self.model.fix()
        self.update_parameters_per()

    def split_p(self, p):
        p_list = jnp.array([
            p[
                self.get_indices(k)
            ]
            for k in range(len(self.parameters_per_model))
        ])

        return p_list

    def get_indices(self,i):
        """
        Get array of ints for the ith submodel, in models list using parameters_per_model
        Returns
        -------
        indices : 'np.ndarray(int)`
            Array of indices for the parameters in the ith model that is in fitting mode
        """
        return self._param_bool[i]

    def create_param_bool(self):
        self._param_bool = np.zeros((self.size,int(np.sum(self.parameters_per_model))))
        for i in range(self.size):
            self._param_bool[i,int(jnp.sum(self.parameters_per_model[:i])):int(jnp.sum(self.parameters_per_model[: i + 1]))] = jnp.ones(
                                            (int(jnp.sum(self.parameters_per_model[: i + 1])) - int(jnp.sum(self.parameters_per_model[:i]))),
                                            dtype=bool,
                                        )
        self._param_bool = jnp.array(self._param_bool,dtype=bool)
    
    def call(self, p, x, i, *args):
        # indices = self.get_indices(i)
        parameters = self.split_p(p)
        x = self.model(parameters[i], x, i, *args)
        return x

In [None]:
for rrl_row in rrl:

    data, times = extract_dataset(rrl_row["APOGEE_ID"])
    data.to_device(gpu)
    
    targ_name = "2MASS " + rrl_row["APOGEE_ID"][2:]

    star = coord.SkyCoord.from_name(targ_name,parse=True)
    loc      = coord.EarthLocation.of_site("APO")
    init_vels   = np.array([star.radial_velocity_correction(obstime=time, location=loc).to(u.m/u.s).value for time in times])
    init_shifts = jabble.physics.shifts(init_vels)
    norm_p_val = 2
    p_val = 3
    resolution = 30_000
    
    vel_padding = 60_000
    
    norm_pt = 1/60
    model = get_normalization_model(data,norm_p_val,norm_pt) + get_model(data,resolution,p_val,vel_padding,init_shifts)
    model.to_device(gpu)

    # model.fix()
    # model.fit(0)
    # model.display()
    # for i,dataframe in enumerate(data):
    #     # print(jnp.sum(jnp.isnan(dataframe.ys[~dataframe.mask])))
    #     # print(jnp.sum(jnp.isnan(dataframe.yivar[~dataframe.mask])))
    #     # print(jnp.sum(jnp.isnan(dataframe.xs[~dataframe.mask])))
    #     # print("masks:",jnp.sum(dataframe.mask))
    #     def _internal(p):
    #         return loss(p,dataframe.xs,dataframe.ys,dataframe.yivar,dataframe.mask,i,model).sum()
    #     gradient = jax.grad(_internal)(model.get_parameters())
    #     nan_grad = np.isnan(gradient)
    #     zer_grad = np.isclose(gradient,0.0)
    #     # so none of the nans are coming from other epochs
    #     # print(np.sum(zer_grad))
    #     # if np.sum(nan_grad) != 0:
    #     # nan_grad_temp = model[0].split_p(nan_grad)
    #     # if jnp.sum(nan_grad_temp[i]):
    #     #     print(i,':',jnp.sum(nan_grad_temp[i]),'/',len(nan_grad_temp[i]))
    #     #     print('ma:',jnp.sum(dataframe.mask))
    #     #     print('y:: zeros {}, nans {}'.format(jnp.sum(np.isclose(dataframe.ys[~dataframe.mask],0.0)),jnp.sum(np.isnan(dataframe.ys[~dataframe.mask]))))
    #     #     print('yi: zeros {}, nans {}'.format(jnp.sum(np.isclose(dataframe.yivar[~dataframe.mask],0.0)),jnp.sum(np.isnan(dataframe.yivar[~dataframe.mask]))))
    #     #     break
    #     # plt.plot(dataframe.xs,dataframe.ys)
    #     # plt.plot(model[0][i].xs[nan_grad_temp[i]],model(model.get_parameters(),model[0][i].xs[nan_grad_temp[i]],i),'.r')
    #     # title_name = '07-{}_nangrad_norm_{}.png'.format(rrl_row["APOGEE_ID"],i)
    #     # plt.title(title_name)
    #     # plt.savefig(os.path.join(out_dir, title_name))
    #     # plt.show()
    # # break
    model = fit_cycle(model,loss,data,gpu)

    mdl_name = 'apogee-{}-cframe'.format(rrl_row["APOGEE_ID"])
    jabble.model.save('{}.mdl'.format(mdl_name),model)

  fluxes.append(np.log(flux).flatten())
  fluxes.append(np.log(flux).flatten())
  errores.append((1/(err/flux)**2))


[32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32]


In [None]:
x1 = np.array([1,0,3]) + np.zeros(3)
print(x1,x1.astype(bool))

In [None]:
x1 = np.array([1,2,1])
print(x1[:0])