# Finding Edge Cases

This notebook illustrates the detection and mitigation of certain edge cases in the `specprod` database.

## Imports

In [None]:
#
# Imports
#
import os
import sys
sys.path.insert(0, os.path.join(os.environ['HOME'], 'Documents', 'Code', 'git', 'desihub', 'specprod-db', 'py'))
import itertools
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import and_
from sqlalchemy.sql import func
from sqlalchemy.exc import ProgrammingError

#
# DESI software
#
from desiutil.log import get_logger, DEBUG
from specprodDB import __version__ as specprodDB_version
import specprodDB.load as db
#
# Spectroscopic production run; change to 'guadalupe' if needed.
# The value is also exported to the environment as SPECPROD.
#
specprod = 'fuji'
os.environ['SPECPROD'] = specprod
#
# Remember the directory the notebook was started from.
#
workingdir = os.getcwd()
#
# Record the software versions used for this run.
#
for _pkg, _ver in (('sqlalchemy', sqlalchemy_version), ('specprodDB', specprodDB_version)):
    print(f'{_pkg}=={_ver}')

## Connect to database

In [None]:
#
# Connect to the database. For much more output, swap in the two
# commented-out DEBUG/verbose lines below in place of the active pair.
#
# db.log = get_logger(DEBUG)
# postgresql = db.setup_db(schema=specprod, hostname='specprod-db.desi.lbl.gov', username='desi', verbose=True)
#
# NOTE(review): the active connection uses the hard-coded 'fuji_test' schema,
# not `specprod`; changing `specprod` alone does not change the schema.
#
db.log = get_logger()
postgresql = db.setup_db(
    schema='fuji_test',
    hostname='specprod-db.desi.lbl.gov',
    username='desi',
)

## Useful Configuration

In [None]:
#
# Which (survey, program) combinations exist in each spectroscopic production.
#
specprod_survey_program = dict(
    fuji=dict(cmx=('other',),
              special=('dark',),
              sv1=('backup', 'bright', 'dark', 'other'),
              sv2=('backup', 'bright', 'dark'),
              sv3=('backup', 'bright', 'dark')),
    guadalupe=dict(special=('bright', 'dark'),
                   main=('bright', 'dark')),
    iron=dict(cmx=('other',),
              main=('backup', 'bright', 'dark'),
              special=('backup', 'bright', 'dark', 'other'),
              sv1=('backup', 'bright', 'dark', 'other'),
              sv2=('backup', 'bright', 'dark'),
              sv3=('backup', 'bright', 'dark')),
)
#
# Targeting-bit columns of the target table to compare, per survey.
# cmx has a single column; sv1/sv2/sv3 use survey-prefixed columns;
# main and special share the unprefixed columns.
#
_primary_bits = dict(desi=db.Target.desi_target, bgs=db.Target.bgs_target, mws=db.Target.mws_target)
target_bits = dict(cmx=dict(cmx=db.Target.cmx_target))
for _sv in ('sv1', 'sv2', 'sv3'):
    target_bits[_sv] = {_p: getattr(db.Target, f'{_sv}_{_p}_target') for _p in ('desi', 'bgs', 'mws')}
target_bits['main'] = dict(_primary_bits)
target_bits['special'] = dict(_primary_bits)

## Finding Anomalous Targeting

We are trying to identify objects that appear on multiple tiles with the same `targetid` and `survey`, but with different targeting bits on different tiles. However, we don't care about cases where a (`targetid`, `tileid`) pair appears only as a potential target, that is, in `target` but not in `fiberassign`; the join with `fiberassign` below excludes those.

Let's get the set of `targetid` for a particular `survey` and `program` that appear on more than one tile:

```sql
SELECT t.targetid
    FROM fuji.target AS t JOIN fuji.fiberassign AS f ON t.targetid = f.targetid AND t.tileid = f.tileid
    WHERE t.survey = 'sv1' AND t.program = 'dark'
    GROUP BY t.targetid HAVING COUNT(t.tileid) > 1;
```

In [None]:
#
# For every (survey, program), build (but do not yet execute) a query for the
# targetids that are in both target and fiberassign on the same tile, and
# whose joined rows have COUNT(tileid) > 1.
#
assigned_multiple_tiles = {}
for survey, survey_programs in specprod_survey_program[specprod].items():
    assigned_multiple_tiles[survey] = {}
    for program in survey_programs:
        on_same_tile = and_(db.Target.targetid == db.Fiberassign.targetid,
                            db.Target.tileid == db.Fiberassign.tileid)
        assigned_multiple_tiles[survey][program] = (db.dbSession.query(db.Target.targetid)
                                                    .join(db.Fiberassign, on_same_tile)
                                                    .filter(db.Target.survey == survey)
                                                    .filter(db.Target.program == program)
                                                    .group_by(db.Target.targetid)
                                                    .having(func.count(db.Target.tileid) > 1))
print(assigned_multiple_tiles['sv1']['dark'])

We will call the result of this query `assigned_multiple_tiles`. Now let's find the distinct pairs of `targetid`, `sv1_desi_target` from this set:

```sql
SELECT DISTINCT targetid, sv1_desi_target
    FROM fuji.target WHERE targetid IN (assigned_multiple_tiles) AND survey = 'sv1' AND program = 'dark';
```

and similarly for `sv1_bgs_target` and `sv1_mws_target`.

In [None]:
#
# For each (survey, program) and each family of targeting bits, build a
# subquery of the distinct (targetid, bits) pairs restricted to the
# multiply-assigned targetids found above.
#
distinct_target = {}
for survey, program_queries in assigned_multiple_tiles.items():
    distinct_target[survey] = {}
    for program, multi_tile_query in program_queries.items():
        distinct_target[survey][program] = {}
        for bits, bit_column in target_bits[survey].items():
            distinct_target[survey][program][bits] = (db.dbSession.query(db.Target.targetid, bit_column)
                                                      .filter(db.Target.targetid.in_(multi_tile_query))
                                                      .filter(db.Target.survey == survey)
                                                      .filter(db.Target.program == program)
                                                      .distinct()
                                                      .subquery())
print(distinct_target['sv1']['dark']['desi'])

We will call the result of this query `distinct_target`. Next we keep only the `targetid` values that appear more than once in `distinct_target`, i.e., those with more than one distinct value of the targeting bits:

```sql
SELECT targetid
    FROM (distinct_target) AS dt GROUP BY targetid HAVING COUNT(sv1_desi_target) > 1;
```

In [None]:
def _bit_column_name(survey, bits):
    """Name of the targeting-bit column carried by a distinct_target subquery."""
    if survey == 'cmx':
        return 'cmx_target'
    if survey.startswith('sv'):
        return f"{survey}_{bits}_target"
    return f"{bits}_target"


#
# Keep the targetids that occur with more than one distinct value of the
# targeting bits, i.e. that were targeted differently on different tiles.
#
multiple_target = {}
for survey, by_program in distinct_target.items():
    multiple_target[survey] = {}
    for program, by_bits in by_program.items():
        multiple_target[survey][program] = {}
        for bits, sub in by_bits.items():
            column = getattr(sub.c, _bit_column_name(survey, bits))
            rows = (db.dbSession.query(sub.c.targetid)
                    .group_by(sub.c.targetid)
                    .having(func.count(column) > 1)
                    .all())
            multiple_target[survey][program][bits] = [r[0] for r in rows]
multiple_target

Do these sets of targetid overlap?

In [None]:
#
# For each sv1/dark targetid with differing sv1_mws_target bits, report
# whether it also has differing sv1_desi_target bits.
#
desi_differing = set(multiple_target['sv1']['dark']['desi'])
for targetid in multiple_target['sv1']['dark']['mws']:
    print(targetid, targetid in desi_differing)

Yes, so only 3 of the `targetid` values with differences in `sv1_mws_target` are not already among those with differences in `sv1_desi_target`.

We call the result of the `GROUP BY ... HAVING` query above `multiple_target`. If we only want to know the number of objects, we're actually done at this stage: the answer is the number of unique `targetid` values in `multiple_target`. But we can easily get more complete information:

```sql
SELECT t.targetid, t.survey, t.tileid, t.program, t.obsconditions, t.numobs_init, t.priority_init, t.subpriority, t.sv1_desi_target, t.sv1_bgs_target, t.sv1_mws_target, t.sv1_scnd_target, p.ra, p.dec
    FROM fuji.target AS t JOIN fuji.photometry AS p ON t.targetid = p.targetid
    WHERE t.survey = 'sv1' AND t.program = 'dark' AND t.targetid IN (multiple_target) ORDER BY t.targetid, t.tileid;
```

In [None]:
def _bit_columns_for_survey(survey):
    """Return the target-table targeting-bit columns relevant to `survey`.

    cmx only has cmx_target; sv1/sv2/sv3 use the survey-prefixed
    desi/bgs/mws/scnd columns; main and special use the unprefixed ones.
    """
    if survey == 'cmx':
        names = ['cmx_target']
    elif survey.startswith('sv'):
        names = [f"{survey}_{p}_target" for p in ('desi', 'bgs', 'mws', 'scnd')]
    else:
        names = [f"{p}_target" for p in ('desi', 'bgs', 'mws', 'scnd')]
    return [getattr(db.Target, name) for name in names]


#
# Show the full target rows (plus photometric ra, dec) of every anomalous
# targetid. The targeting-bit columns are chosen per survey, so that e.g.
# sv3 rows display sv3_* bits rather than the (irrelevant) sv1_* bits.
#
for survey in multiple_target:
    for program in multiple_target[survey]:
        for bits in multiple_target[survey][program]:
            if multiple_target[survey][program][bits]:
                print(survey, program, bits)
                q = (db.dbSession.query(db.Target.targetid, db.Target.survey, db.Target.tileid, db.Target.program,
                                        db.Target.obsconditions, db.Target.numobs_init, db.Target.priority_init,
                                        db.Target.subpriority, *_bit_columns_for_survey(survey),
                                        db.Photometry.ra, db.Photometry.dec)
                     .join(db.Photometry)
                     .filter(db.Target.survey == survey)
                     .filter(db.Target.program == program)
                     .filter(db.Target.targetid.in_(multiple_target[survey][program][bits]))
                     .order_by(db.Target.targetid, db.Target.tileid))
                # print(q)
                print(q.all())

Now let's find corresponding rows in the `zpix` table. We can reuse the `multiple_target` data from above.

```sql
SELECT id, targetid, z, zwarn
    FROM fuji.zpix
    WHERE targetid IN (multiple_target) AND survey = 'sv1' AND program = 'dark';
```

In [None]:
#
# Fetch the zpix rows of every anomalous targetid, per survey/program/bits.
# Empty targetid lists are skipped and get no entry.
#
multiple_zpix = {}
for survey, by_program in multiple_target.items():
    multiple_zpix[survey] = {}
    for program, by_bits in by_program.items():
        multiple_zpix[survey][program] = {}
        for bits, targetids in by_bits.items():
            if not targetids:
                continue
            print(survey, program, bits)
            multiple_zpix[survey][program][bits] = (db.dbSession.query(db.Zpix)
                                                    .filter(db.Zpix.targetid.in_(targetids))
                                                    .filter(db.Zpix.survey == survey)
                                                    .filter(db.Zpix.program == program)
                                                    .all())
multiple_zpix

### Did some anomalous targets not get observed?

We know that the targetids above were *assigned* -- they are not just potential targets -- but it is possible in principle that some did not actually get observed.

In [None]:
#
# All sv1/dark targetids with differing desi or mws bits, the tiles those
# targetids were targeted on, and the ztile rows linking the two.
#
sv1_dark = multiple_target['sv1']['dark']
multiple_set = set(sv1_dark['desi'] + sv1_dark['mws'])
multiple_tiles = (db.dbSession.query(db.Target.tileid)
                  .filter(db.Target.targetid.in_(multiple_set))
                  .filter(db.Target.survey == 'sv1')
                  .filter(db.Target.program == 'dark')
                  .distinct())
ztile_check = (db.dbSession.query(db.Ztile)
               .filter(db.Ztile.tileid.in_(multiple_tiles))
               .filter(db.Ztile.targetid.in_(multiple_set))
               .order_by(db.Ztile.targetid, db.Ztile.tileid))

In [None]:
ztiles = ztile_check.all()
#
# Group the ztile tileids by targetid (the query already sorts by targetid,
# tileid), then report every anomalous targetid, including any with no
# ztile rows at all.
#
tiles_by_targetid = {}
for zt in ztiles:
    tiles_by_targetid.setdefault(zt.targetid, []).append(zt.tileid)
for targetid in multiple_set:
    print(targetid, tiles_by_targetid.get(targetid, []))

In [None]:
#
# Exposures of the two tiles of interest, ordered by tile, night, expid.
#
tiles_of_interest = [80690, 80691]
exposures = (db.dbSession.query(db.Exposure)
             .filter(db.Exposure.tileid.in_(tiles_of_interest))
             .order_by(db.Exposure.tileid, db.Exposure.night, db.Exposure.expid)
             .all())

In [None]:
#
# Every camera name: one of the letters b, r, z followed by a digit 0-9.
#
all_cameras = {f"{arm}{number:d}" for arm, number in itertools.product('brz', range(10))}

In [None]:
#
# For each exposure, print the symmetric difference between the cameras
# that have frames and the full camera list: missing cameras, plus any
# unexpected camera names.
#
for exposure in exposures:
    cameras_present = {frame.camera for frame in exposure.frames}
    print(exposure.tileid, cameras_present ^ all_cameras)

In [None]:
#
# Raw (tileid, frames) pairs for inspection.
#
[(exposure.tileid, exposure.frames) for exposure in exposures]

## ToO With Zero Target Bits

Due to a [known issue](https://github.com/desihub/fiberassign/pull/342), some ToO objects had their targeting bits zeroed out. This was subsequently fixed in the fiberassign files and propagated forward to the lsdr9-photometry, but in some cases the fix was not retroactively applied to the individual redshift catalog files. Therefore, when the final zcatalog files were assembled, they still had the targeting bits zeroed out. Let's find the ToOs in `fuji`.
```sql
SELECT targetid, survey, program, sv3_desi_target, sv3_bgs_target, sv3_mws_target, sv3_scnd_target
    FROM fuji_test.zpix WHERE ((targetid & (CAST((2^16 - 1) AS bigint) << 42)) >> 42) = 9999;
```

In [None]:
#
# Bits 42-57 of targetid form a 16-bit field; this notebook treats the
# value 9999 in that field as marking a ToO. Collect the zpix targetids
# with that value, per survey and program.
#
too_shift = 42
too_mask = (2**16 - 1) << too_shift
too_field = db.Zpix.targetid.op('&')(too_mask).op('>>')(too_shift)
zero_ToO = {}
for survey, survey_programs in specprod_survey_program[specprod].items():
    zero_ToO[survey] = {}
    for program in survey_programs:
        matches = (db.dbSession.query(db.Zpix.targetid)
                   .filter(too_field == 9999)
                   .filter(db.Zpix.survey == survey)
                   .filter(db.Zpix.program == program)
                   .all())
        zero_ToO[survey][program] = [m[0] for m in matches]
zero_ToO

## Correcting Anomalous Targeting

Now that we know exactly which objects are anomalous, we can try to fix their targeting bits. We want to take the bitwise `OR` of the targeting bits for these objects. We can reuse objects returned by the `multiple_target` query above. First though, let's compress the list of `targetid` by removing duplicates.

In [None]:
def _merge_targetids(merged, survey, program, targetids, source):
    """Add `targetids` to ``merged[survey][program]``, printing the step.

    On first insertion a *copy* of `targetids` is stored, so later ``+=``
    extensions never mutate the caller's list (for example the lists held in
    ``multiple_target`` or ``zero_ToO``). `source` is the printable name of
    the list being merged.
    """
    by_program = merged.setdefault(survey, dict())
    if program in by_program:
        print(f"targetids_to_fix['{survey}']['{program}'] += {source}")
        by_program[program] += targetids
    else:
        print(f"targetids_to_fix['{survey}']['{program}'] = {source}")
        by_program[program] = list(targetids)


targetids_to_fix = dict()
for survey in multiple_target:
    for program in multiple_target[survey]:
        for bits in multiple_target[survey][program]:
            if multiple_target[survey][program][bits]:
                _merge_targetids(targetids_to_fix, survey, program, multiple_target[survey][program][bits],
                                 f"multiple_target['{survey}']['{program}']['{bits}']")
for survey in zero_ToO:
    for program in zero_ToO[survey]:
        if zero_ToO[survey][program]:
            _merge_targetids(targetids_to_fix, survey, program, zero_ToO[survey][program],
                             f"zero_ToO['{survey}']['{program}']")
#
# Remove duplicate targetids, keeping first-seen order.
#
for survey in targetids_to_fix:
    for program in targetids_to_fix[survey]:
        targetids_to_fix[survey][program] = list(dict.fromkeys(targetids_to_fix[survey][program]))
targetids_to_fix

In [None]:
# Number of distinct sv1/dark targetids to fix.
len({*targetids_to_fix['sv1']['dark']})

There are a lot of targeting bits, so it's easier to generate the full list programmatically. We're doing metaprogramming!

```sql
SELECT t.targetid, BIT_OR(t.cmx_target) AS cmx_target, BIT_OR(desi_target) AS desi_target, ...
    FROM fuji.target AS t WHERE t.targetid IN (targetids_to_fix) AND t.survey = 'sv1' AND t.program = 'dark' GROUP BY t.targetid;
```

In [None]:
table = 'zpix'
surveys = ('', 'sv1', 'sv2', 'sv3')
programs = ('desi', 'bgs', 'mws', 'scnd')
#
# All targeting-bit column names: cmx_target, then desi_target ... scnd_target,
# then sv1_desi_target ... sv3_scnd_target.
#
masks = ['cmx_target'] + [('_'.join(p) if p[0] else p[1]) + '_target'
                          for p in itertools.product(surveys, programs)]
bit_or_query = dict()
for survey in targetids_to_fix:
    bit_or_query[survey] = dict()
    for program in targetids_to_fix[survey]:
        #
        # Sorted, de-duplicated targetids give a deterministic IN (...) list.
        #
        targetids = sorted(set(targetids_to_fix[survey][program]))
        #
        # Equivalent raw SQL, printed for reference.
        #
        print("SELECT t.targetid, " +
              ', '.join([f"BIT_OR(t.{m}) AS {m}" for m in masks]) +
              f" FROM {specprod}.target AS t WHERE t.targetid IN ({', '.join(map(str, targetids))}) AND t.survey = '{survey}' AND t.program = '{program}' GROUP BY t.targetid;")
        #
        # Build the same query directly with SQLAlchemy, instead of
        # generating Python source text and passing it to eval().
        #
        bit_or_query[survey][program] = (db.dbSession.query(db.Target.targetid,
                                                            *[func.bit_or(getattr(db.Target, m)).label(m) for m in masks])
                                         .filter(db.Target.targetid.in_(targetids))
                                         .filter(db.Target.survey == survey)
                                         .filter(db.Target.program == program)
                                         .group_by(db.Target.targetid))
        print(bit_or_query[survey][program])

In [None]:
# Show the sv1/dark BIT_OR query and how many targetids it returns.
sv1_dark_bit_or = bit_or_query['sv1']['dark']
print(sv1_dark_bit_or)
sv1_dark_bit_or.count()

Even across several different categories there is only a small number of these, so we can simply loop over each one, ensuring that only one row in the `zpix` table is updated at a time.

In [None]:
#
# Dry run by default. For each anomalous targetid, find its single zpix row and
# print the targeting-bit update that would be applied. Set
# apply_zpix_updates = True to actually issue the UPDATE statements; nothing is
# committed in this cell.
#
apply_zpix_updates = False
update = ', '.join([f"db.Zpix.{m}: {{0.{m}:d}}" for m in masks])
for survey in bit_or_query:
    for program in bit_or_query[survey]:
        for row in bit_or_query[survey][program].all():
            zpix_query = db.dbSession.query(db.Zpix).filter(db.Zpix.targetid == row.targetid).filter(db.Zpix.survey == survey).filter(db.Zpix.program == program)
            #
            # .one() raises unless exactly one zpix row matches, so any
            # update below touches a single row.
            #
            zpix_match = zpix_query.one()
            print(f"{zpix_match}.update({{ {update.format(row)} }})")
            if apply_zpix_updates:
                #
                # Write each OR-ed targeting-bit column into the same-named
                # zpix column (and nothing else, e.g. never zpix.z).
                #
                try:
                    zpix_query.update({getattr(db.Zpix, m): getattr(row, m) for m in masks})
                except ProgrammingError as e:
                    print(e)
                    db.dbSession.rollback()

In [None]:
#
# Roll back the current transaction on db.dbSession, discarding any
# uncommitted changes made through the session in the cells above.
#
db.dbSession.rollback()

## QSO "Bump"

In [None]:
from desitarget.sv1.sv1_targetmask import desi_mask as sv1_desi_mask
from desitarget.sv3.sv3_targetmask import desi_mask as sv3_desi_mask

In [None]:
#
# zpix rows classified as QSO with z < 0.5 and zwarn == 0, where neither
# sv1_desi_target nor sv3_desi_target has the QSO bit set. Each is paired
# with the target rows that have the same targetid, survey and program.
#
same_target = and_(db.Zpix.targetid == db.Target.targetid,
                   db.Zpix.survey == db.Target.survey,
                   db.Zpix.program == db.Target.program)
qso_bump = (db.dbSession.query(db.Zpix, db.Target)
            .join(db.Target, same_target)
            .filter(db.Zpix.z < 0.5)
            .filter(db.Zpix.spectype == 'QSO')
            .filter(db.Zpix.zwarn == 0)
            .filter(db.Zpix.sv1_desi_target.op('&')(sv1_desi_mask.QSO) == 0)
            .filter(db.Zpix.sv3_desi_target.op('&')(sv3_desi_mask.QSO) == 0)
            .distinct()
            .order_by(db.Zpix.survey, db.Zpix.program, db.Zpix.targetid))
print(qso_bump)
print(qso_bump.count())

In [None]:
#
# One CSV line per (zpix, target) pair: identifiers from zpix, followed by
# every targeting-bit column from target, integers written in decimal.
#
csv_bit_columns = ('cmx_target', 'desi_target', 'bgs_target', 'mws_target', 'scnd_target',
                   'sv1_desi_target', 'sv1_bgs_target', 'sv1_mws_target', 'sv1_scnd_target',
                   'sv2_desi_target', 'sv2_bgs_target', 'sv2_mws_target', 'sv2_scnd_target',
                   'sv3_desi_target', 'sv3_bgs_target', 'sv3_mws_target', 'sv3_scnd_target')
lines = [','.join(('targetid', 'survey', 'program') + csv_bit_columns)]
for row in qso_bump.all():
    zpix_row, target_row = row[0], row[1]
    fields = [format(zpix_row.targetid, 'd'), f"{zpix_row.survey}", f"{zpix_row.program}"]
    fields.extend(format(getattr(target_row, name), 'd') for name in csv_bit_columns)
    lines.append(','.join(fields))

    

In [None]:
#
# Write the CSV: the entries of `lines` separated by newlines, plus a final newline.
#
with open('/global/cfs/cdirs/desi/users/bweaver/qso_bump.csv', 'w') as csv_file:
    print(*lines, sep='\n', file=csv_file)