In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModel, AdamW
from transformers.data.data_collator import DataCollatorWithPadding
import pandas as pd

In [2]:
# 定义模型
class Model(nn.Module):
    def __init__(self, num_classes, path) -> None:
        super().__init__()
        self.pretrained = AutoModel.from_pretrained(path)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.pretrained(input_ids, attention_mask, token_type_ids)
        out = out.last_hidden_state[:, 0]
        out = self.dropout(out)
        logits = self.fc(out)

        return logits

In [None]:
path = "huggingface/models/bert_base_uncased"
tokenizer = AutoTokenizer.from_pretrained(path)
dataset = load_from_disk("huggingface/datasets/glue/sst2")

# 数据预处理函数
def preprocess_function(data):
    data = tokenizer.batch_encode_plus(data['sentence'], truncation=True)
    return data
# 处理数据集
dataset = dataset.map(function=preprocess_function, batched=True, batch_size=1000,
                num_proc=4, remove_columns=['sentence', 'idx'])

# 获取训练集,验证集和测试集
train_dataloader = DataLoader(dataset['train'], batch_size=128, 
        collate_fn=DataCollatorWithPadding(tokenizer),shuffle=True, drop_last=True)

valid_dataloader = DataLoader(dataset['validation'], batch_size=128, 
        collate_fn=DataCollatorWithPadding(tokenizer),shuffle=True, drop_last=True)

test_dataloader = DataLoader(dataset['test'], batch_size=64, shuffle=False, drop_last=True,
                                collate_fn=DataCollatorWithPadding(tokenizer))

model = Model(num_classes=2, path=path)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 对预训练的Bert和分类头使用不同的学习率
optimizer = AdamW([
    {"params": model.pretrained.parameters(), 'lr': 2e-5},
    {"params": model.fc.parameters(), 'lr': 5e-4}
])
criterion = nn.CrossEntropyLoss()
epochs = 5

In [4]:
# 在验证集上的精度
def test():
    model.eval()
    model.to(device)
    correct, total = 0, 0
    for idx, data in enumerate(valid_dataloader):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        labels = data['labels'].to(device)
        with torch.no_grad():
            logits = model(input_ids, attention_mask, token_type_ids)
        logits = logits.argmax(dim=-1)
        correct += (logits == labels).sum()
        total += len(labels)
    return correct / total

In [5]:
# 训练函数，使用验证集上的精度来选择模型
def train():
    model.train()
    model.to(device)
    correct, total = 0, 0
    total_loss, updates = 0, 0
    best_acc = 0
    for epoch in range(epochs):
        for idx, data in enumerate(train_dataloader):
            input_ids = data['input_ids'].to(device)
            attention_mask = data['attention_mask'].to(device)
            token_type_ids = data['token_type_ids'].to(device)
            labels = data['labels'].to(device)

            logits = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logits = logits.argmax(dim=-1)
            correct += (logits == labels).sum()
            total += len(labels)

            total_loss += loss.item()
            updates += 1

            if idx % 50 == 0:
                print('epoch:{}, idx:{}, loss:{}, acc:{}'.format(epoch + 1, idx, 
                                                        total_loss / updates, correct / total))
        acc = test()
        print("epoch:{}, valid_dataset acc:{}".format(epoch + 1, acc))
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'code/huggingface实战/分类/classify.pt')


In [6]:
# 开始训练
train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


epoch:1, idx:0, loss:0.7110611200332642, acc:0.484375
epoch:1, idx:50, loss:0.4006512132929821, acc:0.8144914507865906
epoch:1, idx:100, loss:0.3311491499442865, acc:0.8524907231330872
epoch:1, idx:150, loss:0.30035069584846497, acc:0.8683257699012756
epoch:1, idx:200, loss:0.2788397831777435, acc:0.8802083134651184
epoch:1, idx:250, loss:0.2653611544297036, acc:0.8882283568382263
epoch:1, idx:300, loss:0.25490892091858824, acc:0.8935838937759399
epoch:1, idx:350, loss:0.24564787488632392, acc:0.8981258869171143
epoch:1, idx:400, loss:0.23697058759871267, acc:0.9024314284324646
epoch:1, idx:450, loss:0.2284041431519779, acc:0.9069775342941284
epoch:1, idx:500, loss:0.22317740579922044, acc:0.9096494317054749
epoch:1, valid_dataset acc:0.9231771230697632
epoch:2, idx:0, loss:0.21962794293941085, acc:0.9113348126411438
epoch:2, idx:50, loss:0.20998327766562871, acc:0.9159174561500549
epoch:2, idx:100, loss:0.2008533557101585, acc:0.9200558066368103
epoch:2, idx:150, loss:0.19309104467292

In [7]:
# 推理函数
def inference():
    model.eval()
    model.to(device)
    sentences = []
    labels = []
    for data in test_dataloader:
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        token_type_ids = data['token_type_ids'].to(device)
        with torch.no_grad():
            logits = model(input_ids, attention_mask, token_type_ids)
        logits = logits.argmax(dim=-1)
        for input_id in input_ids:
            sentence = tokenizer.decode(input_id)
            sentence = sentence.replace('[CLS]', '').replace('[SEP]', '').replace('[PAD]', '').strip()
            sentences.append(sentence)
        labels.extend(logits.detach().cpu().numpy())
    res = pd.DataFrame()
    res['sentences'] = sentences
    res['labels'] = labels

    res.to_csv("code/huggingface实战/分类/inference.csv", sep=',')

In [None]:
model.load_state_dict(torch.load("code/huggingface实战/分类/classify.pt"))
inference()