In [1]:
import os
import numpy as np
from collections import defaultdict
import math

from conlleval2 import evaluate

In [2]:
def estimate_emission_parameter(file_path):
         
    emissions = {}
    y_count = {}
    resulting_dict = {}
    
    train_data = open(file_path, 'r', encoding="utf-8")
    lines = train_data.readlines()
    
    all_x = set()
    all_y = set()

    for line in lines:
        line = line.strip()

        if len(line) > 0:
            x, pos, y = line.split(" ")
            all_x.add(x)
            all_y.add(y)
            y_count[y] = y_count.get(y,0) + 1
            emissions[(x, y)]  = emissions.get((x,y),0) + 1
            emissions[(pos, y)] = emissions.get((pos,y),0) + 1

    for x, y in emissions.keys():
        key = "emission:" + y + "+" + x
        resulting_dict[key] = math.log(emissions[(x, y)] / y_count[y])
        emissions[(x, y)] = math.log(emissions[(x, y)] / y_count[y])
    
    return resulting_dict, list(all_y), emissions

train_path = "full/train"
resulting_dict ,states, emissions = estimate_emission_parameter(train_path)

In [3]:
emissions

{('The', 'O'): -4.333515843240958,
 ('DT', 'O'): -2.2045841023613524,
 ('official', 'O'): -7.166729187297174,
 ('JJ', 'O'): -2.617071711239342,
 ('cause', 'O'): -8.776167099731275,
 ('NN', 'O'): -1.8272698774179623,
 ('of', 'O'): -3.518671727703493,
 ('IN', 'O'): -1.9894501491261936,
 ('death', 'O'): -7.166729187297174,
 ('has', 'O'): -4.649032714686183,
 ('VBZ', 'O'): -3.5639524322366496,
 ('not', 'O'): -5.97280671882474,
 ('RB', 'O'): -3.8525431826246486,
 ('been', 'O'): -5.755742213586912,
 ('VBN', 'O'): -3.3294297280649645,
 ('officially', 'O'): -8.370701991623111,
 ('determined', 'O'): -8.370701991623111,
 (',', 'O'): -2.6427690567346263,
 ('but', 'O'): -6.211217742269738,
 ('CC', 'O'): -3.7199212943829667,
 ('investigators', 'O'): -9.46931428029122,
 ('NNS', 'O'): -2.5233002891919925,
 ('believe', 'O'): -9.46931428029122,
 ('VBP', 'O'): -4.122206749573751,
 ('the', 'O'): -2.8532490951584024,
 ('36-year-old', 'O'): -9.46931428029122,
 ('writer', 'O'): -9.46931428029122,
 ('died', 

In [4]:
def get_resulting_and_transition_dict(file_path, emission_dict):
    transitions = {}
    y_count = {}
    
    train_data = open(file_path, 'r', encoding="utf-8")
    lines = train_data.readlines()
    start = "start"
    all_y = set(["start", "stop"])

    for line in lines:
        line = line.strip()
        if len(line) <= 0:
            transitions[(start, "stop")] = transitions.get((start, "stop"),0) + 1
            start = "start"
            y_count[start] = y_count.get(start,0) + 1
        else:
            x, pos, y = line.split(" ")
            transitions[(start, y)] =  transitions.get((start,y),0) + 1
            y_count[y] = y_count.get(y,0) + 1
            start = y
            all_y.add(y)

    for start, end in transitions.keys():
        key = "transition:" + start + "+" + end
        emission_dict[key] = math.log(transitions[(start, end)] / y_count[start])
        transitions[(start, end)] = math.log(transitions[(start, end)] / y_count[start])

    return emission_dict, transitions

resulting_dict, transitions = get_resulting_and_transition_dict(train_path, resulting_dict)

In [5]:
resulting_dict

{'emission:O+The': -4.333515843240958,
 'emission:O+DT': -2.2045841023613524,
 'emission:O+official': -7.166729187297174,
 'emission:O+JJ': -2.617071711239342,
 'emission:O+cause': -8.776167099731275,
 'emission:O+NN': -1.8272698774179623,
 'emission:O+of': -3.518671727703493,
 'emission:O+IN': -1.9894501491261936,
 'emission:O+death': -7.166729187297174,
 'emission:O+has': -4.649032714686183,
 'emission:O+VBZ': -3.5639524322366496,
 'emission:O+not': -5.97280671882474,
 'emission:O+RB': -3.8525431826246486,
 'emission:O+been': -5.755742213586912,
 'emission:O+VBN': -3.3294297280649645,
 'emission:O+officially': -8.370701991623111,
 'emission:O+determined': -8.370701991623111,
 'emission:O+,': -2.6427690567346263,
 'emission:O+but': -6.211217742269738,
 'emission:O+CC': -3.7199212943829667,
 'emission:O+investigators': -9.46931428029122,
 'emission:O+NNS': -2.5233002891919925,
 'emission:O+believe': -9.46931428029122,
 'emission:O+VBP': -4.122206749573751,
 'emission:O+the': -2.8532490

In [6]:
def get_combined_dict(resulting_dict, emissions, transitions):
    for word, state in emissions.keys():
        for start, end in transitions.keys():
            key = "combine:" + start + "+" + end + "+" + word
            resulting_dict[key] = transitions.get((start, end), -9999999) + emissions.get((word, end), -9999999)
    return resulting_dict

resulting_dict = get_combined_dict(resulting_dict, emissions, transitions)

In [7]:
def viterbi_algo(x, states, resulting_dict):
    scores = np.full((len(x), len(states)), -np.inf)
    parents = np.full((len(x), len(states)), 0, dtype=int)
    threshold = -9999999
    for i in range(len(states)):
        combined_key1 = "combine:" + "start" + "+" + states[i] + "+" + x[0].split()[0]
        combined_key2 = "combine:" + "start" + "+" + states[i] + "+" + x[0].split()[1]
        emission_key1 = "emission:" + states[i] + "+" + x[0].split()[0]
        emission_key2 = "emission:" + states[i] + "+" + x[0].split()[1]
        transition_key = "transition:" + "start" + "+" + states[i]

        scores[0, i] =+ resulting_dict.get(combined_key1, threshold) + \
                        resulting_dict.get(combined_key2, threshold) + \
                        resulting_dict.get(emission_key1, threshold) + \
                        resulting_dict.get(emission_key2,threshold) + \
                            resulting_dict.get(transition_key, threshold) 
    
    for i in range(1, len(x)):
        for j in range(len(states)): 
            for k in range(len(states)):
                combined_key1 = "combine:" + states[j] + "+" + states[k] + "+" + x[i].split()[0]
                combined_key2 = "combine:" + states[j] + "+" + states[k] + "+" + x[i].split()[1]
                emission_key1 = "emission:" + states[k] + "+" + x[i].split()[0]
                emission_key2 = "emission:" + states[k] + "+" + x[i].split()[1]
                transition_key = "transition:" + states[j] + "+" + states[k]

                overall_score = scores[i-1, j] + resulting_dict.get(combined_key1, threshold) + \
                                                    resulting_dict.get(combined_key2, threshold) + \
                                                    resulting_dict.get(emission_key1,threshold) + \
                                                    resulting_dict.get(emission_key2, threshold) + \
                                                    resulting_dict.get(transition_key, threshold)
                   
                if overall_score > scores[i, k]:
                    scores[i, k] = overall_score
                    parents[i,k] = j
    
    best_score = -np.inf
    best_parent = None
    
    for i in range(len(states)):
        transition_key = "transition:" + states[i] + "+" + "stop"
        total = scores[len(x)-1, i] + resulting_dict.get(transition_key, -10**8)    
        if total > best_score:
            best_score = total
            best_parent = i
    
    best_state = [states[best_parent]]
    prev_parent = best_parent
    
    for i in range(len(x)-1, 0, -1):
        prev_parent = parents[i, prev_parent]
        output = states[prev_parent]
        best_state = [output] + best_state
    
    return best_state

In [8]:
def get_prediction(file_dir,resulting_dict):

    train_path = os.path.join(file_dir, "train")
    input_path = os.path.join(file_dir, "dev.in")
    
    test_set = open(input_path, 'r', encoding="utf-8")
    lines = test_set.readlines()
    
    sequences = [] 
    sequence = []
    for line in lines:
        if line == '\n':
            sequences.append(sequence)
            sequence = []
            continue

        line = line.replace('\n', '')
        sequence.append(line)
        
    out_path = os.path.join(file_dir, "dev.p5.CRF.f4.out")
    out_file = open(out_path, "w", encoding="utf-8")
    
    for x in sequences:
        predicted = viterbi_algo(x, states, resulting_dict)
        for i in range(len(x)):
            out_file.write(x[i] + ' ' + predicted[i] + '\n')
        out_file.write('\n')
    out_file.close()
    
    print("Complete prediction for dataset")

file_dir = "full"
get_prediction(file_dir,resulting_dict)

Complete prediction for dataset


In [9]:
def eval(pred,gold):
    f_pred = open(pred,encoding = 'utf-8')
    f_gold = open(gold,encoding = 'utf-8')
    data_pred = f_pred.readlines()
    data_gold = f_gold.readlines()
    gold_tags = list()
    pred_tags = list()
    for sentence in range(len(data_pred)):
        words_pred = data_pred[sentence].strip().split(' ')
        words_gold = data_gold[sentence].strip().split(' ')  
        if len(words_gold)==1:
            continue
        gold_tags.append(words_gold[2])
        pred_tags.append(words_pred[2])
    return gold_tags,pred_tags


true_path = "full/dev.out"
pred_path = "full/dev.p5.CRF.f4.out"

g_tags, p_tags = eval(true_path, pred_path)
print(evaluate(g_tags,p_tags,verbose=True))

processed 2097 tokens with 226 phrases; found: 236 phrases; correct: 161.
accuracy:  72.98%; (non-O)
accuracy:  93.90%; precision:  68.22%; recall:  71.24%; FB1:  69.70
              art: precision:   0.00%; recall:   0.00%; FB1:   0.00  3
              eve: precision:   0.00%; recall:   0.00%; FB1:   0.00  1
              geo: precision:  85.88%; recall:  73.00%; FB1:  78.92  85
              gpe: precision:  68.00%; recall:  80.95%; FB1:  73.91  25
              nat: precision:   0.00%; recall:   0.00%; FB1:   0.00  2
              org: precision:  42.86%; recall:  51.72%; FB1:  46.88  35
              per: precision:  68.75%; recall:  75.86%; FB1:  72.13  32
              tim: precision:  64.15%; recall:  75.56%; FB1:  69.39  53
((68.22033898305084, 71.23893805309734, 69.69696969696969), 0)
