# Distributed Data Parallel Training of a Vision Transformer on 1, 2, or 4 GPUs

## How to run
- Run the cell below.
- The cell writes the training script to train_ddp_vit.py.
- Run the following command to train with or without mixed precision (omit `--use_amp` to train without it):
- `python train_ddp_vit.py --world_size=2 --use_amp`
- `world_size` is the number of GPUs (one process is spawned per GPU).
- `use_amp` is the flag that enables mixed precision.
- Run mixed precision on a V100 GPU; it will not work on a P100 GPU.

In [1]:
%%writefile train_ddp_vit.py
import os
import time
import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import pandas as pd
from sklearn.metrics import accuracy_score

# Constants
NUM_EPOCHS = 2  # number of passes over the training split made by main()
BATCH_SIZE = 16  # batch size given to every DataLoader, i.e. per process rather than global
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE before normalization
SUBSET_SIZE = 8000  # NOTE(review): not referenced anywhere in the visible script - confirm it is still needed
DATASET_DIR = "dataset"  # dataset root, relative to the working directory
CSV_PATH = os.path.join(DATASET_DIR, "train.csv")  # CSV read by get_data_loaders(); uses its file_name and label columns
IMAGE_DIR = os.path.join(DATASET_DIR, "train_data")  # folder that image basenames are resolved against


def setup(rank, world_size):
    """Join the NCCL process group and bind this process to its GPU.

    Args:
        rank: Index of this process within the group; it is also used as the
            CUDA device index for the process.
        world_size: Total number of processes taking part in the group.
    """
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    """Tear down the default process group if one is currently active.

    Calling this when no process group has been initialized (for example
    after a failed setup, or a second time during error handling) is a no-op
    instead of an error, so it is safe to use from ``finally`` blocks.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


class ImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs listed in a DataFrame.

    The frame must provide a ``file_name`` column and a ``label`` column.
    Only the basename of ``file_name`` is used; it is joined onto ``img_dir``
    to locate the image on disk.
    """

    def __init__(self, dataframe, img_dir, transform=None):
        """Store the sample table, the image folder and the optional transform.

        Args:
            dataframe: pandas DataFrame with ``file_name`` and ``label`` columns.
            img_dir: Directory that contains the image files.
            transform: Optional callable applied to the RGB PIL image.
        """
        self.df = dataframe
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Args:
            idx: Zero-based position of the sample within the frame.

        Returns:
            Tuple ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` if one was given) and ``label`` is an ``int``.
        """
        # Positional lookup: samplers hand out positions 0..len-1, which only
        # coincide with index labels when the frame has a default RangeIndex.
        rel_path = self.df['file_name'].iloc[idx]
        label = int(self.df['label'].iloc[idx])

        img_path = os.path.join(self.img_dir, os.path.basename(rel_path))
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, label


def get_data_loaders(rank, world_size):
    """Build the distributed train, validation and test loaders.

    The CSV at ``CSV_PATH`` is shuffled with a fixed seed (42), so every
    process derives the same split: rows 0-9999 for training, 10000-11999 for
    validation and 12000-13999 for testing. Rows past 14000 are not used.

    Args:
        rank: Index of the calling process; selects its shard of each split.
        world_size: Number of processes the data is sharded across.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Each loader wraps a
        ``DistributedSampler``; only the training sampler shuffles.
    """
    df = pd.read_csv(CSV_PATH)
    df.columns = df.columns.str.strip()
    df = df.drop(columns=['Unnamed: 0'], errors='ignore')
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    # The feature extractor is loaded only for the normalization statistics
    # that match the pretrained checkpoint.
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    ])

    # (first row, one-past-last row, shuffle?) for train / val / test.
    # Evaluation splits keep a fixed order: shuffling them buys nothing and
    # makes per-rank results harder to reproduce.
    split_specs = (
        (0, 10000, True),
        (10000, 12000, False),
        (12000, 14000, False),
    )

    loaders = []
    for start, stop, shuffle in split_specs:
        split_df = df.iloc[start:stop].reset_index(drop=True)
        dataset = ImageDataset(split_df, IMAGE_DIR, transform)
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
        )
        loaders.append(
            DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                sampler=sampler,
                num_workers=2,
                pin_memory=True,  # page-locked host buffers speed up host-to-GPU copies
            )
        )

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader


def train(model, loader, optimizer, criterion, device, scaler, use_amp):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module whose forward output exposes a ``logits`` attribute.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer updating the model parameters.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        scaler: Gradient scaler; only used when ``use_amp`` is true.
        use_amp: Whether to run the forward pass under CUDA autocast and step
            the optimizer through ``scaler``.

    Returns:
        Mean of the per-batch loss values as a ``float``, or ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(images).logits
            loss = criterion(outputs, labels)

        if use_amp:
            # Scale the loss so small half-precision gradients do not underflow;
            # step() unscales before the optimizer update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches explicitly: avoids dividing by zero on an empty loader and
    # does not require the loader to implement __len__.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


def evaluate(model, loader, device):
    """Compute classification accuracy of ``model`` over ``loader``.

    Predictions are the argmax over dimension 1 of the model's ``logits``.
    Correct predictions are counted on ``device`` and read back once at the
    end, rather than copying every batch to the host.

    Args:
        model: Module whose forward output exposes a ``logits`` attribute.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Fraction of samples seen through ``loader`` that were classified
        correctly, as a ``float``; ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = torch.argmax(model(images).logits, dim=1)
            correct += (preds == labels).sum()
            total += labels.numel()

    if total == 0:
        return 0.0
    return correct.item() / total


def _mean_across_ranks(value, device, world_size):
    """Return the average of a per-rank scalar over all processes.

    Every rank must call this (it performs an all-reduce). Averaging per-rank
    means gives the global mean because ``DistributedSampler`` hands every
    rank the same number of samples.
    """
    buf = torch.tensor(float(value), dtype=torch.float64, device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.item() / world_size


def main(rank, world_size, use_amp):
    """Per-process entry point: fine-tune ViT with DDP and report metrics.

    Args:
        rank: Index of this process; also the index of the GPU it uses.
        world_size: Total number of processes (one per GPU).
        use_amp: Whether to train with automatic mixed precision.

    Loss and accuracy are averaged over all ranks before rank 0 prints them,
    so the reported numbers cover the whole split rather than rank 0's shard.
    The process group is destroyed even if training raises.
    """
    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        train_loader, val_loader, test_loader = get_data_loaders(rank, world_size)

        # ignore_mismatched_sizes lets the pretrained classifier head be
        # replaced by a freshly initialized 2-class head.
        model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=2,
            ignore_mismatched_sizes=True
        ).to(device)

        model = DDP(model, device_ids=[rank])
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        if rank == 0:
            print(f"{'With' if use_amp else 'Without'} AMP, Training with {world_size} GPU(s)")

        start = time.time()
        for epoch in range(NUM_EPOCHS):
            # Re-seed the sampler so each epoch sees a different shuffle.
            train_loader.sampler.set_epoch(epoch)
            train_loss = train(model, train_loader, optimizer, criterion, device, scaler, use_amp)
            val_acc = evaluate(model, val_loader, device)

            # Collectives: executed by every rank, printed by rank 0 only.
            train_loss = _mean_across_ranks(train_loss, device, world_size)
            val_acc = _mean_across_ranks(val_acc, device, world_size)
            if rank == 0:
                print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Acc = {val_acc:.4f}")
        total_time = time.time() - start

        # The test loader is sharded too, so every rank evaluates its shard
        # and the results are combined; evaluating on rank 0 alone would cover
        # only 1/world_size of the test split.
        test_acc = _mean_across_ranks(
            evaluate(model, test_loader, device), device, world_size
        )
        if rank == 0:
            print(f"Test Accuracy: {test_acc:.4f}")
            print(f"Total training time: {total_time:.2f} seconds")
    finally:
        cleanup()


if __name__ == "__main__":
    # Command-line entry point: parse options, export the rendezvous address
    # for torch.distributed, then start one training process per GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=1, help="number of GPUs")
    parser.add_argument("--use_amp", action="store_true", help="use Automatic Mixed Precision")
    parser.add_argument("--master_addr", type=str, default="127.0.0.1", help="Master address")
    parser.add_argument("--master_port", type=str, default="29501", help="Master port")
    args = parser.parse_args()

    # Fail fast with a readable message; otherwise a bad world size only
    # surfaces as a CUDA/NCCL error inside the spawned workers.
    available_gpus = torch.cuda.device_count()
    if available_gpus == 0:
        parser.error("no CUDA devices are visible; this script requires at least one GPU")
    if not 1 <= args.world_size <= available_gpus:
        parser.error(
            f"--world_size must be between 1 and {available_gpus} "
            f"(visible CUDA devices), got {args.world_size}"
        )

    # Set in the parent so the spawned workers inherit them.
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port

    # spawn() passes the process index as the first argument (rank) of main.
    torch.multiprocessing.spawn(
        main,
        args=(args.world_size, args.use_amp),
        nprocs=args.world_size,
        join=True
    )



Overwriting train_ddp_vit.py
