# Finding Edge Cases

This notebook illustrates the detection and mitigation of certain edge cases in the `specprod` database.

In [1]:
#
# Imports
#
import os
import sys
sys.path.insert(0, os.path.join(os.environ['HOME'], 'Documents', 'Code', 'git', 'desihub', 'specprod-db', 'py'))
import itertools
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import and_
from sqlalchemy.sql import func
#
# DESI software
#
from desiutil.log import get_logger, DEBUG
from specprodDB import __version__ as specprodDB_version
import specprodDB.load as db
#
#
# Choose the spectroscopic production run and export it via $SPECPROD.
#
specprod = 'fuji'  # Change this to 'guadalupe' if needed.
os.environ['SPECPROD'] = specprod
#
# Remember the working directory and report the package versions in use.
#
workingdir = os.getcwd()
print('sqlalchemy=={0}'.format(sqlalchemy_version))
print('specprodDB=={0}'.format(specprodDB_version))

sqlalchemy==1.4.46
specprodDB==1.1.0.dev49


In [2]:
#
# Connect to the database using the logger returned by get_logger().
# For much more output, use DEBUG/verbose mode instead:
#
# db.log = get_logger(DEBUG)
# postgresql = db.setup_db(schema=specprod, hostname='specprod-db.desi.lbl.gov', username='desi', verbose=True)
#
# NOTE: the active connection uses the 'fuji_test' schema, not `specprod`.
#
db.log = get_logger()
connection_options = dict(schema='fuji_test', hostname='specprod-db.desi.lbl.gov', username='desi')
postgresql = db.setup_db(**connection_options)

## Finding Anomalous Targeting

We are trying to identify objects that appear on multiple tiles with the same `targetid` and `survey`, but with different targeting bits on different tiles. However, in principle, we don't care about cases where a (`targetid`, `tileid`) pair appears only as a potential target.

Let's get the set of `targetid` for a particular `survey` and `program` that appear on more than one tile:

```sql
SELECT t.targetid
    FROM fuji.target AS t JOIN fuji.fiberassign AS f ON t.targetid = f.targetid AND t.tileid = f.tileid
    WHERE t.survey = 'sv1' AND t.program = 'dark'
    GROUP BY t.targetid HAVING COUNT(t.tileid) > 1;
```

In [None]:
# Join condition: the same target on the same tile in both tables.
same_target_and_tile = and_(db.Target.targetid == db.Fiberassign.targetid,
                            db.Target.tileid == db.Fiberassign.tileid)
# sv1/dark targetids that have more than one (target, fiberassign) tile match.
observed_multiple_tiles = (
    db.dbSession.query(db.Target.targetid)
    .join(db.Fiberassign, same_target_and_tile)
    .filter(db.Target.survey == 'sv1')
    .filter(db.Target.program == 'dark')
    .group_by(db.Target.targetid)
    .having(func.count(db.Target.tileid) > 1)
)
# Printing a query shows the SQL it will emit; nothing is executed here.
print(observed_multiple_tiles)

We will call the result of this query `observed_multiple_tiles`. Now let's find the distinct pairs of `targetid`, `sv1_desi_target` from this set:

```sql
SELECT DISTINCT targetid, sv1_desi_target
    FROM fuji.target WHERE targetid IN (observed_multiple_tiles) AND survey = 'sv1' AND program = 'dark';
```

In [None]:
# The three sv1 targeting bitmask columns to examine, keyed by a short name.
targets = {
    'desi': db.Target.sv1_desi_target,
    'bgs': db.Target.sv1_bgs_target,
    'mws': db.Target.sv1_mws_target,
}
# One subquery per bitmask: the distinct (targetid, bitmask) pairs among the
# sv1/dark targets selected by observed_multiple_tiles.
distinct_target = {}
for name, mask_column in targets.items():
    pairs = (
        db.dbSession.query(db.Target.targetid, mask_column)
        .filter(db.Target.targetid.in_(observed_multiple_tiles))
        .filter(db.Target.survey == 'sv1')
        .filter(db.Target.program == 'dark')
        .distinct()
    )
    distinct_target[name] = pairs.subquery()
    print(distinct_target[name])

We will call the result of this query `distinct_target`.  Next we eliminate cases where `targetid` appears only once in `distinct_target`:

```sql
SELECT targetid
    FROM (distinct_target) AS dt GROUP BY targetid HAVING COUNT(sv1_desi_target) > 1;
```

In [None]:
# Name of the bitmask column carried by each distinct_target subquery.
mask_column_names = {'desi': 'sv1_desi_target', 'bgs': 'sv1_bgs_target', 'mws': 'sv1_mws_target'}
columns = {name: distinct_target[name].c[column_name]
           for name, column_name in mask_column_names.items()}
multiple_target = {}
multiple_targetids = []
for name, pairs in distinct_target.items():
    # Keep only targetids that have more than one distinct bitmask value.
    anomalous = (
        db.dbSession.query(pairs.c.targetid)
        .group_by(pairs.c.targetid)
        .having(func.count(columns[name]) > 1)
    )
    multiple_target[name] = anomalous
    print(anomalous)
    print(anomalous.count())
    # Accumulate the actual ids across all three bitmasks (may contain repeats).
    multiple_targetids.extend(row[0] for row in anomalous.all())

In [None]:
# Total ids collected across the three bitmasks vs. the number of unique ids.
n_collected = len(multiple_targetids)
n_unique = len(set(multiple_targetids))
n_collected, n_unique

We will call the result of this query `multiple_target`. If we only want to know the number of objects, we're actually done at this stage: the answer is the number of rows of `multiple_target`.  But we can easily get more complete information:

```sql
SELECT t.targetid, t.survey, t.tileid, t.program, t.obsconditions, t.numobs_init, t.priority_init, t.subpriority, t.sv1_desi_target, t.sv1_bgs_target, t.sv1_mws_target, t.sv1_scnd_target, p.ra, p.dec
    FROM fuji.target AS t JOIN fuji.photometry AS p ON t.targetid = p.targetid
    WHERE t.survey = 'sv1' AND t.program = 'dark' AND t.targetid IN (multiple_target) ORDER BY t.targetid, t.tileid;
```

In [None]:
# Columns reported for every anomalous target: targeting data plus sky position.
detail_columns = (
    db.Target.targetid, db.Target.survey, db.Target.tileid, db.Target.program,
    db.Target.obsconditions, db.Target.numobs_init, db.Target.priority_init,
    db.Target.subpriority,
    db.Target.sv1_desi_target, db.Target.sv1_bgs_target,
    db.Target.sv1_mws_target, db.Target.sv1_scnd_target,
    db.Photometry.ra, db.Photometry.dec,
)
for name, anomalous in multiple_target.items():
    q = (
        db.dbSession.query(*detail_columns)
        .join(db.Photometry)
        .filter(db.Target.survey == 'sv1')
        .filter(db.Target.program == 'dark')
        .filter(db.Target.targetid.in_(anomalous))
        .order_by(db.Target.targetid, db.Target.tileid)
    )
    # Show the SQL, then the full result set.
    print(q)
    print(q.all())

Now let's find corresponding rows in the `zpix` table. We can reuse the `multiple_target` query from above.

```sql
SELECT id, targetid, z, zwarn
    FROM fuji.zpix
    WHERE targetid IN (multiple_target) AND survey = 'sv1' AND program = 'dark';
```

In [None]:
# For each bitmask, the sv1/dark zpix rows of the anomalous targetids.
multiple_zpix = {}
for name, anomalous in multiple_target.items():
    zpix_rows = (
        db.dbSession.query(db.Zpix)
        .filter(db.Zpix.targetid.in_(anomalous))
        .filter(db.Zpix.survey == 'sv1')
        .filter(db.Zpix.program == 'dark')
    )
    multiple_zpix[name] = zpix_rows
    print(zpix_rows)
    print(zpix_rows.count())

### Correcting Anomalous Targeting

Now that we know exactly which objects are anomalous, we can try to fix their targeting bits. We want to take the bitwise `OR` of the targeting bits for these objects. We can reuse objects returned by the `multiple_target` query above.  There are a lot of targeting bits, so it's easier to generate the full list programmatically. We're doing metaprogramming!

```sql
SELECT t.targetid, BIT_OR(t.cmx_target) AS cmx_target, BIT_OR(t.desi_target) AS desi_target, ...
    FROM fuji.target AS t WHERE t.targetid IN (multiple_target) AND t.survey = 'sv1' AND t.program = 'dark' GROUP BY t.targetid;
```

In [None]:
# Generate the long list of targeting-bitmask columns instead of typing it.
# (itertools is already imported in the first cell.)
table = 'zpix'  # NOTE(review): not referenced by any visible cell.
surveys = ('', 'sv1', 'sv2', 'sv3')
programs = ('desi', 'bgs', 'mws', 'scnd')
# 'cmx_target', then '<program>_target' for the main survey (empty prefix)
# and '<survey>_<program>_target' for each sv survey.
masks = ['cmx_target'] + [('_'.join(p) if p[0] else p[1]) + '_target'
                          for p in itertools.product(surveys, programs)]
inner_columns = ['targetid', 'survey', 'program'] + masks  # NOTE(review): not referenced by any visible cell.
# Raw SQL form; "(multiple_target)" is a placeholder for the subquery.
print("SELECT t.targetid, " +
      ', '.join([f"BIT_OR(t.{m}) AS {m}" for m in masks]) +
      f" FROM {specprod}.target AS t WHERE t.targetid IN (multiple_target) AND t.survey = 'sv1' AND t.program = 'dark' GROUP BY t.targetid;")
# SQLAlchemy form. `multiple_target` is a dict of queries and cannot be passed
# to in_(), so the generated code uses the collected list of ids instead.
print("db.dbSession.query(db.Target.targetid, " +
      ', '.join([f"func.bit_or(db.Target.{m}).label('{m}')" for m in masks]) +
      ").filter(db.Target.targetid.in_(sorted(set(multiple_targetids)))).filter(db.Target.survey == 'sv1').filter(db.Target.program == 'dark').group_by(db.Target.targetid)")

In [None]:
# Bitwise OR of every targeting bitmask, per anomalous sv1/dark targetid.
#
# `multiple_target` is a dict of queries keyed by 'desi'/'bgs'/'mws', so it
# cannot be handed to in_() directly. Use the ids already collected from all
# three queries (de-duplicated) instead.
anomalous_targetids = sorted(set(multiple_targetids))
# One BIT_OR(...) AS <mask> column per name in `masks` (defined in the
# metaprogramming cell above), in the same order as the hand-written version.
bit_or_columns = [func.bit_or(getattr(db.Target, m)).label(m) for m in masks]
multiple_target_or = (
    db.dbSession.query(db.Target.targetid, *bit_or_columns)
    .filter(db.Target.targetid.in_(anomalous_targetids))
    .filter(db.Target.survey == 'sv1')
    .filter(db.Target.program == 'dark')
    .group_by(db.Target.targetid)
)

In [None]:
# Show the generated SQL, then display the number of targetids it returns.
print(str(multiple_target_or))
n_corrections = multiple_target_or.count()
n_corrections

There are a small number of these, so we can just loop over each one, ensuring that only one row in the `zpix` table is updated at a time.

In [None]:
from sqlalchemy.exc import ProgrammingError

# Apply the OR-ed targeting bits to zpix, one targetid at a time.
# Nothing is committed in this cell; the changes stay pending in the session.
for row in multiple_target_or.all():
    # The sv1/dark zpix row(s) for this target; exactly one is expected.
    zpix_match = (
        db.dbSession.query(db.Zpix)
        .filter(db.Zpix.targetid == row.targetid)
        .filter(db.Zpix.survey == 'sv1')
        .filter(db.Zpix.program == 'dark')
    )
    print(zpix_match)
    # Map each zpix bitmask column to its corrected (OR-ed) value.
    # NOTE(review): assumes db.Zpix has a column for every name in `masks`;
    # only sv1_desi_target and sv3_desi_target are confirmed elsewhere here.
    corrected_bits = {getattr(db.Zpix, m): getattr(row, m) for m in masks}
    try:
        # Query.update() runs the UPDATE immediately and returns the number of
        # matched rows; it is not a compilable statement object.
        n_updated = zpix_match.update(corrected_bits, synchronize_session=False)
    except ProgrammingError as e:
        print(e)
        db.dbSession.rollback()
    else:
        if n_updated != 1:
            print(f"WARNING: targetid {row.targetid} matched {n_updated} zpix rows; expected exactly 1.")

In [None]:
# Roll back the current session transaction, discarding any uncommitted changes.
db.dbSession.rollback()

## Did some not get observed?

In [None]:
# Tiles and targetids under investigation.
check_tileids = [80690, 80691]
check_targetids = [
    39632929852229953,
    39632929856425410,
    39632940065360515,
    39632940073750247,
    39632950194603290,
    39632929860619211,
    39632945140469357,
    39632929852228688,
    39632929856423605,
    39632940073748950,
    39632940065359576,
    39632945144663513,
    39632940044388890,
]
# ztile rows for those targets on those tiles, ordered for easy comparison.
ztile_check = (
    db.dbSession.query(db.Ztile)
    .filter(db.Ztile.tileid.in_(check_tileids))
    .filter(db.Ztile.targetid.in_(check_targetids))
    .order_by(db.Ztile.targetid, db.Ztile.tileid)
)

In [None]:
# Execute the query and display every matching ztile row.
ztile_check.all()

In [None]:
# Targetids to look up in fiberassign.
diff = [
    39632929852229953, 39632929856425410, 39632940065360515, 39632940073750247,
    39632950194603290, 39632929860619211, 39632945140469357, 39632929852228688,
    39632929856423605, 39632940073748950, 39632940065359576, 39632945144663513,
    39632940044388890,
]

# For each targetid, list the tiles on which it has a fiberassign row.
for targid in diff:
    assigned_tiles = (
        db.dbSession.query(db.Fiberassign.tileid)
        .filter(db.Fiberassign.targetid == targid)
        .order_by(db.Fiberassign.tileid)
        .all()
    )
    print(targid, assigned_tiles)


In [None]:
# All exposures of the two tiles, ordered by tile, night and exposure id.
exposure_query = (
    db.dbSession.query(db.Exposure)
    .filter(db.Exposure.tileid.in_([80690, 80691]))
    .order_by(db.Exposure.tileid, db.Exposure.night, db.Exposure.expid)
)
exposures = exposure_query.all()

In [None]:
# Every camera name: one of 'b', 'r', 'z' followed by a digit 0-9 (30 in total).
all_cameras = {band + format(number, 'd') for band in 'brz' for number in range(10)}

In [None]:
# For each exposure, print the symmetric difference between the cameras that
# have a frame and the full camera set, i.e. cameras present in only one of them.
for exposure in exposures:
    cameras_with_frames = {frame.camera for frame in exposure.frames}
    print(exposure.tileid, cameras_with_frames ^ all_cameras)

In [None]:
# Display each exposure's tile alongside its full list of frames.
[(exposure.tileid, exposure.frames) for exposure in exposures]

In [4]:
from desitarget.sv1.sv1_targetmask import desi_mask as sv1_desi_mask
from desitarget.sv3.sv3_targetmask import desi_mask as sv3_desi_mask

In [25]:
# Match zpix rows to target rows on targetid, survey and program.
zpix_target_match = and_(db.Zpix.targetid == db.Target.targetid,
                         db.Zpix.survey == db.Target.survey,
                         db.Zpix.program == db.Target.program)
# Objects classified as QSO at z < 0.5 with zwarn == 0 whose sv1 and sv3
# desi_target bitmasks both lack the QSO bit.
qso_bump = (
    db.dbSession.query(db.Zpix, db.Target)
    .join(db.Target, zpix_target_match)
    .filter(db.Zpix.z < 0.5)
    .filter(db.Zpix.spectype == 'QSO')
    .filter(db.Zpix.zwarn == 0)
    .filter(db.Zpix.sv1_desi_target.op('&')(sv1_desi_mask.QSO) == 0)
    .filter(db.Zpix.sv3_desi_target.op('&')(sv3_desi_mask.QSO) == 0)
    .distinct()
    .order_by(db.Zpix.survey, db.Zpix.program, db.Zpix.targetid)
)
print(qso_bump)
print(qso_bump.count())

SELECT DISTINCT fuji_test.zpix.id AS fuji_test_zpix_id, fuji_test.zpix.targetid AS fuji_test_zpix_targetid, fuji_test.zpix.survey AS fuji_test_zpix_survey, fuji_test.zpix.program AS fuji_test_zpix_program, fuji_test.zpix.spgrp AS fuji_test_zpix_spgrp, fuji_test.zpix.spgrpval AS fuji_test_zpix_spgrpval, fuji_test.zpix.healpix AS fuji_test_zpix_healpix, fuji_test.zpix.z AS fuji_test_zpix_z, fuji_test.zpix.zerr AS fuji_test_zpix_zerr, fuji_test.zpix.zwarn AS fuji_test_zpix_zwarn, fuji_test.zpix.chi2 AS fuji_test_zpix_chi2, fuji_test.zpix.coeff_0 AS fuji_test_zpix_coeff_0, fuji_test.zpix.coeff_1 AS fuji_test_zpix_coeff_1, fuji_test.zpix.coeff_2 AS fuji_test_zpix_coeff_2, fuji_test.zpix.coeff_3 AS fuji_test_zpix_coeff_3, fuji_test.zpix.coeff_4 AS fuji_test_zpix_coeff_4, fuji_test.zpix.coeff_5 AS fuji_test_zpix_coeff_5, fuji_test.zpix.coeff_6 AS fuji_test_zpix_coeff_6, fuji_test.zpix.coeff_7 AS fuji_test_zpix_coeff_7, fuji_test.zpix.coeff_8 AS fuji_test_zpix_coeff_8, fuji_test.zpix.coeff_9 A

In [27]:
# Targeting bitmask columns, in the order they appear in the CSV.
mask_columns = (
    'cmx_target', 'desi_target', 'bgs_target', 'mws_target', 'scnd_target',
    'sv1_desi_target', 'sv1_bgs_target', 'sv1_mws_target', 'sv1_scnd_target',
    'sv2_desi_target', 'sv2_bgs_target', 'sv2_mws_target', 'sv2_scnd_target',
    'sv3_desi_target', 'sv3_bgs_target', 'sv3_mws_target', 'sv3_scnd_target',
)
# First line is the CSV header.
lines = [','.join(('targetid', 'survey', 'program') + mask_columns)]
# Each result is a (Zpix, Target) pair: identifiers come from the Zpix row,
# bitmask values from the Target row.
for zpix_row, target_row in qso_bump.all():
    fields = [f"{zpix_row.targetid:d}", f"{zpix_row.survey}", f"{zpix_row.program}"]
    fields.extend(format(getattr(target_row, name), 'd') for name in mask_columns)
    lines.append(','.join(fields))

    

In [28]:
# Write the header and data lines to disk, one newline-terminated line each.
output_csv = '/global/cfs/cdirs/desi/users/bweaver/qso_bump.csv'
with open(output_csv, 'w') as csv_file:
    csv_file.write(''.join(line + '\n' for line in lines))