In [17]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer
import random
import torch.nn.functional as F

# The tokenizer and the language-model head are loaded from the same
# Hugging Face checkpoint, so the identifier is spelled out only once.
MODEL_NAME = "pranavpsv/gpt2-genre-story-generator"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, return_dict=True)

In [184]:
def get_next_logits(sentence):
    """Return the model's next-token logits for ``sentence``.

    The sentence is tokenized with the module-level ``tokenizer`` and run
    through the module-level ``model`` with gradient tracking disabled.

    Args:
        sentence: Text prefix to condition on.

    Returns:
        The logits row of the last input position, i.e. the scores for the
        token that would follow ``sentence``.
    """
    with torch.no_grad():
        # Only one sentence is encoded, so no padding is requested.
        encoded = tokenizer(sentence, add_special_tokens=True)
        # Wrap the ids in a list so the model receives an explicit
        # (batch=1, sequence) tensor rather than an un-batched 1-D one.
        input_ids = torch.tensor([encoded['input_ids']])
        outputs = model(input_ids)
        # Batch entry 0, last sequence position, every vocabulary score.
        return outputs.logits[0, -1, :]

In [185]:
def get_prob_from_logits(idx, logits):
    """Look up the score stored at position ``idx`` of ``logits``.

    Despite the name, the value is returned exactly as stored: no softmax
    or other normalisation is applied here.

    Args:
        idx: Index (or index tensor) into ``logits``.
        logits: Indexable collection of scores.

    Returns:
        ``logits[idx]``.
    """
    return logits[idx]

In [186]:
def get_word_id_form_dict(word):
    """Return the vocabulary id of ``word`` as a single ``'Ġ'``-prefixed token.

    The key looked up in the module-level ``tokenizer``'s vocabulary is
    ``'Ġ' + word`` (``'Ġ'`` is the marker GPT-2 style byte-level BPE
    vocabularies use for a leading space).

    Args:
        word: The word to look up, without a leading space.

    Returns:
        The id stored in the vocabulary for ``'Ġ' + word``.

    Raises:
        KeyError: If ``'Ġ' + word`` is not an entry of the vocabulary, e.g.
            because the tokenizer splits the word into several pieces.
    """
    token = 'Ġ' + word
    vocab = tokenizer.get_vocab()
    if token not in vocab:
        # Same exception type as a plain dict lookup, with a readable message.
        raise KeyError(
            "{!r} is not a single token ({!r}) in the tokenizer vocabulary".format(word, token)
        )
    return vocab[token]
    

In [197]:
def buildDict(target_id, sent, logits, k, d, last_id, branch=3):
    """Recursively record how far each candidate token scores above the target.

    Starting from ``logits`` (the next-token scores for ``sent``), the
    ``branch`` highest-scoring tokens are expanded. For every expanded token
    an entry ``d[(k, token_id, last_id)] = score(token) - score(target)`` is
    written, then the token is appended to the sentence and the search
    continues one level deeper with ``k - 1``.

    Args:
        target_id: Vocabulary id of the word being steered towards.
        sent: Sentence text that ``logits`` was computed for.
        logits: Next-token scores for ``sent``.
        k: Remaining search depth; nothing is recorded when ``k <= 0``.
        d: Dictionary filled in place with the entries described above.
        last_id: Token id chosen at the previous level (the caller passes
            ``""`` at the root).
        branch: Number of top-scoring tokens expanded per level.

    Returns:
        None. Results are written into ``d``.
    """
    if k <= 0:
        return

    # Never ask topk for more entries than the score vector holds.
    branch = min(branch, logits.size(-1))
    _, topk_ids = torch.topk(logits, branch)
    target_score = get_prob_from_logits(target_id, logits)
    for token_id in topk_ids:
        d[(k, token_id, last_id)] = get_prob_from_logits(token_id, logits) - target_score
        # A model call is only needed if another level will be explored;
        # at the deepest level the extended sentence would never be used.
        if k > 1:
            new_sent = sent + tokenizer.decode([token_id])
            new_logits = get_next_logits(new_sent)
            buildDict(target_id, new_sent, new_logits, k - 1, d, token_id, branch)
    return
        

In [198]:
def get_action_list(d, k):
    """Extract a token path from a search dictionary built by ``buildDict``.

    The entry with the smallest recorded value is located first (the first
    such entry in insertion order on ties). From there the function climbs
    towards level ``k`` by repeatedly looking for an entry whose level is one
    higher and whose token equals the current entry's ``last_id``.

    Args:
        d: Mapping ``(level, token_id, last_id) -> diff``.
        k: Level of the root entries (the depth the search started with).

    Returns:
        List of token ids ordered from the best-scoring entry up to the
        root, so that ``list.pop()`` yields them root-first.

    Raises:
        ValueError: If ``d`` is empty.
    """
    # First key (in insertion order) that holds the smallest value.
    find = min(d.items(), key=lambda item: item[1])[0]
    i = find[0]
    action = [find[1]]

    while i < k:
        found_parent = False
        # NOTE(review): a parent is matched by (token, level) only, so when
        # the same token occurs under several branches the first match wins;
        # the scan also deliberately keeps going after a match, as before.
        for key in d:
            if (find[2] == key[1]) and (key[0] == i + 1):
                find = key
                i = find[0]
                action.append(find[1])
                found_parent = True
        if not found_parent:
            # No entry links upwards from here; stop instead of spinning.
            break

    return action

In [271]:

def generate_sentence_action_search(sent, target_word, tolerant, steps=20, depth=3, reset_tolerant=0.3):
    """Extend ``sent`` token by token using short look-ahead plans.

    Whenever the current plan is exhausted, the next-token scores for the
    sentence are computed, ``buildDict`` explores ``depth`` levels of
    continuations and ``get_action_list`` turns the result into a list of
    token ids to emit. For each emitted token, if its score minus the target
    word's score is below ``tolerant``, a variant of the sentence with
    ``target_word`` inserted at that point is started and ``tolerant`` is
    reset; otherwise ``tolerant`` is relaxed by 0.1. Every variant keeps
    receiving the same tokens as the main sentence. The list of variants is
    printed at the end.

    Args:
        sent: Starting text.
        target_word: Word to try to insert (looked up through
            ``get_word_id_form_dict``).
        tolerant: Initial threshold on the score difference.
        steps: Number of tokens to generate.
        depth: Look-ahead depth handed to ``buildDict`` / ``get_action_list``.
        reset_tolerant: Value ``tolerant`` is set to after an insertion.

    Returns:
        None. The collected variants are printed.
    """
    # The target token does not change between steps, so resolve it once.
    target_id = get_word_id_form_dict(target_word)

    action = []
    results = []
    for _ in range(steps):
        if len(action) == 0:
            # Plan exhausted: score the current sentence and plan again.
            logits = get_next_logits(sent)
            target_prob = get_prob_from_logits(target_id, logits)
            d = dict()
            buildDict(target_id, sent, logits, depth, d, "")
            action = get_action_list(d, depth)

        next_id = action.pop()
        next_word = tokenizer.decode([next_id])
        # NOTE(review): ``logits`` and ``target_prob`` date from the most
        # recent planning step, so for the later tokens of a plan this
        # difference is not taken at the sentence's current position.
        diff = get_prob_from_logits(next_id, logits) - target_prob
        if diff < tolerant:
            results.append(sent + ' ' + target_word)
            tolerant = reset_tolerant
        else:
            tolerant += 0.1

        # Every variant (including one started just now) gets this token too.
        results = [variant + next_word for variant in results]
        sent += next_word
    print(results)

In [272]:
# Demo of the look-ahead search: random genre tag, fixed story opening.
genre_list = ['superhero', 'action', 'drama', 'horror', 'thriller', 'sci_fi']
genre_tag = genre_list[random.randrange(len(genre_list))]
sent = '<BOS> <{}> Long long ago'.format(genre_tag)
target_word, tolerant = 'pig', 0.3
generate_sentence_action_search(sent, target_word, tolerant)

['<BOS> <thriller> Long long ago, the world of the martial artists and martial-clerk were ruled by pig Master Li and Master Wu']


In [278]:
def generate_sentence_topk_search(input_word, target_word, tolerant, genre, total_loop, top_k=10, print_best=True):
    """Sample a story with random top-k choices, branching on the target word.

    The prompt is ``'<BOS> <genre>' + input_word``. At every step one of the
    ``top_k`` highest-scoring next tokens is picked uniformly at random. If
    the picked token's score minus the target word's score is below
    ``tolerant``, a branch is generated in which ``target_word`` is inserted
    at this point and the remaining steps are filled by a recursive call
    (with a ``-inf`` tolerance, so the branch itself never branches). The
    branch text is stored together with that difference and ``tolerant``
    drops by 1; otherwise ``tolerant`` grows by 0.1.

    Args:
        input_word: Text placed right after the genre tag.
        target_word: Word to try to insert (looked up through
            ``get_word_id_form_dict``).
        tolerant: Initial threshold on the score difference.
        genre: Genre name placed inside the ``<...>`` tag.
        total_loop: Number of tokens to generate.
        top_k: Size of the candidate pool sampled from at each step.
        print_best: When true, print the stored branch with the smallest
            difference (if any branch was stored). Recursive calls pass
            ``False`` so only the outermost call reports.

    Returns:
        The prompt extended by the sampled tokens, without any insertion.
    """
    sent = '<BOS> <' + genre + '>' + input_word
    current_word = input_word
    # The target token does not change between steps, so resolve it once.
    target_id = get_word_id_form_dict(target_word)
    results = dict()
    for i in range(total_loop):
        logits = get_next_logits(sent)
        target_prob = get_prob_from_logits(target_id, logits)

        _, topk_ids = torch.topk(logits, top_k)
        next_id = topk_ids[random.randint(0, top_k - 1)]
        next_word = tokenizer.decode([next_id])

        diff = get_prob_from_logits(next_id, logits) - target_prob
        if diff < tolerant:
            temp_word = current_word + ' ' + target_word
            # Complete the branch with the steps that are still left.
            branch = generate_sentence_topk_search(
                temp_word, target_word, float('-inf'), genre, total_loop - i - 1,
                top_k=top_k, print_best=False,
            )
            results[branch] = diff
            tolerant -= 1
        else:
            tolerant += 0.1
        current_word += next_word
        sent += next_word

    if print_best and len(results) > 0:
        # First stored branch (insertion order) with the smallest difference.
        best_branch = min(results.items(), key=lambda item: item[1])[0]
        print(best_branch)
    return sent

In [286]:
# Demo of the top-k guided search: 40 generated tokens, random genre tag.
TOTAL_LOOP = 40
genre_list = ['superhero', 'action', 'drama', 'horror', 'thriller', 'sci_fi']
genre = genre_list[random.randrange(len(genre_list))]
input_word, target_word = 'Long long ago', 'water'
tolerant = 10
generate_sentence_topk_search(input_word, target_word, tolerant, genre, TOTAL_LOOP)

<BOS> <sci_fi>Long long ago there came in a small town an abandoned settlement called Sengun who have developed advanced weaponry to fend a wave. They were attacked with the use an unusual device called water cannons they have a unique type


'<BOS> <sci_fi>Long long ago there came in a small town an abandoned settlement called Sengun who have developed advanced weaponry to fend a wave. They were attacked with the use an unusual device called Airtan. But it was'

In [18]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The input tensor is not modified; a filtered copy is returned.

        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over every removed entry.

        Returns:
            A new 1-D tensor in which removed entries are set to ``filter_value``.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    # Filter a copy so the caller's tensor keeps its original scores.
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a score less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

In [20]:
def generate_sentence_nucleus(input_word, target_word, tolerant, genre, total_loop, top_p=0.9, print_best=True):
    """Sample a story with nucleus (top-p) sampling, branching on the target word.

    The prompt is ``'<BOS> <genre>' + input_word``. At every step the
    next-token scores are filtered with ``top_k_top_p_filtering`` and one
    token is drawn from the resulting distribution. If the drawn token's
    score minus the target word's score is below ``tolerant``, a branch with
    ``target_word`` inserted at this point is generated for the remaining
    steps, stored together with that difference, and ``tolerant`` drops by
    1; otherwise ``tolerant`` grows by 0.1.

    Args:
        input_word: Text placed right after the genre tag.
        target_word: Word to try to insert (looked up through
            ``get_word_id_form_dict``).
        tolerant: Initial threshold on the score difference.
        genre: Genre name placed inside the ``<...>`` tag.
        total_loop: Number of tokens to generate.
        top_p: Cumulative-probability threshold passed to the filter.
        print_best: When true, print the stored branch with the smallest
            difference (if any branch was stored).

    Returns:
        The prompt extended by the sampled tokens, without any insertion.
    """
    sent = '<BOS> <' + genre + '>' + input_word
    current_word = input_word
    # The target token does not change between steps, so resolve it once.
    target_id = get_word_id_form_dict(target_word)
    results = dict()
    for i in range(total_loop):
        logits = get_next_logits(sent)
        # Read the target's score before any filtering is applied.
        target_prob = get_prob_from_logits(target_id, logits)

        filtered_logits = top_k_top_p_filtering(logits, top_k=0, top_p=top_p)
        probabilities = F.softmax(filtered_logits, dim=-1)
        next_id = torch.multinomial(probabilities, 1)
        next_word = tokenizer.decode([next_id])

        diff = get_prob_from_logits(next_id, logits) - target_prob
        if diff < tolerant:
            temp_word = current_word + ' ' + target_word
            # NOTE(review): branches are completed by the top-k sampler, not
            # by nucleus sampling - confirm this is intended.
            branch = generate_sentence_topk_search(
                temp_word, target_word, float('-inf'), genre, total_loop - i - 1
            )
            results[branch] = diff
            tolerant -= 1
        else:
            tolerant += 0.1
        current_word += next_word
        sent += next_word

    if print_best and len(results) > 0:
        # First stored branch (insertion order) with the smallest difference.
        best_branch = min(results.items(), key=lambda item: item[1])[0]
        print(best_branch)
    return sent

In [21]:
# Demo of the nucleus-sampling search: 40 generated tokens, random genre tag.
TOTAL_LOOP = 40
genre_list = ['superhero', 'action', 'drama', 'horror', 'thriller', 'sci_fi']
genre = genre_list[random.randrange(len(genre_list))]
input_word, target_word = 'Long long ago', 'water'
tolerant = 10
generate_sentence_nucleus(input_word, target_word, tolerant, genre, TOTAL_LOOP)

'<BOS> <thriller>Long long ago, a group of Eonians, led by Gokhō, was dispatched to the Moon, but their arrival caused chaos amongst the others. Still a distant memory, most Eonians decided'

## Example of progressive rendering

Target words: 'pig', 'dirt'; style (genre) = "horror"

In the deep dark woods grass where we hear that many dead pig men come in dirt on trees that slope inwards. They kill their horses by hanging. Then the men enter their hideaway where they hang and kill two pig brothers by tying their hands and neck in a string attached