In [1]:
import os
# Expose only GPU 0 to this process; set before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [2]:
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, MaxLengthCriteria, StoppingCriteriaList, BeamSearchScorer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from textblob import TextBlob

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single source of truth for the checkpoint so model and tokenizer always match.
MODEL_NAME = "EleutherAI/pythia-1B"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Reuse EOS as the padding token for both the model config and the tokenizer.
model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

In [4]:
class MyLogitsProcessor(LogitsProcessor):
    """Re-rank each hypothesis' top-k next tokens with an extra text-level score.

    At every decoding step, each hypothesis keeps only its ``top_k`` highest
    scoring next tokens. Every candidate continuation (current sequence plus
    one candidate token) is decoded to text and passed to
    ``extra_scoring_func``. The returned value is added to that candidate's
    score, and all tokens outside the top-k are set to ``-inf``.

    Args:
        tokenizer: object exposing ``batch_decode(ids, skip_special_tokens=...)``
            (e.g. a Hugging Face tokenizer), used to turn candidates into text.
        extra_scoring_func: callable ``str -> float`` whose result is added to
            each candidate's score.
        top_k: number of candidate tokens re-scored per hypothesis. It is
            clamped to the vocabulary size.
    """

    def __init__(self, tokenizer, extra_scoring_func, top_k):
        self.tokenizer = tokenizer
        self.extra_scoring_func = extra_scoring_func
        self.top_k = top_k

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a copy of ``scores`` in which only the re-scored top-k entries are finite.

        Args:
            input_ids: (num_hypotheses, cur_len) token ids generated so far.
            scores: (num_hypotheses, vocab_size) next-token scores.

        Returns:
            A new tensor with the same shape, dtype and device as ``scores``.
        """
        num_hypos, cur_len = input_ids.shape
        # torch.topk raises if k exceeds the size of the vocabulary dimension.
        k = min(self.top_k, scores.shape[-1])
        top = torch.topk(scores, k=k, dim=-1, largest=True, sorted=True)

        # Pair every hypothesis prefix with each of its k candidates:
        # (num_hypos, k, cur_len + 1), flattened to (num_hypos * k, cur_len + 1).
        candidates = torch.cat(
            [input_ids.unsqueeze(1).expand(-1, k, -1), top.indices.unsqueeze(-1)],
            dim=-1,
        ).reshape(-1, cur_len + 1)

        # Use the tokenizer passed to the constructor, not a notebook global.
        texts = self.tokenizer.batch_decode(candidates, skip_special_tokens=True)
        extra = torch.tensor(
            [self.extra_scoring_func(text) for text in texts],
            dtype=scores.dtype,
            device=scores.device,
        ).reshape(num_hypos, k)

        # Everything outside the top-k becomes -inf; top-k entries get score + extra.
        new_scores = torch.full_like(scores, float("-inf"))
        return new_scores.scatter(-1, top.indices, top.values + extra)


def textblob_polarity_scoring(text):
    """Return TextBlob's sentiment polarity for ``text``, rescaled as ``p / 2 + 0.5``.

    TextBlob documents polarity as lying in [-1, 1], so the result is
    expected in [0, 1]: 0 is most negative and 1 is most positive.
    """
    polarity = TextBlob(text).sentiment.polarity
    return polarity / 2 + 0.5


def run(num_beams=3, max_length=10, input_prompt='Hey, you', top_k=10000, extra_score_weight=-20):
    """Beam-search a continuation of ``input_prompt`` steered by TextBlob sentiment.

    At each step, every candidate token's score is shifted by
    ``textblob_polarity_scoring(candidate_text) * extra_score_weight`` through
    ``MyLogitsProcessor``. A positive weight favours text TextBlob rates as
    positive; a negative weight favours text it rates as negative.

    Args:
        num_beams: beam width.
        max_length: maximum total length in tokens, counting the prepended
            BOS token and the prompt.
        input_prompt: text to continue.
        top_k: number of candidate tokens re-scored per beam per step.
        extra_score_weight: multiplier applied to the sentiment score.

    Returns:
        A dict with ``'sentence'`` (the first decoded sequence, special tokens
        removed) and ``'score'`` (its ``textblob_polarity_scoring`` value).
    """
    device = model.device

    # Tokenizer output lives on CPU; move it next to the model before concatenating.
    prompt_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to(device)
    bos_ids = torch.full((1, 1), model.config.bos_token_id, dtype=torch.long, device=device)
    # beam_search expects one row per beam: every beam starts from <bos> + prompt.
    input_ids = torch.cat([bos_ids, prompt_ids], dim=-1).repeat(num_beams, 1)

    # NOTE(review): calling `model.beam_search` directly is deprecated in recent
    # transformers releases; this cell relies on a version that still provides it.
    final_sentence = model.beam_search(
        input_ids,
        beam_scorer=BeamSearchScorer(
            batch_size=1,
            max_length=max_length,
            num_beams=num_beams,
            # Keep scorer tensors on the model's device instead of hard-coding "cuda".
            device=device,
            length_penalty=1.0,
            do_early_stopping=True,
        ),
        logits_processor=LogitsProcessorList([
            MyLogitsProcessor(tokenizer, lambda text: textblob_polarity_scoring(text) * extra_score_weight, top_k)
        ]),
        stopping_criteria=StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length)
        ]),
        pad_token_id=tokenizer.eos_token_id,
    )

    final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
    final_sentence_score = textblob_polarity_scoring(final_sentence_str)

    return {
        'sentence': final_sentence_str,
        'score': final_sentence_score
    }


# Positive weight: rewards candidates that TextBlob rates as positive.
run(input_prompt='Hey, you', extra_score_weight=10)

{'sentence': "Hey, you're welcome!\n\nI", 'score': 1.0}

In [6]:
# Negative weight: rewards candidates that TextBlob rates as negative.
run(input_prompt='Hey, you', extra_score_weight=-20)

{'sentence': 'Hey, you nasty bitch!\n\nI', 'score': 0.0}

In [8]:
# Longer (32-token) generation from a different prompt, steered toward positive sentiment.
run(input_prompt='WTF, this is', max_length=32, extra_score_weight=20)

{'sentence': 'WTF, this is awesome! I love the way you drew her, and I love the way you drew her hair! I love the way you drew',
 'score': 0.7125}