# E2. 使用 continuous prompt 完成 SST2 分类

In [1]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

import sys
sys.path.append('..')

import fastNLP
from fastNLP import Trainer
from fastNLP.core.utils.utils import dataclass_to_dict
from fastNLP.core.metrics import Accuracy

print(transformers.__version__)

4.18.0


In [2]:
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

task = "sst2"
model_checkpoint = "distilbert-base-uncased"

In [3]:
class PromptEncoder(nn.Module):
    def __init__(self, template, hidden_size):
        nn.Module.__init__(self)
        self.template = template
        self.hidden_size = hidden_size
        self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]
        self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()

        self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))
        # embed
        self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)
        # LSTM
        self.lstm_head = torch.nn.LSTM(input_size=hidden_size,
                                       hidden_size=hidden_size // 2,
                                       num_layers=2, dropout=0.0,
                                       bidirectional=True, batch_first=True)
        # MLP
        self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                      nn.ReLU(),
                                      nn.Linear(hidden_size, hidden_size))
        print("init prompt encoder...")

    def forward(self, device):
        input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)
        output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()
        return output_embeds

In [4]:
class ClassModel(nn.Module):
    def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):
        nn.Module.__init__(self)
        self.template = template
        self.num_labels = num_labels
        self.spell_length = sum(template)
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
        self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, 
                                                                            num_labels=num_labels)
        for param in self.back_bone.parameters():
            param.requires_grad = False
        self.embeddings = self.back_bone.get_input_embeddings()
        
        self.hidden_size = self.embeddings.embedding_dim
        self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})
        self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]
        self.pad_token_id = self.tokenizer.pad_token_id
        
        self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)

        self.loss_fn = nn.CrossEntropyLoss()

    def get_query(self, query):
        device = query.device
        return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device),               # [CLS]
                          torch.tensor([self.pseudo_token_id] * self.template[0]).to(device),   # [PROMPT]
                          torch.tensor([self.tokenizer.mask_token_id]).to(device),              # [MASK]  
                          torch.tensor([self.pseudo_token_id] * self.template[1]).to(device),   # [PROMPT]
                          query,   
                          torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0)       # [SEP]

    def forward(self, input_ids):
        input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])
        attention_mask = input_ids != self.pad_token_id
        
        bz = input_ids.shape[0]
        inputs_embeds = input_ids.clone()
        inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id
        inputs_embeds = self.embeddings(inputs_embeds)

        blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1]  # bz
        replace_embeds = self.prompt_encoder(input_ids.device)
        for bidx in range(bz):
            for i in range(self.spell_length):
                inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]
        
        return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

    def train_step(self, input_ids, attention_mask, labels):
        pred = self(input_ids).logits
        return {"loss": self.loss_fn(pred, labels)}

    def evaluate_step(self, input_ids, attention_mask, labels):
        pred = self(input_ids).logits
        pred = torch.max(pred, dim=-1)[1]
        return {"pred": pred, "target": labels}

In [5]:
num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2

model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)

optimizers = AdamW(params=model.parameters(), lr=5e-4)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classi

init prompt encoder...


In [6]:
from datasets import load_dataset, load_metric

dataset = load_dataset("glue", "mnli" if task == "mnli-mm" else task)

Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
def preprocess_function(examples):
    return model.tokenizer(examples['sentence'], truncation=True)

encoded_dataset = dataset.map(preprocess_function, batched=True)

  0%|          | 0/68 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [8]:
class TestDistilBertDataset(Dataset):
    def __init__(self, dataset):
        super(TestDistilBertDataset, self).__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        item = self.dataset[item]
        return item["input_ids"], item["attention_mask"], [item["label"]] 

In [9]:
def test_bert_collate_fn(batch):
    input_ids, atten_mask, labels = [], [], []
    max_length = [0] * 3
    for each_item in batch:
        input_ids.append(each_item[0])
        max_length[0] = max(max_length[0], len(each_item[0]))
        atten_mask.append(each_item[1])
        max_length[1] = max(max_length[1], len(each_item[1]))
        labels.append(each_item[2])
        max_length[2] = max(max_length[2], len(each_item[2]))

    for i in range(3):
        each = (input_ids, atten_mask, labels)[i]
        for item in each:
            item.extend([0] * (max_length[i] - len(item)))
    return {"input_ids": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),
            "attention_mask": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),
            "labels": torch.cat([torch.tensor(item) for item in labels], dim=0)}

In [10]:
dataset_train = TestDistilBertDataset(encoded_dataset["train"])
dataloader_train = DataLoader(dataset=dataset_train, 
                              batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)
dataset_valid = TestDistilBertDataset(encoded_dataset["validation"])
dataloader_valid = DataLoader(dataset=dataset_valid, 
                              batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)

In [11]:
trainer = Trainer(
    model=model,
    driver='torch',
    device='cuda',
    n_epochs=10,
    optimizers=optimizers,
    train_dataloader=dataloader_train,
    evaluate_dataloaders=dataloader_valid,
    metrics={'acc': Accuracy()}
)

In [12]:
trainer.run(num_eval_batch_per_dl=10)

In [13]:
trainer.evaluator.run()

{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}