Skip to content

Commit

Permalink
Make pyarrow optional (#113)
Browse files Browse the repository at this point in the history
* make pyarrow optional

* add importorskip
  • Loading branch information
lilyminium committed Apr 11, 2024
1 parent 71a4af4 commit 72d52f7
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 214 deletions.
4 changes: 2 additions & 2 deletions devtools/conda-envs/test_env_dgl_false.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ dependencies:
- scipy
- ambertools

# database
- pyarrow
# # database
# - pyarrow

# gcn
- pytorch >=2.0
Expand Down
171 changes: 0 additions & 171 deletions openff/nagl/label/_label.py

This file was deleted.

23 changes: 14 additions & 9 deletions openff/nagl/label/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
import tqdm
import typing

import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from openff.units import unit

from openff.nagl.utils._parallelization import get_mapper_to_processes
from openff.nagl.label.labels import LabellerType
from openff.utilities import requires_package

if typing.TYPE_CHECKING:
import pyarrow

class LabelledDataset:

@requires_package("pyarrow")
def __init__(
self,
source,
Expand All @@ -28,6 +26,8 @@ def to_pandas(self, columns=None):
return self.dataset.to_table(columns=columns).to_pandas()

def _reload(self):
import pyarrow.dataset as ds

self.dataset = ds.dataset(self.source, format="parquet")

@classmethod
Expand All @@ -44,6 +44,8 @@ def from_smiles(
):
from openff.toolkit import Molecule

import pyarrow as pa
import pyarrow.dataset as ds

loader = functools.partial(
Molecule.from_smiles,
Expand Down Expand Up @@ -84,16 +86,19 @@ def from_smiles(

def append_columns(
self,
columns: typing.Dict[pa.Field, typing.Iterable[typing.Any]],
columns: typing.Dict["pyarrow.Field", typing.Iterable[typing.Any]],
exist_ok: bool = False,
):
self._append_columns(columns, exist_ok=exist_ok)

def _append_columns(
self,
columns: typing.Dict[pa.Field, typing.Iterable[typing.Any]],
columns: typing.Dict["pyarrow.Field", typing.Iterable[typing.Any]],
exist_ok: bool = False,
):
import pyarrow.dataset as ds
import pyarrow.parquet as pq

from .utils import _append_column_to_table

n_all_rows = self.dataset.count_rows()
Expand Down
40 changes: 25 additions & 15 deletions openff/nagl/label/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import typing

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds

from openff.units import unit
from openff.utilities import requires_package

from openff.nagl._base.base import ImmutableModel
from openff.utilities import requires_package

if typing.TYPE_CHECKING:
import pyarrow

ChargeMethodType = typing.Literal[
"am1bcc", "am1-mulliken", "gasteiger", "formal_charge",
Expand All @@ -28,12 +27,13 @@ class _BaseLabel(ImmutableModel, abc.ABC):
smiles_column: str = "mapped_smiles"
verbose: bool = False

@requires_package("pyarrow")
def _append_column(
self,
table: pa.Table,
key: typing.Union[pa.Field, str],
table: "pyarrow.Table",
key: typing.Union["pyarrow.Field", str],
values: typing.Iterable[typing.Any],
) -> pa.Table:
) -> "pyarrow.Table":
from .utils import _append_column_to_table
return _append_column_to_table(
table,
Expand All @@ -46,9 +46,9 @@ def _append_column(
@abc.abstractmethod
def apply(
self,
table: pa.Table,
table: "pyarrow.Table",
verbose: bool = False,
) -> pa.Table:
) -> "pyarrow.Table":
raise NotImplementedError()


Expand All @@ -62,10 +62,11 @@ class LabelConformers(_BaseLabel):

def apply(
self,
table: pa.Table,
table: "pyarrow.Table",
verbose: bool = False,
):
from openff.toolkit import Molecule
import pyarrow as pa

rms_cutoff = self.rms_cutoff
if not isinstance(rms_cutoff, unit.Quantity):
Expand Down Expand Up @@ -170,9 +171,11 @@ def _assign_charges(

def apply(
self,
table: pa.Table,
table: "pyarrow.Table",
verbose: bool = False,
):
import pyarrow as pa

rows = table.to_pylist()
if verbose:
rows = tqdm.tqdm(rows, desc="Assigning charges")
Expand Down Expand Up @@ -221,9 +224,11 @@ def _calculate_dipoles(

def apply(
self,
table: pa.Table,
table: "pyarrow.Table",
verbose: bool = False,
):
import pyarrow as pa

rows = table.to_pylist()
if verbose:
rows = tqdm.tqdm(rows, desc="Calculating dipoles")
Expand Down Expand Up @@ -322,9 +327,11 @@ def _calculate_esp(

def apply(
self,
table: pa.Table,
table: "pyarrow.Table",
verbose: bool = False,
):
import pyarrow as pa

rows = table.to_pylist()
if verbose:
rows = tqdm.tqdm(rows, desc="Calculating ESPs")
Expand Down Expand Up @@ -402,7 +409,7 @@ def apply(
]

def apply_labellers(
table: pa.Table,
table: "pyarrow.Table",
labellers: typing.Iterable[LabellerType],
verbose: bool = False,
):
Expand All @@ -417,6 +424,9 @@ def apply_labellers_to_batch_file(
labellers: typing.Iterable[LabellerType] = tuple(),
verbose: bool = False,
):
import pyarrow.dataset as ds
import pyarrow.parquet as pq

if not labellers:
return
source = pathlib.Path(source)
Expand Down
6 changes: 2 additions & 4 deletions openff/nagl/label/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import logging
import typing

import numpy as np

import pyarrow.parquet as pq
import pyarrow.dataset as ds
from openff.utilities import requires_package

if typing.TYPE_CHECKING:
import pyarrow as pa

logger = logging.getLogger(__name__)

@requires_package("pyarrow")
def _append_column_to_table(
table: "pa.Table",
key: typing.Union["pa.Field", str],
Expand Down

0 comments on commit 72d52f7

Please sign in to comment.