
How are the evaluation metrics determined? #9

Closed
hanyc0914 opened this issue Dec 18, 2020 · 3 comments

Comments

@hanyc0914

How are the metrics in the table, such as recall and F1 score, computed?

@Ethan-yt
Owner

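After the competition ended, the test-set labels were released to us, and we wrote a script to compute the metrics. You can refer to the code: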

```python
import re

# Entity markup used in both files: {{label:word}} or {{label::word}}
ENTITY_PATTERN = r"{{(.*?)::?(.*?)}}"


def get_entities(lines):
    """Return (LABEL, start, end) tuples, with offsets counted in the
    tag-stripped text of all lines joined together."""
    result = []
    cur = 0  # running character offset in the tag-stripped text
    for line in lines:
        entities = []
        last_end = 0
        for m in re.finditer(ENTITY_PATTERN, line):
            label = m.group(1).upper()
            word = m.group(2)
            cur += m.start() - last_end  # untagged text before this tag
            last_end = m.end()
            if word:
                entities.append((label, cur, cur + len(word)))
                cur += len(word)
        cur += len(line) - last_end  # untagged text after the last tag
        result.extend(entities)
    return result


def getlines(path):
    """Read a UTF-8 file and drop empty lines."""
    with open(path, encoding="utf-8") as f:
        return [line for line in f.read().split("\n") if line]


def main():
    ground_truth_path = 'zs100w_0921_wyq_up.txt'
    pred_path = 'result.txt'

    ground_truth_lines = getlines(ground_truth_path)
    pred_lines = getlines(pred_path)

    # Both files must hold the same sentences; only the tags may differ
    assert len(ground_truth_lines) == len(pred_lines), "Files have different numbers of lines"
    for i, (gtl, pl) in enumerate(zip(ground_truth_lines, pred_lines)):
        gtl_no_label = re.sub(ENTITY_PATTERN, r'\2', gtl)
        pl_no_label = re.sub(ENTITY_PATTERN, r'\2', pl)
        assert gtl_no_label == pl_no_label, f"Different data in row {i}: \n{gtl} \n{pl}"

    # A predicted entity is correct only if its label, start and end all match
    true_entities = set(get_entities(ground_truth_lines))
    pred_entities = set(get_entities(pred_lines))

    # Overall scores across all labels
    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)
    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0
    print('P', p)
    print('R', r)
    print('F1', score)

    # Group spans by label for per-label scores
    true_entities_dict = {}
    for t, start, end in true_entities:
        true_entities_dict.setdefault(t, set()).add((start, end))

    pred_entities_dict = {}
    for t, start, end in pred_entities:
        pred_entities_dict.setdefault(t, set()).add((start, end))

    # Scores are reported for every ground-truth label; a label that never
    # appears in the predictions is treated as having no predicted spans
    labels = list(true_entities_dict)
    nb_correct_dict = {k: len(true_entities_dict[k] & pred_entities_dict.get(k, set())) for k in labels}
    nb_pred_dict = {k: len(pred_entities_dict.get(k, set())) for k in labels}
    nb_true_dict = {k: len(true_entities_dict[k]) for k in labels}

    p_dict = {k: nb_correct_dict[k] / nb_pred_dict[k] if nb_pred_dict[k] > 0 else 0 for k in labels}
    r_dict = {k: nb_correct_dict[k] / nb_true_dict[k] if nb_true_dict[k] > 0 else 0 for k in labels}
    score_dict = {k: 2 * p_dict[k] * r_dict[k] / (p_dict[k] + r_dict[k]) if p_dict[k] + r_dict[k] > 0 else 0
                  for k in labels}
    print('P', p_dict)
    print('R', r_dict)
    print('F1', score_dict)


if __name__ == '__main__':
    main()
```
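To make the matching rule concrete, here is a small, self-contained sketch that restates the span extraction for a single line and computes precision = correct / predicted, recall = correct / true, and F1 = 2PR / (P + R) on made-up sentences. The labels `person_name`, `location` and `organization` are hypothetical examples, not necessarily the competition's label set, and offsets here are relative to one line rather than the whole file.

```python
import re

# Same markup as the evaluation script: {{label:word}} or {{label::word}}
TAG = re.compile(r"{{(.*?)::?(.*?)}}")


def spans(line):
    """Set of (LABEL, start, end) offsets into `line` with the tags stripped."""
    out, cur, last_end = [], 0, 0
    for m in TAG.finditer(line):
        cur += m.start() - last_end  # untagged text before this tag
        last_end = m.end()
        word = m.group(2)
        if word:
            out.append((m.group(1).upper(), cur, cur + len(word)))
            cur += len(word)
    return set(out)


def micro_prf(true, pred):
    """Exact-match precision, recall and F1 over sets of spans."""
    correct = len(true & pred)
    p = correct / len(pred) if pred else 0.0
    r = correct / len(true) if true else 0.0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f1


gold = spans("{{person_name:张三}}去了{{location:北京}}")
missed = spans("{{person_name:张三}}去了北京")                     # one entity left untagged
mislabeled = spans("{{person_name:张三}}去了{{organization:北京}}")  # right span, wrong label

print(sorted(gold))                 # [('LOCATION', 4, 6), ('PERSON_NAME', 0, 2)]
print(micro_prf(gold, missed))      # (1.0, 0.5, 0.6666666666666666)
print(micro_prf(gold, mislabeled))  # (0.5, 0.5, 0.5)
```

The mislabeled case shows that a tag with the correct boundaries but the wrong label counts as both a false positive and a false negative.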

@hanyc0914
Author

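Great, thank you so much! One more question: the vocabulary size is 23292, but the output dimension of the network's last layer is 768. Computing the cross-entropy loss then fails with an error saying the label is out of range of the actual dimension. For example, with logits.shape = [32, 204, 768] and labels.shape = [32, 204], some values in labels are bound to be larger than 768. How can I fix this? And why isn't the network's last layer set to output 23292 dimensions?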
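The error comes from how cross-entropy indexes its input: it needs one score per class (here, per vocabulary entry) at each position, and every label must be a valid index into those scores. The pure-Python sketch below is a simplified stand-in for a framework loss, not any library's actual implementation; it computes the loss for a single position and rejects out-of-range labels. The token id 5000 is a made-up example.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one position: -log(softmax(logits)[label])."""
    num_classes = len(logits)
    if not 0 <= label < num_classes:
        # A vocabulary id cannot index a vector with fewer entries than the vocabulary
        raise IndexError(f"label {label} out of range for {num_classes} classes")
    m = max(logits)  # subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


label = 5000                  # hypothetical token id, valid for a 23292-entry vocabulary
hidden_state = [0.0] * 768    # per-position output of a model without an LM head
vocab_logits = [0.0] * 23292  # per-position output of a model with an LM head

try:
    cross_entropy(hidden_state, label)
except IndexError as e:
    print(e)  # label 5000 out of range for 768 classes

print(cross_entropy(vocab_logits, label))  # equals log(23292), since all logits are equal
```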

@Ethan-yt
Owner

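The model I uploaded is a transformers.RobertaForMaskedLM. If you load it with transformers.RobertaModel, the lm_head layer is discarded, so the output is the hidden states, with size equal to the hidden size. Use transformers.RobertaForMaskedLM instead; its final layer maps to the vocabulary dimension.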
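Conceptually, the head that RobertaModel drops is what turns each hidden state into one score per vocabulary entry. The sketch below shows only that final linear projection, with toy sizes (H = 4 instead of 768, V = 6 instead of 23292). In transformers' RobertaLMHead this projection is preceded by a dense layer, a GELU activation and LayerNorm, which are omitted here; the function name and weights are made up for illustration.

```python
def lm_head_decoder(hidden, weight, bias):
    """Map one hidden vector of size H to V vocabulary logits.

    `weight` is a V x H list of rows and `bias` has length V. This is only a
    simplified stand-in for the final projection of a masked-LM head.
    """
    if any(len(row) != len(hidden) for row in weight):
        raise ValueError("every weight row must have length H")
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(weight, bias)]


# Toy sizes (hypothetical): H = 4 instead of 768, V = 6 instead of 23292
H, V = 4, 6
hidden = [1.0, 0.0, -1.0, 2.0]
# Row v copies hidden[v] for v < H; the remaining rows are all zeros
weight = [[float(v == i) for i in range(H)] for v in range(V)]
bias = [0.0] * V

logits = lm_head_decoder(hidden, weight, bias)
print(len(hidden), "->", len(logits))  # 4 -> 6
print(logits)                          # [1.0, 0.0, -1.0, 2.0, 0.0, 0.0]
```

After this projection, each position has V scores, so any label in range(V) is a valid target for the loss.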

For details, see the Hugging Face documentation:

https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForMaskedLM
