In [1]:
from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np
import itertools
from collections import Counter
import dill as pickle
import os

In [2]:
# Run configuration for this evaluation notebook.
STORE_RESULTS = True  # when True, the per-sentence results are pickled by a later cell
pkl_name = "21.pkl"  # file name of the pickle written when STORE_RESULTS is set
pkl_path = "../output_objects/"  # directory the pickle is written into
gt_file = "../test_gold.txt"  # gold-standard file, parsed with read_gt_file
pred_file = "../output_data/temp.txt"  # prediction file, parsed with read_result_file

In [3]:
class Token:
    """A single token: an ordered list of string features plus an optional label.

    Attributes:
        features: list of feature strings; the object passed in is stored
            as-is (not copied), or a fresh empty list when omitted.
        label: the label attached to the token, or None when unlabeled.
    """

    def __init__(self, features=None, label=None):
        # Identity test on purpose: `== None` would dispatch to the argument's
        # own __eq__, which misfires for containers that overload equality.
        self.features = [] if features is None else features
        self.label = label

    def add_feature(self, feature):
        """Append one feature string to this token."""
        self.features.append(feature)

    def get_string(self, label=True):
        """Return the token as one newline-terminated, space-separated line.

        Args:
            label: when true, the label is appended after the features.

        Raises:
            TypeError: if ``label`` is true but ``self.label`` is not a string
                (for example None on an unlabeled token).
        """
        s = " ".join(self.features)
        if label:
            s = s + " " + self.label
        return s + "\n"

class Sentence:
    """An ordered collection of token objects.

    Tokens are expected to expose ``features``, ``label`` and
    ``get_string(label=...)``, which are the members used below.
    """

    def __init__(self, tokens=None):
        # Identity test on purpose: `== None` would dispatch to the argument's
        # own __eq__, which misfires for containers that overload equality.
        self.tokens = [] if tokens is None else tokens

    def add_token(self, token):
        """Append one token to the end of the sentence."""
        self.tokens.append(token)

    def get_num_tokens(self):
        """Return the number of tokens in the sentence."""
        return len(self.tokens)

    def get_labels(self):
        """Return the label of every token, in order, as a new list."""
        return [token.label for token in self.tokens]

    def get_words_str(self):
        """Return the first feature of every token joined by single spaces."""
        return " ".join(token.features[0] for token in self.tokens)

    def get_in_format(self, label=True, new_line=True):
        """Return the sentence as a list of token lines.

        Args:
            label: forwarded to each token's ``get_string``.
            new_line: when true, a lone ``'\\n'`` entry is appended after the
                token lines.

        Returns:
            A list with one string per token, plus the optional trailing
            ``'\\n'`` entry.
        """
        sent_list = [token.get_string(label=label) for token in self.tokens]
        if new_line:
            sent_list.append('\n')
        return sent_list

In [4]:
def read_gt_file(path,label=True):
    """Parse a blank-line-separated token file into a list of Sentence objects.

    Every non-blank line is split on whitespace and becomes one Token. Every
    blank (or whitespace-only) line closes the current sentence, so
    consecutive blank lines produce empty Sentence objects. A final sentence
    that is not followed by a blank line is kept as well.

    Args:
        path: file to read; it is decoded as latin1.
        label: when true, the last column of each line is taken as the token
            label and the remaining columns as features; otherwise every
            column is a feature and the token has no label.

    Returns:
        The list of Sentence objects in file order. The number of sentences
        is also printed.
    """
    data = []
    sentence = Sentence()
    with open(path,'r',encoding='latin1') as f:
        for line in f:
            contents = line.split()
            if not contents:
                # Separator line: close the current sentence. Whitespace-only
                # lines are treated the same way instead of producing an
                # empty token (or an IndexError when a label is expected).
                data.append(sentence)
                sentence = Sentence()
            elif label:
                sentence.add_token(Token(contents[:-1],contents[-1]))
            else:
                sentence.add_token(Token(contents))

    # The file may end without a terminating blank line; do not lose the
    # tokens collected for the last sentence in that case.
    if sentence.get_num_tokens() > 0:
        data.append(sentence)

    print("Number of sentences: ",len(data))
    return data

In [5]:
def pickler(path,pkl_name,obj):
    """Serialize ``obj`` to the file ``pkl_name`` inside directory ``path``.

    The directory is created first when it does not exist, so a fresh
    checkout without the output folder does not fail on the write.

    Args:
        path: target directory; an empty string means the current directory.
        pkl_name: name of the file to write inside ``path``.
        obj: the object to serialize.
    """
    if path:
        # makedirs rejects an empty path, hence the guard.
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, pkl_name), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object stored in file ``pkl_name`` under ``path``.

    NOTE(review): unpickling can execute arbitrary code embedded in the file,
    so this should only be pointed at files produced by a trusted source.
    """
    pkl_file = os.path.join(path, pkl_name)
    with open(pkl_file, 'rb') as stream:
        return pickle.load(stream)

In [None]:
def read_result_file(path,per_sent=False):
    """Read predicted labels (first column of each line) from a result file.

    Blank or whitespace-only lines separate sentences.

    Args:
        path: file to read; it is decoded as utf8.
        per_sent: when false, return one flat list of labels. When true,
            return a list with one label list per sentence; every separator
            line closes a sentence (so consecutive separators yield empty
            lists), and a last sentence without a trailing separator is kept.

    Returns:
        A list of labels, or a list of per-sentence label lists.
    """
    predictions = []
    sent_preds = []
    with open(path,'r',encoding='utf8') as f:
        for line in f:
            fields = line.split()
            if fields:
                # Only the first column carries the prediction.
                if per_sent:
                    sent_preds.append(fields[0])
                else:
                    predictions.append(fields[0])
            elif per_sent:
                predictions.append(sent_preds)
                sent_preds = []
    # The file may end without a separator line; keep the labels collected
    # for the final sentence instead of silently dropping them.
    if per_sent and sent_preds:
        predictions.append(sent_preds)
    return predictions

In [None]:
def get_true_labels(gold,per_sent=False):
    """Collect the labels returned by ``get_labels()`` of every item in ``gold``.

    With ``per_sent`` false the labels of all items are concatenated into one
    flat list; with ``per_sent`` true the result holds one entry per item,
    namely whatever that item's ``get_labels()`` returned.
    """
    if per_sent:
        return [sent.get_labels() for sent in gold]
    return [tag for sent in gold for tag in sent.get_labels()]

In [None]:
def compute_accuracies(true_labels,predictions):
    """Return ``metrics.classification_report`` for the given label sequences.

    Both arguments are passed positionally and no explicit target names are
    supplied.
    """
    return metrics.classification_report(true_labels, predictions)

In [None]:
def get_sentences(gold):
    """Return, for every sentence in ``gold``, its ``get_in_format`` output.

    Each sentence is formatted with ``label=False`` and ``new_line=False``.
    """
    return [item.get_in_format(label=False, new_line=False) for item in gold]

In [None]:
def get_correct_and_incorrect(sentences,gold_labels,predictions):
    """Split the items into those whose prediction equals the gold entry and the rest.

    Iterates over the indices of ``gold_labels`` and indexes the other two
    sequences at the same position, so a shorter ``sentences`` or
    ``predictions`` raises IndexError.

    Returns:
        ``(correct, incorrect)``; each is a list of
        ``[index, sentence, gold, prediction]`` records in input order.
    """
    correct, incorrect = [], []
    for idx, expected in enumerate(gold_labels):
        predicted = predictions[idx]
        bucket = incorrect if expected != predicted else correct
        bucket.append([idx, sentences[idx], expected, predicted])
    return correct, incorrect

In [None]:
def store_results(pkl_path,pkl_name,sentences,gold_labels,predictions,correct,incorrect):
    """Bundle the evaluation artefacts into one dict and hand it to ``pickler``.

    The dict uses the keys 'sentences', 'gold_labels', 'predictions',
    'correct' and 'incorrect', in that order.
    """
    bundle = dict(
        sentences=sentences,
        gold_labels=gold_labels,
        predictions=predictions,
        correct=correct,
        incorrect=incorrect,
    )
    pickler(pkl_path, pkl_name, bundle)

In [None]:
# Parse the gold-standard file into a list of Sentence objects (labels included).
gold = read_gt_file(gt_file)

In [None]:
# Flat list of predicted labels: the first column of every non-blank line.
predictions = read_result_file(pred_file)

In [None]:
# Flat list of gold labels, one per token, concatenated over all sentences.
true_labels = get_true_labels(gold)

In [None]:
# Optionally persist a per-sentence view of the results for later inspection.
if STORE_RESULTS:
    # Token lines of every gold sentence, without labels and without the trailing '\n' entry.
    sentences = get_sentences(gold)
    # Gold labels and predictions grouped per sentence so they can be compared sentence by sentence.
    gold_labels = get_true_labels(gold,per_sent=True)
    sent_preds = read_result_file(pred_file,per_sent=True)
    # A sentence lands in `correct` only when its whole label list equals the predicted list.
    correct,incorrect = get_correct_and_incorrect(sentences,gold_labels,sent_preds)
    store_results(pkl_path,pkl_name,sentences,gold_labels,sent_preds,correct,incorrect)

In [None]:
# Report computed from the flat (token-level) gold and predicted label lists.
print("classification report:")
print(compute_accuracies(true_labels,predictions))

In [None]:
# Accuracy over the flat (token-level) gold and predicted label lists.
print("Accuracy: ",metrics.accuracy_score(true_labels,predictions))

In [None]:
# Confusion matrix of the flat label lists; no explicit `labels` order is passed.
# It is plotted below with rows as true labels and columns as predicted labels.
cnf_matrix = metrics.confusion_matrix(true_labels,predictions)

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print the confusion matrix and draw it on the current matplotlib figure.

    Args:
        cm: 2-D numpy array; rows are drawn along the 'True label' axis and
            columns along the 'Predicted label' axis.
        classes: tick labels, used for both axes in the given order.
        normalize: when true, every row is divided by its row sum before
            printing and plotting. Rows that sum to zero are shown as zeros.
        title: title of the plot.
        cmap: colormap handed to ``plt.imshow``.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # An all-zero row (a class with no true samples) would otherwise be
        # divided by zero, giving NaN cells and a runtime warning; such rows
        # are left at 0.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cell annotations: two decimals for ratios, plain integers for counts.
    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis labels are taken into account.
    plt.tight_layout()

# The tick labels must follow the row/column order of cnf_matrix. The matrix
# was built without an explicit `labels` argument, in which case sklearn uses
# the sorted set of labels that occur in either input. Deriving the names the
# same way keeps the axes correct even when a label is missing from the data
# or an unexpected label shows up.
class_names = sorted(set(true_labels) | set(predictions))

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

plt.show()

In [None]:
# How often each label occurs in the flat prediction list.
print("Counts of predicted labels: ",Counter(predictions))