In [1]:
# # # #
#   - A pretrained BERT with CRF model.
# # # #

# %%
import sys
import os
import time
import importlib
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim

from torch.utils.data.distributed import DistributedSampler
from torch.utils import data

from tqdm import tqdm, trange
import collections

from transformers import BertModel, BertForTokenClassification
from transformers.modeling_bert import BertLayerNorm
import pickle
#from transformers import BertAdam, warmup_linear
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer
from seqeval.metrics import accuracy_score, f1_score, classification_report

def set_work_dir(local_path="bert-crf-company", server_path="bert-crf-company"):
    if (os.path.exists(os.path.abspath('..')+'/'+local_path)):
        os.chdir(os.path.abspath('..')+'/'+local_path)
    elif (os.path.exists(os.path.abspath('..')+'/'+server_path)):
        os.chdir(os.path.abspath('..')+'/'+server_path)
    else:
        raise Exception('Set work path error!')


def get_data_dir(local_path="bert-crf-company/data", server_path="bert-crf-company/data"):
    if (os.path.exists(os.path.abspath('..')+'/'+local_path)):
        return os.path.abspath('..')+'/'+local_path
    elif (os.path.exists(os.path.abspath('..')+'/'+server_path)):
        return os.path.abspath('..')+'/'+server_path
    else:
        raise Exception('get data path error!')


print('Python version ', sys.version)
print('PyTorch version ', torch.__version__)

set_work_dir()
print('Current dir:', os.getcwd())

### 是否使用 GPU
cuda_yes = torch.cuda.is_available()
# cuda_yes = True

print('Cuda is available?', cuda_yes)
device = torch.device("cuda:1" if cuda_yes else "cpu")
print('Device:', device)

data_dir = os.path.join(get_data_dir(), '')
# "Whether to run training."
do_train = True
# "Whether to run eval on the dev set."
do_eval = True
# "Whether to run the model in inference mode on the test set."
do_predict = True
# Whether load checkpoint file before train model
load_checkpoint = True
# "The vocabulary file that the BERT model was trained on."
max_seq_length = 180 #256
batch_size = 32 #32
# "The initial learning rate for Adam."
learning_rate0 = 5e-5
lr0_crf_fc = 8e-5
weight_decay_finetune = 1e-5 #0.01
weight_decay_crf_fc = 5e-6 #0.005
total_train_epochs = 15

gradient_accumulation_steps = 1

warmup_proportion = 0.1

### 模型输出目录 ###
output_dir = './output/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

bert_model_scale = './bert-base-chinese'
do_lower_case = False


Python version  3.8.6 | packaged by conda-forge | (default, Dec 26 2020, 05:05:16) 
[GCC 9.3.0]
PyTorch version  1.7.1
Current dir: /root/mao/249/bert-crf-company
Cuda is available? True
Device: cuda:1


In [2]:
# %%
'''
Functions and Classes for read and organize data set
'''

class InputExample(object):
    """A single training/test example for NER."""

    def __init__(self, guid, words, labels):
        """Constructs a InputExample.
        Args:
          guid: Unique id for the example(a sentence or a pair of sentences).
          words: list of words of sentence
          labels_a/labels_b: (Optional) string. The label seqence of the text_a/text_b. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.words = words
        self.labels = labels


class InputFeatures(object):
    """A single set of features of data.
    result of convert_examples_to_features(InputExample)
    """
    def __init__(self, input_ids, input_mask, segment_ids,  predict_mask, label_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.predict_mask = predict_mask
        self.label_ids = label_ids


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""
    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data(cls, input_file):
        """
        Reads a BIO data.
        """
        with open(input_file) as f:
            # out_lines = []
            out_lists = []
            entries = f.read().strip().split("\n")
            for line in entries:
                entries_ = line.split(' ')
                words = []
                ner_labels = []                
                for entry in entries_:
                    ls_entry=entry.split('/')
                    assert len(ls_entry)==2
                    words.append(ls_entry[0])
                    ner_labels.append(ls_entry[1])

                out_lists.append([words,ner_labels])
        return out_lists
    

    @classmethod
    def _read_data_demo(cls, input_file):
        """
        Reads line
        """
        with open(input_file) as f:
            # out_lines = []
            out_lists = []
            entries = f.read().strip().splitlines()
            
            for line in entries:
                words = []
                ner_labels = []
                for ch in line:
                    words.append(ch)
                    ner_labels.append('O')
                out_lists.append([words,ner_labels])
        return out_lists


In [3]:
class CompanyDataProcessor(DataProcessor):
    def __init__(self):
        self._label_types = [ 'X', '[CLS]', '[SEP]', 'O', 'I-LOC', 'B-LOC', 'I-NAME', 'I-INDU', 'I-ORG', 'B-NAME', 'B-INDU', 'B-ORG']
        self._num_labels = len(self._label_types)
        self._label_map = {label: i for i,
                           label in enumerate(self._label_types)}

    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_data(os.path.join(data_dir, "train.1000.txt")))

    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_data(os.path.join(data_dir, "dev.200.txt")))

    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_data(os.path.join(data_dir, "test.200.txt"))) 

    def get_labels(self):
        return self._label_types

    def get_num_labels(self):
        return self.get_num_labels

    def get_label_map(self):
        return self._label_map
    
    def get_start_label_id(self):
        return self._label_map['[CLS]']

    def get_stop_label_id(self):
        return self._label_map['[SEP]']

    def _create_examples(self, all_lists):
        examples = []
        for (i, one_lists) in enumerate(all_lists):
            guid = i
            words = one_lists[0]
            labels = one_lists[-1]
            examples.append(InputExample(
                guid=guid, words=words, labels=labels))
        return examples

    def _create_examples2(self, lines):
        examples = []
        for (i, line) in enumerate(lines):
            guid = i
            text = line[0]
            ner_label = line[-1]
            examples.append(InputExample(
                guid=guid, text_a=text, labels_a=ner_label))
        return examples


In [4]:
class NerDataset(data.Dataset):
    def __init__(self, examples, tokenizer, label_map, max_seq_length):
        self.examples=examples
        self.tokenizer=tokenizer
        self.label_map=label_map
        self.max_seq_length=max_seq_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        feat=example2feature(self.examples[idx], self.tokenizer, self.label_map, max_seq_length)
        return feat.input_ids, feat.input_mask, feat.segment_ids, feat.predict_mask, feat.label_ids

    @classmethod
    def pad(cls, batch):

        seqlen_list = [len(sample[0]) for sample in batch]
        maxlen = np.array(seqlen_list).max()

        f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: X for padding
        input_ids_list = torch.LongTensor(f(0, maxlen))
        input_mask_list = torch.LongTensor(f(1, maxlen))
        segment_ids_list = torch.LongTensor(f(2, maxlen))
        predict_mask_list = torch.ByteTensor(f(3, maxlen))
        label_ids_list = torch.LongTensor(f(4, maxlen))

        return input_ids_list, input_mask_list, segment_ids_list, predict_mask_list, label_ids_list

In [5]:
def example2feature(example, tokenizer, label_map, max_seq_length):

    add_label = 'X'
    # tokenize_count = []
    tokens = ['[CLS]']
    predict_mask = [0]
    label_ids = [label_map['[CLS]']]
    for i, w in enumerate(example.words):
        # use bertTokenizer to split words
        sub_words = tokenizer.tokenize(w)
        if not sub_words:
            sub_words = ['[UNK]']
        # tokenize_count.append(len(sub_words))
        tokens.extend(sub_words)
        for j in range(len(sub_words)):
            if j == 0:
                predict_mask.append(1)
                label_ids.append(label_map[example.labels[i]])
            else:
                # '##xxx' -> 'X' (see bert paper)
                predict_mask.append(0)
                label_ids.append(label_map[add_label])

    # truncate
    if len(tokens) > max_seq_length - 1:
        print('Example No.{} is too long, length is {}, truncated to {}!'.format(example.guid, len(tokens), max_seq_length))
        tokens = tokens[0:(max_seq_length - 1)]
        predict_mask = predict_mask[0:(max_seq_length - 1)]
        label_ids = label_ids[0:(max_seq_length - 1)]
    tokens.append('[SEP]')
    predict_mask.append(0)
    label_ids.append(label_map['[SEP]'])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    input_mask = [1] * len(input_ids)

    feat=InputFeatures(
                # guid=example.guid,
                # tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                predict_mask=predict_mask,
                label_ids=label_ids)

    return feat


#%%
'''
Prepare data set
'''
# random.seed(44)
np.random.seed(44)
torch.manual_seed(44)
if cuda_yes:
    torch.cuda.manual_seed_all(44)

# Load pre-trained model tokenizer (vocabulary)
companyProcessor = CompanyDataProcessor()
label_list = companyProcessor.get_labels()
label_map = companyProcessor.get_label_map()
train_examples = companyProcessor.get_train_examples(data_dir)
dev_examples = companyProcessor.get_dev_examples(data_dir)
test_examples = companyProcessor.get_test_examples(data_dir)

total_train_steps = int(len(train_examples) / batch_size / gradient_accumulation_steps * total_train_epochs)

print("***** Running training *****")
print("  Num examples = %d"% len(train_examples))
print("  Batch size = %d"% batch_size)
print("  Num steps = %d"% total_train_steps)

tokenizer = BertTokenizer.from_pretrained(bert_model_scale, do_lower_case=do_lower_case)

train_dataset = NerDataset(train_examples,tokenizer,label_map,max_seq_length)
dev_dataset = NerDataset(dev_examples,tokenizer,label_map,max_seq_length)
test_dataset = NerDataset(test_examples,tokenizer,label_map,max_seq_length)

train_dataloader = data.DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                shuffle=True,
                                num_workers=4,
                                collate_fn=NerDataset.pad)

dev_dataloader = data.DataLoader(dataset=dev_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=NerDataset.pad)

test_dataloader = data.DataLoader(dataset=test_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=NerDataset.pad)


def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x/warmup
    return 1.0 - x


***** Running training *****
  Num examples = 1000
  Batch size = 32
  Num steps = 468


In [6]:
#%%
'''
#####  Use BertModel + CRF  #####
##### 直接用 bert，不用 BertForTokenClassification 了。
CRF is for transition and the maximum likelyhood estimate(MLE).
Bert is for latent label -> Emission of word embedding.
'''

print('*** Use BertModel + CRF ***')


def log_sum_exp_1vec(vec):  # shape(1,m)
    max_score = vec[0, np.argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

def log_sum_exp_mat(log_M, axis=-1):  # shape(n,m)
    return torch.max(log_M, axis)[0]+torch.log(torch.exp(log_M-torch.max(log_M, axis)[0][:, None]).sum(axis))

def log_sum_exp_batch(log_Tensor, axis=-1): # shape (batch_size,n,m)
    return torch.max(log_Tensor, axis)[0]+torch.log(torch.exp(log_Tensor-torch.max(log_Tensor, axis)[0].view(log_Tensor.shape[0],-1,1)).sum(axis))


class BERT_CRF_NER(nn.Module):
    def __init__(self, bert_model, start_label_id, stop_label_id, num_labels, max_seq_length, batch_size, device):
        super(BERT_CRF_NER, self).__init__()
        self.hidden_size = 768
        self.start_label_id = start_label_id
        self.stop_label_id = stop_label_id
        self.num_labels = num_labels
        # self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.device=device

        # use pretrainded BertModel 
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.2)
        # Maps the output of the bert into label space.
        self.hidden2label = nn.Linear(self.hidden_size, self.num_labels)

        # Matrix of transition parameters.  Entry i,j is the score of transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.num_labels, self.num_labels))

        # These two statements enforce the constraint that we never transfer *to* the start tag(or label),
        # and we never transfer *from* the stop label (the model would probably learn this anyway,
        # so this enforcement is likely unimportant)
        self.transitions.data[start_label_id, :] = -10000
        self.transitions.data[:, stop_label_id] = -10000

        nn.init.xavier_uniform_(self.hidden2label.weight)
        nn.init.constant_(self.hidden2label.bias, 0.0)
        # self.apply(self.init_bert_weights)

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)): 
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    
    def _forward_alg(self, feats):
        '''
        recursion: 递归
        this also called alpha-recursion or forward recursion, to calculate log_prob of all barX 
        '''
        
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]
        
        # alpha_recursion,forward, alpha(zt)=p(zt,bar_x_1:t)
        log_alpha = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        # normal_alpha_0 : alpha[0]=Ot[0]*self.PIs
        # self.start_label has all of the score. it is log,0 is p=1
        log_alpha[:, 0, self.start_label_id] = 0
        
        # feats: sentances -> word embedding -> lstm -> MLP -> feats
        # feats is the probability of emission, feat.shape=(1,tag_size)
        for t in range(1, T):
            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)

        # log_prob of all barX
        log_prob_all_barX = log_sum_exp_batch(log_alpha)
        return log_prob_all_barX

    def _get_bert_features(self, input_ids, segment_ids, input_mask):
        '''
        sentances -> word embedding -> lstm -> MLP -> feats
        '''
        bert_seq_out, _ = self.bert(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, output_hidden_states=False)        
        
        bert_seq_out = self.dropout(bert_seq_out)
        bert_feats = self.hidden2label(bert_seq_out)
        return bert_feats

    def _score_sentence(self, feats, label_ids):
        ''' 
        Gives the score of a provided label sequence
        p(X=w1:t,Zt=tag1:t)=...p(Zt=tag_t|Zt-1=tag_t-1)p(xt|Zt=tag_t)...
        '''
        
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        batch_transitions = self.transitions.expand(batch_size,self.num_labels,self.num_labels)
        batch_transitions = batch_transitions.flatten(1)

        score = torch.zeros((feats.shape[0],1)).to(device)
        # the 0th node is start_label->start_word,the probability of them=1. so t begin with 1.
        for t in range(1, T):
            score = score + \
                batch_transitions.gather(-1, (label_ids[:, t]*self.num_labels+label_ids[:, t-1]).view(-1,1)) \
                    + feats[:, t].gather(-1, label_ids[:, t].view(-1,1)).view(-1,1)
        return score

    def _viterbi_decode(self, feats):
        '''
        Max-Product Algorithm or viterbi algorithm, argmax(p(z_0:t|x_0:t))
        '''
        
        # T = self.max_seq_length
        T = feats.shape[1]
        batch_size = feats.shape[0]

        # batch_transitions=self.transitions.expand(batch_size,self.num_labels,self.num_labels)

        log_delta = torch.Tensor(batch_size, 1, self.num_labels).fill_(-10000.).to(self.device)
        log_delta[:, 0, self.start_label_id] = 0
        
        # psi is for the vaule of the last latent that make P(this_latent) maximum.
        psi = torch.zeros((batch_size, T, self.num_labels), dtype=torch.long).to(self.device)  # psi[0]=0000 useless
        for t in range(1, T):
            # delta[t][k]=max_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # delta[t] is the max prob of the path from  z_t-1 to z_t[k]
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            # psi[t][k]=argmax_z1:t-1( p(x1,x2,...,xt,z1,z2,...,zt-1,zt=k|theta) )
            # psi[t][k] is the path choosed from z_t-1 to z_t[k],the value is the z_state(is k) index of z_t-1
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)

        # trace back
        path = torch.zeros((batch_size, T), dtype=torch.long).to(self.device)

        # max p(z1:t,all_x|theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(), -1)

        for t in range(T-2, -1, -1):
            # choose the state of z_t according the state choosed of z_t+1.
            path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()

        return max_logLL_allz_allx, path


    def neg_log_likelihood(self, input_ids, segment_ids, input_mask, label_ids):
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)
        forward_score = self._forward_alg(bert_feats)
        # p(X=w1:t,Zt=tag1:t)=...p(Zt=tag_t|Zt-1=tag_t-1)p(xt|Zt=tag_t)...
        gold_score = self._score_sentence(bert_feats, label_ids)
        # - log[ p(X=w1:t,Zt=tag1:t)/p(X=w1:t) ] = - log[ p(Zt=tag1:t|X=w1:t) ]
        return torch.mean(forward_score - gold_score)

    # this forward is just for predict, not for train
    # dont confuse this with _forward_alg above.
    def forward(self, input_ids, segment_ids, input_mask):
        # Get the emission scores from the BiLSTM
        bert_feats = self._get_bert_features(input_ids, segment_ids, input_mask)

        # Find the best path, given the features.
        score, label_seq_ids = self._viterbi_decode(bert_feats)
        return score, label_seq_ids


*** Use BertModel + CRF ***


In [7]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# defind NER model 

start_label_id = companyProcessor.get_start_label_id()
stop_label_id = companyProcessor.get_stop_label_id()

bert_model = BertModel.from_pretrained(bert_model_scale)
# from transformers import BertConfig 
# model_config = BertConfig.from_pretrained(bert_model_scale, output_hidden_states=True)
# bert_model = BertModel.from_pretrained(bert_model_scale, config=model_config)


model = BERT_CRF_NER(bert_model, start_label_id, stop_label_id, len(label_list), max_seq_length, batch_size, device)


if load_checkpoint and os.path.exists(output_dir+'/ner_bert_crf_checkpoint.pt'):
    checkpoint = torch.load(output_dir+'/ner_bert_crf_checkpoint.pt', map_location='cpu')
    start_epoch = checkpoint['epoch']+1
    valid_acc_prev = checkpoint['valid_acc']
    valid_f1_prev = checkpoint['valid_f1']
    pretrained_dict=checkpoint['model_state']
    net_state_dict = model.state_dict()
    pretrained_dict_selected = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
    net_state_dict.update(pretrained_dict_selected)
    model.load_state_dict(net_state_dict)
    print('Loaded the pretrain NER_BERT_CRF model, epoch:',checkpoint['epoch'],'valid acc:', 
            checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])
else:
    start_epoch = 0
    valid_acc_prev = 0
    valid_f1_prev = 0

model.to(device)

# Prepare optimizer
param_optimizer = list(model.named_parameters())


no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
new_param = ['transitions', 'hidden2label.weight', 'hidden2label.bias']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) \
        and not any(nd in n for nd in new_param)], 'weight_decay': weight_decay_finetune},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) \
        and not any(nd in n for nd in new_param)], 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if n in ('transitions','hidden2label.weight')] \
        , 'lr':lr0_crf_fc, 'weight_decay': weight_decay_crf_fc},
    {'params': [p for n, p in param_optimizer if n == 'hidden2label.bias'] \
        , 'lr':lr0_crf_fc, 'weight_decay': 0.0}
]

# optimizer = BertAdam(optimizer_grouped_parameters, lr=learning_rate0, warmup=warmup_proportion, t_total=total_train_steps)
# optimizer = optim.Adam(model.parameters(), lr=learning_rate0)

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate0)

In [8]:
def evaluate(model, predict_dataloader, batch_size, epoch_th, dataset_name):
    # print("***** Running prediction *****")
    model.eval()
    all_preds = []
    all_labels = []
    total=0
    correct=0
    start = time.time()
    with torch.no_grad():
        for batch in predict_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, predict_mask, label_ids = batch
            
            _, predicted_label_seq_ids = model(input_ids, segment_ids, input_mask)
            
            
            # _, predicted = torch.max(out_scores, -1)
            valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)
            valid_label_ids = torch.masked_select(label_ids, predict_mask)
            
            all_preds.extend(valid_predicted.tolist())
            all_labels.extend(valid_label_ids.tolist())
            # print(len(valid_label_ids),len(valid_predicted),len(valid_label_ids)==len(valid_predicted))
            total += len(valid_label_ids)
            correct += valid_predicted.eq(valid_label_ids).sum().item()

    all_labels_tag=[label_list[i] for i in all_labels]
    all_preds_tag=[label_list[i] for i in all_preds]

    test_acc = correct/total

    f1=f1_score([all_labels_tag], [all_preds_tag])
    print("F1-Score: {}".format(f1))
    print("Classification report: -- ")
    ### seqeval.metrics 中的函数：classification_report
    print(classification_report([all_labels_tag], [all_preds_tag]))
    
    end = time.time()
    print('Epoch:%d, Acc:%.2f, on %s, Spend:%.3f minutes for evaluation' \
        % (epoch_th, 100.*test_acc, dataset_name,(end-start)/60.0))
    print('--------------------------------------------------------------')
    return test_acc, f1


In [9]:
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# train procedure
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

global_step_th = int(len(train_examples) / batch_size / gradient_accumulation_steps * start_epoch)

# train_start=time.time()
for epoch in range(start_epoch, total_train_epochs):
    tr_loss = 0
    train_start = time.time()
    model.train()
    optimizer.zero_grad()
    # for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, predict_mask, label_ids = batch

        neg_log_likelihood = model.neg_log_likelihood(input_ids, segment_ids, input_mask, label_ids)

        if gradient_accumulation_steps > 1:
            neg_log_likelihood = neg_log_likelihood / gradient_accumulation_steps

        neg_log_likelihood.backward()

        tr_loss += neg_log_likelihood.item()

        if (step + 1) % gradient_accumulation_steps == 0:
            # modify learning rate with special warm up BERT uses
            lr_this_step = learning_rate0 * warmup_linear(global_step_th/total_train_steps, warmup_proportion)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            global_step_th += 1
            
        print("Epoch:{}-{}/{}, Negative loglikelihood: {} ".format(epoch, step, len(train_dataloader), neg_log_likelihood.item()))
    
    print('--------------------------------------------------------------')
    print("Epoch:{} completed, Total training's Loss: {}, Spend: {}m".format(epoch, tr_loss, (time.time() - train_start)/60.0))
    
    ### 当前 epoch 的 evaluate
    valid_acc, valid_f1 = evaluate(model, dev_dataloader, batch_size, epoch, 'Valid_set')

    # Save a checkpoint
    if valid_f1 > valid_f1_prev:
        torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'valid_acc': valid_acc,
            'valid_f1': valid_f1, 'max_seq_length': max_seq_length, 'lower_case': do_lower_case},
                    os.path.join(output_dir, 'ner_bert_crf_checkpoint.pt'))
        valid_f1_prev = valid_f1



evaluate(model, test_dataloader, batch_size, total_train_epochs-1, 'Test_set')


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
'''
Test_set prediction using the best epoch of NER_BERT_CRF model
'''
checkpoint = torch.load(output_dir+'/ner_bert_crf_checkpoint.pt', map_location='cpu')
epoch = checkpoint['epoch']
valid_acc_prev = checkpoint['valid_acc']
valid_f1_prev = checkpoint['valid_f1']
pretrained_dict=checkpoint['model_state']
net_state_dict = model.state_dict()
pretrained_dict_selected = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
net_state_dict.update(pretrained_dict_selected)
model.load_state_dict(net_state_dict)
print('Loaded the pretrain  NER_BERT_CRF  model, epoch:',checkpoint['epoch'],'valid acc:', 
      checkpoint['valid_acc'], 'valid f1:', checkpoint['valid_f1'])

model.to(device)
#evaluate(model, train_dataloader, batch_size, total_train_epochs-1, 'Train_set')
evaluate(model, test_dataloader, batch_size, epoch, 'Test_set')
# print('Total spend:',(time.time()-train_start)/60.0)


Epoch:0-0/32, Negative loglikelihood: 9743.21484375 
Epoch:0-1/32, Negative loglikelihood: 9765.427734375 
Epoch:0-2/32, Negative loglikelihood: 9763.287109375 
Epoch:0-3/32, Negative loglikelihood: 9740.1259765625 
Epoch:0-4/32, Negative loglikelihood: 9742.478515625 
Epoch:0-5/32, Negative loglikelihood: 9425.0498046875 
Epoch:0-6/32, Negative loglikelihood: 9730.8359375 
Epoch:0-7/32, Negative loglikelihood: 9736.9765625 
Epoch:0-8/32, Negative loglikelihood: 9731.392578125 
Epoch:0-9/32, Negative loglikelihood: 9721.7763671875 
Epoch:0-10/32, Negative loglikelihood: 9719.5595703125 
Epoch:0-11/32, Negative loglikelihood: 9408.3671875 
Epoch:0-12/32, Negative loglikelihood: 9714.478515625 
Epoch:0-13/32, Negative loglikelihood: 9712.365234375 
Epoch:0-14/32, Negative loglikelihood: 9708.087890625 
Epoch:0-15/32, Negative loglikelihood: 9705.453125 
Epoch:0-16/32, Negative loglikelihood: 9703.2001953125 
Epoch:0-17/32, Negative loglikelihood: 9699.1552734375 
Epoch:0-18/32, Negative 

  valid_predicted = torch.masked_select(predicted_label_seq_ids, predict_mask)
  valid_label_ids = torch.masked_select(label_ids, predict_mask)


F1-Score: 0.9469240048250904
Classification report: -- 
              precision    recall  f1-score   support

        INDU       0.91      0.98      0.94       200
         LOC       0.96      0.98      0.97       203
        NAME       0.93      0.95      0.94       208
         ORG       0.92      0.96      0.94       203

   micro avg       0.93      0.96      0.95       814
   macro avg       0.93      0.96      0.95       814
weighted avg       0.93      0.96      0.95       814

Epoch:0, Acc:98.02, on Valid_set, Spend:0.007 minutes for evaluation
--------------------------------------------------------------
Epoch:1-0/32, Negative loglikelihood: 9673.21484375 
Epoch:1-1/32, Negative loglikelihood: 9671.703125 
Epoch:1-2/32, Negative loglikelihood: 9669.615234375 
Epoch:1-3/32, Negative loglikelihood: 9668.697265625 
Epoch:1-4/32, Negative loglikelihood: 9668.37109375 
Epoch:1-5/32, Negative loglikelihood: 9667.7138671875 
Epoch:1-6/32, Negative loglikelihood: 9666.830078125 
Epo

Epoch:4-2/32, Negative loglikelihood: 9656.587890625 
Epoch:4-3/32, Negative loglikelihood: 9345.1455078125 
Epoch:4-4/32, Negative loglikelihood: 9656.73046875 
Epoch:4-5/32, Negative loglikelihood: 9656.875 
Epoch:4-6/32, Negative loglikelihood: 9656.365234375 
Epoch:4-7/32, Negative loglikelihood: 9656.4931640625 
Epoch:4-8/32, Negative loglikelihood: 9656.2041015625 
Epoch:4-9/32, Negative loglikelihood: 9344.8623046875 
Epoch:4-10/32, Negative loglikelihood: 9656.63671875 
Epoch:4-11/32, Negative loglikelihood: 9656.4765625 
Epoch:4-12/32, Negative loglikelihood: 9657.43359375 
Epoch:4-13/32, Negative loglikelihood: 9656.671875 
Epoch:4-14/32, Negative loglikelihood: 9656.4111328125 
Epoch:4-15/32, Negative loglikelihood: 9344.6982421875 
Epoch:4-16/32, Negative loglikelihood: 9344.794921875 
Epoch:4-17/32, Negative loglikelihood: 9656.1865234375 
Epoch:4-18/32, Negative loglikelihood: 9345.4853515625 
Epoch:4-19/32, Negative loglikelihood: 9344.75 
Epoch:4-20/32, Negative loglike

Epoch:7-16/32, Negative loglikelihood: 9030.7314453125 
Epoch:7-17/32, Negative loglikelihood: 9653.564453125 
Epoch:7-18/32, Negative loglikelihood: 9653.28515625 
Epoch:7-19/32, Negative loglikelihood: 9653.4423828125 
Epoch:7-20/32, Negative loglikelihood: 9654.3466796875 
Epoch:7-21/32, Negative loglikelihood: 9653.625 
Epoch:7-22/32, Negative loglikelihood: 9653.310546875 
Epoch:7-23/32, Negative loglikelihood: 9653.431640625 
Epoch:7-24/32, Negative loglikelihood: 9653.2294921875 
Epoch:7-25/32, Negative loglikelihood: 9342.37109375 
Epoch:7-26/32, Negative loglikelihood: 9653.5400390625 
Epoch:7-27/32, Negative loglikelihood: 9653.29296875 
Epoch:7-28/32, Negative loglikelihood: 9652.9736328125 
Epoch:7-29/32, Negative loglikelihood: 9652.9609375 
Epoch:7-30/32, Negative loglikelihood: 9653.2333984375 
Epoch:7-31/32, Negative loglikelihood: 7473.2333984375 
--------------------------------------------------------------
Epoch:7 completed, Total training's Loss: 304554.48828125, S

Epoch:10-28/32, Negative loglikelihood: 9651.275390625 
Epoch:10-29/32, Negative loglikelihood: 9651.6337890625 
Epoch:10-30/32, Negative loglikelihood: 9651.2509765625 
Epoch:10-31/32, Negative loglikelihood: 8717.4765625 
--------------------------------------------------------------
Epoch:10 completed, Total training's Loss: 305428.443359375, Spend: 0.0631765325864156m
F1-Score: 0.9932473910374464
Classification report: -- 
              precision    recall  f1-score   support

        INDU       0.99      0.99      0.99       200
         LOC       1.00      0.99      0.99       203
        NAME       0.99      1.00      0.99       208
         ORG       1.00      1.00      1.00       203

   micro avg       0.99      0.99      0.99       814
   macro avg       0.99      0.99      0.99       814
weighted avg       0.99      0.99      0.99       814

Epoch:10, Acc:99.68, on Valid_set, Spend:0.008 minutes for evaluation
--------------------------------------------------------------
E

Epoch:14-0/32, Negative loglikelihood: 9339.796875 
Epoch:14-1/32, Negative loglikelihood: 9650.81640625 
Epoch:14-2/32, Negative loglikelihood: 9650.6171875 
Epoch:14-3/32, Negative loglikelihood: 9650.94921875 
Epoch:14-4/32, Negative loglikelihood: 9651.0615234375 
Epoch:14-5/32, Negative loglikelihood: 9650.91796875 
Epoch:14-6/32, Negative loglikelihood: 9650.80078125 
Epoch:14-7/32, Negative loglikelihood: 9650.455078125 
Epoch:14-8/32, Negative loglikelihood: 8716.8134765625 
Epoch:14-9/32, Negative loglikelihood: 9650.9462890625 
Epoch:14-10/32, Negative loglikelihood: 9650.3876953125 
Epoch:14-11/32, Negative loglikelihood: 9650.380859375 
Epoch:14-12/32, Negative loglikelihood: 9650.46484375 
Epoch:14-13/32, Negative loglikelihood: 9651.048828125 
Epoch:14-14/32, Negative loglikelihood: 9650.591796875 
Epoch:14-15/32, Negative loglikelihood: 9650.833984375 
Epoch:14-16/32, Negative loglikelihood: 9650.6962890625 
Epoch:14-17/32, Negative loglikelihood: 9650.8896484375 
Epoch:

(0.9962962962962963, 0.9894606323620583)