# Soft Prompt

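Today, let's implement a simple soft prompt, following the prompt-tuning approach of Lester et al., *The Power of Scale for Parameter-Efficient Prompt Tuning* (https://arxiv.org/abs/2104.08691v1). A small set of learnable embeddings is prepended to the input, and only these added weights are finetuned while the pretrained model stays frozen.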

<img src="img/soft_embedding.png" width="200">
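
In the paper's notation (restated here as a sketch), a sequence of $n$ input tokens is embedded into $X_e \in \mathbb{R}^{n \times e}$, and a learnable prompt $P_e \in \mathbb{R}^{p \times e}$ is prepended before the frozen model runs. Training maximizes the likelihood of the target $Y$ while updating only the prompt parameters $\theta_P$; the model parameters $\theta$ stay fixed. In this notebook, $p$ corresponds to `n_tokens` and $e = 768$.

$$
[P_e; X_e] \in \mathbb{R}^{(p+n) \times e}, \qquad \max_{\theta_P} \; \log \Pr_{\theta;\,\theta_P}\!\left(Y \mid [P; X]\right)
$$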

In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

import torch
import torch.nn as nn

## 1. Load model

Let's load the pretrained GPT-2 language model and its tokenizer.

In [2]:
# Only needed behind the AIT proxy; comment out or delete these lines otherwise.
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [3]:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')

Let's examine the original token embedding: a lookup table with one 768-dimensional vector for each of the 50,257 vocabulary entries.

In [4]:
model.get_input_embeddings()

Embedding(50257, 768)
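
To put "only finetune the added weights" in numbers, the short calculation below compares the size of this frozen table with a soft prompt of `n_tokens = 20` vectors (the value used later in this notebook). It only multiplies the shapes shown above; for reference, the whole GPT-2 small model has roughly 124M parameters.

```python
# Shapes read off the output above: Embedding(50257, 768); n_tokens = 20 as set later on.
vocab_size, emb_dim = 50257, 768
n_tokens = 20

embedding_params = vocab_size * emb_dim   # frozen token-embedding table
soft_prompt_params = n_tokens * emb_dim   # the only weights we will train

print(f"token embedding: {embedding_params:,} parameters")
print(f"soft prompt:     {soft_prompt_params:,} parameters")
print(f"ratio:           {soft_prompt_params / embedding_params:.4%}")
```

The soft prompt is about 0.04% of the size of the embedding table alone.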

## 2. Soft embeddings

Let's define a soft embedding whose prompt vectors are trained while the pretrained model stays frozen.

In [5]:
class SoftEmbedding(nn.Module):
    def __init__(self,
                 wte: nn.Embedding,                  # original transformer word token embedding
                 n_tokens: int = 10,                 # number of soft-prompt tokens for the task
                 random_range: float = 0.5,          # range for uniform random init
                 initialize_from_vocab: bool = True):
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.Parameter(self.initialize_embedding(wte,
                                                                        n_tokens,
                                                                        random_range,
                                                                        initialize_from_vocab))
        # self.learned_embedding: (n_tokens, emb dim)
        # self.wte.weight:        (vocab size, emb dim)

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True):
        if initialize_from_vocab:
            # copy the embeddings of the first n_tokens vocabulary entries
            return wte.weight[:n_tokens].clone().detach()
        # otherwise sample uniformly from [-random_range, random_range]
        return torch.empty(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)

    # define the forward process
    def forward(self, tokens):
        # tokens: (b, n_tokens + input_len); the first n_tokens IDs are placeholders,
        # so skip them and embed only the real input text
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        # input_embedding: (b, input_len, emb dim)

        # repeat the learned embedding for every item in the batch
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        # learned_embedding: (b, n_tokens, emb dim)

        concat_embed = torch.cat([learned_embedding, input_embedding], 1)
        # concat_embed: (b, n_tokens + input_len, emb dim)

        return concat_embed
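
`initialize_embedding` copies the first `n_tokens` rows of the vocabulary. The paper compares random-uniform and vocabulary-based initializations; one possible variant (an assumption, not what this notebook does) is to copy randomly chosen vocabulary rows instead. The helper `init_from_random_vocab` below is hypothetical, and the 100 x 8 toy table stands in for GPT-2's 50257 x 768 embedding.

```python
from typing import Optional

import torch
import torch.nn as nn


def init_from_random_vocab(wte: nn.Embedding, n_tokens: int,
                           generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical initializer: copy the embeddings of n_tokens randomly chosen vocabulary entries."""
    ids = torch.randperm(wte.num_embeddings, generator=generator)[:n_tokens]  # distinct token IDs
    return wte.weight[ids].clone().detach()


# Usage on a toy embedding table instead of GPT-2's.
torch.manual_seed(0)
toy_wte = nn.Embedding(100, 8)
init = init_from_random_vocab(toy_wte, n_tokens=5, generator=torch.Generator().manual_seed(42))
learned_embedding = nn.Parameter(init)  # what SoftEmbedding would wrap as its trainable prompt
print(learned_embedding.shape)  # torch.Size([5, 8])
```

Passing a seeded `torch.Generator` keeps the chosen rows reproducible across runs.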

In [6]:
n_tokens = 20
initialize_from_vocab = True

s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=initialize_from_vocab)

In [7]:
s_wte

SoftEmbedding(
  (wte): Embedding(50257, 768)
)

Now we can replace the model's input embedding layer with ours.

In [8]:
model.set_input_embeddings(s_wte)  # set_input_embeddings accepts any nn.Module, not only an nn.Embedding

## 3. Testing

Now let's see the forward pass in action.

In [9]:
inputs = tokenizer("Harry Potter is", return_tensors="pt")

In [10]:
inputs['input_ids']

tensor([[18308, 14179,   318]])

In [11]:
inputs['attention_mask']

tensor([[1, 1, 1]])

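Since the soft prompt is prepended inside the embedding layer, we have to prepend `n_tokens` placeholder IDs to `input_ids` ourselves so that the sequence length matches; `SoftEmbedding.forward` discards them.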

In [12]:
# pad input_ids (and, in the next cell, attention_mask) to input_len + n_tokens
# the placeholder IDs are discarded by SoftEmbedding.forward, so any value works; here we use 1
inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)
inputs['input_ids']

tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
         18308, 14179,   318]])
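
The claim that the placeholder value does not matter can be checked directly. The sketch below uses a toy 100 x 8 embedding and a hypothetical function `soft_prompt_embed` that mirrors `SoftEmbedding.forward` (it uses `expand` rather than `repeat`, which avoids a copy while still letting gradients reach the prompt). Padding with 1 or with 99 yields identical embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, emb_dim = 4, 8
wte = nn.Embedding(100, emb_dim)                        # toy vocabulary, not GPT-2's
learned = nn.Parameter(torch.randn(n_tokens, emb_dim))  # toy soft prompt


def soft_prompt_embed(tokens: torch.Tensor) -> torch.Tensor:
    """Mirrors SoftEmbedding.forward: drop the first n_tokens IDs, prepend the learned rows."""
    text = wte(tokens[:, n_tokens:])
    prompt = learned.unsqueeze(0).expand(tokens.size(0), -1, -1)
    return torch.cat([prompt, text], dim=1)


text_ids = torch.tensor([[11, 22, 33]])
pad_with_1 = torch.cat([torch.full((1, n_tokens), 1, dtype=torch.long), text_ids], dim=1)
pad_with_99 = torch.cat([torch.full((1, n_tokens), 99, dtype=torch.long), text_ids], dim=1)

out_1 = soft_prompt_embed(pad_with_1)
out_99 = soft_prompt_embed(pad_with_99)
print(out_1.shape)                 # torch.Size([1, 7, 8])
print(torch.equal(out_1, out_99))  # True
```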

In [13]:
# pad attention_mask as well (1 = the model attends to the soft-prompt positions)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)
inputs['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
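
The two cells above hard-code a batch size of 1 through `torch.full((1, n_tokens), 1)`. A batched version could look like the hypothetical helper `pad_for_soft_prompt` below, which reads the batch size, dtype, and device from the tensors it is given. The second example row is assumed to be a shorter sequence right-padded with mask 0.

```python
import torch


def pad_for_soft_prompt(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                        n_tokens: int, pad_id: int = 1):
    """Hypothetical helper: prepend n_tokens placeholder IDs and mask ones to every row."""
    batch_size = input_ids.size(0)
    id_pad = torch.full((batch_size, n_tokens), pad_id,
                        dtype=input_ids.dtype, device=input_ids.device)
    mask_pad = torch.ones((batch_size, n_tokens),
                          dtype=attention_mask.dtype, device=attention_mask.device)
    return torch.cat([id_pad, input_ids], dim=1), torch.cat([mask_pad, attention_mask], dim=1)


ids = torch.tensor([[18308, 14179, 318], [18308, 14179, 0]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
new_ids, new_mask = pad_for_soft_prompt(ids, mask, n_tokens=20)
print(new_ids.shape, new_mask.shape)  # torch.Size([2, 23]) torch.Size([2, 23])
```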

In [14]:
outputs = model(**inputs)

In [15]:
outputs.logits.shape

torch.Size([1, 23, 50257])
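
The logits cover all 23 positions, 20 of which belong to the soft prompt. For language-model training, one option (an assumption here; this notebook does not train yet) is to pass labels whose prompt positions are set to -100, the index the Hugging Face causal-LM loss ignores. The sketch below reproduces the shift-by-one cross-entropy that `GPT2LMHeadModel` computes when `labels` are given, on random stand-in logits, to show which positions contribute. The helpers `make_soft_prompt_labels` and `shifted_lm_loss` are hypothetical.

```python
import torch
import torch.nn.functional as F


def make_soft_prompt_labels(input_ids: torch.Tensor, n_tokens: int) -> torch.Tensor:
    """Hypothetical helper: copy input_ids and mask the placeholder positions with -100."""
    labels = input_ids.clone()
    labels[:, :n_tokens] = -100
    return labels


def shifted_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Causal-LM loss: position t predicts token t+1; labels equal to -100 are ignored."""
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1), ignore_index=-100)


torch.manual_seed(0)
n_tokens, vocab_size = 20, 50257
input_ids = torch.cat([torch.full((1, n_tokens), 1, dtype=torch.long),
                       torch.tensor([[18308, 14179, 318]])], dim=1)
logits = torch.randn(1, input_ids.size(1), vocab_size)  # stand-in for outputs.logits

labels = make_soft_prompt_labels(input_ids, n_tokens)
loss = shifted_lm_loss(logits, labels)
print(loss)
```

Only three predictions are scored: the last prompt position predicts the first real token, and so on. With the real model, passing `labels=labels` alongside `**inputs` would return this kind of loss as `outputs.loss`.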

## 4. Freezing

Finally, we freeze all parameters except the learned soft-prompt embedding, then train as usual. Yay!

In [16]:
# check the first few named parameters (printing all of them is long)
for name, param in list(model.named_parameters())[:5]:
    print(name, param.requires_grad)

transformer.wte.learned_embedding True
transformer.wte.wte.weight True
transformer.wpe.weight True
transformer.h.0.ln_1.weight True
transformer.h.0.ln_1.bias True


In [17]:
# freeze everything except the first parameter, which is transformer.wte.learned_embedding
parameters = list(model.parameters())
for x in parameters[1:]:
    x.requires_grad = False
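
`parameters[1:]` relies on `learned_embedding` being the first parameter returned by `model.parameters()`, which holds here because the soft embedding is registered as `transformer.wte`. Selecting by name does not depend on that ordering. The self-contained sketch below uses a hypothetical toy model (embedding, mean pooling, linear head; not GPT-2) to show name-based freezing followed by one ordinary optimizer step. Afterwards, only the soft prompt has changed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySoftPromptModel(nn.Module):
    """Hypothetical stand-in for GPT-2: embeddings -> mean pool -> linear head."""

    def __init__(self, vocab_size: int = 50, emb_dim: int = 8, n_tokens: int = 4):
        super().__init__()
        self.n_tokens = n_tokens
        self.wte = nn.Embedding(vocab_size, emb_dim)
        self.learned_embedding = nn.Parameter(self.wte.weight[:n_tokens].clone().detach())
        self.head = nn.Linear(emb_dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        text = self.wte(tokens[:, self.n_tokens:])
        prompt = self.learned_embedding.unsqueeze(0).expand(tokens.size(0), -1, -1)
        pooled = torch.cat([prompt, text], dim=1).mean(dim=1)  # mixes prompt and text positions
        return self.head(pooled)                               # (batch, vocab_size)


torch.manual_seed(0)
model = ToySoftPromptModel()

# Freeze by name: only parameters whose name contains "learned_embedding" stay trainable.
for name, param in model.named_parameters():
    param.requires_grad = "learned_embedding" in name

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-2)
before = {name: p.detach().clone() for name, p in model.named_parameters()}

# One ordinary training step.
tokens = torch.cat([torch.full((2, model.n_tokens), 1, dtype=torch.long),
                    torch.randint(0, 50, (2, 5))], dim=1)
targets = torch.randint(0, 50, (2,))
loss = F.cross_entropy(model(tokens), targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

changed = {name: not torch.equal(before[name], p) for name, p in model.named_parameters()}
print(changed)
```

Handing the optimizer only the trainable parameters also keeps it from allocating state for the frozen weights.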

In [18]:
# make sure things are frozen accordingly
for name, param in model.named_parameters():
    print(name, param.requires_grad)

transformer.wte.learned_embedding True
transformer.wte.wte.weight False
transformer.wpe.weight False
transformer.h.0.ln_1.weight False
transformer.h.0.ln_1.bias False
transformer.h.0.attn.c_attn.weight False
transformer.h.0.attn.c_attn.bias False
transformer.h.0.attn.c_proj.weight False
transformer.h.0.attn.c_proj.bias False
transformer.h.0.ln_2.weight False
transformer.h.0.ln_2.bias False
transformer.h.0.mlp.c_fc.weight False
transformer.h.0.mlp.c_fc.bias False
transformer.h.0.mlp.c_proj.weight False
transformer.h.0.mlp.c_proj.bias False
transformer.h.1.ln_1.weight False
transformer.h.1.ln_1.bias False
transformer.h.1.attn.c_attn.weight False
transformer.h.1.attn.c_attn.bias False
transformer.h.1.attn.c_proj.weight False
transformer.h.1.attn.c_proj.bias False
transformer.h.1.ln_2.weight False
transformer.h.1.ln_2.bias False
transformer.h.1.mlp.c_fc.weight False
transformer.h.1.mlp.c_fc.bias False
transformer.h.1.mlp.c_proj.weight False
transformer.h.1.mlp.c_proj.bias False
transformer