In [18]:
import os
import time
import sklearn
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from sklearn import metrics


from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

# Device used for inference.
device = torch.device("cpu")

# Batch / loader settings (train_batch_size and data_workers are read by the
# DataLoader built in the input cell).
train_batch_size = 32
valid_batch_size = train_batch_size
test_batch_size = train_batch_size
data_workers = 2

# Optimisation hyper-parameters. NOTE(review): none of these are read by the
# cells in this notebook; they look like leftovers from the training script
# that produced the checkpoint.
max_train_epochs = 5
warmup_proportion = 0.05
gradient_accumulation_steps = 1
learning_rate = 2e-5
weight_decay = 0.01
max_grad_norm = 1.0

# Timestamp string (e.g. for naming run artefacts); unused by the cells shown here.
cur_time = time.strftime("%Y-%m-%d_%H:%M:%S")

# Fine-tuned state dict for BertForSeqTagging. The default is an absolute path
# on the original author's machine; point the notebook elsewhere without
# editing it by setting the ADDRESS_PARSER_MODEL_PATH environment variable.
model_path = os.environ.get(
    "ADDRESS_PARSER_MODEL_PATH",
    '/home/zhy/anaconda3/envs/trans/trans-p/Neural_Chinese_Address_Parsing_BERT_state_dict.pkl',
)

from transformers import BertConfig, BertTokenizer, BertModel, BertForTokenClassification
# Special tokens of the BERT vocabulary, spelled out so the dataset code can
# assemble token sequences by hand.
cls_token, eos_token, unk_token, pad_token, mask_token = (
    '[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]'
)

# Tokenizer and architecture config of the Chinese BERT checkpoint, both
# resolved by the same pretrained identifier.
pretrained_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
config = BertConfig.from_pretrained(pretrained_name)

# Generic aliases so the model class can be written independently of the
# concrete BERT classes.
TheModel, ModelForTokenClassification = BertModel, BertForTokenClassification


In [19]:
# BIO tag inventory; position i is the tag for classifier output unit i.
# NOTE(review): 'subRoad'/'subRoadno' and 'subroad'/'subroadno' are distinct
# tags here. The order presumably has to match the trained checkpoint, so do
# not reorder or merge entries without retraining.
labels = [
    'B-assist', 'I-assist', 'B-cellno', 'I-cellno', 'B-city', 'I-city',
    'B-community', 'I-community', 'B-country', 'I-country', 'B-devZone', 'I-devZone',
    'B-district', 'I-district', 'B-floorno', 'I-floorno', 'B-houseno', 'I-houseno',
    'B-otherinfo', 'I-otherinfo', 'B-person', 'I-person', 'B-poi', 'I-poi',
    'B-prov', 'I-prov', 'B-redundant', 'I-redundant', 'B-road', 'I-road',
    'B-roadno', 'I-roadno', 'B-roomno', 'I-roomno', 'B-subRoad', 'I-subRoad',
    'B-subRoadno', 'I-subRoadno', 'B-subpoi', 'I-subpoi', 'B-subroad', 'I-subroad',
    'B-subroadno', 'I-subroadno', 'B-town', 'I-town',
]
# Reverse lookup: tag string -> class index.
label2id = {tag: idx for idx, tag in enumerate(labels)}
num_labels = len(labels)

In [20]:
class BertForSeqTagging(ModelForTokenClassification):
    """BERT encoder plus a linear head that assigns one tag per token.

    Built from the module-level ``config`` / ``num_labels`` and the
    ``TheModel`` / ``ModelForTokenClassification`` aliases. ``forward`` keeps
    only positions whose ``attention_mask`` is 1, so its outputs are flattened
    across the batch.
    """

    def __init__(self):
        super().__init__(config)
        self.num_labels = num_labels
        # Use the pretrained Chinese BERT encoder in place of the one the
        # parent constructor created.
        self.bert = TheModel.from_pretrained('bert-base-chinese')
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        # Head sized for this notebook's tag inventory.
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        # NOTE(review): depending on the transformers version, init_weights()
        # may also re-initialise the encoder weights loaded just above. The
        # notebook loads a full fine-tuned state dict after construction, so
        # inference is unaffected; confirm before reusing this for training.
        self.init_weights()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the tagger.

        Args:
            input_ids: token ids, assumed (batch, seq_len) LongTensor.
            attention_mask: 1 for real tokens (including [CLS]/[SEP]) and 0 for
                padding; same shape as ``input_ids``.
            labels: optional tag ids with the same shape; -100 marks positions
                the loss must ignore.

        Returns:
            The cross-entropy loss over active positions when ``labels`` is
            given. Otherwise, logits of shape (num_active_tokens, num_labels),
            where the active tokens of each row are concatenated in batch order.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        # reshape rather than view: the collate function slices the padded
        # batch tensors, which can leave them non-contiguous, and .view()
        # raises on non-contiguous input.
        active_loss = attention_mask.reshape(-1) == 1
        active_logits = logits.reshape(-1, self.num_labels)[active_loss]

        if labels is None:
            return active_logits
        # CrossEntropyLoss ignores target -100 by default, so the [CLS]/[SEP]
        # positions (labelled -100 by the dataset) do not contribute.
        active_labels = labels.reshape(-1)[active_loss]
        return torch.nn.CrossEntropyLoss()(active_logits, active_labels)
        


In [21]:
def eval(data_loader=None, examples=None):
    """Tag every example produced by ``data_loader``; return per-character rows.

    NOTE: this shadows the built-in ``eval``. The name is kept because later
    cells call it.

    Args:
        data_loader: loader yielding ``(input_ids, attention_mask, label, indexs)``
            batches as built by ``the_collate_fn``. Defaults to the global
            ``test_data_loader``.
        examples: the list that the batch indices refer to; ``examples[idx][0]``
            is the character list. Defaults to the global ``test_list``.

    Returns:
        A flat list with one ``[char, gold_tag, predicted_tag]`` row per
        character, plus an empty list ``[]`` after each example. The first 20
        rows are also printed.
    """
    if data_loader is None:
        data_loader = test_data_loader
    if examples is None:
        examples = test_list

    # Inference mode: disables dropout so that predictions are deterministic.
    model.eval()

    result = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            input_ids, attention_mask, label = (b.to(device) for b in batch[:-1])
            # The model returns logits only for positions where mask == 1,
            # concatenated row by row; split them back into one chunk per example.
            logits = model(input_ids, attention_mask).cpu()
            lengths = attention_mask.sum(dim=1).tolist()
            assert sum(lengths) == len(logits)
            for i, example_logits in enumerate(torch.split(logits, lengths)):
                # argmax over raw logits equals argmax over their softmax.
                predicted = example_logits.argmax(dim=1).tolist()
                gold = label[i].tolist()
                # j + 1 skips the leading [CLS] position.
                for j, char in enumerate(examples[batch[-1][i]][0]):
                    result.append([char, labels[gold[j + 1]], labels[predicted[j + 1]]])
                result.append([])
    print(result[:20])
    return result

In [22]:
class MyDataSet(torch.utils.data.Dataset):
    """Turns ``(sentence, label)`` examples into fixed-length BERT inputs.

    Each item is ``[CLS] + sentence + [SEP]``, right-padded with ``[PAD]`` to
    ``max_len + 2`` tokens.

    Args:
        examples: sequence of ``(sentence, label)`` pairs. ``sentence`` is a
            list of tokens (one per character) and ``label`` is a list of tag
            ids of the same length.
        max_len: sentence length to pad to. ``None`` (the default) reads the
            global ``max_token_len`` when an item is accessed, as the
            original code did.
    """

    def __init__(self, examples, max_len=None):
        self.examples = examples
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, total_len, label, index)``.

        ``total_len`` is the unpadded length including [CLS] and [SEP]; the
        collate function uses it to trim padding. Special and padding
        positions get label -100.

        Raises:
            ValueError: if the sentence is longer than the padding length, or
                if the label list does not match the sentence length.
        """
        example = self.examples[index]
        sentence, label = example[0], example[1]
        max_len = max_token_len if self.max_len is None else self.max_len

        if len(sentence) > max_len:
            raise ValueError(
                f"example {index}: sentence has {len(sentence)} tokens "
                f"but max_len is {max_len}"
            )
        if len(label) != len(sentence):
            raise ValueError(
                f"example {index}: {len(label)} labels for {len(sentence)} tokens"
            )

        pad_len = max_len - len(sentence)
        total_len = len(sentence) + 2

        input_token = [cls_token] + sentence + [eos_token] + [pad_token] * pad_len
        input_ids = tokenizer.convert_tokens_to_ids(input_token)
        attention_mask = [1] * total_len + [0] * pad_len
        label = [-100] + label + [-100] + [-100] * pad_len
        assert max_len + 2 == len(input_ids) == len(attention_mask) == len(input_token) == len(label)

        return input_ids, attention_mask, total_len, label, index
    
def the_collate_fn(batch):
    """Stack ``MyDataSet`` items into tensors trimmed to the longest item.

    Args:
        batch: list of ``(input_ids, attention_mask, total_len, label, index)``
            tuples as returned by ``MyDataSet.__getitem__``.

    Returns:
        ``(input_ids, attention_mask, label, indexs)``. The first three are
        contiguous LongTensors of shape (len(batch), max total_len);
        ``indexs`` is the list of dataset indices.
    """
    total_len = max(item[2] for item in batch)
    # Trim each list *before* building the tensor. Slicing a full-width tensor
    # instead (tensor[:, :total_len]) yields a non-contiguous view whenever
    # padding is actually removed, and that breaks .view() calls downstream.
    input_ids = torch.LongTensor([item[0][:total_len] for item in batch])
    attention_mask = torch.LongTensor([item[1][:total_len] for item in batch])
    label = torch.LongTensor([item[3][:total_len] for item in batch])
    indexs = [item[4] for item in batch]
    return input_ids, attention_mask, label, indexs

In [23]:
# Build the tagger and load the fine-tuned weights.
model = BertForSeqTagging()
# map_location lets a checkpoint that was saved on a GPU load on `device`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)
load_result = model.load_state_dict(state_dict)
model.to(device)
# Inference only: turn off dropout so that predictions are deterministic.
model.eval()
load_result

<All keys matched successfully>

In [24]:
# Read one address to parse. Pressing Enter without typing uses the example.
example_address = "中华人民共和国河北省廊坊市市河西区展览路街道美丽小区1"
cent = input("Enter any value: ").strip() or example_address

# The model tags one token per character, so split the address into
# characters. Whitespace is dropped because it is not a tagged character.
i2 = [ch for ch in cent if not ch.isspace()]
# Placeholder "gold" label ids, one per character. Id 1 is the value the
# tokenizer's attention mask used to supply. They are built per character
# because that mask's length follows WordPiece segmentation, which can differ
# from the character count (e.g. for runs of digits or Latin letters).
a2 = [1] * len(i2)

test_list = [[i2, a2]]
test_dataset = MyDataSet(test_list)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=data_workers,
    collate_fn=the_collate_fn,
)
print(test_list)



[[['中', '华', '人', '民', '共', '和', '国', '河', '北', '省', '廊', '坊', '市', '市', '河', '西', '区', '展', '览', '路', '街', '道', '美', '丽', '小', '区'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]


In [25]:
# Pad every example to the longest sentence in test_list (MyDataSet reads this
# global when it builds items), then run inference over the loader.
max_token_len = max((len(example[0]) for example in test_list), default=0)
print('max_token_len', max_token_len)
result = eval()


max_token_len 26


100%|██████████| 1/1 [00:00<00:00,  6.22it/s]

[['中', 'I-assist', 'B-country'], ['华', 'I-assist', 'I-country'], ['人', 'I-assist', 'I-prov'], ['民', 'I-assist', 'I-prov'], ['共', 'I-assist', 'I-prov'], ['和', 'I-assist', 'I-prov'], ['国', 'I-assist', 'I-prov'], ['河', 'I-assist', 'B-prov'], ['北', 'I-assist', 'I-prov'], ['省', 'I-assist', 'I-prov'], ['廊', 'I-assist', 'B-city'], ['坊', 'I-assist', 'I-city'], ['市', 'I-assist', 'I-city'], ['市', 'I-assist', 'I-city'], ['河', 'I-assist', 'B-district'], ['西', 'I-assist', 'I-district'], ['区', 'I-assist', 'I-district'], ['展', 'I-assist', 'B-town'], ['览', 'I-assist', 'I-town'], ['路', 'I-assist', 'I-town']]





In [26]:
# Pretty-print the prediction: each B- tag starts a new "<entity> :" heading,
# and the characters of that span are printed on the line below it.
for row in result:
    if not row:
        # Empty rows separate examples (eval() appends one after each example);
        # end the current line instead of trying to unpack them.
        print()
        continue
    char, _gold, pred = row
    if pred.startswith("B"):
        print('')
        print(pred[2:], ':')
    print(char, end='')



country :
中华人民共和国
prov :
河北省
city :
廊坊市市
district :
河西区
town :
展览路街道
poi :
美丽小区