In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_scheduler
from datasets import load_dataset,Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
# 1. 记载数据
df_train = pd.read_csv('/root/data/train_set.csv', sep='\t')

In [4]:
# 2. 数据预处理
## 2.1 数据通过bert模型进行向量化
train_dataset = Dataset.from_pandas(df_train)

In [5]:
train_dataset['label'][0]

2

In [6]:
tokenizer = BertTokenizer.from_pretrained('/root/model/bert-base-chinese')

In [7]:
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True, max_length=128)

In [8]:
encoded_dataset = train_dataset.map(preprocess_function, batched=True)

Map: 100%|██████████| 200000/200000 [37:22<00:00, 89.20 examples/s]


In [41]:
train_dataset_torch = encoded_dataset.with_format('torch')

In [42]:
train_dataset_torch['input_ids'][0]

tensor([  101, 11545,  8161,  8369,  9265, 11863,  9960,  9560,  8159, 12779,
         8148, 10706,  8160, 12613,  8144, 12815,  8160,  9564,  8160,  9860,
         8156, 13092,  8144, 11003,  8157, 10896,  8129,  9247,  8160,  8347,
         9488,  8273,  9355, 10212, 10325,  8158,  8214,  8452,  8990,  8144,
         9083,  8160, 11960,  8157,  8222,  8393, 10595,  8156,  8183,  9039,
        12937,  8159,  8284,  9332, 12115,  8129,  9291,  8156, 11003,  8157,
         8203,  9292,  9564,  8160,  8728,  8160, 10194,  8152, 11256,  8129,
        12779,  8148, 10706,  8160, 12613,  8144,  8360,  8805, 10842,  8159,
        12129,  8129, 12408,  8160, 12129,  8129, 12459,  8272,  8740, 10706,
         8160, 12346,  8157, 11256,  8129,  8267,  9039, 11003,  8158,  8360,
         8805,  8369,  9265, 11948,  8152, 12937,  8129,  9801,  8158, 10602,
         8152, 11210,  8158,  8460,  8158, 13068,  8158,  8284,  8161,  8567,
        13262,  8203,  9292, 11256,  8129, 11545,  8161,   102])

In [43]:
# 2.2 转换为dotaloader
train_loader = DataLoader(train_dataset_torch, batch_size=16, shuffle=True)

In [12]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7fb0e8a16fb0>

In [13]:
num_labels = len(set(df_train['label']))  # 假设标签是从 0 开始的连续整数
num_labels

14

In [70]:
# 3. 加载模型
model = BertForSequenceClassification.from_pretrained('/root/model/bert-base-chinese', num_labels=num_labels)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /root/model/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [71]:
# 优化器
optimizer = AdamW(model.parameters(), lr=3e-5)

In [72]:
num_training_steps = len(train_loader) * 3  # 假设训练 3 个 epoch
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)

In [73]:
# 设备配置
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [74]:
# 4. 训练模型
model.train()
for epoch in range(3):  # 训练 3 个 epoch
    for i , batch in enumerate(train_loader):
        # print(batch)  'target':batch['label'].to(device), 
        outputs = model(**{'input_ids':batch['input_ids'].to(device),'token_type_ids':batch['token_type_ids'].to(device),'attention_mask':batch['attention_mask'].to(device)})
        # print(outputs)
        logits = outputs.logits

        # 如果需要计算损失（在训练时）
        loss_fn = torch.nn.CrossEntropyLoss()

        loss = loss_fn(logits, batch['label'].to(device))
         
        # print("Logits:", logits)
        if i %10 == 0: 
            print("Loss:", loss.item()) 
        
        loss.backward()
 
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    print("Loss:", loss.item()) 
    print(f"Epoch {epoch + 1} completed.")

Loss: 2.6137664318084717
Loss: 2.549243450164795
Loss: 2.76652455329895
Loss: 2.7074153423309326
Loss: 2.539043664932251
Loss: 2.597288131713867
Loss: 2.468055486679077
Loss: 2.4175307750701904
Loss: 2.488964796066284
Loss: 2.604745388031006
Loss: 2.2907803058624268
Loss: 2.5860893726348877
Loss: 2.38511061668396
Loss: 2.351752281188965
Loss: 2.3658854961395264
Loss: 2.226323366165161
Loss: 2.2169675827026367
Loss: 2.105623245239258
Loss: 2.2054359912872314
Loss: 2.1738059520721436
Loss: 2.8371376991271973
Loss: 2.159653425216675
Loss: 2.49169659614563
Loss: 2.2930331230163574
Loss: 2.7077815532684326
Loss: 2.608449697494507
Loss: 2.307523012161255
Loss: 2.41243577003479
Loss: 2.1018056869506836
Loss: 1.9875187873840332
Loss: 2.4616456031799316
Loss: 2.1064398288726807
Loss: 2.2388083934783936
Loss: 2.178513288497925
Loss: 2.2860982418060303
Loss: 2.091620922088623
Loss: 2.313572883605957
Loss: 2.2473278045654297
Loss: 2.1105382442474365
Loss: 2.284374952316284
Loss: 2.0561962127685547

In [75]:
# 5. 模型存储
torch.save(model.state_dict(), 'BertForSequenceClassification.pth')