-
Notifications
You must be signed in to change notification settings - Fork 0
/
hmm.py
86 lines (70 loc) · 3.75 KB
/
hmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import json
from pomegranate import DiscreteDistribution, HiddenMarkovModel
from utils import *
from models import Model
import os
class GenerativeHMM(Model):
    """Generative Hidden Markov Model over character sequences.

    Wraps a pomegranate ``HiddenMarkovModel`` with randomly initialized
    discrete emission distributions, transition matrix, and start
    probabilities. Training runs Baum-Welch one iteration per "epoch" so
    that train/validation negative log probability can be tracked and
    checkpointed between iterations.

    Relies on attributes set by ``Model.__init__`` (e.g. ``hidden_size``,
    ``num_characters``, ``all_characters``, ``epochs``, ``pseudo_count``,
    ``n_jobs``, ``save_epochs``, ``base_log``, ``name``, ``model_type``,
    ``early_stopping``) — assumed to exist on the base class; confirm there.
    """

    def __init__(self, args):
        Model.__init__(self, args)
        self.model = None
        self.build_model()
        # Fix: fit() appends to `train_loss_history`; the original
        # initialized `train_loss` instead, leaving the history attribute
        # undefined in this class.
        self.train_loss_history = []
        self.valid_loss_history = []
        # Sanity check: flattened sequence length must match the declared
        # input dimensionality.
        assert(self.seq_length * self.num_characters == self.input)

    def build_model(self):
        """Build the HMM with random (row-normalized) parameters and bake it."""
        distributions = []
        for _ in range(self.hidden_size):
            emission_probs = np.random.random(self.num_characters)
            emission_probs = emission_probs / emission_probs.sum()
            distributions.append(DiscreteDistribution(dict(zip(self.all_characters, emission_probs))))
        trans_mat = np.random.random((self.hidden_size, self.hidden_size))
        trans_mat = trans_mat / trans_mat.sum(axis=1, keepdims=1)
        starts = np.random.random(self.hidden_size)
        starts = starts / starts.sum()
        # Verify the random initialization is properly normalized.
        np.testing.assert_almost_equal(starts.sum(), 1)
        np.testing.assert_array_almost_equal(np.ones(self.hidden_size), trans_mat.sum(axis=1))
        self.model = HiddenMarkovModel.from_matrix(trans_mat, distributions, starts)
        self.model.bake()

    def fit(self, train_dataloader, valid_dataloader, verbose=True, logger=None, save_model=True, weights=None, **kwargs):
        """Train with one Baum-Welch iteration per epoch, tracking losses.

        Args:
            train_dataloader: iterable of training sequences (passed to
                pomegranate's fit and to evaluate()).
            valid_dataloader: optional validation sequences; may be falsy.
            verbose: print per-epoch losses to ``logger``.
            logger: file-like object for status output (None -> stdout).
            save_model: checkpoint every ``self.save_epochs`` epochs.
            weights: unused here; kept for interface compatibility.
        """
        start_time = time.time()
        # Fix: the original left `valid_loss` unbound when no validation
        # set was given, crashing the verbose print and early stopping.
        valid_loss = float("nan")
        for epoch in range(1, self.epochs + 1):
            # One EM iteration per epoch so losses can be logged between them.
            self.model.fit(train_dataloader, max_iterations=1, pseudocount=self.pseudo_count,
                           n_jobs=self.n_jobs, return_history=True)
            train_loss = self.evaluate(train_dataloader)
            self.train_loss_history.append(train_loss)
            if valid_dataloader:
                valid_loss = self.evaluate(valid_dataloader)
                self.valid_loss_history.append(valid_loss)
            if verbose:
                print("epoch {0}, train neg log prob: {1:.4f}, test neg log probability {2:.4f}, time: {3:.2f} sec".format(
                    epoch, train_loss, valid_loss, time.time() - start_time), file=logger)
            if epoch % self.save_epochs == 0 and save_model:
                path = os.path.join(self.base_log, self.name, "{0}_checkpoint_{1}.json".format(self.model_type, epoch))
                self.save_model(path)
            # Early stopping needs a real validation loss; skip it when no
            # validation set was provided (original raised NameError here).
            if self.early_stopping and valid_dataloader:
                super().early_stop_iteration(train_loss, valid_loss, epoch, logger)
                if self.early_stopping.early_stop:
                    break

    def evaluate(self, dataloader, verbose=False, logger=None, weights=None, **kwargs):
        """Return the average negative log probability of the sequences.

        Args:
            dataloader: 2-D (batch of index sequences) or 3-D array-like.
            verbose: print the average to ``logger``.
            weights: unused here; kept for interface compatibility.
        """
        # Convert once instead of three times as in the original.
        data = np.array(dataloader)
        assert(data.ndim == 2 or data.ndim == 3)
        loss = -sum(self.model.log_probability(seq) for seq in data)
        if verbose:
            print("Average neg log prob: {0:.4f}".format(loss / len(dataloader)), file=logger)
        return loss / len(dataloader)

    def sample(self, num_samples, length, to_string=True, **kwargs):
        """Draw ``num_samples`` sequences of ``length`` characters as strings."""
        return ["".join(x) for x in self.model.sample(n=num_samples, length=length)]

    def show_model(self, logger=None, **kwargs):
        """Print the underlying pomegranate model to ``logger``."""
        # Fix: original was `print(self.model, logger)`, which printed the
        # logger object itself instead of directing output to it.
        print(self.model, file=logger)

    def plot_model(self, save_fig_dir, show=False, **kwargs):
        # self.model.plot() does not plot legible graphs for hidden size > 10
        pass

    def save_model(self, path, **kwargs):
        """Serialize the HMM to JSON at ``path``."""
        with open(path, 'w') as f:
            json.dump(self.model.to_json(), f)

    def load_model(self, path, **kwargs):
        """Restore the HMM from a JSON checkpoint at ``path``."""
        with open(path, 'r') as f:
            json_model = json.load(f)
        self.model = HiddenMarkovModel.from_json(json_model)

    def plot_history(self, save_fig_dir, **kwargs):
        """Delegate loss-history plotting to the base class."""
        super().plot_history(save_fig_dir=save_fig_dir, **kwargs)