# In [6]:
from rouge import Rouge
from nltk.translate.bleu_score import corpus_bleu
from collections import Counter

# In [7]:
# Для подсчета метрики METEOR необходимо скачать .jar-файл:
# https://mvnrepository.com/artifact/edu.cmu/Meteor/1.5
# (по умолчанию calc_method_score ищет его по пути meteor-1.5/meteor-1.5.jar)

def calc_method_score(records, predict_func, meteor_jar="meteor-1.5/meteor-1.5.jar", nrows=None):
    """Generate a summary for every record and print the full metric report.

    Args:
        records: iterable (e.g. a Dataset) of records with a "summary" key
            (the reference summary) and a "text" key (the source text).
        predict_func: callable ``predict_func(text, summary)`` that returns the
            generated summary for the source text.
        meteor_jar: path to the METEOR 1.5 jar, forwarded to ``print_metrics``.
        nrows: if not None, only the first ``nrows`` records are evaluated
            (same meaning as in ``calc_bert_score``). Default: all records.
    """
    references = []
    predictions = []
    for index, record in enumerate(records):
        if nrows is not None and index >= nrows:
            break
        references.append(record["summary"])
        predictions.append(predict_func(record["text"], record["summary"]))

    # Normalise each (reference, prediction) pair with the shared ``postprocess``
    # helper (defined elsewhere), requesting tokenization and lower-casing.
    for index, (ref, hyp) in enumerate(zip(references, predictions)):
        references[index], predictions[index] = postprocess(ref, hyp, tokenize_after=True, lower=True)
    print_metrics(references, predictions, meteor_jar=meteor_jar)


def calc_bert_score(records, predict_func, nrows=None):
    """Generate summaries for (up to ``nrows``) records and print the BERTScore report.

    Args:
        records: iterable of records with "text" and "summary" keys.
        predict_func: callable ``predict_func(text, summary)`` returning the
            generated summary.
        nrows: if not None, stop after the first ``nrows`` records.

    NOTE(review): ``calc_metrics`` calls ``calc_bert_score(hyps, refs)`` and
    unpacks ``(scores, hash_code)`` from it. Because this function has the same
    name, that call resolves here (records=hyps, predict_func=refs) instead of to
    a BERTScore helper. Consider renaming one of them — TODO confirm intended helper.
    """
    # Phase 1: run the model over the selected records, keeping (gold, generated) pairs.
    pairs = []
    for idx, rec in enumerate(records):
        if nrows is not None and idx >= nrows:
            break
        pairs.append((rec["summary"], predict_func(rec["text"], rec["summary"])))

    # Phase 2: postprocess every pair without tokenization or lower-casing.
    references, predictions = [], []
    for gold, generated in pairs:
        clean_ref, clean_hyp = postprocess(gold, generated, tokenize_after=False, lower=False)
        references.append(clean_ref)
        predictions.append(clean_hyp)

    print_metrics(references, predictions, meteor_jar=None, metric="bert_score")

def print_metrics(refs, hyps, metric="all", meteor_jar=None):
    """Compute metrics via ``calc_metrics`` and print them as a plain-text report.

    BLEU, chrF, ROUGE F, METEOR and BERTScore F values are multiplied by 100 and
    printed with one decimal; the average length is printed unscaled. Only the
    metrics present in the result dict are printed.

    Args:
        refs: reference summaries, forwarded to ``calc_metrics``.
        hyps: generated summaries, forwarded to ``calc_metrics``.
        metric: which metric(s) to compute, forwarded to ``calc_metrics``.
        meteor_jar: path to the METEOR jar, forwarded to ``calc_metrics``.
    """
    scores = calc_metrics(refs, hyps, metric=metric, meteor_jar=meteor_jar)

    print("-------------METRICS-------------")
    for label, key in (("Count:\t", "count"), ("Ref:\t", "ref_example"), ("Hyp:\t", "hyp_example")):
        print(label, scores[key])

    # (presence key, rows) in report order; each row is (format string, extractor).
    # A group is printed only when its presence key exists in ``scores``.
    report_layout = (
        ("bleu", (("BLEU:     \t{:3.1f}", lambda s: s["bleu"] * 100.0),)),
        ("chrf", (("chrF:     \t{:3.1f}", lambda s: s["chrf"] * 100.0),)),
        ("rouge-1", (
            ("ROUGE-1-F:\t{:3.1f}", lambda s: s["rouge-1"]['f'] * 100.0),
            ("ROUGE-2-F:\t{:3.1f}", lambda s: s["rouge-2"]['f'] * 100.0),
            ("ROUGE-L-F:\t{:3.1f}", lambda s: s["rouge-l"]['f'] * 100.0),
        )),
        ("meteor", (("METEOR:   \t{:3.1f}", lambda s: s["meteor"] * 100.0),)),
        ("length", (("Avg length:\t{:3.1f}", lambda s: s["length"]),)),
    )
    for presence_key, rows in report_layout:
        if presence_key not in scores:
            continue
        for template, extract in rows:
            print(template.format(extract(scores)))

    # calc_metrics stores BERTScore under "bert_score_<hash>" keys; print each one.
    for name in [k for k in scores if "bert_score" in k]:
        print("{}:\t{:3.1f}".format(name, scores[name]["f"] * 100.0))

def calc_metrics(refs, hyps, metric="all", meteor_jar=None):
    """Compute summarization metrics for aligned reference/hypothesis lists.

    Args:
        refs: reference summaries, one per hypothesis. Items are strings; for the
            BLEU/METEOR paths an item may also be a list of alternative references.
        hyps: generated summaries (strings), aligned with ``refs``.
        metric: "all" or one of "bleu", "rouge", "meteor", "chrf", "length",
            "bert_score". "bert_score" is only computed when requested explicitly
            (it is not part of "all") and ``torch.cuda.is_available()``.
        meteor_jar: path to the METEOR jar; METEOR is skipped when it is None or
            the file does not exist.

    Returns:
        dict with "count", "ref_example" and "hyp_example" (the last pair, or None
        for empty input) plus one entry per computed metric.
    """
    metrics = dict()
    metrics["count"] = len(hyps)
    metrics["ref_example"] = refs[-1] if refs else None
    metrics["hyp_example"] = hyps[-1] if hyps else None
    # Nothing to score; previously refs[-1] raised IndexError here.
    if not hyps or not refs:
        return metrics

    # Wrap single references in a list; keep existing lists of references as-is.
    # (The old check ``r is not list`` compared identity with the type and was
    # always True, so lists were double-wrapped.)
    many_refs = [r if isinstance(r, list) else [r] for r in refs]
    if metric in ("bleu", "all"):
        t_hyps = [hyp.split(" ") for hyp in hyps]
        t_refs = [[r.split(" ") for r in rs] for rs in many_refs]
        metrics["bleu"] = corpus_bleu(t_refs, t_hyps)
    if metric in ("rouge", "all"):
        rouge = Rouge()
        scores = rouge.get_scores(hyps, refs, avg=True)
        metrics.update(scores)
    if metric in ("meteor", "all") and meteor_jar is not None and os.path.exists(meteor_jar):
        meteor = Meteor(meteor_jar, language='ru')
        metrics["meteor"] = meteor.compute_score(hyps, many_refs)
    if metric in ("bert_score",) and torch.cuda.is_available():
        # NOTE(review): this resolves the module-level name ``calc_bert_score``,
        # which this notebook rebinds to the dataset driver
        # ``calc_bert_score(records, predict_func, nrows=None)``; that function
        # expects dict records and returns None, so it cannot provide
        # ``(scores, hash_code)``. Presumably a BERTScore helper imported elsewhere
        # was intended — rename one of the two. TODO confirm.
        bert_scores, hash_code = calc_bert_score(hyps, refs)
        metrics["bert_score_{}".format(hash_code)] = bert_scores
    if metric in ("chrf", "all"):
        metrics["chrf"] = corpus_chrf(refs, hyps, beta=1.0)
    if metric in ("length", "all"):
        # Average hypothesis length in characters (len of each string).
        metrics["length"] = mean([len(h) for h in hyps])
    return metrics