# 📚 Adding a Custom Dataset Tutorial

## 🎯 Tutorial Overview

This comprehensive guide walks you through the process of integrating your custom dataset into our library. The process is divided into three main steps:

1. **Dataset Creation** 🔨
   - Implement data loading mechanisms
   - Define preprocessing steps
   - Structure data in the required format

2. **Integrate with Dataset APIs** 🔄
   - Add dataset to the library framework
   - Ensure compatibility with existing systems
   - Set up proper inheritance structure

3. **Configuration Setup** ⚙️
   - Define dataset parameters
   - Specify data paths and formats
   - Configure preprocessing options

## 📋 Tutorial Structure

This tutorial follows a unique structure to provide the clearest possible learning experience:

> 💡 **Main Notebook (Current File)**
> - High-level concepts and explanations
> - Step-by-step workflow description
> - References to implementation files

> 📁 **Supporting Files**
> - Detailed code implementations
> - Specific examples and use cases
> - Technical documentation

### 🛠️ Technical Framework

This tutorial demonstrates custom dataset integration using:
- `torch_geometric.data.InMemoryDataset` as the base class
- <TB_name> library's dataset management system

### 🎓 Important Notes

- To make the learning process concrete, we'll work with a practical toy "language" dataset example.
- While we use the "language" dataset as an example, all file references use the generic `<dataset_name>` format so that the steps carry over to any dataset.




# Step 1: Create a Dataset 🛠️

## Overview

Adding your custom dataset to <TB_name> requires implementing specific loading and preprocessing functionality. We utilize the `torch_geometric.data.InMemoryDataset` interface to make this process straightforward.

## Required Methods

To implement your dataset, you need to override two key methods from the `torch_geometric.data.InMemoryDataset` class:

- `download()`: Handles dataset acquisition
- `process()`: Manages data preprocessing

> 💡 **Reference Implementation**: For complete examples, see the files in the `topobench/data/datasets/` directory, such as `us_county_demos_dataset.py` and `mantra_dataset.py`.
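
To show when these two hooks are triggered, here is a minimal sketch that uses only the standard library. `ToyOnDiskDataset` is a hypothetical stand-in, not the `InMemoryDataset` implementation; it assumes the usual behaviour that `download()` runs only when the raw files are missing and `process()` only when the processed files are missing.

```python
import tempfile
from pathlib import Path


class ToyOnDiskDataset:
    """Simplified stand-in for the download/process lifecycle (not PyG code)."""

    raw_file_name = "raw.txt"
    processed_file_name = "processed.txt"

    def __init__(self, root):
        self.raw_dir = Path(root) / "raw"
        self.processed_dir = Path(root) / "processed"
        self.calls = []  # records which hooks actually ran

        # Acquire raw files only if they are not on disk yet.
        if not (self.raw_dir / self.raw_file_name).exists():
            self.raw_dir.mkdir(parents=True, exist_ok=True)
            self.download()
        # Preprocess only if the processed file is not on disk yet.
        if not (self.processed_dir / self.processed_file_name).exists():
            self.processed_dir.mkdir(parents=True, exist_ok=True)
            self.process()

    def download(self):
        """Stand-in for fetching data: write a tiny raw file locally."""
        self.calls.append("download")
        (self.raw_dir / self.raw_file_name).write_text("a b\nb c\n")

    def process(self):
        """Stand-in for preprocessing: upper-case the raw content."""
        self.calls.append("process")
        raw = (self.raw_dir / self.raw_file_name).read_text()
        (self.processed_dir / self.processed_file_name).write_text(raw.upper())


with tempfile.TemporaryDirectory() as root:
    first = ToyOnDiskDataset(root)   # nothing cached: both hooks run
    second = ToyOnDiskDataset(root)  # files already exist: both hooks are skipped
    first_calls, second_calls = first.calls, second.calls
```

The second instantiation reuses the files written by the first one, which is why a real dataset only needs to be downloaded and processed once per root directory.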

### Deep Dive: The Download Method

The `download()` method is responsible for acquiring dataset files from external resources. Let's examine its implementation using our language dataset example, where the data is stored in a zip file hosted on Google Drive.

#### Implementation Steps

1. **Download Data** 📥
  - Fetch data from the specified source URL
  - Save to the raw directory

2. **Extract Content** 📦
  - Unzip the downloaded file
  - Place contents in appropriate directory

3. **Organize Files** 📂
  - Move extracted files to named folders
  - Clean up temporary files and directories
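
The extract and organize steps can be tried in isolation with the standard library. The sketch below is only an illustration: it builds a small local archive instead of downloading one, and the dataset name `toy_dataset` and the file names are hypothetical.

```python
import os
import os.path as osp
import shutil
import tempfile
import zipfile

name = "toy_dataset"  # hypothetical dataset name

with tempfile.TemporaryDirectory() as raw_dir:
    # Stand-in for step 1: build a local archive instead of downloading one.
    archive = osp.join(raw_dir, f"{name}.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr(f"{name}/train.txt", "hello")
        zf.writestr(f"{name}/test.txt", "world")

    # Step 2: extract the archive into the raw directory, then delete it.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(raw_dir)
    os.unlink(archive)

    # Step 3: move files out of the extracted folder and remove the folder.
    extracted = osp.join(raw_dir, name)
    for file in os.listdir(extracted):
        shutil.move(osp.join(extracted, file), raw_dir)
    shutil.rmtree(extracted)

    final_files = sorted(os.listdir(raw_dir))
```

After the three steps the raw directory contains only the data files, with no archive and no nested folder left behind.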

#### Code Implementation

```python
import os
import os.path as osp
import shutil

from torch_geometric.data import extract_zip

# Library helper; this is the import path used by the reference datasets.
from topobench.data.utils import download_file_from_drive


def download(self) -> None:
    r"""Download the dataset from a URL and save it to the raw directory.

    Raises:
        FileNotFoundError: If the dataset URL is not found.
    """
    # Step 1: Download data from the source
    self.url = self.URLS[self.name]
    self.file_format = self.FILE_FORMAT[self.name]
    download_file_from_drive(
        file_link=self.url,
        path_to_save=self.raw_dir,
        dataset_name=self.name,
        file_format=self.file_format,
    )

    # Step 2: extract zip file
    folder = self.raw_dir
    filename = f"{self.name}.{self.file_format}"
    path = osp.join(folder, filename)
    extract_zip(path, folder)
    # Delete zip file
    os.unlink(path)

    # Step 3: organize files
    # Move files from osp.join(folder, self.name) to folder
    for file in os.listdir(osp.join(folder, self.name)):
        shutil.move(osp.join(folder, self.name, file), folder)
    # Delete osp.join(folder, self.name) dir
    shutil.rmtree(osp.join(folder, self.name))
```

### Deep Dive: The Process Method

The `process()` method handles data preprocessing and organization. Here's the method's structure:

```python
from torch_geometric.io import fs


def process(self) -> None:
    r"""Handle the data for the dataset.

    This method loads the Language dataset, applies preprocessing
    transformations, and saves the processed data.
    """
    # Step 1: extract the data
    # Convert the raw data to a list of torch_geometric.data.Data objects
    graph_sentences = ...

    # Step 2: collate the graphs
    self.data, self.slices = self.collate(graph_sentences)

    # Step 3: save processed data
    fs.torch_save(
        (self._data.to_dict(), self.slices, {}, self._data.__class__),
        self.processed_paths[0],
    )
```

`self.collate` collates a list of `Data` or `HeteroData` objects into the internal storage format: it turns a list of `torch_geometric.data.Data` objects into a single `BaseData` object, together with a dictionary of slices that records where each individual graph starts and ends.
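
The idea behind this storage format can be illustrated without any dependencies. The sketch below is a simplified pure-Python analogue, not the PyG implementation: `toy_collate` and `toy_get` are hypothetical helpers that concatenate per-graph node features into one flat list and keep cumulative offsets so that each graph can be recovered.

```python
from itertools import accumulate


def toy_collate(graphs):
    """Concatenate per-graph node features and record slice boundaries."""
    flat = [row for g in graphs for row in g["x"]]
    # Cumulative node counts: graph i owns rows slices[i]:slices[i + 1].
    slices = [0, *accumulate(len(g["x"]) for g in graphs)]
    return {"x": flat}, {"x": slices}


def toy_get(data, slices, idx):
    """Recover the idx-th graph from the collated storage."""
    start, end = slices["x"][idx], slices["x"][idx + 1]
    return {"x": data["x"][start:end]}


graphs = [
    {"x": [[1.0], [2.0]]},         # graph with 2 nodes
    {"x": [[3.0]]},                # graph with 1 node
    {"x": [[4.0], [5.0], [6.0]]},  # graph with 3 nodes
]
data, slices = toy_collate(graphs)
second_graph = toy_get(data, slices, 1)
```

Storing one large object plus offsets is what allows the whole dataset to be saved to, and loaded from, a single processed file.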



# Step 2: Integrate with Dataset APIs 🔄

Now that we have created a dataset class, we need to integrate it with the library. In this section we describe where to add the dataset files and how to make it available through data loaders.

Here's how to structure your files. The file marked with `**` is the dataset registry, which picks up new dataset files automatically (see the registry note later in this step):
```text
topobench/
├── data/
│   ├── datasets/
│   │   ├── **__init__.py**
│   │   ├── base.py
│   │   ├── <dataset_name>.py   # Your dataset file
│   │   └── ...
│   ├── loaders/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── graph/
│   │   │   ├── <loader_name>.py   # Your loader file
│   │   ├── hypergraph/
│   │   │   ├── <loader_name>.py   # Your loader file
│   │   ├── .../
```

To make your dataset available to the library:

The file `<dataset_name>.py` was created in the previous step (`us_county_demos_dataset.py` in our case) and should be placed in the `topobench/data/datasets/` directory.


The registry in `topobench/data/datasets/__init__.py` discovers the files in `topobench/data/datasets/` and updates its `__all__` variable automatically, so there is no need to edit `__init__.py` by hand for your dataset to be loadable by the library. Simply create a file `<dataset_name>.py` and place it in the `topobench/data/datasets/` directory.
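
To make the idea of automatic discovery concrete, here is a minimal sketch of how a registry of this kind can work. It is an assumption for illustration, not the code in `topobench/data/datasets/__init__.py`: the function `discover_dataset_classes`, the "name ends with `Dataset`" rule, and the temporary package are all hypothetical.

```python
import importlib.util
import inspect
import tempfile
from pathlib import Path


def discover_dataset_classes(package_dir):
    """Import every module in `package_dir` and collect its *Dataset classes."""
    found = {}
    for path in sorted(Path(package_dir).glob("*.py")):
        if path.stem in {"__init__", "base"}:
            continue  # skip package plumbing
        spec = importlib.util.spec_from_file_location(path.stem, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        for cls_name, cls in inspect.getmembers(module, inspect.isclass):
            # Keep classes defined in this module whose name ends in "Dataset".
            if cls.__module__ == module.__name__ and cls_name.endswith("Dataset"):
                found[cls_name] = cls
    return found


with tempfile.TemporaryDirectory() as pkg:
    Path(pkg, "__init__.py").write_text("")
    Path(pkg, "my_dataset.py").write_text(
        "class MyDataset:\n    pass\n\nclass Helper:\n    pass\n"
    )
    registry = discover_dataset_classes(pkg)
    exported = sorted(registry)  # plays the role of `__all__`
```

With a scheme like this, dropping a new file into the directory is enough for its dataset class to appear in the exported names.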

------------------------------------------------------------------------------------------------

Next, add a data loader for your dataset. Create a loader file in the subdirectory of `topobench/data/loaders/` that matches your data domain (for example `topobench/data/loaders/graph/<loader_name>.py`).

For the example dataset, the loader file `topobench/data/loaders/graph/us_county_demos_dataset_loader.py` consists of the following:

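```python
from pathlib import Path

from omegaconf import DictConfig

from topobench.data.datasets import USCountyDemosDataset
from topobench.data.loaders.base import AbstractLoader


class USCountyDemosDatasetLoader(AbstractLoader):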
    """Load US County Demos dataset with configurable year and task variable.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters containing:
            - data_dir: Root directory for data
            - data_name: Name of the dataset
            - year: Year of the dataset (if applicable)
            - task_variable: Task variable for the dataset
    """

    def __init__(self, parameters: DictConfig) -> None:
        super().__init__(parameters)

    def load_dataset(self) -> USCountyDemosDataset:
        """Load the US County Demos dataset.

        Returns
        -------
        USCountyDemosDataset
            The loaded US County Demos dataset with the appropriate `data_dir`.

        Raises
        ------
        RuntimeError
            If dataset loading fails.
        """

        dataset = self._initialize_dataset()
        self.data_dir = self._redefine_data_dir(dataset)
        return dataset

    def _initialize_dataset(self) -> USCountyDemosDataset:
        """Initialize the US County Demos dataset.

        Returns
        -------
        USCountyDemosDataset
            The initialized dataset instance.
        """
        return USCountyDemosDataset(
            root=str(self.root_data_dir),
            name=self.parameters.data_name,
            parameters=self.parameters,
        )

    def _redefine_data_dir(self, dataset: USCountyDemosDataset) -> Path:
        """Redefine the data directory based on the chosen (year, task_variable) pair.

        Parameters
        ----------
        dataset : USCountyDemosDataset
            The dataset instance.

        Returns
        -------
        Path
            The redefined data directory path.
        """
        return dataset.processed_root
```
The loader class has to inherit from `AbstractLoader`. Only the `load_dataset` method is required; the other methods are optional and are used here for convenience and structure.

## Notes:
- The `load_dataset` method of the `AbstractLoader` class must return a `torch.utils.data.Dataset` object.
- **Important:** to allow automatic registration of the loader, make sure the loader class name contains "DatasetLoader" (example: USCountyDemos**DatasetLoader**).
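
The contract between a base loader and its subclasses can be sketched with the standard library. `ToyAbstractLoader` below is a hypothetical stand-in for `AbstractLoader`, not the library class. In particular, the `load()` method and the directory it returns are assumptions, modelled on the `dataset, dataset_dir = loader.load()` call used later in this tutorial.

```python
from abc import ABC, abstractmethod
from pathlib import Path


class ToyAbstractLoader(ABC):
    """Hypothetical stand-in for `AbstractLoader` (illustration only)."""

    def __init__(self, parameters):
        self.parameters = parameters
        self.root_data_dir = Path(parameters["data_dir"])

    @abstractmethod
    def load_dataset(self):
        """Subclasses must build and return the dataset here."""

    def load(self):
        """Return the dataset together with an (assumed) data directory."""
        dataset = self.load_dataset()
        return dataset, self.root_data_dir / self.parameters["data_name"]


class ToyDatasetLoader(ToyAbstractLoader):
    """Concrete loader; note the "DatasetLoader" suffix in the class name."""

    def load_dataset(self):
        return [f"sample-{i}" for i in range(3)]  # placeholder dataset


loader = ToyDatasetLoader({"data_dir": "datasets/graph", "data_name": "toy"})
dataset, dataset_dir = loader.load()
```

Because `load_dataset` is abstract, a subclass that forgets to implement it cannot be instantiated, which surfaces the mistake early.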

# Step 3: Define Configuration 🔧

Now that we've integrated our dataset, we need to define its configuration parameters. In this section, we'll explain how to create and structure the configuration file for your dataset.

## Configuration File Structure
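Create a new YAML file for your dataset at `configs/dataset/<data_domain>/<dataset_name>.yaml` (the example dataset is selected later in this tutorial with the override `dataset=graph/US-county-demos`) with the following structure: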


### While creating a configuration file, you will need to specify: 

1) Loader class (`topobench.data.loaders.USCountyDemosDatasetLoader`) for automatic instantiation inside the provided pipeline, together with the parameters for the loader.
```yaml
# Dataset loader config
loader:
  _target_: topobench.data.loaders.USCountyDemosDatasetLoader
  parameters: 
    data_domain: graph             # Primary data domain. Options: ['graph', 'hypergraph', 'cell', 'simplicial']
    data_type: cornel              # Data type. String indicating where the dataset comes from.
    data_name: US-county-demos     # Name of the dataset
    year: 2012                     # US-county-demos comes in multiple versions. Options: [2012, 2016]
    task_variable: 'Election'      # Variable used as the prediction target. Options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
``` 

2) The dataset parameters: 

```yaml
# Dataset parameters
parameters:
  num_features: 6         # Number of features in the dataset
  num_classes: 1          # Dimension of the target variable
  task: regression        # Dataset task. Options: [classification, regression]
  loss_type: mse          # Task-specific loss function
  monitor_metric: mae     # Metric to monitor during training
  task_level: node        # Task level. Options: [node, graph]
```

3) The dataset split parameters: 
```yaml
#splits
split_params:
  learning_setting: transductive      # Type of learning. Options:['transductive', 'inductive']
  data_seed: 0                        # Seed for data splitting
  split_type: random                  # Type of splitting. Options: ['k-fold', 'random']
  k: 10                               # Number of folds in case of "k-fold" cross-validation
  train_prop: 0.5                     # Training proportion in case of 'random' splitting strategy
  standardize: True                   # Standardize the data or not. Options: [True, False]
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
```

4) Finally the dataloader parameters:

```yaml
# Dataloader parameters
dataloader_params:
  batch_size: 1       # Number of graphs per batch. In the transductive setting always 1, as there is only one graph.
  num_workers: 0      # Number of workers for data loading
  pin_memory: False   # Pin memory for data loading
```
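
To illustrate how `data_seed` and `train_prop` interact in a `random` split, here is a small standard-library sketch. It is not the library's splitting code: `toy_random_split` is a hypothetical helper, and it assumes that the part left over after training is shared evenly between validation and test, which should be checked against the actual implementation.

```python
import random


def toy_random_split(num_items, train_prop, data_seed):
    """Shuffle indices with a fixed seed and cut them into three parts."""
    indices = list(range(num_items))
    random.Random(data_seed).shuffle(indices)
    n_train = int(train_prop * num_items)
    # Assumption: the remainder is shared evenly by validation and test.
    n_val = (num_items - n_train) // 2
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train, val, test = toy_random_split(num_items=10, train_prop=0.5, data_seed=0)
```

Fixing the seed makes the split reproducible: calling the function again with the same arguments returns the same three index lists.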

### Notes:
- The `paths` section of the configuration is populated automatically with the paths to the data directory and the data splits directory.
- Some of the dataset parameters are used to configure the model.yaml and other files. Hence we suggest always including the above parameters in the dataset configuration file.
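
Since other configuration files rely on these parameters, it can help to check a new dataset configuration before running the pipeline. The helper below is a hypothetical sketch working on a plain dictionary; it is not part of the library. The allowed values for `task` and `task_level` are taken from the options listed in this tutorial and from the `task_level` check in the `run` function shown later.

```python
REQUIRED_PARAMETERS = (
    "num_features", "num_classes", "task",
    "loss_type", "monitor_metric", "task_level",
)
ALLOWED = {
    "task": {"classification", "regression"},
    "task_level": {"node", "graph"},
}


def check_dataset_parameters(parameters):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = [
        f"missing key: {key}"
        for key in REQUIRED_PARAMETERS
        if key not in parameters
    ]
    for key, options in ALLOWED.items():
        if key in parameters and parameters[key] not in options:
            problems.append(f"{key}={parameters[key]!r} not in {sorted(options)}")
    return problems


good = {"num_features": 6, "num_classes": 1, "task": "regression",
        "loss_type": "mse", "monitor_metric": "mae", "task_level": "node"}
bad = {"num_features": 6, "task": "regression", "task_level": "classification"}

good_problems = check_dataset_parameters(good)
bad_problems = check_dataset_parameters(bad)
```

For the `bad` example the helper reports the three missing keys and the invalid `task_level` value.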





## Preparing to Load the Custom Dataset: Understanding Configuration Imports

Before loading our dataset, it's crucial to understand the configuration imports, particularly those from the `topobench.utils.config_resolvers` module. These utility functions play a key role in dynamically configuring your machine learning pipeline.

### Key Imports for Dynamic Configuration

Let's import the essential configuration resolver functions:

```python
from topobench.utils.config_resolvers import (
    get_default_transform,
    get_monitor_metric,
    get_monitor_mode,
    infer_in_channels,
)
```

### Why These Imports Matter

In our previous step, we explored configuration variables that use dynamic lookups, such as:

```yaml
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
```

However, some configurations require more advanced automation, which is where these imported functions become invaluable.
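
Before moving on to resolver functions, it may help to see what a plain `${...}` lookup does. The sketch below is a toy re-implementation with the standard library, written only to illustrate the substitution. Real configurations are resolved by Hydra and OmegaConf, not by this code, and the root directory `/tmp/datasets` is a hypothetical value.

```python
import re

_PATTERN = re.compile(r"\$\{([^${}]+)\}")


def lookup(cfg, dotted_key):
    """Follow a dotted key such as 'paths.data_dir' through nested dicts."""
    node = cfg
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(cfg, value):
    """Replace every ${a.b.c} reference in `value` until none remain."""
    while _PATTERN.search(value):
        value = _PATTERN.sub(lambda m: str(lookup(cfg, m.group(1))), value)
    return value


cfg = {
    "paths": {"data_dir": "/tmp/datasets"},  # hypothetical root directory
    "dataset": {"loader": {"parameters": {
        "data_domain": "graph",
        "data_type": "cornel",
        "data_dir": "${paths.data_dir}/${dataset.loader.parameters.data_domain}"
                    "/${dataset.loader.parameters.data_type}",
    }}},
}
resolved = resolve(cfg, cfg["dataset"]["loader"]["parameters"]["data_dir"])
```

Each reference is replaced by the value found at that dotted path, so the three references collapse into one concrete directory path.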

### Practical Example: Dynamic Transforms

Consider the configuration in `projects/TopoBench/configs/run.yaml`, where the `transforms` parameter uses the `get_default_transform` function:

```yaml
transforms: ${get_default_transform:${dataset},${model}}
```

This syntax allows for automatic transformation selection based on the dataset and model, demonstrating the power of these configuration resolver functions.

By importing and utilizing these functions, you gain:
- Flexible configuration management
- Automatic parameter inference
- Reduced manual configuration overhead

These facilitate seamless dataset loading and preprocessing for multiple topological domains and provide an easy and intuitive interface for incorporating novel functionality.





In [1]:
import hydra
from hydra import compose, initialize
from hydra.utils import instantiate



from topobench.utils.config_resolvers import (
    get_default_metrics,
    get_default_trainer,
    get_default_transform,
    get_flattened_channels,
    get_monitor_metric,
    get_monitor_mode,
    get_non_relational_out_channels,
    get_required_lifting,
    infer_in_channels,
    infer_num_cell_dimensions,
)


with initialize(config_path="../configs", job_name="job"):

    cfg = compose(
        config_name="run.yaml",
        overrides=[
            "model=hypergraph/unignn2",
            "dataset=graph/US-county-demos",
        ], 
        return_hydra_config=True
    )
loader = instantiate(cfg.dataset.loader)


dataset, dataset_dir = loader.load()

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../configs", job_name="job"):


In [2]:
print(dataset)

US-county-demos(self.root=/home/levtel/projects/TopoBench/datasets/graph/cornel, self.name=US-county-demos, self.parameters={'data_domain': 'graph', 'data_type': 'cornel', 'data_name': 'US-county-demos', 'year': 2012, 'task_variable': 'Election', 'data_dir': '/home/levtel/projects/TopoBench/datasets/graph/cornel'}, self.force_reload=False)


In [3]:
print(dataset[0])

Data(x=[3224, 6], edge_index=[2, 18966], y=[3224])


Take a look at the default transforms:

In [4]:
print('Transform name:', cfg.transforms.keys())


Transform name: dict_keys(['graph2hypergraph_lifting'])


In [5]:
from topobench.data.preprocessor import PreProcessor
preprocessed_dataset = PreProcessor(dataset, dataset_dir, cfg['transforms'])

Transform parameters are the same, using existing data_dir: /home/levtel/projects/TopoBench/datasets/graph/cornel/US-county-demos/graph2hypergraph_lifting/3613529153


In [8]:
preprocessed_dataset[0]

Data(x=[3224, 6], edge_index=[2, 18966], y=[3224], incidence_hyperedges=[3224, 3224], num_hyperedges=3224, x_0=[3224, 6], x_hyperedges=[3224, 6])

The dataset is now integrated. Let's train a model on it:

CLI command: `python -m topobench model=hypergraph/unignn2 dataset=graph/US-county-demos`



In [6]:

import os
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig
from topobench.dataloader import TBDataloader
from topobench.utils import instantiate_callbacks

# It's good practice to clear Hydra's state in a notebook 
# in case you run the cell multiple times.
GlobalHydra.instance().clear()


def run(cfg: DictConfig) -> None:
    """Run pipeline with given configuration."""
    # Instantiate and load dataset
    dataset_loader = hydra.utils.instantiate(cfg.dataset.loader)
    dataset, dataset_dir = dataset_loader.load()

    # Preprocess dataset and load the splits
    transform_config = cfg.get("transforms", None)
    preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
    dataset_train, dataset_val, dataset_test = (
        preprocessor.load_dataset_splits(cfg.dataset.split_params)
    )
    
    # Prepare datamodule
    if cfg.dataset.parameters.task_level in ["node", "graph"]:
        datamodule = TBDataloader(
            dataset_train=dataset_train,
            dataset_val=dataset_val,
            dataset_test=dataset_test,
            **cfg.dataset.get("dataloader_params", {}),
        )
    else:
        raise ValueError("Invalid task_level")

    # Model for us is Network + logic: inputs backbone, readout, losses
    model = hydra.utils.instantiate(
        cfg.model,
        evaluator=cfg.evaluator,
        optimizer=cfg.optimizer,
        loss=cfg.loss,
    )
    callbacks = instantiate_callbacks(cfg.get("callbacks"))

    # Instantiate the trainer.
    # Set the output directory explicitly because we are running in a notebook
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=False,
        num_sanity_val_steps=0,
        default_root_dir="."  # Explicitly set the output directory
    )
    
    trainer.fit(
        model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
    )
    ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(
        model=model, datamodule=datamodule, ckpt_path=ckpt_path
    )


# Compose the configuration and run the pipeline
with initialize(config_path="../configs", job_name="job"):
    cfg = compose(
        config_name="run.yaml",
        overrides=[
            "model=hypergraph/unignn2",
            "dataset=graph/US-county-demos",
            "callbacks=notebook"
        ], 
        return_hydra_config=True
    )
    cfg.trainer.max_epochs=5
    run(cfg)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../configs", job_name="job"):
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A30') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/levtel/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory ./checkpoints 

Transform parameters are the same, using existing data_dir: /home/levtel/projects/TopoBench/datasets/graph/cornel/US-county-demos/graph2hypergraph_lifting/3613529153


/home/levtel/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/home/levtel/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/mae improved. New best score: 0.735
`Trainer.fit` stopped: `max_epochs=5` reached.
Restoring states from the checkpoint path at ./checkpoints/epoch=4-step=5-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loaded model weights from the checkpoint at ./checkpoints/epoch=4-step=5-v1.ckpt
/home/levtel/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]


