In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel, BertForMaskedLM

In [2]:
class bert_cls(nn.Module):
    def __init__(self, model_name, num_labels, freeze_bert=True):
        super(bert_cls, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        print(self.bert)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_labels)
        self.softmax = nn.Softmax(dim=-1)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids, attention_mask)
        hidden_state = output['last_hidden_state'][:, 0, :]
        hidden_state = self.dropout(hidden_state)
        logits = self.classifier(hidden_state)
        return logits

model = bert_cls('bert-base-chinese', 2)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# checkpoint = torch.load('model/bert_cls_model.pth', map_location='cpu')
# model.load_state_dict(checkpoint)

In [4]:
# 实例测试流程(这时候还没load参数, 只是展示了整个pipeline的流程)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
sentence = '我想要一个美女'
token = tokenizer(sentence, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
input_ids = token['input_ids'].to(device)
attention_mask = token['attention_mask'].to(device)
print('input_ids:', input_ids, '\n', 'attention_mask:', attention_mask)

input_ids: tensor([[ 101, 2769, 2682, 6206,  671,  702, 5401, 1957,  102,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]]) 
 attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 

In [5]:
model.eval()
model(input_ids, attention_mask)

tensor([[-0.0604,  0.0856]], grad_fn=<AddmmBackward0>)

In [6]:
df = pd.read_csv('dataset/message80W.csv', encoding='utf-8', header=None, index_col=0)
df.columns = ['label', 'message']
df.head()

Unnamed: 0_level_0,label,message
0,Unnamed: 1_level_1,Unnamed: 2_level_1
1,0,商业秘密的秘密性那是维系其商业价值和垄断地位的前提条件之一
2,1,南口阿玛施新春第一批限量春装到店啦   春暖花开淑女裙、冰蓝色公主衫 ...
3,0,带给我们大常州一场壮观的视觉盛宴
4,0,有原因不明的泌尿系统结石等
5,0,23年从盐城拉回来的麻麻的嫁妆


In [7]:
from tqdm import tqdm

msg_list = df['message'][:20000].tolist()
msg_tokens = [tokenizer(msg, return_tensors='pt', padding='max_length', truncation=True, max_length=128) for msg in tqdm(msg_list)]

100%|██████████| 20000/20000 [00:17<00:00, 1147.76it/s]


In [8]:
def test():
    pred_list = []
    for i in tqdm(range(len(msg_tokens[:100]))):
        # print(msg_tokens[i]['input_ids'])
        # print(msg_tokens[i]['attention_mask'])
        logics = (model(msg_tokens[i]['input_ids'], msg_tokens[i]['attention_mask']))
        logics = logics.detach().cpu().numpy()
        pred = np.argmax(logics, axis=-1).tolist()
        pred_list += pred
    return pred_list


In [9]:
label_list = df['label'][:100].tolist()
# fine-tune 之前
pred_list = test()
from sklearn.metrics import classification_report
report = classification_report(label_list, pred_list)
print(report)

100%|██████████| 100/100 [01:45<00:00,  1.06s/it]


              precision    recall  f1-score   support

           0       0.85      0.63      0.72        87
           1       0.09      0.23      0.13        13

    accuracy                           0.58       100
   macro avg       0.47      0.43      0.42       100
weighted avg       0.75      0.58      0.65       100



In [11]:
checkpoint = torch.load('bert-base-chinese/pytorch_model.bin', map_location='cpu')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [12]:
# fine-tune 之后
pred_list = test()
report = classification_report(label_list, pred_list)
print(report)

100%|██████████| 100/100 [02:01<00:00,  1.22s/it]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        87
           1       1.00      1.00      1.00        13

    accuracy                           1.00       100
   macro avg       1.00      1.00      1.00       100
weighted avg       1.00      1.00      1.00       100




