P-Tuning 是 Prompt-based Tuning 的简称，是一种基于提示的微调技术。它的主要思想是在预训练语言模型（如 BERT、GPT 等）输入序列中插入一些可训练的提示（prompts），从而使模型能够更好地适应下游任务。  

#### P-Tuning 的含义
- **P**：代表 Prompt，即提示。提示是一组特殊的 token，其嵌入向量是可训练的。
- **Tuning**：指微调，即对模型进行进一步的训练，使其适应特定任务。  

#### 核心思想
传统的微调方法是直接在预训练语言模型的基础上添加特定任务的头部（如分类头）并进行训练。而 P-Tuning 则是在输入序列中插入一些额外的提示 token，这些提示 token 的嵌入向量是可训练的，通过这些提示 token，模型能够在下游任务中更好地利用预训练知识。

#### P-Tuning 和 Prefix-Tuning 的主要区别
- **P-Tuning**：适用于各种下游任务，尤其是分类和回归任务。提示 token 可以插入到输入序列的任何位置，包括前面、中间或后面。
- **Prefix-Tuning**：主要用于生成任务，如文本生成和机器翻译。提示 token 通常添加在输入序列的前面，作为前缀。

In [2]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset

In [3]:
class CustomDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length) -> None:
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [4]:
class PTuningBERT(nn.Module):
    def __init__(self, model_name, num_labels, prompt_length):
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.prompt_embeddings = nn.Embedding(prompt_length, self.bert.config.hidden_size)
        self.prompt_length = prompt_length

        # 冻结 bert 模型的参数
        for param in self.bert.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, attention_mask, labels=None):
        batch_size = input_ids.size(0)
        device = input_ids.device

        prompt_tokens = torch.arange(0, self.prompt_length).unsqueeze(0).repeat(batch_size, 1).to(device)
        prompt_embeddings = self.prompt_embeddings(prompt_tokens)

        inputs_embeds = self.bert.bert.embeddings(input_ids)
        split_point = inputs_embeds.size(1) // 2
        inputs_embeds = torch.cat((
            inputs_embeds[:, :split_point], 
            prompt_embeddings, 
            inputs_embeds[:, split_point:]
            ), dim=1)
        attention_mask = torch.cat((
            attention_mask[:, :split_point],
            torch.ones(batch_size, self.prompt_length).to(device), 
            attention_mask[:, split_point:]
            ), dim=1)

        outputs = self.bert(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = PTuningBERT('bert-base-uncased', num_labels=2, prompt_length=10)
texts = ["This is a positive example.", "This is a negative example."]
labels = [1, 0]
max_length = 128

dataset = CustomDataset(texts, labels, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=2)

optimizer = AdamW(model.parameters(), lr=1e-5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(3):
    model.train()
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
        

Epoch 1, Loss: 0.6911073923110962
Epoch 2, Loss: 0.6256528496742249
Epoch 3, Loss: 0.8360525369644165


In [7]:
model.eval()
with torch.no_grad():
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask, labels=labels)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=1)
        print(f'Predictions: {predictions}, Labels: {labels}')

Predictions: tensor([1, 1], device='cuda:0'), Labels: tensor([1, 0], device='cuda:0')
