In [2]:
import os

# Limit the CUDA devices visible to this process to device "0". This runs in
# the first cell, ahead of the torch / transformers imports in the next cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [3]:
import regex
from lark import UnexpectedInput, Lark, UnexpectedCharacters, UnexpectedToken, UnexpectedEOF, UnexpectedInput
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import numpy as np
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer, BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList
import torch
from dataclasses import dataclass
from typing import List, Optional, Union

  from .autonotebook import tqdm as notebook_tqdm


In [13]:



# Read the JSON grammar (written in Lark's grammar syntax) from disk.
with open("cfg_json.lark", "r") as grammar_file:
    cfg_json = grammar_file.read()

# Keyword options for the Lark parser, collected in one place.
lark_options = {
    "parser": 'lalr',
    # The basic lexer isn't required and isn't usually recommended, but it is
    # good enough for JSON and slightly faster.
    "lexer": 'basic',
    # Turning off propagate_positions and placeholders gives a small speed-up.
    "propagate_positions": False,
    "maybe_placeholders": False,
    "regex": True,
}
json_parser = Lark(cfg_json, **lark_options)

# Tokenizer whose vocabulary is handed to the parsing stepper.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")
vocab = tokenizer.get_vocab()

# ParsingStepper is defined in another cell of this notebook; it is built from
# the parser, the tokenizer vocabulary and the tokenizer's EOS token string.
state = ParsingStepper(json_parser, vocab, tokenizer.eos_token)

# Smoke check: show the state reported for a short, incomplete JSON array.
sample_state = state.get_parsing_state('[null')
str(sample_state)

"(1, {'RSQB', 'COMMA'})"

In [14]:
# Walk through every prefix of the sample string (from the empty string up to
# the full text) and print the state the stepper reports for each prefix.
# NOTE(review): the sample is not valid JSON (a `"b":` key appears inside an
# array) -- presumably intentional, to show the state reported around the
# invalid point; confirm.
s = '{"a": ["1", "b": ["1", "2", "3"]]}'
prefix_ends = range(len(s) + 1)
for end in prefix_ends:
    prefix = s[:end]
    cfg_state = state.get_parsing_state(prefix)
    print(f"'{prefix}' -> {cfg_state}")

'' -> (0, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{' -> (0, {'RBRACE', 'ESCAPED_STRING'})
Second catch
'{"' -> (1, {'RBRACE', 'ESCAPED_STRING'})
Second catch
'{"a' -> (1, {'RBRACE', 'ESCAPED_STRING'})
'{"a"' -> (1, {'COLON'})
'{"a":' -> (4, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": ' -> (4, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": [' -> (6, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
Second catch
'{"a": ["' -> (7, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
Second catch
'{"a": ["1' -> (7, {'RSQB', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL', 'ESCAPED_STRING'})
'{"a": ["1"' -> (7, {'RSQB', 'COMMA'})
'{"a": ["1",' -> (10, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
'{"a": ["1", ' -> (10, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})
Second catch
'{"a": ["1", "' -> (12, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE',

In [16]:
# Query the stepper once for a longer, still-incomplete JSON object.
s = '{"num_values": "4", "values": ["1", "2",">",">"],'
print(state.get_parsing_state(s))

# Echo the string with an underscore inserted at character offset 15.
# NOTE(review): the offset is hard-coded and is not taken from the state
# printed above -- confirm it still marks the position of interest.
head, tail = s[:15], s[15:]
print("_".join((head, tail)))

(48, {'ESCAPED_STRING'})
{"num_values": _"4", "values": ["1", "2",">",">"],


In [6]:
# Checkpoint shared by the model and the tokenizer loaded in this cell.
model_name = "meta-llama/Llama-2-7b-hf"

# Note: this rebinds `tokenizer` to the tokenizer of `model_name`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda:0",
    load_in_4bit=True,
)

# Use the EOS token as the padding token, both on the tokenizer and in the
# model configuration.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

2023-10-01 15:28:11.464215: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Loading checkpoint shards: 100%|██████████| 2/2 [00:58<00:00, 29.22s/it]


In [17]:
# Beam-search settings. `max_length` is the total sequence length passed to
# both the beam scorer and the stopping criterion below.
num_beams = 2
input_prompt = '{"num_values": "4", "values": ["1", "2",'
max_length = 35

# Tokenize WITHOUT special tokens: exactly one BOS token is prepended
# explicitly below. Letting the tokenizer add its own BOS as well would leave
# two BOS tokens at the start of the prompt.
prompt_ids = tokenizer(
    input_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).input_ids

# One copy of the prompt per beam, placed on the model's device.
# (The tokenizer returns a (1, seq_len) tensor for a single string.)
prompt_ids = prompt_ids.repeat(num_beams, 1).to(model.device)
bos_ids = torch.full(
    (num_beams, 1),
    model.config.bos_token_id,
    dtype=torch.long,
    device=model.device,
)
input_ids = torch.cat([bos_ids, prompt_ids], dim=-1)

# `beam_search` is called directly here rather than through `generate`, so
# gradient tracking is disabled explicitly for the decoding loop.
with torch.no_grad():
    final_sentence = model.beam_search(
        input_ids,
        beam_scorer=BeamSearchScorer(
            batch_size=1,
            max_length=max_length,
            num_beams=num_beams,
            # Keep the scorer on the same device as the inputs and the model.
            device=model.device,
            length_penalty=1.0,
            do_early_stopping=True,
        ),
        # NOTE(review): this is called with the tokenizer, so it presumably
        # resolves to a grammar-aware processor defined in another cell that
        # shadows the `LogitsProcessor` base class imported from
        # transformers -- confirm.
        logits_processor=LogitsProcessorList([
            LogitsProcessor(tokenizer)
        ]),
        stopping_criteria=StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length)
        ]),
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode the returned sequences and show the first one.
final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
print(final_sentence_str)

Decoded sequences: ['{"num_values": "4", "values": ["1", "2",', '{"num_values": "4", "values": ["1", "2",']
Parsing states: [(39, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'}), (39, {'ESCAPED_STRING', 'TRUE', 'LSQB', 'FALSE', 'LBRACE', 'NULL'})]
Valid tokens: [['tr', '",', '">', '":', '")', '");', '".', '";', '").', 'true', '"]', '"><', '","', 'nu', 'null', 'false', '"/>', '":"', '"),', '"></', 'fa', '"))', '"`', '"?', '"));', '"}', '">\r', '"];', '"},', '"],', '")]', '",\r', '""', '"].', '"+', 'fal', '");\r', '"?>', '"\r', '"=>', '"])', '")`', '".$', '"/', '";\r', '"\\', '":{"', 't', 'n', 'f', '"', '{', '['], ['tr', '",', '">', '":', '")', '");', '".', '";', '").', 'true', '"]', '"><', '","', 'nu', 'null', 'false', '"/>', '":"', '"),', '"></', 'fa', '"))', '"`', '"?', '"));', '"}', '">\r', '"];', '"},', '"],', '")]', '",\r', '""', '"].', '"+', 'fal', '");\r', '"?>', '"\r', '"=>', '"])', '")`', '".$', '"/', '";\r', '"\\', '":{"', 't', 'n', 'f', '"', '{', '[']]
Argmax: 