In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
import numpy as np
import torch
import torch.nn as nn
import time
import math

from utils import seed_all_randomness, load_corpus, str2bool
import utils_testing

import torch.nn.functional as F
from gpt2_model.tokenization_gpt2 import GPT2Tokenizer
from gpt2_model.modeling_gpt2_multi import GPT2MultiLMHeadModel, GPT2MoSLMHeadModel, GPT2LMHeadModel, GPT2Model
from gpt2_model.configuration_gpt2 import GPT2Config
import argparse



In [2]:
# Notebook-wide argument parser; the shared options and the per-model options
# are registered on it further down and then parsed from a hard-coded string.
parser = argparse.ArgumentParser(
    description='PyTorch Train Future Topic Prediction',
)

def add_model_options(parser, suffix):
    """Register one model's loading/architecture options on ``parser``.

    Every option is named ``--<name><suffix>``, so the same set of options can
    be registered several times on one parser (e.g. once per model being
    compared) without the names clashing.

    Args:
        parser: ``argparse.ArgumentParser`` to extend (anything exposing
            ``add_argument`` works).
        suffix: String appended to every option name, e.g. ``"_multi"``.
    """
    def add(name, **kwargs):
        # Single place where the suffix is attached to the option name.
        parser.add_argument('--' + name + suffix, **kwargs)

    def add_bool(name, default, help_text):
        # With nargs='?' argparse stores `const` when the flag is given without
        # a value; its default const is None, so a bare `--flag` would silently
        # become None. const=True makes a bare flag mean "enabled".
        add(name, type=str2bool, nargs='?', const=True, default=default, help=help_text)

    add('model_path', type=str, default='./models/',
        help='path to load the model')
    add('load_file_name', type=str, default='LM_weights.pt',
        help='file name of saved model')
    add('n_facet', type=int, default=1,
        help='number of facets')
    add('n_facet_hidden', type=int, default=1,
        help='number of hidden states')
    add('n_facet_MLP', type=int, default=0,
        help='size of compression layer')
    add('n_facet_window', type=int, default=0,
        help='size of windows we look at')
    add('n_facet_effective', type=int, default=1,
        help='number of facet heads')

    add_bool('use_avg', False,
             'Whether we want to add an average embedding term to stablize the training')
    add_bool('use_MoS', True,
             'Whether we want to do the normalization for each facet (i.e., use mixture of softmax)')
    # Help text previously spelled the accepted value as "statis".
    add('weight_mode', type=str, default='dynamic',
        help='could be empty, dynamic, and static')
    add_bool('use_proj_bias', True,
             'Whether we want to add an bias term in the linear projection layer')
    add('efficient_mode', type=str, default='None',
        help='how to save computational time')
    add('masking_ratio', type=float, default=-1,
        help='dynamically use single facets. Use -1 to turn off this efficient mode')
    add('last_num', type=int, default=0,
        help='number of facet that does not have multiple partitions')
    
    

# parser.add_argument('--data', type=str, default='./data/processed/wiki2016_gpt2/',
#                     help='location of the data corpus')
# parser.add_argument('--tensor_folder', type=str, default='tensors_all_min100',
#                     help='location of the data corpus')
# parser.add_argument('--model_path_multi', type=str,  default='./models/',
#                     help='path to load the multi-facet model')
# parser.add_argument('--model_path_single', type=str,  default='./models/',
#                     help='path to load the single-facet model')
# --- Options shared by both models -------------------------------------------
parser.add_argument('--outf', type=str, default='gen_log/generated.txt',
                    help='output file for generated text')

parser.add_argument('--batch_size', type=int, default=4, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=256,
                    help='sequence length')
parser.add_argument('--max_batch_num', type=int, default=100,
                    help='number of batches for evaluation')
parser.add_argument('--num_sent_gen', type=int, default=3, metavar='N',
                    help='In each prompt, generate how many sentences')
parser.add_argument('--gen_sent_len', type=int, default=50, metavar='N',
                    help='In each prompt, generate sentences with length gen_sent_len')

# For the two boolean options below, nargs='?' stores `const` when the flag is
# given without a value. argparse's default const is None, so const=True is set
# explicitly to make a bare `--run_eval` / `--cuda` mean True.
parser.add_argument('--run_eval', type=str2bool, nargs='?', const=True, default=True,
                    help='If false, we only print the results')

parser.add_argument('--cuda', type=str2bool, nargs='?', const=True, default=True,
                    help='use CUDA')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')

# --- Per-model options: one full set for each of the two compared models -----
add_model_options(parser, "_multi")
add_model_options(parser, "_single")


def load_model(model_path, load_file_name, gpt2_config, n_facet, n_facet_window, n_facet_hidden, n_facet_MLP, n_facet_effective, weight_mode, use_avg, use_MoS, use_proj_bias, efficient_mode, last_num, device, cuda):
    """Build a GPT-2 LM head model and load its saved weights.

    Args:
        model_path: Directory containing the saved state dict.
        load_file_name: File name of the state dict inside ``model_path``.
        gpt2_config: Config passed to ``GPT2Model`` and to the LM head model.
        n_facet, n_facet_hidden: Forwarded to whichever LM head model is built.
        n_facet_window, n_facet_MLP, n_facet_effective, weight_mode,
        use_proj_bias, efficient_mode, last_num: Forwarded to
            ``GPT2MoSLMHeadModel``; unused when ``use_MoS`` is false.
        use_avg: Forwarded to ``GPT2MultiLMHeadModel``; unused when
            ``use_MoS`` is true.
        use_MoS: Selects ``GPT2MoSLMHeadModel`` (true) or
            ``GPT2MultiLMHeadModel`` (false).
        device: ``map_location`` for ``torch.load``; also passed to
            ``GPT2MoSLMHeadModel``.
        cuda: If true, the loaded model is moved to the GPU with ``.cuda()``.

    Returns:
        The constructed model with the state dict loaded.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    LM_state_dict = torch.load(os.path.join(model_path, load_file_name), map_location=device)
    GPT2_encoder = GPT2Model(gpt2_config)
    if use_MoS:
        GPT2_LM = GPT2MoSLMHeadModel(gpt2_config, GPT2_encoder, n_facet, n_facet_hidden, weight_mode, use_proj_bias,
                                     n_facet_window=n_facet_window, n_facet_MLP=n_facet_MLP, efficient_mode=efficient_mode,
                                     device=device, n_facet_effective_in=n_facet_effective, last_num=last_num)
    else:
        GPT2_LM = GPT2MultiLMHeadModel(gpt2_config, GPT2_encoder, n_facet, n_facet_hidden, use_avg)
    GPT2_LM.load_state_dict(LM_state_dict)
    # Use the `cuda` argument rather than the notebook-global `args`, so the
    # function behaves as its signature says and works without `args` in scope.
    if cuda:
        GPT2_LM = GPT2_LM.cuda()
    return GPT2_LM




In [7]:
#args_str_list = ("--model_path_multi ./models/gpt2_wiki_n3_init-20210320-145419").split()
#args_str_list = ("--model_path_multi ../models/gpt2_wiki_n3_init-20210218-133219 --load_file_name_multi LM_weights_8.pt --n_facet_multi 3 --n_facet_effective_multi 3 --n_facet_window_multi -2 --n_facet_hidden_multi 3 --n_facet_MLP_multi -1" +

# Active configuration: checkpoint locations and facet settings for the
# multi-facet and single-facet models, written as a command line and split on whitespace.
args_str_list = (
    "--model_path_multi ../models/gpt2_wiki_n3_init-20210320-145419 --load_file_name_multi LM_weights_8.pt --n_facet_multi 6 --n_facet_effective_multi 3 --n_facet_window_multi -2 --n_facet_hidden_multi 3 --n_facet_MLP_multi -1 --efficient_mode_multi even_last_2"
    " --model_path_single ../models/gpt2_wiki_n1_init-20210316-011244 --load_file_name_single LM_weights_8.pt --n_facet_single 1 --n_facet_window_single -2 --n_facet_hidden_single 3 --n_facet_MLP_single -1"
    " --seed 11 --cuda True"
).split()

# Alternative configuration, kept for reference:
#args_str_list = ("--model_path_multi ../models/gpt2_wiki_n3_init-20210224-004334 --load_file_name_multi LM_weights_8.pt --n_facet_multi 3 --n_facet_effective_multi 3 --n_facet_window_multi -2 --n_facet_hidden_multi 3 --n_facet_MLP_multi -1" +
#                         " --model_path_single ../models/gpt2_wiki_n1_init-20210224-004343 --load_file_name_single LM_weights_8.pt --n_facet_single 1 --n_facet_window_single -2 --n_facet_hidden_single 3 --n_facet_MLP_single -1" +
#                          " --seed 11 --cuda True").split()

# Name of the pretrained configuration to fetch.
#model_name = 'gpt2-medium'
model_name = 'gpt2'

args = parser.parse_args(args_str_list)

# Set the random seed manually for reproducibility.
seed_all_randomness(args.seed, args.cuda)

print('Args: {}'.format(args))

device = torch.device("cuda") if args.cuda else torch.device("cpu")

gpt2_config = GPT2Config.from_pretrained(model_name)
gpt2_config.output_hidden_states = True

# Both models share the config; everything else comes from the matching
# `_multi` / `_single` options.
model_multi = load_model(
    model_path=args.model_path_multi,
    load_file_name=args.load_file_name_multi,
    gpt2_config=gpt2_config,
    n_facet=args.n_facet_multi,
    n_facet_window=args.n_facet_window_multi,
    n_facet_hidden=args.n_facet_hidden_multi,
    n_facet_MLP=args.n_facet_MLP_multi,
    n_facet_effective=args.n_facet_effective_multi,
    weight_mode=args.weight_mode_multi,
    use_avg=args.use_avg_multi,
    use_MoS=args.use_MoS_multi,
    use_proj_bias=args.use_proj_bias_multi,
    efficient_mode=args.efficient_mode_multi,
    last_num=args.last_num_multi,
    device=device,
    cuda=args.cuda,
)

model_single = load_model(
    model_path=args.model_path_single,
    load_file_name=args.load_file_name_single,
    gpt2_config=gpt2_config,
    n_facet=args.n_facet_single,
    n_facet_window=args.n_facet_window_single,
    n_facet_hidden=args.n_facet_hidden_single,
    n_facet_MLP=args.n_facet_MLP_single,
    n_facet_effective=args.n_facet_effective_single,
    weight_mode=args.weight_mode_single,
    use_avg=args.use_avg_single,
    use_MoS=args.use_MoS_single,
    use_proj_bias=args.use_proj_bias_single,
    efficient_mode=args.efficient_mode_single,
    last_num=args.last_num_single,
    device=device,
    cuda=args.cuda,
)

tokenizer_GPT2 = GPT2Tokenizer.from_pretrained('gpt2')

# Merge the base vocabulary with any added tokens, then invert it into a
# list indexed by token id so ids can be turned back into token strings.
vocab_map = {**tokenizer_GPT2.encoder, **tokenizer_GPT2.added_tokens_encoder}
vocab_size = len(vocab_map)
idxl2token = [''] * vocab_size
for w, idx in vocab_map.items():
    idxl2token[idx] = w
#print(idxl2token)

#idxl2token = tokenizer_GPT2.decode(torch.tensor(range()))

Args: Namespace(batch_size=4, bptt=256, cuda=True, efficient_mode_multi='even_last_2', efficient_mode_single='None', gen_sent_len=50, last_num_multi=0, last_num_single=0, load_file_name_multi='LM_weights_8.pt', load_file_name_single='LM_weights_8.pt', masking_ratio_multi=-1, masking_ratio_single=-1, max_batch_num=100, model_path_multi='../models/gpt2_wiki_n3_init-20210320-145419', model_path_single='../models/gpt2_wiki_n1_init-20210316-011244', n_facet_MLP_multi=-1, n_facet_MLP_single=-1, n_facet_effective_multi=3, n_facet_effective_single=1, n_facet_hidden_multi=3, n_facet_hidden_single=3, n_facet_multi=6, n_facet_single=1, n_facet_window_multi=-2, n_facet_window_single=-2, num_sent_gen=3, outf='gen_log/generated.txt', run_eval=True, seed=11, use_MoS_multi=True, use_MoS_single=True, use_avg_multi=False, use_avg_single=False, use_proj_bias_multi=True, use_proj_bias_single=True, weight_mode_multi='dynamic', weight_mode_single='dynamic')


In [10]:
def generate_next_word(model, input_token, top_k=10):
    """Return the ``top_k`` highest-scoring candidates at the last position.

    Args:
        model: Callable that, given ``input_token``, returns a 4-tuple whose
            first element ``outputs`` is indexable; only ``outputs[0]`` is
            used and it is indexed as ``[:, -1, :]`` (three dimensions).
        input_token: Tensor of token ids, passed to ``model`` unchanged.
        top_k: Number of candidates to return (default 10, the previously
            hard-coded value). Clamped to the size of the last dimension so a
            small vocabulary does not make ``torch.topk`` raise.

    Returns:
        Tuple ``(probs_top, index_top)``: the top values and their indices
        along the last dimension, as produced by ``torch.topk``.

    NOTE(review): the scores are used as-is, with no softmax applied here;
    presumably the model already returns probabilities. Confirm against the model.
    """
    # Pure inference: disable autograd so no graph is built or kept alive.
    with torch.no_grad():
        outputs, _emb_div, _count_best_arr, _weight = model(input_token)
        # Keep only the scores at the last index of dimension 1.
        probs = outputs[0][:, -1, :]
        k = min(top_k, probs.size(-1))
        probs_top, index_top = torch.topk(probs, k=k)
    return probs_top, index_top

def visualize_top_k_words(probs_top, index_top, idxl2token):
    """Print ranked candidates as ``rank  token  score`` lines.

    Args:
        probs_top: Tensor of scores whose last dimension holds the candidates
            (e.g. the values returned by ``torch.topk``).
        index_top: Tensor of token ids with the same shape as ``probs_top``.
        idxl2token: Sequence mapping a token id to its token string.

    Each row (all leading dimensions flattened) is printed as one ranked list
    followed by a blank line. For a single row the output is one list and one
    blank line.
    """
    # Reshape to (rows, k) instead of squeeze(): squeeze() turns a single
    # candidate into a scalar (no len()) and leaves a batch as nested lists
    # that cannot index idxl2token.
    prob_rows = probs_top.reshape(-1, probs_top.shape[-1]).tolist()
    index_rows = index_top.reshape(-1, index_top.shape[-1]).tolist()
    for prob_row, index_row in zip(prob_rows, index_rows):
        for rank, (token_id, prob) in enumerate(zip(index_row, prob_row)):
            print("{}  {}  {}".format(rank, idxl2token[token_id], prob))
        print("")

In [17]:
#tokenizer_GPT2.encode

# Switch both models to eval mode before comparing their predictions.
for _lm in (model_multi, model_single):
    _lm.eval()

# --- Prompts tried with the medium model --------------------------------------
#prompt = "There are a queen and a man in front of me, and I talk to the"
#prompt = "There are an uncle and a woman in front of me, and I talk to the"
#prompt = "There are an aunt and a man in front of me, and I talk to the"
#prompt = "There are a niece and a woman in front of me, and I talk to the"
#prompt = "John, Kathryn and Mary are in their home, and John is watching TV with "
#prompt = "John, Kathryn and Mary are in their home, and Mary is watching TV with "

# --- Prompts tried with the small model (the uncommented one is active) -------
prompt = "There are an uncle and a woman in front of me, and I talk to the"
#prompt = "John, Jenny and Mary are in their home, and John is watching TV with "
#prompt = "There are plates and balloons in front of me, and I pick up the"

#prompt = "Obama plans to visit Beijing and Russian, and his flight first arrives at"
#prompt = "There are a few children and a dog in front of me, and I choose to play with the"
#prompt = "There are a few cats and a dog in front of me, and I choose to play with the"
#prompt = "There are plates and a balloon in front of me, and I pick up the"
#prompt = "There are balls and a balloon in front of me, and I pick up the"
#prompt = "There are a king and a woman in front of me, and I talk to the"
#prompt = "There are a policeman and a bride in front of me, and I talk to the"
#prompt = "There are a prince and a woman in front of me, and I talk to the"
#prompt = "There are a prince and women in front of me, and I talk to the"
#prompt = "There are a man and a princess in front of me, and I talk to the"
#prompt = "There are a prince and a bride in front of me, and I talk to the"
#prompt = "There are a son and a woman in front of me, and I talk to the"
#prompt = "There are a sister and a man in front of me, and I talk to the"
#prompt = "There are the poor and the quiet in front of me, and I talk to the"
#prompt = "There are the rich and the loud in front of me, and I talk to the"

#prompt = "John and Mary are in their home, and a bullet hits"
#prompt = "A king and a woman are in their home, and the"
#prompt = "A king and a woman go to a park, and a dog attacks one of them, who is"
#prompt = "Obama plans to visit Paris and England, and his flight first arrives at"
#prompt = "I went to Paris and England before, and I love one of the places more, which is"
#prompt = "I went to France and London before, and I love one of the places more, which is"
#prompt = "I went to Boston and California before, and I love one of the places more, which is"
#prompt = "I went to Paris and China before, and I love one of the places more, which is"
#prompt = "Obama plans to visit Greece and Baghdad, and his flight first arrives at"
#prompt = "I went to Greece and Baghdad before, and I love one of the places more, which is"

# Tokenize the prompt, map the tokens to ids, and add a leading dimension of
# size 1 before handing the ids to the models.
tokenized_text = tokenizer_GPT2.tokenize(prompt, add_prefix_space=True)
indexed_tokens = tokenizer_GPT2.convert_tokens_to_ids(tokenized_text)
indexed_tokens_tensor = torch.tensor(
    indexed_tokens, device=device, dtype=torch.long
).unsqueeze(0)

print(prompt)
print(indexed_tokens_tensor)

# Same prompt through both models, multi-facet first, each followed by its
# ranked candidate list.
for _label, _lm in (("Multi-facet", model_multi), ("Single-facet", model_single)):
    probs_top, index_top = generate_next_word(_lm, indexed_tokens_tensor)
    print(_label)
    visualize_top_k_words(probs_top, index_top, idxl2token)

There are an uncle and a woman in front of me, and I talk to the
tensor([[1318,  389,  281, 7711,  290,  257, 2415,  287, 2166,  286,  502,   11,
          290,  314, 1561,  284,  262]], device='cuda:0')
Multi-facet
0  Ġuncle  0.14971046149730682
1  Ġwoman  0.11480355262756348
2  Ġman  0.08414516597986221
3  Ġtwo  0.027981584891676903
4  Ġother  0.027950972318649292
5  Ġgirl  0.02295503579080105
6  Ġmother  0.018764926120638847
7  Ġboy  0.01686326414346695
8  Ġlady  0.014466330409049988
9  Ġaunt  0.014079896733164787

Single-facet
0  Ġwoman  0.14697295427322388
1  Ġman  0.10963626205921173
2  Ġother  0.032980743795633316
3  Ġtwo  0.02860880084335804
4  Ġgirl  0.02445436269044876
5  Ġuncle  0.02117065340280533
6  Ġmother  0.018148234114050865
7  Ġlady  0.016263790428638458
8  Ġboy  0.013309802860021591
9  Ġpriest  0.011639346368610859

