In [None]:
%load_ext autoreload
%autoreload 2

# https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/tree/main
from charset_normalizer.utils import is_punctuation, is_symbol, unicode_range, is_accentuated, is_latin, \
    remove_accent, is_separator, is_cjk, is_case_variable, is_hangul, is_katakana, is_hiragana, is_ascii, is_thai

from threading import Thread
from typing import Iterator, List, Set

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from regexparser import SequentialPattern


In [None]:
# Checkpoint used for both the model and the tokenizer, and the device that
# `run` moves the tokenized prompt to.
model_id = 'meta-llama/Llama-2-13b-chat-hf'
device = 'cuda'

# The weights are only loaded when CUDA is available; without a GPU `model`
# stays None and only the tokenizer is usable.
if not torch.cuda.is_available():
    model = None
else:
    config = AutoConfig.from_pretrained(model_id)
    # NOTE(review): presumably selects the non-tensor-parallel code path of the
    # Llama implementation -- confirm against the transformers LlamaConfig docs.
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map='auto',
    )

# The tokenizer is loaded regardless of GPU availability; `use_fast` is driven
# by this flag (currently the slow tokenizer is requested).
fast_tokenizer = False
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=fast_tokenizer)

In [None]:
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble the full chat prompt string for one generation request.

    The prompt starts with the system prompt wrapped in ``<<SYS>>`` markers,
    followed by every ``(user_input, response)`` pair of ``chat_history`` and
    finally the new ``message`` terminated by ``[/INST]``.

    Whitespace handling: the very first user text of the conversation (the
    first history entry, or ``message`` itself when there is no history) is
    used verbatim; every later user text is stripped. Responses are always
    stripped.

    Args:
        message: The new user message to answer.
        chat_history: Earlier ``(user_input, response)`` turns, oldest first.
        system_prompt: Text placed inside the ``<<SYS>>`` section.

    Returns:
        The concatenated prompt string.
    """
    pieces = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    turns_seen = 0
    for user_input, response in chat_history:
        # Only the first user turn of the conversation keeps its whitespace.
        if turns_seen:
            user_input = user_input.strip()
        turns_seen += 1
        pieces.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    if turns_seen:
        message = message.strip()
    pieces.append(f'{message} [/INST]')
    return ''.join(pieces)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return how many tokens the fully assembled prompt occupies.

    The prompt is built with ``get_prompt`` and encoded by the module-level
    ``tokenizer`` without adding special tokens; the size of the last axis of
    the resulting ``input_ids`` array is returned.
    """
    encoded = tokenizer(
        [get_prompt(message, chat_history, system_prompt)],
        return_tensors='np',
        add_special_tokens=False,
    )
    return encoded['input_ids'].shape[-1]

In [None]:
from typing import Tuple
import pandas as pd
from lmformatenforcer import JsonSchemaParser, CharacterLevelParser, RegexParser, StringParser, generate_enforced


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        required_regex: str = None,
        required_str: str = None,
        required_json_schema: dict = None) -> str:
    """Generate one sampled reply to ``message``, optionally forcing its format.

    The prompt is built with ``get_prompt``, encoded with the module-level
    ``tokenizer``, moved to ``device`` and passed to the module-level ``model``.
    When a format constraint is given, generation goes through
    ``generate_enforced`` with the matching parser; otherwise plain
    ``model.generate`` is used.

    Args:
        message: New user message.
        chat_history: Earlier ``(user_input, response)`` turns.
        system_prompt: System prompt placed at the start of the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cut-off.
        required_regex: If truthy, constrain the output with ``RegexParser``.
        required_str: If truthy, constrain the output with ``StringParser``.
        required_json_schema: If truthy, constrain the output with
            ``JsonSchemaParser``.

        At most one constraint is expected. If several are supplied the
        precedence is JSON schema, then string, then regex (unchanged from the
        previous behaviour, where the last assignment won).

    Returns:
        The decoded completion. Prompt tokens are removed before decoding, so
        only the newly generated text is returned.

    Raises:
        RuntimeError: If ``model`` is ``None`` (the setup cell leaves it unset
            when CUDA is not available).
    """
    if model is None:
        raise RuntimeError(
            'The model is not loaded (no CUDA device was available during setup); cannot generate.'
        )

    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False, return_token_type_ids=False).to(device)
    # Remember where the prompt ends so it can be cut off the generated sequence.
    prompt_length = inputs['input_ids'].shape[-1]

    generate_kwargs = dict(
        inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        # Scores + dict output are requested so that diagnostics can be
        # attached to the result (see `enforced_scores` below).
        output_scores=True,
        return_dict_in_generate=True
    )

    # Pick the parser by precedence; an elif chain avoids constructing parsers
    # that would be discarded anyway.
    parser: CharacterLevelParser = None
    if required_json_schema:
        parser = JsonSchemaParser(required_json_schema)
    elif required_str:
        parser = StringParser(required_str)
    elif required_regex:
        parser = RegexParser(required_regex)

    if parser is not None:
        output = generate_enforced(model, tokenizer, parser, **generate_kwargs)
    else:
        output = model.generate(**generate_kwargs)

    # `decode` has no `skip_prompt` option (that flag belongs to the streamer
    # classes), so the prompt is removed explicitly by slicing the token ids.
    completion_ids = output.sequences[0][prompt_length:]
    string_output = tokenizer.decode(completion_ids, skip_special_tokens=True)

    # The diagnostics attribute is not guaranteed to exist on the result
    # (notably on the unconstrained path), so it is looked up defensively
    # instead of assuming it is there.
    enforced_scores = getattr(output, 'enforced_scores', None)
    if enforced_scores is not None:
        pd.set_option('display.width', 1000)
        pd.set_option('display.max_columns', 10)
        pd.set_option('display.max_rows', 200)
        print(enforced_scores)

    # NOTE: a streaming variant (a streamer consumed while `model.generate`
    # runs in a background Thread) was previously sketched here as
    # commented-out code; this function returns the full text in one piece.
    return string_output


In [None]:
# System prompt sent with every demo request.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
# Generation / input size limits for the demo.
MAX_MAX_NEW_TOKENS = 200
DEFAULT_MAX_NEW_TOKENS = 100
MAX_INPUT_TOKEN_LENGTH = 4000

# Regexes are raw strings so that `\d` is not an invalid string escape, and the
# dot is escaped so it only matches a literal decimal point / full stop instead
# of any character.
floating_point_regex = r'Michael Jordan was Born in (\d)+(\.\d+)?'
# The leading space is part of the pattern (kept exactly as before).
integer_regex = r' Michael Jordan was Born in (\d)+\.'
question = 'In what year was Michael Jordan Born? Please answer using only a number, without any prefix or suffix.'

# Alternative run without any format constraint (disabled):
# print("Without enforcing")
# result = run(question, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
# print(result)

# Active demo: the answer must match `integer_regex`.
print(f"With regex force. Regex: {integer_regex}")
result = run(
    question,
    chat_history=[],
    system_prompt=DEFAULT_SYSTEM_PROMPT,
    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    required_regex=integer_regex,
)
print(result)

# Alternative run forcing an exact string (disabled):
# print("With string force")
# result = run(question, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_str='The answer is 1963')
# print(result)


In [None]:
from pydantic import BaseModel


# Demo configuration, re-declared in this cell so it can be read on its own.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
MAX_MAX_NEW_TOKENS = 200
DEFAULT_MAX_NEW_TOKENS = 100
MAX_INPUT_TOKEN_LENGTH = 4000

# Regexes are raw strings so that `\d` is not an invalid string escape, and the
# dot is escaped so it only matches a literal decimal point / full stop instead
# of any character.
floating_point_regex = r'Michael Jordan was Born in (\d)+(\.\d+)?'
# The leading space is part of the pattern (kept exactly as before).
integer_regex = r' Michael Jordan was Born in (\d)+\.'
# The JSON schema text is appended to this question by the code that follows.
question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '

# Pydantic model describing the structure the answer must follow. Its JSON
# schema is appended to the question text and is also handed to `run` as
# `required_json_schema`.
# NOTE(review): documented with `#` comments instead of a docstring on purpose --
# presumably pydantic would copy a class docstring into the generated schema
# (as a description), which would change the prompt; confirm before adding one.
class AnswerFormat(BaseModel):
    # Two string fields followed by two integer fields; the declaration order
    # is left untouched.
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

# The question sent to the model carries the schema's JSON text right after it.
question_with_schema = f'{question}{AnswerFormat.schema_json()}'

# Active demo: generation constrained to the AnswerFormat JSON schema.
print("With json schema force force.")
result = run(
    question_with_schema,
    chat_history=[],
    system_prompt=DEFAULT_SYSTEM_PROMPT,
    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    required_json_schema=AnswerFormat.schema(),
)
print(result)

# NOTE(review): the following words are scratch notes whose purpose is not
# evident from this notebook -- consider removing them.
# abcdef
# michael
# and
# ing
# gratify
# ing

# Alternative run without the schema constraint (disabled):
# print(f"Without json schema force force.")
# result = run(question_with_schema, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
# print(result)

# Sample output fragments noted by the author:
# "first_name":"Michael", "last_name"

# { "first_name":"Michael", "last_name":".Jordan", "year_of_birth":1963, "num_seasons_in_nba":15}
