In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List
import transformers
from tokenizers import AddedToken
import torch
from torch import nn

In [9]:
class LinearWrapper(nn.Module):
    def __init__(self, layer: nn.Linear, num_embeddings: int, freeze_old=True):
        super().__init__()
        self.layer = layer
        self.num_embeddings = num_embeddings
        self.n_new_tokens = num_embeddings - layer.out_features
        self.new_embeddings = nn.Linear(layer.in_features, self.n_new_tokens, bias=False)
        self.new_embeddings.to(layer.weight.device).to(layer.weight.dtype)
        if freeze_old:
            for param in self.layer.parameters():
                param.requires_grad = False
    
    def forward(self, x):
        z1 = self.layer(x)
        z2 = self.new_embeddings(x)
        return torch.cat([z1, z2], dim=-1)

class EmbeddingWrapper(nn.Module):
    def __init__(self, embedding: nn.Embedding, num_embeddings: int, freeze_old=True):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.num_embeddings = num_embeddings
        self.n_new_tokens = num_embeddings - embedding.num_embeddings
        # inspired from here 
        # https://github.com/huggingface/transformers/blob/185463784e0a0b4cd7974ce5bded7a52ae170f6d/src/transformers/modeling_utils.py#L2026
        self.old_embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.old_embeddings.weight.data = torch.ones_like(self.old_embeddings.weight.data)*0#1e-7
        self.old_embeddings.weight.data[:embedding.num_embeddings] = embedding.weight.data
        self.old_embeddings.to(embedding.weight.device).to(embedding.weight.dtype)
        self.new_embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.new_embeddings.weight.data[:embedding.num_embeddings] = torch.ones_like(embedding.weight.data)*0#1e-7
        self.new_embeddings.to(embedding.weight.device).to(embedding.weight.dtype)
        if freeze_old:
            for param in self.old_embeddings.parameters():
                param.requires_grad = False

    
    def forward(self, x):
        return self.old_embeddings(x) + self.new_embeddings(x)


class Llama2EmbeddingSurgeon():
    def __init__(self, llama, extended_tokenizer):
        self.llama = llama 
        self.extended_tokenizer = extended_tokenizer
        self.extended_embedding = EmbeddingWrapper(llama.model.embed_tokens, len(extended_tokenizer))
        self.extended_unembedding = LinearWrapper(llama.lm_head, len(extended_tokenizer))
        
    def get_surgeried_model(self):
        self.backup_embed_tokens = self.llama.model.embed_tokens
        self.backup_lm_head = self.llama.lm_head
        self.llama.model.embed_tokens = self.extended_embedding
        self.llama.lm_head = self.extended_unembedding
        self.llama.config.vocab_size = len(self.extended_tokenizer)
        return self.llama
    
    def save(self, llama, path):
        # check if llama is surgeried
        assert llama.model.embed_tokens == self.extended_embedding
        assert llama.lm_head == self.extended_unembedding
        self.llama.model.embed_tokens = self.backup_embed_tokens
        self.llama.lm_head = self.backup_lm_head
        self.llama.save_pretrained(path)
        self.extended_tokenizer.save_pretrained(path)
        torch.save(self.extended_embedding.state_dict(), f"{path}/extended_embedding.pt")
        torch.save(self.extended_unembedding.state_dict(), f"{path}/extended_unembedding.pt") 
    
    @classmethod
    def load(cls, path):
        extended_embedding_dict = torch.load(f"{path}/extended_embedding.pt")
        extended_unembedding_dict = torch.load(f"{path}/extended_unembedding.pt")
        llama = AutoModelForCausalLM.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)
        surgeon = cls(llama, tokenizer)
        surgeon.extended_embedding.load_state_dict(extended_embedding_dict)
        surgeon.extended_unembedding.load_state_dict(extended_unembedding_dict)
        return surgeon

class PeftModelEmbeddingSurgeon():
    def __init__(self, peft_model, extended_tokenizer):
        try:
            self.llama = peft_model.base_model.model
        except AttributeError:
            self.llama = peft_model
        self.peft_model = peft_model
        self.extended_tokenizer = extended_tokenizer
        self.extended_embedding = nn.Embedding(len(extended_tokenizer), self.llama.model.embed_tokens.embedding_dim)
        self.extended_embedding.to(self.llama.model.embed_tokens.weight.device).to(self.llama.model.embed_tokens.weight.dtype)
        self.extended_unembedding = nn.Linear(self.llama.model.embed_tokens.embedding_dim, len(extended_tokenizer), bias=False)
        self.extended_unembedding.to(self.llama.lm_head.weight.device).to(self.llama.lm_head.weight.dtype)
       # self.extended_embedding = EmbeddingWrapper(self.llama.model.embed_tokens, len(extended_tokenizer))
       # self.extended_unembedding = LinearWrapper(self.llama.lm_head, len(extended_tokenizer))
        
    def get_surgeried_model(self):
        self.backup_embed_tokens = self.llama.model.embed_tokens
        self.backup_lm_head = self.llama.lm_head
        self.llama.model.embed_tokens = self.extended_embedding
        self.llama.lm_head = self.extended_unembedding
        self.llama.config.vocab_size = len(self.extended_tokenizer)
        return self.peft_model

    def save(self, peft_model, path):        
        self.llama.model.embed_tokens = self.backup_embed_tokens
        self.llama.lm_head = self.backup_lm_head
        self.peft_model.save_pretrained(path)
        self.extended_tokenizer.save_pretrained(path)
        torch.save(self.extended_embedding.state_dict(), f"{path}/extended_embedding.pt")
        torch.save(self.extended_unembedding.state_dict(), f"{path}/extended_unembedding.pt") 
    
    @classmethod
    def load(cls, path, **kwargs):
        extended_embedding_dict = torch.load(f"{path}/extended_embedding.pt")
        extended_unembedding_dict = torch.load(f"{path}/extended_unembedding.pt")
        peft_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(path)
        surgeon = cls(peft_model, tokenizer)
        surgeon.extended_embedding.load_state_dict(extended_embedding_dict)
        surgeon.extended_unembedding.load_state_dict(extended_unembedding_dict)
        return surgeon

In [3]:
model_name_or_path = '/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/'
# load model in torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map='auto', torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=64, 
    lora_alpha=16, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "fc1", "fc2", 
                    "gate_proj", "up_proj", "down_proj",], 
    lora_dropout=0.00, 
    bias="none", 
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 159,907,840 || all params: 6,898,323,456 || trainable%: 2.3180681656919973


In [5]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4

In [6]:
pause_token = AddedToken("<|pause|>", 
                         single_word=False, 
                         lstrip=True, 
                         rstrip=True)
                         #special=True, 
                         #normalized=False)

In [7]:
tokenizer.add_tokens([pause_token], special_tokens=True)
print(tokenizer)
# get idx of pause otken
pause_token_id = tokenizer.convert_tokens_to_ids("<|pause|>")
print(pause_token_id)

LlamaTokenizerFast(name_or_path='/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf/', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<|pause|>", rstrip=True, lstrip=True, single_word=False, normalized=False, special=True),
}
32000


In [10]:
## conventionally you'd do this like this:
#model.resize_token_embeddings(len(tokenizer))
        
## ours
surgeon = PeftModelEmbeddingSurgeon(model, tokenizer)
model = surgeon.get_surgeried_model()
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32001, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4

In [11]:
for k, p in model.named_parameters():
    if p.requires_grad:
        print(k, p.shape)

base_model.model.model.embed_tokens.weight torch.Size([32001, 4096])
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([4096, 64])
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight torch.Size([64, 4096])
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight torch.Size([4096, 64])
base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight torch.Size([64, 4096])
base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight torch.Size([11008, 6

In [12]:
for name, param in model.named_parameters():
    print(name, param.device)

base_model.model.model.embed_tokens.weight cuda:0
base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight cuda:0
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight cuda:0
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight cuda:0
base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight cuda:0
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight cuda:0
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight cuda:0
base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight cuda:0
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight cuda:0
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight cuda:0
base_model.model.model.layers.0.self_attn.o_proj.base_layer.weight cuda:0
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight cuda:0
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight cuda:0
base_model.model.model.layers.

In [13]:
# test
toks1 = tokenizer.encode('The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog', return_tensors='pt')
toks2 = tokenizer.encode('The quick brown fox jumps over the lazy dog', return_tensors='pt')
idx2 = 0
for idx1 in range(len(toks1[0])):
    if toks1[0, idx1].item() != pause_token_id:
        print('w/o', toks2[0, idx2].item(), tokenizer.decode([toks2[0, idx2]]), 'w', toks1[0, idx1].item(), tokenizer.decode([toks1[0, idx1]]))
        assert toks2[0, idx2] == toks1[0, idx1]
        idx2 += 1
    else:
        print('skipping pause token...')

w/o 1 <s> w 1 <s>
w/o 450 The w 450 The
skipping pause token...
w/o 4996 quick w 4996 quick
skipping pause token...
w/o 17354 brown w 17354 brown
skipping pause token...
w/o 1701 fo w 1701 fo
w/o 29916 x w 29916 x
w/o 432 j w 432 j
w/o 17204 umps w 17204 umps
w/o 975 over w 975 over
w/o 278 the w 278 the
w/o 17366 lazy w 17366 lazy
w/o 11203 dog w 11203 dog


In [14]:
out = model.generate(tokenizer.encode('The<|pause|> quick<|pause|> brown <|pause|> fox jumps over the lazy dog', return_tensors='pt'), max_length=50, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)



In [15]:
tokenizer.decode(out[0], skip_special_tokens=True)

'The quick brown fox jumps over the lazy dog Kknow Email Weltkrie Щenschapp CreateDelcurrмериIsGEN Emailходить InterGEN cím pert có presenceestigoriousмери CreateParameters presenceIsromagnet K giant dispos javafxactor OUTbigg'

In [16]:
tokenizer.decode(out[0], skip_special_tokens=False)

'<s> The<|pause|> quick<|pause|> brown<|pause|> fox jumps over the lazy dog Kknow Email Weltkrie Щenschapp CreateDelcurrмериIsGEN Emailходить InterGEN cím pert có presenceestigoriousмери CreateParameters presenceIsromagnet K giant dispos javafxactor OUTbigg'

# Test training

In [17]:
from datasets import load_dataset
dataset = load_dataset("gsm8k", "main", split = "train")
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = len(examples["question"])*["Solve the math problem using a eval tool. The command eval[[expr]] allows you to evaluate an expression."]
    inputs       = examples["question"]
    outputs      = examples["answer"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    #print(texts)
    return { "text" : texts, }
pass

dataset = dataset.map(formatting_prompts_func, batched = True)

In [18]:
from trl import SFTTrainer
from transformers import TrainingArguments
import os

tokenizer.pad_token = tokenizer.eos_token

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    dataset_text_field = "text",
    max_seq_length = 1024,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        gradient_checkpointing=False,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        warmup_steps = 0,
        max_steps = 1,
        #num_train_epochs = 1,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

with torch.cuda.amp.autocast():
    trainer_stats = trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m. Use [1m`wandb login --relogin`[0m to force relogin


ValueError: Attempting to unscale FP16 gradients.

# Test saving and loading

In [None]:
surgeon.save(model, '/dlabscratch1/tmp/peft_test')



In [None]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4

In [None]:
import gc
model.cpu()
gc.collect()
for i in range(torch.cuda.device_count()):
    torch.cuda.set_device(i) 
    torch.cuda.empty_cache() 

392

In [None]:
surgeon2 = PeftModelEmbeddingSurgeon.load('/dlabscratch1/tmp/peft_test', device_map='auto', torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
model2 = surgeon2.get_surgeried_model()

In [None]:
for name, param in model2.named_parameters():
    print(name, param.device)

model.embed_tokens.old_embeddings.weight cpu
model.embed_tokens.new_embeddings.weight cpu
model.layers.0.self_attn.q_proj.base_layer.weight cpu
model.layers.0.self_attn.q_proj.lora_A.default.weight cpu
model.layers.0.self_attn.q_proj.lora_B.default.weight cpu
model.layers.0.self_attn.k_proj.base_layer.weight cpu
model.layers.0.self_attn.k_proj.lora_A.default.weight cpu
model.layers.0.self_attn.k_proj.lora_B.default.weight cpu
model.layers.0.self_attn.v_proj.base_layer.weight cpu
model.layers.0.self_attn.v_proj.lora_A.default.weight cpu
model.layers.0.self_attn.v_proj.lora_B.default.weight cpu
model.layers.0.self_attn.o_proj.base_layer.weight cpu
model.layers.0.self_attn.o_proj.lora_A.default.weight cpu
model.layers.0.self_attn.o_proj.lora_B.default.weight cpu
model.layers.0.mlp.gate_proj.base_layer.weight cpu
model.layers.0.mlp.gate_proj.lora_A.default.weight cpu
model.layers.0.mlp.gate_proj.lora_B.default.weight cpu
model.layers.0.mlp.up_proj.base_layer.weight cpu
model.layers.0.mlp.u

In [None]:
import gc
model2.cpu()
surgeon2.extended_embedding.cpu()
surgeon2.extended_unembedding.cpu()
surgeon2.backup_embed_tokens.cpu()
surgeon2.backup_lm_head.cpu()
surgeon.extended_embedding.cpu()
surgeon.extended_unembedding.cpu()
surgeon.backup_embed_tokens.cpu()
surgeon.backup_lm_head.cpu()
gc.collect()
for i in range(torch.cuda.device_count()):
    torch.cuda.set_device(i) 
    torch.cuda.empty_cache() 