In [98]:
from collections import defaultdict
import numpy as np
import sys

In [141]:
# Sample HMM description. Sections are separated by dashed lines; the matrix
# rows are tab-separated, so tabs are spelled out as "\t" to keep them visible
# and safe from editors that convert tabs to spaces.
input_lines = [
    # observed sequence
    "xyxzzxyxyy",
    "--------",
    # emission alphabet
    "x y z",
    "--------",
    # hidden states
    "A B",
    "--------",
    # transition matrix (header row, then one row per state)
    "\tA\tB",
    "A\t0.641\t0.359",
    "B\t0.729\t0.271",
    "--------",
    # emission matrix (header row, then one row per state)
    "\tx\ty\tz",
    "A\t0.117\t0.691\t0.192\t",
    "B\t0.097\t0.42\t0.483",
]

In [142]:
def _parse_matrix(rows, columns):
    """Turn labelled probability rows into a nested ``{label: {column: float}}`` dict."""
    matrix = {}
    for row in rows:
        fields = row.split()
        # Blank lines and label-only lines (e.g. a trailing separator) carry
        # no probabilities, so they must not create an entry.
        if len(fields) < 2:
            continue
        label, values = fields[0], fields[1:]
        matrix.setdefault(label, {}).update(zip(columns, map(float, values)))
    return matrix


def parse_input(lines):
    """Parse an HMM description given as a list of text lines.

    Expected layout (sections separated by a dashed line):

    * line 0: the observed sequence
    * line 2: the emission symbols
    * line 4: the hidden states
    * line 6: transition-matrix header, followed by one row per state
    * then a separator, the emission-matrix header, and one row per state

    Fields may be separated by any run of spaces and/or tabs, so state and
    symbol names must not contain whitespace themselves.

    Args:
        lines: list of strings, with or without trailing newlines.

    Returns:
        A tuple ``(sequence, symbols, states, state_transitions,
        emission_matrix)`` where the two matrices are plain nested dicts:
        ``state_transitions[from_state][to_state]`` and
        ``emission_matrix[state][symbol]``.

    Raises:
        IndexError: if ``lines`` is too short to contain the header sections.
        ValueError: if a matrix cell is not a valid float.
    """
    # Strip so a trailing newline never ends up as a (missing) emission key.
    sequence = lines[0].strip()
    # split() without an argument tolerates tabs and repeated spaces, where
    # split(" ") would produce empty-string entries.
    symbols = lines[2].split()
    states = lines[4].split()

    first_transition_row = 7
    # transition rows + separator + emission header precede the emission rows
    first_emission_row = 9 + len(states)

    state_transitions = _parse_matrix(
        lines[first_transition_row:first_transition_row + len(states)], states
    )
    emission_matrix = _parse_matrix(lines[first_emission_row:], symbols)
    return sequence, symbols, states, state_transitions, emission_matrix

In [143]:
# Unpack the sample description into the observed sequence, the symbol and
# state lists, and the nested transition / emission probability tables.
sequence, symbols, states, state_transitions, emission_matrix = parse_input(input_lines)

In [144]:
def viterbi(sequence, state_transitions, emission_matrix, state_probabilities, backtrack):
    """Run the Viterbi recurrence in log space over ``sequence``.

    The scores for the first observed symbol are expected to be supplied by
    the caller in ``state_probabilities``; ``sequence`` holds the remaining
    symbols.

    Args:
        sequence: iterable of the remaining observed symbols.
        state_transitions: ``state_transitions[from_state][to_state]`` gives
            the transition probability.
        emission_matrix: ``emission_matrix[state][symbol]`` gives the
            emission probability.
        state_probabilities: dict mapping each state to its log score after
            the first symbol.
        backtrack: list that is extended in place with one dict per consumed
            symbol, mapping each state to its best predecessor.

    Returns:
        The state with the highest final log score (the first such state in
        dict order on ties).

    Raises:
        ValueError: if there are no states to choose from at the end.
    """
    # Iterating instead of recursing keeps the call depth constant, so long
    # sequences cannot hit the interpreter's recursion limit.
    # A zero probability legitimately maps to -inf; silence NumPy's
    # divide-by-zero warning for that case only.
    with np.errstate(divide="ignore"):
        for symbol in sequence:
            new_state_probabilities = {}
            pointers = {}
            for to_state in state_transitions:
                emission = emission_matrix[to_state][symbol]
                best_prob = -np.inf
                best_state = None
                for from_state, prob in state_probabilities.items():
                    p = prob + np.log(state_transitions[from_state][to_state] * emission)
                    # Accept the first candidate unconditionally so that a
                    # predecessor is recorded even when every score is -inf.
                    if best_state is None or p > best_prob:
                        best_prob = p
                        best_state = from_state
                new_state_probabilities[to_state] = best_prob
                pointers[to_state] = best_state
            backtrack.append(pointers)
            state_probabilities = new_state_probabilities

    # max() returns the first maximal key, matching a stable descending sort.
    return max(state_probabilities, key=state_probabilities.get)
    
def backtrack_path(backtrack, start_state):
    """Rebuild the state path by following back-pointers from the last state.

    Args:
        backtrack: list of dicts, one per step, each mapping a state to its
            best predecessor (oldest step first).
        start_state: the state the path ends in.

    Returns:
        The path as a single string in chronological order. It contains
        ``len(backtrack) + 1`` states; with an empty ``backtrack`` (a
        one-symbol sequence) it is just ``start_state``.
    """
    path = [start_state]
    # Walk the pointer tables from the last step to the first, each time
    # looking up the predecessor of the most recently recovered state.
    for pointers in reversed(backtrack):
        path.append(pointers[path[-1]])
    return "".join(reversed(path))

In [145]:
# Log score of every state for the first symbol: a uniform prior over the
# states times the emission probability of that symbol.
init_prob = {
    state: np.log(emission_matrix[state][sequence[0]] / len(state_transitions))
    for state in emission_matrix
}

# Decode the remaining symbols, then follow the back-pointers to get the path.
backtrack = []
final_state = viterbi(sequence[1:], state_transitions, emission_matrix, init_prob, backtrack)
print(backtrack_path(backtrack, final_state))

AAABBAAAAA


In [149]:
# Only the parsing needs the file handle, so close it before decoding.
with open("../data/dataset_26256_7.txt", "r") as fin:
    sequence, _, _, state_transitions, emission_matrix = parse_input([l.strip() for l in fin])

# viterbi() accumulates log terms, so the initial scores must be
# log-probabilities as well (uniform prior times first-symbol emission).
# Output recorded before this fix used linear probabilities here and may differ.
init_prob = dict()
for state in emission_matrix.keys():
    init_prob[state] = np.log(emission_matrix[state][sequence[0]] / len(state_transitions))

backtrack = []
final_state = viterbi(sequence[1:], state_transitions, emission_matrix, init_prob, backtrack)
print(backtrack_path(backtrack, final_state))

BCADDDDADDBCCADADDADADABCADDABCABCADDDDADADABCADADDDDDADABCCCABCABCADDADDDDBCADDDABCABCABCABCABCCADD


In [147]:
# Two-state HMM over the symbols "H" and "T".
# NOTE(review): presumably F = fair coin and B = biased coin — confirm.

# emission_matrix[state][symbol]
emission_matrix = {
    "F": {"H": 0.5, "T": 0.5},
    "B": {"H": 0.75, "T": 0.25},
}

# state_transitions[from_state][to_state]
state_transitions = {
    "F": {"F": 0.9, "B": 0.1},
    "B": {"F": 0.1, "B": 0.9},
}



In [148]:
sequence = "HHTT"

# viterbi() accumulates log terms, so the initial scores must be
# log-probabilities as well (uniform prior times first-symbol emission).
init_prob = dict()
for state in emission_matrix.keys():
    init_prob[state] = np.log(emission_matrix[state][sequence[0]] / len(state_transitions))

backtrack = []
final_state = viterbi(sequence[1:], state_transitions, emission_matrix, init_prob, backtrack)
print(backtrack_path(backtrack, final_state))

FFFF
