Merge branch 'master' into add_cgcnn_megnet
corochann committed Nov 5, 2019
2 parents 53ed4f1 + b904fa0 commit fb22a79
Showing 7 changed files with 314 additions and 30 deletions.
@@ -1,6 +1,9 @@
import os

import numpy
import scipy.sparse
import chainer

from chainer_chemistry.dataset.graph_dataset.base_graph_data import PaddingGraphData # NOQA


@@ -13,10 +16,10 @@ def get_reddit_coo_data(dirpath):
"""

print("Loading node feature and label")
reddit_data = numpy.load(dirpath + "reddit_data.npz")
reddit_data = numpy.load(os.path.join(dirpath, "reddit_data.npz"))

print("Loading edge data")
coo_adj = scipy.sparse.load_npz(dirpath + "reddit_graph.npz")
coo_adj = scipy.sparse.load_npz(os.path.join(dirpath, "reddit_graph.npz"))
row = coo_adj.row.astype(numpy.int32)
col = coo_adj.col.astype(numpy.int32)
data = coo_adj.data.astype(numpy.float32)
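For context, this hunk replaces string concatenation with `os.path.join`, which inserts the path separator only when needed. A minimal illustration of the difference (not part of the commit):

```python
import os

dirpath = "/tmp/reddit"  # note: no trailing separator
print(dirpath + "reddit_data.npz")               # -> /tmp/redditreddit_data.npz (wrong)
print(os.path.join(dirpath, "reddit_data.npz"))  # -> /tmp/reddit/reddit_data.npz
```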
13 changes: 9 additions & 4 deletions chainer_chemistry/datasets/citation_network/citation.py
@@ -1,14 +1,18 @@
import os

import numpy
import networkx as nx
from tqdm import tqdm


def citation_to_networkx(dirpath, name):
G = nx.Graph()
# node feature, node label
with open("{}{}.content".format(dirpath, name)) as f:
with open(os.path.join(dirpath, "{}.content".format(name))) as f:
lines = f.readlines()
compressor = {}
acc = 0
for line in f:
for line in tqdm(lines):
lis = line.split()
key, val = lis[0], lis[-1]
if val in compressor:
@@ -23,8 +27,9 @@ def citation_to_networkx(dirpath, name):
G.graph['label_num'] = acc

# edge
with open("{}{}.cites".format(dirpath, name)) as f:
for line in f:
with open(os.path.join(dirpath, "{}.cites".format(name))) as f:
lines = f.readlines()
for line in tqdm(lines):
u, v = line.split()
if u not in G.nodes.keys():
print("Warning: {} does not appear in {}{}.content".format(
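For orientation, a hypothetical usage sketch of the parser above, assuming the module layout shown in the file headers (the `get_cora_dirpath` helper is added by this commit further below):

```python
from chainer_chemistry.datasets.citation_network.citation import (
    citation_to_networkx)
from chainer_chemistry.datasets.citation_network.cora import get_cora_dirpath

# Fetch (and, if needed, download) the Cora files, then parse them
# into a networkx graph carrying the number of label classes.
dirpath = get_cora_dirpath()
G = citation_to_networkx(dirpath, "cora")
print(G.number_of_nodes(), G.graph['label_num'])
```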
91 changes: 91 additions & 0 deletions chainer_chemistry/datasets/citation_network/citeseer.py
@@ -0,0 +1,91 @@
import os
import tarfile
from logging import getLogger
from typing import List, Tuple

from chainer.dataset import download

download_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz'
feat_file_name = 'citeseer.content'
edge_file_name = 'citeseer.cites'

_root = 'pfnet/chainer/citeseer'

_label_names = ['Agents', 'AI', 'DB', 'IR', 'ML', 'HCI']


def get_citeseer_label_names():
# type: () -> List[str]
"""Return label names of Cora dataset."""
return _label_names


def get_citeseer_dirpath(download_if_not_exist=True):
# type: (bool) -> str
"""Construct a dirpath which stores citeseer dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
dirpath (str): directory path for citeseer dataset.
"""
feat_cache_path, edge_cache_path = get_citeseer_filepath(
download_if_not_exist=download_if_not_exist)
dirpath = os.path.dirname(feat_cache_path)
dirpath2 = os.path.dirname(edge_cache_path)
assert dirpath == dirpath2
return dirpath


def get_citeseer_filepath(download_if_not_exist=True):
# type: (bool) -> Tuple[str, str]
"""Construct a filepath which stores citeseer dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
feat_cache_path (str): file path for citeseer dataset (features).
edge_cache_path (str): file path for citeseer dataset (edge index).
"""
feat_cache_path, edge_cache_path = _get_citeseer_filepath()
if not os.path.exists(feat_cache_path):
if download_if_not_exist:
is_successful = download_and_extract_citeseer(
save_dirpath=os.path.dirname(feat_cache_path))
if not is_successful:
logger = getLogger(__name__)
logger.warning('Download failed.')
return feat_cache_path, edge_cache_path


def _get_citeseer_filepath():
# type: () -> Tuple[str, str]
"""Construct a filepath which stores citeseer dataset.
This method does not check if the file is already downloaded or not.
Returns:
feat_cache_path (str): file path for citeseer dataset (features).
edge_cache_path (str): file path for citeseer dataset (edge index).
"""
cache_root = download.get_dataset_directory(_root)
feat_cache_path = os.path.join(cache_root, feat_file_name)
edge_cache_path = os.path.join(cache_root, edge_file_name)
return feat_cache_path, edge_cache_path


def download_and_extract_citeseer(save_dirpath):
# type: (str) -> bool
print('downloading citeseer dataset...')
download_file_path = download.cached_download(download_url)
print('extracting citeseer dataset...')
with tarfile.open(download_file_path, 'r') as tf:
    tf.extractall(os.path.dirname(save_dirpath))
return True
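These helpers implement a check-then-download cache on top of chainer's dataset directory. A hedged usage sketch (names taken from this file; the cache location shown is the typical chainer default, not guaranteed):

```python
from chainer_chemistry.datasets.citation_network.citeseer import (
    get_citeseer_dirpath, get_citeseer_label_names)

dirpath = get_citeseer_dirpath()   # downloads and extracts citeseer.tgz on first call
print(dirpath)                     # typically ~/.chainer/dataset/pfnet/chainer/citeseer
print(get_citeseer_label_names())  # ['Agents', 'AI', 'DB', 'IR', 'ML', 'HCI']
```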
93 changes: 93 additions & 0 deletions chainer_chemistry/datasets/citation_network/cora.py
@@ -0,0 +1,93 @@
import os
import tarfile
from logging import getLogger
from typing import List, Tuple

from chainer.dataset import download

download_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'
feat_file_name = 'cora.content'
edge_file_name = 'cora.cites'

_root = 'pfnet/chainer/cora'

_label_names = [
'Case_Based', 'Genetic_Algorithms', 'Neural_Networks',
'Probabilistic_Methods', 'Reinforcement_Learning', 'Rule_Learning',
'Theory'
]


def get_cora_label_names():
# type: () -> List[str]
"""Return label names of Cora dataset."""
return _label_names


def get_cora_dirpath(download_if_not_exist=True):
# type: (bool) -> str
"""Construct a dirpath which stores Cora dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
dirpath (str): directory path for Cora dataset.
"""
feat_cache_path, edge_cache_path = get_cora_filepath(
download_if_not_exist=download_if_not_exist)
dirpath = os.path.dirname(feat_cache_path)
dirpath2 = os.path.dirname(edge_cache_path)
assert dirpath == dirpath2
return dirpath


def get_cora_filepath(download_if_not_exist=True):
# type: (bool) -> Tuple[str, str]
"""Construct a filepath which stores Cora dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
feat_cache_path (str): file path for Cora dataset (features).
edge_cache_path (str): file path for Cora dataset (edge index).
"""
feat_cache_path, edge_cache_path = _get_cora_filepath()
if not os.path.exists(feat_cache_path):
if download_if_not_exist:
is_successful = download_and_extract_cora(
save_dirpath=os.path.dirname(feat_cache_path))
if not is_successful:
logger = getLogger(__name__)
logger.warning('Download failed.')
return feat_cache_path, edge_cache_path


def _get_cora_filepath():
# type: () -> Tuple[str, str]
"""Construct a filepath which stores Cora dataset.
This method does not check if the file is already downloaded or not.
Returns:
feat_cache_path (str): file path for Cora dataset (features).
edge_cache_path (str): file path for Cora dataset (edge index).
"""
cache_root = download.get_dataset_directory(_root)
feat_cache_path = os.path.join(cache_root, feat_file_name)
edge_cache_path = os.path.join(cache_root, edge_file_name)
return feat_cache_path, edge_cache_path


def download_and_extract_cora(save_dirpath):
# type: (str) -> bool
print('downloading cora dataset...')
download_file_path = download.cached_download(download_url)
print('extracting cora dataset...')
with tarfile.open(download_file_path, 'r') as tf:
    tf.extractall(os.path.dirname(save_dirpath))
return True
86 changes: 84 additions & 2 deletions chainer_chemistry/datasets/reddit/reddit.py
@@ -1,20 +1,102 @@
import os
from logging import getLogger
from typing import Tuple
from zipfile import ZipFile

import numpy
import networkx as nx
import scipy.sparse
from chainer.dataset import download

download_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/reddit.zip'
feat_file_name = 'reddit_data.npz'
edge_file_name = 'reddit_graph.npz'

_root = 'pfnet/chainer/reddit'


def reddit_to_networkx(dirpath):
print("Loading graph data")
coo_adj = scipy.sparse.load_npz(dirpath + "reddit_graph.npz")
coo_adj = scipy.sparse.load_npz(os.path.join(dirpath, edge_file_name))
G = nx.from_scipy_sparse_matrix(coo_adj)

print("Loading node feature and label")
# node feature, node label
reddit_data = numpy.load(dirpath + "reddit_data.npz")
reddit_data = numpy.load(os.path.join(dirpath, feat_file_name))
G.graph['x'] = reddit_data['feature'].astype(numpy.float32)
G.graph['y'] = reddit_data['label'].astype(numpy.int32)

G.graph['label_num'] = 41
# G = nx.convert_node_labels_to_integers(G)
print("Finish loading graph: {}".format(dirpath))
return G


def get_reddit_dirpath(download_if_not_exist=True):
# type: (bool) -> str
"""Construct a dirpath which stores reddit dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
dirpath (str): directory path for reddit dataset.
"""
feat_cache_path, edge_cache_path = get_reddit_filepath(
download_if_not_exist=download_if_not_exist)
dirpath = os.path.dirname(feat_cache_path)
dirpath2 = os.path.dirname(edge_cache_path)
assert dirpath == dirpath2
return dirpath


def get_reddit_filepath(download_if_not_exist=True):
# type: (bool) -> Tuple[str, str]
"""Construct a filepath which stores reddit dataset.
This method check whether the file exist or not, and downloaded it
if necessary.
Args:
download_if_not_exist (bool): If ``True``, download dataset
if it is not downloaded yet.
Returns:
feat_cache_path (str): file path for reddit dataset (features).
edge_cache_path (str): file path for reddit dataset (edge index).
"""
feat_cache_path, edge_cache_path = _get_reddit_filepath()
if not os.path.exists(feat_cache_path):
if download_if_not_exist:
is_successful = download_and_extract_reddit(
save_dirpath=os.path.dirname(feat_cache_path))
if not is_successful:
logger = getLogger(__name__)
logger.warning('Download failed.')
return feat_cache_path, edge_cache_path


def _get_reddit_filepath():
# type: () -> Tuple[str, str]
"""Construct a filepath which stores reddit dataset.
This method does not check if the file is already downloaded or not.
Returns:
feat_cache_path (str): file path for reddit dataset (features).
edge_cache_path (str): file path for reddit dataset (edge index).
"""
cache_root = download.get_dataset_directory(_root)
feat_cache_path = os.path.join(cache_root, feat_file_name)
edge_cache_path = os.path.join(cache_root, edge_file_name)
return feat_cache_path, edge_cache_path


def download_and_extract_reddit(save_dirpath):
# type: (str) -> bool
print('downloading reddit dataset...')
download_file_path = download.cached_download(download_url)
print('extracting reddit dataset...')
with ZipFile(download_file_path, 'r') as zf:
    zf.extractall(save_dirpath)
return True
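A usage sketch along the same lines (function names from this file; not part of the commit): fetch the cached dataset directory and build the graph, which carries node features and labels in its graph attributes:

```python
from chainer_chemistry.datasets.reddit.reddit import (
    get_reddit_dirpath, reddit_to_networkx)

dirpath = get_reddit_dirpath()   # downloads reddit.zip on first call
G = reddit_to_networkx(dirpath)  # features in G.graph['x'], labels in G.graph['y']
print(G.graph['x'].shape, G.graph['y'].shape, G.graph['label_num'])
```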
24 changes: 17 additions & 7 deletions examples/network_graph/README.md
@@ -10,27 +10,37 @@ Before running the example, the following packages also need to be installed:
- [`seaborn`](https://seaborn.pydata.org/)
- [`scikit-learn`](http://scikit-learn.org/stable/)

## How to run the code

### Dataset

Please download the dataset to use, unzip and place it under each directory.
## Supported dataset

- [Cora](https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz)
- [Citeseer](https://linqs-data.soe.ucsc.edu/public/lbc/citeseer.tgz)
- [Reddit](https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/reddit.zip)
- We use the dataset provided by the [dmlc/dgl](https://github.com/dmlc/dgl/blob/master/python/dgl/data/reddit.py) repository.

Note that each dataset is downloaded automatically.

## How to run the code

### Train a model

To train a model, run the following:

On the CPU:
```bash
PYTHONPATH=. python examples/network_graph/train_network_graph.py --dataset cora
python train_network_graph.py --dataset cora
```

On the GPU:
To train a sparse model on the GPU:
```bash
python train_network_graph.py --dataset cora --device 0 --method gin_sparse
```

### Train a model with reddit dataset

The reddit dataset is large, so training runs only under a specific configuration.
Please turn on the `coo` option (COO sparse-matrix mode) to train on the reddit dataset.

```bash
PYTHONPATH=. python examples/network_graph/train_network_graph.py --dataset cora
python train_network_graph.py --dataset reddit --device 0 --method gin --coo true
```