In [None]:
import bqplot
import os
from bqplot_image_gl import ImageGL
from astropy.io import fits
from astropy.visualization import PercentileInterval
from bqplot import Figure, LinearScale, Axis, ColorScale
from bqplot import pyplot as plt
from bqplot import Toolbar
import numpy as np
import ipyvuetify as v
from ipywidgets import Layout
import ipywidgets as widgets
import traitlets


In [None]:
# Debugging aid: an Output widget. The callbacks defined in this notebook are
# decorated with ``@out.capture()``, so whatever they print ends up in this
# widget.
out = widgets.Output()
# Add a synced ``_metadata`` trait carrying the mount id 'out'.
# NOTE(review): presumably a front-end template uses ``mount_id`` to decide
# where to place the widget on the page -- confirm against the template.
out.add_traits(_metadata=traitlets.Dict(default_value={'mount_id': 'out'}).tag(sync=True))
# bare expression so the widget is displayed as this cell's output
out

In [None]:
# JWST target source ids. They populate the 'Target' select box and are
# formatted into the FITS file names by prep_data().
jwst_sources = [227, 482, 546, 1186]
# Random initial selection, currently disabled.
# NOTE: np.random.randint excludes its upper bound, so randint(0,3) only yields
# indices 0-2 and would never pick the last source; use len(jwst_sources) as
# the upper bound if this line is re-enabled.
#selected_src = jwst_sources[np.random.randint(0,3)]

def prep_data(source, path='/Users/bcherinka/Work/mosviz/data'):
    ''' Open the FITS files holding the test data for one source.

    Parameters
    ----------
    source : int or str
        Source identifier; it is formatted into the three file names.
    path : str, optional
        Root data directory that contains the ``jwst_level3_NIRSpec``
        folder.  Defaults to the original author's local data directory,
        so pass your own location when running on another machine.

    Returns
    -------
    tuple
        ``(hdu1d, hdu2d, cut)`` -- the objects returned by ``fits.open``
        for the ``x1d`` file, the ``s2d`` file and the image cutout, in
        that order.  They are returned open; closing them is up to the
        caller.
    '''
    path1d = os.path.join(path, 'jwst_level3_NIRSpec/jw95065_{0}_nrs_msaspec_barshadow_x1d.fits'.format(source))
    path2d = os.path.join(path, 'jwst_level3_NIRSpec/jw95065_{0}_nrs_msaspec_barshadow_s2d.fits'.format(source))
    pathimg = os.path.join(path, 'jwst_level3_NIRSpec/cutouts/SOURCEID_{0}.fits'.format(source))
    hdu1d = fits.open(path1d)
    hdu2d = fits.open(path2d)
    cut = fits.open(pathimg)
    return hdu1d, hdu2d, cut

In [None]:
# get the data for the selected source
# hdu1d, hdu, cut = prep_data(selected_src)

In [None]:
def get_xy(data):
    ''' Pull the x and y values of a spectrum out of ``data``.

    ``data`` must be indexable by the keys ``'WAVELENGTH'`` and ``'FLUX'``.
    Returns the pair ``(wavelength, flux)``.
    '''
    return data['WAVELENGTH'], data['FLUX']

# # get the x and y values for the 1d spectrum plot
# wave, flux = get_xy(hdu1d[1].data)

In [None]:
def create_spec1d(x, y):
    ''' Build a bqplot figure titled 'Spectrum 1d' with ``y`` plotted against ``x``.

    Returns ``(figure, line)``: the pyplot figure and the object returned
    by ``plt.plot``.
    '''
    figure = plt.figure(title='Spectrum 1d')
    line = plt.plot(x, y)

    # let the figure size itself, but keep a minimum height so it shows
    # nicely in the notebook
    layout_settings = (
        ('width', 'auto'),
        ('height', 'auto'),
        ('min_height', '400px'),
    )
    for attribute, value in layout_settings:
        setattr(figure.layout, attribute, value)

    return figure, line

# # create the 1d spectrum plot
# fig, plot = create_spec1d(wave, flux)
# fig

In [None]:
def create_spec2d_heatmap(data):
    ''' Show a 2d spectrum as a bqplot heatmap on a reversed 'Greys' colour scale.

    ``data`` is first passed through astropy's ``PercentileInterval(95)``;
    the result is what gets drawn.  Returns ``(figure, heatmap)``.
    '''
    # apply the 95% percentile interval to the data
    limited = PercentileInterval(95)(data)

    # new figure, then attach the colour scale and the heatmap to it
    figure = plt.figure(padding_y=0)
    grey_scale = ColorScale(scheme='Greys', reverse=True)
    plt.scales(scales={'color': grey_scale})

    # the colour axis is hidden
    heatmap = plt.heatmap(limited, axes_options={'color': {'visible': False}})
    return figure, heatmap

In [None]:
def create_spec2d_image(data):
    ''' create a 2d spectrum view as a bqplot ImageGL

    Parameters
    ----------
    data : array-like
        The image to display; it is handed to ``ImageGL`` as-is.

    Returns
    -------
    Figure
        A bqplot ``Figure`` with x/y axes whose only mark is the image.
    '''
    # the figure axes and the image share the same pair of linear scales
    scale_x = LinearScale()
    scale_y = LinearScale()
    scales = {'x': scale_x,
              'y': scale_y}
    axis_x = Axis(scale=scale_x, label='x')
    axis_y = Axis(scale=scale_y, label='y', orientation='vertical')

    spec2d = Figure(scales=scales, axes=[axis_x, axis_y])

    # colour mapping for the pixel values, fixed to the range 0-3
    scales_image = {'x': scale_x,
                    'y': scale_y,
                    'image': ColorScale(scheme='viridis', min=0, max=3)}

    # draw the array supplied by the caller rather than a notebook-global HDU
    # list, so the function does not depend on hidden kernel state
    image = ImageGL(image=data, scales=scales_image)

    spec2d.marks = (image,)
    return spec2d

In [None]:
# # create the 2d spectra image as a heatmap
# spec2d, heat = create_spec2d_heatmap(hdu[1].data)
# spec2d

In [None]:
def create_cutout(data, target=None):
    ''' Draw an image cutout as a bqplot heatmap on a reversed 'Greys' colour scale.

    The figure is titled ``'Src <target>'``, laid out at 500px x 500px, and
    has its aspect ratio pinned to ``data.shape[1] / data.shape[0]``.  All
    axes (x, y and colour) are hidden.  Returns ``(figure, heatmap)``.
    '''
    # pin min and max to the same value so the aspect ratio is fixed
    ratio = data.shape[1] / data.shape[0]
    figure_options = dict(
        title='Src {0}'.format(target),
        layout=Layout(width='500px', height='500px'),
        min_aspect_ratio=ratio,
        max_aspect_ratio=ratio,
        padding_y=0,
    )
    figure = plt.figure(**figure_options)

    plt.scales(scales={'color': ColorScale(scheme='Greys', reverse=True)})

    # hide every axis
    hidden_axes = {name: {'visible': False} for name in ('x', 'y', 'color')}
    heatmap = plt.heatmap(data, axes_options=hidden_axes)
    return figure, heatmap

In [None]:
# # create the image cutout
# img, image = create_cutout(cut[0].data, target=selected_src)
# img

In [None]:
# @out.capture()
# def update_data(widget, event, data):
#     # get new data from selected target
#     selected_src = int(data)
#     hdu1d, hdu, cut = prep_data(selected_src)

#     # update spec1d data
#     wave, flux = get_xy(hdu1d[1].data)
#     plot.y = flux
#     plot.x = wave

#     # update spec2d data
#     i = PercentileInterval(95)
#     heat.color = i(hdu[1].data)

#     print('selected source', selected_src)
#     print('old image title', img.title)

#     # update image cutout data
#     image.color = cut[0].data
#     img.title = 'Src {0}'.format(selected_src)
    
#     print('new image title', img.title)


In [None]:
@out.capture()
def update_data(selected_src):
    ''' Rebuild all three views for ``selected_src`` and show them in ``row``.

    Opens the FITS files for the source, builds the image cutout, the 2d
    spectrum heatmap and the 1d spectrum figure, then replaces the children
    of the notebook-level ``row`` widget with one column per figure.  The
    FITS files are closed again before returning, so repeated selections do
    not accumulate open file handles.  Anything printed here is captured by
    the ``out`` debugging widget.
    '''
    hdu1d, hdu, cut = prep_data(selected_src)
    try:
        # 1d spectrum line plot
        wave, flux = get_xy(hdu1d[1].data)
        fig, plot = create_spec1d(wave, flux)
        print('plot info', len(plot.x), len(plot.y))

        # 2d spectrum heatmap
        spec2d, heat = create_spec2d_heatmap(hdu[1].data)
        print('selected source', selected_src)

        # image cutout
        img, image = create_cutout(cut[0].data, target=selected_src)
        print('new image title', img.title)

        row.children = [
            # image cutout
            v.Col(xs12=True, lg6=True, xl4=True, children=[
                img
            ]),

            # 2d spectrum heatmap
            v.Col(xs12=True, xl4=True, children=[
                spec2d
            ]),

            # 1d spectrum line plot
            v.Col(xs12=True, xl4=True, children=[
                fig
            ]),
        ]
    finally:
        # The data arrays were already read above; release the file handles
        # whether or not building the figures succeeded.
        for hdulist in (hdu1d, hdu, cut):
            hdulist.close()

In [None]:
# this class is needed to allow front-end defined components to send events to notebook defined components
# class UpdateRow(v.Row):
#     def __init__(self, **kwargs):
#         super().__init__(**kwargs)
#         self.on_event('change', update_data)    

In [None]:
# # create the Vuetify layout using v.Row and v.Col (Vuetify API v2.0 spec - replaces v.Layout and v.Flex, respectively)
# row = v.Row(_metadata={'mount_id': 'protospec'}, dense=True, row=True, wrap=True, align_center=True, children=[

#     # load the histogram and slider content
#     v.Col(xs12=True, lg6=True, xl4=True, children=[
#         img
#     ]),

#     # load the line plot content
#     v.Col(xs12=True, xl4=True, children=[
#         spec2d
#     ]),
#     # load the line plot content
#     v.Col(xs12=True, xl4=True, children=[
#         fig
#     ]),       
# ])

# # attach a change event to the Layout so that it updates the data of all three components
# row.on_event('change', update_data)

In [None]:
@out.capture()
def update_data_handler(widget, event, data):
    ''' 'change' event callback: reload every view for the newly chosen target.

    ``data`` is the event payload; it is converted with ``int`` and passed
    to ``update_data``.  ``widget`` and ``event`` are part of the callback
    signature but are not used.  Printed output goes to the ``out`` widget.
    '''
    source_id = int(data)
    print('selected_src', source_id)
    update_data(source_id)

In [None]:
# Container for the three figures. It starts with a 'Loading..' placeholder;
# update_data() replaces its children once a target has been selected.
row = v.Row(_metadata={'mount_id': 'protospec'}, dense=True, row=True, wrap=True, align_center=True, children=['Loading..'])
# Reload the views when a 'change' event arrives on the row itself.
# NOTE(review): presumably this lets a front-end defined component trigger the
# update -- confirm against the page template.
row.on_event('change', update_data_handler)

In [None]:
# create an app bar holding a select box for choosing the target source
btn1 = v.Select(dense=True, outlined=True, label='Target', items=jwst_sources)
# choosing a target reloads all three views
btn1.on_event('change', update_data_handler)
# the AppBar is the last expression of the cell, so it is what gets displayed
v.AppBar(_metadata={'mount_id': 'appbar'}, app=True, absolute=True, children=[
    v.Col(class_='col-md-2 col-lg-2', children=[btn1])
])