# Example 2: Digits

### Load digits dataset

In [None]:
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

# Load the labelled digits dataset (8x8 greyscale images with their digit
# labels in `y`), then build the lens by projecting the data onto its
# first 8 principal components.

X, y = load_digits(return_X_y=True)
pca = PCA(n_components=8, random_state=42)
lens = pca.fit_transform(X)

### Build Mapper graph

In [None]:
from sklearn.cluster import DBSCAN

from tdamapper.core import MapperAlgorithm
from tdamapper.cover import CubicalCover, BallCover

# Mapper settings: a cubical cover of the lens values and DBSCAN as the
# clustering step. These values are typically found by experimenting
# interactively until the resulting graph looks informative.

cover = CubicalCover(n_intervals=6, overlap_frac=0.45)
clustering = DBSCAN(eps=50.0, min_samples=4)

mapper_graph = MapperAlgorithm(
    cover=cover,
    clustering=clustering,
    verbose=False,
).fit_transform(X, lens)

### Plot Mapper graph showing the most frequent digits

In [None]:
import numpy as np

from tdamapper.plot import MapperPlot


def mode(X):
    """Return the most frequent value in ``X``.

    ``np.unique`` flattens ``X`` and returns its distinct values in sorted
    order together with their counts. ``argmax`` selects the first maximal
    count, so ties are resolved in favour of the smallest value.
    """
    uniq, freq = np.unique(X, return_counts=True)
    return uniq[freq.argmax()]


# MapperPlot computes the positions of the graph nodes (in 2 or 3
# dimensions; 2 here) and stores them. Every figure drawn from this
# object therefore reuses the same layout.

mapper_plot = MapperPlot(mapper_graph, dim=2, iterations=400, seed=42)

# Interactive figure in which each node is coloured by the most frequent
# digit among the points it contains.

fig = mapper_plot.plot_plotly(
    colors=y, cmap='jet', agg=mode,
    title='most frequent digit', width=600, height=600
)
fig.show(renderer='notebook_connected', config={'scrollZoom': True})

### Plot Mapper graph showing the mean digit in each node

In [None]:
# Colour each node by the mean of the digits it contains (NaN values are
# ignored by np.nanmean). Reusing the same MapperPlot object keeps the
# node positions identical to the previous figure.

fig = mapper_plot.plot_plotly(
    colors=y, cmap='jet', agg=np.nanmean,
    title='mean digit', width=600, height=600
)
fig.show(renderer='notebook_connected', config={'scrollZoom': True})

### Plot Mapper graph showing the standard deviation of the digits in each node

In [None]:
# Colour each node by the standard deviation of the digits it contains
# (NaN values are ignored by np.nanstd). The MapperPlot object is the same
# one used above, so the node positions do not change.

fig = mapper_plot.plot_plotly(
    colors=y, cmap='viridis', agg=np.nanstd,
    title='std of digit', width=600, height=600
)
fig.show(renderer='notebook_connected', config={'scrollZoom': True})