In [None]:
import sys
import os

# Make the project's helper modules in ../scripts (load_data, preprocessing) importable.
# The path is resolved against the kernel's current working directory, presumably this
# notebook's folder -- adjust if the notebook is run from elsewhere.
# Guarded so that re-running this cell does not append duplicate entries to sys.path.
SCRIPTS_DIR = os.path.abspath("../scripts")
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [None]:
import pandas as pd
import numpy as np
import pickle
from load_data import *
from preprocessing import *

# Load Data from File

In [None]:
BROWN_PATH = "../data/brown-universal.txt"
TAGS_PATH = "../data/tags-universal.txt"

# Load the tagged Brown corpus split into train/test sentence collections
# (split=0.8 is presumably the training fraction -- see load_brown_data), plus the tag set.
train, test = load_brown_data(BROWN_PATH, split=0.8)
tags = load_tags(TAGS_PATH)

In [None]:
# Sanity-check the sizes of the two splits.
print(f"There are {len(train)} sentences in the training set.")
print(f"There are {len(test)} sentences in the testing set.")

In [None]:
# Only the last N_TRAIN_SAMPLES training sentences are used as the Baum-Welch
# training set below.
N_TRAIN_SAMPLES = 1000
train_sample = train[-N_TRAIN_SAMPLES:]

In [None]:
# The corpus's `` and '' quote tokens are both mapped to a plain '"'.
QUOTE_TOKENS = {"``", "''"}


def normalize_word(word):
    """Return '"' for the `` / '' quote tokens and any other word unchanged."""
    return '"' if word in QUOTE_TOKENS else word


# One list of (normalized) word strings per training sentence.
train_sentences = [
    [normalize_word(token.get_word()) for token in sentence]
    for sentence in train_sample
]

# Unique vocabulary in first-seen order. Appending every token occurrence (as before)
# repeated each word many times, inflating BaumWelch's vocab_size / emission table
# with duplicate columns.
words = list(dict.fromkeys(word for sentence in train_sentences for word in sentence))

# Baum-Welch Algorithm

Note that the `train_em` and `baum_welch` methods were written using NLTK's open-source HMM code as a reference; it can be found at https://www.nltk.org/api/nltk.tag.hmm.html

In [None]:
class BaumWelch:
    """Unsupervised HMM tagger trained with Baum-Welch (forward-backward EM).

    Every probability table is stored as base-2 log-probabilities:
        transition_probs[i, j]  log2 P(tag j at t+1 | tag i at t), shape (num_tags, num_tags)
        emission_probs[i, k]    log2 P(vocab[k] | tag i),          shape (num_tags, vocab_size)
        initial_probs[i]        log2 P(tag i at t = 0),            shape (num_tags,)
    """

    # log2-probability substituted for log2(0), so forward/backward values and
    # re-estimated tables never contain -inf.
    LOG_FLOOR = np.log2(1e-10)
    # log2 P(word | tag) assumed for words that are not in the vocabulary.
    LOG_OOV = np.log2(1e-6)

    def __init__(self, tags, words, seed=None):
        """Build the vocabulary and randomly initialize all probability tables.

        tags: the hidden states (tag labels).
        words: observed tokens. Duplicates are collapsed (first-seen order is kept),
            so a raw token stream is accepted.
        seed: optional seed for the random initialization; None draws from numpy's
            global random state.
        """
        self.states = tags
        self.vocab = list(dict.fromkeys(words))
        self.vocab_lookup = {word: i for i, word in enumerate(self.vocab)}
        self.num_tags = len(self.states)
        self.vocab_size = len(self.vocab)

        rng = np.random if seed is None else np.random.default_rng(seed)
        # Random rather than uniform start: with identical rows every tag is
        # interchangeable and EM cannot break the symmetry. Use
        # initialize_probabilities() to supply other starting values.
        transition_probs = rng.random((self.num_tags, self.num_tags))
        emission_probs = rng.random((self.num_tags, self.vocab_size))
        initial_probs = rng.random(self.num_tags)

        # normalize each distribution so it sums to 1
        transition_probs /= transition_probs.sum(axis=1, keepdims=True)
        emission_probs /= emission_probs.sum(axis=1, keepdims=True)
        initial_probs /= initial_probs.sum()

        self.transition_probs = np.log2(transition_probs)
        self.emission_probs = np.log2(emission_probs)
        self.initial_probs = np.log2(initial_probs)

    def save_hmm(self, filename="../results/hmm_tagger-BW.pkl"):
        """Pickle the model to `filename`, with tables converted back to plain probabilities."""
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump({
                "transition_probs": np.exp2(self.transition_probs),
                "emission_probs": np.exp2(self.emission_probs),
                "initial_probs": np.exp2(self.initial_probs),
                "vocab": self.vocab,
                "states": self.states
            }, f)

    def initialize_probabilities(self, transition, emission, initial, log=False):
        """Replace the transition, emission, and initial probability tables.

        transition: (num_tags, num_tags) array; row i is the distribution over next tags.
        emission: (num_tags, vocab_size) array, columns ordered like self.vocab.
        initial: (num_tags,) array.
        log: True if the arrays are already base-2 log-probabilities, False if they
            are plain probabilities.
        """
        # Copy into float arrays so later training never mutates the caller's arrays.
        transition = np.array(transition, dtype=float)
        emission = np.array(emission, dtype=float)
        initial = np.array(initial, dtype=float)
        if not log:
            transition = np.log2(transition)
            emission = np.log2(emission)
            initial = np.log2(initial)
        self.transition_probs = transition
        self.emission_probs = emission
        self.initial_probs = initial

    def logsumexp2(self, arr, axis=None):
        """Numerically stable log2(sum(2 ** arr)), over all elements or along `axis`.

        Returns a scalar when `axis` is None, otherwise an array with `axis` removed.
        A slice that is entirely -inf yields -inf (rather than nan).
        """
        arr = np.asarray(arr, dtype=float)
        max_ = np.max(arr, axis=axis, keepdims=True)
        # Shift by 0 where the max is not finite: (-inf) - (-inf) would be nan.
        max_ = np.where(np.isfinite(max_), max_, 0.0)
        with np.errstate(divide="ignore"):  # log2(0) -> -inf is the intended result
            out = np.log2(np.sum(np.exp2(arr - max_), axis=axis, keepdims=True)) + max_
        if axis is None:
            return out.reshape(())[()]
        return np.squeeze(out, axis=axis)

    def _log_emissions(self, word):
        """log2 P(word | tag) for every tag; LOG_OOV for words outside the vocabulary."""
        idx = self.vocab_lookup.get(word)
        if idx is None:
            return np.full(self.num_tags, self.LOG_OOV)
        return self.emission_probs[:, idx]

    def _normalize_log_counts(self, numer, denom):
        """Turn log2 expected counts into log2 distributions along the last axis.

        Entries that are -inf (never observed) or nan (0/0 for an unvisited state) are
        replaced by LOG_FLOOR before renormalizing, so no probability is exactly zero.
        """
        with np.errstate(invalid="ignore"):
            ratio = numer - denom
        ratio = np.where(np.isnan(ratio) | np.isneginf(ratio), self.LOG_FLOOR, ratio)
        return ratio - np.expand_dims(self.logsumexp2(ratio, axis=-1), -1)

    def train_em(self, sequences, max_iterations=100, epsilon=0.001):
        """Fit the model to unlabeled `sequences` with Expectation-Maximization.

        sequences: iterable of token sequences; empty sequences are ignored.
        max_iterations: upper bound on the number of EM iterations.
        epsilon: stop once the total log2-likelihood changes by less than this
            between consecutive iterations.
        Returns self. Prints the log2-likelihood of every iteration.
        """
        # Materialize once: a generator would otherwise be exhausted after iteration 0.
        sequences = [sequence for sequence in sequences if len(sequence) > 0]
        if not sequences:
            return self

        prev_log_likelihood = None
        for iteration in range(max_iterations):
            # E-step accumulators, in log2 space (so "zero" is -inf)
            acc_transitions_num = np.full((self.num_tags, self.num_tags), -np.inf)
            acc_emissions_num = np.full((self.num_tags, self.vocab_size), -np.inf)
            acc_initial_num = np.full(self.num_tags, -np.inf)
            acc_transition_denom = np.full(self.num_tags, -np.inf)
            acc_emission_denom = np.full(self.num_tags, -np.inf)
            acc_initial_denom = -np.inf

            log_likelihood = 0.0
            for sequence in sequences:
                (log_prob_sequence,
                 seq_transitions_num,
                 seq_emissions_num,
                 seq_transition_denom,
                 seq_emission_denom,
                 seq_initial_num,
                 seq_initial_denom) = self.baum_welch(sequence)

                # Subtracting log2 P(sequence) turns joint expected counts into
                # posterior expected counts before they are summed over sequences.
                acc_transitions_num = np.logaddexp2(acc_transitions_num, seq_transitions_num - log_prob_sequence)
                acc_emissions_num = np.logaddexp2(acc_emissions_num, seq_emissions_num - log_prob_sequence)
                acc_transition_denom = np.logaddexp2(acc_transition_denom, seq_transition_denom - log_prob_sequence)
                acc_emission_denom = np.logaddexp2(acc_emission_denom, seq_emission_denom - log_prob_sequence)
                acc_initial_num = np.logaddexp2(acc_initial_num, seq_initial_num - log_prob_sequence)
                acc_initial_denom = np.logaddexp2(acc_initial_denom, seq_initial_denom - log_prob_sequence)

                log_likelihood += log_prob_sequence

            # M-step: expected counts -> normalized log2 distributions
            self.transition_probs = self._normalize_log_counts(acc_transitions_num, acc_transition_denom[:, None])
            self.emission_probs = self._normalize_log_counts(acc_emissions_num, acc_emission_denom[:, None])
            self.initial_probs = self._normalize_log_counts(acc_initial_num, acc_initial_denom)

            print("iteration", iteration, "logprob", log_likelihood)
            if prev_log_likelihood is not None and abs(log_likelihood - prev_log_likelihood) < epsilon:
                break
            prev_log_likelihood = log_likelihood
        return self

    def baum_welch(self, sequence):
        """One forward-backward pass; returns this sequence's unnormalized expected counts.

        With gamma_t(i) = alpha_t(i) + beta_t(i) (log2 space), returns, all in log2 and
        NOT yet divided by P(sequence):
            log_prob_sequence      log2 P(sequence)
            acc_transition_num     (n, n): sum_t alpha_t(i) a_ij b_j(o_{t+1}) beta_{t+1}(j)
            acc_emission_num       (n, V): sum of gamma_t(i) over t with o_t = vocab[k]
            acc_transition_denom   (n,):   sum over t < T-1 of gamma_t(i)
            acc_emission_denom     (n,):   sum over all t of gamma_t(i)
            acc_initial_num        (n,):   gamma_0(i)
            acc_initial_denom      scalar: log2 sum_i gamma_0(i)
        Out-of-vocabulary words contribute to the denominators only.
        """
        seq_len = len(sequence)
        alpha = self.forward(sequence)
        beta = self.backward(sequence)
        gamma = alpha + beta  # (seq_len, num_tags)

        log_prob_sequence = self.logsumexp2(alpha[-1])

        acc_transition_num = np.full((self.num_tags, self.num_tags), -np.inf)
        for t in range(seq_len - 1):
            # xi_t(i, j): rows index the source tag i, columns the destination tag j,
            # so the next word's emission and beta_{t+1} broadcast along columns.
            next_terms = self._log_emissions(sequence[t + 1]) + beta[t + 1]
            xi_t = alpha[t][:, None] + self.transition_probs + next_terms[None, :]
            acc_transition_num = np.logaddexp2(acc_transition_num, xi_t)

        acc_emission_num = np.full((self.num_tags, self.vocab_size), -np.inf)
        for t, word in enumerate(sequence):
            xi = self.vocab_lookup.get(word)
            if xi is not None:  # OOV words have no emission column to update
                acc_emission_num[:, xi] = np.logaddexp2(acc_emission_num[:, xi], gamma[t])

        if seq_len > 1:
            acc_transition_denom = self.logsumexp2(gamma[:-1], axis=0)
        else:
            acc_transition_denom = np.full(self.num_tags, -np.inf)
        acc_emission_denom = self.logsumexp2(gamma, axis=0)
        acc_initial_num = gamma[0].copy()
        acc_initial_denom = self.logsumexp2(gamma[0])

        return (log_prob_sequence,
                acc_transition_num,
                acc_emission_num,
                acc_transition_denom,
                acc_emission_denom,
                acc_initial_num,
                acc_initial_denom)

    def forward(self, sequence):
        """Forward log2-probabilities: alpha[t, j] = log2 P(o_0..o_t, tag_t = j).

        Returns an array of shape (len(sequence), num_tags); -inf entries are
        replaced by LOG_FLOOR.
        """
        alpha = np.full((len(sequence), self.num_tags), -np.inf)
        alpha[0] = self.initial_probs + self._log_emissions(sequence[0])
        for t in range(1, len(sequence)):
            # alpha_t(j) = log2 sum_i 2^(alpha_{t-1}(i) + a_ij)  +  b_j(o_t)
            alpha[t] = (self.logsumexp2(alpha[t - 1][:, None] + self.transition_probs, axis=0)
                        + self._log_emissions(sequence[t]))
        alpha[np.isneginf(alpha)] = self.LOG_FLOOR
        return alpha

    def backward(self, sequence):
        """Backward log2-probabilities: beta[t, i] = log2 P(o_{t+1}..o_{T-1} | tag_t = i).

        Returns an array of shape (len(sequence), num_tags); -inf entries are
        replaced by LOG_FLOOR.
        """
        # beta_{T-1}(i) = log2 1 = 0: nothing remains to be observed after the last word.
        beta = np.zeros((len(sequence), self.num_tags))
        for t in range(len(sequence) - 2, -1, -1):
            # beta_t(i) = log2 sum_j 2^(a_ij + b_j(o_{t+1}) + beta_{t+1}(j))
            next_terms = self._log_emissions(sequence[t + 1]) + beta[t + 1]
            beta[t] = self.logsumexp2(self.transition_probs + next_terms[None, :], axis=1)
        beta[np.isneginf(beta)] = self.LOG_FLOOR
        return beta

In [None]:
# Untrained tagger over the corpus tag set and the training vocabulary.
bw_tagger = BaumWelch(tags=tags, words=words)

In [None]:
# Train in consecutive batches of BATCH_SIZE sentences. range() covers every batch,
# including the final (possibly partial) one -- the previous while-loop set `done`
# as soon as batch_end reached the end and so never trained the last batch.
#
# NOTE(review): each train_em call re-estimates all tables from the expected counts
# of its own batch only (starting from the previous batch's parameters), so later
# batches overwrite what earlier ones learned except through that starting point.
# Passing all of train_sentences in a single call would use every sentence's statistics.
BATCH_SIZE = 50
for batch_start in range(0, len(train_sentences), BATCH_SIZE):
    batch = train_sentences[batch_start:batch_start + BATCH_SIZE]
    bw_tagger.train_em(batch, max_iterations=10)

In [None]:
# Write the trained model to the default results pickle.
bw_tagger.save_hmm(filename="../results/hmm_tagger-BW.pkl")