# Welcome!
This is a reference implementation of Plug-and-Blend (https://github.com/xxbidiao/plug-and-blend, which is itself based on https://arxiv.org/abs/2104.04039), built on the LogitsProcessor framework that is new in Hugging Face Transformers. Feel free to check them out if anything in this notebook is unclear.

# Set things up
Here we download the model needed to set up the modifier (topic) network.

In [None]:
!pip install transformers
!pip install torch
!pip install scipy
!pip install tqdm

# Imports
import scipy
from scipy import stats
import transformers
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, LogitsProcessorList, GPTNeoForCausalLM
# Local directory expected to hold the GeDi topic checkpoint; it is later passed to
# GPT2LMHeadModel.from_pretrained(). Presumably created by extracting gedi_topic.zip — confirm.
gedi_path = "gedi_topic/"

In [2]:
# Download the topic (modifier) model.

# only run it the first time

# !wget https://storage.googleapis.com/sfr-gedi-data/gedi_topic.zip
# import zipfile
# with zipfile.ZipFile('gedi_topic.zip', 'r') as zip_ref:
#     zip_ref.extractall('./')

Now let's set up the logits processor.

In [3]:
# Pick the device for the models: CUDA when a GPU is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Location of the GeDi topic checkpoint used by the logits processor.
gedi_location = gedi_path

class PlugAndBlendLogitsProcessor(transformers.LogitsProcessor):
    """Logits processor that steers generation towards one topic using GeDi.

    Each instance holds one ``(topic, weight)`` pair. On every generation step
    it scores all candidate next tokens with the GeDi topic model and adds
    ``weight * omega`` times the resulting log-probability of the "positive"
    class to the base model's scores. Several instances can be placed in one
    ``LogitsProcessorList`` to blend topics.

    The GeDi model and tokenizer are class attributes: they are loaded once,
    when this class is defined, and shared by all instances.
    """

    gedi_model = GPT2LMHeadModel.from_pretrained(gedi_location).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # default omega from original GeDi work; a higher value means more aggressive topic steering.
    # It can be changed for all instances by assigning PlugAndBlendLogitsProcessor.omega.
    # default value (1x) is 30.
    omega = 30

    # Token ids of the class-conditional control codes prepended to GeDi's input.
    # Computed once here rather than on every generation step.
    _pt_id = tokenizer.encode("positive")[0]
    _nt_id = tokenizer.encode("negative")[0]

    def __init__(self, topic: str, weight: float):
        """
        :param topic: topic name; only its first GPT-2 token is fed to GeDi.
        :param weight: multiplier for this topic's modifier (0 disables the topic).
        """
        super().__init__()
        self.topic = topic
        self.weight = weight
        self.encoded_topic = PlugAndBlendLogitsProcessor.tokenizer.encode(topic)[0]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add this topic's weighted GeDi modifier to ``scores`` (in place) and return it."""
        # A zero weight contributes nothing, so skip the expensive GeDi forward pass.
        if self.weight == 0:
            return scores

        modifiers = self.get_gedi_modifiers(input_ids=input_ids)

        # GeDi may live on a different device than the base model's scores.
        modifiers = modifiers.to(scores.device)

        scores += modifiers * self.weight * PlugAndBlendLogitsProcessor.omega
        return scores

    def get_gedi_modifiers(self, input_ids):
        """Compute, for every candidate next token, GeDi's log-probability of the "positive" class.

        :param input_ids: token ids generated so far, indexed as (batch, seq_len).
        :return: tensor of shape (batch, vocab): log p(positive | sequence + token),
            where the class-conditional sequence log-likelihoods are normalized by length.
        """
        batch_size = input_ids.shape[0]

        def _prefix_column(token_id):
            # (batch, 1) column filled with token_id, matching input_ids' dtype and device.
            return torch.full(
                (batch_size, 1), token_id, dtype=input_ids.dtype, device=input_ids.device
            )

        topic_column = _prefix_column(self.encoded_topic)
        # Stack the "positive"- and "negative"-prefixed copies along the batch axis so a
        # single forward pass scores both classes: rows [0, batch) are positive and
        # rows [batch, 2 * batch) are negative.
        seq_batched = torch.cat(
            [
                torch.cat((_prefix_column(PlugAndBlendLogitsProcessor._pt_id), topic_column, input_ids), dim=1),
                torch.cat((_prefix_column(PlugAndBlendLogitsProcessor._nt_id), topic_column, input_ids), dim=1),
            ],
            dim=0,
        ).to(device)

        # Inference only: never build an autograd graph through GeDi.
        with torch.no_grad():
            all_logits = PlugAndBlendLogitsProcessor.gedi_model(input_ids=seq_batched)["logits"]

            # Baseline: log-likelihood of the existing sequence (without the next token)
            # under each class. Cross entropy against the shifted labels picks out
            # -log p(token_t | tokens_<t) at every position.
            shift_logits = all_logits[..., :-1, :].contiguous()
            shift_labels = seq_batched[..., 1:].contiguous()

            # Class weights: give weight 0 to the eos (padding) token so it adds nothing.
            class_weight = torch.ones(
                shift_logits.size(-1), dtype=torch.float, device=shift_logits.device
            )
            class_weight[PlugAndBlendLogitsProcessor.tokenizer.eos_token_id] = 0
            loss_fct = torch.nn.CrossEntropyLoss(reduction="none", weight=class_weight)

            token_logp = -loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            ).view(seq_batched.shape[0], -1)
            seq_len = token_logp.shape[1]
            sequence_logp = token_logp.sum(dim=1)

            # Add the baseline to the log-probabilities of every candidate next token.
            gedi_logits = torch.log_softmax(all_logits[:, -1, :], -1)
            gedi_logits += sequence_logp.unsqueeze(1)

            # Length-normalize, split back into the positive/negative halves and stack
            # them on a new last axis, giving shape (batch, vocab, 2).
            logits = torch.stack(torch.split(gedi_logits / seq_len, batch_size), 2)

            # Normalize over the two classes; keep only the "positive" one (index 0).
            return torch.log_softmax(logits, dim=-1)[..., 0]

# Tests

import functools


@functools.lru_cache(maxsize=1)
def _load_base_model_and_tokenizer():
    """Load the base language model and its tokenizer once; later calls reuse them.

    test_generation may be called many times (e.g. once per weight setting of an
    experiment), and reloading a 2.7B-parameter model on every call dominated runtime.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # As this is plug-and-blend, you can change this to any model that uses the GPT2 tokenizer
    # (i.e. has the same input_ids => actual sentence mapping), e.g.
    # GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).
    # NOTE(review): the model stays on the CPU, as in the original; move it to `device`
    # (together with input_ids) if GPU memory allows.
    model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
    return model, tokenizer


def test_generation(prompt=None, topics=None, print_out=False, max_length=50):
    """Generate a continuation of ``prompt`` steered towards the given topics.

    :param prompt: text to continue; defaults to "Once upon a time,".
    :param topics: dict mapping topic name -> weight; defaults to {"Science": 1, "Nature": 1}.
    :param print_out: if True, also print the generated text.
    :param max_length: maximum total length in tokens (prompt + continuation) passed to generate().
    :return: the decoded output sequence (it starts with the prompt).
    """
    if prompt is None:
        prompt = "Once upon a time,"
    if topics is None:
        # default topics
        topics = {"Science": 1, "Nature": 1}

    model, tokenizer = _load_base_model_and_tokenizer()
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # One processor per topic; their weighted modifiers add up during generation.
    lp_list = LogitsProcessorList(
        [PlugAndBlendLogitsProcessor(topic=topic, weight=weight) for topic, weight in topics.items()]
    )

    greedy_output = model.generate(
        input_ids,
        max_length=max_length,
        logits_processor=lp_list,
        no_repeat_ngram_size=3,
        # The value generate() falls back to anyway; passing it silences the per-call warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    result = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
    if print_out:
        print("Output:\n" + 100 * '-')
        print(result)
    return result


Some weights of the model checkpoint at gedi_topic/ were not used when initializing GPT2LMHeadModel: ['logit_scale']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Generate things (Demo)

This demo generates text with GPT-Neo 2.7B (`EleutherAI/gpt-neo-2.7B`) as the base model. Refer to the body of `test_generation` to see how to use a different model (any model works as long as its tokenizer is equivalent to `GPT2Tokenizer.from_pretrained("gpt2")`).

Change `test_prompt` to set the prompt, and edit the `test_topics` dictionary to choose the topics you want in the generated text. A total weight of 1 (all weights added up) gives the standard control strength; in our experiments, 2 to 4 gives stronger steering.

In [4]:
# Steer towards a single topic at twice the standard control strength.
test_topics = {"World": 2}
test_prompt = "Here is a fun story."

test_generation(test_prompt, test_topics)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'Here is a fun story.\n\nA few years ago, I was invited to speak at a conference in the United States. I was asked to speak on the topic of “The Future of the Internet.”\n\nI was asked'

In [5]:
# Test whether two tokenizers decode things the same way.
# Check that the GPT-Neo tokenizer and the GPT-2 tokenizer (used by GeDi) encode text
# identically: plug-and-blend requires both models to share the same input_ids mapping.
tokenizer1 = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def get_inp(tok, pr):
    """Return the input_ids tensor that tokenizer ``tok`` produces for text ``pr``."""
    return tok(pr, return_tensors="pt").input_ids


prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote"

ids_gpt2 = get_inp(tokenizer, prompt)
ids_neo = get_inp(tokenizer1, prompt)
print(ids_gpt2)
print(ids_neo)
# Explicit verdict instead of comparing the two printed tensors by eye.
print("Tokenizers agree:", torch.equal(ids_gpt2, ids_neo))

tensor([[  818,   257, 14702,  4917,    11,  5519,  5071,   257, 27638,   286,
         28000, 19942,  2877,   287,   257,  6569]])
tensor([[  818,   257, 14702,  4917,    11,  5519,  5071,   257, 27638,   286,
         28000, 19942,  2877,   287,   257,  6569]])


# Experiments

In [6]:
from transformers import pipeline
import scipy

# Zero-shot classifier used to score how strongly generated text matches a topic.
# No model is given, so transformers picks its default zero-shot model.
classification_pipeline = pipeline(task="zero-shot-classification")

def classifier_scoring_full(topic1, topic2, text):
    """
    Using a zero-shot classifier, score how close ``text`` is to topic1 rather than topic2.
    :param topic1: topic1. the closer the better.
    :param topic2: topic2. the further away the better.
    :param text: text under consideration.
    :return: a float, the classifier's score for topic1 when choosing between topic1 and
        topic2 (expected in [0, 1] — confirm against the pipeline docs). Empty text gets
        the neutral score 0.5 without running the classifier, so callers always receive
        a number (the previous dict return for empty text broke numeric callers).
    :raises RuntimeError: if topic1 is missing from the classifier output.
    """
    global classification_pipeline

    # Handle empty input first so no classifier is needed (or loaded) for it.
    if not text:
        print("Empty sentence! using score=0.5")
        return 0.5

    if classification_pipeline is None:
        classification_pipeline = pipeline("zero-shot-classification")

    result = classification_pipeline(text, [topic1, topic2])
    for label, score in zip(result["labels"], result["scores"]):
        if label == topic1:
            return score
    raise RuntimeError("Topic is missing from inferring results. Classification model may not be working.")

No model was supplied, defaulted to facebook/bart-large-mnli (https://huggingface.co/facebook/bart-large-mnli)


In [7]:
# Sanity check: text about programming is expected to score high for "Technology" versus "Sun".
classifier_scoring_full(
    topic1="Technology",
    topic2="Sun",
    text="Computer programming languages are tools for people to teach what a machine should do.",
)

Ignored unknown kwarg option direction
Ignored unknown kwarg option direction


0.8506259322166443

In [None]:
def ranking_svm_loss(scores, should_increase=True):
    """
    Measure how well ``scores`` follow the expected order (Kendall's tau-a).

    Every pair (i, j) with i < j is concordant if scores[j] - scores[i] has the
    expected sign, tied if the two scores are equal, and discordant otherwise.
    :param scores: all scores, in the order they were produced (any iterable).
    :param should_increase: True means `scores` should be increasing.
    :return: (concordant - discordant) / number_of_pairs, in [-1, 1]:
        1 = perfectly ordered, -1 = perfectly reversed. Returns 0.0 when there
        are fewer than two scores (no pairs to compare).
    """
    import itertools

    scores = list(scores)
    pair_count = len(scores) * (len(scores) - 1) / 2.0
    if pair_count == 0:
        # Ordering is undefined without pairs; report the neutral value instead of dividing by 0.
        return 0.0

    order = 1 if should_increase else -1
    concordant = 0
    ties = 0
    for earlier, later in itertools.combinations(scores, 2):
        if (later - earlier) * order > 0:
            concordant += 1
        elif earlier == later:
            ties += 1
    discordant = pair_count - concordant - ties
    return (concordant - discordant) / pair_count

def experiment_two_topic_weights(prompt, topic1, topic2, scorer=lambda x: 0, step=0.25):
    """
    Sweep the blend between two topics and measure how consistently a scorer follows it.

    For weights w = 0, step, 2 * step, ... up to 1, text is generated with
    ``{topic1: w, topic2: 1 - w}`` and scored with ``scorer``. As w grows the
    text should move towards topic1, so the scores are expected to increase.
    :param prompt: generation prompt.
    :param topic1: topic whose weight increases over the sweep.
    :param topic2: topic whose weight decreases over the sweep.
    :param scorer: callable text -> number; higher should mean closer to topic1.
    :param step: positive weight increment between runs (default 0.25, i.e. 5 runs).
    :return: ranking_svm_loss of the scores (1 = scores increase perfectly with w).
    :raises ValueError: if step is not positive.
    """
    if step <= 0:
        raise ValueError("step must be positive, got %r" % (step,))

    end = 1.0
    # Derive each weight from an integer index instead of accumulating floats, so
    # rounding cannot drop or overshoot the last weight for steps such as 0.1.
    num_steps = int(end / step + 1e-9)
    all_scores = []
    for i in range(num_steps + 1):
        weight1 = i * step
        all_topics_for_eval = {topic1: weight1, topic2: end - weight1}
        text = test_generation(prompt=prompt, topics=all_topics_for_eval)
        all_scores.append(scorer(text))
    return ranking_svm_loss(all_scores, should_increase=True)

In [None]:
from functools import partial

# Topics understood by the GeDi topic model; every unordered pair is evaluated.
topics = [
    "Business",
    "Science",
    "World",
    "Sports",
]

# Short, topic-neutral prompts used as starting points for the blending experiments.
prompts = [
    # simple declarative sentences
    "The cat stretched.",
    "Jacob stood on his tiptoes.",
    "The car turned the corner.",
    "Kelly twirled in circles.",
    "She opened the door.",
    "Aaron made a picture.",
    "I'm sorry.",
    "I danced.",
    # imperatives
    "Run!",
    "Open the jar carefully.",
    "Read the directions.",
    "Don't cry.",
    "Use common sense.",
    "Make the best of things.",
    "Catch up!",
    # compound subjects
    "Sarah and Ira drove to the store.",
    "Jenny and I opened all the gifts.",
    "The cat and dog ate.",
    "My parents and I went to a movie.",
    "Mrs. Juarez and Mr. Smith are dancing gracefully.",
    "Samantha, Elizabeth, and Joan are on the committee.",
    "The ham, green beans, mashed potatoes, and corn are gluten-free.",
    "The paper and pencil sat idle on the desk.",
]

import logging
import re


def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: iterable of one or more str prefixes to match (e.g. ["transformers", "torch"]).
          Optional. Default is `("",)` to match all active loggers.
          The match is a case-sensitive, literal `module_name.startswith(prefix)`; characters
          such as "." have no special meaning (previously they were treated as regex syntax).
    """
    prefixes = tuple(prefices)
    # Iterate over a snapshot so any changes getLogger() makes to the registry cannot break the loop.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)


set_global_logging_level()

#prompts = ["Test"]

import json


def _save_results(results, path="result.txt"):
    """Write ``results`` to ``path`` as JSON, overwriting any previous content."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f)


all_result = {}

# Evaluate every unordered topic pair once (alphabetically smaller topic first).
for topic1 in topics:
    for topic2 in topics:
        if topic1 >= topic2:
            continue
        all_result_key = "%s-%s" % (topic1, topic2)
        scoring_function = partial(classifier_scoring_full, topic1, topic2)
        pair_scores = []
        all_result[all_result_key] = pair_scores
        for item in prompts:
            pair_scores.append(experiment_two_topic_weights(item, topic1, topic2, scoring_function))
            # Progress: running mean of the order score for this topic pair
            # (instead of dumping the whole, ever-growing results dict every time).
            print("%s:%s" % (all_result_key, sum(pair_scores) / len(pair_scores)))
        # Checkpoint after each pair so a failure late in this long run keeps earlier results.
        _save_results(all_result)

_save_results(all_result)