In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria
# from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer

# Fetch the fine-tuned checkpoint (tokenizer + causal LM weights) from the Hugging Face Hub.
# The repository id is kept in one place so both loaders always point at the same checkpoint.
MODEL_NAME = "huangtuoyue/GPT2-large-GOTfinetuned_v5"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

Downloading:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

In [2]:
# Run on CUDA when torch reports a usable GPU; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Left as the cell's last expression so the notebook displays the model summary.
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50259, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout)

In [3]:
# Custom stopping criterion: ends generation once a given keyword token is produced
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is one of the given ids.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, keywords_ids:list):
        # Token ids whose appearance as the latest token ends generation.
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a keyword id."""
        newest_token = input_ids[0][-1]
        return newest_token in self.keywords

# ---- settings for context generation ----
# 1. stop as soon as [BOS] is produced -- Jon starts talking
#    (only the first token id of each encoded stop word is used)
stop_context_words = ['[BOS]']
stop_context_ids = [tokenizer.encode(word)[0] for word in stop_context_words]
stop_context_criteria = KeywordsStoppingCriteria(keywords_ids=stop_context_ids)
# 2. token sequences that must not be sampled while generating context (includes [EOS])
bad_context_words = ['[EOS]', ' He', ' She', ' he']
bad_context_ids = tokenizer(bad_context_words, add_special_tokens=False)["input_ids"]

# ---- settings for dialogue generation ----
# 1. stop as soon as [EOS] is produced -- Jon stops talking
#    (only the first token id of each encoded stop word is used)
stop_dialogue_words = ['[EOS]']
stop_dialogue_ids = [tokenizer.encode(word)[0] for word in stop_dialogue_words]
stop_dialogue_criteria = KeywordsStoppingCriteria(keywords_ids=stop_dialogue_ids)
# 2. token sequences that must not be sampled while generating dialogue (includes [BOS])
bad_dialogue_words = ['[BOS]', ' "', ' Jon', ' said', ' he', ' Snow', ' He', ' She', ' him']
bad_dialogue_ids = tokenizer(bad_dialogue_words, add_special_tokens=False)["input_ids"]

def generate(text, pred_dialogue):
    """Sample a continuation of ``text`` with the fine-tuned model.

    Args:
        text: Prompt string fed to the tokenizer.
        pred_dialogue: When truthy, use the dialogue settings (short output,
            stops on the dialogue stop ids, dialogue bad-words list). Otherwise
            use the context settings (longer output, stops on the context stop
            ids, context bad-words list).

    Returns:
        The first decoded sequence returned by ``model.generate``, decoded with
        special tokens kept so callers can look for ``[BOS]`` / ``[EOS]``.
    """
    encoded = tokenizer(text, return_tensors='pt')

    if pred_dialogue:
        # for dialogue generation
        mode_kwargs = dict(
            min_length=3, max_new_tokens=30,
            temperature=1.1, top_p=1, repetition_penalty=1,
            stopping_criteria=StoppingCriteriaList([stop_dialogue_criteria]),
            bad_words_ids=bad_dialogue_ids,
        )
    else:
        # for context generation
        # NOTE(review): `min_length` is documented as counting prompt tokens too, so a
        # long prompt may already satisfy it -- confirm whether `min_new_tokens` was meant.
        mode_kwargs = dict(
            min_length=30, max_new_tokens=80,
            temperature=0.95, top_p=1, repetition_penalty=1.1,
            stopping_criteria=StoppingCriteriaList([stop_context_criteria]),
            bad_words_ids=bad_context_ids,
        )

    outputs = model.generate(
        encoded.input_ids.to(device),
        # Pass the tokenizer's mask explicitly so generate() does not have to infer it.
        attention_mask=encoded.attention_mask.to(device),
        do_sample=True,
        pad_token_id=50256,  # hard-coded pad id; presumably the base end-of-text token -- TODO confirm
        **mode_kwargs,
    )

    # Decode outputs and keep special tokens. The keyword was previously misspelled
    # (`skip_speical_tokens`), so it never reached the decoder.
    res = tokenizer.batch_decode(outputs, skip_special_tokens=False)
    return res[0]

In [None]:
# Interactive story loop.
# Each round generates one piece of context (which must end with [BOS]), then
# collects three valid dialogue candidates (each ending with [EOS]) and asks the
# player to pick one. The transcript is written to ./play.txt.
ITERATION = 5  # play 5 iterations (one iteration = context + chosen dialogue)
BOS = "[BOS]"
EOS = "[EOS]"
MAX_PROMPT_WORDS = 50   # only the last 50 words of the story are used as the prompt
MIN_CONTEXT_WORDS = 55  # a context generation with this many words or fewer is rejected

# Counters are kept at module level so they can be inspected after the loop.
incorrect_generation = 0
total = 0
context = "Jon had just been told that he was going to take part in a special mission for the Lord Commander of the Night's Watch."
i = 0
options = []
pred_dialogue = False
detokenizer = TreebankWordDetokenizer()

# `with` guarantees the transcript is flushed and closed even if generation or
# input() raises; utf-8 avoids platform-dependent encoding errors on generated text.
with open("./play.txt", "w", encoding="utf-8") as f:
    print(context)
    f.write(context)

    while i < ITERATION:
        # Build the prompt from (at most) the last MAX_PROMPT_WORDS words of the story.
        context_list = context.split()
        run_context = detokenizer.detokenize(context_list[-MAX_PROMPT_WORDS:])
        last_idx = len(run_context)

        pred = generate(run_context, pred_dialogue)

        # count total generation time #
        total += 1

        # locate [BOS] & [EOS] markers in the generation #
        words = pred.split()
        bos_list = [j for j, word in enumerate(words) if word == BOS]
        eos_list = [j for j, word in enumerate(words) if word == EOS]

        # Validity for context generation #
        if not pred_dialogue:
            if len(words) <= MIN_CONTEXT_WORDS or words[-1] != BOS:
                # too short, or did not finish on [BOS]: retry
                incorrect_generation += 1
                continue

            # correct context: show only the newly generated text, without the trailing [BOS]
            context = pred
            new_text = context[last_idx:-len(BOS)]
            print(new_text)
            print("\n")
            f.write(new_text)
            f.write("\n")
            pred_dialogue = True
            continue

        # Validity for dialogue generation #
        if not words or words[-1] != EOS or not bos_list:
            # no closing [EOS] (or no [BOS] to anchor the line on): retry
            incorrect_generation += 1
            continue

        # correct dialogue: keep the words between the last [BOS] and the final [EOS]
        option = detokenizer.detokenize(words[bos_list[-1] + 1:eos_list[-1]])
        options.append(option)
        if len(options) < 3:
            continue

        print("\n")
        print("============= Select Dialogue ============")
        print("A: {} \nB: {} \nC: {}".format(options[0], options[1], options[2]))
        d = {"A": options[0], "B": options[1], "C": options[2]}

        # Re-prompt until a valid key is entered; an unchecked value would raise KeyError below.
        user_input = input().strip().upper()
        while user_input not in d:
            print("Invalid choice, please enter A, B or C.")
            user_input = input().strip().upper()

        print("Choice Selected: " + user_input)
        print("============= Context Continued ============")
        print("\n")

        f.write("\n")
        f.write("\n")
        f.write("A:{} \nB:{} \nC:{}".format(options[0], options[1], options[2]))
        f.write("\n\n")

        # Append the chosen line to the story and go back to context generation.
        context += (" " + d[user_input] + " [EOS]")
        options = []
        i += 1
        pred_dialogue = False