In [1]:
import numpy as np
from util import read_passages
from operator import itemgetter, attrgetter
from sklearn.metrics import f1_score, confusion_matrix
import matplotlib.pyplot as plt
import itertools

In [2]:
def sortByCount(array):
    """
    Tally how often each distinct element occurs in ``array``.

    Returns a list of ``(item, count)`` tuples ordered from the most frequent
    item to the least frequent one. Items with equal counts keep the order in
    which they come out of ``set(array)``.
    """
    # Seed every distinct element with a zero tally, then count occurrences.
    tallies = dict.fromkeys(set(array), 0)
    for element in array:
        tallies[element] += 1
    # list.sort is stable, so tied items keep the dictionary's insertion order.
    ranked = list(tallies.items())
    ranked.sort(key=itemgetter(1), reverse=True)
    return ranked

In [7]:
# File holding the model predictions to evaluate.
# NOTE(review): the file name appears to encode the run's settings — confirm.
pred_file = "predictions/lucky_testatt=True_cont=word_lstm=False_bi=True_crf=True.out"
# Labelled test file; the labels read from it are used as the ground truth below.
test_data = "lucky_test.txt"


In [8]:
# Read the test file with the second argument set to True and keep both
# returned values (text sequences and label sequences).
# NOTE(review): read_passages comes from the project's util module; the flag
# presumably means "the file contains labels" — confirm against util.py.
str_seqs, label_seqs = read_passages(test_data,True)

# Read the prediction file with the flag set to False. The FIRST returned
# value is used as the predicted labels; the second one is discarded.
pred_labels, _ = read_passages(pred_file, False)    

def linearize(labels):
    """
    Flatten a sequence of per-paper label sequences into one flat list.

    The labels keep their original order: all labels of the first paper,
    then all labels of the second paper, and so on.
    """
    return [label for paper in labels for label in paper]

# Flatten the gold and the predicted label sequences into two flat lists
# so they can be compared element by element.
true_label = linearize(label_seqs)
pred_label = linearize(pred_labels)

# average="weighted": per-class F1 scores averaged with each class weighted
# by its number of true instances.
f1 = f1_score(y_true=true_label, y_pred=pred_label, average="weighted")
print("F1 score:", f1)

F1 score: 0.7964487570659828


In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          filename="confusion_matrix.pdf"):
    """
    Draw a confusion matrix as an annotated heat map and save it to a file.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix. Rows are labelled as true labels and columns as
        predicted labels.
    classes : sequence
        Tick labels, one per row/column of ``cm``, in matrix order.
    normalize : bool, default False
        If True, each row is divided by its sum and cells are printed with
        two decimals. Rows whose sum is zero are left as zeros instead of
        producing NaN. If False, a notice is printed and the raw values are
        shown as integers.
    title : str, default 'Confusion matrix'
        Title of the figure.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap used for the cells.
    filename : str, default "confusion_matrix.pdf"
        Path the figure is written to.

    Raises
    ------
    ValueError
        If ``cm`` is not a square matrix with one row/column per entry of
        ``classes``; plotting it anyway would attach labels to the wrong cells.
    """
    cm = np.asarray(cm)
    n_classes = len(classes)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            "cm has shape %s but %d class labels were given"
            % (cm.shape, n_classes))

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Divide only where the row total is non-zero; empty rows stay 0.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
    else:
        print('Confusion matrix, without normalization')

    plt.figure(figsize=(6,4.5))
    plt.imshow(cm, interpolation='nearest', aspect='auto', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(n_classes)
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # Explicit limits cover every cell in full; with the lower bound given
    # first, row 0 is drawn at the bottom of the plot.
    plt.ylim([-0.5, n_classes-0.5])

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    # No `quality` argument: it only ever applied to JPEG output and newer
    # Matplotlib releases reject it.
    plt.savefig(filename, bbox_inches='tight')

In [None]:
# Set matplotlib's global default font size to 14 for figures drawn afterwards.
plt.rcParams['font.size'] = 14

In [None]:
# Build one explicit, sorted label list from BOTH the gold and the predicted
# labels and use it for the matrix and for the tick labels. Taking the tick
# labels from the gold labels alone would mislabel the axes whenever a label
# occurs only in the predictions.
class_labels = sorted(set(true_label) | set(pred_label))
cnf_matrix = confusion_matrix(true_label, pred_label, labels=class_labels)
plot_confusion_matrix(cnf_matrix, classes=class_labels, normalize=True,
                      title='')