In [1]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch
import numpy as np
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer, BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList

import cfg_decoding.parsing as p
import cfg_decoding.logits_processor as lp

import importlib
importlib.reload(p)
importlib.reload(lp)

  from .autonotebook import tqdm as notebook_tqdm


<module 'cfg_decoding.logits_processor' from '/workspaces/funcqa_experiments/cfg_decoding/logits_processor.py'>

In [3]:
MODEL_NAME = "meta-llama/Llama-2-13b-chat-hf"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [5]:
importlib.reload(p)

with open("funcqa.lark", "r") as f:
    cfg_def = f.read()

stepper = p.create_parsing_stepper(cfg_def, tokenizer)

print(stepper.get_parsing_state("add(1"))

# s = 'add(10., 2.)'
# for i in range(len(s)+1):
#     cfg_state = stepper.get_parsing_state(s[:i])
#     print(f"'{s[:i]}' -> {cfg_state}")

State(start_idx=0, terminals={'__ANON_0'})


In [6]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, load_in_4bit=True, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 3/3 [01:54<00:00, 38.27s/it]


In [8]:
importlib.reload(lp)

num_beams = 10
input_prompt = '''Use functions add, mul, div and sub to solve the following math problem.

E.g. multiply(1, 20) or add(1, mul(2, 3)) or divide(5, 3) or subtract(15, 3) or add(10, 2)

Question: 1 + 20

Calculation: '''

input_ids = tokenizer(
    input_prompt, 
    return_tensors="pt"
).input_ids
input_ids = torch.stack([input_ids] * num_beams, dim=0).reshape(num_beams, -1).to(model.device)
bos_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) * model.config.bos_token_id
input_ids = torch.cat([bos_ids, input_ids], dim=-1)

prompt_end_index = input_ids.shape[1]
max_length = prompt_end_index + 20

final_sentence = model.beam_search(
    input_ids, 
    beam_scorer=BeamSearchScorer(
        batch_size=1,
        max_length=max_length,
        num_beams=num_beams,
        device="cuda",
        length_penalty=1.0,
        do_early_stopping=True,
    ),
    logits_processor = LogitsProcessorList([
        lp.GrammarConstrainedLogitsProcessor(tokenizer, stepper, prompt_end_index=prompt_end_index)
    ]),
    stopping_criteria = StoppingCriteriaList([
        MaxLengthCriteria(max_length=max_length)
    ]),
    pad_token_id=tokenizer.eos_token_id, 
)

final_sentence_str = tokenizer.batch_decode(final_sentence, skip_special_tokens=True)[0]
print(final_sentence_str)

Use functions add, mul, div and sub to solve the following math problem.

E.g. multiply(1, 20) or add(1, mul(2, 3)) or divide(5, 3) or subtract(15, 3) or add(10, 2)

Question: 1 + 20

Calculation: <T>add(1,20)


In [23]:
text = model.generate(
    input_ids[0:1], 
    logits_processor = LogitsProcessorList([
        lp.GrammarConstrainedLogitsProcessor(tokenizer, stepper, prompt_end_index=prompt_end_index)
    ]),
    max_new_tokens=40,
    # stopping_criteria = StoppingCriteriaList([
    #     MaxLengthCriteria(max_length=max_length)
    # ]),
    pad_token_id=tokenizer.eos_token_id, 
    do_sample=False,
    temperature=0.0,
)



In [24]:
text[0]

tensor([    1,     1,  4803,  3168,   788, 29892, 15065, 29892,  1933,   322,
         1014,   304,  4505,   278,  1494,  5844,  1108, 29889,    13,    13,
        29923, 29889, 29887, 29889, 22932, 29898, 29896, 29892, 29871, 29906,
        29900, 29897,   470,   788, 29898, 29896, 29892, 15065, 29898, 29906,
        29892, 29871, 29941,   876,   470, 16429, 29898, 29945, 29892, 29871,
        29941, 29897,   470, 23197, 29898, 29896, 29945, 29892, 29871, 29941,
        29897,   470,   788, 29898, 29896, 29900, 29892, 29871, 29906, 29897,
           13,    13, 16492, 29901, 29871, 29896,   718, 29871, 29906, 29900,
           13,    13, 27065,   362, 29901, 29871, 29966, 29911, 29958,  1202,
        29898, 29896, 29892, 29906, 29900, 29897,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0], device='cuda:

In [26]:
tokenizer.batch_decode(text[:, prompt_end_index:], skip_special_tokens=True)[0]

'<T>add(1,20)'