Ensure that this notebook is running on the kernel with the legacy survey cutout service + dask docker image.

To set up the Dask cluster, run the following command in a Jupyter terminal:

`salloc -N 4 -n 512 -t 240 -C cpu -q interactive --image=biprateep/dask-viewer-cutouts:latest --account=m4236`

and then 

`./launch_dask.sh` 

The `-n` argument controls the number of workers to be launched.


### Collect the spectra data

In [1]:
import os
from pathlib import Path
import pandas as pd
import fitsio
import numpy as np
import re
import time
from desiutil.io import encode_table
from desiutil.log import get_logger, DEBUG
from desispec.io.util import native_endian, checkgzip
from desispec.io import iotime
from tqdm import tqdm
import zarr

import dask
from dask.distributed import Client
import dask.dataframe as dd
import dask.array as da
import dask.dataframe as dd
from dask.distributed import LocalCluster
from dask.diagnostics import ProgressBar

from map.views import get_layer
from astrometry.util.util import Tan

# np.seterr(divide='ignore', invalid='ignore')

Matplotlib created a temporary cache directory at /tmp/matplotlib-zksa8wsq because the default path (/homedir/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# Read the scheduler file generated by the script above and connect your notebook to the client
scheduler_file = os.path.join(os.environ["SCRATCH"], "scheduler.json")
# scheduler_file = os.path.join(os.environ["CFS"], "desi/users/bid13/scheduler.json")
# Build dashboard links relative to the JupyterHub proxy prefix instead of a bare host:port URL.
dask.config.config["distributed"]["dashboard"]["link"] = "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"
client = Client(scheduler_file=scheduler_file)
# Bare expression: renders the client summary as the cell output.
client

0,1
Connection method: Scheduler file,Scheduler file: /pscratch/sd/b/bid13/scheduler.json
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:8787/status,

0,1
Comm: tcp://10.249.1.127:8786,Workers: 84 (shown below: 5)
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:8787/status,Total threads: 168
Started: Just now,Total memory: 39.08 TiB

0,1
Comm: tcp://10.249.1.127:32913,Total threads: 2
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:46097/status,Memory: 476.37 GiB
Nanny: tcp://10.249.1.127:38445,
Local directory: /tmp/dask-scratch-space/worker-jhg3h9uw,Local directory: /tmp/dask-scratch-space/worker-jhg3h9uw
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 61.21 MiB,Spilled bytes: 0 B
Read bytes: 0.0 B,Write bytes: 0.0 B

0,1
Comm: tcp://10.249.1.127:32985,Total threads: 2
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:40443/status,Memory: 476.37 GiB
Nanny: tcp://10.249.1.127:38163,
Local directory: /tmp/dask-scratch-space/worker-vc7x0io3,Local directory: /tmp/dask-scratch-space/worker-vc7x0io3
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 61.45 MiB,Spilled bytes: 0 B
Read bytes: 0.0 B,Write bytes: 0.0 B

0,1
Comm: tcp://10.249.1.127:33091,Total threads: 2
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:43963/status,Memory: 476.37 GiB
Nanny: tcp://10.249.1.127:45279,
Local directory: /tmp/dask-scratch-space/worker-rn5nrm7z,Local directory: /tmp/dask-scratch-space/worker-rn5nrm7z
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 65.12 MiB,Spilled bytes: 0 B
Read bytes: 0.0 B,Write bytes: 0.0 B

0,1
Comm: tcp://10.249.1.127:33319,Total threads: 2
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:45613/status,Memory: 476.37 GiB
Nanny: tcp://10.249.1.127:45239,
Local directory: /tmp/dask-scratch-space/worker-ljq3qmw9,Local directory: /tmp/dask-scratch-space/worker-ljq3qmw9
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 65.06 MiB,Spilled bytes: 0 B
Read bytes: 0.0 B,Write bytes: 0.0 B

0,1
Comm: tcp://10.249.1.127:33557,Total threads: 2
Dashboard: /user/bid13/perlmutter-login-node-base/proxy/10.249.1.127:35987/status,Memory: 476.37 GiB
Nanny: tcp://10.249.1.127:43411,
Local directory: /tmp/dask-scratch-space/worker-s_d3crz1,Local directory: /tmp/dask-scratch-space/worker-s_d3crz1
Tasks executing:,Tasks in memory:
Tasks ready:,Tasks in flight:
CPU usage: 0.0%,Last seen: Just now
Memory usage: 65.21 MiB,Spilled bytes: 0 B
Read bytes: 0.0 B,Write bytes: 0.0 B


In [3]:
# cluster = LocalCluster(threads_per_worker=1)
# client = cluster.get_client()
# client

In [4]:
# Release name; it is also the last component of the output directory under $SCRATCH.
release = "iron"
dest_path = Path(os.environ["SCRATCH"]) / "data" / "foundation" / f"{release}"

In [5]:
# Load only the catalogue columns used below to locate spectra and request image cutouts.
zcat = pd.read_parquet(dest_path / "desi_zcat_maglim_19_5.parquet", columns = ["SURVEY","PROGRAM","HEALPIX","TARGETID","MYID","PHOTSYS","TARGET_RA","TARGET_DEC"])
# Index and order the rows by MYID; the chunks built from zcat below keep this row order.
zcat = zcat.set_index("MYID")
zcat = zcat.sort_index()


In [6]:
def split_dataframe(df, chunk_size = 1000):
    """
    Cut ``df`` into consecutive row slices of at most ``chunk_size`` rows.

    Args:
        df: Object supporting ``len`` and positional slicing (a DataFrame here).
        chunk_size: Maximum number of rows per slice.

    Returns:
        list: The slices in their original order; the last one may be shorter.
        An empty ``df`` yields an empty list.
    """
    n_parts = int(np.ceil(len(df) / chunk_size))
    return [df[k * chunk_size:(k + 1) * chunk_size] for k in range(n_parts)]

In [7]:
def _resolution_coadd(resolution, pix_weights):
    """
    Given the resolution matrices for set of spectra, and
    inverse variances (or generally weights) for fluxes return the
    accumulated resolution matrix, and the combined weights
    See #2372.

    Args:
    resolution (ndarray): (nspec, nres, npix) array of resolution matrices
    pix_weights (ndarray): (nspec, npix) array of ivars or weights

    Returns resolution matrix (nres, npix),
    and the weight (nres, npix)
    """
    ww = resolution.shape[1] // 2
    # resolution kernel width
    npix = resolution.shape[2]
    # indices of the corresponding variance point
    # that needs to be used for ivar weights
    res_indices = (np.arange(npix)[None, :] +
                   np.arange(-ww, ww + 1)[:, None]) % npix
    res_whts = np.array([_[res_indices] for _ in pix_weights])
    res = np.sum(res_whts * resolution, axis=0)
    res_norm = np.sum(res_whts, axis=0)
    return res, res_norm

In [8]:
def _matches_within(values, reference, tolerance):
    """
    Flag the entries of ``values`` that lie within ``tolerance`` of at least
    one entry of ``reference``.

    ``reference`` must be sorted in increasing order: only the two entries
    bracketing each value are inspected, which replaces a per-pixel scan of
    the whole reference array.

    Returns a boolean array with the shape of ``values``.
    """
    values = np.asarray(values)
    reference = np.asarray(reference)
    if reference.size == 0:
        return np.zeros(values.shape, dtype=bool)
    pos = np.searchsorted(reference, values)
    last = reference.size - 1
    below = reference[np.clip(pos - 1, 0, last)]
    above = reference[np.clip(pos, 0, last)]
    return (np.abs(values - below) <= tolerance) | (np.abs(values - above) <= tolerance)


def coadd_cameras(flux_cam, wave_cam, ivar_cam, mask_cam, res_cam):
    """
    Merge the per-camera ("b", "r", "z") data of one spectrum onto a single
    wavelength grid.

    Pixels covered by a single camera are copied. Pixels that also appear
    (within 1e-4) in an adjacent camera are combined: flux and resolution by
    an inverse-variance weighted mean, ivar by summation, mask by bitwise OR.
    The mask is then cleared wherever the combined ivar is positive, and ivar
    is zeroed wherever a mask bit remains.

    Args:
        flux_cam, ivar_cam, mask_cam: dicts keyed by band; row 0 of each
            array is the spectrum that gets coadded.
        wave_cam: dict keyed by band of 1D wavelength grids. Each grid is
            expected to be increasing.
        res_cam: dict keyed by band of resolution arrays indexed as
            [row, diagonal, pixel]; row 0 is used.

    Returns:
        tuple: (flux, wave, ivar, mask, rdata) where flux, ivar and mask have
        shape (1, nwave), wave has shape (nwave,) and rdata has shape
        (1, max_ndiag, nwave).

    Raises:
        ValueError: if a band's wavelength grid is not aligned with the
            combined grid.
        AssertionError: if the combined wavelength grid is not strictly
            increasing.
    """
    sbands = np.array(["b", "r", "z"])  # bands sorted by inc. wavelength
    tolerance = 0.0001  # A , tolerance

    # create wavelength array: the first band, then the part of each later
    # band that extends beyond what is already covered
    wave = None
    for b in sbands:
        if wave is None:
            wave = wave_cam[b]
        else:
            wave = np.append(wave, wave_cam[b][wave_cam[b] > wave[-1] + tolerance])
    nwave = wave.size

    # check alignment, caching band wavelength grid indices as we go.
    # The cached slice starts at the same nearest-pixel index that the check
    # validates, so float jitter below the tolerance cannot shift a band by
    # one pixel the way an independent searchsorted lookup could.
    windict = {}
    for b in sbands:
        imin = int(np.argmin(np.abs(wave_cam[b][0] - wave)))
        windices = np.arange(imin, imin + len(wave_cam[b]), dtype=int)
        dwave = wave_cam[b] - wave[windices]
        if np.any(np.abs(dwave) > tolerance):
            msg = "Input wavelength grids (band '{}') are not aligned. Use --lin-step or --log10-step to resample to a common grid.".format(
                b)
            raise ValueError(msg)
        windict[b] = slice(imin, imin + len(wave_cam[b]))

    # for each band, flag (1) the pixels that also exist in the previous or
    # the next band; these are the pixels that need normalizing later
    overlap_flag = {}
    for i, b in enumerate(sbands):
        wave_b = wave_cam[b]
        flag = np.zeros_like(wave_b, dtype=int)
        if i > 0:
            flag[_matches_within(wave_b, wave_cam[sbands[i - 1]], tolerance)] = 1
        if i < len(sbands) - 1:
            flag[_matches_within(wave_b, wave_cam[sbands[i + 1]], tolerance)] = 1
        overlap_flag[b] = flag

    # defining arrays for coadded data
    flux = np.zeros((1, nwave))
    ivar = np.zeros((1, nwave))
    mask = np.zeros((1, nwave), dtype=np.int32)

    max_ndiag = max([res_cam[b].shape[1] for b in sbands])
    rdata = np.zeros((1, max_ndiag, nwave))
    rnorm = np.zeros_like(rdata)

    for b in sbands:
        iband = windict[b]

        f = flux_cam[b]
        iv = ivar_cam[b]
        m = mask_cam[b]

        # True for pixels in b that are non-overlapping / overlapping
        no_overlap = (overlap_flag[b] == 0)
        overlap = ~no_overlap

        # basic slices are views, so the masked writes below land in the
        # output arrays
        flux_band = flux[0, iband]
        ivar_band = ivar[0, iband]
        mask_band = mask[0, iband]

        # Non-overlapping: directly copy
        flux_band[no_overlap] = f[0][no_overlap]
        ivar_band[no_overlap] = iv[0][no_overlap]

        # Overlapping: accumulate (inverse variance weighted sum),
        # normalized after all bands have been added
        flux_band[overlap] += iv[0][overlap] * f[0][overlap]
        ivar_band[overlap] += iv[0][overlap]

        # coadding mask
        mask_band[no_overlap] = m[0][no_overlap]  # non-overlapping, simple copy
        mask_band[overlap] |= m[0][overlap]  # overlapping, OR logic

        # resolution matrix: centre this band's diagonals inside the widest
        # kernel found across the bands
        res = res_cam[b][0][np.newaxis, :, :]
        raccum, rnorm_i = _resolution_coadd(res, iv[0:1])
        ndiag = raccum.shape[0]
        offset = (max_ndiag - ndiag) // 2
        rdata_band = rdata[0, offset:offset + ndiag, iband]
        rnorm_band = rnorm[0, offset:offset + ndiag, iband]

        # non-overlapping regions, simple copying
        rdata_band[:, no_overlap] = res[0][:, no_overlap]
        rnorm_band[:, no_overlap] = 1.0

        # overlapping regions, weighted accumulation (normalized below)
        rdata_band[:, overlap] += raccum[:, overlap]
        rnorm_band[:, overlap] += rnorm_i[:, overlap]

    # in the combined unique wave pixels
    # which pixels have two measurements due to overlapping
    overlap_pixel_mask = np.zeros_like(flux, dtype=bool)
    for b in sbands:
        overlap_pixel_mask[:, windict[b]] = overlap_flag[b][None, :]

    # Only normalize on overlapping pixels (basically inverse variance weighted mean)
    # For non-overlapping (already direct copied), skip normalization
    normalize_mask = overlap_pixel_mask
    # adding (ivar == 0) turns a zero denominator into one
    flux[normalize_mask] /= (ivar[normalize_mask] + (ivar[normalize_mask] == 0))

    mask[ivar > 0] = 0  # mask =0 means good pixels

    ivar[mask.astype(bool)] = 0  # encoding all mask values in ivar

    # just sanity check that wavelength is an increasing array
    assert np.all(np.diff(wave) > 0)

    rdata_norm_pixels = normalize_mask[0]  # all rows of normalize mask are basically same
    rdata[:, :, rdata_norm_pixels] /= rnorm[:, :, rdata_norm_pixels] + (rnorm[:, :, rdata_norm_pixels] == 0)

    return flux, wave, ivar, mask, rdata


In [9]:
def read_spectra(survey, program, healpix, targetid, release="iron",read_hdu={
            "FIBERMAP": True,
            "EXP_FIBERMAP": False,
            "SCORES": False,
            "EXTRA_CATALOG": False,
            "MASK": False,
            "RESOLUTION": True,
        }):
    """
    Read one target from a healpix coadd file and merge its cameras into a
    single spectrum with ``coadd_cameras``.

    Args:
        survey, program: Path components and file-name parts of the coadd file.
        healpix: Healpix number of the coadd file (converted with ``int``).
        targetid: TARGETID of the row to extract, looked up in FIBERMAP.
        release: Reduction name under /global/cfs/cdirs/desi/spectro/redux.
        read_hdu: Only the "RESOLUTION" entry is consulted; it decides whether
            the RESOLUTION HDUs are read. Note that ``coadd_cameras`` indexes
            the resolution data for every band, so disabling it makes the
            coadd step fail. The default dict is never modified here.

    Returns:
        tuple: (flux float32, wave, ivar float32, mask, resolution float16)
        as produced by ``coadd_cameras``.

    Raises:
        IndexError: if ``targetid`` is not present in the FIBERMAP table.
        RuntimeError: if a non-table HDU name does not look like BAND_TYPE.
    """
    release_path = Path( f"/global/cfs/cdirs/desi/spectro/redux/{release}")
    healpix = int(healpix)

    infile = (
       release_path
        / "healpix"
        / survey
        / program
        / str(healpix // 100)
        / str(healpix)
        / f"coadd-{survey}-{program}-{healpix}.fits"
    )

    ftype = np.float32

    # HDUs that hold tables rather than per-band arrays; skipped in the loop
    table_hdus = ("FIBERMAP", "EXP_FIBERMAP", "SCORES", "EXTRA_CATALOG")

    wave = {}
    flux = {}
    ivar = {}
    mask = {}
    res = {}

    # The context manager closes the file even when a lookup or read fails.
    with fitsio.FITS(infile, mode='r') as hdus:
        matches = np.argwhere(hdus["FIBERMAP"].read(columns="TARGETID")==targetid)
        if len(matches) == 0:
            raise IndexError("TARGETID {} not found in {}".format(targetid, infile))
        targetrow = matches[0][0]
        row = slice(targetrow, targetrow + 1)

        # For efficiency, go through the HDUs in disk-order.  Use the
        # extension name to determine where to put the data.
        for h in range(1, len(hdus)):
            name = hdus[h].read_header()["EXTNAME"]
            if name in table_hdus:
                continue

            # Find the band based on the name
            mat = re.match(r"(.*)_(.*)", name)
            if mat is None:
                raise RuntimeError("FITS extension name {} does not contain the band".format(name))
            band = mat.group(1).lower()
            hdu_type = mat.group(2)

            if hdu_type == "WAVELENGTH":
                #- Note: keep original float64 resolution for wavelength
                wave[band] = native_endian(hdus[h].read())
            elif hdu_type == "FLUX":
                flux[band] = native_endian(hdus[h][row, :])
            elif hdu_type == "IVAR":
                ivar[band] = native_endian(hdus[h][row, :])
            elif hdu_type == "MASK":
                mask[band] = native_endian(hdus[h][row, :].astype(np.uint32))
            elif hdu_type == "RESOLUTION" and read_hdu["RESOLUTION"]:
                res[band] = native_endian(hdus[h][row, :, :])

    flux, wave, ivar, mask, res = coadd_cameras(flux, wave, ivar, mask, res)

    return flux.astype(ftype), wave, ivar.astype(ftype), mask, res.astype(np.float16) # use fp 16 for Res?

In [10]:
@dask.delayed(nout=3)
def get_spectra(params:pd.DataFrame) -> np.array:
    """
    Read the coadded spectrum of every row in ``params`` and stack the results.

    ``params`` must provide the SURVEY, PROGRAM, HEALPIX and TARGETID columns.
    Returns three arrays (flux, ivar, resolution), each the concatenation of
    the per-row outputs of ``read_spectra`` in row order.
    """
    collected = {"flux": [], "ivar": [], "res": []}
    for _, row in params.iterrows():
        # wavelength grid and mask are returned by read_spectra but not kept
        flux, _wave, ivar, _mask, res = read_spectra(
            row["SURVEY"], row["PROGRAM"], row["HEALPIX"], row["TARGETID"]
        )
        collected["flux"].append(flux)
        collected["ivar"].append(ivar)
        collected["res"].append(res)

    return tuple(np.concatenate(collected[key]) for key in ("flux", "ivar", "res"))

In [11]:
# One chunk of catalogue rows (default chunk size: 1000) per delayed task.
chunk_params = split_dataframe(zcat)

In [12]:
# Build one lazy Dask array per chunk; nothing is read from disk in this cell.
all_fluxes = []
all_ivars = []
all_reses = []

for i in tqdm(chunk_params):
    fluxes, ivars, reses = get_spectra(i)
    # The shapes are declared up front and hard-coded (7781 wavelength pixels,
    # 11 resolution diagonals); they must match what get_spectra returns.
    all_fluxes.append(da.from_delayed(fluxes,dtype=np.float32, shape=(len(i),7781)))
    all_ivars.append(da.from_delayed(ivars,dtype=np.float32, shape=(len(i),7781)))
    all_reses.append(da.from_delayed(reses,dtype=np.float16, shape=(len(i),11,7781)))



100%|██████████| 4738/4738 [00:02<00:00, 1683.63it/s]


In [13]:
# Join the per-chunk arrays along the first axis. The names are rebound from
# lists to Dask arrays, so the previous cell must be re-run before this one.
all_fluxes =  da.concatenate(all_fluxes)
all_ivars =  da.concatenate(all_ivars)
all_reses = da.concatenate(all_reses)

In [14]:
# Zarr store that receives the spectra components written below.
filepath = dest_path / "desi_maglim_19_5.zarr"

In [15]:
# compute=False: each call only returns a task; nothing is written until dask.compute runs.
flux_task = all_fluxes.to_zarr(str(filepath), compressor=None,component="FLUX", compute=False) 
ivar_task = all_ivars.to_zarr(str(filepath), compressor=None,component="IVAR", compute=False)
res_task = all_reses.to_zarr(str(filepath), compressor=None,component="RESOLUTION", compute=False)
# The wavelength grid is generated (3600 to 9824 in 7781 points) rather than read from the files.
# NOTE(review): only this component passes overwrite=True — confirm that is intended for re-runs.
wave_task = da.linspace(3600,9824,7781).to_zarr(str(filepath),overwrite=True, compressor=None,component="WAVE", compute=False)

  return to_zarr(self, *args, **kwargs)


In [16]:
# Run all four write tasks in a single compute call.
dask.compute([flux_task, ivar_task,  res_task, wave_task])

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


([array([], dtype=float64),
  array([], dtype=float64),
  array([], dtype=float64),
  None],)

### Precheck for querying images

In [None]:
def ls_layer_name(photsys):
    """
    Map a PHOTSYS flag to the matching Legacy Surveys DR9 layer name.

    Args:
        photsys: ``b"S"`` or ``b"N"`` (bytes).

    Returns:
        str: "ls-dr9-south" for ``b"S"``, "ls-dr9-north" for ``b"N"``.

    Raises:
        ValueError: for any other value; the message reports the value.
    """
    if photsys==b"S":
        return "ls-dr9-south"
    elif photsys==b"N":
        return "ls-dr9-north"
    else:
        raise ValueError(f"Unrecognised PHOTSYS value {photsys!r}; expected b'S' or b'N'")
# Convert PHOTSYS to bytes and map every row to its layer name. This rebinds the
# name ls_layer_name from the function to the resulting Series.
ls_layer_name = pd.Series(zcat['PHOTSYS']).astype("|S").map(ls_layer_name)

In [None]:
def check_survey_overlap(layer, ra, dec, pixscale, width, height):
    """
    Report whether a cutout footprint touches at least one brick of a layer.
    Might not give right answer for Split Layers like 'ls-dr9'

    Args:
        layer: Name of the layer e.g., 'ls-dr9'
        ra: Right Ascension in degrees
        dec: Declination in degrees
        pixscale: Pixel scale in arcseconds per pixel
        width: Width of image in pixels
        height: Height of image in pixels

    Returns:
        bool: True if overlap exists, False otherwise.
    """
    # Look up the layer object (should be done outside a function called in a loop?)
    survey_layer = get_layer(layer)

    # Construct WCS like: https://github.com/legacysurvey/imagine/blob/main/map/views.py#L2086
    deg_per_pix = pixscale / 3600.
    crpix_x = (width + 1) / 2.
    crpix_y = (height + 1) / 2.

    # TAN WCS centered on (ra, dec); the RA pixel step is negative
    tan_args = (
        ra, dec,
        crpix_x, crpix_y,
        -deg_per_pix, 0., 0., deg_per_pix,
        width, height,
    )
    wcs = Tan(*map(float, tan_args))

    scale = 0 # What does this scale parameter mean and what should be the value?

    # This hopefully queries the brick table/tree without reading image files
    bricks = survey_layer.bricks_touching_aa_wcs(wcs, scale=scale)

    return not (bricks is None or len(bricks) == 0)

In [None]:
# Cutout geometry: pixel scale in arcseconds per pixel, height and width in pixels.
pixscale = 0.262
H = 128
W = 128

# One row per catalogue object; scalar values are repeated for every row and
# the index comes from the zcat-derived Series.
params = pd.DataFrame({"ls9_layer_name": ls_layer_name,
                      "wise_layer_name": "unwise-neo6",
                      "ls10_layer_name": "ls-dr10-south",
                      "ra": zcat["TARGET_RA"],
                      "dec": zcat["TARGET_DEC"],
                      "height": H,
                      "width": W,
                      "pixscale": pixscale,})

In [None]:
# Chunks of 200 rows for the overlap check; this rebinds chunk_params from the spectra section.
chunk_params = split_dataframe(params, chunk_size=200)

In [None]:
@dask.delayed(nout=1)
def get_overlap_flags(params:pd.DataFrame):
    """
    Check, for every row of ``params``, whether its cutout overlaps each of
    the three layers named in that row.

    Returns an array with one row per input row and three columns, in the
    order LS DR9, WISE, LS DR10.
    """
    layer_columns = ("ls9_layer_name", "wise_layer_name", "ls10_layer_name")
    per_row = [
        np.array([
            check_survey_overlap(row[column], row["ra"], row["dec"], row["pixscale"], row["width"], row["height"])
            for column in layer_columns
        ])
        for _, row in params.iterrows()
    ]
    return np.vstack(per_row)

In [None]:
# One lazy (n_rows, 3) boolean block per chunk of params.
flags = []

for i in tqdm(chunk_params):
    # Builtin ``bool`` instead of ``np.bool``: that alias was removed in NumPy 1.24.
    flags.append(da.from_delayed(get_overlap_flags(i),dtype=bool, shape=(len(i),3)))

In [None]:
# Stack the per-chunk blocks and evaluate them; flags is rebound from a list to the computed array.
flags = da.vstack(flags).compute()

In [None]:
# The rows of ``flags`` follow the row order of ``params`` (the chunks were cut
# from it in order). Give the frame the same index as ``params`` so that any
# index-aligned combination with ``params`` pairs each flag row with its own
# object instead of relying on the index being exactly 0..n-1.
overlap_df = pd.DataFrame(flags,
                          columns = ["lsdr9_overlap_flag",
                                     "wise_overlap_flag",
                                     "lsdr10_overlap_flag"],
                          index = params.index)

In [None]:
# Column-wise concatenation aligns rows on the index, so overlap_df and params
# must carry the same index. params is rebound: re-running this cell adds the columns again.
params = pd.concat([overlap_df,params], axis=1)

In [None]:
# Persist the cutout parameters and overlap flags; the image section below reloads this file.
params.to_parquet(dest_path / "desi_img_params_maglim_19_5.parquet")

### Get images

In [None]:
# Reload the parameters written by the precheck section and order the rows by index.
params = pd.read_parquet(dest_path / "desi_img_params_maglim_19_5.parquet")
params = params.sort_index()

In [None]:
def _cutout_or_blank(layer_name, has_overlap, bands, p, tempfiles):
    """
    Fetch image and inverse-variance cutouts for ``bands`` from one layer,
    falling back to all-zero planes when the row is flagged as having no
    overlap or when the cutout call raises.

    Returns two lists (images, inverse variances) with one 2D plane per band.
    """
    # numpy image planes are (rows, columns) = (height, width)
    plane_shape = (int(p["height"]), int(p["width"]))

    def blank_planes():
        return [np.zeros(plane_shape) for _ in bands]

    if not has_overlap:
        return blank_planes(), blank_planes()

    layer = get_layer(layer_name)
    try:
        imgs, ivars, _hdr = layer.write_cutout(p["ra"], p["dec"], p["pixscale"], p["width"], p["height"], None, bands=bands,
                   fits=True, jpeg=False, tempfiles=tempfiles, get_images=True, with_invvar=True,)
    except Exception:
        # Deliberate best effort: a failed cutout becomes blank planes so that
        # one bad object does not abort the whole chunk.
        return blank_planes(), blank_planes()
    return list(imgs), list(ivars)


@dask.delayed(nout=2)
def get_cutouts(params:pd.DataFrame) -> np.array:
    """
    Build six-band image and inverse-variance cutouts for every row of ``params``.

    Band order along axis 1 is g, r, i, z, W1, W2: g, r, z come from the LS DR9
    layer of the row, i from the LS DR10 layer and W1, W2 from the WISE layer.
    Layers flagged as non-overlapping, and cutouts that fail, contribute zeros.

    Returns:
        tuple: (images, ivars), both float16 with one entry per input row.
        Values are clipped to [-65500, 65500] before the cast.
    """
    images = list()
    ivars = list()

    for _, p in params.iterrows():
        # NOTE(review): write_cutout is handed this list but nothing here reads
        # it afterwards — confirm no temporary files are left behind.
        tempfiles = list()

        img_grz, ivar_grz = _cutout_or_blank(p["ls9_layer_name"], p["lsdr9_overlap_flag"], 'grz', p, tempfiles)
        img_w1w2, ivar_w1w2 = _cutout_or_blank(p["wise_layer_name"], p["wise_overlap_flag"], '12', p, tempfiles)
        img_i, ivar_i = _cutout_or_blank(p["ls10_layer_name"], p["lsdr10_overlap_flag"], 'i', p, tempfiles)

        # slot the i band between r and z, then append the WISE bands
        images.append(np.array(img_grz[:2] + img_i + img_grz[2:] + img_w1w2))
        ivars.append(np.array(ivar_grz[:2] + ivar_i + ivar_grz[2:] + ivar_w1w2))

    #To prevent overflow during float16 casting
    images = np.clip(images, -65500, 65500)
    ivars = np.clip(ivars, -65500, 65500)
    return np.array(images).astype(np.float16), np.array(ivars).astype(np.float16)

In [None]:
# Chunks of 100 rows, one delayed cutout task per chunk.
chunk_params = split_dataframe(params, chunk_size = 100)

In [None]:
# Build one lazy Dask array per chunk. all_ivars is rebound here and no longer
# refers to the spectra inverse variances from the first section.
all_images = []
all_ivars = []
for i in tqdm(chunk_params):
    images, ivars = get_cutouts(i)
    # Declared shape: 6 bands of 128 x 128 pixels (hard-coded; must match the width/height in params).
    all_images.append(da.from_delayed(images,dtype=np.float16, shape=(len(i),6,128,128)))
    all_ivars.append(da.from_delayed(ivars,dtype=np.float16, shape=(len(i),6,128,128)))

In [None]:
# Join the per-chunk arrays along the first axis (rebinds the lists to Dask arrays).
all_images = da.concatenate(all_images)
all_ivars = da.concatenate(all_ivars)
# all_images = all_images.compute()
# all_images = all_images.rechunk(chunks=(10,-1,-1,-1))


In [None]:
# Same Zarr store as the spectra; the images are written as additional components.
filepath = dest_path / "desi_maglim_19_5.zarr"
# filepath = dest_path / "test"

In [None]:
# compute=False: only the write tasks are built here; the next cell executes them.
task_img = all_images.to_zarr(str(filepath), compressor=None, component = "IMG",compute=False)
task_ivar = all_ivars.to_zarr(str(filepath), compressor=None, component = "IMG_IVAR",compute=False)

In [None]:
# Run both image write tasks in a single compute call.
dask.compute([task_img, task_ivar])

In [None]:
# ``data`` is not defined anywhere in this notebook; ``params`` holds one row per
# object. The count is divided by 1e6, so the label states the unit.
print("Total number of objects (millions):", len(params)/1e6)

# Quality Assurance
Spot-check by retrieving a sample spectrum from the saved file and comparing it with the spectrum on disk
and in the Legacy Survey viewer.

In [None]:
import os
from pathlib import Path
import zarr
import matplotlib.pyplot as plt

In [None]:
release = "iron"
dest_path = Path(os.environ["SCRATCH"]) / "data" / "foundation" / f"{release}"


# The arrays in the Zarr store were written in MYID-sorted catalogue order, so
# the catalogue must be put in that same order before ``.iloc[idx]`` can be
# compared with row ``idx`` of the store.
zcat = (
    pd.read_parquet(dest_path / "desi_zcat_maglim_19_5.parquet", columns = ["SURVEY","PROGRAM","HEALPIX","TARGETID","MYID"])
    .set_index("MYID")
    .sort_index()
)

# Opened read-only; the path is passed as a string, as in the cells that wrote the store.
spec = zarr.open(str(dest_path / "desi_maglim_19_5.zarr"), mode="r")


In [None]:
# Row to spot-check.
idx = 67990
survey, program, healpix, targetid = (
    zcat[column].iloc[idx] for column in ("SURVEY", "PROGRAM", "HEALPIX", "TARGETID")
)
# Reference: read the spectrum directly from the coadd file on disk.
flux, wave, ivar, mask, res = read_spectra(survey, program, healpix, targetid)

# Same row from the saved Zarr store (no mask is stored there).
flux2 = spec["FLUX"][idx,:]
wave2 = spec["WAVE"][:]
ivar2 = spec["IVAR"][idx,:]
mask2 = None
res2 = spec["RESOLUTION"][idx,:]
print(zcat.iloc[idx])
# print url for LS viewer
print(f"Web link: https://www.legacysurvey.org/viewer/desi-spectrum/dr1/targetid{targetid}")

In [None]:
# Overlay the flux read from disk (solid) and from the Zarr store (thin dashed).
plt.figure(figsize=(12,5))
plt.plot(wave,np.squeeze(flux))
plt.plot(wave,flux2, ls="--",lw=0.5)
# Mean squared difference between the two versions.
print(f"Squares Error:{np.mean((flux-flux2)**2)}")

In [None]:
# Overlay the inverse variance read from disk (solid) and from the Zarr store (thin dashed).
plt.figure(figsize=(12,5))
plt.plot(wave,np.squeeze(ivar))
plt.plot(wave,ivar2, ls="--",lw=0.5)
# Mean squared difference between the two versions.
print(f"Squares Error:{np.mean((ivar-ivar2)**2)}")

In [None]:
# Resolution data read from disk (top) and from the Zarr store (bottom).
fig, ax = plt.subplots(2,1, figsize=(10,5))
ax[0].matshow(np.squeeze(res))
ax[0].set_aspect('auto')
ax[1].matshow(np.squeeze(res2))
ax[1].set_aspect('auto')
# Mean squared difference between the two versions.
print(f"Squares Error:{np.mean((res-res2)**2)}")

### images

In [None]:
# ``sys`` is not imported anywhere else in this notebook.
import sys

# Put ../src/utils on the import path so that plotutils can be imported.
sys.path.insert(1, "../src/utils")

from plotutils import make_rgb

In [None]:
# NOTE(review): this rebinds dest_path and opens a different catalogue and image
# store from the ones written earlier in this notebook — presumably left over
# from an earlier dataset; confirm which files this check should read.
dest_path = Path(os.environ["SCRATCH"]) / "photo_z" / "data" / "DESI_supervised"

zcatalog = pd.read_parquet(dest_path / "zcat_supervised_labels_subset")
imgs = zarr.open(dest_path / "images_supervised.zarr")

In [None]:
# Row of the image store and catalogue to spot-check.
idx = 27000000 

In [None]:

# Saved image stack and catalogue row for the selected object.
select_saved_img = imgs[idx]
select_row = zcatalog.iloc[idx]

# PHOTSYS is compared with str values here (the precheck section compares bytes).
# layer_name stays undefined if the value is neither "N" nor "S".
if select_row["PHOTSYS"] == "N":
    layer_name = 'ls-dr9-north'
elif select_row["PHOTSYS"] == "S":
    layer_name = 'ls-dr9-south'
tempfiles=[]
layer = get_layer( layer_name)
# Fresh grz cutout to compare with the saved one. The result is unpacked into two
# values here because with_invvar is not requested.
grz_img_download, hdr = layer.write_cutout(select_row["TARGET_RA"],select_row["TARGET_DEC"], 0.262,128, 128, out_fn=None,bands="grz",
                   fits=True, jpeg=False, tempfiles=tempfiles, get_images=True)
# NOTE(review): assumes the first three saved channels are g, r, z — confirm against how this store was written.
grz_img_saved = select_saved_img[:3,:,:]


layer = get_layer( "unwise-neo6")
wise_img_download, hdr = layer.write_cutout(select_row["TARGET_RA"],select_row["TARGET_DEC"], 0.262,128, 128, out_fn=None,bands="12",
                   fits=True, jpeg=False, tempfiles=tempfiles, get_images=True)
# NOTE(review): assumes every channel from index 3 onwards is a WISE band — confirm.
wise_img_saved = select_saved_img[3:,:,:]

In [None]:
# 2 x 2 grid: downloaded (left) and saved (right) images; grz on top, W1/W2 below.
fig , ax = plt.subplots(2,2)
ax = np.ravel(ax)
ax[0].imshow(make_rgb(grz_img_download,survey="ls_grz"))
ax[0].set_title("Downloaded")
ax[1].imshow(make_rgb(grz_img_saved,survey="ls_grz"))
ax[1].set_title("saved")

ax[2].imshow(make_rgb(wise_img_download,survey="unwise_w1w2"))
ax[3].imshow(make_rgb(wise_img_saved,survey="unwise_w1w2"))

for a in ax:
    a.set_axis_off()