-
Notifications
You must be signed in to change notification settings - Fork 349
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
236 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import random | ||
import bisect | ||
import re | ||
from .chain import BEGIN, END, accumulate | ||
from .text import Text, ParamError, DEFAULT_TRIES | ||
|
||
# Runs of whitespace separate words; compiled once at import time.
word_split_pattern = re.compile(r"\s+")


def word_split(sentence):
    """
    Split `sentence` on runs of whitespace and return the list of pieces.
    Leading or trailing whitespace produces an empty string at that end.
    """
    pieces = word_split_pattern.split(sentence)
    return pieces
|
||
|
||
class CompiledText(object):
    """
    A "compiled" version of a Text model, where state transition functions have been
    pre-computed into a faster and more compact form.

    For every state of the source chain, `self.sxf` holds a pair of parallel
    lists `(words, cff)`: the candidate next words and the running totals of
    their weights, so a next word can be drawn with one binary search.
    """

    def __init__(self, model):
        """
        model: a Text model to compile

        Raises ParamError if `model` is not a `Text` instance.
        """
        if not isinstance(model, Text):
            raise ParamError("unrecognized markovify model type %s" % (type(model),))

        def compile_next(next_dict):
            # keys() and values() iterate in matching order, so words[i]
            # owns the weight interval that ends at cff[i].
            words = list(next_dict.keys())
            cff = list(accumulate(next_dict.values()))
            return (words, cff)

        # the compiled state transition function:
        self.sxf = {
            state: compile_next(next_dict)
            for (state, next_dict) in model.chain.model.items()
        }
        self.state_size = model.state_size

    def _gen_(self, init_state, max_words):
        """
        Walk the chain from `init_state` and return the generated words as a list.

        init_state: a sequence of `self.state_size` words, or a falsy value to
            start from the all-BEGIN state. It is coerced to a tuple so that it
            can be used as a key of `self.sxf`.
        max_words: stop as soon as the list holds more than this many words;
            None disables the limit.

        The words of `init_state` (minus leading BEGIN markers) open the
        returned list. Raises KeyError if a state is not present in `self.sxf`.
        """
        state = tuple(init_state) if init_state else (BEGIN,) * self.state_size

        # Seed the output with the initial state, dropping leading BEGIN padding.
        seq = list(state)
        while seq and seq[0] == BEGIN:
            del seq[0]

        while True:
            words, cff = self.sxf[state]
            # r lies in [0, total weight); bisect finds the first running
            # total strictly greater than r, i.e. the interval containing r.
            r = random.random() * cff[-1]
            word = words[bisect.bisect(cff, r)]
            if word == END:
                break
            seq.append(word)
            if max_words is not None and len(seq) > max_words:
                break
            state = state[1:] + (word,)
        return seq

    def make_sentence(self, init_state=None, tries=DEFAULT_TRIES, max_words=None):
        """
        Attempts `tries` (default: `DEFAULT_TRIES`) times to generate a valid
        sentence, based on the model.
        If successful, returns the sentence as a string. If not, returns None.
        If `init_state` (a tuple of `self.state_size` words) is not specified,
        generation starts from the all-BEGIN state, i.e. at a sentence start
        chosen at random in accordance with the model.
        If `max_words` is specified, it will attempt generation until a sentence
        of at most max_words long is created (or number of tries expires)

        Raises KeyError if `init_state` is not a state known to the model.
        """
        # Let the generator overshoot by one word so that over-long sentences
        # can be recognised and rejected here.
        limit = None if max_words is None else max_words + 1
        for _ in range(tries):
            seq = self._gen_(init_state, limit)
            if max_words is not None and len(seq) > max_words:
                continue
            return " ".join(seq)
        return None

    def make_short_sentence(self, max_chars, min_chars=0, **kwargs):
        """
        Tries making a sentence of no more than `max_chars` characters and optionally
        no less than `min_chars` characters, passing **kwargs to `self.make_sentence`.
        Returns the sentence, or None if no attempt produced a suitable one.
        """
        tries = kwargs.get('tries', DEFAULT_TRIES)

        for _ in range(tries):
            sentence = self.make_sentence(**kwargs)
            if sentence and min_chars <= len(sentence) <= max_chars:
                return sentence
        return None

    def make_sentence_with_start(self, beginning, strict=True, **kwargs):
        """
        Tries making a sentence that begins with `beginning` string,
        which should be a string of one to `self.state_size` words known
        to exist in the corpus.
        If strict == True, then markovify will draw its initial inspiration
        only from sentences that start with the specified word/phrase.
        If strict == False, then markovify will draw its initial inspiration
        from any sentence containing the specified word/phrase.
        **kwargs are passed to `self.make_sentence`

        Returns the sentence, or None if no candidate start state produced one.
        Raises ParamError if `beginning` does not contain 1 to
        `self.state_size` words, and KeyError if the resulting start state
        is not known to the model.
        """
        split = tuple(word_split(beginning))
        word_count = len(split)

        if word_count == self.state_size:
            init_states = [split]
        elif 0 < word_count < self.state_size:
            if strict:
                # Pad on the left with BEGIN so the phrase opens a sentence.
                init_states = [(BEGIN,) * (self.state_size - word_count) + split]
            else:
                # Any known state whose non-BEGIN words start with the phrase,
                # which also covers states padded with BEGIN.
                init_states = [
                    key for key in self.sxf
                    if tuple(w for w in key if w != BEGIN)[:word_count] == split
                ]
                random.shuffle(init_states)
        else:
            err_msg = "`make_sentence_with_start` for this model requires a string containing 1 to {0} words. Yours has {1}: {2}".format(self.state_size, word_count, str(split))
            raise ParamError(err_msg)

        for init_state in init_states:
            output = self.make_sentence(init_state, **kwargs)
            if output is not None:
                return output
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import unittest | ||
import markovify | ||
import sys, os | ||
import operator | ||
|
||
# Shared fixtures: the corpus is read and both models are built once at
# import time, then reused by the tests below.
_sherlock_path = os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")
with open(_sherlock_path) as f:
    sherlock = f.read()

# Plain model first, then its compiled counterpart.
sherlock_model = markovify.Text(sherlock)
sherlock_compiled = markovify.CompiledText(sherlock_model)
|
||
class MarkovifyCompiledTest(unittest.TestCase):
    """
    Tests for `markovify.CompiledText`, run against the module-level Sherlock
    fixtures. unittest assertion methods are used instead of bare `assert`
    so the checks still run under `python -O` and report useful messages.
    """

    def assertStartsWith(self, sent, start_str):
        """Fail unless `sent` is a string that begins with `start_str`."""
        self.assertIsNotNone(sent)
        self.assertTrue(
            sent.startswith(start_str),
            "%r does not start with %r" % (sent, start_str),
        )

    def test_constructor_typechecking(self):
        with self.assertRaises(markovify.text.ParamError):
            markovify.CompiledText("this string is not a Text model")

    def test_sherlock(self):
        text_model = sherlock_compiled
        sent = text_model.make_sentence()
        self.assertIsNotNone(sent)
        self.assertNotEqual(len(sent), 0)

    def test_make_sentence_with_start(self):
        text_model = sherlock_compiled
        start_str = "Sherlock Holmes"
        sent = text_model.make_sentence_with_start(start_str)
        self.assertStartsWith(sent, start_str)

    def test_make_sentence_with_start_one_word(self):
        text_model = sherlock_compiled
        start_str = "Sherlock"
        sent = text_model.make_sentence_with_start(start_str)
        self.assertStartsWith(sent, start_str)

    def test_make_sentence_with_start_one_word_that_doesnt_begin_a_sentence(self):
        text_model = sherlock_compiled
        start_str = "dog"
        # In strict mode the BEGIN-padded state is looked up directly, so an
        # unknown sentence start surfaces as a KeyError.
        with self.assertRaises(KeyError):
            text_model.make_sentence_with_start(start_str)

    def test_make_sentence_with_word_not_at_start_of_sentence(self):
        text_model = sherlock_compiled
        start_str = "dog"
        sent = text_model.make_sentence_with_start(start_str, strict=False)
        self.assertStartsWith(sent, start_str)

    def test_make_sentence_with_words_not_at_start_of_sentence(self):
        text_model = markovify.Text(sherlock, state_size=3)
        text_model = markovify.CompiledText(text_model)
        # " I was " has 128 matches in sherlock.txt
        # " was I " has 2 matches in sherlock.txt
        start_str = "was I"
        sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50)
        self.assertStartsWith(sent, start_str)

    def test_make_sentence_with_words_not_at_start_of_sentence_miss(self):
        text_model = markovify.Text(sherlock, state_size=3)
        text_model = markovify.CompiledText(text_model)
        start_str = "was werewolf"
        sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50)
        self.assertIsNone(sent)

    def test_make_sentence_with_words_not_at_start_of_sentence_of_state_size(self):
        text_model = markovify.Text(sherlock, state_size=2)
        text_model = markovify.CompiledText(text_model)
        start_str = "was I"
        sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50)
        self.assertStartsWith(sent, start_str)

    def test_make_sentence_with_words_to_many(self):
        text_model = sherlock_compiled
        start_str = "dog is good"
        with self.assertRaises(markovify.text.ParamError):
            text_model.make_sentence_with_start(start_str, strict=False)

    def test_make_sentence_with_start_three_words(self):
        start_str = "Sherlock Holmes was"
        text_model = sherlock_compiled
        # Three words are too many for the shared fixture model.
        with self.assertRaises(markovify.text.ParamError):
            text_model.make_sentence_with_start(start_str)
        # With state_size=3 the same phrase is accepted.
        text_model = markovify.Text(sherlock, state_size=3)
        text_model = markovify.CompiledText(text_model)
        text_model.make_sentence_with_start(start_str)
        sent = text_model.make_sentence_with_start("Sherlock")
        self.assertIsNotNone(sent)
        self.assertNotIn(markovify.chain.BEGIN, sent)

    def test_short_sentence_fail(self):
        text_model = sherlock_compiled
        sent = text_model.make_short_sentence(1, tries=1)
        self.assertIsNone(sent)

    def test_short_sentence(self):
        text_model = sherlock_compiled
        sent = None
        # Generation is random; retry until a sentence is produced.
        while sent is None:
            sent = text_model.make_short_sentence(45)
        self.assertLessEqual(len(sent), 45)

    def test_short_sentence_min_chars(self):
        text_model = sherlock_compiled
        sent = None
        # Generation is random; retry until a sentence is produced.
        while sent is None:
            sent = text_model.make_short_sentence(100, min_chars=50)
        self.assertLessEqual(len(sent), 100)
        self.assertGreaterEqual(len(sent), 50)

    def test_max_words(self):
        text_model = sherlock_compiled
        sent = text_model.make_sentence(max_words=0)
        self.assertIsNone(sent)
|
||
# Run the test suite when this file is executed directly.
if '__main__' == __name__:
    unittest.main()