In [12]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
import transformers
from tokenizers import AddedToken
import torch

In [2]:
# Directory passed to from_pretrained() for both the model and the tokenizer.
# NOTE(review): absolute, machine-specific path — presumably a shared scratch
# volume; adjust when running on another machine.
model_name_or_path = '/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/'

In [3]:
# Load the causal-LM weights in half precision (torch.float16 -- not bfloat16)
# and let device_map='auto' decide where the weights are placed, then load the
# tokenizer from the same checkpoint directory.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map='auto',
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
# Register two scratch tokens to compare how the tokenizer records them:
# "<test1>" as a regular added token and "<test2>" as a special one.
# The final call is left as a bare expression so its return value is shown
# as the cell output.
regular_test_tokens = ["<test1>"]
special_test_tokens = ["<test2>"]
tokenizer.add_tokens(regular_test_tokens, special_tokens=False)
tokenizer.add_tokens(special_test_tokens, special_tokens=True)

0

In [88]:
# Definition of the "<|pause|>" token. Both lstrip and rstrip are enabled and
# single_word is disabled; `special` and `normalized` are deliberately not
# passed here and keep whatever defaults AddedToken applies.
# NOTE(review): the strip flags presumably make the token absorb adjacent
# whitespace so that inserting it does not change how neighbouring words are
# tokenized -- confirm against the tokenizers documentation.
pause_token = AddedToken(
    "<|pause|>",
    lstrip=True,
    rstrip=True,
    single_word=False,
)

In [94]:
# Register the pause token as a special token and print the tokenizer so the
# table of added tokens can be inspected.
tokenizer.add_tokens([pause_token], special_tokens=True)
print(tokenizer)
# Get the id assigned to the pause token. convert_tokens_to_ids falls back to
# the unknown-token id for strings that are not in the vocabulary, so fail
# loudly here rather than letting later cells silently compare against <unk>.
pause_token_id = tokenizer.convert_tokens_to_ids("<|pause|>")
assert pause_token_id != tokenizer.unk_token_id, "'<|pause|>' was not registered in the tokenizer"
print(pause_token_id)

LlamaTokenizerFast(name_or_path='/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<test1>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	32001: AddedToken("<test2>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32002: AddedToken("<pause>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)

In [95]:
# Resize the model's token embeddings to the tokenizer's current length
# (base vocabulary plus every added token), then print the module tree to
# check the new embedding size.
extended_vocab_size = len(tokenizer)
model.resize_token_embeddings(extended_vocab_size)
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32004, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Lin

In [96]:
# Show the id that was looked up for "<|pause|>" in the tokenizer.
print(pause_token_id)

32003


In [98]:
# Sanity check: inserting "<|pause|>" into a sentence must not change how the
# remaining text is tokenized. Encode the sentence with and without pause
# tokens, walk the pause-annotated ids, skip every pause id and require the
# rest to match the plain encoding one-to-one.
toks1 = tokenizer.encode('The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog', return_tensors='pt')
toks2 = tokenizer.encode('The quick brown fox jumps over the lazy dog', return_tensors='pt')
idx2 = 0  # cursor into the plain (pause-free) encoding
for idx1 in range(toks1.shape[1]):
    tok_with_pause = toks1[0, idx1].item()
    if tok_with_pause == pause_token_id:
        print('skipping pause token...')
        continue
    # Guard the cursor so surplus tokens fail as an assertion, not an IndexError.
    assert idx2 < toks2.shape[1], 'pause-annotated text produced extra non-pause tokens'
    tok_plain = toks2[0, idx2].item()
    print('w/o', tok_plain, tokenizer.decode([tok_plain]), 'w', tok_with_pause, tokenizer.decode([tok_with_pause]))
    assert tok_plain == tok_with_pause
    idx2 += 1
# Every plain token must have been matched; otherwise the pause-annotated
# encoding is missing tokens at the end and the loop above would not notice.
assert idx2 == toks2.shape[1], 'pause-annotated text is missing tokens from the plain encoding'

w/o 1 <s> w 1 <s>
w/o 450 The w 450 The
skipping pause token...
w/o 4996 quick w 4996 quick
skipping pause token...
w/o 17354 brown w 17354 brown
skipping pause token...
w/o 1701 fo w 1701 fo
w/o 29916 x w 29916 x
w/o 432 j w 432 j
w/o 17204 umps w 17204 umps
w/o 975 over w 975 over
w/o 278 the w 278 the
w/o 17366 lazy w 17366 lazy
w/o 11203 dog w 11203 dog


In [99]:
# Generate a continuation of the pause-annotated prompt.
# The encoded prompt is moved to the model's device (the model was loaded with
# device_map='auto', so it is not necessarily on CPU) and the attention mask is
# passed explicitly instead of leaving generate() to infer it.
prompt = 'The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog'
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
out = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    max_length=50,
    num_return_sequences=1,
    # No dedicated pad token is configured here, so the EOS id is used for padding.
    pad_token_id=tokenizer.eos_token_id,
)



In [100]:
# Render the first generated sequence as text with special tokens removed.
generated_ids = out[0]
tokenizer.decode(generated_ids, skip_special_tokens=True)

'The quick brown fox jumps over the lazy dog.\nThe quick brown fox jumps over the lazy dog.\nThe quick brown fox jumps over the lazy dog. The quick brown fox jumps over'

In [101]:
# Render the first generated sequence as text, keeping special tokens so the
# pause markers in the prompt remain visible.
generated_ids = out[0]
tokenizer.decode(generated_ids, skip_special_tokens=False)

'<s> The<|pause|> quick<|pause|> brown<|pause|> fox jumps over the lazy dog.\nThe quick brown fox jumps over the lazy dog.\nThe quick brown fox jumps over the lazy dog. The quick brown fox jumps over'