In [39]:
%matplotlib widget

In [40]:
from matplotlib import pyplot as plot

In [41]:
import numpy

In [42]:
import torch
from torch import nn
from torch import distributions
import torch.optim

In [43]:
'''
create a fake data assuming HMM with two latent states
'''
seqlen = 1000
pi_true = numpy.array([[0.95, 0.05], [0.1, 0.9]])
mu_true = numpy.ones(2) * [0.3, 3.0]
z0_true = numpy.int(numpy.random.rand() > 0.5)
z = [z0_true]
x = []
for t in range(seqlen):
    x.append(numpy.random.poisson(mu_true[z[-1]]))
    if t < seqlen - 1:
        z.append(numpy.random.multinomial(1, pi_true[z[-1]]).argmax())
obs_sequence = x
    
plot.figure()
plot.plot(numpy.arange(len(x))+1, x, 'r-')
plot.plot(numpy.arange(len(z)), z, 'b-')
plot.grid(True)
plot.legend(['obs', 'latent'])
plot.show()

FigureCanvasNbAgg()

In [183]:
'''
a single sequence Poison-HMM: it supports a single sequence.

it is a very inefficient implementation of forward-backward algorithm and viterbi algorithm.

TODO:

1. indexes must be double- and triple-checked.
2. optimize
'''
class PoissonHMM(nn.Module):
    def __init__(self, n_states=2):
        super(PoissonHMM, self).__init__()
        
        self.n_states = n_states
        
        self.upi = nn.Parameter(torch.from_numpy(numpy.random.randn(n_states, n_states)).float())
        self.upi0 = nn.Parameter(torch.from_numpy(numpy.ones(n_states)).float())
        self.urate = nn.Parameter(torch.from_numpy(numpy.random.rand(n_states)).float())
#         self.urate_scale = nn.Parameter(torch.from_numpy(numpy.random.rand(n_states)).float())
        

    def _prep(self, reverse=False):
        upi_ = torch.softmax(self.upi, dim=1)
        
        transitions = []
        emissions = []
        for si in range(self.n_states):
            if reverse:
                transitions.append(distributions.Categorical(upi_[:,si]))
            else:
                transitions.append(distributions.Categorical(upi_[si]))
#             emissions.append(distributions.Poisson(torch.log(1+torch.exp(self.urate[si]))))
            emissions.append(distributions.Normal(self.urate[si],1.))
        
        upi0_ = torch.exp(self.upi0 - torch.logsumexp(self.upi0, 0))
        initial = distributions.Categorical(upi0_)

        return transitions, emissions, initial
        
    def generate(self, n_steps=100):
        zs = []
        xs = []
        
        transitions, emissions, initial = self._prep()
        
        for t in range(n_steps):
            if t == 0:
                zs.append(initial.sample())
            else:
                zs.append(transitions[zs[-1].item()].sample())
            xs.append(emissions[zs[-1].item()].sample())
            
        return zs, xs
    
    def _viterbi(self, obs, agg='max', prev=None):
        seqlen = len(obs)
        
        scores = []
        if agg == 'max':
            pointers = []
        
        transitions, emissions, initial = self._prep(reverse=prev is not None)
        
        if prev is None:
            prev = initial.probs
            
        for t in range(seqlen):
            # observation probability
            logp_obs = []
            for st in range(self.n_states): # old state
                logp_obs.append(emissions[st].log_prob(obs[t])) 
            logp_obs = torch.stack(logp_obs)
            logp_tra = []
            for st0 in range(self.n_states): # old state
                logp_tra.append([])
                for st1 in range(self.n_states): # new state
                    logp_tra[-1].append(torch.log(1e-8+transitions[st0].probs[st1]))
                logp_tra[-1] = torch.stack(logp_tra[-1])
            logp_tra = torch.stack(logp_tra) # old state (rows) x new state (columns)
            # combine them
            if agg == 'max':
                prev, prev_id = torch.max(logp_tra + logp_obs.view(-1,1) + prev.view(-1, 1), dim=0) # find the best old state
                scores.append(prev)
                pointers.append(prev_id)
            elif agg == 'sum':
                prev = torch.logsumexp(logp_tra + logp_obs.view(-1,1) + prev.view(-1, 1), dim=0) # find the best old state
                scores.append(prev)
            else:
                raise Exception('NOT SUPPORTED')
        
        if agg == 'max':
            # max score
            logp, idx = scores[-1].max(), scores[-1].argmax()
            backtracked = [pointers[-1][idx]]
            for t in range(seqlen-1):
                backtracked.append(pointers[-t][backtracked[-1]])

            return logp, backtracked[::-1]
        elif agg == 'sum':
            return scores
        else: 
            raise Exception('NOT SUPPORTED')
       
    def viterbi(self, obs):
        return self._viterbi(obs, agg='max')
    
 
    def forward_backward(self, obs):
        seqlen = len(obs)
        
        scores = []
        pointers = []
        
        alphas = torch.stack(self._viterbi(obs, agg='sum'))
        betas = torch.stack(self._viterbi(torch.flip(obs, [0]), agg='sum', prev=alphas[-1]))
        
        return torch.logsumexp(alphas + betas, dim=1).mean()

In [189]:
phmm = PoissonHMM(n_states=2)

In [190]:
obs_sequence = numpy.array([13,  7,  6, 10,  9,  5,  5,  7,  6,  5,  3,  6,  3, 10,  8, 10, 14,
      14, 11,  7,  8,  6,  7,  7,  6,  5,  5,  6,  7, 10,  6,  4,  7,  8,
       5,  8, 11,  8,  7,  7,  6,  7,  2,  4,  5,  5,  5,  3,  6,  5,  9,
       3,  5,  7,  8,  8,  8,  7, 10,  8, 10,  9, 10,  7, 11,  9, 10, 10,
      11,  7,  8,  9,  9,  7,  7,  5,  8, 11, 10,  7,  5,  9,  6,  6,  9,
       4,  8, 11,  5, 11,  2,  4,  9,  7,  6,  6,  7,  7,  8,  5,  8,  8,
      10,  9,  4,  3,  3,  7,  2,  2,  5,  4,  6,  7,  3,  5,  4,  4,  4,
       4,  4,  8,  6,  8, 10, 10,  6,  8,  5,  8,  6,  9,  7,  8,  5,  4,
       6,  7,  3,  4,  5,  6,  5,  4])

In [191]:
'''
it sometimes fails with nan. debug it.
'''
n_steps = 10
disp_inter = 1
l2coeff = .5

optimizer = torch.optim.LBFGS(phmm.parameters())

print('Step\tLoss')
# counter = 0
for ni in range(n_steps):    
    def closure():
        optimizer.zero_grad()
        logp = phmm.forward_backward(torch.from_numpy(numpy.array(obs_sequence)).float())
        loss = -logp
        loss = (1.-l2coeff) * loss + l2coeff * sum([(param ** 2).sum() for param in phmm.parameters()])
        loss.backward()
        return loss
    if numpy.mod(ni, disp_inter) == 0:
        logp = phmm.forward_backward(torch.from_numpy(numpy.array(obs_sequence)).float())
        print('{} {}'.format(ni+1, -logp))
    
    optimizer.step(closure)

Step	Loss
1 6503.5947265625
2 744.3103637695312
3 744.3214111328125
4 744.3217163085938
5 744.3218383789062
6 744.322021484375
7 744.3215942382812
8 744.3214721679688
9 744.3215942382812
10 744.3214111328125


In [192]:
sorted_idx = numpy.argsort(phmm.urate.data.numpy())

In [193]:
logp, inferred = phmm.viterbi(torch.from_numpy(numpy.array(obs_sequence)).float())

# let's plot
plot.figure()

plot.plot(numpy.arange(len(obs_sequence)), obs_sequence, 'k-')
# plot.plot(numpy.arange(len(z)), numpy.array(z), 'r--')
plot.plot(numpy.arange(len(inferred)), [sorted_idx[iz.item()] for iz in inferred], 'b--')
plot.legend(['observed', 
#              'true latent', 
             'inferred latent'])
plot.grid(True)
plot.show()



FigureCanvasNbAgg()

In [194]:
print('Transition matrix')
# print('\tTrue')
# print(pi_true)
print('\tEstimated')
print(torch.softmax(phmm.upi, 1).data.numpy())

# print('Poison rate')
# # print('\tTrue')
# # print(mu_true)
# print('\tEstimated')
# print(torch.log(1+torch.exp(phmm.urate)).data.numpy())

print('Gaussian mean')
# print('\tTrue')
# print(mu_true)
print('\tEstimated')
print(phmm.urate.data.numpy())

Transition matrix
	Estimated
[[0.66147995 0.33852005]
 [0.31218347 0.68781656]]
Gaussian mean
	Estimated
[4.7399683 8.589926 ]
