Skip to content

Commit

Permalink
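Fixed bug in token_collocation_matrix: when both return_vocab and return_bigrams_with_indices are set, return the matrix, both vocabularies, and the bigrams with indices (4-tuple)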
Browse files: browse the repository at this point in the history
  • Loading branch information
internaut committed Apr 18, 2023
1 parent de2e2db commit cd68265
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
8 changes: 6 additions & 2 deletions tests/test_tokenseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,15 @@ def test_token_collocation_matrix_hypothesis(sentences, min_count, pass_embed_to
res = tokenseq.token_collocation_matrix(**args)
vocab1 = vocab2 = bigrams_w_indices = None

if return_vocab:
if return_vocab and return_bigrams_with_indices:
assert isinstance(res, tuple)
assert len(res) == 4
mat, vocab1, vocab2, bigrams_w_indices = res
elif return_vocab and not return_bigrams_with_indices:
assert isinstance(res, tuple)
assert len(res) == 3
mat, vocab1, vocab2 = res
elif return_bigrams_with_indices:
elif not return_vocab and return_bigrams_with_indices:
assert isinstance(res, tuple)
assert len(res) == 2
mat, bigrams_w_indices = res
Expand Down
21 changes: 14 additions & 7 deletions tmtoolkit/tokenseq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def token_collocation_matrix(sentences: List[List[StrOrInt]], min_count: int = 1
return_vocab: bool = False, return_bigrams_with_indices: bool = False) \
-> Union[sparse.csr_matrix,
Tuple[sparse.csr_matrix, np.ndarray, np.ndarray],
Tuple[sparse.csr_matrix, List[Tuple, Tuple[int, int]]]]:
Tuple[sparse.csr_matrix, List[Tuple, Tuple[int, int]]],
Tuple[sparse.csr_matrix, np.ndarray, np.ndarray, List[Tuple, Tuple[int, int]]]]:
"""
Generate a sparse token collocation matrix from bigrams in `sentences`.
Expand All @@ -181,10 +182,14 @@ def token_collocation_matrix(sentences: List[List[StrOrInt]], min_count: int = 1

vocab_dtype = 'uint64' if tokens_as_hashes else 'str'
empty_mat = sparse.csr_matrix([], dtype='uint32', shape=(1, 1))

if return_vocab:
empty_res = (empty_mat, np.array([], dtype=vocab_dtype), np.array([], dtype=vocab_dtype))
elif return_bigrams_with_indices:
empty_vocab1 = np.array([], dtype=vocab_dtype)
empty_vocab2 = empty_vocab1.copy()

if return_vocab and return_bigrams_with_indices:
empty_res = (empty_mat, empty_vocab1, empty_vocab2, [])
elif return_vocab and not return_bigrams_with_indices:
empty_res = (empty_mat, empty_vocab1, empty_vocab2)
elif not return_vocab and return_bigrams_with_indices:
empty_res = (empty_mat, [])
else:
empty_res = empty_mat
Expand Down Expand Up @@ -220,9 +225,11 @@ def token_collocation_matrix(sentences: List[List[StrOrInt]], min_count: int = 1
col_ind = indices_of_matches(bg_second, bg_vocab_second, b_is_sorted=True, check_a_in_b=True)
mat = sparse.coo_matrix((tuple(bigrams.values()), (row_ind, col_ind)), dtype='uint32').tocsr()

if return_vocab:
if return_vocab and return_bigrams_with_indices:
return mat, bg_vocab_first, bg_vocab_second, list(zip(bigrams.keys(), zip(row_ind, col_ind)))
elif return_vocab and not return_bigrams_with_indices:
return mat, bg_vocab_first, bg_vocab_second
elif return_bigrams_with_indices:
elif not return_vocab and return_bigrams_with_indices:
return mat, list(zip(bigrams.keys(), zip(row_ind, col_ind)))
else:
return mat
Expand Down

0 comments on commit cd68265

Please sign in to comment.