Megatron SFT Qwen3 Omni is Slow #5990

@mertunsall

Description

Training speed is extremely slow on 8xH100 with the config below (the only difference from my usual setup is long-context training):

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
megatron sft \
    --load 'Qwen3-Omni-30B-A3B-Instruct-mcore' \
    --dataset 'dataset_id' \
    --train_type lora \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --tensor_model_parallel_size 2 \
    --expert_model_parallel_size 2 \
    --context_parallel_size 2 \
    --moe_grouped_gemm true \
    --moe_shared_expert_overlap true \
    --moe_aux_loss_coeff 1e-3 \
    --micro_batch_size 1 \
    --global_batch_size 32 \
    --packing true \
    --freeze_llm false \
    --freeze_vit true \
    --freeze_aligner true \
    --split_dataset_ratio 0.01 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 4e-5 \
    --max_epochs 1 \
    --lr_warmup_fraction 0.05 \
    --min_lr 4e-6 \
    --save 'Qwen3-Omni-30B-A3B-Instruct-280925-lora' \
    --eval_interval 200 \
    --save_interval 200 \
    --max_length 32768 \
    --num_workers 100 \
    --dataset_num_proc 100 \
    --no_save_optim true \
    --no_save_rng true \
    --sequence_parallel true \
    --use_flash_attn true \
    --use_hf true

I get ~15 minutes for 5 iterations (about 3 minutes per iteration).
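
For reference, a minimal back-of-envelope sketch of the implied throughput. It assumes (not stated in the issue) that with `--packing true` each of the 32 packed sequences in a global batch is close to the full `max_length` of 32768 tokens, so the result is an upper bound on tokens actually processed per second:

```python
# Rough throughput implied by the reported timing, under the assumption
# that every packed sequence fills max_length (upper bound).

iterations = 5
wall_clock_minutes = 15          # ~15 minutes reported for 5 iterations
global_batch_size = 32
max_length = 32_768
num_gpus = 8                     # 8xH100

seconds_per_iter = wall_clock_minutes * 60 / iterations
tokens_per_iter = global_batch_size * max_length
tokens_per_second = tokens_per_iter / seconds_per_iter

print(f"{seconds_per_iter:.0f} s/iter")
print(f"<= {tokens_per_second:,.0f} tokens/s total, "
      f"<= {tokens_per_second / num_gpus:,.0f} tokens/s per GPU")
# -> 180 s/iter, <= ~5,825 tokens/s total, <= ~728 tokens/s per GPU
```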
