-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into add_cgcnn_megnet
- Loading branch information
Showing
7 changed files
with
314 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
import tarfile | ||
from logging import getLogger | ||
from typing import List, Tuple | ||
|
||
from chainer.dataset import download | ||
|
||
download_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz' | ||
feat_file_name = 'citeseer.content' | ||
edge_file_name = 'citeseer.cites' | ||
|
||
_root = 'pfnet/chainer/citeseer' | ||
|
||
_label_names = ['Agents', 'AI', 'DB', 'IR', 'ML', 'HCI'] | ||
|
||
|
||
def get_citeseer_label_names(): | ||
# type: () -> List[str] | ||
"""Return label names of Cora dataset.""" | ||
return _label_names | ||
|
||
|
||
def get_citeseer_dirpath(download_if_not_exist=True): | ||
# type: (bool) -> str | ||
"""Construct a dirpath which stores citeseer dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
dirpath (str): directory path for citeseer dataset. | ||
""" | ||
feat_cache_path, edge_cache_path = get_citeseer_filepath( | ||
download_if_not_exist=download_if_not_exist) | ||
dirpath = os.path.dirname(feat_cache_path) | ||
dirpath2 = os.path.dirname(edge_cache_path) | ||
assert dirpath == dirpath2 | ||
return dirpath | ||
|
||
|
||
def get_citeseer_filepath(download_if_not_exist=True): | ||
# type: (bool) -> Tuple[str, str] | ||
"""Construct a filepath which stores citeseer dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
feat_cache_path (str): file path for citeseer dataset (features). | ||
edge_cache_path (str): file path for citeseer dataset (edge index). | ||
""" | ||
feat_cache_path, edge_cache_path = _get_citeseer_filepath() | ||
if not os.path.exists(feat_cache_path): | ||
if download_if_not_exist: | ||
is_successful = download_and_extract_citeseer( | ||
save_dirpath=os.path.dirname(feat_cache_path)) | ||
if not is_successful: | ||
logger = getLogger(__name__) | ||
logger.warning('Download failed.') | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def _get_citeseer_filepath(): | ||
# type: () -> Tuple[str, str] | ||
"""Construct a filepath which stores citeseer dataset. | ||
This method does not check if the file is already downloaded or not. | ||
Returns: | ||
feat_cache_path (str): file path for citeseer dataset (features). | ||
edge_cache_path (str): file path for citeseer dataset (edge index). | ||
""" | ||
cache_root = download.get_dataset_directory(_root) | ||
feat_cache_path = os.path.join(cache_root, feat_file_name) | ||
edge_cache_path = os.path.join(cache_root, edge_file_name) | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def download_and_extract_citeseer(save_dirpath): | ||
# type: (str) -> bool | ||
print('downloading citeseer dataset...') | ||
download_file_path = download.cached_download(download_url) | ||
print('extracting citeseer dataset...') | ||
tf = tarfile.open(download_file_path, 'r') | ||
tf.extractall(os.path.dirname(save_dirpath)) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
import tarfile | ||
from logging import getLogger | ||
from typing import List, Tuple | ||
|
||
from chainer.dataset import download | ||
|
||
download_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz' | ||
feat_file_name = 'cora.content' | ||
edge_file_name = 'cora.cites' | ||
|
||
_root = 'pfnet/chainer/cora' | ||
|
||
_label_names = [ | ||
'Case_Based', 'Genetic_Algorithms', 'Neural_Networks', | ||
'Probabilistic_Methods', 'Reinforcement_Learning', 'Rule_Learning', | ||
'Theory' | ||
] | ||
|
||
|
||
def get_cora_label_names(): | ||
# type: () -> List[str] | ||
"""Return label names of Cora dataset.""" | ||
return _label_names | ||
|
||
|
||
def get_cora_dirpath(download_if_not_exist=True): | ||
# type: (bool) -> str | ||
"""Construct a dirpath which stores Cora dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
dirpath (str): directory path for Cora dataset. | ||
""" | ||
feat_cache_path, edge_cache_path = get_cora_filepath( | ||
download_if_not_exist=download_if_not_exist) | ||
dirpath = os.path.dirname(feat_cache_path) | ||
dirpath2 = os.path.dirname(edge_cache_path) | ||
assert dirpath == dirpath2 | ||
return dirpath | ||
|
||
|
||
def get_cora_filepath(download_if_not_exist=True): | ||
# type: (bool) -> Tuple[str, str] | ||
"""Construct a filepath which stores Cora dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
feat_cache_path (str): file path for Cora dataset (features). | ||
edge_cache_path (str): file path for Cora dataset (edge index). | ||
""" | ||
feat_cache_path, edge_cache_path = _get_cora_filepath() | ||
if not os.path.exists(feat_cache_path): | ||
if download_if_not_exist: | ||
is_successful = download_and_extract_cora( | ||
save_dirpath=os.path.dirname(feat_cache_path)) | ||
if not is_successful: | ||
logger = getLogger(__name__) | ||
logger.warning('Download failed.') | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def _get_cora_filepath(): | ||
# type: () -> Tuple[str, str] | ||
"""Construct a filepath which stores Cora dataset. | ||
This method does not check if the file is already downloaded or not. | ||
Returns: | ||
feat_cache_path (str): file path for Cora dataset (features). | ||
edge_cache_path (str): file path for Cora dataset (edge index). | ||
""" | ||
cache_root = download.get_dataset_directory(_root) | ||
feat_cache_path = os.path.join(cache_root, feat_file_name) | ||
edge_cache_path = os.path.join(cache_root, edge_file_name) | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def download_and_extract_cora(save_dirpath): | ||
# type: (str) -> bool | ||
print('downloading cora dataset...') | ||
download_file_path = download.cached_download(download_url) | ||
print('extracting cora dataset...') | ||
tf = tarfile.open(download_file_path, 'r') | ||
tf.extractall(os.path.dirname(save_dirpath)) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,102 @@ | ||
import os | ||
from logging import getLogger | ||
from zipfile import ZipFile | ||
|
||
import numpy | ||
import networkx as nx | ||
import scipy | ||
from chainer.dataset import download | ||
|
||
download_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/reddit.zip' | ||
feat_file_name = 'reddit_data.npz' | ||
edge_file_name = 'reddit_graph.npz' | ||
|
||
_root = 'pfnet/chainer/reddit' | ||
|
||
|
||
def reddit_to_networkx(dirpath): | ||
print("Loading graph data") | ||
coo_adj = scipy.sparse.load_npz(dirpath + "reddit_graph.npz") | ||
coo_adj = scipy.sparse.load_npz(os.path.join(dirpath, edge_file_name)) | ||
G = nx.from_scipy_sparse_matrix(coo_adj) | ||
|
||
print("Loading node feature and label") | ||
# node feature, edge label | ||
reddit_data = numpy.load(dirpath + "reddit_data.npz") | ||
reddit_data = numpy.load(os.path.join(dirpath, feat_file_name)) | ||
G.graph['x'] = reddit_data['feature'].astype(numpy.float32) | ||
G.graph['y'] = reddit_data['label'].astype(numpy.int32) | ||
|
||
G.graph['label_num'] = 41 | ||
# G = nx.convert_node_labels_to_integers(G) | ||
print("Finish loading graph: {}".format(dirpath)) | ||
return G | ||
|
||
|
||
def get_reddit_dirpath(download_if_not_exist=True): | ||
# type: (bool) -> str | ||
"""Construct a dirpath which stores reddit dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
dirpath (str): directory path for reddit dataset. | ||
""" | ||
feat_cache_path, edge_cache_path = get_reddit_filepath( | ||
download_if_not_exist=download_if_not_exist) | ||
dirpath = os.path.dirname(feat_cache_path) | ||
dirpath2 = os.path.dirname(edge_cache_path) | ||
assert dirpath == dirpath2 | ||
return dirpath | ||
|
||
|
||
def get_reddit_filepath(download_if_not_exist=True): | ||
# type: (bool) -> Tuple[str, str] | ||
"""Construct a filepath which stores reddit dataset. | ||
This method check whether the file exist or not, and downloaded it | ||
if necessary. | ||
Args: | ||
download_if_not_exist (bool): If ``True``, download dataset | ||
if it is not downloaded yet. | ||
Returns: | ||
feat_cache_path (str): file path for reddit dataset (features). | ||
edge_cache_path (str): file path for reddit dataset (edge index). | ||
""" | ||
feat_cache_path, edge_cache_path = _get_reddit_filepath() | ||
if not os.path.exists(feat_cache_path): | ||
if download_if_not_exist: | ||
is_successful = download_and_extract_reddit( | ||
save_dirpath=os.path.dirname(feat_cache_path)) | ||
if not is_successful: | ||
logger = getLogger(__name__) | ||
logger.warning('Download failed.') | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def _get_reddit_filepath(): | ||
# type: () -> Tuple[str, str] | ||
"""Construct a filepath which stores reddit dataset. | ||
This method does not check if the file is already downloaded or not. | ||
Returns: | ||
feat_cache_path (str): file path for reddit dataset (features). | ||
edge_cache_path (str): file path for reddit dataset (edge index). | ||
""" | ||
cache_root = download.get_dataset_directory(_root) | ||
feat_cache_path = os.path.join(cache_root, feat_file_name) | ||
edge_cache_path = os.path.join(cache_root, edge_file_name) | ||
return feat_cache_path, edge_cache_path | ||
|
||
|
||
def download_and_extract_reddit(save_dirpath): | ||
# type: (str) -> bool | ||
print('downloading reddit dataset...') | ||
download_file_path = download.cached_download(download_url) | ||
print('extracting reddit dataset...') | ||
zip = ZipFile(download_file_path, 'r') | ||
zip.extractall(save_dirpath) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.