In [1]:
import functools
from matplotlib import pyplot as plt
import pandas as pd
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, MaxLengthCriteria, StoppingCriteriaList, BeamSearchScorer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

  from .autonotebook import tqdm as notebook_tqdm


#### Load the models

Load our LLM

In [2]:
#model_name = 'psmathur/orca_mini_3b'
model_name = 'meta-llama/Llama-2-7b-hf'

model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 2/2 [00:28<00:00, 14.06s/it]


Load our sentiment model

In [3]:
sentiment_pipe = pipeline("text-classification", model="michellejieli/emotion_text_classifier", device='cuda:0')

#### Define Logits Processor

In [4]:
class MyLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer, top_k, extra_scoring_func, extra_scoring_magnitude=1):
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.extra_scoring_func = extra_scoring_func
        self.extra_scoring_magnitude = extra_scoring_magnitude

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

        # B x num_beams
        num_hypos = scores.shape[0]
        num_beams = num_hypos // 1
        cur_len = input_ids.shape[-1]

        # Decode sequences
        decoded_sequences = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # Get top 100 hypotheses 
        top_hypotheses = torch.topk(scores, k=self.top_k, dim=-1, largest=True, sorted=True)
        top_hypotheses_indices = top_hypotheses.indices

        # Get top sentences by merging beams and hypotheses
        top_sentences = [
            input_ids.unsqueeze(2).repeat(1, 1, self.top_k),  # NB x t x 100
            top_hypotheses_indices.unsqueeze(1)  # NB x 1 x 100
        ]  
        top_sentences = torch.concat(top_sentences, dim=1)  # NB x t+1 x 100
        top_sentences = top_sentences.transpose(1, 2).reshape(-1, cur_len + 1)  # NB*100 x t+1

        # Compute scores for each hypothesis
        top_sentences = tokenizer.batch_decode(top_sentences, skip_special_tokens=True)
        top_sentences_extra_scores = self.extra_scoring_func(top_sentences)
        top_sentences_extra_scores = torch.tensor(top_sentences_extra_scores, device=scores.device)
        top_sentences_extra_scores = top_sentences_extra_scores.reshape(num_beams, -1)

        # Mask out scores that are not modified (not in top-k)
        scores[:, :] = float('-inf')
        for i in range(num_beams):
            # Renormalize scores after masking out
            top_hypotheses.values[i] = torch.log(torch.softmax(top_hypotheses.values[i], dim=-1))
            # Bring back scores 
            scores[i, top_hypotheses.indices[i]] = top_hypotheses.values[i]
        # Add extra scores
        for i in range(num_beams):
            top_sentences_extra_scores[i] = torch.log(top_sentences_extra_scores[i]) * self.extra_scoring_magnitude
            scores[i, top_hypotheses.indices[i]] += top_sentences_extra_scores[i]

        return scores

This is wraps our sentiment model to measure how likely a sentence contains a certain emotion

In [5]:
def emotion_scoring(texts, emotion):

    # anger 🤬
    # disgust 🤢
    # fear 😨
    # joy 😀
    # neutral 😐
    # sadness 😭
    # surprise 😲
    assert emotion in ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise']

    results = sentiment_pipe(texts, top_k=None)
    scores = []
    for result in results:
        score = [s for s in result if s['label'] == emotion]
        score = score[0]['score']
        scores.append(score)

    for i in range(len(scores)):
        scores[i] = scores[i]

    return scores

def get_emotion_scoring(emotion):
    return functools.partial(emotion_scoring, emotion=emotion)

Now, we call the huggingface beamsearch with our own logits processor to modulate the emotions 

In [6]:
def run_modulated_beamsearch(extra_scoring_func, num_beams=5, max_length=50, input_prompt='Hey, you', top_k=100, extra_scoring_magnitude=1): 

    input_ids = tokenizer(
        input_prompt, 
        return_tensors="pt"
    ).input_ids
    input_ids = torch.stack([input_ids] * num_beams, dim=0).reshape(num_beams, -1).to(model.device)
    bos_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) * model.config.bos_token_id
    input_ids = torch.cat([bos_ids, input_ids], dim=-1)

    final_sentence = model.beam_search(
        input_ids, 
        beam_scorer=BeamSearchScorer(
            batch_size=1,
            max_length=max_length,
            num_beams=num_beams,
            device="cuda",
            length_penalty=1.0,
            do_early_stopping=True,
        ),
        logits_processor = LogitsProcessorList([
            MyLogitsProcessor(
                tokenizer, 
                top_k,
                extra_scoring_func, 
                extra_scoring_magnitude
            )
        ]),
        stopping_criteria = StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length)
        ]),
        pad_token_id=tokenizer.eos_token_id, 
    )

    final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
    #final_sentence_score = textblob_polarity_scoring(final_sentence_str)

    return final_sentence_str



#### Results

Let's see how generated text's would look like with different configurations

In [7]:
results = []

for emotion in ['joy', 'anger', 'surprise']:
    for magnitute in [1, 10, -1]:
        sentence = run_modulated_beamsearch(
            extra_scoring_func=get_emotion_scoring(emotion), 
            extra_scoring_magnitude=magnitute,
            input_prompt='Hello, ', 
            top_k=100, 
            max_length=16
        )
        results.append({
            'emotion': emotion,
            'magnitute': magnitute,
            'sentence': sentence
        })
        print(f"emotion={emotion} magnitude={magnitute}: {repr(sentence)}")



emotion=joy magnitude=1: 'Hello,  glad to see you here! 😊'
emotion=joy magnitude=10: 'Hello,  glad to be here! 😊\n'
emotion=joy magnitude=-1: "Hello, \n nobody, \nI'm not sure if"
emotion=anger magnitude=1: 'Hello, \n nobody! 😊\n\nI'
emotion=anger magnitude=10: 'Hello, icy hell!" she said, her voice muffled'
emotion=anger magnitude=-1: "Hello,  I'm glad you're here! I'"
emotion=surprise magnitude=1: 'Hello, \n nobody! This is your captain speaking. We are'
emotion=surprise magnitude=10: "Hello,  what?! I'm just an AI,"
emotion=surprise magnitude=-1: 'Hello, 🙋\u200d♀️💬'


In [8]:
df = pd.DataFrame(results)
df = df.set_index(['emotion', 'magnitute']).sort_index(ascending=[True, False])
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.max_colwidth', None):
    display(df)

Unnamed: 0_level_0,Unnamed: 1_level_0,sentence
emotion,magnitute,Unnamed: 2_level_1
anger,10,"Hello, icy hell!"" she said, her voice muffled"
anger,1,"Hello, \n nobody! 😊\n\nI"
anger,-1,"Hello, I'm glad you're here! I'"
joy,10,"Hello, glad to be here! 😊\n"
joy,1,"Hello, glad to see you here! 😊"
joy,-1,"Hello, \n nobody, \nI'm not sure if"
surprise,10,"Hello, what?! I'm just an AI,"
surprise,1,"Hello, \n nobody! This is your captain speaking. We are"
surprise,-1,"Hello, 🙋‍♀️💬"
