In [None]:
import rootutils

# Locate the project root by searching from "./" for a ".project-root" marker
# file. pythonpath=True asks rootutils to add that root to sys.path so that
# project-local packages (e.g. topobenchmarkx) can be imported from this notebook.
rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

%load_ext autoreload
%autoreload 2

# Define download functions

In [None]:
from urllib.parse import parse_qs, urlparse

import requests


# Function to extract file ID from Google Drive URL
def get_file_id_from_url(url):
    """
    Extracts the file ID from a Google Drive file URL.

    Two URL layouts are recognised:

    * query-string form, e.g. ``https://drive.google.com/uc?id=<ID>``;
    * path form, e.g. ``https://drive.google.com/file/d/<ID>/view``. The ID is
      the path segment that directly follows ``file/d/``, wherever that marker
      sits in the path (so ``/u/0/file/d/<ID>/view`` is handled as well).

    Args:
        url (str): The Google Drive file URL.

    Returns:
        str: The file ID extracted from the URL.

    Raises:
        ValueError: If no non-empty file ID can be found in the URL.
    """
    parsed_url = urlparse(url)
    query_params = parse_qs(parsed_url.query)

    file_id = ""
    if "id" in query_params:  # Case 1: URL format contains '?id='
        file_id = query_params["id"][0]
    else:  # Case 2: URL format contains '/file/d/<ID>'
        # Anchor on the marker itself rather than on a fixed segment index, so
        # that prefixes such as '/u/0' do not shift the extracted segment.
        _, marker, remainder = parsed_url.path.partition("file/d/")
        if marker:
            file_id = remainder.split("/", 1)[0]

    # An empty ID (e.g. a path ending in 'file/d/') is as invalid as no ID.
    if not file_id:
        raise ValueError("The provided URL is not a valid Google Drive file URL.")
    return file_id


# Function to download file from Google Drive
def download_file_from_drive(
    file_link, path_to_save, dataset_name, file_format="tar.gz"
):
    """
    Downloads a file from a Google Drive link and saves it to the specified path.

    The file is streamed to ``{path_to_save}/{dataset_name}.{file_format}.part``
    and renamed to ``{path_to_save}/{dataset_name}.{file_format}`` only after the
    whole body has been written, so a truncated archive never appears under the
    final name.

    Args:
        file_link (str): The Google Drive link of the file to download.
        path_to_save (str): Existing directory in which the file is stored.
        dataset_name (str): The name of the dataset (used as the file stem).
        file_format (str, optional): The format of the downloaded file. Defaults to "tar.gz".

    Returns:
        None. A non-200 response is reported with a printed message and no file
        is written.

    Raises:
        ValueError: If ``file_link`` is not a recognised Google Drive URL.
        requests.exceptions.RequestException: On connection errors or timeouts.
        OSError: If the output file cannot be written.
    """
    import os  # local import: only needed for the atomic rename below

    file_id = get_file_id_from_url(file_link)

    download_link = f"https://drive.google.com/uc?id={file_id}"
    output_path = f"{path_to_save}/{dataset_name}.{file_format}"
    partial_path = f"{output_path}.part"

    # NOTE(review): for large files Google Drive may answer 200 with an HTML
    # confirmation page instead of the file itself; that case is not detected
    # here -- confirm against the files actually being downloaded.
    # timeout: without it a stalled connection blocks forever.
    # stream=True: avoid holding the whole archive in memory.
    with requests.get(download_link, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print("Failed to download the file.")
            return

        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    os.replace(partial_path, output_path)
    print("Download complete.")

# Define the data load function 


In [None]:
import numpy as np
import pandas as pd
import torch
import torch_geometric


def load_us_county_demos(path, year=2012):
    """
    Load the US county demographics graph as a ``torch_geometric`` ``Data`` object.

    Reads ``{path}/county_graph.csv`` (edge list with ``SRC``/``DST`` columns
    holding FIPS codes) and ``{path}/county_stats_{year}.csv`` (per-county
    statistics), keeps only counties with complete statistics and at least one
    edge to another kept county, and relabels the nodes consecutively.

    Args:
        path (str): Directory containing the two CSV files.
        year (int, optional): Year suffix of the statistics file. Defaults to 2012.

    Returns:
        torch_geometric.data.Data: Graph with

        * ``x``: NumPy array of the feature columns, in the fixed order
          MedianIncome, MigraRate, BirthRate, DeathRate, BachelorRate,
          UnemploymentRate;
        * ``y``: NumPy array with the ``Election`` target,
          ``(DEM - GOP) / (DEM + GOP)``;
        * ``edge_index``: undirected edges without self-loops (torch tensor).
    """
    edges_df = pd.read_csv(f"{path}/county_graph.csv")
    stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1")

    keep_cols = [
        "FIPS",
        "DEM",
        "GOP",
        "MedianIncome",
        "MigraRate",
        "BirthRate",
        "DeathRate",
        "BachelorRate",
        "UnemploymentRate",
    ]
    # Drop rows with missing values
    stat = stat[keep_cols].dropna()

    # Delete edges whose endpoints are not present in stat df
    known_fips = stat["FIPS"].unique()
    edges_df = edges_df[
        edges_df["SRC"].isin(known_fips) & edges_df["DST"].isin(known_fips)
    ]

    # Keep every county that is an endpoint of at least one remaining edge.
    # Using "or" (not "and") guarantees that each edge endpoint still has a row
    # in stat, so the FIPS -> index mapping below never yields NaN.
    has_edge = stat["FIPS"].isin(edges_df["SRC"]) | stat["FIPS"].isin(edges_df["DST"])
    stat = stat[has_edge].reset_index(drop=True)

    # Remove rows where SRC == DST
    edges_df = edges_df[edges_df["SRC"] != edges_df["DST"]]

    # Get torch_geometric edge_index format
    edge_index = torch.tensor(
        np.stack([edges_df["SRC"].to_numpy(), edges_df["DST"].to_numpy()])
    )

    # Make edge_index undirected
    edge_index = torch_geometric.utils.to_undirected(edge_index)

    # Convert edge_index back to pandas DataFrame
    edges_df = pd.DataFrame(edge_index.numpy().T, columns=["SRC", "DST"])

    del edge_index

    # Map FIPS codes to row positions [0, ..., num_nodes) of stat
    fips_map = {fips: i for i, fips in enumerate(stat["FIPS"].unique())}
    edges_df["SRC"] = edges_df["SRC"].map(fips_map)
    edges_df["DST"] = edges_df["DST"].map(fips_map)

    # Get torch_geometric edge_index format
    edge_index = torch.tensor(
        np.stack([edges_df["SRC"].to_numpy(), edges_df["DST"].to_numpy()])
    )

    # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically)
    edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index)

    # Convert mask to positional index and keep only the surviving rows, so
    # that row i of stat corresponds to node i of the relabelled edge_index
    index = np.arange(mask.size(0))[mask]
    stat = stat.iloc[index].reset_index(drop=True)

    # Create Election variable
    stat["Election"] = (stat["DEM"] - stat["GOP"]) / (stat["DEM"] + stat["GOP"])

    # Drop DEM and GOP columns and FIPS
    stat = stat.drop(columns=["DEM", "GOP", "FIPS"])

    # Prediction col
    y_col = "Election"  # TODO: Define through config file
    # Preserve the DataFrame column order: building this list from a set made
    # the feature order vary from one interpreter run to the next.
    x_col = [col for col in stat.columns if col != y_col]

    # MedianIncome may be stored as text with thousands separators ("45,000");
    # astype(str) keeps this working when pandas already parsed it as numeric.
    stat["MedianIncome"] = (
        stat["MedianIncome"]
        .astype(str)
        .str.replace(",", "", regex=False)
        .astype(float)
    )

    # NOTE(review): x and y are handed to Data as NumPy arrays, not tensors --
    # confirm that downstream consumers expect this.
    x = stat[x_col].to_numpy()
    y = stat[y_col].to_numpy()

    data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)

    return data

# Define the dataset class

The dataset class inherits from `InMemoryDataset` (torch_geometric).

Next, it is essential to override three methods: `__init__`, `download`, and `process`,

as well as a number of properties: `raw_dir`, `processed_dir`, `raw_file_names`, and `processed_file_names`.


In [None]:
import os.path as osp
from collections.abc import Callable

from omegaconf import DictConfig
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs

from topobenchmarkx.io.load.cornel_dataset import load_us_county_demos
from topobenchmarkx.io.load.download_utils import download_file_from_drive
from topobenchmarkx.io.load.split_utils import random_splitting


class USCountyDemosDataset(InMemoryDataset):
    r"""
    Dataset class for US County Demographics dataset.

    Args:
        root (str): Root directory where the dataset will be saved.
        name (str): Name of the dataset. Underscores are replaced by hyphens
            before the name is used as a key of ``URLS`` / ``FILE_FORMAT`` and
            as a directory name.
        parameters (DictConfig): Configuration parameters for the dataset.
            ``parameters.year`` selects the statistics file; the whole object
            is also passed to ``random_splitting``.
        transform (Optional[Callable]): A function/transform that takes in an
            `torch_geometric.data.Data` object and returns a transformed version.
            Forwarded unchanged to ``InMemoryDataset``.
        pre_transform (Optional[Callable]): A function/transform that takes in an
            `torch_geometric.data.Data` object and returns a transformed version.
            Applied in ``process`` before the data is saved to disk.
        pre_filter (Optional[Callable]): A function that takes in an
            `torch_geometric.data.Data` object and returns a boolean value.
            Forwarded to ``InMemoryDataset``; ``process`` does not apply it.
        force_reload (bool): Forwarded to ``InMemoryDataset``. (default: True)

    Attributes:
        URLS (dict): Dictionary containing the URLs for downloading the dataset.
        FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
        RAW_FILE_NAMES (dict): Currently empty and not read by this class.

    """

    URLS = {
        "US-county-demos": "https://drive.google.com/file/d/1FNF_LbByhYNICPNdT6tMaJI9FxuSvvLK/view?usp=sharing",
    }

    FILE_FORMAT = {
        "US-county-demos": "zip",
    }

    RAW_FILE_NAMES = {}

    def __init__(
        self,
        root: str,
        name: str,
        parameters: DictConfig,
        transform: Callable | None = None,
        pre_transform: Callable | None = None,
        pre_filter: Callable | None = None,
        force_reload: bool = True,
    ) -> None:
        # These attributes must exist before super().__init__ runs, because the
        # raw_dir / processed_dir / raw_file_names properties, download() and
        # process() all read them.
        self.name = name.replace("_", "-")
        self.parameters = parameters

        # Static, do not modify
        # --------------------------------------------------------
        super().__init__(
            root, transform, pre_transform, pre_filter, force_reload=force_reload
        )

        # Logic that should be modified while adding new dataset:
        # --------------------------------------------------------
        # Step 3: Load the processed data
        # After the data has been downloaded from source
        # Then preprocessed to obtain x,y and saved into processed folder
        # We can now load the processed data from processed folder

        # Load the processed data
        data, _, _ = fs.torch_load(self.processed_paths[0])

        # Rebuild a Data object from the loaded dictionary
        data = Data.from_dict(data)

        # Step 4: Create the splits and upload desired fold
        splits = random_splitting(data.y, parameters=self.parameters)
        # Assign train val test masks to the graph
        data.train_mask = torch.from_numpy(splits["train"])
        data.val_mask = torch.from_numpy(splits["valid"])
        data.test_mask = torch.from_numpy(splits["test"])

        # Assign data object to self.data, to make it be processed by Dataset class
        self.data = data

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw files: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed file: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, "processed")

    @property
    def raw_file_names(self) -> list[str]:
        """Raw files that must exist in ``raw_dir`` for the download to be skipped.

        These are exactly the two CSV files that ``load_us_county_demos`` reads
        in ``process``. Listing names that never exist would force a fresh
        download on every instantiation.
        """
        return ["county_graph.csv", f"county_stats_{self.parameters.year}.csv"]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file inside ``processed_dir``."""
        return "data.pt"

    def download(self) -> None:
        """
        Downloads the dataset archive from Google Drive, extracts it into the
        raw directory and removes the archive.

        Raises:
            KeyError: If ``self.name`` has no entry in ``URLS`` or ``FILE_FORMAT``.
        """

        # Step 1: Download data from the source
        self.url = self.URLS[self.name]
        self.file_format = self.FILE_FORMAT[self.name]

        download_file_from_drive(
            file_link=self.url,
            path_to_save=self.raw_dir,
            dataset_name=self.name,
            file_format=self.file_format,
        )

        # Extract the downloaded file if it is compressed
        fs.cp(
            f"{self.raw_dir}/{self.name}.{self.file_format}", self.raw_dir, extract=True
        )

        # Move the extracted files to the datasets/domain/dataset_name/raw/ directory
        # (the archive is expected to unpack into a folder named after the dataset)
        for filename in fs.ls(osp.join(self.raw_dir, self.name)):
            fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))
        fs.rm(osp.join(self.raw_dir, self.name))

        # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}'
        fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}")

    def process(self) -> None:
        """
        Process the data for the dataset.

        This method loads the US county demographics data for
        ``self.parameters.year``, applies ``pre_transform`` if one was given,
        and saves the result to ``self.processed_paths[0]``.

        Returns:
            None
        """
        data = load_us_county_demos(self.raw_dir, year=self.parameters.year)

        data = data if self.pre_transform is None else self.pre_transform(data)
        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f"{self.name}({len(self)})"

### Heterophilic datasets

In [None]:
import os
import urllib.request


def hetero_load(name, path="./data/hetero_data"):
    """
    Load a heterophilic graph dataset stored as ``{path}/{name}.npz``.

    The archive must provide the arrays ``node_features``, ``node_labels`` and
    ``edges``. The edge array is transposed to obtain ``edge_index`` (presumably
    it is stored as ``(num_edges, 2)`` -- TODO confirm), then the graph is made
    undirected and self-loops are removed.

    Args:
        name (str): Dataset name, i.e. the file stem of the ``.npz`` archive.
        path (str, optional): Directory containing the archive.
            Defaults to "./data/hetero_data".

    Returns:
        torch_geometric.data.Data: Graph with ``x``, ``y`` and ``edge_index``.
    """
    file_name = f"{name}.npz"

    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been copied into tensors.
    with np.load(os.path.join(path, file_name)) as npz:
        x = torch.tensor(npz["node_features"])
        y = torch.tensor(npz["node_labels"])
        edge_index = torch.tensor(npz["edges"]).T

    # Make edge_index undirected
    edge_index = torch_geometric.utils.to_undirected(edge_index)

    # Remove self-loops
    edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)

    return torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)


def download_hetero_datasets(name, path):
    """
    Download ``{name}.npz`` from the OpenGSL HeterophilousDatasets repository.

    Args:
        name (str): Dataset name without the ``.npz`` extension.
        path (str): Existing directory in which the file is stored.

    Raises:
        Exception: If the download (or building the target path) fails. The
            original error is attached as ``__cause__``.
    """
    url = "https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/"
    file_name = f"{name}.npz"
    try:
        print(f"Downloading {file_name}")
        path2save = os.path.join(path, file_name)
        urllib.request.urlretrieve(url + file_name, path2save)
        print("Done!")
    except Exception as err:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; `from err` keeps the real reason visible.
        raise Exception(
            """Download failed! Make sure you have stable Internet connection and enter the right name"""
        ) from err


from collections.abc import Callable

from omegaconf import DictConfig
from torch_geometric.data import InMemoryDataset

from topobenchmarkx.io.load.us_county_demos import load_us_county_demos


class HeteroDataset(InMemoryDataset):
    r"""
    Dataset class for heterophilic graph datasets distributed as ``.npz`` files.

    Args:
        root (str): Root directory where the dataset will be saved.
        name (str): Name of the dataset; also the stem of the downloaded
            ``{name}.npz`` file.
        parameters (DictConfig): Configuration parameters for the dataset,
            passed to ``random_splitting``.
        transform (Optional[Callable]): A function/transform that takes in an
            `torch_geometric.data.Data` object and returns a transformed version.
            Forwarded unchanged to ``InMemoryDataset``.
        pre_transform (Optional[Callable]): A function/transform that takes in an
            `torch_geometric.data.Data` object and returns a transformed version.
            Applied in ``process`` before the data is saved to disk.
        pre_filter (Optional[Callable]): A function that takes in an
            `torch_geometric.data.Data` object and returns a boolean value.
            Forwarded to ``InMemoryDataset``; ``process`` does not apply it.
        force_reload (bool): Forwarded to ``InMemoryDataset``. (default: True)
        use_node_attr (bool): Accepted for interface compatibility; currently
            unused. (default: False)
        use_edge_attr (bool): Accepted for interface compatibility; currently
            unused. (default: False)

    Attributes:
        RAW_FILE_NAMES (dict): Currently empty and not read by this class.

    """

    RAW_FILE_NAMES = {}

    def __init__(
        self,
        root: str,
        name: str,
        parameters: DictConfig,
        transform: Callable | None = None,
        pre_transform: Callable | None = None,
        pre_filter: Callable | None = None,
        force_reload: bool = True,
        use_node_attr: bool = False,
        use_edge_attr: bool = False,
    ) -> None:
        # These attributes must exist before super().__init__ runs, because the
        # directory/file-name properties, download() and process() read them.
        self.name = name  # .replace("_", "-")
        self.parameters = parameters
        super().__init__(
            root, transform, pre_transform, pre_filter, force_reload=force_reload
        )

        # Step 3: Load the processed data
        # After the data has been downloaded from source
        # Then preprocessed to obtain x,y and saved into processed folder
        # We can now load the processed data from processed folder

        # Load the processed data
        data, _, _ = fs.torch_load(self.processed_paths[0])

        # Rebuild a Data object from the loaded dictionary
        data = Data.from_dict(data)

        # Step 5: Create the splits and upload desired fold
        splits = random_splitting(data.y, parameters=self.parameters)
        # Assign train val test masks to the graph
        data.train_mask = torch.from_numpy(splits["train"])
        data.val_mask = torch.from_numpy(splits["valid"])
        data.test_mask = torch.from_numpy(splits["test"])

        # Assign data object to self.data, to make it be processed by Dataset class
        self.data, self.slices = self.collate([data])

    # Do not forget to take care of properties
    @property
    def raw_dir(self) -> str:
        """Directory holding the raw file: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed file: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, "processed")

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file inside ``processed_dir``."""
        return "data.pt"

    @property
    def raw_file_names(self) -> list[str]:
        """Specify the downloaded raw file name."""
        return [f"{self.name}.npz"]

    def download(self) -> None:
        """
        Downloads ``{name}.npz`` into the raw directory.

        Raises:
            Exception: If ``download_hetero_datasets`` fails.
        """

        # Step 1: Download data from the source
        download_hetero_datasets(name=self.name, path=self.raw_dir)

    def process(self) -> None:
        """
        Process the data for the dataset.

        This method loads the raw ``.npz`` graph with ``hetero_load``, applies
        ``pre_transform`` if one was given, and saves the result to
        ``self.processed_paths[0]``.

        Returns:
            None
        """

        data = hetero_load(name=self.name, path=self.raw_dir)
        data = data if self.pre_transform is None else self.pre_transform(data)
        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        # Same "<name>(<number of graphs>)" format as USCountyDemosDataset
        return f"{self.name}({len(self)})"


# Resolve every dataset location from one project root instead of repeating an
# absolute, user-specific path. rootutils.setup_root exports PROJECT_ROOT by
# default; the previous hard-coded location is kept as a fallback so the cell
# behaves as before when that variable is not set.
project_root = os.environ.get("PROJECT_ROOT", "/home/lev/projects/TopoBenchmarkX")
datasets_root = f"{project_root}/datasets"

data_domain = "graph"
data_type = "heterophilic"
data_name = "amazon_ratings"

data_dir = f"{datasets_root}/{data_domain}/{data_type}"

# Options forwarded to random_splitting through HeteroDataset
parameters = {
    "split_type": "random",
    "k": 10,
    "train_prop": 0.5,
    "data_seed": 0,
    "data_split_dir": f"{datasets_root}/data_splits/{data_name}",
}

dataset = HeteroDataset(
    name=data_name, root=data_dir, parameters=parameters, force_reload=True
)