In [109]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments
from trl import SFTConfig, SFTTrainer
from trl import setup_chat_format
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from concurrent.futures import ThreadPoolExecutor
from trl import DataCollatorForCompletionOnlyLM
# Hub checkpoint id shared by the tokenizer here and by the model loaded in a later cell.
model_name = "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GSM8K, "main" configuration; preprocess_dataset below reads its "train" split.
ds = load_dataset("openai/gsm8k", "main")

def preprocess_dataset(ds, split="train", prompt_path="prompts/r1_zero.prompt"):
    """Turn a GSM8K-style dataset into prompts, ground truths and full SFT texts.

    Args:
        ds: Mapping from split name to a table exposing "question" and
            "answer" columns (e.g. the object returned by
            ``load_dataset("openai/gsm8k", "main")``).
        split: Name of the split to read. Defaults to "train".
        prompt_path: Path of the prompt template file. The template must
            contain a ``{question}`` placeholder; it is filled in with
            ``str.format``.

    Returns:
        A tuple ``(question_prompts, ground_truth, prompt_completion)`` of
        three lists aligned by example index:
          * question_prompts: the template filled in with each question;
          * ground_truth: the text following the "\\n#### " separator;
          * prompt_completion: prompt + " <reasoning> </think>"
            + " <answer> <ground truth> </answer>".

    Raises:
        IndexError: if an answer does not contain the "\\n#### " separator.
    """
    # GSM8K answers look like "<reasoning>\n#### <final answer>".
    answer_sep = '\n#### '

    # Read each column once; every column access may materialise the data.
    questions, answers = ds[split]["question"], ds[split]["answer"]
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt_string = f.read()

    def split_answer(ans):
        """Return (reasoning, final_answer) for one raw answer string."""
        parts = ans.split(answer_sep)
        if len(parts) < 2:
            raise IndexError(f"answer has no {answer_sep!r} separator: {ans!r}")
        return parts[0], parts[1]

    question_prompts, ground_truth, prompt_completion = [], [], []
    # Plain loop instead of thread pools: the work is pure-Python string
    # formatting, so threads add overhead without any parallel speed-up.
    for q, ans in zip(questions, answers):
        prompt = prompt_string.format(question=q)
        reasoning, final_answer = split_answer(ans)
        # The separator is split outside the f-string: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12.
        cot = ' ' + reasoning + ' </think>'
        gt = f" <answer> {final_answer} </answer>"
        question_prompts.append(prompt)
        ground_truth.append(final_answer)
        prompt_completion.append(prompt + cot + gt)
    return question_prompts, ground_truth, prompt_completion



# Build a completion-only collator whose response_template matches the end of the prompt,
# so that only the tokens after the template contribute to the loss.
# NOTE(review): an earlier note here said truncation happens from the left "by default";
# nothing in this cell configures truncation — confirm where (and whether) it is applied.
# 1) Look up the ID of the built-in <|im_end|> token:
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

# 2) Tell the tokenizer to use that as its pad token:
# NOTE(review): if <|im_end|> is also meant to be learned as the end-of-sequence marker,
# check that the collator does not mask every pad-token position out of the labels — TODO confirm.
tokenizer.pad_token = "<|im_end|>"
tokenizer.pad_token_id = im_end_id
collator = DataCollatorForCompletionOnlyLM(
    tokenizer = tokenizer,
    # Anything before *and including* this string gets label = -100
    response_template  = r"Assistant: <think>",   # no trailing space here: preprocess_dataset adds the space after '>' as part of the completion
    
)


#preprocess_dataset(ds)[2][0:2]

# some test on collator
# curr_batch = tokenizer(
#     preprocess_dataset(ds)[2][0:2],
#     padding=False,               # collator can handle this , set to false
#     truncation=False,           # collator will handle this with left truncate
#     return_special_tokens_mask=False   # only need for MLM task
#     #return_tensors=None        # lists, not tensors – collator wants lists
# )

# # #collator format [[dict] , [dict], [dict]]
# cnt = len(curr_batch['input_ids'])

# curr_batch_reconstruct =[]
# for i in range(cnt):
#     temp = dict()
#     for k,vals in curr_batch.items():
#         temp[k] = vals[i]
#     curr_batch_reconstruct.append(temp)




# collator_output_dict = collator.torch_call(curr_batch_reconstruct)
# collator_output_dict['labels'][1]
# print(collator_output_dict['input_ids'][0][128:])
# print(collator_output_dict['labels'][0][128:])
# print(len(collator_output_dict['input_ids'][0]))
# print(len(collator_output_dict['labels'][0]))
# # # print(tokenizer.decode(collator_output_dict['input_ids'][1]))
# # # print(tokenizer.decode([42 if v==-100 else v for v in collator_output_dict['labels'][1].tolist() ]))


# # #collator([12,33])
# # #print(curr_batch.keys())
# # #len(curr_batch['input_ids'][0]), len(curr_batch['input_ids'][1])    
# # #curr_batch['attention_mask'][1] # the second mask has 1 digit 0

tensor([    16,     17,     14,     21,     15,    284,    400,   2442,     16,
            17,     14,     21,     15,     28,     15,     13,     17,   2452,
            15,     13,     17,    817,   9383,    624,  33978,    220,     20,
            15,   4420,     11,   1340,  15303,    220,     15,     13,     17,
           856,    220,     20,     15,    284,    400,   2442,     15,     13,
            17,      9,     20,     15,     28,     16,     15,   2452,     16,
            15,     13,    690,  26865,     29,    366,   9217,     29,    220,
            16,     15,    690,   9217,     29, 151645])
tensor([ -100,  -100,  -100, 41601,   685,  6088,   220,    19,    23,    14,
           17,   284,  1115,    19,    23,    14,    17,    28,    17,    19,
         2452,    17,    19, 26111,   304,  3217,   624,    45,  4212,   685,
         6088,   220,    19,    23,    10,    17,    19,   284,  1115,    19,
           23,    10,    17,    19,    28,    22,    17,  2452,    22, 

In [110]:
# Sanity check: 151645 is the trailing id in the collator output above; confirm which token it is.
tokenizer.decode([151645])

'<|im_end|>'

In [None]:

# Alternative loader kept for reference (bf16 + FlashAttention 2 on an explicit device):
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     ).to(policy_device)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Qwen's tokenizer already lists <|im_end|> in additional_special_tokens, so no new
# token has to be created; convert_tokens_to_ids only looks up the existing id.
eos_token_str = "<|im_end|>"
eos_id = tokenizer.convert_tokens_to_ids(eos_token_str)

# The stop/pad ids are generation defaults, not trainer-config fields: the original
# SFTConfig(pad_token_id=..., eos_token_id=...) call passed keywords that SFTConfig
# does not define. Set them on the model's generation config instead so that
# generate() stops on <|im_end|>.
model.generation_config.eos_token_id = eos_id
model.generation_config.pad_token_id = eos_id

# NOTE(review): sft_config is not consumed by the plain Trainer below; it is kept
# (with valid arguments only) in case other cells reference it.
sft_config = SFTConfig(
    output_dir="./sft_lora_results",
    max_seq_length=1024,
)


# LoRA configuration
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Wrap the base model with the LoRA adapters
peft_model = get_peft_model(model, lora_config)


# Training hyper-parameters (effective batch size per device = 2 * 8 = 16)
training_args = TrainingArguments(
    output_dir="./sft_lora_results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=True,
    logging_steps=1
)

# Train with the Trainer API.
# The original passed torch.utils.data.DataCollatorWithPadding, which does not exist
# (and torch is never imported). Use the completion-only collator built in the first
# cell: it pads the batch and masks everything up to "Assistant: <think>" in the labels.
# NOTE(review): SFTDataset is not defined in the visible notebook; it presumably yields
# tokenized examples (dicts with "input_ids" / "attention_mask") — confirm before running.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=SFTDataset("SFT_data.json"),
    data_collator=collator,
)

trainer.train()

In [11]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load a fresh tokenizer for the chat-template experiment.
# Note: this rebinds `tokenizer`, so the pad-token override applied to the earlier
# tokenizer object in the collator cell does not carry over to this instance.
model_name = "Qwen/Qwen2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
#tokenizer.decode(tokenizer.pad_token_id)

# Small system + user conversation used to inspect the tokenizer's built-in chat template.
messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate",},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
 ]
# add_generation_prompt=True appends the opening of the assistant turn
# (the decoded output below ends with "<|im_start|>assistant\n").
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
tokenized_chat


tensor([[151644,   8948,    198,   2610,    525,    264,  11657,   6236,   6331,
            879,   2677,  30580,    304,    279,   1707,    315,    264,  53966,
         151645,    198, 151644,    872,    198,   4340,   1657,  58332,    646,
            264,   3738,   8180,    304,    825,  11699,     30, 151645,    198,
         151644,  77091,    198]])

In [8]:
# Decode the first (only) row of token ids back to text to see the rendered chat template.
tokenizer.decode(tokenized_chat[0])

'<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHow many helicopters can a human eat in one sitting?<|im_end|>\n<|im_start|>assistant\n'

In [3]:
# setup_chat_format installs a chat template so that [{role: ..., content: ...}] messages
# can be rendered as plain text.
# It is not needed for Qwen: the tokenizer already handles such messages
# (see the apply_chat_template call above).
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from trl import SFTConfig, SFTTrainer
# from trl import setup_chat_format
# model_name = "Qwen/Qwen2.5-1.5B"
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# model, tokenizer = setup_chat_format(model, tokenizer)
# tokenizer.decode(tokenizer.eos_token_id)
# List the extra special tokens the tokenizer already defines (includes <|im_start|> / <|im_end|>).
print(tokenizer.additional_special_tokens)

['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']


In [5]:
# Inspect the maximum sequence length configured on the tokenizer.
tokenizer.model_max_length

131072