In [20]:
import torch
import pickle

In [21]:
# Pick the inference device. GPU index 6 is the preferred card, but it only
# exists on hosts with at least 7 visible GPUs; on smaller hosts fall back to
# the first GPU instead of failing later when tensors are moved to a
# non-existent device. Without CUDA, run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:6" if torch.cuda.device_count() > 6 else "cuda:0")
else:
    DEVICE = torch.device("cpu")

In [22]:
# Load the serialized NER model and place it on DEVICE.
# `map_location` remaps the checkpoint's storages while unpickling, so a file
# saved on a GPU that is absent on this host (or on a CPU-only host) still loads.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules -- confirm the installed version and pass
# `weights_only=False` explicitly if this file stores the whole module.
model = torch.load("NERModel.pth", map_location=DEVICE).to(DEVICE)

In [43]:
# Restore the pickled vocabulary object from disk.
# SECURITY: pickle.load can execute arbitrary code; only open trusted files.
with open("vocab.pkl", mode="rb") as vocab_file:
    vocab = pickle.load(vocab_file)

# String-to-index mapping exposed by the vocabulary; used to encode model input.
word2idx = vocab.get_stoi()

In [48]:
# Label inventory: "O" plus B-/I- tags for PER, ORG and LOC, and a trailing
# "PAD" label. A tag's id is its position in this sequence (0..7).
tag2id = dict(
    zip(
        ("O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "PAD"),
        range(8),
    )
)

# Inverse lookup (id -> tag string) used to decode model predictions.
id2tag = dict(zip(tag2id.values(), tag2id.keys()))

In [50]:
def predict(text, unk_token="<unk>"):
    """Tag every token of ``text`` with the loaded NER model.

    ``text`` is iterated element by element, so a plain string is tagged one
    character at a time; any sequence of tokens works as well.

    Args:
        text: String (or sequence of tokens) to tag.
        unk_token: Vocabulary entry substituted for tokens missing from
            ``word2idx``. Assumes the vocabulary marks unknown tokens with
            this string -- TODO confirm against how ``vocab.pkl`` was built.
            If the entry is absent, no substitution happens.

    Returns:
        A list of tag strings, one per id in the model's first (only) batch
        row. Empty input returns ``[]`` without calling the model.

    Raises:
        KeyError: A token is not in the vocabulary and ``unk_token`` is not
            in the vocabulary either.
    """
    # Nothing to tag: avoid sending a zero-length sequence through the model.
    if not text:
        return []

    unk_idx = word2idx.get(unk_token)
    token_ids = []
    for token in text:
        idx = word2idx.get(token, unk_idx)
        if idx is None:
            raise KeyError(
                f"token {token!r} is not in the vocabulary and no "
                f"{unk_token!r} fallback entry exists"
            )
        token_ids.append(idx)

    model.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        # Batch of one sequence: shape (1, len(text)).
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
        valid_lens = torch.tensor(
            [input_ids.size(1)], dtype=torch.long, device=DEVICE
        )
        # The model returns a pair; only the first element (tag ids) is used.
        tag_ids, _ = model(input_ids, valid_lens)

    # int() makes the lookup work whether ids come back as Python ints or as
    # 0-d tensors (a tensor does not hash like the int keys of id2tag).
    return [id2tag[int(tag_id)] for tag_id in tag_ids[0]]

In [63]:
# Sample Chinese sentence used to exercise predict(); roughly: "My name is
# Zhang San, I am a graduate student at UCAS, at Yanqi Lake."
# NOTE(review): the name `input` shadows Python's built-in input() for the
# rest of the kernel session; consider renaming it (together with every cell
# that reads it) to something like `sample_text`.
input = "我叫张三，我在国科大读研，在雁栖湖。"

In [64]:
# Tag the sample sentence. predict() iterates over the string, so the result
# holds one tag per character (punctuation included).
predict(input)

['O',
 'O',
 'B-PER',
 'I-PER',
 'O',
 'O',
 'O',
 'B-ORG',
 'I-ORG',
 'I-ORG',
 'O',
 'O',
 'O',
 'O',
 'B-LOC',
 'I-LOC',
 'I-LOC',
 'O']