# Case study: MNIST hand-written digits dataset

##### License: Apache 2.0


This notebook shows how to use topological data analysis to generate features for classifying digits.

## Import libraries
The first step consists of importing the relevant *gtda* components and other useful libraries and modules.

In [None]:
from gtda.images import Binarizer, Inverter, ImageToPointCloud, HeightFiltration, DilationFiltration, RadialFiltration, ErosionFiltration, SignedDistanceFiltration
from pgtda.images import DensityFiltration
from gtda.homology import VietorisRipsPersistence, CubicalPersistence
from gtda.diagrams import ForgetDimension, PairwiseDistance, Amplitude, Scaler, PersistenceEntropy, BettiCurve, PersistenceLandscape, HeatKernel
from sklearn.pipeline import Pipeline, make_pipeline, FeatureUnion, make_union
from gtda.diagrams._utils import _subdiagrams

import numpy as np
import gzip
import pickle as pkl
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

## Plotting functions

In [None]:
def plot_images(X):
    """Display up to the first 20 images of ``X`` on a 4x5 grid.

    Parameters
    ----------
    X : ndarray
        Stack of 2D images indexed along the first axis. All panels share one
        colour range, computed over the entries of ``X`` different from
        ``np.inf``.

    Notes
    -----
    Panel titles are read from the notebook-level ``y_train`` array, so the
    i-th image of ``X`` is labelled with ``y_train[i]``.
    """
    import copy  # local import: only used to obtain a private colormap copy

    n_rows, n_cols = 4, 5
    n_shown = min(len(X), n_rows * n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
    axes = axes.flatten()

    # Customise a private copy so the shared ``plt.cm.binary`` object is left
    # untouched, and hand that copy to imshow so the 'bad' colour is really used.
    cmap = copy.copy(plt.cm.binary)
    cmap.set_bad('y')

    finite_values = X[X != np.inf]
    if finite_values.size:
        vmin, vmax = np.min(finite_values), np.max(finite_values)
    else:
        # Nothing to scale against: let matplotlib pick the range.
        vmin = vmax = None

    for i, ax in enumerate(axes):
        ax.axis('off')  # hide the axes ticks (also blanks unused panels)
        if i >= n_shown:
            continue
        ax.imshow(X[i], cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title('Correct label is '+str(int(y_train[i])), color= 'black', fontsize=12)
    plt.show()

In [None]:
def plot_point_clouds(X):
    """Plot up to the first 20 point clouds of ``X`` on a 4x5 grid.

    Parameters
    ----------
    X : ndarray
        Array indexed as ``X[i, :, 0]`` / ``X[i, :, 1]`` for the two plotted
        coordinates of the i-th point cloud.

    Notes
    -----
    Both axes are fixed to the range [0, 27]. Panel titles are read from the
    notebook-level ``y_train`` array.
    """
    n_rows, n_cols = 4, 5
    n_shown = min(len(X), n_rows * n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
    axes = axes.flatten()

    # No colormap or colour range is needed for a marker plot, so the shared
    # ``plt.cm.binary`` object is no longer touched here.
    for i, ax in enumerate(axes):
        ax.axis('off')  # hide the axes ticks (also blanks unused panels)
        if i >= n_shown:
            continue
        ax.plot(X[i, :, 0], X[i, :, 1], marker='s', linestyle='')
        ax.set_xlim(0, 27)
        ax.set_ylim(0, 27)
        ax.set_title('Correct label is '+str(int(y_train[i])), color= 'black', fontsize=12)
    plt.show()

In [None]:
def plot_diagrams(X):
    """Plot up to the first 20 persistence diagrams of ``X`` on a 4x5 grid.

    Parameters
    ----------
    X : ndarray
        Collection of diagrams indexed as ``X[i, point, :]``; the homology
        dimensions to draw are the distinct values found in ``X[0, :, 2]``.

    Notes
    -----
    Dimensions 0, 1 and 2 are drawn in blue, red and green; any other
    dimension raises ``KeyError``. The diagonal drawn in every panel spans the
    smallest and largest coordinate seen across the plotted diagrams. Panel
    titles are read from the notebook-level ``y_train`` array.
    """
    n_rows, n_cols = 4, 5
    n_shown = min(len(X), n_rows * n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
    axes = axes.flatten()
    colors = {0: 'b', 1: 'r', 2: 'g'}
    homology_dimensions = sorted(set(X[0, :, 2]))

    vmin, vmax = np.inf, -np.inf
    for dim in homology_dimensions:
        # The helper is applied to the whole collection and its arguments do not
        # depend on the sample index, so call it once per dimension instead of
        # once per (sample, dimension) pair.
        diagrams_dim = _subdiagrams(X, [dim], remove_dim=True)
        for i in range(n_shown):
            diagram_dim = diagrams_dim[i]
            vmin, vmax = min(vmin, np.min(diagram_dim)), max(vmax, np.max(diagram_dim))
            axes[i].plot(diagram_dim[:,0], diagram_dim[:,1], 'o', color=colors[int(dim)])

    for i, ax in enumerate(axes):
        if i >= n_shown:
            ax.axis('off')  # blank panels that have no diagram
            continue
        ax.plot([vmin, vmax], [vmin, vmax], color='k')
        ax.set_title('Diagram for label '+str(int(y_train[i])), color= 'black', fontsize=12)
    plt.show()

In [None]:
def plot_matrix(X):
    """Show one matrix as an image with a horizontal colour bar below it.

    Parameters
    ----------
    X : ndarray
        2D array to display. The colour bar is normalised to the minimum and
        maximum entries of ``X`` and is labelled 'Filtration values'.
    """
    fig = plt.figure(figsize=(15,10))
    lowest, highest = np.min(X), np.max(X)

    # Draw the matrix first: add_axes below changes the current axes.
    plt.imshow(X)

    # Reserve room at the bottom of the figure and place the bar there.
    fig.subplots_adjust(bottom=0.2)
    bar_axes = fig.add_axes([0.3, 0.2, 0.4, 0.03])
    bar = mpl.colorbar.ColorbarBase(
        bar_axes,
        norm=mpl.colors.Normalize(vmin=lowest, vmax=highest),
        orientation='horizontal',
    )
    bar.set_label('Filtration values')
    plt.show()

In [None]:
def plot_matrices(X):
    """Show up to the first 20 matrices of ``X`` on a 4x5 grid with one colour bar.

    Parameters
    ----------
    X : ndarray
        Stack of 2D arrays indexed along the first axis. All panels and the
        colour bar share the range given by the minimum and maximum of ``X``.
    """
    n_rows, n_cols = 4, 5
    n_shown = min(len(X), n_rows * n_cols)

    figure, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
    axes = axes.flatten()

    vmin, vmax = np.min(X), np.max(X)
    # Plain iteration over the flattened axes: no itertools (which the notebook
    # never imports) and no reshape are needed to pair matrices with panels.
    for i, ax in enumerate(axes):
        ax.axis('off')  # hide the axes ticks (also blanks unused panels)
        if i < n_shown:
            ax.imshow(X[i], vmin=vmin, vmax=vmax)

    # Reserve room at the bottom of the figure for a shared horizontal bar.
    figure.subplots_adjust(bottom=0.2)
    cbar_ax = figure.add_axes([0.3, 0.2, 0.4, 0.03])
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    colorbar = mpl.colorbar.ColorbarBase(cbar_ax, norm=norm, orientation='horizontal')
    colorbar.set_label('Filtration values')
    plt.show()

In [None]:
def plot_curve(X, y=None, n_curves=1):
    """Draw ``n_curves`` consecutive, equally long chunks of ``X`` on one figure.

    Parameters
    ----------
    X : ndarray
        Values to plot; its first axis is cut into ``n_curves`` chunks of
        ``X.shape[0] // n_curves`` entries each.
    y : array-like, optional
        Horizontal coordinates shared by all chunks. Defaults to
        ``0, 1, ..., chunk_length - 1``.
    n_curves : int, optional
        Number of chunks to draw.
    """
    plt.figure(figsize=(10,5))
    n_points = X.shape[0] // n_curves
    abscissa = np.arange(n_points) if y is None else y
    for k in range(n_curves):
        start = k * n_points
        plt.plot(abscissa, X[start:start + n_points])
    plt.show()

## Loading the MNIST dataset

In [None]:
# Download data here: https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz
# Adjust this path to wherever the archive was saved locally.
mnist_path = '/home/rookstar/Downloads/mnist.pkl.gz'

# NOTE: unpickling can execute arbitrary code; only load an archive obtained from a trusted source.
# The context manager closes the gzip handle once the pickle has been read.
with gzip.open(mnist_path, 'rb') as mnist_file:
    # The third element of the pickled tuple is discarded.
    ((X, y), (X_valid, y_valid), _) = pkl.load(mnist_file, encoding='latin-1')

# View every sample as a 28x28 image.
X = X.reshape((-1, 28, 28))

print(X.shape, y.shape)
print(np.min(X), np.max(X))

In [None]:
# Carve a training block and the test block that follows it out of the loaded samples

n_train, n_test = 40000, 10000

X_train, y_train = X[:n_train], y[:n_train]
X_test, y_test = X[n_train:n_train + n_test], y[n_train:n_train + n_test]

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

## Some examples of the input data
We choose the first 20 samples from the training set and visualize them.

In [None]:
# Draw the first 20 training images; plot_images titles them with y_train.
plot_images(X_train[:20])

## Binarization of the images

In [None]:
# Binarize the first 20 training images using a threshold of 0.4
binarizer = Binarizer(threshold=0.4)
X_train_binarized = binarizer.fit(X_train[:20]).transform(X_train[:20])

In [None]:
# Draw the binarized versions of the same 20 images.
plot_images(X_train_binarized)

## Transforming an image to a point cloud

In [None]:
# Convert the first 20 training images into point clouds
point_cloud_transformer = ImageToPointCloud()
X_train_points = point_cloud_transformer.fit(X_train[:20]).transform(X_train[:20])

In [None]:
# Draw the 20 point clouds produced by ImageToPointCloud.
plot_point_clouds(X_train_points)

In [None]:
# Vietoris-Rips persistence of the point clouds in homology dimensions 0 and 1
rips_complex = VietorisRipsPersistence(
    metric='euclidean',
    max_edge_length=100,
    homology_dimensions=(0, 1),
)
X_train_rips = rips_complex.fit(X_train_points).transform(X_train_points)

In [None]:
# Draw the Vietoris-Rips persistence diagrams computed above.
plot_diagrams(X_train_rips)

## Computing the persistence landscape

In [None]:
# Persistence landscapes (3 layers, 1000 bins) of the Vietoris-Rips diagrams
landscape = PersistenceLandscape(n_bins=1000, n_layers=3, n_jobs=1)
X_train_landscape = landscape.fit(X_train_rips).transform(X_train_rips)

In [None]:
print(X_train_landscape.shape)
# Plot one landscape slice against the sampling stored under key 1.0.
# NOTE(review): presumably indices (1, 1, 0) select sample 1, homology dimension 1, first layer - confirm.
plot_curve(
    X_train_landscape[1, 1, 0, :],
    y=landscape.samplings_[1.0].reshape(-1,),
    n_curves=1,
)

## Computing the Betti curves of stacked diagrams

In [None]:
# Apply ForgetDimension to the Vietoris-Rips diagrams
diagram_stacker = ForgetDimension()
X_train_stacked = diagram_stacker.fit(X_train_rips).transform(X_train_rips)

In [None]:
# Betti curves (100 bins) of the stacked diagrams
betti = BettiCurve(n_bins=100, n_jobs=1)
X_train_betti = betti.fit(X_train_stacked).transform(X_train_stacked)

In [None]:
print(X_train_betti.shape)
# Plot one Betti curve against the sampling stored under key np.inf.
# NOTE(review): presumably np.inf is the dimension label left by ForgetDimension - confirm.
plot_curve(
    X_train_betti[1, 0],
    y=betti.samplings_[np.inf].reshape(-1,),
    n_curves=1,
)

## Inverting the boolean images

In [None]:
# Invert the binarized images
inverter = Inverter(n_jobs=4)
X_train_inverted = inverter.fit(X_train_binarized[:20]).transform(X_train_binarized[:20])

In [None]:
# Draw the inverted binary images.
plot_images(X_train_inverted)

## Applying a boolean image filtration

In [None]:
n_iterations = 28

# NOTE(review): the variable is named after the signed-distance filtration but
# holds a DilationFiltration - confirm which of the two is intended here.
signed_distance_filtration = DilationFiltration(
    n_iterations=n_iterations,
    n_jobs=4,
)
X_train_filtered = (
    signed_distance_filtration
    .fit(X_train_binarized[:20])
    .transform(X_train_binarized[:20])
)

In [None]:
# Draw the filtration values; plot_images ignores np.inf entries when computing the colour range.
plot_images(X_train_filtered)

## Getting persistence diagrams out of images

In [None]:
# Cubical persistence of the filtered images
cubical_complex = CubicalPersistence(n_jobs=1)
X_train_cubical = cubical_complex.fit(X_train_filtered).transform(X_train_filtered)

In [None]:
# Draw the cubical persistence diagrams computed above.
plot_diagrams(X_train_cubical)

## Rescaling the diagrams

In [None]:
# Keyword arguments forwarded to Scaler: bottleneck metric with no extra parameters
metric = dict(metric='bottleneck', metric_params={})

diagram_scaler = Scaler(**metric)
X_train_scaled = diagram_scaler.fit(X_train_cubical).transform(X_train_cubical)

In [None]:
# Draw the diagrams after rescaling with Scaler.
plot_diagrams(X_train_scaled)

## Computing the distance matrix between the diagrams

In [None]:
# Pairwise Wasserstein distances between the (unscaled) cubical diagrams
diagram_distance = PairwiseDistance(
    metric='wasserstein',
    metric_params=dict(p=2, delta=0.1),
    n_jobs=1,
)
X_train_distance = diagram_distance.fit(X_train_cubical).transform(X_train_cubical)

In [None]:
# Show the pairwise distance matrix as an image (plot_matrix labels its colour bar 'Filtration values').
plot_matrix(X_train_distance)

## Putting everything in a pipeline

In [None]:
# Chain binarization, signed-distance filtration, cubical persistence and
# pairwise Wasserstein distances into a single estimator
steps = [
    ('binarizer', Binarizer(threshold=0.4)),
    ('filtration', SignedDistanceFiltration(n_iterations=28)),
    ('persistence', CubicalPersistence(n_jobs=1)),
    (
        'distance',
        PairwiseDistance(
            metric='wasserstein',
            metric_params=dict(p=2, delta=0.1),
            n_jobs=1,
        ),
    ),
]

pipeline_signed_distance = Pipeline(steps=steps)

In [None]:
# fit_transform already fits every step, so a separate fit() call beforehand
# would only run the whole pipeline's fitting work a second time.
X_train_pipeline_distance = pipeline_signed_distance.fit_transform(X_train[:20])

In [None]:
# Show the distance matrix produced by the pipeline.
plot_matrix(X_train_pipeline_distance)

## Applying several pipelines based on different filtrations

In [None]:
# One height filtration per direction
direction_list = [np.array(vector) for vector in ([0, 1], [0, -1], [1, 0], [-1, 0])]

filtration_list = [HeightFiltration(direction=direction) for direction in direction_list]

# Every filtration gets its own binarizer -> filtration -> persistence -> amplitude chain
steps_list = [
    [
        ('binarizer', Binarizer(threshold=0.4)),
        ('filtration', filtration),
        ('persistence', CubicalPersistence()),
        ('distance', Amplitude(metric='heat', metric_params={'p': 2})),
    ]
    for filtration in filtration_list
]

# Each pipeline is named after the string form of its direction vector
pipeline_list = [
    (str(direction), Pipeline(pipeline_steps))
    for direction, pipeline_steps in zip(direction_list, steps_list)
]
feature_union_filtrations = FeatureUnion(pipeline_list, n_jobs=-1)

In [None]:
# Fit the union of pipelines on the first 20 training images and collect their outputs
X_train_filtrations = feature_union_filtrations.fit(X_train[:20]).transform(X_train[:20])

In [None]:
# plot_curve cuts the FIRST axis of its input into len(filtration_list) equal chunks.
# NOTE(review): the first axis of X_train_filtrations presumably indexes samples, not filtrations -
# confirm that this split shows what is intended.
plot_curve(X_train_filtrations, n_curves=len(filtration_list))