In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import os

In [2]:
def setup_rank():
    rank = os.environ['rank']
    world_size = os.environ['world_size']

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=world_size
    )

    torch.cuda.set_device(rank)

In [4]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 10)
        )

    def forward(self, x):
        return self.layers(x)

In [5]:
def cleanup():
    dist.destroy_process_group()

In [7]:
def train_ddp():
    rank = os.environ['rank']
    world_size = os.environ['world_size']

    setup()

    model = SimpleModel().to(rank)
    ddp_model = DDP(,
        model, 
        device_ids=[rank],
        bucket_cap_mb=10,
        gradient_as_bucket_view=True
    )

    dataset = torch.randn(1000, 1024)
    label_dataset = torch.randint(0, 10, (1000,))

    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    dataloader = DataLoader(
        list(zip(dataset, label_dataset)),
        batch_size=16,
        sampler=sampler
    )

    optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    ddp_model.train()
    for epoch in range(10):
        sampler.set_epoch(epoch)

        for batch_idx, (data, labels) in enumerate(dataloader):
            data = data.to(rank)
            labels = labels.to(rank)

            optimizer.zero_grad()
            output = ddp_model(data)
            loss = loss_fn(output, labels)

            loss.backward()
            optimizer.step()

    cleanup()

SyntaxError: invalid syntax (3601827007.py, line 8)

In [8]:
def train_with_gradient_accumulation():
    rank = os.environ['rank']
    world_size = os.environ['world_size']

    setup_rank()

    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    dataset = torch.randn(1000, 1024)
    label_dataset = torch.randint(0, 10, (1000,))

    sampler = DistributedSample(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    dataloader = DataLoader(
        list(zip(dataset, label_dataset)),
        batch_size=16,
        sampler=sampler
    )

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropy()
    accumulation_steps = os.environ['accumulation_steps']

    for epoch in range(10):
        sampler.set_epoch(epoch)

        for batch_idx, (data, labels) in enumerate(dataLoader):
            data, labels = data.to(rank), labels.to(rank)

            with ddp_model.no_sync() if (batch_idx + 1) % accumulation_steps != 0 else nullcontext():
                outputs = ddp_model(data)
                loss = loss_fn(outputs, labels) / accumulation_steps
                loss.backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

    cleanup()
