Skip to content

Commit

Permalink
allow passing graph and label kwargs to dataset construction
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau committed Dec 7, 2023
1 parent a6429fc commit 183fc8a
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/obnb/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from obnb.label import filters
from obnb.typing import Any, Callable, Dict, List, LogLevel, Optional
from obnb.util.converter import GenePropertyConverter
from obnb.util.misc import default
from obnb.util.version import parse_data_version


Expand All @@ -15,6 +16,8 @@ class OpenBiomedNetBench(Dataset):
root: Directory where the data will be saved.
graph_name: Name of the biological network to use.
label_name: Name of the label sets to use.
graph_kwargs: Extra keyword arguments passed to the selected graph data object.
label_kwargs: Extra keyword arguments passed to the selected label data object.
version: Archive data version to use. "current" uses the most recent
processed archive data. "latest" downloads the latest data directly
from the source and processes it from scratch.
Expand Down Expand Up @@ -48,6 +51,8 @@ def __init__(
graph_name: str,
label_name: str,
*,
graph_kwargs: Optional[Dict[str, Any]] = None,
label_kwargs: Optional[Dict[str, Any]] = None,
version: str = "current",
auto_generate_feature: Optional[str] = "OneHotLogDeg",
graph_as_feature: bool = False,
Expand All @@ -63,11 +68,18 @@ def __init__(
log_level: LogLevel = "INFO",
):
"""Initialize OpenBiomedNetBench object."""
self.graph_kwargs = default(graph_kwargs, {})
self.label_kwargs = default(label_kwargs, {})
self.version = parse_data_version(version)

# Download network data
graph_cls = getattr(obnb.data, graph_name)
graph = graph_cls(root, version=self.version, log_level=log_level)
graph = graph_cls(
root,
version=self.version,
log_level=log_level,
**self.graph_kwargs,
)

# Set up study-bias holdout data splitter
train_ratio = round(1 - val_ratio - test_ratio, 4)
Expand Down Expand Up @@ -116,6 +128,7 @@ def __init__(
log_level=log_level,
),
log_level=log_level,
**self.label_kwargs,
)

# Perform necessary data conversion
Expand Down

0 comments on commit 183fc8a

Please sign in to comment.