# Exploring Fashion MNIST

In this notebook, we're going to show how one can build a bespoke interface for exploring the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) embeddings that we looked at in [2-Composing-Linking-Scatter-Plots.ipynb](#Synchronizing-the-Selection-and-Hover). With `jscatter` alone, we can easily select points and synchronize the selection across all four embeddings. While this is great, to better understand point clusters it would help to see the images that the points represent. `jscatter` does not support this out of the box, but you can easily build an interface that does using [traitlets](https://traitlets.readthedocs.io/), [ipywidgets](https://ipywidgets.readthedocs.io/), and [anywidget](https://anywidget.dev/).

> 🚨 Shout-out Alert

Another shout-out to the marvelous [Trevor Manz](https://trevorma.nz/) for creating [anywidget](https://anywidget.dev/), which makes it super easy to build custom ipywidgets. Fun fact: `jscatter` itself is implemented as an [anywidget](https://anywidget.dev/).

---

## Setup Data & Scatter Plot Config

This is the same as in [2-Composing-Linking-Scatter-Plots.ipynb](#Synchronizing-the-Selection-and-Hover).

In [None]:
!mkdir -p data
!curl -L -C - -o data/fashion-mnist-embeddings.pq https://storage.googleapis.com/flekschas/jupyter-scatter-tutorial/fashion-mnist-embeddings.pq

In [None]:
import pandas as pd
fashion_mnist_embeddings = pd.read_parquet('data/fashion-mnist-embeddings.pq')
# Map class IDs to labels; make only `class` categorical so the x/y columns stay numeric
fashion_mnist_embeddings = fashion_mnist_embeddings.replace({"class": {0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}})
fashion_mnist_embeddings['class'] = fashion_mnist_embeddings['class'].astype('category')
fashion_mnist_embeddings.head(3)

In [None]:
config = dict(
    background_color='#111111',
    color_by='class',
    color_map={
        "T-shirt/top": '#FFFF00',
        "Trouser": '#1CE6FF',
        "Pullover": '#FF34FF',
        "Dress": '#FF4A46',
        "Coat": '#008941',
        "Sandal": '#006FA6',
        "Shirt": '#A30059',
        "Sneaker": '#FFDBE5',
        "Bag": '#7A4900',
        "Ankle boot": '#0000A6'
    },
    legend=True,
    axes=False,
    zoom_on_selection=True, # To automatically zoom to selected points
)
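The keys of `color_map` have to match the labels produced by the `replace()` mapping exactly; a typo in a key would silently fall back to the gray `#666` used in the selection callback further down. A small sanity check might look like this. `CLASS_NAMES` and `COLOR_MAP` simply restate the mapping and colors from the cells above so the sketch runs on its own, and `missing_colors` is a hypothetical helper, not part of `jscatter`.

```python
import re

# Restates the ID -> label mapping from the loading cell
CLASS_NAMES = {0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
               5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot"}

# Copy of config["color_map"] so the sketch is self-contained
COLOR_MAP = {
    "T-shirt/top": "#FFFF00", "Trouser": "#1CE6FF", "Pullover": "#FF34FF",
    "Dress": "#FF4A46", "Coat": "#008941", "Sandal": "#006FA6",
    "Shirt": "#A30059", "Sneaker": "#FFDBE5", "Bag": "#7A4900",
    "Ankle boot": "#0000A6",
}

HEX_COLOR = re.compile(r"#[0-9A-Fa-f]{6}")

def missing_colors(class_names, color_map):
    """Return labels that have no color or whose color is not a #RRGGBB string."""
    return sorted(
        label for label in class_names.values()
        if not HEX_COLOR.fullmatch(color_map.get(label, ""))
    )

print(missing_colors(CLASS_NAMES, COLOR_MAP))  # []
```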

## Compose & Link Scatter Plots

As before, we're going to create one scatter plot instance for each of the four dimensionality reduction techniques.

In [None]:
from jscatter import Scatter

pca = Scatter(data=fashion_mnist_embeddings, x='pcaX', y='pcaY', **config)
tsne = Scatter(data=fashion_mnist_embeddings, x='tsneX', y='tsneY', **config)
umap = Scatter(data=fashion_mnist_embeddings, x='umapX', y='umapY', **config)
cae = Scatter(data=fashion_mnist_embeddings, x='caeX', y='caeY', **config)
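The four `Scatter` calls differ only in their `x`/`y` column names, which follow a `<method>X`/`<method>Y` pattern. A minimal sketch, assuming that naming holds for every embedding column, builds the per-method keyword arguments in a loop. `METHODS` and `scatter_kwargs` are hypothetical names introduced here for illustration.

```python
# Hypothetical helper: derive x/y column names per embedding method, assuming
# the columns follow the "<method>X" / "<method>Y" pattern used above
METHODS = ["pca", "tsne", "umap", "cae"]

def scatter_kwargs(methods, shared_config):
    """Return one keyword-argument dict per method, merged with the shared config."""
    return {
        method: {"x": f"{method}X", "y": f"{method}Y", **shared_config}
        for method in methods
    }

kwargs = scatter_kwargs(METHODS, {"legend": True, "axes": False})
print(kwargs["tsne"]["x"], kwargs["tsne"]["y"])  # tsneX tsneY
```

Each dict could then be passed as `Scatter(data=fashion_mnist_embeddings, **kwargs[method])`. Writing the four calls out explicitly, as in the cell above, is arguably easier to read in a tutorial.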

Selecting points and synchronizing the selection across all four embeddings is great, but ultimately each point represents an image, and to better understand point clusters it would be great to know what those images look like. To render the images, we're going to create a simple widget with `anywidget` that displays a grid of images.

_(Note that to keep this demo fairly simple, we rendered the Fashion MNIST images beforehand and uploaded them to the cloud for easy access. You could also build a more sophisticated widget with `anywidget` that receives the raw pixels as a binary stream and renders them on the fly with something like the [Canvas API](https://developer.mozilla.org/en-US/docs/Web/API/Canvas_API).)_
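As a rough illustration of the binary-stream alternative, the sketch below packs grayscale images into a single `bytes` buffer that a widget could sync (for example through a `traitlets.Bytes` trait) and slices image *i* back out. It assumes 28×28 single-channel 8-bit images, the native Fashion MNIST format. `pack_images` and `unpack_image` are hypothetical helpers, and the browser-side Canvas rendering is not shown.

```python
WIDTH = HEIGHT = 28          # Fashion MNIST images are 28x28 pixels
IMAGE_SIZE = WIDTH * HEIGHT  # one 8-bit grayscale value per pixel

def pack_images(images):
    """Concatenate flat pixel lists (values 0-255) into one contiguous bytes buffer."""
    buffer = bytearray()
    for pixels in images:
        if len(pixels) != IMAGE_SIZE:
            raise ValueError(f"expected {IMAGE_SIZE} pixels, got {len(pixels)}")
        buffer.extend(pixels)  # raises ValueError for values outside 0-255
    return bytes(buffer)

def unpack_image(buffer, i):
    """Return the pixels of image i as HEIGHT rows of WIDTH integers."""
    start = i * IMAGE_SIZE
    flat = buffer[start:start + IMAGE_SIZE]
    return [list(flat[row * WIDTH:(row + 1) * WIDTH]) for row in range(HEIGHT)]

# Two toy images: one all black, one with a left-to-right gradient
black = [0] * IMAGE_SIZE
gradient = [col * 9 for _ in range(HEIGHT) for col in range(WIDTH)]
packed = pack_images([black, gradient])
print(len(packed))                     # 1568
print(unpack_image(packed, 1)[0][:4])  # [0, 9, 18, 27]
```

Because every image has the same size, the offset of image *i* is simply `i * 784`, so the front end could read any image straight out of the buffer without extra metadata.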

In [None]:
from anywidget import AnyWidget
from traitlets import Int, List

class ImagesWidget(AnyWidget):
    _esm = """
    const baseUrl = 'https://storage.googleapis.com/flekschas/regl-scatterplot/fashion-mnist-images/';
    export function render({ model, el }) {
      const container = document.createElement('div');
      container.classList.add('images-container');
      
      const title = document.createElement('div');
      title.classList.add('images-title');
      title.textContent = 'Selected Images';
      container.appendChild(title);

      const grid = document.createElement('div');
      grid.classList.add('images-grid');
      container.appendChild(grid);

      // Pick up to k items from x uniformly at random, without replacement
      function choose(x, k) {
        const idxs = Array.from({ length: x.length }, (_, i) => i);
        return Array.from({ length: Math.min(k, x.length) }, () => {
          const i = Math.floor(Math.random() * idxs.length);
          const idx = idxs[i];
          idxs.splice(i, 1);
          return x[idx];
        });
      }

      function getImages() {
        return choose(model.get("images"), model.get("max"));
      }

      function renderImages() {
        grid.textContent = ''; // Remove all previously rendered images from the grid
        
        getImages().forEach(([image, color]) => {
          const imgId = String(image).padStart(5, '0');
        
          const img = document.createElement('div');
          img.classList.add('images-fashion-mnist');
          img.style.backgroundColor = color;
          img.style.backgroundImage = `url(${baseUrl}${imgId}.png)`;
        
          grid.appendChild(img);
        });
      }

      model.on("change:images", renderImages);
      model.on("change:max", renderImages);
      
      renderImages();
      
      el.appendChild(container);
    }
    """

    _css = """
    .images-container {
      width: 100%;
      height: 736px;
      padding: 0 0 0 0.25rem;
      overflow: auto;
    }
    
    .images-title {
      font-size: var(--jp-widgets-font-size);
      font-weight: bold;
      text-align: center;
      line-height: 28px;
    }
    
    .images-grid {
      display: grid;
      grid-template-columns: repeat(auto-fit, minmax(32px, 1fr));
      align-content: flex-start;
      gap: 8px;
      width: 100%;
      height: 610px;
      margin-top: 4px;
      overflow: auto;
    }
    
    .images-fashion-mnist {
      display: flex;
      width: 32px;
      height: 32px;
      background-repeat: no-repeat;
      background-position: center;
    }
    """

    images = List().tag(sync=True)
    max = Int(50).tag(sync=True)
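Per selected point, the widget's JavaScript does two small things: it picks at most `max` entries at random without replacement (`choose`), and it turns each point index into an image URL by zero-padding it to five digits. A Python equivalent, handy for checking which files the widget will request, might look like the sketch below. `sample_images` and `image_url` are hypothetical helpers, and the five-digit file naming is inferred from the `padStart(5, '0')` call rather than from any documentation of the storage bucket.

```python
import random

BASE_URL = "https://storage.googleapis.com/flekschas/regl-scatterplot/fashion-mnist-images/"

def sample_images(images, k, rng=random):
    """Pick up to k entries uniformly at random, without replacement (like `choose`)."""
    return rng.sample(images, min(k, len(images)))

def image_url(index):
    """Zero-pad the point index to five digits, mirroring String(image).padStart(5, '0')."""
    return f"{BASE_URL}{index:05d}.png"

selection = [(i, "#FFFF00") for i in range(120)]  # toy (index, color) pairs
picked = sample_images(selection, 50, rng=random.Random(0))
print(len(picked))    # 50
print(image_url(42))  # URL ending in "00042.png"
```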

Next, we instantiate the images widget and update it whenever the selection of the PCA scatter plot changes.

In [None]:
images_widget = ImagesWidget()

def on_selection_change(change):
    # `change.new` holds the (positional) indices of the selected points
    classes = fashion_mnist_embeddings["class"].iloc[change.new]
    # Pair each point index, which doubles as the image ID, with its class color
    images_widget.images = [
        (int(i), config["color_map"].get(label, "#666"))
        for i, label in zip(change.new, classes)
    ]

pca.widget.observe(on_selection_change, names=["selection"])
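To see what the callback hands to the widget without rendering any plots, the sketch below runs the same index-to-color logic against a mock change event. `SimpleNamespace` stands in for the traitlets change object (only its `new` attribute is used), `selection_to_images` is a hypothetical helper, and `toy_classes` is made-up data; in the notebook the labels come from `fashion_mnist_embeddings["class"]`.

```python
from types import SimpleNamespace

COLOR_MAP = {"Trouser": "#1CE6FF", "Bag": "#7A4900"}  # excerpt of config["color_map"]
FALLBACK = "#666"                                     # color for unknown classes

toy_classes = ["Bag", "Trouser", "Trouser", "Unknown"]  # made-up label per point index

def selection_to_images(selection, classes, color_map, fallback=FALLBACK):
    """Map selected point indices to (index, color) pairs for ImagesWidget.images."""
    return [(int(i), color_map.get(classes[i], fallback)) for i in selection]

change = SimpleNamespace(new=[1, 3])  # mimics the traitlets change object
images = selection_to_images(change.new, toy_classes, COLOR_MAP)
print(images)  # [(1, '#1CE6FF'), (3, '#666')]
```

Keeping the mapping in a plain function like this makes it easy to test separately from the widget.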

Finally, to put everything together into a usable layout, we use `jscatter.compose` to arrange the scatter plots in a 2x2 grid with synchronized selection and hover, and combine that grid with the images widget using `ipywidgets.AppLayout`.

In [None]:
from ipywidgets import AppLayout
from jscatter import compose

scatters = compose(
    [(pca, "PCA"), (tsne, "t-SNE"), (umap, "UMAP"), (cae, "CAE")],
    sync_selection=True,
    sync_hover=True,
    rows=2,
    row_height=310
)

AppLayout(center=scatters, right_sidebar=images_widget)

Try selecting some points, either by selecting a class from the drop-down menu or by lassoing some points. Up to `50` randomly chosen images of the selected points will appear on the right. You can also easily increase the maximum number of images shown:

In [None]:
images_widget.max = 300

---

## Next

Next up, learn how `jscatter`'s point selection can be used to navigate in genomic data.

➡️ [Building a Bespoke Interface for Navigating Genomic Data via Loci Embeddings](3-Genomics.ipynb)