Merge pull request #81 from thallada/combine-no-retain
Fix combining models with mixed retain_original values
jsvine committed Oct 7, 2017
2 parents 8db3b9c + 74230aa commit 0b33928
Showing 4 changed files with 32 additions and 7 deletions.
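
As a minimal sketch of the use case this commit fixes (the corpus path and variable names are placeholders, not from the repository): models built with different `retain_original` settings can now be passed to `markovify.combine`, and the combined model keeps the originals of whichever inputs retained them.

```python
import markovify

# Placeholder corpus path.
with open("corpus.txt") as f:
    text = f.read()

full_model = markovify.Text(text)  # retain_original defaults to True
slim_model = markovify.Text(text, retain_original=False)

# Combining models with mixed retain_original values is what this commit fixes:
# the combined model retains originals because at least one input does.
combo = markovify.combine([full_model, slim_model])
assert combo.retain_original
print(combo.make_sentence())
```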
12 changes: 8 additions & 4 deletions README.md
@@ -174,19 +174,23 @@ with open("path/to/my/huge/corpus.txt") as f:
     print(text_model.make_sentence())
 ```
 
-And `(b)` read in the corpus line-by-line or file-by-file and combine it into one model at the end:
+And `(b)` read in the corpus line-by-line or file-by-file and combine them into one model at each step:
 
 ```python
-models = []
+combined_model = None
 for (dirpath, _, filenames) in os.walk("path/to/my/huge/corpus"):
     for filename in filenames:
         with open(os.path.join(dirpath, filename)) as f:
-            models.append(markovify.Text(f, retain_original=False))
+            model = markovify.Text(f, retain_original=False)
+            if combined_model:
+                combined_model = markovify.combine(models=[combined_model, model])
+            else:
+                combined_model = model
 
-combined_model = markovify.combine(models)
 print(combined_model.make_sentence())
 ```
 
+
 ## Markovify In The Wild
 
 - BuzzFeed's [Tom Friedman Sentence Generator](http://www.buzzfeed.com/jsvine/the-tom-friedman-sentence-generator) / [@mot_namdeirf](https://twitter.com/mot_namdeirf).
2 changes: 1 addition & 1 deletion markovify/text.py
@@ -225,7 +225,7 @@ def from_chain(cls, chain_json, corpus=None, parsed_sentences=None):
         If corpus is None, overlap checking won't work.
         """
         chain = Chain.from_json(chain_json)
-        return cls(corpus or '', parsed_sentences=parsed_sentences, state_size=chain.state_size, chain=chain)
+        return cls(corpus or None, parsed_sentences=parsed_sentences, state_size=chain.state_size, chain=chain)
 
 
 class NewlineText(Text):
5 changes: 3 additions & 2 deletions markovify/utils.py
@@ -45,10 +45,11 @@ def combine(models, weights=None):
     if isinstance(ret_inst, Chain):
         return Chain.from_json(c)
     if isinstance(ret_inst, Text):
-        if ret_inst.retain_original:
+        if any(m.retain_original for m in models):
             combined_sentences = []
             for m in models:
-                combined_sentences += m.parsed_sentences
+                if m.retain_original:
+                    combined_sentences += m.parsed_sentences
             return ret_inst.from_chain(c, parsed_sentences=combined_sentences)
         else:
             return ret_inst.from_chain(c)
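In effect, retention becomes an "any of the inputs" condition: the combined model keeps parsed sentences only from the models that actually retained them, and keeps no original at all when no input did. A minimal sketch of that behavior, mirroring the tests below (the inline corpora and variable names are throwaway examples):

```python
import markovify

# Throwaway corpora, just to get two small models.
keeper = markovify.Text("The quick brown fox jumps over the lazy dog. " * 50)
dropper = markovify.Text("Pack my box with five dozen liquor jugs. " * 50,
                         retain_original=False)

# At least one input retained its original, so the combined model does too,
# and only the retaining model's parsed sentences are carried over.
mixed = markovify.combine([dropper, keeper])
assert mixed.retain_original
assert mixed.parsed_sentences == keeper.parsed_sentences

# Neither input retained anything, so the combined model has no original either.
slim = markovify.combine([dropper, dropper])
assert not slim.retain_original
```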
20 changes: 20 additions & 0 deletions test/test_combine.py
@@ -9,6 +9,7 @@ def get_sorted(chain_json):
 with open(os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")) as f:
     sherlock = f.read()
     sherlock_model = markovify.Text(sherlock)
+    sherlock_model_no_retain = markovify.Text(sherlock, retain_original=False)
 
 class MarkovifyTest(unittest.TestCase):
 
@@ -55,6 +56,25 @@ def test_mismatched_model_types(self):
             text_model_b = markovify.NewlineText(sherlock)
             combo = markovify.combine([ text_model_a, text_model_b ])
 
+    def test_combine_no_retain(self):
+        text_model = sherlock_model_no_retain
+        combo = markovify.combine([ text_model, text_model ])
+        assert(not combo.retain_original)
+
+    def test_combine_retain_on_no_retain(self):
+        text_model_a = sherlock_model_no_retain
+        text_model_b = sherlock_model
+        combo = markovify.combine([ text_model_a, text_model_b ])
+        assert(combo.retain_original)
+        assert(combo.parsed_sentences == text_model_b.parsed_sentences)
+
+    def test_combine_no_retain_on_retain(self):
+        text_model_a = sherlock_model_no_retain
+        text_model_b = sherlock_model
+        combo = markovify.combine([ text_model_b, text_model_a ])
+        assert(combo.retain_original)
+        assert(combo.parsed_sentences == text_model_b.parsed_sentences)
+
 if __name__ == '__main__':
     unittest.main()
