# HMM Model

## Step 0: Description 
First working attempt at building a hidden Markov model (HMM) in PyTorch for smFRET applications. It uses the GPU when one is available. Please note that most of this code is adapted from the repository at https://github.com/lorenlugosch/pytorch_HMM

## Step 1: Setting up the venv and imports.
Make sure all of the necessary packages are installed. If not, create a conda environment that has torch installed.

In [3]:
import torch

## Step 2: Define the HMM Model (init)
First, we need to define the HMM model. It holds the three sets of parameters that we need: priors, transitions, and emissions.
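
For reference, the indexing conventions implied by the code in this notebook can be written out as below. This is a reading of how `sample()` and `forward()` index the matrices, not a statement taken from the upstream repository. In particular, the transition matrix is normalized over `dim=0`, so its columns (not its rows) sum to 1. The symbols $\pi$, $A$, and $B$ are notation introduced here for the normalized priors, transitions, and emissions.

```latex
\begin{aligned}
\pi_i  &= P(z_0 = i),                  & \sum_i \pi_i  &= 1 \\
A_{ij} &= P(z_t = i \mid z_{t-1} = j), & \sum_i A_{ij} &= 1 \\
B_{ik} &= P(x_t = k \mid z_t = i),     & \sum_k B_{ik} &= 1
\end{aligned}
```

The example model used later has a symmetric transition matrix, so the row/column distinction has no visible effect there, but it would matter for an asymmetric one.
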
### Class Definition

In [4]:
class HMM(torch.nn.Module): # Torch documentation suggests inheritance from torch.nn.Module
  """
  Hidden Markov Model with discrete observations.
  """
  def __init__(self, transitions, emissions, priors):
    """
    Initializes a new HMM model.

    NOTE: 'transitions', 'emissions', and 'priors' should be (nested) lists of numbers.
    They are stored unnormalized and are normalized with torch.nn.functional.softmax
    (or log_softmax) whenever they are used.

    Attributes:
      N (int): the number of states
      M (int): the number of distinct observation symbols
      unnormalized_trans (TransitionMatrix): the transition matrix for this HMM
      unnormalized_emiss (EmissionMatrix): the emission matrix for this HMM
      unnormalized_sp (torch.nn.Parameter): the prior distribution for this HMM
      is_cuda (bool): whether the model was moved to a GPU with cuda()
    """
    super().__init__()

    # First, save the number of observations and the number of states
    self.N = len(priors) # number of states
    self.M = len(emissions[0]) # number of observations

    # For the purposes of sampling and other algos, we will keep inputted probabilities unnormalized and pre-process data as needed.

    # Create A
    self.unnormalized_trans = TransitionMatrix(self.N, transitions)

    # b(x_t)
    self.unnormalized_emiss = EmissionMatrix(self.N, self.M, emissions)

    # pi
    self.unnormalized_sp = torch.nn.Parameter(torch.Tensor(priors))
 
    # use the GPU, for speed
    if torch.cuda.is_available(): 
      self.cuda()
      self.is_cuda = True

    else: self.is_cuda = False

class TransitionMatrix(torch.nn.Module):
  """
  The transition matrix for our HMM model.
  """
  def __init__(self, N, transitions):
    """
    Instantiates a new transition matrix for our HMM model.
    """
    ### Checks to make sure that the number of priors and transitions line up
    if len(transitions) != N:
      raise ValueError(f'Mismatch in the number of priors and rows/cols in "transitions". {N} != {len(transitions)}')

    super().__init__()
    self.N = N
    self.matrix = torch.nn.Parameter(torch.Tensor(transitions))

class EmissionMatrix(torch.nn.Module):
  def __init__(self, N, M, emissions):
    """
    Instantiates a new emission matrix for our HMM model.
    """
    ### Checks if the number of states and rows of the emissions matrix line up
    if len(emissions) != N:
      raise ValueError(f'Mismatch in the number of priors and rows in "emissions". {N} != {len(emissions)}')

    super().__init__()
    self.N = N
    self.M = M
    self.matrix = torch.nn.Parameter(torch.Tensor(emissions))

### Example Initialization

We will follow along with a model. This model is a simple FRET HMM with 2 states and 2 possible observations.

In [5]:
## According to the FRET API tutorial, there are 2 states, 0 and 1, both equally likely to be the initial state.
priors = [0.5,0.5]

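## With 2 states there are 4 transitions.
## Usually the values along the diagonal are close to 1 (staying in the same state is most likely) and the others are close to 0.
## Note: the HMM class applies a softmax to these values, so they act as unnormalized scores rather than final probabilities.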
transitions = [[0.999999, 1e-6],
                [1e-6, 0.999999]]

## In this example, we have two states (A and B) and two possible observations (0 and 1). Rows (i) index states and columns (j) index observations.
# In state A, emitting 1 is more likely than emitting 0; in state B, both observations are equally likely.
observations = [[0.3, 0.7],
                [0.5, 0.5]]

# Thus, we have the model:
model = HMM(transitions, observations, priors)
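
Because the class normalizes its inputs with a softmax, the numbers passed in above are treated as unnormalized scores, and the probabilities the model actually uses differ from them. The following plain-Python sketch (standard library only, no PyTorch) reproduces the three softmax calls by hand to show the effective values. The helper `softmax` and the `effective_*` names are introduced here for illustration. If the inputs are meant to be read as probabilities, one possible workaround, shown on the last lines, is to pass their logarithms instead.

```python
import math

def softmax(values):
    """Plain-Python softmax over a list of floats."""
    shift = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

priors = [0.5, 0.5]
transitions = [[0.999999, 1e-6],
               [1e-6, 0.999999]]
observations = [[0.3, 0.7],
                [0.5, 0.5]]

# Priors: softmax over the single axis.
effective_priors = softmax(priors)

# Emissions: softmax over each row (dim=1 in the notebook).
effective_emissions = [softmax(row) for row in observations]

# Transitions: softmax over each column (dim=0 in the notebook).
columns = [softmax([row[j] for row in transitions]) for j in range(2)]
effective_transitions = [[columns[j][i] for j in range(2)] for i in range(2)]

print(effective_priors)        # roughly [0.5, 0.5]
print(effective_emissions[0])  # roughly [0.401, 0.599]
print(effective_transitions)   # roughly [[0.731, 0.269], [0.269, 0.731]]

# softmax(log p) == p whenever p already sums to 1, so log-probabilities
# as inputs would recover the intended values.
recovered = softmax([math.log(0.999999), math.log(1e-6)])
print(recovered)               # roughly [0.999999, 1e-06]
```

With the raw inputs, the chance of staying in the same state is therefore about 0.73 rather than 0.999999, which is why sampled state sequences from this model switch states fairly often.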

## Step 3: Defining the sample() function
Next, we will write a sample(...) function that simulates the model for T time steps.

In [6]:
def sample(self, T=10):
  """
  Samples T time steps from the HMM and returns (x, z): the observations and
  the hidden states, each a list of length T.

  The parameters unnormalized_sp, unnormalized_trans, and unnormalized_emiss are
  normalized locally with torch.nn.functional.softmax(...).
  """
  state_priors = torch.nn.functional.softmax(self.unnormalized_sp, dim=0)
  emission_matrix = torch.nn.functional.softmax(self.unnormalized_emiss.matrix, dim=1)
  transition_matrix = torch.nn.functional.softmax(self.unnormalized_trans.matrix, dim=0)

  # sample initial state
  z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
  z = []
  x = []
  z.append(z_t)
  
  for t in range(0,T):
    # sample emission
    x_t = torch.distributions.categorical.Categorical(emission_matrix[z_t]).sample().item()
    x.append(x_t)

    # sample transition
    z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item()
    if t < T-1: z.append(z_t)

  return x, z

# Add the sampling method to our HMM class
HMM.sample = sample

### Sample the HMM:
The code below samples from the HMM and reports the states and observations. Note that observation 0 does NOT imply that the system is in state 0. The labels 0 and 1 are used for both states and observations only so that the information is easy to index.

In [7]:
for _ in range(5):
  sampled_x, sampled_z = model.sample(T=5)
  print("x:", sampled_x)
  print("z:", sampled_z)
  print()

x: [1, 1, 0, 0, 0]
z: [0, 0, 0, 0, 1]

x: [1, 1, 1, 0, 1]
z: [0, 1, 0, 0, 0]

x: [0, 0, 1, 1, 1]
z: [0, 1, 0, 0, 0]

x: [1, 1, 1, 1, 0]
z: [1, 1, 0, 0, 0]

x: [1, 0, 0, 1, 0]
z: [0, 0, 0, 1, 1]



## Step 4: The Forward Algorithm
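Now, let's implement the forward algorithm. Note that we will use the log-domain version of the algorithm because it is numerically more stable: multiplying many small probabilities together quickly underflows, whereas adding their logarithms does not.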
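
A small standard-library sketch of the numerical-range issue that motivates working with logarithms. The per-step probability and the sequence length are hypothetical values chosen only to trigger underflow in 64-bit floats.

```python
import math

step_probability = 0.1  # hypothetical per-step probability
num_steps = 400         # hypothetical sequence length

# Probability domain: repeated multiplication underflows to exactly 0.0.
prob = 1.0
for _ in range(num_steps):
    prob *= step_probability

# Log domain: the same quantity is a sum of logs and stays finite.
log_prob = num_steps * math.log(step_probability)

print(prob)      # 0.0
print(log_prob)  # about -921.03
```

Once the probability has collapsed to 0.0, sequences can no longer be compared or differentiated through, while the log-domain value remains usable.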

In [16]:
def forward(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)

  Compute p(x) for each example in the batch. The recursion runs in the
  log domain and the result is exponentiated before it is returned.
  T = length of each example

  This function also locally normalizes the unnormalized_sp, unnormalized_trans, unnormalized_emiss
  using torch.nn.functional.log_softmax(...)

  Worth noting: batch size is just the number of observation sequences passed to the forward algorithm for probability calculation.
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0] # how many sequences we'll be calculating for
  T_max = x.shape[1] # the number of time observations

  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_sp, dim=0) # this log normalizes state priors
  log_alpha = torch.zeros(batch_size, T_max, self.N) # log alpha values, shape (batch, time, state): first dim is the sequence index, second is the time step, last is the state
  
  if self.is_cuda: 
    log_alpha = log_alpha.cuda()

  # SPECIAL NOTE: self.unnormalized_emiss(x[:,0]) invokes torch.nn.Module.__call__(...), which dispatches to EmissionMatrix.forward, i.e. emission_model_forward(self, x[:,0])
    
  log_alpha[:, 0, :] = self.unnormalized_emiss(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    log_alpha[:, t, :] = self.unnormalized_emiss(x[:,t]) + self.unnormalized_trans(log_alpha[:, t-1, :])

  # Select the sum for the final timestep (each x may have different length).
  log_sums = log_alpha.logsumexp(dim=2)
  log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)

  return log_probs.exp()

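def emission_model_forward(self, x_t):
  """
  x_t : IntTensor of shape (batch size)
  Returns log p(x_t | z_t) for every state, shape (batch size, N)
  """
  log_emission_matrix = torch.nn.functional.log_softmax(self.matrix, dim=1)
  out = log_emission_matrix[:, x_t].transpose(0,1)
  return out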

def transition_model_forward(self, log_alpha):
  """
  log_alpha : Tensor of shape (batch size, N)
  Multiply previous timestep's alphas by transition matrix (in log domain)
  """
  log_transition_matrix = torch.nn.functional.log_softmax(self.matrix, dim=0)

  # Matrix multiplication in the log domain
  out = log_domain_matmul(log_transition_matrix, log_alpha.transpose(0,1)).transpose(0,1)
  return out

def log_domain_matmul(log_A, log_B):
  """
  log_A : m x n
  log_B : n x p
  output : m x p matrix

  Normally, a matrix multiplication
  computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

  A log domain matrix multiplication
  computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
  """
  m = log_A.shape[0]
  n = log_A.shape[1]
  p = log_B.shape[1]

  log_A_expanded = torch.reshape(log_A, (m,n,1))
  log_B_expanded = torch.reshape(log_B, (1,n,p))

  elementwise_sum = log_A_expanded + log_B_expanded
  out = torch.logsumexp(elementwise_sum, dim=1)

  return out

TransitionMatrix.forward = transition_model_forward
EmissionMatrix.forward = emission_model_forward
HMM.forward = forward
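
To make the log-domain matrix product concrete, here is a plain-Python sketch (standard library only) that applies the same formula, out_{i,j} = logsumexp_k(log_A_{i,k} + log_B_{k,j}), and compares it with an ordinary matrix product. The helpers `logsumexp` and `log_matmul` and the two 2x2 matrices are hypothetical and exist only for this check. They are not part of the notebook's classes.

```python
import math

def logsumexp(values):
    """Stable log(sum(exp(v))) for a list of finite floats."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))

def log_matmul(log_A, log_B):
    """out[i][j] = logsumexp over k of (log_A[i][k] + log_B[k][j])."""
    n = len(log_B)
    p = len(log_B[0])
    return [[logsumexp([row[k] + log_B[k][j] for k in range(n)])
             for j in range(p)]
            for row in log_A]

# Hypothetical 2x2 matrices with strictly positive entries.
A = [[0.7, 0.3],
     [0.2, 0.8]]
B = [[0.5, 0.5],
     [0.9, 0.1]]

log_A = [[math.log(v) for v in row] for row in A]
log_B = [[math.log(v) for v in row] for row in B]

# Ordinary matrix product, for comparison.
product = [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
           for i in range(2)]

# Exponentiating the log-domain result should give the same numbers.
recovered = [[math.exp(v) for v in row] for row in log_matmul(log_A, log_B)]

print(product)
print(recovered)
```

Both printed matrices should agree up to floating-point rounding, which is the property `transition_model_forward` relies on when it propagates the log alphas.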

As an exercise, let's go ahead and calculate the probabilities of all length-3 observation sequences under our model.

If our model is working, the sum of these probabilities should be 1.

In [35]:
sequences = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [1,1,0], [0,1,1], [1,0,1], [1,1,1]]
x = torch.stack([torch.tensor(x) for x in sequences])
T = torch.tensor([3 for x in sequences])
p_sequences = model.forward(x, T)

print(p_sequences)
print(f'Length-3 Sequences Sum: {sum(p_sequences).item()}')

tensor([[0.0928],
        [0.1114],
        [0.1108],
        [0.1114],
        [0.1356],
        [0.1356],
        [0.1350],
        [0.1673]], grad_fn=<ExpBackward0>)
Length-3 Sequences Sum: 1.0
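
As an independent cross-check of the numbers above, the forward recursion can also be run by hand in the probability domain, which is safe for sequences this short. The sketch below uses only the standard library and assumes the same raw inputs as the example model, normalized the same way as in the notebook (softmax over rows for emissions, over columns for transitions). The names `pi`, `A`, `B`, and `sequence_probability` are introduced here for illustration.

```python
import itertools
import math

def softmax(values):
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Same raw inputs as the example model.
priors = [0.5, 0.5]
transitions = [[0.999999, 1e-6], [1e-6, 0.999999]]
observations = [[0.3, 0.7], [0.5, 0.5]]
N = 2

pi = softmax(priors)
B = [softmax(row) for row in observations]              # rows sum to 1
cols = [softmax([transitions[i][j] for i in range(N)]) for j in range(N)]
A = [[cols[j][i] for j in range(N)] for i in range(N)]  # columns sum to 1

def sequence_probability(xs):
    """Forward recursion: alpha_t(i) = B[i][x_t] * sum_j A[i][j] * alpha_{t-1}(j)."""
    alpha = [pi[i] * B[i][xs[0]] for i in range(N)]
    for x_t in xs[1:]:
        alpha = [B[i][x_t] * sum(A[i][j] * alpha[j] for j in range(N))
                 for i in range(N)]
    return sum(alpha)

probs = {xs: sequence_probability(xs)
         for xs in itertools.product([0, 1], repeat=3)}

print(round(probs[(0, 0, 0)], 4))     # 0.0928
print(round(probs[(1, 1, 1)], 4))     # 0.1673
print(round(sum(probs.values()), 6))  # 1.0
```

The values for `[0,0,0]` and `[1,1,1]` agree with the first and last entries of the tensor printed by `model.forward`.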


## Step 5: Viterbi Analysis
Next, we will implement the Viterbi algorithm to calculate the most likely hidden-state sequence for a given observation sequence under the model.