Skip to content

Commit

Permalink
Sparse support
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 10, 2023
1 parent 9a34346 commit 91293f5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
6 changes: 6 additions & 0 deletions inferelator_velocity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ def _estimate_decay(
:rtype: float, float
"""

# Cast to dense if needed
try:
expression_data = expression_data.A.ravel()
except AttributeError:
pass

# If there is an estimate for alpha,
# modify velocity to remove the alpha component
if alpha_est is not None:
Expand Down
64 changes: 48 additions & 16 deletions inferelator_velocity/tests/test_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import numpy as np
import numpy.testing as npt
from scipy.sparse import (
csr_matrix,
vstack
)

from inferelator_velocity.decay import calc_decay, calc_decay_sliding_windows

Expand All @@ -21,11 +25,20 @@

class TestDecay(unittest.TestCase):

def setUp(self) -> None:
self.expr = V_EXPRESSION.copy()
self.vstack = np.vstack
self.expr_nan = np.full_like(V_EXPRESSION, np.nan)
self.expr_one = np.array([1, 0, 0, 0])

def test_calc_decay_no_alpha(self):

decays, decay_se, alpha_est = calc_decay(V_EXPRESSION, VELOCITY,
decay_quantiles=(0, 1),
include_alpha=False)
decays, decay_se, alpha_est = calc_decay(
self.expr,
VELOCITY,
decay_quantiles=(0, 1),
include_alpha=False
)

correct_ses = np.zeros_like(decay_se)

Expand All @@ -38,11 +51,14 @@ def test_calc_decay_no_alpha(self):
def test_calc_decay_nan(self):

v = np.vstack((VELOCITY, np.full_like(VELOCITY, np.nan)))
e = np.vstack((V_EXPRESSION, np.full_like(V_EXPRESSION, np.nan)))
e = self.vstack((self.expr, self.expr_nan))

decays, decay_se, alpha_est = calc_decay(e, v,
decay_quantiles=(0, 1),
include_alpha=False)
decays, decay_se, alpha_est = calc_decay(
e,
v,
decay_quantiles=(0, 1),
include_alpha=False
)

correct_ses = np.zeros_like(decay_se)

Expand All @@ -55,12 +71,15 @@ def test_calc_decay_nan(self):
def test_calc_decay_alpha(self):

velo = np.vstack((VELOCITY, np.array([1, 0, 0, 0])))
expr = np.vstack((V_EXPRESSION, np.array([1, 0, 0, 0])))
expr = self.vstack((self.expr, self.expr_one))

decays, decay_se, alpha_est = calc_decay(expr, velo,
decay_quantiles=(0, 1),
include_alpha=True,
alpha_quantile=1.0)
decays, decay_se, alpha_est = calc_decay(
expr,
velo,
decay_quantiles=(0, 1),
include_alpha=True,
alpha_quantile=1.0
)

correct_alpha = np.maximum(np.max(velo, axis=0), 0)

Expand All @@ -73,10 +92,14 @@ def test_calc_decay_alpha(self):

def test_calc_decay_window(self):

d, s, a, c = calc_decay_sliding_windows(V_EXPRESSION, VELOCITY, TIME,
decay_quantiles=(0, 1),
include_alpha=False,
n_windows=5)
d, s, a, c = calc_decay_sliding_windows(
self.expr,
VELOCITY,
TIME,
decay_quantiles=(0, 1),
include_alpha=False,
n_windows=5
)

self.assertEqual(len(d), 5)
self.assertEqual(len(d[0]), 4)
Expand All @@ -85,3 +108,12 @@ def test_calc_decay_window(self):

for d_win in d:
npt.assert_array_almost_equal(d_win, correct_decays)


class TestDecaySparse(TestDecay):

def setUp(self) -> None:
self.expr = csr_matrix(V_EXPRESSION)
self.vstack = vstack
self.expr_nan = csr_matrix(np.full_like(V_EXPRESSION, np.nan))
self.expr_one = csr_matrix(np.array([1, 0, 0, 0]))

0 comments on commit 91293f5

Please sign in to comment.