In [1]:
from transformers import AutoTokenizer

# Load the tokenizer belonging to the 'bert-base-cased' checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='bert-base-cased')

tokenizer

PreTrainedTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [2]:
from datasets import load_dataset
from datasets import load_from_disk

# Load the GLUE SST-2 dataset. The local copy under DATA_DIR is used when present;
# otherwise the dataset is fetched from the network once and saved to DATA_DIR,
# so the notebook still runs end-to-end on a fresh machine.
import os

DATA_DIR = './data/glue_sst2'

if os.path.isdir(DATA_DIR):
    # Load from local disk
    datasets = load_from_disk(DATA_DIR)
else:
    # Load from the network, then cache on disk for subsequent runs
    datasets = load_dataset(path='glue', name='sst2')
    datasets.save_to_disk(DATA_DIR)


#分词
def f(data):
    return tokenizer(
        data['sentence'],
        padding='max_length',
        truncation=True,
        max_length=30,
    )


# Tokenize every split in batches of 1000 examples across 4 processes.
datasets = datasets.map(f, batched=True, batch_size=1000, num_proc=4)

# Take small random subsets, otherwise there is too much data to train on quickly.
# The fixed seed makes the chosen subsets, and every result derived from them,
# reproducible across kernel restarts.
SEED = 42
dataset_train = datasets['train'].shuffle(seed=SEED).select(range(1000))
dataset_test = datasets['validation'].shuffle(seed=SEED).select(range(200))

# The full tokenized DatasetDict is no longer needed.
del datasets

dataset_train

 

Loading cached processed dataset at data/glue_sst2/train/cache-7aa524f603792fd9.arrow


 

Loading cached processed dataset at data/glue_sst2/train/cache-0b32e00a0d7a3a56.arrow


 

Loading cached processed dataset at data/glue_sst2/train/cache-0ba0364f53072715.arrow


 

Loading cached processed dataset at data/glue_sst2/train/cache-827155e4d83ebe3c.arrow


 

Loading cached processed dataset at data/glue_sst2/validation/cache-c55ae83ee610d88f.arrow


 

Loading cached processed dataset at data/glue_sst2/validation/cache-6362a30151b5e959.arrow


 

Loading cached processed dataset at data/glue_sst2/validation/cache-2fb35a5f1cc316c4.arrow


 

Loading cached processed dataset at data/glue_sst2/validation/cache-c8408139568ba8e5.arrow


 

Loading cached processed dataset at data/glue_sst2/test/cache-87319ea4f4e6236d.arrow


 

Loading cached processed dataset at data/glue_sst2/test/cache-57288efb5fa03946.arrow


 

Loading cached processed dataset at data/glue_sst2/test/cache-0710d5088cfae21a.arrow


 

Loading cached processed dataset at data/glue_sst2/test/cache-764b7f7882a99b5e.arrow


Dataset({
    features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})

In [3]:
from transformers import AutoModelForSequenceClassification

# Load BERT with a 2-way sequence-classification head; the head's weights are not
# in the checkpoint, so they start freshly initialised (see the warning printed below).
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-cased',
    num_labels=2,
)

# Parameter count expressed in units of 10,000.
n_params = sum(p.nelement() for p in model.parameters())
print(n_params / 10000)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

10831.181


In [4]:
import numpy as np
from datasets import load_metric
from transformers.trainer_utils import EvalPrediction

# Load the accuracy metric through the `datasets` library.
# NOTE(review): `load_metric` is deprecated in newer `datasets` releases in favour
# of the separate `evaluate` package — confirm against the installed version.
metric = load_metric(path='accuracy')


# Metric callback for the Trainer
def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer``.

    Args:
        eval_pred: an ``EvalPrediction`` or any ``(predictions, label_ids)`` pair,
            where ``predictions`` are logits of shape ``[n, num_labels]``.

    Returns:
        ``{'accuracy': float}``: the fraction of arg-max predictions that equal
        the labels.
    """
    logits, labels = eval_pred
    # If predictions arrive as a tuple (a model returning extra outputs),
    # the logits are the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.asarray(logits).argmax(axis=-1)
    labels = np.asarray(labels)
    # Accuracy is computed directly with numpy rather than via the global `metric`.
    # This removes the dependency on the deprecated `datasets.load_metric` and
    # gives the same value.
    return {'accuracy': float((predictions == labels).mean())}


# Sanity-check compute_metrics on hand-made predictions. In every row the larger
# logit is at index 1 and every label is 1, so accuracy must be exactly 1.0.
eval_pred = EvalPrediction(
    predictions=np.array([[0, 1], [2, 3], [4, 5], [6, 7]]),
    label_ids=np.array([1, 1, 1, 1]),
)

mock_result = compute_metrics(eval_pred)
assert mock_result['accuracy'] == 1.0, mock_result

mock_result

{'accuracy': 1.0}

In [5]:
from transformers import TrainingArguments, Trainer

# Training hyper-parameters, all passed to the constructor.
# TrainingArguments is a dataclass whose __post_init__ validates and derives
# settings from these fields; assigning attributes after construction bypasses
# that step.
# NOTE(review): some transformers releases reject attribute assignment after
# construction. `evaluation_strategy` was renamed `eval_strategy` in newer
# releases — confirm against the installed version.
args = TrainingArguments(
    output_dir='./output_dir',
    evaluation_strategy='epoch',  # evaluate at the end of every epoch
    num_train_epochs=1,
    learning_rate=1e-4,
    weight_decay=1e-2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
)

# Bundle model, hyper-parameters, data splits and the metric callback into a Trainer.
trainer = Trainer(
    args=args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
)

# Baseline evaluation before fine-tuning; the classification head is still untrained.
baseline_metrics = trainer.evaluate()
baseline_metrics

The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 200
  Batch size = 32


{'eval_loss': 0.7820796370506287,
 'eval_accuracy': 0.49,
 'eval_runtime': 8.0978,
 'eval_samples_per_second': 24.698,
 'eval_steps_per_second': 0.864}

In [6]:
# Fine-tune the model on the training subset.
train_output = trainer.train()
train_output

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 63


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.416479,0.8


The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 200
  Batch size = 32


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=63, training_loss=0.5025159744989305, metrics={'train_runtime': 216.993, 'train_samples_per_second': 4.608, 'train_steps_per_second': 0.29, 'total_flos': 15416663400000.0, 'train_loss': 0.5025159744989305, 'epoch': 1.0})

In [7]:
# Evaluate again after fine-tuning, for comparison with the baseline.
eval_metrics = trainer.evaluate()
eval_metrics

The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 200
  Batch size = 32


{'eval_loss': 0.4164787232875824,
 'eval_accuracy': 0.8,
 'eval_runtime': 11.1786,
 'eval_samples_per_second': 17.891,
 'eval_steps_per_second': 0.626,
 'epoch': 1.0}

In [8]:
# Save the fine-tuned model (weights + config) into ./output_dir.
trainer.save_model('./output_dir')

Saving model checkpoint to ./output_dir
Configuration saved in ./output_dir/config.json
Model weights saved in ./output_dir/pytorch_model.bin


In [9]:
import torch


def collate_fn(data):
    """Stack a list of tokenized dataset rows into int64 tensors.

    Each element of ``data`` is a row (dict) holding ``label`` plus the padded
    ``input_ids`` / ``token_type_ids`` / ``attention_mask`` lists.

    Returns:
        Tuple ``(label, input_ids, token_type_ids, attention_mask)``, each a
        ``torch.LongTensor``.
    """
    fields = ('label', 'input_ids', 'token_type_ids', 'attention_mask')
    # One LongTensor per field, returned in the order listed above.
    return tuple(torch.LongTensor([row[name] for row in data]) for name in fields)


# DataLoader over the held-out subset, used for a quick manual check of the
# saved model. The seeded generator makes the shuffled batch order, and therefore
# the batch picked below, reproducible.
loader_test = torch.utils.data.DataLoader(dataset=dataset_test,
                                          batch_size=4,
                                          collate_fn=collate_fn,
                                          shuffle=True,
                                          drop_last=True,
                                          generator=torch.Generator().manual_seed(0))

# Take only the first batch.
label, input_ids, token_type_ids, attention_mask = next(iter(loader_test))

label, input_ids, token_type_ids, attention_mask

(tensor([0, 1, 1, 1]),
 tensor([[  101,   119,   119,   119,  1103,  1273, 18907,  1121,   170,  2960,
           1104,  8594,   113,  1380,  1834,  1106,  5233,  1149,  1103,  4289,
            114,   119,   119,   119,   102,     0,     0,     0,     0,     0],
         [  101,  1139,  1992,  7930,   176,  8871,  1377,  4655,  2745, 26478,
           1107,   170, 13657,  2365, 13390,  1104,  4105,  9688,  1105,  9207,
           1193, 10478,  1174,  1149,  8594,   119,   102,     0,     0,     0],
         [  101,  1884,  3848, 16719,  3290,  1921, 19603,  1103,  4817,  1111,
            188,  1732,  7200, 10947, 12606,  2895,  3899,   119,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
         [  101,   170,  4600,  8179,  1104,  1103,  2581,  2286, 14430,  5532,
            119,   102,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[0

In [10]:
import torch


#测试
def test():
    #加载参数
    model.load_state_dict(torch.load('./output_dir/pytorch_model.bin'))

    model.eval()

    #运算
    out = model(input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask)

    #[4, 2] -> [4]
    out = out['logits'].argmax(dim=1)

    correct = (out == label).sum().item()

    return correct / len(label)


test()

0.75