In [None]:
from astropy.io import fits
import matplotlib.pyplot as plt
from matplotlib import colors
from astropy.visualization import make_lupton_rgb
import numpy as np
import pandas as pd
from astropy.nddata import Cutout2D
from astropy.wcs import WCS

import cv2
from detectron2.structures import BoxMode
from astropy.table import Table
import glob
from astropy.coordinates import SkyCoord  # High-level coordinates
from detectron2.config import LazyConfig, get_cfg, instantiate
import os
import scipy.stats as stats

import astropy.units as u
from astropy.coordinates import SkyCoord
import h5py
import json
from detectron2.data import detection_utils as utils


In [None]:

# Load the JADES GOODS-S deep NIRCam mosaics (science data in HDU 1).

files = sorted(glob.glob('/home/shared/hsc/JWST/images/hlsp_jades_jwst_nircam_goods-s-deep_f*'))
if not files:
    raise FileNotFoundError('no JADES GOODS-S deep mosaics matched the glob pattern')
JADES_filters = ['090','115','150','200','277','335','356','410','444']

# Mosaics reordered to follow JADES_filters. Matching is a plain substring test on
# the filter number, so a file matching several entries appears once per match.
JADES_files = [file for filt in JADES_filters for file in files if filt in file]

# Read every mosaic exactly once; JADESims reuses these arrays instead of
# re-opening (and re-reading) the same FITS files a second time.
allims = []
headers = {}
for file in files:
    with fits.open(file) as hdul:
        allims.append(hdul[1].data)
        headers[file] = hdul[1].header

image_by_file = dict(zip(files, allims))
JADESims = [image_by_file[file] for file in JADES_files]

# Single WCS used for every cutout below: the last filter mosaic's (the same file
# the previous per-file loops ended on). This assumes all mosaics share one pixel grid.
wcs = WCS(headers[JADES_files[-1]] if JADES_files else headers[files[-1]])
        

In [None]:
# Test-split inputs stored as HDF5: per-image JSON metadata strings and the image array.

test_metadatafile = '/home/shared/hsc/JWST/processed_data/by_area/test_compresample_metadata_wcs.hdf5'
testfile = '/home/shared/hsc/JWST/processed_data/by_area/flattened_images_test_compresample.hdf5'

with h5py.File(test_metadatafile, 'r') as meta_h5, h5py.File(testfile, 'r') as image_h5:
    test_metadata = meta_h5['metadata_dicts'][:]
    test_images = image_h5['images'][:]


In [None]:

#determine truncated object IDs

def outside_box(box, shape):
    """Return True when an XYXY box pokes out of an image of the given shape.

    ``box`` is indexed as ``(x0, y0, x1, y1)`` and ``shape`` as ``(height, width)``.
    The box counts as outside if either min corner is negative or either max
    corner exceeds the image extent. Checks run in that order and stop at the
    first hit.
    """
    return bool(
        box[0] < 0
        or box[1] < 0
        or box[2] > shape[1]
        or box[3] > shape[0]
    )


# Collect (image index, annotation index, object ID) for every object with a
# spec-z whose bounding box spills past the edge of its test cutout.
truncated = []
for i, raw in enumerate(test_metadata):
    d = json.loads(raw)
    shape = (d['height'], d['width'])
    boxes = utils.annotations_to_instances(d['annotations'], shape).gt_boxes.tensor.cpu().numpy()
    for j, box in enumerate(boxes):
        # redshift == -1 marks annotations without a spec-z match (get_metadata's convention)
        if d['annotations'][j]['redshift'] != -1 and outside_box(box, shape):
            truncated.append((i, j, d['annotations'][j]['obj_id']))

# reshape keeps a (0, 3) array when nothing is truncated, so the column slice
# below yields an empty ID list instead of raising IndexError.
truncated = np.array(truncated, dtype=int).reshape(-1, 3)
truncated_ids = np.unique(truncated[:, 2])


In [None]:

# Catalogs (photometry, spec-z) plus the detection and segmentation mosaics.

fncat = '/home/shared/hsc/JWST/catalogs/hlsp_jades_jwst_nircam_goods-s-deep_photometry_v2.0_catalog.fits'
fnspecz = '/home/shared/hsc/JWST/catalogs/JADES_GOODS_zspec_cleaned.fits'
fndect = '/home/shared/hsc/JWST/images/hlsp_jades_jwst_nircam_goods-s-deep_detection_v2.0_drz.fits'
fnseg = '/home/shared/hsc/JWST/images/hlsp_jades_jwst_nircam_goods-s-deep_segmentation_v2.0_drz.fits'

# HDU 2 and HDU 3 of the photometry catalog; HDU 3 provides the ID / BBOX_* columns
# that get_metadata uses for bounding boxes.
dphot, dseg = (Table.read(fncat, hdu=hdu_index).to_pandas() for hdu_index in (2, 3))
dspecz = Table.read(fnspecz, hdu=1).to_pandas()

with fits.open(fnseg) as seg_hdul:
    seg_hdul.info()
    seg_hdu = seg_hdul[0]
    segim, segh = seg_hdu.data, seg_hdu.header

with fits.open(fndect) as dect_hdul:
    dect_hdul.info()
    dect_hdu = dect_hdul[0]
    dectim, dect = dect_hdu.data, dect_hdu.header

In [None]:

#Annotation function

def get_metadata(segcutout, dspecz, imid, bbox_table=None):
    """Build a detectron2-style dataset dict for one segmentation cutout.

    Parameters
    ----------
    segcutout : astropy.nddata.Cutout2D
        Cutout of the segmentation map; each nonzero pixel value is an object ID.
    dspecz : pandas.DataFrame
        Spec-z catalog with ``ID`` and ``z`` columns.
    imid : int
        Value stored as the dict's ``image_id``.
    bbox_table : pandas.DataFrame, optional
        Table with ``ID`` and ``BBOX_XMIN/BBOX_YMIN/BBOX_XMAX/BBOX_YMAX`` columns in
        full-mosaic pixel coordinates. Defaults to the module-level ``dseg``.

    Returns
    -------
    dict
        ``annotations``, with one entry per object. Each entry holds an XYWH bbox
        from the catalog, contour polygons from the mask, ``obj_id`` and
        ``redshift``; ``redshift`` is -1 when the object has no spec-z match. The
        dict also has ``height``, ``width`` and ``image_id``.

    Raises
    ------
    IndexError
        If a segmentation ID present in the cutout is missing from ``bbox_table``.
    """
    if bbox_table is None:
        bbox_table = dseg
    bbox_cols = ['BBOX_XMIN', 'BBOX_YMIN', 'BBOX_XMAX', 'BBOX_YMAX']

    segimcut = segcutout.data
    # Taken from the cutout itself so an all-background (or empty) cutout still works.
    height, width = segimcut.shape

    annos = []
    for s in np.unique(segimcut):
        if s == 0:  # label 0 is not an object (background)
            continue
        mask = (segimcut == s).astype(np.uint8)

        # Catalog bbox of the whole object, moved into cutout pixel coordinates. For
        # objects clipped by the cutout edge it can extend past the cutout bounds.
        # One table scan per object instead of one per bbox column.
        xmin, ymin, xmax, ymax = bbox_table.loc[bbox_table['ID'] == s, bbox_cols].to_numpy()[0]
        x0, y0 = segcutout.to_cutout_position((xmin, ymin))
        x1, y1 = segcutout.to_cutout_position((xmax, ymax))
        h = int(y1 - y0)
        w = int(x1 - x0)

        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # Polygons as flat [x1, y1, ..., xn, yn] lists; contours with < 3 points are dropped.
        segmentation = [c.flatten().tolist() for c in contours if c.size > 4]

        specmatch = np.where(dspecz.ID == s)[0]
        redshift = -1 if len(specmatch) == 0 else dspecz['z'].values[specmatch[0]]

        annos.append({
            "bbox": [int(x0), int(y0), w, h],
            "area": w * h,
            "bbox_mode": BoxMode.XYWH_ABS,
            "segmentation": segmentation,
            "category_id": 0,
            "obj_id": int(s),
            "redshift": redshift,
        })

    return {
        "annotations": annos,
        "height": height,
        "width": width,
        "image_id": imid,
    }


In [None]:

# Example: a 512x512 multi-band cutout centred on one truncated object, plus its
# annotations. The position comes from the object's spec-z catalog RA/Dec, and
# every cutout uses the shared mosaic WCS `wcs`.

j = 38  # index into truncated_ids; must be < len(truncated_ids)
obj_id = truncated_ids[j]
# Scalar sky position (first spec-z match) rather than a length-N coordinate array.
spec_row = dspecz.loc[dspecz.ID.values == obj_id].iloc[0]
coords = SkyCoord(ra=spec_row['RA'] * u.degree, dec=spec_row['DEC'] * u.degree)

image = np.array([Cutout2D(im, position=coords, size=(512,512), wcs=wcs).data for im in JADESims])

# NOTE(review): the segmentation map is cut with the science-mosaic WCS, which
# assumes segim shares that pixel grid — confirm against WCS(segh).
segcutout = Cutout2D(segim, position=coords, size=(512,512), wcs=wcs)

d = get_metadata(segcutout, dspecz, j)

In [None]:
from detectron2.data import MetadataCatalog, DatasetCatalog

# Metadata handle for the "astro_test" split; passed to the Visualizer below.
astrotest_metadata = MetadataCatalog.get(name="astro_test")

In [None]:

import matplotlib.pyplot as plt
from deepdisc.astrodet.visualizer import Visualizer
from deepdisc.astrodet.visualizer import ColorMode

from astropy.visualization import make_lupton_rgb

# RGB composite from the first three cutout bands (indices 2, 1, 0 -> R, G, B),
# overlaid with the ground-truth annotations in `d`.
fig, ax = plt.subplots(figsize=(14, 7))
rgb = make_lupton_rgb(image[2], image[1], image[0], minimum=0, stretch=0.1, Q=5)

print("total instances:", len(d["annotations"]))
visualizer = Visualizer(
    rgb,
    metadata=astrotest_metadata,
    scale=5,
    instance_mode=ColorMode.SEGMENTATION,  # remove the colors of unsegmented pixels. This option is only available for segmentation models
)
groundTruth = visualizer.draw_dataset_dict(d, lf=False, alpha=0.01)

ax.imshow(groundTruth.get_image())
ax.axis("off")
ax.set_title('Ground Truth');

In [None]:

# For each truncated object: a 512x512 multi-band cutout centred on it plus its
# annotations. Cutouts whose pixels are all zero (no coverage) are skipped.

dicts = []
images = []
for j, obj_id in enumerate(truncated_ids):
    print(j)

    # Scalar sky position (first spec-z match) rather than a length-N coordinate array.
    spec_row = dspecz.loc[dspecz.ID.values == obj_id].iloc[0]
    coords = SkyCoord(ra=spec_row['RA'] * u.degree, dec=spec_row['DEC'] * u.degree)

    cutouts = [Cutout2D(im, position=coords, size=(512,512), wcs=wcs) for im in JADESims]
    image = np.array([cutout.data for cutout in cutouts])
    # Skip empty cutouts before annotating them, so no annotation work is wasted.
    if np.all(image == 0):
        continue

    # NOTE(review): the segmentation map is cut with the science-mosaic WCS, which
    # assumes segim shares that pixel grid — confirm against WCS(segh).
    segcutout = Cutout2D(segim, position=coords, size=(512,512), wcs=wcs)
    d = get_metadata(segcutout, dspecz, j)

    ddict = {
        'wcs': cutouts[-1].wcs.to_header_string(),
        'height': image.shape[1],
        'width': image.shape[2],
        'annotations': d['annotations'],
        'image_id': int(obj_id),
    }
    dicts.append(ddict)
    images.append(image.flatten())


In [None]:

# Save the truncated-object test set in HDF5 format (input for RAIL photo-z).

import deepdisc.data_format.conversions as conversions

trunc_images_path = '/home/shared/hsc/JWST/processed_data/by_area/flattened_images_test_compresample_trunc.hdf5'
trunc_metadata_path = '/home/shared/hsc/JWST/processed_data/by_area/test_comp_resample_trunc_metadata_wcs.hdf5'

# One row per kept object (its flattened multi-band cutout), assuming equal cutout sizes.
trunc_images = np.array(images)
with h5py.File(trunc_images_path, "w") as out_h5:
    out_h5.create_dataset("images", data=trunc_images)

# Matching per-object metadata dicts, in the same order as the image rows.
conversions.ddict_to_hdf5(dicts, trunc_metadata_path)