# 📚 Adding a Custom Dataset Tutorial

## 🎯 Tutorial Overview

This guide walks you through integrating a custom dataset into TopoBench. The process is divided into three main steps:

1.  **Dataset Creation** 🔨
    * Implement a new dataset class (inheriting from `InMemoryDataset`).
    * Define the `download` logic to get your raw data files.
    * Define the `process` logic to convert raw data into `Data` objects.

2.  **Integrate with Dataset APIs** 🔄
    * Implement a `DatasetLoader` class to make your dataset accessible to the library.
    * Place your new Python files in the correct directories for auto-registration.

3.  **Configuration Setup** ⚙️
    * Create a `.yaml` config file for your dataset.
    * Define all dataset parameters, such as task type and feature dimensions.
    * Configure data paths, splitting strategies, and batching settings.

## 📋 How This Tutorial is Structured

This tutorial is presented in two complementary parts:

> 💡 **1. The Main Guide (This Notebook)**
> * Provides a high-level, step-by-step walkthrough of the entire **custom dataset integration** process.
> * Explains the core concepts and demonstrates how the integrated dataset can be used with the TopoBench library, complete with runnable code.

> 📁 **2. The Reference Templates (Library Files)**
> * The guide refers to the actual `.py` and `.yaml` implementation files within the library.
> * These files are intended to serve as starting points, which can be copied and modified for a new custom dataset.

## 🛠️ Technical Framework

Adding a new dataset involves "plugging it in" to TopoBench's existing architecture. This tutorial's process connects three key components:

* **`torch_geometric.data.InMemoryDataset`**: This is the standard **PyTorch Geometric (PyG)** base class your new dataset must inherit from. We use it because our datasets are expected to fit and be processed entirely in memory.

* **TopoBench's `AbstractLoader`**: This is the TopoBench-specific class that makes the framework aware of your dataset. You will implement a new loader that inherits from this, which tells TopoBench *how* to find and instantiate your dataset class.

* **Hydra Configuration**: This is how you define your dataset's parameters (like its name, task, and paths). TopoBench uses **Hydra** to manage all experiment configurations in `.yaml` files. This system makes it simple to compose experiments and override any setting from the command line.

## 🎓 Important Notes

* To make the learning process concrete, this guide uses the existing **`US-county-demos`** dataset as a reference.
* When following the instructions to create a new dataset, all template paths and class names use the placeholder **`<dataset_name>`**. This placeholder must be replaced with the actual name of the new dataset.

**For example:**

| If the instruction shows a template like: | And the new dataset is named "MyDataset": |
| :--- | :--- |
| `configs/dataset/<dataset_name>.yaml` | The new file should be: `configs/dataset/MyDataset.yaml` |
| `class <DatasetName>Dataset:` | The new class should be: `class MyDatasetDataset:` |



# Step 1: Create a Dataset Class 🛠️

## The Goal

The first step is to create a new **Python class** that defines your dataset. This class tells TopoBench how to get, process, and save your data.

We will inherit from **`torch_geometric.data.InMemoryDataset`**. This is a powerful base class from PyTorch Geometric that does most of the heavy lifting. For example, it automatically:
* Checks if the dataset is already downloaded or processed.
* Handles the logic of `self.raw_dir` and `self.processed_dir`.

Our only job is to provide the logic for the two methods it needs.
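
To make that division of labour concrete, here is a minimal pure-Python sketch of the caching pattern. It is an illustration only, not PyG's actual code: the class `CachedDataset` and the file names `raw.csv` and `data.txt` are hypothetical stand-ins. The point is that the base class decides *when* `download()` and `process()` run, while the subclass only supplies *what* they do.

```python
import tempfile
from pathlib import Path


class CachedDataset:
    """Toy stand-in for the download/process caching done by the base class."""

    raw_file_names = ["raw.csv"]          # hypothetical raw file
    processed_file_names = ["data.txt"]   # hypothetical processed file

    def __init__(self, root: str) -> None:
        self.raw_dir = Path(root) / "raw"
        self.processed_dir = Path(root) / "processed"
        self.calls = []  # records which hooks actually ran

        # Download only if some raw file is missing.
        if not all((self.raw_dir / f).exists() for f in self.raw_file_names):
            self.raw_dir.mkdir(parents=True, exist_ok=True)
            self.download()

        # Process only if some processed file is missing.
        if not all(
            (self.processed_dir / f).exists() for f in self.processed_file_names
        ):
            self.processed_dir.mkdir(parents=True, exist_ok=True)
            self.process()

    def download(self) -> None:
        self.calls.append("download")
        (self.raw_dir / "raw.csv").write_text("1,2,3\n")

    def process(self) -> None:
        self.calls.append("process")
        raw = (self.raw_dir / "raw.csv").read_text()
        (self.processed_dir / "data.txt").write_text(raw.replace(",", " "))


root = tempfile.mkdtemp()
first = CachedDataset(root)    # nothing cached yet: both hooks run
second = CachedDataset(root)   # files already exist: both hooks are skipped
print(first.calls, second.calls)
```

On the second instantiation neither hook runs, which is why deleting the `raw/` or `processed/` folder is the usual way to force a fresh download or reprocessing.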

## Key Methods to Implement

To get our new class working, we only need to implement two core methods:

* **`download()`**
    * **What it does:** Gets your *raw* data files (like `.zip`, `.csv`, `.json`, etc.) from their source.
    * **Your Task:** Download from a URL or copy from a local directory.
    * **Output:** All raw files must be saved in the `self.raw_dir` (the `raw/` folder).

* **`process()`**
    * **What it does:** Converts the *raw* files into the final, ready-to-use PyTorch Geometric format.
    * **Your Task:** Load the raw files from `self.raw_dir`, perform all preprocessing, and build one or more `torch_geometric.data.Data` objects.
    * **Output:** The final dataset (a list of `Data` objects) must be saved as a single `.pt` file in the `self.processed_dir` (the `processed/` folder).

> 💡 **Reference Implementation**: For a complete example, see the directory `topobench/data/datasets/` and files like `us_county_demos_dataset.py`, `mantra_dataset.py`, etc.




## Deep Dive: The Download Method

> 💡 **Reference Implementation:**
> The code discussed in this section is adapted from:
> `topobench/data/datasets/us_county_demos_dataset.py`

The `download()` method's only job is to get the raw data files (like `.zip`, `.csv`, etc.) and save them in the `self.raw_dir` folder.

Let's examine the implementation from our `US-county-demos` example.

### 1. Defining URLs (Class Attributes)

First, it's a best practice to define URLs as class attributes, so they are easy to find and change. In the reference file, these are defined at the top of the class:

```python
class USCountyDemosDataset(InMemoryDataset):
    URLS = {
        'US-county-demos': '10-3W-P-1m-R_r-Z-L3S6_G1-hZk-m'
    }
    FILE_FORMAT = {
        'US-county-demos': 'zip'
    }

    def __init__(self, ...):
        ...
```

### 2. Implementation Steps & Code

With that context, the `download` method is straightforward. The process is:

1.  **Get URL & Download** 📥
    * Look up the dataset's URL and file format from the class attributes.
    * Call `download_file_from_drive` to fetch the data and save it to `self.raw_dir`.

2.  **Extract Content** 📦
    * Call `extract_zip` to unzip the downloaded file (e.g., `US-county-demos.zip`) into the `self.raw_dir`.
    * Delete the original `.zip` file to save space.

3.  **Organize Files** 📂
    * Often, unzipping a file creates an extra sub-folder (e.g., the zip extracts to `raw/US-county-demos/`).
    * This step "flattens" the directory by moving all files from that sub-folder directly into `self.raw_dir`.
    * Finally, it removes the now-empty sub-folder.
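
Steps 2 and 3 can be tried in isolation with only the standard library. The sketch below is not the reference implementation: the helper `extract_and_flatten`, the archive name `demo.zip`, and the files inside it are hypothetical, and `zipfile` stands in for the `extract_zip` helper.

```python
import os
import shutil
import tempfile
import zipfile


def extract_and_flatten(zip_path: str, raw_dir: str, subfolder: str) -> None:
    """Unzip into raw_dir, delete the archive, and lift files out of subfolder."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(raw_dir)
    os.unlink(zip_path)  # the archive is no longer needed

    source = os.path.join(raw_dir, subfolder)
    for name in os.listdir(source):
        shutil.move(os.path.join(source, name), raw_dir)
    shutil.rmtree(source)  # remove the now-empty sub-folder


# Build a small archive that unpacks to "demo/<files>" (hypothetical names).
raw_dir = tempfile.mkdtemp()
zip_path = os.path.join(raw_dir, "demo.zip")
with zipfile.ZipFile(zip_path, "w") as archive:
    archive.writestr("demo/nodes.csv", "id,feature\n0,1.5\n")
    archive.writestr("demo/edges.txt", "0 1\n")

extract_and_flatten(zip_path, raw_dir, "demo")
print(sorted(os.listdir(raw_dir)))
```

After the call, `raw_dir` contains only the two data files: the archive and the intermediate `demo/` folder are both gone.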

#### Code Implementation

```python
# Relies on module-level imports in the reference file:
# os, os.path as osp, shutil, plus the helpers
# download_file_from_drive and extract_zip.
def download(self) -> None:
    r"""Download the dataset from a URL and save it to the raw directory.

    Raises:
        FileNotFoundError: If the dataset URL is not found.
    """
    # Step 1: Download data from the source
    self.url = self.URLS[self.name]
    self.file_format = self.FILE_FORMAT[self.name]
    download_file_from_drive(
        file_link=self.url,
        path_to_save=self.raw_dir,
        dataset_name=self.name,
        file_format=self.file_format,
    )

    # Step 2: Extract zip file
    folder = self.raw_dir
    filename = f"{self.name}.{self.file_format}"
    path = osp.join(folder, filename)
    extract_zip(path, folder)
    # Delete zip file
    os.unlink(path)

    # Step 3: Organize files
    # Move files from the unzipped sub-folder up one level
    source_folder = osp.join(folder, self.name)
    for file in os.listdir(source_folder):
        shutil.move(osp.join(source_folder, file), folder)
    # Delete the now-empty sub-folder
    shutil.rmtree(source_folder)
```

## Deep Dive: The Process Method



> 💡 **Reference Implementation:**
> The code discussed in this section is the *actual* implementation from:
> `topobench/data/datasets/us_county_demos_dataset.py`

Complete `process()` method:
```python
def process(self) -> None:
    r"""Handle the data for the dataset.

    This method loads the US county demographics data, applies any
    preprocessing transformations if specified, and saves the processed
    data to the appropriate location.
    """
    # Step 1: extract the data
    data = read_us_county_demos(
        self.raw_dir, self.year, self.task_variable
    )
    data_list = [data]

    # Step 2: collate the graphs
    self.data, self.slices = self.collate(data_list)
    self._data_list = None  # Reset cache.

    # Step 3: save processed data
    fs.torch_save(
        (self._data.to_dict(), self.slices, {}, self._data.__class__),
        self.processed_paths[0],
    )
```


The `process()` method is the heart of your dataset. Its job is to load the raw files you downloaded in `self.raw_dir` and convert them into the final `Data` objects that PyTorch Geometric can use.

The method has three responsibilities:

1.  **Load Raw Data:** Read your `.csv`, `.json`, `.txt`, or other files from `self.raw_dir`.
2.  **Build `Data` Objects:** Create one or more `torch_geometric.data.Data` objects holding your node features (`x`), connectivity (`edge_index`), and targets (`y`).
3.  **Collate & Save:** Combine all `Data` objects and save them to `self.processed_dir`.

In the `US-county-demos` implementation, the complex logic of Step 1 is hidden inside a helper function called `read_us_county_demos`. This is a very common and clean way to structure your code.

### Step 1: Extract the Data (via Helper)

Instead of writing all the `pandas` and `torch` logic directly in the `process` method, the code calls a custom function `read_us_county_demos`.

* **What it does:** This helper function (which is in the same file) is responsible for loading the raw `.csv` and `.txt` files, creating the `x`, `edge_index`, and `y` tensors, and returning a single, fully-formed `Data` object.
* **Parameters:** It passes `self.raw_dir` (to find the files) and parameters like `self.year` and `self.task_variable` so it can build the *correct* version of the dataset.

```python
        # Step 1: extract the data
        data = read_us_county_demos(
            self.raw_dir, self.year, self.task_variable
        )
        data_list = [data]
```
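
To give a feel for what such a helper does, here is a small sketch that turns a node table and an edge list into `x`, `edge_index`, and `y`. It is not the real `read_us_county_demos`: the function `read_toy_graph`, the column names, and the file contents are all hypothetical, and plain Python lists stand in for tensors so the example runs without PyTorch.

```python
import csv
import io

# Hypothetical raw files; the real dataset's layout may differ.
NODES_CSV = "id,income,age,label\n0,1.0,30,0.2\n1,2.0,40,0.4\n2,3.0,50,0.6\n"
EDGES_TXT = "0 1\n1 2\n"


def read_toy_graph(nodes_file, edges_file, target: str):
    """Return (x, edge_index, y) as nested Python lists."""
    rows = list(csv.DictReader(nodes_file))
    feature_names = [c for c in rows[0] if c not in ("id", target)]
    x = [[float(r[c]) for c in feature_names] for r in rows]
    y = [float(r[target]) for r in rows]

    # COO connectivity: row 0 holds sources, row 1 holds destinations.
    # Each undirected edge is stored in both directions.
    src, dst = [], []
    for line in edges_file:
        u, v = map(int, line.split())
        src += [u, v]
        dst += [v, u]
    return x, [src, dst], y


x, edge_index, y = read_toy_graph(
    io.StringIO(NODES_CSV), io.StringIO(EDGES_TXT), target="label"
)
print(x, edge_index, y)
```

In a real helper, these lists would be converted to tensors and wrapped in a `torch_geometric.data.Data` object before being returned.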

### Step 2: Collate the Graphs

This step is standard for `InMemoryDataset`.

* **`self.collate(data_list)`**: This helper method takes the list of `Data` objects (in this case, a list with just one item) and returns two things: a single collated data object in which all graphs are stored back to back, and a `slices` dictionary recording where each graph begins and ends. This is the standard storage format PyG uses.

```python
        # Step 2: collate the graphs
        self.data, self.slices = self.collate(data_list)
```
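
As a rough mental model of collation, the pure-Python analogy below stores several "graphs" in one flat list and keeps the boundaries in a separate list. This is only an analogy under simplifying assumptions: PyG's `collate` works on tensors and keeps one set of slices per attribute, and the functions `collate` and `get` here are hypothetical.

```python
def collate(graphs):
    """Concatenate per-graph node lists and record where each graph starts and ends."""
    data, slices = [], [0]
    for nodes in graphs:
        data.extend(nodes)
        slices.append(len(data))
    return data, slices


def get(data, slices, idx):
    """Recover graph `idx` from the flat storage."""
    return data[slices[idx]:slices[idx + 1]]


graphs = [["a", "b", "c"], ["d"], ["e", "f"]]
data, slices = collate(graphs)
print(data, slices)
print(get(data, slices, 1))
```

With a single graph, as in `US-county-demos`, the slices simply mark the start and end of that one graph.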

### Step 3: Save Processed Data

This step saves the collated data to `self.processed_dir`. The `US-county-demos` dataset uses PyG's file-system utility `fs.torch_save` for this.

* **`fs.torch_save(...)`**: This PyG helper (it is not specific to TopoBench) serializes a tuple holding the data object's dictionary representation (`self._data.to_dict()`), the slices, an empty dictionary, and the data class, and writes it to the path given by `self.processed_paths[0]`.

```python
        # Step 3: save processed data
        fs.torch_save(
            (self._data.to_dict(), self.slices, {}, self._data.__class__),
            self.processed_paths[0],
        )
```


## Step 2: Integrate with Dataset APIs 🔄



### The Goal

At this point, the `Dataset` class defined in Step 1 is just an "offline" Python file. TopoBench has no way of finding or using it.

The goal of Step 2 is to "plug in" our new dataset so the framework can **discover and load it**. This requires two actions:
1.  Placing the dataset file in the correct directory for auto-registration.
2.  Creating a **Loader Class** to act as the bridge between TopoBench's configuration system and our dataset class.

---

### Part 1: Place Your Dataset File

This is the easiest step.

1.  **Move your file:** Place the `<dataset_name>.py` file (e.g., `us_county_demos_dataset.py`) created in Step 1 into the `topobench/data/datasets/` directory.

    ```text
    topobench/
    ├── data/
    │   ├── datasets/
    │   │   ├── __init__.py         # <-- This file handles auto-registration
    │   │   ├── base.py
    │   │   ├── <dataset_name>.py   # <-- Your file from Step 1 goes here
    │   │   └── ...
    │   ├── loaders/
    │   │   └── ...
    ```

2.  **That's it!** The `topobench/data/datasets/__init__.py` file is designed to **automatically discover** and register any new dataset class in this directory. There is no need to edit `__init__.py` manually.

---

### Part 2: Create a Dataset Loader

Next, we must create a **Loader**. This is a simple class that inherits from `AbstractLoader` and tells TopoBench *how* to instantiate your dataset class and pass in the correct parameters (like `year` or `task_variable`) from the config file.

1.  **Create a new loader file:** Create a new file named `<dataset_name>_loader.py` (e.g., `us_county_demos_dataset_loader.py`) and place it in the appropriate `loaders` subdirectory. For standard graphs, this is `topobench/data/loaders/graph/`.

    ```text
    topobench/
    ├── data/
    │   ├── datasets/
    │   │   └── ...
    │   ├── loaders/
    │   │   ├── __init__.py       # <-- This file also handles auto-registration
    │   │   ├── base.py           # <-- Contains the AbstractLoader
    │   │   ├── graph/
    │   │   │   ├── <dataset_name>_loader.py  # <-- Your new loader file goes here
    │   │   │   └── ...
    │   │   ├── hypergraph/
    │   │   │   └── ...
    ```

2.  **Implement the Loader Class:** The loader is mostly boilerplate. Its main job is to implement the `load_dataset` method, which simply creates an instance of your dataset class from Step 1.

    Here is the complete template from `us_county_demos_dataset_loader.py` (including the necessary imports):

    ```python
    from pathlib import Path
    from omegaconf import DictConfig

    from topobench.data.datasets import USCountyDemosDataset  # <-- Import class from Step 1
    from topobench.data.loaders.base import AbstractLoader    # <-- Import base class


    class USCountyDemosDatasetLoader(AbstractLoader):
        """Load US County Demos dataset with configurable year and task variable.

        Parameters
        ----------
        parameters : DictConfig
            Configuration parameters containing:
                - data_dir: Root directory for data
                - data_name: Name of the dataset
                - year: Year of the dataset (if applicable)
                - task_variable: Task variable for the dataset
        """

        def __init__(self, parameters: DictConfig) -> None:
            super().__init__(parameters)

        def load_dataset(self) -> USCountyDemosDataset:
            """Load the US County Demos dataset.

            This is the main method called by TopoBench.
            It initializes the dataset and returns it.
            """
            dataset = self._initialize_dataset()
            self.data_dir = self._redefine_data_dir(dataset)
            return dataset

        def _initialize_dataset(self) -> USCountyDemosDataset:
            """Helper method to instantiate the dataset class."""
            
            # This is the key line: it creates an instance
            # of the class from Step 1, passing in parameters
            # from the config file (accessed via self.parameters).
            return USCountyDemosDataset(
                root=str(self.root_data_dir),
                name=self.parameters.data_name,
                parameters=self.parameters,
            )

        def _redefine_data_dir(self, dataset: USCountyDemosDataset) -> Path:
            """Helper method to get the final processed data path."""
            return dataset.processed_root
    ```
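
The relationship between the base class and your loader can be sketched with a toy version. Everything here is a simplified assumption rather than TopoBench's code: `ToyAbstractLoader` and `ToyDatasetLoader` are hypothetical, and the only behaviour taken from this tutorial is that calling `load()` on a loader returns a `(dataset, data_dir)` pair.

```python
from abc import ABC, abstractmethod


class ToyAbstractLoader(ABC):
    """Simplified stand-in for a loader base class (not TopoBench's code)."""

    def __init__(self, parameters: dict) -> None:
        self.parameters = parameters
        self.data_dir = parameters["data_dir"]

    @abstractmethod
    def load_dataset(self):
        """Subclasses build and return the dataset here."""

    def load(self):
        # Assumed behaviour: delegate to the subclass, then report
        # where the data lives.
        dataset = self.load_dataset()
        return dataset, self.data_dir


class ToyDatasetLoader(ToyAbstractLoader):
    def load_dataset(self):
        dataset = list(range(self.parameters["size"]))  # placeholder "dataset"
        # Mirror the template: point data_dir at the dataset-specific folder.
        self.data_dir = f"{self.data_dir}/{self.parameters['data_name']}"
        return dataset


loader = ToyDatasetLoader({"data_dir": "/tmp/data", "data_name": "toy", "size": 3})
dataset, dataset_dir = loader.load()
print(dataset, dataset_dir)
```

The subclass never overrides `load()`; it only fills in `load_dataset()`, which is the same split of responsibilities the template above follows.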

---

### ❗ Key Integration Rules (Auto-Registration)

For the TopoBench framework to find your new loader class automatically, you **must** follow these rules:

* **Inheritance:** Your loader class (e.g., `USCountyDemosDatasetLoader`) must inherit from `AbstractLoader`.
* **Method:** Your loader class must implement the `load_dataset` method, which must return a `torch.utils.data.Dataset` object (our `InMemoryDataset` class qualifies).
* **Naming Convention:** This is the most important part. The class name **must** end with the suffix **`DatasetLoader`**.
    * ✅ `MyCoolDatasetLoader`
    * ✅ `USCountyDemosDatasetLoader`
    * ❌ `MyCoolDataset`
    * ❌ `MyCoolLoader`
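
To see why the suffix matters, consider how suffix-based discovery can work in general. The sketch below is an assumption about the mechanism, not TopoBench's actual `__init__.py`: the function `discover_loaders` and the file `my_cool_loader.py` are hypothetical, and only the naming rule is checked (not inheritance).

```python
import importlib.util
import inspect
import tempfile
from pathlib import Path


def discover_loaders(directory: Path, suffix: str = "DatasetLoader") -> dict:
    """Import every .py file in `directory` and collect classes named *suffix."""
    found = {}
    for path in sorted(directory.glob("*.py")):
        if path.stem == "__init__":
            continue
        spec = importlib.util.spec_from_file_location(path.stem, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        for name, obj in inspect.getmembers(module, inspect.isclass):
            # Keep only classes defined in this file with the expected suffix.
            if obj.__module__ == module.__name__ and name.endswith(suffix):
                found[name] = obj
    return found


# Hypothetical loader file with one correctly named and one misnamed class.
pkg = Path(tempfile.mkdtemp())
(pkg / "my_cool_loader.py").write_text(
    "class MyCoolDatasetLoader:\n    pass\n\n"
    "class MyCoolLoader:\n    pass\n"
)

registry = discover_loaders(pkg)
print(sorted(registry))
```

Only `MyCoolDatasetLoader` ends up in the registry; the misnamed `MyCoolLoader` is silently ignored, which is the failure mode the naming rule is meant to prevent.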

## Step 3: Define Configuration ⚙️

### The Goal

The goal of this step is to create a `.yaml` configuration file. This file is the "control panel" for your dataset. It tells TopoBench (and the **Hydra** framework) exactly:
1.  **What to load:** Using the `_target_` key to point to your `Loader` class from Step 2.
2.  **How to load it:** By passing parameters (like `year` or `task_variable`) to your loader.
3.  **What its properties are:** By defining parameters like `num_features`, `task`, and `task_level`.

### File Location

First, create your new configuration file in the appropriate sub-directory. The path is critical for TopoBench to find it when you reference it (e.g., `dataset=graph/US-county-demos`).

**Create the file at:**
`configs/dataset/<data_domain>/<dataset_name>.yaml`

**For our example:**
`configs/dataset/graph/US-county-demos.yaml`

---

## Configuration File Structure

Your new `.yaml` file must define four main sections:

### 1. Loader Configuration

This section tells Hydra which `Loader` class to instantiate.

* `_target_`: This is the most important key. It's a special **Hydra** command that tells the framework to import and create an instance of this *exact* Python class (the one created in Step 2).
* `parameters`: This nested dictionary contains all the arguments that will be passed to your `Loader`'s `__init__` method.
* `data_dir: ${...}`: Note the `${...}` syntax. This is *variable interpolation*, a feature of OmegaConf (the configuration library Hydra is built on). It reuses values from other parts of the configuration (in this case, the global `paths.data_dir`).

```yaml
# Dataset loader config
loader:
  _target_: topobench.data.loaders.graph.us_county_demos_dataset_loader.USCountyDemosDatasetLoader
  parameters: 
    data_domain: graph             # Primary data domain. Options: ['graph', 'hypergraph', 'cell', 'simplicial']
    data_type: cornel              # Data type: a string indicating where the dataset comes from
    data_name: US-county-demos     # Name of the dataset
    year: 2012                     # US-county-demos comes in several versions. Options: [2012, 2016]
    task_variable: 'Election'      # Variable used as the prediction target. Options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
```
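
The effect of `_target_` can be imitated in a few lines. This is a deliberately simplified sketch, not Hydra's implementation (the real `hydra.utils.instantiate` also handles nested configs and more): the function `instantiate_from_target` is hypothetical, and a standard-library class stands in for the loader so the example runs anywhere.

```python
import importlib


def instantiate_from_target(config: dict):
    """Import the class named by `_target_` and call it with the remaining keys."""
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    module_path, _, class_name = config["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


# `fractions.Fraction` plays the role of the loader class here.
config = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = instantiate_from_target(config)
print(obj)
```

In the loader config above, the only remaining key is `parameters`, so the loader's `__init__` receives the whole nested block as its single `parameters` argument.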

### 2. Dataset Parameters

This section defines the objective properties of your dataset. These values are used by other parts of the framework (like models and trainers) to configure themselves correctly.

```yaml
# Dataset parameters
parameters:
  num_features: 6         # Number of features in the dataset
  num_classes: 1          # Dimension of the target variable
  task: regression        # Dataset task. Options: [classification, regression]
  loss_type: mse          # Task-specific loss function
  monitor_metric: mae     # Metric to monitor during training
  task_level: node        # Task level. Options: [node, edge, graph]
```
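
Because these values are typed by hand, a quick sanity check can catch typos before a run fails later on. The helper below is a hypothetical convenience, not part of TopoBench; its allowed values are taken from the option lists in the comments above.

```python
ALLOWED = {
    "task": {"classification", "regression"},
    "task_level": {"node", "edge", "graph"},
}


def check_dataset_parameters(params: dict) -> list:
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    for key, options in ALLOWED.items():
        if params.get(key) not in options:
            problems.append(f"{key}={params.get(key)!r} not in {sorted(options)}")
    for key in ("num_features", "num_classes"):
        if not isinstance(params.get(key), int) or params[key] < 1:
            problems.append(f"{key} must be a positive integer")
    return problems


params = {
    "num_features": 6, "num_classes": 1, "task": "regression",
    "loss_type": "mse", "monitor_metric": "mae", "task_level": "node",
}
ok = check_dataset_parameters(params)
bad = check_dataset_parameters({**params, "task_level": "nodes"})
print(ok)
print(bad)
```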

### 3. Split Parameters

This section controls all aspects of data splitting for training, validation, and testing.

```yaml
# Splits
split_params:
  learning_setting: transductive      # Type of learning. Options:['transductive', 'inductive']
  data_seed: 0                        # Seed for data splitting
  split_type: random                  # Type of splitting. Options: ['k-fold', 'random']
  k: 10                               # Number of folds in case of "k-fold" cross-validation
  train_prop: 0.5                     # Training proportion in case of 'random' splitting strategy
  standardize: True                   # Standardize the data or not. Options: [True, False]
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
```
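
To illustrate how `data_seed` and `train_prop` interact in a `random` split, here is a minimal sketch. It is not TopoBench's splitting code: the function `random_split` is hypothetical, and dividing the remainder evenly between validation and test is an assumption made for the example.

```python
import random


def random_split(num_items: int, train_prop: float, seed: int):
    """Shuffle indices with a fixed seed and cut them into train/val/test."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    n_train = int(train_prop * num_items)
    n_val = (num_items - n_train) // 2  # assumption: remainder split evenly
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


# 3224 is the number of nodes in the US-county-demos graph.
train, val, test = random_split(num_items=3224, train_prop=0.5, seed=0)
print(len(train), len(val), len(test))
```

In the transductive setting the indices refer to nodes of the single graph, and fixing the seed makes the split reproducible across runs.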

### 4. Dataloader Parameters

Finally, these parameters are passed directly to the PyTorch `DataLoader` wrapper that handles batching.

```yaml
# Dataloader parameters
dataloader_params:
  batch_size: 1       # Number of graphs per batch. Always 1 in the transductive setting, since there is only one graph.
  num_workers: 0      # Number of workers for data loading
  pin_memory: False   # Pin memory for data loading
```

---

### Notes
* The `paths` object (used in `${paths.data_dir}`) is automatically resolved by Hydra from a global configuration file.
* Include all of these parameters, even if they seem redundant. Other configuration files (those for models, for example) read these values dynamically (e.g., `model.in_channels=${dataset.parameters.num_features}`).
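
The cross-referencing described in these notes can be illustrated with a tiny interpolation routine. This is a toy model of the idea, not OmegaConf's implementation: the functions `lookup` and `resolve` are hypothetical, and the `/data` path is a made-up value.

```python
import re

PATTERN = re.compile(r"\$\{([^${}]+)\}")


def lookup(config: dict, dotted_key: str):
    """Follow a dotted path such as 'dataset.parameters.num_features'."""
    node = config
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(value, config: dict):
    """Replace each ${a.b.c} reference in `value` with the value it points to."""
    if not isinstance(value, str):
        return value
    whole = PATTERN.fullmatch(value)
    if whole:  # a lone reference keeps the referenced value's type
        return resolve(lookup(config, whole.group(1)), config)
    return PATTERN.sub(
        lambda m: str(resolve(lookup(config, m.group(1)), config)), value
    )


config = {
    "paths": {"data_dir": "/data"},
    "dataset": {
        "loader": {"parameters": {"data_name": "US-county-demos"}},
        "parameters": {"num_features": 6},
        "split_params": {
            "data_split_dir": "${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}"
        },
    },
    "model": {"in_channels": "${dataset.parameters.num_features}"},
}

in_channels = resolve(config["model"]["in_channels"], config)
split_dir = resolve(config["dataset"]["split_params"]["data_split_dir"], config)
print(in_channels, split_dir)
```

If `num_features` were missing from the dataset config, the lookup for `model.in_channels` would fail, which is why every parameter needs to be present.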

## 🎉 Congratulations! The Dataset is Integrated!

That's it! By completing these three steps, the new dataset is now fully integrated into the TopoBench framework.

### Summary of What Was Accomplished

1.  **Dataset Class:** A new `.py` file was created (e.g., `us_county_demos_dataset.py`) that inherits from `InMemoryDataset` and implements the `download` and `process` logic.
2.  **Loader Class:** A second `.py` file was created (e.g., `us_county_demos_dataset_loader.py`) that inherits from `AbstractLoader` to "plug" the dataset into the framework.
3.  **Config File:** A new `.yaml` file (e.g., `US-county-demos.yaml`) was created to define all the dataset's parameters and tell Hydra how to load it.

### Next Steps: Run It!

Thanks to this setup, the dataset is now available to the entire framework. It can be loaded, preprocessed, and used in any experiment simply by referencing its name.

The following part of this notebook will demonstrate how to load the newly integrated dataset and even run a full training and evaluation pipeline.


## Step 4: Run It! (Understanding Config Resolvers)

We are now ready to load and use our new dataset.

Before we run the code, we must understand one advanced feature of Hydra that TopoBench's configuration system relies on: **Resolvers**.

### What is a Resolver?

A Resolver is a Python function that can be called *from within a `.yaml` file*. This allows for powerful, dynamic configuration.

In Step 3, we saw simple **variable substitution**:
```yaml
# This just copies a value
data_dir: ${paths.data_dir}/${...}
```

But TopoBench also uses function resolvers for automated logic:

```yaml
# This calls a Python function
transforms: ${get_default_transform:${dataset},${model}}
```

When Hydra sees this line, it will **call the `get_default_transform` function**, passing it the `dataset` and `model` configs. This function then *dynamically* figures out the correct data transformations, which is a key part of TopoBench's automation.
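
The mechanics of a function resolver can be sketched with a toy registry. In OmegaConf, the library underneath Hydra, resolvers are registered with `OmegaConf.register_new_resolver`; the code below only mimics that idea in plain Python. The names `register_resolver`, `resolve_call`, and `pick_lifting` are hypothetical, and the selection logic is invented for illustration rather than taken from `get_default_transform`.

```python
import re

RESOLVERS = {}


def register_resolver(name: str, func) -> None:
    """Make `func` callable from config strings as ${name:arg1,arg2}."""
    RESOLVERS[name] = func


def resolve_call(value: str):
    """Evaluate a single ${name:args} expression using the registry."""
    match = re.fullmatch(r"\$\{(\w+):([^{}]*)\}", value)
    if match is None:
        return value  # not a resolver call; leave untouched
    name, raw_args = match.groups()
    if name not in RESOLVERS:
        raise KeyError(f"resolver {name!r} is not registered")
    args = [a.strip() for a in raw_args.split(",")]
    return RESOLVERS[name](*args)


# Hypothetical resolver: choose a lifting from the two data domains.
def pick_lifting(dataset_domain: str, model_domain: str) -> str:
    if dataset_domain == model_domain:
        return "no_lifting"
    return f"{dataset_domain}2{model_domain}_lifting"


register_resolver("pick_lifting", pick_lifting)
result = resolve_call("${pick_lifting:graph,hypergraph}")
print(result)
```

Calling `resolve_call` with a name that was never registered raises an error, which is the toy equivalent of what happens when a config references a resolver that has not been made available.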

### Why This Matters in a Notebook

For Hydra to find and use functions like `get_default_transform`, they must be **registered** first.

In a normal command-line run, TopoBench handles this automatically. But when running in a notebook, **we must register them manually**. We do this by simply **importing them** in a code cell *before* we initialize Hydra.

The next code cells will show the full list of imports from `topobench.utils.config_resolvers` needed for a complete run. This "registration" step unlocks the framework's full power:

* **Automatic Parameter Inference** (e.g., inferring model input channels from data)
* **Dynamic Configuration** (e.g., selecting the right transforms for a given model/dataset)
* **Reduced Manual Setup** (less boilerplate to write in config files)

Now, let's see the complete code in action.

```python
import hydra
from hydra import compose, initialize
from hydra.utils import instantiate

# Imported so that the resolvers used in the configs are available.
from topobench.utils.config_resolvers import (
    get_default_metrics,
    get_default_trainer,
    get_default_transform,
    get_flattened_channels,
    get_monitor_metric,
    get_monitor_mode,
    get_non_relational_out_channels,
    get_required_lifting,
    infer_in_channels,
    infer_num_cell_dimensions,
)

with initialize(config_path="../configs", job_name="job"):
    cfg = compose(
        config_name="run.yaml",
        overrides=[
            "model=hypergraph/unignn2",
            "dataset=graph/US-county-demos",
        ],
        return_hydra_config=True,
    )

loader = instantiate(cfg.dataset.loader)
dataset, dataset_dir = loader.load()
```


```python
print(dataset)
```

```text
US-county-demos(self.root=/Users/leone/Desktop/PhD/projects/projects/TopoBench/datasets/graph/cornel, self.name=US-county-demos, self.parameters={'data_domain': 'graph', 'data_type': 'cornel', 'data_name': 'US-county-demos', 'year': 2012, 'task_variable': 'Election', 'data_dir': '/Users/leone/Desktop/PhD/projects/projects/TopoBench/datasets/graph/cornel'}, self.force_reload=False)
```

```python
print(dataset[0])
```

```text
Data(x=[3224, 6], edge_index=[2, 18966], y=[3224])
```

Take a look at the default transforms:

```python
print('Transform name:', cfg.transforms.keys())
```

```text
Transform name: dict_keys(['graph2hypergraph_lifting'])
```

```python
from topobench.data.preprocessor import PreProcessor

preprocessed_dataset = PreProcessor(dataset, dataset_dir, cfg['transforms'])
```

```text
Processing...
Done!
```

```python
preprocessed_dataset[0]
```

```text
Data(x=[3224, 6], edge_index=[2, 18966], y=[3224], incidence_hyperedges=[3224, 3224], num_hyperedges=3224, x_0=[3224, 6], x_hyperedges=[3224, 6])
```

The dataset is integrated, so let's train a model on it.

CLI command: `python -m topobench model=hypergraph/unignn2 dataset=graph/US-county-demos`



```python
# Requires `hydra`, `compose`, `initialize`, and `PreProcessor`
# from the earlier cells.
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig
from topobench.dataloader import TBDataloader
from topobench.utils import instantiate_callbacks

# It's good practice to clear Hydra's state in a notebook
# in case you run the cell multiple times.
GlobalHydra.instance().clear()


def run(cfg: DictConfig) -> None:
    """Run the training and evaluation pipeline with the given configuration."""
    # Instantiate and load dataset
    dataset_loader = hydra.utils.instantiate(cfg.dataset.loader)
    dataset, dataset_dir = dataset_loader.load()

    # Preprocess dataset and load the splits
    transform_config = cfg.get("transforms", None)
    preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
    dataset_train, dataset_val, dataset_test = (
        preprocessor.load_dataset_splits(cfg.dataset.split_params)
    )

    # Prepare datamodule
    if cfg.dataset.parameters.task_level in ["node", "graph"]:
        datamodule = TBDataloader(
            dataset_train=dataset_train,
            dataset_val=dataset_val,
            dataset_test=dataset_test,
            **cfg.dataset.get("dataloader_params", {}),
        )
    else:
        raise ValueError("Invalid task_level")

    # Model for us is Network + logic: inputs backbone, readout, losses
    model = hydra.utils.instantiate(
        cfg.model,
        evaluator=cfg.evaluator,
        optimizer=cfg.optimizer,
        loss=cfg.loss,
    )
    callbacks = instantiate_callbacks(cfg.get("callbacks"))

    # Set the Trainer's output directory explicitly, since we are in a notebook.
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=False,
        num_sanity_val_steps=0,
        default_root_dir=".",
    )

    trainer.fit(
        model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
    )
    # Evaluate the best checkpoint on the test split.
    ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.test(
        model=model, datamodule=datamodule, ckpt_path=ckpt_path
    )


with initialize(config_path="../configs", job_name="job"):
    cfg = compose(
        config_name="run.yaml",
        overrides=[
            "model=hypergraph/unignn2",
            "dataset=graph/US-county-demos",
            "callbacks=notebook",
        ],
        return_hydra_config=True,
    )
    cfg.trainer.max_epochs = 5
    run(cfg)
```

```text
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Transform parameters are the same, using existing data_dir: /Users/leone/Desktop/PhD/projects/projects/TopoBench/datasets/graph/cornel/US-county-demos/graph2hypergraph_lifting/304036748
Metric val/mae improved. New best score: 0.649
`Trainer.fit` stopped: `max_epochs=5` reached.
Restoring states from the checkpoint path at ./checkpoints/epoch=4-step=5.ckpt
Loaded model weights from the checkpoint at ./checkpoints/epoch=4-step=5.ckpt
```


