In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_scheduler
from datasets import load_dataset,Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from torch.optim import AdamW
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
# Read the tab-separated test split into a DataFrame.
df_test = pd.read_csv(
    '/root/data/test_a.csv',
    sep='\t',
)

In [5]:
# Wrap the DataFrame in a Hugging Face Dataset; the bare name on the last line
# makes the notebook display its summary.
test_dataset = Dataset.from_pandas(df=df_test)
test_dataset

Dataset({
    features: ['text'],
    num_rows: 50000
})

In [6]:
# Load the BERT tokenizer from the bert-base-chinese directory on local disk.
tokenizer = BertTokenizer.from_pretrained(
    '/root/model/bert-base-chinese',
)

In [7]:
def preprocess_function(examples, *, max_length=128):
    """Tokenize the ``text`` field of a batch of examples.

    Uses the notebook-level ``tokenizer``. Every sequence is truncated and
    padded to exactly ``max_length`` tokens. Padding to a fixed width (rather
    than to the longest row of the current ``map`` batch) guarantees that all
    rows of the encoded dataset have the same length, so a DataLoader batch
    that mixes rows from two different ``map`` batches can still be stacked.

    Args:
        examples: Mapping with a ``'text'`` entry holding the raw text(s).
        max_length: Fixed token length of every encoded sequence
            (keyword-only; defaults to 128).

    Returns:
        Whatever ``tokenizer`` returns for the given texts.
    """
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

In [8]:
# Tokenize the whole test set; batched=True passes preprocess_function
# groups of rows rather than one row at a time.
encoded_dataset = test_dataset.map(
    function=preprocess_function,
    batched=True,
)

Map: 100%|██████████| 50000/50000 [13:36<00:00, 61.27 examples/s]


In [9]:
# Get a view of the encoded dataset formatted as 'torch'.
test_dataset_torch = encoded_dataset.with_format(type='torch')

In [41]:
# Show the dataset summary (feature names and row count) as this cell's output.
test_dataset_torch

Dataset({
    features: ['text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 50000
})

In [42]:
# Batch the encoded test set 16 rows at a time; no shuffle argument is passed.
test_loader = DataLoader(
    dataset=test_dataset_torch,
    batch_size=16,
)

In [43]:
# Echo the DataLoader object as this cell's output.
test_loader

<torch.utils.data.dataloader.DataLoader at 0x7f95b024f430>

In [29]:
# Load the model
# 1. Build the architecture; it must match the one used when the weights were saved.
#    num_labels=14 requests a 14-class classification head.
model = BertForSequenceClassification.from_pretrained(
    '/root/model/bert-base-chinese',
    num_labels=14,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /root/model/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
# 2. Load the saved state dict.
#    map_location=device remaps the stored tensors onto the device selected for
#    this session, so a checkpoint written from a GPU still loads when the
#    notebook falls back to the CPU.
#    weights_only=True restricts unpickling to tensors and plain containers
#    instead of executing arbitrary pickled objects from the file.
state_dict = torch.load(
    'BertForSequenceClassification.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [31]:
 # 3. Switch the model to evaluation mode (if needed) and move it to the selected device.
model.eval().to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [46]:
# Collect the predicted class index of every test example, batch by batch.
all_preds = []
tensor_keys = ('input_ids', 'token_type_ids', 'attention_mask')
with torch.no_grad():  # forward passes only; no gradient tracking
    for batch in test_loader:
        # Move just the tensors the model consumes onto the target device.
        model_inputs = {key: batch[key].to(device) for key in tensor_keys}

        # Forward pass; the arg-max along dim 1 picks the highest-scoring class per row.
        logits = model(**model_inputs).logits
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)

[1 2 8 5 0 4 2 1 6 5 1 0 3 6 0 1]
[2 1 7 0 2 5 7 6 3 8 1 2 2 6 0 7]
[1 1 6 0 3 0 5 3 4 2 1 2 1 1 1 2]
[ 2  0  5  2  3  1  0 10  5  7  0  3  5  3  0  5]
[ 2 13  1  0  2  2  6  1  4  4  0  2  2  2  3  0]
[ 8  8  1  4  9  1  1 10  3  0  9  1  1  8  3  7]
[10 10  0  0  2  0  6  3  0  0  3  2  8  4  1  2]
[2 3 5 2 1 7 2 1 0 2 1 3 2 9 2 1]
[1 0 0 4 2 2 2 2 0 0 0 0 1 0 6 5]
[1 0 0 0 0 5 3 1 2 1 3 7 2 8 5 5]
[2 8 2 9 3 0 6 5 2 1 3 4 2 3 0 3]
[7 0 2 1 4 3 1 4 8 2 8 2 2 9 2 7]
[1 4 0 3 6 7 0 7 9 8 0 2 0 6 2 3]
[ 2  0  1  7  2  7  1  3  2  0  3 10  9  4  2  0]
[ 3  8  1  1  0  1  1  6 11  2  4  5  3  1  1  0]
[ 1 10  0  8 11  0  3  8  3  1  0  4  0  1  3  3]
[ 4  2  1 11  2  1  0  2  7  1 11  0  0  6  2  0]
[5 1 1 3 4 1 1 4 8 6 3 0 0 1 5 3]
[7 1 5 7 4 4 0 8 2 0 0 1 6 2 4 5]
[ 3 13  7  1  2  6  1  3  2 10  1  1  3  4  0  2]
[ 5  1 13  2 10  4  2  0  7  3  1  1  1  2 10  0]
[1 1 1 3 7 7 1 4 0 2 5 0 4 0 7 0]
[0 5 2 3 0 1 2 2 1 1 7 7 7 2 2 2]
[ 2  1  1  0  1  7  7  0  6 10  1  5  2  0  2  5]
[8 1 1 2

KeyboardInterrupt: 