Skip to content

Commit

Permalink
encapsulate dict to csr_matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
dkaslovsky committed Oct 13, 2018
1 parent 00d152d commit f9c5815
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
10 changes: 4 additions & 6 deletions coupled_biased_random_walks/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from coupled_biased_random_walks.count import (ObservationCounter,
get_feature_name, get_mode)
from coupled_biased_random_walks.matrix import (random_walk,
from coupled_biased_random_walks.matrix import (dict_to_csr_matrix,
random_walk,
row_normalize_csr_matrix)


Expand Down Expand Up @@ -138,11 +139,8 @@ def _compute_biased_transition_matrix(self):
raise CBRWFitError('all biased joint probabilities are zero')

# construct sparse matrix
# csr_matrix cannot accept iterators so cast to list for python 3
data = list(prob_idx.values())
idx = zip(*list(prob_idx.keys()))
shape = len(self._counter.index)
trans_matrix = csr_matrix((data, idx), shape=(shape, shape))
n_features = len(self._counter.index)
trans_matrix = dict_to_csr_matrix(prob_idx, shape=n_features)
return row_normalize_csr_matrix(trans_matrix)

def _compute_biases(self):
Expand Down
17 changes: 17 additions & 0 deletions coupled_biased_random_walks/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ def random_walk(transition_matrix, alpha, err_tol, max_iter):
return pi


def dict_to_csr_matrix(data_dict, shape):
    """
    Converts dict of index -> value to csr_matrix
    :param data_dict: dict mapping matrix index tuple (row, col) to corresponding matrix value
    :param shape: (row, col) tuple for shape of csr_matrix (also accepts a single
        integer, including numpy integer types, when row = col)
    :return: csr_matrix of the requested shape holding the values of data_dict
    :raises ValueError: if data_dict is empty
    """
    from numbers import Integral

    if not data_dict:
        raise ValueError('dict must not be empty')

    # numpy integers are not instances of int on python 3, so test against the
    # abstract Integral type and normalize to plain ints for the shape tuple
    if isinstance(shape, Integral):
        shape = (int(shape), int(shape))

    # csr_matrix should be handed concrete sequences rather than lazy
    # iterators/views, so materialize both the values and the indices;
    # iterating the dict yields its keys in the same order as .values()
    data = list(data_dict.values())
    idx = tuple(zip(*data_dict))
    return csr_matrix((data, idx), shape=shape)


def row_normalize_csr_matrix(matrix):
"""
Row normalize a csr matrix without mutating the input
Expand Down
52 changes: 47 additions & 5 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
from six import iteritems
from six.moves import zip

from coupled_biased_random_walks.matrix import (random_walk,
from coupled_biased_random_walks.matrix import (dict_to_csr_matrix,
random_walk,
row_normalize_csr_matrix)

np.random.seed(0)


def construct_2x2_matrix(data):
def construct_2x2_csr_matrix(data):
"""
Construct a 2x2 csr_matrix
:param data: list of length 4 of data for csr matrix corresponding to idx position
"""
idx = [(0, 0), (0, 1), (1, 0), (1, 1)]
matrix_data = []
matrix_idx = []
Expand All @@ -24,6 +29,16 @@ def construct_2x2_matrix(data):
return csr_matrix(([], ([], [])), shape=(2 ,2))


def csr_matrix_equality(c1, c2):
    """
    Return True when two csr matrices have the same shape and the same entries
    :param c1: first csr_matrix
    :param c2: second csr_matrix
    """
    same_shape = c1.shape == c2.shape
    if same_shape:
        # comparing for inequality is the cheaper direction: the result only
        # stores the positions that differ, so equal matrices give nnz of 0
        differing = c1 != c2
        return differing.nnz == 0
    return False


class TestRandomWalk(unittest.TestCase):
"""
Unit tests for random_walk
Expand All @@ -36,21 +51,48 @@ class TestRandomWalk(unittest.TestCase):
def test_random_walk(self):
# prob 0.5, 0.5
data = [0, 1, 1, 0]
matrix = construct_2x2_matrix(data)
matrix = construct_2x2_csr_matrix(data)
pi = random_walk(matrix, alpha=self.alpha, err_tol=self.err_tol, max_iter=self.max_iter)
self.assertEqual(len(pi), 2)
self.assertAlmostEqual(pi[0], 0.5, 3)
self.assertAlmostEqual(pi[1], 0.5, 3)

# prob 1, 0 (alpha = 1)
data = [1, 0, 1, 0]
matrix = construct_2x2_matrix(data)
matrix = construct_2x2_csr_matrix(data)
pi = random_walk(matrix, alpha=1, err_tol=self.err_tol, max_iter=self.max_iter)
self.assertEqual(len(pi), 2)
self.assertAlmostEqual(pi[0], 1, 3)
self.assertAlmostEqual(pi[1], 0, 3)


class TestDictToCSRMatrix(unittest.TestCase):
    """
    Table-driven unit tests for dict_to_csr_matrix
    """

    def test_dict_to_csr_matrix(self):
        # each entry pairs an input dict and square shape with the 2x2
        # csr_matrix the conversion is expected to produce
        table = {
            'test 1': {
                'data_dict': {(0, 1): 25, (1, 0): 16},
                'shape': 2,
                'expected': construct_2x2_csr_matrix([0, 25, 16, 0]),
            },
            'test 2': {
                'data_dict': {(0, 0): 1, (1, 1): 1},
                'shape': 2,
                'expected': construct_2x2_csr_matrix([1, 0, 0, 1]),
            },
        }

        for test_name, params in iteritems(table):
            result = dict_to_csr_matrix(params['data_dict'], params['shape'])
            matches = csr_matrix_equality(result, params['expected'])
            # test_name is passed as the failure message to identify the case
            self.assertTrue(matches, test_name)


class TestRowNormalizeCSRMatrix(unittest.TestCase):
"""
Unit tests for row_normalize_csr_matrix
Expand Down Expand Up @@ -87,7 +129,7 @@ def test_valid_row_normalize(self):
}

for test_name, test in iteritems(valid_table):
matrix = construct_2x2_matrix(test['data'])
matrix = construct_2x2_csr_matrix(test['data'])
normalized = row_normalize_csr_matrix(matrix)
row_sums = normalized.sum(axis=1)
self.assertAlmostEqual(row_sums[0], test['expected_row_0'], 3, test_name)
Expand Down

0 comments on commit f9c5815

Please sign in to comment.