In [None]:
import csv
import os
import io
import time
import requests
import numpy as np
from glob import glob
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns

from astropy.table import Table, vstack
from astropy.io import fits, ascii
from scipy.signal import medfilt
from astropy.convolution import Gaussian1DKernel, convolve
from astropy.modeling import fitting, models
from astropy.cosmology import Planck15
from astropy import units as u

from desispec.io import read_spectra, read_fibermap
from prospect.mycoaddcam import coadd_brz_cameras
from desitarget.targetmask import desi_mask

In [None]:
class DesiRedshiftTable:
    '''Helper around a DESI redshift table: quality cuts and 300 kpc search radii.

    The wrapped table is expected to contain, at a minimum, the columns:

        'TARGETID'    : unique target id assigned to the DESI object
        'Z'           : best redshift calculated by the Redrock pipeline
        'ZWARN'       : Redrock bitmask flagging the quality of the fitted redshift
        'SPECTYPE'    : spectral type assigned by the DESI spectroscopic pipeline
        'target_ra'   : target right ascension given for DESI fiber assignment
        'target_dec'  : target declination given for DESI fiber assignment
        'FIBERSTATUS' : bitmask describing fiber conditions during observation
        'DESI_TARGET' : DESI targeting bitmask (target classes and bright-object flags)
        'LAST_TILEID' : last tile the target was observed on
        'LAST_NIGHT'  : last night the target was observed/assigned through

    ``all_z_todate`` builds such a table from every cumulative zbest file to date.
    The instance methods modify the table passed to ``__init__`` *in place*
    (rows are removed, columns added, rows re-sorted); it is not copied.
    '''

    def __init__(self, t, desi_targetid, z, zwarn, spectype, fiberstatus, desi_target):
        '''Wrap table ``t``.

        The remaining arguments are columns of ``t``; only their ``.name``
        attributes are kept, to know which columns the methods operate on.
        '''
        self.t = t
        self.tid = desi_targetid.name
        self.z = z.name
        self.zwarn = zwarn.name
        self.spectype = spectype.name
        self.fiberstatus = fiberstatus.name
        self.desi_target = desi_target.name

    def qual_cuts(self):
        '''Remove, in place, every row failing any of these quality cuts:

        - DESI_TARGET has none of the QSO / LRG / ELG / BGS_ANY bits set
        - FIBERSTATUS is non-zero
        - ZWARN is non-zero
        - DESI_TARGET has a NEAR/IN_BRIGHT_OBJECT, BRIGHT_OBJECT or NO_TARGET bit set
        - SPECTYPE is 'STAR'
        '''
        target_bits = self.t[self.desi_target]
        not_target_class = (target_bits & desi_mask.mask("QSO|LRG|ELG|BGS_ANY")) == 0
        bad_fiber = self.t[self.fiberstatus] != 0
        bad_zwarn = self.t[self.zwarn] != 0
        near_bright = (target_bits & desi_mask.mask(
            "NEAR_BRIGHT_OBJECT|IN_BRIGHT_OBJECT|BRIGHT_OBJECT|NO_TARGET")) != 0
        is_star = self.t[self.spectype] == 'STAR'
        reject = np.flatnonzero(not_target_class | bad_fiber | bad_zwarn | near_bright | is_star)
        # remove_rows mutates the Table in place, so every reference to it sees the cut.
        self.t.remove_rows(reject)

    def add_angsep_dl(self, d_kpc=300):
        '''Add derived columns computed from the redshift column.

        Adds 'dL' (luminosity distance in kpc) and 'ang_sep' (angle in degrees
        subtended by ``d_kpc`` kpc, default 300). Raises if the columns already
        exist, so call it once per table.
        '''
        z = self.t[self.z]
        self.t.add_columns([self.d_lumin(z), self.ang_sep(z, d_kpc)], names=['dL', 'ang_sep'])

    def min_cuts(self):
        '''Apply low/negative-redshift fixes in place, then sort the table by 'ang_sep'.

        - Rows with Z < 0 get SPECTYPE 'STAR' (Redrock z < 0 are blueshifted stars).
        - Rows with Z < 0.01 get ang_sep = 0.001 deg and dL = d_lumin(0.01).

        Requires the columns added by ``add_angsep_dl``.
        NOTE(review): if called after ``qual_cuts`` (as in this notebook), rows
        relabelled 'STAR' here are *not* removed. Also, 0.001 deg is not
        ang_sep(0.01, 300) (roughly 0.4 deg), so 'dL' and 'ang_sep' are not
        mutually consistent for these rows. Confirm both are intended.
        '''
        z = self.t[self.z]
        blueshifted = z < 0
        low_z = z < 0.01
        self.t[self.spectype][blueshifted] = 'STAR'
        self.t['ang_sep'][low_z] = 0.001
        self.t['dL'][low_z] = self.d_lumin(0.01)
        self.t.sort('ang_sep')

    @staticmethod
    def d_lumin(z):
        '''Return the Planck15 luminosity distance, in kpc, for redshift(s) ``z``.'''
        return Planck15.luminosity_distance(z=z).value * 1000

    @staticmethod
    def ang_sep(z, d):
        '''Return the angle in degrees subtended by ``d`` kpc at redshift ``z``.

        Uses the small-angle relation d / D_L(z) with the luminosity distance.
        NOTE(review): a physical angular size is normally computed with the
        angular-diameter distance D_A = D_L / (1+z)**2. D_L is kept here to
        preserve existing results; confirm intent.
        '''
        return (d / DesiRedshiftTable.d_lumin(z)) * (180 / np.pi)

    @staticmethod
    def all_z_todate():
        '''Return one stacked table of every Redrock redshift calculated to date.

        Reads every ``zbest*.fits`` file found (recursively) under
        ``$DESI_SPECTRO_REDUX/daily/tiles/cumulative``. From the matching
        ``coadd*.fits`` file it appends FIBERMAP columns and three bookkeeping
        columns:

        - 'spectro': petal character taken from the file name
        - 'coadd_file': path of the coadd
        - 'coadd_index': row of the target within that coadd

        Assumes each zbest table's rows line up one-to-one with its coadd FIBERMAP.

        Raises
        ------
        KeyError
            If DESI_SPECTRO_REDUX is not set.
        FileNotFoundError
            If no zbest files are found.
        '''
        start = time.time()
        basedir = os.environ['DESI_SPECTRO_REDUX']
        redux = 'daily/tiles/cumulative'
        search_root = os.path.join(basedir, redux)
        # Sorted so that row order (and hence which duplicate is "last") is reproducible.
        zbest_paths = sorted(glob(os.path.join(search_root, '**', 'zbest*.fits'), recursive=True))
        if not zbest_paths:
            raise FileNotFoundError(f'No zbest*.fits files found under {search_root}')

        tabs = []
        for zpath in zbest_paths:
            tab = Table.read(zpath, hdu='ZBEST')['TARGETID', 'Z', 'ZERR', 'ZWARN', 'SPECTYPE', 'SUBTYPE']
            zname = os.path.basename(zpath)
            coadd_path = os.path.join(os.path.dirname(zpath), zname.replace('zbest', 'coadd', 1))
            fmap = Table.read(coadd_path, hdu='FIBERMAP')
            # Assumes names like "zbest-<petal>-..." with a single-character petal id.
            petal = zname.split('zbest-')[1][0]
            tab.add_columns(
                [fmap['TARGET_RA'], fmap['TARGET_DEC'], fmap['FIBERSTATUS'], fmap['DESI_TARGET'],
                 fmap['LAST_TILEID'], fmap['LAST_NIGHT'],
                 petal, coadd_path, np.arange(len(tab))],  # index sized to the file, not a fixed 500
                names=['target_ra', 'target_dec', 'FIBERSTATUS', 'DESI_TARGET', 'LAST_TILEID', 'LAST_NIGHT',
                       'spectro', 'coadd_file', 'coadd_index'],
            )
            tabs.append(tab)

        elapsed_min = (time.time() - start) / 60
        print(f"  Total time: {elapsed_min:0.3} min")

        return vstack(tabs)

In [None]:
# Build the table of every Redrock redshift to date: reads every cumulative zbest file
# plus its coadd FIBERMAP, so this can be slow (the elapsed time is printed).
z_tab = DesiRedshiftTable.all_z_todate()

In [None]:
# Column summary of the stacked table, shown before any cuts are applied.
z_tab.info

In [None]:
# Wrap z_tab and apply the cuts. These methods act on z_tab itself:
# rows are removed, 'dL'/'ang_sep' columns are added and the table is re-sorted.
t = DesiRedshiftTable(
    z_tab,
    *(z_tab[colname] for colname in ('TARGETID', 'Z', 'ZWARN', 'SPECTYPE', 'FIBERSTATUS', 'DESI_TARGET')),
)
t.qual_cuts()      # drop non-target-class, bad-fiber, ZWARN, bright-object and STAR rows
t.add_angsep_dl()  # luminosity distance + angle subtended by 300 kpc
t.min_cuts()       # low/negative-z fixes, then sort by 'ang_sep'

In [None]:
def plot_host(gal_ra, gal_dec, t_ra, t_dec, z, **kwargs):
    """Plot a Legacy Survey DR9 cutout marking a galaxy and a nearby transient.

    The cutout is centered on the galaxy. Nothing is drawn when the transient
    lies more than 60" from the galaxy; in that case only the viewer URL and a
    message are printed.

    Parameters
    ----------
    gal_ra, gal_dec : float
        Galaxy right ascension and declination, in degrees.
    t_ra, t_dec : float
        Transient right ascension and declination, in degrees.
    z : float
        Galaxy redshift; sets the radius of the 300 kpc search circle.

    Optional Parameters
    -------------------
    desi_id : DESI TARGETID of the galaxy (any type; shown as text).
    survey : name of the survey reporting the transient.
    t_name : transient name / ID.
    spectype : spectral type of the galaxy.

    Returns
    -------
    None
    """
    # Define optional parameters
    opt = {
        'desi_id': 'None Given',
        'survey': '?',
        't_name': 'None Given',
        'spectype': 'None Given',
    }
    opt.update(kwargs)
    # Values may be non-strings (e.g. an integer TARGETID); format them as text once.
    desi_id, survey, t_name, spectype = (str(opt[key]) for key in ('desi_id', 'survey', 't_name', 'spectype'))

    pixscale = 0.27  # arcsec per pixel of the requested cutout

    # Transient offset from the galaxy in arcsec. The RA difference is wrapped into
    # [-180, 180) deg and scaled by cos(dec), with dec converted from degrees to radians.
    d_ra = ((t_ra - gal_ra + 180) % 360 - 180) * np.cos(np.radians(gal_dec)) * 3600
    d_dec = (t_dec - gal_dec) * 3600
    t_ang_sep = np.hypot(d_ra, d_dec)

    print(f'https://www.legacysurvey.org/viewer-desi?ra={gal_ra}&dec={gal_dec}&layer=ls-dr9&zoom=15')
    if t_ang_sep > 60:
        print('Angular separation too large for cutout image.')
        return None

    # Radius (arcsec) subtended by 300 kpc at z, i.e. the q3c_join cross-match radius.
    q3c_ang_sep = 3600 * DesiRedshiftTable.ang_sep(z, 300)

    # Use a larger cutout when the search circle would not fit in 256 px (about 69" across).
    size = 512 if q3c_ang_sep > 30 else 256
    center = size // 2
    cutout_url = (f'http://legacysurvey.org/viewer/cutout.jpg?ra={gal_ra}&dec={gal_dec}'
                  f'&size={size}&layer=ls-dr9&pixscale={pixscale}&bands=grz')

    print(f'DESI TargetID: {desi_id}\nSurvey: {survey}\nName/ID: {t_name}')

    response = requests.get(cutout_url, timeout=60)
    response.raise_for_status()
    im = Image.open(io.BytesIO(response.content))
    width, height = im.size

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    # Cutouts are North-up / East-left and imshow puts row 0 at the top, so a
    # positive RA offset moves left (-x) and a positive Dec offset moves up (-y).
    ax.scatter(center - d_ra / pixscale, center - d_dec / pixscale, marker='x', s=100, c='fuchsia',
               label=f'Survey: {survey}\nname: {t_name}\nRA: {t_ra:0.2f} deg\nDec: {t_dec:0.2f} deg')
    # Legend-only handle for the galaxy, which is drawn as the crosshair below.
    ax.scatter(np.nan, np.nan, marker='+', s=100, color='tab:green',
               label=f'Galaxy\nTarget ID: {desi_id}\nRA: {gal_ra:0.2f} deg\nDec: {gal_dec:0.2f} deg')
    for lo, hi in ((center + 20, center + 30), (center - 30, center - 20)):
        ax.hlines(center, lo, hi, color='tab:green')
        ax.vlines(center, lo, hi, color='tab:green')

    # Region searched by q3c_join(): a circle on the sky, so a circle in the cutout.
    ax.add_patch(mpl.patches.Circle((center, center), radius=q3c_ang_sep / pixscale,
                                    edgecolor='red', ls=':', facecolor='none', linewidth=2))

    # 10 arcsec scale bar (bottom right).
    ax.errorbar(width - 50, height - 25, xerr=10 / pixscale / 2, color='w', capsize=5)
    ax.text(width - 50, height - 15, '10 arcsec', c='w', horizontalalignment='center',
            verticalalignment='top', size='medium', fontweight='bold')

    # Compass (bottom left): North up, East left.
    x0, y0 = 80, height - 25
    ax.arrow(x0, y0, 0, -50, color='w', head_width=5)
    ax.arrow(x0, y0, -50, 0, color='w', head_width=5)
    ax.text(x0, y0 - 65, 'N', c='w', horizontalalignment='center',
            verticalalignment='center', fontweight='bold')
    ax.text(x0 - 65, y0, 'E', c='w', horizontalalignment='center',
            verticalalignment='center', fontweight='bold')

    ax.text(15, 45,
            f'spectype = {spectype}\nz = {z:0.1f}\nq3c angular separation = {q3c_ang_sep:0.2f}'
            f'\ncalculated angular separation = {t_ang_sep:0.2f}',
            bbox=dict(facecolor='w', alpha=0.5),
            fontsize=12)
    # Legend `fontsize` is ignored when `prop` is given, so the size goes into `prop`.
    ax.legend(bbox_to_anchor=(1.02, 1),
              loc='upper left',
              frameon=False,
              prop={'weight': 'bold', 'size': 12},
              labelspacing=2)
    ax.set_title("LegSurv DR9 Cutout\ncentered on Desi Target (RA, Dec)", fontsize=14)
    ax.axis('off')
    plt.show()

In [None]:
# First row of z_tab. The cuts above modified z_tab in place, so this is the
# first surviving row after sorting by 'ang_sep'.
z_tab[0]

In [None]:
def plot_spec(tab, targetid):
    '''Plot the coadded DESI spectrum of ``targetid`` with its expected emission lines.

    ``tab`` must provide the columns 'TARGETID', 'Z', 'target_ra', 'target_dec',
    'coadd_file' and 'coadd_index' (as built by ``DesiRedshiftTable.all_z_todate``).
    The function:

    - prints a Legacy Survey viewer link and every column of the matching row;
    - reads the coadd and plots the b/r/z camera fluxes, the combined noise and a
      median-filtered continuum;
    - marks the emission lines redshifted to the row's 'Z'.

    When ``targetid`` appears more than once, the last matching row is used.
    Returns None; if no row matches, a message is printed and nothing is plotted.
    '''
    matches = np.flatnonzero(tab['TARGETID'] == targetid)
    if len(matches) == 0:
        print("No matching target id in table.")
        return None
    index = matches[-1]

    ra = tab['target_ra'][index]
    dec = tab['target_dec'][index]
    print(f'https://www.legacysurvey.org/viewer-desi?ra={ra}&dec={dec}&layer=ls-dr9&zoom=15')

    for col in tab.columns:
        print(col, ': ', tab[col][index])

    coadd_obj = read_spectra(tab['coadd_file'][index])
    row = tab['coadd_index'][index]
    t_id = coadd_obj.target_ids()[row]

    cams = ('b', 'r', 'z')
    med_filt_size = 19  # (odd) kernel size of the median filter used as continuum

    wave_arr = [coadd_obj.wave[cam] for cam in cams]
    flux_arr = [coadd_obj.flux[cam][row] for cam in cams]
    # Noise = 1/sqrt(ivar); pixels with ivar == 0 become inf and are excluded from the y-limits.
    with np.errstate(divide='ignore'):
        noise_arr = [np.sqrt(coadd_obj.ivar[cam][row]) ** (-1.0) for cam in cams]

    x_spc, y_flx, y_err = coadd_brz_cameras(wave_arr, flux_arr, noise_arr)
    continuum = medfilt(y_flx, med_filt_size)

    # Rest-frame vacuum wavelengths [Angstrom], shifted to the observed frame at this row's redshift.
    rest_lines = np.array([3727.092, 3729.875, 4102.892, 4341.684, 4862.683, 4960.295, 5008.240])
    names = ['[OII]', '[OII]', 'H\u03B4', 'H\u03B3', 'H\u03B2', '[OIII]', '[OIII]']
    emission_lines = rest_lines * (1 + tab['Z'][index])

    y = np.concatenate((continuum, y_err))
    finite_y = y[~np.isinf(y)]
    xmin, xmax = np.nanmin(coadd_obj.wave['b']), np.nanmax(coadd_obj.wave['z'])
    ymin, ymax = np.nanmin(finite_y), np.nanmax(finite_y)

    sns.set_style("darkgrid")
    plt.figure(figsize=(25, 7))
    for wave, flux in zip(wave_arr, flux_arr):
        plt.plot(wave, flux, 'lightseagreen', alpha=0.25, linewidth=0.5)
    plt.plot(np.nan, np.nan, 'lightseagreen', linewidth=1, label='data')  # legend handle only
    plt.plot(x_spc, y_err, 'cornflowerblue', alpha=0.75, linewidth=1, label='noise')
    plt.plot(x_spc, continuum, 'mediumorchid', alpha=1, linewidth=1, label='median-filtered continuum')

    # Only mark lines that fall inside the observed wavelength range.
    wave_lo, wave_hi = np.min(coadd_obj.wave['b']), np.max(coadd_obj.wave['z'])
    for name, line in zip(names, emission_lines):
        if wave_lo < line < wave_hi:
            plt.axvline(line, color='coral', alpha=0.5, linewidth=2)
            plt.text(line, ymax, name, rotation=90, va='center', ha='right', fontsize=20)

    plt.axis([xmin, xmax, ymin - 1, ymax + 1])
    plt.legend(fontsize=20)
    plt.xlabel(r'wavelength $\lambda [\AA]$', fontsize=18)
    plt.ylabel(r'Flux [erg/s/cm2/$\AA$]', fontsize=18)
    plt.title(f'DESI Target ID: {t_id}', fontsize=18)