In [None]:
import logging
# Send every record that reaches the root logger (including records that
# propagate up from library loggers) to ipynb.log at DEBUG level.
format_ = ("%(asctime)s %(name)s-%(levelname)s "
           "[%(pathname)s %(lineno)d] %(message)s")
formatter = logging.Formatter(format_)
# basicConfig is a no-op once the root logger has handlers; on a first run
# the stream handler it installs is detached again just below.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
# Close each handler we detach (instead of handlers.clear()) so re-running
# this cell does not leak an open file handle on ipynb.log every time.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()
handler = logging.FileHandler(filename='ipynb.log', encoding='utf-8', mode='a+')
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
logger.debug("Jupyter logging configured.")

In [None]:
import astropy
from astroquery.sdss import SDSS

from bokeh.io import curdoc
from bokeh.models import CategoricalColorMapper, ColumnDataSource,\
    CustomJS, LassoSelectTool, BoxSelectTool, Range1d, ResetTool
from bokeh.models.formatters import NumeralTickFormatter
from bokeh.models.selections import Selection
from bokeh.plotting import figure
from bokeh.events import Reset

import ipyaladin as ipyal

import ipywidgets as widgets

import numpy as np

import pandas as pd

import astropixie_widgets  as aww
# Notebook-specific initialisation provided by astropixie_widgets; what it
# configures lives in that package (NOTE(review): presumably Bokeh/widget
# output wiring for Jupyter — confirm in aww.config).
aww.config.setup_notebook()


class SHRD:
    """
    Skyviewer and HR Diagram Widget.

    Pairs an ipyaladin sky viewer with a Bokeh H-R diagram built from the
    same SDSS catalog and mirrors selections between the two views:

    * ``__init__`` builds the Aladin widget and queries SDSS for the catalog.
    * ``show`` displays Aladin and hands ``_hr_diagram_select`` to
      ``aww.visual.show`` as the Bokeh document builder.
    * Selecting points on the H-R plot calls ``_hr_selection``, which pushes
      the matching ``objID`` strings to ``aladin.selection_ids``.
    * ``meta_selection_update`` is assigned as ``aladin.selection_update``
      (presumably invoked by ipyaladin on sky selections — confirm) and
      schedules ``_skyviewer_selection`` on the Bokeh document's next tick.
    """
    aladin = None         # ipyaladin.Aladin widget, built by _skyviewer()
    pf = None             # Bokeh figure holding the H-R diagram
    doc = None            # Bokeh document the figure was added to
    region = None         # aww.data.SDSSRegion wrapping a copy of self.cat
    selection_ids = None  # last object IDs reported by the sky viewer
    cat = None            # SDSS query result, set by _catalog()
    session = None        # glyph renderer for the H-R scatter points

    def __init__(self):
        self._skyviewer()
        self._catalog()

    def _skyviewer(self):
        """Create the Aladin sky viewer centred on Berkeley 20 and return it."""
        self.aladin = ipyal.Aladin(target='Berkeley 20', fov=0.42, survey='P/SDSS9/color')
        # Hide all UI chrome except the catalog overlay.
        self.aladin.show_reticle = False
        self.aladin.show_zoom_control = False
        self.aladin.show_fullscreen_control = False
        self.aladin.show_layers_control = False
        self.aladin.show_goto_control = False
        self.aladin.show_share_control = False
        self.aladin.show_catalog = True
        self.aladin.show_frame = False
        self.aladin.show_coo_grid = False
        return self.aladin

    def _catalog(self):
        """
        Query SDSS for clean point sources near (83.154, 0.188) and return them.

        The result is stored on ``self.cat`` and, when the sky viewer exists,
        added to it as a table overlay.
        """
        query = """
SELECT TOP 3200
       p.objID,
       p.ra,
       p.dec,
       p.u,
       p.g,
       p.r,
       p.i,
       p.z
FROM PhotoPrimary AS p
JOIN dbo.fGetNearbyObjEq(83.15416667, 0.18833333, 3.24) AS r ON r.objID = p.objID
WHERE p.clean = 1 and p.probPSF = 1
"""
        self.cat = SDSS.query_sql(query)
        if self.aladin:
            self.aladin.add_table(self.cat)
        return self.cat

    def _hr_diagram_select(self, doc):
        """
        Build the H-R diagram into the Bokeh document ``doc``.

        Wraps a copy of the catalog in an ``aww.data.SDSSRegion``, plots
        luminosity (log axis) against effective temperature, and wires the
        selection callbacks in both directions between plot and sky viewer.
        """
        self.region = aww.data.SDSSRegion(self.cat.copy())
        temps = aww.science.teff(self.region)
        lums = aww.science.luminosity(self.region)
        ids = self.region.ids()
        x, y = temps, lums
        colors, color_mapper = aww.visual.hr_diagram_color_helper(temps)
        # Range runs from high to low temperature (hot on the left), each end
        # padded by 5% of its own value.
        x_range = [max(x) + max(x) * 0.05, min(x) - min(x) * 0.05]
        source = ColumnDataSource(data=dict(x=x, y=y, id=ids, color=colors), name='hr')
        self.pf = figure(y_axis_type='log', x_range=x_range,
                         tools='lasso_select,box_select,reset',
                         title='H-R Diagram for {0}'.format(self.region.name))
        # Compute the lasso selection on release, not on every mouse move.
        self.pf.select(LassoSelectTool).select_every_mousemove = False
        self.session = self.pf.circle(x='x', y='y', source=source,
                                      size=4,
                                      color={'field': 'color',
                                             'transform': color_mapper},
                                      alpha=1, name='hr',
                                      line_color='#444444', line_width=0.5)
        self.pf.xaxis.axis_label = 'Temperature (Kelvin)'
        self.pf.yaxis.axis_label = 'Luminosity (solar units)'
        self.pf.yaxis.formatter = NumeralTickFormatter()
        doc.add_root(self.pf)

        def reset_(event):
            # Only logs for now; selection state is not cleared on reset.
            logger.debug('reset!')

        self.doc = doc
        self.aladin.selection_update = self.meta_selection_update
        self.session.data_source.on_change('selected', self._hr_selection)
        self.pf.on_event(Reset, reset_)

    def _hr_selection(self, attr, old, new):
        """
        Push the object IDs selected on the H-R plot to the sky viewer.

        ``new`` may be a Bokeh ``Selection``-like object exposing ``indices``
        or the legacy ``{'1d': {'indices': [...]}}`` mapping. Lookup failures
        are logged and leave the sky viewer untouched.
        """
        indices = getattr(new, 'indices', None)
        if indices is None:
            indices = new['1d']['indices']
        # Force an integer dtype: an empty list would otherwise become a
        # float64 array, which np.take rejects as indices.
        inds = np.asarray(indices, dtype=int)
        try:
            selection_ids = np.take(self.region.cat['objID'], inds)
        except Exception as e:
            logger.warning(e)
            return
        self.aladin.selection_ids = [str(s) for s in selection_ids]

    def show(self):
        """Display Aladin, overlay the catalog, and render the H-R diagram."""
        try:
            widgets.widget.display(self.aladin)
            self.aladin.add_table(self.cat)
            aww.visual.show(self._hr_diagram_select)
        except Exception:
            logger.exception('Unable to show SHRD widgets.')

    def _selected_ids(self, selection_ids=None):
        """Return ``selection_ids`` (default ``self.selection_ids``) as np.int64."""
        chosen = self.selection_ids if selection_ids is None else selection_ids
        return [np.int64(i) for i in chosen]

    def _region_frame(self):
        """Rebuild the region from ``self.cat`` as a DataFrame indexed by ``id``."""
        region = type(self.region)(self.cat.copy())
        arr = region.to_array()
        columns = [d[0] for d in region._dtype]
        return pd.DataFrame(arr.flatten(), index=arr['id'].flatten(),
                            columns=columns)

    def _filter_selection(self, selection_ids=None):
        """
        Return ``(temps, lums, ids)`` for the WHOLE region.

        The data is deliberately not reduced to the selection:
        ``_skyviewer_selection`` keeps every point and marks the selection
        through row indices from ``_filter_selection_indices``.
        ``selection_ids`` is accepted for interface compatibility only.
        NOTE(review): assumes ``to_array()`` keeps the same row order as
        ``teff``/``luminosity`` — confirm in aww.data.
        """
        temps = aww.science.teff(self.region)
        lums = aww.science.luminosity(self.region)
        return temps, lums, self._region_frame()['id']

    def _filter_selection_indices(self, selection_ids=None):
        """Return row positions whose ``id`` is in the selection."""
        wanted = self._selected_ids(selection_ids)
        df = self._region_frame()
        return list(np.where(df['id'].isin(wanted))[0])

    def _skyviewer_selection(self):
        """
        Reflect the sky-viewer selection on the H-R plot.

        Replaces the plot's data source with one carrying a ``Selection`` of
        the rows matching ``self.selection_ids`` and re-attaches the
        ``_hr_selection`` callback. Errors are logged, not raised.
        """
        try:
            if not self.pf:
                logger.warning('Figure does not exist.')
                return
            selected = self.pf.select(name='hr')
            if not selected:
                return
            new_temps, new_lums, new_ids = self._filter_selection(self.selection_ids)
            indices = self._filter_selection_indices(self.selection_ids)
            colors, _ = aww.visual.hr_diagram_color_helper(new_temps)
            selection = Selection(indices=indices)
            # Same column names as the source built in _hr_diagram_select.
            new_source = ColumnDataSource(
                data=dict(x=new_temps, y=new_lums, id=new_ids, color=colors),
                selected=selection, name='hr')
            target = selected[0]
            if isinstance(target, ColumnDataSource):
                selected_old = target.selected
                self.session.data_source = new_source
            elif target:
                selected_old = target.data_source.selected
                target.data_source = new_source
            else:
                logger.warning('No H-R diagram model to update.')
                return
            self.session.data_source.trigger('selected', selected_old, selection)
            self.session.data_source.on_change('selected', self._hr_selection)
        except Exception as e:
            logger.warning(e)

    def meta_selection_update(self, selection_ids):
        """Store sky-viewer selection IDs and schedule the plot update."""
        self.selection_ids = selection_ids
        self.doc.add_next_tick_callback(self._skyviewer_selection)

try:
    shrd = SHRD()
    shrd.show()
except Exception:
    # Keep the cell best-effort, but record the full traceback instead of a
    # DEBUG-level one-liner that is easy to miss in ipynb.log.
    logger.exception('Failed to build or show the SHRD widget.')

In [None]:
import astropixie_widgets as aww
aww.config.setup_notebook()

try:
    shrd = aww.visual.SHRD()
    shrd.show()
except Exception:
    # Best-effort like the cell above, but keep the traceback in the log.
    logger.exception('Failed to build or show aww.visual.SHRD.')