In [1]:
import re 
import pandas as pd 
from collections import defaultdict

In [2]:
def preprocess_treebank(treebank_str):
    """Flatten one raw treebank chunk into a single line of tokens.

    Line breaks are turned into spaces, and the ``=``, ``[`` and ``]``
    characters are removed. Leading and trailing whitespace is stripped.

    Args:
        treebank_str: Raw text of one sentence/chunk from a tagged file.

    Returns:
        The cleaned string. Tokens may be separated by more than one space.
    """
    # Replace newlines with a space rather than deleting them: deleting would
    # glue the last token of one line to the first token of the next line
    # whenever the line has no trailing blank ("join/VB\nthe/DT").
    cleaned = treebank_str.replace("\n", " ")
    for unwanted in ("=", "[", "]"):
        cleaned = cleaned.replace(unwanted, "")
    return cleaned.strip()

def parse_treebank(treebank_str):
    """Parse tagged treebank text into sentences of ``(word, tag)`` pairs.

    The text is split into sentences on blank lines (``"\\n\\n"``). Each
    sentence is cleaned with ``preprocess_treebank`` and split on whitespace;
    every token of the form ``word/TAG`` becomes a ``(word, tag)`` tuple.
    Tokens that do not match that form are ignored. Each kept sentence is
    wrapped in the boundary markers ``("<s>", "<s>")`` and ``("<e>", "<e>")``.

    Args:
        treebank_str: Full contents of a tagged treebank file.

    Returns:
        A list of sentences, each a list of ``(word, tag)`` tuples. Sentences
        that contain no ``word/TAG`` token are dropped.
    """
    token_pattern = re.compile(r"(\S+)/(\S+)")
    parsed_sents = []

    for raw_sent in treebank_str.split("\n\n"):
        flat_sent = preprocess_treebank(raw_sent)
        tagged_words = []

        # split() without an argument ignores runs of blanks, so no empty
        # tokens are produced.
        for token in flat_sent.split():
            match = token_pattern.fullmatch(token)
            if match is None:
                continue
            # The first group is greedy, so the token is divided at its LAST
            # slash. Words that themselves contain a slash (e.g. "1\/2/CD")
            # therefore keep the slash in the word and get the correct tag.
            word, tag = match.groups()
            tagged_words.append((word, tag))

        # Keep only sentences that contain at least one real token.
        if tagged_words:
            parsed_sents.append([("<s>", "<s>"), *tagged_words, ("<e>", "<e>")])

    return parsed_sents

def get_tags_count(documents):
    """Count single tags and adjacent tag pairs over all sentences.

    Args:
        documents: Iterable of sentences; each sentence is a sequence whose
            items carry the tag at index 1.

    Returns:
        A ``(tags, pair_tags)`` tuple of ``defaultdict(int)`` objects:
        ``tags[tag]`` is the number of occurrences of ``tag`` and
        ``pair_tags[(t1, t2)]`` is the number of times ``t2`` directly
        follows ``t1`` inside a sentence.
    """
    unigram_counts = defaultdict(int)
    bigram_counts = defaultdict(int)

    for sentence in documents:
        # Unigram pass first so that key insertion order matches tag order.
        for entry in sentence:
            unigram_counts[entry[1]] += 1

        # Pair every item with its right-hand neighbour.
        for left, right in zip(sentence, sentence[1:]):
            bigram_counts[(left[1], right[1])] += 1

    return unigram_counts, bigram_counts

def get_emit_count(documents):
    """Count how often each ``(word, tag)`` pair occurs.

    Args:
        documents: Iterable of sentences, each an iterable of two-item
            ``(word, tag)`` entries.

    Returns:
        A ``defaultdict(int)`` mapping ``(word, tag)`` to its frequency.
    """
    pair_frequency = defaultdict(int)

    for sentence in documents:
        for entry in sentence:
            token, label = entry
            key = (token, label)
            pair_frequency[key] = pair_frequency[key] + 1

    return pair_frequency

def get_vocabulary(documents):
    """Collect the distinct words that occur in the documents.

    Args:
        documents: Iterable of sentences; each sentence is an iterable of
            items carrying the word at index 0.

    Returns:
        A list of unique words in order of first appearance. The order is
        deterministic, so anything built from it (e.g. matrix columns) is
        the same on every run.
    """
    # A dict keeps insertion order; a set's iteration order for strings can
    # change between interpreter runs because of hash randomisation.
    first_seen = dict.fromkeys(
        parsed_word[0] for doc in documents for parsed_word in doc
    )
    return list(first_seen)

def get_transition_matrix(documents):
    """Estimate tag-to-tag transition probabilities from the documents.

    The entry at row ``left`` and column ``right`` is
    ``count(left followed by right) / count(left)``, using the counts from
    ``get_tags_count``. Pairs that never occur keep probability ``0.0``.

    Args:
        documents: Parsed sentences as lists of ``(word, tag)`` tuples.

    Returns:
        A float ``pandas.DataFrame`` indexed by tag on both axes.
    """
    tags, pair_tags = get_tags_count(documents)
    tag_names = list(tags)

    # Start from 0.0 so the frame is float-typed from the beginning; writing
    # floats into an integer frame relies on implicit upcasting.
    trans = pd.DataFrame(0.0, index=tag_names, columns=tag_names)

    # One write per distinct pair instead of one per token in the corpus;
    # the stored values are the same.
    for (left_tag, right_tag), pair_count in pair_tags.items():
        trans.at[left_tag, right_tag] = pair_count / tags[left_tag]

    return trans

def get_emission_matrix(documents):
    """Estimate word emission probabilities for every tag.

    The entry at row ``tag`` and column ``word`` is
    ``count(word tagged as tag) / count(tag)``. Combinations that never
    occur keep probability ``0.0``.

    Args:
        documents: Parsed sentences as lists of ``(word, tag)`` tuples.

    Returns:
        A float ``pandas.DataFrame`` with tags as the index and vocabulary
        words as the columns.
    """
    tags, _ = get_tags_count(documents)
    emit_count = get_emit_count(documents)
    vocab = get_vocabulary(documents)

    # Start from 0.0 so the frame is float-typed from the beginning; writing
    # floats into an integer frame relies on implicit upcasting.
    emit = pd.DataFrame(0.0, index=list(tags), columns=vocab)

    # One write per distinct (word, tag) pair instead of one per token in
    # the corpus; the stored values are the same.
    for (word, tag), pair_count in emit_count.items():
        emit.at[tag, word] = pair_count / tags[tag]

    return emit

def get_initial_state_matrix(documents):
    """Compute the distribution of the tag found at the start of each sentence.

    Args:
        documents: Parsed sentences; the tag of the first item of each
            sentence (``doc[0][1]``) is counted.

    Returns:
        A ``pandas.Series`` indexed by tag holding each tag's share of
        sentence-initial positions.
    """
    tag_counts, _ = get_tags_count(documents)
    start_counts = pd.Series(0, index=tag_counts)

    for sentence in documents:
        opening_tag = sentence[0][1]
        start_counts[opening_tag] = start_counts[opening_tag] + 1

    # Normalise the raw counts by their total.
    total_starts = start_counts.sum()
    return start_counts / total_starts


def parse_treebank_file(filepath):
    """Read a tagged treebank file and parse it with ``parse_treebank``.

    Args:
        filepath: Path of the file to read in text mode.

    Returns:
        Whatever ``parse_treebank`` returns for the file's contents.
    """
    with open(filepath, "r") as handle:
        raw_text = handle.read()
    return parse_treebank(raw_text)

def get_file_id(i):
    """Format an integer as a decimal string zero-padded to at least four characters."""
    return format(i, "04d")


def read_all_treebank_files(n_files, file_dir="./treebank/tagged"):
    """Parse files ``wsj_0001.pos`` through ``wsj_<n_files>.pos`` and merge them.

    Args:
        n_files: Number of files to read; file numbers run from 1 to
            ``n_files`` inclusive.
        file_dir: Directory that contains the ``wsj_XXXX.pos`` files.

    Returns:
        One flat list with the parsed sentences of all files, in file order.
    """
    sentences = []
    for number in range(1, n_files + 1):
        file_id = get_file_id(number)
        sentences += parse_treebank_file(f"{file_dir}/wsj_{file_id}.pos")
    return sentences

In [75]:
class HiddenMarkovModel:
    """Hidden Markov model tagger estimated from tagged treebank files.

    The model parameters are the transition, emission and initial-state
    tables produced by ``get_transition_matrix``, ``get_emission_matrix``
    and ``get_initial_state_matrix``.
    """

    def __init__(self, treebank_dir, n_files):
        """Read the treebank files and estimate the model tables.

        Args:
            treebank_dir: Directory passed to ``read_all_treebank_files``.
            n_files: Number of files to read from that directory.
        """
        self.all_docs = read_all_treebank_files(n_files, treebank_dir)

        self.trans = get_transition_matrix(self.all_docs)
        self.emit = get_emission_matrix(self.all_docs)
        self.pi = get_initial_state_matrix(self.all_docs)

        # States are the emission table's rows, the vocabulary its columns.
        self.states = list(self.emit.index)
        self.vocab = list(self.emit.columns)

        self.n_states = len(self.states)
        self.n_vocab = len(self.vocab)

    def preprocess_obs(self, obs):
        """Wrap an observation string in the ``<s>`` / ``<e>`` boundary markers."""
        return "<s> " + obs + " <e>"

    def _emission_prob(self, state, word):
        """Return the emission probability of ``word`` in ``state``.

        A word that is not a column of the emission table gets the constant
        factor 1.0 for every state, so it does not favour any state and the
        decision at that position is left to the transition probabilities.
        """
        if word in self.emit.columns:
            return float(self.emit.at[state, word])
        return 1.0

    def viterbi(self, obs):
        """Return the most likely state sequence for an observation string.

        Args:
            obs: Whitespace-separated words, without boundary markers.

        Returns:
            A list of states with one entry per observed token, including
            the states chosen for the added ``<s>`` and ``<e>`` markers.
            Returns an empty list when the model has no states.
        """
        if not self.states:
            return []

        obs_list = self.preprocess_obs(obs).split()
        n_obs = len(obs_list)

        # Scores and back-pointers are stored per POSITION, not per word, so
        # a sentence containing the same word twice is handled correctly.
        scores = [
            {
                state: float(self.pi.at[state]) * self._emission_prob(state, obs_list[0])
                for state in self.states
            }
        ]
        backptr = [{}]

        for t in range(1, n_obs):
            word = obs_list[t]
            prev_scores = scores[t - 1]
            cur_scores = {}
            cur_back = {}

            for cur_state in self.states:
                # Probabilities are non-negative, so starting below zero
                # guarantees a real predecessor is always recorded, even if
                # every path into cur_state has probability zero.
                best_prev = None
                best_score = -1.0
                for prev_state in self.states:
                    score = prev_scores[prev_state] * float(
                        self.trans.at[prev_state, cur_state]
                    )
                    # Strict comparison keeps the first state on ties.
                    if score > best_score:
                        best_score = score
                        best_prev = prev_state

                cur_scores[cur_state] = best_score * self._emission_prob(cur_state, word)
                cur_back[cur_state] = best_prev

            scores.append(cur_scores)
            backptr.append(cur_back)

        # Pick the best final state (first one on ties) and walk backwards.
        final_scores = scores[-1]
        state = max(self.states, key=final_scores.__getitem__)
        reversed_path = [state]
        for t in range(n_obs - 1, 0, -1):
            state = backptr[t][state]
            reversed_path.append(state)

        reversed_path.reverse()
        return reversed_path


In [79]:
# Smoke test: build the model from the first tagged treebank file only.
treebank_dir = "./treebank/tagged"
n_files = 1
hmm = HiddenMarkovModel(treebank_dir=treebank_dir, n_files=n_files)

In [None]:
# Decode a sample phrase with the model built in the previous cell; the bare
# final expression makes the notebook display the returned state path.
obs = "the chairman publishing group"
hmm.viterbi(obs)