Skip to content

Commit

Permalink
fixing random seed types on load. closes #112
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimieMurdock committed May 30, 2015
1 parent 6906011 commit 28c7d68
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
28 changes: 28 additions & 0 deletions unit_tests/tests_ldacgsmulti.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ def test_LdaCgsMulti_IO(self):
finally:
os.remove(tmp.name)

def test_LdaCgsMulti_SeedTypes(self):
""" Test for issue #74 issues. """

from tempfile import NamedTemporaryFile
import os

c = random_corpus(1000, 50, 6, 100)
tmp = NamedTemporaryFile(delete=False, suffix='.npz')
try:
m0 = LdaCgsMulti(c, 'document', K=10)
m0.train(n_iterations=20)
m0.save(tmp.name)
m1 = LdaCgsMulti.load(tmp.name)

for s0, s1 in zip(m0.seeds, m1.seeds):
assert type(s0) == type(s1)
for s0, s1 in zip(m0._mtrand_states,m1._mtrand_states):
for i in range(5):
assert type(s0[i]) == type(s1[i])
finally:
os.remove(tmp.name)

def test_LdaCgsMulti_random_seeds(self):
from vsm.corpus.util.corpusbuilders import random_corpus

Expand Down Expand Up @@ -165,6 +187,12 @@ def test_LdaCgsMulti_IO(self):
p.start()
p.join()

def test_LdaCgsMulti_SeedTypes(self):
t = MPTester()
p = Process(target=t.test_LdaCgsMulti_SeedTypes, args=())
p.start()
p.join()

def test_LdaCgsMulti_random_seeds(self):
t = MPTester()
p = Process(target=t.test_LdaCgsMulti_random_seeds, args=())
Expand Down
23 changes: 23 additions & 0 deletions unit_tests/tests_ldacgsseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ def test_LdaCgsSeq_IO(self):
self.assertTrue(not hasattr(m1, 'log_prob'))
finally:
os.remove(tmp.name)

def test_LdaCgsSeq_SeedTypes(self):
""" Test for issue #74 issues. """

from tempfile import NamedTemporaryFile
import os

c = random_corpus(1000, 50, 6, 100)
tmp = NamedTemporaryFile(delete=False, suffix='.npz')
try:
m0 = LdaCgsSeq(c, 'document', K=10)
m0.train(n_iterations=20)
m0.save(tmp.name)
m1 = LdaCgsSeq.load(tmp.name)

self.assertTrue(type(m0.seed) == type(m1.seed))
self.assertTrue(type(m0._mtrand_state[0]) == type(m1._mtrand_state[0]))
self.assertTrue(type(m0._mtrand_state[1]) == type(m1._mtrand_state[1]))
self.assertTrue(type(m0._mtrand_state[2]) == type(m1._mtrand_state[2]))
self.assertTrue(type(m0._mtrand_state[3]) == type(m1._mtrand_state[3]))
self.assertTrue(type(m0._mtrand_state[4]) == type(m1._mtrand_state[4]))
finally:
os.remove(tmp.name)


def test_LdaCgsQuerySampler_init(self):
Expand Down
22 changes: 11 additions & 11 deletions vsm/model/ldafunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,21 @@ def load_lda(filename, ldaclass):
m.log_probs = arrays_in['log_probs'].tolist()

if 'seed' in arrays_in:
m.seed = arrays_in['seed']
m._mtrand_state = (arrays_in['mtrand_state0'],
m.seed = int(arrays_in['seed'])
m._mtrand_state = (str(arrays_in['mtrand_state0']),
arrays_in['mtrand_state1'],
arrays_in['mtrand_state2'],
arrays_in['mtrand_state3'],
arrays_in['mtrand_state4'])
int(arrays_in['mtrand_state2']),
int(arrays_in['mtrand_state3']),
float(arrays_in['mtrand_state4']))

if 'seeds' in arrays_in:
m.seeds = list(arrays_in['seeds'])
m._mtrand_states = zip(arrays_in['mtrand_states0'],
m.seeds = map(int, list(arrays_in['seeds']))
m._mtrand_states = zip(map(str, arrays_in['mtrand_states0']),
arrays_in['mtrand_states1'],
arrays_in['mtrand_states2'],
arrays_in['mtrand_states3'],
arrays_in['mtrand_states4'])

map(int, arrays_in['mtrand_states2']),
map(int, arrays_in['mtrand_states3']),
map(float, arrays_in['mtrand_states4']))
m.n_proc = len(m.seeds)

return m

Expand Down

0 comments on commit 28c7d68

Please sign in to comment.