In [1]:
import os
import logging
import warnings

# os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
warnings.filterwarnings("ignore")

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim

from datasets import load_dataset
from sklearn import metrics
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, AutoTokenizer, BertModel, get_linear_schedule_with_warmup
from transformers.optimization import get_cosine_schedule_with_warmup, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchcrf import CRF
from tqdm import tqdm

class config:
    root_dir = ''
    data_dir = root_dir + 'data/example_datasets_msra/'
    load_before = False
    # bert_model = 'microsoft/deberta-v3-large'
    # bert_model = 'microsoft/deberta-v2-xxlarge'
    # bert_model = 'microsoft/mdeberta-v3-base'
    # bert_model = 'WENGSYX/Deberta-Chinese-Large'
    bert_model = 'hfl/chinese-roberta-wwm-ext-large'
    # bert_model = 'bert-base-uncased'
    # bert_model = 'bert-base-chinese'
    model_dir = root_dir + 'checkpoints/' + bert_model.split('/')[-1] + '.pt'
    device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

    # train config
    output_dir = 'outputs/'
    overwrite_output_dir = True
    epoch = 11
    batch_size = 256 + 128
    fp16 = True
    val_split_size = 0.13
    test_split_size = 0.17
    learning_rate = 3e-5
    weight_decay = 0.01
    clip_grad = 5
    patience = 0.0002
    patience_num = 10
    max_sequence_length = 256
    warm_up_ratio = 0.1

    labels = ['<PAD>', '[CLS]', '[SEP]', 'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
    label2id = {tag: idx for idx, tag in enumerate(labels)}
    id2label = {idx: tag for idx, tag in enumerate(labels)}
    num_labels = len(label2id)

    tokenizer = BertTokenizer.from_pretrained(bert_model)

In [2]:
class NERDataset(Dataset):
    def __init__(self, words, labels, config, word_pad_idx=0, label_pad_idx=-1):
        self.tokenizer = config.tokenizer
        self.label2id = config.label2id
        self.id2label = {_id: _label for _label, _id in list(config.label2id.items())}
        self.dataset = self.preprocess(words, labels)
        self.word_pad_idx = word_pad_idx
        self.label_pad_idx = label_pad_idx
        self.device = config.device

    def preprocess(self, origin_sentences, origin_labels):
        sentences = []
        labels = []
        for line, tag in zip(origin_sentences, origin_labels):
            words = ['[CLS]'] + line[0:config.max_sequence_length-2] + ['[SEP]']
            label = [1] + tag[0:config.max_sequence_length-2] + [2]
            sentences.append(words)
            labels.append(label)

            start = config.max_sequence_length-2
            while len(tag) > start:
                sub_words = ['[SEP]'] + line[start:start+config.max_sequence_length-2] + ['[SEP]']
                sub_label = [2] + tag[start:start+config.max_sequence_length-2] + [2]
                sentences.append(sub_words)
                labels.append(sub_label)

                start += config.max_sequence_length

        data = [(sentence, label) for sentence, label in zip(sentences, labels)]
        return data

    def __getitem__(self, idx):
        words, tags = self.dataset[idx]
        token_ids = self.tokenizer.convert_tokens_to_ids(words)
        return token_ids, tags

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        token_tensors = torch.LongTensor([i[0] + [0] * (config.max_sequence_length - len(i[0])) for i in batch])
        label_tensors = torch.LongTensor([i[1] + [0] * (config.max_sequence_length - len(i[1])) for i in batch])
        mask = (token_tensors > 0)
        return token_tensors, label_tensors, mask

In [3]:
train_test_ds = load_dataset('msra_ner', split='train+test')
train_x, test_x, train_y, test_y = train_test_split(
    train_test_ds['tokens'], 
    train_test_ds['ner_tags'], 
    test_size=config.test_split_size, 
    # random_state=0,
    # shuffle=True, 
)

train_x, val_x, train_y, val_y = train_test_split(
    train_x, 
    train_y, 
    test_size=config.val_split_size, 
    # random_state=0,
    # shuffle=True, 
)

train_dataset = NERDataset(train_x, train_y, config)
val_dataset = NERDataset(val_x, val_y, config)
test_dataset = NERDataset(test_x, test_y, config)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, collate_fn=train_dataset.collate_fn, shuffle=False) # , num_workers=4
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, collate_fn=train_dataset.collate_fn, shuffle=False) # , num_workers=4
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, collate_fn=test_dataset.collate_fn, shuffle=False) # , num_workers=4

Reusing dataset msra_ner (/home/jovyan/.cache/huggingface/datasets/msra_ner/msra_ner/1.0.0/5ce47bc7f8da59fd9d0ad08d185fa72f5576b614f136a56e82c7669d22ea5cfe)


In [4]:
class BertBiLSTMxCRF(nn.Module):
    def __init__(self, config):
        super(BertBiLSTMxCRF, self).__init__()
        self.hidden_dim = 512

        self.bert = BertModel.from_pretrained(config.bert_model)

        self.embedding_dim = self.bert.config.hidden_size
        self.bilstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.hidden_dim // 2,
            batch_first=True,
            num_layers=2,
            dropout=0.5,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(self.hidden_dim, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)
    
    def _get_features(self, sentence):
        with torch.no_grad():
            embeds  = self.bert(sentence)
        enc, _ = self.bilstm(embeds[0])
        enc = self.dropout(enc)
        feats = self.linear(enc)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):
        emissions = self._get_features(sentence)
        if not is_test:
            loss=-self.crf.forward(emissions, tags, mask, reduction='mean')
            return loss
        else:
            decode=self.crf.decode(emissions, mask)
            return decode

In [5]:
def train(e, model, iterator, optimizer, scheduler, device):
    model.train()
    losses = 0.0
    step = 0
    for batch in (pbar := tqdm(iterator)):
        step += 1
        sentence, tags, mask = (i.to(device) for i in batch)

        loss = model(sentence, tags, mask)
        losses += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        pbar.set_description(f'epoch: {e}')
        pbar.set_postfix({'train loss': '{:.4f}'.format(loss.item())})

    # print({'avg train loss': '{:.4f}'.format(losses/step)})

def validate(e, model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    losses = 0
    step = 0
    with torch.no_grad():
        for batch in (pbar := tqdm(iterator)):
            step += 1

            sentence, tags, mask = (i.to(device) for i in batch)

            y_hat = model(sentence, tags, mask, is_test=True)

            loss = model(sentence, tags, mask)
            losses += loss.item()
            # Save prediction
            for j in y_hat:
              Y_hat.extend(j)
            # Save labels
            mask = (mask==1)
            y_orig = torch.masked_select(tags, mask)
            Y.append(y_orig.cpu())

            pbar.set_description(f'epoch: {e}')
            pbar.set_postfix({'val loss': '{:.4f}'.format(loss.item())})

    Y = torch.cat(Y, dim=0).numpy()
    Y_hat = np.array(Y_hat)
    acc = (Y_hat == Y).mean()*100

    print({'avg val loss': '{:.4f}'.format(losses/step), 'val acc': acc})
    return model, losses/step, acc

def test(model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    with torch.no_grad():
        for batch in (pbar := tqdm(iterator, desc='test')):
            sentence, tags, mask = (i.to(device) for i in batch)
            y_hat = model(sentence, tags, mask, is_test=True)
            # Save prediction
            for j in y_hat:
              Y_hat.extend(j)
            # Save labels
            mask = (mask==1)
            y_orig = torch.masked_select(tags, mask)
            Y.append(y_orig)

    Y = torch.cat(Y, dim=0).numpy()
    y_true = [config.id2label[i] for i in Y]
    y_pred = [config.id2label[i] for i in Y_hat]

    return y_true, y_pred


model = BertBiLSTMxCRF(config).to(config.device)
optimizer = AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=config.weight_decay)

len_dataset = len(train_dataset) 
total_steps = (len_dataset // config.batch_size) * config.epoch if len_dataset % config.batch_size == 0 else (len_dataset // config.batch_size + 1) * config.epoch
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = config.warm_up_ratio * total_steps, num_training_steps = total_steps)


best_val_loss = 1e18
best_val_acc = 1e-18
for epoch in range(1, config.epoch+1):
    train(epoch, model, train_loader, optimizer, scheduler, config.device)
    candidate_model, loss, acc = validate(epoch, model, val_loader, config.device)

    if loss < best_val_loss and acc > best_val_acc:
        best_model = candidate_model
        best_val_loss = loss
        best_val_acc = acc
        torch.save(model.state_dict(), config.model_dir)

    print()

y_test, y_pred = test(best_model, test_loader, config.device)
print(metrics.classification_report(y_test, y_pred, labels=config.label2id.keys(), digits=3))


Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext-large were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
epoch: 1: 100%|██████████| 92/92 [10:16<00:00,  6.70s/it, train loss=9.6145]  
epoch: 1: 100%|██████████| 19/19 [03:48<00:00, 12.01s/it, va

{'avg val loss': '8.4045', 'val acc': 94.45856297062394}
