Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

训练结束后predict.py脚本获取不到标签问题 #3

Open
bree2eC opened this issue May 16, 2021 · 7 comments
Open

训练结束后predict.py脚本获取不到标签问题 #3

bree2eC opened this issue May 16, 2021 · 7 comments

Comments

@bree2eC
Copy link

bree2eC commented May 16, 2021

训练结束后predict.py脚本获取不到标签问题
知乎上也有很多人说predict.py脚本获取到的标签为空,其实不是训练数据有问题或者轮次不够,作者的get_label 函数逻辑有一些小小的问题,我这里简单修改了一下,可以成功获取到标签,新的predict.py 的get_label 函数如下:

def get_label(sentence):
    """
    Prediction of the sentence's label.
    """
    feature = get_feature_test(sentence)
    fd = {MODEL.albert.input_ids: [feature[0]],
          MODEL.albert.input_masks: [feature[1]],
          MODEL.albert.segment_ids:[feature[2]],
          }
    prediction = MODEL.sess.run(MODEL.albert.predictions, feed_dict=fd)[0]
    print(prediction)
    r=[]
    for i in range(len(prediction)):
        if prediction[i]!=0.0:
            r.append(id2label(i))
    return r
    #return [id2label(l) for l in np.where(prediction==1)[0] if l!=0]
@hellonlp
Copy link
Owner

建议打印prediction看看

@nomel0921
Copy link

训练结束后predict.py脚本获取不到标签问题 知乎上也有很多人说predict.py脚本获取到的标签为空,其实不是训练数据有问题或者轮次不够,作者的get_label 函数逻辑有一些小小的问题,我这里简单修改了一下,可以成功获取到标签,新的predict.py 的get_label 函数如下:

def get_label(sentence):
    """
    Prediction of the sentence's label.
    """
    feature = get_feature_test(sentence)
    fd = {MODEL.albert.input_ids: [feature[0]],
          MODEL.albert.input_masks: [feature[1]],
          MODEL.albert.segment_ids:[feature[2]],
          }
    prediction = MODEL.sess.run(MODEL.albert.predictions, feed_dict=fd)[0]
    print(prediction)
    r=[]
    for i in range(len(prediction)):
        if prediction[i]!=0.0:
            r.append(id2label(i))
    return r
    #return [id2label(l) for l in np.where(prediction==1)[0] if l!=0]

您好,按照您的修改之后我的输出都是第一个标签了...我看了下任何句子的prediction几乎相等,这是什么原因呢...

@lonkecxd
Copy link

请问你这个问题解决了吗?我也遇到了,prediction结果都是[]。

@hellonlp
Copy link
Owner

我的第一个标签是空标签"|",所以会不一样

@xtuyaowu
Copy link

大家有用起来么?要不要加个QQ好友交流一下?qq:1840658279

@rio5050
Copy link

rio5050 commented May 6, 2023

prediction打印出来全是0......

@hellonlp
Copy link
Owner

是正常的,说明预测结果为空

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

6 participants