46 changes: 46 additions & 0 deletions cookbook/megatron/qwen3_5.py
@@ -0,0 +1,46 @@
from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor

device_mesh = DeviceMesh.from_sizes(dp_size=4, tp_size=1, pp_size=1, ep_size=4)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

logger = get_logger()

MODEL_ID = 'Qwen/Qwen3.5-35B-A3B'

def train():
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    dataset.set_template('Template', model_id=MODEL_ID)
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=4)

    model = MegatronModel(model_id=MODEL_ID)
    lora_config = LoraConfig(r=8, lora_alpha=16, target_modules='all-linear')
    model.add_adapter_to_model('default', lora_config)
    model.set_optimizer(optimizer_cls='default', lr=1e-4)
    model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=2, lr_decay_steps=len(dataloader))
    logger.info(get_device_placement())
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')

    for step, batch in enumerate(dataloader):
        model.forward_backward(inputs=batch)
        model.clip_grad_and_step()
        if step % 5 == 0:
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Step {step}/{len(dataloader)}, metric: {metric}')

    # NOTE: you should merge LoRA for the Qwen3.5 model when using Megatron
    model.save('last-checkpoint', merge_lora=True)
    logger.info('Training completed.')


if __name__ == '__main__':
    train()
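Since save('last-checkpoint', merge_lora=True) writes full HF-format weights rather than a PEFT adapter (see the save() changes in megatron.py below), the resulting checkpoint should be loadable with stock Hugging Face tooling. A minimal sketch, assuming the default output_dir of 'output' and that the merged Qwen3.5 checkpoint loads through AutoModelForCausalLM (not verified in this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

# 'output/last-checkpoint' follows from save()'s default output_dir ('output')
# plus the checkpoint name passed in the cookbook script above.
checkpoint_dir = 'output/last-checkpoint'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, torch_dtype='auto', device_map='auto')

prompt = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'Who are you?'}],  # hypothetical self-cognition probe
    tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0], skip_special_tokens=True))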
58 changes: 37 additions & 21 deletions src/twinkle/model/megatron/args.py
@@ -107,6 +107,7 @@ class TwinkleMegatronArgs:
num_experts: int = 0
num_experts_per_tok: int = 2
shared_expert_intermediate_size: int = 0
moe_router_enable_expert_bias: bool = False

# =========================================================================
# Training/inference settings
@@ -137,9 +138,6 @@ class TwinkleMegatronArgs:
# =========================================================================
merge_lora: bool = False
target_modules: List[str] = field(default_factory=list)
freeze_llm: bool = False
freeze_vit: bool = False
freeze_aligner: bool = False

# =========================================================================
# FP8 quantization settings
@@ -160,7 +158,6 @@
# =========================================================================
untie_embeddings_and_output_weights: bool = True
max_shard_size: str = '5GB'
llm_model_type: str = 'gpt' # For transformers 5.0 compatibility
use_cpu_initialization: bool = False

def __post_init__(self):
@@ -260,6 +257,10 @@ def head_dim(self) -> int:
def intermediate_size(self) -> int:
return self.ffn_hidden_size

@property
def moe_shared_expert_intermediate_size(self) -> int:
return self.shared_expert_intermediate_size

@property
def num_query_groups(self) -> int:
"""Alias for num_key_value_heads (Megatron naming)."""
@@ -330,9 +331,12 @@ def from_hf_config(
# Get rope_scaling
rope_scaling = getattr(text_config, 'rope_scaling', None)

# Detect multimodal model
model_type = getattr(hf_config, 'model_type', 'qwen2')
is_multimodal = 'vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower()

# Detect multimodal model from the registered MegatronModelMeta
from .model.register import get_megatron_model_meta
model_meta = get_megatron_model_meta(model_type)
is_multimodal = model_meta.is_multimodal if model_meta is not None else False

# Determine QKV bias
if hasattr(text_config, 'attention_bias'):
@@ -435,7 +439,6 @@ def create_model(self, ) -> List[nn.Module]:
if self._model is not None:
return self._model
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import AttnBackend

@@ -611,24 +614,37 @@ def _get_base_model(m):
if exists('megatron_core>=0.13'):
config.expert_tensor_parallel_size = self.etp_size

# Save transformer config for later use (e.g., DDP wrapping)
if mg_config_dict.get('use_shared_expert_gate'):
config.moe_use_shared_expert_gate = True
if mg_config_dict.get('rotary_interleaved'):
config.rotary_interleaved = True
partial_rotary_factor = mg_config_dict.get('partial_rotary_factor')
if partial_rotary_factor is not None:
config.rotary_percent = partial_rotary_factor
config.apply_rope_fusion = False
mrope_section = mg_config_dict.get('mrope_section')
if mrope_section is not None:
config.mrope_section = mrope_section
if mg_config_dict.get('mrope_interleaved'):
config.mrope_interleaved = True

self.config = config

# Get layer spec - enable moe_grouped_gemm for MoE models
moe_grouped_gemm = num_experts > 0
try:
layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=mg_config_dict.get('num_experts'),
moe_grouped_gemm=moe_grouped_gemm,
qk_layernorm=mg_config_dict.get('qk_layernorm', False),
)
except (ImportError, AttributeError):
raise RuntimeError(
'TransformerEngine is not installed or not compatible with this version of Megatron-Core.')
# Delegate model-specific config & layer spec construction to the loader
loader = model_meta.loader() if model_meta else None
if loader is not None:
loader.post_config(config, self, mg_config_dict)
layer_spec = loader.get_layer_spec(config, self, mg_config_dict)
else:
from .model.register import MegatronModelLoader
default_loader = MegatronModelLoader()
default_loader.post_config(config, self, mg_config_dict)
layer_spec = default_loader.get_layer_spec(config, self, mg_config_dict)

# Create model
max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096)
rotary_base = mg_config_dict.get('rotary_base', 10000)
position_embedding_type = mg_config_dict.get('position_embedding_type', 'rope')
extra_init_args = {}
if hasattr(hf_config,
'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling:
@@ -651,7 +667,7 @@ def _get_base_model(m):
post_process=mpu.is_pipeline_last_stage(**extra_kwargs),
parallel_output=True,
share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False),
position_embedding_type='rope',
position_embedding_type=position_embedding_type,
rotary_base=rotary_base,
**extra_init_args)
model.append(_model)
@@ -666,7 +682,7 @@ def _get_base_model(m):
post_process=mpu.is_pipeline_last_stage(),
parallel_output=True,
share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False),
position_embedding_type='rope',
position_embedding_type=position_embedding_type,
rotary_base=rotary_base,
**extra_init_args,
)
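The args.py change above replaces the hard-coded get_gpt_layer_with_transformer_engine_spec call with a lookup of the registered MegatronModelMeta, whose loader now owns model-specific config tweaks (post_config) and layer-spec construction (get_layer_spec). A rough sketch of what a per-model loader could look like under that contract; the method names and call signatures come from this diff, while the subclass, its Qwen3.5-specific details, and the base-class behavior are assumptions:

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec

from .register import MegatronModelLoader  # registry helpers extended in this PR


class Qwen3_5MoeLoader(MegatronModelLoader):  # hypothetical subclass for illustration

    def post_config(self, config, creator, mg_config_dict):
        # Apply model-specific knobs to the TransformerConfig before the model is built,
        # e.g. the new moe_router_enable_expert_bias flag added to TwinkleMegatronArgs.
        config.moe_router_enable_expert_bias = mg_config_dict.get('moe_router_enable_expert_bias', False)

    def get_layer_spec(self, config, creator, mg_config_dict):
        # Mirrors what the removed inline code did: grouped GEMM for MoE, optional QK layernorm.
        num_experts = mg_config_dict.get('num_experts') or 0
        return get_gpt_layer_with_transformer_engine_spec(
            num_experts=mg_config_dict.get('num_experts'),
            moe_grouped_gemm=num_experts > 0,
            qk_layernorm=mg_config_dict.get('qk_layernorm', False),
        )

Registration would presumably go through register_megatron_model with a MegatronModelMeta carrying is_multimodal and loader, but those constructor details are not visible in this diff.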
36 changes: 32 additions & 4 deletions src/twinkle/model/megatron/megatron.py
@@ -819,6 +819,7 @@ def save(self,
output_dir: Optional[str] = None,
interval: int = 1,
save_optimizer: bool = False,
merge_lora: bool = False,
**kwargs):
"""Save model checkpoint.

@@ -832,6 +833,9 @@
interval: Save every *interval* steps.
save_optimizer: If True, save optimizer + lr_scheduler + RNG state
alongside the HF weights for checkpoint resumption.
merge_lora: If True, merge LoRA adapters into base weights and save
the full merged model instead of PEFT adapter format. The merge
is reversed after saving so training can continue.
**kwargs: Additional arguments forwarded to the underlying save
methods (e.g. ``adapter_name``).
"""
@@ -846,8 +850,16 @@
output_dir = 'output'
checkpoint_dir = os.path.join(output_dir, name)

# Always save HF-format weights (for inference / deployment).
self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
is_lora = (optimizer_config.adapter_name != _default_adapter_name)

if merge_lora and is_lora:
self._merge_lora_adapters(optimizer_config.adapter_name)
self._save_hf_format(checkpoint_dir, _default_adapter_name)
self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
self._unmerge_lora_adapters()
else:
self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)

# Optionally save mcore optimizer state (for training resumption).
if save_optimizer:
@@ -857,8 +869,6 @@
**kwargs,
)

self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)

# Final synchronization to ensure all ranks complete save.
if dist.is_initialized():
dist.barrier()
@@ -1160,6 +1170,24 @@ def _read_iteration(tracker_path: str) -> int:
iteration = iters_cuda[0].item()
return iteration

def _merge_lora_adapters(self, adapter_name: str = 'default'):
"""Merge LoRA adapters into base model weights."""
from .tuners.lora import LoraParallelLinear
with torch.no_grad():
for model in self.strategy.unwrap_model(self.model):
for module in model.modules():
if isinstance(module, (LoraParallelLinear, LoraLinear)):
module.merge(adapter_names=[adapter_name])

def _unmerge_lora_adapters(self):
"""Unmerge LoRA adapters to restore training state."""
from .tuners.lora import LoraParallelLinear
with torch.no_grad():
for model in self.strategy.unwrap_model(self.model):
for module in model.modules():
if isinstance(module, (LoraParallelLinear, LoraLinear)):
module.unmerge()

def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None):
"""Save in HuggingFace format using bridge adapter.

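For reference, merge()/unmerge() on a LoRA linear are an in-place fold of the low-rank update into the base weight and its exact reversal, which is why _merge_lora_adapters can be undone after _save_hf_format and training resumes unchanged. A standalone sketch of the arithmetic only; the real LoraParallelLinear additionally handles tensor-parallel sharding and per-adapter scaling bookkeeping:

import torch

def lora_merge(weight: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor,
               alpha: float, r: int) -> torch.Tensor:
    # W' = W + (alpha / r) * B @ A  -- the adapter delta is folded into the base weight.
    return weight + (alpha / r) * (lora_B @ lora_A)

def lora_unmerge(merged: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor,
                 alpha: float, r: int) -> torch.Tensor:
    # Subtracting the identical delta restores the pre-merge weight (up to fp error),
    # so the checkpoint on disk is fully merged while the training state is untouched.
    return merged - (alpha / r) * (lora_B @ lora_A)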
2 changes: 1 addition & 1 deletion src/twinkle/model/megatron/model/__init__.py
@@ -1,4 +1,4 @@
from . import gpts, mm_gpts
from .constant import MegatronModelType
from .gpt_bridge import GPTBridge
from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
from .register import MegatronModelLoader, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
3 changes: 3 additions & 0 deletions src/twinkle/model/megatron/model/constant.py
@@ -14,6 +14,8 @@ class MLLMModelType:
qwen2_5_vl = 'qwen2_5_vl'
qwen3_vl = 'qwen3_vl'
qwen3_vl_moe = 'qwen3_vl_moe'
qwen3_5 = 'qwen3_5'
qwen3_5_moe = 'qwen3_5_moe'


class ModelType(LLMModelType, MLLMModelType):
@@ -29,6 +31,7 @@ class MLLMMegatronModelType:
qwen2_vl = 'qwen2_vl'
qwen2_5_vl = 'qwen2_5_vl'
qwen3_vl = 'qwen3_vl'
qwen3_5 = 'qwen3_5'


class MegatronModelType(LLMMegatronModelType, MLLMMegatronModelType):
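The new qwen3_5 constants are the keys under which a MegatronModelMeta is expected to be registered; from_hf_config in args.py resolves them via get_megatron_model_meta to decide is_multimodal and to pick the loader. A small lookup sketch, assuming the package path implied by the file layout and that a meta has been registered for this type:

from twinkle.model.megatron.model import MegatronModelType, get_megatron_model_meta

meta = get_megatron_model_meta(MegatronModelType.qwen3_5)
if meta is not None:
    print(meta.is_multimodal)  # consulted by TwinkleMegatronArgs.from_hf_config
    loader = meta.loader()     # supplies post_config() and get_layer_spec() during create_model()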