In [12]:
import sys
# Add manually root '/home/lev/projects/TopoBenchmarkX'
root_path = '/home/lev/projects/TopoBenchmarkX'
if root_path not in sys.path:
    sys.path.append(root_path)

import os.path as osp
from typing import Callable, List, Optional

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs, read_tu_data

from topobenchmarkx.io.load.download_utils import download_file_from_drive

class CornelDataset(InMemoryDataset):
    r"""
    """

    URLS = {
        'contact-high-school': 'https://drive.google.com/open?id=1VA2P62awVYgluOIh1W4NZQQgkQCBk-Eu',
        'US-county-demos': 'https://drive.google.com/file/d/1FNF_LbByhYNICPNdT6tMaJI9FxuSvvLK/view?usp=sharing',
    }

    FILE_FORMAT = {
        'contact-high-school': 'tar.gz',
        'US-county-demos': 'zip',
    }

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        use_node_attr: bool = False,
        use_edge_attr: bool = False,
        cleaned: bool = False,
    ) -> None:
        self.name = name.replace('_', '-')
        self.cleaned = cleaned
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        self.data, _, _ = fs.torch_load(self.processed_paths[0])
        
        # if not isinstance(out, tuple) or len(out) < 3:
        #     raise RuntimeError(
        #         "The 'data' object was created by an older version of PyG. "
        #         "If this error occurred while loading an already existing "
        #         "dataset, remove the 'processed/' directory in the dataset's "
        #         "root folder and try again.")
        # assert len(out) == 3 or len(out) == 4

        # if len(out) == 3:  # Backward compatibility.
        #     data, self.slices, self.sizes = out
        #     data_cls = Data
        # else:
        #     data, self.slices, self.sizes, data_cls = out

        # if not isinstance(data, dict):  # Backward compatibility.
        #     self.data = data
        # else:
        #     self.data = data_cls.from_dict(data)

        # assert isinstance(self._data, Data)
        # if self._data.x is not None and not use_node_attr:
        #     num_node_attributes = self.num_node_attributes
        #     self._data.x = self._data.x[:, num_node_attributes:]
        # if self._data.edge_attr is not None and not use_edge_attr:
        #     num_edge_attrs = self.num_edge_attributes
        #     self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:]

    # @property
    # def raw_dir(self) -> str:
    #     name = f'raw{"_cleaned" if self.cleaned else ""}'
    #     return osp.join(self.root, self.name, name)

    # @property
    # def processed_dir(self) -> str:
    #     name = f'processed{"_cleaned" if self.cleaned else ""}'
    #     return osp.join(self.root, self.name, name)

    # @property
    # def num_node_labels(self) -> int:
    #     return self.sizes['num_node_labels']

    # @property
    # def num_node_attributes(self) -> int:
    #     return self.sizes['num_node_attributes']

    # @property
    # def num_edge_labels(self) -> int:
    #     return self.sizes['num_edge_labels']

    # @property
    # def num_edge_attributes(self) -> int:
    #     return self.sizes['num_edge_attributes']

    # @property
    # def raw_file_names(self) -> List[str]:
    #     names = ['A', 'graph_indicator']
    #     return [f'{self.name}_{name}.txt' for name in names]

    # @property
    # def processed_file_names(self) -> str:
    #     return 'data.pt'

    def download(self) -> None:
        # Download data
        self.url = self.URLS[self.name] 
        self.file_format = self.FILE_FORMAT[self.name]
        
        download_file_from_drive(
            file_link=self.url, 
            path_to_save=self.raw_dir, 
            dataset_name=self.name,
            file_format=self.file_format
        )

        fs.cp(f'{self.raw_dir}/{self.name}.{self.file_format}', self.raw_dir, extract=True)

        # Move into raw/
        for filename in fs.ls(osp.join(self.raw_dir, self.name)):
            fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))
        fs.rm(osp.join(self.raw_dir, self.name))

        # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}'
        fs.rm(f'{self.raw_dir}/{self.name}.{self.file_format}')

    def process(self) -> None:
        data = load_us_county_demos(self.raw_dir, self.name)

        data = data if self.pre_transform is None else self.pre_transform(data)
        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'

In [13]:
import numpy as np
import pandas as pd
import torch
import torch_geometric

def load_us_county_demos(path, dataset_name, year=2012):

    edges_df = pd.read_csv(f'{path}/county_graph.csv')
    stat = pd.read_csv(f'{path}/county_stats_{year}.csv', encoding='ISO-8859-1')
    
    keep_cols = ['FIPS', 'DEM', 'GOP', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
    # Drop rows with missing values
    stat = stat[keep_cols].dropna()

    # Delete edges that are not present in stat df
    unique_fips = stat['FIPS'].unique()

    src_ = edges_df['SRC'].apply(lambda x: x in unique_fips) 
    dst_ = edges_df['DST'].apply(lambda x: x in unique_fips)

    edges_df = edges_df[src_ & dst_]

    # Remove rows from stat df where edges_df['SRC'] or edges_df['DST'] are not present
    stat = stat[stat['FIPS'].isin(edges_df['SRC']) & stat['FIPS'].isin(edges_df['DST'])]
    stat = stat.reset_index(drop=True)

    # Remove rows where SRC == DST
    edges_df = edges_df[edges_df['SRC'] != edges_df['DST']]

    # Get torch_geometric edge_index format
    edge_index = torch.tensor(np.stack([edges_df['SRC'].to_numpy(), edges_df['DST'].to_numpy()]))

    # Make edge_index undirected
    edge_index = torch_geometric.utils.to_undirected(edge_index)

    # Convert edge_index back to pandas DataFrame
    edges_df = pd.DataFrame(edge_index.numpy().T, columns=['SRC', 'DST'])

    del edge_index

    # Map stat['FIPS'].unique() to [0, ..., num_nodes]
    fips_map = {fips: i for i, fips in enumerate(stat['FIPS'].unique())}
    stat['FIPS'] = stat['FIPS'].map(fips_map)

    # Map edges_df['SRC'] and edges_df['DST'] to [0, ..., num_nodes]
    edges_df['SRC'] = edges_df['SRC'].map(fips_map)
    edges_df['DST'] = edges_df['DST'].map(fips_map)

    # Get torch_geometric edge_index format
    edge_index = torch.tensor(np.stack([edges_df['SRC'].to_numpy(), edges_df['DST'].to_numpy()]))

    # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically)
    edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index)

    # Conver mask to index
    index = np.arange(mask.size(0))[mask]
    stat = stat.iloc[index]
    stat = stat.reset_index(drop=True)

    # Get new values for FIPS from current index
    # To understand why please print stat.iloc[[516, 517, 518, 519, 520]] for 2012 year
    # Basically the FIPS values has been shifted
    stat['FIPS'] = stat.reset_index()['index']

    # Create Election variable
    stat['Election'] = (stat['DEM'] - stat['GOP']) / (stat['DEM'] + stat['GOP'])

    # Drop DEM and GOP columns and FIPS
    stat = stat.drop(columns=['DEM', 'GOP', 'FIPS'])

    # Prediction col
    y_col = 'Election' # TODO: Define through config file
    x_col = list(set(stat.columns).difference(set([y_col])))

    stat['MedianIncome'] = stat['MedianIncome'].apply(lambda x: x.replace(',', '')).to_numpy().astype(float)

    x = stat[x_col].to_numpy()
    y = stat[y_col].to_numpy()


    data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)
    
    return data


In [14]:
a = CornelDataset(root='/home/lev/projects/TopoBenchmarkX/datasets/graph', name='US-county-demos')

Download complete.


In [16]:
a[0]

({'x': array([[ 6.9000e+00,  2.1900e+01,  1.1100e+01,  1.0200e+01,  5.1441e+04,
          -6.1000e+00],
         [ 7.5000e+00,  2.8600e+01,  1.1100e+01,  1.0000e+01,  4.8867e+04,
           1.7600e+01],
         [ 1.1500e+01,  1.3600e+01,  1.1000e+01,  1.0700e+01,  3.0287e+04,
          -6.8000e+00],
         ...,
         [ 5.6000e+00,  1.8700e+01,  1.4800e+01,  5.5000e+00,  6.1057e+04,
          -4.5000e+00],
         [ 5.2000e+00,  2.1200e+01,  1.0700e+01,  1.2500e+01,  4.9533e+04,
          -3.0000e+00],
         [ 4.1000e+00,  1.6800e+01,  1.0400e+01,  9.4000e+00,  5.3665e+04,
          -1.0400e+01]]),
  'edge_index': tensor([[   0,    0,    0,  ..., 3106, 3106, 3106],
          [  10,   23,   25,  ..., 3088, 3089, 3097]]),
  'y': array([-0.46424958, -0.56411933,  0.02926744, ..., -0.60490232,
         -0.58287365, -0.73974715])},
 None,
 torch_geometric.data.data.Data)