In [2]:
%load_ext autoreload
%autoreload 2
import os
import argparse
import torch
import numpy as np
import utils_testing
import utils
from utils import str2bool
import colorama
import re
import sys
from gpt2_model.tokenization_gpt2 import GPT2Tokenizer
from gpt2_model.modeling_gpt2_condition import GPT2LMHeadModel
from gpt2_model.configuration_gpt2 import GPT2Config

# Notebook configuration, expressed as an argparse parser so that the same
# option names can be shared with the project's command-line scripts.
parser = argparse.ArgumentParser(description='PyTorch Interactive LM')

# --- paths ---
parser.add_argument('--checkpoint_topics', type=str, default='../models/',
                    help='topical model checkpoint to use')
parser.add_argument('--checkpoint_conditional', type=str, default='../models/',
                    help='conditional LM model checkpoint to use')
# The literal default 'target_emb.pt' is treated as a sentinel by the loading
# cell, which resolves it relative to --checkpoint_topics.
parser.add_argument('--emb_file', type=str, default='target_emb.pt',
                    help='path to a word embedding file')
parser.add_argument('--word_dict', type=str, default='../data/processed/wiki2016_gpt2/tensors_all_min100/dict_idx_compact',
                    help='path to a dictionary file')
# Options disabled in this notebook (kept for reference):
#parser.add_argument('--outf', type=str, default='../gen_log/generated.txt',
#                    help='output file for generated text')
#parser.add_argument('--batch_size', type=int, default=3, metavar='N',
#                    help='batch size')

# --- generation settings ---
parser.add_argument('--num_sent_gen', type=int, default=3, metavar='N',
                    help='In each prompt, generate how many sentences')
parser.add_argument('--gen_sent_len', type=int, default=50, metavar='N',
                    help='In each prompt, generate sentences with length gen_sent_len')
parser.add_argument('--bptt', type=int, default=512,
                    help='sequence length')
parser.add_argument('--bptt_conditional', type=int, default=256,
                    help='sequence length')
parser.add_argument('--top_k_nn', type=int, default=5,
                    help='Representing each topic using how many words')

# --- devices ---
parser.add_argument('--cuda_topics', type=str2bool, nargs='?', default=True,
                    help='use CUDA for topical model')
parser.add_argument('--cuda_conditional', type=str2bool, nargs='?', default=True,
                    help='use CUDA for conditional LM')
# NOTE: default=True combined with action='store_true' means this option is
# True whether or not the flag is passed.
parser.add_argument('--single_gpu', default=True, action='store_true',
                    help='use single GPU')

# Model-architecture options are registered by the project helper
# (later cells read e.g. args.n_basis, which is not declared above).
utils_testing.add_model_arguments(parser)

# Checkpoints used by this notebook.
CHECKPOINT_TOPICS = '../models/future_topic_all-20200106-222318'
# Earlier conditional LM checkpoint: '../models/conditional_all-20200106-235956'
CHECKPOINT_CONDITIONAL = '../models/conditional_all-20200115-160129'  # new average model
WORD_DICT = '../data/processed/wiki2016_gpt2/tensors_all_min100/dict_idx_compact'

# Parse exactly once from an explicit argument list. Passing a list (instead of
# whitespace-splitting one long string) keeps paths containing spaces intact.
args = parser.parse_args([
    '--checkpoint_topics', CHECKPOINT_TOPICS,
    '--checkpoint_conditional', CHECKPOINT_CONDITIONAL,
    '--word_dict', WORD_DICT,
])


In [3]:
# Resolve the default embedding file name relative to the topic checkpoint.
if args.emb_file == "target_emb.pt":
    args.emb_file = os.path.join(args.checkpoint_topics, "target_emb.pt")

# The topic model lives on the first GPU. The conditional LM goes on the second
# GPU when one is visible and otherwise shares the first, so the notebook also
# runs on single-GPU machines.
device_topics = torch.device("cuda:0" if args.cuda_topics else "cpu")
conditional_gpu_index = 1 if torch.cuda.device_count() > 1 else 0
device_conditional = torch.device("cuda:{}".format(conditional_gpu_index) if args.cuda_conditional else "cpu")

with open(args.word_dict) as f_in:
    idx2word_freq = utils.load_idx2word_freq(f_in)
# Reverse lookup: word -> row index in idx2word_freq. Each entry's first field
# is the word (the second field was unpacked as `freq` and never used).
word_d2_idx = {idx2word_freq[i][0]: i for i in range(len(idx2word_freq))}

parallel_encoder, parallel_decoder, encoder, decoder, word_norm_emb = utils.loading_all_models(args, idx2word_freq, device_topics)
output_emb_size = word_norm_emb.size(1)
#print(next(encoder.parameters()).device)

model_name = 'gpt2'

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
encoder_state_dict = torch.load(os.path.join(args.checkpoint_conditional, 'encoder.pt'), map_location=device_conditional)
gpt2_config = GPT2Config.from_pretrained(model_name)
# The conditional LM consumes word embeddings of the same width as word_norm_emb.
gpt2_config.word_emb_dim = output_emb_size
# Use .to() rather than .cuda(): .cuda() fails when device_conditional is the CPU
# (i.e. when --cuda_conditional is false).
model_condition = GPT2LMHeadModel.from_pretrained(model_name, state_dict = encoder_state_dict, config = gpt2_config).to(device_conditional)
#print(next(model_condition.parameters()).device)

# NOTE(review): the tokenizer is loaded from 'distilgpt2' while the model uses
# 'gpt2'; presumably both share the same vocabulary -- confirm.
tokenizer_GPT2 = GPT2Tokenizer.from_pretrained('distilgpt2')

# Inference only: switch off dropout in every model.
encoder.eval()
decoder.eval()
model_condition.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): Laye

In [6]:
def show_future_topics(prompt, encoder, decoder, word_norm_emb, n_basis, top_k, bptt, idx2word_freq, tokenizer_GPT2, device_topics):
    """Print the topics predicted to follow ``prompt`` and return their nearest words.

    The prompt is tokenized, truncated to its last ``bptt`` tokens and encoded.
    The encoder output at the final position is decoded into basis vectors;
    each basis vector is described by the ``top_k`` rows of ``word_norm_emb``
    with the largest dot product against it.

    Args:
        prompt: Text to condition on; must yield at least one token.
        encoder: Callable mapping a ``(1, seq)`` LongTensor to a 2-tuple whose
            first item is a ``(1, seq, hidden)`` tensor.
        decoder: Callable mapping the last hidden state to a rank-3 tensor of
            basis vectors (presumably ``(1, n_basis, emb)`` -- confirm).
        word_norm_emb: ``(vocab, emb)`` word embedding matrix (presumably
            row-normalized, judging by its name -- confirm).
        n_basis: Number of basis vectors (topics) to print.
        top_k: Number of nearest words kept per topic.
        bptt: Maximum number of trailing prompt tokens fed to the encoder.
        idx2word_freq: Indexable by word id; item ``[0]`` of an entry is the word.
        tokenizer_GPT2: Tokenizer providing ``tokenize`` and ``convert_tokens_to_ids``.
        device_topics: Device on which the token tensor is created.

    Returns:
        ``(top_value, top_index, feature)`` where ``top_index`` has shape
        ``(1, top_k, n_basis)`` and holds word ids, ``top_value`` holds the
        matching similarities rescaled to sum to one per topic, and ``feature``
        is the ``(1, seq)`` token tensor. The tensors carry no autograd graph.

    Raises:
        ValueError: If ``prompt`` produces no tokens.
    """
    tokenized_text = tokenizer_GPT2.tokenize(prompt, add_prefix_space=True)
    print(tokenized_text)
    indexed_tokens = tokenizer_GPT2.convert_tokens_to_ids(tokenized_text)
    if len(indexed_tokens) == 0:
        # Without this check the failure would surface later as an obscure
        # indexing error on the empty encoder output.
        raise ValueError('prompt must contain at least one token')

    # Keep only the most recent `bptt` tokens.
    start_idx = len(indexed_tokens) - bptt
    if start_idx > 0:
        indexed_tokens = indexed_tokens[start_idx:]
    feature = torch.tensor(indexed_tokens, dtype=torch.long, device=device_topics).unsqueeze(0)

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        # Use the encoder that was passed in (the previous code silently relied
        # on a notebook-level global named `parallel_encoder`).
        output_emb, _ = encoder(feature)
        output_emb_last = output_emb[:, -1, :]
        basis_pred = decoder(output_emb_last)
        # Unit-normalize each basis vector; the epsilon guards against zero norm.
        basis_norm_pred = basis_pred / (0.000000000001 + basis_pred.norm(dim=2, keepdim=True))

        # (1, vocab, emb) @ (1, emb, n_basis) -> (1, vocab, n_basis)
        basis_norm_pred = basis_norm_pred.permute(0, 2, 1)
        sim_pairwise = torch.matmul(word_norm_emb.unsqueeze(dim=0), basis_norm_pred)
        top_value, top_index = torch.topk(sim_pairwise, top_k, dim=1, sorted=True)
        # Rescale the kept similarities so they sum to one for every topic.
        top_value = top_value / (0.000000000001 + top_value.sum(dim=1, keepdim=True))

    # One line per topic: "<topic id>, word, word, ..., "
    for j in range(n_basis):
        words = [idx2word_freq[top_index[0, k, j].item()][0] for k in range(top_k)]
        print(str(j) + ', ' + ''.join(word_nn + ', ' for word_nn in words))
    print()

    return top_value, top_index, feature

def conditional_generation(selected_conditions, gen_sent_len, num_sent_gen, word_d2_idx, idx2word_freq, model_condition, word_norm_emb, top_index, top_value, feature, bptt_conditional, tokenizer_GPT2, device_conditional):
    """Sample continuations of a prompt with and without topic/word conditions.

    Each topic is summarized as the ``top_value``-weighted, re-normalized
    average of the embeddings of its ``top_index`` words. The embeddings of the
    selected topics and selected words are handed to
    ``utils_testing.sample_seq`` together with the (possibly left-truncated)
    prompt; a second, unconditioned call is made for comparison. Every sample
    is decoded and reported through ``utils_testing.print_sampled_sent`` on
    stdout, all conditional samples first, then all unconditioned ones.

    Args:
        selected_conditions: Iterable mixing integers (topic indices) and
            strings (vocabulary words). Unknown words are skipped with a warning.
        gen_sent_len: Number of tokens to generate per sample.
        num_sent_gen: Number of samples to draw per mode.
        word_d2_idx: Mapping from word to its row in ``word_norm_emb``.
        idx2word_freq: Vocabulary table forwarded to ``print_sampled_sent``.
        model_condition: Conditional language model forwarded to ``sample_seq``.
        word_norm_emb: ``(vocab, emb)`` word embedding matrix.
        top_index: ``(1, top_k, n_basis)`` word ids per topic.
        top_value: Weights matching ``top_index``.
        feature: ``(1, seq)`` prompt token ids.
        bptt_conditional: Total token budget (prompt plus generated tokens).
        tokenizer_GPT2: Tokenizer providing ``_convert_id_to_token`` and
            ``convert_tokens_to_string``.
        device_conditional: Device of ``model_condition``.

    Raises:
        ValueError: If ``bptt_conditional`` leaves no room for any prompt token.
    """
    max_prompt_len = bptt_conditional - gen_sent_len
    if max_prompt_len <= 0:
        raise ValueError('gen_sent_len must be smaller than bptt_conditional')

    # Weighted average of each topic's nearest-word embeddings, then unit-normalized.
    word_norm_emb_top = word_norm_emb[top_index, :]
    weights = top_value.unsqueeze(-1)
    word_norm_emb_w_sum = torch.sum(word_norm_emb_top * weights, dim=1) / weights.sum(dim=1)
    word_w_sum_norm = word_norm_emb_w_sum / (0.000000000001 + word_norm_emb_w_sum.norm(dim=-1, keepdim=True))
    word_w_sum_norm = word_w_sum_norm.to(device=device_conditional)

    # Split the conditions into topic indices and vocabulary word ids.
    selected_topic_idx = []
    selected_word_idx = []
    for x in selected_conditions:
        # Accept NumPy integers too; they used to fall through to the word
        # branch and crash in the warning's string concatenation.
        if isinstance(x, (int, np.integer)):
            selected_topic_idx.append(x)
        else:
            if x not in word_d2_idx:
                print('Warning: Ignore the word '+x+' because it is too rare')
                continue
            selected_word_idx.append(word_d2_idx[x])
    selected_topic_idx = torch.tensor(np.sort(selected_topic_idx), dtype=torch.long, device=device_conditional)
    selected_word_idx = torch.tensor(selected_word_idx, dtype=torch.long, device=device_conditional)

    # Keep only as many trailing prompt tokens as the budget allows.
    end_int = feature.size(1)
    start_int = 0
    if end_int > max_prompt_len:
        start_int = end_int - max_prompt_len
    # The condition is attached to the last prompt token (position relative to
    # the truncated prompt).
    insert_loc_truncated = np.array([end_int - 1]) - start_int

    feature_expanded = feature[0, start_int:end_int].unsqueeze(0).expand(num_sent_gen, end_int - start_int).to(device=device_conditional)
    future_emb_chosen_topics = word_w_sum_norm[0, selected_topic_idx, :]
    # word_norm_emb may live on a different device than the conditional model:
    # index it on its own device, then move the selected rows over, so that
    # both the indexing and the concatenation below see a single device.
    future_emb_chosen_words = word_norm_emb[selected_word_idx.to(word_norm_emb.device), :].to(device_conditional)
    num_selection = future_emb_chosen_topics.size(0) + future_emb_chosen_words.size(0)
    future_emb_chosen = torch.cat([future_emb_chosen_topics, future_emb_chosen_words], dim=0).unsqueeze(0).expand(num_sent_gen, num_selection, word_norm_emb.size(-1))

    # Sampling is inference only; the conditional samples are drawn first.
    with torch.no_grad():
        output = utils_testing.sample_seq(model_condition, feature_expanded, insert_loc_truncated, [future_emb_chosen], gen_sent_len, device_conditional)
        output_org = utils_testing.sample_seq(model_condition, feature_expanded, None, None, gen_sent_len, device_conditional)

    topic_idx_list = selected_topic_idx.tolist()
    word_idx_list = selected_word_idx.tolist()
    for label, samples in (('conditional ', output), ('original ', output_org)):
        for j in range(num_sent_gen):
            tokens = [tokenizer_GPT2._convert_id_to_token(x) for x in samples[j, :].tolist()]
            generated_sent = tokenizer_GPT2.convert_tokens_to_string(tokens)
            utils_testing.print_sampled_sent(topic_idx_list, generated_sent, top_index[0, :, :], idx2word_freq, sys.stdout, label + str(j), word_idx_list)

In [7]:
# Prompt whose likely continuation topics are listed below.
prompt = "Google announces a new product"
# Other prompts tried:
#prompt = "Barack Obama writes a new book"
#prompt = "Barack Obama writes a new book on spirituality and the role of religion in society"
#prompt = "Barack Obama became the first African"
#prompt = "The magician curses the zombie"
#prompt = "Trump, the leader of the Republican"

# Inference only: run under no_grad so the returned tensors do not keep an
# autograd graph (and its GPU memory) alive for the following cells.
with torch.no_grad():
    top_value, top_index, feature = show_future_topics(prompt, parallel_encoder, parallel_decoder, word_norm_emb, args.n_basis, args.top_k_nn, args.bptt, idx2word_freq, tokenizer_GPT2, device_topics)



['ĠGoogle', 'Ġannounces', 'Ġa', 'Ġnew', 'Ġproduct']
0, development, design, innovation, designing, developing, 
1, interface, device, controller, configuration, system, 
2, websites, web, website, site, links, 
3, released, version, versions, release, releases, 
4, retailer, retailers, selling, brands, retail, 
5, 2011, 2010, 2012, 2009, 2008, 
6, provide, specific, use, allow, utilize, 
7, implementations, APIs, interoperate, backend, browsers, 
8, investment, investments, profitability, equity, financial, 
9, Software, Desktop, Apps, Networking, Windows, 



In [7]:
# Conditions that steer the generation: an integer selects one of the topics
# computed by show_future_topics, a string selects a vocabulary word.
# Other combinations tried:
#   [4,7,2,'happy']   [4,8,2,'happy']   ['zombie']   [2]   [9, 8]
selected_conditions = [6, 'story']

# Sampling sizes come from the notebook arguments
# (override here, e.g. gen_sent_len = 100, for longer samples).
gen_sent_len, num_sent_gen = args.gen_sent_len, args.num_sent_gen

conditional_generation(
    selected_conditions=selected_conditions,
    gen_sent_len=gen_sent_len,
    num_sent_gen=num_sent_gen,
    word_d2_idx=word_d2_idx,
    idx2word_freq=idx2word_freq,
    model_condition=model_condition,
    word_norm_emb=word_norm_emb,
    top_index=top_index,
    top_value=top_value,
    feature=feature,
    bptt_conditional=args.bptt_conditional,
    tokenizer_GPT2=tokenizer_GPT2,
    device_conditional=device_conditional,
)


conditional 0:  head coach to head a professional NFL franchise. Former Philadelphia Eagles head coach Mike Shanahan [31mresigned[0m, following an attempt to sign a book signing deal with a [31mformer[0m Colts team. A new [31mstory[0m has it that one of the youngest players ever signed was a young college
6 topic: {'former': 1, 'resigned': 1}
word: {'story': 1}

conditional 1:  U.S. Senator from Massachusetts to be reelected. The [31mstory[0m tells the [31mstory[0m of two women who are friends and confidants in the United States Senate. A [31mformer[0m friend, Senator Hillary Clinton, also [31mjoined[0m the Senate in 2001. In the [31mstory[0m
6 topic: {'former': 1, 'joined': 1}
word: {'story': 3}

conditional 2:  to win a general election election     and the first African-American to win election as a Democratic President   The [31mstory[0m featured several other  black  and  African American  figures in the Democratic Party, but mostly [31mretired[0m or turned ove