In [None]:
# default_exp core

# `ndpretty` implementation

This notebook implements `ndpretty`, which displays numpy arrays (and PyTorch tensors) in Jupyter as colour-coded HTML tables.

In [None]:
#export
import numpy as np
from sys import modules

import IPython
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

Our first attempt at colouring the cell backgrounds used the matplotlib (`plt`) colormaps, but text on top of those colours is hard to read.

In [None]:
#hide
import matplotlib.pyplot as plt
from matplotlib import cm

def show_color(rgb_color):
    """Echo and render a sample paragraph whose background is `rgb_color`.

    `rgb_color` must be a 3-tuple of integers (the callers in this notebook
    pass uint8 values); it is formatted as a ``#rrggbb`` hex string.
    """
    background = '#%02x%02x%02x' % rgb_color
    snippet = '<p style="background-color: ' + background + '">This is text.</p>'
    print(snippet)  # show the raw HTML next to the rendered swatch
    IPython.core.display.display(IPython.display.HTML(snippet))

def show_plt_color(idx, heatmap):
    """Show entry `idx` of a listed matplotlib colormap as a colour swatch.

    `heatmap` must expose a ``colors`` sequence of RGB(A) floats in [0, 1]
    (e.g. ``cm.viridis``); only the first three channels are used so that
    RGBA colormaps also fit the ``#rrggbb`` format.
    """
    channels = np.asarray(heatmap.colors)[idx][:3]
    # Scale by 255, not 256: a channel value of 1.0 would otherwise become
    # 256 and wrap around to 0 when cast to uint8.
    rgb = tuple(np.round(channels * 255).astype(np.uint8))
    show_color(rgb)

# Drag the slider to preview each of the 256 entries of the viridis colormap.
# The trailing ';' hides the repr of the function returned by `interact`.
interact(show_plt_color, idx=widgets.IntSlider(min=0, max=255, step=1, value=10), heatmap=fixed(cm.viridis));

interactive(children=(IntSlider(value=10, description='idx', max=255), Output()), _dom_classes=('widget-intera…

<function __main__.show_plt_color(idx, heatmap)>

Therefore we use our own scheme instead: a linear interpolation between two colours, controlled by a weight `alpha` in [0, 1].

In [None]:
# Endpoint colours of the scale as (R, G, B) tuples.
lowest_color = (110, 110, 255)
highest_color = (220, 55, 55)

def show_my_color(alpha):
    """Render the colour at position `alpha` of the lowest -> highest blend.

    `alpha` = 0 gives `lowest_color`, `alpha` = 1 gives `highest_color`;
    channels are truncated to uint8.
    """
    low = np.array(lowest_color)
    high = np.array(highest_color)
    blended = high * alpha + low * (1 - alpha)
    show_color(tuple(blended.astype(np.uint8)))

interact(show_my_color, alpha=widgets.FloatSlider(min=0, max=1, step=0.01, value=0))

interactive(children=(FloatSlider(value=0.0, description='alpha', max=1.0, step=0.01), Output()), _dom_classes…

<function __main__.show_my_color(alpha)>

Some shared definitions: the two endpoint colours of the scale and the format used for numeric cells.

In [None]:
#export
# Endpoint colours of the cell background scale, as (R, G, B) tuples: cells
# with blend weight 0 are shaded `lowest_color`, weight 1 `highest_color`.
lowest_color = (110, 110, 255)
highest_color = (220, 55, 55)

# printf-style format applied to numeric cells (5 significant digits).
number_format = '%.5g'

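Now the core implementation: `ndarray_html` dispatches on the number of dimensions, and `_html_array` renders a 2-D array as a table.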

In [None]:
#export
def ndarray_html(a):
    """Display numpy array `a` as a colour-coded HTML table.

    0-D arrays are shown as a 1x1 table, 1-D arrays as a single column and
    2-D arrays directly; an empty 1-D array prints ``[]``. For more than two
    dimensions, an interactive text box lets the user pick a 2-D slice
    (initially index 0 along every trailing axis).

    Raises:
        TypeError: if `a` is not a numpy ndarray.
    """
    # Check the argument itself (the old check inspected a freshly created
    # empty array and therefore could never fail).
    if not isinstance(a, np.ndarray):
        raise TypeError('Only numpy ndarrays are supported')

    if a.ndim == 0:
        _html_array(a.reshape(1, 1))
    elif a.ndim == 1:
        if a.shape[0] == 0:
            print('[]')
            return
        _html_array(a[:, np.newaxis])
    elif a.ndim == 2:
        _html_array(a)
    else:
        # Default slice keeps the first two axes, e.g. "[:, :, 0, 0]" for 4-D.
        slice_str = "[:, :, " + "0, " * (a.ndim - 3) + "0]"
        slice_widget = widgets.Text(
            value=slice_str,
            placeholder="e.g. " + slice_str,
            description='Slice:',
            disabled=False
        )
        interact(_html_higher_d_array, a=fixed(a), slice_str=slice_widget)

def _to_HTML(a, alphas, is_numeric, lowest_color, highest_color):
    html = '<div style="overflow: auto">'
    html += '<table>'
    html += '<tr>'
    html += '<th></th>'
    for j in range(a.shape[1]):
        html += f'<th>{j}</th>'
    html += '</tr>'

    for i in range(a.shape[0]):
        html += "<tr>"
        html += f'<td><b>{i}</b></td>'
        for j in range(a.shape[1]):
            alpha = alphas[i][j]
            rgb_color = tuple((alpha * np.array(highest_color) + (1 - alpha) * np.array(lowest_color)).astype(np.uint8))
            hex_color = '#%02x%02x%02x' % rgb_color
            if is_numeric:
                value = number_format % a[i][j]
            else:
                value = str(a[i][j])
            html += '<td style="background-color: %s">%s</td>' % (hex_color, value)
        html += "</tr>"
    html += '</table>'
    html += '</div>'
    return html

def _html_higher_d_array(a, slice_str):
    try: 
        sliced_a = eval("a" + slice_str)
        assert len(sliced_a.shape) == 2, "didn't slice down to 2D"
        _html_array(sliced_a)
    except Exception as e:
        print("Invalid slice: " + str(e))

def _html_array(a):
    """Render a 2-D array as a colour-coded HTML table and display it.

    Blend weights (0 = `lowest_color`, 1 = `highest_color`):
    - numeric arrays are scaled linearly from their smallest to their largest
      finite value; NaN/inf cells, constant arrays and empty arrays get 0.5;
    - bool arrays map False to 0 and True to 1;
    - all other dtypes get 0.5.
    """
    is_numeric = np.issubdtype(a.dtype, np.number)
    # Neutral mid-scale weight, used whenever values cannot be ranked.
    # (Builtin float dtype: the `np.float` alias was removed in NumPy 1.24.)
    alphas = np.full(a.shape, 0.5)

    if is_numeric:
        values = a
        if np.issubdtype(a.dtype, np.integer):
            # Work in float so `max - min` cannot overflow (e.g. int8 [-128, 127]).
            values = a.astype(np.float64)
        finite = np.isfinite(values)
        # `.any()` is False for empty arrays, which avoids min()/max() on zero-size input.
        if finite.any():
            lo = values[finite].min()
            a_range = values[finite].max() - lo
            if a_range != 0:
                alphas = np.where(finite, (values - lo) / a_range, 0.5)
    elif np.issubdtype(a.dtype, np.bool_):
        alphas = a.astype(float)

    html = _to_HTML(a, alphas, is_numeric, lowest_color, highest_color)
    IPython.core.display.display(IPython.display.HTML(html))

## Example usages

A numeric 2D array:

In [None]:
ndarray_html(np.eye(10))  # 10x10 float identity matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,1,0,0,0,0,0,0,0,0,0
1,0,1,0,0,0,0,0,0,0,0
2,0,0,1,0,0,0,0,0,0,0
3,0,0,0,1,0,0,0,0,0,0
4,0,0,0,0,1,0,0,0,0,0
5,0,0,0,0,0,1,0,0,0,0
6,0,0,0,0,0,0,1,0,0,0
7,0,0,0,0,0,0,0,1,0,0
8,0,0,0,0,0,0,0,0,1,0
9,0,0,0,0,0,0,0,0,0,1


A numeric 1D array:

In [None]:
ndarray_html(np.ones(10))  # shape (10,), shown as a single column

Unnamed: 0,0
0,1
1,1
2,1
3,1
4,1
5,1
6,1
7,1
8,1
9,1


A numeric 4D array:

In [None]:
# Seeded generator so the example renders the same values on every run.
ndarray_html(np.random.default_rng(0).random((2, 3, 4, 5)))

interactive(children=(Text(value='[:, :, 0, 0]', description='Slice:', placeholder='e.g. [:, :, 0, 0]'), Outpu…

A string array:

In [None]:
ndarray_html(np.diag(np.array(['nd', 'pretty', 'ndpretty'])))  # off-diagonal cells are empty strings

Unnamed: 0,0,1,2
0,nd,,
1,,pretty,
2,,,ndpretty


A bool array:

In [None]:
ndarray_html(np.array([1, 0, 1], dtype=bool))  # True/False map to the two ends of the scale

Unnamed: 0,0
0,True
1,False
2,True


## Registering formatters for IPython

We don't want to call `ndarray_html` explicitly every time we want to see the table. To make it the default display for cell return values, the helper functions below register it (optionally together with a one-line summary printer) as IPython formatters.

This uses IPython's display formatters, as documented here: https://ipython.readthedocs.io/en/stable/config/integrating.html?highlight=third%20party#formatters-for-third-party-types

In [None]:
#export
def register_formatter(dtype, html_formatter, print_formatter=None):
    """Register IPython display formatters for objects of type `dtype`.

    Args:
        dtype: the Python type (class) to register, e.g. ``np.ndarray``.
        html_formatter: callable registered for the 'text/html' mimetype.
        print_formatter: optional callable registered for 'text/plain';
            skipped when None.

    Returns:
        True if the formatters were registered, False when no IPython shell
        is active (e.g. when imported from a plain Python interpreter), in
        which case nothing is done instead of raising NameError.
    """
    shell = IPython.get_ipython()
    if shell is None:
        return False

    formatters = shell.display_formatter.formatters
    formatters['text/html'].for_type(dtype, html_formatter)
    if print_formatter is not None:
        formatters['text/plain'].for_type(dtype, print_formatter)
    return True

### for `ndarray`

In [None]:
#export
def ndarray_stats_print_formatter(x, _, __):
    """'text/plain' formatter that prints a summary like "20×4×3 float64 ndarray".

    The two extra arguments IPython passes to plain-text formatters are
    ignored; the summary goes to stdout via ``print``.
    """
    dims = '×'.join(str(n) for n in x.shape)
    print(f'{dims} {x.dtype} ndarray')

def no_print_formatter(x, _, __):
    """'text/plain' formatter that outputs nothing.

    Pass it as `print_formatter` to suppress the summary line.
    """
    return None

def register_ndarray_formatter(print_formatter=ndarray_stats_print_formatter):
    """Make `ndarray_html` the IPython HTML display for numpy arrays.

    `print_formatter` is registered for 'text/plain'; pass None to leave the
    plain-text repr untouched, or `no_print_formatter` to silence it.
    """
    register_formatter(dtype=np.ndarray, html_formatter=ndarray_html, print_formatter=print_formatter)

In [None]:
# Use `ndarray_html` (plus the shape/dtype summary line) for every ndarray result.
register_ndarray_formatter(print_formatter=ndarray_stats_print_formatter)

In [None]:
# Seeded so the example is reproducible; the display comes from the registered formatters.
np.random.default_rng(0).random((20, 4, 3))

20×4×3 float64 ndarray


interactive(children=(Text(value='[:, :, 0]', description='Slice:', placeholder='e.g. [:, :, 0]'), Output()), …



### for `torch.Tensor`

We also define a default formatter for PyTorch `Tensor`s, which converts them to numpy arrays and reuses `ndarray_html`.

In [None]:
#export
def torch_tensor_html(t):
    """Display a PyTorch tensor with `ndarray_html`.

    The tensor is detached from the autograd graph and copied to CPU first,
    because ``Tensor.numpy()`` refuses tensors that require grad or that
    live on another device (e.g. a GPU).
    """
    ndarray_html(t.detach().cpu().numpy())

def tensor_stats_print_formatter(x, _, __):
    """'text/plain' formatter that prints a summary like "10×4 torch.float32 tensor".

    The two extra arguments IPython passes to plain-text formatters are
    ignored; the summary goes to stdout via ``print``.
    """
    dims = '×'.join(str(n) for n in x.shape)
    print(f'{dims} {x.dtype} tensor')

def register_torch_tensor_formatter(print_formatter=tensor_stats_print_formatter):
    """Make `torch_tensor_html` the IPython HTML display for PyTorch tensors.

    `print_formatter` is registered for 'text/plain'; pass None to skip it.
    """
    # Imported here: the exported module has no top-level `import torch`
    # (torch is optional), so `torch.Tensor` would otherwise be a NameError.
    import torch
    register_formatter(torch.Tensor, torch_tensor_html, print_formatter)

In [None]:
import torch

register_torch_tensor_formatter()

# Seeded so the example is reproducible; torch.tensor replaces the legacy
# torch.Tensor(ndarray) constructor and keeps the float32 dtype.
torch.tensor(np.random.default_rng(0).random((10, 4)), dtype=torch.float32)

10×4 torch.float32 tensor


Unnamed: 0,0,1,2,3
0,0.059401,0.80517,0.32449,0.19236
1,0.87222,0.60864,0.3428,0.75751
2,0.74584,0.17048,0.57657,0.53344
3,0.75652,0.56956,0.1592,0.41391
4,0.11494,0.56134,0.68147,0.7926
5,0.055549,0.47204,0.47246,0.045021
6,0.48567,0.56684,0.16062,0.068107
7,0.7451,0.042947,0.67884,0.23879
8,0.12544,0.25079,0.48289,0.95939
9,0.16596,0.29857,0.90868,0.38793




### Default configuration for convenience

Finally we define a convenience function to quickly initialise the default configuration.

In [None]:
#export
def default():
    """Register the ndarray formatter, and the torch formatter if torch is loaded.

    torch is detected through `sys.modules`, so the torch formatter is only
    registered when torch has already been imported.
    """
    register_ndarray_formatter()
    torch_already_imported = 'torch' in modules
    if torch_already_imported:
        register_torch_tensor_formatter()

In [None]:
# Enable all available formatters (ndarray, plus torch if it is already imported).
default()

---

### TODOs and potential new features
- [x] zero-dim arrays
- [x] strings
- [x] ints
- [x] bools
- [ ] slider for decimal places or text field for format?
- [ ] fallback in case of exception
- [ ] size limit: if array is larger than certain size, do auto-slicing
- [ ] unregister hook
- [ ] check if in IPython and don't crash otherwise