In [1]:
from jinja2 import Template

# ChatML-style Jinja template: one system turn, then one turn per message.
# Escaping: '\\n' in this (non-raw) Python string reaches Jinja as the escape
# sequence '\n', which the template assigns to `nl`.
# Every tag uses `-` whitespace control, so the indentation and line breaks of
# the template source itself never reach the rendered text.
# 'human' turns are split on '<image>'; each placeholder is re-emitted followed
# by `nl`. Messages whose `from` is neither 'human' nor 'gpt' render nothing.
# NOTE(review): a turn whose value is none is still closed with <|im_end|> here,
# whereas preprocess_qwen (defined in a later cell) emits only the role header
# for such a turn — confirm which behavior is intended.
CHAT_TEMPLATE = """
{%- set nl = '\\n' -%}
{%- set im_start = '<|im_start|>' -%}
{%- set im_end = '<|im_end|>' -%}
{{- im_start -}}system{{- nl -}}{{- system_message -}}{{- im_end -}}{{- nl -}}
{%- for message in messages -%}
    {%- if message.from == 'human' -%}
        {{- im_start -}}user{{- nl -}}
        {%- if message.value is not none -%}
            {%- set parts = message.value.split('<image>') -%}
            {%- for part in parts -%}
                {{- part -}}
                {%- if not loop.last -%}<image>{{- nl -}}{%- endif -%}
            {%- endfor -%}
        {%- endif -%}
        {{- im_end -}}{{- nl -}}
    {%- elif message.from == 'gpt' -%}
        {{- im_start -}}assistant{{- nl -}}
        {{- message.value if message.value is not none else '' -}}
        {{- im_end -}}{{- nl -}}
    {%- endif -%}
{%- endfor -%}
"""

# Minimal two-turn sample: a user message with one image placeholder and a reply.
# `conversation` and `system_message` are reused by later cells.
conversation = [
        {"from": "human", "value": "<image>Hello!"},
        {"from": "gpt", "value": "Hi there!"}
]
system_message = "Test system message"
# Compile the template once, then render it for the sample conversation.
template = Template(CHAT_TEMPLATE)
rendered = template.render(
    system_message=system_message,
    messages=conversation
)
# print (rather than a bare expression) so the newlines show as line breaks.
print(rendered)

<|im_start|>system
Test system message<|im_end|>
<|im_start|>user
<image>
Hello!<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>



In [19]:
from transformers import AutoTokenizer
import torch, re
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

def preprocess_qwen(sources, tokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant."):
    """Tokenize one conversation into a ChatML-style id sequence.

    The sequence is a system turn followed by one turn per message, each laid
    out as ``<role header>\\n<content><im_end>\\n``.

    Args:
        sources: List of turns; each is a dict with a ``"from"`` key
            (``"human"`` or ``"gpt"``) and a ``"value"`` key (str or None).
            A leading turn that is not from ``"human"`` is dropped. An empty
            list yields the system turn only.
        tokenizer: Callable such that ``tokenizer(text).input_ids`` is a list
            of ints. Its ``additional_special_tokens_ids`` must unpack into
            exactly two ids, used as (im_start, im_end).
        has_image: When True, every ``DEFAULT_IMAGE_TOKEN`` placeholder in a
            turn becomes ``IMAGE_TOKEN_INDEX`` followed by a newline.
        max_len: Accepted for signature compatibility; no truncation is
            applied in this function.
        system_message: Text of the system turn.

    Returns:
        A ``torch.long`` tensor of shape ``(1, sequence_length)``.

    Raises:
        KeyError: If a turn (other than a dropped leading one) has a
            ``"from"`` value other than ``"human"`` or ``"gpt"``.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    im_start, im_end = tokenizer.additional_special_tokens_ids
    nl_tokens = tokenizer("\n").input_ids

    # Drop a leading non-human turn; guard against an empty conversation and
    # compare the raw role so an unknown leading role is skipped, not a KeyError.
    source = sources
    if source and source[0]["from"] != "human":
        source = source[1:]

    # System turn: <im_start>system\n{system_message}<im_end>\n
    input_id = (
        [im_start]
        + tokenizer("system").input_ids
        + nl_tokens
        + tokenizer(system_message).input_ids
        + [im_end]
        + nl_tokens
    )

    for sentence in source:
        header = tokenizer(roles[sentence["from"]]).input_ids + nl_tokens
        value = sentence["value"]

        if value is None:
            # A turn without content is left open: role header only, no
            # <im_end>. Presumably used as a generation prompt — confirm.
            input_id += header
            continue

        if has_image and DEFAULT_IMAGE_TOKEN in value:
            # Tokenize the text around each placeholder and put the image
            # sentinel (plus a newline) where the placeholder was.
            parts = value.split(DEFAULT_IMAGE_TOKEN)
            body = []
            for idx, part in enumerate(parts):
                body += tokenizer(part).input_ids
                if idx < len(parts) - 1:
                    body += [IMAGE_TOKEN_INDEX] + nl_tokens
        else:
            body = tokenizer(value).input_ids

        turn_ids = header + body + [im_end] + nl_tokens
        if has_image:
            # Sanity check: one sentinel per placeholder in the source text.
            assert turn_ids.count(IMAGE_TOKEN_INDEX) == value.count(DEFAULT_IMAGE_TOKEN)
        input_id += turn_ids

    return torch.tensor([input_id], dtype=torch.long)

# Load the Pangea tokenizer, then register the LLaVA marker tokens as special
# tokens — three separate add_tokens calls, in this fixed order.
tokenizer = AutoTokenizer.from_pretrained('neulab/Pangea-7B')
for marker_group in (
    [DEFAULT_IMAGE_PATCH_TOKEN],
    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
    [DEFAULT_IMAGE_TOKEN],
):
    tokenizer.add_tokens(marker_group, special_tokens=True)

# Tokenize the sample conversation from the template cell with image handling on.
input_ids = preprocess_qwen(
    conversation,
    tokenizer,
    system_message=system_message,
    has_image=True,
)

In [20]:
# Display the ids that preprocess_qwen produced for the sample conversation.
input_ids

tensor([[151644,   8948,    198,   2271,   1849,   1943, 151645,    198, 151644,
            872,    198,   -200,    198,   9707,      0, 151645,    198, 151644,
          77091,    198,  13048,   1052,      0, 151645,    198]])

In [21]:
# Vocabulary size of the tokenizer; this cell's execution count (21) follows
# the setup cell (19) that registers the extra marker tokens.
len(tokenizer.vocab)

151650

In [25]:
# Tokenize the rendered template text and map the vocabulary id of
# DEFAULT_IMAGE_TOKEN back to the IMAGE_TOKEN_INDEX sentinel, so the result
# can be compared with the ids produced by preprocess_qwen.
image_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
# Rewrite every occurrence: list.index would only find the first image token
# and raises ValueError when the text contains none.
new_input_ids = [
    IMAGE_TOKEN_INDEX if token_id == image_token_id else token_id
    for token_id in tokenizer(rendered)["input_ids"]
]
# True when the template reproduces preprocess_qwen's tokenization exactly.
new_input_ids == input_ids[0].tolist()

True

In [13]:
# NOTE(review): duplicate of the earlier vocabulary-size cell. Its execution
# count (13) predates the tokenizer setup cell (19) and the recorded value
# differs from the earlier cell's, so this output looks stale — re-run
# top-to-bottom or remove.
len(tokenizer.vocab)

151649

In [16]:
# Decode the first 11 ids; per the recorded output this is the system turn
# plus the opening user header.
# NOTE(review): 11 is hard-coded for this sample, and the execution count (16)
# predates the cell that builds input_ids (19) — re-run to confirm.
print(tokenizer.decode(input_ids[0][:11]))

<|im_start|>system
Test system message<|im_end|>
<|im_start|>user

