<a href="https://colab.research.google.com/github/li199959/one/blob/main/%E5%88%86%E7%B1%BB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from pickletools import optimize
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [None]:
!pip install transformers

In [3]:
# Path (relative to the working directory) of the '_!_'-delimited text file
# that ToutiaoDataset reads.
data_path = 'toutiao_cat_data.txt'

In [5]:
class ToutiaoDataset(data.Dataset):
    """In-memory text-classification dataset read from a ``_!_``-delimited file.

    Every non-blank line of the file must consist of exactly five fields
    joined by ``_!_``; the second field is an integer category code and the
    fourth is the text. The other three fields are ignored.

    Attributes:
        texts: list of the text field of every parsed line, in file order.
        labels: list of ``int(category) - 100`` for every parsed line.
    """

    # Field delimiter used by the data file.
    SEPARATOR = '_!_'
    # Number of fields expected on each line.
    NUM_FIELDS = 5
    # Subtracted from the category code to obtain the label.
    LABEL_OFFSET = 100

    def __init__(self, data_path) -> None:
        """Parse ``data_path`` eagerly and keep all samples in memory."""
        super(ToutiaoDataset, self).__init__()
        self.build(data_path)

    def build(self, data_path):
        """Read ``data_path`` and populate ``self.texts`` / ``self.labels``.

        Blank (or whitespace-only) lines are skipped.

        Raises:
            ValueError: if a non-blank line does not have exactly five
                fields, or its category field is not an integer.
        """
        texts = []
        labels = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate empty lines such as a trailing newline at EOF.
                    continue

                fields = line.split(self.SEPARATOR)
                if len(fields) != self.NUM_FIELDS:
                    raise ValueError(
                        f'{data_path}:{line_no}: expected {self.NUM_FIELDS} '
                        f'fields separated by {self.SEPARATOR!r}, '
                        f'got {len(fields)}'
                    )

                _, category, _, text, _ = fields
                labels.append(int(category) - self.LABEL_OFFSET)
                texts.append(text)

        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.texts)

    def __getitem__(self, i):
        """Return the ``(text, label)`` pair stored at index ``i``."""
        return self.texts[i], self.labels[i]

# Parse the whole data file into memory (ToutiaoDataset reads it eagerly).
toutiao_dataset = ToutiaoDataset(data_path)
# Tokenizer for the 'bert-base-chinese' checkpoint; collate_fn reads this
# module-level name.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
def collate_fn(text_label):
    """Collate ``(text, label)`` pairs into tokenized tensors for one batch.

    Args:
        text_label: list of ``(text, label)`` tuples, e.g.
            ``[('first headline', 0), ('second headline', 1)]``.

    Returns:
        A tuple ``(input_ids, attention_mask, token_type_ids, labels)``.
        The first three are the tensors returned by ``bert_tokenizer`` with
        ``return_tensors='pt'``; ``labels`` is a ``torch.LongTensor`` built
        from the integer labels, in batch order.
    """
    texts = [text for text, _ in text_label]
    labels = [label for _, label in text_label]

    # Pad to the longest *tokenized* sequence in the batch rather than to a
    # length derived from character counts, and truncate at 512 tokens (the
    # same cap the character-based computation used).
    # The result is named `encoded` so it does not shadow the module alias
    # `data` (torch.utils.data).
    encoded = bert_tokenizer(
        texts,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        padding='longest',
        return_tensors='pt',  # 'tf', 'pt' or 'np'
        return_token_type_ids=True,
        return_attention_mask=True,
    )

    return (
        encoded['input_ids'],
        encoded['attention_mask'],
        encoded['token_type_ids'],
        torch.LongTensor(labels),
    )


# Module-level BERT encoder used inside ToutiaoClassification.forward.
# Its last_hidden_state is expected to be (batch_size, seq_len, hidden_size).
pretrained = BertModel.from_pretrained('bert-base-chinese')
# Alternative that is not used here: build the model and load weights manually.
# pretrained = BertModel()
# pretrained.load_state_dict(xx)
class ToutiaoClassification(nn.Module):
    """Linear classifier applied to the output of the module-level encoder.

    The encoder (``pretrained``) is not a submodule of this class; the only
    layer owned here is ``fc``.
    """

    def __init__(self, hidden_size, num_classes) -> None:
        """Create the ``hidden_size -> num_classes`` linear layer."""
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return class scores for a tokenized batch.

        The three arguments are passed through to ``pretrained`` as keyword
        arguments of the same name.
        """
        # The encoder runs with gradient tracking disabled, so only `fc`
        # receives gradients.
        with torch.no_grad():
            encoder_output = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

        # Hidden state at sequence position 0 of every batch item:
        # (batch_size, hidden_size).
        first_token_state = encoder_output.last_hidden_state[:, 0, :]
        return self.fc(first_token_state)

# Training hyperparameters.
lr = 5e-4          # learning rate passed to AdamW
batch_size = 32    # samples per DataLoader batch
shuffle = True     # reshuffle the dataset every epoch
epochs = 4         # number of full passes over the dataset

# Batches are tokenized on the fly by collate_fn.
dataloader = data.DataLoader(toutiao_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
# hidden_size=768 must match the encoder's hidden size; num_classes=17
# presumably covers category codes 100..116 after the -100 offset applied in
# ToutiaoDataset -- TODO confirm against the data file.
model = ToutiaoClassification(hidden_size=768, num_classes=17)
# Only model.parameters() (the linear head) is handed to the optimizer.
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Switch the classifier to training mode and run the optimisation loop,
# printing the step number, batch loss and batch accuracy after every step.
model.train()
for epoch in range(epochs):
    counter = 0  # 1-based step index within the current epoch
    for counter, batch in enumerate(dataloader, start=1):
        input_ids, attention_mask, token_type_ids, label = batch

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        prediction = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        loss = criterion(prediction, label)
        loss.backward()
        optimizer.step()

        # Batch accuracy from the scores computed before this step's update.
        prediction = prediction.argmax(dim=1)
        num_correct = (prediction == label).sum().item()
        accuracy = num_correct / len(label)

        print(counter, loss.item(), accuracy)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


1 3.0365872383117676 0.0
2 2.798135280609131 0.03125
3 2.816507577896118 0.0625
4 2.7371623516082764 0.15625
5 2.7119100093841553 0.125
6 2.532694101333618 0.21875
7 2.6986520290374756 0.09375
8 2.648597240447998 0.125
9 2.7281436920166016 0.125
10 2.474766731262207 0.25
11 2.5555312633514404 0.1875
12 2.406691789627075 0.375
13 2.500128746032715 0.25
14 2.592433452606201 0.1875
15 2.3990066051483154 0.21875
16 2.40392804145813 0.3125
17 2.2900640964508057 0.3125
18 2.4136197566986084 0.21875
19 2.309659957885742 0.3125
20 2.2608439922332764 0.3125
21 2.231853723526001 0.28125
22 2.2815401554107666 0.28125
23 2.226658821105957 0.34375
24 2.1324214935302734 0.3125
25 2.2456893920898438 0.3125
26 2.373019218444824 0.3125
27 2.044534683227539 0.53125
28 2.086271286010742 0.3125
29 2.1630616188049316 0.375
30 2.0102007389068604 0.53125
31 2.0341320037841797 0.5625
32 2.194951057434082 0.5
33 2.0136704444885254 0.5
34 2.1275453567504883 0.4375
35 2.1159725189208984 0.5
36 2.086566925048828 

KeyboardInterrupt: ignored