In [2]:
import os
from tqdm import tqdm
import numpy as np
import torch

In [3]:
root = "./data"
train_data_path = os.path.join(root, "train_data.txt")
test_data_path = os.path.join(root, "test_data.txt")

In [4]:
train_ratio = 0.7
batch_size = 128

In [5]:
def get_sentences(path):
    out = []
    with open(path, "r", encoding="utf-8-sig") as f:
        sentence = []
        labels = []
        for line in tqdm(f):
            s = line.strip().split("\t")
            if (len(s) == 2):
                word, label = s[0], s[1]
            else:
                word, label = "[PAD]", s[0]
            sentence.append(word)
            labels.append(label)
            if (word == '。'):
                out.append((sentence, labels))
                sentence = []
                labels = []
    return out

In [6]:
train_data = get_sentences(train_data_path)
test_data = get_sentences(test_data_path)

np.random.shuffle(train_data)
train_number = int(train_ratio * len(train_data))
valid_data = train_data[train_number : ]
train_data = train_data[ : train_number]
print(f"total sentences in train data: {len(train_data)}")
print(f"total sentences in valid data: {len(valid_data)}")
print(f"total sentences in test data: {len(test_data)}")


        


418362it [00:00, 2175048.04it/s]
132709it [00:00, 2225446.05it/s]

total sentences in train data: 5425
total sentences in valid data: 2325
total sentences in test data: 2480





In [7]:
from torch.utils.data import Dataset
from transformers import BertTokenizer

tot_labels = ['[PAD]', '[CLS]', '[SEP]', 'O', 'B-BODY','I-TEST', 'I-EXAMINATIONS',
            'I-TREATMENT', 'B-DRUG', 'B-TREATMENT', 'I-DISEASES', 'B-EXAMINATIONS',
            'I-BODY', 'B-TEST', 'B-DISEASES', 'I-DRUG']
label2idx = {label : idx for idx, label in enumerate(tot_labels)}
idx2label = {idx : label for idx, label in enumerate(tot_labels)}

class NERDataset(Dataset):
    def __init__(self, data, tokenizer, MAX_LEN=256-2):
        self.data = data
        self.tokenizer = tokenizer
        self.sentences = []
        self.labels = []
        for sentence, labels in self.data:
            if (len(sentence) > MAX_LEN):
                sentence = sentence[ : MAX_LEN]
                labels = labels[ : MAX_LEN]
            sentence = ["[CLS]"] + sentence + ["[SEP]"]
            labels = ["[CLS]"] + labels + ["[SEP]"]
            self.sentences.append(tokenizer.convert_tokens_to_ids(sentence))
            self.labels.append([label2idx[label] for label in labels])
             
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.sentences[index], self.labels[index]
    
    def collate_fn(self, batch):
        max_len = max([len(data[0]) for data in batch])
        PAD_ID = self.tokenizer.pad_token_id
        sentence_tensors = torch.LongTensor([data[0] + [PAD_ID] * (max_len - len(data[0])) for data in batch])
        labels_tensors = torch.LongTensor([data[1] + [label2idx["[PAD]"]] * (max_len - len(data[1])) for data in batch])
        masks = (sentence_tensors != PAD_ID)
        return sentence_tensors, labels_tensors, masks
    


In [8]:
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
train_dataset = NERDataset(train_data, tokenizer)
valid_dataset = NERDataset(valid_data, tokenizer)
test_dataset = NERDataset(test_data, tokenizer)
    

In [9]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.collate_fn, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=test_dataset.collate_fn, num_workers=4)

for data in train_dataloader:
    print(data[0])
    print(data[1])
    print(data[2])
    print(data[0].shape)
    print(data[1].shape)
    print(data[2].shape)
    break


tensor([[ 101,  123,  121,  ...,    0,    0,    0],
        [ 101, 8020, 7360,  ...,    0,    0,    0],
        [ 101, 4680, 1184,  ...,    0,    0,    0],
        ...,
        [ 101, 8024, 5357,  ...,    0,    0,    0],
        [ 101,  754,  123,  ...,    0,    0,    0],
        [ 101, 2642, 5442,  ...,    0,    0,    0]])
tensor([[1, 3, 3,  ..., 0, 0, 0],
        [1, 3, 4,  ..., 0, 0, 0],
        [1, 3, 3,  ..., 0, 0, 0],
        ...,
        [1, 3, 3,  ..., 0, 0, 0],
        [1, 3, 3,  ..., 0, 0, 0],
        [1, 3, 3,  ..., 0, 0, 0]])
tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])
torch.Size([128, 256])
torch.Size([128, 256])
torch.Size([128, 256])


In [10]:
from transformers import BertModel
from torchcrf import CRF
import torch.nn as nn

class BertCRF(nn.Module):
    def __init__(self, target_size, hidden_dim=768):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.crf = CRF(target_size, batch_first=True)
        self.linear = nn.Linear(hidden_dim, target_size)
        
    def forward(self, sentences, labels, mask):
        with torch.no_grad():
            output = self.bert(input_ids=sentences, attention_mask=mask)
        last_hidden_states = output.last_hidden_state
        out = self.linear(last_hidden_states)
        loss = -self.crf(out, labels, mask, reduction="mean")
        return loss
    
    def decode(self, sentences, mask):
        with torch.no_grad():
            output = self.bert(input_ids=sentences, attention_mask=mask)
            last_hidden_states = output.last_hidden_state
            out = self.linear(last_hidden_states)
            decode = self.crf.decode(out, mask)
            return decode

        

In [11]:
target_size = len(label2idx)
model = BertCRF(target_size=target_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=1e-5, eps=1e-6)

In [13]:
import logging

logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)

file_handler = logging.FileHandler('train.log')

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)

In [14]:
def evaluate(model, dataloader):
    model.eval()
    predicted = []
    truth = []
    for data in tqdm(dataloader):
        sentences, labels, masks = data
        sentences = sentences.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        out = model.decode(sentences, masks)
        for temp in out:
            predicted.extend(temp)
        y_origin = torch.masked_select(labels, masks)
        truth.extend(y_origin.cpu().tolist())
    predicted = np.array(predicted)
    truth = np.array(truth)
    return predicted, truth, (predicted == truth).mean()
        
        
        

In [95]:
Epoch = 30
eval_every = 1
early_stop = 5

best_acc = 0
best_epoch = -1
best_model_pt_path = "model.pt"

for epoch in range(Epoch):
    epoch_loss = 0
    for data in tqdm(train_dataloader):
        sentences, labels, masks = data
        sentences = sentences.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        loss = model(sentences, labels, masks)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    logger.info(f"Epoch:{epoch}, loss:{epoch_loss:.4f}")
    if (epoch % eval_every == 0):
        acc = evaluate(model, valid_dataloader)[-1]
        if (acc > best_acc):
            best_acc = acc
            best_epoch = epoch
            torch.save(model.state_dict(), best_model_pt_path)
        else:
            early_stop -= 1
        logger.info(f"Epoch:{epoch}, valid acc:{acc * 100:.2f}%, current best valid acc:{best_acc * 100:.2f}%, at epoch {best_epoch}")
        if (early_stop == 0):
            logger.info(f"Early stop!")
            break
        

        


100%|██████████| 43/43 [00:55<00:00,  1.28s/it]
100%|██████████| 19/19 [00:28<00:00,  1.50s/it]
100%|██████████| 43/43 [00:54<00:00,  1.27s/it]
100%|██████████| 19/19 [00:28<00:00,  1.50s/it]
100%|██████████| 43/43 [00:54<00:00,  1.27s/it]
100%|██████████| 19/19 [00:28<00:00,  1.50s/it]
100%|██████████| 43/43 [00:55<00:00,  1.29s/it]
100%|██████████| 19/19 [00:28<00:00,  1.49s/it]
100%|██████████| 43/43 [00:57<00:00,  1.33s/it]
100%|██████████| 19/19 [00:28<00:00,  1.51s/it]
100%|██████████| 43/43 [00:54<00:00,  1.28s/it]
100%|██████████| 19/19 [00:27<00:00,  1.47s/it]
100%|██████████| 43/43 [00:53<00:00,  1.25s/it]
100%|██████████| 19/19 [00:28<00:00,  1.48s/it]
100%|██████████| 43/43 [00:56<00:00,  1.32s/it]
100%|██████████| 19/19 [00:28<00:00,  1.48s/it]
100%|██████████| 43/43 [00:55<00:00,  1.30s/it]
100%|██████████| 19/19 [00:27<00:00,  1.47s/it]
100%|██████████| 43/43 [00:56<00:00,  1.30s/it]
100%|██████████| 19/19 [00:28<00:00,  1.48s/it]
100%|██████████| 43/43 [00:56<00:00,  1.

KeyboardInterrupt: 

In [16]:
from sklearn.metrics import classification_report

test_model = BertCRF(target_size=target_size)
test_model.load_state_dict(torch.load("model.pt"))
test_model = test_model.to(device)
predicted, truth, test_acc = evaluate(test_model, test_dataloader)
print(f"test accuracy: {test_acc * 100:.2f}%")
predicted = [idx2label[idx] for idx in predicted]
truth = [idx2label[idx] for idx in truth]
selected_labels = ['B-BODY','I-TEST', 'I-EXAMINATIONS',
            'I-TREATMENT', 'B-DRUG', 'B-TREATMENT', 'I-DISEASES', 'B-EXAMINATIONS',
            'I-BODY', 'B-TEST', 'B-DISEASES', 'I-DRUG']
report = classification_report(truth, predicted, labels=selected_labels)
print(report)



Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 20/20 [00:30<00:00,  1.54s/it]


test accuracy: 93.82%
                precision    recall  f1-score   support

        B-BODY       0.83      0.84      0.83      3078
        I-TEST       0.49      0.33      0.40      1604
I-EXAMINATIONS       0.85      0.76      0.80       833
   I-TREATMENT       0.90      0.84      0.87      1870
        B-DRUG       0.74      0.62      0.67       480
   B-TREATMENT       0.78      0.61      0.69       162
    I-DISEASES       0.83      0.76      0.79      6843
B-EXAMINATIONS       0.80      0.66      0.72       345
        I-BODY       0.69      0.75      0.72      2434
        B-TEST       0.46      0.27      0.34       567
    B-DISEASES       0.75      0.67      0.71      1310
        I-DRUG       0.85      0.73      0.79      1432

     micro avg       0.79      0.72      0.75     20958
     macro avg       0.75      0.65      0.69     20958
  weighted avg       0.78      0.72      0.74     20958

