In [1]:
import numpy as np
from tensorflow import keras
from tools import signature as sig
from sklearn.preprocessing import StandardScaler

from ipywidgets import interact, fixed
import ipywidgets as widgets
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from ipywidgets import Image, Layout
import PIL.Image
import io

In [2]:
# Download the MNIST dataset (train and test splits).
(x_train_origin, y_train_origin), (x_test_origin, y_test_origin) = (
    keras.datasets.mnist.load_data(path='all_mnist.npz')
)

In [24]:
# Classes (digits) to keep. The list is kept sorted because later cells map an
# argmax position back to a digit through `keep[...]`.
keep = sorted(range(10))
# Passed straight to sig.pretraitement — presumably the number of samples per
# class (TODO confirm in tools.signature).
n = 100

(x, y), (x_test, y_test) = sig.pretraitement(
    x_train_origin, y_train_origin,
    x_test_origin, y_test_origin,
    keep=keep,
    n=n,
    verbose=False,
)

In [25]:
# Train the classifier. load=False presumably forces a fresh training run
# (the captured output shows 15 epochs followed by a save to `model_name`).
model_name = "./models/signature_all_100"
model = sig.trained_model(
    x, y,
    hidden_layers=[32, 16],
    model_name=model_name,
    verbose=True,
    load=False,
)

predictions = sig.get_preditions(model, x, keep, verbose=False)

# True digit of each sample: argmax of its target row (presumably one-hot),
# mapped back to the digit through `keep`.
labels = [keep[np.argmax(target_row)] for target_row in y]

# Standardize every hidden layer's activations independently (in place in the
# returned list) before choosing k and clustering.
hidden_layers = model.get_hidden_layers_outputs(x)
for layer_idx, activations in enumerate(hidden_layers):
    hidden_layers[layer_idx] = StandardScaler().fit_transform(activations)

best_k = sig.get_best_k(hidden_layers, kmax=len(keep), verbose=False)
clusters = sig.clustering(hidden_layers, best_k)

- input_shape:  (784,)
- n_neurons:  [32, 16, 10]
Epoch 1/15
84/84 - 1s - loss: 0.0892 - accuracy: 0.1550
Epoch 2/15
84/84 - 0s - loss: 0.0868 - accuracy: 0.2370
Epoch 3/15
84/84 - 0s - loss: 0.0839 - accuracy: 0.2990
Epoch 4/15
84/84 - 0s - loss: 0.0799 - accuracy: 0.3660
Epoch 5/15
84/84 - 0s - loss: 0.0751 - accuracy: 0.4660
Epoch 6/15
84/84 - 0s - loss: 0.0702 - accuracy: 0.5240
Epoch 7/15
84/84 - 0s - loss: 0.0656 - accuracy: 0.5410
Epoch 8/15
84/84 - 0s - loss: 0.0616 - accuracy: 0.5740
Epoch 9/15
84/84 - 0s - loss: 0.0581 - accuracy: 0.5960
Epoch 10/15
84/84 - 0s - loss: 0.0547 - accuracy: 0.6140
Epoch 11/15
84/84 - 0s - loss: 0.0517 - accuracy: 0.6290
Epoch 12/15
84/84 - 0s - loss: 0.0492 - accuracy: 0.6490
Epoch 13/15
84/84 - 0s - loss: 0.0470 - accuracy: 0.6620
Epoch 14/15
84/84 - 0s - loss: 0.0448 - accuracy: 0.6840
Epoch 15/15
84/84 - 0s - loss: 0.0426 - accuracy: 0.7340
INFO:tensorflow:Assets written to: ./models/signature_all_100/assets


In [26]:
# Build the colorscale shared by the scatter plots and the parcats trace.
# Entry k of `palette` is the color drawn for color index k. The order must
# match the options of `color_toggle` (index 0 = 'None' = grey, 1 = 'Red', ...),
# because the toggle's index is written directly into the color array.
palette = ['grey', '#ee1717', '#7201a8', '#17a60d', '#291ae0', '#f418ff', '#e3ae04',
           '#fb5e09', 'black', '#0ed6e1', '#d35400']
colorscale = {
    # Every point starts at index 0 (grey / "no color").
    'color': np.zeros(len(x), dtype='uint8'),
    'colorscale': palette,
    'cmin': 0,
    # Derived from the palette instead of hard-coded, so that integer index k
    # lands exactly on palette entry k even if colors are added or removed.
    'cmax': len(palette) - 1,
}

traces = sig.umap_plot(x, labels, predictions, hidden_layers, clusters, colorscale)

parcats = sig.parcats(labels, keep, best_k, clusters, predictions, colorscale)

# The parcats trace must be last: the callbacks treat fig.data[-1] as parcats.
traces.append(parcats)

In [27]:
def image_to_byte_array(image):
    """Serialize ``image`` as PNG and return the encoded bytes.

    ``image`` is any object with a PIL-style ``save(fp, format=...)`` method
    (here a ``PIL.Image.Image``).
    """
    with io.BytesIO() as buffer:
        image.save(buffer, format='png')
        return buffer.getvalue()

def to_image(x, shape=None):
    """Encode one flattened sample as PNG bytes for the hover preview widget.

    Parameters
    ----------
    x : array-like
        One sample (a row of the preprocessed data). It is multiplied by 255
        before conversion, so values are expected in [0, 1] — presumably what
        ``sig.pretraitement`` produces (TODO confirm).
    shape : tuple of int, optional
        (height, width) to reshape ``x`` into. Defaults to the shape of one
        original training image, ``x_train_origin[0]``.

    Returns
    -------
    bytes
        PNG-encoded grayscale image.
    """
    if shape is None:
        shape = np.shape(x_train_origin[0])
    # Clip before the uint8 cast: out-of-range values (for example standardized
    # input) would otherwise wrap around instead of saturating at black/white.
    pixels = np.clip(np.reshape(x, shape) * 255, 0, 255).astype(np.uint8)
    return image_to_byte_array(PIL.Image.fromarray(pixels))

def _paint_selection(current_colors, point_inds):
    """Recolor ``point_inds`` with the active toggle color and sync every trace.

    Parameters
    ----------
    current_colors : array-like
        Per-point color indices currently held by the trace that fired the event.
    point_inds : list of int
        Indices of the selected/clicked points.

    Side effects: inside a single ``fig.batch_update()``, writes the new array to
    ``marker.color`` of every trace except the last (the scatter plots) and to
    ``line.color`` of the last trace (the parcats diagram).
    """
    # np.array copies, so the trace's own color data is never mutated in place.
    new_color = np.array(current_colors)
    # The toggle's index doubles as the colorscale index (0 = 'None' / grey).
    new_color[point_inds] = color_toggle.index
    *scatter_traces, parcats_trace = fig.data
    with fig.batch_update():
        for scatter in scatter_traces:
            scatter.marker.color = new_color
        parcats_trace.line.color = new_color


# Update color callback
def update_color(trace, points, state):
    """Selection callback for the scatter plots.

    Reads the current colors from ``trace.marker.color``. ``state`` is required
    by the plotly callback signature and is not used.
    """
    _paint_selection(trace.marker.color, points.point_inds)


# Update color callback
def update_parcats_color(trace, points, state):
    """Click callback for the parallel-categories trace.

    A parcats trace stores its colors in ``line.color`` rather than
    ``marker.color``. ``state`` is required by the plotly callback signature
    and is not used.
    """
    _paint_selection(trace.line.color, points.point_inds)
        
def print_image(trace, points, state):
    """Hover callback: show the first hovered sample in ``image_widget``.

    Does nothing when the event carries no point indices. ``trace`` and
    ``state`` are required by the plotly callback signature and are not used.
    """
    hovered = points.point_inds
    if not hovered:
        return
    image_widget.value = to_image(x[hovered[0]])

In [28]:
# Three stacked panels: two scatter plots (rows 1-2) and the parcats diagram (row 3).
fig = make_subplots(
    rows=3,
    figure=go.FigureWidget(),
    specs=[[{'type': 'xy'}], [{'type': 'xy'}], [{'type': 'parcats'}]],
)
fig.add_traces(traces, rows=[1, 2, 3], cols=[1, 1, 1])
fig.update_layout(height=800, dragmode='lasso', hovermode='closest')

# Color picker. Its index is written straight into the color arrays, so the
# option order must mirror the colorscale's color list.
color_toggle = widgets.ToggleButtons(
    options=['None', 'Red', 'Purple', 'Green', 'Blue', 'Pink',
             'Yellow', 'Orange', 'Black', 'Aqua', 'Brown'],
    index=1,
    description='Color:',
    disabled=False,
)

# Preview of the hovered digit, initialised with the first sample.
image_widget = Image(value=to_image(x[0]), layout=Layout(height='100px', width='100px'))

# Wire the callbacks: lasso selection and hover on each scatter plot,
# click on the parcats diagram (always the last trace).
*scatter_traces, parcats_trace = fig.data
for scatter in scatter_traces:
    scatter.on_selection(update_color)
for scatter in scatter_traces:
    scatter.on_hover(print_image)
parcats_trace.on_click(update_parcats_color)

# Display figure
widgets.VBox([color_toggle, image_widget, fig])

VBox(children=(ToggleButtons(description='Color:', index=1, options=('None', 'Red', 'Purple', 'Green', 'Blue',…