Skip to content

Commit

Permalink
signed compression
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimieMurdock committed Apr 26, 2016
1 parent ce39c73 commit 54d22a2
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions vsm/corpus/base.py
Expand Up @@ -468,9 +468,9 @@ def __init__(self,

# Integer encoding of a string-type corpus
if len(self.words) < 2 ** 16:
self.dtype = np.uint16
self.dtype = np.int16
else:
self.dtype = np.uint32
self.dtype = np.int32

self.corpus = np.asarray([self.words_int[unicode(word)]
for word in self.corpus],
Expand Down
52 changes: 26 additions & 26 deletions vsm/model/_cgs_update.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions vsm/model/_cgs_update.pyx
Expand Up @@ -4,7 +4,7 @@ import cython
@cython.cdivision(True)

def cgs_update(int itr,
unsigned int [:] corpus,
int [:] corpus,
double [:,:] word_top,
double [:] inv_top_sums,
double [:,:] top_doc,
Expand Down Expand Up @@ -87,7 +87,7 @@ def cgs_update(int itr,
mtrand_state[3], mtrand_state[4])

def cgs_update_short(int itr,
unsigned short [:] corpus,
short [:] corpus,
double [:,:] word_top,
double [:] inv_top_sums,
double [:,:] top_doc,
Expand Down
12 changes: 6 additions & 6 deletions vsm/model/ldacgsmulti.py
Expand Up @@ -194,10 +194,10 @@ def corpus(self, a):
if self._write_globals:
global _corpus
if not '_corpus' in globals():
if self.corpus.dtype == 'uint16':
dtype = 'H'
elif self.corpus.dtype == 'uint32':
dtype = 'I'
if self.corpus.dtype == 'int16':
dtype = 'h'
elif self.corpus.dtype == 'int32':
dtype = 'i'
else:
raise NotImplementedError

Expand Down Expand Up @@ -405,9 +405,9 @@ def update((docs, doc_indices, mtrand_state, dtype)):

indices = np.array([(j - start) for (i,j) in docs], dtype='i')

if dtype == np.uint16:
if dtype == np.int16:
update_fn = cgs_update_short
elif dtype == np.uint32:
elif dtype == np.int32:
update_fn = cgs_update
else:
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions vsm/model/ldacgsseq.py
Expand Up @@ -125,9 +125,9 @@ def train(self, n_iterations=100, verbose=1, **kwargs):
:param kwargs: For compatability with calls to LdaCgsMulti.
:type kwargs: optional
"""
if self.corpus.dtype == np.uint16:
if self.corpus.dtype == np.int16:
update = cgs_update_short
elif self.corpus.dtype == np.uint32:
elif self.corpus.dtype == np.int32:
update = cgs_update
else:
raise NotImplementedError(
Expand Down

0 comments on commit 54d22a2

Please sign in to comment.