In [14]:
%matplotlib widget

In [15]:
from matplotlib import pyplot as plot

In [16]:
import numpy

In [17]:
import torch
from torch import nn
from torch import distributions
import torch.optim

In [18]:
'''
create a fake data assuming HMM with two latent states
'''
seqlen = 500
pi_true = numpy.array([[0.95, 0.05], [0.1, 0.9]])
mu_true = numpy.ones(2) * [0.3, 3.0]
z0_true = numpy.int(numpy.random.rand() > 0.5)
z = [z0_true]
x = []
for t in range(seqlen):
    x.append(numpy.random.poisson(mu_true[z[-1]]))
    if t < seqlen - 1:
        z.append(numpy.random.multinomial(1, pi_true[z[-1]]).argmax())
obs_sequence = x
    
plot.figure()
plot.plot(numpy.arange(len(x))+1, x, 'r-')
plot.plot(numpy.arange(len(z)), z, 'b-')
plot.grid(True)
plot.legend(['obs', 'latent'])
plot.show()

FigureCanvasNbAgg()

In [19]:
'''
a single sequence Poison-HMM: it supports a single sequence.

it is a very inefficient implementation of forward-backward algorithm and viterbi algorithm.

TODO:

1. indexes must be double- and triple-checked.
2. optimize
'''
class PoissonHMM(nn.Module):
    def __init__(self, n_states=2):
        super(PoissonHMM, self).__init__()
        
        self.n_states = n_states
        
        self.upi = nn.Parameter(torch.from_numpy(numpy.random.randn(n_states, n_states)).float())
        self.urate = nn.Parameter(torch.from_numpy(numpy.random.randn(n_states)).float())
        self.upi0 = nn.Parameter(torch.from_numpy(numpy.ones(n_states)).float())
        

    def _prep(self, reverse=False):
        upi_ = torch.softmax(self.upi, dim=1)
        
        transitions = []
        emissions = []
        for si in range(self.n_states):
            if reverse:
                transitions.append(distributions.Categorical(upi_[:,si]))
            else:
                transitions.append(distributions.Categorical(upi_[si]))
            emissions.append(distributions.Poisson(torch.log(1+torch.exp(self.urate[si]))))
        
        upi0_ = torch.exp(self.upi0 - torch.logsumexp(self.upi0, 0))
        initial = distributions.Categorical(upi0_)

        return transitions, emissions, initial
        
    def generate(self, n_steps=100):
        zs = []
        xs = []
        
        transitions, emissions, initial = self._prep()
        
        for t in range(n_steps):
            if t == 0:
                zs.append(initial.sample())
            else:
                zs.append(transitions[zs[-1].item()].sample())
            xs.append(emissions[zs[-1].item()].sample())
            
        return zs, xs
    
    def _viterbi(self, obs, agg='max', prev=None):
        seqlen = len(obs)
        
        scores = []
        if agg == 'max':
            pointers = []
        
        transitions, emissions, initial = self._prep(reverse=prev is not None)
        
        if prev is None:
            prev = initial.probs
            
        for t in range(seqlen):
            # observation probability
            logp_obs = []
            for st in range(self.n_states): # old state
                logp_obs.append(emissions[st].log_prob(obs[t])) 
            logp_obs = torch.stack(logp_obs)
            logp_tra = []
            for st0 in range(self.n_states): # old state
                logp_tra.append([])
                for st1 in range(self.n_states): # new state
                    logp_tra[-1].append(torch.log(1e-8+transitions[st0].probs[st1]))
                logp_tra[-1] = torch.stack(logp_tra[-1])
            logp_tra = torch.stack(logp_tra) # old state (rows) x new state (columns)
            # combine them
            if agg == 'max':
                prev, prev_id = torch.max(logp_tra + logp_obs.view(-1,1) + prev.view(-1, 1), dim=0) # find the best old state
                scores.append(prev)
                pointers.append(prev_id)
            elif agg == 'sum':
                prev = torch.logsumexp(logp_tra + logp_obs.view(-1,1) + prev.view(-1, 1), dim=0) # find the best old state
                scores.append(prev)
            else:
                raise Exception('NOT SUPPORTED')
        
        if agg == 'max':
            # max score
            logp, idx = scores[-1].max(), scores[-1].argmax()
            backtracked = [pointers[-1][idx]]
            for t in range(seqlen-1):
                backtracked.append(pointers[-t][backtracked[-1]])

            return logp, backtracked[::-1]
        elif agg == 'sum':
            return scores
        else: 
            raise Exception('NOT SUPPORTED')
       
    def viterbi(self, obs):
        return self._viterbi(obs, agg='max')
    
 
    def forward_backward(self, obs):
        seqlen = len(obs)
        
        scores = []
        pointers = []
        
        alphas = torch.stack(self._viterbi(obs, agg='sum'))
        betas = torch.stack(self._viterbi(torch.flip(obs, [0]), agg='sum', prev=alphas[-1]))
        
        return (alphas + betas).sum(dim=1).mean()

In [20]:
phmm = PoissonHMM(n_states=2)

In [21]:
'''
it sometimes fails with nan. debug it.
'''
n_steps = 10
disp_inter = 1

optimizer = torch.optim.LBFGS(phmm.parameters())

print('Step\tLoss')
# counter = 0
for ni in range(n_steps):    
    def closure():
        optimizer.zero_grad()
        logp = phmm.forward_backward(torch.from_numpy(numpy.array(obs_sequence)).float())
        loss = -logp
        loss.backward()
        return loss
    if numpy.mod(ni, disp_inter) == 0:
        logp = phmm.forward_backward(torch.from_numpy(numpy.array(obs_sequence)).float())
        print('{} {}'.format(ni+1, -logp))
    
    optimizer.step(closure)

Step	Loss
1 1748806.125
2 1290449.75
3 1290447.75
4 1290447.75
5 1290447.0
6 1290447.75
7 1290447.5
8 1290446.875
9 1290446.875
10 1290446.875


In [22]:
logp, inferred = phmm.viterbi(torch.from_numpy(numpy.array(obs_sequence)).float())

# let's plot
plot.figure()

plot.plot(numpy.arange(len(obs_sequence)), obs_sequence, 'k-')
plot.plot(numpy.arange(len(z)), numpy.array(z)-1, 'r--')
plot.plot(numpy.arange(len(inferred)), [iz.item()-1 for iz in inferred], 'b--')
plot.legend(['observed', 'true latent', 'inferred latent'])
plot.grid(True)
plot.show()

FigureCanvasNbAgg()

In [23]:
print('Transition matrix')
print('\tTrue')
print(pi_true)
print('\tEstimated')
print(torch.softmax(phmm.upi, 1).data.numpy())

print('Poison rate')
print('\tTrue')
print(mu_true)
print('\tEstimated')
print(torch.exp(phmm.urate).data.numpy())

Transition matrix
	True
[[0.95 0.05]
 [0.1  0.9 ]]
	Estimated
[[0.94299453 0.05700541]
 [0.08822832 0.9117717 ]]
Poison rate
	True
[0.3 3. ]
	Estimated
[ 0.36258495 18.782112  ]
