In [1]:
import random

import torch
from nlpcda import Similarword
from torch.utils.data import Dataset
import pandas as pd

Simbert不能正常使用，除非你安装：bert4keras、tensorflow ，为了安装快捷，没有默认安装.... No module named 'bert4keras'


In [2]:
from torch import nn, optim

class MyDataset(Dataset):
    """Text-classification dataset backed by a CSV file.

    The CSV is read with pandas and must provide a ``sentence`` column and a
    ``label`` column. Rows containing any missing value are dropped.

    Args:
        csv_path: Path of the CSV file to load.
        type: Split name; augmentation is only applied when this is ``"train"``.
        is_augement: When true, a ``Similarword`` augmenter is created and
            roughly half of the training samples are passed through it.
    """

    def __init__(self, csv_path, type, is_augement):
        # dropna() keeps the original row labels, which leaves gaps in the
        # index. __getitem__ is called with positions 0..len-1, so the index
        # has to be renumbered or lookups hit missing (or wrong) rows.
        self.df = pd.read_csv(csv_path).dropna().reset_index(drop=True)
        self.type = type
        self.is_augement = is_augement
        if self.is_augement:
            self.smw = Similarword(create_num=2, change_rate=0.2)

    def __len__(self):
        """Return the number of rows left after dropping incomplete ones."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(sentence, label)`` for the row at position ``index``.

        Tokenization is not done here; the raw text is returned as-is.
        """
        content = self.df.loc[index, "sentence"]
        if self.type == "train" and self.is_augement:
            # Augment about half of the training samples.
            if random.random() > 0.5:
                # NOTE(review): presumably replace() returns a list of
                # candidate sentences and the last one is an augmented
                # variant - confirm against the nlpcda documentation.
                content = self.smw.replace(content)[-1]
        label = self.df.loc[index, "label"]
        return content, label

def sup_collate_fn(batch):
    """Collate ``(text, label)`` samples into one tokenized batch.

    Args:
        batch: Sequence of ``(text, label)`` pairs.

    Returns:
        The tokenizer output for all texts (padded/truncated to 40 tokens,
        PyTorch tensors) with an extra ``"labels"`` tensor added.
    """
    # Loading the tokenizer is expensive and does not depend on the batch, so
    # it is loaded on first use and kept on the function object afterwards.
    tokenizer = getattr(sup_collate_fn, "_tokenizer", None)
    if tokenizer is None:
        # NOTE(review): the sample batch printed in this notebook contains
        # many id-100 tokens; if that is [UNK], this vocabulary may not match
        # the language of the data - confirm the checkpoint choice.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        sup_collate_fn._tokenizer = tokenizer

    texts = [item[0] for item in batch]
    labels = [item[1] for item in batch]

    inputs = tokenizer(texts, max_length=40, padding="max_length", truncation=True, return_tensors="pt")
    inputs["labels"] = torch.tensor(labels)
    return inputs



In [3]:
from transformers import BertTokenizer, BertModel

class CLS_model(nn.Module):
    """BERT encoder followed by a small MLP classification head.

    The token states of the encoder are reduced with mean pooling and max
    pooling, the two pooled vectors are concatenated, and the result goes
    through three linear layers (ReLU + dropout after the first two).

    Args:
        embedding_dim: Width of the first hidden layer of the head; the
            second hidden layer has ``embedding_dim // 2`` units.
        target_size: Number of output classes (size of the returned logits).
    """

    def __init__(self, embedding_dim, target_size):
        super(CLS_model, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # fc1 consumes the concatenation of the mean- and max-pooled token
        # states, hence twice the encoder's hidden size.
        hidden_size = self.bert.config.hidden_size
        self.fc1 = nn.Linear(hidden_size * 2, embedding_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(embedding_dim, embedding_dim // 2)
        self.fc3 = nn.Linear(embedding_dim // 2, target_size)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return unnormalized class scores for a batch of token ids.

        When ``attention_mask`` is given, positions where it is 0 are
        excluded from both pooling operations, so padding added to reach a
        fixed length does not change the pooled representation. Without a
        mask, all positions are pooled.
        """
        output = self.bert(input_ids=input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_mask)

        last_hidden_state = output.last_hidden_state

        if attention_mask is None:
            seq_avg = torch.mean(last_hidden_state, dim=1)
            seq_max = torch.max(last_hidden_state, dim=1)[0]
        else:
            # Add a trailing axis so the mask broadcasts over the hidden axis.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            # clamp avoids a division by zero for a fully masked row.
            token_count = mask.sum(dim=1).clamp(min=1.0)
            seq_avg = (last_hidden_state * mask).sum(dim=1) / token_count
            # Masked positions get the smallest finite value so they can
            # never win the max (finite rather than -inf to avoid NaNs).
            lowest = torch.finfo(last_hidden_state.dtype).min
            seq_max = last_hidden_state.masked_fill(mask == 0, lowest).max(dim=1)[0]

        concat_out = torch.cat((seq_avg, seq_max), dim=1)

        # Classification head.
        fc1_out = self.dropout(self.activation(self.fc1(concat_out)))
        fc2_out = self.dropout(self.activation(self.fc2(fc1_out)))
        fc3_out = self.fc3(fc2_out)

        return fc3_out

class FGM():
    """Gradient-direction perturbation of selected model parameters.

    ``attack()`` shifts every trainable parameter whose name contains
    ``emb_name`` by ``epsilon * grad / ||grad||`` and remembers the original
    values; ``restore()`` puts the original values back.

    Args:
        model: Module whose parameters are perturbed.
        emb_name: Substring selecting the parameters to perturb by name.
        epsilon: Norm of the perturbation added to each selected parameter.
    """

    def __init__(self, model,emb_name,epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name
        self.backup = {}

    def attack(self):
        """Perturb the selected parameters along their current gradient.

        Parameters without a gradient are skipped. Parameters whose gradient
        norm is zero or NaN are backed up but left unchanged.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and self.emb_name in name):
                continue
            if param.grad is None:
                # No backward pass has reached this parameter yet; there is
                # no direction to perturb along.
                continue
            if name not in self.backup:
                # Keep the first (clean) copy if attack() runs twice before
                # restore(), so restore() always returns to clean weights.
                self.backup[name] = param.data.clone()
            norm = torch.norm(param.grad)
            if norm!=0 and not torch.isnan(norm):
                r_at = self.epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(self):
        """Restore every backed-up parameter and clear the backup."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}


In [4]:

def train_loop(dataloader, model, loss_fn, optimizer,fgm):
    """Run one training epoch with an extra adversarial pass per batch.

    For every batch the gradients of the clean loss are computed, ``fgm``
    perturbs the model, the gradients of the loss on the perturbed model are
    accumulated on top, the perturbation is undone, and one optimizer step is
    taken. Progress is printed every 100 batches and once at the end.

    Args:
        dataloader: Yields dict batches holding a ``'labels'`` tensor; all
            other entries are passed to ``model`` as keyword arguments.
        model: Module returning per-class scores.
        loss_fn: Callable ``loss_fn(pred, labels)`` returning a scalar tensor.
        optimizer: Optimizer over the model parameters.
        fgm: Object exposing ``attack()`` and ``restore()``.

    Returns:
        ``(avg_loss, avg_accuracy)``: mean clean loss per batch and the
        fraction of correct clean predictions over the dataset.
    """
    size = len(dataloader.dataset)
    model.train()  # training mode
    total_loss, total_accuracy = 0, 0
    seen = 0  # samples processed so far, for the running accuracy

    # Use the device the model lives on rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for batch, data in enumerate(dataloader):
        inputs = {k: v.to(device) for k, v in data.items() if k != 'labels'}
        labels = data['labels'].to(device)

        # Clean forward pass.
        pred = model(**inputs)
        loss = loss_fn(pred, labels)

        # Clean backward pass.
        optimizer.zero_grad()
        loss.backward()

        # Adversarial pass: its gradients accumulate on top of the clean ones.
        fgm.attack()
        pred_adv = model(**inputs)
        loss_adv = loss_fn(pred_adv, labels)
        loss_adv.backward()
        fgm.restore()

        optimizer.step()

        total_loss += loss.item()
        _, predictions = torch.max(pred, 1)
        total_accuracy += (predictions == labels).type(torch.float).sum().item()
        seen += labels.size(0)

        if batch % 100 == 0:
            current_loss = total_loss / (batch + 1)
            # Divide by the samples actually seen: the last batch may be
            # smaller than batch_size, and batch_size may be None.
            current_accuracy = total_accuracy / seen if seen else 0.0
            print(f"loss: {current_loss:>7f}  [accuracy: {100*current_accuracy:>0.2f}%]")

    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    avg_accuracy = total_accuracy / size if size else 0.0
    print(f"Average loss: {avg_loss:>7f}  [Average accuracy: {100*avg_accuracy:>0.2f}%]")
    return avg_loss, avg_accuracy

    

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    model.eval()  # 设置模型为评估模式
    with torch.no_grad():
        for batch, data in enumerate(dataloader):
            inputs = {k: v.to(device) for k, v in data.items() if k != 'labels'}
            labels = data['labels']

            pred = model(**inputs)
            test_loss += loss_fn(pred, labels).item()
            _, predictions = torch.max(pred, 1)
            correct += (predictions == labels).type(torch.float).sum().item()

    test_loss /= num_batches
    accuracy = correct / size
    print(f"Test Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f}")

In [5]:
from torch.utils.data import DataLoader

# Training split with augmentation switched off.
train_dataset = MyDataset(
    csv_path='dataset/train.csv',
    type="train",
    is_augement=False,
)

# Shuffled batches of 32 samples, tokenized by sup_collate_fn.
trainloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=sup_collate_fn,
)

In [6]:
# Show the first batch produced by the loader as a sanity check.
next(iter(trainloader))

{'input_ids': tensor([[ 101,  100,  100,  ...,    0,    0,    0],
        [ 101,  100, 1964,  ...,    0,    0,    0],
        [ 101,  100, 1820,  ...,    0,    0,    0],
        ...,
        [ 101,  100,  100,  ...,    0,    0,    0],
        [ 101,  100,  100,  ...,    0,    0,    0],
        [ 101,  100,  100,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([ 7,  3,  5,  7, 10,  2,  8,  7,  8, 10,  5,  3,  7,  8,  5,  0,  7,  5,
         7,  6, 12,  4,  2, 12,  3, 13,  7,  1,  5,  3,  5,  3])}

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 15-way classifier with a 200-unit first hidden layer, placed on that device.
model = CLS_model(embedding_dim=200, target_size=15).to(device)

In [8]:
from torch.optim import Adam

# Objective, optimizer, and the perturbation helper that targets the word
# embedding matrix of the encoder.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(
    model.parameters(),
    lr=2e-5,
)
fgm = FGM(
    model,
    emb_name='bert.embeddings.word_embeddings.weight',
    epsilon=1,
)

In [9]:
# Train for a fixed number of epochs. Evaluation stays disabled because no
# test loader is built in the cells shown here.
epochs = 20
for epoch in range(epochs):
    header = f"Epoch {epoch+1}\n-------------------------------"
    print(header)
    train_loop(trainloader, model, loss_fn, optimizer, fgm)
    # test_loop(test_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.687459  [accuracy: 9.38%]
loss: 2.631022  [accuracy: 9.81%]


KeyboardInterrupt: 