# Distributed Tensor Parallelism with Megatron-Core GPT Model

This example demonstrates how to train a GPT model using [Megatron-Core](https://github.com/NVIDIA/Megatron-LM) with [Tensor Parallelism (TP)](https://docs.nvidia.com/nemo/megatron-bridge/0.2.0/parallelisms.html) on Kubeflow Trainer.

Tensor Parallelism splits the weight matrices of individual model layers across multiple GPUs, so you can train models that are too large to fit on a single GPU. Because Megatron-Core uses `torchrun` as its distributed launcher, it works with the existing `torch-distributed` ClusterTrainingRuntime without changes.
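
To make the idea concrete, here is a minimal plain-Python sketch of splitting one weight matrix by columns across two simulated GPUs. It is not how Megatron-Core implements its parallel layers. The helper names (`matmul`, `split_columns`, `column_parallel_matmul`) are made up for illustration, and no GPUs or communication are involved.

```python
# Illustrative only: plain-Python column-parallel matmul, no GPUs or Megatron-Core.

def matmul(x, w):
    """Multiply a (rows x inner) matrix by an (inner x cols) matrix."""
    return [
        [sum(x_row[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
        for x_row in x
    ]

def split_columns(w, num_shards):
    """Split a weight matrix into equal column shards, one per simulated GPU."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    width = cols // num_shards
    return [[row[s * width:(s + 1) * width] for row in w] for s in range(num_shards)]

def column_parallel_matmul(x, w, num_shards):
    """Each shard computes its slice of the output; slices are then concatenated."""
    partial_outputs = [matmul(x, shard) for shard in split_columns(w, num_shards)]
    return [
        [value for part in partial_outputs for value in part[i]]
        for i in range(len(x))
    ]

x = [[1, 2], [3, 4]]               # activations: batch of 2, hidden size 2
w = [[1, 0, 2, 1], [0, 1, 1, 3]]   # weight: 2 x 4, split into two 2 x 2 shards
full = matmul(x, w)
sharded = column_parallel_matmul(x, w, num_shards=2)
print(sharded == full)
```

Each simulated GPU holds only half of the weight columns, yet the concatenated result matches the unsharded product. In a real job the concatenation step is a collective operation over NCCL.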

This notebook is based on the official [run_simple_mcore_train_loop.py](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/run_simple_mcore_train_loop.py) from the Megatron-LM repository.

See also the [Megatron-Core Quickstart](https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/quickstart.html).

## Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [None]:
# !pip install -U kubeflow

## Prerequisites

**GPU Requirement**: This notebook requires at least **2 NVIDIA GPUs** on a single node. Megatron-Core requires CUDA and uses the NCCL backend for distributed communication.

**Training Runtimes**: Make sure the Kubeflow Trainer Controller Manager and Training Runtimes are installed. Follow the [installation guide](https://www.kubeflow.org/docs/components/trainer/operator-guides/installation/).

## Define the Training Function

The first step is to create a function that trains a GPT model using Megatron-Core with Tensor Parallelism.

This training function is based on the official [run_simple_mcore_train_loop.py](https://github.com/NVIDIA/Megatron-LM/blob/main/examples/run_simple_mcore_train_loop.py) from the Megatron-LM repository.
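
The `loss_func` inside the training function averages the per-token losses over only the positions selected by `loss_mask`. The following plain-Python sketch shows the same arithmetic on lists instead of tensors. The function name `masked_mean_loss` and the sample values are assumptions made for illustration.

```python
def masked_mean_loss(losses, loss_mask):
    """Average per-token losses, counting only positions where the mask is 1."""
    if len(losses) != len(loss_mask):
        raise ValueError("losses and loss_mask must have the same length")
    kept = sum(loss_mask)
    if kept == 0:
        raise ValueError("loss_mask selects no tokens")
    # Masked-out positions contribute 0 to the numerator and are
    # excluded from the denominator.
    return sum(l * m for l, m in zip(losses, loss_mask)) / kept

per_token_losses = [2.0, 4.0, 6.0, 8.0]   # made-up values
mask = [1.0, 1.0, 0.0, 1.0]               # third token is ignored
loss = masked_mean_loss(per_token_losses, mask)
print(loss)
```

Dividing by `loss_mask.sum()` rather than by the number of tokens keeps masked positions from diluting the average.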

In [None]:
def train_megatron_gpt_tp():
    import os
    import torch
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from functools import partial
    from pathlib import Path

    from megatron.core import parallel_state
    from megatron.core import dist_checkpointing
    from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
    from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.models.gpt.gpt_model import GPTModel
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
    from megatron.core.datasets.utils import compile_helpers
    from megatron.core.datasets.blended_megatron_dataset_builder import (
        BlendedMegatronDatasetBuilder,
    )
    from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
    from megatron.core.distributed import DistributedDataParallel
    from megatron.core.distributed import DistributedDataParallelConfig
    from megatron.core.distributed.finalize_model_grads import finalize_model_grads
    from megatron.core.tokenizers import MegatronTokenizer

    _SEQUENCE_LENGTH = 64

    # ----------------------------------------------------------------
    # Step 1: Initialize torch.distributed and Megatron model parallel
    # ----------------------------------------------------------------
    # tensor_model_parallel_size=2 means each layer's weight matrices
    # are split across 2 GPUs. This is the core of Tensor Parallelism.
    parallel_state.destroy_model_parallel()

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )

    tensor_model_parallel_size = 2
    pipeline_model_parallel_size = 1
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size, pipeline_model_parallel_size
    )

    # Set a fixed seed for reproducibility across tensor-parallel ranks.
    model_parallel_cuda_manual_seed(123)

    # ----------------------------------------------------------------
    # Step 2: Build a small GPT model
    # ----------------------------------------------------------------
    transformer_config = TransformerConfig(
        num_layers=2,
        hidden_size=12,
        num_attention_heads=4,
        use_cpu_initialization=True,
        pipeline_dtype=torch.float32,
    )

    gpt_model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=get_gpt_layer_local_spec(),
        vocab_size=100,
        max_sequence_length=_SEQUENCE_LENGTH,
    )

    device = torch.device("cuda")
    gpt_model.to(device)

    # Wrap with DistributedDataParallel for gradient synchronization.
    ddp_config = DistributedDataParallelConfig(
        grad_reduce_in_fp32=False,
        overlap_grad_reduce=False,
        use_distributed_optimizer=False,
    )
    gpt_model = DistributedDataParallel(
        config=transformer_config,
        ddp_config=ddp_config,
        module=gpt_model,
    )

    optim = Adam(gpt_model.parameters())

    # ----------------------------------------------------------------
    # Step 3: Prepare a mock dataset (no real data download needed)
    # ----------------------------------------------------------------
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            compile_helpers()
        torch.distributed.barrier()
    else:
        compile_helpers()

    config = GPTDatasetConfig(
        random_seed=0,
        sequence_length=_SEQUENCE_LENGTH,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
        tokenizer=MegatronTokenizer.from_pretrained(
            metadata_path={"library": "null-text"},
            vocab_size=_SEQUENCE_LENGTH,
        ),
        mid_level_dataset_surplus=0.005,
    )

    datasets = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [1000, None, None], lambda: True, config
    ).build()

    train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True)
    train_iterator = iter(train_dataloader)

    # ----------------------------------------------------------------
    # Step 4: Define forward step function
    # ----------------------------------------------------------------
    def forward_step_func(data_iterator, model):
        def loss_func(loss_mask, output_tensor):
            losses = output_tensor.float()
            loss_mask = loss_mask.view(-1).float()
            loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
            return loss, {"lm loss": loss}

        data = next(data_iterator)
        tokens = data["tokens"].to(device)
        attention_mask = data["attention_mask"].to(device)
        position_ids = data["position_ids"].to(device)
        labels = data["labels"].to(device)
        loss_mask = data["loss_mask"].to(device)

        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

        return output_tensor, partial(loss_func, loss_mask)

    # ----------------------------------------------------------------
    # Step 5: Training loop (5 iterations)
    # ----------------------------------------------------------------
    forward_backward_func = get_forward_backward_func()

    for iteration in range(5):
        optim.zero_grad()

        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=train_iterator,
            model=gpt_model,
            num_microbatches=1,
            seq_length=_SEQUENCE_LENGTH,
            micro_batch_size=8,
            decoder_seq_length=_SEQUENCE_LENGTH,
            forward_only=False,
        )

        # Synchronize gradients across TP and DP groups.
        finalize_model_grads([gpt_model])

        optim.step()

        print(f"Iteration {iteration}: Losses reduced: {losses_reduced}")

    # ----------------------------------------------------------------
    # Step 6: Save and load a distributed checkpoint
    # ----------------------------------------------------------------
    ckpt_path = os.getcwd() + "/ckpt"
    Path(ckpt_path).mkdir(exist_ok=True)

    model_to_save = gpt_model.module if hasattr(gpt_model, "module") else gpt_model
    sharded_state_dict = model_to_save.sharded_state_dict(prefix="")
    dist_checkpointing.save(
        sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path
    )
    print(f"Checkpoint saved to {ckpt_path}")

    sharded_state_dict = model_to_save.sharded_state_dict(prefix="")
    checkpoint = dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict, checkpoint_dir=ckpt_path
    )
    model_to_save.load_state_dict(checkpoint)
    print("Checkpoint loaded successfully")

    torch.distributed.destroy_process_group()

## List the Training Runtimes

You can get the list of available Training Runtimes to start your TrainJob.

We use the `torch-distributed` runtime since Megatron-Core uses `torchrun` natively.
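
If the runtime is not installed, a lookup loop that only assigns on a match leaves its variable undefined, and the failure surfaces later. One way to fail early is a small helper like the sketch below. `find_runtime` is a hypothetical helper, not part of the Kubeflow SDK, and the stand-in objects only mimic the `.name` attribute used in this notebook.

```python
from types import SimpleNamespace

def find_runtime(runtimes, name):
    """Return the first runtime whose .name matches, or raise a clear error."""
    runtimes = list(runtimes)
    for runtime in runtimes:
        if runtime.name == name:
            return runtime
    available = ", ".join(sorted(r.name for r in runtimes))
    raise LookupError(f"runtime {name!r} not found; available: {available}")

# Stand-in objects for illustration; in the notebook you would pass
# client.list_runtimes() instead. "example-runtime" is a made-up name.
fake_runtimes = [
    SimpleNamespace(name="example-runtime"),
    SimpleNamespace(name="torch-distributed"),
]
selected = find_runtime(fake_runtimes, "torch-distributed")
print(selected.name)
```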

In [None]:
from kubeflow.trainer import TrainerClient, CustomTrainer

client = TrainerClient()

for runtime in client.list_runtimes():
    print(runtime)
    if runtime.name == "torch-distributed":
        torch_runtime = runtime

## Run the Distributed TrainJob

The TrainJob trains the GPT model defined above on 1 node with 2 GPUs using Tensor Parallelism (`tensor_model_parallel_size=2`).
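
The number of GPUs must be a multiple of the model-parallel size (tensor parallel size times pipeline parallel size). Whatever is left over becomes data-parallel replicas. The sketch below works through that arithmetic for this job. `data_parallel_size` is a hypothetical helper, not a Megatron-Core function, and it ignores other parallelism dimensions such as context parallelism.

```python
def data_parallel_size(num_nodes, gpus_per_node, tensor_parallel, pipeline_parallel=1):
    """Return how many data-parallel replicas fit in the given GPU layout."""
    world_size = num_nodes * gpus_per_node
    model_parallel = tensor_parallel * pipeline_parallel
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by "
            f"tensor x pipeline parallel size {model_parallel}"
        )
    return world_size // model_parallel

# This notebook: 1 node x 2 GPUs, tensor parallel size 2, pipeline parallel size 1.
dp = data_parallel_size(num_nodes=1, gpus_per_node=2, tensor_parallel=2)
print(dp)
```

With 2 GPUs and a tensor parallel size of 2, both GPUs hold shards of a single model replica, so there is exactly one data-parallel replica.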

In [None]:
job_name = client.train(
    trainer=CustomTrainer(
        func=train_megatron_gpt_tp,
        num_nodes=1,
        resources_per_node={
            "memory": "16Gi",
            "nvidia.com/gpu": 2,
        },
        packages_to_install=["megatron-core"],
    ),
    runtime=torch_runtime,
)

## Check the TrainJob steps

You can inspect the steps of the TrainJob that was created.

Each step reports its own status.

In [None]:
# Wait for the running status.
client.wait_for_job_status(name=job_name, status={"Running"})

In [None]:
for c in client.get_job(name=job_name).steps:
    print(f"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\n")

## Watch the TrainJob logs

You can use the `get_job_logs()` API to stream the TrainJob logs.

In [None]:
for logline in client.get_job_logs(job_name, follow=True):
    print(logline)

## Delete the TrainJob

When the TrainJob is finished, you can delete the resource.

In [None]:
# client.delete_job(job_name)