In [1]:
# # # #
#   - A pretrained BERT with CRF model.
# # # #

# %%
import sys
import os
import importlib
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils import data

from transformers import BertModel
from transformers.modeling_bert import BertLayerNorm
import pickle
#from transformers import BertAdam, warmup_linear
from transformers import BertTokenizer

def set_work_dir(local_path="bert-crf-company", server_path="bert-crf-company"):
    """Change the working directory to the project folder next to the current one.

    Each candidate name is resolved against the parent of the current working
    directory; the first one that exists becomes the new working directory.

    Args:
        local_path: Folder name tried first.
        server_path: Folder name tried when ``local_path`` does not exist.

    Raises:
        Exception: If neither candidate exists under the parent directory.
    """
    parent_dir = os.path.abspath('..')
    for candidate in (local_path, server_path):
        target = parent_dir + '/' + candidate
        if os.path.exists(target):
            os.chdir(target)
            return
    raise Exception('Set work path error!')


def get_data_dir(local_path="bert-crf-company/data", server_path="bert-crf-company/data"):
    """Locate the data folder relative to the parent of the working directory.

    Args:
        local_path: Relative folder tried first.
        server_path: Relative folder tried when ``local_path`` does not exist.

    Returns:
        The absolute path (parent directory + '/' + candidate) of the first
        candidate that exists.

    Raises:
        Exception: If neither candidate exists.
    """
    parent_dir = os.path.abspath('..')
    for candidate in (local_path, server_path):
        full_path = parent_dir + '/' + candidate
        if os.path.exists(full_path):
            return full_path
    raise Exception('get data path error!')


# ---- Environment report -----------------------------------------------------
print('Python version ', sys.version)
print('PyTorch version ', torch.__version__)

# Switch into the project folder before resolving any relative path below
# (output_dir and bert_model_scale are both relative to it).
set_work_dir()
print('Current dir:', os.getcwd())

# Joining with '' guarantees data_dir ends with a path separator.
data_dir = os.path.join(get_data_dir(), '')

# ---- Run configuration ------------------------------------------------------
max_seq_length = 180 #256
batch_size = 32 #32

cuda_yes = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_yes else "cpu")

# If you want to use your own trained model, change this directory.
output_dir = './output/'   ###    output

# Only a warning: the later loading cell will still fail if the checkpoint is missing.
if not os.path.exists(os.path.join(output_dir, 'ner_bert_crf_checkpoint.pt')):
    print('【No trained model!】')

# Local directory holding the pretrained BERT weights and vocabulary.
bert_model_scale = './bert-base-chinese'
do_lower_case = False


Python version  3.8.6 | packaged by conda-forge | (default, Dec 26 2020, 05:05:16) 
[GCC 9.3.0]
PyTorch version  1.7.1
Current dir: /root/mao/249/bert-crf-company


In [2]:
# %%
'''
Functions and Classes for read and organize data set
'''

class InputExample(object):
    """One raw NER example: a token sequence and its per-token labels."""

    def __init__(self, guid, words, labels):
        """Store the example fields unchanged.

        Args:
          guid: Identifier of the example.
          words: List of words (tokens) making up the sentence.
          labels: Label sequence stored alongside ``words``.
        """
        self.guid, self.words, self.labels = guid, words, labels


class InputFeatures(object):
    """Model-ready features for one example.

    Plain container for the five parallel sequences produced when an
    InputExample is converted into model input.
    """
    def __init__(self, input_ids, input_mask, segment_ids,  predict_mask, label_ids):
        # All five fields are stored as given; no copying or validation.
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.predict_mask, self.label_ids = predict_mask, label_ids


class DataProcessor(object):
    """Base class for data converters for sequence labelling data sets."""
    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data_demo(cls, input_file):
        """Read an unlabelled text file, one sentence per line.

        Every line is split into single characters and each character is given
        the placeholder label 'O'.

        Args:
            input_file: Path of the text file. It is decoded as UTF-8 so the
                result does not depend on the platform's default encoding.

        Returns:
            A list with one ``[words, ner_labels]`` pair per line, where both
            members are lists of equal length.
        """
        with open(input_file, encoding='utf-8') as f:
            # Leading/trailing whitespace of the whole file is dropped before splitting.
            entries = f.read().strip().splitlines()
        return [[list(line), ['O'] * len(line)] for line in entries]


In [3]:
class CompanyDataProcessor(DataProcessor):
    """Data processor for the company-name NER task.

    Holds the fixed label inventory and builds `InputExample` objects from the
    demo text file.
    """

    def __init__(self):
        # The list position of each label is its integer id.
        self._label_types = [ 'X', '[CLS]', '[SEP]', 'O', 'I-LOC', 'B-LOC', 'I-NAME', 'I-INDU', 'I-ORG', 'B-NAME', 'B-INDU', 'B-ORG']
        self._num_labels = len(self._label_types)
        self._label_map = {label: i for i,
                           label in enumerate(self._label_types)}

    def get_demo_examples(self, data_dir):
        """Build examples from ``demo.txt`` inside ``data_dir``."""
        return self._create_examples(
            self._read_data_demo(os.path.join(data_dir, "demo.txt")))

    def get_labels(self):
        """Return the list of label strings (index == label id)."""
        return self._label_types

    def get_num_labels(self):
        """Return the number of labels."""
        # Fixed: this used to return the bound method itself instead of the count.
        return self._num_labels

    def get_label_map(self):
        """Return the mapping from label string to label id."""
        return self._label_map

    def get_start_label_id(self):
        """Return the id of the '[CLS]' label."""
        return self._label_map['[CLS]']

    def get_stop_label_id(self):
        """Return the id of the '[SEP]' label."""
        return self._label_map['[SEP]']

    def _create_examples(self, all_lists):
        """Wrap ``[words, ..., labels]`` records into `InputExample` objects.

        The first element of each record is used as the words and the last one
        as the labels; the record's position is used as its guid.
        """
        return [
            InputExample(guid=i, words=one_lists[0], labels=one_lists[-1])
            for i, one_lists in enumerate(all_lists)
        ]

    def _create_examples2(self, lines):
        """Variant of `_create_examples` kept for callers that use this name.

        Fixed: it used to pass ``text_a``/``labels_a`` keywords, which
        `InputExample` does not accept, so every call raised TypeError.
        """
        return [
            InputExample(guid=i, words=line[0], labels=line[-1])
            for i, line in enumerate(lines)
        ]


In [4]:
class NerDataset(data.Dataset):
    """Dataset that converts `InputExample`s to features lazily, one item at a time."""

    def __init__(self, examples, tokenizer, label_map, max_seq_length):
        self.examples=examples
        self.tokenizer=tokenizer
        self.label_map=label_map
        self.max_seq_length=max_seq_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the five feature lists of example ``idx``.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids, predict_mask, label_ids)``.
        """
        # Fixed: use the length given to this dataset, not the notebook-level global.
        feat=example2feature(self.examples[idx], self.tokenizer, self.label_map, self.max_seq_length)
        return feat.input_ids, feat.input_mask, feat.segment_ids, feat.predict_mask, feat.label_ids

    @classmethod
    def pad(cls, batch):
        """Collate function: right-pad every field with 0 to the longest sample.

        Args:
            batch: Sequence of 5-tuples as returned by `__getitem__`.

        Returns:
            Five tensors of shape (len(batch), longest sample): input ids, input
            mask, segment ids (all long), predict mask (uint8) and label ids
            (long). 0 is the id of label 'X' in this notebook's label map.
        """
        maxlen = max(len(sample[0]) for sample in batch)

        def padded(field):
            return [sample[field] + [0] * (maxlen - len(sample[field])) for sample in batch]

        input_ids_list = torch.LongTensor(padded(0))
        input_mask_list = torch.LongTensor(padded(1))
        segment_ids_list = torch.LongTensor(padded(2))
        # Kept as uint8 so existing consumers see the same dtype; convert with
        # .bool() before using it as a mask to avoid the uint8 deprecation warning.
        predict_mask_list = torch.ByteTensor(padded(3))
        label_ids_list = torch.LongTensor(padded(4))

        return input_ids_list, input_mask_list, segment_ids_list, predict_mask_list, label_ids_list

In [5]:
def example2feature(example, tokenizer, label_map, max_seq_length):
    """Convert one example into BERT input features.

    Each word is split with ``tokenizer.tokenize``; the first piece carries the
    word's label and is marked for prediction, later pieces get label 'X' and
    are masked out. '[CLS]' and '[SEP]' frame the sequence, which is cut so the
    final length never exceeds ``max_seq_length``.

    Args:
        example: Object exposing ``guid``, ``words`` and ``labels``.
        tokenizer: Object exposing ``tokenize`` and ``convert_tokens_to_ids``.
        label_map: Mapping from label string to label id.
        max_seq_length: Maximum number of tokens including '[CLS]' and '[SEP]'.

    Returns:
        An `InputFeatures` instance.
    """
    continuation_label = 'X'
    tokens, predict_mask, label_ids = ['[CLS]'], [0], [label_map['[CLS]']]

    for position, word in enumerate(example.words):
        # Words the tokenizer maps to nothing are represented by '[UNK]'.
        pieces = tokenizer.tokenize(word) or ['[UNK]']
        tokens.extend(pieces)
        # Only the first piece is predicted on and carries the real label.
        predict_mask.append(1)
        label_ids.append(label_map[example.labels[position]])
        # '##xxx' continuation pieces -> 'X' (see the BERT paper).
        for _ in pieces[1:]:
            predict_mask.append(0)
            label_ids.append(label_map[continuation_label])

    # Leave room for the closing '[SEP]'.
    keep = max_seq_length - 1
    if len(tokens) > keep:
        print('Example No.{} is too long, length is {}, truncated to {}!'.format(example.guid, len(tokens), max_seq_length))
        tokens, predict_mask, label_ids = tokens[0:keep], predict_mask[0:keep], label_ids[0:keep]

    tokens.append('[SEP]')
    predict_mask.append(0)
    label_ids.append(label_map['[SEP]'])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    sequence_length = len(input_ids)

    return InputFeatures(
        input_ids=input_ids,
        input_mask=[1] * sequence_length,
        segment_ids=[0] * sequence_length,
        predict_mask=predict_mask,
        label_ids=label_ids)


In [6]:
#%%
'''
#####  Use BertModel + CRF  #####
##### 直接用 bert，不用 BertForTokenClassification 了。
CRF is for transition and the maximum likelyhood estimate(MLE).
Bert is for latent label -> Emission of word embedding.
'''
# Banner for the notebook output marking the BertModel + CRF section.
print('*** Use BertModel + CRF ***')

def log_sum_exp_1vec(vec):  # shape(1,m)
    """Numerically stable log(sum(exp(vec))) over all entries of ``vec``.

    Args:
        vec: Tensor of shape (1, m).

    Returns:
        A 0-dim tensor holding the log-sum-exp of all entries.
    """
    # torch.max keeps everything in torch, so this also works for tensors that
    # require grad or live on the GPU (np.argmax needed a numpy conversion).
    max_score = vec.max()
    # Subtracting the maximum keeps exp() from overflowing.
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score)))

def log_sum_exp_mat(log_M, axis=-1):  # shape(n,m)
    """Numerically stable log-sum-exp of a matrix along ``axis``.

    Args:
        log_M: Tensor of shape (n, m).
        axis: Dimension to reduce. Any valid dimension works; the previous
            broadcasting was only correct for the last one.

    Returns:
        Tensor with ``axis`` removed.
    """
    # keepdim=True makes the maximum broadcast against log_M for every axis,
    # and the maximum is computed only once.
    max_vals = torch.max(log_M, axis, keepdim=True)[0]
    return max_vals.squeeze(axis) + torch.log(torch.exp(log_M - max_vals).sum(axis))

def log_sum_exp_batch(log_Tensor, axis=-1): # shape (batch_size,n,m)
    """Numerically stable log-sum-exp of a batched tensor along ``axis``.

    Args:
        log_Tensor: Tensor of shape (batch_size, n, m).
        axis: Dimension to reduce. Any valid dimension works; the previous
            reshape of the maximum was only correct for the last one.

    Returns:
        Tensor with ``axis`` removed, e.g. (batch_size, n) for ``axis=-1``.
    """
    # keepdim=True makes the maximum broadcast against log_Tensor for every
    # axis, and the maximum is computed only once.
    max_vals = torch.max(log_Tensor, axis, keepdim=True)[0]
    return max_vals.squeeze(axis) + torch.log(torch.exp(log_Tensor - max_vals).sum(axis))


class BERT_CRF_NER(nn.Module):
    """BERT encoder with a linear emission layer and a CRF on top.

    BERT hidden states are projected to per-label emission scores; a learned
    transition matrix scores label bigrams. Training minimises the negative
    log-likelihood of the gold label sequence, inference runs Viterbi decoding.
    """

    def __init__(self, bert_model, start_label_id, stop_label_id, num_labels, max_seq_length, batch_size, device):
        """
        Args:
            bert_model: Encoder module; called with ``input_ids``,
                ``token_type_ids`` and ``attention_mask`` and expected to return
                the sequence output as its first result.
            start_label_id: Label id used at position 0 (the '[CLS]' label).
            stop_label_id: Label id of the closing '[SEP]' label.
            num_labels: Size of the label inventory.
            max_seq_length: Accepted for interface compatibility; not used,
                sequence length is taken from the input.
            batch_size: Stored on the instance; batch size is otherwise taken
                from the input.
            device: Device on which the CRF work tensors are created.
        """
        super(BERT_CRF_NER, self).__init__()
        # Follow the encoder's own width when it advertises one; 768 is the
        # previous hard-coded value and remains the fallback.
        self.hidden_size = getattr(getattr(bert_model, 'config', None), 'hidden_size', 768)
        self.start_label_id = start_label_id
        self.stop_label_id = stop_label_id
        self.num_labels = num_labels
        # self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.device=device

        # Use the pretrained BertModel.
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.2)
        # Maps the output of the bert into label space.
        self.hidden2label = nn.Linear(self.hidden_size, self.num_labels)

        # Matrix of transition parameters.  Entry i,j is the score of transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.num_labels, self.num_labels))

        # These two statements enforce the constraint that we never transfer *to* the start tag(or label),
        # and we never transfer *from* the stop label (the model would probably learn this anyway,
        # so this enforcement is likely unimportant)
        self.transitions.data[start_label_id, :] = -10000
        self.transitions.data[:, stop_label_id] = -10000

        nn.init.xavier_uniform_(self.hidden2label.weight)
        nn.init.constant_(self.hidden2label.bias, 0.0)
        # self.apply(self.init_bert_weights)

    def init_bert_weights(self, module):
        """Initialise the weights of ``module`` BERT-style.

        Not applied by ``__init__`` (the call there is commented out); applying
        it to the whole model would also overwrite the pretrained encoder.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Fixed: the model has no ``self.config``; the range lives on the encoder config.
            module.weight.data.normal_(mean=0.0, std=self.bert.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _forward_alg(self, feats):
        """Forward (alpha) recursion: log of the summed score of all label paths.

        Args:
            feats: Emission scores, shape (batch, T, num_labels).

        Returns:
            Tensor of shape (batch, 1).
        """
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # In log space: all mass starts on the start label (0 == log 1),
        # every other label starts at an effectively impossible score.
        log_alpha = torch.full((batch_size, 1, self.num_labels), -10000., device=self.device)
        log_alpha[:, 0, self.start_label_id] = 0

        # Position 0 is the start label itself, so the recursion begins at t=1.
        for t in range(1, T):
            # (L, L) + (B, 1, L) -> (B, L, L): entry [b, i, j] scores reaching i from j.
            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)

        # Log of the total score over all label sequences.
        log_prob_all_barX = log_sum_exp_batch(log_alpha)
        return log_prob_all_barX

    def _get_bert_features(self, input_ids, segment_ids, input_mask):
        """Compute emission scores: tokens -> BERT -> dropout -> linear.

        Returns:
            Tensor of shape (batch, T, num_labels).
        """
        bert_outputs = self.bert(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, output_hidden_states=False)
        # Index instead of tuple-unpacking so that both a plain tuple and a
        # ModelOutput object (which supports integer indexing) are accepted.
        bert_seq_out = bert_outputs[0]

        bert_seq_out = self.dropout(bert_seq_out)
        bert_feats = self.hidden2label(bert_seq_out)
        return bert_feats

    def _score_sentence(self, feats, label_ids):
        """Score a given label sequence (transition + emission scores).

        Args:
            feats: Emission scores, shape (batch, T, num_labels).
            label_ids: Label ids, shape (batch, T); position 0 is the start label.

        Returns:
            Tensor of shape (batch, 1).
        """
        # Position 0 is start_label -> first word with probability 1, so the
        # scored steps are t = 1 .. T-1.
        current_labels = label_ids[:, 1:]
        previous_labels = label_ids[:, :-1]

        # transitions[to, from] looked up for every step at once: (batch, T-1).
        transition_scores = self.transitions[current_labels, previous_labels]
        emission_scores = feats[:, 1:].gather(-1, current_labels.unsqueeze(-1)).squeeze(-1)

        # Summing over an empty step axis (T == 1) yields 0, as before.
        return (transition_scores + emission_scores).sum(dim=1, keepdim=True)

    def _viterbi_decode(self, feats):
        """Viterbi (max-product) decoding of the best label path.

        Args:
            feats: Emission scores, shape (batch, T, num_labels).

        Returns:
            Tuple ``(best_score, path)`` with shapes (batch,) and (batch, T).
        """
        T = feats.shape[1]
        batch_size = feats.shape[0]

        log_delta = torch.full((batch_size, 1, self.num_labels), -10000., device=self.device)
        log_delta[:, 0, self.start_label_id] = 0

        # psi[:, t, k] is the best previous label when label k is chosen at t;
        # psi[:, 0] is never filled and never read.
        psi = torch.zeros((batch_size, T, self.num_labels), dtype=torch.long, device=self.device)
        for t in range(1, T):
            # Best score of any path ending in each label at t, before emission.
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # Trace back.
        path = torch.zeros((batch_size, T), dtype=torch.long, device=self.device)

        # squeeze(1) removes only the singleton step axis; a bare squeeze()
        # also removed the batch axis when batch_size == 1.
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(1), -1)

        for t in range(T-2, -1, -1):
            # Choose the label at t according to the label already chosen at t+1.
            path[:, t] = psi[:, t+1].gather(-1, path[:, t+1].view(-1, 1)).squeeze(-1)

        return max_logLL_allz_allx, path

    # Note: training calls this function, which in turn calls _forward_alg.
    def neg_log_likelihood(self, input_ids, segment_ids, input_mask, label_ids):
        """Mean negative log-likelihood of ``label_ids`` over the batch."""
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)
        forward_score = self._forward_alg(bert_feats)
        gold_score = self._score_sentence(bert_feats, label_ids)
        # -log[ score(gold path) / sum of scores of all paths ]
        return torch.mean(forward_score - gold_score)

    def forward(self, input_ids, segment_ids, input_mask):
        """Predict the best label sequence.

        Returns:
            Tuple ``(score, label_seq_ids)`` as produced by `_viterbi_decode`.
        """
        # Get the emission scores from BERT.
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)

        # Find the best path, given the features.
        score, label_seq_ids = self._viterbi_decode(bert_feats)
        return score, label_seq_ids


*** Use BertModel + CRF ***


In [7]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Load the NER model: tokenizer, demo data, BERT encoder and the trained
# BERT-CRF checkpoint.

# Tokenizer (vocabulary) from the local pretrained BERT directory.
tokenizer = BertTokenizer.from_pretrained(bert_model_scale, do_lower_case=do_lower_case)

companyProcessor = CompanyDataProcessor()

# Label inventory and label -> id mapping of the task.
label_list = companyProcessor.get_labels()
label_map = companyProcessor.get_label_map()

# Demo sentences are read from demo.txt inside data_dir; features are built
# lazily by the dataset and padded per batch by NerDataset.pad.
demo_examples = companyProcessor.get_demo_examples(data_dir)
demo_dataset = NerDataset(demo_examples,tokenizer,label_map,max_seq_length)
demo_dataloader = data.DataLoader(dataset=demo_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=NerDataset.pad)

# '[CLS]' / '[SEP]' label ids act as the CRF start / stop labels.
start_label_id = companyProcessor.get_start_label_id()
stop_label_id = companyProcessor.get_stop_label_id()

bert_model = BertModel.from_pretrained(bert_model_scale)

model = BERT_CRF_NER(bert_model, start_label_id, stop_label_id, len(label_list), max_seq_length, batch_size, device)

# The checkpoint is loaded on the CPU first; the model is moved to `device` below.
# This raises if the checkpoint file is missing (the config cell only warns about it).
checkpoint = torch.load(output_dir+'/ner_bert_crf_checkpoint.pt', map_location='cpu')
start_epoch = checkpoint['epoch']+1
valid_acc_prev = checkpoint['valid_acc']
valid_f1_prev = checkpoint['valid_f1']
pretrained_dict=checkpoint['model_state']
net_state_dict = model.state_dict()
# Keep only checkpoint entries whose names exist in the current model, then
# overlay them on the model's own state so missing entries keep their initial values.
pretrained_dict_selected = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
net_state_dict.update(pretrained_dict_selected)
model.load_state_dict(net_state_dict)
print('Loaded the pretrain NER_BERT_CRF model, epoch:',checkpoint['epoch'],'valid acc:', 
        checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])

# Last expression of the cell: moves the model and displays its structure.
model.to(device)


Loaded the pretrain NER_BERT_CRF model, epoch: 10 valid acc: 0.9967663702506063 valid f1: 0.9932473910374464


BERT_CRF_NER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

In [8]:

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Opening / closing bracket printed around each entity type.
_ENTITY_BRACKETS = {
    'LOC': ('[', ']'),
    'NAME': ('【', '】'),
    'INDU': ('{', '}'),
    'ORG': ('<', '>'),
}


def _format_entities(words, labels):
    """Join ``words`` into one string with every predicted entity bracketed.

    A 'B-<type>' label closes the entity that is currently open (if any) and
    opens a new one; an 'O' label closes the open entity; any other label
    (e.g. 'I-<type>') just continues. An entity still open at the end of the
    sentence is closed there.

    Args:
        words: Sequence of word strings.
        labels: Predicted label strings, aligned with ``words``. Pairs beyond
            the shorter of the two sequences are ignored.

    Returns:
        The annotated sentence.
    """
    pieces = []
    pending_close = ''  # closing bracket of the entity that is currently open
    for word, label in zip(words, labels):
        if label.startswith('B-'):
            opening, closing = _ENTITY_BRACKETS.get(label[2:], ('', ''))
            pieces.append(pending_close + opening)
            pending_close = closing
        elif label == 'O':
            pieces.append(pending_close)
            pending_close = ''
        pieces.append(word)
    pieces.append(pending_close)
    return ''.join(pieces)


# Single source of truth for the demo batch size (used by the loader below).
demo_batch_size = 10

i_batch=0
# Index into demo_examples of the first sentence of the current batch; advancing
# it by the real batch length keeps words and predictions aligned for any batch size.
example_offset = 0
model.eval()
with torch.no_grad():
    demon_dataloader = data.DataLoader(dataset=demo_dataset,
                                batch_size=demo_batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=NerDataset.pad)
    for batch in demon_dataloader:
        i_batch+=1
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch
        _, predicted_label_seq_ids = model(input_ids, segment_ids, input_mask)

        # Boolean masks avoid the "mask with dtype torch.uint8" deprecation warning.
        predict_mask = predict_mask.bool()
        valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)

        for i in range(len(input_ids)):
            # Keep only positions that correspond to the first piece of a word.
            new_ids = predicted_label_seq_ids[i][predict_mask[i]].tolist()
            predicted_label = [label_list[label_id] for label_id in new_ids]

            ls_words = demo_examples[example_offset + i].words
            str_words = _format_entities(ls_words, predicted_label)

            print(str_words)

        example_offset += len(input_ids)
                

[武汉]【翔海】{水处理}<有限公司>
[漯河市]【天利】{白瓜籽加工}<有限责任公司>
[大连]【泰和道】{农业发展}<有限公司>
[天津]【大致】【和堂】{医药}<有限公司>【延年】{中药}<店>
[上海]【佳吉】{快运}<有限公司>[余姚市]【低塘】<营业部>
[上饶市]【百业】【众赢】{网络科技}<有限公司>
[本溪市]【枫汤居】{宾馆}<有限公司>
[四川省]【鸿利达】{涂装科技}<有限责任公司>
[武汉]【爱尔】{眼科医院}
[通海县]{电影院}
【邹易儒】([北京]){影视文化}<工作室>


  valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)
