In [20]:
import os
import time
import sklearn
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from sklearn import metrics


from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

device = torch.device("cuda")
max_train_epochs = 5
warmup_proportion = 0.05
gradient_accumulation_steps = 1
train_batch_size = 32
valid_batch_size = train_batch_size
test_batch_size = train_batch_size
data_workers= 2


learning_rate=2e-5
weight_decay=0.01
max_grad_norm=1.0

    
cur_time = time.strftime("%Y-%m-%d_%H:%M:%S")

base_path = '/home/zhy/anaconda3/envs/trans/neural-chinese-address-parsing/data/'

from transformers import BertConfig, BertTokenizer, BertModel, BertForTokenClassification
cls_token='[CLS]'
eos_token='[SEP]'
unk_token='[UNK]'
pad_token='[PAD]'
mask_token='[MASK]'
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
config = BertConfig.from_pretrained('bert-base-chinese')
TheModel = BertModel
ModelForTokenClassification = BertForTokenClassification

eos_id = tokenizer.convert_tokens_to_ids([eos_token])[0]
unk_id = tokenizer.convert_tokens_to_ids([unk_token])[0]
period_id = tokenizer.convert_tokens_to_ids(['.'])[0]
print(eos_id, unk_id, period_id)

102 100 119


In [21]:
labels = ['B-assist', 'I-assist', 'B-cellno', 'I-cellno', 'B-city', 'I-city', 'B-community', 'I-community', 'B-country', 'I-country', 'B-devZone', 'I-devZone', 'B-district', 'I-district', 'B-floorno', 'I-floorno', 'B-houseno', 'I-houseno', 'B-otherinfo', 'I-otherinfo', 'B-person', 'I-person', 'B-poi', 'I-poi', 'B-prov', 'I-prov', 'B-redundant', 'I-redundant', 'B-road', 'I-road', 'B-roadno', 'I-roadno', 'B-roomno', 'I-roomno', 'B-subRoad', 'I-subRoad', 'B-subRoadno', 'I-subRoadno', 'B-subpoi', 'I-subpoi', 'B-subroad', 'I-subroad', 'B-subroadno', 'I-subroadno', 'B-town', 'I-town']
label2id = {}
for i, l in enumerate(labels):
    label2id[l] = i
num_labels = len(labels)

In [22]:
def get_data_list(f):
    data_list = []
    origin_token, token, label = [], [], []
    for l in f:
        l = l.strip().split()
        if not l:
            data_list.append([token, label, origin_token])
            origin_token, token, label = [], [], []
            continue
        for i, tok in enumerate(l[0]):
            token.append(tok)
            label.append(label2id[l[1]])
        origin_token.append(l[0])
    assert len(token) == 0
    return data_list

f_train = open(base_path + 'train.txt')
f_test = open(base_path + 'test.txt')
f_dev = open(base_path + 'dev.txt')

train_list = get_data_list(f_train)
test_list = get_data_list(f_test)
dev_list = get_data_list(f_dev)
print(len(train_list), len(test_list), len(dev_list))
max_token_len = 0
for ls in [train_list, test_list, dev_list]:
    for l in ls:
        max_token_len = max(max_token_len, len(l[0]))
print('max_token_len', max_token_len)

8957 2985 2985
max_token_len 76


In [23]:
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]
        sentence = example[0]
        #vaild_id = example[1]
        label = example[1]
        
        pad_len = max_token_len - len(sentence)
        total_len = len(sentence)+2
        
        input_token = [cls_token] + sentence + [eos_token] + [pad_token] * pad_len
        input_ids = tokenizer.convert_tokens_to_ids(input_token)
        attention_mask = [1] + [1] * len(sentence) + [1] + [0] * pad_len

        label = [-100] + label + [-100] + [-100] * pad_len
        assert max_token_len + 2 == len(input_ids) == len(attention_mask) == len(input_token)
        
        return input_ids, attention_mask, total_len, label, index

def the_collate_fn(batch):
    total_lens = [b[2] for b in batch]
    total_len = max(total_lens)
    input_ids = torch.LongTensor([b[0] for b in batch])
    attention_mask = torch.LongTensor([b[1] for b in batch])
    label = torch.LongTensor([b[3] for b in batch])
    input_ids = input_ids[:,:total_len]
    attention_mask = attention_mask[:,:total_len]
    label = label[:,:total_len]

    indexs = [b[4] for b in batch]

    return input_ids, attention_mask, label, indexs

train_dataset = MyDataSet(train_list)
train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle = True,
    num_workers=data_workers,
    collate_fn=the_collate_fn,
)

test_dataset = MyDataSet(test_list)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=train_batch_size,
    shuffle = False,
    num_workers=data_workers,
    collate_fn=the_collate_fn,
)

In [24]:
def eval():
    result = []
    for step, batch in enumerate(tqdm(test_data_loader)):
        input_ids, attention_mask, label = (b.to(device) for b in batch[:-1])
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            logits = F.softmax(logits, dim=-1)
        logits = logits.data.cpu()
        logit_list = []
        sum_len = 0
        for m in attention_mask:
            l = m.sum().cpu().item()
            logit_list.append(logits[sum_len:sum_len+l])
            sum_len += l
        assert sum_len == len(logits)
        for i, l in enumerate(logit_list):
            rr = torch.argmax(l, dim=1)
            for j, w in enumerate(test_list[batch[-1][i]][0]):
                result.append([w, labels[label[i][j+1].cpu().item()],labels[rr[j+1]]])
            result.append([])
    print(result[:20])
    return result

def log(msg):

    print(msg)

In [25]:
class BertForSeqTagging(ModelForTokenClassification):
    def __init__(self):
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = TheModel.from_pretrained('bert-base-chinese')
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.init_weights()
            
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        batch_size, max_len, feature_dim = sequence_output.shape
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        active_loss = attention_mask.view(-1) == 1
        active_logits = logits.view(-1, self.num_labels)[active_loss]

        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            active_labels = labels.view(-1)[active_loss]
            loss = loss_fct(active_logits, active_labels)
            return loss
        else:
            return active_logits
        
model = BertForSeqTagging()
model.to(device)
t_total = len(train_data_loader) // gradient_accumulation_steps * max_train_epochs + 1

num_warmup_steps = int(warmup_proportion * t_total)
log('warmup steps : %d' % num_warmup_steps)

no_decay = ['bias', 'LayerNorm.weight'] 
param_optimizer = list(model.named_parameters())
optimizer_grouped_parameters = [
    {'params':[p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay': weight_decay},
    {'params':[p for n, p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

warmup steps : 70




In [26]:
for epoch in range(max_train_epochs):
    # train
    epoch_loss = None
    epoch_step = 0
    start_time = time.time()
    model.train()
    for step, batch in enumerate(tqdm(train_data_loader)):
        input_ids, attention_mask, label = (b.to(device) for b in batch[:-1])
        loss = model(input_ids, attention_mask, label)
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step() 
            optimizer.zero_grad()
            
        if epoch_loss is None:
            epoch_loss = loss.item()
        else:
            epoch_loss = 0.98*epoch_loss + 0.02*loss.item()
        epoch_step += 1
    
    used_time = (time.time() - start_time)/60
    log('Epoch = %d Epoch Mean Loss %.4f Time %.2f min' % (epoch, epoch_loss, used_time))
    result = eval()
    with open('result.txt', 'w') as f:
        for r in result:
            f.write('\t'.join(r) + '\n')
    y_true = []
    y_pred = []
    for r in result:
        if not r: continue
        y_true.append(label2id[r[1]])
        y_pred.append(label2id[r[2]])
    print(sklearn.metrics.f1_score(y_true, y_pred, average='micro'))


100%|██████████| 280/280 [00:42<00:00,  6.61it/s]


Epoch = 0 Epoch Mean Loss 0.3638 Time 0.71 min


100%|██████████| 94/94 [00:06<00:00, 13.51it/s]


[['龙', 'B-town', 'B-town'], ['港', 'I-town', 'I-town'], ['镇', 'I-town', 'I-town'], ['泰', 'B-poi', 'B-poi'], ['和', 'I-poi', 'I-poi'], ['小', 'I-poi', 'I-poi'], ['区', 'I-poi', 'I-poi'], ['B', 'B-houseno', 'B-houseno'], ['懂', 'I-houseno', 'I-houseno'], ['1', 'B-roomno', 'B-roomno'], ['0', 'B-roomno', 'B-roomno'], ['9', 'B-roomno', 'B-roomno'], ['7', 'B-roomno', 'B-roomno'], [], ['浙', 'B-prov', 'B-prov'], ['江', 'I-prov', 'I-prov'], ['省', 'I-prov', 'I-prov'], ['嘉', 'B-city', 'B-city'], ['兴', 'I-city', 'I-city'], ['市', 'I-city', 'I-city']]
0.9014303038949921


100%|██████████| 280/280 [00:43<00:00,  6.43it/s]


Epoch = 1 Epoch Mean Loss 0.2595 Time 0.73 min


100%|██████████| 94/94 [00:06<00:00, 13.98it/s]


[['龙', 'B-town', 'B-town'], ['港', 'I-town', 'I-town'], ['镇', 'I-town', 'I-town'], ['泰', 'B-poi', 'B-poi'], ['和', 'I-poi', 'I-poi'], ['小', 'I-poi', 'I-poi'], ['区', 'I-poi', 'I-poi'], ['B', 'B-houseno', 'B-houseno'], ['懂', 'I-houseno', 'I-houseno'], ['1', 'B-roomno', 'B-roomno'], ['0', 'B-roomno', 'B-roomno'], ['9', 'B-roomno', 'B-roomno'], ['7', 'B-roomno', 'B-roomno'], [], ['浙', 'B-prov', 'B-prov'], ['江', 'I-prov', 'I-prov'], ['省', 'I-prov', 'I-prov'], ['嘉', 'B-city', 'B-city'], ['兴', 'I-city', 'I-city'], ['市', 'I-city', 'I-city']]
0.908653160222571


100%|██████████| 280/280 [00:43<00:00,  6.41it/s]


Epoch = 2 Epoch Mean Loss 0.2036 Time 0.73 min


100%|██████████| 94/94 [00:07<00:00, 12.29it/s]


[['龙', 'B-town', 'B-town'], ['港', 'I-town', 'I-town'], ['镇', 'I-town', 'I-town'], ['泰', 'B-poi', 'B-poi'], ['和', 'I-poi', 'I-poi'], ['小', 'I-poi', 'I-poi'], ['区', 'I-poi', 'I-poi'], ['B', 'B-houseno', 'B-houseno'], ['懂', 'I-houseno', 'I-houseno'], ['1', 'B-roomno', 'B-roomno'], ['0', 'B-roomno', 'B-roomno'], ['9', 'B-roomno', 'B-roomno'], ['7', 'B-roomno', 'B-roomno'], [], ['浙', 'B-prov', 'B-prov'], ['江', 'I-prov', 'I-prov'], ['省', 'I-prov', 'I-prov'], ['嘉', 'B-city', 'B-city'], ['兴', 'I-city', 'I-city'], ['市', 'I-city', 'I-city']]
0.9110251105721215


100%|██████████| 280/280 [00:45<00:00,  6.15it/s]


Epoch = 3 Epoch Mean Loss 0.1719 Time 0.76 min


100%|██████████| 94/94 [00:07<00:00, 11.96it/s]


[['龙', 'B-town', 'B-town'], ['港', 'I-town', 'I-town'], ['镇', 'I-town', 'I-town'], ['泰', 'B-poi', 'B-poi'], ['和', 'I-poi', 'I-poi'], ['小', 'I-poi', 'I-poi'], ['区', 'I-poi', 'I-poi'], ['B', 'B-houseno', 'B-houseno'], ['懂', 'I-houseno', 'I-houseno'], ['1', 'B-roomno', 'B-roomno'], ['0', 'B-roomno', 'B-roomno'], ['9', 'B-roomno', 'B-roomno'], ['7', 'B-roomno', 'B-roomno'], [], ['浙', 'B-prov', 'B-prov'], ['江', 'I-prov', 'I-prov'], ['省', 'I-prov', 'I-prov'], ['嘉', 'B-city', 'B-city'], ['兴', 'I-city', 'I-city'], ['市', 'I-city', 'I-city']]
0.9120951633613925


100%|██████████| 280/280 [00:44<00:00,  6.26it/s]


Epoch = 4 Epoch Mean Loss 0.1385 Time 0.75 min


100%|██████████| 94/94 [00:06<00:00, 13.74it/s]


[['龙', 'B-town', 'B-town'], ['港', 'I-town', 'I-town'], ['镇', 'I-town', 'I-town'], ['泰', 'B-poi', 'B-poi'], ['和', 'I-poi', 'I-poi'], ['小', 'I-poi', 'I-poi'], ['区', 'I-poi', 'I-poi'], ['B', 'B-houseno', 'B-houseno'], ['懂', 'I-houseno', 'I-houseno'], ['1', 'B-roomno', 'B-roomno'], ['0', 'B-roomno', 'B-roomno'], ['9', 'B-roomno', 'B-roomno'], ['7', 'B-roomno', 'B-roomno'], [], ['浙', 'B-prov', 'B-prov'], ['江', 'I-prov', 'I-prov'], ['省', 'I-prov', 'I-prov'], ['嘉', 'B-city', 'B-city'], ['兴', 'I-city', 'I-city'], ['市', 'I-city', 'I-city']]
0.9142531031530888


In [27]:
torch.save(model.state_dict(), 'Neural_Chinese_Address_Parsing_BERT_state_dict.pkl')