### Notebook containing code for combining truncated object detections with regular detections

In [None]:
from astropy.io import fits
import matplotlib.pyplot as plt
from matplotlib import colors
from astropy.visualization import make_lupton_rgb
import numpy as np
import pandas as pd
from astropy.nddata import Cutout2D
from astropy.wcs import WCS

import cv2
from detectron2.structures import BoxMode
from astropy.table import Table
import glob
from astropy.coordinates import SkyCoord  # High-level coordinates
from detectron2.config import LazyConfig, get_cfg, instantiate
import os
import scipy.stats as stats
import h5py
import json
import astropy.units as u
from astropy.coordinates import SkyCoord

import warnings
import time

from astropy.wcs import FITSFixedWarning
warnings.filterwarnings("ignore", category=FITSFixedWarning)
import torch
import torch.nn.functional as F
from detectron2.data import detection_utils as utils
import pickle

In [3]:
testmetaf = h5py.File(test_metadatafile, "r")
test_metadata = testmetaf['metadata_dicts']

testf = h5py.File(testfile, "r")
testims = testf['images']

testmetaf_trunc = h5py.File(test_trunc_metadatafile, "r")
test_trunc_metadata = testmetaf_trunc['metadata_dicts']

testf_trunc = h5py.File(testfile_trunc, "r")
testims_trunc = testf_trunc['images']


trainmetaf = h5py.File(train_metadatafile, "r")
train_metadata = trainmetaf['metadata_dicts']

trainf = h5py.File(trainfile, "r")
trainims = trainf['images']



train_ids = []
for i in range(len(train_metadata)):
    d = json.loads(train_metadata[i])
    for a in d['annotations']:
        if a['redshift']!=-1:
            train_ids.append(a['obj_id'])

            
test_ids = []
for i in range(len(test_metadata)):
    d = json.loads(test_metadata[i])
    for a in d['annotations']:
        if a['redshift']!=-1:
            test_ids.append(a['obj_id'])

fncat = '/home/shared/hsc/JWST/catalogs/hlsp_jades_jwst_nircam_goods-s-deep_photometry_v2.0_catalog.fits'
fnspecz = '/home/shared/hsc/JWST/catalogs/JADES_GOODS_zspec_cleaned.fits'

dphot = Table.read(fncat, hdu=2).to_pandas()
dspecz = Table.read(fnspecz,hdu=1).to_pandas()
dt = Table.read(fncat, hdu=7).to_pandas()
            
dspecztrain = dspecz.iloc[np.nonzero(np.in1d(dspecz.ID.values,np.unique(train_ids)))]
dspecztest = dspecz.iloc[np.nonzero(np.in1d(dspecz.ID.values,np.unique(test_ids)))]

In [6]:
def missing_bands(dspec,dphot_kron):
    nmb = []
    mbs = []
    for idi in dspec.ID.values:
        ind = np.where(dphot_kron.ID==idi)
        nob=0
        mbi=[]
        for F in JADES_filters_F:
            if dphot_kron[f'{F}_KRON'].values[ind]==0:
                nob+=1
                mbi.append(JADES_filters_F)
        nmb.append(nob)

    return nmb,mbs

dphot_kron = Table.read(fncat, hdu=8).to_pandas()
dphot_size = Table.read(fncat, hdu=3).to_pandas()

nmb_train,mbs_train = missing_bands(dspecztrain,dphot_kron)
dspecztrain['Num_missing']=nmb_train

nmb_test,mbs_test = missing_bands(dspecztest,dphot_kron)
dspecztest['Num_missing']=nmb_test

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dspecztrain['Num_missing']=nmb_train
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dspecztest['Num_missing']=nmb_test


In [12]:

def outside_box(box,shape):
    if box[0]<0:
        return True
    elif box[1]<0:
        return True
    elif box[2]>shape[1]:
        return True
    elif box[3]>shape[0]:
        return True
    else:
        return False


## Determines the truncated objects in the original test set
truncated = []
for i in range(len(test_metadata)):
    d = json.loads(test_metadata[i])
    shape=(d['height'],d['width'])
    boxes = utils.annotations_to_instances(d['annotations'],shape).gt_boxes.tensor.cpu().numpy()
    for j,box in enumerate(boxes):
        outside = outside_box(box,shape)
        if d['annotations'][j]['redshift']!=-1 and outside:
            truncated.append((i,j,d['annotations'][j]['obj_id']))
            
truncated=np.array(truncated)
    

def get_res(filename):
    with open(filename, 'rb') as fp:
        data = pickle.load(fp)
    return [data[key] for key in data.keys()]
    
    
JADESfilts = ['F090W','F115W','F150W','F200W','F277W','F335M','F356W','F410M','F444W']

def get_missfilts(ids):
    missfilts = []
    for i, idi in enumerate(ids):
        cnb=0
        mf=[]
        for filt in JADESfilts:
            if dt[dt['ID']==idi][f'{filt}_KRON'].values[0] ==0:
                cnb+=1
                mf.append(filt)

        missfilts.append(mf)

    return missfilts



def get_mi(missfilts):
    mi=[]
    for i,m in enumerate(missfilts):
        if m==np.unique(missfilts)[0]:
            mi.append(i)
    return mi
    
def get_tot(dtest, truncated, res_name,res_trunc_name):
    
     """Takes as the outputs of running inference on the normal and truncated test set 
         and returns a single catalog

        Parameters
        ----------
        dtest: pandas DataFrame
            Dataframe that contains the entire test set of objects 
            (including those truncated in the original test set of images).  Must have an ID key
        truncated: array(int):
            Array of test set object IDs that were truncated
        res_name: str
            path to a dictionary containing zphot, ztrue, object_ids, scores, and pdfs of inferred objects
            Assumes the detections have already been matched to a test set catalog and duplicate matches 
            have been filtered by score
        trunc_name: str
            path to a dictionary containing zphot, ztrue, object_ids, scores, and pdfs of inferred objects
            from the truncated test set. Assumes the detections have already been matched to a test set catalog 
            and duplicate matches have been filtered by score

        
        Returns
        -------
        zts_tot, zps_tot, ids_tot, pdfs_tot
        zts_tot : array(float)
            Filtered array of true redshifts 
        zps_tot : array(float)
            Filtered array of photometric redshifts
        ids_tot : array(float)
            Filtered array of object ids
        pdfs_tot : 2D array(float)
            Filtered array of z PDFs
        """
    
    zps,zts,ids,scores,pdfs = get_res(res_name)
    zps_trunc,zts_trunc,ids_trunc,scores_trunc,pdfs_trunc = get_res(res_trunc_name)
    
    #What are the test set objects that are not detected in the original test set of images?
    nondects_ids = dtest.ID.values[np.nonzero(np.in1d(dtest.ID.values,ids,invert=True))]
    
    #How many test set objects that are not detected are truncated?
    truncated_ids = np.unique(truncated)
    #print(len(np.nonzero(np.in1d(nondects_ids,truncated_ids))[0]))

    #How many originally truncated objects that weren't detected, are now detected?
    orig_trunc_nondects = nondects_ids[np.nonzero(np.in1d(nondects_ids,truncated_ids))[0]]
    
    #objects originally detected that were truncated
    orig_trunc_dects = ids[np.nonzero(np.in1d(ids,truncated_ids))]
    
    zts[np.nonzero(np.in1d(ids,truncated_ids)*np.in1d(ids,ids_trunc))] = zts_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_dects))]
    zps[np.nonzero(np.in1d(ids,truncated_ids)*np.in1d(ids,ids_trunc))] = zps_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_dects))]
    ids[np.nonzero(np.in1d(ids,truncated_ids)*np.in1d(ids,ids_trunc))] = ids_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_dects))]
    pdfs[np.nonzero(np.in1d(ids,truncated_ids)*np.in1d(ids,ids_trunc))] = pdfs_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_dects))]
    
    zts_tot = np.concatenate([zts,zts_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_nondects))]])
    zps_tot = np.concatenate([zps,zps_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_nondects))]])
    ids_tot = np.concatenate([ids,ids_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_nondects))]])
    pdfs_tot = np.concatenate([pdfs,pdfs_trunc[np.nonzero(np.in1d(ids_trunc,orig_trunc_nondects))]])
        
    return zts_tot, zps_tot, ids_tot, pdfs_tot

