In [1]:
# Library imports
import numpy as np
import pandas as pd
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from matplotlib import rcParams
from PIL import Image
import math
from collections import OrderedDict

import umap
import umap.plot
from torch.utils.data import Dataset, DataLoader
from torchvision import utils

from arch.VAE import Encoder

In [2]:
# Render matplotlib figures at high resolution.
rcParams.update({'figure.dpi': 400})

# The Data

In [3]:
# data settings:
batch_size = 128  # training batch size; not used by any cell shown in this notebook
batch_size_test = 1000  # batch size used when encoding the dataset below

# misc settings
no_cuda = False  # set True to force CPU even when CUDA is available
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Device:", device)

# Extra DataLoader options, only enabled on GPU (pinned host memory speeds up
# host-to-device copies). These are passed to the loader below.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Normalisation constants: (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixel range
# to [-1, 1].
scale_mu = 0.5
scale_sd = 0.5

# datasets
# download=True only fetches the files when they are missing under ../data,
# so a fresh checkout can still "Restart & Run All".
data = datasets.FashionMNIST('../data', train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((scale_mu,), (scale_sd,))
                             ]))

# Data Loaders
# Labels are collected batch-by-batch alongside the encodings, so shuffling
# does not misalign images and targets.
loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size_test,
                                     shuffle=True, **kwargs)

Device: cuda


# Define the Encoder

In [4]:
# Build the encoder with coarse_resolution=(4, 4) on the selected device and
# report how many trainable parameters it has.
enc = Encoder(coarse_resolution=(4, 4))
enc = enc.to(device)
n_trainable = sum(param.numel() for param in enc.parameters() if param.requires_grad)
print("Number of free parameters: ", n_trainable)

Number of free parameters:  2864


In [5]:
# map_location places the stored tensors on the `device` chosen in the data
# settings cell, so a checkpoint saved on GPU also loads on a CPU-only machine.
enc.load_state_dict(torch.load('../models/fashion_mnist_encoder.pt', map_location=device))

<All keys matched successfully>

# Get the encodings for all images

In [6]:
# Run the encoder in eval mode over every batch without tracking gradients,
# and gather the codes together with their class labels.
enc.eval()
encoded_batches = []
label_batches = []
with torch.no_grad():
    for batch, batch_labels in loader:
        encoded_batches.append(enc(batch.to(device)))
        label_batches.append(batch_labels)

# Concatenate the batches, then move them to host memory as NumPy arrays.
encodings = torch.cat(encoded_batches).cpu().numpy()
targets = torch.cat(label_batches).cpu().numpy()

In [7]:
# Sanity check: one encoding row per label.
shapes = (encodings.shape, targets.shape)
print(*shapes)

(60000, 256) (60000,)


# UMAP

In [8]:
# Use string labels so downstream plotting treats them as categories.
targets = np.asarray(targets, dtype=str)

In [9]:
# Number of encodings to embed; 60000 covers every row printed above.
num = 60000
# NOTE: no random_state is passed, so the UMAP embedding (and the saved image)
# can differ between runs.
umap_reducer = umap.UMAP()
mapper = umap_reducer.fit(encodings[:num])

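### Plot settings

The cells below reproduce `umap.plot.connectivity` step by step using the private helpers from `umap.plot`, and save the rendered image to `../data/` instead of displaying it.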

In [31]:
# Settings for an inlined version of umap.plot.connectivity(mapper, ...).
umap_object = mapper
edge_bundling = 'hammer'  # None or 'hammer'
labels = targets[:num]  # must match the rows used to fit `mapper`
show_points = True
theme = 'fire'  # a key of umap.plot._themes, or None
width = 12000  # canvas width in pixels
height = 12000  # canvas height in pixels
pointsize = 300  # rounded and used as dynspread's max_px for the point layer
# str.format instead of '+' so theme=None yields 'umap_None' rather than
# raising TypeError.
name = 'umap_{}'.format(theme)

# NOTE: the bare string below is not a docstring and has no runtime effect (it
# is evaluated and discarded). It appears to be the parameter reference of
# umap.plot.connectivity, whose steps the following cells reproduce.
# Parameters it lists that this notebook never sets (``values``, ``color_key``)
# are passed as None when the points are rendered. Its "Returns" section
# describes connectivity(); in this notebook ``result`` is a datashader image
# that is written to a PNG rather than a matplotlib axis.
"""Plot connectivity relationships of the underlying UMAP
simplicial set data structure. Internally UMAP will make
use of what can be viewed as a weighted graph. This graph
can be plotted using the layout provided by UMAP as a
potential diagnostic view of the embedding. Currently this only works
for 2D embeddings. While there are many optional parameters
to further control and tailor the plotting, you need only
pass in the trained/fit umap model to get results. This plot
utility will attempt to do the hard work of avoiding
overplotting issues and provide options for plotting the
points as well as using edge bundling for graph visualization.
Parameters
----------
umap_object: trained UMAP object
    A trained UMAP object that has a 2D embedding.
edge_bundling: string or None (optional, default None)
    The edge bundling method to use. Currently supported
    are None or 'hammer'. See the datashader docs
    on graph visualization for more details.
edge_cmap: string (default 'gray_r')
    The name of a matplotlib colormap to use for shading/
    coloring the edges of the connectivity graph. Note that
    the ``theme``, if specified, will override this.
show_points: bool (optional False)
    Whether to display the points over top of the edge
    connectivity. Further options allow for coloring/
    shading the points accordingly.
labels: array, shape (n_samples,) (optional, default None)
    An array of labels (assumed integer or categorical),
    one for each data sample.
    This will be used for coloring the points in
    the plot according to their label. Note that
    this option is mutually exclusive to the ``values``
    option.
values: array, shape (n_samples,) (optional, default None)
    An array of values (assumed float or continuous),
    one for each sample.
    This will be used for coloring the points in
    the plot according to a colorscale associated
    to the total range of values. Note that this
    option is mutually exclusive to the ``labels``
    option.
theme: string (optional, default None)
    A color theme to use for plotting. A small set of
    predefined themes are provided which have relatively
    good aesthetics. Available themes are:
       * 'blue'
       * 'red'
       * 'green'
       * 'inferno'
       * 'fire'
       * 'viridis'
       * 'darkblue'
       * 'darkred'
       * 'darkgreen'
cmap: string (optional, default 'Blues')
    The name of a matplotlib colormap to use for coloring
    or shading points. If no labels or values are passed
    this will be used for shading points according to
    density (largely only of relevance for very large
    datasets). If values are passed this will be used for
    shading according the value. Note that if theme
    is passed then this value will be overridden by the
    corresponding option of the theme.
color_key: dict or array, shape (n_categories) (optional, default None)
    A way to assign colors to categoricals. This can either be
    an explicit dict mapping labels to colors (as strings of form
    '#RRGGBB'), or an array like object providing one color for
    each distinct category being provided in ``labels``. Either
    way this mapping will be used to color points according to
    the label. Note that if theme
    is passed then this value will be overridden by the
    corresponding option of the theme.
color_key_cmap: string (optional, default 'Spectral')
    The name of a matplotlib colormap to use for categorical coloring.
    If an explicit ``color_key`` is not given a color mapping for
    categories can be generated from the label list and selecting
    a matching list of colors from the given colormap. Note
    that if theme
    is passed then this value will be overridden by the
    corresponding option of the theme.
background: string (optional, default 'white)
    The color of the background. Usually this will be either
    'white' or 'black', but any color name will work. Ideally
    one wants to match this appropriately to the colors being
    used for points etc. This is one of the things that themes
    handle for you. Note that if theme
    is passed then this value will be overridden by the
    corresponding option of the theme.
width: int (optional, default 800)
    The desired width of the plot in pixels.
height: int (optional, default 800)
    The desired height of the plot in pixels
Returns
-------
result: matplotlib axis
    The result is a matplotlib axis with the relevant plot displayed.
    If you are using a notbooks and have ``%matplotlib inline`` set
    then this will simply display inline.
"""
import datashader as ds
import datashader.bundling as bd
import datashader.transfer_functions as tf
from umap.plot import _get_embedding, _datashade_points, _select_font_color, _embed_datashader_in_an_axis, _themes
from warnings import warn

In [32]:
# Colour settings. The defaults mirror those listed in the connectivity
# reference (cmap 'Blues', color_key_cmap 'Spectral', edge_cmap 'gray_r',
# background 'white'); a theme, when given, overrides all four. Binding them
# unconditionally means theme=None still defines every name the rendering
# cells use, instead of failing later with NameError.
cmap = "Blues"
color_key_cmap = "Spectral"
edge_cmap = "gray_r"
background = "white"
if theme is not None:
    theme_settings = _themes[theme]
    cmap = theme_settings["cmap"]
    color_key_cmap = theme_settings["color_key_cmap"]
    edge_cmap = theme_settings["edge_cmap"]
    background = theme_settings["background"]

### Get the points and edges

In [33]:
# 2D embedding coordinates, as an array and as an x/y DataFrame for datashader.
points = _get_embedding(umap_object)
point_df = pd.DataFrame(points, columns=("x", "y"))

# Pixel spread used for the point layer; never less than one pixel.
# Density-based alternative: 100.0 / np.sqrt(points.shape[0])
point_size = pointsize
px_size = int(np.round(point_size)) if point_size > 1 else 1

# Shade edges logarithmically when points are drawn on top of them,
# otherwise use histogram equalisation.
edge_how = "log" if show_points else "eq_hist"

In [12]:
# One row per non-zero entry of the UMAP graph: (source, target, weight).
# Stacking promotes everything to a common float dtype, so the index columns
# are cast back to int32 afterwards.
coo_graph = umap_object.graph_.tocoo()
edge_table = np.column_stack((coo_graph.row, coo_graph.col, coo_graph.data))
edge_df = pd.DataFrame(edge_table, columns=("source", "target", "weight")).astype(
    {"source": np.int32, "target": np.int32}
)

In [13]:
# Turn the edge table into drawable line segments, either straight or
# hammer-bundled; reject any other bundling method up front.
if edge_bundling is not None and edge_bundling != "hammer":
    raise ValueError("{} is not a recognised bundling method".format(edge_bundling))

if edge_bundling == "hammer":
    warn(
        "Hammer edge bundling is expensive for large graphs!\n"
        "This may take a long time to compute!"
    )
    edges = bd.hammer_bundle(point_df, edge_df, weight="weight")
else:
    edges = bd.directly_connect_edges(point_df, edge_df, weight="weight")

This may take a long time to compute!
  """


### Create the canvas and render the objects

In [34]:
# Canvas covering the embedding's extent, ordered as (x_min, x_max, y_min, y_max).
extent = umap.plot._get_extent(points)
canvas = ds.Canvas(
    plot_width=width,
    plot_height=height,
    x_range=tuple(extent[0:2]),
    y_range=tuple(extent[2:4]),
)

In [35]:
# render the edges
edge_img = tf.shade(
    canvas.line(edges, "x", "y", agg=ds.sum("weight")),
    cmap=plt.get_cmap(edge_cmap),
    how=edge_how,
)

# render the background
edge_img = tf.set_background(edge_img, background)

In [36]:
# render the points
if show_points:
    point_img = _datashade_points(
        points,
        None,
        labels,
        None,
        cmap,
        None,
        color_key_cmap,
        None,
        width,
        height,
        False,
    )
    if px_size > 1:
        point_img = tf.dynspread(point_img, threshold=0.9, max_px=px_size, how='add')
    result = tf.stack(edge_img, point_img, how="over")
else:
    result = edge_img

In [37]:
# Write the rendered datashader image to ../data/<name>.png.
out_path = '../data/{name}.png'.format(name=name)
result.to_pil().save(out_path)