# Basic tutorial: tabular data
#### Author: Matteo Caorsi

This short tutorial introduces the basic functionality of the *giotto-deep* API.

The main steps of the tutorial are the following:
 1. creation of a dataset
 2. creation of a model
 3. define metrics and losses
 4. run training
 5. visualise results interactively

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import plotly.express as px
import torch
from torch import nn
import pandas as pd

from sklearn import datasets

from gdeep.models import FFNet
from gdeep.models import ModelExtractor
from gdeep.analysis.interpretability import Interpreter
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim import SGD, Adam
from gdeep.utility.optimisation import SAMOptimizer
from torch.utils.data import SubsetRandomSampler

from gdeep.visualisation import  persistence_diagrams_of_activations
from gdeep.pipeline import Pipeline

from torch.utils.tensorboard import SummaryWriter
from gdeep.data import TorchDataLoader

from gtda.diagrams import BettiCurve

from gtda.plotting import plot_betti_surfaces

from sklearn.model_selection import StratifiedKFold

# Initialize the tensorboard writer

In order to analyse the results of your models, you need to start tensorboard.
In a terminal, move into the `/example` folder and run the following command:

```
tensorboard --logdir=runs
```

Then go [here](http://localhost:6006/) after the training to see all the visualisation results.

In [None]:
# Create the tensorboard writer with its default settings; the markdown above
# starts tensorboard with `--logdir=runs`, so the logs are expected there.
writer = SummaryWriter()

# Create your dataset

In [None]:
# Select the dataset named "DoubleTori" through the giotto-deep dataloader wrapper.
dl = TorchDataLoader(name="DoubleTori")
# Optional alternative: restrict the data to a subset of samples with a sampler, e.g.
#   train_indices = list(range(160))
#   dl.build_dataloaders(batch_size=23, sampler=SubsetRandomSampler(train_indices))
# Build the training and validation dataloaders with a batch size of 23.
dl_tr, dl_val = dl.build_dataloaders(batch_size=23)
# No separate test dataloader is used in this tutorial.
dl_ts = None

In [None]:
# Sanity check: the first element of a training batch should be a torch tensor.
first_batch = next(iter(dl_tr))
isinstance(first_batch[0], torch.Tensor)

## Define and train your model

In [None]:

class model1(nn.Module):
    """Small feed-forward classifier used throughout this tutorial.

    The input is flattened and then passed through an ``FFNet`` built with
    ``arch=[3, 5, 10, 5, 2]`` (presumably the layer widths, i.e. 3 input
    features and 2 outputs -- confirm against the ``FFNet`` documentation).
    """

    def __init__(self):
        super().__init__()
        layers = [nn.Flatten(), FFNet(arch=[3, 5, 10, 5, 2])]
        self.seqmodel = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the sequential model applied to ``x``."""
        output = self.seqmodel(x)
        return output


model = model1()

In [None]:
# initialise the loss function
loss_fn = nn.CrossEntropyLoss()

# initialise the pipeline class. The dataloaders are passed as a tuple of
# (training, validation, test); the test dataloader `dl_ts` is None in this
# tutorial. Note: do not pass `_` here -- in a notebook it is the output of the
# previously executed cell, not a dataloader.
pipe = Pipeline(
    model,
    (dl_tr, dl_val, dl_ts),
    loss_fn,
    writer,
    StratifiedKFold(3, shuffle=True),
)

# initialise the SAM optimiser
optim = SAMOptimizer(SGD)  # this is a class, not an instance!

# train the model with learning rate scheduler and cross-validation
# (the positional arguments `5, True` are presumably the number of epochs and
# the cross-validation flag -- confirm against the `Pipeline.train` signature)
pipe.train(
    optim,
    5,
    True,
    optimizers_param={"lr": 0.01},
    lr_scheduler=ExponentialLR,
    scheduler_params={"gamma": 0.9},
    profiling=False,
    store_grad_layer_hist=True,
    writer_tag="tori",
)


## Add pipeline hooks

It is possible to add a hook (a callable) that is called at the end of each training epoch.

The arguments of the hook are fixed and must be given in this order: `epoch`, `optim`, `me`, `writer`.

In [None]:
def example_of_hook(epoch, optim, me, writer):
    """Example pipeline hook: print the learning rate and the layer parameters.

    Args:
        epoch: value printed as the current epoch.
        optim: optimiser; the learning rate is read from ``param_groups[0]['lr']``.
        me: model extractor; ``get_layers_param()`` is called and printed.
        writer: unused here, but part of the fixed hook signature.
    """
    lr_message = f"Here we print the learning rate {optim.param_groups[0]['lr']} at epoch={epoch}"
    print(lr_message)
    params_message = (
        f"We can also get the value of gradients and parameters of the model "
        f"using the model extractor! {me.get_layers_param():}"
    )
    print(params_message)
    

# Register the hook on the pipeline; as described above, it is called at the
# end of each training epoch.
pipe.register_pipe_hook(example_of_hook)

Let's train the model with cross validation: we just have to set the parameter `cross_validation = True`.

The `keep_training = True` flag allows us to restart from the same scheduler, optimiser and trained model obtained at the end of the last training in the instance of the class `pipe`.

In [None]:
# Continue training with cross-validation, restarting from the state left by
# the previous call to `pipe.train`.
continue_training_options = {
    "cross_validation": True,
    "keep_training": True,
    "profiling": True,
    "writer_tag": "tori/kt",
}
pipe.train(SGD, 3, **continue_training_options)

# Because `keep_training` was set, the optimiser is the one from the previous
# training rather than a fresh one.
print(pipe.optimizer)

# Simply use interpretability tools

In [None]:

# Wrap the trained model in the interpretability helper.
inter = Interpreter(pipe.model)

# Draw ONE batch and take both the inputs and the labels from it: calling
# `next(iter(dl_tr))` twice creates two iterators, and with a shuffling
# dataloader the inputs and labels could come from different batches.
batch = next(iter(dl_tr))
inter.interpret_tabular(batch[0], batch[1]);


# Extract inner data from your models

In [None]:


# Build the extractor from the trained model and the loss function.
me = ModelExtractor(pipe.model, loss_fn)

# Print the name and the shape of every entry returned by the extractor.
lista = me.get_layers_param()
for param_name, param_value in lista.items():
    print(param_name, param_value.shape)


In [None]:
# Take the first sample of the first training batch.
x = next(iter(dl_tr))[0][0]
# The decision boundary computation is skipped for int64 inputs.
if x.dtype is not torch.int64:
    res = me.get_decision_boundary(x, n_epochs=100)
    # A bare `res.shape` expression inside an `if` body is not displayed by the
    # notebook, so print it explicitly.
    print(res.shape)

In [None]:
# Use the inputs of one training batch to collect the activations.
first_batch = next(iter(dl_tr))
x = first_batch[0]
list_activations = me.get_activations(x)
# Number of activation entries returned by the extractor.
len(list_activations)


In [None]:
# Fetch one batch of inputs and targets.
x, target = next(iter(dl_tr))
# Gradients are only extracted when the inputs are of dtype torch.float.
if x.dtype is torch.float:
    gradients = me.get_gradients(x, target=target)[1]
    for grad in gradients:
        print(grad.shape)

# Visualise activations and other topological aspects of your model

In [None]:
from gdeep.visualisation import Visualiser

# Build the visualiser from the trained pipeline.
vs = Visualiser(pipe)

# Plot the model, then the activations and persistence diagrams obtained for
# the batch `x` (results are presumably sent to tensorboard -- see the
# tensorboard section above).
vs.plot_data_model()
vs.plot_activations(x)
vs.plot_persistence_diagrams(x)


In [None]:

# Plot the decision boundary; the trailing `;` suppresses the cell output.
vs.plot_decision_boundary();

In [None]:
# Betti plots of the layers for the batch `x`; `(0, 1)` are presumably the
# homology dimensions to compute -- confirm against the Visualiser docs.
vs.betti_plot_layers((0, 1), x)

In [None]:
# Plot the results stored in the interpreter built earlier; the return value is
# kept in `plt` (NOTE: this name is conventionally used for matplotlib.pyplot).
plt = vs.plot_interpreter_tabular(inter)

In [None]:
# Evaluate the model performance on the classification task
# (the argument `2` is presumably the number of classes -- confirm).
pipe.evaluate_classification(2)