In [1]:
import os 
# Restrict this kernel to GPUs 0 and 1. This cell runs before torch is imported in the next cell.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0,1"})

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch
import numpy as np
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer, BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList

import cfg_decoding.parsing as p
import cfg_decoding.logits_processor as lp

import importlib
# Re-import the local cfg_decoding modules so that edits made on disk are
# picked up without restarting the kernel.
importlib.reload(p)
# Last expression of the cell: its return value (the reloaded module) is what
# the notebook displays as the cell output.
importlib.reload(lp)

  from .autonotebook import tqdm as notebook_tqdm


<module 'cfg_decoding.logits_processor' from '/workspaces/funcqa_experiments/cfg_decoding/logits_processor.py'>

In [3]:
# Checkpoint identifier passed to from_pretrained() for both the tokenizer and the model below.
MODEL_NAME = "meta-llama/Llama-2-13b-chat-hf"

In [4]:
# Load the tokenizer for MODEL_NAME; it is handed to the grammar stepper in the next cell.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [5]:
# Pick up any on-disk edits to the parsing module before building the stepper.
importlib.reload(p)

# Read the grammar definition from funcqa.lark. The encoding is stated
# explicitly so the text read does not depend on the platform's locale default.
with open("funcqa.lark", "r", encoding="utf-8") as grammar_file:
    cfg_def = grammar_file.read()

# Build the parsing stepper from the grammar text and the tokenizer.
stepper = p.create_parsing_stepper(cfg_def, tokenizer)

# Sanity check: show the parsing state reported for a partial expression.
print(stepper.get_parsing_state("add(1"))

# Debugging aid (disabled): print the parsing state for every prefix of a
# complete expression. Uncomment to trace how the state evolves per character.
# s = 'add(10., 2.)'
# for i in range(len(s)+1):
#     cfg_state = stepper.get_parsing_state(s[:i])
#     print(f"'{s[:i]}' -> {cfg_state}")

State(start_idx=0, terminals={'__ANON_0'})


In [6]:
# Load the causal LM with load_in_4bit=True and place it on device "cuda:1".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    device_map="cuda:1",
)
# Fresh tokenizer instance for the same checkpoint; rebinds the name loaded earlier.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use the end-of-sequence token for padding, on both the tokenizer and the model config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it]


In [7]:
# Pick up any on-disk edits to the logits-processor module before using it.
importlib.reload(lp)

num_beams = 2
max_new_tokens = 20  # upper bound on the number of tokens generated after the prompt

# NOTE(review): the instruction names add/mul/div/sub while the examples use
# multiply/divide/subtract -- confirm which spellings funcqa.lark accepts.
input_prompt = '''Use functions add, mul, div and sub to solve the following math problem.

E.g. multiply(1, 20) or add(1, mul(2, 3)) or divide(5, 3) or subtract(15, 3) or add(10, 2)

Question: 1 + 20

Calculation: '''

# Tokenise WITHOUT automatic special tokens. A single BOS is prepended by hand
# below; Llama tokenizers add their own BOS by default (add_bos_token=True),
# which previously made the prompt start with two BOS tokens.
prompt_ids = tokenizer(
    input_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).input_ids

# One identical copy of the prompt per beam: (1, prompt_len) -> (num_beams, prompt_len).
input_ids = prompt_ids.repeat(num_beams, 1).to(model.device)
bos_ids = torch.full(
    (num_beams, 1),
    model.config.bos_token_id,
    dtype=torch.long,
    device=model.device,
)
input_ids = torch.cat([bos_ids, input_ids], dim=-1)

# Everything before this column is prompt; the grammar constraint is given this
# index so it can tell prompt tokens from generated ones.
prompt_end_index = input_ids.shape[1]
max_length = prompt_end_index + max_new_tokens

# beam_search() is called directly rather than through generate(), so gradient
# tracking is disabled explicitly here; this cell only performs inference.
with torch.no_grad():
    final_sentence = model.beam_search(
        input_ids,
        beam_scorer=BeamSearchScorer(
            batch_size=1,
            max_length=max_length,
            num_beams=num_beams,
            # Keep the scorer on the same device as the model (it is loaded on
            # "cuda:1", whereas a bare "cuda" refers to the current default device).
            device=model.device,
            length_penalty=1.0,
            do_early_stopping=True,
        ),
        logits_processor=LogitsProcessorList([
            lp.GrammarConstrainedLogitsProcessor(tokenizer, stepper, prompt_end_index=prompt_end_index)
        ]),
        stopping_criteria=StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length)
        ]),
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode the first returned sequence (prompt followed by the generated continuation).
final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
print(final_sentence_str)

2023-11-24 12:19:19.130368: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-24 12:19:19.130397: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-24 12:19:19.131241: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-24 12:19:19.135342: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Decoded sequences:  ['', '']
Parsing states:  [State(start_idx=0, terminals={'__ANON_0'}), State(start_idx=0, terminals={'__ANON_0'})]
Existing tokens for terminal:  ['', '']
Valid tokens:  [['<'], ['<']]
--------------------
Decoded sequences:  ['<', '<']
Parsing states:  [State(start_idx=0, terminals={'__ANON_0'}), State(start_idx=0, terminals={'__ANON_0'})]
Existing tokens for terminal:  ['<', '<']
Valid tokens:  [['T'], ['T']]
--------------------
Decoded sequences:  ['<T', '<T']
Parsing states:  [State(start_idx=0, terminals={'__ANON_0'}), State(start_idx=0, terminals={'__ANON_0'})]
Existing tokens for terminal:  ['<T', '<T']
Valid tokens:  [['>'], ['>']]
--------------------
Decoded sequences:  ['<T>', '<T>']
Parsing states:  [State(start_idx=3, terminals={'SUBTRACT', 'DIVIDE', 'ADD', 'MULTIPLY'}), State(start_idx=3, terminals={'SUBTRACT', 'DIVIDE', 'ADD', 'MULTIPLY'})]
Existing tokens for terminal:  ['', '']
Valid tokens:  [['ad', 'add', 'sub', 'su', 'mu', 'div', 'mult', 'di', '