In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
import transformers
from tokenizers import AddedToken
import torch
from torch import nn

In [2]:
class LinearWrapper(nn.Module):
    def __init__(self, layer: nn.Linear, num_embeddings: int):
        super().__init__()
        self.layer = layer
        self.num_embeddings = num_embeddings
        self.n_new_tokens = num_embeddings - layer.out_features
        self.new_embeddings = nn.Linear(layer.in_features, self.n_new_tokens)
        self.new_embeddings.to(layer.weight.device).to(layer.weight.dtype)
    
    def forward(self, x):
        z1 = self.layer(x)
        z2 = self.new_embeddings(x)
        return torch.cat([z1, z2], dim=-1)

class EmbeddingWrapper(nn.Module):
    def __init__(self, embedding: nn.Embedding, num_embeddings: int):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.num_embeddings = num_embeddings
        self.n_new_tokens = num_embeddings - embedding.num_embeddings
        # inspired from here 
        # https://github.com/huggingface/transformers/blob/185463784e0a0b4cd7974ce5bded7a52ae170f6d/src/transformers/modeling_utils.py#L2026
        self.old_embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.old_embeddings.weight.data = torch.zeros_like(self.old_embeddings.weight.data)
        self.old_embeddings.weight.data[:embedding.num_embeddings] = embedding.weight.data
        self.old_embeddings.to(embedding.weight.device).to(embedding.weight.dtype)
        self.new_embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.new_embeddings.weight.data[:embedding.num_embeddings] = torch.zeros_like(embedding.weight.data)
        self.new_embeddings.to(embedding.weight.device).to(embedding.weight.dtype)

    
    def forward(self, x):
        return self.old_embeddings(x) + self.new_embeddings(x)

In [3]:
model_name_or_path = '/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/'
# load model in torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map='auto', torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
pause_token = AddedToken("<|pause|>", 
                         single_word=False, 
                         lstrip=True, 
                         rstrip=True)
                         #special=True, 
                         #normalized=False)

In [5]:
tokenizer.add_tokens([pause_token], special_tokens=True)
print(tokenizer)
# get idx of pause otken
pause_token_id = tokenizer.convert_tokens_to_ids("<|pause|>")
print(pause_token_id)

LlamaTokenizerFast(name_or_path='/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<|pause|>", rstrip=True, lstrip=True, single_word=False, normalized=False, special=True),
}
32000


In [6]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [7]:
## conventionally you'd do this like this:
#model.resize_token_embeddings(len(tokenizer))

## ours
model.model.embed_tokens = EmbeddingWrapper(model.model.embed_tokens, len(tokenizer))
model.lm_head = LinearWrapper(model.lm_head, len(tokenizer))
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): EmbeddingWrapper(
      (old_embeddings): Embedding(32001, 4096)
      (new_embeddings): Embedding(32001, 4096)
    )
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post

In [8]:
print(pause_token_id)

32000


In [9]:
# test
toks1 = tokenizer.encode('The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog', return_tensors='pt')
toks2 = tokenizer.encode('The quick brown fox jumps over the lazy dog', return_tensors='pt')
idx2 = 0
for idx1 in range(len(toks1[0])):
    if toks1[0, idx1].item() != pause_token_id:
        print('w/o', toks2[0, idx2].item(), tokenizer.decode([toks2[0, idx2]]), 'w', toks1[0, idx1].item(), tokenizer.decode([toks1[0, idx1]]))
        assert toks2[0, idx2] == toks1[0, idx1]
        idx2 += 1
    else:
        print('skipping pause token...')

w/o 1 <s> w 1 <s>
w/o 450 The w 450 The
skipping pause token...
w/o 4996 quick w 4996 quick
skipping pause token...
w/o 17354 brown w 17354 brown
skipping pause token...
w/o 1701 fo w 1701 fo
w/o 29916 x w 29916 x
w/o 432 j w 432 j
w/o 17204 umps w 17204 umps
w/o 975 over w 975 over
w/o 278 the w 278 the
w/o 17366 lazy w 17366 lazy
w/o 11203 dog w 11203 dog


In [10]:
out = model.generate(tokenizer.encode('The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog', return_tensors='pt'), max_length=50, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)



In [11]:
tokenizer.decode(out[0], skip_special_tokens=True)

'The quick brown fox jumps over the lazy dog to the other side of the river, to get over the other side of the sea.\nThe Lazy Dog, a local dog walking company, has been fined £'

In [12]:
tokenizer.decode(out[0], skip_special_tokens=False)

'<s> The<|pause|> quick<|pause|> brown<|pause|> fox jumps over the lazy dog to the other side of the river, to get over the other side of the sea.\nThe Lazy Dog, a local dog walking company, has been fined £'

In [15]:
model.save_pretrained('/dlabscratch1/tmp/llama2_test/')

In [16]:
model2 = AutoModelForCausalLM.from_pretrained('/dlabscratch1/tmp/llama2_test/', torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of the model checkpoint at /dlabscratch1/tmp/llama2_test/ were not used when initializing LlamaForCausalLM: ['lm_head.layer.weight', 'lm_head.new_embeddings.bias', 'lm_head.new_embeddings.weight', 'model.embed_tokens.new_embeddings.weight', 'model.embed_tokens.old_embeddings.weight']
- This IS expected if you are initializing LlamaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LlamaForCausalLM were not initialized from the model checkpoint at /dlabscratch1/tmp/llama2_test/ and are newly initialized: ['lm_head.weight', 'model.embed_tokens.weight']
You should probably TRAIN this model on a down