In [1]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
#from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data import build_detection_train_loader
from detectron2.data import build_detection_test_loader

from detectron2.engine import DefaultTrainer
from detectron2.engine import SimpleTrainer
from detectron2.engine import HookBase
from typing import Dict, List, Optional
import detectron2.solver as solver
import detectron2.modeling as modeler
import detectron2.data as data
import detectron2.data.transforms as T
import detectron2.checkpoint as checkpointer
from detectron2.data import detection_utils as utils
import weakref
import copy
import torch
import time

import imgaug.augmenters as iaa

from astrodet import astrodet as toolkit
from PIL import Image, ImageEnhance
from astropy.visualization import make_lupton_rgb
from astrodet.detectron import plot_stretch_Q
from detectron2.utils.file_io import PathManager
from iopath.common.file_io import file_lock

import logging
logger = logging.getLogger(__name__)
import shutil
import json

In [2]:
# Sanity-check the imports and record the library versions this notebook ran with
print(detectron2.__version__, np.__version__, cv2.__version__, sep='\n')

0.6
1.20.3
4.5.3


In [3]:
# Prettify the plotting: apply the astrodet project's matplotlib style helper.
# NOTE(review): presumably this changes global matplotlib settings for every
# later figure in the notebook -- confirm against astrodet.astrodet.set_mpl_style.
from astrodet.astrodet import set_mpl_style
set_mpl_style()

In [4]:
# TODO: point these at your own data before running; they are absolute paths
# on the original author's machine.
# dirpath is the dataset root: each entry of dataset_names is joined onto it
# to locate that split's directory, and the JSON dump is written next to them.
dirpath = '/home/shared/hsc/decam/decam_data/' # Path to dataset
output_dir = '/home/shared/hsc/decam/models/'

# Splits to process. Only 'test' is enabled; swap in the commented line to
# process all three splits.
#dataset_names = ['train', 'test', 'val'] 
dataset_names = ['test'] 

In [5]:
from detectron2.structures import BoxMode
from astropy.io import fits
import glob

def get_astro_dicts(img_dir):
    """Build detectron2-style dataset records from the ``set_*`` directories in ``img_dir``.

    Every ``set_*`` directory is expected to contain a ``masks.fits`` file with
    one HDU per source. Each HDU is normalized by its own maximum, thresholded
    into a binary mask, smoothed, and traced into polygon contours.

    Args:
        img_dir: Directory that holds the ``set_*`` sub-directories of one
            split. May be absolute or relative.

    Returns:
        A list with one dict per set, with keys ``file_name``, ``image_id``,
        ``height``, ``width`` and ``annotations``. Each annotation holds
        ``bbox`` (XYWH, enclosing all kept contours), ``area`` (bounding-box
        area ``w*h``, not the mask area), ``bbox_mode``, ``segmentation``
        (flattened polygons) and ``category_id`` (the HDU's ``CLASS_ID`` minus 1).
        Sources whose mask yields no usable contour are skipped.
    """
    # Directories are visited in lexicographic order; image_id follows that order.
    set_dirs = sorted(glob.glob('%s/set_*' % img_dir))

    dataset_dicts = []

    # Loop through each set
    for idx, set_dir in enumerate(set_dirs):
        record = {}

        # glob already returns every match prefixed with img_dir, so the paths
        # are built from set_dir alone. Joining img_dir in front again only
        # worked for absolute paths and duplicated the prefix for relative ones.
        mask_path = os.path.join(set_dir, "masks.fits")
        filename = os.path.join(set_dir, "img")

        # Open each FITS image; everything needed is read eagerly so the file
        # can be closed before the masks are processed.
        with fits.open(mask_path, memmap=False, lazy_load_hdus=False) as hdul:
            sources = len(hdul)
            height, width = hdul[0].data.shape
            # Normalize each source image to its own peak value
            data = [hdu.data/np.max(hdu.data) for hdu in hdul]
            category_ids = [hdu.header["CLASS_ID"] for hdu in hdul]

        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width
        objs = []

        # Mask value thresholds per category_id: CLASS_ID 1 uses a much lower
        # cut than every other class.
        thresh = [0.005 if i == 1 else 0.08 for i in category_ids]

        # Generate segmentation masks
        for i in range(sources):
            image = data[i]
            mask = np.zeros([height, width], dtype=np.uint8)
            # Create mask from threshold
            mask[:,:][image > thresh[i]] = 1
            # Smooth mask
            mask[:,:] = cv2.GaussianBlur(mask[:,:], (9,9), 2)

            # https://github.com/facebookresearch/Detectron/issues/100
            # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
            # (contours, hierarchy); indexing from the end works with both.
            contours = cv2.findContours((mask).astype(np.uint8), cv2.RETR_TREE,
                                        cv2.CHAIN_APPROX_SIMPLE)[-2]

            # Keep only contours with more than two points (more than four
            # flattened coordinates).
            kept = [contour for contour in contours if contour.size > 4]
            # No valid contours
            if len(kept) == 0:
                continue
            segmentation = [contour.flatten().tolist() for contour in kept]

            # Bounding box enclosing every kept contour. Previously the box of
            # whichever contour happened to be iterated last was used, even if
            # that contour had been rejected by the filter above.
            x, y, w, h = cv2.boundingRect(np.concatenate(kept))

            # Add to dict
            obj = {
                "bbox": [x, y, w, h],
                "area": w*h,
                "bbox_mode": BoxMode.XYWH_ABS,
                "segmentation": segmentation,
                "category_id": category_ids[i] - 1,
            }
            objs.append(obj)

        record["annotations"] = objs
        dataset_dicts.append(record)

    return dataset_dicts


# Register every split in dataset_names with detectron2 under "astro_<split>"
# and attach the class names to its metadata.
for i, d in enumerate(dataset_names):
    filenames_dir = os.path.join(dirpath,d)
    # Bind the directory as a default argument. A plain closure looks
    # filenames_dir up when the loader is *called*, so with several splits
    # every registered loader would read the last split's directory.
    DatasetCatalog.register("astro_" + d, lambda path=filenames_dir: get_astro_dicts(path))
    # NOTE(review): detectron2's documented metadata key is `thing_colors`;
    # `things_colors` is kept as written in case project code reads it -- confirm.
    MetadataCatalog.get("astro_" + d).set(thing_classes=["star", "galaxy"], things_colors = ['blue', 'gray'])
# NOTE(review): "astro_train" only gets thing_classes set by the loop above
# when 'train' is listed in dataset_names -- confirm this is the intended split.
astro_metadata = MetadataCatalog.get("astro_train")

In [6]:
def convert_to_json(dict_list, name, output_file, allow_cached=True):
    """
    Write a list of dataset records to a JSON file.

    The records are serialized as-is with ``json.dump``; no format conversion
    is performed here, despite the "COCO format" wording of the log messages.
    The data is written to ``output_file + ".tmp"`` first and then moved into
    place, all while holding a file lock on ``output_file``.

    Args:
        dict_list: JSON-serializable object to dump (e.g. the list of records
            built for one dataset split).
        name: dataset name. Currently unused; kept so existing callers keep working.
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip writing
    """

    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data

    # A bare file name has an empty dirname, which mkdirs cannot create.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        PathManager.mkdirs(out_dir)
    with file_lock(output_file):
        if PathManager.exists(output_file) and allow_cached:
            logger.warning(
                f"Using previously cached COCO format annotations at '{output_file}'. "
                "You need to clear the cache file if your dataset has been modified."
            )
        else:

            print(f"Caching COCO format annotations at '{output_file}' ...")
            tmp_file = output_file + ".tmp"
            try:
                with PathManager.open(tmp_file, "w") as f:
                    json.dump(dict_list, f)
            except Exception:
                # Do not leave a truncated temp file behind if serialization fails.
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)
                raise
            shutil.move(tmp_file, output_file)

In [8]:
# Build the record list for every split up front and keep it in a dict keyed
# by split name, timing the whole pass with wall-clock time.
t0 = time.time()
dataset_dicts = {}
for i, d in enumerate(dataset_names):
    print(f'Loading {d}')
    # Calls the loader directly rather than going through DatasetCatalog
    dataset_dicts[d] = get_astro_dicts(os.path.join(dirpath, d))
    
print('Took ', time.time()-t0, 'seconds ')

Loading test
Took  23.645785093307495 seconds 


In [9]:
dataset_dicts['test'][0]

{'file_name': '/home/shared/hsc/decam/decam_data/test/set_0/img',
 'image_id': 0,
 'height': 512,
 'width': 512,
 'annotations': [{'bbox': [327, 503, 8, 9],
   'area': 72,
   'bbox_mode': <BoxMode.XYWH_ABS: 1>,
   'segmentation': [[328,
     503,
     328,
     504,
     327,
     505,
     327,
     508,
     328,
     509,
     328,
     511,
     333,
     511,
     333,
     509,
     334,
     508,
     334,
     504,
     333,
     503]],
   'category_id': 1},
  {'bbox': [110, 509, 7, 3],
   'area': 21,
   'bbox_mode': <BoxMode.XYWH_ABS: 1>,
   'segmentation': [[111, 509, 110, 510, 110, 511, 116, 511, 114, 509]],
   'category_id': 1},
  {'bbox': [366, 79, 12, 13],
   'area': 156,
   'bbox_mode': <BoxMode.XYWH_ABS: 1>,
   'segmentation': [[370,
     79,
     369,
     80,
     368,
     80,
     366,
     82,
     366,
     87,
     368,
     89,
     368,
     90,
     369,
     90,
     370,
     91,
     374,
     91,
     377,
     88,
     377,
     82,
     374,
     79]],
 

In [41]:
#! rm /home/shared/hsc/HSC/HSC_DR3/data/test.json
#! rm /home/shared/hsc/HSC/HSC_DR3/data/test.json.lock

In [10]:
dirpath

'/home/shared/hsc/decam/decam_data/'

In [11]:
# Earlier variant, left disabled: a different function name and signature.
#convert_to_coco_json('astro_train', os.path.join(dirpath,'train.json'), allow_cached=True)

# Dump each split's records to <dirpath>/<split>.json. allow_cached=False
# rewrites the file even if it already exists. Only the 'test' split is
# enabled, matching dataset_names above.
#convert_to_json(dataset_dicts['train'], 'astro_train', os.path.join(dirpath,'train.json'), allow_cached=False)
convert_to_json(dataset_dicts['test'], 'astro_test', os.path.join(dirpath,'test.json'), allow_cached=False)
#convert_to_json(dataset_dicts['val'], 'astro_val', os.path.join(dirpath,'val.json'), allow_cached=False)

Caching COCO format annotations at '/home/shared/hsc/decam/decam_data/test.json' ...


In [18]:
def get_data_from_json(file):
    """Read the JSON document stored at ``file`` and return the parsed object."""
    with open(file, 'r') as handle:
        parsed = json.loads(handle.read())
    return parsed

In [19]:
# NOTE(review): this is a different file from the <dirpath>/test.json written
# by convert_to_json above; it is presumably produced by a separate
# evaluation/conversion step that is not shown in this notebook -- confirm.
testfile='/home/shared/hsc/decam/models/int16/astro_test_coco_format.json'

# Parsed JSON content of that file
td = get_data_from_json(testfile)


In [20]:
len(td['images'])

250