GKD training: abnormally high loss and grad_norm #6600

@haorannlp

ms-swift 3.11.0dev0: GKD training with vLLM (colocate mode) produces abnormally high loss and grad_norm values. Training command:

# pip3 install -e .

export WANDB_PROJECT="my_ms_swift_debug"
export WANDB_EXP_NAME="megatron.opd.open-thoughts3-1.2M-opd-test"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NNODES=1 \
NODE_RANK=0 \
swift rlhf \
    --rlhf_type gkd \
    --model /mnt/hdfs/lhr.217_syd/model/Qwen3-8B-Base \
    --teacher_model /mnt/hdfs/lhr.217_syd/model/Qwen3-8B \
    \
    --train_type full \
    --dataset /mnt/hdfs/lhr.217_syd/data/sft/OpenThoughts3-1.2M/OpenThoughts3-1.2M.jsonl \
    --split_dataset_ratio 0.01 \
    --dataloader_num_workers 128 \
    --dataset_num_proc 4 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    \
    --seq_kd false \
    --lmbda 1.0 \
    --beta 1.0 \
    --sft_alpha 0 \
    --learning_rate 1e-5 \
    --lr_scheduler_type constant \
    \
    --deepspeed zero3_offload \
    --offload_teacher_model true \
    --teacher_deepspeed zero3_offload \
    --attn_impl flash_attn \
    --use_liger_kernel true \
    \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.2 \
    --vllm_tensor_parallel_size 8 \
    --vllm_max_model_len 32768 \
    --offload_model false \
    --sleep_level 0 \
    \
    --max_length 16000 \
    --max_completion_length 8192 \
    --temperature 0.9 \
    --truncation_strategy left \
    \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --warmup_ratio 0.05 \
    --per_device_eval_batch_size 1 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 2 \
    --logging_steps 1 \
    --output_dir /mnt/hdfs/lhr.217_syd/experiments/opd/testrun \
    --logging_dir ./local_tensorboard \
    --save_only_model true \
    --report_to wandb \
    --log_completions true

Training log (loss and grad_norm):

{'loss': 240.5, 'grad_norm': 2309.76513672, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '1/148500', 'percentage': '0.00%', 'elapsed_time': '36s', 'remaining_time': '62d 9h 31m 39s', 'memory(GiB)': 28.59, 'train_speed(iter/s)': 0.027545}

Train:   0%|          | 1/148500 [00:36<1497:22:22, 36.30s/it]
Train:   0%|          | 1/148500 [00:36<1497:22:22, 36.30s/it]INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 2/148500 [01:29<1904:05:14, 46.16s/it]
                                                              
{'loss': 261.25, 'grad_norm': 22868.16210938, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '2/148500', 'percentage': '0.00%', 'elapsed_time': '1m 29s', 'remaining_time': '76d 19h 8m 25s', 'memory(GiB)': 28.63, 'train_speed(iter/s)': 0.02238}

Train:   0%|          | 2/148500 [01:29<1904:05:14, 46.16s/it]
Train:   0%|          | 2/148500 [01:29<1904:05:14, 46.16s/it]INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:35:47 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 3/148500 [02:23<2056:35:53, 49.86s/it]
                                                              
{'loss': 473.25, 'grad_norm': 6887.65527344, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '3/148500', 'percentage': '0.00%', 'elapsed_time': '2m 23s', 'remaining_time': '82d 6h 47m 1s', 'memory(GiB)': 28.63, 'train_speed(iter/s)': 0.020888}

Train:   0%|          | 3/148500 [02:23<2056:35:53, 49.86s/it]
Train:   0%|          | 3/148500 [02:23<2056:35:53, 49.86s/it]INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:36:42 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 4/148500 [03:16<2105:58:37, 51.06s/it]
                                                              
{'loss': 243.5, 'grad_norm': 3961.19458008, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '4/148500', 'percentage': '0.00%', 'elapsed_time': '3m 16s', 'remaining_time': '84d 10h 30m 6s', 'memory(GiB)': 28.63, 'train_speed(iter/s)': 0.020355}

Train:   0%|          | 4/148500 [03:16<2105:58:37, 51.06s/it]
Train:   0%|          | 4/148500 [03:16<2105:58:37, 51.06s/it]INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:37:34 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 5/148500 [04:11<2164:04:35, 52.46s/it]
                                                              
{'loss': 429.0, 'grad_norm': 4230.05908203, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '5/148500', 'percentage': '0.00%', 'elapsed_time': '4m 11s', 'remaining_time': '86d 10h 36m 59s', 'memory(GiB)': 28.63, 'train_speed(iter/s)': 0.019883}

Train:   0%|          | 5/148500 [04:11<2164:04:35, 52.46s/it]
Train:   0%|          | 5/148500 [04:11<2164:04:35, 52.46s/it]INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:38:29 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 6/148500 [05:08<2231:48:03, 54.11s/it]
                                                              
{'loss': 805.0, 'grad_norm': 7340.90771484, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '6/148500', 'percentage': '0.00%', 'elapsed_time': '5m 8s', 'remaining_time': '88d 10h 42m 57s', 'memory(GiB)': 45.52, 'train_speed(iter/s)': 0.019432}

Train:   0%|          | 6/148500 [05:08<2231:48:03, 54.11s/it]
Train:   0%|          | 6/148500 [05:08<2231:48:03, 54.11s/it]INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:39:27 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 7/148500 [06:05<2273:25:39, 55.12s/it]
                                                              
{'loss': 495.0, 'grad_norm': 4887.18505859, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '7/148500', 'percentage': '0.00%', 'elapsed_time': '6m 5s', 'remaining_time': '89d 20h 28m 54s', 'memory(GiB)': 45.52, 'train_speed(iter/s)': 0.019127}

Train:   0%|          | 7/148500 [06:06<2273:25:39, 55.12s/it]
Train:   0%|          | 7/148500 [06:06<2273:25:39, 55.12s/it]INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:40:24 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 8/148500 [07:03<2307:28:42, 55.94s/it]
                                                              
{'loss': 692.1875, 'grad_norm': 7491.97802734, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '8/148500', 'percentage': '0.01%', 'elapsed_time': '7m 3s', 'remaining_time': '91d 0h 29m 2s', 'memory(GiB)': 45.74, 'train_speed(iter/s)': 0.018882}

Train:   0%|          | 8/148500 [07:03<2307:28:42, 55.94s/it]
Train:   0%|          | 8/148500 [07:03<2307:28:42, 55.94s/it]INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:41:22 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 9/148500 [08:01<2331:39:09, 56.53s/it]
                                                              
{'loss': 1081.5625, 'grad_norm': 13781.66308594, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '9/148500', 'percentage': '0.01%', 'elapsed_time': '8m 1s', 'remaining_time': '91d 22h 42m 43s', 'memory(GiB)': 45.74, 'train_speed(iter/s)': 0.018692}

Train:   0%|          | 9/148500 [08:01<2331:39:09, 56.53s/it]
Train:   0%|          | 9/148500 [08:01<2331:39:09, 56.53s/it]INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:42:19 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 10/148500 [08:59<2348:04:44, 56.93s/it]
                                                               
{'loss': 810.3125, 'grad_norm': 8102.26953125, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '10/148500', 'percentage': '0.01%', 'elapsed_time': '8m 59s', 'remaining_time': '92d 16h 31m 1s', 'memory(GiB)': 45.76, 'train_speed(iter/s)': 0.018542}

Train:   0%|          | 10/148500 [08:59<2348:04:44, 56.93s/it]
Train:   0%|          | 10/148500 [08:59<2348:04:44, 56.93s/it]INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:43:17 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 11/148500 [09:57<2360:45:43, 57.23s/it]
                                                               
{'loss': 1323.25, 'grad_norm': 11757.16601562, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '11/148500', 'percentage': '0.01%', 'elapsed_time': '9m 57s', 'remaining_time': '93d 7h 30m 26s', 'memory(GiB)': 45.76, 'train_speed(iter/s)': 0.018418}

Train:   0%|          | 11/148500 [09:57<2360:45:43, 57.23s/it]
Train:   0%|          | 11/148500 [09:57<2360:45:43, 57.23s/it]INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:44:15 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 12/148500 [10:55<2367:42:27, 57.40s/it]
                                                               
{'loss': 1919.0, 'grad_norm': 17126.91015625, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '12/148500', 'percentage': '0.01%', 'elapsed_time': '10m 55s', 'remaining_time': '93d 19h 30m 12s', 'memory(GiB)': 45.76, 'train_speed(iter/s)': 0.01832}

Train:   0%|          | 12/148500 [10:55<2367:42:27, 57.40s/it]
Train:   0%|          | 12/148500 [10:55<2367:42:27, 57.40s/it]INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:45:13 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 13/148500 [11:52<2369:43:50, 57.45s/it]
                                                               
{'loss': 1916.0, 'grad_norm': 16510.4921875, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '13/148500', 'percentage': '0.01%', 'elapsed_time': '11m 52s', 'remaining_time': '94d 4h 57m 7s', 'memory(GiB)': 45.76, 'train_speed(iter/s)': 0.018243}

Train:   0%|          | 13/148500 [11:52<2369:43:50, 57.45s/it]
Train:   0%|          | 13/148500 [11:52<2369:43:50, 57.45s/it]INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:46:11 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 14/148500 [12:50<2372:50:59, 57.53s/it]
                                                               
{'loss': 1807.0625, 'grad_norm': 16647.33789062, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '14/148500', 'percentage': '0.01%', 'elapsed_time': '12m 50s', 'remaining_time': '94d 13h 26m 29s', 'memory(GiB)': 45.89, 'train_speed(iter/s)': 0.018175}

Train:   0%|          | 14/148500 [12:50<2372:50:59, 57.53s/it]
Train:   0%|          | 14/148500 [12:50<2372:50:59, 57.53s/it]INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache
INFO 11-14 13:47:08 [block_pool.py:321] Successfully reset prefix cache

Train:   0%|          | 15/148500 [13:47<2374:54:04, 57.58s/it]
                                                               
{'loss': 1989.0, 'grad_norm': 15605.45703125, 'learning_rate': 1e-05, 'epoch': 0.0, 'global_step/max_steps': '15/148500', 'percentage': '0.01%', 'elapsed_time': '13m 48s', 'remaining_time': '94d 20h 46m 38s', 'memory(GiB)': 45.89, 'train_speed(iter/s)': 0.018116}

Increasing the training batch size to 128 gives the same result.
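
To make the trend in the log above easier to read, the metric lines can be separated from the vLLM `INFO` and progress-bar noise with a few lines of Python. This is only a sketch: `parse_metrics` is a hypothetical helper (not part of ms-swift), and it assumes each metric record is printed on its own line as a Python dict literal, as in the log above. The two sample records are steps 1 and 15 from that log, shortened to the relevant keys.

```python
import ast


def parse_metrics(log_text):
    """Return (step, loss, grad_norm) tuples from dict-literal log lines.

    Hypothetical helper: assumes every metric record sits on its own line
    and contains 'loss', 'grad_norm' and 'global_step/max_steps'.
    """
    rows = []
    for line in log_text.splitlines():
        line = line.strip()
        # Skip progress bars, INFO lines and blanks.
        if not (line.startswith("{") and line.endswith("}")):
            continue
        try:
            record = ast.literal_eval(line)
        except (ValueError, SyntaxError):
            continue
        if not isinstance(record, dict):
            continue
        if not {"loss", "grad_norm", "global_step/max_steps"} <= record.keys():
            continue
        # 'global_step/max_steps' looks like '15/148500'; keep the step part.
        step = int(record["global_step/max_steps"].split("/")[0])
        rows.append((step, record["loss"], record["grad_norm"]))
    return rows


# Steps 1 and 15 from the log above, shortened to the relevant keys.
sample = """\
{'loss': 240.5, 'grad_norm': 2309.76513672, 'learning_rate': 1e-05, 'global_step/max_steps': '1/148500'}
INFO 11-14 13:34:54 [block_pool.py:321] Successfully reset prefix cache
{'loss': 1989.0, 'grad_norm': 15605.45703125, 'learning_rate': 1e-05, 'global_step/max_steps': '15/148500'}
"""

for step, loss, grad_norm in parse_metrics(sample):
    print(f"step {step:>3}  loss {loss:>10.2f}  grad_norm {grad_norm:>12.2f}")
```

Running this over the full log should list one row per step, which makes it straightforward to compare runs (for example, the batch size 1 and batch size 128 configurations).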
