In [None]:
import glob
import sys
import os
import requests
from desispec.io import read_spectra
from astropy.table import Table
import numpy
import matplotlib.pyplot as plt
from desiutil.log import get_logger, DEBUG
from desidiff.src.group_tiles import *
from desidiff.src.dates_to_process import *
from desidiff.src.coadd import *
from desidiff.src.scores import *
from desidiff.src.ContinuumFitFilter_desidiff import *

In [None]:
# Cumulative-tile spectra file to analyse (adjacent literals concatenate to one path).
filename = (
    "/global/project/projectdirs/desi/spectro/redux/fuji"
    "/tiles/cumulative/80664/20210402"
    "/spectra-4-80664-thru20210402.fits"
)

In [None]:
# Load the spectra themselves, then two column subsets read from HDU 1:
# target coordinates from the spectra file and redshift results from the
# matching redrock file (same path with 'spectra' swapped for 'redrock').
coord_columns = 'TARGETID', 'TARGET_RA', 'TARGET_DEC'
redshift_columns = 'TARGETID', 'Z', 'ZERR', 'ZWARN', 'SPECTYPE'
redrock_filename = filename.replace('spectra', 'redrock')

spectra = read_spectra(filename)
ra_dec_list = Table.read(
    filename, format='fits', hdu=1, memmap=True
)[coord_columns]
rr = Table.read(
    redrock_filename, format='fits', hdu=1, memmap=True
)[redshift_columns]

In [None]:
# SkyPortal credentials: the token lives in a file outside the notebook and is
# read with every newline character removed.
secret_file = "/global/cfs/cdirs/desi/science/td/secrets/desidiff_sp.txt"
with open(secret_file, 'r') as token_file:
    token = "".join(token_file.read().split('\n'))
headers = {'Authorization': f'token {token}'}

filter_name = 'DESIDIFF'

# Use a wide, non-default figure size for every plot below.
plt.rcParams.update({"figure.figsize": (20, 6)})

In [None]:
# TARGETID of the single object examined in the cells below.
t = 39_628_438_574_727_938

In [None]:
# Keep only the spectra belonging to the chosen target.
selected_targets = [t]
target_spectra = spectra.select(targets=selected_targets)

In [None]:
# Distinct observing nights on which this target id has spectra.
night_column = target_spectra.fibermap['NIGHT']
unique_nights = numpy.unique(night_column)

In [None]:
# Observed-frame wavelength limits (same units as the plotted wavelength axis);
# pixels outside these limits are masked in the corresponding band.
lminb=3700.
lminr=5800.
lmaxr=7580.
lmaxz=9100.

# Redshift-table and coordinate rows for the target; the first matching
# redshift row supplies the redshift and spectral type used below.
zinfo = rr[rr['TARGETID']==t]
ra_dec_data = ra_dec_list[ra_dec_list['TARGETID'] == t]
redshift = zinfo['Z'][0]
spectype = zinfo['SPECTYPE'][0]


def _plot_new_ref_diff(waves, fluxes, mask, xlim, xlabel):
    """Draw difference / new / reference spectra for every band on the current figure.

    waves, fluxes : 3-tuples ``(difference, new, reference)`` of per-band dicts.
    mask          : per-band dict; only pixels where the mask equals 0 are drawn.
    xlim, xlabel  : x-axis limits and label for the finished figure.
    """
    styles = (('red', None, 'Difference'),
              ('blue', 0.5, 'New Spectrum'),
              ('green', 0.5, 'Reference Spectrum'))
    bands = list(fluxes[0].keys())
    for band in bands:
        good = numpy.where(mask[band] == 0)[0]
        for wave, flux, (color, alpha, label) in zip(waves, fluxes, styles):
            # Label a single band only so each curve appears once in the legend.
            extra = {'label': label} if band == bands[-1] else {}
            plt.plot(wave[band][good], flux[band][good], color=color, alpha=alpha, **extra)
    plt.legend()
    plt.xlim(xlim)
    plt.xlabel(xlabel)
    plt.ylabel('Flux')
    plt.show()


# Each night in turn plays the role of the "new" epoch; all remaining nights
# together form the reference.
for night in unique_nights:
    ref_nights = unique_nights[unique_nights != night]
    if ref_nights.size == 0:
        # A target seen on one night only has nothing to be differenced against.
        print(f"{t} {night}: no reference nights available, skipping")
        continue

    ## build reference
    refSpectra = [target_spectra.select(nights=ref_nights, targets=[t])]

    ## build new
    newSpectra = [target_spectra.select(nights=[night], targets=[t])]

    ## search
    newflux, newivar, newwave, newmask = coadd(newSpectra)
    refflux, refivar, refwave, refmask = coadd(refSpectra)

    # renormalize spectra to match each other
    # There is a significant background of spectra that have the same shape but different fluxes
    # This seems to be related to mistaken coordinates of bright sources
    norm = normalization(newflux,newmask, refflux,refmask)

    for key in newflux.keys():
        newflux[key]=newflux[key]/norm
        newivar[key]=newivar[key]*norm**2

    # Difference spectrum: fluxes subtract, variances (1/ivar) add, and a pixel
    # is unmasked (0) only when it is unmasked in both inputs.
    difflux = {key: newflux[key] - refflux[key]
                   for key in newflux.keys()}
    difivar = {key: 1./(1./newivar[key] + 1./refivar[key])
                   for key in newivar.keys()}
    difmask = {key: newmask[key] + refmask[key]
                   for key in newmask.keys()}
    difwave = dict(newwave)

    # Rest-frame wavelength grids (observed wavelength / (1 + z)), kept as arrays
    # so they can be indexed with the same pixel selections as the fluxes.
    restwave_new = {band: numpy.asarray(newwave[band]) / (1 + redshift) for band in newwave.keys()}
    restwave_ref = {band: numpy.asarray(refwave[band]) / (1 + redshift) for band in newwave.keys()}
    restwave_diff = {band: numpy.asarray(difwave[band]) / (1 + redshift) for band in newwave.keys()}

    # Mask systematically problematic areas of the spectrum
    # trim red edge
    if 'z' in difflux.keys():
        difmask['z'][difwave['z'] > lmaxz]=1

    # trim blue edge
    if 'b' in difflux.keys():
        difmask['b'][difwave['b'] < lminb]=1

    if 'r' in difflux.keys():
        difmask['r'][difwave['r'] < lminr]=1
        difmask['r'][difwave['r'] > lmaxr]=1

    # mean-subtracted difference
    difflux_clipmean = clipmean(difflux,difivar,difmask)

    ## filters

    # Difference spectrum may have broadband signal
    perband_filter = perband_SN(difflux_clipmean,difivar,difmask)

    # fractional increase
    perband_inc = perband_increase(difflux_clipmean,difivar,difmask, refflux,refivar,refmask)

    # Difference spectrum may have high-frequency signal
    perres_filter = perconv_SN(difwave, difflux_clipmean,difivar,difmask, ncon = 7)

    # Search for signature lines of TDEs, only interested in Galaxies
    linetable = line_finder(difwave, difflux_clipmean,difivar,difmask,redshift)
    if spectype == "GALAXY":
        # NOTE(review): this builds a tuple rather than a score, so TDElogic below
        # can never be True. It looks like a scoring-function call lost its name;
        # the intended callable is not visible in this notebook -- TODO confirm
        # against desidiff.src.ContinuumFitFilter_desidiff and restore the call.
        TDE_score = (linetable, difflux_clipmean)
    else:
        TDE_score = 0
    # Hlines
    Hline_score = Hline_filter(linetable)

    # Turn every score into an explicit boolean trigger.
    #broadband: any band with S/N > 10 and fractional increase > 0.25
    bblogic = any(numpy.logical_and(numpy.array(list(perband_filter.values()))>10, numpy.array(list(perband_inc.values()))>0.25))
    linelogic = perres_filter >=2
    TDElogic = any(TDE_score == score for score in (2, 3, 4, 5))
    Hlinelogic = Hline_score >= 1

    logic = [bblogic,linelogic, TDElogic, Hlinelogic]
    logic_name = ['Broadband', 'line', 'TDE','Hline'] #must be in same order as logic!, use as mask
    logic_name = numpy.ma.masked_array(logic_name, mask = [not i for i in logic])
    plt.clf()
    if any(logic):
        # Report the target / night and which triggers fired (untriggered names are masked)
        print(t, night)
        print(logic_name)

        # Observed-frame plot
        _plot_new_ref_diff((difwave, newwave, refwave),
                           (difflux, newflux, refflux),
                           difmask, (lminb, lmaxz), 'Wavelength (A)')

        # Rest-frame plot: same masked pixels, axis limits shifted to the rest frame
        _plot_new_ref_diff((restwave_diff, restwave_new, restwave_ref),
                           (difflux, newflux, refflux),
                           difmask,
                           (lminb/(1 + redshift), lmaxz/(1 + redshift)),
                           'Rest Wavelength (A)')