Skip to content

Commit

Permalink
Fix up sequential sampling + test!
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Johnson committed Jun 23, 2015
1 parent 2fcf753 commit f64fa1e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
9 changes: 9 additions & 0 deletions test/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,21 @@ def test_feed_forward(self):

def test_predict_sequence(self):
    '''Check the values yielded by ``predict_sequence``.

    Samples from the network returned by ``self._build(13)`` twice: once
    with three parallel streams and once with the default single stream.
    In both cases exactly one item must be yielded per requested step.
    '''
    net = self._build(13)
    seed = [0, 0, 1, 2]
    steps = 4

    # Several streams: every step yields a list holding one integer label
    # per stream.
    samples = list(net.predict_sequence(seed, steps, streams=3))
    assert len(samples) == steps
    for cs in samples:
        assert isinstance(cs, list)
        assert len(cs) == 3
        assert all(isinstance(c, int) for c in cs)

    # Default single stream: every step yields a bare integer label.
    samples = list(net.predict_sequence(seed, steps))
    assert len(samples) == steps
    for cs in samples:
        assert isinstance(cs, int)


class TestWeightedClassifier(TestClassifier):
def _build(self, *hiddens):
Expand Down
21 changes: 11 additions & 10 deletions theanets/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,38 +370,39 @@ def error(self, outputs):
return (weights * nlp).sum() / weights.sum()
return nlp.mean()

def predict_sequence(self, seed, steps, streams=1):
    '''Draw a sequential sample of classes from this network.

    The seed labels are written as one-hot inputs for the first
    ``len(seed)`` time steps of every stream. For each following step the
    network's class probabilities for the most recent time step are used to
    sample one label per stream, and those labels are written back as the
    one-hot input for that step before it is yielded.

    Parameters
    ----------
    seed : list of int
        A list of integer class labels to "seed" the classifier.
    steps : int
        The number of time steps to sample.
    streams : int, optional
        Number of parallel streams to sample from the model. Defaults to 1.

    Yields
    ------
    label(s) : int or list of int
        Yields at each time step an integer class label sampled sequentially
        from the model. If the number of requested streams is greater than
        1, this will be a list containing the corresponding number of class
        labels.
    '''
    start = len(seed)
    # At least two batch columns are always simulated, even for a single
    # stream; in that case only the first column's label is yielded.
    # NOTE(review): the reason for the minimum of 2 is not visible here --
    # presumably a batch of size 1 is problematic downstream; confirm.
    batch = max(2, streams)
    inputs = np.zeros((start + steps, batch, self.layers[0].size), 'f')
    # One-hot encode the seed labels at the leading time steps, identically
    # in every batch column.
    inputs[np.arange(start), :, seed] = 1
    for i in range(start, start + steps):
        chars = []
        # The last entry along the first axis of the prediction is iterated
        # to obtain one pdf per batch column.
        for pdf in self.predict_proba(inputs[:i])[-1]:
            try:
                c = np.random.multinomial(1, pdf).argmax(axis=-1)
            except ValueError:
                # sometimes the pdf triggers a normalization error. just
                # choose greedily in this case.
                c = pdf.argmax(axis=-1)
            # Store plain Python ints so callers never see numpy scalars.
            chars.append(int(c))
        # Feed the sampled labels back in as the one-hot input for step i.
        inputs[i, np.arange(batch), chars] = 1
        yield chars[0] if streams == 1 else chars

0 comments on commit f64fa1e

Please sign in to comment.