In [9]:
# In case you need to point to pre-existing deepdisc / detectron2 installs
import sys
# change these paths to your specific directories where deepdisc and detectron2 are stored
sys.path.insert(0, '/home/yse2/deepdisc/src')
sys.path.insert(0, '/home/yse2/detectron2')
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=SettingWithCopyWarning)
# warnings.filterwarnings("ignore", category=DtypeWarning)
import deepdisc
import detectron2
print(deepdisc.__file__)
print(detectron2.__file__)
from detectron2.data import MetadataCatalog, DatasetCatalog

# Standard imports
import os, json
import numpy as np
import pandas as pd
import time
import math
import glob
import scarlet
import cv2
import argparse
# for multiprocessing
import multiprocessing
from functools import partial
import psutil

# astropy
import astropy.io.fits as fits
import astropy.units as u
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy.nddata import Cutout2D
from astropy.table import Table
from astropy.stats import gaussian_fwhm_to_sigma
from astropy.visualization import make_lupton_rgb

# Astrodet imports
from deepdisc.preprocessing.get_data import get_cutout
from deepdisc.astrodet.hsc import get_tract_patch_from_coord, get_hsc_data
from deepdisc.astrodet.visualizer import ColorMode
from deepdisc.astrodet.visualizer import Visualizer

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors

from galcheat.utilities import mag2counts, mean_sky_level
from btk.survey import Filter, Survey, make_wcs
import galsim
import btk

def e1e2_to_ephi(e1, e2):
    """Return the position angle (radians) of an ellipse from its ellipticity components.

    Ellipticity is a spin-2 quantity, ``e1 = |e| cos(2 phi)`` and
    ``e2 = |e| sin(2 phi)``, so ``phi = 0.5 * arctan2(e2, e1)``.  Using
    ``arctan2`` keeps the correct quadrant and is well defined when ``e1 == 0``
    (a plain ``arctan(e2 / e1)`` divides by zero there and also returns 2*phi).

    Parameters
    ----------
    e1, e2 : float, numpy array or pandas Series
        Ellipticity components; broadcast against each other.

    Returns
    -------
    Position angle in radians, in [-pi/2, pi/2], with the broadcast shape/type
    of the inputs.
    """
    return 0.5 * np.arctan2(e2, e1)

def get_wcs_dict_path(roman_file):
    """Map a Roman cutout ``.npy`` path to the matching LSST WCS metadata JSON path.

    Example::

        .../dc2_51.37_-38.3/full_c108_51.37_-38.3_centered_roman_9.npy
        -> ./trunc-lsst/metadata/dc2_51.37_-38.3/full_c108_lsst_9.json

    The parent directory name is taken as the subpatch, the second
    ``_``-separated token of the file name as the cutout id, and the last token
    (with ``.npy`` removed) as the object id.
    """
    pieces = roman_file.split('/')
    subpatch_dir, npy_name = pieces[-2], pieces[-1]

    tokens = npy_name.split('_')
    cutout_id = tokens[1]
    object_id = tokens[-1].replace('.npy', '')

    return f"./trunc-lsst/metadata/{subpatch_dir}/full_{cutout_id}_lsst_{object_id}.json"

def dcut_reformat(obj_params):
    """Turn one object's truth-parameter dict into a one-row catalog DataFrame.

    Adds, per band in ``ugrizy``:
      * ``{band}_ab`` -- copy of ``mag_true_{band}``
      * ``fluxnorm_bulge_{band}`` / ``fluxnorm_disk_{band}`` -- the scaled total
        flux split by ``bulge_to_total_ratio_{band}``
      * ``fluxnorm_agn_{band}`` -- zeros
    plus size columns ``a_b``, ``b_b``, ``a_d``, ``b_d``, position angles
    ``pa_bulge``, ``pa_disk``, ``pa_tot`` (degrees, via ``e1e2_to_ephi``) and
    shears ``g1``, ``g2``.
    """
    flux_scale = 3.0128e28
    cat = pd.DataFrame([obj_params])  # single-row frame

    for band in 'ugrizy':
        mag = cat[f'mag_true_{band}']
        cat[f'{band}_ab'] = mag
        total = flux_scale * 10**(-0.4*mag)
        bt_ratio = cat[f'bulge_to_total_ratio_{band}']

        cat[f'fluxnorm_bulge_{band}'] = total * bt_ratio
        cat[f'fluxnorm_disk_{band}'] = total * (1-bt_ratio)
        cat[f'fluxnorm_agn_{band}'] = np.zeros(total.shape)

    size_columns = (
        ('a_b', 'size_bulge_true'),
        ('b_b', 'size_minor_bulge_true'),
        ('a_d', 'size_disk_true'),
        ('b_d', 'size_minor_disk_true'),
    )
    for target, source in size_columns:
        cat[target] = cat[source]

    def angle_in_degrees(suffix):
        # position angle from the ellipticity_{1,2}_<suffix> pair, radians -> degrees
        first = cat[f'ellipticity_1_{suffix}']
        second = cat[f'ellipticity_2_{suffix}']
        return e1e2_to_ephi(first, second) * 180.0/np.pi

    cat['pa_bulge'] = angle_in_degrees('bulge_true')
    cat['pa_disk'] = angle_in_degrees('disk_true')
    cat['pa_tot'] = angle_in_degrees('true')

    cat['g1'], cat['g2'] = cat['shear_1'], cat['shear_2']

    return cat

# Fixed seed: the numpy RandomState below seeds the galsim deviate `grng`, which
# make_im passes as `rng` when photon-shooting stars. Keep this order of
# construction/draws so results are reproducible for a given call sequence.
seed = 8312
rng = np.random.RandomState(seed)
# galsim deviate seeded with one integer drawn from the numpy RNG.
grng = galsim.BaseDeviate(rng.randint(0, 2**30))

def get_star_gsparams(mag, flux, noise):
    """
    Choose galsim GSParams for a star from its brightness.

    Parameters
    ----------
    mag: float
        mag of star
    flux: float
        flux of star
    noise: float
        noise of image

    Returns
    --------
    (gsparams, isbright). For mag >= 18 this is (None, False). Brighter stars get
    a quantized folding_threshold (capped at 0.005); stars with mag < 15 also
    get tighter kvalue_accuracy / maxk_threshold. isbright is True for mag < 18.
    """
    wants_threshold = mag < 18
    wants_accuracy = mag < 15

    if not (wants_threshold or wants_accuracy):
        return None, False

    settings = {}
    if wants_threshold:
        # Snap noise/flux down to a power of e so only a few distinct
        # folding_threshold values reach galsim: fewer entries in its C++ cache,
        # avoiding the per-object overhead continuous values would cause.
        ratio = noise/flux
        snapped = np.exp(np.floor(np.log(ratio)))
        settings['folding_threshold'] = min(snapped, 0.005)

    if wants_accuracy:
        settings.update(kvalue_accuracy=1.0e-8, maxk_threshold=1.0e-5)

    return galsim.GSParams(**settings), True

def make_star(entry, survey, filt):
    """Build a point-like galsim profile for a star in band ``filt``.

    The star's ``mag_<band>`` value (first row of ``entry``) is converted to
    electrons; the survey's mean sky level sets the noise used to pick GSParams.

    Returns
    -------
    (star_profile, gsparams, flux_in_electrons)
    """
    star_mag = entry[f'mag_{filt.name}'].iloc[0]
    electrons = mag2counts(star_mag, survey, filt).to_value("electron")
    sky = mean_sky_level(survey, filt).to_value('electron')
    gsparams, _is_bright = get_star_gsparams(star_mag, electrons, sky)

    # A Gaussian with a tiny FWHM (1e-4) stands in for a point source.
    profile = galsim.Gaussian(fwhm=1.0e-4, flux=electrons, gsparams=gsparams)
    return profile, gsparams, electrons

def _reduced_shear(a, b, pa_deg):
    """Return (g1, g2) for an ellipse with axes ``a``, ``b`` at position angle ``pa_deg`` (degrees).

    |g| = (a - b) / (a + b) = (1 - q) / (1 + q) with q = b / a, which is what
    galsim calls the *reduced shear* (its ``e1``/``e2`` are the distortion
    (1 - q**2) / (1 + q**2) instead).
    """
    q = b / a
    pa = np.pi * pa_deg / 180
    g = (1 - q) / (1 + q)
    return g * np.cos(2 * pa), g * np.sin(2 * pa)


def make_galaxy(entry, survey, filt, no_disk=False, no_bulge=False, no_agn=True):
    """Build the galsim profile of one galaxy in one band.

    Parameters
    ----------
    entry : pandas.DataFrame
        One-row catalog (as produced by ``dcut_reformat``). Reads ``<band>_ab``,
        ``fluxnorm_{disk,bulge,agn}_<band>``, ``a_d``, ``b_d``, ``a_b``, ``b_b``
        and ``position_angle_true_dc2`` (degrees).
    survey, filt
        btk survey and filter; used to convert the AB magnitude to electrons.
    no_disk, no_bulge, no_agn : bool
        Exclude the corresponding component (AGN excluded by default).

    Returns
    -------
    galsim profile: the sum of the components with positive flux.

    Raises
    ------
    ValueError
        If every component has zero flux.
    """
    band = filt.name

    def first(column):
        return entry[column].iloc[0]

    total_flux = mag2counts(first(band + "_ab"), survey, filt).to_value("electron")

    total_fluxnorm = (entry["fluxnorm_disk_" + band] +
                      entry["fluxnorm_bulge_" + band] +
                      entry["fluxnorm_agn_" + band]).iloc[0]

    def component_flux(name, excluded):
        # Share of the total electron count carried by one component.
        if excluded:
            return 0.0
        return first(f"fluxnorm_{name}_{band}") / total_fluxnorm * total_flux

    disk_flux = component_flux("disk", no_disk)
    bulge_flux = component_flux("bulge", no_bulge)
    agn_flux = component_flux("agn", no_agn)

    if disk_flux + bulge_flux + agn_flux == 0:
        raise ValueError("No visible components")

    components = []

    if disk_flux > 0:
        a_d, b_d = first("a_d"), first("b_d")
        g1, g2 = _reduced_shear(a_d, b_d, first('position_angle_true_dc2'))
        # NOTE(review): the disk uses the major axis a_d as half-light radius while
        # the bulge uses sqrt(a*b) -- confirm which convention is intended.
        disk = galsim.Exponential(flux=disk_flux, half_light_radius=a_d)
        # First component sign is flipped, as in the original convention
        # (presumably to match the DC2 orientation -- confirm).
        components.append(disk.shear(g1=-g1, g2=g2))

    if bulge_flux > 0:
        a_b, b_b = first("a_b"), first("b_b")
        g1, g2 = _reduced_shear(a_b, b_b, first('position_angle_true_dc2'))
        bulge = galsim.DeVaucouleurs(flux=bulge_flux, half_light_radius=np.sqrt(a_b * b_b))
        components.append(bulge.shear(g1=-g1, g2=g2))

    if agn_flux > 0:
        components.append(galsim.Gaussian(flux=agn_flux, sigma=1e-8))

    return galsim.Add(components)

def make_im(entry, survey, filt, nx=128, ny=128):
    """Render one object in band ``filt`` on an ``nx`` x ``ny`` stamp.

    ``truth_type == 1`` is rendered as a galaxy (sheared by ``g1``/``g2``,
    PSF-convolved, default drawing method); anything else is rendered as a star
    by photon shooting with the module-level ``grng`` deviate.

    Returns
    -------
    galsim Image
    """
    band = survey.get_filter(filt)
    psf = band.psf
    pixel_scale = survey.pixel_scale.to_value("arcsec")

    if entry['truth_type'].iloc[0] != 1:  # not a galaxy -> treat as a star
        star, _gsparams, flux = make_star(entry, survey, band)
        photon_cap = 10_000_000
        # n_photons=0 tells galsim to use the flux as the photon count
        photons = 0 if flux < photon_cap else photon_cap
        return galsim.Convolve(star, psf).drawImage(
            nx=nx,
            ny=ny,
            scale=pixel_scale,
            method="phot",
            n_photons=photons,
            poisson_flux=True,
            maxN=1_000_000,  # photons are shot in batches of this size
            rng=grng,
        )

    galaxy = make_galaxy(entry, survey, band).shear(
        g1=entry["g1"].iloc[0], g2=entry["g2"].iloc[0]
    )
    return galsim.Convolve(galaxy, psf).drawImage(nx=nx, ny=ny, scale=pixel_scale)

def get_bbox(mask):
    """Return ``(row_min, row_max, col_min, col_max)`` around the nonzero pixels of ``mask``.

    Each side is padded by 4 pixels and NOT clipped to the array, so values may
    be negative or exceed the mask shape. Raises IndexError for an all-zero mask.
    """
    pad = 4
    filled_rows = np.flatnonzero(np.any(mask, axis=1))
    filled_cols = np.flatnonzero(np.any(mask, axis=0))
    return (
        filled_rows[0] - pad,
        filled_rows[-1] + pad,
        filled_cols[0] - pad,
        filled_cols[-1] + pad,
    )

def create_single_object_metadata(obj_info, survey, filters, wcs_dict, lvl=2):
    """Build a detectron2-style record with one annotation for a centred object.

    The object is rendered in every band of ``filters`` (``make_im``), the BTK
    segmentations of those stamps are OR-ed into one mask, and the mask's
    padded bounding box and contours are shifted onto the object's pixel
    position in the cutout WCS.

    Parameters
    ----------
    obj_info : dict
        Reads ``orig_image_id``, ``subpatch``, ``obj_params``, ``ra``/``dec``
        (degrees), ``obj_id`` and ``truth_type``.
    survey
        btk survey used for rendering and sky levels.
    filters : iterable of str
        Band names.
    wcs_dict : dict
        Reads ``filename`` and ``wcs`` (passed to astropy ``WCS``).
    lvl : int
        Unused; kept for backward compatibility.

    Returns
    -------
    dict or None
        The record, or None (after printing a message) when the rendered mask
        is empty or yields no usable contour.
    """
    ddict = {
        "file_name": "./" + wcs_dict["filename"],
        "image_id": obj_info["orig_image_id"],
        "height": 142,  # Assuming fixed size for centered cutouts
        "width": 142,
        "subpatch": obj_info["subpatch"],
        "wcs": wcs_dict['wcs']
    }

    cat = dcut_reformat(obj_info["obj_params"])
    new_wcs = WCS(wcs_dict["wcs"])

    ra = obj_info["ra"]
    dec = obj_info["dec"]
    new_x, new_y = new_wcs.world_to_pixel(SkyCoord(ra=ra*u.deg, dec=dec*u.deg))

    # int() truncates toward zero; the object position is kept on integer pixels.
    x = int(new_x)
    y = int(new_y)

    segs = []
    for filt in filters:
        im = make_im(cat, survey, filt)
        # prepend two singleton axes before handing the stamp to btk
        imd = np.expand_dims(np.expand_dims(im.array, 0), 0)
        sky_level = mean_sky_level(survey, filt).to_value('electron')  # gain = 1
        segs.append(btk.metrics.utils.get_segmentation(imd, sky_level, sigma_noise=2))

    # Union of the per-band segmentations, as a 0/1 2-D mask.
    mask = np.clip(np.sum(segs, axis=0), a_min=0, a_max=1)[0][0]
    if np.sum(mask) == 0:
        print(f"Mask summed to zero. Object skipped! {ddict['file_name']}")
        return None

    # get_bbox returns (row_min, row_max, col_min, col_max) in stamp pixels.
    y0, y1, x0, x1 = get_bbox(mask)
    w = x1 - x0
    h = y1 - y0

    # Box of the stamp-derived size centred on the object's cutout position,
    # as [x, y, w, h] (bbox_mode 1).
    bbox = [x-w/2, y-h/2, w, h]

    contours, _ = cv2.findContours(
        mask.astype(np.uint8),
        cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE
    )

    # Offset that moves stamp coordinates onto the cutout, consistent with bbox.
    dx = x - x0 - w // 2
    dy = y - y0 - h // 2

    segmentation = []
    for contour in contours:
        flat = contour.flatten()  # interleaved x, y pairs
        if len(flat) > 4:  # at least three points
            flat[::2] += dx
            flat[1::2] += dy
            segmentation.append(flat.tolist())

    if len(segmentation) == 0:
        print(f"No segm mask! Obj skipped! {ddict['file_name']}")
        return None

    obj = {
        "obj_id": obj_info["obj_id"],
        "bbox": bbox,
        "area": w * h,
        "bbox_mode": 1,
        "segmentation": segmentation,
        "category_id": 1 if obj_info["truth_type"] == 2 else 0,
        # Store scalars rather than the one-row Series held in `cat`, so the
        # record serializes to plain numbers.
        "redshift": float(cat['redshift'].iloc[0]),
        "mag_i": float(cat["mag_i"].iloc[0]),
    }

    ddict["annotations"] = [obj]

    return ddict

def process_centered_objects(centered_objects_file, *, output_file='test-centered.json'):
    """Render and annotate every object listed in ``centered_objects_file``.

    The input is a JSON list of object records. For each record the WCS metadata
    JSON is located with ``get_wcs_dict_path(obj_info["roman_file"])`` and a
    record is built with ``create_single_object_metadata``. Objects for which no
    record could be built (the builder returned None) are skipped; appending
    None would make the DataFrame construction fail.

    Parameters
    ----------
    centered_objects_file : str
        Path to the JSON list of centred objects.
    output_file : str, keyword-only
        Where the records are written, as a JSON array (``orient='records'``).
    """
    with open(centered_objects_file, 'r', encoding='utf-8') as f:
        centered_objects = json.load(f)

    survey = btk.survey.get_surveys("LSST")
    filters = ['u', 'g', 'r', 'i', 'z', 'y']

    all_metadata = []
    for obj_info in centered_objects:
        wcs_dict_file = get_wcs_dict_path(obj_info["roman_file"])
        with open(wcs_dict_file, 'r', encoding='utf-8') as f:
            wcs_dict = json.load(f)
        metadata = create_single_object_metadata(
            obj_info,
            survey,
            filters,
            wcs_dict
        )
        if metadata is None:  # object skipped (empty mask / no contour)
            continue
        all_metadata.append(metadata)

    df = pd.DataFrame(all_metadata)
    df.to_json(output_file, orient='records')

/home/yse2/deepdisc/src/deepdisc/__init__.py
/home/yse2/detectron2/detectron2/__init__.py


In [11]:
centered_objects_file = "cutout_processing_info_1.json"  # Your truncated objects file
# The WCS metadata file is resolved per object inside process_centered_objects
# (via get_wcs_dict_path), which takes a single argument; passing a WCS path
# here raised TypeError.
process_centered_objects(centered_objects_file)

In [9]:
centered_objects_file = "cutout_processing_info_9.json"  # Your truncated objects file
# The WCS metadata file is resolved per object inside process_centered_objects
# (via get_wcs_dict_path), which takes a single argument; passing a WCS path
# here raised TypeError.
process_centered_objects(centered_objects_file)

In [10]:
# Annotate every object listed in the combined processing-info file.
centered_objects_file = "cutout_processing_info.json"
process_centered_objects(centered_objects_file=centered_objects_file)