# Implementing your own model

In this tutorial we show how to implement your own model and test it on a dataset. 

This particular example loads the PROTEINS dataset, applies a simplicial clique lifting to turn each graph into a simplicial complex, and sets up a SANN model to work on the result.

We train the model using the appropriate training and validation datasets, and finally test it on the test dataset.

### <font color='289C4E'>Table of contents</font><a class='anchor' id='top'></a>
&emsp;[1. Imports](#sec1)

&emsp;[2. Configurations and utilities](#sec2)

&emsp;[3. Loading the data](#sec3)

&emsp;[4. Backbone definition](#sec4)

&emsp;[5. Model initialization](#sec5)

&emsp;[6. Training](#sec6)

&emsp;[7. Testing the model](#sec7)

## 1. Imports <a class="anchor" id="sec1"></a>

In [1]:
import lightning as pl
import torch
from omegaconf import OmegaConf

from topobenchmarkx.data.loaders import GraphLoader
from topobenchmarkx.data.preprocessor import PreProcessor
from topobenchmarkx.dataloader import TBXDataloader
from topobenchmarkx.evaluator import TBXEvaluator
from topobenchmarkx.loss import TBXLoss
from topobenchmarkx.model import TBXModel
from topobenchmarkx.nn.backbones.simplicial.sann import SANN
from topobenchmarkx.nn.wrappers.simplicial.sann_wrapper import SANNWrapper
from topobenchmarkx.nn.encoders import SANNFeatureEncoder
from topobenchmarkx.nn.readouts.sann import SANNReadout
from topobenchmarkx.optimizer import TBXOptimizer
from topobenchmarkx.data.datasets.sann_data import SANNData

%load_ext autoreload
%autoreload 2

## 2. Configurations and utilities <a class="anchor" id="sec2"></a>

Configurations can be specified in YAML files or, as in this example, directly in code.

In [82]:
in_channels = 3
hidden_channels = 2*32
hidden_channels_readout = 2*32
out_channels = 2

loader_config = {
    "data_domain": "graph",
    "data_type": "TUDataset",
    "data_name": "PROTEINS",
    "data_dir": "./data/MUTAG/PROTEINS",
}

transform_config = {
    "clique_lifting": {
        "transform_type": "lifting",
        "transform_name": "SimplicialCliqueLifting",
        "feature_lifting": "Duplicate",
        "all_ones": True,
        "complex_dim": 3,
    },
    "sann_encoding": {
        "transform_type": "data manipulation",
        "transform_name": "PrecomputeKHopFeatures",
        "max_hop": 2,
        "complex_dim": 3,
    },
}

wrapper_config = {
    "out_channels": out_channels-1,
    "num_cell_dimensions": hidden_channels,
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "k-fold",
    "data_seed": 0,
    "data_split_dir": "./data/PROTEINS/splits/",
    "k": 10,
}


readout_config = {
    "readout_name": "SANNReadout",
    "num_cell_dimensions": in_channels,
    "hidden_dim": hidden_channels,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {"task": "classification", "loss_type": "cross_entropy"}

evaluator_config = {"task": "classification",
                    "num_classes": out_channels, 
                    "metrics": ["accuracy", "precision", "recall"]}

optimizer_config = {"optimizer_id": "Adam",
                    "parameters":
                        {"lr": 1e-3,"weight_decay": 1e-4}
                    }

loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)

In [76]:
def wrapper(**factory_kwargs):
    """Return a factory that wraps a backbone in a SANNWrapper with fixed keyword arguments."""
    def factory(backbone):
        return SANNWrapper(backbone, **factory_kwargs)
    return factory
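
The `wrapper` helper does not build a wrapper itself. It returns a function that builds one later, once a backbone is supplied. The sketch below shows the same pattern in isolation. `DummyWrapper` is a hypothetical stand-in for `SANNWrapper` so that the snippet runs without the library, and the keyword values mirror `wrapper_config` above.

```python
from functools import partial


class DummyWrapper:
    """Hypothetical stand-in for SANNWrapper: stores the backbone and its options."""

    def __init__(self, backbone, **kwargs):
        self.backbone = backbone
        self.kwargs = kwargs


def wrapper(**factory_kwargs):
    # Capture the keyword arguments now, construct the wrapper later
    def factory(backbone):
        return DummyWrapper(backbone, **factory_kwargs)
    return factory


# Step 1: fix the options; no wrapper object exists yet
make_wrapper = wrapper(out_channels=1, num_cell_dimensions=64)

# Step 2: whoever owns the backbone calls the factory
wrapped = make_wrapper("my-backbone")

# functools.partial gives an equivalent factory without the nested function
make_wrapper_partial = partial(DummyWrapper, out_channels=1, num_cell_dimensions=64)
wrapped_partial = make_wrapper_partial("my-backbone")
```

The deferred form is convenient when the object that receives the factory, rather than the caller, decides when the backbone is wrapped.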

## 3. Loading the data <a class="anchor" id="sec3"></a>

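In this example we use the PROTEINS dataset, as set in `loader_config`. It is a graph dataset; the `SimplicialCliqueLifting` transform lifts each graph to a simplicial complex, and `PrecomputeKHopFeatures` then precomputes the k-hop features on it.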

We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchmarkX) to learn more about the various liftings offered.
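
The `transform_config` above selects `SimplicialCliqueLifting`. As a rough illustration of what a clique lifting does, the sketch below treats every clique of a graph as a simplex: nodes have dimension 0, edges dimension 1, triangles dimension 2, and so on. This is a brute-force toy written for this tutorial, not the TopoBenchmarkX implementation, and the function name `clique_lifting` is hypothetical.

```python
from itertools import combinations


def clique_lifting(num_nodes, edges, complex_dim):
    """Return {dim: [simplices]}, where each simplex is a clique of the graph."""
    edge_set = {frozenset(e) for e in edges}
    simplices = {0: [(v,) for v in range(num_nodes)]}
    for dim in range(1, complex_dim + 1):
        # A set of dim + 1 nodes is a simplex if every pair of them is connected
        simplices[dim] = [
            nodes
            for nodes in combinations(range(num_nodes), dim + 1)
            if all(frozenset(pair) in edge_set for pair in combinations(nodes, 2))
        ]
    return simplices


# Toy graph: a triangle 0-1-2 plus a pendant edge 2-3
edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
lifted = clique_lifting(num_nodes=4, edges=edges, complex_dim=2)

print(lifted[1])  # [(0, 1), (0, 2), (1, 2), (2, 3)]
print(lifted[2])  # [(0, 1, 2)]
```

Enumerating all node subsets is exponential in the graph size, so this is only suitable for small examples.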

In [5]:
graph_loader = GraphLoader(loader_config)

dataset, dataset_dir = graph_loader.load()

preprocessor = PreProcessor(dataset, dataset_dir, transform_config, force_reload=True)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)

datamodule = TBXDataloader(dataset_train, dataset_val, dataset_test, batch_size=64)

Processing...
Done!


Transform parameters are the same, using existing data_dir: ./data/MUTAG/PROTEINS/PROTEINS/clique_lifting_sann_encoding/2503638606


## 4. Backbone definition <a class="anchor" id="sec4"></a>

To implement a new model we only need to define the forward method.

For a hypergraph with $n$ nodes and $m$ hyperedges, this model computes the hyperedge features as $X_1 = B_1^{\top} X_0$, where $X_0$ holds one row of features per node and $B_1 \in \mathbb{R}^{n \times m}$ is the incidence matrix, with $(B_1)_{ij}=1$ if node $i$ belongs to hyperedge $j$ and $0$ otherwise.

The outputs are then computed as $X^{'}_0=\text{ReLU}(X_0 W_0 + b_0)$ and $X^{'}_1=\text{ReLU}(X_1 W_1 + b_1)$, that is, with two linear layers followed by a ReLU activation, where $b_0$ and $b_1$ are bias vectors.
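
A minimal sketch of such a forward method is given below. It assumes node features are stored one row per node, so that aggregating over an $n \times m$ incidence matrix amounts to multiplying its transpose by $X_0$. The class name `TinyHypergraphBackbone` and the plain-tensor arguments are assumptions made for illustration. The interface that TopoBenchmarkX expects from a backbone is not shown in this tutorial and may differ.

```python
import torch


class TinyHypergraphBackbone(torch.nn.Module):
    """Hypothetical backbone: aggregate node features per hyperedge, then two linear layers."""

    def __init__(self, dim_hidden: int):
        super().__init__()
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)  # acts on node features
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)  # acts on hyperedge features

    def forward(self, x_0: torch.Tensor, incidence: torch.Tensor):
        # x_0: (n, d) node features; incidence: (n, m), 1 where node i is in hyperedge j
        x_1 = incidence.T @ x_0  # (m, d): sum of member-node features for each hyperedge
        x_0_out = torch.relu(self.linear_0(x_0))
        x_1_out = torch.relu(self.linear_1(x_1))
        return x_0_out, x_1_out


# Toy hypergraph: 4 nodes, 2 hyperedges {0, 1, 2} and {2, 3}
torch.manual_seed(0)
incidence = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
x_0 = torch.randn(4, 8)

toy_model = TinyHypergraphBackbone(dim_hidden=8)
x_0_out, x_1_out = toy_model(x_0, incidence)
print(x_0_out.shape, x_1_out.shape)
```

Each hyperedge row is the sum of the features of the nodes it contains, which is what multiplying by the transposed incidence matrix computes.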

## 5. Model initialization <a class="anchor" id="sec5"></a>

Now that the model is defined we can create the TBXModel, which takes care of implementing everything else that is needed to train the model. 

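First we need to instantiate the components that specify the behaviour of the model: the backbone and its wrapper, the readout, the loss, the feature encoder, the evaluator, and the optimizer.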

In [84]:
backbone = SANN(in_channels=hidden_channels, hidden_channels=hidden_channels)
backbone_wrapper = wrapper(**wrapper_config)

readout = SANNReadout(**readout_config)
loss = TBXLoss(**loss_config)

in_channels_hops = [
    [in_channels, 6, 18],
    [in_channels, 12, 39],
    [in_channels, 9, 30]
]
feature_encoder = SANNFeatureEncoder(in_channels=in_channels_hops, out_channels=hidden_channels, selected_hops=range(3))

evaluator = TBXEvaluator(**evaluator_config)
optimizer = TBXOptimizer(**optimizer_config)

ValueError: Invalid task classification

In [11]:
data_loader = datamodule.train_dataloader()
d0 = next(iter(data_loader))
# Inspect one of the feature tensors in the first training batch
# print(d0.batch_1.shape)
# print(d0.x1_1.shape)
d0.x1_1

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.8138, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.4226, 0.4049, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.4226, 0.4049, 0.0000],
        ...,
        [0.8338, 0.8338, 0.8338,  ..., 0.8276, 0.0000, 0.0000],
        [0.8338, 0.8338, 0.8338,  ..., 0.8276, 0.0000, 0.0000],
        [0.8571, 0.8571, 0.8571,  ..., 0.8453, 0.0000, 0.0000]])

Now we can instantiate the TBXModel.

In [78]:
model = TBXModel(backbone=backbone,
                 backbone_wrapper=backbone_wrapper,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False)

## 6. Training <a class="anchor" id="sec6"></a>

Now we can use the `lightning` trainer to train the model.

In [79]:
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=200, accelerator="cpu", enable_progress_bar=True, log_every_n_steps=1, enable_checkpointing=False)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name            | Type               | Params
-------------------------------------------------------
0 | feature_encoder | SANNFeatureEncoder | 10.2 K
1 | backbone        | SANNWrapper        | 75.0 K
2 | readout         | SANNReadout        | 45.3 K
3 | val_acc_best    | MeanMetric         | 0     
-------------------------------------------------------
130 K     Trainable params
0         Non-trainable params
130 K     Total params
0.522     Total estimated model params size (MB)



In [20]:
print('      Training metrics\n', '-'*26)
for key in train_metrics:
    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))

      Training metrics
 --------------------------
train/accuracy:       0.8402
train/precision:      0.8490
train/recall:         0.7358
val/loss:             0.6590
val/accuracy:         0.8036
val/precision:        0.7805
val/recall:           0.7111
train/loss:           0.3559


## 7. Testing the model <a class="anchor" id="sec7"></a>

Finally, we can test the model and obtain the results.

In [21]:
trainer.test(model, datamodule=datamodule)
test_metrics = trainer.callback_metrics


In [56]:
print('      Testing metrics\n', '-'*25)
for key in test_metrics:
    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))

      Testing metrics
 -------------------------
test/loss:           0.5737
test/accuracy:       0.7679
test/precision:      0.7317
test/recall:         0.6667
