In [2]:
from transformers import AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the tokenizers for the three instruct/chat checkpoints under comparison.
# Each call resolves the checkpoint by its Hugging Face Hub identifier.
# NOTE(review): `phi` is not referenced by any cell visible in this export.
granite = AutoTokenizer.from_pretrained('ibm-granite/granite-3.1-8b-instruct')
llama = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')
phi = AutoTokenizer.from_pretrained('microsoft/phi-4')

In [59]:
import typing as t
# Sample three-turn conversation (system, user, assistant) used to exercise the masking logic.
# Stated goal: unmask everything BUT the system message.
msgs = [
    {
        "content": "Hello world!",
        "role": "system"
    },
    {
        "content": "Holey shit",
        "role": "user"
    },
    {
        "content": "Hello from the other side",
        "role": "assistant"
    }
]

# Placeholder character substituted for every message's content so that the
# content positions can later be located in the templated token sequence.
GLYPH = "𓀴"


def placeholder_msgs(msgs: t.List[t.Dict[str, str]]):
    """
    Build a copy of `msgs` in which each message keeps its "role" but has its
    "content" replaced by GLYPH. The input list and its dicts are not modified.
    """
    templated = []
    for message in msgs:
        templated.append({"role": message["role"], "content": GLYPH})
    return templated


input_ids = granite.apply_chat_template(msgs)

# Encode the raw contents WITHOUT special tokens: these IDs get spliced into the
# templated sequence, so a tokenizer-added BOS/EOS would corrupt the result.
individual_msgs = [granite.encode(msg["content"], add_special_tokens=False) for msg in msgs]
placeholder_ids = granite.apply_chat_template(placeholder_msgs(msgs))


# print(individual_msgs)
# print(input_ids)
# print(placeholder_ids)

# Token IDs of the bare glyph; special tokens are excluded so the sequence can be
# matched verbatim inside the templated IDs.
glyph_id = granite.encode(GLYPH, add_special_tokens=False)
print(glyph_id)
assert glyph_id, "the glyph must encode to at least one token"

# first attempt: naive linear scan for every occurrence of `glyph_id` in `placeholder_ids`;
# each hit is recorded as a half-open (start, end) index pair
ranges = []

i = 0
# The scan indexes `placeholder_ids`, so it must be bounded by ITS length
# (not by len(input_ids), which differs as soon as the real contents are not glyph-sized).
while i < len(placeholder_ids):
    # look to start substring matching
    if placeholder_ids[i] == glyph_id[0]:
        print(f'potentially found a glyph id match starting at {i=}')
        j = i
        k = 0
        matching = True
        while k < len(glyph_id) and j < len(placeholder_ids):
            # keep looking to see how far we can match against the glyph IDs
            if placeholder_ids[j] != glyph_id[k]:
                print(f'but unfortunately, found that at {k=}, {j=}, {placeholder_ids[j]=} != {glyph_id[k]=}')
                matching = False
                break

            j += 1
            k += 1

        # the whole glyph sequence matched (k == len(glyph_id) also rules out
        # running off the end of `placeholder_ids` mid-match)
        if k == len(glyph_id) and matching:
            # placeholder_ids[i:j] is exactly one occurrence of the glyph
            ranges.append((i, j))

            # resume scanning AT j; `continue` skips the `i += 1` below so that
            # the token at index j is still considered as a possible match start
            i = j
            continue
    i += 1

# every message must have produced exactly one placeholder occurrence
assert len(ranges) == len(msgs)


# next, let's collect the ranges and decode them to see what each one covers
pieces = [granite.decode(placeholder_ids[i:j]) for i, j in ranges]
# print(pieces)
# print('=---')
# print(granite.decode(input_ids))


# The algorithm looks like this:
# 1. given a list of messages, create a template list of messages whose contents are replaced with a glyph
# 2. tokenize the glyph messages through the chat template
# 3. in the tokenized list, identify the ranges where the glyph occurs; these ranges are to be replaced
#    with the tokenized content of the corresponding original message
# 4. knowing where the message ranges are, unmask according to our policy:
#   1. build the labels as a copy of the input IDs that is masked (-100) everywhere except where we want it unmasked
#   2. when unmasking a particular message, if the tokenizer has an EOS token, assert that it is the last token



# the algorithm: walk `placeholder_ids`, copying template tokens as masked labels (-100)
# and swapping each glyph range for the real tokenized message content (left unmasked).
# NOTE(review): `ranges` is consumed by this loop (rebound to shorter slices), so it is
# empty afterwards; re-running this part alone without recomputing `ranges` masks everything.
final_input_ids = []
final_labels = []

# cursor into `placeholder_ids`
j = 0


while j < len(placeholder_ids):
    # no placeholder ranges left to process
    if not ranges:
        # just append everything else to the end, fully masked
        final_input_ids.extend(placeholder_ids[j:])
        final_labels.extend([-100] * len(placeholder_ids[j:]))
        break
    
    # always work on the earliest remaining range
    start_idx, end_idx = ranges[0]
    if j < start_idx:
        # default case: a template token before the next placeholder; copy it through
        final_input_ids.append(placeholder_ids[j])
        final_labels.append(-100)   # mask this out, we don't care about it
        j += 1
        continue
    else:
        # otherwise, we now must insert the tokenized message content. We select it via:
        # (number of messages) - (number of ranges still pending) == index of the current message,
        # which relies on len(ranges) == len(msgs) having held before the loop started
        msg_idx = len(individual_msgs) - len(ranges)  # this should always select the correct message
        msg = individual_msgs[msg_idx]

        # append the message tokens to the input IDs and, unmasked, to the labels
        # NOTE(review): this happens for every role, so the system message is unmasked too,
        # which does not match the stated goal of unmasking everything but the system message.
        final_input_ids.extend(msg)
        final_labels.extend(msg)

        # jump the cursor past the placeholder and drop the range we just handled
        j = end_idx
        ranges = ranges[1:]

        # we want to also unmask the EOS token if it is present
        # NOTE(review): `placeholder_ids[j]` raises IndexError if the placeholder was the
        # very last thing in `placeholder_ids` (j == len(placeholder_ids)).
        print(f"after extending, {j=} is set to {placeholder_ids[j]=}")

        if granite.eos_token_id is not None:
            print('detected eos token id, proceeding forward')
            suffix_start_j = j
            # advance to the first EOS token at or after the end of the placeholder
            while j < len(placeholder_ids) and placeholder_ids[j] != granite.eos_token_id:
                j += 1
            
            if j >= len(placeholder_ids) or placeholder_ids[j] != granite.eos_token_id:
                raise RuntimeError('failed to find the trailing EOS token id')
            
            # by now we know that we are both within range and have found the trailing eos token id;
            # everything from the end of the placeholder up to and including EOS is kept unmasked
            final_input_ids.extend(placeholder_ids[suffix_start_j:j+1])
            final_labels.extend(placeholder_ids[suffix_start_j:j+1])
            j += 1

        


print(final_input_ids)
print(final_labels)

# sanity checks: the rebuilt sequence must decode to the same text as templating the real
# messages directly, and labels must line up one-to-one with the input IDs
print(granite.decode(final_input_ids))
assert granite.decode(final_input_ids) == granite.decode(input_ids)
assert len(final_labels) == len(final_input_ids)

print(granite.encode(granite.eos_token))

# keep shallow copies of this cell's result (referenced by commented-out asserts in a later cell)
_final_input_ids = final_input_ids[:]
_final_labels = final_labels[:]

[189, 246, 227, 131]
potentially found a glyph id match starting at i=3
potentially found a glyph id match starting at i=12
potentially found a glyph id match starting at i=21
after extending, j=7 is set to placeholder_ids[j]=0
detected eos token id, proceeding forward
after extending, j=16 is set to placeholder_ids[j]=0
detected eos token id, proceeding forward
after extending, j=25 is set to placeholder_ids[j]=0
detected eos token id, proceeding forward
[49152, 2946, 49153, 8279, 5788, 19, 0, 203, 49152, 496, 49153, 42808, 107, 787, 283, 0, 203, 49152, 17594, 49153, 8279, 645, 322, 1604, 5209, 0, 203]
[-100, -100, -100, 8279, 5788, 19, 0, -100, -100, -100, -100, 42808, 107, 787, 283, 0, -100, -100, -100, -100, 8279, 645, 322, 1604, 5209, 0, -100]
<|start_of_role|>system<|end_of_role|>Hello world!<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Holey shit<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Hello from the other side<|end_of_text|>

[0]


In [52]:
# Inspect which token id the Granite tokenizer reports as its EOS token.
granite.eos_token_id

0

In [49]:
# Scratch check of list slicing: the end index of a slice is exclusive.
test = [1, 2, 3]
test[:2]

[1, 2]

In [89]:
from transformers import PreTrainedTokenizer

# NOTE: GLYPH and placeholder_msgs are redefined here with the same behavior as in the
# earlier cell; this later definition is the one in effect once this cell has run.
GLYPH = "𓀴"


def placeholder_msgs(msgs: t.List[t.Dict[str, str]]):
    """
    Return a new list mirroring `msgs` where every message keeps its "role"
    and carries GLYPH as its "content".
    """
    return [dict(role=entry["role"], content=GLYPH) for entry in msgs]


# NOTE: this is the same exploratory range scan as in the earlier cell, repeated at the
# top of this cell; it runs before the function definitions further down.
input_ids = granite.apply_chat_template(msgs)

# Encode the raw contents WITHOUT special tokens: these IDs get spliced into the
# templated sequence, so a tokenizer-added BOS/EOS would corrupt the result.
individual_msgs = [granite.encode(msg["content"], add_special_tokens=False) for msg in msgs]
placeholder_ids = granite.apply_chat_template(placeholder_msgs(msgs))


# print(individual_msgs)
# print(input_ids)
# print(placeholder_ids)

# Token IDs of the bare glyph; special tokens are excluded so the sequence can be
# matched verbatim inside the templated IDs.
glyph_id = granite.encode(GLYPH, add_special_tokens=False)
print(glyph_id)
assert glyph_id, "the glyph must encode to at least one token"

# naive linear scan for every occurrence of `glyph_id` in `placeholder_ids`;
# each hit is recorded as a half-open (start, end) index pair
ranges = []

i = 0
# The scan indexes `placeholder_ids`, so it must be bounded by ITS length
# (not by len(input_ids), which differs as soon as the real contents are not glyph-sized).
while i < len(placeholder_ids):
    # look to start substring matching
    if placeholder_ids[i] == glyph_id[0]:
        print(f'potentially found a glyph id match starting at {i=}')
        j = i
        k = 0
        matching = True
        while k < len(glyph_id) and j < len(placeholder_ids):
            # keep looking to see how far we can match against the glyph IDs
            if placeholder_ids[j] != glyph_id[k]:
                print(f'but unfortunately, found that at {k=}, {j=}, {placeholder_ids[j]=} != {glyph_id[k]=}')
                matching = False
                break

            j += 1
            k += 1

        # the whole glyph sequence matched (k == len(glyph_id) also rules out
        # running off the end of `placeholder_ids` mid-match)
        if k == len(glyph_id) and matching:
            # placeholder_ids[i:j] is exactly one occurrence of the glyph
            ranges.append((i, j))

            # resume scanning AT j; `continue` skips the `i += 1` below so that
            # the token at index j is still considered as a possible match start
            i = j
            continue
    i += 1

# every message must have produced exactly one placeholder occurrence
assert len(ranges) == len(msgs)


# next, lets collect the ranges and decode them to see what we get 
# print(pieces)
# print('=---')
# print(granite.decode(input_ids))


# basically the algorithm will look like this:
# 1. given some list of messages, create a template set of messages with the contents replaced with a glyph
# 2. tokenize the glyph messages and identify the portions in the message where the glyph exists
# 3. with the tokenized list, identify the ranges where the glyph exists. We will want to replace these ranges with tokenized copies of each message
# 4. with the knowledge of where the new message ranges are, we can now unmask according to our policy
#   1. create a copy of the input IDs and leave the portions masked (-100) except for where we expect them to be unmasked
#   2. when unmasking a particular message, if the tokenizer has an EOS token, assert that it is last token 


def get_placeholder_ranges(placeholder_ids: t.List[int], tokenizer: PreTrainedTokenizer):
    """
    Find every occurrence of the tokenized GLYPH inside `placeholder_ids`.

    Args:
        placeholder_ids: token IDs to scan.
        tokenizer: tokenizer used to turn GLYPH into the token sequence to search for.
            It is called as `tokenizer.encode(GLYPH, add_special_tokens=False)`.

    Returns:
        A list of `(start, end)` index pairs in ascending order. Each pair is half-open,
        i.e. `placeholder_ids[start:end]` equals the glyph's token IDs. Matches do not overlap.

    Raises:
        ValueError: if the tokenizer encodes GLYPH to an empty sequence (nothing to search for).
    """
    glyph_id = tokenizer.encode(GLYPH, add_special_tokens=False)  # we want to ignore special tokens since we're just extracting the token IDs here
    if not glyph_id:
        # without this guard `glyph_id[0]` below would fail with an opaque IndexError
        raise ValueError('the placeholder glyph encoded to an empty token sequence')

    ranges = []
    i = 0
    while i < len(placeholder_ids):
        # look to start substring matching
        if placeholder_ids[i] == glyph_id[0]:
            print(f'potentially found a glyph id match starting at {i=}')
            j = i
            k = 0
            matching = True
            while k < len(glyph_id) and j < len(placeholder_ids):
                # keep looking to see how far we can match against the glyph IDs
                if placeholder_ids[j] != glyph_id[k]:
                    print(f'but unfortunately, found that at {k=}, {j=}, {placeholder_ids[j]=} != {glyph_id[k]=}')
                    matching = False
                    break

                j += 1
                k += 1

            # the full glyph matched; `k == len(glyph_id)` also rejects a partial match
            # that was cut short by reaching the end of `placeholder_ids`
            if k == len(glyph_id) and matching:
                # placeholder_ids[i:j] is exactly one occurrence of the glyph
                ranges.append((i, j))

                # resume scanning AT j. The `continue` skips the `i += 1` below; without it
                # the token at index j was never examined as a potential match start, so a
                # glyph immediately following another glyph would be missed.
                i = j
                continue
        i += 1

    return ranges



def unmask_messages(msgs: t.List[t.Dict[str, str]], tokenizer: PreTrainedTokenizer) -> t.Dict[str, t.List[int]]:
    """
    Given a list of messages and an arbitrary tokenizer, returns a dictionary with
    `input_ids` and `labels` containing the correct masking.

    The chat template is rendered once with every message content replaced by a
    placeholder glyph. Each placeholder range is then swapped for the real tokenized
    content. Template tokens are masked with -100 in `labels`; message content is
    left unmasked, as are (when the tokenizer defines an EOS token) all tokens from
    the end of the content up to and including the next EOS token.

    Args:
        msgs: chat messages, each a dict with "role" and "content" keys.
        tokenizer: tokenizer providing `apply_chat_template`, `encode` and `eos_token_id`.

    Returns:
        A dict with two equally long lists: "input_ids" and "labels".

    Raises:
        RuntimeError: if the number of placeholders found in the rendered template does
            not match the number of messages, if the tokenizer defines an EOS token but
            none follows a message, or if the search for a message's EOS ran over the
            next message's placeholder.

    NOTE(review): every role is currently unmasked, including "system".
    """

    # first we need to create the placeholder IDs
    placeholder_ids = tokenizer.apply_chat_template(placeholder_msgs(msgs))
    ranges = get_placeholder_ranges(placeholder_ids, tokenizer)
    if len(ranges) != len(msgs):
        # messages are paired with ranges by position; a count mismatch would silently
        # splice the wrong content into the wrong slot
        raise RuntimeError(
            f'expected {len(msgs)} placeholder ranges (one per message) but found {len(ranges)}'
        )
    individual_msgs = [tokenizer.encode(m["content"], add_special_tokens=False) for m in msgs]  # no special tokens here since we are looking to inject these into a broader template

    # use the EOS id of the tokenizer that was passed in, never a notebook-level global
    eos_token_id = tokenizer.eos_token_id

    final_input_ids = []
    final_labels = []

    # cursor into `placeholder_ids`: everything before it has already been emitted
    j = 0
    for msg, (start_idx, end_idx) in zip(individual_msgs, ranges):
        if j > start_idx:
            # the previous EOS search ran past this placeholder, meaning glyph tokens
            # were already copied into the output; fail loudly instead of returning garbage
            raise RuntimeError('failed to find the trailing EOS token id')

        # template tokens preceding the placeholder: copied through and masked out
        prefix = placeholder_ids[j:start_idx]
        final_input_ids.extend(prefix)
        final_labels.extend([-100] * len(prefix))

        # the real message content replaces the placeholder and stays unmasked
        final_input_ids.extend(msg)
        final_labels.extend(msg)

        # continue right after the placeholder
        j = end_idx

        # debug output; guarded because the placeholder may be the very last token
        if j < len(placeholder_ids):
            print(f"after extending, {j=} is set to {placeholder_ids[j]=}")

        # we want to also unmask the EOS token if it is present
        if eos_token_id is not None:
            print('detected eos token id, proceeding forward')
            suffix_start_j = j
            while j < len(placeholder_ids) and placeholder_ids[j] != eos_token_id:
                j += 1

            if j >= len(placeholder_ids):
                raise RuntimeError('failed to find the trailing EOS token id')

            # by now we know that we are both within range and have found the trailing eos token id;
            # keep everything between the content and the EOS (inclusive) unmasked
            final_input_ids.extend(placeholder_ids[suffix_start_j:j + 1])
            final_labels.extend(placeholder_ids[suffix_start_j:j + 1])
            j += 1

    # whatever the template emits after the last message is copied through, masked
    tail = placeholder_ids[j:]
    final_input_ids.extend(tail)
    final_labels.extend([-100] * len(tail))

    return {
        "input_ids": final_input_ids,
        "labels": final_labels
    }

# Same three-turn sample conversation as in the earlier cell, redefined here.
msgs = [
    {"content": "Hello world!", "role": "system"},
    {"content": "Holey shit", "role": "user"},
    {"content": "Hello from the other side", "role": "assistant"},
]

GLYPH = "𓀴"
# encode() is called with its defaults here, so the printed IDs include whatever
# special tokens the Llama tokenizer adds on its own
print(f"encoded glyph: {llama.encode(GLYPH)}")

# run the function-based implementation against the Llama tokenizer
results = unmask_messages(msgs, llama)

# assert results["input_ids"] == _final_input_ids
# assert results["labels"] == _final_labels
print(results["labels"])
# print(results["input_ids"])

# print(llama.apply_chat_template(placeholder_msgs(msgs)))


# decode only the label positions that are not masked (-100) to eyeball what is unmasked
unmasked = list(filter(lambda label: label != -100, results["labels"]))
print(llama.decode(unmasked))




[189, 246, 227, 131]
potentially found a glyph id match starting at i=3
potentially found a glyph id match starting at i=12
potentially found a glyph id match starting at i=21
encoded glyph: [128000, 172, 241, 222, 112]
potentially found a glyph id match starting at i=25
potentially found a glyph id match starting at i=34
potentially found a glyph id match starting at i=43
after extending, j=29 is set to placeholder_ids[j]=128009
detected eos token id, proceeding forward
after extending, j=38 is set to placeholder_ids[j]=128009
detected eos token id, proceeding forward
after extending, j=47 is set to placeholder_ids[j]=128009
detected eos token id, proceeding forward
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 9906, 1917, 0, 128009, -100, -100, -100, -100, 39, 50099, 17619, 128009, -100, -100, -100, -100, 9906, 505, 279, 1023, 3185, 128009]
Hello world!<|eot_id|>Holey shit<|eot_id

In [73]:
# Dump the placeholder-templated conversation one token per line: id --> decoded text.
for tok in llama.apply_chat_template(placeholder_msgs(msgs)):
    print(f'{tok} --> {llama.decode(tok)!r}')

128000 --> '<|begin_of_text|>'
128006 --> '<|start_header_id|>'
9125 --> 'system'
128007 --> '<|end_header_id|>'
271 --> '\n\n'
38766 --> 'Cut'
1303 --> 'ting'
33025 --> ' Knowledge'
2696 --> ' Date'
25 --> ':'
6790 --> ' December'
220 --> ' '
2366 --> '202'
18 --> '3'
198 --> '\n'
15724 --> 'Today'
2696 --> ' Date'
25 --> ':'
220 --> ' '
1627 --> '26'
10263 --> ' Jul'
220 --> ' '
2366 --> '202'
19 --> '4'
271 --> '\n\n'
172 --> '�'
241 --> '�'
222 --> '�'
112 --> '�'
128009 --> '<|eot_id|>'
128006 --> '<|start_header_id|>'
882 --> 'user'
128007 --> '<|end_header_id|>'
271 --> '\n\n'
172 --> '�'
241 --> '�'
222 --> '�'
112 --> '�'
128009 --> '<|eot_id|>'
128006 --> '<|start_header_id|>'
78191 --> 'assistant'
128007 --> '<|end_header_id|>'
271 --> '\n\n'
172 --> '�'
241 --> '�'
222 --> '�'
112 --> '�'
128009 --> '<|eot_id|>'


In [83]:
# Round-trip check for the glyph with the Llama tokenizer: decode the four IDs seen in the
# token dump, and encode the glyph with special tokens disabled, to compare the two.
llama.decode([172, 241, 222, 112]), llama.encode(GLYPH, add_special_tokens=False)

('𓀴', [172, 241, 222, 112])

[128000, 13347, 11, 358, 2846, 7043]