## Emissions
This is the fixed emission code from part 2.

In [1]:
def train_emission(filename):
    """
    Returns - a dictionary containing emission parameters

    The file is read line by line; only lines that split (on whitespace)
    into exactly two tokens, "<observation> <state>", are counted. Every
    other line (blank separators, malformed rows) is skipped.

    The result maps each state y to a dict of count(y -> x) for every
    observation x seen anywhere in the file, including a zero entry for
    observations never emitted by y:
    eg: {state1: {obs1: 1, obs2: 5}, state2: {obs1: 4, obs2: 0}}
    """
    emission_dict = {}
    # every distinct observation in the file, needed for the zero-fill pass
    observations = set()

    with open(filename, encoding="utf8") as f:
        for raw_line in f:
            tokens = raw_line.split()

            # process only valid lines
            if len(tokens) != 2:
                continue

            word, tag = tokens
            observations.add(word)

            per_tag = emission_dict.setdefault(tag, {})
            per_tag[word] = per_tag.get(word, 0) + 1

    # give every state an explicit count of 0 for observations it never emitted
    for per_tag in emission_dict.values():
        for word in observations:
            per_tag.setdefault(word, 0)

    return emission_dict

In [2]:
def get_emission_params_fixed(emission_dict, state, obs, k=0.5):
    """
    Returns - the smoothed emission estimate for state -> obs

    emission_dict - per-state observation counts, {state: {obs: count}}
    state         - the state y; must be a key of emission_dict
    obs           - the observation x; the literal "#UNK#" stands for an
                    unknown word and uses k as its count
    k             - smoothing constant added to count(y)

    Computes count(y -> x) / (count(y) + k), with count(y -> x) = k when
    obs is "#UNK#".

    Raises Exception if state is not in emission_dict, and KeyError if obs
    is neither "#UNK#" nor a key of emission_dict[state].
    """
    if state in emission_dict:
        counts = emission_dict[state]
    else:
        raise Exception("State not in emission dict")

    total_for_state = sum(counts.values())  # count(y)

    # "#UNK#" is looked up nowhere: it is assigned the pseudo-count k
    numerator = k if obs == "#UNK#" else counts[obs]  # count(y -> x)

    return numerator / (total_for_state + k)

## Transitions 

In [3]:
def train_transition(filename):
    """
    Returns - a dictionary containing transition parameters

    The file is read line by line. A line that splits (on whitespace) into
    exactly two tokens, "<observation> <state>", extends the current
    sentence; a line with fewer than two tokens ends it. Lines with more
    than two tokens are ignored.

    The result maps each state u (including 'START') to a dict of
    count(u, v) for every state v seen in the file plus 'STOP', with an
    explicit 0 for transitions that never occurred:
    eg: {START: {y1: 1, y2: 5, STOP: 0},
         y1: {y1: 3, y2: 4, STOP: 1},
         y2: {y1: 1, y2: 0, STOP: 3}}

    Sentence boundaries are handled so that:
    - repeated or leading blank lines do not count a START -> STOP
      transition (there is no sentence to close);
    - a final sentence that is not followed by a blank line still gets its
      last-state -> STOP transition.
    """
    transition_dict = {}

    # every state that can be a transition target; STOP is always one of them
    states = {'STOP'}

    def count_transition(u, v):
        # rows are created only when a transition is actually counted, so
        # no state ends up with an all-zero row
        row = transition_dict.setdefault(u, {})
        row[v] = row.get(v, 0) + 1

    prev_state = 'START'

    with open(filename, encoding="utf8") as f:
        for line in f:
            split_line = line.split()

            # sentence boundary
            if len(split_line) < 2:
                # only close a sentence that was actually opened
                if prev_state != 'START':
                    count_transition(prev_state, 'STOP')
                    prev_state = 'START'

            # processing the sentence
            elif len(split_line) == 2:
                curr_state = split_line[1]
                states.add(curr_state)
                count_transition(prev_state, curr_state)
                prev_state = curr_state

    # the file may end without a trailing blank line: close the open sentence
    if prev_state != 'START':
        count_transition(prev_state, 'STOP')

    # give every row an explicit count of 0 for transitions never observed
    for row in transition_dict.values():
        for state in states:
            row.setdefault(state, 0)

    return transition_dict

In [4]:
def get_transition_params(transition_dict, u, v):
    """
    Returns - the transition estimate count(u, v) / count(u)

    transition_dict - per-state transition counts, {u: {v: count}}
    u               - the source state; must be a key of transition_dict
    v               - the target state; must be a key of transition_dict[u]

    count(u) is the sum of all counts stored for u. If that sum is 0 (u has
    no recorded outgoing transitions) the estimate is 0.0 rather than a
    ZeroDivisionError.

    Raises Exception if u is not in transition_dict or v is not in
    transition_dict[u].
    """
    if u not in transition_dict:
        raise Exception("State u not in transition dict")

    state_data = transition_dict[u]

    if v not in state_data:
        raise Exception("State v not in transition dict")

    count_u_to_v = state_data[v]  # count(u,v)
    count_u = sum(state_data.values())  # count(u)

    # a zero-filled row has no outgoing mass: every transition has probability 0
    if count_u == 0:
        return 0.0

    return count_u_to_v / count_u

### Testing with example

In [5]:
# Build the transition counts from the EN training file (path is relative to
# the notebook's working directory).
transition_dict = train_transition('../dataset/EN/train')

#transition_dict  (uncomment to inspect the full count table)
# Estimate count(START, B-PP) / count(START); as the last expression of the
# cell, its value is what the notebook displays below.
get_transition_params(transition_dict, 'START', 'B-PP')

0.1087041628604985

## Training

In [6]:
def train(filename):
    """
    Returns - emission and transition parameters

    Runs train_emission and then train_transition on the same file and
    returns their results as the tuple (emission_dict, transition_dict).
    """
    emission_dict = train_emission(filename)
    transition_dict = train_transition(filename)
    return emission_dict, transition_dict