Ensure that this notebook is running on the kernel with the legacy survey cutout service + dask docker image.

To set up the Dask cluster, run the following command in a Jupyter terminal:

`salloc -N 4 -n 512 -t 240 -C cpu -q interactive --image=biprateep/dask-viewer-cutouts:latest --account=m4236`

and then 

`./launch_dask.sh` 

The `-n` argument controls the number of workers to be launched.


### Collect the spectra data

In [None]:
import os
from pathlib import Path
import pandas as pd
import fitsio
import numpy as np
import re
import time
from desiutil.io import encode_table
from desiutil.log import get_logger, DEBUG
from desispec.io.util import native_endian, checkgzip
from desispec.io import iotime
from tqdm import tqdm
import zarr

import dask
from dask.distributed import Client
import dask.dataframe as dd
import dask.array as da
from dask.distributed import LocalCluster
from dask.diagnostics import ProgressBar

# np.seterr(divide='ignore', invalid='ignore')

In [None]:
# Read the scheduler file generated by the script above and connect your notebook to the client
scheduler_file = os.path.join(os.environ["SCRATCH"], "scheduler.json")
# scheduler_file = os.path.join(os.environ["CFS"], "desi/users/bid13/scheduler.json")
# Set the dashboard link template through dask's public configuration API
# rather than mutating the internal config dictionary in place.
dask.config.set({"distributed.dashboard.link": "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"})
client = Client(scheduler_file=scheduler_file)
# Last expression of the cell: displays the client summary in the notebook.
client

In [None]:
# cluster = LocalCluster(threads_per_worker=1)
# client = cluster.get_client()
# client

In [None]:
# Data release to process, and the scratch directory holding its catalogue/outputs.
release = "iron"
dest_path = Path(os.environ["SCRATCH"], "data", "foundation", release)

In [None]:
# Load only the columns needed to locate each spectrum, then order the
# catalogue by MYID so row position follows the sorted MYID index.
catalog_columns = ["SURVEY", "PROGRAM", "HEALPIX", "TARGETID", "MYID"]
zcat = (
    pd.read_parquet(dest_path / "desi_zcat_maglim_19_5.parquet", columns=catalog_columns)
    .set_index("MYID")
    .sort_index()
)


In [None]:
def _resolution_coadd(resolution, pix_weights):
    """
    Given the resolution matrices for set of spectra, and
    inverse variances (or generally weights) for fluxes return the
    accumulated resolution matrix, and the combined weights
    See #2372.

    Args:
    resolution (ndarray): (nspec, nres, npix) array of resolution matrices
    pix_weights (ndarray): (nspec, npix) array of ivars or weights

    Returns resolution matrix (nres, npix),
    and the weight (nres, npix)
    """
    ww = resolution.shape[1] // 2
    # resolution kernel width
    npix = resolution.shape[2]
    # indices of the corresponding variance point
    # that needs to be used for ivar weights
    res_indices = (np.arange(npix)[None, :] +
                   np.arange(-ww, ww + 1)[:, None]) % npix
    res_whts = np.array([_[res_indices] for _ in pix_weights])
    res = np.sum(res_whts * resolution, axis=0)
    res_norm = np.sum(res_whts, axis=0)
    return res, res_norm

In [None]:
def _within_tolerance_of(values, reference, tolerance):
    """Flag entries of ``values`` lying within ``tolerance`` of any ``reference`` entry.

    Gives the same answer as
    ``[np.any(np.abs(v - reference) <= tolerance) for v in values]``
    but uses a binary search on the sorted reference instead of a Python loop.
    """
    values = np.asarray(values)
    reference = np.sort(np.asarray(reference).ravel())
    if reference.size == 0:
        return np.zeros(values.shape, dtype=bool)
    # The reference value closest to v is one of the two neighbours of the
    # position where v would be inserted into the sorted reference.
    pos = np.searchsorted(reference, values)
    last = reference.size - 1
    below = reference[np.clip(pos - 1, 0, last)]
    above = reference[np.clip(pos, 0, last)]
    return (np.abs(values - below) <= tolerance) | (np.abs(values - above) <= tolerance)


def coadd_cameras(flux_cam, wave_cam, ivar_cam, mask_cam, res_cam):
    """
    Merge the b, r and z camera data of one spectrum onto a single wavelength grid.

    Pixels seen by a single camera are copied as they are. Pixels whose
    wavelength matches (within 1e-4) a pixel of the previous or next camera
    are combined: flux and resolution by inverse-variance weighted mean,
    ivar by sum, mask by bitwise OR.

    Args:
        flux_cam (dict): per-band flux arrays keyed by "b", "r", "z";
            only row 0 of each array is used.
        wave_cam (dict): per-band wavelength arrays (one value per pixel).
        ivar_cam (dict): per-band inverse variances, same layout as flux_cam.
        mask_cam (dict): per-band integer masks, same layout as flux_cam.
        res_cam (dict): per-band resolution data indexed as
            (row, diagonal, pixel); only row 0 is used.

    Returns:
        tuple: (flux, wave, ivar, mask, rdata) where flux, ivar and mask
        (int32) have shape (1, nwave), wave has shape (nwave,) and rdata has
        shape (1, max_ndiag, nwave).

    Raises:
        ValueError: if a band's wavelength grid does not line up with the
            combined grid to within the tolerance.
    """
    sbands = ("b", "r", "z")  # bands sorted by inc. wavelength
    tolerance = 0.0001  # A , tolerance

    # create wavelength array: start from the first band, then append only
    # the pixels of each later band that lie beyond the current red end
    wave = None
    for b in sbands:
        if wave is None:
            wave = wave_cam[b]
        else:
            wave = np.append(wave, wave_cam[b][wave_cam[b] > wave[-1] + tolerance])

    # check alignment, caching band wavelength grid indices as we go.
    # The validated offset is reused for placing the data below so that the
    # check and the placement can never disagree.
    windict = {}
    for b in sbands:
        imin = int(np.argmin(np.abs(wave_cam[b][0] - wave)))
        windices = np.arange(imin, imin + len(wave_cam[b]), dtype=int)
        dwave = wave_cam[b] - wave[windices]
        if np.any(np.abs(dwave) > tolerance):
            msg = "Input wavelength grids (band '{}') are not aligned. Use --lin-step or --log10-step to resample to a common grid.".format(
                b)
            raise ValueError(msg)
        windict[b] = slice(imin, imin + len(wave_cam[b]))

    nwave = wave.size

    # For each band, flag the pixels that are also present in a neighbouring
    # band; it makes life easier tracking everything when normalizing.
    overlap_flag = {}
    for i, b in enumerate(sbands):
        flagged = np.zeros(np.shape(wave_cam[b]), dtype=bool)
        if i > 0:  # overlap with previous band
            flagged |= _within_tolerance_of(wave_cam[b], wave_cam[sbands[i - 1]], tolerance)
        if i < len(sbands) - 1:  # overlap with next band
            flagged |= _within_tolerance_of(wave_cam[b], wave_cam[sbands[i + 1]], tolerance)
        overlap_flag[b] = flagged

    # defining arrays for coadded data
    flux = np.zeros((1, nwave))
    ivar = np.zeros((1, nwave))
    mask = np.zeros((1, nwave), dtype=np.int32)

    # Bands may carry a different number of resolution diagonals; the output
    # uses the largest and centres narrower bands inside it.
    max_ndiag = max(res_cam[b].shape[1] for b in sbands)
    rdata = np.zeros((1, max_ndiag, nwave))
    rnorm = np.zeros_like(rdata)

    for b in sbands:
        iband = windict[b]
        overlap = overlap_flag[b]
        no_overlap = ~overlap

        f = flux_cam[b]
        iv = ivar_cam[b]
        m = mask_cam[b]

        # NOTE: flux[0, iband] is a view, so the masked assignments below
        # write straight into the output arrays.

        # Non-overlapping: directly copy
        flux[0, iband][no_overlap] = f[0][no_overlap]
        ivar[0, iband][no_overlap] = iv[0][no_overlap]

        # Overlapping: accumulate (inverse variance weighted sum),
        # normalised after the loop
        flux[0, iband][overlap] += iv[0][overlap] * f[0][overlap]
        ivar[0, iband][overlap] += iv[0][overlap]

        # coadding mask
        mask[0, iband][no_overlap] = m[0][no_overlap]  # non-overlapping, simple copy
        mask[0, iband][overlap] |= m[0][overlap]  # overlapping, OR logic

        # resolution: weight each diagonal by the ivar of the pixel it refers to
        res = res_cam[b][0][np.newaxis, :, :]
        raccum, rnorm_i = _resolution_coadd(res, iv[0:1])
        ndiag = raccum.shape[0]
        offset = (max_ndiag - ndiag) // 2
        rows = slice(offset, offset + ndiag)

        # non-overlapping regions, simple copying
        rdata[0, rows, iband][:, no_overlap] = res[0][:, no_overlap]
        rnorm[0, rows, iband][:, no_overlap] = 1.0

        # overlapping regions, weighted sum (normalised after the loop)
        rdata[0, rows, iband][:, overlap] += raccum[:, overlap]
        rnorm[0, rows, iband][:, overlap] += rnorm_i[:, overlap]

    # in the combined unique wave pixels
    # which pixels have two measurements due to overlapping
    overlap_pixel_mask = np.zeros_like(flux, dtype=bool)
    for b in sbands:
        overlap_pixel_mask[:, windict[b]] = overlap_flag[b][None, :]

    # Only normalize on overlapping pixels (basically inverse variance weighted mean)
    # For non-overlapping (already direct copied), skip normalization.
    # Adding (ivar == 0) turns a zero denominator into 1 to avoid dividing by zero.
    flux[overlap_pixel_mask] /= (ivar[overlap_pixel_mask] + (ivar[overlap_pixel_mask] == 0))

    mask[ivar > 0] = 0  # mask =0 means good pixels

    ivar[mask.astype(bool)] = 0  # encoding all mask values in ivar

    # just sanity check that wavelength is an increasing array
    assert np.all(np.diff(wave) > 0)

    rdata_norm_pixels = overlap_pixel_mask[0]  # the mask has a single row
    rdata[:, :, rdata_norm_pixels] /= rnorm[:, :, rdata_norm_pixels] + (rnorm[:, :, rdata_norm_pixels] == 0)

    return flux, wave, ivar, mask, rdata


In [None]:
def read_spectra(survey, program, healpix, targetid, release="iron", read_hdu=None):
    """
    Read one target from a healpix coadd file and merge its cameras.

    The file
    ``<release>/healpix/<survey>/<program>/<healpix // 100>/<healpix>/coadd-<survey>-<program>-<healpix>.fits``
    is opened, the row whose FIBERMAP TARGETID equals ``targetid`` is
    located, and only that row of each per-band FLUX / IVAR / MASK /
    RESOLUTION HDU is read before calling ``coadd_cameras``.

    Args:
        survey, program: values used in the directory and file name.
        healpix: healpix number used in the directory and file name.
        targetid: value searched for in the TARGETID column of FIBERMAP.
        release (str): redux release directory name.
        read_hdu (dict, optional): flags keyed by HDU name. Only
            ``"RESOLUTION"`` is consulted in this function; missing keys
            fall back to the defaults (RESOLUTION enabled). The resolution
            data is required by ``coadd_cameras``, so disabling it makes the
            coadd step fail.

    Returns:
        tuple: (flux, wave, ivar, mask, res) as returned by
        ``coadd_cameras``, with flux, ivar and res cast to float32.

    Raises:
        IndexError: if ``targetid`` is not present in the file.
        RuntimeError: if a data HDU name has no ``<band>_<type>`` form.
    """
    options = {
        "FIBERMAP": True,
        "EXP_FIBERMAP": False,
        "SCORES": False,
        "EXTRA_CATALOG": False,
        "MASK": False,
        "RESOLUTION": True,
    }
    if read_hdu is not None:
        options.update(read_hdu)

    release_path = Path(f"/global/cfs/cdirs/desi/spectro/redux/{release}")
    infile = (
        release_path
        / "healpix"
        / survey
        / program
        / str(int(healpix / 100))
        / str(healpix)
        / f"coadd-{survey}-{program}-{healpix}.fits"
    )

    ftype = np.float32
    # Table HDUs that carry no per-band spectral data and are skipped.
    table_hdus = {"FIBERMAP", "EXP_FIBERMAP", "SCORES", "EXTRA_CATALOG"}

    wave = {}
    flux = {}
    ivar = {}
    mask = {}
    res = {}

    # The context manager closes the file even when a lookup or read fails.
    with fitsio.FITS(infile, mode='r') as hdus:
        target_ids = hdus["FIBERMAP"].read(columns="TARGETID")
        matches = np.argwhere(target_ids == targetid)
        if len(matches) == 0:
            raise IndexError(f"TARGETID {targetid} not found in {infile}")
        targetrow = matches[0][0]
        row = slice(targetrow, targetrow + 1)  # keep a leading axis of length 1

        # For efficiency, go through the HDUs in disk-order.  Use the
        # extension name to determine where to put the data.
        for h in range(1, len(hdus)):
            name = hdus[h].read_header()["EXTNAME"]
            if name in table_hdus:
                continue

            # Find the band based on the name
            mat = re.match(r"(.*)_(.*)", name)
            if mat is None:
                raise RuntimeError("FITS extension name {} does not contain the band".format(name))
            band = mat.group(1).lower()
            kind = mat.group(2)

            if kind == "WAVELENGTH":
                #- Note: keep original float64 resolution for wavelength
                wave[band] = native_endian(hdus[h].read())
            elif kind == "FLUX":
                flux[band] = native_endian(hdus[h][row, :])
            elif kind == "IVAR":
                ivar[band] = native_endian(hdus[h][row, :])
            elif kind == "MASK":
                mask[band] = native_endian(hdus[h][row, :].astype(np.uint32))
            elif kind == "RESOLUTION" and options["RESOLUTION"]:
                res[band] = native_endian(hdus[h][row, :, :])
            # any other per-band HDU is ignored

    flux, wave, ivar, mask, res = coadd_cameras(flux, wave, ivar, mask, res)

    return flux.astype(ftype), wave, ivar.astype(ftype), mask, res.astype(ftype) # use fp 16 for Res?

In [None]:
def split_dataframe(df, chunk_size = 1000):
    """Cut ``df`` into consecutive slices of at most ``chunk_size`` rows.

    Returns a list of slices in the original order; the last one may be
    shorter. An empty input yields an empty list.
    """
    n_chunks = int(np.ceil(len(df) / chunk_size))
    starts = (k * chunk_size for k in range(n_chunks))
    return [df[lo:lo + chunk_size] for lo in starts]

In [None]:
@dask.delayed(nout=3)
def get_spectra(params: pd.DataFrame) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    """Read and coadd the spectrum of every row in ``params``.

    Args:
        params: frame with SURVEY, PROGRAM, HEALPIX and TARGETID columns;
            one spectrum is read per row, in row order. Must contain at
            least one row (``np.concatenate`` rejects an empty list).

    Returns:
        Three arrays (flux, ivar, resolution), each the per-row results of
        ``read_spectra`` concatenated along the first axis.
    """
    fluxes = []
    ivars = []
    reses = []
    # Iterate the four columns in lockstep instead of building a Series per row.
    columns = (params["SURVEY"], params["PROGRAM"], params["HEALPIX"], params["TARGETID"])
    for survey, program, healpix, targetid in zip(*columns):
        # wave and mask are returned by read_spectra but not stored here
        flux, _, ivar, _, res = read_spectra(survey, program, healpix, targetid)
        fluxes.append(flux)
        ivars.append(ivar)
        reses.append(res)

    return np.concatenate(fluxes), np.concatenate(ivars), np.concatenate(reses)

In [None]:
# One catalogue slice per delayed task (1000 rows each, the last may be shorter).
chunk_params = split_dataframe(zcat, chunk_size=1000)

In [None]:
# Shapes declared to dask for each delayed chunk; they must match what
# get_spectra actually returns (7781 wavelength pixels, 11 resolution rows).
N_WAVE = 7781
N_DIAG = 11

all_fluxes, all_ivars, all_reses = [], [], []

for chunk in tqdm(chunk_params):
    n_rows = len(chunk)
    fluxes, ivars, reses = get_spectra(chunk)
    all_fluxes.append(da.from_delayed(fluxes, dtype=np.float32, shape=(n_rows, N_WAVE)))
    all_ivars.append(da.from_delayed(ivars, dtype=np.float32, shape=(n_rows, N_WAVE)))
    all_reses.append(da.from_delayed(reses, dtype=np.float32, shape=(n_rows, N_DIAG, N_WAVE)))



In [None]:
# Stitch the per-chunk arrays together along the first (spectrum) axis.
all_fluxes, all_ivars, all_reses = (
    da.concatenate(pieces) for pieces in (all_fluxes, all_ivars, all_reses)
)

In [None]:
# Output zarr store, written next to the input catalogue.
filepath = dest_path.joinpath("desi_maglim_19_5.zarr")

In [None]:
with ProgressBar():
    #TODO: Move the kwargs to match the updated API
    # FLUX, IVAR and RESOLUTION all come from the same delayed get_spectra
    # tasks. Writing them with three separate eager to_zarr calls would run
    # those tasks (and re-read every file) once per call, so the writes are
    # built lazily and executed together in a single graph instead.
    zarr_components = {
        "FLUX": all_fluxes,
        "IVAR": all_ivars,
        "RESOLUTION": all_reses,
        # NOTE(review): wavelength grid is hard-coded here rather than taken
        # from the files; presumably it matches the coadded grid — confirm.
        "WAVE": da.linspace(3600, 9824, 7781),
    }
    pending_writes = [
        array.to_zarr(str(filepath), overwrite=True, compressor=None, component=name, compute=False)
        for name, array in zarr_components.items()
    ]
    dask.compute(*pending_writes)

# Quality Assurance
Spot-check by retrieving a sample spectrum from the saved file and comparing it with the spectrum on disk
and in the Legacy Survey viewer.

In [None]:
import os
from pathlib import Path
import zarr
import matplotlib.pyplot as plt

In [None]:
release = "iron"
dest_path = Path(os.environ["SCRATCH"]) / "data" / "foundation" / f"{release}"


zcat = pd.read_parquet(dest_path / "desi_zcat_maglim_19_5.parquet", columns = ["SURVEY","PROGRAM","HEALPIX","TARGETID","MYID"])
# The spectra were written in the order of the catalogue indexed and sorted
# by MYID; apply the same ordering here so that a positional row number in
# zcat refers to the same target as that row of the zarr arrays.
zcat = zcat.set_index("MYID").sort_index()

# Open the saved store read-only.
spec = zarr.open(dest_path / "desi_maglim_19_5.zarr", mode="r")


In [None]:
# Positional row of the target to spot-check.
idx = 67990
survey = zcat["SURVEY"].iloc[idx]
program = zcat["PROGRAM"].iloc[idx]
healpix = zcat["HEALPIX"].iloc[idx]
targetid = zcat["TARGETID"].iloc[idx]

# Reference: read and coadd the spectrum straight from the release files.
flux, wave, ivar, mask, res = read_spectra(survey, program, healpix, targetid)

# Candidate: the same row from the saved zarr store (no mask is stored).
flux2 = spec["FLUX"][idx, :]
wave2 = spec["WAVE"][:]
ivar2 = spec["IVAR"][idx, :]
mask2 = None
res2 = spec["RESOLUTION"][idx, :]

print(zcat.iloc[idx])
# print url for LS viewer
print(f"Web link: https://www.legacysurvey.org/viewer/desi-spectrum/dr1/targetid{targetid}")

In [None]:
# Overlay the flux read from disk (solid) and the flux read back from the
# zarr store (dashed). The stored spectrum is drawn against the stored
# wavelength grid so that a wrong WAVE array shows up as a shift.
plt.figure(figsize=(12,5))
plt.plot(wave,np.squeeze(flux))
plt.plot(wave2,flux2, ls="--",lw=0.5)
print(f"Squares Error:{np.mean((flux-flux2)**2)}")

In [None]:
# Overlay the inverse variance read from disk (solid) and the one read back
# from the zarr store (dashed). The stored curve is drawn against the stored
# wavelength grid so that a wrong WAVE array shows up as a shift.
plt.figure(figsize=(12,5))
plt.plot(wave,np.squeeze(ivar))
plt.plot(wave2,ivar2, ls="--",lw=0.5)
print(f"Squares Error:{np.mean((ivar-ivar2)**2)}")

In [None]:
# Show the resolution data from disk (top) and from the zarr store (bottom)
# as images, then report their mean squared difference.
fig, (ax_disk, ax_store) = plt.subplots(2, 1, figsize=(10, 5))
for panel, matrix in ((ax_disk, res), (ax_store, res2)):
    panel.matshow(np.squeeze(matrix))
    panel.set_aspect('auto')
print(f"Squares Error:{np.mean((res-res2)**2)}")