Skip to content

Commit

Permalink
Merge 32960f5 into 15496b2
Browse files Browse the repository at this point in the history
  • Loading branch information
erikerlandson committed Nov 29, 2019
2 parents 15496b2 + 32960f5 commit 8851783
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 27 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ model_combo = markovify.combine([ model_a, model_b ], [ 1.5, 1 ])

This code snippet would combine `model_a` and `model_b`, but, it would also place 50% more weight on the connections from `model_a`.

### Compiling a model

Once a model has been generated, it may also be compiled for improved text generation speed and reduced size.
```python
text_model = markovify.Text(text, state_size=3)
text_model.compile()
```

Once a model is compiled, it may not currently be combined with other models using `markovify.combine(...)`.
If you wish to combine models, do that first and then compile the result.

### Working with messy texts

Starting with `v0.7.2`, `markovify.Text` accepts two additional parameters: `well_formed` and `reject_reg`.
Expand Down
24 changes: 20 additions & 4 deletions markovify/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def accumulate(iterable, func=operator.add):
total = func(total, element)
yield total

def compile_next(next_dict):
words = list(next_dict.keys())
cff = list(accumulate(next_dict.values()))
return [words, cff]

class Chain(object):
"""
A Markov chain representing processes that have both beginnings and ends.
Expand All @@ -42,7 +47,17 @@ def __init__(self, corpus, state_size, model=None):
"""
self.state_size = state_size
self.model = model or self.build(corpus, self.state_size)
self.precompute_begin_state()
self.compiled = (len(self.model) > 0) and (type(self.model[tuple([BEGIN]*state_size)]) == list)
if not self.compiled:
self.precompute_begin_state()

def compile(self):
if not self.compiled:
del self.begin_cumdist
del self.begin_choices
sxf = { state: compile_next(next_dict) for (state, next_dict) in self.model.items() }
self.model = sxf
self.compiled = True

def build(self, corpus, state_size):
"""
Expand Down Expand Up @@ -77,16 +92,17 @@ def precompute_begin_state(self):
Significantly speeds up chain generation on large corpora. Thanks, @schollz!
"""
begin_state = tuple([ BEGIN ] * self.state_size)
choices, weights = zip(*self.model[begin_state].items())
cumdist = list(accumulate(weights))
choices, cumdist = compile_next(self.model[begin_state])
self.begin_cumdist = cumdist
self.begin_choices = choices

def move(self, state):
"""
Given a state, choose the next item at random.
"""
if state == tuple([ BEGIN ] * self.state_size):
if self.compiled:
choices, cumdist = self.model[state]
elif state == tuple([ BEGIN ] * self.state_size):
choices = self.begin_choices
cumdist = self.begin_cumdist
else:
Expand Down
3 changes: 3 additions & 0 deletions markovify/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(self, input_text, state_size=2, chain=None, parsed_sentences=None,
parsed = parsed_sentences or self.generate_corpus(input_text)
self.chain = chain or Chain(parsed, state_size)

def compile(self):
self.chain.compile()

def to_dict(self):
"""
Returns the underlying data as a Python dict.
Expand Down
4 changes: 4 additions & 0 deletions markovify/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

def get_model_dict(thing):
if isinstance(thing, Chain):
if thing.compiled:
raise ValueError("Not implemented for compiled markovify.Chain")
return thing.model
if isinstance(thing, Text):
if thing.chain.compiled:
raise ValueError("Not implemented for compiled markovify.Chain")
return thing.chain.model
if isinstance(thing, list):
return dict(thing)
Expand Down
64 changes: 41 additions & 23 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,28 @@
def get_sorted(chain_json):
return sorted(chain_json, key=operator.itemgetter(0))

with open(os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")) as f:
sherlock = f.read()
sherlock_model = markovify.Text(sherlock)

class MarkovifyTest(unittest.TestCase):
class MarkovifyTestBase(unittest.TestCase):
__test__ = False

def test_text_too_small(self):
text = u"Example phrase. This is another example sentence."
text_model = markovify.Text(text)
assert(text_model.make_sentence() == None)

def test_sherlock(self):
text_model = sherlock_model
text_model = self.sherlock_model
sent = text_model.make_sentence()
assert(len(sent) != 0)

def test_json(self):
text_model = sherlock_model
text_model = self.sherlock_model
json_model = text_model.to_json()
new_text_model = markovify.Text.from_json(json_model)
sent = text_model.make_sentence()
sent = new_text_model.make_sentence()
assert(len(sent) != 0)

def test_chain(self):
text_model = sherlock_model
text_model = self.sherlock_model
chain_json = text_model.chain.to_json()

stored_chain = markovify.Chain.from_json(chain_json)
Expand All @@ -43,34 +40,34 @@ def test_chain(self):
assert(len(sent) != 0)

def test_make_sentence_with_start(self):
text_model = sherlock_model
text_model = self.sherlock_model
start_str = "Sherlock Holmes"
sent = text_model.make_sentence_with_start(start_str)
assert(sent != None)
assert(start_str == sent[:len(start_str)])

def test_make_sentence_with_start_one_word(self):
text_model = sherlock_model
text_model = self.sherlock_model
start_str = "Sherlock"
sent = text_model.make_sentence_with_start(start_str)
assert(sent != None)
assert(start_str == sent[:len(start_str)])

def test_make_sentence_with_start_one_word_that_doesnt_begin_a_sentence(self):
text_model = sherlock_model
text_model = self.sherlock_model
start_str = "dog"
with self.assertRaises(KeyError) as context:
sent = text_model.make_sentence_with_start(start_str)

def test_make_sentence_with_word_not_at_start_of_sentence(self):
text_model = sherlock_model
text_model = self.sherlock_model
start_str = "dog"
sent = text_model.make_sentence_with_start(start_str, strict=False)
assert(sent != None)
assert(start_str == sent[:len(start_str)])

def test_make_sentence_with_words_not_at_start_of_sentence(self):
text_model = markovify.Text(sherlock, state_size=3)
text_model = self.sherlock_model_ss3
# " I was " has 128 matches in sherlock.txt
# " was I " has 2 matches in sherlock.txt
start_str = "was I"
Expand All @@ -79,39 +76,39 @@ def test_make_sentence_with_words_not_at_start_of_sentence(self):
assert(start_str == sent[:len(start_str)])

def test_make_sentence_with_words_not_at_start_of_sentence_miss(self):
text_model = markovify.Text(sherlock, state_size=3)
text_model = self.sherlock_model_ss3
start_str = "was werewolf"
sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50)
assert(sent == None)

def test_make_sentence_with_words_not_at_start_of_sentence_of_state_size(self):
text_model = markovify.Text(sherlock, state_size=2)
text_model = self.sherlock_model_ss2
start_str = "was I"
sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50)
assert(sent != None)
assert(start_str == sent[:len(start_str)])

def test_make_sentence_with_words_to_many(self):
text_model = sherlock_model
text_model = self.sherlock_model
start_str = "dog is good"
with self.assertRaises(markovify.text.ParamError) as context:
sent = text_model.make_sentence_with_start(start_str, strict=False)

def test_make_sentence_with_start_three_words(self):
start_str = "Sherlock Holmes was"
text_model = sherlock_model
text_model = self.sherlock_model
try:
text_model.make_sentence_with_start(start_str)
assert(False)
except markovify.text.ParamError:
assert(True)
text_model = markovify.Text(sherlock, state_size=3)
text_model = self.sherlock_model_ss3
text_model.make_sentence_with_start(start_str)
sent = text_model.make_sentence_with_start("Sherlock")
assert(markovify.chain.BEGIN not in sent)

def test_short_sentence(self):
text_model = sherlock_model
text_model = self.sherlock_model
sent = None
while sent is None:
sent = text_model.make_short_sentence(45)
Expand All @@ -120,17 +117,17 @@ def test_short_sentence(self):
def test_short_sentence_min_chars(self):
sent = None
while sent is None:
sent = sherlock_model.make_short_sentence(100, min_chars=50)
sent = self.sherlock_model.make_short_sentence(100, min_chars=50)
assert len(sent) <= 100
assert len(sent) >= 50

def test_dont_test_output(self):
text_model = sherlock_model
text_model = self.sherlock_model
sent = text_model.make_sentence(test_output=False)
assert sent is not None

def test_max_words(self):
text_model = sherlock_model
text_model = self.sherlock_model
sent = text_model.make_sentence(max_words=0)
assert sent is None

Expand All @@ -156,5 +153,26 @@ def test_custom_regex(self):

model = markovify.NewlineText('This sentence (would normall fail', well_formed = False)

class MarkovifyTest(MarkovifyTestBase):
__test__ = True

with open(os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")) as f:
sherlock_text = f.read()
sherlock_model = markovify.Text(sherlock_text)
sherlock_model_ss2 = markovify.Text(sherlock_text, state_size = 2)
sherlock_model_ss3 = markovify.Text(sherlock_text, state_size = 3)

class MarkovifyTestCompiled(MarkovifyTestBase):
__test__ = True

with open(os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")) as f:
sherlock_text = f.read()
sherlock_model = markovify.Text(sherlock_text)
sherlock_model_ss2 = markovify.Text(sherlock_text, state_size = 2)
sherlock_model_ss3 = markovify.Text(sherlock_text, state_size = 3)
sherlock_model.compile()
sherlock_model_ss2.compile()
sherlock_model_ss3.compile()

if __name__ == '__main__':
unittest.main()
14 changes: 14 additions & 0 deletions test/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def get_sorted(chain_json):
sherlock = f.read()
sherlock_model = markovify.Text(sherlock)
sherlock_model_no_retain = markovify.Text(sherlock, retain_original=False)
sherlock_model_compiled = markovify.Text(sherlock)
sherlock_model_compiled.compile()

class MarkovifyTest(unittest.TestCase):

Expand Down Expand Up @@ -56,6 +58,18 @@ def test_mismatched_model_types(self):
text_model_b = markovify.NewlineText(sherlock)
combo = markovify.combine([ text_model_a, text_model_b ])

def test_compiled_model_fail(self):
with self.assertRaises(Exception) as context:
model_a = sherlock_model
model_b = sherlock_model_compiled
combo = markovify.combine([ text_model_a, text_model_b ])

def test_compiled_chain_fail(self):
with self.assertRaises(Exception) as context:
model_a = sherlock_model.chain
model_b = sherlock_model_compiled.chain
combo = markovify.combine([ text_model_a, text_model_b ])

def test_combine_no_retain(self):
text_model = sherlock_model_no_retain
combo = markovify.combine([ text_model, text_model ])
Expand Down

0 comments on commit 8851783

Please sign in to comment.