In [None]:
%load_ext autoreload

%autoreload 2

%env CUDA_VISIBLE_DEVICES=5

import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from trl import (SFTTrainer, SFTConfig)

from transformers import (AutoTokenizer, AutoModelForCausalLM, TrainingArguments)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset


In [None]:
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

# Local snapshot of Llama-3.2-3B-Instruct (path resolved relative to the working directory).
model_name = "huggingface/meta-llama/Llama-3.2-3B-Instruct"

# Tokenizer and full-precision base model; LoRA adapters are attached in the next cell.
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)


In [None]:
# LoRA adapters on every attention projection (q/k/v/o) and MLP projection
# (gate/up/down) of the Llama decoder blocks.
lora_rank = 8

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "down_proj", "up_proj",
    ],
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

# The KV cache is not needed while training (and is incompatible with gradient checkpointing).
model.config.use_cache = False

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)

In [None]:
# Local copy of the Tulu-3 SFT mixture; only its "train" split is loaded here.
datasets_id = "huggingface/allenai/tulu-3-sft-mixture"
raw_ds = load_dataset(path=datasets_id, split="train")

def filter_func(x, allowed_src=("math", "science", "history", "literature")):
    """Decide whether a Tulu-3 row is kept for fine-tuning.

    Args:
        x: one dataset row with a ``"messages"`` list of ``{"role", "content"}``
            dicts and a ``"source"`` string.
        allowed_src: substrings; the row is kept when any of them occurs in
            ``x["source"]``. Defaults to the original hard-coded list.

    Returns:
        True when the row has at least one non-blank assistant message AND its
        source matches one of ``allowed_src``; False otherwise.
    """
    messages = x["messages"] or []
    # Rows without a non-blank assistant turn give nothing to learn under
    # assistant-only loss. ``or ""`` tolerates a missing (None) content field.
    has_reply = any(
        m["role"] == "assistant" and (m["content"] or "").strip()
        for m in messages
    )
    if not has_reply:
        return False
    src = x["source"] or ""
    return any(allowed in src for allowed in allowed_src)


# Filter the raw split, then hold out 0.5% of the kept rows for evaluation.
# Each stage gets its own name so `raw_ds` keeps meaning "the unfiltered train split"
# instead of silently becoming a DatasetDict.
filtered_ds = raw_ds.filter(filter_func).flatten_indices()
assert len(filtered_ds) > 0, "filter_func kept no rows - check the allowed sources"

split_ds = filtered_ds.train_test_split(test_size=0.005, seed=42)

train_ds = split_ds["train"].shuffle(seed=42)
eval_ds = split_ds["test"]

# Tiny subset for smoke-testing the training loop (guarded for very small splits).
mini_ds = train_ds.select(range(min(10, len(train_ds))))

print("size of train dataset: ", len(train_ds))
print("size of eval dataset: ", len(eval_ds))

In [None]:
# Custom Llama-3.2 chat template. Later cells request assistant-token masks with it, so it
# presumably contains {% generation %} markers (TODO confirm).
# Read directly into `my_template` with no "" placeholder. If the file is missing, the cell
# fails loudly instead of leaving an empty template string for the later cells.
with open("llama-3.2.jinja2", "r", encoding="utf-8") as template_file:
    my_template = template_file.read()



In [None]:


# Render one training conversation with the custom template and request the
# assistant-token mask (1 = assistant-generated token, 0 = system/user token).
msg = mini_ds[0]["messages"]

processed = tokenizer.apply_chat_template(
    msg,
    return_dict=True,
    return_assistant_tokens_mask=True,
    chat_template=my_template,
)

print(processed["assistant_masks"])

# assistant_only_loss=True needs at least one assistant token per example. An all-zero
# mask means the template is missing its {% generation %} markers.
assert any(processed["assistant_masks"]), "custom template produced no assistant tokens"



In [None]:
# Sanity check: the custom template should render exactly the same text as the
# tokenizer's current (built-in) template. Lengths are printed first, then equality.
origin_str = tokenizer.apply_chat_template(
    msg,
    tokenize=False,
    add_generation_prompt=True,
)
print(len(origin_str))

my_str = tokenizer.apply_chat_template(
    msg,
    tokenize=False,
    add_generation_prompt=True,
    chat_template=my_template,
)
print(len(my_str))

print(origin_str == my_str)



In [None]:
# Keep only the token ids the mask marks as assistant output, then show the last one
# decoded (the cell's displayed result).
input_ids = torch.tensor(processed["input_ids"])
assistant_masks = torch.tensor(processed["assistant_masks"], dtype=torch.bool)

gen_ids = input_ids[assistant_masks]

tokenizer.decode(gen_ids[-1:])






In [None]:
# Training configuration. NOTE: training below uses the 10-example `mini_ds` smoke-test
# subset. At batch 2 x grad-accum 8 that is a single optimizer step, so the logging,
# eval and save intervals only take effect once `train_ds` is swapped in.
sft_args = SFTConfig(
    output_dir="./output/test",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    learning_rate=2e-4,
    num_train_epochs=1,
    bf16=True,
    gradient_checkpointing=True,
    # Under LoRA the base weights (including input embeddings) are frozen. With reentrant
    # checkpointing (the transformers default), gradients may not reach the adapters unless
    # enable_input_require_grads() was called. Non-reentrant checkpointing avoids this.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Loss only on assistant tokens; requires a chat template with {% generation %} markers.
    assistant_only_loss=True,
    packing=False,
)

# SFTTrainer renders conversations with the tokenizer's chat template. Install the custom
# template so assistant masks exist; with the stock template, assistant_only_loss gets none.
# This changes the tokenizer for the rest of the session, including when it is saved.
tokenizer.chat_template = my_template

trainer = SFTTrainer(
    model=model,
    args=sft_args,
    train_dataset=mini_ds,
    # Evaluate on the held-out split, not on the very examples being trained on.
    eval_dataset=eval_ds,
    processing_class=tokenizer,
)

In [None]:
# Run the fine-tuning loop; the returned TrainOutput is displayed as the cell result.
train_output = trainer.train()
train_output

In [None]:
# Save the fine-tuned weights. Training already ran in the previous cell. Calling
# trainer.train() again here (as this cell used to) would launch training a second
# time on top of the already-updated weights.
from peft import PeftModel

if isinstance(model, PeftModel):
    # Save only the LoRA adapter weights, not the full base model.
    model.save_pretrained("./out-lora-tulu3/adapter")
else:
    # Unexpected case (e.g. the model was never wrapped by PEFT): save the whole model.
    model.save_pretrained("./out-lora-tulu3/full")
tokenizer.save_pretrained("./out-lora-tulu3")
print("✅ Done. Saved to ./out-lora-tulu3")

In [None]:
from datasets import load_dataset

# Loads every split of the mixture as a DatasetDict. The result is unfiltered and
# independent of the `raw_ds` / `train_ds` pipeline.
ds = load_dataset(path="huggingface/allenai/tulu-3-sft-mixture")

# Optional on-disk cache (needs `from pathlib import Path` first):
# ds.save_to_disk(Path("huggingface") / "allenai/tulu-3-sft-mixture" / "test")
