# Case study: MNIST hand-written digits dataset

#### Authors: Guillaume Tauzin <guillaume.tauzin@epfl.ch>
##### License: TBD


This notebook shows how to use topological data analysis to generate features for classifying hand-written digits.

The first step is to import the *giotto* library, together with scikit-learn, NumPy, pandas and Matplotlib.

In [None]:
from giotto.images import ImageInverter, HeightFiltration, DilationFiltration, RadialFiltration, ErosionFiltration, SignedDistanceFiltration
from giotto.homology import CubicalPersistence
from giotto.diagram import DiagramDistance
from sklearn.pipeline import Pipeline, FeatureUnion

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

## Plotting functions

In [None]:
def plot_images(X, labels=None):
    """Show the first 20 images of ``X`` on a 4x5 grid with a shared colour scale.

    Parameters
    ----------
    X : array-like
        Stack of 2-D images indexed as ``X[i]``; boolean or numeric.
    labels : array-like, optional
        One label per image, used in the subplot titles. Defaults to the
        notebook-level ``y_train``.

    The grey scale spans the finite values of ``X``; non-finite pixels
    (``inf``, ``-inf``, ``nan``) are drawn in yellow. With fewer than 20
    images the remaining panels are left empty instead of raising.
    """
    import copy

    X = np.asarray(X)
    # ravel() so that labels of shape (n, 1), like y_train, give scalars below.
    labels = np.asarray(y_train if labels is None else labels).ravel()
    n_shown = min(20, len(X), len(labels))

    # Work on a copy so set_bad does not alter the shared global colormap, and
    # pass that copy to imshow so the yellow 'bad' colour is actually applied.
    cmap = copy.copy(plt.cm.binary)
    cmap.set_bad('y')
    finite = X[np.isfinite(X)]
    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)

    _, axes = plt.subplots(4, 5, figsize=(15, 10))
    for i, ax in enumerate(axes.flatten()):
        ax.axis('off')  # hide the axes ticks
        if i < n_shown:
            ax.imshow(X[i], cmap=cmap, vmin=vmin, vmax=vmax)
            ax.set_title('Correct label is ' + str(int(labels[i])), color='black', fontsize=12)
    plt.show()

In [None]:
def plot_diagrams(X, labels=None):
    """Scatter the persistence diagrams of the first 20 samples on a 4x5 grid.

    Parameters
    ----------
    X : dict
        Maps each homology dimension to a per-sample sequence of diagrams.
        ``X[dimension][i]`` is indexed as a 2-D array whose columns 0 and 1
        are plotted against each other (assumed to be birth/death pairs as
        produced by giotto's ``CubicalPersistence`` -- confirm).
    labels : array-like, optional
        One label per sample, used in the subplot titles. Defaults to the
        notebook-level ``y_train``.

    Dimensions 0, 1 and 2 are drawn in blue, red and green; any other
    dimension is drawn in black. The reference diagonal spans the range of the
    finite plotted values, falling back to [0, 5] when there are none.
    """
    # ravel() so that labels of shape (n, 1), like y_train, give scalars below.
    labels = np.asarray(y_train if labels is None else labels).ravel()
    colors = {0: 'b', 1: 'r', 2: 'g'}
    dimensions = list(X.keys())
    n_shown = min([20, len(labels)] + [len(X[dimension]) for dimension in dimensions])

    _, axes = plt.subplots(4, 5, figsize=(15, 10))
    axes = axes.flatten()

    finite_chunks = []
    for i in range(n_shown):
        for dimension in dimensions:
            pairs = np.asarray(X[dimension][i])
            axes[i].plot(pairs[:, 0], pairs[:, 1], 'o', color=colors.get(dimension, 'k'))
            plotted = pairs[:, :2]
            finite_chunks.append(plotted[np.isfinite(plotted)])
        axes[i].set_title('Diagram for label ' + str(int(labels[i])), color='black', fontsize=12)

    finite = np.concatenate(finite_chunks) if finite_chunks else np.empty(0)
    low, high = (finite.min(), finite.max()) if finite.size else (0, 5)
    # Draw the diagonal once the data range is known so it covers every point.
    for ax in axes[:n_shown]:
        ax.plot([low, high], [low, high])
    plt.show()

In [None]:
def plot_matrix(X):
    """Render one 2-D matrix as a 10x10-inch image with an attached colour bar."""
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(X)
    fig.colorbar(image, ax=ax)
    plt.show()

In [None]:
def plot_matrices(X):
    """Plot several matrices side by side on a common colour scale.

    Parameters
    ----------
    X : array-like
        ``(n_matrices, n_rows, n_cols)`` or ``(n_matrices, n_rows, n_cols, n_z)``.
        With 4-D input each of the ``n_z`` slices gets its own row of
        subplots. A 2-D input is drawn as a single matrix.

    Non-finite entries are drawn in yellow and excluded from the colour range.
    A single horizontal colour bar labelled 'Filtration values' is shared by
    all panels.
    """
    import copy

    matrices = np.asarray(X)
    if matrices.ndim == 2:
        matrices = matrices[np.newaxis]
    n_matrices = matrices.shape[0]
    n_z = matrices.shape[3] if matrices.ndim == 4 else 1
    matrices = matrices.reshape((n_matrices, matrices.shape[1], matrices.shape[2], n_z))

    # squeeze=False keeps `axes` 2-D even for a single row or column of panels.
    figure, axes = plt.subplots(n_z, n_matrices, figsize=(18, 8 + (n_z - 1) * 5), squeeze=False)

    # Copy so set_bad does not alter the shared global colormap.
    cmap = copy.copy(plt.cm.binary)
    cmap.set_bad('y')
    finite = matrices[np.isfinite(matrices)]
    vmin, vmax = (finite.min(), finite.max()) if finite.size else (None, None)

    image = None
    for i in range(n_matrices):
        for j in range(n_z):
            image = axes[j, i].imshow(matrices[i, :, :, j], cmap=cmap, vmin=vmin, vmax=vmax)

    # Every panel shares cmap/vmin/vmax, so one colour bar built from the last
    # image describes all of them.
    figure.subplots_adjust(bottom=0.2)
    cbar_ax = figure.add_axes([0.3, 0.2 - (n_z - 1) * 0.06 / n_z, 0.4, 0.03 - (n_z - 1) * 0.026 / n_z])
    colorbar = figure.colorbar(image, cax=cbar_ax, orientation='horizontal')
    colorbar.set_label('Filtration values')
    plt.show()

## Loading the MNIST dataset

In [None]:
# Load the MNIST training and test splits from CSV.
train_data = pd.read_csv('../data/mnist_train.csv')
test_data = pd.read_csv("../data/mnist_test.csv")

# Report the number of rows (samples) in each split.
n_samples_train = len(train_data)
n_samples_test = len(test_data)
print('n_samples in train: ', n_samples_train)
print('n_samples in test: ', n_samples_test)

In [None]:
# Set up the data: keep only the first n_train / n_test samples so the
# topological computations below stay fast (replace 20 and 100 with
# n_samples_train / n_samples_test to use the full dataset), then binarize.
n_train, n_test = min(20, n_samples_train), min(100, n_samples_test)
binary_threshold = 0.4  # a pixel becomes True when intensity / 255 exceeds this

X_train = train_data.drop(columns=['label']).values[:n_train].reshape((n_train, 28, 28)) / 255 > binary_threshold
y_train = train_data['label'].values[:n_train].reshape((n_train, 1))
# If the test CSV carries a 'label' column like the training CSV, drop it so
# that only the 28*28 pixel columns remain for the reshape.
X_test = test_data.drop(columns=['label'], errors='ignore').values[:n_test].reshape((n_test, 28, 28)) / 255 > binary_threshold

print(X_train.shape, X_test.shape)

## Some examples of the input data
We choose the first 20 samples from the training set and visualize them.

In [None]:
# Binarized training digits with their labels.
plot_images(X=X_train)

## Inverting the boolean images

In [None]:
# Invert the binarized training images with giotto's ImageInverter
# (presumably swapping True and False pixels -- confirm in the giotto docs).
# fit/transform follow the scikit-learn transformer interface; n_jobs=4 is
# assumed to parallelize the work across images.
image_inverter = ImageInverter(n_jobs=4)
image_inverter.fit(X_train)
X_train_inverted = image_inverter.transform(X_train)

In [None]:
# The same digits after inversion.
plot_images(X=X_train_inverted)

## Applying a boolean image filtration

In [None]:
# Apply giotto's signed-distance filtration to the inverted boolean images;
# n_iterations and normalize are forwarded to SignedDistanceFiltration.
n_iterations = 4

signed_distance_filtration = SignedDistanceFiltration(n_iterations=n_iterations, normalize=False, n_jobs=4)
# Fit on the same images that are transformed, as scikit-learn-style
# transformers expect fit and transform to receive the same kind of input.
signed_distance_filtration.fit(X_train_inverted)
X_train_filtered = signed_distance_filtration.transform(X_train_inverted)

In [None]:
# Signed-distance filtration values of the inverted digits.
plot_images(X=X_train_filtered)

## Computing persistence diagrams from images

In [None]:
# Compute cubical persistence diagrams of the filtered images from the previous
# step, matching the filtration -> persistence order used by the pipeline below.
cubical_complex = CubicalPersistence(n_jobs=1)
cubical_complex.fit(X_train_filtered)
X_train_cubical = cubical_complex.transform(X_train_filtered)

In [None]:
# Persistence diagrams, one panel per sample.
plot_diagrams(X=X_train_cubical)

## Computing the distance matrix between the diagrams

In [None]:
# Pairwise distances between the persistence diagrams in X_train_cubical,
# computed by giotto's DiagramDistance with the Wasserstein metric.
# metric_params is interpreted by giotto -- 'order' is presumably the
# Wasserstein exponent and 'delta' an approximation parameter; confirm in its docs.
diagram_distance = DiagramDistance(metric='wasserstein', metric_params={'order': 2, 'delta': 0.1}, n_jobs=1)
diagram_distance.fit(X_train_cubical)
X_train_distance = diagram_distance.transform(X_train_cubical)

In [None]:
# Distance matrix between the diagrams.
plot_matrix(X=X_train_distance)

## Putting everything in a pipeline

In [None]:
# Chain the step-by-step computations above into one scikit-learn Pipeline:
# image inversion -> signed-distance filtration -> cubical persistence ->
# pairwise Wasserstein distances between the diagrams.
steps = [
    ('inverter', ImageInverter(n_jobs=4)),
    ('filtration', SignedDistanceFiltration(n_iterations=4, normalize=False)),
    ('persistence', CubicalPersistence()),
    ('distance', DiagramDistance(metric='wasserstein', metric_params={'order': 2, 'delta': 0.1}, n_jobs=1)),
]

pipeline_signed_distance = Pipeline(steps)

In [None]:
# Fit every step of the pipeline on X_train, then pass X_train through all of
# its transforms; the result is the output of the final 'distance' step.
pipeline_signed_distance.fit(X_train)
X_train_pipeline_distance = pipeline_signed_distance.transform(X_train)

In [None]:
# Distance matrix produced by the end-to-end pipeline.
plot_matrix(X=X_train_pipeline_distance)

## Applying several pipelines based on different filtrations

In [None]:
# One pipeline per height-filtration direction:
# height filtration -> cubical persistence -> pairwise Wasserstein distances.
# scikit-learn's FeatureUnion runs every pipeline on the same input and
# concatenates their outputs horizontally.
direction_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
filtration_list = [HeightFiltration(direction=direction, normalize=False)
                   for direction in direction_list]
steps_list = [
    [
        ('filtration', filtration),
        ('persistence', CubicalPersistence()),
        ('distance', DiagramDistance(metric='wasserstein', metric_params={'order': 2, 'delta': 0.1})),
    ]
    for filtration in filtration_list
]

# Each pipeline is named after its direction, e.g. '[0, 1]'.
pipeline_list = [(str(direction), Pipeline(direction_steps))
                 for direction, direction_steps in zip(direction_list, steps_list)]
feature_union_filtrations = FeatureUnion(pipeline_list, n_jobs=1)

In [None]:
# Fit all direction pipelines on X_train and transform it; FeatureUnion
# concatenates the per-direction outputs horizontally into one array.
feature_union_filtrations.fit(X_train)
X_train_filtrations = feature_union_filtrations.transform(X_train)

In [None]:
# Distance matrices from all height-filtration directions (horizontally stacked
# by the FeatureUnion -- presumably one block per direction).
plot_matrices(X=X_train_filtrations)