In [3]:
%matplotlib widget
import glob
import os
from mpl_toolkits.axes_grid1 import make_axes_locatable

from astropy.io import fits
from astropy.stats import sigma_clipped_stats
from astropy.visualization import ImageNormalize, SqrtStretch, LogStretch, LinearStretch, ZScaleInterval, ManualInterval
import ipywidgets as widgets
from ipywidgets import interact, fixed, interactive, VBox, HBox
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from matplotlib import ticker
# plt.style.use('dark_background')
plt.style.use('ggplot')

import numpy as np

In [4]:
# Root directory of the CR-rejection results; each run lives in a
# "<exptime>_<param>" sub-directory below it.  Set the CR_REJECTION_BASE
# environment variable to point the notebook at a different copy of the data;
# the original location is kept as the default.
_BASE = os.environ.get(
    'CR_REJECTION_BASE',
    '/Users/nmiles/hst_cosmic_rays/analyzing_cr_rejection/',
)

In [5]:
# Exposure to inspect: file name plus its exposure time in seconds.  The
# exposure time is also used to build the per-run directory name.
img_to_compare, exptime = 'odvbbpa5q_all.fits', 60.0  # random 60 seconds exposure

# Alternative: random 1100 second exposure
# img_to_compare, exptime = 'odvbbozpq_all.fits', 1100.0

In [6]:
# Names of the three rejection-parameter runs being compared.
param1, param2, param3 = 'med_4', 'med_6.5,5.5,4.5', 'min_6.5,5.5,4.5'

# Each run stores the same exposure under "<_BASE>/<exptime>_<param>/".
img1, img2, img3 = (
    os.path.join(_BASE, f'{exptime}_{run_name}', img_to_compare)
    for run_name in (param1, param2, param3)
)

In [7]:
def read_image(img):
    """Load the data arrays of the first three HDUs of a FITS file.

    Parameters
    ----------
    img : str
        Path of the FITS file to open.

    Returns
    -------
    dict
        Maps, in this order, ``'sci'`` to the data of HDU 0, ``'dq'`` to the
        data of HDU 1 and ``'crlabel'`` to the data of HDU 2.
        NOTE(review): presumably science pixels, data-quality flags and a
        cosmic-ray label map -- inferred from the key names only; confirm
        against the code that writes these files.
    """
    hdu_keys = ('sci', 'dq', 'crlabel')
    with fits.open(img) as hdulist:
        # Touch every .data attribute while the file is still open.
        return {key: hdulist[index].data for index, key in enumerate(hdu_keys)}

In [8]:
# Read the three versions of the exposure, one per parameter set.
data1, data2, data3 = (read_image(path) for path in (img1, img2, img3))

# Sigma-clipped (5 sigma) statistics of the first run's 'sci' array; `med` and
# `std` set the upper display limit of the image normalization.
mean, med, std = sigma_clipped_stats(data1['sci'], sigma=5)

In [9]:
# Shared display scaling for every 'sci' panel: square-root stretch between 0
# and (clipped median + 10 clipped standard deviations) of the first run.
norm = ImageNormalize(
    data1['sci'],
    interval=ManualInterval(vmin=0, vmax=med + 10 * std),
    stretch=SqrtStretch(),
)

In [10]:
def generate_cmap(ncolors):
    """Build a reproducible random colormap whose first entry is black.

    Colors are drawn in HSV space from a ``RandomState`` seeded with 1234
    (hue in [0, 1), saturation in [0.2, 0.7), value in [0.5, 1.0)) and then
    converted to RGB, so the same ``ncolors`` always yields the same map.

    Parameters
    ----------
    ncolors : int
        Number of entries in the colormap.  Callers pass ``max(label) + 1``
        so that every label value, including 0, gets its own entry.  NumPy
        integer scalars are accepted.

    Returns
    -------
    matplotlib.colors.ListedColormap
        Colormap with ``ncolors`` entries; entry 0 is black.
    """
    # np.max() of an array returns a NumPy scalar; normalise it to a plain int.
    ncolors = int(ncolors)

    prng = np.random.RandomState(1234)
    h = prng.uniform(low=0.0, high=1.0, size=ncolors)
    s = prng.uniform(low=0.2, high=0.7, size=ncolors)
    v = prng.uniform(low=0.5, high=1.0, size=ncolors)

    # Stack as (ncolors, 3).  The previous dstack + squeeze collapsed to shape
    # (3,) when ncolors == 1 (label image with no detections), which made the
    # row assignment below raise ValueError.
    hsv = np.column_stack((h, s, v))
    rgb = colors.hsv_to_rgb(hsv)

    # Entry 0 is forced to black.
    rgb[0] = (0, 0, 0)
    return colors.ListedColormap(rgb)

In [11]:
# Text-box controls for the zoom window (initially 500..540 on both axes).
# Only traits that IntText actually defines are passed: the former `options`,
# `orientation` and `readout` keywords are not IntText traits, so they had no
# effect and only triggered an "unrecognized arguments" DeprecationWarning.
# NOTE(review): IntText is unbounded; if the old options=range(1, 1025) was
# meant as a limit, widgets.BoundedIntText(min=1, max=1024) would enforce it.
# NOTE(review): the following cell rebinds these four names to IntSlider
# widgets, so when the notebook is run top to bottom the sliders are the
# controls that end up being used.
xmin_slider, xmax_slider, ymin_slider, ymax_slider = (
    widgets.IntText(
        value=start,
        description=label,
        disabled=False,
        continuous_update=True,
    )
    for label, start in (('xmin', 500), ('xmax', 540), ('ymin', 500), ('ymax', 540))
)

# Keep the y window identical to the x window: changing either member of a
# linked pair updates the other.
l1 = widgets.link((xmin_slider, 'value'), (ymin_slider, 'value'))
l2 = widgets.link((xmax_slider, 'value'), (ymax_slider, 'value'))

In [12]:
# Slider controls for the zoom window.  All four sliders share the same range
# (0..1024 in steps of 10) and differ only in label and starting value.
xmin_slider, xmax_slider, ymin_slider, ymax_slider = (
    widgets.IntSlider(
        min=0,
        max=1024,
        step=10,
        value=start,
        description=label,
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        readout=True,
    )
    for label, start in (('xmin', 500), ('xmax', 540), ('ymin', 500), ('ymax', 540))
)

# Tie the y limits to the x limits: moving either slider of a linked pair
# moves the other, so the displayed window has equal x and y bounds.
l1 = widgets.link((xmin_slider, 'value'), (ymin_slider, 'value'))
l2 = widgets.link((xmax_slider, 'value'), (ymax_slider, 'value'))

In [13]:
def _replace_image(ax, array, **imshow_kwargs):
    """Draw ``array`` on ``ax`` after removing images left by an earlier call."""
    for stale_image in list(ax.images):
        stale_image.remove()
    return ax.imshow(array, origin='lower', **imshow_kwargs)


def _replace_colorbar(fig, ax, mappable, label):
    """Attach a labelled colorbar to the right of ``ax``, replacing the old one."""
    stale_cax = getattr(ax, '_plot_results_cax', None)
    if stale_cax is not None and stale_cax in fig.axes:
        fig.delaxes(stale_cax)
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    # Remember the colorbar axes on its parent so the next redraw can drop it.
    ax._plot_results_cax = cax
    cbar = fig.colorbar(mappable, cax=cax)
    cbar.set_label(label)
    return cbar


def _draw_column(fig, column_axes, data, sci_title, norm, dq_norm, with_colorbars):
    """Fill one column of panels, one row per key of ``data`` in dict order."""
    for ax, (key, array) in zip(column_axes, data.items()):
        ax.grid(False)
        if key == 'sci':
            ax.set_title(sci_title)
            image = _replace_image(ax, array, norm=norm, cmap='gray')
            if with_colorbars:
                _replace_colorbar(fig, ax, image, "SCI Value (COUNTS)")
        elif key == 'dq':
            ax.set_title('DQ')
            image = _replace_image(ax, array, norm=dq_norm, cmap='inferno')
            if with_colorbars:
                _replace_colorbar(fig, ax, image, 'DQ Value')
        elif key == 'crlabel':
            ax.set_title('CR LABEL')
            # One colormap entry per label value, 0 included.
            label_cmap = generate_cmap(int(np.max(array)) + 1)
            _replace_image(ax, array, cmap=label_cmap)


def plot_results(
    axes=None,
    data1=None,
    data2=None,
    data3=None,
    figsize=(10,6),
    norm=None,
    param1=None,
    param2=None,
    param3=None,
    xmin=None,
    xmax=None,
    ymin=None,
    ymax=None
):
    """Draw a 3x3 comparison grid: one column per data set, one row per array.

    Column ``i`` shows ``data<i>``; rows follow the key order of each dict
    (``'sci'``, ``'dq'``, ``'crlabel'`` as produced by ``read_image``).  The
    function is safe to call repeatedly on the same axes (e.g. from an
    ipywidgets callback): images and colorbars from the previous call are
    removed before the new ones are drawn instead of piling up.

    Parameters
    ----------
    axes : numpy.ndarray of matplotlib Axes
        2-D array of axes with at least three columns.  The figure is taken
        from ``axes[0, 0]``, so no global ``fig`` is required.
    data1, data2, data3 : dict
        Arrays to display, keyed by ``'sci'``, ``'dq'`` and ``'crlabel'``.
        Other keys are ignored apart from turning the grid off on their panel.
    figsize : tuple of float
        Figure (width, height) in inches.
    norm : matplotlib normalization
        Normalization shared by the three ``'sci'`` panels.
    param1, param2, param3 : str
        Labels appended to the ``'sci'`` panel titles.
    xmin, xmax, ymin, ymax : int or None
        View limits applied to every panel; ``None`` leaves that bound as is.
    """
    fig = axes[0, 0].figure

    # Discrete color boundaries for the 'dq' panels.
    # NOTE(review): the values look like DQ bit flags -- confirm.
    cbar_bounds = [0, 8, 16, 32, 64, 256, 1024, 4096, 8192]
    dq_norm = colors.BoundaryNorm(boundaries=cbar_bounds,
                                  ncolors=plt.cm.inferno.N)

    # The ipympl canvas is sized through its widget layout; other backends
    # have no `layout`, so fall back to the regular figure size.
    canvas_layout = getattr(fig.canvas, 'layout', None)
    if canvas_layout is not None:
        canvas_layout.width = f'{figsize[0]}in'
        canvas_layout.height = f'{figsize[1]}in'
    else:
        fig.set_size_inches(figsize)

    # Only the right-most column carries colorbars.
    columns = (
        (axes[:, 0], data1, param1, False),
        (axes[:, 1], data2, param2, False),
        (axes[:, 2], data3, param3, True),
    )
    for column_axes, data, param, with_colorbars in columns:
        _draw_column(fig, column_axes, data, f'SCI {param}', norm, dq_norm,
                     with_colorbars)

    # Apply the zoom window after drawing and on every panel, so the result
    # does not depend on the axes having been created with sharex/sharey.
    for ax in axes.flat:
        ax.set_xlim((xmin, xmax))
        ax.set_ylim((ymin, ymax))

    plt.show()

In [14]:
# Output widget that will capture the figure and its controls; drawn with a
# thin black border.
out = widgets.Output(layout=widgets.Layout(border='1px solid black'))

In [15]:
# Build the 3x3 figure and its controls inside the Output widget so that
# everything they display is captured by `out`.
with out:
    fig, axes = plt.subplots(
        nrows=3,
        ncols=3,
        sharex=True,
        sharey=True,
        gridspec_kw={'wspace': 0.01, 'hspace': 0.15},
    )
    # Data, labels, normalization and figure size are fixed; only the four
    # window limits are exposed as interactive controls.
    w = interactive(
        plot_results,
        axes=fixed(axes),
        data1=fixed(data1),
        data2=fixed(data2),
        data3=fixed(data3),
        param1=fixed(param1),
        param2=fixed(param2),
        param3=fixed(param3),
        norm=fixed(norm),
        figsize=fixed((12, 10)),
        xmin=xmin_slider,
        xmax=xmax_slider,
        ymin=ymin_slider,
        ymax=ymax_slider,
    )
    display(w)

In [16]:
# Display the Output widget that captured the figure and its controls.
out

Output(layout=Layout(border='1px solid black'), outputs=({'output_type': 'display_data', 'data': {'text/plain'…

In [15]:
# Tear-down: empty the Output widget, then close every open matplotlib figure.
out.clear_output()
plt.close('all')