In [4]:
import os
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam

In [5]:
class SimpleModel(nn.Module):
    """Small fully connected network: Linear -> GELU -> Linear -> GELU -> Linear.

    The three constructor arguments are kept on the instance as
    ``in_features``, ``hidden_features`` and ``out_features``; the layer
    stack itself is exposed as ``layers`` (an ``nn.Sequential``).
    """

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        self.in_features, self.hidden_features, self.out_features = (
            in_features,
            hidden_features,
            out_features,
        )

        # Layer widths from input to output; consecutive pairs define one Linear each.
        widths = (in_features, hidden_features, hidden_features, out_features)

        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.GELU())

        # The final Linear has no activation after it, so drop the trailing GELU.
        self.layers = nn.Sequential(*modules[:-1])

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.layers(x)

In [6]:
class DataParallelism:
    """DistributedDataParallel (DDP) training harness configured from environment variables.

    Required environment variables (all read with ``os.environ[...]``, so a
    missing one raises ``KeyError``):

    * ``RANK``, ``WORLD_SIZE`` -- process-group coordinates (integers).
    * ``IN_FEATURES``, ``HIDDEN_FEATURES``, ``OUT_FEATURES`` -- ``SimpleModel`` sizes.
    * ``GRADIENT_ACCUMULATION`` -- boolean flag; ``1/true/yes/on`` (any case) enable it.
    * ``BATCH_SIZE`` -- read by :meth:`train`.
    * ``n_epochs`` -- read by the training loops (note the lower-case key).
    * ``ACCUMULATION_STEPS`` -- read by :meth:`train_with_ga` only.

    Optional: ``LOCAL_RANK`` (CUDA device index, defaults to ``RANK``),
    ``MASTER_ADDR`` / ``MASTER_PORT`` (default ``localhost`` / ``12355``).

    Constructing an instance initialises the NCCL process group; call
    :meth:`clean_rank` when finished.
    """

    def __init__(self):
        # os.environ only yields strings, so every numeric setting is parsed explicitly.
        self.rank = self._env_int('RANK')
        self.world_size = self._env_int('WORLD_SIZE')

        # The CUDA device index is the rank *within this node*. Use LOCAL_RANK when the
        # launcher provides it and fall back to the global rank (single-node case).
        self.local_rank = int(os.environ.get('LOCAL_RANK', self.rank))
        self.device = torch.device('cuda', self.local_rank)

        self.setup_rank()

        self.in_features = self._env_int('IN_FEATURES')
        self.hidden_features = self._env_int('HIDDEN_FEATURES')
        self.out_features = self._env_int('OUT_FEATURES')

        # The module must already live on the device named in ``device_ids``
        # before it is wrapped in DistributedDataParallel.
        self.model = SimpleModel(
            self.in_features, self.hidden_features, self.out_features
        ).to(self.device)
        self.gradient_accumulation = self._env_flag('GRADIENT_ACCUMULATION')

        # Bucket tuning is applied only to the non-accumulating configuration.
        ddp_kwargs = {}
        if not self.gradient_accumulation:
            ddp_kwargs.update(bucket_cap_mb=25, gradient_as_bucket_view=True)

        self.ddp = DistributedDataParallel(
            module=self.model,
            device_ids=[self.local_rank],
            **ddp_kwargs
        )

    @staticmethod
    def _env_int(name):
        """Return environment variable ``name`` parsed as an ``int``.

        Raises:
            KeyError: if the variable is not set.
            ValueError: if its value is not an integer literal.
        """
        raw = os.environ[name]
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(
                f"environment variable {name!r} must be an integer, got {raw!r}"
            ) from exc

    @staticmethod
    def _env_flag(name):
        """Return environment variable ``name`` parsed as a boolean flag.

        ``1``, ``true``, ``yes``, ``y`` and ``on`` (case-insensitive, surrounding
        whitespace ignored) are True; every other value, including the empty
        string and ``0``, is False.

        Raises:
            KeyError: if the variable is not set.
        """
        return os.environ[name].strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    def setup_rank(self):
        """Initialise the NCCL process group and select this process's CUDA device."""
        # setdefault keeps any rendezvous address/port the launcher already exported.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')

        dist.init_process_group(
            backend='nccl',
            world_size=self.world_size,
            rank=self.rank
        )

        torch.cuda.set_device(self.local_rank)

    def clean_rank(self):
        """Destroy the process group if one is currently initialised."""
        if dist.is_initialized():
            dist.destroy_process_group()

    def train(self, n_samples):
        """Train on ``n_samples`` randomly generated samples.

        Inputs are drawn from ``torch.randn`` with ``in_features`` columns and
        labels from ``torch.randint`` in ``[0, out_features)``. No seed is set
        here, so the generated tensors are not guaranteed to match across ranks.
        Dispatches to :meth:`train_with_ga` or :meth:`train_with_no_ga`
        depending on ``self.gradient_accumulation``.
        """
        features = torch.randn((n_samples, self.in_features))
        labels = torch.randint(0, self.out_features, (n_samples,))

        # One dataset object shared by sampler and loader so their lengths always agree.
        dataset = torch.utils.data.TensorDataset(features, labels)

        sampler = DistributedSampler(
            dataset=dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=self._env_int('BATCH_SIZE'),
            sampler=sampler
        )

        if not self.gradient_accumulation:
            self.train_with_no_ga(dataloader, sampler)
        else:
            self.train_with_ga(dataloader, sampler)

    def train_with_no_ga(self, dataloader: DataLoader, sampler: DistributedSampler):
        """Run ``n_epochs`` epochs with one optimizer step per batch.

        Uses Adam (lr=1e-3) and cross-entropy loss. ``sampler.set_epoch`` is
        called at the start of every epoch.
        """
        n_epochs = self._env_int('n_epochs')
        optimizer = Adam(self.ddp.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            sampler.set_epoch(epoch)

            for batch, y in dataloader:
                batch = batch.to(self.device)
                y = y.to(self.device)

                optimizer.zero_grad()
                output = self.ddp(batch)

                loss = loss_fn(output, y)
                loss.backward()

                optimizer.step()

    def train_with_ga(self, dataloader: DataLoader, sampler: DistributedSampler):
        """Run ``n_epochs`` epochs, stepping once per ``ACCUMULATION_STEPS`` batches.

        Batches are grouped into consecutive windows of ``ACCUMULATION_STEPS``
        micro-batches. Backward passes inside a window run under
        ``self.ddp.no_sync()``; the window's last backward runs outside it and
        is followed by ``optimizer.step()`` / ``optimizer.zero_grad()``. The
        final window of an epoch may be shorter and is still stepped.

        Raises:
            ValueError: if ``ACCUMULATION_STEPS`` is not a positive integer.
        """
        n_epochs = self._env_int('n_epochs')
        optimizer = Adam(self.ddp.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        accumulation_steps = self._env_int('ACCUMULATION_STEPS')
        if accumulation_steps < 1:
            raise ValueError(
                f"ACCUMULATION_STEPS must be >= 1, got {accumulation_steps}"
            )

        n_batches = len(dataloader)

        for epoch in range(n_epochs):
            sampler.set_epoch(epoch)

            for idx, (batch, y) in enumerate(dataloader):
                batch = batch.to(self.device)
                y = y.to(self.device)

                step = idx + 1
                # Also sync on the epoch's last batch so a partial window is
                # applied instead of being left in .grad or silently dropped.
                is_sync_step = step % accumulation_steps == 0 or step == n_batches

                # Number of micro-batches in the window this batch belongs to
                # (smaller than accumulation_steps only for the epoch's tail).
                window_start = (idx // accumulation_steps) * accumulation_steps
                window_size = min(accumulation_steps, n_batches - window_start)

                # `with None` is not valid; use a no-op context on sync steps.
                sync_ctx = nullcontext() if is_sync_step else self.ddp.no_sync()
                with sync_ctx:
                    output = self.ddp(batch)
                    # Scale so the accumulated gradient is the average of the
                    # window's per-batch losses rather than their sum.
                    loss = loss_fn(output, y) / window_size
                    loss.backward()

                if is_sync_step:
                    optimizer.step()
                    optimizer.zero_grad()