# Training Data




## This notebook assumes you have the DC2 data downloaded  

You will have to change the directory paths to match where the data lives on your system.

In [None]:
# In case you need to point to pre-existing scarlet install
import sys
import deepdisc

In [None]:
# Standard imports
import sys, os
import numpy as np
import time
import glob

import scarlet
import sep

import astropy.io.fits as fits
from astropy.wcs import WCS
from astropy.stats import gaussian_fwhm_to_sigma
from astropy.coordinates import SkyCoord

from scarlet.display import AsinhMapping
from astropy.nddata import Cutout2D

# Astrodet imports
import deepdisc.preprocessing.detection as detection
import deepdisc.preprocessing.process as process

from deepdisc.astrodet.hsc import get_tract_patch_from_coord, get_hsc_data

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib
import matplotlib.pyplot as plt

# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='gray', interpolation='none', origin='lower')
from skimage.util.shape import view_as_blocks
from matplotlib import colors

import pandas as pd

In [None]:
# Report the library versions: this checks the imports and records a known-good environment
print(scarlet.__version__, np.__version__, sep.__version__, sep='\n')

### Run Scarlet to produce segmentation maps

First, let's test scarlet using one DC2 image. The DC2 image data is divided into "tracts" and "patches" on the sky. You can get the data here https://data.lsstdesc.org/.

You will need to change the directory paths below

In [None]:
# get_cutout is called below to extract one cutout (and its image data) from a tract/patch
from deepdisc.preprocessing.get_data import get_cutout
# Band labels handed to get_cutout and to scarlet in the cells below
filters = ['u','g','r','i','z','y']
# Root directory of the downloaded DC2 coadd results -- change this to your local path
dirpath = '/home/g4merz/DC2/coadd-t3828-t3829/deepCoadd-results/'
nb=16 #The number of cutouts per side of an image.  4k CCDs are too large to train with, so we reduce the size
sp=18 #The "subpatch", i.e. which of the nb x nb cutouts to use

### Using an input catalog

The cells below assume you have an input catalog `all_tracts_cat.csv` corresponding to the tracts and patches you've downloaded. The code can run without one, but a catalog is required for truth-matching any quantities.

In [None]:
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.stats import gaussian_fwhm_to_sigma

# Load the full input catalog once and build a single SkyCoord holding every object's position.
# Both names are read as module-level globals by get_cutout_cat below.
dall = pd.read_csv('/home/g4merz/DC2/nersc_data/data/all_tracts_cat.csv')
ra_all, dec_all = (dall[column].values for column in ('ra', 'dec'))
allcatalog = SkyCoord(ra=ra_all * u.degree, dec=dec_all * u.degree)

In [None]:
def get_cutout_cat(dirpath,tract,patch,sp,nblocks=4,filters=['u','g','r','i','z','y']):
    '''
    Load one cutout and select the catalog rows that fall inside it.

    WARNING!!!!!
    It is not efficient to have the full catalog (defined here as dall) as input to a
    function when doing multiprocessing.  Keep it in the top level process.  For that
    reason this function reads the module-level ``dall`` (catalog DataFrame) and
    ``allcatalog`` (SkyCoord built from it) instead of taking them as arguments.

    Parameters
    ----------
    dirpath : str
        Root directory of the coadd data, forwarded to ``get_cutout``.
    tract : str
        Tract identifier, e.g. '3828'.
    patch : str
        Patch identifier, e.g. '1,1'.
    sp : int
        Index of the sub-patch (cutout) to load.
    nblocks : int
        Forwarded to ``get_cutout``; presumably the number of cutouts per side of
        the patch -- confirm against ``get_cutout``.
    filters : list of str
        Band labels forwarded to ``get_cutout``.

    Returns
    -------
    datsm
        The second value returned by ``get_cutout`` (the image data for the cutout).
    dcut : pandas.DataFrame
        Copy of the rows of ``dall`` whose positions land inside the cutout, with
        added ``new_x`` / ``new_y`` pixel coordinates, ``objectId`` moved to the
        first column, and rows sorted by ``objectId``.
    '''
    cutout, datsm = get_cutout(dirpath, tract=tract, patch=patch, sp=sp, nblocks=nblocks, filters=filters, plot=False)

    # Project every catalog position into this cutout's pixel frame
    xs, ys = cutout.wcs.world_to_pixel(allcatalog)
    inside = (xs >= 0) & (xs < cutout.shape[1] - 1) & (ys >= 0) & (ys < cutout.shape[0] - 1)
    inds = np.where(inside)[0]

    # Work on an explicit copy: assigning columns to a slice of ``dall`` triggers
    # SettingWithCopyWarning and leaves it ambiguous whether ``dall`` is modified.
    dcut = dall.iloc[inds].copy()
    dcut['new_x'] = xs[inds]
    dcut['new_y'] = ys[inds]

    # Move objectId to the front: pop the column, then re-insert it at position 0
    column_to_move = dcut.pop("objectId")
    dcut.insert(0, "objectId", column_to_move)

    # sort_values returns a new frame; the result must be kept for the sort to take effect
    dcut = dcut.sort_values(by='objectId')

    return datsm, dcut

In [None]:
# Load sub-patch 18 of tract 3828, patch 1,1 (nblocks=8) together with its matched catalog rows
datsm, dcut = get_cutout_cat(dirpath, tract='3828', patch='1,1', sp=18, nblocks=8)

In [None]:
# Show entry 3 of the cutout data (presumably the i band, given the filter order) on a log
# stretch, with the catalog positions that fall inside the cutout drawn on top
fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(datsm[3], norm=colors.LogNorm(), origin='lower')
ax.scatter(dcut['new_x'].values, dcut['new_y'].values, marker='.')
ax.axis('off')

#### The function below will run scarlet on the input images, and output segmentation maps as well as image fits files for each cutout

In [None]:
filters = ['u','g','r','i','z','y']


def generate_training_data_example(dirpath, tract, patch, sp, outdir, plot_image=False, plot_stretch_Q=False, plot_scene=True,
                                   plot_likelihood=False, write_results=True, nblocks=None,
                                   psf_dir='/home/g4merz/DC2/nersc_data/data/psfs/'):
    """
    Run scarlet on a single DC2 cutout and optionally write the results to FITS files.

    Reads the module-level ``filters`` list, and (when ``nblocks`` is not given) the
    module-level ``nb`` grid size.

    Parameters
    ----------
    dirpath : str
          Root directory of the coadd data, forwarded to ``get_cutout``
    tract : str
          Tract identifier, e.g. '3828'
    patch : str
          Patch identifier, e.g. '1,1'
    sp : int
          Index of the sub-patch (cutout) to process
    outdir : str
          Directory the scarlet results are written to
    plot_image : bool
          Currently unused; kept so existing calls keep working
    plot_stretch_Q : bool
          Currently unused; kept so existing calls keep working
    plot_scene : bool
           Whether or not plot scene with scarlet
    plot_likelihood : bool
           Whether or not plot the log likelihood of the scarlet fitting
    write_results : bool
          Whether or not to write results to FITS file
    nblocks : int or None
          Grid size forwarded to ``get_cutout``.  ``None`` (the default) uses the
          notebook-level ``nb``, which is what this function always used before
    psf_dir : str
          Directory holding the ``{tract}_{patch}_0_psfs.npy`` PSF files.  The default
          is the original hard-coded location; change it to match your system

    Returns
    -------
    None.  The scarlet results are written to FITS files in ``outdir`` when
    ``write_results`` is True.

    """
    print(tract,patch,sp)
    print()

    # Fall back to the notebook-level grid size so existing calls behave exactly as before
    if nblocks is None:
        nblocks = nb

    # A truth-matched variant would call get_cutout_cat here instead; only the image data is needed
    _cutout, datas = get_cutout(dirpath, tract=tract, patch=patch, sp=sp, nblocks=nblocks, filters=filters, plot=False)

    ### Run scarlet on image ###

    # Pixel scale in arcsec/pixel.
    # NOTE(review): the original comment called this the HSC pixel scale, but this notebook
    # processes DC2 images -- confirm 0.2 is the intended value.
    ps = 0.2
    # Observed PSF sigma in pixels, from an assumed 0.8 arcsec FWHM (the original comment
    # attributes this to the HSC DR2 UD field) -- TODO confirm it is appropriate for DC2
    sigma_obs = gaussian_fwhm_to_sigma*0.8/ps

    # NOTE(review): the "_0_" in the file name is fixed and does not depend on ``sp`` --
    # presumably one PSF file is shared by every sub-patch of a patch; verify.
    psf = np.load(os.path.join(psf_dir, f'{tract}_{patch}_0_psfs.npy'))

    # Run Scarlet
    out = detection.run_scarlet(datas, filters, catalog=None, lvl=2, sigma_model=1, sigma_obs=sigma_obs, psf=psf, plot_scene=plot_scene,
                         max_chi2=1000000, morph_thresh=1, stretch=1, Q=5,
                         plot_wavelet=False, plot_likelihood=plot_likelihood, plot_sources=False, add_ellipses=False,
                         add_labels=False, add_boxes=False, lvl_segmask=1, maskthresh=0.005)

    # Unpack output
    observation, starlet_sources, model_frame, catalog, catalog_deblended, segmentation_masks = out

    # Save Scarlet data to FITS file
    if write_results:
        filenames = process.write_scarlet_results(datas, observation, starlet_sources, model_frame,
                                             catalog_deblended, segmentation_masks, outdir=outdir,
                                             filters=filters, s=f'{tract}_{patch}_{sp}', source_catalog=catalog)

        print(f'\nSaved scarlet results as {filenames} \n')
    
    
        

#### Let's run on one cutout

In [None]:
%%time
# Directory that receives the scarlet FITS output
outdir = './'

# Process one cutout (sub-patch 17 of tract 3828, patch 1,1), plotting the scene and the likelihood
generate_training_data_example(
    dirpath=dirpath,
    tract='3828',
    patch='1,1',
    sp=17,
    outdir=outdir,
    plot_scene=True,
    plot_likelihood=True,
    write_results=True,
)


#### Now let's run in parallel to speed things up a bit

In [None]:
import multiprocessing
from itertools import repeat

# Number of worker processes in the pool
processes = 4

tract='3828'
sps = (13,17,18,19)
patch = '1,1'
outdir = './'

# One argument tuple per sub-patch: only `sp` varies, everything else is repeated
args = zip(repeat(dirpath),repeat(tract), repeat(patch), sps, repeat(outdir))

t0 = time.time()
with multiprocessing.Pool(processes=processes) as pool:
    results = pool.starmap(generate_training_data_example, args)

# Report the wall-clock time so the speed-up over the single-cutout run is visible
print(f'Processed {len(sps)} cutouts in {time.time() - t0:.1f} s')

#### Now we can utilize some preprocessing functions

In [None]:
# Import from the installed `deepdisc` package, like every other deepdisc import in this
# notebook; the `src.` prefix only resolves when the notebook is run from the repository root
from deepdisc.data_format.file_io import DDLoader
from deepdisc.data_format.annotation_functions.annotate_dc2 import annotate_dc2

Here we create a `DDLoader` instance, which gathers the output files and formats them

In [None]:
# Collect the scarlet image and segmentation-mask files written to the current directory
loader = DDLoader().generate_filedict(
    './',
    ['U', 'G', 'R', 'I', 'Z','Y'],
    '*_scarlet_img.fits',
    '*_scarlet_segmask.fits',
)
filedict = loader.filedict
# Per-filter lists of image files, transposed so the filter axis comes last
img_files = np.array([filedict[band]["img"] for band in filedict["filters"]]).T


Here we randomly split the datasets into "train" and "test" directories, with 2 cutouts each.  These new directories are created inside `splitdirs`

In [None]:
# Parent directory in which the split directories are created -- change this to your local path
splitdirs = '/home/g4merz/deepdisc/tests/'
# nfiles=[2,2]: presumably the number of cutouts placed in each split directory -- confirm against DDLoader
loader.random_sample(splitdirs,nfiles=[2,2])

Generate a new filedict for the new train directory

In [None]:
# Derive the train directory from `splitdirs` so the split location only has to be changed in one place
train_dir = os.path.join(splitdirs, 'train/')
loader = DDLoader().generate_filedict(train_dir, ['U', 'G', 'R', 'I', 'Z','Y'], '*_scarlet_img.fits', '*_scarlet_segmask.fits')
filedict = loader.filedict
# Per-filter lists of image files, transposed so the filter axis comes last
img_files = np.transpose([filedict[filt]["img"] for filt in filedict["filters"]])

Use a preprocessing function to turn the single-band fits images into a multi-band numpy array


In [None]:
from deepdisc.preprocessing.process import fits_to_numpy #, fits_to_hdf5

In [None]:
# Convert the single-band FITS images of the train split; the directory is derived from
# `splitdirs` so it follows any change made to the split location above
fits_to_numpy(img_files, os.path.join(splitdirs, 'train/'))

Additional functionality not covered in this notebook

In [None]:
# The code below is deliberately wrapped in a string so that it is NOT executed.
# To use it, fits_to_hdf5 must also be imported (it is commented out in the import cell above).
'''
#This is used to create a dataset_dict in the deepdisc format.  It is necessary for training.  
#However, the code assumes you have used an input catalog to get ground truth redshifts and object classes, which we have skipped here

d='train'
dataset_dicts={}
dataset_dicts[d] = loader.generate_dataset_dict(annotate_dc2).get_dataset()  


#This is used to flatten images for RAIL

fits_to_hdf5(img_files,'/home/g4merz/deepdisc/tests/train/',dset='train')


'''