# Digits dataset

In [1]:
import numpy as np

from sklearn.datasets import load_digits
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA

from tdamapper.core import MapperAlgorithm
from tdamapper.cover import CubicalCover
from tdamapper.clustering import FailSafeClustering
from tdamapper.plot import MapperPlot


# Load the labelled digits dataset: X holds the samples, y the digit labels.
X, y = load_digits(return_X_y=True)

# Compute the 2-D lens values with PCA. The seed is pinned because, depending
# on the scikit-learn version and the data shape, PCA's default 'auto' solver
# may pick the randomized SVD, which would make the lens (and therefore the
# Mapper graph built from it) vary between runs.
lens = PCA(n_components=2, random_state=42).fit_transform(X)

### Build Mapper graph

In [2]:
# Configure Mapper with a cubical cover of the lens space and an
# agglomerative clusterer applied through FailSafeClustering.
mapper_algo = MapperAlgorithm(
    cover=CubicalCover(n_intervals=10, overlap_frac=0.65),
    # The wrapper is here to prevent clustering failures.
    clustering=FailSafeClustering(
        clustering=AgglomerativeClustering(n_clusters=10),
        verbose=False,
    ),
)

# Build the Mapper graph of X from the precomputed lens values.
mapper_graph = mapper_algo.fit_transform(X, lens)

### Plot Mapper graph colored by the mean digit label

In [3]:
# Prepare a 2-D plot of the Mapper graph (400 iterations, seed fixed at 42).
mapper_plot = MapperPlot(mapper_graph, dim=2, iterations=400, seed=42)

# Draw the graph with nodes colored by the mean digit value.
fig = mapper_plot.plot_plotly(
    title='digit (mean)',
    colors=y,          # digit values drive the node colors
    agg=np.nanmean,    # aggregate on graph nodes according to mean
    cmap='jet',        # jet colormap, used for classes
    width=600,
    height=600,
)

# Render in the notebook with scroll-to-zoom enabled.
fig.show(config={'scrollZoom': True}, renderer='notebook_connected')

### Plot Mapper graph colored by the standard deviation of digit labels

In [4]:
# Update the existing figure: same digit values, now aggregated per node
# with the standard deviation.
mapper_plot.plot_plotly_update(
    fig,
    title='digit (std)',
    colors=y,          # still colored from the digit values
    agg=np.nanstd,     # aggregate on graph nodes according to std
    cmap='viridis',    # viridis colormap, used for ranges
)

# Render the updated figure with scroll-to-zoom enabled.
fig.show(config={'scrollZoom': True}, renderer='notebook_connected')