In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import glob
from astropy.io import fits
from astropy.convolution import Gaussian2DKernel
from astropy.stats import gaussian_fwhm_to_sigma
from astropy.stats import sigma_clipped_stats
from astropy.visualization import ImageNormalize, LogStretch, ZScaleInterval
import numpy as np


from ComputeStats import ComputeStats
from CosmicRayLabel import CosmicRayLabel
import matplotlib.pyplot as plt

from scipy import ndimage
from matplotlib import colors
from matplotlib.patches import Rectangle, Circle
from photutils import detect_sources, detect_threshold
from photutils import detection
from photutils import source_properties, deblend_sources

In [None]:
 def get_data(fname):
        """ Grab the SCI extensions from fits file
        """
        pc = ('sci', 1) # Chip 2
        wf2 = ('sci', 2) # Chip 1
        wf3 = ('sci', 3)
        wf4 = ('sci', 4)
        detector_data = []
        with fits.open(fname) as hdu:
            for ext in [pc, wf2, wf3, wf4]:
                try:
                    ext = hdu.index_of(ext)
                    ext_data = hdu[ext].data
                except KeyError:
                    print('{1} is missing for {0}'.format(fname, ext))
                    ext1 = None
                else:
                    detector_data.append(ext_data)
        # If second ext is missing, only work with the first
        # Otherwise combine each DQ ext to make full-frame

        return detector_data
    

In [None]:
def find_sources(fname, deblend=False, nsigma=5., npixels=4, fwhm=1.):
    """Build a segmentation map for every SCI chip returned by ``get_data``.

    Parameters
    ----------
    fname : str
        FITS file passed to ``get_data``.
    deblend : bool, optional
        If True, run ``deblend_sources`` on each chip's segmentation map.
    nsigma : float, optional
        Detection threshold in units of the sigma-clipped standard deviation
        above the sigma-clipped median (default 5).
    npixels : int, optional
        Minimum number of connected pixels for a detection (default 4).
    fwhm : float, optional
        FWHM in pixels of the 3x3 Gaussian smoothing kernel (default 1).

    Returns
    -------
    data_list : list of numpy.ndarray
        The raw (not background-subtracted) chip images from ``get_data``.
    segm_list : list
        For each chip, the object returned by ``detect_sources``, or by
        ``deblend_sources`` when deblending succeeded. Entries may be None
        if ``detect_sources`` found nothing.
    """
    sci_data = get_data(fname)

    # Small Gaussian smoothing kernel, normalised to unit sum.
    sigma = fwhm * gaussian_fwhm_to_sigma
    kernel = Gaussian2DKernel(sigma, x_size=3, y_size=3)
    kernel.normalize()

    data_list, segm_list = [], []
    for data in sci_data:
        mean, median, std = sigma_clipped_stats(data, sigma_lower=3, sigma_upper=2)
        print('mean: {}, median: {}, std: {}'.format(mean, median, std))
        # The median is subtracted before detection, so the threshold is
        # nsigma*std above zero (i.e. median + nsigma*std on the raw image).
        # Adding the median to the threshold as well would count it twice.
        bkg_sub = data - median
        segm = detect_sources(bkg_sub,
                              threshold=nsigma * std,
                              npixels=npixels,
                              filter_kernel=kernel,
                              connectivity=8)

        if deblend and segm is not None:
            print('deblending')
            try:
                # Deblend this chip using the same array that was used for
                # detection (not always the first chip).
                segm = deblend_sources(bkg_sub,
                                       segm,
                                       npixels=npixels,
                                       nlevels=32,
                                       filter_kernel=kernel,
                                       contrast=0.1,
                                       connectivity=8)
            except ValueError as err:
                # Best effort: keep the undeblended map, but say so.
                print('deblending failed, keeping undeblended map: {}'.format(err))
        segm_list.append(segm)
        data_list.append(data)
    return data_list, segm_list



In [None]:
# Sort the glob results: glob's order is filesystem-dependent, and later cells
# select files by position (f_1000[0], f_1000[1]), so an unsorted list would
# make the chosen files vary between machines and runs.
flist = sorted(glob.glob('../crrejtab/WFPC2/mastDownload/HST/*/*c0m.fits'))

In [None]:

# Keep the files whose EXPTIME header exceeds 1000, with their exposure times
# in a parallel list (t_exptime[i] belongs to f_1000[i]).
f_1000, t_exptime = [], []
for path in flist:
    exptime = fits.getval(path, 'exptime')
    if not exptime > 1000:
        continue
    f_1000.append(path)
    t_exptime.append(exptime)

In [None]:
# Work with the first two long-exposure files; fail early with a clear message
# instead of a bare IndexError when fewer than two were selected.
if len(f_1000) < 2:
    raise ValueError('need at least 2 files with EXPTIME > 1000, found {}'.format(len(f_1000)))
fname1, fname2 = f_1000[0], f_1000[1]

In [None]:
# Per-chip images and segmentation maps for the first file (deblending off by default).
data1_list, segm1_list = find_sources(fname1)

In [None]:
# Per-chip images and segmentation maps for the second file (deblending off by default).
data2_list, segm2_list = find_sources(fname2)

In [None]:
# Binary (1/0) mask of pixels in the first chip above median + 5*std.
# Recompute the sigma-clipped stats (same clipping as find_sources) instead of
# pasting printed numbers, so the cell stays correct if the input file changes.
mean1, median1, std1 = sigma_clipped_stats(data1_list[0], sigma_lower=3, sigma_upper=2)
high_val = np.where(data1_list[0] > median1 + 5 * std1, 1, 0)

In [None]:
# Display the mask dimensions.
np.shape(high_val)

In [None]:
# Two-colour map for displaying a 0/1 mask: 0 -> black, 1 -> white.
cmap = colors.ListedColormap(colors=['black', 'white'])

In [None]:
# Minimum region size for the labelling step: connected regions are kept only
# if they contain MORE than this many pixels.
threshold = 2

In [None]:
# Label 8-connected groups of flagged pixels in high_val, drop groups with
# `threshold` or fewer pixels, then relabel what remains.
label, num_feat = ndimage.label(high_val, structure=np.ones((3, 3)))
print('Found {} objects'.format(num_feat))
# Pixel count per label value (index = label).
sizes = np.bincount(label.ravel())
# ndimage.label always assigns 0 to the background; zero it explicitly instead
# of assuming the background is the largest label.
sizes[0] = 0
large_CRs = sizes > threshold

# Map the per-label keep flags back onto the 2-D image and clear every pixel
# belonging to a small group (or the background) so it is ignored.
label_mask = large_CRs[label]
high_val[~label_mask] = 0
label, num_feat = ndimage.label(high_val, structure=np.ones((3, 3)))
print('After thresholding there are {} objects'.format(num_feat))

In [None]:
# Show each chip image of the first file on a common log stretch. Axes are
# shared, so the zoom set on the first panel applies to all panels.
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(6, 6))
norm = ImageNormalize(data1_list[0], stretch=LogStretch(), vmin=0, vmax=500)

# zip stops at the shorter sequence, so fewer than 4 chips leaves panels empty
# instead of raising IndexError.
for chip_index, (ax, chip_data) in enumerate(zip(axes.flat, data1_list)):
    ax.imshow(chip_data, norm=norm, cmap='gray', origin='lower')
    ax.set_title('data1_list[{}]'.format(chip_index))
axes[0, 0].set_xlim(460, 470)
axes[0, 0].set_ylim(415, 430)

In [None]:
# Detected segments per unit exposure time for the first chip of fname1.
# np.unique includes the background label 0, so count only non-zero labels.
np.count_nonzero(np.unique(segm1_list[0].data)) / t_exptime[0]

In [None]:
# Labelled objects per unit exposure time. `label` was derived from
# data1_list[0] (fname1 == f_1000[0]), hence t_exptime[0]. Background label 0
# is excluded from the count.
np.count_nonzero(np.unique(label)) / t_exptime[0]

In [None]:
from ComputeStats import ComputeStats
from CosmicRayLabel import CosmicRayLabel

In [None]:
# Run the project's CosmicRayLabel on the first file.
# NOTE(review): CosmicRayLabel is defined outside this notebook; what
# label_wfpc2_data() computes is not visible here — presumably it populates
# attributes such as `sci` and `integration_time` that are read later; confirm.
label_obj = CosmicRayLabel(fname1)
label_obj.label_wfpc2_data()

In [None]:
# Set up cosmic-ray statistics for the first file from the `label` array built
# by the ndimage labelling cell.
# NOTE(review): `label` was computed from data1_list[0] (the first SCI
# extension returned by get_data), but the image passed here is
# label_obj.sci[1]. If label_obj.sci is 0-indexed, labels from one chip are
# being paired with the image of another — confirm CosmicRayLabel.sci indexing.
stats_obj = ComputeStats(fname1, label, label_obj.sci[1], label_obj.integration_time)

In [None]:
# Run the statistics; compute_stats must return exactly five values here.
# NOTE(review): the meaning of threshold=True and of each returned value lives
# in ComputeStats (not visible here); the names suggest affected pixels, rate,
# sizes, shapes and energy deposition — confirm.
affected_tmp, rate_tmp, sizes_tmp, shapes_tmp, deposition_tmp = stats_obj.compute_stats(threshold=True)

In [None]:
# Compare CosmicRayLabel's image with find_sources' chip image (left column)
# and the ndimage labels with the photutils segmentation (right column).
# NOTE(review): `label` was built from data1_list[0], while the other panels
# show index 1 — confirm this pairing is intended.
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(6, 6))
norm = ImageNormalize(label_obj.sci[1], stretch=LogStretch(), vmin=0, vmax=500)
# Build the seeded random label colormap once and reuse it for both label panels.
segm_cmap = segm1_list[1].cmap(random_state=1245)

axes[0, 0].imshow(label_obj.sci[1], norm=norm, cmap='gray', origin='lower')
axes[0, 0].set_title('label_obj.sci[1]')
axes[0, 1].imshow(label, origin='lower', cmap=segm_cmap)
axes[0, 1].set_title('ndimage label')

axes[1, 0].imshow(data1_list[1], norm=norm, cmap='gray', origin='lower')
axes[1, 0].set_title('data1_list[1]')
axes[1, 1].imshow(segm1_list[1].data, origin='lower', cmap=segm_cmap)
axes[1, 1].set_title('segm1_list[1]')