In [None]:
!pip install torch==1.4.0

In [None]:
import torch

In [None]:
import numpy as np

class HMM(torch.nn.Module):
  """
  Hidden Markov Model over a discrete observation alphabet.

  Every distribution is stored as unnormalized scores:
    * ``transition_model``          -- N x N transition scores (A)
    * ``emission_model``            -- N x M emission scores (b(x_t))
    * ``unnormalized_state_priors`` -- N initial-state scores (pi)

  Args:
    M: number of possible observations.
    N: number of hidden states.
  """
  def __init__(self, M, N):
    super().__init__()
    self.M, self.N = M, N

    # Sub-modules first, prior last: the random initialisation is drawn
    # in this order.
    self.transition_model = TransitionModel(N)
    self.emission_model = EmissionModel(N, M)
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(N))

    # Record CUDA availability and move all parameters to the GPU when present.
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda:
      self.cuda()

class TransitionModel(torch.nn.Module):
  """
  Holds the learnable N x N matrix of unnormalized transition scores.

  Args:
    N: number of hidden states.
  """
  def __init__(self, N):
    super().__init__()
    self.N = N
    initial_scores = torch.randn(N, N)
    self.unnormalized_transition_matrix = torch.nn.Parameter(initial_scores)

class EmissionModel(torch.nn.Module):
  """
  Holds the learnable N x M matrix of unnormalized emission scores.

  Args:
    N: number of hidden states (rows).
    M: number of possible observations (columns).
  """
  def __init__(self, N, M):
    super().__init__()
    self.N, self.M = N, M
    initial_scores = torch.randn(N, M)
    self.unnormalized_emission_matrix = torch.nn.Parameter(initial_scores)

In [None]:
def sample(self, T=3):
  """
  Draw one sequence of T timesteps from the model.

  Args:
    T: number of timesteps to generate (default 3).

  Returns:
    (x, z): two lists of Python ints, both of length T. x[t] is the index of
    the observation emitted at timestep t and z[t] is the index of the hidden
    state that emitted it. For T == 0 both lists are empty.
  """
  softmax = torch.nn.functional.softmax

  def draw(probs):
    # One categorical draw, returned as a plain Python int.
    return torch.distributions.categorical.Categorical(probs).sample().item()

  x = []; z = []
  # Sampling is never back-propagated through, so skip autograd bookkeeping.
  with torch.no_grad():
    state_priors = softmax(self.unnormalized_state_priors, dim=0)
    # Normalized over dim 0: column j is the distribution over next states
    # given that the current state is j.
    transition_matrix = softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
    # Normalized over dim 1: row i is the distribution over observations
    # given state i.
    emission_matrix = softmax(self.emission_model.unnormalized_emission_matrix, dim=1)

    z_t = None
    for t in range(T):
      # The first state comes from the prior; later states come from the
      # transition column of the previous state. Drawing the state at the top
      # of the loop avoids sampling a transition that is never used.
      if t == 0:
        z_t = draw(state_priors)
      else:
        z_t = draw(transition_matrix[:, z_t])
      z.append(z_t)

      # sample emission from the current state's row
      x.append(draw(emission_matrix[z_t]))

  return x, z

# Attach sample() to the HMM class as a method, so that model.sample(T)
# returns the (observations, states) pair drawn from that model.
HMM.sample = sample

In [None]:
import string
# Symbol inventory read by encode()/decode() below: a symbol's integer code is
# its position in `alphabet`. The toy demo uses three symbols; the commented
# line is the full lower-case alternative.
#alphabet = string.ascii_lowercase
alphabet = 'abc'

def encode(s):
  """
  Turn a string into a list of integer codes.

  Each character is replaced by its position in the module-level ``alphabet``.
  A character that is not in ``alphabet`` makes ``alphabet.index`` raise
  ValueError.
  """
  codes = []
  for ch in s:
    codes.append(alphabet.index(ch))
  return codes

def decode(x):
  """
  Turn a list of integer codes back into a string.

  Each code is used as an index into the module-level ``alphabet`` and the
  resulting symbols are concatenated.
  """
  symbols = []
  for code in x:
    symbols.append(alphabet[code])
  return "".join(symbols)

# Build a two-state HMM over the toy alphabet
model = HMM(M=len(alphabet), N=2)

def _hardwire(parameter, probabilities):
  """Overwrite every entry of `parameter` with the log of `probabilities`."""
  # The parameters are unnormalized scores that get pushed through a softmax,
  # so storing log(p) makes the softmax return exactly p.
  log_values = torch.tensor(np.log(probabilities), dtype=parameter.dtype)
  parameter.data.copy_(log_values)

# Let state 0 = consonant, state 1 = vowel
_hardwire(model.unnormalized_state_priors, [0.6, 0.4])
print("State priors:", torch.nn.functional.softmax(model.unnormalized_state_priors, dim=0))

# One row per state, one column per symbol; each row sums to 1.
_hardwire(model.emission_model.unnormalized_emission_matrix,
          [[0.1, 0.4, 0.5],
           [0.7, 0.2, 0.1]])
# (With a full a-z alphabet one could instead put -inf scores on the vowels in
# state 0 and on the consonants in state 1, so each state emits only its own
# letter class.)
print("Emission matrix:", torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1))

# Transposed layout: the matrix is normalized over dim 0, so each COLUMN sums
# to 1 (entry [i, j] is the probability of moving to state i from state j).
_hardwire(model.transition_model.unnormalized_transition_matrix,
          [[0.7, 0.4],
           [0.3, 0.6]])
# (To force strict consonant/vowel alternation, the diagonal scores could be
# set to -inf and the off-diagonal scores to 0.)
print("Transition matrix:", torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0))



In [None]:
def HMM_forward(self, x, T):
  """
  Compute log p(x) for each example in the batch with the forward algorithm.

  x : integer Tensor of shape (batch size, T_max), padded observation indices
  T : integer Tensor of shape (batch size), true length of each example

  Returns a Tensor of shape (batch size, 1) holding log p(x[:T]) per example.
  Positions at or beyond an example's length are still pushed through the
  recursion but are ignored, because the result is read at timestep T - 1.
  """
  # Follow the parameters' current device instead of the is_cuda flag cached
  # at construction time, which goes stale if the model is moved afterwards.
  device = self.unnormalized_state_priors.device
  # int64 is what tensor indexing and torch.gather expect for index tensors.
  x = x.to(device).long()
  T = T.to(device).long()

  T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)

  # Collect one (batch, N) tensor per timestep and stack at the end, rather
  # than writing slices in place into a preallocated tensor under autograd.
  log_alpha_t = self.emission_model(x[:, 0]) + log_state_priors
  log_alphas = [log_alpha_t]
  for t in range(1, T_max):
    log_alpha_t = self.emission_model(x[:, t]) + self.transition_model(log_alpha_t)
    log_alphas.append(log_alpha_t)
  log_alpha = torch.stack(log_alphas, dim=1)  # (batch size, T_max, N)

  # Select the sum for the final timestep (each x may have different length).
  log_sums = log_alpha.logsumexp(dim=2)
  log_probs = torch.gather(log_sums, 1, T.view(-1, 1) - 1)
  return log_probs

def emission_model_forward(self, x_t):
  """
  Look up log p(x_t | state) for every state.

  x_t : index Tensor of observations for one timestep.
  Returns the row-normalized log emission matrix indexed at columns x_t,
  with the first two axes swapped so the observation axis comes first.
  """
  log_emission_matrix = torch.log_softmax(self.unnormalized_emission_matrix, dim=1)
  per_state = log_emission_matrix[:, x_t]  # state axis first
  return per_state.permute(1, 0)

def transition_model_forward(self, log_alpha):
  """
  Advance the forward variables by one timestep, in the log domain.

  log_alpha : Tensor of shape (batch size, N), previous timestep's log alphas
  Returns a Tensor of shape (batch size, N).

  The transition scores are normalized over dim 0, and the normalized matrix
  is applied on the left of the transposed alphas.
  """
  log_A = torch.log_softmax(self.unnormalized_transition_matrix, dim=0)

  # Log-domain matrix product, carried out on the (N, batch size) layout
  previous = log_alpha.transpose(0, 1)
  advanced = log_domain_matmul(log_A, previous)
  return advanced.transpose(0, 1)

def log_domain_matmul(log_A, log_B):
  """
  Matrix multiplication carried out in the log domain.

  log_A : m x n
  log_B : n x p
  output : m x p matrix

  Normally, a matrix multiplication
  computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

  A log domain matrix multiplication
  computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
  """
  # Broadcasting (m, n, 1) against (1, n, p) yields the same (m, n, p) table of
  # pairwise sums as stacking p copies of log_A and m copies of log_B, without
  # materialising those copies.
  elementwise_sum = log_A.unsqueeze(2) + log_B.unsqueeze(0)

  # Reduce over the shared dimension k.
  return torch.logsumexp(elementwise_sum, dim=1)

# Install the functions defined above as the forward() of each module class,
# so that calling a module instance dispatches to them.
HMM.forward = HMM_forward
EmissionModel.forward = emission_model_forward
TransitionModel.forward = transition_model_forward

In [None]:
import torch.utils.data
from collections import Counter
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
  """
  Dataset of text lines that also owns the DataLoader used to batch them.

  Args:
    lines: list of strings, one example per entry.
    batch_size: minibatch size for ``self.loader`` (default 1024).
    shuffle: whether ``self.loader`` reshuffles every pass (default True).
    num_workers: number of loader worker processes (default 1).

  Attributes:
    lines: the list passed in.
    loader: DataLoader over this dataset that collates with ``Collate()``.
  """
  def __init__(self, lines, batch_size=1024, shuffle=True, num_workers=1):
    # Must be set before the loader is built: the loader asks for len(self).
    self.lines = lines # list of strings
    collate = Collate() # function for generating a minibatch from strings

    self.loader = torch.utils.data.DataLoader(
      self,
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=shuffle,
      collate_fn=collate,
    )

  def __len__(self):
    """Number of examples."""
    return len(self.lines)

  def __getitem__(self, idx):
    """
    Return example ``idx`` with leading spaces and all trailing spaces and
    newlines removed.
    """
    # Strip trailing spaces and newlines in one pass; alternating
    # rstrip("\n") / rstrip(" ") calls leave residue on inputs such as
    # "word \n \n".
    return self.lines[idx].lstrip(" ").rstrip(" \n")

class Collate:
  """Callable that turns a list of strings into one padded integer minibatch."""
  def __init__(self):
    pass

  def __call__(self, batch):
    """
    Encode and pad a minibatch of strings.

    Args:
      batch: list of strings.

    Returns:
      (x, x_lengths): x is an int64 tensor of shape (len(batch), T), where T
      is the longest encoded length in the batch, and x_lengths is an int64
      tensor of shape (len(batch),) holding each example's unpadded length.
      Shorter examples are right-padded with 0; since 0 is also a valid symbol
      code, x_lengths is what distinguishes padding from data.
    """
    # convert letters to integers
    encoded = [encode(line) for line in batch]
    x_lengths = [len(codes) for codes in encoded]
    T = max(x_lengths, default=0)

    # Allocate the zero-padded batch with an explicit integer dtype. Letting
    # the dtype be inferred per row would produce a float tensor when every
    # string in the batch is empty.
    x = torch.zeros((len(encoded), T), dtype=torch.long)
    for row, codes in enumerate(encoded):
      if codes:
        x[row, :len(codes)] = torch.tensor(codes, dtype=torch.long)

    x_lengths = torch.tensor(x_lengths, dtype=torch.long)
    return (x, x_lengths)

In [None]:
# Upload the corpus through the Colab file picker. Outside Colab the
# google.colab package does not exist, so fall back to whatever is already in
# the working directory instead of aborting the whole notebook run.
try:
  from google.colab import files
except ImportError:
  uploaded = {}
  print("google.colab not available; skipping upload and using local files.")
else:
  uploaded = files.upload()

In [None]:
# The reference corpus can alternatively be fetched with:
#!wget https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt
import re

filename = "text.txt"

# Decode explicitly as UTF-8 so the result does not depend on the platform
# locale. Undecodable bytes become the replacement character, which is not in
# a-z and is therefore treated as a word separator below.
with open(filename, "r", encoding="utf-8", errors="replace") as f:
  lines = f.readlines()

# Lower-case every line and keep each maximal run of a-z as one training word;
# everything else (digits, punctuation, whitespace) only separates words.
words = [word for line in lines for word in re.findall("[a-z]+", line.lower())]

# The observation alphabet is fixed to a-z, independent of which letters
# actually occur in the corpus.
alphabet = list(string.ascii_lowercase)

train_lines, valid_lines = train_test_split(words, test_size=0.1, random_state=42)
train_dataset = TextDataset(train_lines)
valid_dataset = TextDataset(valid_lines)

M = len(alphabet)

In [None]:
from tqdm import tqdm # for displaying progress bar

class Trainer:
  """
  Runs training and evaluation passes of a model with Adam.

  Args:
    model: module called as ``model(x, T)`` to obtain per-example log
      probabilities, and providing ``sample()``.
    lr: Adam learning rate.
    print_interval: print the loss and some samples every this many batches
      (default 50).
  """
  def __init__(self, model, lr, print_interval=50):
    self.model = model
    self.lr = lr
    self.print_interval = print_interval
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.00001)

  def _print_samples(self, count):
    """Print `count` sampled sequences: decoded observations, then states."""
    for _ in range(count):
      sampled_x, sampled_z = self.model.sample()
      print(decode(sampled_x))
      print(sampled_z)

  def train(self, dataset):
    """
    Run one optimisation pass over ``dataset.loader``.

    Returns the example-weighted mean negative log-likelihood over the pass
    (each batch's loss is measured before that batch's update), or NaN if the
    loader produced no examples.
    """
    self.model.train()
    total_loss = 0.0
    num_samples = 0
    for idx, batch in enumerate(tqdm(dataset.loader)):
      x, T = batch
      batch_size = len(x)
      log_probs = self.model(x, T)
      loss = -log_probs.mean()

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      # Weight by batch size so a smaller final batch does not skew the mean.
      total_loss += loss.item() * batch_size
      num_samples += batch_size
      if idx % self.print_interval == 0:
        print("loss:", loss.item())
        self._print_samples(5)
    return total_loss / num_samples if num_samples else float("nan")

  def test(self, dataset):
    """
    Evaluate on ``dataset.loader`` without updating the model.

    Returns the example-weighted mean negative log-likelihood, or NaN if the
    loader produced no examples.
    """
    self.model.eval()
    total_loss = 0.0
    num_samples = 0
    # Nothing is back-propagated here, so do not record autograd graphs.
    with torch.no_grad():
      for idx, batch in enumerate(dataset.loader):
        x, T = batch
        batch_size = len(x)
        log_probs = self.model(x, T)
        loss = -log_probs.mean()

        total_loss += loss.item() * batch_size
        num_samples += batch_size
        if idx % self.print_interval == 0:
          print("loss:", loss.item())
          self._print_samples(1)
    return total_loss / num_samples if num_samples else float("nan")

In [None]:
# Fresh two-state model sized for the corpus alphabet (M symbols)
model = HMM(N=2, M=M)

# Optimisation settings
num_epochs = 10
trainer = Trainer(model, lr=0.01)

# One training pass followed by one validation pass per epoch
for epoch in range(num_epochs):
  epoch_number = epoch + 1
  print("========= Epoch %d of %d =========" % (epoch_number, num_epochs))
  train_loss = trainer.train(train_dataset)
  valid_loss = trainer.test(valid_dataset)

  print("========= Results: epoch %d of %d =========" % (epoch_number, num_epochs))
  print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )

## Emission Probability for State 0 ##

In [None]:
import matplotlib.pyplot as plt
# Bar chart of the learned emission distribution of hidden state 0.
emission = torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1)
# Take the whole row rather than a hard-coded 26 entries, so the plot follows
# the model's observation count.
temp0 = emission[0].detach().cpu().tolist()
# 1-based bar positions; kept under the original name `i` in case other cells
# read it.
i = list(range(1, len(temp0) + 1))
fig, ax = plt.subplots()
ax.bar(i, temp0, tick_label=alphabet)
ax.set(title="Emission probabilities, state 0", xlabel="symbol", ylabel="probability")
plt.show()

## Emission Probability for State 1 ##

In [None]:
import matplotlib.pyplot as plt
# Bar chart of the learned emission distribution of hidden state 1.
emission = torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1)
temp1 = emission[1].detach().cpu().tolist()
# Compute the bar positions here instead of relying on a variable left behind
# by another cell, so this cell runs on its own.
positions = list(range(1, len(temp1) + 1))
fig, ax = plt.subplots()
ax.bar(positions, temp1, tick_label=alphabet)
ax.set(title="Emission probabilities, state 1", xlabel="symbol", ylabel="probability")
plt.show()