### HOWTO Guide for HSTCosmicrays (Beta version)

1. Create `labeler.CosmicRayLabel()`  and `metadata.GenerateMetadata()` objects to store all the necessary information in convenient containers
2. Run the cosmic ray labeling algorithm on the input image two separate ways
    - By using the DQ array information after running a cosmic ray rejection routine
    - By defining a custom threshold value to use when generating a binary image
3. Using the generated label, compute statistics for each one of the cosmic rays identified 
4. Load in the pre-trained ML model for distinguishing cosmic rays from point sources and use it to classify the identified sources
5. Repeat the analysis on a more interesting dataset containing an actual astrophysical source

In [None]:
%matplotlib notebook
import os
import glob
import pickle
import sys
sys.path.append('/Users/nmiles/hst_cosmic_rays/')
import warnings
warnings.simplefilter('ignore')

# Import packages from the hst_cosmic_rays package
from pipeline.label import labeler
from pipeline.stat_utils import statshandler
from pipeline.utils import metadata, initialize



from astropy.io import fits
from astropy.visualization import ImageNormalize, ZScaleInterval,\
                                    SqrtStretch, LinearStretch, LogStretch
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np


Define a convenience dictionary that maps a short name to the image stretch used when displaying FITS files

In [None]:
stretchdict_ = {'sqrt': SqrtStretch(),
                'log': LogStretch(),
                'linear': LinearStretch()}
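
To make the role of the stretch concrete, here is a minimal NumPy sketch of what a normalization with `vmin`/`vmax` followed by a stretch does conceptually: the data are scaled to the range 0 to 1, clipped, and then passed through the stretch function. The `normalize` helper is hypothetical and only covers the linear and square-root cases; it is not the `astropy` implementation.

```python
import numpy as np

def normalize(data, vmin, vmax, stretch='linear'):
    """Scale data to [0, 1] using vmin/vmax, clip, then apply a stretch.

    Hypothetical helper for illustration only.
    """
    scaled = (np.asarray(data, dtype=float) - vmin) / (vmax - vmin)
    scaled = np.clip(scaled, 0.0, 1.0)
    if stretch == 'linear':
        return scaled
    if stretch == 'sqrt':
        # The square root brightens faint pixels relative to bright ones
        return np.sqrt(scaled)
    raise ValueError(f"Unsupported stretch: {stretch!r}")

pixels = np.array([-5.0, 0.0, 5.0, 20.0, 100.0])
linear_out = normalize(pixels, vmin=0, vmax=20, stretch='linear')
sqrt_out = normalize(pixels, vmin=0, vmax=20, stretch='sqrt')
print(linear_out)
print(sqrt_out)
```

Pixels below `vmin` or above `vmax` saturate at 0 and 1, which is why the choice of limits matters as much as the stretch itself.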


Setup a path to a directory containing the data you want to analyze

In [None]:
data_path = '/Users/nmiles/hst_cosmic_rays/data/STIS/SAA_data'
flist = glob.glob(os.path.join(data_path, 'o*flt.fits'))
print('\n'.join(flist[:5]))
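
Since `glob.glob` returns files in arbitrary order, the file picked by `flist[0]` can differ between machines. A minimal sketch of a more defensive listing is shown below; `find_flt_files` is a hypothetical helper, and the placeholder files exist only to make the example self-contained.

```python
import glob
import os
import tempfile

def find_flt_files(directory, pattern='o*flt.fits'):
    """Return matching files in sorted order, raising if none are found."""
    matches = sorted(glob.glob(os.path.join(directory, pattern)))
    if not matches:
        raise FileNotFoundError(f"No files matching {pattern!r} in {directory}")
    return matches

# Demonstrate on a throwaway directory with empty placeholder files
with tempfile.TemporaryDirectory() as tmpdir:
    for name in ('ob_flt.fits', 'oa_flt.fits', 'oa_raw.fits'):
        open(os.path.join(tmpdir, name), 'w').close()
    demo_files = [os.path.basename(f) for f in find_flt_files(tmpdir)]

print(demo_files)
```

Sorting makes the first file reproducible, and the explicit error is easier to diagnose than an `IndexError` on an empty list.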

Define a helper function to use for plotting images

In [None]:
def plot(data, norm=None, stretch_type=None, vmin=None, vmax=None):
    """Simple plotting function"""
    if norm is None and stretch_type is None:
        pass
    elif norm is None and stretch_type is not None:
        if vmin is not None and vmax is not None:
            norm = ImageNormalize(data, stretch=stretchdict_[stretch_type],vmin=vmin, vmax=vmax)
        else:
            norm = ImageNormalize(data, stretch=stretchdict_[stretch_type], interval=ZScaleInterval())
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5,5))
    im = ax.imshow(data, norm=norm, origin='lower', cmap='gray')
    plt.show()

Now that we have a list of files that we want to examine and all our helper functions defined, we examine the two classes, `metadata.GenerateMetadata()` and `labeler.CosmicRayLabel`, from the `hstcosmicrays` package that we will be using.

In [None]:
metadata.GenerateMetadata?

In [None]:
labeler.CosmicRayLabel?

In [None]:
instr = 'STIS_CCD'
test_file = flist[0]

The first thing we want to do is examine some metadata associated with our test file. The two attributes of the `GenerateMetadata` class used to store the relevant metadata are `instr_cfg` and `metadata`

In [None]:
# Create a metadata object for our test file and instr
img_meta = metadata.GenerateMetadata(fname=test_file, instr=instr)

# Get the image metadata
img_meta.get_image_data()

# Get the observatory metadata from the SPT file
img_meta.get_observatory_info()

# Print out the extracted metadata and our instrument configuration
print(f"Metadata extracted for {os.path.basename(test_file)}")
print("-"*50)
for key, val in img_meta.metadata.items():
    print(f"{key} -> {val}")
print("-"*50,'\n')
print(f"Instrument Configuration for {instr}")
print("-"*50)
for key1, val1 in img_meta.instr_cfg.items():
    if isinstance(val1, dict):
        print(f"{str(key1):}:")
        for key, val in img_meta.instr_cfg[key1].items():
            print(f"{str(key):>25} -> {str(val):}")
    

Ok, now that we have the required metadata read in, we create a `CosmicRayLabel` object. We use the gain keyword contained in the `instr_cfg` attribute of the `GenerateMetadata` object that we created above. This ensures that if the units of the input file are in DN or COUNTS, we convert it to ELECTRONS before proceeding.

In [None]:
# Create an instance of the CosmicRayLabel class and read in SCI and DQ extensions
cr_label = labeler.CosmicRayLabel(
    fname=test_file,
    gain_keyword=img_meta.instr_cfg['instr_params']['gain_keyword']
)
# Read in the sci data
cr_label.get_data(extname='sci', extnums=[1])

# Read in the dq data
cr_label.get_data(extname='dq', extnums=[1])
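
The unit conversion mentioned above amounts to multiplying the pixel values by the detector gain (in electrons per DN). How `CosmicRayLabel` performs this internally is not shown in this notebook, so the sketch below is only an illustration; `to_electrons` and the gain value are hypothetical.

```python
import numpy as np

def to_electrons(data, units, gain):
    """Convert an image in DN or COUNTS to electrons.

    Hypothetical helper: data already in ELECTRONS are returned unchanged.
    """
    if units.upper() in ('DN', 'COUNTS'):
        return data * gain
    return data

counts = np.array([[10.0, 20.0],
                   [0.0, 5.0]])
gain = 4.0  # hypothetical gain in electrons per DN
electrons = to_electrons(counts, units='COUNTS', gain=gain)
unchanged = to_electrons(counts, units='ELECTRONS', gain=gain)
print(electrons)
```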

Now let's plot each extension that we just read in

In [None]:
plot(cr_label.dq)

In [None]:
plot(cr_label.sci, stretch_type='sqrt', vmin=0, vmax=20)

There are two ways to label cosmic rays. 
1. Using the DQ array to identify groups of pixels affected by the same CR
2. Using a custom generated binary image to identify groups of pixels affected by the same CR

The cell blocks below demonstrate the first method
<hr style="height:1px;border:none;color:#333;background-color:#333;" />
If the images have been processed such that their DQ array contains a BIT flag identifying pixels affected by cosmic rays, then we run the labeling analysis with the parameters defined below.

In [None]:
dq_labeling_parameters = {
    'use_dq': True, # Whether or not to use the DQ array
    'dq_flag': 8192,  # The BIT flag identifying CR (default 8192)
    'do_bitwise_comp': True, # Do a BITWISE_AND comparison with dq_flag
    'deblend': False, # If True, try to deblend (experimental, best to leave as False)
    'threshold_l': 2, # Lower threshold for size of the labeled object to be considered a CR
    'threshold_u': 5000, # Upper threshold for size of labeled object to be considered a CR
    'pix_thresh': None, # Not used when labeling cosmic rays with the DQ array
    'structure_element': np.ones((3,3)) # Structuring element to be used in labeling
}
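
To illustrate why a bitwise comparison is useful, here is a small NumPy sketch. A DQ value is a sum of individual flags, so a pixel hit by a cosmic ray may also carry other flags. A bitwise AND with 8192 selects every pixel that has the CR bit set, whereas an exact-equality test (shown only for contrast) misses pixels with additional flags. The DQ values below are made up for the example.

```python
import numpy as np

CR_FLAG = 8192  # DQ bit used for cosmic rays in this notebook

# Made-up DQ values: some pixels carry the CR flag plus other flags
dq = np.array([[0, 8192, 8192 + 16],
               [4, 8192 + 4, 16]], dtype=np.int16)

# True wherever the CR bit is set, regardless of any other flags
bitwise_mask = (dq & CR_FLAG) > 0

# True only where the CR flag is the sole flag set
exact_mask = dq == CR_FLAG

print(bitwise_mask.sum(), exact_mask.sum())
```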

Run the labeling analysis code for CCDs using the parameters defined above. Once the processing has finished, plot the `SCI` extension and the CR segmentation map side-by-side using the `plot()` method of the `CosmicRayLabel` class.

In [None]:
cr_label.ccd_labeling(**dq_labeling_parameters)
cr_label.plot(xlim=(200, 400), ylim=(200, 400))

If the image's DQ array doesn't contain BIT flags for CRs, then you can instead run the labeling using a custom threshold (this is option 2 mentioned previously)

In [None]:
threshold_labeling_parameters = {
    'use_dq': False,
    'dq_flag': None,
    'do_bitwise_comp': False,
    'deblend': False, # If True, try to deblend (experimental, best to leave as False)
    'threshold_l': 2, # Lower threshold for size of the labeled object to be considered a CR
    'threshold_u': 5000, # Upper threshold for size of labeled object to be considered a CR
    'pix_thresh': 3*np.mean(cr_label.sci), # Set the absolute threshold to 3x the mean val
    'structure_element': np.ones((3,3)) # Structuring element to be used in labeling
}

In [None]:
cr_label.ccd_labeling(**threshold_labeling_parameters)
cr_label.plot(xlim=(200, 400), ylim=(200, 400))
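
The threshold-based approach can be sketched with `scipy.ndimage.label` on a toy image: build a binary image from `pix_thresh`, group touching pixels into labeled objects, and discard objects outside the size limits. This is an illustration of the general technique, not the internals of `ccd_labeling`; in particular, treating the size limits as inclusive is an assumption.

```python
import numpy as np
from scipy import ndimage

# Toy image: flat background of 1 with a few bright pixels
image = np.ones((6, 6))
image[0, 0] = 50.0                 # isolated single-pixel object
image[2, 2] = image[3, 3] = 40.0   # two pixels touching only diagonally

pix_thresh = 3 * image.mean()
binary = image > pix_thresh

# A 3x3 structuring element of ones joins diagonal neighbours (8-connectivity)
labels, n_objects = ndimage.label(binary, structure=np.ones((3, 3)))

# The default structuring element only joins pixels that share an edge
_, n_four = ndimage.label(binary)

# Keep objects within the size limits (inclusive bounds are an assumption)
threshold_l, threshold_u = 2, 5000
sizes = np.bincount(labels.ravel())[1:]  # drop the background count
keep = np.flatnonzero((sizes >= threshold_l) & (sizes <= threshold_u)) + 1
cleaned = np.where(np.isin(labels, keep), labels, 0)

print(n_objects, n_four, sizes, np.count_nonzero(cleaned))
```

With the 3x3 structuring element the diagonal pair counts as one object, and the lower size limit removes the isolated pixel, which is more likely a hot pixel than a cosmic ray track.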

Once we have the label in hand, we can start computing some statistics describing the identified cosmic rays using the `Stats` class in the `statshandler` module. As before, we inspect the class first to determine the proper inputs

In [None]:
statshandler.Stats?

In [None]:
stats = statshandler.Stats

In [None]:
cr_stats = statshandler.Stats(cr_label=cr_label,
                           integration_time=img_meta.metadata['integration_time'],
                          detector_size=img_meta.instr_cfg['instr_params']['detector_size'])
cr_stats.compute_cr_statistics()
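
Which quantities `compute_cr_statistics()` derives, and how, is not shown in this notebook. As a generic sketch of per-object statistics from a label image, the example below computes the size, total signal, and signal-weighted centroid of each labeled object with `np.bincount`. The arrays are made up, and the centroids are given in (row, column) order.

```python
import numpy as np

# Made-up segmentation map (0 = background) and matching science image
label = np.array([[0, 1, 1, 0],
                  [0, 1, 0, 0],
                  [0, 0, 0, 2],
                  [0, 0, 2, 2]])
sci = np.array([[0.0, 10.0, 20.0, 0.0],
                [0.0, 10.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 5.0],
                [0.0, 0.0, 5.0, 10.0]])

n_labels = label.max()
flat = label.ravel()
rows, cols = np.indices(label.shape)

# Index 0 of each bincount is the background, so drop it with [1:]
sizes = np.bincount(flat, minlength=n_labels + 1)[1:]
totals = np.bincount(flat, weights=sci.ravel(), minlength=n_labels + 1)[1:]
row_c = np.bincount(flat, weights=(rows * sci).ravel(), minlength=n_labels + 1)[1:] / totals
col_c = np.bincount(flat, weights=(cols * sci).ravel(), minlength=n_labels + 1)[1:] / totals
centroids = np.column_stack([row_c, col_c])  # (row, column) per object

print(sizes, totals)
print(centroids)
```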

Now we load the previously trained K-Nearest Neighbors classifier and use the model to predict whether the identified objects are cosmic rays (1) or stars (0). Since this is a dark frame, the expectation is for everything to be classified as a cosmic ray.

In [None]:
with open('knn_classifier_Oct03_2019.pkl', 'rb') as fd:
    clf = pickle.load(fd)

In [None]:
predict = clf.predict(list(zip(cr_stats.size_in_sigmas, cr_stats.shapes)))

In [None]:
predict.sum() == len(predict)
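
The classifier expects one row per object and one column per feature. The sketch below uses made-up feature values to show that `list(zip(...))` and `np.column_stack` build the same `(n_objects, 2)` matrix, and that the all-cosmic-rays check can be written more explicitly with `np.all`. The `predictions` array is a hypothetical stand-in for the classifier output.

```python
import numpy as np

# Made-up feature values for three objects
size_in_sigmas = [0.8, 1.5, 0.4]
shapes = [0.3, 0.9, 0.1]

features_zip = np.array(list(zip(size_in_sigmas, shapes)))
features_stack = np.column_stack([size_in_sigmas, shapes])
print(features_stack.shape)

# Hypothetical classifier output: 1 = cosmic ray, 0 = star
predictions = np.array([1, 1, 1])
all_crs = bool(np.all(predictions == 1))
print(all_crs)
```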

### Example 2: Identifying CRs in observations with external sources

Now we are going to apply everything that we just used above, but instead we are going to analyze an image with actual sources and cosmic rays

In [None]:
from astropy.convolution import Gaussian2DKernel
from astropy.stats import gaussian_fwhm_to_sigma

from photutils import detect_sources, detect_threshold

In [None]:
datapath = '/Users/nmiles/hst_cosmic_rays/notebooks/MAST_2019-11-09T1348/HST/ocr7qvhaq'

In [None]:
flist = glob.glob(datapath+'/*flt.fits')

In [None]:
star_label = labeler.CosmicRayLabel(fname=flist[0])
star_label.get_data(extname='sci',extnums=[1])

In [None]:
plot(star_label.sci, stretch_type='log', vmin=0, vmax=1000)

In [None]:
star_label.sci

In [None]:
np.mean(star_label.sci)

In [None]:
# sigma = 3.0 * gaussian_fwhm_to_sigma  # FWHM = 3.
# kernel = Gaussian2DKernel(sigma, x_size=3, y_size=3)
# kernel.normalize()
segm = detect_sources(star_label.sci, 10*np.median(star_label.sci) , npixels=8)

In [None]:
threshold_labeling_parameters = {
    'use_dq': False,
    'dq_flag': None,
    'do_bitwise_comp': False,
    'deblend': False, # If True, try to deblend (experimental, best to leave as False)
    'threshold_l': 2, # Lower threshold for size of the labeled object to be considered a CR
    'threshold_u': 800, # Upper threshold for size of labeled object to be considered a CR
    'pix_thresh': 10*np.mean(cr_label.sci), # Absolute threshold: 10x the mean of the Example 1 dark frame (cr_label)
    'structure_element': np.ones((3,3)) # Structuring element to be used in labeling
}

In [None]:
star_label.ccd_labeling(**threshold_labeling_parameters)

In [None]:
star_label.plot(instr='STIS/CCD')

In [None]:
star_stats = statshandler.Stats(
    cr_label=star_label,
    # NOTE: img_meta was built from the Example 1 dark frame, not from this image
    integration_time=img_meta.metadata['integration_time'],
    detector_size=img_meta.instr_cfg['instr_params']['detector_size']
)
star_stats.compute_cr_statistics()

In [None]:
classifications = clf.predict(list(zip(star_stats.size_in_sigmas, star_stats.shapes)))

In [None]:
percentage_of_crs = np.sum(classifications)/ len(classifications)
percentage_of_stars = len(np.where(classifications==0)[0])/len(classifications)

In [None]:
print(f"Percentage of CRs: {percentage_of_crs:.2%}")
print(f"Percentage of stars: {percentage_of_stars:.2%}")
print(f"Total: {percentage_of_crs + percentage_of_stars:.2%}")
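
The same class fractions can be obtained in one step with `np.bincount`, which counts how many objects fall into each class. The sketch below uses a made-up array of labels in place of the classifier output.

```python
import numpy as np

# Made-up class labels: 1 = cosmic ray, 0 = star
classes = np.array([1, 1, 0, 1, 0, 1, 1, 1])

# counts[0] is the number of stars, counts[1] the number of cosmic rays
counts = np.bincount(classes, minlength=2)
fractions = counts / counts.sum()

print(f"Stars: {fractions[0]:.2%}, CRs: {fractions[1]:.2%}")
```

Passing `minlength=2` keeps both entries present even when one class is missing from the predictions.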

In [None]:
star_label.plot(instr='STIS/CCD')

In [None]:
# Circle each labeled object: red for cosmic rays (1), green for stars (0)
for centroid, c in zip(star_stats.centroids, classifications):
    color = 'red' if c else 'green'
    # Circle expects (x, y), so the centroid indices are swapped
    patch = Circle(xy=(centroid[1], centroid[0]), radius=3, color=color, fill=False)
    star_label.ax1.add_patch(patch)