# 📚 Adding a Custom Dataset Tutorial

## 🎯 Tutorial Overview

This comprehensive guide walks you through integrating your custom dataset into our library. The process is divided into three main steps, with an optional fourth step (Step 4) for custom data transformations:

1. **Dataset Creation** 🔨
   - Implement data loading mechanisms
   - Define preprocessing steps
   - Structure data in the required format

2. **Integrate with Dataset APIs** 🔄
   - Add dataset to the library framework
   - Ensure compatibility with existing systems
   - Set up proper inheritance structure

3. **Configuration Setup** ⚙️
   - Define dataset parameters
   - Specify data paths and formats
   - Configure preprocessing options

## 📋 Tutorial Structure

This tutorial is split between this notebook and a set of supporting files:

> 💡 **Main Notebook (Current File)**
> - High-level concepts and explanations
> - Step-by-step workflow description
> - References to implementation files

> 📁 **Supporting Files**
> - Detailed code implementations
> - Specific examples and use cases
> - Technical documentation

### 🛠️ Technical Framework

This tutorial demonstrates custom dataset integration using:
- `torch_geometric.data.InMemoryDataset` as the base class
- <TBX_name> library's dataset management system

### 🎓 Important Notes

- To keep things concrete, we work with a toy "language" dataset as a running example.
- While we use the "language" dataset as an example, file references use the generic `<dataset_name>` placeholder so that the instructions apply to any dataset.




# Step 1: Create a Dataset 🛠️

## Overview

Adding your custom dataset to <TBX_name> requires implementing specific loading and preprocessing functionality. We utilize the `torch_geometric.data.InMemoryDataset` interface to make this process straightforward.

## Required Methods

To implement your dataset, you need to override two key methods from the `torch_geometric.data.InMemoryDataset` class:

- `download()`: Handles dataset acquisition
- `process()`: Manages data preprocessing

> 💡 **Reference Implementation**: For a complete example, check `topobenchmarkx/data/datasets/language_dataset.py`

### Deep Dive: The Download Method

The `download()` method is responsible for acquiring dataset files from external resources. Let's examine its implementation using our language dataset example, where the data is stored in a zip file hosted on Google Drive.

#### Implementation Steps

1. **Download Data** 📥
   - Fetch data from the specified source URL
   - Save to the raw directory

2. **Extract Content** 📦
   - Unzip the downloaded file
   - Place the contents in the raw directory

3. **Organize Files** 📂
   - Move the extracted files out of the dataset-named subfolder into the raw directory
   - Clean up temporary files and directories

#### Code Implementation

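```python
import os
import os.path as osp
import shutil

from torch_geometric.data import extract_zip

# download_file_from_drive is a library helper; import it as in language_dataset.py


# Method of the dataset class
def download(self) -> None:
    # Step 1: download data from source (URL and format are looked up by dataset name)
    self.url = self.URLS[self.name]
    self.file_format = self.FILE_FORMAT[self.name]
    download_file_from_drive(
        file_link=self.url,
        path_to_save=self.raw_dir,
        dataset_name=self.name,
        file_format=self.file_format,
    )

    # Step 2: extract zip file
    folder = self.raw_dir
    filename = f"{self.name}.{self.file_format}"
    path = osp.join(folder, filename)
    extract_zip(path, folder)
    os.unlink(path)  # Delete the zip archive

    # Step 3: move extracted files from <raw_dir>/<name>/ up into raw_dir
    for file in os.listdir(osp.join(folder, self.name)):
        if file.endswith("ipynb"):
            continue  # skip notebooks shipped inside the archive
        shutil.move(osp.join(folder, self.name, file), osp.join(folder, file))
    shutil.rmtree(osp.join(folder, self.name))  # Cleanup
```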
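Steps 2 and 3 rely only on file operations, so they can be tried in isolation. The sketch below skips the Google Drive download and instead builds a small local zip whose contents sit in a `<name>/` subfolder, then extracts it, deletes the archive, and flattens the subfolder into the raw directory, skipping notebook files as the method above does. The helper name `extract_and_flatten` and the file names are hypothetical, and it uses the standard library's `zipfile` in place of PyG's `extract_zip`.

```python
import os
import os.path as osp
import shutil
import tempfile
import zipfile


def extract_and_flatten(raw_dir: str, name: str, file_format: str = "zip") -> list[str]:
    """Hypothetical helper mirroring steps 2-3 of download()."""
    path = osp.join(raw_dir, f"{name}.{file_format}")
    # Step 2: extract the archive next to itself, then delete it
    with zipfile.ZipFile(path) as zf:
        zf.extractall(raw_dir)
    os.unlink(path)
    # Step 3: move files out of the <name>/ subfolder, skipping notebooks
    inner = osp.join(raw_dir, name)
    for file in os.listdir(inner):
        if file.endswith("ipynb"):
            continue
        shutil.move(osp.join(inner, file), osp.join(raw_dir, file))
    shutil.rmtree(inner)  # also removes the skipped notebooks
    return sorted(os.listdir(raw_dir))


# Build a fake archive that mimics the downloaded one
raw_dir = tempfile.mkdtemp()
with zipfile.ZipFile(osp.join(raw_dir, "LanguageDataset.zip"), "w") as zf:
    zf.writestr("LanguageDataset/sentences.pt", b"")
    zf.writestr("LanguageDataset/explore.ipynb", b"{}")

files = extract_and_flatten(raw_dir, "LanguageDataset")
print(files)  # ['sentences.pt']
```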






### Deep Dive: The Process Method

The `process()` method handles data preprocessing and organization. Here's the method's structure:

```python
from torch_geometric.io import fs


# Method of the dataset class
def process(self) -> None:
    r"""Handle the data for the dataset.

    This method loads the Language dataset, applies preprocessing
    transformations, and saves processed data."""

    # Step 1: extract the data
    # Convert raw data to a list of torch_geometric.data.Data objects (elided here)
    graph_sentences = ...

    # Step 2: collate the graphs into one Data object plus slice indices
    self.data, self.slices = self.collate(graph_sentences)

    # Step 3: save processed data
    fs.torch_save(
        (self._data.to_dict(), self.slices, {}, self._data.__class__),
        self.processed_paths[0],
    )
    self.graph = graph_sentences
```


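`self.collate`: per the PyG documentation, it "collates a list of `Data` or `HeteroData` objects to the internal storage format" of `InMemoryDataset`. In practice, it concatenates the list of `torch_geometric.data.Data` objects into a single `Data` object and returns it together with a `slices` dictionary that records where each graph's entries start and end, so that individual graphs can later be retrieved by index.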



# Step 2: Integrate with Dataset APIs 🔄

Now that we have created a dataset class, we need to integrate it with the benchmark library. In this section we describe where to add the dataset files and how to make it available through data loaders.


Here's how to structure your files; the files marked `# updated` are modified in this step:
```
topobenchmarkx/
├── data/
│   ├── datasets/
│   │   ├── __init__.py         # updated
│   │   ├── base.py
│   │   ├── <dataset_name>.py   # your dataset file
│   │   └── ...
│   └── loaders/
│       ├── __init__.py
│       └── loaders.py          # updated
```

To make your dataset available to the library:

The file `<dataset_name>.py` was created in the previous step (`language_dataset.py` in our case) and should be placed in the `topobenchmarkx/data/datasets/` directory.


Next, update the registry in `topobenchmarkx/data/datasets/__init__.py` to include your custom dataset in the library:

```python
from .<dataset_name> import <dataset_name_class>

__all__ = [
    # Other datasets...
    '<dataset_name_class>',
]
```

Next, update the data loader system by modifying the loader file (`topobenchmarkx/data/loaders/loaders.py`) to include your dataset.

For the toy example dataset, we add the following branch to the `load` method of the `GraphLoader` class:

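```python
# Inside GraphLoader.load(): LanguageDataset must be imported in loaders.py,
# and root_data_dir is defined earlier in the method
elif self.parameters.data_name in ["LanguageDataset"]:
    dataset = LanguageDataset(
        root=root_data_dir,
        name=self.parameters["data_name"],
        parameters=self.parameters,
    )

    data_dir = dataset.processed_root
```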


**Notes:**
- In `topobenchmarkx/data/loaders/loaders.py` we additionally provide a template for adding a new dataset.
- The `load` method of the `GraphLoader` class has to return `dataset: torch_geometric.data.Dataset` and `data_dir: str`.
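The `load` contract from these notes can be sketched without the library. Below, `ToyDataset` and `ToyLoader` are hypothetical stand-ins, not the real `LanguageDataset` or `GraphLoader`, whose internals may differ: `load` dispatches on `data_name` like the `elif` branch above and returns the dataset together with a directory string. The `processed_root` layout and the `data_dir` parameter value are assumptions made for illustration.

```python
import os.path as osp


class ToyDataset:
    """Hypothetical stand-in for a dataset class such as LanguageDataset."""

    def __init__(self, root: str, name: str, parameters: dict):
        self.name = name
        self.parameters = parameters
        # Assumed layout: processed files live under <root>/<name>/processed
        self.processed_root = osp.join(root, name, "processed")


class ToyLoader:
    """Hypothetical stand-in for GraphLoader, showing only the load() contract."""

    def __init__(self, parameters: dict):
        self.parameters = parameters

    def load(self) -> tuple[ToyDataset, str]:
        root_data_dir = self.parameters["data_dir"]
        if self.parameters["data_name"] in ["LanguageDataset"]:
            dataset = ToyDataset(
                root=root_data_dir,
                name=self.parameters["data_name"],
                parameters=self.parameters,
            )
            data_dir = dataset.processed_root
        else:
            raise NotImplementedError(
                f"Dataset {self.parameters['data_name']} is not supported"
            )
        # Contract: return the dataset object and the data directory as a string
        return dataset, data_dir


dataset, data_dir = ToyLoader(
    {"data_dir": "datasets/graph/NLP", "data_name": "LanguageDataset"}
).load()
print(data_dir)  # datasets/graph/NLP/LanguageDataset/processed (on POSIX)
```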

# Step 3: Define Configuration 🔧

Now that we've integrated our dataset, we need to define its configuration parameters. In this section, we'll explain how to create and structure the configuration file for your dataset.

## Configuration File Structure
Create a new YAML file for your dataset in `configs/dataset/<dataset_name>.yaml` with the following structure:

```yaml
# Dataset loader config
loader:
  _target_: topobenchmarkx.data.loaders.GraphLoader
  parameters:
    data_domain: graph
    data_type: NLP
    data_name: LanguageDataset
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}

# Dataset parameters
parameters:
  num_features: 0
  num_classes: 0
  task: regression
  loss_type: mse
  monitor_metric: mae
  task_level: node

# Splits configuration
split_params:
  learning_setting: transductive
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
  data_seed: 0
  split_type: random  # or 'k-fold'
  k: 10               # for k-fold Cross-Validation
  train_prop: 0.5     # for random strategy
  standardize: True

# Dataloader parameters
dataloader_params:
  batch_size: 1       # Fixed
  num_workers: 0
  pin_memory: False
```

# Step 4: Custom Data Transformations ⚙️

While most datasets can be used directly after integration, some require specific preprocessing transformations. These transformations might vary depending on the task, model, or other conditions.

## Example Case: Language Dataset

Let's look at our language dataset's structure:
- Each graph represents an English sentence
- Nodes are tokens, i.e., each node is a string without a corresponding high-dimensional feature embedding
- Edges come from transformer attention, hence forming a fully connected graph

For this dataset, two default transformations are natural:
1. **Graph Sparsification**: Reduce edge density
2. **Node Feature Generation**: Create numerical features from tokens


Below we provide a quick tutorial on how to create a data transformation and how to define a sequence of default transformations that will be executed whenever you use the dataset's config file.

### Creating a Transform

In general, every transform in the library inherits from the `torch_geometric.transforms.BaseTransform` class, which allows a sequence of transforms to be applied to the data. Our interface requires implementing the `forward` method. Every transform takes a `torch_geometric.data.Data` object and returns an updated `torch_geometric.data.Data` object.



For the language dataset, we created the `attention2graph` transform. It is a data manipulation transform, so we place it in the `topobenchmarkx/transforms/data_manipulations/` folder.
Below you can see the `forward` method of the `Attention2Graph` class:


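```python
import torch
import torch_geometric
from torch_geometric.transforms import BaseTransform


class Attention2Graph(BaseTransform):
    # __init__ (which stores the transform parameters, e.g. "threshold")
    # is omitted here; see attention2graph.py for the full class.

    def forward(self, data: torch_geometric.data.Data):
        # Reshape the flattened attention scores back to their original shape
        attention_shape = data.attention_shape
        attention_scores = data.attention_scores.reshape(attention_shape)

        # Keep only token pairs whose attention exceeds the threshold
        mask = attention_scores > self.parameters["threshold"]
        edge_index = torch.stack(torch.where(mask))

        # Drop self-loops and make the graph undirected
        edge_index = torch_geometric.utils.remove_self_loops(edge_index)[0]
        edge_index = torch_geometric.utils.to_undirected(edge_index)
        data.edge_index = edge_index

        return data
```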

Please see the `topobenchmarkx/transforms/data_manipulations/attention2graph.py` file for the full implementation.

### Register the Transform

As with the dataset, we have to register the transform we created. To do so, update `topobenchmarkx/transforms/data_manipulations/__init__.py` as follows:

``` python
# Step 1: Import your transform
from .attention2graph import Attention2Graph

# Step 2: Add to DATA_MANIPULATIONS dictionary
DATA_MANIPULATIONS = {
    "Identity": IdentityTransform,
    # ... other transforms ...
    "Attention2Graph": Attention2Graph,  # Add your transform
}

# Step 3: Add to __all__
__all__ = [
    "DATA_MANIPULATIONS",
    # ... other transforms ...
    "Attention2Graph"  # Add your transform
]
```
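Registration matters because the configuration refers to a transform by its string name (`transform_name: "Attention2Graph"` in the configuration file). How `DataTransform` performs the lookup is not shown in this tutorial; the sketch below is an assumption about the general name-to-class pattern, with hypothetical stub classes and a hypothetical `build_transform` helper, and it shows why an unregistered name cannot be resolved.

```python
class IdentityTransform:  # hypothetical stub
    def __init__(self, **kwargs):
        self.parameters = kwargs


class Attention2Graph:  # hypothetical stub
    def __init__(self, **kwargs):
        self.parameters = kwargs


DATA_MANIPULATIONS = {
    "Identity": IdentityTransform,
    "Attention2Graph": Attention2Graph,
}


def build_transform(transform_name: str, **kwargs):
    """Hypothetical: resolve a config name to a registered class and instantiate it."""
    try:
        cls = DATA_MANIPULATIONS[transform_name]
    except KeyError:
        raise ValueError(f"Unknown transform {transform_name!r}; is it registered?") from None
    return cls(**kwargs)


t = build_transform("Attention2Graph", threshold=0.1)
print(type(t).__name__, t.parameters)  # Attention2Graph {'threshold': 0.1}
```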

### Create a configuration file 
With the transform registered, we can now create its configuration file and use it in the framework:

``` yaml
_target_: topobenchmarkx.transforms.data_transform.DataTransform
transform_name: "Attention2Graph"
transform_type: "data manipulation"
threshold: 0.1
``` 
Please refer to `configs/transforms/dataset_defaults/attention2graph.yaml` for the example. 

**Notes:**

- You might notice the special key `_target_` in the configuration file. For any new transform, `_target_` is always `topobenchmarkx.transforms.data_transform.DataTransform`. [For more information, see the "Instantiating objects with Hydra" section of the Hydra documentation.](https://hydra.cc/docs/advanced/instantiate_objects/overview/)

### Default transforms

Now that we have created the transform, we can define a list of default transforms that are always executed whenever the dataset is used with its default configuration.


To configure the default transforms, navigate to `configs/transforms/dataset_defaults`, create `<def_transforms.yaml>`, and add the following content:

```yaml
defaults:
  - transform_1: transform_1
  - transform_2: transform_2
  - transform_3: transform_3
```


**Important**
There are different types of transforms, including `data_manipulations`, `liftings`, and `feature_liftings`. If you want to use multiple transforms from the same category, say `data_manipulations`, you need a special syntax that gives each one its own package name (see the Hydra documentation on overriding packages), as in the example below:

```yaml
defaults:
  - data_manipulations@first_usage: transform_1
  - data_manipulations@second_usage: transform_2
```


In the case of the language dataset we have the following `language.yaml` file:

```yaml
defaults:
  - data_manipulations@equal_gaus_features: equal_gaus_features
  - data_manipulations@attention2graph: attention2graph
  - liftings@_here_: ${get_required_lifting:graph,${model}}

equal_gaus_features:
  num_features: 10
```


A few aspects of this example are worth noting:
- There are two transforms from the same category, `data_manipulations`, so we use the `@` operator to give them the distinct names `equal_gaus_features` and `attention2graph`.
- For `equal_gaus_features`, we have to override the number of features, because `equal_gaus_features.yaml` uses a special "configuration register" to infer the feature dimension. Our language dataset has no hidden node features, so we set the number of features ourselves.
- We recommend adding `liftings@_here_: ${get_required_lifting:graph,${model}}` so that a default lifting is applied, which makes it possible to run any domain-specific topological model.

# Summary 📝

To recap, adding the language dataset involved:

1. **Step 1**: Add the `language_dataset.py` file to the `topobenchmarkx/data/datasets` folder.
2. **Step 2**: Update `__init__.py` in `topobenchmarkx/data/datasets` and `loaders.py` in `topobenchmarkx/data/loaders`.
3. **Step 3**: Create the config file in `TopoBenchmark/configs/dataset/graph`.

**Preprocessing (optional):** if you need default transforms, create them as described in Step 4, generate the corresponding `.yaml` configuration files, and add them to the `configs/transforms/dataset_defaults` folder. To create a transform:

1. Create the transform required for the dataset and add it to the appropriate `topobenchmarkx/transforms` folder.
2. Register it in that folder's `__init__.py`.
3. Add the configuration file associated with the transform to `transforms/data_manipulations` (see the `attention2graph.yaml` file).

The following notebook cell sets up the project root, composes the Hydra configuration, and loads the dataset:

```python
import rootutils

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

%load_ext autoreload
%autoreload 2

import hydra
import torch
import torch_geometric
from hydra import compose, initialize
from omegaconf import OmegaConf

from topobenchmarkx.data.preprocessor import PreProcessor
from topobenchmarkx.dataloader.dataloader import TBXDataloader
from topobenchmarkx.data.loaders import GraphLoader

from topobenchmarkx.utils.config_resolvers import (
    get_default_transform,
    get_monitor_metric,
    get_monitor_mode,
    infer_in_channels,
)

# version_base="1.1" keeps the Hydra 1.1 defaults that are otherwise assumed with a warning
initialize(version_base="1.1", config_path="../configs", job_name="job")

cfg = compose(config_name="run.yaml", return_hydra_config=True)
graph_loader = GraphLoader(cfg.dataset.loader.parameters)
dataset, dataset_dir = graph_loader.load()
```

```
Download complete.
Extracting /home/lev/projects/TopoBenchmark/datasets/graph/NLP/LanguageDataset/raw/LanguageDataset.zip
Processing...
Done!
```

Check that the `num_features` override from `language.yaml` was applied:

```python
cfg['transforms']['equal_gaus_features']['num_features']
```

```
10
```

Finally, apply the default transforms to the dataset:

```python
preprocessed_dataset = PreProcessor(dataset, dataset_dir, cfg['transforms'])
```

```
Processing...
Done!
```