In [1]:
from llama.llama.tokenizer import Tokenizer as OriginalTokenizer
from transformers import LlamaTokenizer as HFTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from typing import Literal, List, TypedDict, Optional
import json
import numpy as np
import os

MAIN_DIR = ".."
with open(os.path.join(MAIN_DIR, "auth", "api_keys.json"), "r") as f:
    api_keys = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm


# Original Llama implementation

In [6]:
Role = Literal["system", "user", "assistant"]

class Message(TypedDict):
    role: Role
    content: str

class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required

Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

In [7]:
llama_tokenizer = OriginalTokenizer(
    os.path.join(MAIN_DIR, "model", "original-llama-tokenizer", "tokenizer.model")
    )

sentencepiece_vocab = {
    id : llama_tokenizer.sp_model.id_to_piece(id)
    for id in range(llama_tokenizer.sp_model.get_piece_size())
}

In [87]:
max_gen_len = 256

unsafe_requests = []

dialog = [
    Message(role= "system", content="This is a system message"),
    Message(role= "user", content="What's 1 + 1?"),
    Message(role= "assistant", content="2"),
    Message(role = "user", content="Are you sure?")
    ]

unsafe_requests.append(
    any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
)
if dialog[0]["role"] == "system":
    dialog = [
        {
            "role": dialog[1]["role"],
            "content": B_SYS
            + dialog[0]["content"]
            + E_SYS
            + dialog[1]["content"],
        }
    ] + dialog[2:]
assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
    [msg["role"] == "assistant" for msg in dialog[1::2]]
), (
    "model only supports 'system', 'user' and 'assistant' roles, "
    "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
)
dialog_tokens: List[int] = sum(
    [
        llama_tokenizer.encode(
            f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
            bos=True,
            eos=True,
        )
        for prompt, answer in zip(
            dialog[::2],
            dialog[1::2],
        )
    ],
    [],
)

assert (
    dialog[-1]["role"] == "user"
), f"Last message must be from user, got {dialog[-1]['role']}"

dialog_tokens += llama_tokenizer.encode(
    f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
    bos=True,
    eos=False,
)

AssertionError: Last message must be from user, got assistant

In [62]:
print("No of tokens: ", len(dialog_tokens))
print("".join([sentencepiece_vocab[id] for id in dialog_tokens]))

No of tokens:  51
<s>▁[INST]▁<<SYS>><0x0A>This▁is▁a▁system▁message<0x0A><</SYS>><0x0A><0x0A>What's▁1▁+▁1?▁[/INST]▁2▁</s><s>▁[INST]▁Are▁you▁sure?▁[/INST]


# HuggingFace Tokenizer

- **With System**:\<s> [INST] <\<SYS>>\nSystem message\n<\</SYS>>\n\n**User1** [/INST] **Assistant1** \</s>\<s> [INST] **User2** [/INST]``

- **Without System**: \<s> [INST] **User1** [/INST] **Assistant1** \</s>\<s> [INST] **User2** [/INST]

In [76]:
hf_tokenizer = HFTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", token=api_keys["HF_API_KEY"], cache_dir=os.path.join(MAIN_DIR, "model"),
    add_bos_token = False, add_eos_token = False
    )

In [82]:
def create_llama2_chat_prompt(
    messages: List[str], tokenizer: PreTrainedTokenizer, system_prompt: Optional[str] = None
) -> str:
    prompt_str = ""

    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    BOS_TOKEN, EOS_TOKEN = tokenizer.bos_token, tokenizer.eos_token

    if system_prompt:
        messages[0] = B_SYS + system_prompt + E_SYS + messages[0]
        
    for user_message, assistant_message in zip(messages[::2], messages[1::2]):
        user_message = BOS_TOKEN + " " + B_INST + " " + user_message + " " + E_INST
        assistant_message = " " + assistant_message + " " + EOS_TOKEN
        
        prompt_str += user_message
        prompt_str += assistant_message

    if (len(messages) % 2) == 1:       
        user_query = BOS_TOKEN + " " + B_INST + " " + messages[-1] + " " + E_INST
        prompt_str += user_query
    
    return prompt_str.strip()

ALPACA_INSTRUCTION_COMPLETION_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
"""

ALPACA_INSTRUCTION_TEMPLATE_NO_INPUT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""

def create_llama2_instruction_prompt(
    instruction: str,
    input: Optional[str] = None,
    output: Optional[str] = None,
    prompt_template: str = ALPACA_INSTRUCTION_COMPLETION_TEMPLATE
) -> str:
    if not input:
        prompt_template = ALPACA_INSTRUCTION_TEMPLATE_NO_INPUT
        prompt_str = prompt_template.format(instruction = instruction)
        
    else:
        prompt_str = prompt_template.format(instruction=instruction, input=input)
    
    if output:
        prompt_str += f"\n{output}"
            
    return prompt_str

In [88]:
system_prompt = "This is a system message"
messages = [
    "What's 1 + 1?",
    "2",
    "Are you sure?",
]

prompt_str = create_llama2_chat_prompt(messages, hf_tokenizer, system_prompt)
hf_tokens = hf_tokenizer.encode(prompt_str)

In [85]:
# Check
np.array_equal(np.array(hf_tokens), np.array(dialog_tokens))

False