# Redshift Database Tutorial

## Abstract

This tutorial will cover the basics of using the redshift database, which is loaded from the outputs of the DESI pipeline.  Currently, this is based on software release 22.1b, and uses a [PostgreSQL](https://www.postgresql.org/) database. We use [SQLAlchemy](http://www.sqlalchemy.org/) to abstract away the details of the database.

## Requirements

This tutorial uses data from the `fuji` production (`/global/cfs/cdirs/desi/spectro/redux/fuji`), and the **DESI 22.1b** kernel.

## Initial Setup

This just imports everything we need and sets up paths and environment variables so we can find things.  The paths are based on the [minitest notebook](https://github.com/desihub/desitest/blob/master/mini/minitest.ipynb).

In [None]:
#
# Imports
#
import os
from argparse import Namespace
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.font_manager import fontManager, FontProperties
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import inspect
from sqlalchemy.sql import func
from astropy.constants import c as lightspeed
from astropy.table import Table, MaskedColumn
#
# DESI software
#
from desiutil.log import get_logger, DEBUG
from desitarget.targetmask import (desi_mask, mws_mask, bgs_mask)
from desisim.spec_qa import redshifts as dsq_z
from desispec import __version__ as desispec_version
import desispec.database.redshift as db
#
# Paths to files, etc.
#
specprod = os.environ['SPECPROD'] = 'fuji'
basedir = os.path.join(os.environ['DESI_SPECTRO_REDUX'], specprod)
# surveydir = os.environ['DESISURVEY_OUTPUT'] = os.path.join(basedir, 'survey')
# targetdir = os.path.join(basedir, 'targets')
# fibassigndir = os.path.join(basedir, 'fiberassign')
# os.environ['DESI_SPECTRO_REDUX'] = os.path.join(basedir, 'spectro', 'redux')
# os.environ['DESI_SPECTRO_SIM'] = os.path.join(basedir, 'spectro', 'sim')
# os.environ['PIXPROD'] = 'mini'
# os.environ['SPECPROD'] = 'mini'
# reduxdir = os.path.join(os.environ['DESI_SPECTRO_REDUX'], os.environ['SPECPROD'])
# simdatadir = os.path.join(os.environ['DESI_SPECTRO_SIM'], os.environ['PIXPROD'])
# os.environ['DESI_SPECTRO_DATA'] = simdatadir
#
# Working directory.
#
workingdir = os.getcwd()
print(sqlalchemy_version)
print(desispec_version)

## Contents of the Database

All tables are grouped into a database "schema", and that schema is named for the production run (*e.g.* `fuji`).  When writing "raw" SQL, table names need to be schema-qualified, for example, `fuji.target`.  However, the SQLAlchemy abstraction layer is designed to take care of this for you.
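
To make schema qualification concrete, here is a minimal sketch that uses Python's built-in `sqlite3` module as a stand-in for PostgreSQL (an assumption, chosen only so the example runs without database credentials). An attached in-memory database named `fuji` plays the role of the schema, and the table contents are made up.

```python
import sqlite3

# SQLite stand-in: an attached database behaves like a named schema.
conn = sqlite3.connect(':memory:')
conn.execute("ATTACH DATABASE ':memory:' AS fuji")
conn.execute("CREATE TABLE fuji.target (targetid INTEGER PRIMARY KEY, ra REAL, dec REAL)")
# Made-up rows, for illustration only.
conn.executemany("INSERT INTO fuji.target VALUES (?, ?, ?)",
                 [(1, 150.0, 2.0), (2, 150.1, 2.1)])
# Raw SQL refers to the table by its schema-qualified name.
rows = conn.execute("SELECT targetid, ra, dec FROM fuji.target ORDER BY targetid").fetchall()
print(rows)
```

With SQLAlchemy, the same table is reached through `db.Target` without spelling out the schema.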

The tables are:

* `target`. This contains the photometric and targeting bits.
  - Loaded from `targetphot` files.
  - SQLAlchemy object: `db.Target`.
  - Primary key: `targetid`.
* `tile`. This contains information about observations grouped by tile.
  - Loaded from top-level `tiles-${SPECPROD}.fits`.
  - SQLAlchemy object: `db.Tile`.
  - Primary key: `tileid`.
* `exposure`. This contains information about individual exposures.
  - Loaded from top-level `exposures-${SPECPROD}.fits`, `EXPOSURES` HDU.
  - SQLAlchemy object: `db.Exposure`.
  - Primary key: `expid`.
* `frame`. This contains information about individual exposures, but broken down by camera.  There will usually, but not always, be 30 frames per exposure.
  - Loaded from top-level `exposures-${SPECPROD}.fits`, `FRAMES` HDU.
  - SQLAlchemy object: `db.Frame`.
  - Primary key: `frameid`, composed from `expid` and a mapping of `camera` to an arbitrary integer.
* `fiberassign`. This contains information about fiber positions.
  - Loaded from fiberassign files in the tiles product.  All fiberassign files corresponding to tiles in the `tile` table are loaded.
  - SQLAlchemy object: `db.Fiberassign`.
  - Primary key: (`tileid`, `targetid`, `location`)
* `potential`. This contains a list of `targetid`s that *could* have been targeted on a given tile.
  - Loaded from the `POTENTIAL_ASSIGNMENTS` HDU in the same fiberassign files mentioned above.
  - SQLAlchemy object: `db.Potential`.
  - Primary key: (`tileid`, `targetid`, `location`)
* `zpix`. This contains the pipeline redshifts grouped by HEALPixel.
  - Loaded from the `zpix-*.fits` files in the `zcatalog/` directory.
  - SQLAlchemy object: `db.Zpix`.
  - Primary key: (`targetid`, `survey`, `program`)
* `ztile`. This contains the pipeline redshifts grouped by tile in a variety of ways.
  - Loaded from the `ztile-*.fits` files in the `zcatalog/` directory.
  - SQLAlchemy object: `db.Ztile`
  - Primary key: (`targetid`, `spgrp`, `spgrpval`, `tileid`)
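
The `frameid` key of the `frame` table is described above as a combination of `expid` and an integer standing in for `camera`. The sketch below shows one way such a key could be built. The function names and the exact arithmetic are hypothetical, chosen for illustration; the scheme actually used by the loading code is not stated here.

```python
def camera_to_int(camera):
    """Hypothetical mapping: band index (b, r, z) times ten plus spectrograph number."""
    band, spectrograph = camera[0], int(camera[1])
    return 'brz'.index(band) * 10 + spectrograph


def make_frameid(expid, camera):
    """Hypothetical composition: the two trailing digits hold the camera number."""
    return 100 * expid + camera_to_int(camera)


# Three bands times ten spectrographs gives the usual 30 cameras per exposure.
cameras = [band + str(n) for band in 'brz' for n in range(10)]
frameids = [make_frameid(12345, camera) for camera in cameras]
print(len(set(frameids)))
```

Any scheme like this gives each (exposure, camera) pair a distinct integer, which is all a primary key needs.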

## Initial Database Connection

This connection uses a `~/.pgpass` file to set up connection credentials.  [Be sure you have set that up](https://desi.lbl.gov/trac/wiki/DESIProductionDatabase#Setuppgpass).

In [None]:
db.log = get_logger(DEBUG)
postgresql = db.setup_db(schema=specprod, hostname='nerscdb03.nersc.gov', username='desi', verbose=True)

## Learning About the Tables

The tables in the database are listed above.  To inspect an individual table, you can use the `__table__` attribute.

In [None]:
#
# Print the table columns and their types.
#
[(c.name, c.type) for c in db.Zpix.__table__.columns]

We can also `inspect()` the database.  For details, see the [SQLAlchemy inspection documentation](http://docs.sqlalchemy.org/en/latest/core/inspection.html?highlight=inspect#module-sqlalchemy.inspection).

In [None]:
inspector = inspect(db.engine)
for table_name in inspector.get_table_names(schema=specprod):
    print(table_name)
    for column in inspector.get_columns(table_name, schema=specprod):
        print("Column: {name} {type}".format(**column))

### Exercises

* What is the type of the `night` column of the `exposure` table?
* What is the primary key of the `ztile` table?

## Simple Queries

Queries are set up with the `.query()` method on Session objects.  In this case, there's a prepared Session object called `db.dbSession`.  `.filter()` corresponds to a `WHERE` clause in SQL.

### Select ELG Targets

Note the special way we obtain the bitwise AND operator: the column's `.op('&')` method.

In [None]:
q = db.dbSession.query(db.Target).filter(db.Target.desi_target.op('&')(desi_mask.ELG) != 0).all()

In [None]:
[(row.targetid, row.desi_target, row.ra, row.dec) for row in q[:10]]
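
In raw SQL, the `.op('&')` filter above corresponds to a bitwise AND in the `WHERE` clause. Here is a minimal sketch, again using `sqlite3` as a stand-in for PostgreSQL. The value of `ELG_BIT` and the table rows are assumptions made for illustration; in the notebook, use `desi_mask.ELG`.

```python
import sqlite3

ELG_BIT = 1 << 1  # assumed value for illustration; use desi_mask.ELG in practice

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE target (targetid INTEGER PRIMARY KEY, desi_target INTEGER)")
# Made-up targeting bitmasks: 1, 2, 3, 4 and 6.
conn.executemany("INSERT INTO target VALUES (?, ?)",
                 [(1, 1), (2, 2), (3, 3), (4, 4), (5, 6)])
# (desi_target & mask) != 0 keeps rows with the bit set, whatever other bits are set.
elg = conn.execute(
    "SELECT targetid FROM target WHERE (desi_target & ?) != 0 ORDER BY targetid",
    (ELG_BIT,)).fetchall()
print([row[0] for row in elg])
```

Testing for `!= 0` rather than `== ELG_BIT` matters because a target can belong to several classes at once.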

### Exercise

* How many objects in the `zpix` table have `spectype` 'GALAXY'?

### Exposures, Nights, Tiles

In [None]:
q = db.dbSession.query(db.Exposure.tileid, db.Exposure.survey, db.Exposure.program).filter(db.Exposure.night==20210115).all()

In [None]:
q

In [None]:
q = db.dbSession.query(db.Tile).count()

In [None]:
q

In [None]:
q = db.dbSession.query(db.Exposure.night, db.Exposure.expid).filter(db.Exposure.tileid==100).all()

In [None]:
q

### Redshift and Classification

Simple query filtering on string values.

In [None]:
q = db.dbSession.query(db.Zpix.spectype, db.Zpix.subtype, db.Zpix.z).filter(db.Zpix.spectype=='STAR').filter(db.Zpix.subtype!='').all()

In [None]:
q

## A Join

Now let's `JOIN` two tables.  In this case, we'll look at photometric flux and measured redshift. We'll `LIMIT` the query with slice notation.

In [None]:
q = db.dbSession.query(db.Target, db.Zpix).filter(db.Target.targetid == db.Zpix.targetid)[:50]

In [None]:
[(row.Target.flux_g, row.Target.flux_r, row.Target.flux_z, row.Zpix.z) for row in q]
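
For comparison, here is roughly what the join above looks like in raw SQL. This is a sketch using `sqlite3` as a stand-in for PostgreSQL, with made-up tables that carry only a few of the real columns.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE target (targetid INTEGER PRIMARY KEY, flux_r REAL)")
conn.execute("CREATE TABLE zpix (targetid INTEGER, survey TEXT, program TEXT, z REAL)")
# Made-up rows: target 2 has no redshift, so it drops out of the join.
conn.executemany("INSERT INTO target VALUES (?, ?)", [(1, 10.0), (2, 20.0), (3, 30.0)])
conn.executemany("INSERT INTO zpix VALUES (?, ?, ?, ?)",
                 [(1, 'sv1', 'dark', 0.5), (3, 'sv1', 'dark', 1.2)])
sql = """SELECT t.targetid, t.flux_r, z.z
         FROM target AS t JOIN zpix AS z ON t.targetid = z.targetid
         ORDER BY t.targetid
         LIMIT 50"""
rows = conn.execute(sql).fetchall()
print(rows)
```

The `ON` condition plays the same role as the `.filter(db.Target.targetid == db.Zpix.targetid)` call, and `LIMIT 50` corresponds to the `[:50]` slice.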

In [None]:
#
# A very similar plot appears in the tutorial notebook dc17a-truth.
#
# dv = lightspeed.to('km / s') * np.array([(row.ZCat.z - row.Truth.truez) / (1.0 + row.Truth.truez) for row in q])
# ttype = [row.Truth.templatetype for row in q]
# fig, axes = plt.subplots(2, 3, figsize=(9,6), dpi=100)
# for k, objtype in enumerate(set(ttype)):
#     i = k % 2
#     j = k % 3
#     # s = axes[i].subplot(2, 3, 1+i)
#     ii = np.array(ttype) == objtype
#     axes[i][j].hist(dv[ii], 50, (-100, 100))
#     axes[i][j].set_xlabel('{} dv [km/s]'.format(objtype))
# fig.tight_layout()

### Exercise

* Create a color-color plot for objects targeted as QSOs, and spectroscopically confirmed as such.

## A More Complicated Join

Let's look at objects that appear on more than one tile. For each of those tiles, how many exposures were there?

In this example, we're using `sqlalchemy.sql.func` to get the equivalent of `COUNT(*)` and a subquery that itself is a multi-table join.

In [None]:
# db.dbSession.rollback()
q1 = db.dbSession.query(db.Fiberassign.targetid, func.count('*').label('n_assign')).group_by(db.Fiberassign.targetid).subquery()
q2 = db.dbSession.query(db.Tile.nexp, db.Fiberassign.tileid, q1.c.targetid, q1.c.n_assign).filter(q1.c.n_assign>2).filter(db.Fiberassign.targetid == q1.c.targetid).filter(db.Tile.tileid == db.Fiberassign.tileid)[:100]

In [None]:
q2
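
The counting subquery can also be written in raw SQL. The sketch below uses `sqlite3` as a stand-in for PostgreSQL with a made-up `fiberassign` table, and leaves out the join to `tile` to stay short. It uses the threshold `n_assign > 1`, that is, targets that appear on more than one tile; adjust the threshold as needed.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE fiberassign (tileid INTEGER, targetid INTEGER, location INTEGER)")
# Made-up assignments: target 1 on three tiles, target 2 on two, target 3 on one.
conn.executemany("INSERT INTO fiberassign VALUES (?, ?, ?)",
                 [(100, 1, 10), (101, 1, 11), (102, 1, 12),
                  (100, 2, 20), (101, 2, 21), (100, 3, 30)])
sql = """SELECT f.tileid, c.targetid, c.n_assign
         FROM fiberassign AS f
         JOIN (SELECT targetid, COUNT(*) AS n_assign
               FROM fiberassign GROUP BY targetid) AS c
           ON f.targetid = c.targetid
         WHERE c.n_assign > 1
         ORDER BY c.targetid, f.tileid"""
rows = conn.execute(sql).fetchall()
print(rows)
```

The inner `SELECT ... GROUP BY` is the equivalent of `q1`, and the outer query joins each assignment back to its per-target count.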

In [None]:
#
# If everything matches up, this should return True.
#
# all([row.ZCat.numexp == row.n_assign for row in q2])

### Exercise

* What is the distribution of number of exposures?

## Efficiency Studies

In `desisim.spec_qa.redshifts` there is a lot of functionality for matching redshifts to the truth table (file).  This matching is done automatically for us just by doing a join.  Also note that we're letting the database compute the value of `dz`.

We're going to cheat a little bit and convert the database output into an `astropy.table.Table` that can be understood by the `desisim.spec_qa` machinery.  No reason to waste perfectly good code!  In the future, this machinery can & should be updated to use database inputs directly.  Who wants to work on that?

In [None]:
q = db.dbSession.query(db.Truth, db.Target, db.ZCat, ((db.ZCat.z - db.Truth.truez)/(1.0 + db.Truth.truez)).label('dz'))\
                .filter(db.Truth.targetid == db.ZCat.targetid).filter(db.Target.targetid == db.ZCat.targetid).all()
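
The `dz` expression in the query is the redshift residual scaled by `1 + truez`; multiplying it by the speed of light gives a velocity residual. A small worked example in plain Python, with made-up redshifts (the function names are for illustration only):

```python
C_KM_S = 299792.458  # speed of light in km/s


def redshift_residual(z, truez):
    """Same expression the query labels 'dz'."""
    return (z - truez) / (1.0 + truez)


def velocity_residual(z, truez):
    """Velocity residual in km/s corresponding to dz."""
    return C_KM_S * redshift_residual(z, truez)


# A measured redshift of 1.0010 against a true redshift of 1.0.
dz = redshift_residual(1.0010, 1.0)
dv = velocity_residual(1.0010, 1.0)
print(dz, dv)
```

An offset of 0.001 in redshift at `truez = 1` therefore corresponds to about 150 km/s, not 300 km/s, because of the `1 + truez` factor.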

In [None]:
def truth_query_to_table(q):
    """Convert a query result into a Table, so that it can be used with functions in ``desisim.spec_qa.redshifts.``
    """
    t = Table()
    columns = list()
    mask = [False]*len(q)
    for c in db.Truth.__table__.columns:
        if c.name == 'truespectype' or c.name == 'templatetype':
            columns.append(MaskedColumn([np.char.rstrip(getattr(row.Truth, c.name)) for row in q], name=c.name.upper(), mask=mask))
        else:
            columns.append(MaskedColumn([getattr(row.Truth, c.name) for row in q], name=c.name.upper(), mask=mask))
    for c in ('desi_target', 'bgs_target', 'mws_target'):
        columns.append(MaskedColumn([getattr(row.Target, c) for row in q], name=c.upper(), mask=mask))
    for c in ('z', 'zerr', 'zwarn', 'spectype'):
        if c == 'spectype':
            columns.append(MaskedColumn([np.char.rstrip(getattr(row.ZCat, c)) for row in q], name=c.upper(), mask=mask))
        else:
            columns.append(MaskedColumn([getattr(row.ZCat, c) for row in q], name=c.upper(), mask=mask))
    columns.append(MaskedColumn([row.dz for row in q], name='DZ', mask=mask))
    t.add_columns(columns)
    return t
truth = truth_query_to_table(q)

In [None]:
print('          ntarg   good  fail  miss  lost')
for objtype in set(truth['TEMPLATETYPE']):
    #isx = (truth['TEMPLATETYPE'] == objtype)
    pgood, pfail, pmiss, plost, nx = dsq_z.zstats(truth, objtype=objtype)
    #nx = np.count_nonzero(isx)
    print('{:6s} {:8d}  {:5.1f} {:5.1f} {:5.1f} {:5.1f}'.format(objtype, nx, pgood, pfail, pmiss, plost))

print()
print('good = correct redshift and ZWARN==0')
print('fail = bad redshift and ZWARN==0 (i.e. catastrophic failures)')
print('miss = correct redshift ZWARN!=0 (missed opportunities)')
print('lost = wrong redshift ZWARN!=0 (wrong but at least we know it)')
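
The four categories printed above can be sketched as a small function. The 1000 km/s threshold for a "correct" redshift is an assumption, borrowed from the velocity cut used later in this notebook; `dsq_z.zstats()` may apply different criteria.

```python
def classify(dv_km_s, zwarn, max_dv=1000.0):
    """Assign one of good, fail, miss, lost.

    max_dv is an assumed threshold in km/s for a 'correct' redshift.
    """
    correct = abs(dv_km_s) <= max_dv
    if zwarn == 0:
        return 'good' if correct else 'fail'
    return 'miss' if correct else 'lost'


# Made-up (velocity residual, ZWARN) pairs, one per category.
samples = [(12.0, 0), (5000.0, 0), (-30.0, 4), (8000.0, 4)]
labels = [classify(dv, zwarn) for dv, zwarn in samples]
print(labels)
```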

In [None]:
#
# Confusion matrix.  Borrowed from the minitest notebook.
#
confusion = dsq_z.spectype_confusion(truth)
#
# Pretty print the confusion matrix.
#
print('            Redrock')
print('Truth     ', end='')
for s1 in confusion:
    print('{:>8s}'.format(s1), end='')
print()
for s1 in confusion:
    print('{:8s}  '.format(s1), end='')
    for s2 in confusion:
        try:
            print('{:8d}'.format(confusion[s1][s2]), end='')
        except KeyError:
            print('{:8d}'.format(0), end='')
    print()

In [None]:
#
# Obtain detailed statistics for all objects.
#
stats = dict()
for s in np.unique(truth['TEMPLATETYPE']):
    stats[s] = dsq_z.calc_obj_stats(truth, s)
stats

## Going Beyond the Summary

Summary statistics are useful, but they don't tell us how efficiency and other parameters depend on each other.  How does efficiency depend on magnitude?  On whether the Moon is in the sky?

Some capability exists to do this in `desisim.spec_qa.redshifts`, but we'll start with a basic example just to get the feel of plotting.

In [None]:
#
# ZWARNING versus magnitude.
#
g = 22.5 - 2.5*np.log10(truth['FLUX_G'])
r = 22.5 - 2.5*np.log10(truth['FLUX_R'])
z = 22.5 - 2.5*np.log10(truth['FLUX_Z'])
fig, axes = plt.subplots(3, 1, figsize=(8, 4.5*3), dpi=100)
p = axes[0].plot(g, truth['ZWARN'], 'k.')
foo = axes[0].set_xlim(axes[0].get_xlim()[1], axes[0].get_xlim()[0])
foo = axes[0].grid(True)
# foo = axes[0].set_xlabel('g Magnitude')
foo = axes[0].set_ylabel('ZWARNING')
p = axes[1].plot(g, truth['ZWARN'], 'k.')
foo = axes[1].set_xlim(axes[1].get_xlim()[1], axes[1].get_xlim()[0])
foo = axes[1].set_ylim(0, 50)
foo = axes[1].grid(True)
# foo = axes[1].set_xlabel('g Magnitude')
foo = axes[1].set_ylabel('ZWARNING')
p = axes[2].plot(g, truth['ZWARN'], 'k.')
foo = axes[2].set_xlim(axes[2].get_xlim()[1], axes[2].get_xlim()[0])
foo = axes[2].set_ylim(0, 5)
foo = axes[2].grid(True)
foo = axes[2].set_xlabel('g Magnitude')
foo = axes[2].set_ylabel('ZWARNING')
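
The magnitudes in the cell above come from fluxes in nanomaggies, for which the zero point is 22.5. The conversion is undefined for zero or negative fluxes, which do occur for faint objects. Here is a minimal sketch of a guarded conversion in plain Python; the function name is hypothetical.

```python
import math


def nanomaggies_to_mag(flux):
    """Convert a flux in nanomaggies to a magnitude; NaN where undefined."""
    if flux is None or flux <= 0:
        return math.nan
    return 22.5 - 2.5 * math.log10(flux)


# 1 nanomaggy is magnitude 22.5; a factor of 100 in flux is 5 magnitudes.
mags = [nanomaggies_to_mag(flux) for flux in (1.0, 100.0, 0.0)]
print(mags)
```

With NumPy arrays, `np.log10` returns `nan` or `-inf` for such fluxes and emits a warning, so masking them before plotting has a similar effect.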

`desisim.spec_qa.redshifts.plot_slices()` makes nice plots, so we'll leverage that for a simple example.

In [None]:
#
# Only return a subset of columns, and then use zip() to go from row-based to column-based.
#
q = db.dbSession.query(db.Truth.truez, db.ZCat.z, db.ZCat.zwarn, db.Target.flux_g, 
                       ((db.ZCat.z - db.Truth.truez)/(1.0 + db.Truth.truez)).label('dz'))\
                .filter(db.Truth.targetid == db.ZCat.targetid).filter(db.Target.targetid == db.ZCat.targetid).all()    
truez, z, zwarn, flux_g, dz = zip(*q)
g = 22.5 - 2.5*np.log10(np.array(flux_g))
ok = np.array(zwarn) == 0
dv = lightspeed.to('km / s').value * np.array(dz)
bad = (np.abs(dv) > 1000)

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 4.5), dpi=100)
p = dsq_z.plot_slices(g, dv, ok, bad, 16, 25, 1000, num_slices=20, axis=axes)
foo = axes.set_xlabel('g Magnitude')
foo = axes.set_ylabel('Velocity Residual [km / s]')

### Exercises

* Plot a particular template class, *e.g.* 'QSO_T'.
* Plot versus other magnitudes, *e.g.* r, W1.

## Fly me to the Moon

How does the Moon affect redshifts?

In [None]:
#
# How many actual exposures are there with the Moon up?
#
q = db.dbSession.query(db.ObsList.expid, db.ObsList.moonsep, db.ObsList.moonalt, db.ObsList.moonfrac).filter(db.ObsList.moonalt > 0).all()
q

So there are a few.  But there is a subtle issue: redshifts are based on *all* exposures of an object, though there may be some objects that were observed *only* with the Moon up.  We can try to compare those objects to similar objects observed *only* with the Moon down.

In [None]:
expid_up = [x[0] for x in q]
q = db.dbSession.query(db.ZCat.targetid, db.Target.desi_target, db.Target.bgs_target, db.Target.mws_target, db.ObsList.expid)\
                .filter(db.ZCat.targetid == db.FiberAssign.targetid)\
                .filter(db.ZCat.targetid == db.Target.targetid)\
                .filter(db.FiberAssign.tileid == db.ObsList.tileid)\
                .filter(db.ObsList.expid.in_(expid_up))\
                .order_by(db.ZCat.targetid, db.ObsList.expid).all()
targetid, desi_target, bgs_target, mws_target, expid = zip(*q)

In [None]:
sum(['ELG' in desi_mask.names(t) for t in desi_target])
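
The distinction between "observed with the Moon up" and "observed *only* with the Moon up" is a matter of set arithmetic once the (target, exposure) pairs are in hand. Here is a sketch in plain Python with made-up identifiers; it is stricter than the queries in this section, which select targets with *any* matching exposure.

```python
# Made-up (targetid, expid) pairs and a made-up set of Moon-up exposures.
observations = [(1, 10), (1, 11), (2, 10), (3, 12), (4, 11), (4, 12)]
expid_up = {10, 11}

# Targets with at least one exposure in each condition.
seen_up = {targetid for targetid, expid in observations if expid in expid_up}
seen_dn = {targetid for targetid, expid in observations if expid not in expid_up}

only_up = seen_up - seen_dn   # every exposure had the Moon up
only_dn = seen_dn - seen_up   # every exposure had the Moon down
mixed = seen_up & seen_dn     # exposures in both conditions
print(sorted(only_up), sorted(only_dn), sorted(mixed))
```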

In [None]:
#
# OK, let's find some ELGs with the Moon up, and some with the Moon down.
#
q_up = db.dbSession.query(db.ZCat.targetid, db.Truth.truez, db.ZCat.z, db.ZCat.zwarn,
                       ((db.ZCat.z - db.Truth.truez)/(1.0 + db.Truth.truez)).label('dz'))\
                   .filter(db.Truth.targetid == db.ZCat.targetid)\
                   .filter(db.Target.targetid == db.ZCat.targetid)\
                   .filter(db.ZCat.targetid == db.FiberAssign.targetid)\
                   .filter(db.FiberAssign.tileid == db.ObsList.tileid)\
                   .filter(db.ObsList.expid.in_(expid_up))\
                   .filter(db.Target.desi_target.op('&')(desi_mask.ELG) != 0)\
                   .all()
q_dn = db.dbSession.query(db.ZCat.targetid, db.Truth.truez, db.ZCat.z, db.ZCat.zwarn,
                       ((db.ZCat.z - db.Truth.truez)/(1.0 + db.Truth.truez)).label('dz'))\
                   .filter(db.Truth.targetid == db.ZCat.targetid)\
                   .filter(db.Target.targetid == db.ZCat.targetid)\
                   .filter(db.ZCat.targetid == db.FiberAssign.targetid)\
                   .filter(db.FiberAssign.tileid == db.ObsList.tileid)\
                   .filter(~db.ObsList.expid.in_(expid_up))\
                   .filter(db.Target.desi_target.op('&')(desi_mask.ELG) != 0)\
                   .all()[:8342]
foo, truez_up, z_up, zwarn_up, dz_up = zip(*q_up)
foo, truez_dn, z_dn, zwarn_dn, dz_dn = zip(*q_dn)
truez_up = np.array(truez_up)
z_up = np.array(z_up)
zwarn_up = np.array(zwarn_up)
dv_up = lightspeed.to('km / s').value * np.array(dz_up)
truez_dn = np.array(truez_dn)
z_dn = np.array(z_dn)
zwarn_dn = np.array(zwarn_dn)
dv_dn = lightspeed.to('km / s').value * np.array(dz_dn)
ok_up = zwarn_up == 0
ok_dn = zwarn_dn == 0

In [None]:
#
# Observed redshift versus true redshift.
#
fig, axes = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
p1 = axes.plot(truez_up[ok_up], z_up[ok_up], 'r.', label='Up')
p2 = axes.plot(truez_dn[ok_dn], z_dn[ok_dn], 'b.', label='Down')
foo = axes.set_xlabel('True redshift')
foo = axes.set_ylabel('Pipeline redshift')
foo = axes.legend(loc=4)

In [None]:
#
# Velocity residual versus true redshift.
#
fig, axes = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
p1 = axes.semilogy(truez_up[ok_up], np.abs(dv_up[ok_up]), 'r.', label='Up')
p2 = axes.semilogy(truez_dn[ok_dn], np.abs(dv_dn[ok_dn]), 'b.', label='Down')
foo = axes.set_xlabel('True redshift')
foo = axes.set_ylabel('Absolute Velocity residual [km/s]')
foo = axes.legend(loc=1)

Well, there doesn't appear to be much difference here.  That's not necessarily a bad thing!

### Exercise

* Try a different target class!

## Survey Progress

Let's see which nights have science data.

In [None]:
q = db.dbSession.query(db.ObsList.night, func.count('*').label('n_science'))\
                .filter(db.ObsList.flavor == 'science')\
                .group_by(db.ObsList.night).order_by(db.ObsList.night).all()
q

Observation timestamps for a given night.

In [None]:
q = db.dbSession.query(db.ObsList.expid, db.ObsList.mjd)\
                .filter(db.ObsList.flavor == 'science')\
                .filter(db.ObsList.night == '20200317')\
                .order_by(db.ObsList.mjd).all()
q

So, for a given target in the `target` table, when was the observation completed?  In other words, if a target has multiple observations, we want the date of the *last* observation.

In [None]:
#
# How many targets are there?
#
N_targets = db.dbSession.query(db.Target).count()
N_targets

In [None]:
#
# Find all targetids that have observations.
#
q1 = db.dbSession.query(db.Target.targetid)\
                .filter(db.Target.targetid == db.FiberAssign.targetid)\
                .filter(db.FiberAssign.tileid == db.ObsList.tileid)\
                .group_by(db.Target.targetid)\
                .subquery()
#
# Find the exposure IDs and MJDs for the targetids that have been observed.
#
q2 = db.dbSession.query(db.FiberAssign.targetid, db.ObsList.expid, db.ObsList.mjd)\
                 .filter(db.FiberAssign.targetid == q1.c.targetid)\
                 .filter(db.FiberAssign.tileid == db.ObsList.tileid)\
                 .order_by(q1.c.targetid, db.ObsList.expid).all()
targetid, expid, mjd = zip(*q2)
targetid = np.array(targetid)
expid = np.array(expid)
mjd = np.array(mjd)
#
# Use the counts to give the *last* observation.
#
unique_targetid, i, j, c = np.unique(targetid, return_index=True, return_inverse=True, return_counts=True)
unique_expid = expid[i + (c-1)]
unique_mjd = mjd[i + (c-1)]

In [None]:
#
# Now we have the targets and the date of last observation.  But it's sorted by targetid.
#
ii = unique_expid.argsort()
unique_last_expid, i3, j3, c3 = np.unique(unique_expid[ii], return_index=True, return_inverse=True, return_counts=True)
N_completed = np.cumsum(c3)

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
p1 = axes.plot(unique_mjd[ii][i3] - 58920, N_completed/N_targets, 'k-')
foo = axes.set_xlabel('MJD - 58920')
foo = axes.set_ylabel('Fraction completed')
foo = axes.grid(True)
# foo = axes.legend(loc=1)
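
The `np.unique()` bookkeeping above relies on the query result being sorted by `targetid` and `expid`. As a cross-check, the same "last observation per target, then cumulative count" logic can be sketched in plain Python without assuming any ordering. The rows below are made up, and this is an alternative formulation, not the method used in the cells above.

```python
from collections import Counter
from itertools import accumulate

# Made-up (targetid, expid, mjd) rows, deliberately out of order.
rows = [(1, 100, 58921.1), (1, 103, 58924.2), (2, 101, 58922.3),
        (3, 103, 58924.2), (3, 100, 58921.1)]

# Keep the latest observation seen for each target.
last = {}
for targetid, expid, mjd in rows:
    if targetid not in last or mjd > last[targetid][1]:
        last[targetid] = (expid, mjd)

# Count the targets completed at each MJD, then accumulate in time order.
per_mjd = Counter(mjd for _, mjd in last.values())
mjds = sorted(per_mjd)
n_completed = list(accumulate(per_mjd[mjd] for mjd in mjds))
print(mjds, n_completed)
```

Dividing `n_completed` by the total number of targets gives the completed fraction plotted above.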

### Exercise

* Break down the progress by target class, target bit, etc.