In [1]:
import os
import numpy as np
from collections import defaultdict
import math
from conlleval2 import evaluate

In [2]:
def get_emission_dict(file_path):
    
    e = {}
    y_count = {}
    emission_dict = {}
    
    train_data = open(file_path, 'r', encoding="utf-8")
    lines = train_data.readlines()
    
    all_x = set()
    all_y = set()

    for line in lines:
        # "x pos y\n"
        line = line.strip()

        if len(line) > 0:
            x, pos, y = line.split(" ")
            all_x.add(x)
            all_y.add(y)
            y_count[y] = y_count.get(y,0) + 1
            e[(x, y)]  = e.get((x,y),0) + 1
            e[(pos, y)] = e.get((pos,y),0) + 1

    for x, y in e.keys():
        # x here can be x or pos tag
        key = "emission:" + y + "+" + x
        emission_dict[key] = math.log(e[(x, y)] / y_count[y])
    
    return emission_dict, list(all_y)

train_path = "full/train"
emission_dict, states = get_emission_dict(train_path)

In [3]:
emission_dict

{'emission:O+The': -4.333515843240958,
 'emission:O+DT': -2.2045841023613524,
 'emission:O+official': -7.166729187297174,
 'emission:O+JJ': -2.617071711239342,
 'emission:O+cause': -8.776167099731275,
 'emission:O+NN': -1.8272698774179623,
 'emission:O+of': -3.518671727703493,
 'emission:O+IN': -1.9894501491261936,
 'emission:O+death': -7.166729187297174,
 'emission:O+has': -4.649032714686183,
 'emission:O+VBZ': -3.5639524322366496,
 'emission:O+not': -5.97280671882474,
 'emission:O+RB': -3.8525431826246486,
 'emission:O+been': -5.755742213586912,
 'emission:O+VBN': -3.3294297280649645,
 'emission:O+officially': -8.370701991623111,
 'emission:O+determined': -8.370701991623111,
 'emission:O+,': -2.6427690567346263,
 'emission:O+but': -6.211217742269738,
 'emission:O+CC': -3.7199212943829667,
 'emission:O+investigators': -9.46931428029122,
 'emission:O+NNS': -2.5233002891919925,
 'emission:O+believe': -9.46931428029122,
 'emission:O+VBP': -4.122206749573751,
 'emission:O+the': -2.8532490

In [4]:
def get_resulting_dict(file_path, emission_dict):

    t = {}
    y_count = {}
    
    train_data = open(file_path, 'r', encoding="utf-8")
    lines = train_data.readlines()
    start = "start"
    all_y = set(["start", "stop"])

    for line in lines:
        # "x pos y\n"
        line = line.strip()
        if len(line) <= 0:
            t[(start, "stop")] = t.get((start,"stop"),0) + 1
            start = "start"
            y_count[start] = y_count.get(start,0) + 1
        else:
            x, pos, y = line.split(" ")
            t[(start, y)] = t.get((start,y),0) + 1
            y_count[y] = y_count.get(y,0) + 1
            start = y
            all_y.add(y)

    for start, end in t.keys():
        key = "transition:" + start + "+" + end
        emission_dict[key] = math.log(t[(start, end)] / y_count[start])

    return emission_dict

resulting_dict = get_resulting_dict(train_path, emission_dict)
# print(f)

In [5]:
resulting_dict

{'emission:O+The': -4.333515843240958,
 'emission:O+DT': -2.2045841023613524,
 'emission:O+official': -7.166729187297174,
 'emission:O+JJ': -2.617071711239342,
 'emission:O+cause': -8.776167099731275,
 'emission:O+NN': -1.8272698774179623,
 'emission:O+of': -3.518671727703493,
 'emission:O+IN': -1.9894501491261936,
 'emission:O+death': -7.166729187297174,
 'emission:O+has': -4.649032714686183,
 'emission:O+VBZ': -3.5639524322366496,
 'emission:O+not': -5.97280671882474,
 'emission:O+RB': -3.8525431826246486,
 'emission:O+been': -5.755742213586912,
 'emission:O+VBN': -3.3294297280649645,
 'emission:O+officially': -8.370701991623111,
 'emission:O+determined': -8.370701991623111,
 'emission:O+,': -2.6427690567346263,
 'emission:O+but': -6.211217742269738,
 'emission:O+CC': -3.7199212943829667,
 'emission:O+investigators': -9.46931428029122,
 'emission:O+NNS': -2.5233002891919925,
 'emission:O+believe': -9.46931428029122,
 'emission:O+VBP': -4.122206749573751,
 'emission:O+the': -2.8532490

In [6]:
def viterbi_algo(x, states, f):
    scores = np.full((len(x), len(states)), -np.inf)
    parents = np.full((len(x), len(states)), 0, dtype=int)
    
    for i in range(len(states)):
        emission_key1 = "emission:" + states[i] + "+" + x[0].split()[0]
        emission_key2 = "emission:" + states[i] + "+" + x[0].split()[1]
        transmission_key = "transition:" + "start" + "+" + states[i]
        scores[0, i] = f.get(emission_key1, -10e8) + f.get(emission_key2, -10e8) + f.get(transmission_key, -10e8)
    
    for i in range(1, len(x)):
        for j in range(len(states)):
            for k in range(len(states)):
                emission_key1 = "emission:" + states[k] + "+" + x[i].split()[0]
                emission_key2 = "emission:" + states[k] + "+" + x[i].split()[1]
                transmission_key = "transition:" + states[j] + "+" + states[k]
                overall_score = scores[i-1, j] + f.get(emission_key1, -10e8) + f.get(emission_key2, -10e8) + f.get(transmission_key, -10e8)

                if overall_score > scores[i, k]:
                    scores[i, k] = overall_score
                    parents[i,k] = j
    
    best_score = -np.inf
    best_parent = None
    
    for i in range(len(states)):
        t_feature = "transition:" + states[i] + "+" + "stop"
        total = scores[len(x)-1, i] + f.get(t_feature, -10**8)
        
        if total > best_score:
            best_score = total
            best_parent = i
    
    best_state = [states[best_parent]]
    prev_parent = best_parent
    for i in range(len(x)-1, 0, -1):
        prev_parent = parents[i, prev_parent]
        output = states[prev_parent]
        best_state = [output] + best_state
    return best_state



In [27]:
def get_prediction(file_dir,resulting_dict):
    train_path = os.path.join(file_dir, "train")
    input_path = os.path.join(file_dir, "dev.in")
    
    test_set = open(input_path, 'r', encoding="utf-8")
    lines = test_set.readlines()
    
    sequences = [] #ls of sequences
    sequence = []
    for line in lines:
        if line == '\n':
            sequences.append(sequence)
            sequence = []
            continue

        line = line.replace('\n', '')
        sequence.append(line)

    out_path = os.path.join(file_dir, "dev.p5.CRF.f3.out")
    out_file = open(out_path, "w", encoding="utf-8")
    
    for x in sequences:
        print(x)
        predicted = viterbi_algo(x, states, resulting_dict)
        for i in range(len(x)):
            out_file.write(x[i] + ' ' + predicted[i] + '\n')
        out_file.write('\n')
    out_file.close()
        
    
    print("Complete prediction for dataset")

file_dir = "full"
get_prediction(file_dir,resulting_dict)

['The DT', 'Saudi NNP', 'Interior NNP', 'Ministry NNP', 'says VBZ', 'the DT', 'three CD', 'were VBD', 'not RB', 'on IN', 'a DT', 'most RBS', 'wanted JJ', 'list NN', 'of IN', 'suspected JJ', 'al-Qaida NNP', 'sympathizers NNS', 'issued VBN', 'by IN', 'Saudi JJ', 'authorities NNS', 'last JJ', 'year NN', '. .']
['British JJ', 'forces NNS', ', ,', 'based VBN', 'in IN', 'the DT', 'mainly RB', "Shi'ite JJ", 'south NN', ', ,', 'have VBP', 'suffered VBN', 'far RB', 'fewer JJR', 'losses NNS', 'than IN', 'the DT', 'much JJ', 'larger JJR', 'U.S. NNP', 'force NN', 'fighting VBG', 'Sunni NNP', 'Arab NNP', 'insurgents NNS', 'and CC', 'foreign JJ', 'fighters NNS', 'in IN', 'the DT', 'rest NN', 'of IN', 'Iraq NNP', '. .']
['The DT', 'massive JJ', 'exodus NN', 'of IN', 'hundreds NNS', 'of IN', 'thousands NNS', 'of IN', 'Yemenis NNS', 'from IN', 'the DT', 'south NN', 'to TO', 'the DT', 'north NN', 'contributed VBD', 'to TO', 'two CD', 'decades NNS', 'of IN', 'hostility NN', 'between IN', 'the DT', 'state

['Iran NNP', 'is VBZ', 'said VBN', 'to TO', 'have VB', 'influence NN', 'with IN', "Shi'ite NNP", 'factions NNS', 'and CC', 'militias NNS', 'in IN', 'Iraq NNP', '. .']
['In IN', 'April NNP', '2009 CD', ', ,', 'Croatia NNP', 'joined VBD', 'NATO NNP', '; ;', 'it PRP', 'is VBZ', 'a DT', 'candidate NN', 'for IN', 'eventual JJ', 'EU NNP', 'accession NN', '. .']
['Police NNS', 'say VBP', 'they PRP', 'have VBP', 'arrested VBN', 'four CD', 'suspects NNS', 'in IN', 'connection NN', 'with IN', 'the DT', 'incident NN', 'and CC', 'are VBP', 'now RB', 'giving VBG', 'Obama NNP', '24-hour JJ', 'protection NN', '. .']
['He PRP', 'announced VBD', 'plans NNS', 'to TO', 'close VB', 'it PRP', 'after IN', 'Russia NNP', 'pledged VBD', 'to TO', 'give VB', 'Kyrgyzstan NNP', 'about IN', '$ $', '2 CD', 'billion CD', 'in IN', 'loans NNS', 'and CC', 'aid NN', '. .']
['The DT', 'speech NN', 'will MD', 'be VB', 'Mr. NNP', 'Bush NNP', "'s POS", 'second JJ', 'on IN', 'energy NN', 'issues NNS', 'in IN', 'a DT', 'week N

['He PRP', 'stressed VBD', 'Japan NNP', "'s POS", 'record NN', 'of IN', 'pacifism NN', 'since IN', 'the DT', 'end NN', 'of IN', 'World NNP', 'War NNP', 'II NNP', '. .']
['But CC', 'relations NNS', 'have VBP', 'improved VBN', 'since IN', 'the DT', 'two CD', 'countries NNS', 'launched VBD', 'a DT', 'slow JJ', 'moving NN', 'peace NN', 'process NN', 'in IN', '2004 CD', 'to TO', 'resolve VB', 'their PRP$', 'disputes NNS', ', ,', 'including VBG', 'the DT', 'conflict NN', 'over IN', 'Kashmir NNP', '. .']
['Other JJ', 'national JJ', 'private JJ', 'networks NNS', 'also RB', 'opposed VBD', 'Mr. NNP', 'Chavez NNP', ', ,', 'but CC', 'their PRP$', 'criticism NN', 'of IN', 'the DT', 'government NN', 'is VBZ', 'now RB', 'softer JJR', 'and CC', 'they PRP', 'have VBP', 'kept VBN', 'their PRP$', 'licenses NNS', '. .']
['China NNP', "'s POS", 'state-run JJ', 'news NN', 'agency NN', 'says VBZ', 'a DT', 'coal NN', 'mine NN', 'accident NN', 'in IN', 'central JJ', 'Henan NNP', 'province NN', 'has VBZ', 'kill

In [8]:
def eval(pred,gold):
    f_pred = open(pred,encoding = 'utf-8')
    f_gold = open(gold,encoding = 'utf-8')
    data_pred = f_pred.readlines()
    data_gold = f_gold.readlines()
    gold_tags = list()
    pred_tags = list()
    for sentence in range(len(data_pred)):
        words_pred = data_pred[sentence].strip().split(' ')
        words_gold = data_gold[sentence].strip().split(' ')  
        if len(words_gold)==1:
            continue
        gold_tags.append(words_gold[2])
        pred_tags.append(words_pred[2])
    return gold_tags,pred_tags


true_path = "full/dev.out"
pred_path = "full/dev.p5.CRF.f3.out"

g_tags, p_tags = eval(true_path, pred_path)
print(evaluate(g_tags,p_tags,verbose=True))


processed 2097 tokens with 235 phrases; found: 236 phrases; correct: 163.
accuracy:  72.21%; (non-O)
accuracy:  93.75%; precision:  69.07%; recall:  69.36%; FB1:  69.21
              art: precision:   0.00%; recall:   0.00%; FB1:   0.00  3
              eve: precision:   0.00%; recall:   0.00%; FB1:   0.00  1
              geo: precision:  85.88%; recall:  75.26%; FB1:  80.22  85
              gpe: precision:  68.00%; recall:  80.95%; FB1:  73.91  25
              nat: precision:   0.00%; recall:   0.00%; FB1:   0.00  2
              org: precision:  45.71%; recall:  48.48%; FB1:  47.06  35
              per: precision:  71.88%; recall:  76.67%; FB1:  74.19  32
              tim: precision:  64.15%; recall:  65.38%; FB1:  64.76  53
(69.0677966101695, 69.36170212765957, 69.2144373673036)


In [62]:
resulting_dict = get_resulting_dict(train_path, emission_dict)
def part3(file_dir,resulting_dict,epoch = 5, lr= 0.01):
    train_path = os.path.join(file_dir, "train")
    input_path = os.path.join(file_dir, "dev.in")
    
    test_set = open(train_path, 'r', encoding="utf-8")
    lines = test_set.readlines()
    for i in range(epoch):
        print("starting epoch")
        sequences = [] #ls of sequences
        word_sequence = []
        correct_state = []
        for line in lines:
            temp = []
    #         print(line)
            if line == '\n':
                sequences.append([word_sequence,correct_state])
                word_sequence = []
                correct_state = []
                continue
    #         print(line)
            line = line.strip().split(" ")
    #         print(l)
            word_sequence.append(line[0]+" "+line[1])
            correct_state.append(line[2])
    #     print(sequences[0])
        
    #     print(sequences[0])
        for x in sequences:
            predicted = viterbi_algo(x[0], states, resulting_dict)
            x.append(predicted)
        #print(sequences[0])
        #sequences is a list of lists
        #sequence[0][0] is word + pos
        #sequence[0][1] 
        for i in range(len(sequences)):
            sentence = sequences[i][0]
            word_only = []
            pos_only = []
            for word_pos in sentence:
                #print(i)
                word,pos = word_pos.split(" ")
                word_only.append(word)
                pos_only.append(pos)
            correct_states = sequences[i][1]
            predicted_states = sequences[i][2]
            
                #print(len(sequences))
            #for each prediction, check if its correct
            for i in range(1,len(word_only)):
                if correct_states[i] != predicted_states[i]:
                    #'emission:O+enter': -9.46931428029122,
                    resulting_dict["emission:"+ predicted_states[i] +"+"+ word_only[i]] -= 1* lr
                    resulting_dict["emission:"+ predicted_states[i] +"+"+ pos_only[i]] -= 1*lr
                    resulting_dict["transition:"+ predicted_states[i-1] +"+" + predicted_states[i]] -= 1*lr
                    resulting_dict["emission:"+ correct_states[i] +"+"+ word_only[i]] += 1*lr
                    resulting_dict["emission:"+ correct_states[i] +"+"+ pos_only[i]] += 1*lr
                    resulting_dict["transition:"+ correct_states[i-1] +"+"+ correct_states[i]] += 1*lr

    return resulting_dict
        
resulting_dict = part3(file_dir,resulting_dict)

#         for i in range(len(x)):
#             out_file.write(x[i][0] + ' ' + predicted[i] + '\n')
#         out_file.write('\n')
#     out_file.close()


starting epoch
starting epoch
starting epoch
starting epoch
starting epoch


In [64]:
def get_prediction(file_dir,resulting_dict):
    train_path = os.path.join(file_dir, "train")
    input_path = os.path.join(file_dir, "dev.in")
    
    test_set = open(input_path, 'r', encoding="utf-8")
    lines = test_set.readlines()
    
    sequences = [] #ls of sequences
    sequence = []
    for line in lines:
        if line == '\n':
            sequences.append(sequence)
            sequence = []
            continue

        line = line.replace('\n', '')
        sequence.append(line)

    out_path = os.path.join(file_dir, "dev.p5.CRF.f3_test2.out")
    out_file = open(out_path, "w", encoding="utf-8")
    
    for x in sequences:
        print(x)
        predicted = viterbi_algo(x, states, resulting_dict)
        for i in range(len(x)):
            out_file.write(x[i] + ' ' + predicted[i] + '\n')
        out_file.write('\n')
    out_file.close()
        
    
    print("Complete prediction for dataset")

file_dir = "full"
get_prediction(file_dir,resulting_dict)

['The DT', 'Saudi NNP', 'Interior NNP', 'Ministry NNP', 'says VBZ', 'the DT', 'three CD', 'were VBD', 'not RB', 'on IN', 'a DT', 'most RBS', 'wanted JJ', 'list NN', 'of IN', 'suspected JJ', 'al-Qaida NNP', 'sympathizers NNS', 'issued VBN', 'by IN', 'Saudi JJ', 'authorities NNS', 'last JJ', 'year NN', '. .']
['British JJ', 'forces NNS', ', ,', 'based VBN', 'in IN', 'the DT', 'mainly RB', "Shi'ite JJ", 'south NN', ', ,', 'have VBP', 'suffered VBN', 'far RB', 'fewer JJR', 'losses NNS', 'than IN', 'the DT', 'much JJ', 'larger JJR', 'U.S. NNP', 'force NN', 'fighting VBG', 'Sunni NNP', 'Arab NNP', 'insurgents NNS', 'and CC', 'foreign JJ', 'fighters NNS', 'in IN', 'the DT', 'rest NN', 'of IN', 'Iraq NNP', '. .']
['The DT', 'massive JJ', 'exodus NN', 'of IN', 'hundreds NNS', 'of IN', 'thousands NNS', 'of IN', 'Yemenis NNS', 'from IN', 'the DT', 'south NN', 'to TO', 'the DT', 'north NN', 'contributed VBD', 'to TO', 'two CD', 'decades NNS', 'of IN', 'hostility NN', 'between IN', 'the DT', 'state

['Back RB', 'in IN', 'May NNP', ', ,', 'a DT', 'Nepalese JJ', 'Sherpa NN', 'climbed VBD', 'Everest NNP', 'for IN', 'a DT', 'record NN', '19th JJ', 'time NN', '. .']
['But CC', ', ,', 'the DT', 'president NN', 'predicted VBD', 'that IN', 'a DT', 'free JJ', 'country NN', 'will MD', 'emerge VB', 'in IN', 'Iraq NNP', ', ,', 'proving VBG', 'the DT', 'merit NN', 'of IN', 'his PRP$', 'policies NNS', 'there RB', '. .']
['Kuwait NNP', 'has VBZ', 'also RB', 'found VBN', 'the DT', 'deadly JJ', 'H5N1 NNP', 'variety NN', 'of IN', 'avian JJ', 'flu NN', 'in IN', 'a DT', 'bird NN', 'culled VBN', 'by IN', 'authorities NNS', '. .']
['They PRP', 'also RB', 'discussed VBD', 'creating VBG', 'three CD', 'non-permanent JJ', 'seats NNS', 'for IN', 'African JJ', 'countries NNS', '. .']
['Officials NNS', 'alleged VBD', 'that IN', 'Khan NNP', 'acted VBD', 'as IN', 'a DT', 'link NN', 'between IN', 'top JJ', 'al-Qaida NNP', 'leaders NNS', 'and CC', 'the DT', 'organizations NNS', "' POS", 'operational JJ', 'cells N

['Finance NNP', 'Minister NNP', 'Rodrigo NNP', 'Cabezas NNP', 'said VBD', 'Wednesday NNP', 'that IN', 'his PRP$', 'country NN', "'s POS", 'inflation NN', 'rate NN', 'of IN', '22.5 CD', 'percent NN', 'is VBZ', '" ``', 'unsatisfactory JJ', '. .', '" ``']
['The DT', 'U.S. NNP', 'military NN', 'has VBZ', 'not RB', 'commented VBN', 'on IN', 'the DT', 'report NN', '. .']
['Bird NNP', 'flu NNP', 'has VBZ', 'killed VBN', 'more JJR', 'people NNS', 'in IN', 'Indonesia NNP', 'than IN', 'any DT', 'other JJ', 'country NN', 'since IN', 'it PRP', 'began VBD', 'spreading VBG', 'in IN', 'Southeast NNP', 'Asia NNP', 'in IN', 'late JJ', '2003 CD', '. .']
['He PRP', 'stressed VBD', 'Japan NNP', "'s POS", 'record NN', 'of IN', 'pacifism NN', 'since IN', 'the DT', 'end NN', 'of IN', 'World NNP', 'War NNP', 'II NNP', '. .']
['But CC', 'relations NNS', 'have VBP', 'improved VBN', 'since IN', 'the DT', 'two CD', 'countries NNS', 'launched VBD', 'a DT', 'slow JJ', 'moving NN', 'peace NN', 'process NN', 'in IN',

In [63]:
resulting_dict

{'emission:O+The': -4.333515843240958,
 'emission:O+DT': -2.3625841023613474,
 'emission:O+official': -7.166729187297174,
 'emission:O+JJ': -2.7940717112393356,
 'emission:O+cause': -8.776167099731275,
 'emission:O+NN': -2.0932698774179563,
 'emission:O+of': -3.459671727703495,
 'emission:O+IN': -2.1174501491261895,
 'emission:O+death': -7.166729187297174,
 'emission:O+has': -4.649032714686183,
 'emission:O+VBZ': -3.5639524322366496,
 'emission:O+not': -5.97280671882474,
 'emission:O+RB': -4.089543182624641,
 'emission:O+been': -5.755742213586912,
 'emission:O+VBN': -3.3294297280649645,
 'emission:O+officially': -8.370701991623111,
 'emission:O+determined': -8.370701991623111,
 'emission:O+,': -2.760769056734622,
 'emission:O+but': -6.211217742269738,
 'emission:O+CC': -3.7199212943829667,
 'emission:O+investigators': -9.46931428029122,
 'emission:O+NNS': -2.5733002891919914,
 'emission:O+believe': -9.46931428029122,
 'emission:O+VBP': -4.122206749573751,
 'emission:O+the': -2.97124909

In [65]:
def eval(pred,gold):
    f_pred = open(pred,encoding = 'utf-8')
    f_gold = open(gold,encoding = 'utf-8')
    data_pred = f_pred.readlines()
    data_gold = f_gold.readlines()
    gold_tags = list()
    pred_tags = list()
    for sentence in range(len(data_pred)):
        words_pred = data_pred[sentence].strip().split(' ')
        words_gold = data_gold[sentence].strip().split(' ')  
        if len(words_gold)==1:
            continue
        gold_tags.append(words_gold[2])
        pred_tags.append(words_pred[2])
    return gold_tags,pred_tags


true_path = "full/dev.out"
pred_path = "full/dev.p5.CRF.f3_test2.out"

g_tags, p_tags = eval(true_path, pred_path)
print(evaluate(g_tags,p_tags,verbose=True))


processed 2097 tokens with 193 phrases; found: 236 phrases; correct: 143.
accuracy:  76.73%; (non-O)
accuracy:  92.94%; precision:  60.59%; recall:  74.09%; FB1:  66.67
              art: precision:   0.00%; recall:   0.00%; FB1:   0.00  3
              eve: precision:   0.00%; recall:   0.00%; FB1:   0.00  1
              geo: precision:  65.88%; recall:  83.58%; FB1:  73.68  85
              gpe: precision:  68.00%; recall:  70.83%; FB1:  69.39  25
              nat: precision:   0.00%; recall:   0.00%; FB1:   0.00  2
              org: precision:  42.86%; recall:  57.69%; FB1:  49.18  35
              per: precision:  62.50%; recall:  80.00%; FB1:  70.18  32
              tim: precision:  66.04%; recall:  71.43%; FB1:  68.63  53
(60.59322033898306, 74.09326424870466, 66.66666666666667)
