<a href="https://colab.research.google.com/github/lc0/deeplearning-playground/blob/master/Altair_image_hover.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import tensorflow as tf
from tensorflow import keras

# Fashion-MNIST ships as 28x28 grayscale uint8 images (pixel values 0-255)
# with integer class labels, already split into training and test sets.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixels into [0, 1]. The largest uint8 value is 255, so dividing by
# 255.0 maps it to exactly 1.0; dividing by 256 would cap the range just
# below 1 and would not round-trip with a later `* 255`.
train_images_norm = train_images / 255.0
test_images_norm = test_images / 255.0

In [0]:
# Small fully-connected classifier: flatten each 28x28 image, pass it through
# one 128-unit ReLU layer, then emit a softmax over the 10 classes.
model = keras.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))


In [3]:
# Print a per-layer table of output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [0]:
# Configure training: Adam with its default settings, cross-entropy computed
# against integer class labels, and accuracy reported next to the loss.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.optimizers.Adam(),
    metrics=['accuracy'],
)


In [5]:
# Train for 5 epochs on the normalised training images and their labels.
model.fit(train_images_norm, train_labels, epochs=5)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f4200773d30>

In [6]:
# Measure loss and accuracy on the held-out test set, then report both.
loss, accuracy = model.evaluate(test_images_norm, test_labels)
for caption, value in (('Test loss:', loss), ('Test accuracy:', accuracy)):
    print(caption, value)

Test loss: 0.3475441634654999
Test accuracy: 0.8741999864578247


## Visualisation

In [0]:
# Work with the first 5000 training examples only. Each 28x28 image is
# flattened into a single row of 784 values.
sample_labels = train_labels[:5000]
sample_images = train_images_norm[:5000].reshape(-1, 784)

In [0]:
# Human-readable class names; the name at position i belongs to label i.
fashion_mnist_labels = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

## PCA + Extract features from the last hidden layer

In [0]:
from sklearn.decomposition import PCA

# Feature extractor: takes the same inputs as the classifier but outputs the
# activations of the penultimate layer (the 128-unit hidden Dense layer).
# The layer is picked by position rather than by its auto-generated name
# 'dense': Keras gives auto-named layers a new numeric suffix whenever the
# model-building cell is re-run, which would make a lookup by that name fail.
model_features = tf.keras.Model(model.inputs, model.layers[-2].output)

In [0]:
# sample_images holds one flattened row per image; restore the (28, 28) shape
# declared as the model's input before running the feature extractor.
image_features = model_features.predict(sample_images.reshape(-1, 28, 28))

In [0]:
# Project the extracted features onto their first two principal components
# so they can be drawn as a 2-D scatter plot. random_state is fixed so that
# the projection is reproducible when scikit-learn picks its randomized SVD
# solver for an input of this size.
pca = PCA(n_components=2, random_state=0)

extracted_pca_data = pca.fit_transform(image_features)

# Altair visualization

## Flask server for images

In [12]:
# flask_cors supplies the CORS helper imported for the image server below.
!pip install flask_cors



### flask-ngrok

Here I extended flask-ngrok a bit so that Flask runs in a separate thread.

In [13]:
# Install flask-ngrok from the `thread` branch of the lc0 fork; presumably
# this branch is what provides the start_flask_thread helper imported in the
# next cell.
!pip install git+https://github.com/lc0/flask-ngrok@thread

Collecting git+https://github.com/lc0/flask-ngrok@thread
  Cloning https://github.com/lc0/flask-ngrok (to revision thread) to /tmp/pip-req-build-0tqwvtk4
  Running command git clone -q https://github.com/lc0/flask-ngrok /tmp/pip-req-build-0tqwvtk4
  Running command git checkout -b thread --track origin/thread
  Switched to a new branch 'thread'
  Branch 'thread' set up to track remote branch 'thread' from 'origin'.
Building wheels for collected packages: flask-ngrok
  Building wheel for flask-ngrok (setup.py) ... [?25l[?25hdone
  Created wheel for flask-ngrok: filename=flask_ngrok-0.0.26-cp36-none-any.whl size=6233 sha256=fb626b5f42d66fc0ae7fa572ecafdc652ddeba077e3c6867c3132203a2d37397
  Stored in directory: /tmp/pip-ephem-wheel-cache-w_wn5okl/wheels/af/00/62/632d10dbddb6224ed6535cb27ae2d9629421362638885cdad1
Successfully built flask-ngrok


In [0]:
from flask_ngrok import start_flask_thread
# No server thread exists yet; this handle is later passed to
# start_flask_thread and replaced by its return value.
server_thread = None

In [0]:
import io
from PIL import Image
from flask import Flask, abort, send_file
from flask_cors import CORS


app = Flask(__name__)
# Enable CORS on every route; presumably needed so the chart, which is
# rendered from a different origin, can load images from this server.
CORS(app)

@app.route("/img/<int:image_id>.png")
def image(image_id):
    """Serve one sample image as a PNG.

    Args:
        image_id: zero-based row index into ``sample_images``, taken from
            the URL.

    Returns:
        A Flask response carrying the PNG-encoded 28x28 image.

    Raises:
        werkzeug.exceptions.NotFound: if ``image_id`` is outside the range
            of available samples (rendered by Flask as HTTP 404).
    """
    # Without this guard an out-of-range id raises IndexError inside the
    # view and the client receives HTTP 500 instead of 404.
    if not 0 <= image_id < len(sample_images):
        abort(404)

    # The stored row is flat with values scaled down from 0-255; restore the
    # 28x28 shape and the 0-255 range, then convert to a PIL RGB image.
    img = Image.fromarray(sample_images[image_id].reshape(28, 28)*255).convert('RGB')

    # Encode the PNG into an in-memory buffer rather than a file on disk.
    file_object = io.BytesIO()
    img.save(file_object, 'PNG')

    # Rewind so `send_file()` reads the buffer from its first byte.
    file_object.seek(0)

    return send_file(file_object, mimetype='image/PNG')

In [17]:
# Start serving `app` on port 5555 through flask-ngrok, passing in the current
# thread handle and keeping the one returned.
# NOTE(review): the public URL printed in this cell's output is what BASE_URL
# is hard-coded to further down; presumably it changes per session, so update
# BASE_URL after each restart - confirm.
server_thread = start_flask_thread(server_thread, app, port=5555)

 * Running on ['https://77f2a29b.ngrok.io']
Started a server thread on 0.0.0.0:5555


## Altair Grammar

In [0]:
import altair as alt
import numpy as np
import pandas as pd

In [0]:
# Public address of the Flask image server. It has to match the ngrok URL
# printed when the server thread was started.
BASE_URL = 'https://77f2a29b.ngrok.io'

# One image URL per sample, in the same order as sample_labels.
dummy_image_ids = [
    f"{BASE_URL}/img/{image_id}.png"
    for image_id in range(len(sample_labels))
]

In [34]:
# Preview the first ten generated image URLs.
dummy_image_ids[:10]

['https://77f2a29b.ngrok.io/img/0.png',
 'https://77f2a29b.ngrok.io/img/1.png',
 'https://77f2a29b.ngrok.io/img/2.png',
 'https://77f2a29b.ngrok.io/img/3.png',
 'https://77f2a29b.ngrok.io/img/4.png',
 'https://77f2a29b.ngrok.io/img/5.png',
 'https://77f2a29b.ngrok.io/img/6.png',
 'https://77f2a29b.ngrok.io/img/7.png',
 'https://77f2a29b.ngrok.io/img/8.png',
 'https://77f2a29b.ngrok.io/img/9.png']

In [39]:
def visualize_embeddings_altair(data, labels, dummy_image_ids):
    """Build a scatter plot of 2-D embeddings linked to an image preview.

    Args:
        data: array-like of shape (n_points, 2); column 0 is drawn on the
            x axis and column 1 on the y axis.
        labels: sequence of n_points class labels, used as a nominal colour.
        dummy_image_ids: sequence of n_points image URLs, one per point.

    Returns:
        A horizontally concatenated Altair chart: the scatter plot on the
        left and, on the right, an image panel that shows the images of the
        points currently selected by hovering.

    Raises:
        ValueError: if ``data`` is not a two-column 2-D array.
    """
    points = np.asarray(data)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(
            f"data must have shape (n_points, 2), got {points.shape}"
        )

    # Build the frame column by column so x and y keep their numeric dtype.
    # Stacking coordinates, labels and URLs into one NumPy array would coerce
    # every value to a string, because the URLs are strings.
    df = pd.DataFrame({
        'x': points[:, 0],
        'y': points[:, 1],
        'label': labels,
        'image_id': dummy_image_ids,
    })

    # Hovering a point toggles it in the selection; with empty='none' nothing
    # counts as selected until the first hover.
    multi_mouseover = alt.selection_multi(on='mouseover', toggle=True, empty='none')

    # NOTE(review): `radius` is kept as in the original; if the intent was to
    # enlarge the dots, `size` may be the property that was meant - confirm.
    base = alt.Chart(df).mark_circle(radius=50).encode(
        x='x:Q',
        y='y:Q',
        color='label:N',
    ).add_selection(
        multi_mouseover
    )

    # Image panel: restricted to the rows in the hover selection, each drawn
    # from the URL in its `image_id` column.
    image = alt.Chart(df).mark_image(
    ).encode(
        url='image_id'
    ).transform_filter(
        multi_mouseover
    ).properties(width=100, height=100)

    return base | image


# Render the linked scatter/image chart for the PCA-projected sample features.
visualize_embeddings_altair(extracted_pca_data, sample_labels, dummy_image_ids)