In [1]:
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from dataclasses import dataclass, field, asdict
from typing import Optional
from transformers.hf_argparser import HfArgumentParser
import logging

# Configure root logging: INFO level, "time - level - message" lines.
# (basicConfig has no effect if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# 从 basics 导入必要的组件
from basics.data import Dataset  # 确保 data.py 文件在您的项目中
from basics.nn_utils import cross_entropy, clip_gradient
from basics.model import BasicTransformerLM  # 确保 model.py 文件在您的项目中
from basics.optimizer import AdamW, get_cosine_lr

# Training configuration (mirrors the arguments used in run_train.sh).
@dataclass
class TrainingConfig:
    """Hyper-parameters for DDP training of BasicTransformerLM.

    Parsed from the command line by HfArgumentParser and forwarded wholesale
    (via ``asdict``) to the model, dataset, optimizer and LR schedule.

    ``warmup_iters`` is derived from ``total_iters`` in ``__post_init__``. It is
    declared as an ``init=False`` field, so it is not a CLI/constructor argument.
    Because it is still a dataclass field, ``asdict`` includes it (a plain
    attribute set in ``__post_init__`` would be silently dropped).
    """

    dataset_name: str = field(default="tinystory")
    vocab_size: int = field(default=10000)
    context_length: int = field(default=256)
    batch_size: int = field(default=16)  # per-GPU batch size
    d_model: int = field(default=512)
    num_layers: int = field(default=4)
    num_heads: int = field(default=16)
    d_ff: int = field(default=1344)
    total_iters: int = field(default=20000)
    max_learning_rate: float = field(default=5e-4)
    cosine_cycle_iters: int = field(default=20000)
    weight_decay: float = field(default=0.001)
    device: str = field(default='cuda' if torch.cuda.is_available() else 'cpu')
    wandb_logging: bool = field(default=False)
    wandb_project: Optional[str] = field(default="cs336-assignment1")
    wandb_run_name: Optional[str] = field(default="tinystories-ddp")
    eval_interval: int = field(default=200)
    log_interval: int = field(default=20)
    # Derived in __post_init__ (1% of total_iters); not settable from the CLI.
    warmup_iters: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        """Derive ``warmup_iters`` and validate the wandb settings."""
        self.warmup_iters = int(self.total_iters * 0.01)
        if self.wandb_logging:
            assert self.wandb_project is not None, 'wandb_project must be provided if wandb_logging is True'
            assert self.wandb_run_name is not None, 'wandb_run_name must be provided if wandb_logging is True'

def setup(rank, world_size, backend="nccl"):
    """Initialise the default process group for this worker and bind it to GPU ``rank``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
        backend: torch.distributed backend (default "nccl", which requires GPUs).

    MASTER_ADDR / MASTER_PORT default to localhost:12355. Values already exported
    in the environment take precedence, so several jobs on one host can use
    different ports.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')  # pick a port that is not in use

    # Bind the device *before* creating the group so NCCL communicators use
    # cuda:rank instead of implicitly touching cuda:0 from every process.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    logging.info(f"Rank {rank}: 已初始化进程组，使用设备: cuda:{rank}")


def cleanup():
    """Destroy the default process group and log which rank did so.

    The rank is read *before* destroying the group: afterwards the default
    group no longer exists and ``dist.get_rank()`` would raise. Calling this
    when no group is initialised (e.g. a second time from a ``finally``) is a
    no-op.
    """
    if not dist.is_initialized():
        return
    rank = dist.get_rank()
    dist.destroy_process_group()
    logging.info(f"Rank {rank}: 已销毁进程组")

def main(rank, world_size, config):
    """Per-process DDP training entry point (launched once per GPU by mp.spawn).

    Args:
        rank: Process index in [0, world_size); also the CUDA device index.
        world_size: Number of participating processes.
        config: TrainingConfig with model / data / optimizer / schedule settings.

    Only rank 0 logs and evaluates. The process group is always destroyed on
    exit, including when training raises.

    Raises:
        ValueError: if the training DataLoader yields no batches (the outer
            ``while`` loop would otherwise never terminate).
    """
    setup(rank, world_size)
    try:
        # Seed differs per rank; DDP broadcasts rank 0's parameters when the
        # wrapper is constructed, so all replicas still start from the same weights.
        torch.manual_seed(42 + rank)
        config_kwargs = asdict(config)  # loop-invariant; shared by every component below

        model = BasicTransformerLM(**config_kwargs).to(config.device)
        ddp_model = DDP(model, device_ids=[rank])

        # Every process loads the full dataset; DistributedSampler shards the training split.
        dataset = Dataset(**config_kwargs)
        # NOTE(review): DistributedSampler/DataLoader index dataset.train_data element-wise.
        # If train_data is a flat 1-D token stream, each "batch" is batch_size single tokens
        # rather than sequences. Confirm against basics.data.Dataset.get_batch_from_data.
        train_sampler = DistributedSampler(dataset.train_data, num_replicas=world_size, rank=rank, shuffle=True)
        train_loader = DataLoader(dataset.train_data, batch_size=config.batch_size, sampler=train_sampler)
        # Only rank 0 evaluates, so it reads the whole validation split, not a 1/world_size shard.
        val_loader = DataLoader(dataset.val_data, batch_size=config.batch_size, shuffle=False)
        if len(train_loader) == 0:
            raise ValueError("Training DataLoader is empty; nothing to train on.")

        optimizer = AdamW(ddp_model.parameters(), **config_kwargs)

        iter_num = 0
        epoch = 0
        while iter_num < config.total_iters:
            train_sampler.set_epoch(epoch)  # different shuffle order on every pass
            epoch += 1
            for batch in train_loader:
                optimizer.zero_grad()
                x, y = dataset.get_batch_from_data(batch, config.context_length)
                x, y = x.to(config.device), y.to(config.device)
                logits = ddp_model(x)
                loss = cross_entropy(logits, y)
                loss.backward()  # DDP all-reduces gradients across ranks here
                clip_gradient(ddp_model.parameters(), 1.0)

                lr = get_cosine_lr(iter_num, **config_kwargs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                optimizer.step()

                # Rank 0 logs its *local* batch loss only.
                if rank == 0 and iter_num % config.log_interval == 0:
                    logging.info(f'Iter: {iter_num}, Train loss: {loss.item():.4f}, LR: {lr:.6f}')

                if rank == 0 and iter_num % config.eval_interval == 0:
                    # Evaluate the unwrapped module: a forward pass through the DDP wrapper
                    # on rank 0 alone can issue collectives (e.g. buffer broadcast) that the
                    # other ranks never join.
                    eval(ddp_model.module, val_loader, dataset, config, iter_num, lr)

                iter_num += 1
                if iter_num >= config.total_iters:
                    break  # reached the iteration budget
    finally:
        cleanup()

# Module-level helper that samples one (input, target) window from a raw token tensor.
# NOTE(review): not called anywhere in this notebook; main() and eval() call
# dataset.get_batch_from_data instead.
def get_batch_from_data(data: torch.Tensor, context_length: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one random contiguous window from ``data`` for next-token prediction.

    Args:
        data: Tensor indexed along its first dimension (presumably a 1-D token
            stream; TODO confirm against callers).
        context_length: Length of the returned input and target sequences.

    Returns:
        ``(x, y)`` as int64 tensors of length ``context_length``. ``y`` is the
        window shifted one position to the right, i.e. ``y[i] == data[start + i + 1]``.

    Raises:
        ValueError: if ``context_length`` is not positive, or ``data`` has no
            more than ``context_length`` elements (the shifted target would not fit).
    """
    if context_length <= 0:
        raise ValueError(f"context_length must be positive, got {context_length}")
    # Exclusive upper bound for the start index: x needs context_length tokens and
    # y one more, so start may go up to len(data) - context_length - 1.
    max_start = len(data) - context_length
    if max_start <= 0:
        raise ValueError(
            f"data has {len(data)} elements; need more than context_length={context_length}"
        )
    start_idx = torch.randint(0, max_start, (1,)).item()
    x = data[start_idx:start_idx + context_length].long()  # cast to int64 for embedding / loss
    y = data[start_idx + 1:start_idx + 1 + context_length].long()
    return x, y

def eval(model, val_loader, dataset, config, iter_num, lr):
    """Compute and log the mean per-batch validation loss over ``val_loader``.

    Args:
        model: Module to evaluate; put in eval mode for the duration.
        val_loader: Iterable of validation batches.
        dataset: Object providing ``get_batch_from_data(batch, context_length)``.
        config: Supplies ``context_length`` and ``device``.
        iter_num: Current training iteration (used only in the log line).
        lr: Current learning rate (used only in the log line).

    Returns:
        The mean loss as a float, or NaN when ``val_loader`` yields no batches
        (instead of dividing by zero).

    The model's previous train/eval mode is restored even if evaluation raises.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                x, y = dataset.get_batch_from_data(batch, config.context_length)
                x, y = x.to(config.device), y.to(config.device)
                logits = model(x)
                loss = cross_entropy(logits, y)
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    logging.info(f'Iter: {iter_num}, Val loss: {mean_loss:.4f}, LR: {lr:.6f}')
    return mean_loss

if __name__ == "__main__":
    # Parse TrainingConfig from the command line.
    # NOTE(review): inside a Jupyter kernel sys.argv holds kernel arguments
    # (e.g. "-f ...json"), which HfArgumentParser will likely reject; run as a script.
    parser = HfArgumentParser(TrainingConfig)
    config = parser.parse_args_into_dataclasses()[0]

    # One training process per visible GPU.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn(nprocs=0) would silently do nothing; fail loudly instead.
        raise RuntimeError("No CUDA devices found: NCCL-based DDP training needs at least one GPU.")
    logging.info(f"使用 {world_size} 个 GPU 进行训练")

    # Each spawned worker calls main(rank, world_size, config).
    # NOTE(review): spawn re-imports `main` in fresh interpreters; this typically
    # fails for functions defined in a notebook cell, so launch from a .py file.
    torch.multiprocessing.spawn(main, args=(world_size, config), nprocs=world_size, join=True)


ModuleNotFoundError: No module named 'basics'

In [None]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size, backend="gloo"):
    """Initialise the default process group for a local multi-process demo.

    NOTE(review): this redefines ``setup`` from the first cell; whichever cell
    ran last wins in a shared kernel.

    Args:
        rank: Index of this process.
        world_size: Total number of processes in the group.
        backend: torch.distributed backend. "gloo" works on CPU; pass "nccl"
            for GPUs (and then also call torch.cuda.set_device(rank)).

    MASTER_ADDR / MASTER_PORT default to localhost:29500; values already
    exported in the environment take precedence.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def distributed_demo(rank, world_size):
    """Worker that sums a small random int tensor across all ranks with all_reduce.

    Each rank prints its tensor before and after ``dist.all_reduce``. Afterwards
    every rank holds the element-wise sum over all ranks. The process group is
    destroyed even if the collective raises.
    """
    setup(rank, world_size)
    # For GPU runs (nccl backend), also bind the device: torch.cuda.set_device(rank)
    try:
        # Per-rank generator: every freshly spawned interpreter starts from the same
        # default torch seed, so an unseeded draw would be identical on all ranks.
        rank_gen = torch.Generator().manual_seed(rank)
        data = torch.randint(0, 10, (3,), generator=rank_gen)
        print(f"Rank {rank} data (before all-reduce): {data}")
        # all_reduce sums the tensors of all processes and writes the result back
        # into `data` on every rank.
        dist.all_reduce(data, async_op=False)
        print(f"Rank {rank} data (after all-reduce): {data}")
    finally:
        dist.destroy_process_group()  # release the process group's resources

if __name__ == "__main__":
    # Launch world_size worker processes; each runs distributed_demo(rank, world_size).
    # NOTE(review): spawn re-imports `distributed_demo` in fresh interpreters, which
    # typically fails for functions defined in a notebook cell; run this as a script.
    world_size = 4
    mp.spawn(
        distributed_demo,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )


tensor([0.7041, 0.2253, 0.4310, 0.8126, 0.5045])