In [None]:

from datasets import load_dataset

from transformers import PreTrainedTokenizerFast, PhiForCausalLM, TrainingArguments
from datasets import load_dataset
import pandas as pd
import time

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# 这个版本SFT使用的是SFTTrainer，map数据集比较慢，仅供参考

In [None]:
# Run configuration: SFT data file, tokenizer, starting checkpoint, and output location.
sft_file: str = './data/sft_train_data.parquet'      # parquet file read by load_dataset below
tokenizer_dir: str = './model_save/tokenizer/'        # directory passed to PreTrainedTokenizerFast.from_pretrained
sft_from_checkpoint_file: str = './model_save/pre/'   # checkpoint PhiForCausalLM is loaded from before SFT
model_save_dir: str = './model_save/sft/'             # TrainingArguments.output_dir and final save_model target

# NOTE(review): nothing visible reads this value; the SFTTrainer cell passes
# max_seq_length=384 directly. Confirm which sequence length is intended.
max_seq_len: int = 320

In [None]:
# Load the SFT parquet file as a single 'train' split; the HF datasets cache goes to
# '.cache' relative to the current working directory.
dataset = load_dataset(
    'parquet',
    data_files=sft_file,
    split='train',
    cache_dir='.cache',
)

In [None]:
# Show the loaded dataset object as this cell's output for a quick inspection.
dataset

In [None]:
# Peek at the first two records. Slicing the dataset yields a batch (column name -> list of
# values), which is the shape formatting_prompts_func is later called with.
samples = dataset[:2]
print(samples)

In [None]:
# Load the project's fast tokenizer from disk and show len(tokenizer) as the cell output.
tokenizer = PreTrainedTokenizerFast.from_pretrained(pretrained_model_name_or_path=tokenizer_dir)
len(tokenizer)

In [None]:
def formatting_prompts_func(example: dict[str, list[str]]) -> list[str]:
    """Render a batch of instruction/output pairs into SFT training strings.

    The argument is a *batch*: a mapping from column name to a list of values
    (the same shape as ``dataset[0:2]``). It is not a list of row dicts.

    Each row becomes the literal '[BOS]' marker, the '##提问:' (question) marker,
    a newline, the instruction, a newline, the '##回答:' (answer) marker, a
    newline, the output, and finally the literal '[EOS]' marker.
    NOTE(review): '[BOS]'/'[EOS]' are assumed to be the tokenizer's special-token
    strings; confirm against the tokenizer config.

    Args:
        example: batch containing equal-length 'instruction' and 'output' lists.
            Other columns are ignored.

    Returns:
        One formatted string per row, in batch order (empty list for an empty batch).

    Raises:
        ValueError: if the 'instruction' and 'output' columns differ in length.
    """
    instructions = example['instruction']
    outputs = example['output']
    # Fail loudly on a malformed batch instead of raising IndexError or silently
    # ignoring extra outputs, as a range(len(instructions)) index loop would.
    if len(instructions) != len(outputs):
        raise ValueError(
            f"'instruction' and 'output' columns differ in length: "
            f"{len(instructions)} != {len(outputs)}"
        )
    return [
        f"[BOS]##提问:\n{instruction}\n##回答:\n{output}[EOS]"
        for instruction, output in zip(instructions, outputs)
    ]

# Sanity check: preview the prompt strings built for the two sample rows loaded above.
formatting_prompts_func(samples)

In [None]:
# Section markers for the question ("提问") and answer ("回答") parts of each prompt.
# They must match the markers written by formatting_prompts_func.
instruction_template = "##提问:"
response_template = "##回答:"

# NOTE(review): judging by its name, this collator presumably limits the loss to the text
# after response_template. The templates are matched by token ids, so confirm that the ids
# printed in the next cell also appear inside fully tokenized prompts.
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    instruction_template=instruction_template,
    tokenizer=tokenizer,
    mlm=False,
)

In [None]:
# Tokenize the two templates on their own and print their ids, one id list per template.
template_token_ids = tokenizer([instruction_template, response_template])['input_ids']
print(template_token_ids)

In [None]:

# Start SFT from the pretrained checkpoint and report its parameter count in millions.
model = PhiForCausalLM.from_pretrained(sft_from_checkpoint_file)

model_size = sum(param.numel() for param in model.parameters())
print(f"Phi2 size: {model_size / 1_000_000:.2f}M parameters")

In [None]:
# Training hyper-parameters for the SFT run, grouped by concern.
args = TrainingArguments(
    output_dir=model_save_dir,
    # schedule: 8 per device x 8 accumulation steps = 64 sequences per device per optimizer step
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    # optimizer and precision
    optim="adafactor",
    learning_rate=5e-5,
    warmup_steps=1000,
    weight_decay=0.1,
    bf16=True,
    # checkpointing
    save_steps=2000,
    save_total_limit=3,
    # logging
    logging_steps=10,
    logging_first_step=True,
    log_level='info',
    report_to='tensorboard',
)

# NOTE(review): max_seq_length is hard-coded to 384 here, while the config cell defines an
# otherwise unused max_seq_len = 320. The original inline comment on 384 only said "eos";
# confirm which length is intended.
trainer = SFTTrainer(
    model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
    formatting_func=formatting_prompts_func,
    tokenizer=tokenizer,
    max_seq_length=384,
)


In [None]:
# Run training. To continue an interrupted run instead, call
# trainer.train(resume_from_checkpoint=True).
trainer.train()

In [None]:
import os  # only this cell needs it

# Save the final model first, so that a failure while writing the log CSV cannot cost
# the result of the training run.
trainer.save_model(model_save_dir)

# Dump trainer.state.log_history to a timestamped CSV under ./logs. Create the directory
# up front, because to_csv raises if it does not exist.
log_dir = './logs'
os.makedirs(log_dir, exist_ok=True)
loss_log = pd.DataFrame(trainer.state.log_history)
loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")