In [1]:
# 2025/7/29
# zhangzhong
# https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255

In [None]:
# how stochastic gradient descent works in 5 steps:

# predictions = model(inputs)               # Forward pass
# loss = loss_function(predictions, labels) # Compute loss function
# loss.backward()                           # Backward pass
# optimizer.step()                          # Optimizer step
# predictions = model(inputs)               # Forward pass with new parameters

In [None]:
# Accumulating gradients:
# before calling optimizer.step() to perform a step of gradient descent, 
# we will sum the gradients of several backward operations in the parameter.grad tensors
# This is straightforward to do in PyTorch as the gradient tensors are not reset unless we call model.zero_grad() or optimizer.zero_grad()
# We’ll also need to divide by the number of accumulation steps if our loss is averaged over the training samples.

In [None]:
# model.zero_grad()                                   # Reset gradients tensors
# for i, (inputs, labels) in enumerate(training_set):
#     predictions = model(inputs)                     # Forward pass
#     loss = loss_function(predictions, labels)       # Compute loss function
#     loss = loss / accumulation_steps                # Normalize our loss (if averaged)
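#     # loss.backward() back-propagates through the graph and accumulates (adds) the gradients into the .grad attribute of every leaf tensor (e.g. the model weights).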
#     loss.backward()                                 # Backward pass
#     if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
#         optimizer.step()                            # Now we can do an optimizer step
#         model.zero_grad()                           # Reset gradients tensors
#         if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
#             evaluate_model()                        # ...have no gradients accumulated

In [None]:
# Can you train a model for which not even a single sample can fit on a GPU?
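# Answer: gradient checkpointing (trade extra compute for memory)! That is a bit beyond the scope of these notes, so skipping it for now.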

In [None]:
# DataParallel
# Skipping this one, since I use DDP (DistributedDataParallel) instead.
# parallel_model = torch.nn.DataParallel(model) # Encapsulate the model

# predictions = parallel_model(inputs)          # Forward pass on multi-GPUs
# loss = loss_function(predictions, labels)     # Compute loss function
# loss.mean().backward()                        # Average GPU-losses + backward pass
# optimizer.step()                              # Optimizer step
# predictions = parallel_model(inputs)          # Forward pass with new parameters


In [None]:
# DDP
# But be careful: while the code looks similar, training your model in a distributed setting will change your workflow 
# because you will actually have to start an independent python training script on each node (these scripts are all identical).
# As we will see, once started, these training scripts will be synchronized together by PyTorch distributed backend.
#
# In practice, this means that each training script will have:
# - its own optimizer and performs a complete optimization step with each iteration, no parameter broadcast (step 2 in DataParallel) is needed,
# - an independent Python interpreter: this will also avoid the GIL-freeze that can come from driving several parallel execution threads in a single Python interpreter.

In [None]:
# generated by ChatGPT
# refs:
# 1. https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync
# 2. https://discuss.pytorch.org/t/whats-no-sync-exactly-do-in-ddp/170259
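# The forward pass must ALSO be inside the no_sync() context manager. My first guess was that DDP installs its gradient hooks during forward (TODO: find where the docs say so).
# https://docs.pytorch.org/docs/main/notes/ddp.html
# That design note says the autograd hooks are registered when the DDP module is constructed, not during forward.
# The missing step: DDP's forward() reads the sync flag that no_sync() toggles and, only when syncing is enabled, prepares its reducer for the coming backward pass.
# So whether a backward all-reduces is decided during forward. Either way, follow what the no_sync docs explicitly say:
# The forward pass should be included inside the context manager, or else gradients will still be synchronized.
# TODO: benchmark this to check whether skipping the sync actually speeds up training.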

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import contextlib

# 简单的 MLP 模型用于 MNIST 分类
class SimpleModel(nn.Module):
    """Two-layer MLP classifier for MNIST.

    The input is flattened from dim 1 onward and must yield 28 * 28 = 784 features
    per sample (e.g. an (N, 1, 28, 28) image batch). The output is raw logits of
    shape (N, 10) with no softmax applied, suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden_units, num_classes = 28 * 28, 512, 10
        layers = [
            nn.Flatten(),
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        # Keep the attribute name `net` so state_dict keys stay "net.<idx>.<param>".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

# 单个进程的训练逻辑（每个 GPU 运行一个）
def train(rank, world_size, accumulation_steps=4, batch_size=64, lr=1e-3,
          log_every=100, data_root="./data"):
    """Run one pass over MNIST with DDP, AMP and gradient accumulation on one GPU.

    Launched once per process by ``mp.spawn`` (see ``main``), which passes the
    process index as ``rank`` followed by ``args``.

    Args:
        rank: Process index; also used as this process's CUDA device index.
        world_size: Total number of processes in the group.
        accumulation_steps: Number of micro-batches whose gradients are summed
            before each optimizer step. The effective global batch size is
            ``batch_size * accumulation_steps * world_size``.
        batch_size: Per-process micro-batch size.
        lr: Adam learning rate.
        log_every: Rank 0 prints the loss every ``log_every`` optimizer steps
            (values <= 0 disable logging).
        data_root: Directory MNIST is downloaded to / read from.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    # Pin this process to its own GPU before NCCL initializes.
    torch.cuda.set_device(rank)

    # mp.spawn does not export the rendezvous variables read by the default env://
    # init method; fall back to single-node defaults unless the caller set them.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    try:
        model = DDP(SimpleModel().to(rank), device_ids=[rank])

        transform = transforms.ToTensor()
        # Only rank 0 downloads; the other ranks wait at the barrier, so several
        # processes never write the same dataset files at the same time.
        if rank == 0:
            datasets.MNIST(root=data_root, train=True, download=True)
        dist.barrier()
        dataset = datasets.MNIST(root=data_root, train=True, download=False, transform=transform)

        # DistributedSampler gives each rank its own shard, padded so every rank has
        # the same number of batches -- so all ranks agree on which batches sync.
        # NOTE: single pass; for multiple epochs call sampler.set_epoch(epoch) per epoch.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        num_batches = len(dataloader)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        scaler = GradScaler()

        model.train()
        optimizer.zero_grad()

        optimizer_steps = 0
        # Sum of the (already divided) micro-batch losses of the current group; kept
        # as a detached tensor to avoid a host sync on every micro-batch.
        accum_loss = 0.0

        for step, (inputs, labels) in enumerate(dataloader):
            inputs, labels = inputs.to(rank), labels.to(rank)

            # `step` counts micro-batches from 0, hence step + 1. Sync and update every
            # accumulation_steps micro-batches, and also on the very last batch so a
            # trailing partial group is applied instead of being silently discarded.
            is_final_accum = (step + 1) % accumulation_steps == 0 or step + 1 == num_batches

            # Skip DDP's gradient all-reduce on non-final micro-batches. The forward
            # pass must be inside no_sync() too, or gradients are still synchronized.
            sync_context = contextlib.nullcontext() if is_final_accum else model.no_sync()
            with sync_context:
                with autocast():  # mixed-precision region
                    outputs = model(inputs)
                    # Dividing makes the summed gradients equal those of the group's
                    # mean loss. A trailing partial group is still divided by
                    # accumulation_steps, so that last update is slightly down-weighted.
                    loss = loss_fn(outputs, labels) / accumulation_steps
                scaler.scale(loss).backward()
            accum_loss += loss.detach()

            # Design note: checkpoints, logs and loss records are only meaningful at
            # optimizer steps (multiples of accumulation_steps). Rather than changing
            # what `step` means, a cleaner design would wrap the whole accumulation so
            # that, from the outside, it behaves like a single step.

            if is_final_accum:
                scaler.step(optimizer)  # unscales the accumulated grads, then steps
                scaler.update()
                optimizer.zero_grad()
                optimizer_steps += 1

                # Key logging on optimizer steps: with the micro-batch index, a test such
                # as `step % 100 == 0` never coincides with `(step + 1) % 4 == 0`.
                if rank == 0 and log_every > 0 and optimizer_steps % log_every == 0:
                    print(f"[Rank {rank}] Step {step} (optimizer step {optimizer_steps}), "
                          f"Loss: {accum_loss.item():.4f}")
                accum_loss = 0.0
    finally:
        # Tear down the process group even if training raised.
        dist.destroy_process_group()

# 使用 spawn 启动多进程（每个 GPU 一个进程）
def main():
    """Spawn one training process per visible CUDA device.

    NOTE: with the spawn start method the child processes must be able to import
    ``train``; run this as a standalone script (e.g. ``python ddp_accum.py``) rather
    than from an interactive notebook kernel, where that lookup typically fails.

    Raises:
        RuntimeError: If no CUDA device is visible (NCCL-based DDP needs a GPU).
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # Otherwise mp.spawn(nprocs=0) would start no workers and the script would
        # exit as if training had succeeded.
        raise RuntimeError("No CUDA devices found; this DDP/NCCL example needs at least one GPU.")
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()