-
Notifications
You must be signed in to change notification settings - Fork 0
/
hmm.py
138 lines (122 loc) · 3.78 KB
/
hmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import math

import numpy as np
# using special tags
# - <s> for start of sentence
# - </s> for end of sentence
# - <oov> for out of vocabulary words
def wrap_s_tag(y):
    """Return a new list: ``'<s>'``, then the items of *y*, then ``'</s>'``.

    *y* itself is not modified. It must be a list, because it is
    concatenated with list literals.
    """
    head = ['<s>']
    tail = ['</s>']
    return head + y + tail
def estimate_transition(ys):
    """Estimate bigram tag-transition probabilities by relative frequency.

    Every tag sequence in *ys* is wrapped with ``<s>``/``</s>``. Each
    adjacent pair ``(prev, cur)`` is counted, and so is every single tag.

    Returns a nested dict ``prev -> cur -> count(prev, cur) / count(prev)``.
    Here ``count(prev)`` is the number of times ``prev`` occurs anywhere in
    the wrapped sequences. Only transitions actually observed appear as keys.
    """
    tag_totals = {}
    pair_totals = {}
    for tags in ys:
        wrapped = wrap_s_tag(tags)
        for tag in wrapped:
            tag_totals[tag] = tag_totals.get(tag, 0) + 1
        # Consecutive (previous, current) pairs of the wrapped sequence.
        for prev, cur in zip(wrapped, wrapped[1:]):
            row = pair_totals.setdefault(prev, {})
            row[cur] = row.get(cur, 0) + 1
    return {
        prev: {cur: count * 1. / tag_totals[prev] for cur, count in row.items()}
        for prev, row in pair_totals.items()
    }
def estimate_emission(ys, xs, smoothing=None):
    """Estimate emission tables ``tag -> token -> value`` from aligned data.

    Each pair of tag list ``y`` and token list ``x`` is wrapped with
    ``<s>``/``</s>`` and zipped position by position. ``zip`` stops at the
    shorter of the two.

    Without *smoothing*:
        Each count is multiplied by ``1 / (total count of its tag)``.
        Every tag row also gets an ``'<oov>'`` entry of ``0.``, which
        overwrites any existing ``'<oov>'`` count.

    With *smoothing*:
        The callable receives the raw integer count table once. Its return
        value is ignored, and that same dict object is returned. It is
        therefore expected to normalise the table in place, e.g.
        ``lambda c: add_k_smooth(1, c)``.
    """
    table = {}
    totals = {}
    for tags, tokens in zip(ys, xs):
        for tag, token in zip(wrap_s_tag(tags), wrap_s_tag(tokens)):
            row = table.setdefault(tag, {})
            row[token] = row.get(token, 0) + 1
            totals[tag] = totals.get(tag, 0) + 1

    if smoothing is not None:
        smoothing(table)
        return table

    for tag, row in table.items():
        row['<oov>'] = 0.
        scale = 1. / totals[tag]
        for token in row:
            row[token] *= scale
    return table
def sample_discrete(distribution):
    """Draw one key of *distribution* at random, weighted by its value.

    Parameters
    ----------
    distribution : dict
        Maps outcomes to non-negative weights, normally probabilities that
        sum to 1. Weights are used relative to their total, so unnormalised
        weights are also accepted.

    Returns
    -------
    A key of *distribution*. Keys with weight 0 are never returned.

    Raises
    ------
    ValueError
        If *distribution* is empty or its weights do not sum to a positive
        number.
    """
    keys = list(distribution)
    if not keys:
        raise ValueError("cannot sample from an empty distribution")
    # list(...) is required: on Python 3 dict.values() is a view, which
    # np.cumsum cannot treat as a numeric sequence.
    cumulative = np.cumsum(list(distribution.values()))
    total = cumulative[-1]
    if not total > 0:  # also rejects NaN
        raise ValueError("distribution has no positive probability mass")
    # Scale the uniform draw to the actual total, so that round-off in the
    # weights never leaves the draw past the final cumulative edge.
    x = np.random.rand() * total
    for key, edge in zip(keys, cumulative):
        if x < edge:
            return key
    # Only reachable if floating-point rounding pushed x up to `total`.
    # Return the outcome whose weight brought the running sum to `total`.
    # It has positive mass, unlike any trailing zero-weight keys.
    return keys[int(np.searchsorted(cumulative, total))]
def sample(p_transition, p_emission):
    """Generate one ``(tags, tokens)`` pair from the model by forward sampling.

    Both lists start with ``'<s>'``. Each step does two things:

    1. Draw the next tag from ``p_transition[current_tag]``.
    2. Draw a token from ``p_emission[new_tag]``.

    Sampling stops after the tag ``'</s>'`` has been drawn; its emitted
    token is still appended. Tokens may include ``'<oov>'`` if that entry
    has positive probability in the emission table.
    """
    tags = ['<s>']
    tokens = ['<s>']
    current = '<s>'
    while current != '</s>':
        current = sample_discrete(p_transition[current])
        tags.append(current)
        tokens.append(sample_discrete(p_emission[current]))
    return tags, tokens
def _viterbi_log(m, y, p_transition, p_emission, x, cache):
    """Log-space Viterbi recursion behind ``V``; returns ``(log_score, path)``."""
    key = (m, y)
    if key in cache:
        return cache[key]
    if m == 0:
        # Position 0 is the start marker: only '<s>' is possible there.
        result = (0. if y == '<s>' else -math.inf, [y])
        cache[key] = result
        return result

    best_log = -math.inf
    best_path = []
    for yp, successors in p_transition.items():
        if y not in successors:
            continue
        p_trans = successors[y]
        emissions = p_emission[y]
        # Tokens this tag never emitted in training use its '<oov>' mass.
        p_xm = emissions[x[m]] if x[m] in emissions else emissions['<oov>']
        if p_trans <= 0 or p_xm <= 0:
            continue  # log(0) = -inf: this candidate can never be the max
        prev_log, prev_path = _viterbi_log(
            m - 1, yp, p_transition, p_emission, x, cache)
        score = prev_log + math.log(p_trans) + math.log(p_xm)
        # Strict '>' keeps the first best predecessor on ties.
        if score > best_log:
            best_log = score
            best_path = prev_path
    result = (best_log, best_path + [y])
    cache[key] = result
    return result


def V(m, y, p_transition, p_emission, x, c=None):
    """Best-scoring tag path that ends in tag *y* at position *m* of *x*.

    A path starts with ``'<s>'`` at position 0 and ends with *y* at position
    *m*. Its score is the product of
    ``p_transition[y_{i-1}][y_i] * p_emission[y_i][x[i]]`` for
    ``i = 1..m``. Tokens missing from ``p_emission[y_i]`` are scored with
    ``p_emission[y_i]['<oov>']``. Only transitions listed in *p_transition*
    are allowed.

    The search runs on summed log-probabilities. Multiplying many small
    probabilities underflows to 0.0 on long sentences, which would make
    every path look equally impossible and lose the best one. The returned
    probability is ``exp`` of the best log-score. It can still be 0.0 if
    the true value is below the float range, but the path is correct.

    Parameters
    ----------
    m : int
        Position in *x*; 0 is the ``'<s>'`` position.
    y : hashable
        Tag at position *m*.
    p_transition : dict
        ``prev_tag -> next_tag -> probability``.
    p_emission : dict
        ``tag -> token -> probability``. Needs an ``'<oov>'`` entry for any
        tag that must score a token it does not list.
    x : sequence
        Tokens, already wrapped with ``'<s>'`` / ``'</s>'``.
    c : dict, optional
        Memo table keyed by ``(position, tag)``. Created when omitted. Its
        entries hold ``(log_score, path)`` pairs.

    Returns
    -------
    (float, list)
        The best path probability and the tags ``[y_0, ..., y_m]``. If no
        path has positive probability, this is ``(0.0, [y])``.
    """
    if c is None:
        c = {}
    log_p, path = _viterbi_log(m, y, p_transition, p_emission, x, c)
    return math.exp(log_p), path
def inference(x, p_transition, p_emission):
    """Viterbi-decode the token list *x*.

    Wraps *x* with ``'<s>'``/``'</s>'`` and asks ``V`` for the best path
    ending in ``'</s>'`` at the final position. Returns whatever ``V``
    returns for that call: a ``(probability, tag_path)`` pair.
    """
    padded = wrap_s_tag(x)
    last = len(padded) - 1
    return V(last, '</s>', p_transition, p_emission, padded)
def accuracy(y, y_hat):
    """Fraction of gold tags *y* that the decoded path *y_hat* matches.

    *y_hat* is a wrapped path ``['<s>', t_1, ..., t_n, '</s>']``. Its first
    and last elements are dropped before comparing position by position with
    *y*. A degenerate path (e.g. ``['</s>']``) matches nothing.

    Returns a float in ``[0, 1]``. An empty *y* returns ``0.0``.
    """
    if not y:
        return 0.0
    matches = sum(1 for gold, pred in zip(y, y_hat[1:-1]) if gold == pred)
    # Divide by the number of gold tags, not len(y_hat). y_hat carries the
    # two boundary markers, which would cap a perfect tagging below 1.0.
    return matches * 1. / len(y)
def evaluate(xs, p_transition, p_emission, ys):
    """Mean per-sentence tagging accuracy of Viterbi decoding.

    Each token list in *xs* is decoded with ``inference`` and scored with
    ``accuracy`` against the matching gold tags in *ys*. The pairs are
    zipped, so extra items in the longer list are ignored. Raises
    ``ZeroDivisionError`` when there are no pairs.
    """
    scores = []
    for tokens, gold in zip(xs, ys):
        _, predicted = inference(tokens, p_transition, p_emission)
        scores.append(accuracy(gold, predicted))
    return sum(scores) / len(scores)
def add_k_smooth(k, counts):
    """Apply add-*k* smoothing to nested counts in place and normalise rows.

    For every row of *counts* (``tag -> token -> count``) the steps are:

    1. Set ``'<oov>'`` to 0, overwriting any existing entry.
    2. Add *k* to every entry, including ``'<oov>'``.
    3. Scale each entry by ``1 / row_total``.

    The same *counts* dict is returned.
    """
    for row in counts.values():
        row['<oov>'] = 0
        for token in row:
            row[token] += k
        scale = 1. / sum(row.values())
        for token in row:
            row[token] *= scale
    return counts