Skip to content

Commit

Permalink
Support for sparse matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Apr 10, 2023
1 parent dae432d commit 9a34346
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 20 deletions.
78 changes: 60 additions & 18 deletions inferelator_velocity/tests/test_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import numpy.testing as npt
from scipy.sparse import csr_matrix

from inferelator_velocity.velocity import calc_velocity, _calc_local_velocity

Expand All @@ -25,16 +26,29 @@

class TestVelocity(unittest.TestCase):

def setUp(self) -> None:
self.expr = T_EXPRESSION.copy()
self.ones_graph = np.ones((N, N))
self.knn = KNN.copy()

def test_calc_velocity(self):

correct_velo = np.tile(T_SLOPES[:, None], N).T
velo = calc_velocity(T_EXPRESSION, TIME, np.ones((N, N)),
wrap_time=None)
velo = calc_velocity(
self.expr,
TIME,
self.ones_graph,
wrap_time=None
)

npt.assert_array_almost_equal(correct_velo, velo)

velo_wrap = calc_velocity(T_EXPRESSION, TIME, np.ones((N, N)),
wrap_time=0)
velo_wrap = calc_velocity(
self.expr,
TIME,
self.ones_graph,
wrap_time=0
)

npt.assert_array_almost_equal(correct_velo, velo_wrap)

Expand All @@ -46,13 +60,21 @@ def test_calc_velocity_nan(self):
t = TIME.copy().astype(float)
t[0] = np.nan

velo = calc_velocity(T_EXPRESSION, t, np.ones((N, N)),
wrap_time=None)
velo = calc_velocity(
self.expr,
t,
self.ones_graph,
wrap_time=None
)

npt.assert_array_almost_equal(correct_velo, velo)

velo_wrap = calc_velocity(T_EXPRESSION, t, np.ones((N, N)),
wrap_time=0)
velo_wrap = calc_velocity(
self.expr,
t,
self.ones_graph,
wrap_time=0
)

npt.assert_array_almost_equal(correct_velo, velo_wrap)

Expand All @@ -62,20 +84,32 @@ def test_calc_velocity_wraps(self):

npt.assert_array_almost_equal(
correct_velo,
calc_velocity(T_EXPRESSION, TIME, KNN,
wrap_time=None)
calc_velocity(
self.expr,
TIME,
self.knn,
wrap_time=None
)
)

npt.assert_array_almost_equal(
correct_velo,
calc_velocity(T_EXPRESSION, TIME, KNN,
wrap_time=200)
calc_velocity(
self.expr,
TIME,
self.knn,
wrap_time=200
)
)

npt.assert_array_almost_equal(
correct_velo,
calc_velocity(T_EXPRESSION, TIME, KNN,
wrap_time=0)
calc_velocity(
self.expr,
TIME,
self.knn,
wrap_time=0
)
)

# Correct the first and last velocities
Expand All @@ -86,14 +120,14 @@ def test_calc_velocity_wraps(self):

npt.assert_array_almost_equal(
wrap_edge_correct,
calc_velocity(T_EXPRESSION, TIME, KNN,
calc_velocity(self.expr, TIME, self.knn,
wrap_time=10)
)

def test_single_velocity(self):

velo_0 = _calc_local_velocity(
T_EXPRESSION[0:5],
self.expr[0:5],
TIME[0:5],
2
)
Expand All @@ -103,7 +137,7 @@ def test_single_velocity(self):
def test_single_velocity_wrap(self):

velo_0 = _calc_local_velocity(
T_EXPRESSION[0:6],
self.expr[0:6],
np.hstack((TIME[7:], TIME[0:3])),
2,
wrap_time=N
Expand All @@ -112,10 +146,18 @@ def test_single_velocity_wrap(self):
npt.assert_array_almost_equal(velo_0.ravel(), T_SLOPES)

velo_1 = _calc_local_velocity(
T_EXPRESSION[0:6],
self.expr[0:6],
np.hstack((TIME[7:], TIME[0:3])),
3,
wrap_time=N
)

npt.assert_array_almost_equal(velo_1.ravel(), T_SLOPES)


class TestVelocitySparse(TestVelocity):

def setUp(self) -> None:
self.expr = csr_matrix(T_EXPRESSION)
self.ones_graph = csr_matrix(np.ones((N, N)))
self.knn = csr_matrix(KNN)
10 changes: 8 additions & 2 deletions inferelator_velocity/velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def calc_velocity(
return _np.vstack(
[
_calc_local_velocity(
expr[n_idx, :].copy(),
time_axis[n_idx].copy(),
expr[n_idx, :],
time_axis[n_idx],
(n_idx == i).nonzero()[0][0],
wrap_time=wrap_time
)
Expand Down Expand Up @@ -57,6 +57,12 @@ def _calc_local_velocity(
if wrap_time is not None:
time_axis = _wrap_time(time_axis, wrap_time)

# Densify a sparse matrix
try:
expr = expr.A
except AttributeError:
pass

# Calculate change in expression and time relative to the centerpoint
y_diff = _np.subtract(expr, expr[center_index, :])

Expand Down

0 comments on commit 9a34346

Please sign in to comment.