/
chain.py
151 lines (130 loc) · 4.96 KB
/
chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import random
import operator
import bisect
import json
# Python 3 compatibility: `basestring` only exists on Python 2, so when the
# bare name lookup raises NameError, alias it to `str`.
try:
    basestring
except NameError:
    basestring = str
# Sentinel items marking the boundaries of a run. `Chain.build` pads every run
# with `state_size` copies of BEGIN and appends one END; `Chain.gen` stops as
# soon as END is drawn.
BEGIN = "___BEGIN__"
END = "___END__"
def accumulate(iterable, func=operator.add):
    """
    Cumulative calculations. (Summation, by default.)
    Via: https://docs.python.org/3/library/itertools.html#itertools.accumulate

    `iterable`: Any iterable of values to fold together.

    `func`: A two-argument callable that combines the running total with
    the next element. Defaults to addition.

    Yields the running total after each element, starting with the first
    element itself. An empty `iterable` yields nothing.
    """
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        # A StopIteration escaping a generator body is converted into a
        # RuntimeError under PEP 479 (Python 3.7+), so finish explicitly.
        return
    yield total
    for element in it:
        total = func(total, element)
        yield total
class Chain(object):
    """
    A Markov chain representing processes that have both beginnings and ends.
    For example: Sentences.

    The model is a dict mapping each state (a tuple of `state_size` items)
    to a dict of `{next_item: count}`.
    """

    def __init__(self, corpus, state_size, model=None):
        """
        `corpus`: A list of lists, where each outer list is a "run"
        of the process (e.g., a single sentence), and each inner list
        contains the steps (e.g., words) in the run. If you want to simulate
        an infinite process, you can come very close by passing just one, very
        long run.

        `state_size`: An integer indicating the number of items the model
        uses to represent its state. For text generation, 2 or 3 are typical.

        `model`: An optional prebuilt model dict. When it is given (and
        non-empty) it is used as-is and `corpus` is not read.
        """
        self.state_size = state_size
        self.model = model or self.build(corpus, self.state_size)
        self.precompute_begin_state()

    def build(self, corpus, state_size):
        """
        Build a Python representation of the Markov model. Returns a dict
        of dicts where the keys of the outer dict represent all possible states,
        and point to the inner dicts. The inner dicts represent all possibilities
        for the "next" item in the chain, along with the count of times it
        appears.

        Raises `Exception` when `corpus` is not a non-empty list whose first
        element is a list.
        """
        # Check emptiness before indexing so that an empty corpus reports the
        # validation error rather than an IndexError from `corpus[0]`.
        if (
            not isinstance(corpus, list)
            or not corpus
            or not isinstance(corpus[0], list)
        ):
            raise Exception("`corpus` must be list of lists")

        # Using a DefaultDict here would be a lot more convenient, however the memory
        # usage is far higher.
        model = {}

        for run in corpus:
            # Pad the front so the first real item has a full-size state
            # behind it, and mark the end so generation knows where to stop.
            items = ([BEGIN] * state_size) + run + [END]
            # One transition per real item, plus the final one into END.
            for i in range(len(run) + 1):
                state = tuple(items[i:i + state_size])
                follow = items[i + state_size]
                if state not in model:
                    model[state] = {}
                if follow not in model[state]:
                    model[state][follow] = 0
                model[state][follow] += 1

        return model

    def precompute_begin_state(self):
        """
        Caches the summation calculation and available choices for BEGIN * state_size.
        Significantly speeds up chain generation on large corpuses. Thanks, @schollz!

        Sets `self.begin_choices` and `self.begin_cumdist`.
        """
        begin_state = (BEGIN,) * self.state_size
        choices, weights = zip(*self.model[begin_state].items())
        self.begin_cumdist = list(accumulate(weights))
        self.begin_choices = choices

    def move(self, state):
        """
        Given a state, choose the next item at random, weighted by how often
        each item followed that state in the model.
        """
        if state == (BEGIN,) * self.state_size:
            # Reuse the distribution cached by `precompute_begin_state`.
            choices = self.begin_choices
            cumdist = self.begin_cumdist
        else:
            choices, weights = zip(*self.model[state].items())
            cumdist = list(accumulate(weights))
        # Scale a uniform draw to the total weight, then locate the bucket
        # it falls into within the cumulative distribution.
        r = random.random() * cumdist[-1]
        return choices[bisect.bisect(cumdist, r)]

    def gen(self, init_state=None):
        """
        Starting either with a naive BEGIN state, or the provided `init_state`
        (as a tuple), return a generator that will yield successive items
        until the chain reaches the END state.
        """
        state = init_state or (BEGIN,) * self.state_size
        while True:
            next_word = self.move(state)
            if next_word == END:
                break
            yield next_word
            # Slide the window: drop the oldest item, append the new one.
            state = tuple(state[1:]) + (next_word,)

    def walk(self, init_state=None):
        """
        Return a list representing a single run of the Markov model, either
        starting with a naive BEGIN state, or the provided `init_state`
        (as a tuple).
        """
        return list(self.gen(init_state))

    def to_json(self):
        """
        Dump the model as a JSON string, for loading later. The model is
        written as a list of `[state, followers]` pairs.
        """
        return json.dumps(list(self.model.items()))

    @classmethod
    def from_json(cls, json_thing):
        """
        Given a JSON object or JSON string that was created by `self.to_json`,
        return the corresponding markovify.Chain.

        Raises `ValueError` when the decoded object is neither a list nor
        a dict.
        """
        if isinstance(json_thing, basestring):
            obj = json.loads(json_thing)
        else:
            obj = json_thing

        if isinstance(obj, list):
            # JSON has no tuples, so each state comes back as a list and has
            # to be turned into a tuple again to serve as a dict key.
            rehydrated = {tuple(item[0]): item[1] for item in obj}
        elif isinstance(obj, dict):
            # NOTE(review): a dict is used as-is, so its keys are assumed to
            # already be state tuples -- confirm against callers.
            rehydrated = obj
        else:
            raise ValueError("Object should be dict or list")

        # The state size is taken from the length of the first state key.
        state_size = len(list(rehydrated.keys())[0])
        return cls(None, state_size, rehydrated)