In [10]:
# data
# from datasets import load_dataset
# ds = load_dataset('csv',data_files="data/ChnSentiCorp.csv")

from torch.utils.data import Dataset
from datasets import load_from_disk

class BizDataset(Dataset):
    def __init__(self,split):
        self.dataset = load_from_disk(r"D:\Workspace\regression\huggingface\dataset")
        if split=="train":
            self.dataset=self.dataset["train"]
        elif split=="test":
            self.dataset=self.dataset["test"]
        elif split=="validation":
            self.dataset=self.dataset["validation"]
        else:
            print("split is wrong")
           
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self,item):
        text = self.dataset[item]["text"]
        label = self.dataset[item]["label"]
        return text,label

In [11]:
# net
from transformers import BertModel
import torch

DEVICE= torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_name="bert-base-chinese"
model_name=r"D:\Workspace\regression\huggingface\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"

pretrained = BertModel.from_pretrained(model_name).to(DEVICE)
print(pretrained)

# 如需进行定制化，需要调整对应的输入输出，保持词向量一致
print(pretrained.embeddings.word_embeddings)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

```
import torch
from transformers import BertModel, BertTokenizer

# 加载预训练模型和分词器
model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 输入文本
text = "Hello, how are you?"

# 分词并编码
inputs = tokenizer(text, return_tensors="pt")

# 获取模型输出
with torch.no_grad():
    outputs = model(**inputs)

# 提取 [CLS] token 的隐藏状态
cls_embedding = outputs.last_hidden_state[:, 0]

print("CLS embedding shape:", cls_embedding.shape)  # 输出: (1, 768)

```

In [12]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768,2)
    # 前向推理运算
    def forward(self,input_ids,attention_mask,token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,attention_mask= attention_mask,token_type_ids=token_type_ids)
        
        out= self.fc(out.last_hidden_state[:,0])
        out = out.softmax(dim=1)
        return out

In [15]:
# trainer
from transformers import BertTokenizer,AdamW
from torch.utils.data import DataLoader

EPOCH=10
tokenizer = BertTokenizer.from_pretrained(model_name)

# 对数据进行编码处理
def collate_fn(data):
    sentes = [i[0] for i in data]
    labels = [i[1] for i in data]
    data = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        padding="max_length",  # 填充到最大长度
        truncation=True,       # 截断超过最大长度的序列
        max_length=30,         # 最大长度为 10
        return_length=True,
        return_tensors="pt",
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    labels = torch.LongTensor(labels)
    
    return input_ids,attention_mask,token_type_ids,labels

# 1、创建数据集
train_ds= BizDataset("train")
# 2、创建数据加载器
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)

#训练
if __name__=='__main__':
    print(DEVICE)
    model=Model().to(DEVICE)
    optimizer=AdamW(model.parameters(),lr=1e-3)    
    loss_func =torch.nn.CrossEntropyLoss()
    
    model.train()
    for epoch in range(EPOCH):
        for i ,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_loader):
            input_ids,attention_mask,token_type_ids,labels = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),labels.to(DEVICE)
            # 执行前向计算
            out= model(input_ids,attention_mask,token_type_ids)
            loss = loss_func(out,labels)
            
            # 深度学习优化模型三步走
            #1、清空权重梯度,2、反向传播，3、更新梯度
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if i%5 == 0:
                out = out.argmax(dim=1)
                acc = (out == labels).sum().item()/len(labels)
                print(epoch,i,loss.item(),acc)
        
        # 保存模型参数
        torch.save(model.state_dict(),f"params/{epoch}bert.pt")
        print(epoch,"参数保存成功。")    

cpu
0 0 0.6501349210739136 0.625
0 5 0.6785931587219238 0.625
0 10 0.5828644037246704 0.875
0 15 0.6672710180282593 0.75
0 20 0.6269630789756775 0.875
0 25 0.5823624134063721 0.875
0 30 0.5705785155296326 0.75
0 35 0.6450259685516357 0.625
0 40 0.5675601959228516 0.875
0 45 0.6019430160522461 0.75
0 50 0.5194454193115234 0.875
0 55 0.6560109853744507 0.5
0 60 0.5182225108146667 0.875
0 65 0.5746229887008667 0.75
0 70 0.5453912615776062 0.875
0 75 0.5146540999412537 0.75
0 80 0.622308075428009 0.625
0 85 0.5554884672164917 0.75
0 90 0.4001471996307373 1.0
0 95 0.5963026285171509 0.625
0 100 0.41106510162353516 1.0
0 105 0.49354955554008484 0.875
0 110 0.6354066729545593 0.625
0 115 0.5349334478378296 0.75
0 120 0.46713849902153015 1.0
0 125 0.40546926856040955 1.0
0 130 0.5157116651535034 0.875
0 135 0.6548353433609009 0.625
0 140 0.4921993911266327 0.875
0 145 0.45338380336761475 0.875
0 150 0.57278972864151 0.75
0 155 0.5948582291603088 0.625
0 160 0.5654757618904114 0.75
0 165 0.4616

1 150 0.5602446794509888 0.75
1 155 0.3324597179889679 1.0
1 160 0.3995357155799866 0.875
1 165 0.333290159702301 1.0
1 170 0.5937380194664001 0.75
1 175 0.6031216382980347 0.75
1 180 0.35814058780670166 1.0
1 185 0.7196798324584961 0.5
1 190 0.47112563252449036 0.875
1 195 0.5544594526290894 0.75
1 200 0.44024595618247986 0.875
1 205 0.3181207776069641 1.0
1 210 0.3274097740650177 1.0
1 215 0.36965516209602356 1.0
1 220 0.5725760459899902 0.75
1 225 0.45976656675338745 0.875
1 230 0.32977980375289917 1.0
1 235 0.5687647461891174 0.75
1 240 0.5978078842163086 0.75
1 245 0.5180699229240417 0.875
1 250 0.39836031198501587 0.875
1 255 0.5301960110664368 0.75
1 260 0.6325533390045166 0.75
1 265 0.38303515315055847 1.0
1 270 0.3301967680454254 1.0
1 275 0.32930970191955566 1.0
1 280 0.5607427358627319 0.75
1 285 0.4778633117675781 0.875
1 290 0.552074670791626 0.75
1 295 0.5836372375488281 0.75
1 300 0.6109071373939514 0.625
1 305 0.32744449377059937 1.0
1 310 0.49642616510391235 0.75
1 315

2 295 0.6369665861129761 0.625
2 300 0.6557157039642334 0.625
2 305 0.5747004747390747 0.75
2 310 0.3277896046638489 1.0
2 315 0.3732028007507324 1.0
2 320 0.5769637227058411 0.75
2 325 0.42758625745773315 0.875
2 330 0.3167629837989807 1.0
2 335 0.44729095697402954 0.875
2 340 0.7451139092445374 0.5
2 345 0.5591530203819275 0.75
2 350 0.403085321187973 0.875
2 355 0.5934910178184509 0.625
2 360 0.45056262612342834 0.875
2 365 0.31672534346580505 1.0
2 370 0.31622231006622314 1.0
2 375 0.5475991368293762 0.75
2 380 0.3265668749809265 1.0
2 385 0.4365466237068176 0.875
2 390 0.3320382833480835 1.0
2 395 0.3158762454986572 1.0
2 400 0.42778947949409485 0.875
2 405 0.5590391755104065 0.75
2 410 0.5603075623512268 0.75
2 415 0.5021701455116272 0.75
2 420 0.5637964010238647 0.75
2 425 0.5012670755386353 0.75
2 430 0.4592035412788391 0.75
2 435 0.3222958743572235 1.0
2 440 0.5266078114509583 0.75
2 445 0.31861138343811035 1.0
2 450 0.554032564163208 0.625
2 455 0.3153887391090393 1.0
2 460 0

3 445 0.7103617787361145 0.625
3 450 0.6780800223350525 0.5
3 455 0.4601452052593231 0.75
3 460 0.44828835129737854 0.875
3 465 0.40021395683288574 0.875
3 470 0.3759569525718689 0.875
3 475 0.47413158416748047 0.875
3 480 0.6647790670394897 0.625
3 485 0.5549802184104919 0.75
3 490 0.32065549492836 1.0
3 495 0.5366483926773071 0.75
3 500 0.446577787399292 0.875
3 505 0.36841392517089844 0.875
3 510 0.45536044239997864 0.875
3 515 0.4627752900123596 0.875
3 520 0.4692034125328064 0.75
3 525 0.36862483620643616 1.0
3 530 0.4176875948905945 0.875
3 535 0.6634545922279358 0.625
3 540 0.4542435109615326 0.875
3 545 0.3270670473575592 1.0
3 550 0.5982930660247803 0.75
3 555 0.6666509509086609 0.5
3 560 0.34704938530921936 1.0
3 565 0.5633231401443481 0.75
3 570 0.7127703428268433 0.5
3 575 0.5615665912628174 0.75
3 580 0.6152815818786621 0.75
3 585 0.39018723368644714 0.875
3 590 0.5664861798286438 0.75
3 595 0.3157813847064972 1.0
3 600 0.3377455770969391 1.0
3 605 0.4556068480014801 0.875

KeyboardInterrupt: 