Skip to content

Commit

Permalink
Merge ddf06cb into eac3c80
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jan 10, 2021
2 parents eac3c80 + ddf06cb commit cc05417
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 190 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ jobs:
- name: Install build dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel numpy cython
python -m pip install setuptools wheel
- name: Build wheel
run: |
python setup.py build_ext --inplace
python setup.py sdist bdist_wheel
- name: Test wheel
run: |
Expand Down
1 change: 0 additions & 1 deletion doc/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@ dependencies:
- rdkit>=2016.03.4
- numpy>=1.11.3
- scipy>=0.18.0
- cython>=0.25.2
- sdaxen_python_utilities>=0.1.4
- sphinxcontrib-programoutput
171 changes: 0 additions & 171 deletions e3fp/fingerprint/metrics/_fast.pyx

This file was deleted.

91 changes: 88 additions & 3 deletions e3fp/fingerprint/metrics/array_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import scipy
from scipy.sparse import csr_matrix, issparse, vstack
from ._fast import soergel as fast_soergel
from e3fp.util import maybe_jit


def tanimoto(X, Y=None):
Expand Down Expand Up @@ -58,15 +58,22 @@ def soergel(X, Y=None):
-------
soergel : array of shape (`n_fprints_X`, `n_fprints_Y`)
Notes
--------
If Numba is available, this function is jit-compiled and much more efficient.
See Also
--------
tanimoto: A fast version of this function for binary data.
pearson: Pearson correlation, also appropriate for non-binary data.
cosine, dice
"""
X, Y = _check_array_pair(X, Y)
return fast_soergel(X, Y, sparse=issparse(X))

S = np.empty((X.shape[0], Y.shape[0]), dtype=np.float64)
if issparse(X):
return _sparse_soergel(X.data, X.indices, X.indptr,
Y.data, Y.indices, Y.indptr, S)
return _dense_soergel(X, Y, S)

def dice(X, Y=None):
"""Compute the Dice coefficients between `X` and `Y`.
Expand Down Expand Up @@ -212,3 +219,81 @@ def _sparse_cosine(X, Y):
XY = (X * Y.T).toarray()
with np.errstate(divide="ignore"): # handle 0 in denominator
return np.nan_to_num(XY / (Xnorm * Ynorm.T))

@maybe_jit(nopython=True, nogil=True, cache=True)
def _dense_soergel(X, Y, S):
for ix in range(S.shape[0]):
for iy in range(S.shape[1]):
sum_abs_diff = 0
sum_max = 0
for j in range(X.shape[1]):
diff = X[ix, j] - Y[iy, j]
if diff > 0:
sum_abs_diff += diff
sum_max += X[ix, j]
else:
sum_abs_diff -= diff
sum_max += Y[iy, j]

if sum_max == 0:
S[ix, iy] = 0
continue
S[ix, iy] = 1 - sum_abs_diff / sum_max
return S

@maybe_jit(nopython=True, nogil=True, cache=True)
def _sparse_soergel(Xdata, Xindices, Xindptr, Ydata, Yindices, Yindptr, S):
for ix in range(S.shape[0]):
if Xindptr[ix] == Xindptr[ix + 1]:
for iy in range(S.shape[1]): # no X values in row
S[ix, iy] = 0
continue
jxindmax = Xindptr[ix + 1] - 1
for iy in range(S.shape[1]):
if Yindptr[iy] == Yindptr[iy + 1]: # no Y values in row
S[ix, iy] = 0
continue

sum_abs_diff = 0
sum_max = 0
# Implementation of the final step of merge sort
jyindmax = Yindptr[iy + 1] - 1
jx = Xindptr[ix]
jy = Yindptr[iy]
while jx <= jxindmax and jy <= jyindmax:
jxind = Xindices[jx]
jyind = Yindices[jy]
if jxind < jyind:
sum_max += Xdata[jx]
sum_abs_diff += Xdata[jx]
jx += 1
elif jyind < jxind:
sum_max += Ydata[jy]
sum_abs_diff += Ydata[jy]
jy += 1
else:
diff = Xdata[jx] - Ydata[jy]
if diff > 0:
sum_abs_diff += diff
sum_max += Xdata[jx]
else:
sum_abs_diff -= diff
sum_max += Ydata[jy]
jx += 1
jy += 1

while jx <= jxindmax:
sum_max += Xdata[jx]
sum_abs_diff += Xdata[jx]
jx += 1

while jy <= jyindmax:
sum_max += Ydata[jy]
sum_abs_diff += Ydata[jy]
jy += 1

if sum_max == 0:
S[ix, iy] = 0
continue
S[ix, iy] = 1 - sum_abs_diff / sum_max
return S
19 changes: 19 additions & 0 deletions e3fp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ class E3FPEfficiencyWarning(E3FPWarning, RuntimeWarning):
"""A warning class for a potentially inefficient process."""


def maybe_jit(*args, **kwargs):
"""Decorator to jit a function using Numba if available.
Usage is identical to `numba.jit`.
"""
def wrapper(func):
try:
import numba
has_numba = True
except ImportError:
has_numba = False

if has_numba:
return numba.jit(*args, **kwargs)(func)
else:
return func
return wrapper


class deprecated(object):
"""Decorator to mark a function as deprecated.
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ dependencies:
- rdkit>=2016.03.4
- numpy>=1.11.3
- scipy>=0.18.0
- cython>=0.25.2
- setuptools
- sdaxen_python_utilities>=0.1.4
# optional
Expand Down
12 changes: 0 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from setuptools import setup
from setuptools.command.build_ext import build_ext
from setuptools.extension import Extension
from Cython.Build import cythonize
import numpy as np
from e3fp import version

Expand All @@ -14,7 +12,6 @@
"scipy>=0.18.0",
"numpy>=1.11.3",
"mmh3>=2.3.1",
"cython>=0.25.2",
"sdaxen_python_utilities>=0.1.4",
]
if ON_RTD: # ReadTheDocs can't handle C libraries
Expand All @@ -28,7 +25,6 @@
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Cython",
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
Expand All @@ -43,13 +39,6 @@ def get_readme():
with open("README.rst") as f:
return f.read()

ext_modules = [
Extension(
"e3fp.fingerprint.metrics._fast",
sources=["e3fp/fingerprint/metrics/_fast.pyx"],
include_dirs=[np.get_include()],
)
]

setup(
name="e3fp",
Expand All @@ -74,5 +63,4 @@ def get_readme():
include_package_data=True,
test_suite="nose.collector",
tests_require=test_requirements,
ext_modules=cythonize(ext_modules),
)

0 comments on commit cc05417

Please sign in to comment.