In [105]:
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from transformers import BertTokenizer
from transformers import BertModel
from tqdm import tqdm

In [36]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [25]:
Data = pd.read_csv('./AG_NEWS/train.csv',header=None)

In [26]:
tokenizer = BertTokenizer.from_pretrained('./Bert_Model/bert-base-uncased')

In [78]:
class Dataset(data.Dataset):
    def __init__(self,type='train'):
        super().__init__()
        if type == 'train':
            self.data = pd.read_csv('./AG_NEWS/train.csv',header=None)
        elif type == 'test':
            self.data = pd.read_csv('./AG_NEWS/test.csv',header=None)
        else:
            raise Exception('type只能为train或者test')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 加载bert的基础分词器
        self.TEXT_LEN = max(self.data.iloc[:][1].apply(len)) # 记录训练集当中的最长句子
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,index):
        label,text = self.data.iloc[index][0],self.data.iloc[index][1] # 按照索引读取类别和正文
        tokened = self.tokenizer(text)
        input_ids = tokened['input_ids']
        mask = tokened['attention_mask']
        BERT_PAD_ID = self.tokenizer.pad_token_id
        if len(input_ids) < self.TEXT_LEN:
            pad_len = (self.TEXT_LEN - len(input_ids))
            input_ids += [BERT_PAD_ID] * pad_len
            mask += [0] * pad_len
        target = int(label - 1) #这里需要注意-1
        return torch.tensor(input_ids[:self.TEXT_LEN]), torch.tensor(mask[:self.TEXT_LEN]), torch.tensor(target)

In [79]:
def get_label():
    text = open('./AG_NEWS/classes.txt').read()
    id2label = text.split()
    return id2label, {v: k for k, v in enumerate(id2label)}

In [87]:
EMBEDDING_DIM = 768
NUM_FILTERS = 256
NUM_CLASSES = 4
FILTER_SIZES = [2, 3, 4]
class TextCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('./Bert_Model/bert-base-uncased')
        for name ,param in self.bert.named_parameters():
            param.requires_grad = False
        self.convs = nn.ModuleList([nn.Conv2d(1, NUM_FILTERS, (i, EMBEDDING_DIM)) for i in FILTER_SIZES])
        self.linear = nn.Linear(NUM_FILTERS * 3, NUM_CLASSES)

    def conv_and_pool(self, conv, input):
        out = conv(input)
        out = F.relu(out)
        return F.max_pool2d(out, (out.shape[2], out.shape[3])).squeeze()

    def forward(self, input, mask):
        out = self.bert(input, mask)[0].unsqueeze(1)
        out = torch.cat([self.conv_and_pool(conv, out) for conv in self.convs], dim=1)
        return self.linear(out)

In [88]:
# 下面开始进行模型的训练过程

In [94]:
from torch.utils.data.dataset import random_split 

In [96]:
id2label, _ = get_label()
train_dataset = Dataset('train')
spilt_train,split_valid =  random_split(train_dataset,[int(len(train_dataset)*0.8),len(train_dataset)-int(len(train_dataset)*0.8)])
train_loader = data.DataLoader(spilt_train, batch_size=10, shuffle=True)
dev_loader = data.DataLoader(split_valid, batch_size=10, shuffle=True)

In [107]:
EPOCH = 3
LR = 1e-3
model = TextCNN().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)

Some weights of the model checkpoint at ./Bert_Model/bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [108]:
from sklearn.metrics import classification_report
def evaluate(pred, true, target_names=None, output_dict=False):
    return classification_report(
        true,
        pred,
        target_names=target_names,
        output_dict=output_dict,
        zero_division=0,
    )

In [109]:
MODEL_DIR = './OUTPUT/'

In [110]:
for e in tqdm(range(EPOCH)):
    for b, (input, mask, target) in enumerate(train_loader):
        input = input.to(DEVICE)
        mask = mask.to(DEVICE)
        target = target.to(DEVICE)

        pred = model(input, mask)
        loss = loss_fn(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if b % 500 != 0:
            continue

        y_pred = torch.argmax(pred, dim=1)
        report = evaluate(y_pred.cpu().data.numpy(), target.cpu().data.numpy(), output_dict=True)
        with torch.no_grad():
            for dev_input, dev_mask, dev_target in dev_loader:
                dev_input = dev_input.to(DEVICE)
                dev_mask = dev_mask.to(DEVICE)
                dev_target = dev_target.to(DEVICE)
                dev_pred = model(dev_input, dev_mask)
                dev_pred_ = torch.argmax(dev_pred, dim=1)
                dev_report = evaluate(dev_pred_.cpu().data.numpy(), dev_target.cpu().data.numpy(), output_dict=True)
                break
        print(
            '>> epoch:', e,
            'batch:', b,
            'loss:', round(loss.item(), 5),
            'train_acc:', report['accuracy'],
            'dev_acc:', dev_report['accuracy']
        )
    if e%50 ==0:
        torch.save(model, MODEL_DIR + f'{e}.pth')

  0%|          | 0/3 [00:00<?, ?it/s]

>> epoch: 0 batch: 0 loss: 1.42933 train_acc: 0.2 dev_acc: 0.4
>> epoch: 0 batch: 500 loss: 0.94256 train_acc: 0.8 dev_acc: 0.7
>> epoch: 0 batch: 1000 loss: 0.44073 train_acc: 0.7 dev_acc: 0.9
>> epoch: 0 batch: 1500 loss: 0.57633 train_acc: 0.7 dev_acc: 0.8
>> epoch: 0 batch: 2000 loss: 0.66611 train_acc: 0.6 dev_acc: 0.7
>> epoch: 0 batch: 2500 loss: 0.35223 train_acc: 0.9 dev_acc: 0.9
>> epoch: 0 batch: 3000 loss: 0.29196 train_acc: 0.8 dev_acc: 0.8
>> epoch: 0 batch: 3500 loss: 0.2984 train_acc: 0.9 dev_acc: 0.7
>> epoch: 0 batch: 4000 loss: 0.65905 train_acc: 0.9 dev_acc: 0.8
>> epoch: 0 batch: 4500 loss: 0.32991 train_acc: 0.8 dev_acc: 1.0
>> epoch: 0 batch: 5000 loss: 1.00685 train_acc: 0.7 dev_acc: 0.9
>> epoch: 0 batch: 5500 loss: 0.13897 train_acc: 1.0 dev_acc: 0.9
>> epoch: 0 batch: 6000 loss: 0.14716 train_acc: 1.0 dev_acc: 0.9
>> epoch: 0 batch: 6500 loss: 0.44635 train_acc: 0.8 dev_acc: 0.9
>> epoch: 0 batch: 7000 loss: 0.18137 train_acc: 1.0 dev_acc: 1.0
>> epoch: 0 bat

 33%|███▎      | 1/3 [05:35<11:11, 335.95s/it]

>> epoch: 1 batch: 0 loss: 0.31532 train_acc: 0.9 dev_acc: 0.9
>> epoch: 1 batch: 500 loss: 0.04972 train_acc: 1.0 dev_acc: 0.8
>> epoch: 1 batch: 1000 loss: 1.03451 train_acc: 0.7 dev_acc: 0.6
>> epoch: 1 batch: 1500 loss: 0.6972 train_acc: 0.7 dev_acc: 0.8
>> epoch: 1 batch: 2000 loss: 0.08108 train_acc: 1.0 dev_acc: 0.8
>> epoch: 1 batch: 2500 loss: 0.26872 train_acc: 0.9 dev_acc: 0.7
>> epoch: 1 batch: 3000 loss: 0.40997 train_acc: 0.8 dev_acc: 0.9
>> epoch: 1 batch: 3500 loss: 0.18261 train_acc: 0.9 dev_acc: 0.9
>> epoch: 1 batch: 4000 loss: 0.22237 train_acc: 0.9 dev_acc: 0.8
>> epoch: 1 batch: 4500 loss: 0.67195 train_acc: 0.7 dev_acc: 0.7
>> epoch: 1 batch: 5000 loss: 0.65606 train_acc: 0.8 dev_acc: 1.0
>> epoch: 1 batch: 5500 loss: 0.10515 train_acc: 1.0 dev_acc: 0.9
>> epoch: 1 batch: 6000 loss: 0.2915 train_acc: 0.8 dev_acc: 0.8
>> epoch: 1 batch: 6500 loss: 1.08978 train_acc: 0.7 dev_acc: 1.0
>> epoch: 1 batch: 7000 loss: 0.05594 train_acc: 1.0 dev_acc: 0.9
>> epoch: 1 batc

 67%|██████▋   | 2/3 [11:06<05:32, 332.65s/it]

>> epoch: 2 batch: 0 loss: 0.09919 train_acc: 1.0 dev_acc: 0.9
>> epoch: 2 batch: 500 loss: 0.06005 train_acc: 1.0 dev_acc: 1.0
>> epoch: 2 batch: 1000 loss: 0.06538 train_acc: 1.0 dev_acc: 0.9
>> epoch: 2 batch: 1500 loss: 0.2412 train_acc: 0.9 dev_acc: 0.9
>> epoch: 2 batch: 2000 loss: 0.55536 train_acc: 0.7 dev_acc: 0.9
>> epoch: 2 batch: 2500 loss: 0.23445 train_acc: 0.9 dev_acc: 0.8
>> epoch: 2 batch: 3000 loss: 0.08402 train_acc: 1.0 dev_acc: 1.0
>> epoch: 2 batch: 3500 loss: 0.70136 train_acc: 0.8 dev_acc: 0.9
>> epoch: 2 batch: 4000 loss: 0.35669 train_acc: 0.8 dev_acc: 1.0
>> epoch: 2 batch: 4500 loss: 0.15589 train_acc: 0.9 dev_acc: 0.7
>> epoch: 2 batch: 5000 loss: 0.10645 train_acc: 1.0 dev_acc: 0.8
>> epoch: 2 batch: 5500 loss: 0.66848 train_acc: 0.8 dev_acc: 1.0
>> epoch: 2 batch: 6000 loss: 0.48002 train_acc: 0.8 dev_acc: 0.7
>> epoch: 2 batch: 6500 loss: 0.56361 train_acc: 0.7 dev_acc: 0.8
>> epoch: 2 batch: 7000 loss: 0.11765 train_acc: 0.9 dev_acc: 0.8
>> epoch: 2 bat

100%|██████████| 3/3 [16:36<00:00, 332.23s/it]
