# In [None]:
import numpy as np
from collections import Counter
from scipy.optimize import linear_sum_assignment

def extract_clusters(gold_clusters):
    """Normalise clusters into tuples of mention tuples, dropping empty ones.

    Each cluster in *gold_clusters* is turned into a tuple whose elements are
    the cluster's mentions, each converted to a tuple. Clusters that contain
    no mentions are omitted from the returned list.
    """
    normalised = []
    for cluster in gold_clusters:
        as_tuple = tuple(tuple(mention) for mention in cluster)
        if as_tuple:
            normalised.append(as_tuple)
    return normalised


def extract_mentions_to_predicted_clusters_from_clusters(gold_clusters):
    """Build a lookup from each mention (as a tuple) to the cluster holding it.

    If a mention occurs in more than one cluster, the last cluster in
    iteration order wins.
    """
    return {
        tuple(mention): cluster
        for cluster in gold_clusters
        for mention in cluster
    }


def f1(p_num, p_den, r_num, r_den, beta=1):
    """Return the F-beta score from precision and recall counts.

    Precision is ``p_num / p_den`` and recall is ``r_num / r_den``; either is
    taken as 0 when its denominator is 0. The result is 0 when precision and
    recall sum to 0.
    """
    if p_den == 0:
        precision = 0
    else:
        precision = p_num / float(p_den)

    if r_den == 0:
        recall = 0
    else:
        recall = r_num / float(r_den)

    if precision + recall == 0:
        return 0

    beta_sq = beta * beta
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


class CorefEvaluator(object):
    """Runs the MUC, B-cubed and CEAF-e evaluators side by side.

    Every getter returns a list with one entry per evaluator, in the order
    (muc, b_cubed, ceafe).
    """

    def __init__(self):
        evaluators = []
        for metric in (muc, b_cubed, ceafe):
            evaluators.append(Evaluator(metric))
        self.evaluators = evaluators

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        """Forward one document's clusters and mention maps to every evaluator."""
        for evaluator in self.evaluators:
            evaluator.update(predicted, gold, mention_to_predicted, mention_to_gold)

    def get_f1(self):
        """Return the F1 of each evaluator."""
        return [evaluator.get_f1() for evaluator in self.evaluators]

    def get_recall(self):
        """Return the recall of each evaluator."""
        return [evaluator.get_recall() for evaluator in self.evaluators]

    def get_precision(self):
        """Return the precision of each evaluator."""
        return [evaluator.get_precision() for evaluator in self.evaluators]

    def get_prf(self):
        """Return ``(precisions, recalls, f1s)``, each a per-evaluator list."""
        precisions = self.get_precision()
        recalls = self.get_recall()
        f_scores = self.get_f1()
        return precisions, recalls, f_scores


class Evaluator(object):
    """Accumulates precision/recall numerators and denominators for one metric.

    ``metric`` is a callable. If it is ``ceafe`` it is called once per update
    as ``metric(predicted, gold)`` and must return
    ``(p_num, p_den, r_num, r_den)``. Any other metric is called twice, as
    ``metric(predicted, mention_to_gold)`` for the precision pair and
    ``metric(gold, mention_to_predicted)`` for the recall pair, each returning
    ``(num, den)``.
    """

    def __init__(self, metric, beta=1):
        self.p_num = 0
        self.p_den = 0
        self.r_num = 0
        self.r_den = 0
        self.metric = metric
        self.beta = beta

    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
        """Add one document's counts to the running totals."""
        if self.metric == ceafe:
            # CEAF-e aligns both cluster sets at once and yields all four counts.
            pn, pd, rn, rd = self.metric(predicted, gold)
        else:
            pn, pd = self.metric(predicted, mention_to_gold)
            rn, rd = self.metric(gold, mention_to_predicted)
        self.p_num += pn
        self.p_den += pd
        self.r_num += rn
        self.r_den += rd

    @staticmethod
    def _ratio(num, den):
        """Return ``num / den``, or 0 when either count is zero.

        Guarding on the denominator (not only the numerator) avoids a
        ZeroDivisionError and matches the zero-denominator handling in ``f1``.
        """
        if num == 0 or den == 0:
            return 0
        return num / float(den)

    def get_f1(self):
        """Return the F-beta score of the accumulated counts."""
        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)

    def get_recall(self):
        """Return accumulated recall, or 0 if it is undefined or zero."""
        return self._ratio(self.r_num, self.r_den)

    def get_precision(self):
        """Return accumulated precision, or 0 if it is undefined or zero."""
        return self._ratio(self.p_num, self.p_den)

    def get_prf(self):
        """Return ``(precision, recall, f1)`` for the accumulated counts."""
        return self.get_precision(), self.get_recall(), self.get_f1()

    def get_counts(self):
        """Return the raw totals ``(p_num, p_den, r_num, r_den)``."""
        return self.p_num, self.p_den, self.r_num, self.r_den


def b_cubed(clusters, mention_to_gold):
    """Return the B-cubed ``(numerator, denominator)`` for *clusters*.

    Clusters of exactly one mention are skipped. For each remaining cluster,
    its mentions are grouped by the cluster they map to in *mention_to_gold*
    (unmapped mentions are ignored); groups whose target cluster has exactly
    one mention contribute nothing. The numerator gains the sum of squared
    group sizes divided by the cluster size, and the denominator gains the
    cluster size.
    """
    numerator = 0
    denominator = 0

    for cluster in clusters:
        size = len(cluster)
        if size == 1:
            continue

        # How many mentions of this cluster fall into each counterpart cluster.
        overlap = Counter(
            tuple(mention_to_gold[mention])
            for mention in cluster
            if mention in mention_to_gold
        )
        squared_overlap = sum(
            shared * shared
            for counterpart, shared in overlap.items()
            if len(counterpart) != 1
        )

        numerator += squared_overlap / float(size)
        denominator += size

    return numerator, denominator


def muc(clusters, mention_to_gold):
    """Return the MUC ``(numerator, denominator)`` for *clusters*.

    The denominator is the sum of ``len(cluster) - 1`` over all clusters. The
    numerator sums, per cluster, the number of its mentions found in
    *mention_to_gold* minus the number of distinct clusters those mentions
    map to.
    """
    true_links = 0
    possible_links = 0
    for cluster in clusters:
        possible_links += len(cluster) - 1
        # Counterpart clusters of the mentions that have one.
        aligned = [mention_to_gold[m] for m in cluster if m in mention_to_gold]
        true_links += len(aligned) - len(set(aligned))
    return true_links, possible_links


def phi4(c1, c2):
    """Return twice the count of *c1* members found in *c2*, over the summed sizes."""
    shared = sum(1 for mention in c1 if mention in c2)
    total_size = float(len(c1) + len(c2))
    return 2 * shared / total_size


def ceafe(clusters, gold_clusters):
    """Return CEAF-e counts ``(p_num, p_den, r_num, r_den)``.

    Clusters of exactly one mention are removed from *clusters* (but not from
    *gold_clusters*). Every gold/predicted pair is scored with ``phi4`` and
    ``linear_sum_assignment`` picks the one-to-one alignment with the highest
    total score. That total is both numerators; the denominators are the
    number of remaining predicted clusters and the number of gold clusters.
    """
    kept = [c for c in clusters if len(c) != 1]
    n_gold = len(gold_clusters)
    n_pred = len(kept)

    scores = np.zeros((n_gold, n_pred))
    for row, gold_cluster in enumerate(gold_clusters):
        for col, pred_cluster in enumerate(kept):
            scores[row, col] = phi4(gold_cluster, pred_cluster)

    # The solver minimises cost, so negate the scores to maximise similarity.
    row_ind, col_ind = linear_sum_assignment(-scores)
    similarity = sum(scores[row_ind, col_ind])
    return similarity, n_pred, similarity, n_gold


def get_metrics(gold_clusters, pred_clusters):
    """Score one document's predicted clusters against its gold clusters.

    Returns ``(precisions, recalls, f1s)``; each is a list with one value per
    evaluator held by ``CorefEvaluator``.
    """
    predicted = extract_clusters(pred_clusters)
    gold = extract_clusters(gold_clusters)

    evaluator = CorefEvaluator()
    evaluator.update(
        predicted,
        gold,
        extract_mentions_to_predicted_clusters_from_clusters(predicted),
        extract_mentions_to_predicted_clusters_from_clusters(gold),
    )
    return evaluator.get_prf()


def calculate_and_print_metrics_of_all_clusters(all_pred_clusters, all_gold_clusters, to_print=False):
    """Average per-text MUC, B3 and CEAF scores over a collection of texts.

    ``all_pred_clusters[i]`` and ``all_gold_clusters[i]`` are the predicted
    and gold clusters of text ``i``. Each text is scored on its own with
    ``get_metrics`` and the scores are averaged with equal weight per text.

    When there are no texts every average is 0.0 (instead of dividing by
    zero).

    If *to_print* is true, the nine averages are printed, followed by the mean
    precision, recall and F1 across the three metrics. Those three cross-metric
    means are printed only; they are not part of the return value.

    Returns:
        ``(muc_p, muc_r, muc_f, b3_p, b3_r, b3_f, ceaf_p, ceaf_r, ceaf_f)``
    """
    # Running sums; index 0 is MUC, 1 is B3, 2 is CEAF.
    precision_sums = [0, 0, 0]
    recall_sums = [0, 0, 0]
    f1_sums = [0, 0, 0]
    number_of_texts = 0

    for i, pred_clusters in enumerate(all_pred_clusters):
        p, r, f = get_metrics(all_gold_clusters[i], pred_clusters)
        for k in range(3):
            precision_sums[k] += p[k]
            recall_sums[k] += r[k]
            f1_sums[k] += f[k]
        number_of_texts += 1

    def average(total):
        # An empty collection has no defined average; report 0.0.
        return total / number_of_texts if number_of_texts else 0.0

    muc_p, b3_p, ceaf_p = [average(total) for total in precision_sums]
    muc_r, b3_r, ceaf_r = [average(total) for total in recall_sums]
    muc_f, b3_f, ceaf_f = [average(total) for total in f1_sums]

    if to_print:
        avg_p = (muc_p + b3_p + ceaf_p) / 3
        avg_r = (muc_r + b3_r + ceaf_r) / 3
        avg_f = (muc_f + b3_f + ceaf_f) / 3
        report = (
            ("MUC precision", muc_p),
            ("MUC recall", muc_r),
            ("MUC f1", muc_f),
            ("B3 precision", b3_p),
            ("B3 recall", b3_r),
            ("B3 f1", b3_f),
            ("CEAF precision", ceaf_p),
            ("CEAF recall", ceaf_r),
            ("CEAF f1", ceaf_f),
            ("Average precision", avg_p),
            ("Average recall", avg_r),
            ("Average f1", avg_f),
        )
        for label, value in report:
            print(label + " " + str(value))

    return muc_p, muc_r, muc_f, b3_p, b3_r, b3_f, ceaf_p, ceaf_r, ceaf_f