In [None]:
import numpy as np
import matplotlib.pyplot as plt

class HMM(object):
    """Hidden Markov model over a square grid world, trained with Baum-Welch.

    The model is built from a simulator object and the data it generated.
    The simulator must expose ``state_size``, ``obs_size``, ``obstacles`` and
    ``S_type``; the data object must expose ``X`` (state sequences) and ``Z``
    (observation sequences).

    Transition-matrix orientation: ``transition_prior`` fills row ``s`` with
    the distribution over the successor of ``s``.  ``train`` reads and writes
    ``self.trans`` as ``trans[to, from]``.  The prior is symmetric, so both
    readings coincide before training; after training ``self.trans`` is in
    the ``trans[to, from]`` orientation.
    """

    def __init__(self, simulator, data):
        """Construct an HMM from an HMM_simulator object and its generated data.

        Args:
            simulator: object providing ``state_size``, ``obs_size``,
                ``obstacles`` (state indices) and ``S_type`` (indexable by
                state).
            data: object providing ``X`` and ``Z``, the lists of state and
                observation sequences generated by the simulator.

        Raises:
            ValueError: if every state is an obstacle, or if the number of
                states is not a perfect square (see ``transition_prior``).
        """
        self.simulator = simulator
        self.numStates = simulator.state_size  # number of possibilities for X_i
        self.numOutputs = simulator.obs_size  # number of possibilities for Z_i

        # list of sequences generated from the simulator
        self.trainStates = data.X
        self.trainOutputs = data.Z

        # Uniform prior over the initial state, excluding obstacle states.
        # An explicit integer array keeps an empty obstacle list indexable.
        obstacles = np.unique(np.asarray(simulator.obstacles, dtype=int))
        num_free = self.numStates - obstacles.size
        if num_free <= 0:
            raise ValueError("every state is an obstacle; no initial state is possible")
        self.pi = np.full(self.numStates, 1.0 / num_free)
        self.pi[obstacles] = 0.0

        self.trans = self.transition_prior(p_stay=0.2)
        self.emit = self.observation_prior(p_truesignal=0.7)

    def transition_prior(self, p_stay=0.2):
        """Build the prior transition matrix for a square grid.

        States are taken to be numbered row by row, with row 0 as the north
        edge and column 0 as the west edge.  From every cell each of the four
        moves has probability ``(1 - p_stay) / 4``; a move that would leave
        the grid is added to the probability of staying.  Corner cells
        therefore stay with ``p_stay + (1 - p_stay) / 2``, edge cells with
        ``p_stay + (1 - p_stay) / 4`` and interior cells with ``p_stay``.
        Obstacles are not treated specially.

        Args:
            p_stay: base probability of remaining in the current cell.

        Returns:
            ``(numStates, numStates)`` array whose row ``s`` is the
            distribution over the successor of state ``s``.

        Raises:
            ValueError: if ``numStates`` is not a perfect square.
        """
        num_states = int(self.numStates)
        side = int(round(np.sqrt(num_states)))
        if side * side != num_states:
            raise ValueError(
                "numStates must be a perfect square, got %d" % num_states)

        p_move = (1.0 - p_stay) / 4.0
        trans = np.zeros((num_states, num_states))
        for s in range(num_states):
            row, col = divmod(s, side)
            neighbours = []
            if row > 0:
                neighbours.append(s - side)  # north
            if row < side - 1:
                neighbours.append(s + side)  # south
            if col > 0:
                neighbours.append(s - 1)  # west
            if col < side - 1:
                neighbours.append(s + 1)  # east
            for target in neighbours:
                trans[s, target] = p_move
            # Moves blocked by the grid boundary keep the agent in place.
            trans[s, s] = p_stay + p_move * (4 - len(neighbours))
        return trans

    def observation_prior(self, p_truesignal=0.7):
        """Build the prior observation (emission) matrix.

        For state ``i`` the output equal to ``simulator.S_type[i]`` gets
        probability ``p_truesignal``; the remaining probability is split
        evenly over the other outputs.

        Args:
            p_truesignal: probability of observing the state's own type.

        Returns:
            ``(numStates, numOutputs)`` array of emission probabilities.
        """
        num_outputs = self.numOutputs
        # With a single output there is no "other" output to share mass with.
        if num_outputs > 1:
            p_other = (1.0 - p_truesignal) / (num_outputs - 1)
        else:
            p_other = 0.0

        emit = np.full((self.numStates, num_outputs), p_other)
        for state in range(self.numStates):
            true_type = self.simulator.S_type[state]
            for k in range(num_outputs):
                if k == true_type:
                    emit[state, k] = p_truesignal
        return emit

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Divide elementwise, returning 0 wherever the denominator is not positive."""
        result = np.zeros_like(numerator, dtype=float)
        np.divide(numerator, denominator, out=result, where=denominator > 0)
        return result

    @staticmethod
    def _forward_backward(pi, trans, emit, obs):
        """Run the scaled forward and backward passes for one sequence.

        Each forward vector is renormalised to sum to 1 and the backward
        vectors are divided by the same factors, so long sequences do not
        underflow.  ``trans`` is read as ``trans[to, from]``.

        Returns:
            ``(alpha, beta, scale)`` where ``alpha`` and ``beta`` have shape
            ``(len(obs), numStates)`` and ``sum(log(scale))`` is the
            log-likelihood of the sequence.

        Raises:
            ValueError: if the sequence has zero (or non-finite) likelihood.
        """
        num_steps = len(obs)
        num_states = len(pi)
        alpha = np.zeros((num_steps, num_states))
        beta = np.ones((num_steps, num_states))
        scale = np.zeros(num_steps)

        for t in range(num_steps):
            if t == 0:
                unscaled = emit[:, obs[0]] * pi
            else:
                unscaled = emit[:, obs[t]] * np.matmul(trans, alpha[t - 1])
            scale[t] = unscaled.sum()
            if not (np.isfinite(scale[t]) and scale[t] > 0):
                raise ValueError(
                    "observation sequence has zero likelihood under the "
                    "current model (step %d)" % t)
            alpha[t] = unscaled / scale[t]

        for t in range(num_steps - 2, -1, -1):
            beta[t] = np.matmul(trans.T, emit[:, obs[t + 1]] * beta[t + 1]) / scale[t + 1]

        return alpha, beta, scale

    def train(self, eps=0.05, max_iter=None):
        """Estimate ``pi``, ``trans`` and ``emit`` with the Baum-Welch algorithm.

        Iterates over ``self.trainOutputs`` until the total log-likelihood
        improves by less than ``eps``.  ``self.loglik`` is reset to
        ``[-inf]`` and then receives, per iteration, the log-likelihood of
        the training data under the parameters used in that iteration.
        Empty observation sequences are ignored.

        Args:
            eps: convergence tolerance on the log-likelihood improvement.
            max_iter: optional cap on the number of iterations; ``None``
                (default) iterates until convergence.

        Raises:
            ValueError: if there is no non-empty training sequence, or a
                sequence has zero likelihood under the current model.
        """
        sequences = [np.asarray(seq, dtype=int) for seq in self.trainOutputs]
        sequences = [seq for seq in sequences if seq.size > 0]
        if not sequences:
            raise ValueError("train() needs at least one non-empty observation sequence")

        pi = self.pi
        trans = self.trans
        emit = self.emit
        num_states = self.numStates
        num_outputs = self.numOutputs
        one_hot = np.eye(num_outputs)  # row k is the indicator of output k

        self.loglik = [-float('inf')]  # list to track log likelihood

        while max_iter is None or len(self.loglik) - 1 < max_iter:
            # E-step accumulators, summed over all sequences.
            initial_total = np.zeros(num_states)
            from_to_total = np.zeros((num_states, num_states))
            from_total = np.zeros(num_states)
            emit_total = np.zeros((num_states, num_outputs))
            visit_total = np.zeros(num_states)
            loglik = 0.0

            for obs in sequences:
                alpha, beta, scale = self._forward_backward(pi, trans, emit, obs)
                loglik += np.sum(np.log(scale))

                # Posterior probability of each state at each time step.
                gamma = alpha * beta
                gamma /= gamma.sum(axis=1, keepdims=True)
                initial_total += gamma[0]
                # expected number of transitions from state i
                from_total += gamma[:-1].sum(axis=0)
                visit_total += gamma.sum(axis=0)
                emit_total += np.matmul(gamma.T, one_hot[obs])

                # xi[t, i, j]: posterior probability of being in i at t and
                # in j at t + 1.  trans.T[i, j] is the i -> j probability.
                arrive = emit[:, obs[1:]].T * beta[1:]
                xi = (alpha[:-1, :, np.newaxis]
                      * trans.T[np.newaxis, :, :]
                      * arrive[:, np.newaxis, :])
                xi /= xi.sum(axis=(1, 2), keepdims=True)
                from_to_total += xi.sum(axis=0)

            # M-step: updating HMM model.  States with no expected visits
            # get all-zero probabilities instead of NaN.
            pi = initial_total / len(sequences)
            trans = self._safe_divide(from_to_total, from_total[:, np.newaxis]).T
            emit = self._safe_divide(emit_total, visit_total[:, np.newaxis])

            self.loglik.append(loglik)

            # check for convergence
            if self.loglik[-1] - self.loglik[-2] < eps:
                break

        self.pi = pi
        self.trans = trans
        self.emit = emit