In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, MaxLengthCriteria, StoppingCriteriaList, BeamSearchScorer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from textblob import TextBlob

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# classifier = pipeline(model="tomh/toxigen_hatebert")
#
# sentences = [
#     'I love you',
#     'I hate you!',
#     'You stupid idiot!',
#     'You stupid German idiot!',
#     'You are very nice!'
# ]
#
# for s in sentences: 
#     result = classifier(s, return_all_scores=True)
#     result = result[0]
#     result = [r for r in result if r['label'] == 'LABEL_0'][0]
#     print(s)
#     print(result['score'])
#     print()
#
#
# def hate_bert_scoring(sentence):
#     try: 
#         result = classifier(sentence, return_all_scores=True)
#         result = result[0]
#         result = [r for r in result if r['label'] == 'LABEL_0'][0]
#         return result['score']
#     except:
#         return 0.0

In [10]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1B")

model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

In [12]:
def textblob_polarity_scoring(text):
    return (TextBlob(text).sentiment.polarity / 2 + 0.5) * -1 * 20


class MyLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer, extra_scoring_func, vocab_size):
        self.tokenizer = tokenizer
        self.extra_scoring_func = extra_scoring_func
        self.vocab_size = vocab_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # B x num_beams
        num_hypos = scores.shape[0]
        num_beams = num_hypos // 1
        cur_len = input_ids.shape[-1]

        # Decode sequences
        decoded_sequences = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # Get top 100 hypotheses 
        top_hypotheses = torch.topk(scores, k=10000, dim=-1, largest=True, sorted=True)
        top_hypotheses_indices = top_hypotheses.indices

        # Merge hypotheses and beams 
        top_sentences = [
            input_ids.unsqueeze(2).repeat(1, 1, 10000),  # NB x t x 100
            top_hypotheses_indices.unsqueeze(1)  # NB x 1 x 100
        ]
        #print(f'  top_sentences.shape: {top_sentences[0].shape}')
        #print(f'  top_sentences.shape: {top_sentences[1].shape}')

        top_sentences = torch.concat(
            top_sentences,
            dim=1
        ) # NB x t+1 x 100
        top_sentences = top_sentences.transpose(1, 2).reshape(-1, cur_len + 1) 
        #print('  top_sentences.shape', top_sentences.shape)

        # Compute scores for each hypothesis
        top_sentences_scores = [
            self.extra_scoring_func(s)
            for s in tokenizer.batch_decode(top_sentences, skip_special_tokens=True)
        ]
        top_sentences_scores = torch.tensor(top_sentences_scores, device=scores.device)
        top_sentences_scores = top_sentences_scores.reshape(num_beams, -1)
        #print(f"top_sentences_scores: {top_sentences_scores.shape}")

        print('-' * 10)
        print(f"cur_len: {cur_len}")
        #print(f"num_hypos: {num_hypos}, num_beams: {num_beams}, cur_len: {cur_len}")
        #print(f"scores.shape: {scores.shape}, input_ids.shape: {input_ids.shape}")
        #for decoded_sequence, extra_score in zip(
        #        tokenizer.batch_decode(top_sentences, skip_special_tokens=True), 
        #        top_sentences_scores.reshape(-1)
        #):
        #    print(f"  {repr(decoded_sequence)} -> {extra_score}")

        # Update scores 
        scores[:, :] = float('-inf')
        for i in range(num_beams):
            scores[i, top_hypotheses.indices[i]] = top_hypotheses.values[i] + top_sentences_scores[i]
            #print(scores[i, top_hypotheses.indices[i]])

        return scores

num_beams = 3
max_length = 10

input_prompt = 'Hey, you'
input_ids = tokenizer(
    input_prompt, 
    return_tensors="pt"
).input_ids
input_ids = torch.stack([input_ids] * num_beams, dim=0).reshape(num_beams, -1)
bos_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) * model.config.bos_token_id
input_ids = torch.cat([bos_ids, input_ids], dim=-1)

final_sentence = model.beam_search(
    input_ids, 
    beam_scorer=BeamSearchScorer(
        batch_size=1,
        max_length=max_length,
        num_beams=num_beams,
        device="cuda",
        #length_penalty=1.0,
        #do_early_stopping=False,
        #num_beam_hyps_to_keep=1,
    ),
    logits_processor = LogitsProcessorList([
        MyLogitsProcessor(tokenizer, textblob_polarity_scoring, model.config.vocab_size)
    ]),
    stopping_criteria = StoppingCriteriaList([
        MaxLengthCriteria(max_length=max_length)
    ]),
    pad_token_id=tokenizer.eos_token_id, 
)

final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)
print(final_sentence_str, textblob_polarity_scoring(final_sentence_str[0]))

----------
cur_len: 4
----------
cur_len: 5
----------
cur_len: 6
----------
cur_len: 7
----------
cur_len: 8
----------
cur_len: 9
['Hey, you awesome guy!\n\nI'] -1.0
