### Install the desigal package (assuming all other DESI packages are already installed)

In [None]:
## Installation for use
# !pip install --user git+https://github.com/desihub/desigal.git#egg=desigal

In [None]:
# installation for development
# !git clone https://github.com/desihub/desigal
# %cd ../..
# !pip install --user --editable .

In [None]:
# Enable IPython's autoreload extension so that edits to imported modules
# (e.g. an editable/development install of desigal) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Demo on stacking DESI spectra

This is a long version of the demo that goes through each individual step of the pipeline, so that people can tinker with each component. The end user will not always need all of these steps. For a quick two-step version of the stacking pipeline, see the notebook `spectra_stack_quick_demo.ipynb`.

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table


import desispec
import desispec.io
import desigal.specutils

### Select a couple of spectra for the demo

In [None]:
# Demo sample: two LRG targets; z_targets[i] is the redshift of targets[i]
# (the two lists are paired row by row in the redshift catalogue below).
targets = [
    39627652591526038,
    39627646576885987,
]  # LRGS
z_targets = [
    0.3313666995460735,
    0.3757204903818251,
]

# Alternative demo sample:
# targets = [39627646576885924, 39627640566454233] # QSO
# z_targets = [0.47741841167195725, 0.7193102969415338]

# Fetch the spectra of these targets from the "fuji" release.
spectra = desigal.specutils.get_spectra(
    targets,
    release="fuji",
    n_workers=-1,
    use_db=True,
)

### Unfold the various components of the spectra object

In [None]:
# Redshift catalogue pairing each TARGETID with its redshift Z.
z_cat = Table({"TARGETID": targets, "Z": z_targets})
spec_z = z_cat["Z"]

# Pull out the per-spectrum arrays and fibermaps used by the pipeline below.
flux, wave, ivar, mask = spectra.flux, spectra.wave, spectra.ivar, spectra.mask
fibermap, exp_fibermap = spectra.fibermap, spectra.exp_fibermap

# Coadd Cameras

This is a faster alternative to `desispec.coaddition.coadd_cameras` that gives identical results.
It can handle redshifted spectra and uses a vectorized implementation of the coaddition algorithm.
It is currently not a full replacement for the `desispec` version: it neither merges fibermaps nor performs sanity checks.
Neither is important for our purpose, so they are not implemented.

In [None]:
# Coadd the per-camera spectra onto a single wavelength grid (wave_coadd).
# TODO: mask coaddition is known not to work yet; mask_coadd is returned but is
# not used anywhere further down in this notebook.
flux_coadd, wave_coadd, ivar_coadd, mask_coadd = desigal.specutils.coadd_cameras(flux, wave,ivar,mask)

In [None]:
# Plot every camera-coadded spectrum (red) together with its 1-sigma noise
# 1/sqrt(ivar) (black). Pixels with ivar == 0 give inf; the errstate context
# only silences numpy's divide-by-zero warning for them.
plt.figure(figsize=(20,10))
with np.errstate(divide="ignore"):
    for idx in range(len(flux_coadd)):
        plt.plot(wave_coadd, flux_coadd[idx], c="red")
        plt.plot(wave_coadd, 1/np.sqrt(ivar_coadd[idx]), c="k")
# Set the title once, rather than on every loop iteration.
plt.title("Coadd cameras before redshift correction")

# Do redshift Correction

`deredshift` accepts either a `np.ndarray` or a `dict` as input.

In [None]:
# Apply the redshift correction to flux, wavelength and inverse variance using
# the per-spectrum redshifts in spec_z. The third argument (0) is presumably the
# redshift to shift to, i.e. the rest frame; confirm against
# desigal.specutils.deredshift.
flux_dered, wave_dered, ivar_dered = (
    desigal.specutils.deredshift(data, spec_z, 0, quantity)
    for data, quantity in (
        (flux_coadd, "flux"),
        (wave_coadd, "wave"),
        (ivar_coadd, "ivar"),
    )
)

Plot the observed and de-redshifted flux of each spectrum

In [None]:
# For each spectrum, overlay the observed flux (on the observed wavelength grid)
# and the de-redshifted flux (on its own de-redshifted wavelength grid).
# Both the wavelength and the flux must be indexed with the same loop variable;
# `idx` is a leftover from an earlier cell and must not be used here.
plt.figure(figsize=(20,10))
for i in range(len(flux_dered)):
    plt.plot(wave_coadd, flux_coadd[i], alpha=0.5, label=f"spectrum {i}: observed")
    plt.plot(wave_dered[i], flux_dered[i], label=f"spectrum {i}: de-redshifted")
plt.legend()

# Resample to Common Grid

In [None]:
# Common grid spanning the full de-redshifted wavelength range with a step of
# 0.8 (same units as wave_dered). np.arange excludes the upper end point.
wave_grid = np.arange(np.min(wave_dered), np.max(wave_dered), 0.8)

In [None]:
# Resample every de-redshifted spectrum onto the common wave_grid.
# fill_val=np.nan is presumably used for grid points a spectrum does not cover.
# Available methods: "linear", "sn-cons", "flux-cons".
flux_grid, ivar_grid = desigal.specutils.resample(
    wave_grid,
    wave_dered,
    flux_dered,
    ivar_dered,
    fill_val=np.nan,
    method="linear",
    n_workers=1,
)

In [None]:
# Compare one de-redshifted spectrum before and after resampling onto wave_grid.
plt.figure(figsize=(20,10))
idx = 1
plt.plot(wave_dered[idx], flux_dered[idx], label="de-redshifted")
plt.plot(wave_grid, flux_grid[idx], label="resampled onto wave_grid")
plt.legend()

# Normalize the spectra

In [None]:
# Normalize the resampled spectra using the "flux-window" method over the
# wavelength window [4000, 4050]; see desigal.specutils.normalize.
flux_normed, ivar_normed = desigal.specutils.normalize(
    wave_grid,
    flux_grid,
    ivar_grid,
    method="flux-window",
    flux_window=[4000, 4050],
)

In [None]:
# Compare one spectrum before and after normalization, together with the
# corresponding 1-sigma noise 1/sqrt(ivar).
plt.figure(figsize=(20,10))
idx = 1

plt.plot(wave_grid, flux_normed[idx], label="normalized flux")
plt.plot(wave_grid, flux_grid[idx], label="flux")

# Pixels with ivar == 0 give inf; only numpy's divide-by-zero warning is silenced.
with np.errstate(divide="ignore"):
    plt.plot(wave_grid, 1/np.sqrt(ivar_grid[idx]), label="noise")
    plt.plot(wave_grid, 1/np.sqrt(ivar_normed[idx]), label="normalized noise")
plt.title("Normalized Spectra")
plt.legend()

# Model the IVAR using Sky
Feature under development

### Before that, let's get the sky

In [None]:
# sky_flux, sky_mask = desigal.get_sky(fibermap = fibermap, exp_fibermap=exp_fibermap)

# sky_flux_coadd, wave_coadd, ivar_coadd, mask_coadd = desigal.coadd_cameras(sky_flux, wave, ivar, sky_mask)

In [None]:
# sky_flux_coadd, wave_coadd, ivar_coadd, mask_coadd = desigal.coadd_cameras(sky_flux, wave, ivar, sky_mask)

In [None]:
# plt.figure(figsize=(20,10))
# for idx in range(2):
#     plt.plot(wave_coadd, sky_flux_coadd[idx], ls="--")

### Now let's model the error using the sky

P.S.: The algorithm is currently optimized to take the sky ivar as input, so it needs an update before it can work with the sky flux.

In [None]:
# ivar_model= desigal.model_ivar(ivar_coadd, sky_flux_coadd, wave_coadd)

In [None]:
# (Disabled until the sky-based ivar model is ready.) Compares the modelled and
# observed noise 1/sqrt(ivar) for one spectrum and plots their fractional
# variance residual. The xlabel uses a raw string so "\A" is not parsed as an
# (invalid) escape sequence once this cell is uncommented.
# index = 0

# plt.figure(figsize=(15,6))
# plt.plot(wave_coadd, 1/np.sqrt(ivar_model[index])+1, label="model")
# plt.plot(wave_coadd, 1/np.sqrt(ivar_coadd[index])+1, label ="observed")
# plt.plot(wave_coadd, ((1/ivar_model[index])-(1/ivar_coadd[index]))/(1/ivar_coadd[index]), label="Fractional residual")
# plt.legend(fontsize=10)
# plt.ylabel("Variance (arbitrary units)",size=20)
# plt.xlabel(r"Wavelength ($\AA$)", size=20)
# plt.axhline(0, ls="--", c="k")
# plt.axhline(-1, ls="--", c="k", alpha=0.5)
# plt.axhline(1, ls="--", c="k", alpha=0.5)

# Add the spectra

In [None]:
# Combine the normalized spectra on the common grid into a single stacked
# spectrum (method="mean").
stacked_flux = desigal.specutils.coadd_flux(
    wave_grid, flux_normed, ivar_normed, method="mean"
)

In [None]:
# Plot the step-by-step stacked spectrum on the common grid.
plt.figure(figsize=(20,10))
plt.plot(wave_grid, stacked_flux)
plt.title("Stacked spectrum (step by step)")
plt.xlabel("Wavelength")
plt.ylabel("Normalized flux")

# OR
# `stack_spectra()`: One Function to Rule Them All

In [None]:
# One-call version of the pipeline above, using the same normalization,
# resampling and stacking settings as the step-by-step cells. It presumably
# also performs the camera coadd and redshift correction internally, since it
# is given the raw flux/wave/ivar/mask; confirm in desigal.specutils.
stacked_spectra, stack_grid = desigal.specutils.stack_spectra(
    flux=flux,
    wave=wave,
    ivar=ivar,
    mask=mask,
    redshift=spec_z,
    fibermap=fibermap,
    exp_fibermap=exp_fibermap,
    norm_method="flux-window",  # options: "mean", "median", "flux-window"
    norm_flux_window=[4000, 4050],
    resample_resolution=0.8,
    resample_method="linear",  # options: "linear", "sn-cons", "flux-cons"
    stack_method="mean",  # options: "median", "mean"
    n_workers=1,
    # weight="none",  # TO BE IMPLEMENTED
    # stack_error="none",  # or "bootstrap"; TO BE IMPLEMENTED
)

In [None]:
# Compare the one-call stack with the step-by-step stack. They use the same
# settings, so the two curves are expected to overlap.
plt.figure(figsize=(20,10))
plt.plot(stack_grid, stacked_spectra, label="stack_spectra()")
plt.plot(wave_grid, stacked_flux, ls="--", lw=0.5, label="step by step")
plt.legend()