# 零阶法训练lora中关键代码的编写测试

## 适合Trainer计算loss的数据集构造（参考toolbench的training代码）

### 探究LLaMA3 tokenizer所需offset

In [5]:
# import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
import json, pickle, copy
from peft import PeftModel
from tqdm import tqdm
import re
import concurrent.futures
from transformers.trainer_pt_utils import LabelSmoother

In [2]:
# Load LLaMA3 model
model_pth = '../Meta-Llama-3-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_pth,torch_dtype=torch.float16,load_in_4bit=False,load_in_8bit=False)
model = AutoModelForCausalLM.from_pretrained(model_pth,torch_dtype=torch.float16,load_in_4bit=False,load_in_8bit=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


In [3]:
tokenizer.pad_token = tokenizer.bos_token

In [4]:
# example input
system_input = "You are GPT5."
user_input = "What is the meaning of life?"
message = [{'role': 'system', 'content': system_input}, {'role': 'user', 'content': user_input}]
prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
print(prompt)
length = len(tokenizer(prompt).input_ids)
print(f'token length: {length}')
print(f'token_ids: {tokenizer(prompt).input_ids}')
for input_id in tokenizer(prompt).input_ids:
    print(f'{tokenizer.decode(input_id)}({input_id})')

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are GPT5.<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the meaning of life?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


token length: 28
token_ids: [128000, 128006, 9125, 128007, 271, 2675, 527, 480, 2898, 20, 13, 128009, 128006, 882, 128007, 271, 3923, 374, 279, 7438, 315, 2324, 30, 128009, 128006, 78191, 128007, 271]
<|begin_of_text|>(128000)
<|start_header_id|>(128006)
system(9125)
<|end_header_id|>(128007)


(271)
You(2675)
 are(527)
 G(480)
PT(2898)
5(20)
.(13)
<|eot_id|>(128009)
<|start_header_id|>(128006)
user(882)
<|end_header_id|>(128007)


(271)
What(3923)
 is(374)
 the(279)
 meaning(7438)
 of(315)
 life(2324)
?(30)
<|eot_id|>(128009)
<|start_header_id|>(128006)
assistant(78191)
<|end_header_id|>(128007)


(271)


In [5]:
# example input
system_input = "You are GPT5."
user_input = "What is the meaning of life?"
assistant_output = "I don't know."
message = [{'role': 'system', 'content': system_input}, {'role': 'user', 'content': user_input}, {'role': 'assistant', 'content': assistant_output}]
prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)
print(prompt)
length = len(tokenizer(prompt).input_ids)
print(f'token length: {length}')
print(f'token_ids: {tokenizer(prompt).input_ids}')
for input_id in tokenizer(prompt).input_ids:
    print(f'{tokenizer.decode(input_id)}({input_id})')

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are GPT5.<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the meaning of life?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I don't know.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


token length: 38
token_ids: [128000, 128006, 9125, 128007, 271, 2675, 527, 480, 2898, 20, 13, 128009, 128006, 882, 128007, 271, 3923, 374, 279, 7438, 315, 2324, 30, 128009, 128006, 78191, 128007, 271, 40, 1541, 956, 1440, 13, 128009, 128006, 78191, 128007, 271]
<|begin_of_text|>(128000)
<|start_header_id|>(128006)
system(9125)
<|end_header_id|>(128007)


(271)
You(2675)
 are(527)
 G(480)
PT(2898)
5(20)
.(13)
<|eot_id|>(128009)
<|start_header_id|>(128006)
user(882)
<|end_header_id|>(128007)


(271)
What(3923)
 is(374)
 the(279)
 meaning(7438)
 of(315)
 life(2324)
?(30)
<|eot_id|>(128009)
<|start_header_id|>(128006)
assistant(78191)
<|end_header_id|>(128007)


(271)
I(40)
 don(1541)
't(956)
 know(1440)
.(13)
<|

### 建立数据集对象

In [4]:
# load data
with open('../data/grade_school_math/data/train.jsonl', 'r') as f:
    question_data = [json.loads(line) for line in f]

with open('generated_data/actor_response_data_01.jsonl', 'r') as f:
    actor_data = [json.loads(line) for line in f]

with open('generated_data/critic_response_data_02.jsonl', 'r') as f:
    critic_data = [json.loads(line) for line in f]

with open('generated_data/summarizer_response_data_01.jsonl', 'r') as f:
    summarizer_data = [json.loads(line) for line in f]

In [7]:
actor_data_selected = []
for datum in actor_data:
    if datum['answer_correct'] == datum['answer_actor']:
        actor_data_selected.append(datum)

In [8]:
# generate prompts
system_prompt = 'You are an actor who is responsible for solving math problems. Given a math problem, you need to give a concise analysis followed by the correct answer in the format "\n#### [Answer with digits only]" in the very end of your response.'
messages = [[{'role': 'system', 'content': system_prompt}, 
             {'role': 'question', 'content': question_data[datum['question_id']]['question']},
             {'role': 'assistant', 'content': datum['actor_response']}] for datum in actor_data_selected]
input_prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)[:-47] for message in messages]
pattern = '<|start_header_id|>assistant<|end_header_id|>\n\n'
separate_idx = [prompt.find(pattern) + 47 for prompt in input_prompts]
instruction_prompts = [prompt[:idx] for prompt, idx in zip(input_prompts, separate_idx)]
label_prompts = [prompt[idx:] for prompt, idx in zip(input_prompts, separate_idx)]

In [9]:
# generate instruction_ids
instruction_ids = [tokenizer(prompt).input_ids for prompt in instruction_prompts]
instruction_token_len = [len(ids) for ids in instruction_ids]

In [10]:
# target ids are of the same shape of instruction_ids
IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100
target_ids = [[IGNORE_TOKEN_ID] * len(ids) for ids in instruction_ids]

In [11]:
# generate label_ids
label_ids = [tokenizer(prompt).input_ids for prompt in label_prompts]
label_token_len = [len(ids) for ids in label_ids]
print(label_token_len[:10])

[84, 114, 50, 75, 82, 95, 75, 194, 74, 46]


In [12]:
# combine instruction_ids and label_ids into input_ids
maxlen = 1024
input_ids = [instruction_id + label_id + [tokenizer.pad_token_id] * (maxlen - len(instruction_id) - len(label_id)) for instruction_id, label_id in zip(instruction_ids, label_ids)]
# combine target_ids and label_ids into target_ids
target_ids = [target_id + label_id + [IGNORE_TOKEN_ID] * (maxlen - len(target_id) - len(label_id)) for target_id, label_id in zip(target_ids, label_ids)]
# change list to tensor
input_ids = torch.tensor(input_ids)
target_ids = torch.tensor(target_ids)
print(input_ids.shape)
print(target_ids.shape)

torch.Size([58153, 1024])
torch.Size([58153, 1024])


In [13]:
# generate attention_mask
attention_mask = input_ids.ne(tokenizer.pad_token_id)
print(attention_mask.shape)

torch.Size([58153, 1024])


In [14]:
# combine input_ids, labels and attention_mask to achieve the final dataset
complete_dataset = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=target_ids,
)

In [15]:
model.to(torch.device('cuda:0'))

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [16]:
test_input_ids = input_ids[:4].to(torch.device('cuda:0'))

In [17]:
test_attention_mask = attention_mask[:4].to(torch.device('cuda:0'))

In [23]:
test_label = target_ids[:4].to(torch.device('cuda:0'))

In [29]:
# test model output given input_ids[:4]
# release GPU memory
with torch.no_grad():
    output = model(input_ids=test_input_ids, attention_mask=test_attention_mask, labels=test_label)
print(output.keys())
print(output.logits.shape)
print(output.loss)

odict_keys(['loss', 'logits', 'past_key_values'])
torch.Size([4, 1024, 128256])
tensor(1.3378, device='cuda:0')


In [30]:
# save complete_dataset
with open('./generated_data/processed_actor_positive_data.pkl', 'wb') as f:
    pickle.dump(complete_dataset, f)

In [18]:
# critic dataset
critic_data_selected =  []
for datum in critic_data:
    if datum['judge_correct'] == datum['judge_critic']:
        critic_data_selected.append(datum)
system_prompt = 'You are a critic who is responsible for judging the correctness of the actor\'s answer. Provided with the math problem, correct answer and the student\'s answer, you need to judge whether the actor\'s answer is correct. If the actor\'s answer is right, respond with "#### The answer is: Accepted." Otherwise, analyze the reason why the actor arrived at the wrong answer and respond with "#### The answer is: Wrong Answer. [Reason for the wrong answer, without displaying the correct number to the question]".'
messages = [[{'role': 'system', 'content': system_prompt},
             {'role': 'question', 'content': question_data[actor_data[datum['actor_response_id']]['question_id']]['question']},
             {'role': 'correct answer', 'content': question_data[actor_data[datum['actor_response_id']]['question_id']]['answer']},
             {'role': 'actor\'s answer', 'content': actor_data[datum['actor_response_id']]['actor_response']},
             {'role': 'assistant', 'content': datum['critic_response']}] for datum in critic_data_selected]
input_prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)[:-47] for message in messages]
pattern = '<|start_header_id|>assistant<|end_header_id|>\n\n'
separate_idx = [prompt.find(pattern) + 47 for prompt in input_prompts]
instruction_prompts = [prompt[:idx] for prompt, idx in zip(input_prompts, separate_idx)]
label_prompts = [prompt[idx:] for prompt, idx in zip(input_prompts, separate_idx)]
# generate instruction_ids
instruction_ids = [tokenizer(prompt).input_ids for prompt in instruction_prompts]
instruction_token_len = [len(ids) for ids in instruction_ids]
# target ids are of the same shape of instruction_ids
IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100
target_ids = [[IGNORE_TOKEN_ID] * len(ids) for ids in instruction_ids]
# generate label_ids
label_ids = [tokenizer(prompt).input_ids for prompt in label_prompts]
label_token_len = [len(ids) for ids in label_ids]
print(f'maxlen:{max([ins+lab for ins, lab in zip(instruction_token_len, label_token_len)])}')
# combine instruction_ids and label_ids into input_ids
maxlen = 1536
input_ids = [instruction_id + label_id + [tokenizer.pad_token_id] * (maxlen - len(instruction_id) - len(label_id)) for instruction_id, label_id in zip(instruction_ids, label_ids)]
# combine target_ids and label_ids into target_ids
target_ids = [target_id + label_id + [IGNORE_TOKEN_ID] * (maxlen - len(target_id) - len(label_id)) for target_id, label_id in zip(target_ids, label_ids)]
# change list to tensor
input_ids = torch.tensor(input_ids)
target_ids = torch.tensor(target_ids)
print(input_ids.shape)
print(target_ids.shape)
# generate attention_mask
attention_mask = input_ids.ne(tokenizer.pad_token_id)
print(attention_mask.shape)
# combine input_ids, labels and attention_mask to achieve the final dataset
complete_dataset = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=target_ids,
)
# save complete_dataset
with open('./generated_data/processed_critic_positive_data.pkl', 'wb') as f:
    pickle.dump(complete_dataset, f)

maxlen:1481
torch.Size([110133, 1536])
torch.Size([110133, 1536])
torch.Size([110133, 1536])


In [9]:
# summarizer dataset
summarizer_data_selected =  []
for datum in summarizer_data:
    if datum['label_positive']:
        summarizer_data_selected.append(datum)

system_prompt = 'You are a summarizer who is responsible for deciding the final answer to a given math problem, with the help of an actor\'s solution and a critic\'s judgement of whether the actor\'s answer is correct or not. If the actor\'s answer is correct, summarize the analysis. Otherwise, fix the actor\'s answer according to the critic\'s feedback. Only the correct analysis is allowed to be presented. Do not include statements about whether the actor or critic is correct. Finally, add "\n\n#### [Answer to the question with digits only]" as a summarization.'
messages = [[{'role': 'system', 'content': system_prompt},
             {'role': 'question', 'content': question_data[actor_data[critic_data[datum['critic_response_id']]['actor_response_id']]['question_id']]['question']},
             {'role': 'actor\'s answer', 'content': actor_data[critic_data[datum['critic_response_id']]['actor_response_id']]['actor_response']},
             {'role': 'critic\'s judgement', 'content': critic_data[datum['critic_response_id']]['critic_response']},
             {'role': 'assistant', 'content': datum['summarizer_response']}] for datum in summarizer_data_selected]
input_prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)[:-47] for message in messages]
pattern = '<|start_header_id|>assistant<|end_header_id|>\n\n'
separate_idx = [prompt.find(pattern) + 47 for prompt in input_prompts]
instruction_prompts = [prompt[:idx] for prompt, idx in zip(input_prompts, separate_idx)]
label_prompts = [prompt[idx:] for prompt, idx in zip(input_prompts, separate_idx)]
# generate instruction_ids
instruction_ids = [tokenizer(prompt).input_ids for prompt in instruction_prompts]
instruction_token_len = [len(ids) for ids in instruction_ids]
# target ids are of the same shape of instruction_ids
IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100
target_ids = [[IGNORE_TOKEN_ID] * len(ids) for ids in instruction_ids]
# generate label_ids
label_ids = [tokenizer(prompt).input_ids for prompt in label_prompts]
label_token_len = [len(ids) for ids in label_ids]
print(f'maxlen:{max([ins+lab for ins, lab in zip(instruction_token_len, label_token_len)])}')
# combine instruction_ids and label_ids into input_ids
maxlen = 1536
input_ids = [instruction_id + label_id + [tokenizer.pad_token_id] * (maxlen - len(instruction_id) - len(label_id)) for instruction_id, label_id in zip(instruction_ids, label_ids)]
# combine target_ids and label_ids into target_ids
target_ids = [target_id + label_id + [IGNORE_TOKEN_ID] * (maxlen - len(target_id) - len(label_id)) for target_id, label_id in zip(target_ids, label_ids)]
# change list to tensor
input_ids = torch.tensor(input_ids)
target_ids = torch.tensor(target_ids)
print(input_ids.shape)
print(target_ids.shape)
# generate attention_mask
attention_mask = input_ids.ne(tokenizer.pad_token_id)
print(attention_mask.shape)
# combine input_ids, labels and attention_mask to achieve the final dataset
complete_dataset = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=target_ids,
)
# save complete_dataset
with open('./generated_data/processed_summarizer_positive_data.pkl', 'wb') as f:
    pickle.dump(complete_dataset, f)

maxlen:1434
torch.Size([86169, 1536])
torch.Size([86169, 1536])
torch.Size([86169, 1536])


## 使用PEFT+零阶法训练lora

### 搭建PEFT模型

In [1]:
# import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
import json, pickle, copy
from peft import PeftModel
from peft import (
    get_peft_model,
    prepare_model_for_kbit_training,
    LoraConfig
)
from tqdm import tqdm
import re
import concurrent.futures
from transformers.trainer_pt_utils import LabelSmoother

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load LLaMA3 model
model_pth = '../Meta-Llama-3-8B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_pth,torch_dtype=torch.float16,load_in_4bit=False,load_in_8bit=False)
model = AutoModelForCausalLM.from_pretrained(model_pth,torch_dtype=torch.float16,load_in_4bit=False,load_in_8bit=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.04it/s]


In [3]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [3]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [7]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=40

In [6]:
for name, module in model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)
        print(f'{name} is converted to float32')

base_model.model.model.layers.0.input_layernorm is converted to float32
base_model.model.model.layers.0.post_attention_layernorm is converted to float32
base_model.model.model.layers.1.input_layernorm is converted to float32
base_model.model.model.layers.1.post_attention_layernorm is converted to float32
base_model.model.model.layers.2.input_layernorm is converted to float32
base_model.model.model.layers.2.post_attention_layernorm is converted to float32
base_model.model.model.layers.3.input_layernorm is converted to float32
base_model.model.model.layers.3.post_attention_layernorm is converted to float32
base_model.model.model.layers.4.input_layernorm is converted to float32
base_model.model.model.layers.4.post_attention_layernorm is converted to float32
base_model.model.model.layers.5.input_layernorm is converted to float32
base_model.model.model.layers.5.post_attention_layernorm is converted to float32
base_model.model.model.layers.6.input_layernorm is converted to float32
base_model

In [16]:
# check trainable parameters in model
trainable_parameters = [(name, parameters) for name, parameters in model.named_parameters() if parameters.requires_grad]
for name, parameters in trainable_parameters:
    print(name, parameters.shape)
print(f'Total trainable parameters: {sum([parameters.numel() for _, parameters in trainable_parameters])}')

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([8, 4096])
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([4096, 8])
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight torch.Size([8, 4096])
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight torch.Size([1024, 8])
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([8, 4096])
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([1024, 8])
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight torch.Size([8, 4096])
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight torch.Size([4096, 8])
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight torch.Size([8, 4096])
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight torch.Size([4096, 8])
base_model.model.model.layers.1.self_attn.k_proj.lora_A.default.weight

In [28]:
model.to(torch.device('cuda:0'))

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=40

### 测试零阶训练代码

In [29]:
num_iters = 10
scaling_factor = 1
zo_eps = 1e-3
batch_size = 2
random_seed = 123456
learning_rate = 1e-5

In [30]:
with open('./generated_data/processed_actor_positive_data.pkl', 'rb') as f:
    complete_dataset = pickle.load(f)
    
shuffled_ids = torch.randperm(len(complete_dataset['input_ids']))

for iter in range(10):

    # sample data batch
    batched_input_ids = complete_dataset['input_ids'][shuffled_ids[iter * batch_size: (iter + 1) * batch_size]].to(torch.device('cuda:0'))
    batched_labels = complete_dataset['labels'][shuffled_ids[iter * batch_size: (iter + 1) * batch_size]].to(torch.device('cuda:0'))
    batched_attention_mask = complete_dataset['attention_mask'][shuffled_ids[iter * batch_size: (iter + 1) * batch_size]].to(torch.device('cuda:0'))
    
    # random perturbation
    torch.manual_seed(random_seed)
    for _, params in trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), device=params.data.device, dtype=params.data.dtype)
        params.data += scaling_factor * z * zo_eps
    
    # loss computation
    with torch.no_grad():
        loss1 = model(input_ids=batched_input_ids, attention_mask=batched_attention_mask, labels=batched_labels).loss
        print(f'iter: {iter}, loss1: {loss1}')

    # perturbate in the opposite direction
    torch.manual_seed(random_seed)
    for _, params in trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), device=params.data.device, dtype=params.data.dtype)
        params.data -= 2 * scaling_factor * z * zo_eps

    # loss computation
    with torch.no_grad():
        loss2 = model(input_ids=batched_input_ids, attention_mask=batched_attention_mask, labels=batched_labels).loss
        print(f'iter: {iter}, loss2: {loss2}')
    
    # compute update
    projected_grad = ((loss1 - loss2) / (2 * zo_eps)).item()

    # reset and update parameters
    torch.manual_seed(random_seed)
    for _, params in trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), device=params.data.device, dtype=params.data.dtype)
        params.data += scaling_factor * (zo_eps - projected_grad * learning_rate) * z

iter: 0, loss1: 0.9620891809463501
iter: 0, loss2: 0.9620891809463501
iter: 1, loss1: 1.249617576599121
iter: 1, loss2: 1.249617576599121
iter: 2, loss1: 0.9707486033439636
iter: 2, loss2: 0.9707486033439636
iter: 3, loss1: 1.5380691289901733
iter: 3, loss2: 1.5380691289901733
iter: 4, loss1: 1.1894009113311768
iter: 4, loss2: 1.1894009113311768
iter: 5, loss1: 1.4856212139129639
iter: 5, loss2: 1.4856212139129639
iter: 6, loss1: 0.9410967230796814
iter: 6, loss2: 0.9410967230796814
iter: 7, loss1: 1.1143782138824463
iter: 7, loss2: 1.1143782138824463
iter: 8, loss1: 1.1747761964797974
iter: 8, loss2: 1.1747761964797974
iter: 9, loss1: 0.9255077838897705
iter: 9, loss2: 0.9255077838897705


## 实现多卡并行的零阶法LoRA微调

### 模型与数据准备

In [1]:
# import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
import json, pickle, copy
from peft import PeftModel
from peft import (
    get_peft_model,
    prepare_model_for_kbit_training,
    LoraConfig
)
from tqdm import tqdm
import re
import concurrent.futures
from transformers.trainer_pt_utils import LabelSmoother

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load LLaMA3 model
model_pth = '../Meta-Llama-3-8B-Instruct'


model = AutoModelForCausalLM.from_pretrained(model_pth,torch_dtype=torch.float32,load_in_4bit=False,load_in_8bit=False)

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.59it/s]


In [3]:
models = [copy.deepcopy(model).to(torch.device(f'cuda:{_}')) for _ in range(8)]
tokenizers = [AutoTokenizer.from_pretrained(model_pth,torch_dtype=torch.float16,load_in_4bit=False,load_in_8bit=False) for _ in range(8)]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
for model in models:
    model.trainable_parameters = [(name, parameters) for name, parameters in model.named_parameters() if parameters.requires_grad]

In [5]:
for tokenizer in tokenizers:
    tokenizer.pad_token = tokenizer.bos_token

In [6]:
with open('./generated_data/processed_actor_positive_data.pkl', 'rb') as f:
    complete_dataset = pickle.load(f)
item_counts = [complete_dataset['labels'][id].ne(-100).sum().item() for id in range(complete_dataset['labels'].size(0))]
complete_dataset['item_counts'] = item_counts

### 单卡Loss计算

In [17]:
def compute_loss_single_node(model, input_ids, labels, attention_mask, zo_eps, random_seed, batch_size, scale_factor, device, reset=False):

    n_data = len(input_ids)
    n_batch = n_data // batch_size

    # random perturbation
    generator = torch.Generator().manual_seed(random_seed)
    for _, params in model.trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), generator=generator, dtype=params.data.dtype).to(device=device)
        params.data += z * zo_eps
    
    # compute loss1
    loss1 = 0
    with torch.no_grad():
        for batch_id in range(n_batch):
            loss1 += model(input_ids=input_ids[batch_id * batch_size: (batch_id + 1) * batch_size].to(device),
                           labels=labels[batch_id * batch_size: (batch_id + 1) * batch_size].to(device),
                           attention_mask=attention_mask[batch_id * batch_size: (batch_id + 1) * batch_size].to(device)).loss.item()
    
    # perturbate in the opposite direction
    generator = generator.manual_seed(random_seed)
    for _, params in model.trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), generator=generator, dtype=params.data.dtype).to(device=device)
        params.data -= z * zo_eps * 2
    
    # compute loss2
    loss2 = 0
    with torch.no_grad():
        for batch_id in range(n_batch):
            loss2 += model(input_ids=input_ids[batch_id * batch_size: (batch_id + 1) * batch_size].to(device),
                           labels=labels[batch_id * batch_size: (batch_id + 1) * batch_size].to(device),
                           attention_mask=attention_mask[batch_id * batch_size: (batch_id + 1) * batch_size].to(device)).loss.item()
    
    if reset:
        generator = generator.manual_seed(random_seed)
        for _, params in model.trainable_parameters:
            z = torch.normal(mean=0, std=1, size=params.data.size(), generator=generator, dtype=params.data.dtype).to(device=device)
            params.data += z * zo_eps
    
    return loss1 * scale_factor, loss2 * scale_factor # the scale_factor is corresponding to the number of minibatches in the computation of the CE loss (i.e., number of non-[-100] labels)

### 多卡Loss计算

In [19]:
def compute_loss_multi_node(models, input_ids, labels, attention_mask, zo_eps, random_seed, batch_size, item_counts=None, n_GPU=8, reset=False): # batch_size is exactly the batch_size per node

    n_data = len(input_ids)
    n_data_per_GPU = n_data // n_GPU

    if item_counts is None:
        item_counts = [labels[id].ne(-100).sum().item() for id in range(labels.size(0))]
    
    scale_factors = [sum(item_counts[i * n_data_per_GPU: (i+1) * n_data_per_GPU]) for i in range(n_GPU)]
    loss1, loss2 = 0.0, 0.0

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for i in range(n_GPU):
            futures.append(executor.submit(compute_loss_single_node, models[i], 
                                           input_ids[i * n_data_per_GPU: (i+1) * n_data_per_GPU], 
                                           labels[i * n_data_per_GPU: (i+1) * n_data_per_GPU], 
                                           attention_mask[i * n_data_per_GPU: (i+1) * n_data_per_GPU],
                                           zo_eps, random_seed, batch_size, scale_factors[i],
                                           torch.device(f'cuda:{i}'), reset))
        for future in concurrent.futures.as_completed(futures):
            loss1_delta, loss2_delta = future.result()
            loss1 += loss1_delta
            loss2 += loss2_delta
    
    total_scale_factor = sum(scale_factors)
    return loss1 / total_scale_factor, loss2 / total_scale_factor   

### 单卡参数更新

In [18]:
def update_params_single_node(model, zo_eps, random_seed, scale_factor, device, reset=False): # scale factor here denotes the gradient step
    
    generator = torch.Generator().manual_seed(random_seed)
    for _, params in model.trainable_parameters:
        z = torch.normal(mean=0, std=1, size=params.data.size(), generator=generator, dtype=params.data.dtype).to(device=device)
        params.data += z * (zo_eps - scale_factor) if not reset else z * (-scale_factor)

    return


### 多卡并行更新

In [39]:
def update_params_multi_node(models, zo_eps, random_seed, scale_factor, n_GPU=8, reset=False):

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for i in range(n_GPU):
            futures.append(executor.submit(update_params_single_node, models[i], zo_eps, random_seed, scale_factor, torch.device(f'cuda:{i}'), reset))
        for future in concurrent.futures.as_completed(futures):
            future.result()

    return

### 多卡零阶LoRA微调

In [21]:
def zo_lora_finetuning_multi_node(models, dataset, zo_eps, learning_rate, batch_size, random_seeds=None, n_iters=None, n_epochs=None, n_GPU=8, verbose=False):
    
    assert (n_iters is None) ^ (n_epochs is None)
    input_ids, labels, attention_mask, item_counts = dataset['input_ids'], dataset['labels'], dataset['attention_mask'], dataset['item_counts']
    
    n_iters_per_epoch = len(input_ids) // batch_size
    n_iters = n_iters_per_epoch * n_epochs if n_epochs is not None else n_iters
    batch_size_per_GPU = batch_size // n_GPU

    if random_seeds is None:
        random_seeds = torch.randint(100000000, (n_iters,))

    e = 0 # epoch
    iter = 0 # iteration
    while iter < n_iters:

        # shuffle data for epoch e
        shuffled_ids = torch.randperm(len(input_ids))

        for i in range(n_iters_per_epoch):

            if iter >= n_iters:
                break

            # get iter data
            batch_input_ids = input_ids[shuffled_ids[i * batch_size: (i+1) * batch_size]]
            batch_labels = labels[shuffled_ids[i * batch_size: (i+1) * batch_size]]
            batch_attention_mask = attention_mask[shuffled_ids[i * batch_size: (i+1) * batch_size]]
            batch_item_counts = [item_counts[shuffled_ids[j]] for j in range(i * batch_size, (i+1) * batch_size)]

            # compute loss1, loss2
            loss1, loss2 = compute_loss_multi_node(models=models, 
                                                   input_ids=batch_input_ids, 
                                                   labels=batch_labels, 
                                                   attention_mask=batch_attention_mask, 
                                                   zo_eps=zo_eps, 
                                                   random_seed=random_seeds[iter].item(),
                                                   batch_size=batch_size_per_GPU,
                                                   item_counts=batch_item_counts,
                                                   n_GPU=n_GPU,
                                                   reset=False)
            
            # compute update stepsize
            scale_factor = (loss1 - loss2) / (2 * zo_eps) * learning_rate

            # update params
            update_params_multi_node(models=models, 
                                     zo_eps=zo_eps, 
                                     random_seed=random_seeds[iter].item(),
                                     scale_factor=scale_factor,
                                     n_GPU=n_GPU,
                                     reset=False)

            if verbose:
                print(f'In epoch {e} iter {i}, loss1={loss1}, loss2={loss2}')
            iter += 1
        
        e += 1

    return

In [55]:
zo_eps = 1e-3
learning_rate = 1e-6

zo_lora_finetuning_multi_node(models=models,
                              dataset=complete_dataset,
                              zo_eps=zo_eps,
                              learning_rate=learning_rate,
                              batch_size=128,
                              random_seeds=None,
                              n_iters=10,
                              n_epochs=None,
                              n_GPU=8,
                              verbose=True)



In epoch 0 iter 0, loss1=1.245854886653261, loss2=1.2432450466212233
In epoch 0 iter 1, loss1=1.26297794536308, loss2=1.2639132856709476
In epoch 0 iter 2, loss1=1.2700721938170232, loss2=1.2708157269087872
In epoch 0 iter 3, loss1=1.220146984956842, loss2=1.2196024243648236
In epoch 0 iter 4, loss1=1.2793298186567978, loss2=1.2790694497986326
In epoch 0 iter 5, loss1=1.290576580576764, loss2=1.292582733537844
In epoch 0 iter 6, loss1=1.233810990897503, loss2=1.2348346873733866
In epoch 0 iter 7, loss1=1.255271207099683, loss2=1.2543755417670597
In epoch 0 iter 8, loss1=1.2425236751705417, loss2=1.2442799435157708
In epoch 0 iter 9, loss1=1.2452937210054018, loss2=1.2426205011298876


### Debug

In [29]:
zo_eps = 1e-3
params = [model.base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone() for model in models]
compute_loss_multi_node(models=models,
                        input_ids=complete_dataset['input_ids'][:8],
                        labels=complete_dataset['labels'][:8],
                        attention_mask=complete_dataset['attention_mask'][:8],
                        zo_eps=zo_eps,
                        random_seed=123456,
                        batch_size=1,
                        item_counts=complete_dataset['item_counts'][:8],
                        n_GPU=8,
                        reset=False)
update_params_multi_node(models=models,
                         zo_eps=zo_eps,
                         random_seed=123456,
                         scale_factor=0,
                         n_GPU=8,
                         reset=False)
params_new = [model.base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone() for model in models]

for param, param_new in zip(params, params_new):
    print(torch.allclose(param, param_new))
    print(torch.abs(param - param_new).max())

False
tensor(0.0043, device='cuda:0', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:1', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:2', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:3', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:4', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:5', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:6', grad_fn=<MaxBackward1>)
False
tensor(0.0043, device='cuda:7', grad_fn=<MaxBackward1>)


In [43]:
zo_eps = 1
params0 = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()

compute_loss_single_node(model=models[0],
                        input_ids=complete_dataset['input_ids'][:1],
                        labels=complete_dataset['labels'][:1],
                        attention_mask=complete_dataset['attention_mask'][:1],
                        zo_eps=zo_eps,
                        random_seed=123456,
                        batch_size=1,
                        scale_factor=0,
                        device=torch.device('cuda:0'),
                        reset=False)
update_params_single_node(model=models[0],
                            zo_eps=zo_eps,
                            random_seed=123456,
                            scale_factor=0,
                            device=torch.device('cuda:0'),
                            reset=False)
params0_new = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()

print(torch.allclose(params0, params0_new))
print(torch.abs(params0 - params0_new).max())

True
tensor(1.8626e-09, device='cuda:0', grad_fn=<MaxBackward1>)


In [42]:
zo_eps = 1
params = [model.base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone() for model in models]

update_params_multi_node(models=models,
                         zo_eps=zo_eps,
                         random_seed=123456,
                         scale_factor=0,
                         n_GPU=8,
                         reset=False)

update_params_multi_node(models=models,
                         zo_eps=zo_eps,
                         random_seed=123456,
                         scale_factor=3,
                         n_GPU=8,
                         reset=False)

update_params_multi_node(models=models,
                         zo_eps=zo_eps,
                         random_seed=123456,
                         scale_factor=0,
                         n_GPU=8,
                         reset=False)
params_new = [model.base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone() for model in models]

for param, param_new in zip(params, params_new):
    print(torch.allclose(param, param_new))
    print(torch.abs(param - param_new).max())

True
tensor(4.7684e-07, device='cuda:0', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:1', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:2', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:3', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:4', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:5', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:6', grad_fn=<MaxBackward1>)
False
tensor(1.1921e-07, device='cuda:7', grad_fn=<MaxBackward1>)


In [23]:
random_seed = 123456
params1 = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()
torch.manual_seed(random_seed)
for _, params in models[0].trainable_parameters:
    z = torch.normal(mean=0, std=1, size=params.data.size(), device=torch.device('cuda:0'), dtype=params.data.dtype)
    params.data += z * zo_eps
params2 = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()
torch.manual_seed(random_seed)
for _, params in models[0].trainable_parameters:
    z = torch.normal(mean=0, std=1, size=params.data.size(), device=torch.device('cuda:0'), dtype=params.data.dtype)
    params.data -= z * zo_eps * 2 
params3 = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()
torch.manual_seed(random_seed)
for _, params in models[0].trainable_parameters:
    z = torch.normal(mean=0, std=1, size=params.data.size(), device=torch.device('cuda:0'), dtype=params.data.dtype)
    params.data += z * zo_eps
# update_params_single_node(models[0], zo_eps, random_seed, 0, torch.device('cuda:0'), reset=True)
params4 = models[0].base_model.model.model.layers[31].self_attn.k_proj.lora_A.default.weight.clone()

print((params1[0] == params2[0]).all())
print((params1[0] - params2[0]).abs().max())
print((params1[0] == params3[0]).all())
print((params1[0] - params3[0]).abs().max())
print((params1[0] == params4[0]).all())
print((params1[0] - params4[0]).abs().max())

print(f'params1: {params1[0]}')
print(f'params2: {params2[0]}')
print(f'params3: {params3[0]}')
print(f'params4: {params4[0]}')

tensor(False, device='cuda:0')
tensor(3.4791, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(False, device='cuda:0')
tensor(3.4791, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(False, device='cuda:0')
tensor(2.3842e-07, device='cuda:0', grad_fn=<MaxBackward1>)
params1: tensor([-0.4380, -0.2725,  1.4078,  ..., -1.3894, -0.0833,  1.6678],
       device='cuda:0', grad_fn=<SelectBackward0>)
params2: tensor([ 0.0876, -0.0185, -1.0319,  ..., -1.4705,  0.1747,  3.3763],
       device='cuda:0', grad_fn=<SelectBackward0>)
params3: tensor([-0.9637, -0.5265,  3.8476,  ..., -1.3083, -0.3413, -0.0408],
       device='cuda:0', grad_fn=<SelectBackward0>)
params4: tensor([-0.4380, -0.2725,  1.4078,  ..., -1.3894, -0.0833,  1.6678],
       device='cuda:0', grad_fn=<SelectBackward0>)
