In [1]:
import json
import time
from functools import partial

import bqplot as bq
import ipywidgets as ipyw
import numpy as np
import pandas as pd
import traitlets as trt
from bqplot.interacts import (
    BrushSelector,
    BrushIntervalSelector,
)

from regression_callable import run_neural_net

# Load the serialized model description. The context manager guarantees the
# file handle is closed even if json.load raises on malformed input.
with open('nn_regression.json', 'r') as f:
    model_data = json.load(f)

# Mapping of input dimension name -> bound; the sampling cell below reads
# bound[0] and bound[1] as the interval ends.
bounds = model_data['bounds']

In [2]:
# Draw N_POINTS uniformly distributed samples for each bounded dimension and
# collect them in a dataframe (one column per dimension).
N_POINTS = 1000  # number of samples per dimension
SEED = 0         # fixed seed so Restart & Run All reproduces the same samples

# A local Generator keeps the draw reproducible without touching numpy's
# global random state.
rng = np.random.default_rng(SEED)
points = {
    dimension: rng.uniform(bound[0], bound[1], N_POINTS)
    for dimension, bound in bounds.items()
}

df = pd.DataFrame(points)

In [3]:
# Compute the target column by evaluating the model on the sampled points.
# `partial` (imported in the top import cell) pins the loaded model so the
# callable only needs the per-dimension arrays, passed here as keyword
# arguments named after each dimension.
neural_net = partial(run_neural_net, model=model_data)
x = neural_net(**points)
df['target'] = x

In [16]:
class InteractivePlot(bq.figure.Figure):
    """Base bqplot figure with a brush interaction.

    This class is not usable on its own: ``__init__`` reads ``self.points``
    (the mark to draw), which only the subclasses declare, and ``brush`` is
    an untyped placeholder here that subclasses replace with a concrete
    selector.
    """
    # Frame the plotted columns are read from.
    data: pd.DataFrame = trt.Instance(pd.DataFrame, allow_none=True, args=())
    # 'x' / 'y' scales. Built lazily per instance (see ``_make_scales``) so
    # that no scale widget is created at class-definition time or shared
    # between figures.
    scales: dict = trt.Dict()
    # 'x' / 'y' axes built from ``scales``.
    axes_dict: dict = trt.Dict()
    # Column names plotted on each axis.
    x_var: str = trt.Unicode(default_value=None, allow_none=True)
    y_var: str = trt.Unicode(default_value='', allow_none=True)
    brush = trt.Any()
    # Variable name -> interval selected with the brush (set by subclasses).
    selected_indices: dict = trt.Dict(default_value={})
    # Not read anywhere in this class; the default is a dict to match the
    # trait type (it used to be a list literal).
    filters: dict = trt.Dict(default_value={})

    @trt.default("scales")
    def _make_scales(self):
        """Create a fresh pair of linear scales for this figure."""
        return {'x': bq.LinearScale(), 'y': bq.LinearScale()}

    @trt.default("axes_dict")
    def _make_axes(self):
        """Build the x and y axes on the current scales, labelled with the variable names."""
        return {
            'x': bq.Axis(scale=self.scales['x'], label=self.x_var),
            'y': bq.Axis(scale=self.scales['y'], orientation='vertical', label=self.y_var),
        }

    def __init__(self, *args, **kwargs):
        """Create the figure, then install the brush, the mark and the layout.

        The assignments below run after the parent constructor, so they
        override any ``fig_margin`` or ``layout`` passed as keyword
        arguments; callers adjust those after construction instead.
        """
        super().__init__(*args, **kwargs)
        self.min_aspect_ratio = 1
        self.max_aspect_ratio = 3
        self.interaction = self.brush
        # ``points`` is declared by the subclasses.
        self.marks = [self.points]
        self.fig_margin={'top':0, 'bottom':0, 'left':0, 'right':0}
        self.layout = ipyw.Layout(width="100%", height="100%", margin="0")

    @trt.observe("data")
    def _data_updated(self, *_):
        """Hook called when ``data`` changes; currently does nothing."""
        pass
        
class ScatterPlot(InteractivePlot):
    """Scatter plot of ``data[x_var]`` against ``data[y_var]`` with a 2-D brush."""
    brush: BrushSelector = trt.Instance(BrushSelector)
    points: bq.marks.ScatterGL = trt.Instance(bq.marks.ScatterGL, allow_none=True, args=())

    @trt.default("brush")
    def _make_default_brush(self):
        """Create the 2-D brush on the current scales and watch its ``brushing`` flag."""
        brush = BrushSelector(x_scale=self.scales['x'], y_scale=self.scales['y'])
        brush.observe(self.update_brushing, "brushing")
        return brush

    def update_brushing(self, change):
        """Publish the brushed x/y intervals once the brush stops moving.

        Nothing is published while the brush is still being dragged or while
        the mark has no selection.
        """
        if self.brush.brushing is False:
            if self.points.selected is not None:
                self.selected_indices = {self.y_var:change['owner'].selected_y, self.x_var:change['owner'].selected_x}

    @trt.observe('x_var', 'y_var')
    def update_plot(self, *_):
        """Rebuild scales, axes and mark data after ``x_var`` or ``y_var`` changes."""
        self.xdata = self.data[self.x_var]
        self.ydata = self.data[self.y_var]
        # The dict is mutated in place, so no trait notification is emitted
        # for ``scales``; everything that uses the scales is updated below.
        self.scales['x'] = bq.LinearScale()
        self.scales['y'] = bq.LinearScale()

        self.axes_dict = self._make_axes()
        self.points.x = self.xdata
        self.points.y = self.ydata
        self.points.scales = self.scales
        self.points.opacity = [.3]*len(self.xdata)
        self.points.size = [.3]*len(self.xdata)
        self.points.selected_style={'opacity':'1'}
        self.points.unselected_style={'opacity':'.2'}
        self.axes = list(self.axes_dict.values())

        # The brush is created once, on the scales that exist at that moment.
        # This method runs again whenever a variable changes, so the brush
        # must be moved onto the new scales or it stays attached to scales
        # that no mark or axis uses any more.
        self.brush.x_scale = self.scales['x']
        self.brush.y_scale = self.scales['y']
        self.brush.marks=[self.points]


class Hist(InteractivePlot):
    """Histogram of ``data[x_var]`` with a 1-D interval brush."""
    brush: BrushIntervalSelector = trt.Instance(BrushIntervalSelector)
    points: bq.marks.Hist = trt.Instance(bq.marks.Hist, allow_none=True, args=())

    @trt.default("brush")
    def _make_default_brush(self):
        """Create the interval brush on the current x scale and watch its ``brushing`` flag."""
        brush = BrushIntervalSelector(scale=self.scales['x'])
        brush.observe(self.update_brushing, "brushing")
        return brush

    def update_brushing(self, change):
        """Publish the brushed interval once the brush stops moving.

        Nothing is published while the brush is still being dragged or while
        the mark has no selection.
        """
        if self.brush.brushing is False:
            if self.points.selected is not None:
                self.selected_indices = {self.x_var:change['owner'].selected}

    @trt.observe('x_var')
    def update_plot(self, *_):
        """Rebuild the scale, axis and histogram mark after ``x_var`` changes."""
        self.xdata = self.data[self.x_var]
        x_scale = bq.LinearScale()
        # Mutated in place: no trait notification is emitted for ``scales``.
        self.scales['x'] = x_scale
        self.axes_dict = self._make_axes()
        self.points = bq.marks.Hist(sample=self.xdata, scales={'sample':x_scale, 'count':bq.LinearScale()})
        # A new mark object is created on every call, so the figure has to be
        # pointed at it explicitly; otherwise a later change of ``x_var``
        # would leave the previous histogram on screen.
        self.marks = [self.points]
        self.axes = [self.axes_dict['x']]
        # Pin the x range to the data extent. The vectorised reductions avoid
        # iterating the column in Python, and float() hands the scale a plain
        # float whatever the column's numeric dtype.
        x_scale.max = float(self.xdata.max())
        x_scale.min = float(self.xdata.min())
        # The brush keeps the scale it was created with, so move it onto the
        # scale that was just built before attaching the new mark.
        self.brush.scale = x_scale
        self.brush.marks = [self.points]

In [29]:
class ScatterPlotMatrix(ipyw.GridBox):
    """Lower-triangular scatter-plot matrix with brush-driven range filters.

    The diagonal holds a ``Hist`` per variable, the cells below it hold a
    ``ScatterPlot`` per variable pair, and the cells above it are empty
    placeholders. Brushing a plot narrows ``bounds``; ``filtered_data`` is
    recomputed from ``data`` whenever ``bounds`` changes and is pushed to
    every plot's ``data`` trait through a one-way link.
    """
    # Variable name -> (low, high) range currently applied as a filter.
    bounds: dict = trt.Dict(default_value={})
    # Full, unfiltered frame.
    data: pd.DataFrame = trt.Instance(pd.DataFrame, args=())
    # Rows of ``data`` that fall inside every range in ``bounds``.
    filtered_data: pd.DataFrame = trt.Instance(pd.DataFrame, args=())

    def __init__(self, data, *args, **kwargs):
        """Build the grid of plots for ``data`` and wire up the filtering.

        Parameters
        ----------
        data : pd.DataFrame
            Frame to plot; one candidate variable per column.
        *args, **kwargs
            Forwarded to ``ipywidgets.GridBox``.
        """
        super().__init__(*args, **kwargs)
        self.data = data
        self.vars = list(data.columns)
        #### just for testing: only every third column is plotted #####
        self.vars = self.vars[::3]
        #### ################ #####

        # button for reset
        reset_button = ipyw.Button(description="reset filters")
        reset_button.on_click(self.reset_bounds)

        self.plots = []
        children = []
        for i, var1 in enumerate(self.vars):
            for j, var2 in enumerate(self.vars):
                if j > i:
                    # Upper triangle: empty placeholder keeps the grid aligned.
                    children.append(ipyw.VBox())
                    continue
                plot = self._make_plot(data, i, j, var1, var2)
                children.append(plot)
                self.plots.append(plot)

        # Default bounds span the full data range of every plotted variable.
        self.default_bounds = {
            var: (data[var].min(), data[var].max()) for var in self.vars
        }
        # Assign a copy: the trait would otherwise hold the very same dict
        # object as ``default_bounds`` and any later update would silently
        # change the defaults as well.
        self.bounds = dict(self.default_bounds)
        for plot in self.plots:
            plot.observe(self.update_bounds, "selected_indices")
            trt.dlink((self, "filtered_data"), (plot, "data"))
        self.children = children + [reset_button]

        # Size the grid from the number of variables. The first column and
        # the last row get a little extra room because they carry the axis
        # tick labels. For four variables this yields '28% 24% 24% 24%' and
        # '24% 24% 24% 28%'.
        n_vars = len(self.vars)
        self.layout = ipyw.Layout(
            width='100%',
            grid_template_columns=self._grid_template(n_vars, 0),
            grid_template_rows=self._grid_template(n_vars, n_vars - 1),
            grid_gap='0px 0px',
            height='100%',
        )

    @staticmethod
    def _grid_template(n_tracks, wide_index, extra=4.0):
        """Return a CSS track list of ``n_tracks`` percentages summing to 100.

        The track at ``wide_index`` is ``extra`` percentage points wider than
        the others. Returns ``None`` when there are no tracks.
        """
        if n_tracks <= 0:
            return None
        base = (100.0 - extra) / n_tracks
        widths = [
            base + extra if k == wide_index else base for k in range(n_tracks)
        ]
        return ' '.join(f"{width:g}%" for width in widths)

    @staticmethod
    def _hide_axis(axis):
        """Hide an axis' tick labels and title (interior cells of the matrix)."""
        axis.tick_style = {'font-size':0}
        axis.label = ''

    def _make_plot(self, data, i, j, row_var, col_var):
        """Create and decorate the plot for grid cell (row ``i``, column ``j``).

        Diagonal cells get a histogram of ``col_var``; cells below the
        diagonal get a scatter plot of ``row_var`` against ``col_var``.
        """
        last_row = len(self.vars) - 1
        if i == j:
            plot = Hist(data=data, x_var=col_var)
        else:
            plot = ScatterPlot(data=data, x_var=col_var, y_var=row_var)

        # Only the bottom row shows x tick labels.
        if i != last_row:
            self._hide_axis(plot.axes[0])
        # Only the first column shows y tick labels. A plot with a single
        # axis (the histograms) has no y axis to hide.
        if j != 0 and len(plot.axes) > 1:
            self._hide_axis(plot.axes[1])

        # Leave room for the tick labels that remain visible.
        plot.fig_margin = {
            'top':0,
            'bottom':30 if i == last_row else 0,
            'left':60 if j == 0 else 0,
            'right':0,
        }
        return plot

    def reset_bounds(self, *_):
        """Restore the full-range bounds (callback of the reset button).

        Assigning a new dict fires the ``bounds`` observer, which recomputes
        ``filtered_data``.
        """
        self.bounds = dict(self.default_bounds)

    def update_bounds(self, change):
        """Narrow ``bounds`` from a plot's new brush selection.

        ``change['owner']`` is the plot whose ``selected_indices`` changed.
        The selection is applied only if it tightens the current range of at
        least one variable; variables without a usable interval are ignored
        and keep their current bounds.
        """
        selection = change['owner'].selected_indices
        intervals = {}
        for var, interval in selection.items():
            # A cleared brush reports no interval for the variable.
            if interval is None or len(interval) < 2:
                continue
            intervals[var] = (float(min(interval)), float(max(interval)))

        current = self.bounds
        narrowed = any(
            low > current[var][0] or high < current[var][1]
            for var, (low, high) in intervals.items()
        )
        if narrowed:
            # Build a new dict rather than mutating the current one, so the
            # stored defaults stay intact and the ``bounds`` observer fires.
            self.bounds = {**current, **intervals}

    @trt.observe("bounds")
    def _update_plots(self, reset=None, *_):
        """Recompute ``filtered_data`` from ``data`` and the current ``bounds``.

        Runs as the ``bounds`` observer, in which case ``reset`` receives the
        change object; the argument is not used.
        """
        keep = np.ones(len(self.data), dtype=bool)
        for var in self.vars:
            low, high = self.bounds[var][0], self.bounds[var][1]
            column = self.data[var]
            keep &= ((column >= low) & (column <= high)).to_numpy()
        self.filtered_data = self.data[keep]

In [30]:
# Build the plot matrix from the sampled frame (inputs plus the `target` column).
y = ScatterPlotMatrix(data=df)

           s6        s1        s5        s3        bp       sex       age  \
0   -0.133660  0.088087  0.046904  0.002966  0.010447 -0.010891  0.069047   
1   -0.035847 -0.032317  0.106514  0.165962  0.076178 -0.010248 -0.088393   
2    0.123674 -0.033558 -0.081485  0.127373  0.096450 -0.031086  0.091679   
3    0.028077  0.056038 -0.009792  0.063629  0.012026 -0.011794  0.102296   
4    0.076679 -0.004720 -0.031120 -0.043431 -0.034116 -0.026502 -0.045997   
..        ...       ...       ...       ...       ...       ...       ...   
995 -0.060683 -0.109912 -0.009358  0.103041 -0.087415  0.002105  0.090132   
996  0.010765  0.100682  0.091697  0.088698  0.097804 -0.012271 -0.021411   
997 -0.004750 -0.122496  0.093534 -0.016078 -0.056651  0.010000 -0.081203   
998  0.023128  0.022307  0.004470  0.136345 -0.052627  0.014907  0.038251   
999  0.127445 -0.016749 -0.106671 -0.002123 -0.060178  0.025048 -0.075558   

           s4       bmi        s2      target  
0    0.121270  0.104417  0.

In [31]:
# Display the matrix widget as the cell output.
y

ScatterPlotMatrix(children=(Hist(axes=[Axis(scale=LinearScale(max=0.13514865037440885, min=-0.1372018765475934…

In [490]:
# Debugging aid: grab the mark drawn by the first grid cell for inspection.
mark = y.children[0].points

In [491]:
# Show the mark's repr (its sample, counts, scales and selection).
mark

Hist(colors=['steelblue'], count=array([1., 1., 1., 0., 0., 0., 0., 0., 0., 2.]), interactions={'hover': 'tooltip'}, midpoints=[0.0022860421329527576, 0.008142460208800412, 0.013998878284648067, 0.01985529636049572, 0.025711714436343378, 0.03156813251219103, 0.03742455058803869, 0.04328096866388634, 0.049137386739734], sample=array([ 0.05424654,  0.05792201,  0.0075429 , -0.00064217,  0.01270929]), scales={'sample': LinearScale(), 'count': LinearScale()}, scales_metadata={'sample': {'orientation': 'horizontal', 'dimension': 'x'}, 'count': {'orientation': 'vertical', 'dimension': 'y'}}, selected=array([0, 1, 2, 3, 4], dtype=uint32), tooltip_style={'opacity': 0.9})

In [500]:
# Debugging aid: manually override the range of the scale behind the first
# grid cell's first axis.
y.children[0].axes[0].scale.min = -.14
y.children[0].axes[0].scale.max = .14

In [None]:
# formatting
# put into subclasses
# TRAITLETS TRAITLETS TRAITLETS
# vaex? 