In [None]:
%cd ..
! pip install -q bitsandbytes deepspeed huggingface_hub
! pip install -q -r requirements.txt

In [None]:
import os

import huggingface_hub

# Authenticate with the Hugging Face Hub before models/datasets are fetched.
# If HF_TOKEN is set in the environment, it is used so the notebook can run
# without an interactive prompt; when it is unset, token=None gives the same
# interactive login as calling login() with no arguments.
huggingface_hub.login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Run configuration. Every value below is substituted into the
# `src/train_bash.py` command in the training cell; change settings here
# rather than editing that command.

# Base model and dataset
model_name_or_path = "stabilityai/japanese-stablelm-base-gamma-7b"
# Comma-separated dataset names, passed as `--dataset`.
dataset = "wikipedia_ja_2022,slim_pajama_en,culturax_ja,wikipedia_en_2022,code_stack_en"

# Training parameters
mix_strategy = "interleave_over"
# NOTE(review): confirm that train_bash.py uses `--max_length` as the training
# sequence length; some versions of this script read `--cutoff_len` for that
# and treat `max_length` as a generation-only setting.
max_length = 4096
lora_rank = 64
lora_alpha = 128.0
lora_dropout = 0.05
lora_target = "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
# Extra (non-LoRA) modules, passed as `--additional_target`.
additional_target = "lm_head,embed_tokens"
output_dir = "./pt_outputs"
logging_dir = "./pt_logs"
overwrite_output_dir = True
# Referenced by the training cell. It used to be undefined, which made IPython
# silently skip `{...}` expansion of the *entire* shell command.
overwrite_cache = False
streaming = True
# Streamed datasets have no known length, so the HF Trainer needs a positive
# max_steps; it is what bounds this run, not num_train_epochs below.
max_steps = int(2.5e5)
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
lr_scheduler_type = "cosine"
logging_steps = 10
save_steps = 200
learning_rate = 2e-4
num_train_epochs = 1.0
plot_loss = True

# Performance parameters
use_flash_attn = True
use_bf16 = True
use_fp16 = True  # only takes effect when use_bf16 is False

# bf16 takes precedence: the trainer rejects fp16 and bf16 together, so fp16 is
# switched off whenever bf16 is requested.
use_fp16 = use_fp16 and not use_bf16

In [None]:
%env CUDA_VISIBLE_DEVICES=0

! python src/train_bash.py \
    --stage sft \
    --model_name_or_path {model_name_or_path} \
    --do_train \
    --dataset {dataset} \
    --template default \
    --mix_strategy {mix_strategy} \
    --max_length {max_length} \
    --finetuning_type lora \
    --lora_rank {lora_rank} \
    --lora_alpha {lora_alpha} \
    --lora_dropout {lora_dropout} \
    --lora_target {lora_target} \
    --additional_target {additional_target} \
    --output_dir {output_dir} \
    --logging_dir {logging_dir} \
    {'--overwrite_cache' if overwrite_cache else ''} \
    {'--overwrite_output_dir' if overwrite_output_dir else ''} \
    {'--streaming' if streaming else ''} \
    --max_steps {max_steps} \
    --per_device_train_batch_size {per_device_train_batch_size} \
    --gradient_accumulation_steps {gradient_accumulation_steps} \
    --lr_scheduler_type {lr_scheduler_type} \
    --logging_steps {logging_steps} \
    --save_steps {save_steps} \
    --learning_rate {learning_rate} \
    --num_train_epochs {num_train_epochs} \
    --plot_loss \
    {'--plot_loss' if plot_loss else ''} \
    {'--flash_attn' if use_flash_attn else ''} \
    {'--bf16' if use_bf16 else ''} \
    {'--fp16' if use_fp16 else ''}
