In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json
import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import BertModel, AutoTokenizer

from crf import CRF
from pytorchtools import EarlyStopping

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch.cuda.is_available())

import logging
def logger(content):
    logging.getLogger('matplotlib.font_manager').disabled = True
    log_format = '[%(asctime)s] %(message)s'
    date_format = '%Y%m%d %H:%M:%S'
    logging.basicConfig(level = logging.DEBUG, format = log_format, datefmt = date_format)
    logging.info(content)

True


In [2]:
args = {
    'path': '/data/pretrained/bert-base-chinese/',
    'bert_out_dim': 768,
    'n_class': 9, 
    'dropout': 0.2
}

In [3]:
dataset = []
words = set()
varietys = set()
categorys = set()
with open('./Dataset/comment_labeled.json', 'r') as f:
    for line in f:
        data = json.loads(line)
        comment_id = data['comment_id']
        comment_variety = data['comment_variety'] # 花生
        user_star = data['user_star'] # 2
        comment_text = data['comment_text']
        comment_units = data['comment_units'] # 多个四元组
        dataset.append({
            'id': comment_id, 
            'variety': comment_variety, 
            'user_star': user_star, 
            'text': comment_text, 
            'comment': comment_units, 
        })
        words |= set(comment_text)
        varietys.add(comment_variety)
        categorys |= {i['aspect'] for i in comment_units}

logger('Load dataset: {}'.format(len(dataset)))
logger('words: {}, variety: {}, category: {}'.format(len(words), len(varietys), len(categorys)))

word_list = list(words)
word_list.insert(0, '[PAD]')
word_list.insert(1, '[UNK]')
word_dict = {word_list[i]: i for i in range(len(word_list))}


variety_list = list(varietys)
variety_list.insert(0, '<PAD>')
variety_list.insert(1, '<OOV>')
variety_dict = {variety_list[i]: i for i in range(len(variety_list))}

[20240406 22:21:03] Load dataset: 8670
[20240406 22:21:03] words: 2250, variety: 26, category: 9


In [4]:
def load_glove(word_to_ix, dim = 100):
    if dim == 100:
        path = '/data/pretrained/Glove/glove.6B.100d.txt'
    elif dim == 300:
        path = '/data/pretrained/Glove/glove.840B.300d.txt'
    word_emb = []
    word_emb = torch.zeros((len(word_to_ix), dim), dtype = torch.float)
    with open(path, 'r') as f:
        for line in f:
            data = line.strip().split(' ') # [word emb1 emb2 ... emb n]
            word = data[0]
            if word in word_to_ix:
                word_emb[word_to_ix[word]] = torch.tensor([float(i) for i in data[1:]])
    return word_emb
word_emb = load_glove(word_dict, 300)

logger('Load Glove Word embedding: {}'.format(word_emb.shape))

[20240406 22:21:40] Load Glove Word embedding: torch.Size([2252, 300])


In [5]:
# tag BIO label
tag_list = ['O'] + [i + '-A' for i in ['B', 'I', 'E', 'S']] + [i + '-O' for i in ['B', 'I', 'E', 'S']]
tag_dict = {tag_list[i]: i for i in range(len(tag_list))}
print(tag_dict)

def tagging(text, aspects, opinions):
    tags = ['O'] * len(text)
    for t in sorted(aspects, key = lambda x: x['tail'] - x['head']):
        if t['tail'] - t['head'] == 1:
            tags[t['head']] = 'S-A'
        else:
            tags[t['head']] = 'B-A'
            for i in range(t['head'] + 1, t['tail'] - 1):
                tags[i] = 'I-A'
            tags[t['tail'] - 1] = 'E-A'
    for o in sorted(opinions, key = lambda x: x['tail'] - x['head']):
        if o['tail'] - o['head'] == 1:
            tags[o['head']] = 'S-O'
        else:
            tags[o['head']] = 'B-O'
            for i in range(o['head'] + 1, o['tail'] - 1):
                tags[i] = 'I-O'
            tags[o['tail'] - 1] = 'E-O'
    return tags

sen_len = 40
tokenizer = AutoTokenizer.from_pretrained('/data/pretrained/bert-base-chinese/')
tagged_dataset = []
for item in dataset:
    aspects = []
    opinions = []
    for comment in item['comment']: # 每个 comment 是一个四元组，描述一个方面
        aspects += comment['target']
        opinions += comment['opinion']
    text = item['text']
    tags = tagging(text, aspects, opinions) + ['O'] * (sen_len - len(text))
    word_ids = [word_dict[word] if word in word_dict else word_dict['[UNK]'] for word in item['text']] + \
               [word_dict['[PAD]']] * (sen_len - len(text))
    word_masks = [1] * len(text) + [0] * (sen_len - len(text))
    tag_ids = [tag_dict[i] for i in tags]
    tagged_dataset.append({
        'text': item['text'],
        'word_ids': word_ids, 
        'word_masks': word_masks,
        'labels': tag_ids
    })
print(tagged_dataset[0])

{'O': 0, 'B-A': 1, 'I-A': 2, 'E-A': 3, 'S-A': 4, 'B-O': 5, 'I-O': 6, 'E-O': 7, 'S-O': 8}
{'text': '这次买的没有之前买的品质好，以前每一个颗粒饱满，这次的质量参差不齐。', 'word_ids': [1482, 473, 2186, 562, 1207, 1686, 979, 752, 2186, 562, 2091, 57, 2074, 468, 316, 752, 543, 18, 1453, 2124, 2104, 597, 516, 468, 1482, 473, 562, 57, 2187, 1145, 493, 1509, 1141, 1066, 0, 0, 0, 0, 0, 0], 'word_masks': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 'labels': [5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 6, 7, 0, 0, 0, 0, 0, 0, 0]}


In [6]:
def label_sentence_entity(text, tags, tag_list):
    tags = [tag_list[i] for i in tags]
    entity = []
    count = len(text)
    i = 0
    while i < count:
        if tags[i][0] == 'B':
            j = i + 1
            while j < count:
                if tags[j][0] == 'E':
                    break
                else:
                    j += 1
            entity.append({
                "text": ''.join(text[i: j]),
                "start_index": i,
                "end_index": j,
                "label": tags[i][2:]
            })
            i = j + 1
        elif tags[i][0] == 'S':
            entity.append({
                "text": text[i],
                "start_index": i,
                "end_index": i,
                "label": tags[i][2:]
            })
            i += 1
        else:
            i += 1
    return entity

# print(tokens[0], labels[0])
# label_sentence_entity()

In [7]:
def ner_metrics(pred_entities, gold_entities):
    correct_num = 0
    predict_num = 0
    gold_num = 0
    for i in range(len(pred_entities)):
        gold_entity = gold_entities[i]
        pred_entity = pred_entities[i]
        gold_num += len(gold_entity)
        predict_num += len(pred_entity)
        for entity in gold_entity:
            if entity in pred_entity:
                correct_num += 1
    precision = correct_num / (predict_num + 0.000000001)
    recall = correct_num / (gold_num + 0.000000001)
    f1 = 2 * precision * recall / (precision + recall + 0.000000001)
    return precision, recall, f1

# decode from BIOES
def decode_ner(pred, labels, texts):
    pred_entities = []
    gold_entities = []
    for j in range(pred.shape[0]):
        gold_entity = label_sentence_entity(texts[j], labels[j].tolist(), tag_list)
        pred_entity = label_sentence_entity(texts[j], pred[j], tag_list) # text, start_index, end_index, label
        pred_entities.append(pred_entity)
        gold_entities.append(gold_entity)
        # precision, recall, f1 = ner_metrics(pred_entities, gold_entities)
    return pred_entities, gold_entities

# Model

### Train method

In [8]:
import time
def save_model(model, path):
    ts = time.strftime('%m%d%H%M', time.localtime())
    torch.save(model.state_dict(), '{}.{}'.format(path, ts))
    print('Save model to', '{}.{}'.format(path, ts))

In [9]:
def evaluate(model, test_loader, tag_list):
    model.eval()
    correct_num, predict_num, gold_num = 0, 0, 0
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            tokens, masks, labels, texts = batch['tokens'], batch['masks'], batch['labels'], batch['texts']
            sen_len = max(len(text) for text in texts)
            labels = labels[:, :sen_len]
            pred = model.predict(**batch)[:, :sen_len]
            # print(pred)
            for j in range(labels.shape[0]):
                gold_entity = label_sentence_entity(texts[j], labels[j].tolist(), tag_list)
                pred_entity = label_sentence_entity(texts[j], pred[j], tag_list)
                gold_num += len(gold_entity)
                predict_num += len(pred_entity)
                for entity in gold_entity:
                    if entity in pred_entity:
                        correct_num += 1
                # print(gold_entity)
                # print(pred_entity)
                # return
    precision = correct_num / (predict_num + 0.000000001)
    recall = correct_num / (gold_num + 0.000000001)
    f1 = 2 * precision * recall / (precision + recall + 0.000000001)
    logger('[Test] Precision: {:.6f} Recall: {:.6f} F1: {:.6f}'.format(precision, recall, f1))
    return precision, recall, f1

# evaluate(model, test_loader, tag_list)

In [10]:
def train_lstm(model, train_loader, valid_loader, epochs = 100, lr = 1e-4, patience = 5):
    optimizer = optim.Adam(model.parameters(), lr = 3e-4)
    early_stopping = EarlyStopping(patience = patience, verbose = False)
    entrophy = nn.CrossEntropyLoss()
    avg_train_losses = []
    avg_valid_losses = []
    for epoch in range(epochs):
        train_correct, train_total, valid_correct, valid_total = 0, 0, 0, 0
        train_losses = []
        valid_losses = []
        model.train()
        for _, batch in enumerate(train_loader):
            tokens, masks, labels = batch
            sen_len = torch.max(torch.sum(masks, dim = 1, dtype = torch.int64)).item()
            tokens = tokens[:, :sen_len]
            masks = masks[:, :sen_len]
            labels = labels[:, :sen_len]
            optimizer.zero_grad()
            output = model(tokens, masks) # (n_batch, n_token, n_class)
            loss = entrophy(output.permute(0, 2, 1), labels)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
            predict = model.predict(tokens, masks, labels) # (n_batch, n_tokens)
            train_correct += torch.sum(predict[masks == 1] == labels[masks == 1]).item()
            train_total += torch.sum(masks == 1).item()
        avg_train_loss = np.average(train_losses)
        avg_train_losses.append(avg_train_loss)

        model.eval()
        with torch.no_grad():
            gold_num = 0
            predict_num = 0
            correct_num = 0
            for i, batch in enumerate(valid_loader):
                tokens, masks, labels = batch
                sen_len = torch.max(torch.sum(masks, dim = 1, dtype = torch.int64)).item()
                tokens = tokens[:, :sen_len]
                masks = masks[:, :sen_len]
                labels = labels[:, :sen_len]
                output = model(tokens, masks)
                loss = entrophy(output.permute(0, 2, 1), labels)
                valid_losses.append(loss.item())
                predict = torch.max(output, dim = 2).indices # (n_batch, n_tokens)
                valid_correct += torch.sum(predict[masks == 1] == labels[masks == 1]).item()
                valid_total += torch.sum(masks == 1).item()
                for j in range(labels.shape[0]):
                    gold_entity = label_sentence_entity(text[j], labels[j].tolist(), tag_list)
                    pred_entity = label_sentence_entity(text[j], predict[j], tag_list)
                    gold_num += len(gold_entity)
                    predict_num += len(pred_entity)
                    for entity in gold_entity:
                        if entity in pred_entity:
                            correct_num += 1
            avg_valid_loss = np.average(valid_losses)
            avg_valid_losses.append(avg_valid_loss)
        precision = correct_num / (predict_num + 0.000000001)
        recall = correct_num / (gold_num + 0.000000001)
        f1 = 2 * precision * recall / (precision + recall + 0.000000001)
        logger('[Test] Precision: {:.8f} Recall: {:.8f} F1: {:.8f}'.format(precision, recall, f1))
        
        logger('[epoch {:d}] TLoss: {:.3f} VLoss: {:.3f} TAcc: {:.3f} VAcc: {:.3f}'.format(
                epoch + 1, avg_train_loss, avg_valid_loss, train_correct / train_total, valid_correct / valid_total))
        early_stopping(avg_valid_loss, model)
        if early_stopping.early_stop:
            logger("Early stopping")
            break


In [11]:
def train(model, train_loader, valid_loader, loss_mode = 'linear', epochs = 100, lr = 1e-5):
    print('Train {} model with lr = {}'.format(model.__class__.__name__, lr))
    optimizer = optim.Adam(model.parameters(), lr)
    early_stopping = EarlyStopping(patience = 5, verbose = False)
    entrophy = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        valid_correct = 0
        valid_total = 0
        train_correct = 0
        train_total = 0
        train_losses = []
        valid_losses = []
        model.train()
        for _, batch in enumerate(train_loader):
            optimizer.zero_grad()
            if loss_mode == 'linear':
                output = model(**batch) # (n_batch, n_token, n_class)
                loss = entrophy(output.permute(0, 2, 1), batch['labels'])
            elif loss_mode == 'crf':
                loss = model(**batch) # (n_batch, n_token, n_class)
            else:
                raise ValueError('Invalid loss mode %s', loss_mode)
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
            predict = model.predict(**batch) # (n_batch, n_tokens)
            train_correct += torch.sum(torch.logical_and(predict == batch['labels'][:, :predict.shape[1]], batch['masks'] == 1)).item()
            train_total += torch.sum(batch['masks'] == 1).item()
        avg_train_loss = np.average(train_losses)

        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(valid_loader):
                if loss_mode == 'linear':
                    output = model(**batch) # (n_batch, n_token, n_class)
                    loss = entrophy(output.permute(0, 2, 1), batch['labels'])
                elif loss_mode == 'crf':
                    loss = model(**batch) # (n_batch, n_token, n_class)
                else:
                    raise ValueError('Invalid loss mode %s', loss_mode)
                valid_losses.append(loss.item())
                predict = model.predict(**batch)
                valid_correct += torch.sum(torch.logical_and(predict == batch['labels'][:, :predict.shape[1]], batch['masks'] == 1)).item()
                valid_total += torch.sum(batch['masks'] == 1).item()
            avg_valid_loss = np.average(valid_losses)
        precision, recall, f1 = evaluate(model, valid_loader, tag_list)
        
        # logger('[epoch {:d}] TLoss: {:.3f} VLoss: {:.3f} TAcc: {:.3f} VAcc: {:.3f}'.format(
        #     epoch + 1, avg_train_loss, avg_valid_loss, train_correct / train_total, valid_correct / valid_total))
        logger('[epoch {:d}] TLoss: {:.3f} VLoss: {:.3f} TAcc: {:.3f} VAcc: {:.3f}'.format(
            epoch + 1, avg_train_loss, avg_valid_loss, train_correct / train_total, valid_correct / valid_total))
        logger('Precision: {:.3f} Recall: {:.3f} F1: {:.3f}'.format(precision, recall, f1))
        early_stopping(-valid_correct / valid_total, model)
        if early_stopping.early_stop:
            logger("Early stopping")
            break
    save_model(model, './results/{}.{}'.format(model.__class__.__name__, lr))
    return model


## LSTM + Linear

In [9]:
def collate_lstm(batch):
    tokens = torch.tensor([item['word_ids'] for item in batch], dtype = torch.long, device = device)
    masks = torch.tensor([item['word_masks'] for item in batch], dtype = torch.bool, device = device)
    labels = torch.tensor([item['labels'] for item in batch], dtype = torch.long, device = device)
    sen_len = torch.max(torch.sum(masks, dim = 1, dtype = torch.int64)).item()
    tokens = tokens[:, :sen_len]
    masks = masks[:, :sen_len]
    labels = labels[:, :sen_len]
    return {
        'tokens': tokens, 
        'masks': masks, 
        'labels': labels,
        'texts': [item['text'] for item in batch],
    }

n_train, n_dev = int(0.6 * len(tagged_dataset)), int(0.2 * len(tagged_dataset))
batch_size = 8
train_loader = DataLoader(tagged_dataset[:n_train], batch_size = batch_size, collate_fn = collate_lstm)
valid_loader = DataLoader(tagged_dataset[n_train: n_train + n_dev], batch_size = batch_size, collate_fn = collate_lstm)
test_loader = DataLoader(tagged_dataset[n_train + n_dev:], batch_size = batch_size, collate_fn = collate_lstm)

In [8]:
class LSTMLinear(nn.Module):
    def __init__(self, word_emb, n_class = 9, dropout = 0.2, num_layers = 1, hidden_dim = 200, emb_dim = 300):
        super().__init__()
        self.word_embedding = nn.Embedding.from_pretrained(word_emb)
        self.dropout1 = nn.Dropout(p = dropout)
        self.dropout2 = nn.Dropout(p = dropout)
        self.lstm = nn.LSTM(emb_dim, hidden_dim // 2,
                            num_layers = num_layers, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, n_class)
        
    def forward(self, tokens, masks):
        embeds = self.word_embedding(tokens)
        embeds = self.dropout1(embeds) # (batch_size, sen_len, 256)
        sen_len = torch.sum(masks, dim = 1, dtype = torch.int64).to('cpu') # (batch_size)
        pack_seq = pack_padded_sequence(embeds, sen_len, batch_first = True, enforce_sorted = False)
        lstm_out, _ = self.lstm(pack_seq)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first = True) # (batch_size, seq_len, hidden_size)
        lstm_feats = self.hidden2tag(lstm_out) # （batch_size, seq_len, tagset_size)
        lstm_feats = self.dropout2(lstm_feats)

        return lstm_feats
    
    def predict(self, tokens, masks):
        lstm_feats = self.forward(tokens, masks)
        predict = torch.argmax(lstm_feats, dim = 2) # (n_batch, n_tokens)
        return predict

## Bert Model

In [12]:
# load dataset
args = {
    'path': '/data/pretrained/bert-base-chinese/',
    'bert_out_dim': 768,
    'n_class': 9, 
    'dropout': 0.2
}
tokenizer = AutoTokenizer.from_pretrained('/data/pretrained/bert-base-chinese/')
def collate(batch):
    tokens = tokenizer([item['text'] for item in batch], 
                       padding = 'max_length', truncation = True, max_length = 40, return_tensors = 'pt')
    labels = torch.tensor([item['labels'] for item in batch], dtype = torch.long, device = device)
    # entity_embeds = torch.tensor([item['entity_embeds'] for item in batch], dtype = torch.float, device = device)
    # return tokens['input_ids'].to(device), tokens['attention_mask'].to(device), labels, [item['text'] for item in batch]
    return {
        'tokens': tokens['input_ids'].to(device),
        'masks': tokens['attention_mask'].to(device),
        'labels': labels,
        'texts': [item['text'] for item in batch],
        # 'entity_embeds': entity_embeds,
    }

n_train, n_dev = int(0.6 * len(tagged_dataset)), int(0.2 * len(tagged_dataset))
batch_size = 8
train_loader = DataLoader(tagged_dataset[:n_train], batch_size = batch_size, collate_fn = collate)
valid_loader = DataLoader(tagged_dataset[n_train: n_train + n_dev], batch_size = batch_size, collate_fn = collate)
test_loader = DataLoader(tagged_dataset[n_train + n_dev:], batch_size = batch_size, collate_fn = collate)

### Bert Linear

In [13]:
class BertLinear(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.bert = BertModel.from_pretrained('/data/pretrained/bert-base-chinese/')
        self.dropout = nn.Dropout()
        self.cls = nn.Linear(args['bert_out_dim'], args['n_class'])
        
    def forward(self, tokens, masks, labels = None, texts = None):
        bert_out = self.bert(tokens, masks)['last_hidden_state'] # (n_batch, n_tokens, n_emb)
        bert_out = self.dropout(bert_out)
        cls_out = self.cls(bert_out)
        return cls_out

    def predict(self, **batch):
        cls_out = self.forward(batch['tokens'], batch['masks'])
        pred = torch.max(cls_out, dim = 2).indices
        return pred

# model = BertLinear(args, tokenizer, tag_dict).to(device)


In [14]:
# BertLinear
model = BertLinear(args).to(device)
train(model, train_loader, valid_loader, epochs = 100, lr = 1e-6)
evaluate(model, test_loader, tag_list)

Some weights of the model checkpoint at /data/pretrained/bert-base-chinese/ were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Train BertLinear model with lr = 1e-06


[20240406 22:23:44] [Test] Precision: 0.214992 Recall: 0.107028 F1: 0.142911
[20240406 22:23:44] [epoch 1] TLoss: 0.941 VLoss: 0.587 TAcc: 0.527 VAcc: 0.645
[20240406 22:23:44] Precision: 0.215 Recall: 0.107 F1: 0.143
[20240406 22:24:57] [Test] Precision: 0.399838 Recall: 0.467702 F1: 0.431116
[20240406 22:24:57] [epoch 2] TLoss: 0.551 VLoss: 0.507 TAcc: 0.658 VAcc: 0.684
[20240406 22:24:57] Precision: 0.400 Recall: 0.468 F1: 0.431
[20240406 22:26:12] [Test] Precision: 0.454973 Recall: 0.550294 F1: 0.498114
[20240406 22:26:12] [epoch 3] TLoss: 0.473 VLoss: 0.488 TAcc: 0.698 VAcc: 0.696
[20240406 22:26:12] Precision: 0.455 Recall: 0.550 F1: 0.498
[20240406 22:27:28] [Test] Precision: 0.492835 Recall: 0.599356 F1: 0.540901
[20240406 22:27:28] [epoch 4] TLoss: 0.438 VLoss: 0.478 TAcc: 0.718 VAcc: 0.703
[20240406 22:27:28] Precision: 0.493 Recall: 0.599 F1: 0.541
[20240406 22:28:43] [Test] Precision: 0.508405 Recall: 0.612995 F1: 0.555823
[20240406 22:28:43] [epoch 5] TLoss: 0.416 VLoss: 0

Save model to ./results/BertLinear.1e-06.04062234


[20240406 22:34:58] [Test] Precision: 0.540802 Recall: 0.523449 F1: 0.531984


(0.5408024481467968, 0.5234490702648472, 0.5319842791220859)

## BERT ABSA

In [16]:
# 直接使用MLP对于进行分类
class BertLinear2(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.bert = BertModel.from_pretrained('/data/pretrained/bert-base-chinese/')
        self.dropout = nn.Dropout()
        self.cls = nn.Linear(args['bert_out_dim'], args['n_class'])
        
    def forward(self, **batch):
        bert_out = self.bert(batch['tokens'], batch['masks'])['last_hidden_state'] # (n_batch, n_tokens, n_emb)
        bert_out = self.dropout(bert_out)
        cls_out = self.cls(bert_out)
        return cls_out

    def predict_tagging(self, **batch):
        cls_out = self.forward(batch['tokens'], batch['masks'])
        pred = torch.max(cls_out, dim = 2).indices
        return pred

    # NER
    def predict(self, **batch):
        # sequence tagging
        cls_out = self.forward(**batch)
        pred = torch.max(cls_out, dim = 2).indices
        return pred
    
    # decode from BIOES
    def decode(self, pred, labels, texts):
        # NER: decode from BIOES
        pred_entities, gold_entities = decode_ner(pred, labels, texts)
        # pred_entities, gold_entities = decode_ner(pred, batch['labels'], batch['texts'])
        # metrics
        # precision, recall, f1 = ner_metrics(pred_entities, gold_entities)
        return pred_entities, gold_entities
        pass
    
    
# model = BertLinear(args, tokenizer, tag_dict).to(device)


In [21]:
model = BertLinearNER(args).to(device)
train(model, train_loader, valid_loader, epochs = 100, lr = 1e-6)
evaluate(model, test_loader, tag_list)

Some weights of the model checkpoint at /data/pretrained/bert-base-chinese/ were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Train BertLinearNER model with lr = 1e-06


AttributeError: 'tuple' object has no attribute 'shape'