# Combine with HMM

In [6]:
import os
import json
import time
import argparse

from tqdm import tqdm
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor
from transformers import BeamSearchScorer, LogitsProcessorList, LogitsProcessor, StoppingCriteriaList, StoppingCriteria, MaxLengthCriteria

from HMM.hmm_model import *
from HMM.DFA_model import *
from model_utils import (
    encode_with_messages_format,
    get_prefix_suffix_tokens_for_HMM,
    get_sequence_scores,
    ConstraintLogitsProcessor
)



In [7]:

def init():
    """Parse the notebook's fixed settings and select the CUDA device.

    Stores the parsed namespace in the global ``args`` and the device string
    in the global ``device``, then makes GPU number ``--cuda_core`` the
    current CUDA device.
    """
    global device
    global CUDA_CORE
    global args

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ('--device', {'default': 'cuda', 'type': str}),
        ('--cuda_core', {'default': '1', 'type': str}),
        ('--hmm_batch_size', {'default': 256, 'type': int}),
        ('--hmm_model_path', {'default': None, 'type': str}),
        ('--llama_model_path', {'default': 'gpt2', 'type': str}),
        ('--do_beam_search', {'action': 'store_true'}),
        ('--debug', {'action': 'store_true'}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    # There is no real command line in the notebook, so the arguments are
    # spelled out here. Alternative checkpoints tried earlier:
    #   hmm:   /local1/hzhang19/matcha/models/hmm_llama-story-pretrain-finetune_4096_64/checkpoint-60.weight.th
    #   hmm:   /local1/hzhang19/matcha/models/hmm_llama-story-finetune_32768_64/checkpoint-60.weight.th
    #   llama: /local1/ponienkung/CtrlGen/output/NewFinetuneTULU_Filtered_StoryPretrain-TULU-LLAMA2
    # Old ones:
    #   hmm:   /local1/hzhang19/matcha/models/hmm_llama-story-cont-para_32768_64/checkpoint-90.weight.th
    #   llama: /local1/ponienkung/CtrlGen/output/NewFinetune_cont_para_2K_8K_2K_StoryPretrain-TULU-LLAMA2
    hmm_checkpoint = "/local1/hzhang19/matcha/models/hmm_llama-story-pretrain-finetune_32768_64/checkpoint-50.weight.th"
    llama_checkpoint = "/local1/ponienkung/CtrlGen/output/NewFinetunePretrained_Filtered_StoryPretrain-TULU-LLAMA2"
    fixed_argv = [
        "--hmm_model_path", hmm_checkpoint,
        "--llama_model_path", llama_checkpoint,
        "--hmm_batch_size", "2",
        "--cuda_core", "7",
    ]

    args = parser.parse_args(fixed_argv)
    device = args.device
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_core
    gpu_index = int(args.cuda_core)
    torch.cuda.set_device(gpu_index)

def load_models():
    """Load the LLaMA model, its tokenizer, the HMM and the DFA builders.

    Reads the globals ``args`` and ``device`` and publishes the loaded
    objects as the globals ``llama_model``, ``tokenizer``, ``hmm_model``,
    ``keyphrase_builder``, ``end_sentence_builder``, ``word_count_builder``,
    ``trivial_builder`` and ``eos_builder``.

    Raises:
        SystemExit: with status 1 when any loading step fails; the original
            exception is reported and chained as the cause.
    """
    global tokenizer
    global llama_model
    global hmm_model
    global keyphrase_builder
    global end_sentence_builder
    global word_count_builder
    global trivial_builder
    global eos_builder

    # Second argument given to every DFA builder.
    # NOTE(review): presumably the LLaMA vocabulary size -- confirm against
    # the builder signatures in HMM.DFA_model.
    vocab_size = 32000

    try:
        print(f'loading llama2 from {args.llama_model_path} ...')
        llama_model = LlamaForCausalLM.from_pretrained(args.llama_model_path).to(device)
        llama_model.half()
        tokenizer = LlamaTokenizer.from_pretrained(args.llama_model_path)

        print(f'loading hmm from {args.hmm_model_path} ...')
        hmm_model = HMM(args.hmm_model_path)
        hmm_model.to(device)

        print(f'constructing DFA builders ...')
        keyphrase_builder = KeyphraseBuilder(tokenizer, vocab_size)
        end_sentence_builder = EndSentenceBuilder(tokenizer, vocab_size)
        word_count_builder = WordCountBuilder(tokenizer, vocab_size)
        trivial_builder = TrivialBuilder(tokenizer, vocab_size)
        eos_builder = EOSBuilder(tokenizer, vocab_size)
    except Exception as e:
        # Name only attributes the parser defines: the previous message read
        # args.model_name_or_path, which does not exist, so the handler itself
        # raised AttributeError and hid the real failure.
        print(
            f"Cannot load models (llama: {args.llama_model_path}, "
            f"hmm: {args.hmm_model_path}) because of the following exception:\n {e}"
        )
        print("Exit the process...")
        # A non-zero status so a failed load is not reported as success.
        raise SystemExit(1) from e


In [8]:
# Parse the fixed settings first: load_models() reads the globals
# `args` and `device` that init() sets.
init()
load_models()

loading llama2 from /local1/ponienkung/CtrlGen/output/NewFinetunePretrained_Filtered_StoryPretrain-TULU-LLAMA2 ...


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it]


loading hmm from /local1/hzhang19/matcha/models/hmm_llama-story-pretrain-finetune_32768_64/checkpoint-50.weight.th ...
constructing DFA builders ...


In [14]:
# input_json = {
#     'Prefix': 'Once upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough, ', 
#     'Suffix': 'a story about the 92nd little pig.', 
#     'Instruct': '', 
#     'Prior': '', 
#     'Operation': 'Continuation', 
#     'temperature': 0.95, 
#     'num_return_sequences': 10, 
#     'num_beams': 10, 
#     'no_repeat_ngram_size': 2, 
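#     'top_p': 1.0,
#     'max_tokens': 50, --> replaced by 'token_constraint': [min, max]
# }
# The constraints must also be present in input_json (keys spelled as the code reads them):
# 'word_contraint': [min_words, max_words], or [] for no word-count constraint
# 'keyword_constraint': list of keyphrases, or [] for no keyword constraint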

def prompt(input_json):
    """Generate candidates for one request and score them.

    Uses the notebook globals ``tokenizer``, ``llama_model``, ``hmm_model``,
    ``args``, ``device`` and the DFA builders.

    Args:
        input_json: dict holding
            - text fields ``Prefix``, ``Prior``, ``Suffix``, ``Instruct``
              and ``Operation``;
            - ``token_constraint``: ``[min, max]``; only the max is read;
            - ``word_contraint`` (key spelled this way): ``[min, max]`` word
              count, or empty for none;
            - ``keyword_constraint``: keyphrases, or empty for none;
            - generation settings ``temperature``, ``num_return_sequences``,
              ``num_beams``, ``no_repeat_ngram_size`` and ``top_p``.

    Returns:
        dict with ``beam_outputs_texts`` (the decoded generations) and
        ``beam_outputs_sequences_scores`` (the value returned by
        ``get_sequence_scores``).
    """
    # Get the text and operation
    Prefix, Prior, Suffix, Instruct, Operation = (
        input_json['Prefix'], input_json['Prior'], input_json['Suffix'],
        input_json['Instruct'], input_json['Operation'])
    Prefix = Prefix.rstrip(" ")
    # Get the constraints
    token_constraint, word_contraint, keyword_constraint = (
        input_json["token_constraint"], input_json["word_contraint"],
        input_json["keyword_constraint"])
    max_tokens = token_constraint[1]
    # Get generation config
    temperature = input_json['temperature']
    # num_return_sequences and top_p are required keys but are not used below.
    num_return_sequences = input_json['num_return_sequences']
    num_beams = input_json['num_beams']
    no_repeat_ngram_size = input_json['no_repeat_ngram_size']
    top_p = input_json['top_p']

    # Truthiness (rather than `!= []`) also treats an empty tuple or None as
    # "no constraint" instead of failing on the index below.
    has_word_constraint = bool(word_contraint)
    has_keyword_constraint = bool(keyword_constraint)

    # TODO
    # Make sure the token budget can hold the requested number of words
    # (1.5 tokens per word).
    if has_word_constraint:
        max_tokens = max(max_tokens, int(1.5 * word_contraint[1]))

    # Get prefix, suffix tokens for HMM
    prefix_tokens, suffix_tokens = get_prefix_suffix_tokens_for_HMM(Prefix, Suffix, tokenizer)

    # Construct DFA graph
    # dfa_graphs = [eos_builder.build()]
    dfa_graphs = []
    has_constraints = False
    if has_keyword_constraint:
        print("Build Keyword")
        dfa_graphs.append(keyphrase_builder.build(keyword_constraint))
        has_constraints = True
    if has_word_constraint:
        print("Build Word Length")
        dfa_graphs.append(word_count_builder.build(word_contraint[0], word_contraint[1]))
        has_constraints = True
    # MODIFIED: Comment out this
    if (Suffix == '') and has_constraints:
        dfa_graphs.append(end_sentence_builder.build())
    if dfa_graphs:
        dfa_model = DFAModel(DFA_prod(dfa_graphs, mode='intersection'))
    else:
        dfa_model = DFAModel(trivial_builder.build())
    # Freeform continuation (no constraints and no suffix) skips the HMM.
    USE_HMM = has_constraints or Suffix != ''

    # Get input_ids
    prompt_tokens = encode_with_messages_format(
        Prefix = Prefix,
        SoftControl = Instruct,
        Prior = Prior,
        Suffix = Suffix,
        tokenizer = tokenizer,
        operation = Operation
        )

    # One copy of the prompt per beam / sample.
    input_ids = torch.tensor([prompt_tokens] * num_beams, device=device)
    max_length = len(prompt_tokens) + max_tokens
    print(f"max_length: {max_length}, max_tokens: {max_tokens}")

    with torch.no_grad():
        # Run everything but the last prompt token once and hand the cache to
        # the generation call.
        past_key_values = llama_model(input_ids[:, :-1], return_dict=True).past_key_values

        model_kwargs = {
            'past_key_values': past_key_values,
        }

        hmm_model.initialize_cache(prefix_tokens, suffix_tokens,
            [(1, max_tokens)], dfa_model)

        hmm_config = {
            'hmm_prompt_len': len(prompt_tokens),
            'hmm_prefix': prefix_tokens,
            'hmm_suffix': suffix_tokens,
            'hmm_generation_offset': len(prefix_tokens),
            'hmm_min_tokens': 1,
            'hmm_max_tokens': max_tokens,
            'hmm_batch_size': args.hmm_batch_size,
        }

        # Init logits processors
        stopping_criteria = StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length)])
        logits_processor = LogitsProcessorList([
            ConstraintLogitsProcessor(hmm_model, hmm_config, temperature=temperature)] if USE_HMM else [])
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
        print("Logits Processor: ", logits_processor)
        # If use beamsearch
        if args.do_beam_search:
            print("Do beamsearch")
            beam_scorer = BeamSearchScorer(
                batch_size=1,
                num_beams=num_beams,
                num_beam_hyps_to_keep=num_beams,
                device=llama_model.device,
            )
            outputs = llama_model.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                **model_kwargs
            )
        else:
            outputs = llama_model.sample(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                **model_kwargs
            )

        # clean up output
        sequences = outputs.tolist()
        output_ids = []
        sequence_ids = []
        logits_mask = []
        for seq in sequences:
            seq = seq[len(prompt_tokens):]
            # Drop trailing id 0 (the value this function pads with below),
            # then trailing id 2. The `seq and` guards keep a generation that
            # consists only of those ids from raising IndexError.
            while seq and seq[-1] == 0:
                seq = seq[:-1]
            while seq and seq[-1] == 2:
                seq = seq[:-1]
            # for i in range(min(len(suffix_tokens), len(seq)), 0, -1):
            #     if seq[-i:] == list(suffix_tokens[:i]):
            #         seq = seq[:-i]
            #         break
            output_ids.append(seq)
            sequence_ids.append(list(prompt_tokens) + list(seq) + list(suffix_tokens))
            # after_prefix = len(seq) + len(suffix_tokens)
            # check_logits_len = min(after_prefix, len(seq) * 2 + 10)
            # logits_mask.append([0.0] * len(prompt_tokens) + [1.0] * check_logits_len + [0.0] * (after_prefix - check_logits_len))
            # Only the first suffix token (if there is one) is marked 1.0.
            check_logits_len = min(1, len(suffix_tokens))
            # check_logits_len = len(suffix_tokens)
            logits_mask.append([0.0] * (len(prompt_tokens) + len(seq)) + [1.0] * check_logits_len + [0.0] * (len(suffix_tokens)-check_logits_len))
        # Right-pad ids with 0 and masks with 0.0 to a common length.
        max_len = max([len(x) for x in sequence_ids])
        sequence_ids = [x + [0] * (max_len - len(x)) for x in sequence_ids]
        # print(sequence_ids)
        logits_mask = [x + [0.0] * (max_len - len(x)) for x in logits_mask]
        sequence_ids = torch.tensor(sequence_ids, device=device)
        logits_mask = torch.tensor(logits_mask, device=device)

        # Get returns
        sequences_scores = get_sequence_scores(llama_model, sequence_ids, logits_mask)
        outputs_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        if args.debug:
            top_k = 64
            print(f"-------------------- Top-${top_k} -----------------------")
            # Highest score first.
            sequence_rank = torch.argsort(sequences_scores, descending=True).tolist()
            for sequence_idx in sequence_rank[:top_k]:
                print(f"#{sequence_idx}: {sequences_scores[sequence_idx]}", outputs_texts[sequence_idx])
                print(output_ids[sequence_idx])

        # This will return it as text. You can further use json.loads(r.text) to retrieve the results
        return {
            "beam_outputs_texts": outputs_texts,
            "beam_outputs_sequences_scores": sequences_scores,
        }
    


In [15]:
# input_json = {
#     'Prefix': 'Once upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough, ', 
#     'Suffix': 'a story about the 92nd little pig.', 
#     'Instruct': '', 
#     'Prior': '', 
#     'Operation': 'Continuation', 
#     'max_tokens': 50, 
#     'temperature': 0.95, 
#     'num_return_sequences': 10, 
#     'num_beams': 10, 
#     'no_repeat_ngram_size': 2, 
#     'top_p': 1.0
# }

# Default generation settings; the per-example fields are merged in later
# with input_json.update(...).
input_json = dict(
    temperature=0.8,
    num_return_sequences=5,
    num_beams=64,
    no_repeat_ngram_size=-1,
    top_p=1.0,
    token_constraint=[1, 64],
)

# Insertion between the pig-story prefix and suffix with a 10-20 word
# constraint. Marked "ERROR" by the author.
example_0 = dict(
    Prefix='Once upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough, ',
    Suffix='a story about the 92nd little pig.',
    Instruct='',
    Prior='',
    Operation='Insertion',
    word_contraint=[10, 20],
    keyword_constraint=[],
)

# Same pig-story insertion, but with no word-count or keyword constraint.
# Marked "ERROR" by the author.
example_0a = dict(
    Prefix='Once upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough, ',
    Suffix='a story about the 92nd little pig.',
    Instruct='',
    Prior='',
    Operation='Insertion',
    word_contraint=[],
    keyword_constraint=[],
)

# Continuation of the cinema / reincarnation story with a soft instruction
# and a truncated suffix. The full-length suffix is kept below for reference.
example_1 = dict(
    Prefix="When you die, you appear in a cinema with a number of other people who look like you. You find out that they are your previous reincarnations, and soon you all begin watching your next life on the big screen.\n\nYou are seated at the head of a table in a crowded conference room.",
    # 'Suffix': " He is a film director, and he has a movie he wants to make. He says he wants to make a movie of your next life.\n\n\"Just wait a minute,\" you say, \"I am watching your movie right now.\"\n\n\"That's because must have agreed to make it and you are watching the result,\" he says.\n\"I don't remember agreeing to make a movie of my next life,\" you say.\n \"Of coyrse not,\" replied the man, \" that is not how these things work.\"\n\"How do they work?\" you ask.\n\n\n\"Youe brain is erased between lives,\" says the man. \"Therefore, you don't remember making the movie.\"\n\n\"But this is supposed to be my future, not my past,\" you say.\n\"It is your future,\" says the man, \"but it is also your past.\n\"\n\n\"How can that be?,\" you reply.\n\"It is a paradox,\" says the man.\n\n\n\"Paradox? You means its a scmm, don't you?\" you reply.\n\"No, no,\" says the man, \"it is not a scam.\n Well, maybe its a scam, I forget.\"\n\"You mean you don't know whether it is a scam or not?\" you ask.\n\n\n\"I mean I don't remember making this movie at all,\" says the man.\n\"But you are the director,\" you say.\n\n\"I am the director,\" says the man, \"but I don't remember making this movie.\n\"\n\n\"Are you an idiot?\" you ask the man.\n\"Yes,\" says the man, \"I am an idiot.\" But I'm also you in your next life.\n\n\"That's impossible,\" you say.\n\n\n\"If it were impossible, could I do this?\" asks the man, and he steps out of the screen and takes a seat next to you.\n\"But I'm not dead,\" you say.\n\n\n\" You, me... we all are,\" says the man. \"Always will be too.\"\n\"I don't understand,\" you say.\n\n\nAt that, the man walks right up to you and stands chett to chest. Suddenly you are the man, and he is you.\n\"Do you understand now?\" you ask.\n\n\n\"Yes,\" you answer, and take your seat to watch the rest of the movie.",
    Suffix=" a film director, and he has a movie he wants to make. He says he wants to make a movie of your next life.\n\n\"Just wait a minute,\" you say, \"I am watching your movie right now.",
    Instruct=' and talk about teanage mutant ninja turtles',
    # 'Instruct': '',
    Prior='',
    Operation='Continuation',
    word_contraint=[],
    keyword_constraint=[],
)

# Insertion into the Matt Damon story that must use the keyphrases "agree"
# and "prince". The full-length suffix is kept as a comment; its second
# physical line is now commented too, so the literal parses.
example_2 = {
    'Prefix': 'An alien has kidnapped Matt Damon, not knowing what lengths humanity goes through to retrieve him whenever he goes missing.',
    # 'Suffix': " They reveal that the reason aliens cannot speak English is because their mouths are not capable of pronouncing vowels, and as a result, they employ humans to speak for them.\n\nMatt Damon is shocked by all of this and obviously doesn't want to be involved with aliens. He tells that aliens that he wants to be traded back for the alien prince so he can return to the United States.\n\nThe aliens explain that there is a possibility of that but it's going to cost the United States a big time price. The aliens work out a deal and tell Matt Damon that they want to live in the United States. They don't want to live with humans and instead want to live in their own community.\n\nMatt Damon has no idea how he can make that happen but again, Damon is loved by the United States and all of it's people. He calls the President of the United States and the President tells him there is no possible place for the aliens to go. The country has 50 states and humans inhabit all of those states.\n\nDamon tells the aliens the bad news and they explain to him that he's stuck with them. The interesting thing is that as Damon continued to live with the aliens, the more he started to enjoy it.\n\nThe alien world didn't treat him like a famous actor. He didn't have to deal with the constant stress and scrutiny of being a Hollywood star. The aliens treated him well. He explained to the aliens how great their planet was. But he was very upset about one thing. He wanted to see his family. \n\nHe asked the aliens if there was any possibility that he could go back to the United States so he could see his family. They told him that he couldn't but they could abduct his family and bring them to their home planet.\n\nDamon thought long and hard about that and accepted the offer. The aliens ended up abducting Matt Damon's family and brought them to him. They were extremely happy to see him but also wanted to go home.
    # The issue is that they were stuck there like Damon as well.\n\nBut just like Matt, the more time they spent on the aliens planet, the more they enjoyed it. They were treated extremely well and had fun learning about the alien's culture and all of the technology they had access to. They're stuck on that planet to this day, but they enjoy living there much more than they thought they would.",
    'Suffix': " They reveal that the reason aliens cannot speak English is because their mouths are not capable of pronouncing vowels, and as a result, they employ humans to speak for them.\n\n",
    'Instruct': '',
    'Prior': '',
    'Operation': 'Insertion',
    'word_contraint': [],
    'keyword_constraint': [
        "agree",
        "prince"],
}

# Request used for this run: a free continuation of a short prefix with a
# 20-25 word constraint. Its settings override the defaults in input_json.
test_example = {
    'Prefix': 'Once upon a time ...? ',
    'Suffix': '',
    'Prior': '',
    'Instruct': '',
    'word_contraint': (20, 25),
    'keyword_constraint': [],
    'temperature': 0.95,
    'num_return_sequences': 5,
    'num_beams': 16,
    'no_repeat_ngram_size': -1,
    'top_p': 1.0,
    'token_constraint': [1, 50],
    'Operation': "Continuation",
}

input_json.update(test_example)

# Sample rather than beam-search, and print the ranked candidates.
args.debug = True
args.do_beam_search = False
prompt(input_json)
print("-------")



Build Word Length
max_length: 78, max_tokens: 50
Logits Processor:  [<model_utils.ConstraintLogitsProcessor object at 0x7f1218210f50>]
-------------------- Top-$64 -----------------------
#14: -0.037928611040115356 Once upon a time in the state of California.
[9038, 2501, 263, 931, 297, 278, 2106, 310, 8046, 29889]
#13: -0.06619183719158173 She was alone, distraught and hopeless.
[2296, 471, 7432, 29892, 1320, 336, 688, 400, 322, 8171, 6393, 29889]
#11: -0.06985049694776535 Long ago, in a land far, far away...**In the time of old** *******In most mystical boundless lands.******* ****Within this tale now.
[6242, 8020, 29892, 297, 263, 2982, 2215, 29892, 2215, 3448, 856, 1068, 797, 278, 931, 310, 2030, 1068, 334, 2328, 1068, 797, 1556, 16624, 936, 3216, 2222, 12625, 29889, 2328, 17435, 334, 17435, 3047, 262, 445, 17694, 1286, 29889]
#12: -0.07931356877088547 Once upon a time, a Bad Fairy named Draca did indeed live in a land that was known as Storyville.
[9038, 2501, 263, 931, 29892, 263

In [5]:
# Scratch checks: decode a few token-id sequences back to text and count the
# tokens of a short sentence.
prefix_ids = (1932, 366, 762, 29892, 366, 2615, 297, 263, 24615, 411, 263, 1353, 310, 916, 2305, 1058, 1106, 763, 366, 29889, 887, 1284, 714, 393, 896, 526, 596, 3517, 337, 3742, 2753, 800, 29892, 322, 4720, 366, 599, 3380, 21217, 596, 2446, 2834, 373, 278, 4802, 4315, 29889, 13, 13, 3492, 526, 409, 630, 472, 278, 2343, 310, 263, 1591, 297, 263, 11660, 7176, 21362, 5716, 29889)
suffix_ids = (263, 5828, 1048, 278, 29871, 29929, 29906, 299, 2217, 282, 335, 29889, 2)
newline_they_ids = (13, 2688)

print(tokenizer.decode(prefix_ids))
print("-----------")
print(tokenizer.decode(suffix_ids))
print(tokenizer.decode(newline_they_ids))
# The result of this call is discarded.
tokenizer.encode("\n They")
prior = " A man is speaking to you."
prior_ids = tokenizer.encode(prior)
print(len(prior_ids))

When you die, you appear in a cinema with a number of other people who look like you. You find out that they are your previous reincarnations, and soon you all begin watching your next life on the big screen.

You are seated at the head of a table in a crowded conference room.
-----------
 a story about the 92nd little pig.</s>

 They
9


In [None]:
-------------------- Top-3 -----------------------
#5:  she sent them out into the world to find some food and bring back
#9:  she sent them out into the world to find food for themselves. She told them to be careful and find
#1:  she sent the little pigs out into the world to try and find food, and bring back

-------------------- Top-3 -----------------------
#7:  she sent them out into the world to find jobs. The first one came back with
#8:  she sent each of her little pigs out to find
#9:  she sent her little pigs out into the world in search of

# Example 0a
-------------------- Top-3 -----------------------
#5:  she sent her little pigs to the market to try to sell them. This is
#9:  she sent them out into the world to find food and a home. One hundred little pigs went out into the world and only one came back with
#14:  she sent the little pigs out into the world to find their fortune. She gave each little pig

-------------------- Top-3 -----------------------
#1:  he sent them out into the world to find work and bring home
#8:  she sent them out into the world to find their fortunes and bring back
#12:  the mother sent all but 92 of them into the forest to find

# Example 1
-------------------- Top-3 -----------------------
#7:  A man is giving a presentation to a room full of people.
#13:  The man sitting next to you is your boss.
#3:  You watch as your next reincarnation approaches the table and sits across from you.

-------------------- Top-3 ----------------------- (Trunc Suffix)
#2:  You look at the man sitting at the other end of the table.
#4:  There is a man sitting to your right, smiling at you.
#15:  You are talking to a man named Jack.

-------------------- Top-3 -----------------------
#15:  The alien prince, obsession with Earth and humans, has decided to agree to the human demand.
#4:  The alien prince, seemingly oblivious to the obsession humans have with Matt Damon, agree to release him unharmed.
#11:  A rescue is initiated to save Matt Damon from the clutches of an otherworldly obsession, to which no prince would agree.