# What's Prompt Tuning?
Prompt tuning optimizes a set of 20-100 special tokens at the start of the context in order to replicate the effect of a finetuned language model. These tokens are continuous vector embeddings that can't be decoded into words, but can still force the model to behave in a certain way.

Unfortunately, the transformers training and generation utilities don't yet support embeddings as input, so I've set up rudimentary training and generation loops for a simple description task.

You can read more here:

https://arxiv.org/abs/2104.08691

# What does this sheet do?
This sheet finds a set of special tokens that forces the model to output a given description for a character or object.

If you've played around with prompt engineering in AI Dungeon, it's like getting a computer to write a compressed, non-human-readable World Info entry based on your prose description.

In [None]:
#@title Setup dependencies
!pip install transformers
!git clone https://github.com/corolla-johnson/soft-prompt-tuning.git soft_prompt_tuning
!nvidia-smi

model_setup_for_prompt_tuning = False

# Setup word wrapping
from IPython.display import HTML, display

def set_css(*_info):
  """Emit a <style> block that makes <pre> output wrap long lines.

  Registered below as a 'pre_run_cell' hook. IPython passes an `info`
  argument to pre_run_cell callbacks; it is accepted and ignored here so the
  hook does not depend on IPython adapting zero-argument callbacks. Calling
  set_css() with no arguments still works.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Run set_css before every cell execution.
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
#@title Grab model
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPT2TokenizerFast
from transformers.optimization import Adafactor
from soft_prompt_tuning.soft_embedding import SoftEmbedding
import transformers
import torch
import torch.nn as nn
import os
import tarfile

# Tokenizer and weights both come from the same "gpt2" checkpoint.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Load the language model, then move it onto the GPU.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to("cuda")

In [3]:
#@title Input target description
#@markdown Double-click this cell to open it and edit the 'prefix_len', 'prompt' and 'target' variables.

# Length of the tuned prompt in tokens.
# The paper doesn't recommend going over 100.
prefix_len = 20

# Fixed part of the prompt (how you ask the model for a description)
prompt = "Detailed description of Emma Violence:\n"

# Desired description of the character or object
target = ("Emma Violence is a British cybernetic assassin who is known for her "
          "elegant style and high-profile targets. She is described to be a perfectionist "
          "in her work, often going above and beyond the call of duty. "
          "Despite her cold, ruthless, and calculating nature, she has a warm and "
          "motherly side that she only shows to a select few people. She has two guns "
          "implanted to her forearms and can utilize them both with deadly accuracy.")

# Tokenize strings
prompt_tokens = tokenizer(prompt, return_tensors="pt")
prompt_len = prompt_tokens.input_ids.shape[1]
target_tokens = tokenizer(target, return_tensors="pt")
target_len = target_tokens.input_ids.shape[1]
# Index of the first target token in the [prefix | prompt | target] sequence.
target_start = prefix_len + prompt_len

print(f"Prefix Length: {prefix_len}")
print(f"Prompt Length: {prompt_len}")
print(f"Target Length: {target_len}")

# input_ids and attention_mask must span the full sequence including the
# learned prefix positions. It does not matter what the prefix positions of
# input_ids are padded with, it's just to make HF happy; the tokenizer's
# end-of-text id is used as the filler.
# dtype is spelled out so ids, mask and labels are int64 regardless of how the
# installed torch version infers the dtype of torch.full.
prefix_ids = torch.full((1, prefix_len), tokenizer.eos_token_id, dtype=torch.long)
prefix_mask = torch.ones((1, prefix_len), dtype=torch.long)

# Plain dict of model keyword arguments (passed later as model(**inputs, ...)).
inputs = {
    'input_ids': torch.cat([prefix_ids, prompt_tokens['input_ids'], target_tokens['input_ids']], 1).cuda(),
    'attention_mask': torch.cat([prefix_mask, prompt_tokens['attention_mask'], target_tokens['attention_mask']], 1).cuda(),
}

# -100 on every prefix/prompt position excludes it from the loss, so only the
# target tokens are scored.
ignored_labels = torch.full((1, target_start), -100, dtype=torch.long)
labels = torch.cat([ignored_labels, target_tokens['input_ids']], 1).cuda()

Prefix Length: 20
Prompt Length: 7
Target Length: 90


In [4]:
#@title Configure model for prompt tuning
#@markdown Runs only once
if not model_setup_for_prompt_tuning:
  # Put the model in training mode, then freeze every existing weight; the
  # soft prefix created below is the only tensor handed to the optimizer.
  model.train()
  model.requires_grad_(False)

  # Keep a handle on the original token embedding so it can be restored later.
  old_wte = model.get_input_embeddings()

  # Wrap the original embedding in a SoftEmbedding holding `prefix_len`
  # learned vectors and install it as the model's input embedding.
  s_wte = SoftEmbedding(old_wte, n_tokens=prefix_len, initialize_from_vocab=True)
  s_wte = s_wte.to("cuda")
  model.set_input_embeddings(s_wte)

  # Optimize nothing but the learned prefix.
  params = [model.transformer.wte.learned_embedding]
  optimizer = Adafactor(params=params)

  # Mark the setup as done so re-running this cell does not wrap the model again.
  model_setup_for_prompt_tuning = True

  # Smoke test: a single forward pass through the wrapped model.
  output = model(**inputs, labels=labels)

In [5]:
#@title (OPTIONAL) Load existing prefix
#@markdown This will override the existing prefix_len with the size of the loaded one.

path = "learned_prefix.pt"#@param{type:"string"}

# NOTE(security): torch.load unpickles the file, so only load prefixes you trust.
# map_location puts the tensor on the GPU regardless of where it was saved from.
loaded_prefix = torch.load(path, map_location="cuda")

# Install the loaded values as a fresh Parameter on the soft embedding.
s_wte.learned_embedding = nn.Parameter(loaded_prefix.detach().clone())
prefix_len = s_wte.learned_embedding.shape[0]
s_wte.n_tokens = prefix_len

# The previous optimizer still points at the Parameter that was just replaced,
# so it would keep updating a tensor the model no longer uses. Build a new one
# around the loaded prefix.
optimizer = Adafactor(params=[s_wte.learned_embedding])

# The generation cell may have swapped the plain embedding back in.
model.set_input_embeddings(s_wte)

# Rebuild the training tensors so the number of placeholder positions and the
# label offset match the (possibly different) loaded prefix length.
target_start = prefix_len + prompt_len
inputs['input_ids'] = torch.cat([
    torch.full((1, prefix_len), tokenizer.eos_token_id, dtype=torch.long),
    prompt_tokens['input_ids'],
    target_tokens['input_ids'],
], 1).cuda()
inputs['attention_mask'] = torch.cat([
    torch.ones((1, prefix_len), dtype=torch.long),
    prompt_tokens['attention_mask'],
    target_tokens['attention_mask'],
], 1).cuda()
labels = torch.cat([
    torch.full((1, target_start), -100, dtype=torch.long),
    target_tokens['input_ids'],
], 1).cuda()

print(f"Loaded prefix of length {prefix_len}")

# Report the loss of the loaded prefix; no gradients are needed for that.
with torch.no_grad():
  output = model(**inputs, labels=labels)
  loss = output.loss
print(f"Loss: {loss.item()}")

Loaded prefix of length 20
Loss: 0.7311124205589294


In [None]:
#@title Training
#@markdown 4000+ for "gpt2"
#@markdown
#@markdown 200+ for "GPT-Neo-2.7B"

iterations = 4000#@param{type:"number"}

# The generation cell puts the model in eval mode and, when run without the
# prefix, swaps the plain embedding back in. Restore the training setup so
# this cell works no matter which cells ran before it.
model.set_input_embeddings(s_wte)
model.train()

# The load / reinitialize cells can replace the learned prefix Parameter.
# If the optimizer is tracking a different tensor than the one the model now
# uses, rebuild it; otherwise its updates would never reach the model.
if optimizer.param_groups[0]["params"][0] is not s_wte.learned_embedding:
  optimizer = Adafactor(params=[s_wte.learned_embedding])

for i in range(int(iterations)):
  optimizer.zero_grad()
  output = model(**inputs, labels=labels)
  loss = output.loss
  loss.backward()
  optimizer.step()
  # Log every tenth step to keep the output readable.
  if i%10 == 0:
    print(f"{i}: Loss: {loss.item()}")

In [6]:
#@title (OPTIONAL) Save tuned prefix
path = "learned_prefix.pt"#@param{type:"string"}

# Save straight from the soft embedding: model.transformer.wte may currently be
# the plain embedding (generation without prefix), which has no learned_embedding.
torch.save(s_wte.learned_embedding, path)

print(f"Saved prefix of length {s_wte.learned_embedding.shape[0]}")

# Report the loss of the saved prefix. The soft embedding is attached first so
# the number reflects the prefix; no gradients are needed for a report.
model.set_input_embeddings(s_wte)
with torch.no_grad():
  output = model(**inputs, labels=labels)
  loss = output.loss
print(f"Loss: {loss.item()}")

Saved prefix of length 20
Loss: 0.6657297611236572


In [25]:
#@title (OPTIONAL) Reinitialize prefix
#@markdown Warning: This will reset any training. Make sure to save the tuned prefix first.

# Build a fresh soft embedding of the same length instead of re-running
# __init__ on the live module, and install it on the model.
s_wte = SoftEmbedding(old_wte, s_wte.n_tokens).to("cuda")
model.set_input_embeddings(s_wte)

# The old optimizer tracks the previous prefix tensor (and its statistics);
# create a new one so later training updates the reinitialized prefix.
optimizer = Adafactor(params=[s_wte.learned_embedding])

In [9]:
#@title Generation parameters 
#@markdown Make sure to run this cell after making changes to the parameters.

# Whether to generate with the tuned soft prefix attached.
use_prefix = True #@param{type:"boolean"}

# Alternative prompt, used instead of `prompt` when use_custom_prompt is set.
custom_prompt = "Emma propped the sniper rifle on a balustrade overlooking the Rue de la Paix and carefully aligned her crosshairs with the target. He was a man in a" #@param{type:"string"}
use_custom_prompt = False #@param{type:"boolean"}

# Logits are divided by this value before sampling.
temperature = 0.7 #@param{type:"number"}
# Number of highest-scoring tokens to sample from. This is a token count, so
# it must be a positive integer (10 is the value the generation cell samples with).
top_k = 10 #@param{type:"integer"}

# Number of tokens to generate.
output_length = 120 #@param{type:"number"}

In [17]:
#@title Generate!
#@markdown We're using a fairly barebones top-k sampling scheme so the output might be quite repetitive.
model.eval()

# Attach the soft embedding when the tuned prefix is wanted, otherwise the
# original token embedding.
model.transformer.wte = s_wte if use_prefix else old_wte

# Tokenize whichever prompt was selected, once.
gen_prompt = custom_prompt if use_custom_prompt else prompt
test_inputs = tokenizer(gen_prompt, return_tensors="pt")

gen_ids = test_inputs['input_ids']
gen_mask = test_inputs['attention_mask']
if use_prefix:
  # Reserve `prefix_len` leading positions for the learned prefix. The ids
  # there are only filler (the tokenizer's end-of-text id); the mask marks
  # them as attended.
  gen_ids = torch.cat([torch.full((1, prefix_len), tokenizer.eos_token_id, dtype=torch.long), gen_ids], 1)
  gen_mask = torch.cat([torch.ones((1, prefix_len), dtype=torch.long), gen_mask], 1)

# Store through item assignment so test_inputs['input_ids'] and
# test_inputs.input_ids refer to the same GPU tensors.
test_inputs['input_ids'] = gen_ids.cuda()
test_inputs['attention_mask'] = gen_mask.cuda()

def top_k_logits(logits, k):
    """Keep the k largest logits in each row and set the rest to -inf.

    Args:
        logits: 2-D tensor of scores, one row per sequence.
        k: number of entries to keep per row. Values larger than the row
            width keep everything; values below 1 disable filtering.

    Returns:
        A new tensor of the same shape; `logits` itself is not modified.
    """
    k = min(int(k), logits.size(-1))
    if k < 1:
        return logits.clone()
    # Smallest value among each row's top k, kept as a column for broadcasting.
    kth_largest = torch.topk(logits, k).values[:, [-1]]
    return logits.masked_fill(logits < kth_largest, -float('Inf'))

# Imports are done once, up front, instead of on every loop iteration.
# `random` and `np` are not used by this cell; they stay imported so the names
# remain bound exactly as before.
import random
import numpy as np
from torch.nn import functional as F

# top_k is a token count. Anything below 1 cannot be used as one, so fall back
# to 10, the value this cell has always sampled with.
sample_top_k = int(top_k) if top_k >= 1 else 10

# Work on local GPU tensors and grow them one token per step.
generated_ids = test_inputs.input_ids.cuda()
generated_mask = test_inputs.attention_mask.cuda()
new_token_mask = torch.ones((1, 1), dtype=generated_mask.dtype, device=generated_mask.device)

with torch.no_grad():
  for _ in range(int(output_length)):
    logits = model(input_ids=generated_ids, attention_mask=generated_mask).logits

    # pluck the logits at the final step and scale by temperature
    logits = logits[:, -1, :] / temperature

    # restrict sampling to the highest-scoring tokens
    logits = top_k_logits(logits, sample_top_k)

    # apply softmax to convert to probabilities
    probs = F.softmax(logits, dim=-1)

    # sample from the distribution
    next_token = torch.multinomial(probs, num_samples=1)

    # append to the sequence and continue
    generated_ids = torch.cat((generated_ids, next_token), dim=1)
    generated_mask = torch.cat((generated_mask, new_token_mask), dim=1)

# Expose the finished sequence on test_inputs as well.
test_inputs.input_ids = generated_ids
test_inputs.attention_mask = generated_mask

print(f"Target:\n<{prefix_len} token tuned prompt>\n{prompt + target}\n")

if use_prefix:
  # Skip the placeholder ids that stand in for the learned prefix.
  print(f"Actual:\n<{prefix_len} token tuned prompt>\n{tokenizer.decode(generated_ids[0][prefix_len:])}")
else:
  print(f"Actual:\n{tokenizer.decode(generated_ids[0])}")


Target:
<20 token tuned prompt>
Detailed description of Emma Violence:
Emma Violence is a British cybernetic assassin who is known for her elegant style and high-profile targets. She is described to be a perfectionist in her work, often going above and beyond the call of duty. Despite her cold, ruthless, and calculating nature, she has a warm and motherly side that she only shows to a select few people. She has two guns implanted to her forearms and can utilize them both with deadly accuracy.

Actual:
<20 token tuned prompt>
Detailed description of Emma Violence:
Emma Violence is a British cybernetic assassin who is known for her elegant style and high-profile targets. She is described to be a perfectionist in her work and is known for her elegant style. Despite her cold, calculated, and ruthless approach to a target, she has two guns implanted to her forearms and can utilize them both with deadly accuracy. Despite her cold, calculated, and ruthless approach to a target, she has a warm