In [1]:
    import gridplot

In [None]:
    # NOTE(review): scratch display cell — just shows the ``Plots`` attribute of
    # the ``gridplot`` module imported in the first cell. Its execution count is
    # empty (never run in this session); consider deleting it.
    gridplot.Plots

In [252]:
    # Public names exported by a star-import of this notebook/module.
    __all__ = 'label', 'Plots'

In [253]:
    import pandas as pd
    from bokeh import plotting, models, layouts
    from operator import gt, lt, eq
    from sklearn.cluster import KMeans
    import collections, itertools, operator
    from operator import eq, lt, gt    

In [254]:
    class DeepChainMap(collections.ChainMap):
        """ChainMap whose writes and deletes reach into nested maps.

        Differences from ``collections.ChainMap``:

        * assignment updates the first map that already holds the key and only
          falls back to ``maps[0]`` for new keys;
        * deletion removes the key from the first map that holds it;
        * ``new_child`` pushes the new map onto *this* instance and returns it,
          instead of returning a separate chain;
        * ``complement`` returns a chain over the same maps in reverse order.
        """

        def __setitem__(self, key, value):
            # First map that already knows the key wins; otherwise the front map.
            target = next((mapping for mapping in self.maps if key in mapping), self.maps[0])
            target[key] = value

        def __delitem__(self, key):
            target = next((mapping for mapping in self.maps if key in mapping), None)
            if target is None:
                raise KeyError(key)
            del target[key]

        def new_child(self, m=None):
            """Prepend ``m`` (or a new empty dict) to this chain in place; return self."""
            extended = super().new_child(m)
            self.maps = extended.maps
            return self

        def complement(self):
            """Return a new chain of the same type over the maps in reverse order."""
            return type(self)(*self.maps[::-1])

In [255]:
    def replaceRowCol(obj, **kwargs):
        """Fill ``{row}``/``{col}``-style placeholders in a glyph keyword spec.

        Strings are passed through ``str.format(**kwargs)``; dicts are rebuilt
        with every value processed recursively; any other object (numbers,
        lists, models, ...) is returned unchanged.
        """
        if isinstance(obj, str):
            return obj.format(**kwargs)
        if isinstance(obj, dict):
            return {key: replaceRowCol(value, **kwargs) for key, value in obj.items()}
        return obj

In [256]:
    class Plots(object):
        """A scatter-matrix style grid of bokeh figures, one per feature pair.

        ``sources`` is either a DataFrame or a mapping of data sources. A
        DataFrame is wrapped into a ``DeepChainMap`` whose key ``None`` resolves
        to a ``ColumnDataSource`` of the frame (front map) while the raw frame
        is kept in the back map. Per-feature summary sources for the diagonal
        are pushed onto the same chain by ``_prepare_diag_source``.

        A glyph is selected by attribute access (``plots.Circle``), which looks
        the name up on ``bokeh.models``; calling the object then adds that glyph
        to every figure, with ``{row}``/``{col}`` placeholders in the keyword
        arguments filled per figure.
        """

        def __init__(
            self, sources, features=None,
            figures=None, glyph=None, agg=pd.Series.describe, **kwargs
        ):
            """Build (or re-wrap) a grid of figures.

            Parameters
            ----------
            sources : DataFrame or mapping
                Data to plot; a DataFrame is wrapped as described on the class.
            features : list, optional
                Column names forming the grid axes. When empty/omitted, all
                columns of the underlying DataFrame are used.
            figures : mapping, optional
                Existing ``{(row, col): figure}`` mapping to reuse. Only when it
                is ``None`` is a fresh grid created via ``reset(**kwargs)`` — an
                empty mapping (e.g. the upper triangle of a 1-feature grid) is
                kept as-is.
            glyph : callable, optional
                Glyph model class used by ``__call__``.
            agg : callable
                Per-group aggregation used for the diagonal summary sources
                (default ``pd.Series.describe``).
            **kwargs
                Forwarded to ``reset`` and from there to every figure.

            Raises
            ------
            ValueError
                If ``features`` is empty/omitted and ``sources`` holds no
                DataFrame to take the column names from.
            """
            if isinstance(sources, pd.DataFrame):
                frame = sources
                sources = (
                    DeepChainMap({None: frame})
                    .new_child({None: plotting.ColumnDataSource(frame)}))
            self.sources = sources
            if not features:
                # Default to the columns of the frame actually supplied — never
                # to whatever global happens to be named ``df``.
                frame = self._dataframe()
                if frame is None:
                    raise ValueError(
                        "features must be given when sources has no DataFrame")
                features = list(frame.columns)
            self.features = features
            self.figures = figures
            self.glyph = glyph
            self.agg = agg
            self._prepare_diag_source(**kwargs)
            if self.figures is None:
                self.reset(**kwargs)

        def __getitem__(self, key, **kwargs):
            """Return a new grid restricted to the source(s) named by ``key``.

            ``key`` may be a single source key or a tuple of keys. The result
            shares this grid's features, figures, glyph and ``agg``.
            """
            keys = key if isinstance(key, tuple) else (key,)
            return type(self)(
                {k: self.sources[k] for k in keys},
                self.features, self.figures, self.glyph, self.agg, **kwargs)

        def reset(self, **kwargs):
            """Create a fresh, empty figure for every ``(row, col)`` cell.

            Figures are sorted into ``_diag`` (same position in ``features``),
            ``_lower`` (row after col) and ``_upper`` (row before col), and
            ``figures`` becomes a ``DeepChainMap`` over the three. Each figure
            is 200x200 unless overridden via ``kwargs``; diagonal figures get a
            datetime x axis when the source frame has a ``DatetimeIndex``.
            """
            self._diag, self._upper, self._lower = {}, {}, {}
            frame = self._dataframe()
            datetime_index = frame is not None and type(frame.index) is pd.DatetimeIndex
            position = self.features.index
            for row, col in self:
                i, j = position(row), position(col)
                bucket = self._diag if i == j else (self._lower if i > j else self._upper)
                args = dict(width=200, height=200)
                if row == col and datetime_index:
                    args.update(x_axis_type='datetime')
                # ``plotting.figure`` exists in every bokeh release, whereas the
                # capitalised ``Figure`` alias is deprecated/removed in newer ones.
                bucket[(row, col)] = plotting.figure(**{**args, **kwargs})
            self.figures = DeepChainMap(self._diag, self._upper, self._lower)

        def __iter__(self):
            """Yield every ``(row, col)`` feature pair in row-major order."""
            yield from itertools.product(self.features, repeat=2)

        def _triangle(self, name, cmp):
            """Return the figures of one grid region.

            Uses the dict cached by ``reset`` under ``name`` when present;
            otherwise (grid built from existing figures) selects the figures
            whose row/col positions in ``features`` satisfy ``cmp``. Never
            raises AttributeError, so the properties below cannot fall through
            to ``__getattr__``.
            """
            cached = self.__dict__.get(name)
            if cached is not None:
                return cached
            position = self.features.index
            return {
                (row, col): fig for (row, col), fig in self.figures.items()
                if cmp(position(row), position(col))}

        @property
        def diagonal(self):
            """The ``row == col`` figures as a ``Diagonal`` grid."""
            return Diagonal(
                self.sources, self.features, self._triangle('_diag', eq), agg=self.agg)

        @property
        def upper(self):
            """The figures above the diagonal (row before col)."""
            return type(self)(
                self.sources, self.features, self._triangle('_upper', lt), agg=self.agg)

        @property
        def lower(self):
            """The figures below the diagonal (row after col)."""
            return type(self)(
                self.sources, self.features, self._triangle('_lower', gt), agg=self.agg)

        def __getattr__(self, key):
            """Select ``bokeh.models.<key>`` as the glyph and return ``self``.

            Private and dunder names (``_diag``, ``__deepcopy__``, IPython
            ``_repr_*_`` probes, ...) are never glyphs; they raise a normal
            AttributeError instead of being looked up on ``bokeh.models``.
            """
            if key.startswith('_'):
                raise AttributeError(
                    "{!r} object has no attribute {!r}".format(type(self).__name__, key))
            self.glyph = getattr(models, key)
            return self

        def __call__(self, **kwargs):
            """Add the selected glyph to every figure and return ``self``.

            ``kwargs`` are passed to the glyph constructor after ``{row}`` and
            ``{col}`` placeholders are filled per figure. When ``sources`` has a
            ``None`` entry only that source is drawn; otherwise every source is.

            Raises
            ------
            TypeError
                If no glyph has been selected yet.
            """
            if self.glyph is None:
                raise TypeError(
                    "no glyph selected; choose one by attribute first, e.g. plots.Circle(...)")
            if None in self.sources:
                targets = [self.sources[None]]
            else:
                targets = list(self.sources.values())
            for source in targets:
                for (row, col), fig in self.figures.items():
                    fig.add_glyph(source, self.glyph(**replaceRowCol(kwargs, row=row, col=col)))
            return self

        def apply(self, func):
            """Call ``func(figure)`` for every figure in the grid; return ``self``."""
            for fig in self.figures.values():
                func(fig)
            return self

        def layout(self):
            """Arrange the figures into a ``len(features)``-column bokeh gridplot.

            Cells without a figure (e.g. on an ``upper``/``lower`` view) are
            passed as ``None`` so bokeh leaves them blank.
            """
            cells = [self.figures.get(cell) for cell in self]
            return layouts.gridplot(cells, ncols=len(self.features))

        def show(self):
            """Render the grid layout with ``bokeh.plotting.show``."""
            return plotting.show(self.layout())

        def _dataframe(self):
            """Return the raw DataFrame behind ``sources``, or ``None``.

            The frame is the ``None`` entry of the *last* map holding one, which
            is where ``__init__`` stores it when wrapping a DataFrame.
            """
            if isinstance(self.sources, collections.ChainMap):
                for mapping in reversed(self.sources.maps):
                    if None in mapping:
                        return mapping[None]
            return None

        def _prepare_diag_source(self, **kwargs):
            """Push one per-feature summary ``ColumnDataSource`` onto ``sources``.

            Each feature column is grouped by the frame's index when that is a
            ``CategoricalIndex`` or ``DatetimeIndex``; otherwise by
            ``numpy.digitize`` over 11 evenly spaced edges between the column's
            min and max. ``self.agg`` is applied per group and the result (one
            row per group) is stored under the feature name. Features already
            in ``sources`` are skipped, as is everything when no DataFrame is
            found. ``kwargs`` is accepted for signature compatibility only.
            """
            import numpy as np  # ``pd.np`` was deprecated in pandas 1.0 and removed in 2.0

            frame = self._dataframe()
            if frame is None:
                return
            group_by_index = type(frame.index) in (pd.CategoricalIndex, pd.DatetimeIndex)
            for feature in self.features:
                if feature in self.sources:
                    continue
                series = frame[feature]
                if group_by_index:
                    key = series.index
                else:
                    edges = np.linspace(series.min(), series.max(), 11)
                    key = np.digitize(series, edges)
                stats = series.groupby(key).apply(self.agg)
                if isinstance(stats, pd.Series):
                    # Per-group Series results (e.g. ``describe``) come back
                    # stacked as (group, statistic): pivot statistics into
                    # columns. Scalar aggregations give a flat Series instead.
                    stats = (stats.unstack(-1)
                             if isinstance(stats.index, pd.MultiIndex)
                             else stats.to_frame(feature))
                # DeepChainMap.new_child mutates and returns the same chain; the
                # assignment also keeps the new map for a plain ChainMap.
                self.sources = self.sources.new_child(
                    {feature: plotting.ColumnDataSource(stats)})

In [257]:
    class Diagonal(Plots):
        """Grid of diagonal figures, each drawn from its own feature source.

        Where ``Plots.__call__`` draws the shared ``None`` source on every
        figure, the figure at ``(feature, feature)`` here draws from
        ``sources[feature]`` — the per-feature summary source that
        ``Plots._prepare_diag_source`` stores under the feature name.
        """

        def __call__(self, **kwargs):
            """Add the selected glyph to each diagonal figure; return ``self``.

            ``kwargs`` are passed to the glyph constructor after ``{row}`` and
            ``{col}`` placeholders are filled per figure.

            Raises
            ------
            TypeError
                If no glyph has been selected yet.
            ValueError
                If a figure key is off the diagonal (``row != col``); checked
                explicitly so it also holds under ``python -O``.
            """
            if self.glyph is None:
                raise TypeError(
                    "no glyph selected; choose one by attribute first, e.g. plots.diagonal.Quad(...)")
            for (row, col), fig in self.figures.items():
                if row != col:
                    raise ValueError(
                        "Diagonal grid holds off-diagonal figure {!r}".format((row, col)))
                fig.add_glyph(
                    self.sources[row], self.glyph(**replaceRowCol(kwargs, row=row, col=col)))
            return self

In [258]:
    def label(df, model=KMeans, **kwargs):
        """Return ``df`` re-indexed by the cluster labels ``model`` predicts for it.

        ``model`` may be a callable (e.g. an estimator class), which is called
        with no arguments to build the estimator, or an already-built estimator.
        ``kwargs`` go to ``set_params`` — on a passed-in instance this is
        applied to that instance itself. The estimator is fitted on ``df`` and
        its predictions on ``df`` replace the index as a ``CategoricalIndex``.
        """
        estimator = model() if callable(model) else model
        estimator.set_params(**kwargs)
        fitted = estimator.fit(df)
        clusters = fitted.predict(df)
        return df.set_index(pd.CategoricalIndex(clusters))