# Data Parallelism

This notebook introduces data parallelism.

## 1. `torch.nn.DataParallel`

`torch.nn.DataParallel` enables training on multiple GPUs within a single machine by splitting input data across devices and synchronizing gradients after each step.

### 1) Forward Pass

1. The input mini-batch is **scattered** across the GPUs.
2. Model parameters on GPU 0 are **broadcast** to the other GPUs.
3. Each device performs the **forward pass** to compute logits.
4. The logits from all devices are **gathered** on GPU 0.
5. The final **loss** is computed from the gathered logits (with reduction).

![](../01_Basics/images/dp_forward.png)

<br>

The same process can be expressed in code as shown below.

In [10]:
! pip install torch
import torch.nn as nn

def data_parallel(module, inputs, labels, device_ids, output_device):
    # Scatter inputs across devices
    inputs = nn.parallel.scatter(inputs, device_ids)

    # Replicate model on each device
    replicas = nn.parallel.replicate(module, device_ids)
   
    # Run forward pass in parallel
    outputs = nn.parallel.parallel_apply(replicas, inputs)

    # Gather outputs to a single device
    logits = nn.parallel.gather(outputs, output_device)

    return logits
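
For readers without several GPUs at hand, the scatter, broadcast, forward, and gather steps can be imitated in plain Python. This is only an illustrative sketch: the helper names `scatter`, `forward`, and `gather` below are hypothetical stand-ins (not the `torch.nn.parallel` functions), the "model" is a single weight, and the chunking rule (contiguous chunks of size `ceil(batch / devices)`) is an assumption about how the batch is split.

```python
import math


def scatter(batch, num_devices):
    """Split a batch into contiguous chunks, one per device (assumed chunking rule)."""
    chunk_size = math.ceil(len(batch) / num_devices)
    return [batch[i:i + chunk_size] for i in range(0, len(batch), chunk_size)]


def forward(weight, inputs):
    """Toy 'model': multiply every input by a single weight."""
    return [weight * x for x in inputs]


def gather(outputs):
    """Concatenate the per-device outputs back into one batch."""
    return [y for chunk in outputs for y in chunk]


batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
weight = 2.0

chunks = scatter(batch, num_devices=2)         # one chunk per simulated device
replicas = [weight for _ in chunks]            # "broadcast": each device gets a copy
outputs = [forward(w, c) for w, c in zip(replicas, chunks)]
logits = gather(outputs)                       # collected on the "output device"

# The gathered result matches a forward pass over the unsplit batch.
print(logits == forward(weight, batch))
```

The point of the sketch is that splitting the batch changes where the work happens, not the result.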



### 2) Backward Pass

1. The computed loss is **scattered** to all devices.
2. Each device runs **backward()** to compute gradients.
3. All gradients are **reduced** (summed) to GPU 0.
4. The model parameters on GPU 0 are updated using the reduced gradients.

![](../images/dp_backward.png)
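
The reduce step can be checked with a small, device-free calculation. The sketch below assumes a toy one-parameter model `w * x` with a mean-squared-error loss averaged over the full batch (both are illustrative choices, not part of the notebook's BERT example). Each simulated device computes the gradient contribution of its own shard, and summing the contributions reproduces the full-batch gradient.

```python
def shard_gradient(w, xs, ys, total):
    """Contribution of one shard to d/dw of the full-batch mean squared error."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / total


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 1.0, 2.0, 2.0]
n = len(xs)

# Reference: gradient computed on the whole batch at once.
full_grad = shard_gradient(w, xs, ys, n)

# Each simulated device handles half of the batch.
grad_dev0 = shard_gradient(w, xs[:2], ys[:2], n)
grad_dev1 = shard_gradient(w, xs[2:], ys[2:], n)

# "Reduce": sum the per-device gradients on the master device.
reduced = grad_dev0 + grad_dev1

# The master device then applies the update (plain SGD, hypothetical lr).
lr = 0.1
w_new = w - lr * reduced

print(full_grad, reduced, w_new)
```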


#### For clarification
- `loss.backward()` computes gradients by backpropagation.
- `optimizer.step()` updates parameters using the computed gradients.
- Backward computation is more expensive than the update step.


In [None]:
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset

# 1. Load dataset
datasets = load_dataset("multi_nli").data["train"]
datasets = [
    {
        "premise": str(p),
        "hypothesis": str(h),
        "labels": l.as_py(),
    }
    for p, h, l in zip(datasets[2], datasets[5], datasets[9])
]

# Create DataLoader
data_loader = DataLoader(datasets, batch_size=128, num_workers=4)

# 2. Load pretrained model and tokenizer
model_name = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)

# Load model and move to GPU
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=3).cuda()

# 3. Enable Data Parallelism (Multi-GPU)
# device_ids: list of GPU IDs to use
# output_device: GPU where outputs are gathered
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3], output_device=0)

# 4. Optimizer and loss function
optimizer = Adam(model.parameters(), lr=3e-5)
loss_fn = nn.CrossEntropyLoss(reduction="mean")

# 5. Training loop
for i, data in enumerate(data_loader):

    # Clear old gradients
    optimizer.zero_grad()

    # Tokenize input text
    tokens = tokenizer(
        data["premise"],
        data["hypothesis"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Forward pass (automatically spread across GPUs)
    logits = model(
        input_ids=tokens.input_ids.cuda(),
        attention_mask=tokens.attention_mask.cuda(),
        return_dict=False
    )[0]

    # Compute loss
    loss = loss_fn(logits, data["labels"].cuda())

    # Backward pass (gradients are synchronized automatically)
    loss.backward()

    # Update model parameters
    optimizer.step()

    # Print training status
    if i % 10 == 0:
        print(f"step:{i}, loss:{loss}")

    # Stop early for demo
    if i == 300:
        break

In [None]:
!python ../src/data_parallel.py


Training works correctly on multiple GPUs. However, a problem occurs because all **logits** are gathered on GPU 0, which can cause GPU memory imbalance.

This can be improved by gathering the **loss** instead of the logits. Since the loss is a scalar, it uses much less memory.  
This idea is similar to the `DataParallelCriterion` approach, but it can be implemented more simply by overriding the model's `forward()` function.

![](../images/dp_forward_2.png)


The key idea is to perform **loss computation inside the forward pass**.  
Because the forward function runs in parallel across GPUs, computing the loss there ensures that loss calculation and reduction happen in parallel.

One side effect is that **loss reduction happens twice**:
- First, each GPU reduces its local mini-batch loss.
- Then, the reduced losses from all GPUs are combined into a single value.

Even with double reduction, this approach is more efficient because it:
- Reduces memory usage on GPU 0.
- Parallelizes loss computation.
- Improves overall balance across devices.
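
A small arithmetic check makes the double reduction concrete. The per-sample loss values below are made up for illustration. When every GPU receives the same number of samples, the mean of the per-GPU means equals the global mean; when the shards are uneven (as can happen on a final, smaller batch), the two values can differ slightly.

```python
def mean(values):
    return sum(values) / len(values)


# Hypothetical per-sample losses for one mini-batch of six samples.
per_sample_losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
global_mean = mean(per_sample_losses)

# Equal shards: 3 samples on each of two GPUs.
equal_shards = [per_sample_losses[:3], per_sample_losses[3:]]
double_reduced_equal = mean([mean(s) for s in equal_shards])

# Uneven shards: 4 samples on one GPU, 2 on the other.
uneven_shards = [per_sample_losses[:4], per_sample_losses[4:]]
double_reduced_uneven = mean([mean(s) for s in uneven_shards])

print(global_mean, double_reduced_equal, double_reduced_uneven)
```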


In [None]:
from torch import nn


# Standard model that outputs logits
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(768, 3)

    def forward(self, inputs):
        logits = self.linear(inputs)
        return logits


# Parallel model that computes loss inside forward pass
class ParallelLossModel(Model):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, labels):
        # Compute logits using base model
        logits = super(ParallelLossModel, self).forward(inputs)
        
        # Compute loss on each GPU independently
        loss = nn.CrossEntropyLoss(reduction="mean")(logits, labels)
        
        # Return loss instead of logits
        return loss
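
As a rough usage sketch, a model of this kind can be wrapped in `nn.DataParallel` and trained on the averaged loss. The batch size of 8, the random inputs, and the CPU fallback are assumptions made so the snippet runs anywhere. The class is redefined here in compact form so the block is self-contained.

```python
import torch
from torch import nn


class ParallelLossModel(nn.Module):
    """Compact variant of the model above: returns the loss, not the logits."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(768, 3)

    def forward(self, inputs, labels):
        logits = self.linear(inputs)
        return nn.functional.cross_entropy(logits, labels, reduction="mean")


# Use the GPU(s) if present; otherwise DataParallel simply calls the module on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(ParallelLossModel().to(device))

# Hypothetical random batch: 8 samples with 768 features, 3 classes.
inputs = torch.randn(8, 768, device=device)
labels = torch.randint(0, 3, (8,), device=device)

# With several GPUs this is one loss per GPU; with one device it is a scalar.
loss = model(inputs, labels)

# Second reduction: average the per-GPU losses into a single scalar.
loss = loss.mean()
loss.backward()
```

Calling `.mean()` is harmless in the single-device case, so the same training code works with any number of GPUs.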

Fortunately, most Hugging Face Transformers models already support computing the loss directly inside the forward pass.

By providing the labels to the model's `labels` argument, the model returns the loss automatically.  
This allows us to use parallel loss computation without writing any custom loss logic.

In [None]:
"""
src/efficient_data_parallel.py
"""

# (Steps 1–4 are omitted for brevity)

# 5. Start training loop
for i, data in enumerate(data_loader):

    # Clear previous gradients
    optimizer.zero_grad()

    # Tokenize input text
    tokens = tokenizer(
        data["premise"],
        data["hypothesis"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # Forward pass (model computes loss internally)
    loss = model(
        input_ids=tokens.input_ids.cuda(),
        attention_mask=tokens.attention_mask.cuda(),
        labels=data["labels"],
    ).loss

    # Average the per-GPU losses into one scalar (e.g., shape: [num_gpus] → [])
    loss = loss.mean()

    # Backpropagation
    loss.backward()

    # Update parameters
    optimizer.step()

    # Print training status
    if i % 10 == 0:
        print(f"step:{i}, loss:{loss}")

    # Stop early for demonstration
    if i == 300:
        break


In [None]:
!python ../src/efficient_data_parallel.py

## 2. Limitations of `torch.nn.DataParallel`

### 1) Inefficient multi-threading in Python
`DataParallel` uses multi-threading, but Python is limited by the **Global Interpreter Lock (GIL)**.  
This prevents true parallel execution within a single process.  
For better performance, training should use **multi-process execution** instead.

### 2) Model replication overhead
Gradients are gathered on one GPU and the model is updated there.  
After each update, the model must be broadcast back to all GPUs.  
This repeated synchronization is expensive and limits scalability.

### Solution → All-Reduce 👍
![](../images/allreduce.png)

Instead of collecting gradients on one GPU, **all-reduce** sums gradients across all GPUs and shares the result with every device.  
This allows each GPU to update its own copy of the model without rebroadcasting.
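
The semantics of all-reduce can be sketched without any GPUs. In this illustrative snippet (the function name `all_reduce_sum`, the parameter values, and the learning rate are all hypothetical), every simulated device ends up with the same summed gradient, applies the same update locally, and therefore keeps an identical copy of the parameters.

```python
def all_reduce_sum(per_device_grads):
    """Return, for every device, the element-wise sum of all devices' gradients."""
    total = [sum(vals) for vals in zip(*per_device_grads)]
    return [list(total) for _ in per_device_grads]


# Two simulated devices start from identical parameters.
params = [[1.0, 2.0], [1.0, 2.0]]

# Each device computed different gradients from its own data shard.
grads = [[0.5, 1.0], [1.5, 3.0]]

synced = all_reduce_sum(grads)

# Every device applies the same SGD step to its own replica.
lr = 0.5
params = [
    [p - lr * g for p, g in zip(device_params, device_grads)]
    for device_params, device_grads in zip(params, synced)
]

print(synced, params)
```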

### However...
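A naive all-reduce is itself expensive: if every GPU exchanges its full set of gradients with the others, the communication cost grows with the number of devices.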

## 3. `torch.nn.parallel.DistributedDataParallel` (DDP)

### Ring All-Reduce
Ring all-reduce is a communication algorithm that was popularized for deep learning around 2017 as a way to speed up gradient synchronization.  
It became the foundation of DDP because it is much more efficient than earlier approaches.

![](../images/ring_allreduce.gif)
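
The two phases of ring all-reduce (scatter-reduce, then all-gather) can be simulated in plain Python. This is a simplified sketch, not how NCCL or DDP implement it: it assumes each worker's gradient is split into exactly as many chunks as there are workers, with one number per chunk, and the function name `ring_all_reduce` is hypothetical.

```python
def ring_all_reduce(grads):
    """Simulate ring all-reduce for n workers, each holding n one-element chunks."""
    n = len(grads)
    data = [list(g) for g in grads]

    # Phase 1: scatter-reduce. In step s, worker r sends chunk (r - s) % n
    # to its right neighbour, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        sends = [(r, (r - s) % n, data[r][(r - s) % n]) for r in range(n)]
        for r, chunk, value in sends:
            data[(r + 1) % n][chunk] += value
    # Now worker r holds the fully reduced chunk (r + 1) % n.

    # Phase 2: all-gather. In step s, worker r forwards the reduced chunk
    # (r + 1 - s) % n to its right neighbour, which overwrites its copy.
    for s in range(n - 1):
        sends = [(r, (r + 1 - s) % n, data[r][(r + 1 - s) % n]) for r in range(n)]
        for r, chunk, value in sends:
            data[(r + 1) % n][chunk] = value

    return data


grads = [
    [1, 2, 3],        # worker 0
    [10, 20, 30],     # worker 1
    [100, 200, 300],  # worker 2
]
result = ring_all_reduce(grads)
print(result)
```

Each worker only ever talks to its neighbour and sends `2 * (n - 1)` chunks of size `1/n` of the gradient, so the data sent per worker stays roughly constant as workers are added.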


### What is DDP?
`DistributedDataParallel` (DDP) is a data-parallel training module designed to fix the limitations of `DataParallel`.  
It works on both **single-node and multi-node systems** using **multiple processes** instead of threads.

By using **all-reduce**, DDP removes the need for a master GPU.  
Each process updates its own model, making training faster and more scalable.





In [None]:
"""
src/ddp.py
"""

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset

# 1. Initialize distributed process group
dist.init_process_group("nccl")
rank = dist.get_rank()              # Global rank of this process
world_size = dist.get_world_size()  # Total number of processes
torch.cuda.set_device(rank)         # One GPU per process (on a single node, rank == local GPU index)
device = torch.cuda.current_device()

# 2. Load dataset
datasets = load_dataset("multi_nli").data["train"]
datasets = [
    {
        "premise": str(p),
        "hypothesis": str(h),
        "labels": l.as_py(),
    }
    for p, h, l in zip(datasets[2], datasets[5], datasets[9])
]

# 3. Create DistributedSampler
# This splits the dataset across multiple processes
sampler = DistributedSampler(
    datasets,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

data_loader = DataLoader(
    datasets,
    batch_size=32,
    num_workers=4,
    sampler=sampler,
    shuffle=False,     # Must be False when using DistributedSampler
    pin_memory=True,
)

# 4. Load model and tokenizer
model_name = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=3).cuda()

# 5. Wrap model with DistributedDataParallel
model = DistributedDataParallel(model, device_ids=[device], output_device=device)

# 6. Create optimizer
optimizer = Adam(model.parameters(), lr=3e-5)

# 7. Training loop
for i, data in enumerate(data_loader):

    # Clear gradients
    optimizer.zero_grad()

    # Tokenize input text
    tokens = tokenizer(
        data["premise"],
        data["hypothesis"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # Forward pass (each process runs independently)
    loss = model(
        input_ids=tokens.input_ids.cuda(),
        attention_mask=tokens.attention_mask.cuda(),
        labels=data["labels"],
    ).loss

    # Backpropagation (gradients are synchronized automatically)
    loss.backward()

    # Update model parameters locally on each GPU
    optimizer.step()

    # Print only from main process
    if i % 10 == 0 and rank == 0:
        print(f"step:{i}, loss:{loss}")

    # Stop early for demonstration
    if i == 300:
        break
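
To make the role of `DistributedSampler` in the script above more tangible, here is a rough pure-Python sketch of how indices can be divided among processes. It mirrors, to the best of my understanding, the non-shuffled behaviour with `drop_last=False`: the index list is padded by repeating early indices until it divides evenly, then each rank takes every `num_replicas`-th index. The function name `shard_indices` is hypothetical.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices assigned to one process (no shuffling, padded to equal length)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets num_samples items.
    indices += indices[: total_size - len(indices)]

    # Rank r takes indices r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


# Example: 10 samples split across 4 processes.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Every process sees a disjoint slice of the data (apart from the few padded duplicates), so each replica performs the same number of steps per epoch.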

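Note: `torch.distributed.launch` is deprecated in recent PyTorch releases, so the script is launched with `torchrun` here.  
On older versions without `torchrun`, use `python -m torch.distributed.launch --nproc_per_node=4 ../src/ddp.py` instead.

In [None]:
!torchrun --nproc_per_node=4 ../src/ddp.py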