# Exploring Fashion MNIST

In this notebook we're going to show how one can build a bespoke interface for exploring the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) embeddings that we looked at in [2-Composing-Linking-Scatter-Plots.ipynb](2-Composing-Linking-Scatter-Plots.ipynb#Synchronizing-the-Selection-and-Hover). With `jscatter` alone we can easily select points and synchronize the selection across all four embeddings. While this is great, to better understand point clusters it would help to see the images the points represent. `jscatter` does not support this out of the box, but you can easily build an interface that does using [traitlets](https://traitlets.readthedocs.io/), [ipywidgets](https://ipywidgets.readthedocs.io/), and [anywidget](https://anywidget.dev/).

> 🚨 Shout-out Alert

Another shout-out to the marvelous [Trevor Manz](https://trevorma.nz/) for creating [anywidget](https://anywidget.dev/), which makes it super easy to build custom ipywidgets. Fun fact: `jscatter` itself is implemented as an [anywidget](https://anywidget.dev/).

---

## Setup Data & Scatter Plot Config

This is the same as in [2-Composing-Linking-Scatter-Plots.ipynb](2-Composing-Linking-Scatter-Plots.ipynb#Synchronizing-the-Selection-and-Hover).

```python
# --create-dirs makes curl create the data/ directory if it does not exist yet
!curl -L -C - --create-dirs -o data/fashion-mnist-embeddings.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/fashion-mnist-embeddings.pq
```
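If `curl` is not available, a minimal standard-library sketch can fetch the same file and skip the download when it is already present. The function name `download_if_missing` is hypothetical, and unlike `curl -C -` it does not resume partial downloads (any non-empty file at the path is reused as-is).

```python
import os
import urllib.request

# URL and path copied from the curl command above
EMBEDDINGS_URL = "https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/fashion-mnist-embeddings.pq"
EMBEDDINGS_PATH = os.path.join("data", "fashion-mnist-embeddings.pq")


def download_if_missing(url: str, path: str) -> bool:
    """Hypothetical helper: download `url` to `path` unless a non-empty file exists.

    Returns True if a download happened and False if the existing file was reused.
    Partial downloads are NOT resumed (unlike `curl -C -`).
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return False
    # Create the target directory (e.g. `data/`) if needed
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    urllib.request.urlretrieve(url, path)
    return True
```

In the notebook you would call `download_if_missing(EMBEDDINGS_URL, EMBEDDINGS_PATH)` before loading the data.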



```python
import pandas as pd

# Map the integer labels to human-readable class names
class_names = {0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}

fashion_mnist_embeddings = pd.read_parquet('data/fashion-mnist-embeddings.pq')
# Only the `class` column should be categorical; the coordinate columns must stay numeric
fashion_mnist_embeddings["class"] = fashion_mnist_embeddings["class"].map(class_names).astype("category")
fashion_mnist_embeddings.head(3)
```

| | pcaX | pcaY | tsneX | tsneY | umapX | umapY | caeX | caeY | class |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|:---|
| 0 | -0.207672 | 0.619046 | -0.512748 | 0.862887 | -0.848567 | -0.177148 | -0.792607 | -0.95234 | Ankle boot |
| 1 | 0.42387 | -0.392556 | 0.556802 | -0.625932 | 0.973414 | -0.103313 | -0.493724 | -0.050538 | T-shirt/top |
| 2 | -0.455815 | -0.708062 | -0.037304 | -0.186733 | 0.463554 | -0.061681 | -0.372132 | -0.272005 | T-shirt/top |


```python
config = dict(
    background_color='#111111',
    color_by='class',
    color_map={
        "T-shirt/top": '#FFFF00',
        "Trouser": '#1CE6FF',
        "Pullover": '#FF34FF',
        "Dress": '#FF4A46',
        "Coat": '#008941',
        "Sandal": '#006FA6',
        "Shirt": '#A30059',
        "Sneaker": '#FFDBE5',
        "Bag": '#7A4900',
        "Ankle boot": '#0000A6'
    },
    legend=True,
    axes=False,
    zoom_on_selection=True, # To automatically zoom to selected points
)
```
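Because `color_map` is keyed by class name, a typo in a key would leave that class without its intended color (what `jscatter` does with unmatched categories is not covered here). A small sketch for catching such mismatches; `missing_colors` and `CLASS_NAMES` are hypothetical names, with the names copied from the label mapping used when loading the data.

```python
# Class names copied from the label mapping used when loading the data
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def missing_colors(class_names, color_map):
    """Hypothetical helper: class names without an entry in `color_map`, in order."""
    return [name for name in class_names if name not in color_map]
```

With the config above, `missing_colors(CLASS_NAMES, config["color_map"])` returns an empty list.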

## Compose & Link Scatter Plots

As before, we're going to create one scatter plot instance for each of the four dimensionality reduction techniques.

```python
from jscatter import Scatter

pca = Scatter(data=fashion_mnist_embeddings, x='pcaX', y='pcaY', **config)
tsne = Scatter(data=fashion_mnist_embeddings, x='tsneX', y='tsneY', **config)
umap = Scatter(data=fashion_mnist_embeddings, x='umapX', y='umapY', **config)
cae = Scatter(data=fashion_mnist_embeddings, x='caeX', y='caeY', **config)
```



The first difference from before is a simple drop-down menu for selecting all images of a certain class. For that, we use `ipywidgets.Dropdown` and change the point selection of the PCA scatter plot whenever the drop-down selection changes. Since we will make all four scatter plots synchronize their selection, it suffices to update the selection of any one of them to update all of them.

```python
from ipywidgets import Dropdown

# Convert to a plain list so it can be concatenated with the "-" placeholder
classes = list(fashion_mnist_embeddings["class"].unique())
select_class = Dropdown(options=["-"] + classes, value="-", description="Class")

# Update the scatter plot selection upon choosing a different class.
# Choosing "-" matches no rows, so the selection is set to an empty list.
select_class.observe(
    lambda change: pca.selection(fashion_mnist_embeddings.query(f'`class` == "{change.new}"').index),
    names=["value"]
)
```
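The callback above relies on `.query(...).index`, which returns index *labels*; these coincide with row positions only for a default `RangeIndex`, which the output above suggests but which is an assumption. A sketch of a hypothetical helper that returns positional indices directly, assuming `jscatter` interprets selection entries as row positions:

```python
import numpy as np
import pandas as pd


def indices_for_class(df: pd.DataFrame, name: str) -> np.ndarray:
    """Hypothetical helper: positional indices of all rows whose `class` equals `name`.

    Comparing a categorical column to a name outside its categories (such as
    the "-" placeholder) yields all False, i.e. an empty selection.
    """
    return np.flatnonzero((df["class"] == name).to_numpy())
```

The lambda could then call `pca.selection(indices_for_class(fashion_mnist_embeddings, change.new))`, which works regardless of the DataFrame's index.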

Selecting points and synchronizing the selection across all four embeddings is great, but ultimately each point represents an image, and to better understand point clusters it helps to see what those images look like. To render the images, we're going to create a simple widget with `anywidget` that displays them in a grid.

_(Note that to keep this demo fairly simple, we rendered the Fashion MNIST images beforehand and uploaded them to the cloud for easy access. You could also build a more sophisticated widget with `anywidget` that receives the pixels as a binary stream and renders them on the fly with something like the [Canvas API](https://developer.mozilla.org/en-US/docs/Web/API/Canvas_API).)_
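The pre-rendered images are addressed by point index: the widget's JavaScript zero-pads the index to five digits (`String(image).padStart(5, '0')`) and appends `.png` to a base URL. The same mapping in Python, as a sketch with a hypothetical `image_url` helper and the base URL copied from the widget code:

```python
# Base URL copied from the widget's JavaScript code
BASE_URL = "https://storage.googleapis.com/flekschas/regl-scatterplot/fashion-mnist-images/"


def image_url(point_index: int) -> str:
    """Hypothetical helper: URL of the pre-rendered PNG for a point.

    `:05d` zero-pads to five digits, mirroring `padStart(5, '0')`.
    """
    return f"{BASE_URL}{point_index:05d}.png"
```

For example, `image_url(7)` yields a URL ending in `00007.png`, which can be handy for spot-checking a single image in the browser.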

```python
from anywidget import AnyWidget
from traitlets import Int, List

class ImagesWidget(AnyWidget):
    _esm = """
    const baseUrl = 'https://storage.googleapis.com/flekschas/regl-scatterplot/fashion-mnist-images/';
    export function render({ model, el }) {
      const container = document.createElement("div");
      container.style.width = "100%";
      container.style.height = "100%";
      container.style.minHeight = "192px";
      container.style.overflow = "auto";
      container.style.display = "grid";
      container.style.gridTemplateColumns = "repeat(auto-fit, minmax(32px, 1fr))";
      container.style.gap = "8px";

      // Add the shared CSS rule only once, even if the widget is displayed several times
      if (!document.getElementById('fashion-mnist-style')) {
        const style = document.createElement('style');
        style.id = 'fashion-mnist-style';
        style.textContent = '.fashion-mnist { display: flex; width: 32px; height: 32px; background-repeat: no-repeat; background-position: center; }';
        document.head.appendChild(style);
      }

      // Randomly pick up to k distinct items from x (sampling without replacement)
      function choose(x, k) {
        const idxs = Array.from({ length: x.length }, (_, i) => i);
        return Array.from({ length: Math.min(k, x.length) }, () => {
          // Math.floor gives every remaining index the same probability
          const i = Math.floor(Math.random() * idxs.length);
          const idx = idxs[i];
          idxs.splice(i, 1);
          return x[idx];
        });
      }

      function getImages() {
        return choose(model.get("images"), model.get("max"));
      }

      function renderImages() {
        container.textContent = ''; // Remove all elements from container

        getImages().forEach(([image, color]) => {
          const imgId = String(image).padStart(5, '0');

          const img = document.createElement("div");
          img.className = "fashion-mnist";
          img.style.backgroundColor = color;
          img.style.backgroundImage = `url(${baseUrl}${imgId}.png)`;

          container.appendChild(img);
        });
      }

      // Re-render whenever the image list or the maximum count changes
      model.on("change:images", renderImages);
      model.on("change:max", renderImages);

      renderImages();

      el.appendChild(container);
    }
    """

    # List of (image index, background color) pairs to display
    images = List().tag(sync=True)
    # Maximum number of images shown at once
    max = Int(50).tag(sync=True)
```
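Whenever the selection changes, the widget shows a random subset of at most `max` images, drawn without replacement by its JavaScript `choose()` function. A Python counterpart is handy for reasoning about or testing that behaviour; it is a sketch that uses the standard library's `random.sample`, and the optional `rng` parameter is an addition for reproducibility.

```python
import random


def choose(items, k, rng=random):
    """Pick up to `k` distinct items uniformly at random, without replacement.

    Python counterpart of the widget's JavaScript `choose()`;
    pass e.g. `random.Random(0)` as `rng` for reproducible picks.
    """
    items = list(items)
    return rng.sample(items, min(k, len(items)))
```

If fewer than `k` points are selected, all of them are returned (in random order), which matches the `Math.min(k, x.length)` bound in the widget.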

Next, we instantiate the images widget and update it whenever the selection of the PCA scatter plot changes.

```python
from ipywidgets import HTML

images_title = HTML(value="<b>Selected Images:</b>")
images_widget = ImagesWidget()

def on_selection_change(change):
    # `change.new` holds the indices of the currently selected points
    selected = list(change.new)
    selected_classes = fashion_mnist_embeddings["class"].iloc[selected]
    images_widget.images = [
        # Color each image's background by its class (gray if the class has no color)
        (int(i), config["color_map"].get(cls, "#666"))
        for i, cls in zip(selected, selected_classes)
    ]

pca.widget.observe(on_selection_change, names=["selection"])
```
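The conversion from selected indices to `(index, color)` pairs can also live in a plain function, which makes it testable without any widgets. A sketch with a hypothetical `selection_to_images` helper; the `"#666"` fallback matches the callback above.

```python
import pandas as pd


def selection_to_images(labels: pd.Series, selection, color_map, default="#666"):
    """Hypothetical helper: turn selected positional indices into (index, color) pairs.

    `labels` is the `class` column; classes missing from `color_map`
    fall back to `default`.
    """
    selection = list(selection)
    selected = labels.iloc[selection]
    return [(int(i), color_map.get(label, default)) for i, label in zip(selection, selected)]
```

Inside the observer, this would read `images_widget.images = selection_to_images(fashion_mnist_embeddings["class"], change.new, config["color_map"])`.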

As before, we're going to compose the four scatter plots into a regular 2x2 grid and enable synchronizing their selection and hover state.

```python
from jscatter import compose

scatters = compose(
    # [(pca, "PCA"), (tsne, "t-SNE"), (umap, "UMAP"), (cae, "CAE")],
    [pca, tsne, umap, cae],
    sync_selection=True,
    sync_hover=True,
    rows=2,
    row_height=240
)
```
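Conceptually, synchronizing the selection means keeping one piece of widget state equal across several objects. With `sync_selection=True`, `compose` takes care of this for you; the sketch below only illustrates the idea with plain `traitlets` objects and `traitlets.link`, and is not how `jscatter` implements it internally. `FakePlot` is a hypothetical stand-in, not a `jscatter` class.

```python
from traitlets import HasTraits, List, link


class FakePlot(HasTraits):
    """Hypothetical stand-in for a scatter plot widget with a `selection` trait."""
    selection = List()


a, b = FakePlot(), FakePlot()
# `link` keeps the two traits equal in both directions
link((a, "selection"), (b, "selection"))

a.selection = [1, 2, 3]  # propagates to b.selection
```

Because the link is bidirectional, assigning to `b.selection` updates `a.selection` as well.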

Finally, to put everything together into a usable layout, we're going to use `ipywidgets.AppLayout`.

```python
from ipywidgets import AppLayout, VBox

AppLayout(
    center=VBox([select_class, scatters]),
    right_sidebar=VBox([images_title, images_widget]),
)
```

Try selecting some points, either by choosing a class from the drop-down menu or by lassoing some points. Up to `50` randomly chosen images of the selected points appear on the right. You can also easily raise the maximum number of images shown:

```python
images_widget.max = 100
```

---

## Next

Next up, learn how `jscatter`'s point selection can be used to navigate in genomic data.

➡️ [Building a Bespoke Interface for Navigating Genomic Data via Loci Embeddings](3-Genomics.ipynb)