Skip to content

Commit

Permalink
Add a Text helper class for recurrent text processing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Johnson committed Jun 19, 2015
1 parent 89dec00 commit 7100f06
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 0 deletions.
35 changes: 35 additions & 0 deletions test/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,41 @@ def test_batches_unlabeled(self):
assert f()[0].shape == (STEPS, BATCH, INS)


class TestText:
    '''Checks for the ``theanets.recurrent.Text`` helper.'''

    TXT = 'hello world, how are you!'

    def setUp(self):
        # Fixture with an explicit alphabet; 'a', 'y' and 'u' fall outside it.
        self.txt = theanets.recurrent.Text(
            self.TXT, alpha='helo wrd,!', unknown='_')

    def test_min_count(self):
        # (min_count, expected filtered text, expected derived alphabet)
        cases = (
            (2, 'hello worl__ how _re _o__', ' ehlorw'),
            (3, '__llo _o_l__ _o_ ___ _o__', ' lo'),
        )
        for threshold, want_text, want_alpha in cases:
            txt = theanets.recurrent.Text(
                self.TXT, min_count=threshold, unknown='_')
            assert txt.text == want_text
            assert txt.alpha == want_alpha

    def test_alpha(self):
        fixture = self.txt
        assert fixture.text == 'hello world, how _re _o_!'
        assert fixture.alpha == 'helo wrd,!'

    def test_encode(self):
        # Index 0 is reserved for the "unknown" character.
        for raw, want in (('hello!', [1, 2, 3, 3, 4, 10]),
                          ('you!', [0, 4, 0, 10])):
            assert self.txt.encode(raw) == want

    def test_decode(self):
        for codes, want in (([1, 2, 3, 3, 4, 10], 'hello!'),
                            ([0, 4, 0, 10], '_o_!')):
            assert self.txt.decode(codes) == want

    def test_classifier_batches(self):
        steps, size = 3, 2
        make_batch = self.txt.classifier_batches(steps, size)
        width = 1 + len(self.txt.alpha)
        assert len(make_batch()) == 2
        assert make_batch()[0].shape == (steps, size, width)
        assert make_batch()[1].shape == (steps, size)
        # Two independent draws are expected to differ.
        first = make_batch()[0]
        second = make_batch()[0]
        assert not np.allclose(first, second)


class Base:
def setUp(self):
np.random.seed(3)
Expand Down
132 changes: 132 additions & 0 deletions theanets/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

'''This module contains recurrent network structures.'''

import collections
import numpy as np
import re
import sys
import theano.tensor as TT

Expand Down Expand Up @@ -52,6 +54,104 @@ def labeled_sample():
return unlabeled_sample if labels is None else labeled_sample


class Text(object):
    '''A class for handling sequential text data.

    Parameters
    ----------
    text : str
        A blob of text.
    alpha : str, optional
        An alphabet to use for representing characters in the text. If not
        provided, all characters from the text occurring at least
        ``min_count`` times will be used (in sorted order).
    min_count : int, optional
        If the alphabet is to be computed from the text, discard characters
        that occur fewer than this number of times. Defaults to 2.
    unknown : str, optional
        A character to use to represent "out-of-alphabet" characters in the
        text. This must not be in the alphabet. Defaults to the NUL
        character.

    Attributes
    ----------
    text : str
        A blob of text, with all non-alphabet characters replaced by the
        "unknown" character.
    alpha : str
        A string containing each character in the alphabet.
    '''

    def __init__(self, text, alpha=None, min_count=2, unknown='\0'):
        if alpha is None:
            counts = collections.Counter(text)
            alpha = ''.join(sorted(
                a for a, c in counts.items() if c >= min_count))
        self.alpha = alpha
        assert unknown not in self.alpha
        # Replace out-of-alphabet characters with a plain membership test
        # rather than a regex: an empty alphabet would yield the invalid
        # pattern "[^]", and a backslash "unknown" would be misread as a
        # regex replacement template.
        known = frozenset(self.alpha)
        self.text = ''.join(c if c in known else unknown for c in text)
        # Index 0 is the unknown character; alphabet characters follow.
        self._rev_index = unknown + self.alpha
        self._fwd_index = dict(
            zip(self._rev_index, range(1 + len(self.alpha))))

    def encode(self, txt):
        '''Encode a text string by replacing characters with alphabet index.

        Parameters
        ----------
        txt : str
            A string to encode.

        Returns
        -------
        classes : list of int
            A sequence of alphabet index values corresponding to the given
            text. Characters outside the alphabet are encoded as 0.
        '''
        return [self._fwd_index.get(c, 0) for c in txt]

    def decode(self, enc):
        '''Decode a sequence of alphabet index values back into text.

        Parameters
        ----------
        enc : list of int
            A sequence of alphabet index values to convert to text. Index 0
            maps to the "unknown" character.

        Returns
        -------
        txt : str
            A string containing corresponding characters from the alphabet.
        '''
        return ''.join(self._rev_index[c] for c in enc)

    def classifier_batches(self, time_steps, batch_size):
        '''Create a callable that returns a batch of training data.

        Parameters
        ----------
        time_steps : int
            Number of time steps in each batch.
        batch_size : int
            Number of training examples per batch.

        Returns
        -------
        batch : callable
            A callable that, when called, returns a batch of data that can be
            used to train a classifier model: a list ``[inputs, outputs]``
            where ``inputs`` is a one-hot float array of shape
            ``(time_steps, batch_size, 1 + len(alpha))`` and ``outputs`` is an
            integer array of shape ``(time_steps, batch_size)`` holding the
            class of the character following each input character.

        Raises
        ------
        ValueError
            Raised by the returned callable if the text is shorter than
            ``time_steps + 1`` characters.
        '''
        assert batch_size >= 2, 'batch_size must be at least 2!'

        steps = np.arange(time_steps)

        def batch():
            # Each example needs time_steps inputs plus one trailing target,
            # so valid start offsets are 0 .. len(text) - time_steps - 1.
            n_offsets = len(self.text) - time_steps
            if n_offsets < 1:
                raise ValueError(
                    'text must contain at least time_steps + 1 characters')
            inputs = np.zeros(
                (time_steps, batch_size, 1 + len(self.alpha)), 'f')
            outputs = np.zeros((time_steps, batch_size), 'i')
            for b in range(batch_size):
                offset = np.random.randint(n_offsets)
                enc = self.encode(self.text[offset:offset + time_steps + 1])
                inputs[steps, b, enc[:-1]] = 1
                outputs[steps, b] = enc[1:]
            return [inputs, outputs]

        return batch


# Module-level flag, initially False. Presumably flipped once a one-time
# warning has been emitted; its users are not visible here (TODO confirm).
_warned = False


Expand Down Expand Up @@ -267,3 +367,35 @@ def error(self, outputs):
if self.weighted:
return (weights * nlp).sum() / weights.sum()
return nlp.mean()

def sample(self, seed, n, streams=1):
    '''Draw a sample of n characters from a sequential classifier model.

    Parameters
    ----------
    seed : list of int
        A list of integer class labels to "seed" the classifier.
    n : int
        The number of time steps to sample.
    streams : int, optional
        Number of parallel streams to sample from the model. Defaults to 1.

    Yields
    ------
    label(s) : int or array of ints
        Yields at each time step an integer class label sampled sequentially
        from the model. If the number of requested streams is greater than
        1, this will be an array of the corresponding number of integers.
    '''
    s = len(seed)
    # At least two streams are always simulated, even when one is requested.
    # NOTE(review): presumably the model cannot handle a batch of size 1.
    b = max(2, streams)
    inputs = np.zeros((s + n, b, self.layers[-1].size), 'f')
    # One-hot encode the seed labels identically in every stream.
    inputs[np.arange(s), :, seed] = 1
    stream_index = np.arange(b)
    for i in range(s, s + n):
        # Class distribution for the latest time step, one row per stream.
        pdfs = self.predict_proba(inputs[:i])[-1]
        c = np.empty(b, dtype=int)
        # multinomial() only accepts a 1-D distribution, so draw one label
        # per stream instead of passing the whole 2-D array at once.
        for k, pdf in enumerate(pdfs):
            try:
                c[k] = np.random.multinomial(1, pdf).argmax()
            except ValueError:
                # The pdf can fail numpy's normalization check (e.g. through
                # rounding); choose greedily for this stream in that case.
                c[k] = np.argmax(pdf)
        # Feed the sampled labels back in as the input for time step i.
        inputs[i, stream_index, c] = 1
        yield c if streams >= 2 else c[0]

0 comments on commit 7100f06

Please sign in to comment.