How do I use the trained model to make predictions? #16
Same question.
Reference function: `./model.py/dev(data_loader, vocab, model, device, mode='dev')`
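Reference code:

```python
import torch

# Adapted from dev() in ./model.py. model, vocab and device come from your
# training setup; data_loader yields (sentences, labels, masks, lens) batches.
model.eval()
true_tags = []
pred_tags = []
sent_data = []
with torch.no_grad():  # inference only, no gradient tracking
    for sentences, labels, masks, lens in data_loader:
        # Recover each sentence's tokens, skipping padded positions (mask == 0)
        sent_data.extend([[vocab.id2word.get(tok.item()) for i, tok in enumerate(indices) if mask[i] > 0]
                          for mask, indices in zip(masks, sentences)])
        sentences = sentences.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        # Emission scores from the network, then the best tag sequence from the CRF layer
        y_pred = model(sentences)
        labels_pred = model.crf.decode(y_pred, mask=masks)
        # Trim the gold labels to each sentence's true length
        targets = [itag[:ilen] for itag, ilen in zip(labels.cpu().numpy(), lens)]
        true_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in targets])
        pred_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in labels_pred])
print(f'Predicted tags: {pred_tags}')
```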
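If it helps to see what `model.crf.decode` computes: assuming the CRF layer is `CRF` from the pytorch-crf package (which this call signature suggests, but the thread does not confirm), `decode` runs Viterbi decoding and returns, for each sentence, the tag-index sequence with the highest total of start, emission, transition and end scores. The following is a simplified single-sentence, pure-Python sketch of that algorithm for illustration only; `viterbi_decode` and its argument layout are hypothetical and not part of this project or of pytorch-crf.

```python
def viterbi_decode(emissions, transitions, start, end):
    """Return the highest-scoring tag sequence for one sentence (illustrative sketch).

    emissions:   emissions[t][j] = score of tag j at position t
    transitions: transitions[i][j] = score of moving from tag i to tag j
    start, end:  per-tag scores for the first and the last position
    """
    num_tags = len(start)
    # score[j]: best total score of any path ending in tag j at the current position
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, best_prev = [], []
        for j in range(num_tags):
            # Best previous tag i for arriving at tag j at position t
            i_best = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            best_prev.append(i_best)
            new_score.append(score[i_best] + transitions[i_best][j] + emissions[t][j])
        score = new_score
        backpointers.append(best_prev)
    # Add the end scores, pick the best final tag, then follow the backpointers
    last = max(range(num_tags), key=lambda j: score[j] + end[j])
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path


# Toy example: 2 tags, 3 positions
print(viterbi_decode([[2.0, 0.0], [0.0, 1.0], [0.0, 3.0]],
                     [[0.5, -1.0], [-2.0, 0.5]],
                     [0.0, 0.0], [0.0, 0.0]))  # [0, 1, 1]
```

The returned indices are the kind of output the loop above maps to tag names with `vocab.id2label`.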
Thanks a lot!!! Trying it right away 💪 (email reply to ZhouYaFei's comment of May 10, 2024, 11:10)