In [1]:
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm_notebook as tqdm, tnrange as trange

## Emissions (with smoothing)

In [2]:
def learn_emissions(train_filename):
    ''' Learns emissions parameters from data and returns them as a nested dictionary '''
    with open(train_filename, "r") as f:
        lines = f.readlines()

#     # Keep set of all unique states and observations
#     states = set()
    observations = set()

    # Track emission counts
    emissions = {} # Where key is y, and value is a dictionary of emissions x from y with their frequency

    # Learn from data
    for line in tqdm(lines, desc='Emissions'):
        data_split = line.strip().split(" ")

        # Only process valid lines
        if len(data_split) == 2:
            obs, state = data_split[0], data_split[1]

#             states.add(state)
            observations.add(obs)

            # Track this emission
            current_emissions = {}
            if state in emissions:
                current_emissions = emissions[state]

            # If it exists, increment it, if not set it to 1
            if obs in current_emissions:
                current_emissions[obs] += 1
            else:
                current_emissions[obs] = 1

            emissions[state] = current_emissions # Update
    
    emission_counts = {k: sum(emissions[k].values()) for k in emissions}
    
    return emissions, emission_counts, observations


def get_emission_parameters(emissions, emission_counts, x, y, k=1):
    ''' Returns the MLE of the emission parameters based on the emissions dictionary '''
    if y not in emissions:  # edge case: no records of emission from this state
        return 0

    state_data = emissions[y]
    count_y = emission_counts[y] #sum(state_data.values()) # Denominator
    
    # If x == "#UNK#", it will return the following
    count_y_x = k
    
    # If x exists in training, return its MLE instead
    if x != "#UNK#":
        if x not in state_data:  # edge case: no record of this emission from this state
            return 0
        count_y_x = state_data[x] # Numerator
    
    e = count_y_x / (count_y + k)
    return e

## Transitions

In [3]:
def learn_transitions(train_filename):
    """
    Returns a dictionary containing (key, value) where
        key: (u, v)
        value: Count(u, v)
    """
    with open(train_filename, 'r') as f:
        lines = f.readlines()
        
    transitions = defaultdict(int)
    prev_state = 'START'
    # avoid excessive indentations
    for line in tqdm(lines, desc='Transitions'):
        data_split = line.strip().rsplit(' ', 1)
        
        # line breaks -> new sequence
        if len(data_split) < 2:
            transitions[(prev_state, 'STOP')] += 1
            prev_state = 'START'
            continue

        obs, curr_state = data_split
        transitions[(prev_state, curr_state)] += 1
        prev_state = curr_state
        
    # count number of 'from' states
    transition_counts = defaultdict(int)
    for (u, v), counts in transitions.items():
        transition_counts[u] += counts

    # get all unique states
    u, v = zip(*transitions)
    states = set(u) | set(v)
    return transitions, transition_counts, states

def get_transition_parameters(transitions, transition_counts, u, v):
    if transition_counts[u] == 0:  # edge case: no records of transitions starting from u
        return 0
    return transitions[(u, v)] / transition_counts[u]

## Training

In [4]:
datasets = ['SG', 'CN', 'EN', 'FR']
dataset = datasets[0]
train_filename = f'data/{dataset}/train'
validation_filename = f'data/{dataset}/dev.in'

# Train
emissions, emission_counts, observations = learn_emissions(train_filename)
transitions, transition_counts, states = learn_transitions(train_filename)

HBox(children=(IntProgress(value=0, description='Emissions', max=311777), HTML(value='')))




HBox(children=(IntProgress(value=0, description='Transitions', max=311777), HTML(value='')))




## Viterbi

In [5]:
def viterbi(transitions, transition_counts, states, emissions, emission_counts, obs_seq):
    a = lambda prev, curr: get_transition_parameters(transitions, transition_counts, prev, curr)
    b = lambda state, out: get_emission_parameters(emissions, emission_counts, x=out, y=state)

    # create empty tables
    n = len(obs_seq) + 2  # START + |obs_seq| + STOP
    P = pd.DataFrame(index=states, columns=range(n)).fillna(0)  # probability table
    B = pd.DataFrame(index=states, columns=range(n))  # backpointer table
    
    # initialization
    P.loc['START', 0] = 1
    
    # recursion
    for j in range(1, n-1):
        x = obs_seq[j-1]  # obs_seq starts from 0, j starts from 1
        for v in states:  # curr state
            for u in states:  # prev state
                p = P.loc[u, j-1] * a(u, v) * b(v, x)
                if p > P.loc[v, j]:
                    P.loc[v, j] = p  # update probability
                    B.loc[v, j] = u  # update backpointer
                    
    # termination
    j = n - 1
    for u in states:
        p = P.loc[u, j-1] * a(u, 'STOP')
        if p > P.loc[v, j]:
            P.loc[v, j] = p  # probability
            B.loc[v, j] = u  # backpointer
            
    # backtrace
    state_seq = ['STOP']
    for i in range(n-1, 0, -1):
        curr_state = state_seq[-1]
        prev_state = B.loc[curr_state, i]
        if pd.isnull(prev_state):  # edge case: no possible transition to STOP
            state_seq = []
            break
        state_seq.append(prev_state)
    state_seq = state_seq[::-1]
    return P, B, state_seq

P, B, state_seq = viterbi(transitions, transition_counts, states, emissions, emission_counts, ['Omg', "I'm"])
P

Unnamed: 0,0,1,2,3
B-neutral,0,0.0,6.591126e-09,0.0
START,1,0.0,0.0,0.0
B-positive,0,0.0,4.508593e-10,0.0
O,0,6.4e-05,6.297116e-07,0.0
B-negative,0,0.0,4.495536e-10,0.0
STOP,0,0.0,0.0,0.0
I-neutral,0,0.0,0.0,0.0
I-negative,0,0.0,0.0,0.0
I-positive,0,0.0,0.0,4.64654e-08


In [6]:
B

Unnamed: 0,0,1,2,3
B-neutral,,,O,
START,,,,
B-positive,,,O,
O,,START,O,
B-negative,,,O,
STOP,,,,
I-neutral,,,,
I-negative,,,,
I-positive,,,,O


In [7]:
state_seq

[]