In [24]:
from lstm_model import dictionary_corpus
from lstm_model import utils
from transformers import *

In [25]:
import torch
import torch.nn.functional as F
import os
import warnings
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message='source.*has changed')

In [26]:
from lm_utils import *

In [27]:


def update_hidden(lm, lm_vocabulary, prompt, word, hidden, alpha = 0.5):
    hidden_tmp = []
    for i in range(len(hidden)):
        hidden_tmp.append(Variable(hidden[i], requires_grad=True))
    hidden_tmp = tuple(hidden_tmp)
    output, hidden = lm(prompt, hidden_tmp)
    output_scores = output.view(-1, len(lm_vocabulary.word2idx))[-1]
    cross_entropy = torch.nn.CrossEntropyLoss()
    loss = cross_entropy(output_scores, torch.LongTensor(lm_vocabulary.word2idx[word]))
    loss.backward()
    hidden_new = []
    for i in range(len(hidden)):
        grad = hidden_tmp[i].grad.clone()
        updated_hidden = hidden_tmp - alpha * grad
        hidden_new.append(updated_hidden)
    hidden_new = tuple(hidden_new)
    return hidden_new


# def generation_with_update(prompt, lm, lm_vocabulary, cuda = False, model_type ='lstm_gulordava', sentence = False, n_sentences = 1, max_len = 20,
#               generation_type = 'greedy', filter_by = 'k', top_p = 0.5,top_k = 10, output_generations = 5, temperature = 1,
#               repetition_penalty = 1):
#     if sentence: text += '<eos>'
#     if generation_type == 'greedy': output_generations = 1
#     generated = []
#     original_prompt = lm_vocabulary.encode(prompt)
#     original_prompt = torch.LongTensor(original_prompt).cuda() if cuda else torch.LongTensor(original_prompt)
#     original_prompt = utils.batchify(original_prompt, 1, cuda)
#     while len(generated) < output_generations:
#         tokens_generated = 0
#         words_in_text = set([int(x) for x in original_prompt])
#         end_of_sentence = False
#         generated_text = ''
#         n_sentences_generated = 0
#         # TODO adapt to non-lstm
#         hidden = lm.init_hidden(1)
        
#         output, hidden = lm(original_prompt[:-1], hidden)
#         word = original_prompt[:-1]
        
#         hidden = update_hidden(lm, lm_vocabulary, prompt, word, hidden, alpha = 0.5)
#         prompt = word
        
#         while n_sentences_generated != n_sentences or tokens_generated == max_len:
#             output, hidden = lm(prompt, hidden)
#             output_scores = output.view(-1, len(lm_vocabulary.word2idx))[-1]
#             if temperature != 1:
#                 output_scores /= (temperature)
#             if repetition_penalty != 1:
#                 for _ in words_in_text:
#                     output_scores[_] /= (repetition_penalty)
#             output_scores[encode_lstm('<unk>')[0]] /= - 10 ** 10 
#             next_word_probability_distribution = F.log_softmax(output_scores, dim=-1)
#             sorted_probabilities, sorted_words = torch.sort(next_word_probability_distribution, descending = True)
#             if generation_type == 'greedy':
#                 generated_word_id = greedy(sorted_words)
#             elif generation_type == 'sampling':
#                 generated_word_id = sampling(sorted_words, sorted_probabilities, filter_by= filter_by, top_k = top_k, top_p = top_p)
#             generated_word_id = generated_word_id[0]
#             generated_word = decode_lstm([generated_word_id])
#             if generated_word == '<eos>':
#                 n_sentences_generated += 1
#             else:
#                 generated_text += generated_word + ' '
#                 words_in_text.add(generated_word_id)
#             prompt = encode_lstm(generated_word)
#             tokens_generated  += 1
#         generated.append(generated_text)
#     return generated

lm, lm_vocabulary = load_lm(model_type = 'distilGPT2')

In [115]:
cuda = False


def greedy(sorted_words, n_samples=1):
    return list(sorted_words[:n_samples])


def sampling(sorted_words, sorted_probabilities, filter_by='k', top_k=10, top_p=0.3, n_samples=1):
    stop_at = top_k
    if filter_by == 'p':
        cumulative_probs = torch.cumsum(F.softmax(sorted_probabilities, dim=-1), dim=-1)
        stop_at = 0
        while cumulative_probs[stop_at] < top_p:
            stop_at += 1
        stop_at += 1  # first token above the threshold
    sampled_token = torch.multinomial(F.softmax(torch.Tensor(sorted_probabilities[:stop_at]), dim=-1),
                                      num_samples=n_samples)
    return [sorted_words[int(i)] for i in sampled_token]


def deploy_lm(text, lm, lm_vocabulary,
              model_type='lstm_gulordava', sentence=False, cuda=False, hidden = None):
    if sentence: text += '<eos>'
    text = lm_vocabulary.encode(text)
    with torch.no_grad():
        text = torch.LongTensor(text).cuda() if cuda else torch.LongTensor(text)   
        if model_type == 'lstm_gulordava':
            text = text.unsqueeze(1)
            if hidden == None:
                hidden = lm.init_hidden(1)
            output, hidden = lm(text, hidden)
            output_scores = output.view(-1, len(lm_vocabulary.word2idx))
            return output_scores, hidden
        elif lm_vocabulary.transformer:
            text = text.unsqueeze(0)
            outputs, _ = lm(text)
            output_scores = outputs.squeeze(0)
            return output_scores
        
#     next_word_scores = language_model(context)[0][-1]
  
#   # Trasform scores into probabilities through softmax function
#   # probability distribution over the words in the vocabulary
#   next_word_probability_distribution = F.softmax(next_word_scores, dim = -1)
#   return next_word_probability_distribution
        
def generation(original_prompt, lm, lm_vocabulary, cuda=False, model_type='lstm_gulordava', sentence=False, n_sentences=1,
               max_len=20,
               generation_type='greedy', filter_by='k', top_p=0.5, top_k=10, output_generations=5, temperature=1,
               repetition_penalty=1):
    if sentence: text += '<eos>'
    if generation_type == 'greedy': output_generations = 1
    generated = []
    while len(generated) < output_generations:
        tokens_generated = 0
        words_in_text = set([x for x in original_prompt])
        end_of_sentence = False
        generated_text = ''
        n_sentences_generated = 0
        hidden = lm.init_hidden(1) if model_type == 'lstm_gulordava' else None
        prompt = original_prompt
        while n_sentences_generated != n_sentences or tokens_generated == max_len:
            output = deploy_lm(prompt, lm, lm_vocabulary, hidden = hidden, model_type =model_type)
            if model_type == 'lstm_gulordava': output, hidden = output 
            output_scores = output[-1]
            if temperature != 1:
                output_scores /= (temperature)
            if repetition_penalty != 1:
                for _ in words_in_text:
                    output_scores[_] /= (repetition_penalty)
            output_scores[lm_vocabulary.encode('<unk>')[0]] /= - 10 ** 10
            next_word_probability_distribution = F.log_softmax(output_scores, dim=-1)
            sorted_probabilities, sorted_words = torch.sort(next_word_probability_distribution, descending=True)
            print(sorted_probabilities[:10])
            if generation_type == 'greedy':
                generated_word_id = greedy(sorted_words)
            elif generation_type == 'sampling':
                generated_word_id = sampling(sorted_words, sorted_probabilities, filter_by=filter_by, top_k=top_k,
                                             top_p=top_p)
            generated_word_id = generated_word_id[0]
            generated_word = lm_vocabulary.decode([generated_word_id])
            if generated_word == '<eos>':
                n_sentences_generated += 1
            else:
                generated_text += generated_word + ' '
                words_in_text.add(generated_word_id)
            if model_type == 'lstm_gulordava': 
                prompt = generated_word
            else:
                prompt += generated_word
            tokens_generated += 1
        generated.append(generated_text)
    return generated


#x = lm_vocabulary.tokenizer.encode('My name is')

generation('Every day is', lm, lm_vocabulary, model_type = 'distilGPT2')

# context_indices = lm_vocabulary.encode('Every day is')
# print(context_indices)
# context = torch.tensor(context_indices)
# next_word_scores = language_model(context)[0][-1]
# print(next_word_scores[:10])


tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(16791)
tensor(16791)
tensor(16791)
tensor(16791)
tensor(16791)
tensor(16791)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)
tensor(27)


KeyboardInterrupt: 

In [None]:
text = 'Everyone agrees that an old trash can smell'

print(generation(text, lm, lm_vocabulary, cuda = cuda, generation_type = 'sampling', filter_by = 'p', top_k = 0.5))

In [None]:
import pandas as pd


In [None]:
BOUNDARY_0 = 6
BOUNDARY_1 = 7

data = pd.read_csv('../data/RRCSimple Stimuli - RRC.csv', header = 0)
#print(data)
data.columns = ['id', 'sentence', 2,3 ]
data = list(data.sentence)

for i in range(len(data)):
    data[i] = data[i].replace('.', ' . ')
    prompt1 = ' '.join(data[i].split()[:BOUNDARY_0 + 1])
    prompt2 = ' '.join(data[i].split()[:BOUNDARY_1 + 1])
    print(prompt1, prompt2)
    
#     #print(beam_search(prompt1, lm, lm_vocabulary, model_type ='lstm_gulordava'))
#     #print(generation(prompt1, lm, lm_vocabulary, cuda = cuda, generation_type = 'sampling', filter_by = 'p', top_k = 0.3))
#     print(prompt2)
#     #print(beam_search(prompt2, lm, lm_vocabulary, model_type ='lstm_gulordava'))
#     print(generation(prompt2, lm, lm_vocabulary, cuda = cuda, generation_type = 'sampling', filter_by = 'p', top_k = 0.3))


In [None]:
data = pd.read_csv('../data/SynApt PreExp2 Stimuli - Sheet1.csv', header = 0)
data.columns = ['type', 0,'sentence', 2,3 ]
data = data[data.type == 'NP/Z']
data = list(data.sentence)

BOUNDARY_0 = 5
BOUNDARY_1 = 6

for i in range(len(data[:10])):
    data[i] = data[i].replace('.', ' . ')
    prompt1 = ' '.join(data[i].split()[:BOUNDARY_0 + 1])
    prompt2 = ' '.join(data[i].split()[:BOUNDARY_1 + 1])
    print(prompt1, prompt2)
    

#     print(beam_search(prompt1, lm, lm_vocabulary, model_type ='lstm_gulordava'))
#     #print(generation(prompt1, lm, lm_vocabulary, cuda = cuda, generation_type = 'sampling', filter_by = 'p', top_k = 0.3))
#     print(prompt2)
#     print(beam_search(prompt2, lm, lm_vocabulary, model_type ='lstm_gulordava'))
#     print(generation(prompt2, lm, lm_vocabulary, cuda = cuda, generation_type = 'sampling', filter_by = 'p', top_k = 0.3))
