In [25]:
import os
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

In [26]:
def setup_rank(rank, world_size, backend='nccl', master_addr=None, master_port=None):
    """Initialise the default process group for this rank and bind it to its GPU.

    Args:
        rank: Global rank of this process (int, or a string holding an int).
        world_size: Total number of participating processes (int or int-string).
        backend: ``torch.distributed`` backend. ``'nccl'`` (the default) is for
            GPU training; ``'gloo'`` also works on CPU-only machines.
        master_addr: Rendezvous host. If None, an existing ``MASTER_ADDR``
            environment variable is kept (e.g. one set by a launcher),
            otherwise ``'localhost'`` is used.
        master_port: Rendezvous port. If None, an existing ``MASTER_PORT``
            environment variable is kept, otherwise ``'12355'`` is used.
    """
    # Callers often pass values read from os.environ, which are strings.
    rank = int(rank)
    world_size = int(world_size)

    # Explicit arguments win; otherwise respect launcher-provided settings and
    # fall back to the single-node defaults.
    if master_addr is not None:
        os.environ['MASTER_ADDR'] = str(master_addr)
    else:
        os.environ.setdefault('MASTER_ADDR', 'localhost')

    if master_port is not None:
        os.environ['MASTER_PORT'] = str(master_port)
    else:
        os.environ.setdefault('MASTER_PORT', '12355')

    # Bind this process to its GPU *before* creating the process group so the
    # backend uses this rank's device rather than the default current device
    # (cuda:0). Assumes rank == local GPU index (single-node launch).
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size
    )

In [27]:
def cleanup_rank():
    """Tear down the default process group if one is active.

    Safe to call more than once, or after a failed ``setup_rank``: when no
    process group is initialised this is a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [28]:
class SimpleModel(nn.Module):
    """Three-layer MLP: ``in_features -> d_features -> d_features -> out_features``.

    Hidden layers are followed by GELU; the final Linear has no activation, so
    ``forward`` returns raw scores of width ``out_features``.
    """

    def __init__(self, in_features, d_features, out_features):
        super().__init__()

        # Keep the layer widths around so the architecture can be inspected.
        self.in_features, self.d_features, self.out_features = (
            in_features, d_features, out_features
        )

        widths = (in_features, d_features, d_features, out_features)
        stack = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # Activation goes *between* linear layers, never after the last one.
            if position > 0:
                stack.append(nn.GELU())
            stack.append(nn.Linear(fan_in, fan_out))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension of ``x`` must equal ``in_features``."""
        logits = self.layers(x)
        return logits

In [29]:
def train_ddp():
    """Train ``SimpleModel`` on a synthetic classification task with DDP.

    Meant to run once per process (one process per GPU). All settings come from
    environment variables and are parsed as integers:

    - ``rank``, ``world_size``: this process's rank and the total process count.
      ``rank`` is also used as the CUDA device index (single-node assumption).
    - ``in_features``, ``d_features``, ``out_features``: model widths;
      ``out_features`` is also the number of classes in the synthetic labels.
    - ``n_samples``: size of the synthetic dataset.
    - ``n_epochs``: number of passes over this rank's shard.
    - ``seed`` (optional, default 0): seed for the synthetic data, so every
      rank generates the same dataset before ``DistributedSampler`` shards it.

    The process group is destroyed on exit even if training raises.
    """
    # os.environ values are strings; convert them before they reach torch.
    rank = int(os.environ['rank'])
    world_size = int(os.environ['world_size'])
    in_features = int(os.environ['in_features'])
    d_features = int(os.environ['d_features'])
    out_features = int(os.environ['out_features'])
    n_samples = int(os.environ['n_samples'])
    n_epochs = int(os.environ['n_epochs'])
    seed = int(os.environ.get('seed', 0))

    setup_rank(rank, world_size)

    try:
        device = torch.device('cuda', rank)

        # An identically-seeded private generator gives every rank the same
        # synthetic dataset, so the sampler's shards partition ONE dataset.
        generator = torch.Generator().manual_seed(seed)
        features = torch.randn((n_samples, in_features), generator=generator)
        # CrossEntropyLoss needs targets in [0, out_features).
        labels = torch.randint(0, out_features, (n_samples,), generator=generator)
        dataset = TensorDataset(features, labels)

        # DDP with device_ids requires the module to already live on that GPU.
        model = SimpleModel(in_features, d_features, out_features).to(device)

        ddp_model = DDP(
            module=model,
            device_ids=[rank],
            bucket_cap_mb=25,
            gradient_as_bucket_view=True
        )

        # Sampler and loader share the same dataset object so indices line up.
        sampler = DistributedSampler(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=16,
            sampler=sampler
        )

        optimizer = torch.optim.Adam(params=ddp_model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            # Changes the sampler's shuffle order each epoch (same on all ranks).
            sampler.set_epoch(epoch)

            for data, label in dataloader:
                data = data.to(device)
                label = label.to(device)

                optimizer.zero_grad()
                output = ddp_model(data)

                loss = loss_fn(output, label)
                # DDP averages gradients across ranks during backward.
                loss.backward()

                optimizer.step()
    finally:
        cleanup_rank()

In [30]:
def train_ddp_gradient_accumulation():
    """Train ``SimpleModel`` with DDP, accumulating gradients over micro-batches.

    Reads the same integer environment variables as a plain DDP run
    (``rank``, ``world_size``, ``in_features``, ``d_features``,
    ``out_features``, ``n_samples``, ``n_epochs``, optional ``seed``) plus
    ``accumulation_steps``: the number of micro-batches whose gradients are
    summed before each optimizer step. ``rank`` doubles as the CUDA device
    index (single-node assumption).

    Gradient all-reduce is skipped (``no_sync``) on every micro-batch except
    the one that triggers an optimizer step. The last, possibly shorter,
    window of each epoch is also stepped, so no stale gradients leak into the
    next epoch. The process group is destroyed on exit even if training raises.

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.
    """
    # os.environ values are strings; convert them before they reach torch.
    rank = int(os.environ['rank'])
    world_size = int(os.environ['world_size'])
    in_features = int(os.environ['in_features'])
    d_features = int(os.environ['d_features'])
    out_features = int(os.environ['out_features'])
    n_samples = int(os.environ['n_samples'])
    n_epochs = int(os.environ['n_epochs'])
    accumulation_steps = int(os.environ['accumulation_steps'])
    seed = int(os.environ.get('seed', 0))

    if accumulation_steps < 1:
        raise ValueError(f'accumulation_steps must be >= 1, got {accumulation_steps}')

    setup_rank(rank, world_size)

    try:
        device = torch.device('cuda', rank)

        # Identically-seeded generator -> same synthetic dataset on every rank.
        generator = torch.Generator().manual_seed(seed)
        features = torch.randn((n_samples, in_features), generator=generator)
        # CrossEntropyLoss needs targets in [0, out_features).
        labels = torch.randint(0, out_features, (n_samples,), generator=generator)
        dataset = TensorDataset(features, labels)

        # DDP with device_ids requires the module to already live on that GPU.
        model = SimpleModel(in_features, d_features, out_features).to(device)

        ddp_model = DDP(
            module=model,
            device_ids=[rank],
        )

        # Sampler and loader share the same dataset object so indices line up.
        sampler = DistributedSampler(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=16,
            sampler=sampler
        )

        optimizer = torch.optim.Adam(params=ddp_model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        # DistributedSampler pads shards to equal length, so every rank sees the
        # same number of batches and steps/syncs at the same batch indices.
        n_batches = len(dataloader)
        # The final window of an epoch may be shorter than accumulation_steps.
        last_window = n_batches % accumulation_steps or accumulation_steps
        last_window_start = n_batches - last_window

        optimizer.zero_grad()

        for epoch in range(n_epochs):
            sampler.set_epoch(epoch)

            for batch_idx, (data, label) in enumerate(dataloader):
                data = data.to(device)
                label = label.to(device)

                is_step = (
                    (batch_idx + 1) % accumulation_steps == 0
                    or batch_idx + 1 == n_batches
                )
                window = last_window if batch_idx >= last_window_start else accumulation_steps

                # Only the stepping micro-batch triggers DDP's gradient all-reduce;
                # `with None:` is invalid, hence nullcontext for the sync case.
                sync_ctx = nullcontext() if is_step else ddp_model.no_sync()
                with sync_ctx:
                    output = ddp_model(data)
                    # Scale so the accumulated gradient is a mean over the window.
                    loss = loss_fn(output, label) / window
                    loss.backward()

                if is_step:
                    optimizer.step()
                    optimizer.zero_grad()
    finally:
        cleanup_rank()