In [2]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
import regex
from lark import UnexpectedInput, Lark, UnexpectedCharacters, UnexpectedToken, UnexpectedEOF, UnexpectedInput
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import numpy as np
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer, BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList
import torch
from dataclasses import dataclass
from typing import List, Optional, Union

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
@dataclass
class IntermediateParsingState: 
    active_terminal_names: List[str]
    active_terminal_patterns: List[regex.Regex]
    current_terminal_start_index: int

    def __str__(self) -> str:
        return f"({self.current_terminal_start_index}, {self.active_terminal_names})"
    
    def __repr__(self) -> str:
        return str(self)

class ParsingStepper():
    def __init__(self, parser: Lark, vocab, eos_token):
        self.parser: Lark = parser
        self.partial_token = ""
        self.vocab = vocab
        self.eos_token = eos_token
        self.regex_map = self._create_terminal_regexes()

    def _create_terminal_regexes(self):
        """
        Create a map from terminal names to regexes that match the terminal
        """
        terminal_regexes = {}
        for terminal in self.parser.terminals:
            if terminal.pattern:
                terminal_regexes[terminal.name] = regex.compile(terminal.pattern.to_regexp())
        terminal_regexes['$END'] = regex.compile(self.eos_token)
        return terminal_regexes

    def get_parsing_state(self, current_generation: str): 

        # Get the next parser tokens that would be valid to add to the input string according to the CFG
        next_parser_tokens, token_start_index = self._get_next_parser_tokens(current_generation)
        # Get the regexes for the next parser tokens
        next_patterns = [self.regex_map[terminal] for terminal in next_parser_tokens]
        
        return IntermediateParsingState(next_parser_tokens, next_patterns, token_start_index)
    
    def _get_next_parser_tokens(self, input_str):
        """
        Get the next tokens that would be valid to add to the input string
        :return: A list of tokens that would be valid to add to the input string, and the position in the input string where the next token would start
        """
        try:
            # Try parsing until error or end of input
            self.parser.parse(input_str)
        except UnexpectedInput as e:
            interactive = self.parser.parse_interactive(input_str)
            try: 
                # Get the set of tokens that would be valid next
                interactive.exhaust_lexer()
            except UnexpectedInput as ee: 
                # Now, this exception means that we have characters that do not match any of the terminals (yet). 
                # This means that we have a partial token.
                # Return the set of tokens that would be valid before that partial token 
                print("Second catch")
                return interactive.accepts(), ee.pos_in_stream
            # Return the token 
            return interactive.accepts(), e.pos_in_stream
 
        # If we get here, the input is complete
        return [], len(input_str)


with open("cfg_json.lark", "r") as f:
    cfg_json = f.read()

json_parser = Lark(
    cfg_json, 
    parser='lalr',
    # Using the basic lexer isn't required, and isn't usually recommended.
    # But, it's good enough for JSON, and it's slightly faster.
    lexer='basic',
    # Disabling propagate_positions and placeholders slightly improves speed
    propagate_positions=False,
    maybe_placeholders=False,
    regex=True
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf") 
vocab = tokenizer.get_vocab()

state = ParsingStepper(json_parser, vocab, tokenizer.eos_token)
str(state.get_parsing_state('[null'))

"(1, {'RSQB', 'COMMA'})"

In [14]:
s = '{"a": ["1", "b": ["1", "2", "3"]]}'
for i in range(len(s)+1):
    cfg_state = state.get_parsing_state(s[:i])
    print(f"'{s[:i]}' -> {cfg_state}")

'' -> (0, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{' -> (0, {'RBRACE', 'ESCAPED_STRING'})
Second catch
'{"' -> (1, {'RBRACE', 'ESCAPED_STRING'})
Second catch
'{"a' -> (1, {'RBRACE', 'ESCAPED_STRING'})
'{"a"' -> (1, {'COLON'})
'{"a":' -> (4, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": ' -> (4, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": [' -> (6, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
Second catch
'{"a": ["' -> (7, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
Second catch
'{"a": ["1' -> (7, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
'{"a": ["1"' -> (7, {'RSQB', 'COMMA'})
'{"a": ["1",' -> (10, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": ["1", ' -> (10, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
Second catch
'{"a": ["1", "' -> (12, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE',

In [16]:
s = '{"num_values": "4", "values": ["1", "2",">",">"],'
print(state.get_parsing_state(s))
print(s[:15] + "_" + s[15:])

(48, {'ESCAPED_STRING'})
{"num_values": _"4", "values": ["1", "2",">",">"],


In [6]:
model_name = "meta-llama/Llama-2-7b-hf"

model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

2023-10-01 15:28:11.464215: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Loading checkpoint shards: 100%|██████████| 2/2 [00:58<00:00, 29.22s/it]


In [17]:
class LogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.all_tokens = self.tokenizer.convert_ids_to_tokens(range(self.tokenizer.vocab_size))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

        # input_ids: B * num_beams x T
        # scores: B * num_beams x V

        # Decode sequences
        decoded_sequences = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        print(f"Decoded sequences: {decoded_sequences}")

        # Get parsing states per sequence 
        parsing_states = [state.get_parsing_state(seq) for seq in decoded_sequences]
        print(f"Parsing states: {parsing_states}")

        valid_tokens = [self._filter_tokens_by_regex(self.all_tokens, state.active_terminal_patterns) for state in parsing_states]
        valid_token_ids = [self.tokenizer.convert_tokens_to_ids(tokens) for tokens in valid_tokens]  # list of lists of token ids
        print(f"Valid tokens: {valid_tokens}")

        # Mask out scores 
        scores_mask = torch.ones_like(scores) * float('inf') * -1
        for sequence_index, valid_token_ids_for_sequence in enumerate(valid_token_ids):
            scores_mask[sequence_index, valid_token_ids_for_sequence] = 0

        scores = scores + scores_mask

        print(f"Argmax: {scores.argmax(dim=-1)}")

        print("-" * 8)
        return scores
        

    def _filter_tokens_by_regex(self, tokens, regexes):
        """
        Filter tokens by regexes
        """
        return [
            token 
            for token in tokens 
            if any(regex.fullmatch(token, partial=True) for regex in regexes)
        ]


num_beams = 2
input_prompt = '{"num_values": "4", "values": ["1", "2",'
max_length = 35

input_ids = tokenizer(
    input_prompt, 
    return_tensors="pt"
).input_ids
input_ids = torch.stack([input_ids] * num_beams, dim=0).reshape(num_beams, -1).to(model.device)
bos_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) * model.config.bos_token_id
input_ids = torch.cat([bos_ids, input_ids], dim=-1)

final_sentence = model.beam_search(
    input_ids, 
    beam_scorer=BeamSearchScorer(
        batch_size=1,
        max_length=max_length,
        num_beams=num_beams,
        device="cuda",
        length_penalty=1.0,
        do_early_stopping=True,
    ),
    logits_processor = LogitsProcessorList([
        LogitsProcessor(tokenizer)
    ]),
    stopping_criteria = StoppingCriteriaList([
        MaxLengthCriteria(max_length=max_length)
    ]),
    pad_token_id=tokenizer.eos_token_id, 
)

final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
print(final_sentence_str)

Decoded sequences: ['{"num_values": "4", "values": ["1", "2",', '{"num_values": "4", "values": ["1", "2",']
Parsing states: [(39, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'}), (39, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})]
Valid tokens: [['tr', '",', '">', '":', '")', '");', '".', '";', '").', 'true', '"]', '"><', '","', 'nu', 'null', 'false', '"/>', '":"', '"),', '"></', 'fa', '"))', '"`', '"?', '"));', '"}', '">\r', '"];', '"},', '"],', '")]', '",\r', '""', '"].', '"+', 'fal', '");\r', '"?>', '"\r', '"=>', '"])', '")`', '".$', '"/', '";\r', '"\\', '":{"', 't', 'n', 'f', '"', '{', '['], ['tr', '",', '">', '":', '")', '");', '".', '";', '").', 'true', '"]', '"><', '","', 'nu', 'null', 'false', '"/>', '":"', '"),', '"></', 'fa', '"))', '"`', '"?', '"));', '"}', '">\r', '"];', '"},', '"],', '")]', '",\r', '""', '"].', '"+', 'fal', '");\r', '"?>', '"\r', '"=>', '"])', '")`', '".$', '"/', '";\r', '"\\', '":{"', 't', 'n', 'f', '"', '{', '[']]
Argmax: 