In [1]:
import torch
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from datasets import load_dataset
from torch.utils.data import DataLoader

# Fine-tuned story model and the (debug) inference data used by the cells below.
model_name_or_path = "../output/Story-TULU-LLAMA2-7B-Cont/"
inference_file = "../data/processed/writing_prompts/wp_allcont_debug.jsonl"

# Load the model weights only; the tokenizer (with left padding, which batched
# generation needs) is created in the inference cell.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path,
)

  from .autonotebook import tqdm as notebook_tqdm


[2024-02-05 12:43:44,340] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.22s/it]


In [11]:
from functools import partial

def encode_with_messages_format(example, tokenizer, max_seq_length):
    '''
    Build an inference prompt from example['messages'] and tokenize it.

    Each message is a dict with 'role' ('system', 'user' or 'assistant') and
    'content' fields. System/user turns are rendered as "<|role|>\\n<content>\\n".
    Rendering stops at the first assistant turn: only its "<|assistant|>\\n" cue is
    emitted (its content is withheld), so the model generates that turn.

    Args:
        example: mapping with a non-empty 'messages' list.
        tokenizer: callable tokenizer; called with return_tensors='pt',
            max_length=max_seq_length and truncation=True.
        max_seq_length: maximum number of tokens the tokenizer keeps.

    Returns:
        dict with flat 1-D 'input_ids' and an all-ones 'attention_mask' of the
        same shape (no labels are produced for inference).

    Raises:
        ValueError: if 'messages' is empty or a message has an unknown role.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        """Return (prompt_text, reached_assistant) for the turns up to the first assistant cue."""
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "<|system|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                # For inference, emit only the assistant cue (not its content) and stop,
                # so the prompt covers a single turn.
                message_text += "<|assistant|>\n"
                return message_text, True
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text, False

    message_text, ends_with_assistant_cue = _concat_messages(messages)
    example_text = message_text.strip()
    if ends_with_assistant_cue:
        # strip() also removed the newline after "<|assistant|>"; restore it so the
        # prompt matches the "<|assistant|>\n<content>" layout of assistant turns.
        example_text += "\n"

    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    # Single unpadded sequence, so every position is attended.
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'attention_mask': attention_mask.flatten(),
    }

In [93]:
# Left padding keeps every prompt's last token in the final column, which batched
# generation needs.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
inference_batch_size = 2

# Load the dataset from the inference file.
dataset_args = {}
raw_inference_dataset = load_dataset(
    "json",
    data_files={"inference": inference_file},
    **dataset_args,
)
print(raw_inference_dataset['inference'][0])

# Tokenize every example into a prompt that ends at the assistant cue.
encode_function = partial(
    encode_with_messages_format,
    tokenizer=tokenizer,
    max_seq_length=2048,
)
inference_dataset = raw_inference_dataset.map(
    encode_function,
    batched=False,
    num_proc=8,
    remove_columns=[name for name in raw_inference_dataset["inference"].column_names if name not in ["input_ids", "labels", "attention_mask"]],
    desc="Tokenizing and reformatting instruction data",
)
inference_dataset.set_format(type="pt")
inference_dataset = inference_dataset["inference"]
# The collator pads each batch to its longest prompt, on the tokenizer's (left) side.
inference_dataloader = DataLoader(
    inference_dataset,
    shuffle=False,
    collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
    batch_size=inference_batch_size,
)

model.half().cuda()
# Fix the RNG so the do_sample=True generations below are reproducible across runs.
torch.manual_seed(0)

# Run inference batch by batch.
for batch in inference_dataloader:
    generation_output = model.generate(
        batch.input_ids.cuda(),
        max_new_tokens=65,
        do_sample=True,  # set to False for greedy decoding
        attention_mask=batch.attention_mask.cuda(),
        # Pass return_dict_in_generate=True and output_scores=True to also get the scores (logits).
    )
    output_ids = generation_output
    # All prompts in the batch are left-padded to the same length.
    input_len = batch.input_ids[0].shape[-1]
    output_str_list = tokenizer.batch_decode(output_ids[:, input_len:])
    # Print every full sequence; indexing [0] and [1] fails when the final batch
    # holds a single example.
    for full_text in tokenizer.batch_decode(output_ids):
        print(full_text)
    print("---------------Done")
    print(output_str_list)
    for output_str in output_str_list:
        print(output_str.strip())



You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'dataset': 'writing_prompts', 'id': 'writing_prompts_0', 'messages': [{'role': 'user', 'content': "Continue the given text in enchanting style:\nI watched as the seconds chipped the last of my business hours away. The expected tinkle signalling the entrance of my oddest regular interrupted the deafening ticks of the little hand I had been tracking. So familiar was our weekly routine that I didn't even notice I had begun until the tip of my tattoo gun dipped into John's skin to cross out the last of the names on his back. I didn't need to read it to know Gregory was being offed. For what, I never asked. A few months ago, John had sauntered into my humble tattoo studio right at closing time. Not a second before. I had attempted to turn John away, especially having just refused the business of three, piss-drunk dude-bros who didn't want to hear that matching knuckle-pieces spelling out B.R.O.S. on their right hands was not the best, most awesome idea ever. Instead of arguing or leaving, 

In [95]:
# Greedy (deterministic) decoding of the last `batch` left over from the sampling
# loop in the previous cell. Run that cell first: this cell has no batch of its own.
model.eval()
generation_output = model.generate(
    batch.input_ids.cuda(),
    max_new_tokens=65,
    do_sample=False,
    attention_mask=batch.attention_mask.cuda(),
    # Pass return_dict_in_generate=True and output_scores=True to also get the scores (logits).
)
output_ids = generation_output
# All prompts in the batch are left-padded to the same length.
input_len = batch.input_ids[0].shape[-1]
output_str_list = tokenizer.batch_decode(output_ids[:, input_len:])
# Show the first sequence with special tokens kept, and the rest with them
# stripped, for comparison. Iterating the rest avoids an IndexError when the
# leftover batch holds a single example.
print(tokenizer.batch_decode(output_ids)[0])
for clean_text in tokenizer.batch_decode(output_ids[1:], skip_special_tokens=True):
    print(clean_text)
print("---------------Done")
print(output_str_list)
for output_str in output_str_list:
    print(output_str.strip())

<s> <|user|>
Continue the given text in enchanting style:
I watched as the seconds chipped the last of my business hours away. The expected tinkle signalling the entrance of my oddest regular interrupted the deafening ticks of the little hand I had been tracking. So familiar was our weekly routine that I didn't even notice I had begun until the tip of my tattoo gun dipped into John's skin to cross out the last of the names on his back. I didn't need to read it to know Gregory was being offed. For what, I never asked. A few months ago, John had sauntered into my humble tattoo studio right at closing time. Not a second before. I had attempted to turn John away, especially having just refused the business of three, piss-drunk dude-bros who didn't want to hear that matching knuckle-pieces spelling out B.R.O.S. on their right hands was not the best, most awesome idea ever. Instead of arguing or leaving, John had stood stock-still. Silent. Staring straight into my heavy-lidded eyes. After an

In [7]:
import torch
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from datasets import load_dataset
from torch.utils.data import DataLoader

# Standalone tokenizer (left padding, as for batched generation) for decoding checks.
model_name_or_path = "../output/Story-TULU-LLAMA2-7B-Cont/"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")

# Decode a fixed list of token ids, framed by dashes so that leading and trailing
# whitespace in the decoded text is visible.
print(
    "----"
    + tokenizer.decode([
        29873, 471, 297, 393, 3256, 29892, 408, 306, 12642, 287,
        964, 902, 5076, 29892, 393, 306, 1754, 590, 10608, 29889,
        450, 4799, 2820, 502, 6140, 304, 12003, 264, 29892, 278,
        25005, 20139, 411, 385, 12646, 260, 2673, 29889, 306, 1033,
        4459, 278, 7688, 310, 278, 3838, 1183, 750, 19182, 29892,
        278, 2411, 1414, 310, 902, 9177, 29889, 739, 471, 263,
        18766, 29892, 263, 23222,
    ])
    + "-------"
)



In [3]:
prefix = "I really love my dog, who"
# Reload the tokenizer with default settings: the beam-search test below uses a
# single prompt, so no padding side is configured.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model.half().cuda()
# Next: beam generation on `prefix`.

In [12]:
from parsing import (
    parse_prompt, parse_suggestion, parse_probability,
    filter_suggestions
)

with torch.inference_mode():
    # Wrap the raw prefix as one user turn plus an (empty) assistant turn, so the
    # encoder emits the prompt up to the assistant cue. The "Continue the given
    # text:" instruction mirrors the prompt printed by prompt() in the HMM section.
    encoded_prompt = encode_with_messages_format(
        {"messages": [
            {"role": "user", "content": "Continue the given text:\n" + prefix},
            {"role": "assistant", "content": ""},
        ]},
        tokenizer,
        max_seq_length=2048,
    )
    # The encoder returns flat 1-D tensors; generate() expects a (batch, seq) batch of one.
    input_ids = encoded_prompt["input_ids"].unsqueeze(0).cuda()
    prompt_attention_mask = encoded_prompt["attention_mask"].unsqueeze(0).cuda()
    # beam_outputs fields: sequences, sequences_scores, scores, beam_indices, attentions, hidden_states
    beam_outputs = model.generate(
        input_ids,
        attention_mask=prompt_attention_mask,
        max_new_tokens=100,
        # NOTE(review): temperature/top_p only affect sampling; with do_sample left
        # at False this is plain beam search and they are ignored.
        temperature=0.95,
        num_return_sequences=5,
        num_beams=5,
        no_repeat_ngram_size=2,
        top_p=1,
        early_stopping=True,
        output_scores=True,  # needed for sequences_scores
        return_dict_in_generate=True,
    )

prompt_len = input_ids.shape[-1]
suggestions = []
# sequences_scores holds one score per returned beam; pair each beam with its own score.
for sequence, sequence_score in zip(
    beam_outputs.sequences, beam_outputs.sequences_scores.cpu().detach().tolist()
):
    # decode() rather than batch_decode(): `sequence` is 1-D, and batch_decode
    # would decode every token id as a separate string.
    choice_text = tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
    # NOTE(review): `results`, `stop_rules` and `engine` are not defined anywhere in
    # this notebook; they must be provided before this loop can run.
    suggestion = parse_suggestion(
        choice_text,
        results['after_prompt'],
        stop_rules
    )
    # NOTE(review): passes this beam's own score; confirm it matches what
    # parsing.parse_probability expects.
    probability = parse_probability(sequence_score)
    suggestions.append((suggestion, probability, engine))

In [22]:
# Which fields generate() populated on the returned output object.
print(beam_outputs.__dict__.keys())

# Only the newly generated part of each beam: drop the prompt columns first.
print(
    tokenizer.batch_decode(
        beam_outputs.sequences[:, input_ids.shape[-1]:],
        skip_special_tokens=True,
    )
)
# One score per returned beam, as plain Python floats.
print(beam_outputs.sequences_scores.detach().cpu().tolist())

dict_keys(['sequences', 'sequences_scores', 'scores', 'beam_indices', 'attentions', 'hidden_states'])
["is a rescue dog. I have had him since he was a puppy and he has brought so much joy to my life. He is my baby, my best friend, and my confidant. We do everything together and I couldn't imagine life without him.", "is my best friend. I would do anything for him. He's always there for me, no matter what. When I'm sad, he lays his head on my lap and looks up at me with those big, loving eyes, and I know everything is going to be okay.\n\nBut sometimes, I can't help but wonder if he feels the same way about me. Does he know how much I love him? Do I ever cross his mind when he'", 'is a rescue dog. I have had him since he was 8 weeks old. He is now 12 years old and still going strong. We have been through a lot together and I couldn\'t imagine my life without him.\n\nHe has been my best friend, my confidant, and my protector. When I am sad, he lays his head on my lap and looks at me with

# Combine with HMM

In [12]:
from tqdm import tqdm
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList, MaxLengthCriteria
import argparse
from hmm_model_K import *
from model_utils import (
    encode_with_messages_format,
    ConstraintLogitsProcessor
)

def init(argv=None):
    """Parse run settings into the global ``args`` and select the device.

    Sets the module globals ``args`` (argparse.Namespace) and ``device`` (str).

    Args:
        argv: optional list of command-line style arguments. When None, the
            hard-coded HMM/LLaMA checkpoint paths and ``--cuda_core 4`` below are
            parsed. They are presumably hard-coded because inside a notebook,
            parse_args() would otherwise read the kernel's own sys.argv.
    """
    global device
    global args

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--device', default='cuda', type=str)
    arg_parser.add_argument('--cuda_core', default='1', type=str)
    arg_parser.add_argument('--hmm_batch_size', default=256, type=int)

    arg_parser.add_argument('--min_k', default=5, type=int)
    arg_parser.add_argument('--max_k', default=32, type=int)
    arg_parser.add_argument('--suffix_cap', default=10000, type=int)
    arg_parser.add_argument('--length_penalty', default=0.2, type=float)

    arg_parser.add_argument('--hmm_model_path', default=None, type=str)
    arg_parser.add_argument('--llama_model_path', default='gpt2', type=str)

    if argv is None:
        # NOTE(review): absolute, user-specific paths; consider moving them to a
        # config cell or environment variables.
        argv = [
            "--hmm_model_path", "/local1/hzhang19/matcha/models/hmm_llama-story_32768_64/checkpoint-90.weight.th",
            "--llama_model_path", "/local1/ponienkung/CtrlGen/output/NewFinetune2K8K_StoryPretrain-TULU-LLAMA2",
            "--cuda_core", "4",
        ]
    args = arg_parser.parse_args(argv)

    device = args.device
    # Only pin a CUDA device when running on CUDA; set_device raises on a CPU-only setup.
    if device.startswith("cuda"):
        torch.cuda.set_device(int(args.cuda_core))

def load_models():
    """Load the fp16 LLaMA model, its tokenizer and the HMM onto ``device``.

    Reads the globals ``args`` and ``device`` (set by init()) and sets the globals
    ``llama_model``, ``tokenizer`` and ``hmm_model``.

    Raises:
        Exception: any error raised while loading, re-raised after a diagnostic print.
    """
    global tokenizer
    global llama_model
    global hmm_model
    try:
        print(f'loading llama2 as half prec. from {args.llama_model_path} ...')
        llama_model = LlamaForCausalLM.from_pretrained(args.llama_model_path).half().to(device)
        tokenizer = LlamaTokenizer.from_pretrained(args.llama_model_path)

        print(f'loading hmm from {args.hmm_model_path} ...')
        hmm_model = HMM(args.hmm_model_path)
        hmm_model.to(device)
    except Exception as e:
        # Report the paths init() actually defines; args has no model_name_or_path.
        print(
            f"Cannot load models (llama: {args.llama_model_path}, hmm: {args.hmm_model_path}) "
            f"because of the following exception:\n {e}"
        )
        # Re-raise rather than exit(0): this keeps the traceback and does not report
        # a failed load as a successful exit (in a notebook, exit() may also stop
        # the kernel).
        raise

In [13]:
# Init Models: parse the hard-coded run settings and select the device (init),
# then load the fp16 LLaMA model, its tokenizer and the HMM (load_models).

init()
load_models()


loading llama2 as half prec. from /local1/ponienkung/CtrlGen/output/NewFinetune2K8K_StoryPretrain-TULU-LLAMA2 ...


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.42s/it]


loading hmm from /local1/hzhang19/matcha/models/hmm_llama-story_32768_64/checkpoint-90.weight.th ...


In [8]:
from hmm_model_K import *
from transformers import LogitsProcessor
# HMM utils
class ConstraintLogitsProcessor(LogitsProcessor):
    """Reweight LM next-token scores with HMM constraint scores.

    At each generation step, the LM scores are converted to log-probabilities.
    The HMM's constraint log-scores (``hmm_logits - hmm_logits_`` from
    ``hmm_model.compute_logits``) are added, and the sum is renormalized with a
    second log_softmax.

    Args:
        hmm_model: object exposing ``compute_logits(prefixes, suffix, min_offset,
            max_offset, batch_size=...)`` that returns a pair of 2-D tensors
            (presumably (num_beams, hmm_vocab); TODO confirm).
        hmm_config: dict with keys 'hmm_prompt_len', 'hmm_prefix', 'hmm_suffix',
            'hmm_batch_size', 'hmm_min_suffix_offset', 'hmm_max_suffix_offset'.
        device: device on which the padding columns for extra LM vocabulary are created.
    """

    def __init__(self, hmm_model, hmm_config, device):
        self.hmm_model = hmm_model
        self.hmm_config = hmm_config
        self.device = device

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cfg = self.hmm_config
        hmm_prompt_len = cfg['hmm_prompt_len']
        hmm_prefix = cfg['hmm_prefix']
        hmm_suffix = cfg['hmm_suffix']
        hmm_batch_size = cfg['hmm_batch_size']
        hmm_min_suffix_offset = cfg['hmm_min_suffix_offset']
        hmm_max_suffix_offset = cfg['hmm_max_suffix_offset']

        # The HMM conditions on the raw-text prefix plus the tokens generated so far
        # (everything after the first hmm_prompt_len prompt positions) for each beam.
        prefixes = [tuple(hmm_prefix) + tuple(generated)
                    for generated in input_ids[:, hmm_prompt_len:].tolist()]

        hmm_logits, hmm_logits_ = self.hmm_model.compute_logits(
            prefixes, hmm_suffix,
            hmm_min_suffix_offset, hmm_max_suffix_offset,
            batch_size=hmm_batch_size)

        # Out-of-place, so tensors owned by the HMM (e.g. cached results) are never mutated.
        hmm_logits = hmm_logits - hmm_logits_
        # NOTE(review): values below -1e10 can appear here.

        # The LM vocabulary may be wider than the HMM's (presumably added tokens
        # such as <pad>). Give those extra ids a huge negative score so they are
        # effectively never chosen.
        num_extra = scores.shape[-1] - hmm_logits.shape[-1]
        if num_extra < 0:
            raise ValueError(
                f"HMM vocabulary ({hmm_logits.shape[-1]}) is larger than the LM's ({scores.shape[-1]})")
        if num_extra > 0:
            filler = torch.full((hmm_logits.shape[0], num_extra), -1e30, device=self.device)
            hmm_logits = torch.cat((hmm_logits, filler), dim=1)

        # First log_softmax: turn the raw LM scores into log-probabilities.
        lm_logprobs = torch.log_softmax(scores, dim=-1)
        # Second log_softmax: renormalize the combined LM + HMM log-scores into a distribution.
        return torch.log_softmax(hmm_logits + lm_logprobs, dim=-1)


In [20]:
def prompt(input_data):
    """Run HMM-constrained beam search for one request and decode the beams.

    Reads these ``input_data`` keys: "Prefix", "Suffix", "Instruct", "Prior",
    "Operation", "num_beams", "num_return_sequences" and "no_repeat_ngram_size".
    Other keys (e.g. "max_tokens", "temperature", "top_p") are not used here.
    Also uses the notebook globals ``tokenizer``, ``llama_model``, ``hmm_model``,
    ``args`` and ``device`` (set by init() / load_models()).

    Returns:
        dict with "beam_outputs_texts" (the newly generated part of each returned
        beam, decoded with special tokens kept) and
        "beam_outputs_sequences_scores" (one score per returned beam).
    """
    # Drop the BOS id that encode() prepends.
    prefix_tokens = tuple(tokenizer.encode(input_data["Prefix"])[1:])
    # Encode the suffix behind a "\n" so its first word is not tokenized with a
    # leading space. Then drop the first three ids (BOS plus the ids produced for
    # the "\n"), keep at most args.suffix_cap suffix ids, and terminate with EOS.
    suffix_ids = tokenizer.encode("\n" + input_data["Suffix"])[3:3 + args.suffix_cap]
    suffix_tokens = tuple(suffix_ids + [tokenizer.eos_token_id])
    print(prefix_tokens, suffix_tokens)
    with torch.no_grad():
        torch.cuda.empty_cache()
        input_ids = encode_with_messages_format(
            Prefix=input_data["Prefix"],
            SoftControl=input_data["Instruct"],
            Prior=input_data["Prior"],
            tokenizer=tokenizer,
            operation=input_data["Operation"],
        ).cuda()

        # TODO: take k_ranges as input. Only the last range's beams are returned.
        k_ranges = [(10, 60)]

        for min_k, max_k in tqdm(k_ranges):
            # Initialize the HMM cache for this prefix/suffix and [min_k, max_k] window.
            hmm_model.initialize_cache(
                prefix_tokens,
                suffix_tokens,
                min_k,
                max_k
            )
            hmm_config = {
                'hmm_prompt_len': len(input_ids[0]),
                'hmm_prefix': prefix_tokens,
                'hmm_suffix': suffix_tokens,
                'hmm_batch_size': args.hmm_batch_size,
                'hmm_min_suffix_offset': len(prefix_tokens) + min_k,
                'hmm_max_suffix_offset': len(prefix_tokens) + max_k,
            }

            logits_processor = LogitsProcessorList([ConstraintLogitsProcessor(hmm_model, hmm_config, device)])
            print(tokenizer.batch_decode(input_ids))
            beam_outputs = llama_model.generate(
                input_ids,
                num_beams=input_data["num_beams"],
                num_return_sequences=input_data["num_return_sequences"],
                no_repeat_ngram_size=input_data["no_repeat_ngram_size"],
                max_new_tokens=max_k,
                logits_processor=logits_processor,
                early_stopping=True,
                output_scores=True,  # needed for sequences_scores
                return_dict_in_generate=True,
            )

    beam_outputs_sequences = beam_outputs.sequences.cpu().detach().tolist()
    beam_outputs_sequences_scores = beam_outputs.sequences_scores.cpu().detach().tolist()
    # Keep only the tokens generated after the prompt.
    beam_outputs_texts = [tokenizer.decode(choice[input_ids.shape[-1]:], skip_special_tokens=False) for choice in beam_outputs_sequences]
    return {
        "beam_outputs_texts": beam_outputs_texts,
        "beam_outputs_sequences_scores": beam_outputs_sequences_scores,
    }

In [18]:
# Example request for prompt(). prompt() does not read max_tokens, temperature or
# top_p. The trailing space in Prefix is encoded as its own final token (29871 in
# the printed prefix tokens of the prompt() run).
input_json = dict(
    Prefix='Once upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough, ',
    Suffix='a story about the 92nd little pig.',
    Instruct='',
    Prior='',
    Operation='Continuation',
    max_tokens=50,
    temperature=0.95,
    num_return_sequences=10,
    num_beams=10,
    no_repeat_ngram_size=2,
    top_p=1.0,
)




In [14]:
# Scratch checks for the suffix handling in prompt(). The first line shows what
# is left of an encoded "\n"-prefixed string after dropping its first three ids.
# The second shows how a short id sequence decodes; only its value is displayed.
tokenizer.encode('\n. ')[3:]
tokenizer.decode([306, 1304, 304, 505, 29871])

# tokenizer.encode(["<pad>"])

'I used to have '

In [21]:
prompt(input_data=input_json)

(9038, 2501, 263, 931, 727, 471, 385, 2030, 5637, 282, 335, 1058, 750, 697, 6893, 2217, 282, 23379, 322, 451, 3307, 9687, 304, 8343, 963, 29889, 1105, 746, 896, 892, 2030, 3307, 29892, 29871) (29874, 5828, 1048, 278, 29871, 29929, 29906, 299, 2217, 282, 335, 29889, 2)


  0%|          | 0/1 [00:00<?, ?it/s]

['<s><|user|>\nContinue the given text:\nOnce upon a time there was an old mother pig who had one hundred little pigs and not enough food to feed them. So when they were old enough,\n<|assistant|>\n']


100%|██████████| 1/1 [00:33<00:00, 33.81s/it]


{'beam_outputs_texts': ['she sent them out into the world to find food.\n\nThe first little Pig found a job at a bakery. He worked hard every day, but he received a very goodом repres chiam environmental conseilimp métcenteringbatch anymore wij Corps Engel JrProps этойharm mittalle ellos Mag',
  'she sent them out into the world to find food.\n\nThe first little Pig found a job at a bakery. He worked hard every day, but he always took a cookie quello electric trig spiritualXTèn机宗cfg NYSym Try ur♭ зай)+\\ castle applic SU feelsjection',
  'she sent them out into the world to find food.\n\nThe first little Pig found a job at a bakery. He worked hard every day, baking sweet, golden, Kurchanged valuablenikotta quasiċloped найrableIRerte Buddulleänner svě MenHint shell Cambridge',
  "she sent them out into the world to find food.\n\nThe first little Pig found a job at a bakery. He worked hard every day, but he was always very grateful some дерев byl plata arrestờ діяightsстер situatedearsле