In [None]:
# default_exp core

# Implementation

This is the implementation of `ndpretty`. 

In [None]:
#export
import numpy as np
from sys import modules

import IPython
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

Some static definitions of the colour scale and the number format.

In [None]:
#export
# RGB colour (0-255 per channel) given to cells whose alpha is 0, i.e. the low end of the colour scale.
lowest_color = (110, 110, 255)
# RGB colour (0-255 per channel) given to cells whose alpha is 1, i.e. the high end of the colour scale.
highest_color = (220, 55, 55)

# printf-style format applied to every numeric cell value (at most 5 significant digits).
number_format = '%.5g'

And the core implementation:

In [None]:
#export
def ndarray_html(a):
    """Display a NumPy array as a colour-coded HTML table.

    Dispatch depends on the number of dimensions of ``a``:

    * 0-d: shown as a 1x1 table.
    * 1-d: shown as a single column; an empty 1-d array just prints ``[]``.
    * 2-d: shown directly.
    * more than 2-d: an interactive text box is created in which an index
      expression (initially ``[:, :, 0, ..., 0]``) selects the 2-d view to show.

    Parameters
    ----------
    a : np.ndarray
        The array to display.

    Raises
    ------
    AssertionError
        If ``a`` is not a NumPy ndarray (or a subclass of it).
    """
    # Check the argument itself; the previous check inspected a freshly built
    # array and therefore could never fail.
    assert isinstance(a, np.ndarray), 'Only numpy ndarrays are supported'

    ndim = a.ndim
    if ndim == 0:
        # A 0-d array has no rows/columns to iterate over: show it as one cell.
        _html_array(a.reshape(1, 1))
    elif ndim == 1:
        if a.shape[0] == 0:
            print('[]')
            return
        # Add a trailing axis so the values form one column.
        _html_array(a[:, np.newaxis])
    elif ndim == 2:
        _html_array(a)
    else:
        # Default expression keeps the first two axes and takes index 0 of the rest.
        slice_str = "[:, :, " + "0, " * (ndim - 3) + "0]"
        slice_widget = widgets.Text(
            value=slice_str,
            placeholder="e.g. " + slice_str,
            description='Slice:',
            disabled=False
        )
        interact(_html_higher_d_array, a=fixed(a), slice_str=slice_widget)

def _to_HTML(a, alphas, is_numeric, lowest_color, highest_color):
    html = '<div style="overflow: auto">'
    html += '<table style="border-spacing: 0px;">'
    html += '<tr>'
    html += '<th></th>'
    for j in range(a.shape[1]):
        html += f'<th>{j}</th>'
    html += '</tr>'

    for i in range(a.shape[0]):
        html += "<tr>"
        html += f'<td><b>{i}</b></td>'
        for j in range(a.shape[1]):
            alpha = alphas[i][j]
            rgb_color = tuple((alpha * np.array(highest_color) + (1 - alpha) * np.array(lowest_color)).astype(np.uint8))
            hex_color = '#%02x%02x%02x' % rgb_color
            if is_numeric:
                value = number_format % a[i][j]
            else:
                value = str(a[i][j])
            html += '<td style="background-color: %s; padding:5px !important">%s</td>' % (hex_color, value)
        html += "</tr>"
    html += '</table>'
    html += '</div>'
    return html

def _html_higher_d_array(a, slice_str):
    """Show the 2-d view of ``a`` selected by an index expression.

    ``slice_str`` is the text of an index expression including its brackets
    (for example ``"[:, :, 0, 0]"``); it is appended to the name ``a`` and
    evaluated. Any failure, including a result that is not 2-d, is reported
    by printing ``Invalid slice: ...`` instead of raising.
    """
    try: 
        # NOTE(review): eval() runs arbitrary Python taken from the text field,
        # with this function's local names in scope. That is acceptable only as
        # long as the text comes from the notebook user themselves — confirm
        # before exposing this to any other source of input.
        sliced_a = eval("a" + slice_str)
        assert len(sliced_a.shape) == 2, "didn't slice down to 2D"
        _html_array(sliced_a)
    except Exception as e:
        # Also catches errors raised while rendering, not only bad slice text.
        print("Invalid slice: " + str(e))

def _html_array(a):   
    is_numeric = np.issubdtype(a.dtype, np.number)

    if is_numeric:
        a_range = a.max() - a.min()
        if a_range == 0:
            alphas = np.zeros_like(a, dtype=np.float) + 0.5
        else:
            alphas = (a - a.min()) / a_range
    elif np.issubdtype(a.dtype, np.bool_):
        alphas = a.astype(np.float)
    else:
        alphas = np.zeros_like(a, dtype=np.float) + 0.5

    html = _to_HTML(a, alphas, is_numeric, lowest_color, highest_color)
    IPython.core.display.display(IPython.display.HTML(html))

## Example usages

A numeric 2D array:

In [None]:
# np.diag(np.ones(10)) is a 10×10 matrix with ones on the diagonal and zeros elsewhere.
ndarray_html(np.diag(np.ones(10)))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,1,0,0,0,0,0,0,0,0,0
1,0,1,0,0,0,0,0,0,0,0
2,0,0,1,0,0,0,0,0,0,0
3,0,0,0,1,0,0,0,0,0,0
4,0,0,0,0,1,0,0,0,0,0
5,0,0,0,0,0,1,0,0,0,0
6,0,0,0,0,0,0,1,0,0,0
7,0,0,0,0,0,0,0,1,0,0
8,0,0,0,0,0,0,0,0,1,0
9,0,0,0,0,0,0,0,0,0,1


A numeric 1D array:

In [None]:
# A 1-D input is shown as a single column (ndarray_html adds a trailing axis).
ndarray_html(np.ones((10,)))

Unnamed: 0,0
0,1
1,1
2,1
3,1
4,1
5,1
6,1
7,1
8,1
9,1


A numeric 4D array:

In [None]:
# More than two dimensions: an interactive "Slice:" text box selects the 2-D view to show.
ndarray_html(np.random.rand(2, 3, 4, 5))

interactive(children=(Text(value='[:, :, 0, 0]', description='Slice:', placeholder='e.g. [:, :, 0, 0]'), Outpu…

When an ndarray has more than two dimensions, a text field (an [ipywidget](https://ipywidgets.readthedocs.io)) is shown above the table. There the user can enter an index expression that slices the array down to two dimensions, so that it can be displayed as a table.

_Note: The `ipywidget` that is used for multi-dimensional arrays is not correctly rendered in the documentation. Therefore see this screenshot:_ 

<img src="img/4D.png" height="80px" />

A string array:

In [None]:
# Non-numeric values are shown via str(); the off-diagonal cells are empty strings.
ndarray_html(np.diag(['nd', 'pretty', 'ndpretty']))

Unnamed: 0,0,1,2
0,nd,,
1,,pretty,
2,,,ndpretty


A bool array:

In [None]:
# Boolean input: True and False cells receive the two end colours of the scale.
ndarray_html(np.array([True, False, True]))

Unnamed: 0,0
0,True
1,False
2,True


## Registering formatters for IPython

We don't always want to call `ndarray_html` to show our nice table. In order to make it the default formatter for cell return values, here are some helper functions to automatically register the formatters.

This makes use of [IPython third-party formatters](https://ipython.readthedocs.io/en/stable/config/integrating.html?highlight=third%20party#formatters-for-third-party-types).

In [None]:
#export
def register_formatter(dtype, html_formatter, print_formatter=None):
    """Attach display formatters for ``dtype`` to the running IPython shell.

    ``html_formatter`` is always registered for the ``text/html`` MIME type.
    ``print_formatter`` is registered for ``text/plain`` only when it is not
    ``None``.

    Relies on ``get_ipython`` being available, i.e. on running inside IPython.
    """
    by_mime = get_ipython().display_formatter.formatters

    # (MIME type, callable) pairs in the order they are registered.
    registrations = [('text/html', html_formatter)]
    if print_formatter is not None:
        registrations.append(('text/plain', print_formatter))

    for mime_type, formatter in registrations:
        by_mime[mime_type].for_type(dtype, formatter)

### for `ndarray`

In [None]:
#export
def ndarray_stats_print_formatter(x, _, __):
    """Print a one-line summary of ``x``: its shape joined by ``×``, its dtype and the word ``ndarray``.

    The second and third arguments are accepted but not used.
    """
    dims = '×'.join(str(extent) for extent in x.shape)
    dtype_name = str(x.dtype)
    print(dims + " " + dtype_name + ' ndarray')

def no_print_formatter(x, _, __):
    """Plain-text formatter that deliberately outputs nothing; all arguments are ignored."""
    return None

def register_ndarray_formatter(print_formatter=ndarray_stats_print_formatter):
    """Register ``ndarray_html`` as the HTML formatter for ``np.ndarray``.

    ``print_formatter`` is passed through as the plain-text formatter; by
    default it is the shape/dtype summary printer, and ``None`` leaves the
    plain-text formatter untouched.
    """
    register_formatter(
        np.ndarray,
        ndarray_html,
        print_formatter=print_formatter,
    )

A formatter can be registered like this:

In [None]:
# Installs ndarray_html plus the default shape/dtype summary printer for np.ndarray.
register_ndarray_formatter()

From then onwards, all returned ndarrays are formatted using ndpretty.

In [None]:
# The bare expression below is rendered by the formatters registered above.
np.random.rand(2, 8)

2×8 float64 ndarray


Unnamed: 0,0,1,2,3,4,5,6,7
0,0.66466,0.16588,0.74568,0.86075,0.70097,0.29058,0.0038586,0.14309
1,0.2015,0.16422,0.41461,0.66311,0.012157,0.17411,0.17041,0.14735




### for `torch.Tensor`

We also define a default formatter for PyTorch `Tensor`s:

In [None]:
#export
def torch_tensor_html(t):
    """Display a PyTorch tensor with ``ndarray_html``.

    The tensor is first detached from the autograd graph and moved to the
    CPU, because ``Tensor.numpy()`` refuses tensors that require grad or
    live on another device.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to display.
    """
    ndarray_html(t.detach().cpu().numpy())

def tensor_stats_print_formatter(x, _, __):
    """Print a one-line summary of ``x``: its shape joined by ``×``, its dtype and the word ``tensor``.

    The second and third arguments are accepted but not used.
    """
    dims = '×'.join(str(extent) for extent in x.shape)
    dtype_name = str(x.dtype)
    print(dims + " " + dtype_name + ' tensor')

def register_torch_tensor_formatter(print_formatter=tensor_stats_print_formatter):
    """Register ``torch_tensor_html`` as the HTML formatter for ``torch.Tensor``.

    ``print_formatter`` is passed through as the plain-text formatter; by
    default it is the shape/dtype summary printer, and ``None`` leaves the
    plain-text formatter untouched.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """
    # Imported here because torch is an optional dependency that the exported
    # import cell does not load; without this the name ``torch`` is undefined
    # unless the caller's namespace happens to provide it.
    import torch

    register_formatter(torch.Tensor, torch_tensor_html, print_formatter)

In [None]:
# Demo-only import: this cell carries no #export marker.
import torch

# Install the HTML and summary formatters for torch.Tensor.
register_torch_tensor_formatter()

# The bare expression is rendered by the formatters registered on the line above.
torch.Tensor(np.random.rand(10, 4))

10×4 torch.float32 tensor


Unnamed: 0,0,1,2,3
0,0.16095,0.47095,0.18072,0.1782
1,0.80807,0.98237,0.555,0.35875
2,0.43075,0.63691,0.29013,0.035132
3,0.86955,0.79319,0.57463,0.051737
4,0.90334,0.040102,0.40316,0.44179
5,0.90831,0.53108,0.50677,0.26363
6,0.16269,0.16396,0.46756,0.43925
7,0.97244,0.81819,0.022852,0.95482
8,0.97707,0.68817,0.37027,0.31923
9,0.34871,0.90932,0.83071,0.84133




### Default configuration for convenience

Finally we define a convenience function to quickly initialise the default configuration where both ndarrays and PyTorch tensors are formatted.

In [None]:
#export
def default():
    """Register the default formatters.

    The ndarray formatter is always registered. The ``torch.Tensor``
    formatter is registered only if ``torch`` is already present in
    ``sys.modules``, i.e. has been imported somewhere.
    """
    register_ndarray_formatter()

    torch_is_loaded = 'torch' in modules
    if not torch_is_loaded:
        return
    register_torch_tensor_formatter()

In [None]:
# Registers the ndarray formatter, and the torch.Tensor one as well because torch was imported above.
default()

---

### TODOs and potential new features
- [x] zero-dim arrays
- [x] strings
- [x] ints
- [x] bools
- [ ] slider for decimal places or text field for format
- [ ] fallback in case of exception
- [ ] size limit: if array is larger than certain size, do auto-slicing
- [ ] unregister hook
- [ ] check if in IPython and don't crash otherwise