# Analysis of the topology of signals

#### Author: Matteo Caorsi

In this notebook we give empirical evidence that, in some cases, the topology extracted from time series via Takens embeddings may contain enough information to discriminate signal from noise.

## The dataset

The dataset consists of very weak signals in a very noisy background.

## The procedure

The procedure does not aim at building a classifier, but rather at making sure that there is enough information in the topology to distinguish signals from noise.
Ideally, if we are able to overfit the data, we will know that there is enough information in the topology to perform the classification.
The tool we use to try to overfit our data is `Persformer`, a transformer architecture for persistence diagrams (see [here](https://arxiv.org/abs/2112.15210) for more details).

### The task

We will train `Persformer` and try to overfit the data. The task is binary classification: noise vs. signal. The time series is preprocessed with a Takens embedding and `giotto-tda` is used to compute the persistence diagrams. Each diagram is then labelled `1` if it comes from a portion of the time series containing the signal, and `0` if it comes from pure noise.

### Use of saliency

Afterwards, **saliency maps** let us understand which features of the persistence diagram are relevant for the classification, and consequently build a simple topological classifier that selects the discovered topological features.

## Load dependencies

Here we import the `giotto-deep` dependencies and a few standard packages.

In [None]:

from typing import Tuple, List

from gtda.homology import VietorisRipsPersistence
import numpy as np
import plotly.express as px
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset

from gdeep.data import PreprocessingPipeline
from gdeep.utility import PoolerType
from gdeep.data.datasets import PersistenceDiagramFromFiles
from gdeep.data.datasets.base_dataloaders import (DataLoaderBuilder,
                                                  DataLoaderParamsTuples)
from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (
    OneHotEncodedPersistenceDiagram, collate_fn_persistence_diagrams)
from gdeep.search.hpo import GiottoSummaryWriter
from gdeep.topology_layers import Persformer, PersformerConfig, PersformerWrapper
from gdeep.trainer.trainer import Trainer
from gdeep.utility.utils import autoreload_if_notebook
from gdeep.analysis.interpretability import Interpreter
from gdeep.visualization import Visualiser


autoreload_if_notebook()


## Generate data

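The data is generated in the next cell. A sinusoid (the signal) is interspersed in time with segments of uniform white noise. The interspersed series is then multiplied by `snr = 3`, and white noise of amplitude 1 is added to the whole series, so `snr` acts as an amplitude ratio rather than a power ratio.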

In [None]:
# generate the sinusoidal signal
pure_signal = np.sin(np.linspace(0, 100, num=500))

# intersperse the sinusoids with segments of uniform white noise on (-1, 1]
signal_no_noise = np.hstack((pure_signal, 1-2*np.random.rand(600,), pure_signal, 1-2*np.random.rand(400,)))

# build the ground truth: 1 = signal, 0 = noise
label = np.hstack((np.ones((500,)), np.zeros((600,)), np.ones((500,)), np.zeros((400,))))

# add background noise all over the signal
noise = 1-2*np.random.rand(signal_no_noise.shape[0],)
snr = 3  # amplitude factor applied to the interspersed series
signal = noise + snr*signal_no_noise

# plot
px.scatter(signal, title="Our Signal")
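Since `snr` scales amplitudes, the corresponding power ratio is larger. A unit sinusoid has mean power of about $1/2$ and uniform noise on $(-1, 1]$ has mean power $1/3$, so an amplitude of 3 gives roughly $9 \cdot \tfrac{1}{2} / \tfrac{1}{3} = 13.5$. The following standalone sketch checks this numerically. It is not part of the original pipeline, and the fixed seed is an assumption added for reproducibility.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed (assumption, not in the original cell)

# clean sinusoid, as in the data-generation cell
pure_signal = np.sin(np.linspace(0, 100, num=500))
# uniform noise on (-1, 1], as produced by 1 - 2 * rand
noise = 1 - 2 * rng.random(500)

snr = 3
signal_power = np.mean((snr * pure_signal) ** 2)  # about 9 * 1/2
noise_power = np.mean(noise ** 2)                 # about 1/3
power_ratio = signal_power / noise_power

print(f"signal power {signal_power:.2f}, noise power {noise_power:.3f}, ratio {power_ratio:.1f}")
```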

## Preprocessing

There are two steps in the preprocessing:

### Takens embedding

We use the lagged (Takens) embedding to build a point cloud out of the time series. The topology of this point cloud is what we are interested in. More formally, we extract a sequence of vectors in $\mathbb{R}^{d}$ of the form
$$
TD_{d,\tau} s : \mathbb{R} \to \mathbb{R}^{d}\,, \qquad t \mapsto
\begin{bmatrix}
s(t) \\
s(t + \tau) \\
s(t + 2\tau) \\
\vdots \\
s(t + (d-1)\tau)
\end{bmatrix},
$$
where $d$ is the embedding dimension and $\tau$ is the time delay. The quantity $(d-1)\tau$ is known as the "window size", and the difference $t_{i+1} - t_i$ between consecutive times at which the embedding is evaluated is called the stride. The code below uses $d = 6$, $\tau = 1$ and a stride of 1.
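The formula above can be made concrete with a general delay embedding that has a time delay $\tau$ and a stride. The following NumPy sketch builds an index matrix whose row $i$ contains the times $t_i, t_i + \tau, \dots, t_i + (d-1)\tau$. The helper `delay_embedding` is hypothetical and comes from neither `giotto-deep` nor `giotto-tda`.

```python
import numpy as np

def delay_embedding(s: np.ndarray, d: int, tau: int = 1, stride: int = 1) -> np.ndarray:
    """Return the vectors [s(t), s(t + tau), ..., s(t + (d-1) tau)] for t = 0, stride, 2*stride, ...

    Hypothetical helper for illustration only.
    """
    window = (d - 1) * tau  # the "window size"
    n_vectors = (len(s) - window - 1) // stride + 1  # number of valid starting times
    if n_vectors <= 0:
        raise ValueError("signal too short for this (d, tau)")
    starts = np.arange(n_vectors) * stride
    # row i holds the times starts[i] + k * tau for k = 0, ..., d-1
    idx = starts[:, None] + tau * np.arange(d)[None, :]
    return s[idx]

s = np.arange(10.0)
emb = delay_embedding(s, d=3, tau=2, stride=3)
print(emb)  # rows [0, 2, 4] and [3, 5, 7]
```

With `tau=1` and `stride=1`, this sketch returns `len(s) - d + 1` vectors. That is one more than the notebook's `takens_embedding`, which leaves out the last window.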


### Vietoris-Rips persistence

The overall point cloud is split into a sequence of smaller point clouds of 25 consecutive points each. Each one is transformed into a persistence diagram using the Vietoris-Rips filtration. The resulting persistence diagrams then form our dataset!

In [None]:
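def takens_embedding(signal, dimension):
    """Run the Takens embedding with time delay 1 on an
    already subsampled signal, meaning that there is no stride
    at this level.

    Args:
        signal (np.array):
            one-dimensional input signal
        dimension (int):
            Takens embedding dimension

    Returns:
        np.array of shape (len(signal) - dimension, dimension)
        whose row t is signal[t], signal[t + 1], ..., signal[t + dimension - 1]
    """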
    length = signal.shape[0]
    lista = [signal[i:length - dimension+i] for i in range(dimension)]
    return np.vstack(lista).T

te_signal = takens_embedding(signal, 6)
pts_per_cloud = 25
batches = te_signal.shape[0] // pts_per_cloud

# split the point cloud (stride = pts_per_cloud)
point_clouds = np.split(te_signal[:batches*pts_per_cloud],
                        batches,
                        axis=0)
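For the 2000-sample signal generated above, `takens_embedding(signal, 6)` returns $2000 - 6 = 1994$ vectors. The split therefore yields $1994 // 25 = 79$ point clouds of 25 points each and discards the last $1994 - 79 \cdot 25 = 19$ vectors. The following sketch checks this arithmetic on a placeholder ramp signal of the same length. The ramp stands in for the real signal, since only the shapes matter here.

```python
import numpy as np

n_samples = 2000          # 500 + 600 + 500 + 400 samples, as generated above
dimension = 6
pts_per_cloud = 25

# same construction as takens_embedding: each row holds `dimension` consecutive samples
placeholder = np.arange(n_samples, dtype=float)
te = np.vstack([placeholder[i:n_samples - dimension + i] for i in range(dimension)]).T

batches = te.shape[0] // pts_per_cloud
clouds = np.split(te[:batches * pts_per_cloud], batches, axis=0)

print(te.shape, batches, clouds[0].shape, te.shape[0] - batches * pts_per_cloud)
```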

In [None]:

# initialise the class to compute the topology of the point clouds
vr = VietorisRipsPersistence(metric="euclidean",
                             homology_dimensions=(0, 1),
                             collapse_edges=False,
                             coeff=2,
                             max_edge_length=np.inf,
                             infinity_values=None,
                             reduced_homology=True,
                             n_jobs=-1)

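# compute the persistence diagrams of the point clouds
dgms = vr.fit_transform(point_clouds)
# split the ground-truth labels in the same way as the point clouds
label = np.split(label[:batches * pts_per_cloud],
                 batches,
                 axis=0)

# sanity check: one diagram per point cloud
assert dgms.shape[0] == batches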
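It helps to know what the $H_0$ part of these diagrams contains. In a Vietoris-Rips filtration, every $H_0$ bar is born at 0, and the death times are exactly the edge lengths of a minimum spanning tree of the point cloud. With `reduced_homology=True`, the single bar that never dies is dropped. The following NumPy sketch computes these bars with Kruskal's algorithm and union-find. It is for illustration only: it does not reflect how `giotto-tda` computes diagrams, and it does not cover $H_1$.

```python
import numpy as np

def h0_persistence(points: np.ndarray) -> np.ndarray:
    """(birth, death) pairs of the finite H0 bars of the Vietoris-Rips filtration.

    Births are 0; deaths are the minimum-spanning-tree edge lengths (Kruskal).
    The infinite bar is omitted, as with reduced homology.
    """
    n = len(points)
    # pairwise Euclidean distances
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    i_idx, j_idx = np.triu_indices(n, k=1)
    order = np.argsort(dist[i_idx, j_idx], kind="stable")  # edges by increasing length

    parent = list(range(n))

    def find(a: int) -> int:
        # root of a's component, with path halving
        while parent[a] != a:
            parent[a] = parent[parent[a]]
            a = parent[a]
        return a

    deaths = []
    for e in order:
        a, b = find(int(i_idx[e])), find(int(j_idx[e]))
        if a != b:  # the edge merges two components: one H0 bar dies here
            parent[a] = b
            deaths.append(dist[i_idx[e], j_idx[e]])
            if len(deaths) == n - 1:
                break
    deaths = np.array(deaths, dtype=float)
    return np.column_stack([np.zeros_like(deaths), deaths])

print(h0_persistence(np.array([[0.0], [1.0], [3.0]])))  # bars (0, 1) and (0, 2)
```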

In [None]:
# label of each diagram: majority vote over the labels of its samples
actual_labels = np.round(np.mean(label, axis=1)).astype(int)
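Each diagram receives the majority label of its 25 samples. The segment lengths (500, 600, 500, 400) are all multiples of 25, so every window of labels lies entirely inside one segment and the rounding never has to resolve a mixed window. With an odd window size, an exact mean of 0.5 cannot occur anyway. The following standalone sketch repeats the computation on the ground-truth labels. The value `batches = 79` is hard-coded from the arithmetic of the embedding step.

```python
import numpy as np

pts_per_cloud = 25
batches = 79  # 1994 // 25 point clouds, as computed in the embedding step

label = np.hstack((np.ones(500), np.zeros(600), np.ones(500), np.zeros(400)))
windows = np.split(label[:batches * pts_per_cloud], batches)
window_labels = np.round(np.mean(windows, axis=1)).astype(int)

# every window is "pure": its mean is exactly 0 or 1
print(np.unique(np.mean(windows, axis=1)), window_labels.sum(), len(window_labels))
```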


In [None]:
from gdeep.data.persistence_diagrams import get_one_hot_encoded_persistence_diagram_from_gtda

# Build list of persistence diagrams so that you can feed them to the Persformer
list_of_dgms = []
for dgm in dgms:
    list_of_dgms.append(get_one_hot_encoded_persistence_diagram_from_gtda(dgm))
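`giotto-tda` returns each diagram as rows of the form `(birth, death, homology_dimension)`. A one-hot encoded diagram replaces the last column with one indicator column per homology dimension. With `homology_dimensions=(0, 1)`, each point therefore becomes a 4-vector, which matches `input_size=2 + 2` in the model. The column order `(birth, death, is_H0, is_H1)` is an assumption. The helper `one_hot_diagram` below is a hypothetical NumPy sketch of the transformation, not the code of `get_one_hot_encoded_persistence_diagram_from_gtda`. In particular, dropping the diagonal padding points that `giotto-tda` adds is a choice made only in this sketch.

```python
import numpy as np

def one_hot_diagram(dgm: np.ndarray, homology_dimensions=(0, 1)) -> np.ndarray:
    """Hypothetical sketch: (birth, death, dim) rows -> (birth, death, one-hot(dim)) rows."""
    # drop points on the diagonal (used by giotto-tda as padding); assumption of this sketch
    dgm = dgm[dgm[:, 1] > dgm[:, 0]]
    dims = dgm[:, 2].astype(int)
    one_hot = (dims[:, None] == np.asarray(homology_dimensions)[None, :]).astype(float)
    return np.hstack([dgm[:, :2], one_hot])

dgm = np.array([
    [0.0, 0.5, 0],   # H0 bar
    [0.0, 0.0, 0],   # padding point on the diagonal
    [0.3, 0.9, 1],   # H1 bar
])
print(one_hot_diagram(dgm))
```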
    


### Build new Dataset

With the giotto-deep and PyTorch APIs, it is easy to build a new, fully giotto-deep compatible dataset out of the list of persistence diagrams!

In [None]:

class PersistenceDiagramFromList(Dataset[Tuple[OneHotEncodedPersistenceDiagram, int]]):
    """
    This dataset type is accepted by Persformer. It gets the data from
    a list of diagrams.
    
    Args:
        x: 
            The input list
        y:
            The label list
    """
    def __init__(self,
                 x: List[OneHotEncodedPersistenceDiagram],
                 y: List[int]
                 ):
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """
        Return the length of the dataset.
        
        Returns:
            The length of the dataset.
        """
        return len(self.x)


    def __getitem__(self, index: int) -> Tuple[OneHotEncodedPersistenceDiagram, int]:
        """
        Return the item at the specified index.
        
        Args:
            index: 
                The index of the item.
            
        Returns:
            The item at the specified index.
        """
        diagram = self.x[index]
        label = self.y[index]

        return diagram, label

# the dataset
dataset = PersistenceDiagramFromList(list_of_dgms, actual_labels)

### Build DataLoader

This is the final step in preparing the data: we build DataLoaders out of the datasets, since DataLoaders are the objects the trainer consumes.

In [None]:
from gdeep.data.datasets import DataLoaderBuilder
from gdeep.data.persistence_diagrams import collate_fn_persistence_diagrams

# do not forget the collate function!!
db = DataLoaderBuilder([dataset])
dl_train, _, _ = db.build([{"batch_size": 12, "collate_fn": collate_fn_persistence_diagrams}])
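The collate function is needed because persistence diagrams generally contain different numbers of points, so they cannot be stacked into a batch directly. A collate function for diagrams typically pads every diagram to the size of the largest one in the batch and returns a mask marking the real points, which the attention layers can use to ignore the padding. The following NumPy sketch illustrates that idea. The function `pad_diagrams` is hypothetical, not the code of `collate_fn_persistence_diagrams`.

```python
import numpy as np

def pad_diagrams(diagrams, labels):
    """Pad a list of (n_i, F) diagrams to (B, max n_i, F) and build a (B, max n_i) mask."""
    max_points = max(d.shape[0] for d in diagrams)
    n_features = diagrams[0].shape[1]
    batch = np.zeros((len(diagrams), max_points, n_features))
    mask = np.zeros((len(diagrams), max_points), dtype=bool)
    for i, d in enumerate(diagrams):
        batch[i, :d.shape[0]] = d
        mask[i, :d.shape[0]] = True  # True marks a real point, False marks padding
    return (batch, mask), np.asarray(labels)

(x, m), y = pad_diagrams([np.ones((3, 4)), np.ones((1, 4))], [1, 0])
print(x.shape, m.sum(axis=1), y)
```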

## Build model

We build the *Persformer*, making it reasonably large even though the dataset is small, since the goal is to overfit rather than to build a classifier that generalises.

In [None]:
# Define the model by using a Wrapper for the Persformer model

wrapped_model = PersformerWrapper(
    num_attention_layers=2,
    num_attention_heads=8,
    input_size=2 + 2,  # (birth, death) + one-hot encoding of the homology dimension (H0, H1)
    output_size=2,
    pooler_type=PoolerType.ATTENTION,
    hidden_size=16,
    intermediate_size=16,
)
writer = GiottoSummaryWriter()

loss_function = nn.CrossEntropyLoss()

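# the same DataLoader is passed for every split: we only want to check that the model can fit the training data
trainer = Trainer(wrapped_model, [dl_train, dl_train, dl_train], loss_function, writer)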
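The argument `pooler_type=PoolerType.ATTENTION` presumably selects an attention-based pooling step. Such a step reduces the variable number of point embeddings to a single vector before the classification head. The exact layer in giotto-deep may differ. As a generic illustration, masked attention pooling scores each point, applies a softmax over the real (unmasked) points, and takes the weighted average. In the following NumPy sketch, the scoring vector `w` is a hypothetical stand-in for learned parameters.

```python
import numpy as np

def attention_pool(h: np.ndarray, mask: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Generic masked attention pooling (illustrative sketch, not giotto-deep's layer).

    h: (B, N, H) point embeddings; mask: (B, N), True for real points
    (at least one per diagram); w: (H,) scoring vector.
    Returns (B, H): a softmax-weighted average of the real points of each diagram.
    """
    scores = h @ w                                        # (B, N): one score per point
    scores = np.where(mask, scores, -np.inf)              # padding points get zero weight
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return np.einsum("bn,bnh->bh", weights, h)

h = np.array([[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]])  # the third point is padding
mask = np.array([[True, True, False]])
w = np.zeros(2)  # zero scores -> uniform weights over the real points
print(attention_pool(h, mask, w))  # [[0.5 0.5]]
```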


## Let's train

We now train the model. Again, the goal is to overfit the data: if we succeed, we know that the topological representation contains enough information to perform the classification.

In [None]:
# train the model for 30 epochs with SGD and an exponentially decaying learning rate
trainer.train(SGD, 30,
              lr_scheduler=ExponentialLR,
              scheduler_params={"gamma": 0.9},)
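`ExponentialLR` multiplies the learning rate by `gamma` every time the scheduler steps. Assuming one step per epoch, the learning rate in epoch $k$ (counting from 0) is $\eta_k = \eta_0 \gamma^k$. After 30 steps with $\gamma = 0.9$, the rate has therefore shrunk by a factor $0.9^{30} \approx 0.042$. The base rate $\eta_0$ is whatever the trainer passes to `SGD`, so the value below is only a placeholder.

```python
base_lr = 0.01  # placeholder; not necessarily the value used by the trainer
gamma = 0.9
n_epochs = 30

# learning rate used in each epoch, assuming one scheduler step per epoch
lrs = [base_lr * gamma ** k for k in range(n_epochs)]
final_factor = gamma ** n_epochs  # decay factor reached after the last step

print(f"first lr {lrs[0]:.4f}, last lr {lrs[-1]:.6f}, factor after 30 steps {final_factor:.4f}")
```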

## Interpretation of the results

With `giotto-deep` we can build saliency maps and display the importance scores on top of the diagram points: the color of each point indicates how relevant that topological feature is for the classification of the diagram.

These saliencies will guide us on where the relevant topological information is stored and how to build a simple classifier.
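A gradient-based saliency map scores each input entry by the absolute value of the gradient of the target-class logit with respect to that entry. A diagram point therefore scores high when small changes to its birth or death move the prediction a lot. The following NumPy sketch implements that definition with central finite differences on a toy model. The model `toy_logits` is hypothetical and stands in for the Persformer; the `Interpreter` presumably uses automatic differentiation instead. Summing over the four features to get one score per point is a choice made in this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))  # toy weights: 4 input features -> 2 class logits

def toy_logits(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in model: tanh features, mean-pooled over points, then linear."""
    return np.tanh(x).mean(axis=0) @ W

def saliency(x: np.ndarray, target: int, eps: float = 1e-5) -> np.ndarray:
    """|d logit_target / d x| for every entry of an (N, 4) diagram, by central differences."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (toy_logits(x_plus)[target] - toy_logits(x_minus)[target]) / (2 * eps)
    return np.abs(grad)

diagram = np.array([[0.0, 0.5, 1.0, 0.0],
                    [0.3, 0.9, 0.0, 1.0]])
sal = saliency(diagram, target=1)
point_scores = sal.sum(axis=1)  # one relevance score per diagram point
print(point_scores)
```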

In [None]:
# get a datum and its corresponding class
batch = next(iter(dl_train))
datum = batch[0][0][0].reshape(1, *(batch[0][0][0].shape))
class_ = batch[1][0].item()


# we now use the Saliency maps to interpret the results
inter = Interpreter(trainer.model, method="Saliency")

# interpret the diagram
x, attr = inter.interpret(x=datum, y=class_)

# visualise the results
vs = Visualiser(trainer)
out = vs.plot_attributions_persistence_diagrams(inter)
out


In [None]:
# what does the saliency map look like for the wrong class?
class_ = (batch[1][0].item()+1) % 2
print(class_)

# interpret the diagram
x2, attr2 = inter.interpret(x=datum, y=class_)

# visualise the results
vs = Visualiser(trainer)
out2 = vs.plot_attributions_persistence_diagrams(inter)

out2


## Performance

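In this simple classification task, we can check the confusion matrix to make sure that there are no major issues; this takes one line. Since `dl_train` was also passed to the trainer in place of the validation and test loaders, these figures measure how well the model fits the training data, which is exactly what the overfitting test requires.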

In [None]:
# safety checks on the training performance
trainer.evaluate_classification(2)
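For reference, a binary confusion matrix can be computed from true and predicted labels in a few lines of NumPy. The following sketch uses rows for the true class and columns for the predicted class, a convention chosen here that may differ from the trainer's output. The labels are made-up toy values, not results from this notebook.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes: int = 2) -> np.ndarray:
    """Rows: true class, columns: predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)  # count each (true, pred) pair
    return cm

y_true = [0, 0, 1, 1, 1]  # toy values for illustration
y_pred = [0, 1, 1, 1, 0]
cm = confusion_matrix(y_true, y_pred)
accuracy = np.trace(cm) / cm.sum()
print(cm, accuracy)
```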