# Basic tutorial: image data
#### Author: Matteo Caorsi

This short tutorial introduces the basic functionality of the *giotto-deep* API.

The main steps of the tutorial are the following:
 1. create a dataset
 2. create a model
 3. define metrics and losses
 4. run benchmarks
 5. visualise the results interactively

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np

import torch
from torch import nn

from gdeep.models import FFNet

from gdeep.visualisation import  persistence_diagrams_of_activations

from torch.utils.tensorboard import SummaryWriter
from gdeep.data.datasets import BuildDatasets, BuildDataLoaders


from gtda.diagrams import BettiCurve

from gtda.plotting import plot_betti_surfaces

# Initialize the tensorboard writer

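To analyse the results of your models, you need to start TensorBoard.
In a terminal, move into the `/example` folder and run the command below. `SummaryWriter()` writes its logs to a `runs/` subfolder of the notebook's working directory, which is why `--logdir=runs` must be resolved from there: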

```
tensorboard --logdir=runs
```

After training, open [http://localhost:6006/](http://localhost:6006/) to see all the visualisations.

In [None]:
# TensorBoard writer handed to the training Pipeline further down.
# With log_dir left at its default (None), logs go to ./runs/<datetime>_<hostname>.
writer = SummaryWriter(log_dir=None)

# Create your dataset

In [None]:
from torchvision.transforms import ToTensor
from torch.utils.data.sampler import SubsetRandomSampler

# Build the CIFAR10 datasets with giotto-deep; build_datasets() yields three
# splits, unpacked here as training, validation and test sets.
# NOTE(review): ToTensor is imported but not used anywhere in this notebook.
db = BuildDatasets(name="CIFAR10")
ds_tr, ds_val, ds_ts = db.build_datasets()


In [None]:
# Preprocessing with gdeep's ToTensorImage, configured for 32x32 images.
from gdeep.data.preprocessors import ToTensorImage

transformation = ToTensorImage((32, 32))
# Fitting has no effect for this particular transformation; the call is kept
# to illustrate the general fit-then-attach workflow of gdeep preprocessors.
transformation.fit_to_dataset(ds_tr)

# Attach the transform to each split (train, validation, test), in that order.
transformed_ds_tr, transformed_ds_val, transformed_ds_ts = (
    transformation.attach_transform_to_dataset(split)
    for split in (ds_tr, ds_val, ds_ts)
)

# Restrict sampling to the CIFAR10 images with indices 0..319
# (= 10 batches of 32 images).
# NOTE(review): the same `sampler` keyword is passed for all three
# dataloaders — confirm whether validation/test are meant to be subsampled too.
n_batches, batch_size = 10, 32
train_indices = list(range(batch_size * n_batches))
dl_tr, dl_val, dl_ts = BuildDataLoaders(
    (transformed_ds_tr, transformed_ds_val, transformed_ds_ts)
).build_dataloaders(
    batch_size=batch_size, sampler=SubsetRandomSampler(train_indices)
)

## Define and train your model

In [None]:
import torchvision.models as models
from gdeep.pipeline import Pipeline

# wrap a sequential model in a torch nn.Module
class model3(nn.Module):
    """ResNet-18 backbone followed by a linear classification head.

    torchvision's ResNet-18 outputs 1000 logits; the extra ``nn.Linear``
    layer maps them to ``num_classes`` outputs (10 for CIFAR10).
    The backbone is exposed as ``self.seqmodel[0]`` so that inner layers
    (e.g. ``seqmodel[0].layer2[0].conv1``) can be targeted by
    interpretability tools.

    Args:
        num_classes: number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13
        # in favour of `weights=models.ResNet18_Weights.DEFAULT`; kept here
        # for compatibility with older torchvision releases.
        self.seqmodel = nn.Sequential(
            models.resnet18(pretrained=True),
            nn.Linear(1000, num_classes),
        )

    def forward(self, X):
        """Return the class logits produced by the wrapped sequential model."""
        return self.seqmodel(X)


model = model3()

In [None]:
from torch.optim import SGD

loss_fn = nn.CrossEntropyLoss()

# The Pipeline bundles the model, a pair of dataloaders (presumably
# train / evaluation), the loss and the TensorBoard writer.
# NOTE(review): the test loader dl_ts is passed as the second dataloader
# while dl_val is never used — confirm this is intended.
pipe = Pipeline(model, (dl_tr, dl_ts), loss_fn, writer)

# Train with SGD. The positional arguments are: optimizer class, 3 (presumably
# the number of epochs), False (presumably the cross-validation flag),
# optimizer kwargs and dataloader kwargs — confirm against Pipeline.train.
optimizer_params = {"lr": 0.01}
dataloader_params = {"batch_size": 32, "sampler": SubsetRandomSampler(train_indices)}
pipe.train(SGD, 3, False, optimizer_params, dataloader_params)



# Use interpretability tools

In [None]:
from gdeep.analysis.interpretability import Interpreter
from gdeep.visualisation import Visualiser

# Guided Grad-CAM on a single training image, at the first conv layer of the
# ResNet-18 backbone's layer2. The literal 1 is presumably the target class index.
inter = Interpreter(pipe.model, method="GuidedGradCam")
single_image = next(iter(dl_tr))[0][0].reshape(1, 3, 32, 32)  # add a batch dimension
target_layer = pipe.model.seqmodel[0].layer2[0].conv1
output = inter.interpret_image(single_image, 1, target_layer)

# Plot the interpretation. An AssertionError is taken to mean an all-zero
# heatmap (per the message below); it is reported instead of failing the cell.
vs = Visualiser(pipe)
try:
    vs.plot_interpreter_image(inter)
except AssertionError:
    print("The heatmap is made of all zeros...")

# Extract inner data from your models

In [None]:
from gdeep.models import ModelExtractor

# The extractor gives access to model internals (parameters, activations,
# gradients, decision boundary) used in the following cells.
me = ModelExtractor(pipe.model, loss_fn)

# get_layers_param() returns a mapping; print each entry's name and shape.
layer_params = me.get_layers_param()
for name, param in layer_params.items():
    print(name, param.shape)


In [None]:
# Compute the decision boundary starting from one training sample.
# The guard skips integer (torch.int64) inputs, presumably because the
# boundary search needs a floating-point tensor — confirm.
x = next(iter(dl_tr))[0][0]
if x.dtype != torch.int64:
    res = me.get_decision_boundary(x, n_epochs=1)
    # A bare `res.shape` inside an `if` is not the cell's last top-level
    # expression, so Jupyter never displays it; print it explicitly.
    print(res.shape)

In [None]:
# Take the inputs of one training batch and collect the model's activations
# on them; the cell displays how many activation entries were returned.
first_batch = next(iter(dl_tr))
x = first_batch[0]
list_activations = me.get_activations(x)
len(list_activations)


In [None]:
# For a float input batch, print the shape of each gradient held in the
# second element of get_gradients' result.
x, target = next(iter(dl_tr))
if x.dtype == torch.float:
    per_layer_gradients = me.get_gradients(x, target=target)[1]
    for grad in per_layer_gradients:
        print(grad.shape)

# Visualise activations and other topological aspects of your model

In [None]:

# Plot persistence diagrams for the batch `x` (presumably computed from the
# model's activations). Other views are available via vs.plot_data_model()
# and vs.plot_activations(x).
vs.plot_persistence_diagrams(x)


# Evaluate model

The cell below computes the confusion matrix on the entire training dataloader.

In [None]:
# Confusion matrix of the trained model. The argument is presumably the
# number of classes (CIFAR10 has 10) — confirm against evaluate_classification.
n_classes = 10
pipe.evaluate_classification(n_classes)