In [1]:
from configs import PipelineConfigs
import torch
from utils import text_process, seg_char

In [2]:
# Build the pipeline objects and restore the trained weights from disk.
configs = PipelineConfigs()
# map_location="cpu" lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine; load_state_dict() then copies the values into the model's
# existing parameters, so the model's own device placement is unchanged.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load("./model.pt", map_location="cpu")
configs.model.load_state_dict(state_dict)
model = configs.model
# This notebook only runs inference: switch off dropout so that repeated
# predictions on the same sentence are deterministic.
model.eval()
tokenizer = configs.tokenizer

In [3]:
def bert_pred(model, tokenizer, sentence, verbose=True):
    """Predict one label id per character of ``sentence``.

    The sentence is normalised with ``text_process``, split into characters
    with ``seg_char``, tokenized as pre-split words, and fed through ``model``.
    The forward pass runs in eval mode and without gradient tracking; the
    model's previous train/eval mode is restored afterwards.

    Args:
        model: ``torch.nn.Module`` called as ``model(input_ids, attention_mask)``
            and returning scores of shape ``[batch_size, src_len, 4]``.
        tokenizer: callable tokenizer accepting ``return_tensors`` and
            ``is_split_into_words`` and providing ``convert_ids_to_tokens``.
        sentence: raw input text.
        verbose: when True (the default, matching the previous behaviour),
            print the tokenizer output, the score tensor size and the
            predicted label ids.

    Returns:
        ``(char_list, output)`` where ``char_list`` is the list of tokens
        recovered from the input ids and ``output`` is a nested list
        (one inner list per batch row) of predicted label ids. Both have the
        first and last positions removed -- presumably the special
        [CLS]/[SEP] tokens added by the tokenizer; confirm against the
        tokenizer configuration.
    """
    sentence = text_process(sentence, is_tradition=True)
    char_list = seg_char(sentence)
    inputs = tokenizer(
        char_list,
        return_tensors="pt",
        is_split_into_words=True
    )
    if verbose:
        print(inputs)

    # Inference only: disable dropout and autograd bookkeeping, then put the
    # model back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(inputs["input_ids"], inputs["attention_mask"])
    finally:
        model.train(was_training)

    if verbose:
        print(output.size())    # [batch_size, src_len, 4]
    # Highest-scoring label per position, dropping the first and last position.
    output = torch.argmax(output, dim=-1)[:, 1: -1].tolist()
    if verbose:
        print(output)
    # Read the tokens back from the ids so they line up with the predictions.
    char_list = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])[1: -1]
    return char_list, output

In [6]:
# Segment a single example sentence with the loaded model and tokenizer.
sentence = "迈向充满希望的新世纪"
char_list, preds = bert_pred(model=model, tokenizer=tokenizer, sentence=sentence)



{'input_ids': tensor([[ 101, 6815, 1403, 1041, 4007, 2361, 3307, 4638, 3173,  686, 5279,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
torch.Size([1, 12, 4])
[[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]


In [5]:
# Show the tokens and their predicted label ids, one per line.
print(char_list, preds, sep="\n")

['今', '天', '是', '个', '好', '日', '子']
[[3, 3, 3, 3, 3, 3, 3]]


In [7]:
# Dump the module tree of the loaded model to inspect its layers.
print(model)

BertCWS(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True