# NHR Summer School – Data-Parallel Neural Networks with `PyTorch`
##### Dr. Charlotte Debus (charlotte.debus@kit.edu), Dr. Marie Weiel (marie.weiel@kit.edu), and David Li (david.li@kit.edu)
#### Agenda

| W H E N           | W H A T                                                 |
| :-----------------| :------------------------------------------------------ |
| **09:00 - 10:15** | **Introduction to Neural Networks**                     |  
|                   | Backpropagation and Stochastic Gradient Descent (SGD)   |  
|                   | Layer Architectures                                     |  
|                   | Training a Neural Network                               |  
| 10:15 - 10:30     | *It's coffee o'clock!*                                  |
| **10:30 - 12:00** | **Hands-on Session: Neural Networks with `PyTorch`**    |  
| 12:00 - 13:00     | *Enjoy your lunch break!*                               |  
| **13:00 - 14:15** | **Data-Parallel Neural Networks**                       |  
|                   | Parallelization Strategies for Neural Networks          |  
|                   | Distributed SGD                                         |  
|                   | IID and Large Minibatch Effects                         |  
|   14:15 - 14:30   | *It's coffee o'clock!*                                  |
| **<font color='orange'>14:30 - 16:00</font>** | **<font color='orange'>Hands-on Session:</font> `PyTorch DistributedDataParallel`** |


## Hands-on Session: `PyTorch DistributedDataParallel`
This morning, you learned how to train a neural network in `PyTorch` using the example of [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) classification with the convolutional neural network *AlexNet*. In this hands-on tutorial, you will learn how to train the same network in a distributed data-parallel fashion. We will use `PyTorch`'s `DistributedDataParallel` module for this. 

### Short recap
#### AlexNet

*AlexNet* is a CNN for image classification, originally designed for the [ImageNet](https://www.image-net.org/) dataset. 
The input is an RGB image and the output, after applying a softmax, is a vector of $n_\text{classes}$ numbers that sum to 1, where the $i^\text{th}$ element can be interpreted as the probability that the input image belongs to class $i$. 
*AlexNet* consists of five convolutional layers, some followed by max-pooling, and three fully-connected layers. It uses the non-saturating ReLU activation function.  

Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. **[Imagenet classification with deep convolutional neural networks.](https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)** *Advances in Neural Information Processing Systems* 25 (2012): 1097-1105.   

![Architecture of AlexNet.](AlexNet-1.png "Architecture of AlexNet: AlexNet consists of eight layers: the first five are convolutional layers, some followed by max-pooling layers, the last three are fully connected layers. It uses the non-saturating ReLU activation function.")  

Source: [https://learnopencv.com/understanding-alexnet/](https://learnopencv.com/understanding-alexnet/)

#### CIFAR-10
The [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset contains 60,000 color images of size 32 x 32 from ten classes, where each class holds 6,000 images. 
The dataset is divided into five training batches and one test batch, each containing 10,000 images. 
The test batch contains exactly 1,000 randomly selected images from each class. 
The training batches contain the remaining images in random order. 
Below you see the classes of the dataset and ten random images from each class:  
![CIFAR-10-Dataset.](Cifar.png " ")  

Source: [https://www.cs.toronto.edu/~kriz/cifar.html](https://www.cs.toronto.edu/~kriz/cifar.html)

#### Data-parallel neural networks (DPNNs) in `PyTorch`

After lunch, you learned about data-parallel training of neural networks. 
As you already know, this involves distributing the training process across multiple processors to accelerate computation and increase training throughput. 
`PyTorch` provides the `DistributedDataParallel` (`DDP`) module for this, which abstracts away some of the complexities of implementing data-parallel training in a distributed setting. 
From the official [documentation](https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html):  
>Distributed data-parallel training is a widely adopted single-program multiple-data training paradigm. The model is replicated on every process, and every model replica will be fed with a different set of input data samples. The `DistributedDataParallel` module takes care of gradient communication to keep model replicas synchronized and overlaps it with the gradient computations to speed up training. 
It implements data parallelism at the module level which can run across multiple machines. Applications using `DDP` should spawn multiple processes and create a single `DDP` instance per process. `DDP` uses collective communications in the `torch.distributed` package to synchronize gradients and buffers. More specifically, `DDP` registers an autograd hook for each parameter given by `model.parameters()` and the hook will fire when the corresponding gradient is computed in the backward pass. Then `DDP` uses that signal to trigger gradient synchronization across processes.  
*The recommended way to use `DDP` is to spawn one process for each model replica. `DDP` processes can be placed on the same machine or across machines, but GPU devices cannot be shared across processes.*  

The `torch.distributed` package supports three built-in backends (Gloo, MPI, and NCCL) for communication between processes. 
This [table](https://pytorch.org/docs/stable/distributed.html#backends) shows which functions are available for use with CPU/CUDA tensors. 
Since *Noctua2* connects GPUs with NVLink within a node and a Mellanox InfiniBand interconnect between nodes, we use the officially recommended NCCL backend. 
The [NVIDIA Collective Communication Library](https://developer.nvidia.com/nccl) (NCCL) implements multi-GPU and multi-node communication functions optimized for NVIDIA GPUs and networks. 
It provides routines, such as all-gather, all-reduce, broadcast, reduce, reduce-scatter, and point-to-point transmit and receive. 
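To make the semantics of all-reduce concrete, the following is a minimal pure-Python simulation that needs neither GPUs nor `torch`. The helper `simulated_all_reduce` is a hypothetical name introduced only for illustration. It mimics what a sum all-reduce delivers to every rank and says nothing about how NCCL implements it internally.

```python
# Pure-Python simulation of a sum all-reduce; no GPUs or torch required.
# `simulated_all_reduce` is a hypothetical helper for illustration only.

def simulated_all_reduce(per_rank_values):
    """Return what every rank holds after a sum all-reduce.

    per_rank_values[r] is the list of numbers initially held by rank r.
    """
    # Element-wise sum over all ranks.
    reduced = [sum(column) for column in zip(*per_rank_values)]
    # Every rank receives its own copy of the same reduced result.
    return [list(reduced) for _ in per_rank_values]


# Four ranks, each holding a two-element "gradient".
local_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
world_size = len(local_grads)

summed = simulated_all_reduce(local_grads)
# Dividing by the world size turns the sum into an average.
averaged = [[v / world_size for v in rank_copy] for rank_copy in summed]

print(summed[0])    # Identical on every rank.
print(averaged[0])
```

In contrast to a plain reduce, where only one rank ends up with the result, every rank holds the full reduced vector afterwards.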

#### How to train a DPNN with `PyTorch`'s `DistributedDataParallel` module
Below is a recipe for training a DPNN with `DDP` in `PyTorch`:

1. **Initialize the distributed environment:** Before using `DDP`, you need to define and initialize the distributed environment. This involves setting up the communication backend (NCCL for us), specifying the so-called process group, and assigning a unique rank and the world size to each process in the process group. The rank is like a unique process ID and the world size corresponds to the overall number of processes you want to use. 
2. **Load the data:** Data parallelism means splitting the input data across the processes in the process group and computing the forward and backward passes independently on each rank. This enables parallel processing and reduces the training time. You load the training and validation datasets and distribute them equally over the processes so that each process holds a different, exclusive subset of each dataset. `PyTorch` provides a dedicated sampler for this, the so-called `DistributedSampler`.
3. **Model instantiation and replication:** Afterwards, you need to replicate the model across the processes. Each replica will process a subset of the input data provided by the `DistributedSampler`. To do so, you instantiate the model just as in the serial case and wrap it with `DDP`. This ensures that the gradients computed during the backward pass are synchronized across all replicas.
4. **Training loop:** Repeat for a specified number of iterations or until convergence is reached:
    - *Forward pass*: Each replica of the model independently processes its portion of the input data. 
    - *Backward pass and gradient synchronization*: The gradients are computed independently on each replica. They are then synchronized across all replicas using a function called "all-reduce". This step ensures that the model parameters are updated consistently across all processes.
    - *Optimization step:* Once the gradients are synchronized, the optimizer performs an optimization step to update the model parameters. This step is performed independently and redundantly on each replica.
    - *Validation*: After updating the model parameters, you can compute the current model's accuracy on the training and validation datasets. As each process only holds a portion of each dataset, you need to implement some more communication to obtain the accuracy on each whole dataset. 
5. **Evaluation:** After training, you can evaluate the final model's performance using a held-out test dataset. The evaluation is typically performed on a single process without the need for data parallelism. 
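The core of the recipe above can be checked on a toy problem: if every rank holds an equally sized shard of a mini-batch, averaging the rank-local gradients yields the gradient of the full mini-batch. The sketch below simulates this in plain Python for a one-parameter linear model with a mean-squared-error loss. The data, the model, and the helper `mse_gradient` are assumptions made for illustration and are not part of the tutorial code.

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy global mini-batch of eight samples.
xs = [0.5, -1.0, 2.0, 1.5, -0.5, 3.0, 1.0, -2.0]
ys = [1.0, -2.5, 4.0, 2.0, -1.0, 6.5, 2.0, -3.0]
w = 0.3
world_size = 4

# Each rank gets an equally sized, disjoint shard (strided split).
shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]

# The forward and backward passes happen independently on each rank.
local_grads = [mse_gradient(w, sx, sy) for sx, sy in shards]

# Sum all-reduce followed by division by the world size = gradient averaging.
synced_grad = sum(local_grads) / world_size

# Reference: gradient on the full mini-batch in a single process.
serial_grad = mse_gradient(w, xs, ys)

# Every rank applies the same SGD update, so the replicas stay identical.
lr = 0.1
w_new = w - lr * synced_grad
print(synced_grad, serial_grad, w_new)
```

The equality only holds up to floating-point rounding and relies on all shards having the same size, which is why the data is split evenly across the ranks.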

### What you will do now

Building on our morning hands-on session, you will learn how to train a data-parallel version of *AlexNet* in `PyTorch`. This tutorial will guide you through the steps required for parallelizing the training in a data-parallel fashion using `DDP`. 
As parallel runs are inconvenient in Jupyter notebooks, you will need to create `Python` scripts from the code snippets provided in this notebook and run the scripts as a batch job on *Noctua2*.

The tutorial is structured as follows:

-------------
1. Get all the building blocks.
- **Model:** Define your model <font color='grey'> (~0 min)</font>.  
- **Data:** Define the dataloaders <font color='grey'> (~20 min)</font>.   

*Short break with everyone to discuss your results and possible solutions.*  

- **Training:** Define the training loop <font color='grey'> (~20 min)</font>.

*Short break with everyone to discuss your results and possible solutions.*  

2. Assemble the main `Python` script from your building blocks <font color='grey'> (~20 min)</font>. 

*Short break with everyone to discuss your results and possible solutions.*  

3. Run your code in parallel as a batch job on *Noctua2* <font color='grey'> (~30 min)</font>. 
- Use four GPUs, i.e., four processes in the data-parallel training process. 
- Start your `Python` script in parallel with the `srun` ([doc](https://slurm.schedmd.com/srun.html)) command in a job bash script. 
- Submit your job script to the [SLURM](https://www.schedmd.com/) workload manager with `sbatch` ([doc](https://slurm.schedmd.com/sbatch.html)).

*Final break with everyone to discuss your results and possible solutions.*  

-------------

Below is a corresponding code framework with dataloaders for the training including validation and testing, along with detailed explanations and instructions for each step. 
**Normal comments with '#' describe code as usual, in lines with '##' you need to add code.** 

## 1. Get all the building blocks

### Model: Define the model
As the first step, you again need to define your model architecture. 
This is just a copy-paste of your `AlexNet` module class from the serial case. 
Later on, you will wrap an instance of this module with `DDP` to make it distributed data-parallel (the magic happens here). 
Save the code below as a separate `Python` module file `model.py` so that you can import the `AlexNet` module class from this file into your main script. 

In [2]:
# MODEL
import torch

# Define neural network by subclassing PyTorch's nn.Module. 
class AlexNet(torch.nn.Module):
    
    # Initialize neural network layers in __init__. 
    def __init__(self, num_classes = 1000, dropout = 0.5):
        super().__init__()
        self.features = torch.nn.Sequential(
            # AlexNet has 8 layers: 5 convolutional layers, some followed by max-pooling (see figure),
            # and 3 fully connected layers. In this model, we use nn.ReLU between our layers, 
            # but there are other activations to introduce non-linearity in a model.
            # nn.Sequential is an ordered container of modules. 
            # The data is passed through all the modules in the same order as defined. 
            # You can use sequential containers to put together a quick network.
            #
            # IMPLEMENT FEATURE-EXTRACTOR PART OF ALEXNET HERE!
            # 1st convolutional layer (+ max-pooling)
            torch.nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            # 2nd convolutional layer (+ max-pooling)
            torch.nn.Conv2d(64, 192, kernel_size=5, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            # 3rd + 4th convolutional layer
            torch.nn.Conv2d(192, 384, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(384, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            # 5th convolutional layer (+ max-pooling)
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Average pooling to downscale possibly larger input images.
        self.avgpool = torch.nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = torch.nn.Sequential( 
            # IMPLEMENT FULLY CONNECTED MULTI-LAYER PERCEPTRON PART HERE!
            # 6th, 7th + 8th fully connected layer 
            # The linear layer is a module that applies a linear transformation 
            # on the input using its stored weights and biases.
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(256 * 6 * 6, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(4096, num_classes),
        )
    # Forward pass: Implement operations on the input data, i.e., apply model to input x.
    def forward(self, x):
        # IMPLEMENT OPERATIONS ON INPUT DATA x HERE!
        x = self.features(x)    # Apply feature-extractor part to input.
        x = self.avgpool(x)     # Apply average-pooling part.
        x = torch.flatten(x, 1) # Flatten.
        x = self.classifier(x)  # Apply fully connected multilayer perceptron part.
        return x
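
To see why the adaptive average pooling matters for small inputs, the sketch below traces the spatial edge length of a square image through the feature extractor defined above, using the standard output-size formula $\lfloor (n + 2p - k) / s \rfloor + 1$ for convolution and pooling layers. The helper names `conv_out` and `alexnet_feature_size` are hypothetical. The 64 x 64 input corresponds to the crops used in this tutorial's main script, and 224 x 224 is the classic ImageNet input size.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output edge length of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def alexnet_feature_size(size):
    """Trace the spatial edge length through the feature extractor above."""
    size = conv_out(size, kernel=11, stride=4, padding=2)  # 1st conv
    size = conv_out(size, kernel=3, stride=2)              # max-pool
    size = conv_out(size, kernel=5, padding=2)             # 2nd conv
    size = conv_out(size, kernel=3, stride=2)              # max-pool
    size = conv_out(size, kernel=3, padding=1)             # 3rd conv
    size = conv_out(size, kernel=3, padding=1)             # 4th conv
    size = conv_out(size, kernel=3, padding=1)             # 5th conv
    size = conv_out(size, kernel=3, stride=2)              # max-pool
    return size


for edge in (64, 224):
    print(edge, "->", alexnet_feature_size(edge))
```

For a 224 x 224 input, the feature map is already 6 x 6, whereas a 64 x 64 input shrinks to 1 x 1. The adaptive pooling layer brings both to the 6 x 6 shape that the first linear layer expects.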

### Data: Define dataloaders
Next, you again need to get the data in. You already learned that to train a DPNN, each process needs to load an exclusive subset of the dataset. 
`PyTorch` provides a dedicated sampler to distribute and load data in a distributed training setting, the so-called `DistributedSampler`.
It enables efficient data loading across multiple processes by partitioning the dataset into smaller subsets that are processed independently by each process.
The `DistributedSampler` works in conjunction with `DDP`. It ensures that each process operates on a unique subset of the dataset, avoiding redundant computation and enabling parallelism. 
Below, you can find an overview of how this works:

1. **Data partitioning:** The `DistributedSampler` partitions the dataset into smaller subsets based on the number of processes involved in the distributed training. Each process is responsible for processing a specific subset of the data.

2. **Shuffling and sampling:** Optionally, the `DistributedSampler` can shuffle the dataset before partitioning it to introduce randomness into the training. This helps prevent biases and improves the model's generalization. No communication is needed for this: every process generates the same permutation of indices from a shared seed combined with the current epoch and then picks its own part of it.

3. **Data loading:** During training, each process loads its assigned subset of the dataset using the `DistributedSampler`. The sampler provides indices corresponding to the samples in the process's partition of the dataset.

4. **Parallel processing:** Once the data is loaded, each process operates independently on its portion of the dataset. Forward and backward passes, as well as the optimization step, are performed separately on each process.

5. **Synchronization:** During the backward pass of each training iteration, the gradients are averaged across all processes so that every model replica applies the same parameter update. This synchronization is handled by `DDP`, not by the sampler.

6. **Iteration and epoch completion:** The `DistributedSampler` assigns the same number of samples to every process, either by padding the dataset or, with `drop_last=True`, by dropping its tail. As a result, all processes run the same number of iterations per epoch. To obtain a different shuffling in each epoch, you need to call the sampler's `set_epoch()` method at the beginning of every epoch. Otherwise, the same ordering is used in all epochs.
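The partitioning logic described in the steps above can be sketched in a few lines of plain Python. This is a simplified imitation under stated assumptions (`shuffle=True`, `drop_last=True`, Python's `random` module instead of `torch`'s random number generator), not the actual `DistributedSampler` implementation. `partition_indices` is a hypothetical helper name.

```python
import random


def partition_indices(dataset_len, num_replicas, rank, epoch, seed=0):
    """Mimic the index selection of a distributed sampler with
    shuffle=True and drop_last=True (simplified sketch)."""
    indices = list(range(dataset_len))
    # Same seed on every rank -> same permutation everywhere, no communication.
    random.Random(seed + epoch).shuffle(indices)
    # Drop the tail so that the data splits evenly over all ranks.
    num_samples = dataset_len // num_replicas
    indices = indices[: num_samples * num_replicas]
    # Strided slicing gives each rank a disjoint subset.
    return indices[rank::num_replicas]


world_size = 4
epoch0 = [partition_indices(10, world_size, r, epoch=0) for r in range(world_size)]
epoch1 = [partition_indices(10, world_size, r, epoch=1) for r in range(world_size)]
for r in range(world_size):
    print(f"rank {r}: epoch 0 -> {epoch0[r]}, epoch 1 -> {epoch1[r]}")
```

With ten samples and four ranks, every rank receives two indices and two samples are dropped per epoch. Applied to this tutorial's setup (45,000 training samples after the 10 % validation split, four ranks, and a per-process batch size of 256), each rank would hold 11,250 samples and run 43 full mini-batches per epoch.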

Complete the code below and save it as a separate `Python` module file `helper_dataset.py` so that you can import the dataloader from this file into your main script.

In [None]:
import torch
import torchvision
import numpy as np

def get_dataloaders_cifar10_ddp(
    batch_size, 
    num_workers=0,
    root='data',
    validation_fraction=0.1,
    train_transforms=None,
    test_transforms=None
):
    """
    Get distributed CIFAR10 dataloaders for training and validation in a DDP setting.
    
    Params
    ------
    batch_size : int
                 batch size
    num_workers : int
                  How many workers to use for data loading.
    root : str
           path to data dir
    validation_fraction : float
                          fraction of train dataset used for validation
    train_transforms : torchvision.transforms.<transformation>
                       How to preprocess the training data.
    test_transforms : torchvision.transforms.<transformation>
                      How to preprocess the test data.
                      
    Returns
    -------
    torch.utils.data.Dataloader : training dataloader
    torch.utils.data.Dataloader : validation dataloader
    """
    if train_transforms is None: 
        train_transforms = torchvision.transforms.ToTensor()
    if test_transforms is None: 
        test_transforms = torchvision.transforms.ToTensor()

    train_dataset = torchvision.datasets.CIFAR10(root=root,
                                                 train=True,
                                                 transform=train_transforms,
                                                 download=True)

    valid_dataset = torchvision.datasets.CIFAR10(root=root,
                                                 train=True,
                                                 transform=test_transforms)

    # Perform index-based train-validation split of original training data. 
    total = len(train_dataset) # Get overall number of samples in original training data.
    idx = list(range(total)) # Make index list.
    np.random.default_rng(seed=0).shuffle(idx) # Shuffle indices with a fixed seed so that all ranks get the same split.
    vnum = int(validation_fraction * total) # Determine number of validation samples from validation split.
    train_indices, valid_indices = idx[vnum:], idx[0:vnum] # Extract train and validation indices.

    train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
    valid_dataset = torch.utils.data.Subset(valid_dataset, valid_indices)

    # Sampler that restricts data loading to a subset of the dataset.
    # Especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. 
    # Each process can pass a DistributedSampler instance as a DataLoader sampler, 
    # and load a subset of the original dataset that is exclusive to it.

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
        drop_last=True
    )

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
        drop_last=True
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers, # Use requested number of data-loading workers.
        drop_last=True,
        sampler=train_sampler
    )

    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
        sampler=valid_sampler
    )

    return train_loader, valid_loader

### *Short break with everyone to discuss your results and possible solutions.*

### Training: Define the training loop
Now that we have our `AlexNet` model and the distributed CIFAR-10 data, we want to actually train, validate, and test it by optimizing its parameters in a data-parallel fashion with `DDP`. 
Remember that training a model is an iterative process. In each iteration, the model predicts the output for a given input, calculates the error in its prediction as quantified by the loss function, collects the derivatives of the loss w.r.t. its parameters, and optimizes these parameters using gradient descent. 
As `DDP` handles the synchronization of the gradients over all processes for you, the structure of the training loop stays basically the same as before. 
While each process trains its model replica on its local training data batch provided by the `DistributedSampler`, 
`DDP` takes care of gradient communication to keep the model replicas synchronized. 
You might also want to track the average loss over all processes during the training. 
As `DDP` only takes care of the gradient synchronization, you have to implement this explicitly using collective communication functions from `torch.distributed` ([doc](https://pytorch.org/docs/stable/distributed.html#collective-functions)).  

Similar to the serial case, we will define some useful helper functions for validating our model during training and testing it afterwards on unseen data:
- `get_right_ddp`: Get the number of correctly predicted and overall samples for a given model on a given dataset. You will need those numbers for calculating your model's accuracy during the training loop on the distributed training and validation datasets.
- `compute_accuracy_ddp`: Compute the accuracy of your model's predictions on a given dataset. You will need this function for testing your final model on a held-out test dataset after the training is done. Conceptually the same as for the serial case with some slight technical differences.
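The reason for collecting counts instead of rank-local accuracies can be illustrated with a small pure-Python sketch. The numbers are made up, and `global_accuracy` is a hypothetical helper that stands in for two sum all-reduces followed by a division.

```python
def global_accuracy(per_rank_counts):
    """per_rank_counts: list of (num_correct, num_samples), one tuple per rank."""
    # This is what two sum all-reduces would deliver to every rank.
    total_correct = sum(c for c, _ in per_rank_counts)
    total_samples = sum(n for _, n in per_rank_counts)
    return 100.0 * total_correct / total_samples


# Hypothetical counts from four ranks; the last rank saw fewer samples.
counts = [(90, 100), (80, 100), (70, 100), (10, 20)]

acc_from_counts = global_accuracy(counts)
# Naive alternative: average the rank-local accuracies.
acc_from_means = sum(100.0 * c / n for c, n in counts) / len(counts)
print(f"{acc_from_counts:.2f}% vs. {acc_from_means:.2f}%")
```

When every rank holds the same number of samples, as is the case with `drop_last=True`, both variants agree. Summing the counts stays correct even if the shard sizes differ.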

All of this functionality is defined in the `train_model_ddp` function below. 
Complete the code and save it as a separate `Python` module file `helper_train.py` so that you can import the training function from this file into your main script.

In [None]:
import os
import random
import time
import torch
import numpy as np

def compute_accuracy_ddp(model, data_loader):
    """
    Compute accuracy of model predictions on given labeled data.
    
    Params
    ------
    model : torch.nn.Module
            Model.
    data_loader : torch.utils.data.DataLoader
                  Dataloader.
    
    Returns
    -------
    float : The model's accuracy on the given dataset in percent.
    """
    with torch.no_grad():

        correct_pred, num_examples = 0, 0

        for i, (features, targets) in enumerate(data_loader):

            features = features.cuda()
            targets = targets.float().cuda()

            logits = model(features)
            _, predicted_labels = torch.max(logits, 1) # Get class with highest score.

            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float() / num_examples * 100

def get_right_ddp(model, data_loader):
    """
    Compute the number of correctly predicted samples and the overall number of samples in a given dataset.
    
    This function is needed to compute the accuracy over multiple processors in a distributed data-parallel setting.
    
    Params
    ------
    model : torch.nn.Module
            Model.
    data_loader : torch.utils.data.DataLoader
                  Dataloader.
    
    Returns
    -------
    torch.Tensor : The number of correctly predicted samples.
    torch.Tensor : The overall number of samples in the dataset.
    """
    with torch.no_grad():

        correct_pred, num_examples = 0, 0

        for i, (features, targets) in enumerate(data_loader):

            features = features.cuda()
            targets = targets.float().cuda()
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1) # Get class with highest score.

            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    num_examples = torch.Tensor([num_examples]).cuda()
    return correct_pred, num_examples

def train_model_ddp(
    model,
    num_epochs,
    train_loader,
    valid_loader,
    optimizer
):
    """
    Train model in distributed data-parallel fashion.

    Params
    ------
    model : torch.nn.Module
            model to train
    num_epochs : int
                 number of epochs to train
    train_loader : torch.utils.data.Dataloader
                   training dataloader
    valid_loader : torch.utils.data.Dataloader
                   validation dataloader
    optimizer : torch.optim.Optimizer
                optimizer to use
    """
    start = time.perf_counter() # Measure training time.
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    loss_history, train_acc_history, valid_acc_history = [], [], [] # Initialize history lists.


    for epoch in range(num_epochs): # Loop over epochs.

        train_loader.sampler.set_epoch(epoch)
        model.train() # Set model to training mode.

        for batch_idx, (features, targets) in enumerate(train_loader): # Loop over mini batches.

            features = features.cuda()
            targets = targets.cuda()

            # Forward and backward pass.
            logits = model(features)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step() # Perform single optimization step to update model parameters.

            # Logging.
            torch.distributed.all_reduce(loss) # Allreduce rank-local mini-batch losses.
            loss /= world_size # Average allreduced rank-local mini-batch losses over all ranks.
            loss_history.append(loss.item()) # Append globally averaged loss of this mini-batch to history list.

            if rank == 0:
                print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                      f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                      f'| Averaged Loss: {loss:.4f}')

        model.eval() # Set model to evaluation mode.

        with torch.no_grad(): # Disable gradient calculation.
            # Get rank-local numbers of correctly classified and overall samples in training and validation set.
            right_train, num_train = get_right_ddp(model, train_loader)
            right_valid, num_valid = get_right_ddp(model, valid_loader)
            
            # Allreduce rank-local numbers of correctly classified and overall training and validation samples.
            torch.distributed.all_reduce(right_train)
            torch.distributed.all_reduce(right_valid)
            torch.distributed.all_reduce(num_train)
            torch.distributed.all_reduce(num_valid)
            train_acc = right_train.item() / num_train.item() * 100
            valid_acc = right_valid.item() / num_valid.item() * 100
            train_acc_history.append(train_acc)
            valid_acc_history.append(valid_acc)

            if rank == 0:
                print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                  f'| Train: {train_acc :.2f}% '
                  f'| Validation: {valid_acc :.2f}%')

        elapsed = (time.perf_counter() - start) / 60 # Measure time elapsed since start of training in minutes.
        Elapsed = torch.Tensor([elapsed]).cuda()
        torch.distributed.all_reduce(Elapsed)
        Elapsed /= world_size
        if rank == 0:
            print(f'Time elapsed: {Elapsed.item()} min')

    elapsed = (time.perf_counter() - start) / 60 # Measure total training time.
    Elapsed = torch.Tensor([elapsed]).cuda()
    torch.distributed.all_reduce(Elapsed)
    Elapsed /= world_size

    if rank == 0:
        print(f'Total training time: {Elapsed.item()} min')
        torch.save(loss_history, f'loss_{world_size}_gpu.pt')
        torch.save(train_acc_history, f'train_acc_{world_size}_gpu.pt')
        torch.save(valid_acc_history, f'valid_acc_{world_size}_gpu.pt')

    return loss_history, train_acc_history, valid_acc_history

### *Short break with everyone to discuss your results and possible solutions.*

## 2. Assemble the main `Python` script from your building blocks
Now that you have implemented all the functions and classes you need, you are ready to put together the main `Python` script that is to be executed in parallel on the supercomputer. 
As explained above, you need to set up the so-called process group first. 
After this has been done properly, you can load your data so that each process holds an exclusive subset, instantiate your module, wrap it with `DDP`, and train it in a data-parallel fashion on the process-local data. 
As `DDP` broadcasts model states from the process with rank 0 (often called root) to all other processes in the `DDP` constructor, you do not need to worry about different `DDP` processes starting from different initial model parameter values. 
As you have seen, `DDP` wraps lower-level distributed communication details and provides a clean API as if it were a local model. 
Gradient synchronization communications take place during the backward pass and overlap with the backward computation. 
When the `backward()` returns, `param.grad` already contains the synchronized gradient tensor. 
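Setting up the process group relies on environment variables. With `init_method="env://"`, `torch.distributed` reads the address and port of the rank-0 process from `MASTER_ADDR` and `MASTER_PORT`, which therefore have to be set before the script starts, for example in the job script. The sketch below shows one possible way to validate these variables together with the SLURM variables used in this tutorial. `read_slurm_env` is a hypothetical helper, and the values in `fake_env` (including the host name and port) are made up for illustration.

```python
def read_slurm_env(env, gpus_per_node):
    """Derive rank, world size, and local GPU index from SLURM variables.

    `env` is any mapping, e.g., os.environ. Hypothetical helper for illustration.
    """
    required = ("SLURM_NPROCS", "SLURM_PROCID", "SLURM_LOCALID",
                "MASTER_ADDR", "MASTER_PORT")
    missing = [name for name in required if name not in env]
    if missing:
        raise RuntimeError(f"Missing environment variables: {missing}")
    world_size = int(env["SLURM_NPROCS"])
    rank = int(env["SLURM_PROCID"])
    local_id = int(env["SLURM_LOCALID"])
    # Assumes that tasks fill each node completely with one task per GPU.
    if rank % gpus_per_node != local_id:
        raise RuntimeError("Rank-to-GPU mapping does not match SLURM_LOCALID.")
    return rank, world_size, local_id


# Fake environment of the task with rank 5 in a job with 2 nodes x 4 GPUs.
fake_env = {
    "SLURM_NPROCS": "8", "SLURM_PROCID": "5", "SLURM_LOCALID": "1",
    "MASTER_ADDR": "node01", "MASTER_PORT": "29500",
}
print(read_slurm_env(fake_env, gpus_per_node=4))
```

Failing early with a clear message is a design choice of this sketch. A missing variable would otherwise only surface as a `TypeError` from `int(None)` or as a failure during process-group initialization.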

Complete the code below and save it as a separate `Python` script `main.py` in the same folder as all your helper module files. This file is the one to be actually run in parallel on *Noctua2*.

In [1]:
import os
import torch
import torchvision
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel as DDP
from model import AlexNet
from helper_dataset import get_dataloaders_cifar10_ddp
from helper_train import train_model_ddp, get_right_ddp, compute_accuracy_ddp

def main():
    world_size = int(os.getenv("SLURM_NPROCS")) # Get overall number of processes.
    rank = int(os.getenv("SLURM_PROCID"))       # Get individual process ID.
    slurm_job_gpus = os.getenv("SLURM_JOB_GPUS")
    slurm_localid = int(os.getenv("SLURM_LOCALID"))
    gpus_per_node = torch.cuda.device_count()
    gpu = rank % gpus_per_node
    assert gpu == slurm_localid
    device = f"cuda:{slurm_localid}"
    torch.cuda.set_device(device)

    # Initialize DDP.
    dist.init_process_group(
        backend="nccl", 
        rank=rank, 
        world_size=world_size, 
        init_method="env://"
    )
    if dist.is_initialized():
        print(f"Rank {rank}/{world_size}: Process group initialized with torch rank {torch.distributed.get_rank()} and torch world size {torch.distributed.get_world_size()}.")

    b = 256 # Set batch size.
    e = 100 # Set number of epochs to be trained.

    # Define transforms for data preprocessing to make smaller CIFAR-10 images work with AlexNet.
    # You can find a more detailed explanation in the notebook of the first hands-on session.
    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((70, 70)),
        torchvision.transforms.RandomCrop((64, 64)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((70, 70)),
        torchvision.transforms.CenterCrop((64, 64)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Get distributed dataloaders for training and validation data on all ranks.
    train_loader, valid_loader = get_dataloaders_cifar10_ddp(
        batch_size=b,
        root='/scratch/hpc-prf-nhrgs/mweiel/data',
        train_transforms=train_transforms,
        test_transforms=test_transforms)

    # Get dataloader for test data. 
    # Final testing is only done on root.
    if dist.get_rank() == 0:
        test_dataset = torchvision.datasets.CIFAR10(
            root="/scratch/hpc-prf-nhrgs/mweiel/data",
            train=False,
            transform=test_transforms
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=b,
            shuffle=False
        )

    model = AlexNet(num_classes=10).to(device) # Create model and move it to this process's local GPU.
    ddp_model = DDP( # Wrap model with DDP.
        model, 
        device_ids=[slurm_localid], 
        output_device=slurm_localid
    )
    optimizer = torch.optim.SGD(
        ddp_model.parameters(), 
        momentum=0.9, 
        lr=0.1
    )

    # Train model.
    loss_history, train_acc_history, valid_acc_history = train_model_ddp(
        model=ddp_model,
        num_epochs=e,
        train_loader=train_loader,
        valid_loader=valid_loader,
        optimizer=optimizer
    )

    # Test final model on root.
    if dist.get_rank() == 0:
        test_acc = compute_accuracy_ddp(ddp_model, test_loader) # Compute accuracy on test data.
        print(f'Test accuracy {test_acc:.2f}%')

    dist.destroy_process_group()

# MAIN STARTS HERE.    
if __name__ == '__main__':
    main()
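
The rank-to-GPU mapping at the top of `main()` can be checked without a cluster. The following is a minimal sketch, assuming one task per GPU and a block distribution of tasks over nodes. `process_layout` is a hypothetical helper that is not part of the tutorial's modules, and the environment is passed in as a plain dictionary instead of being set by SLURM.

```python
def process_layout(env, gpus_per_node):
    """Derive rank, world size, and local device from SLURM-style variables.

    `env` is any mapping with the keys SLURM_NPROCS, SLURM_PROCID, and
    SLURM_LOCALID (a stand-in for os.environ in this sketch).
    """
    world_size = int(env["SLURM_NPROCS"])
    rank = int(env["SLURM_PROCID"])
    local_id = int(env["SLURM_LOCALID"])
    gpu = rank % gpus_per_node  # Same mapping as in main().
    if gpu != local_id:
        raise RuntimeError(f"Rank {rank}: GPU index {gpu} != local ID {local_id}")
    return {"rank": rank, "world_size": world_size, "device": f"cuda:{gpu}"}


# Emulate a job with two nodes and four tasks per node (assumed values).
layouts = [
    process_layout(
        {
            "SLURM_NPROCS": "8",
            "SLURM_PROCID": str(rank),
            "SLURM_LOCALID": str(rank % 4),
        },
        gpus_per_node=4,
    )
    for rank in range(8)
]
devices = [layout["device"] for layout in layouts]
print(devices)
```

Every process gets a unique global rank, while the device index repeats on each node. This is why the model is moved to `cuda:{slurm_localid}` rather than to a device numbered by the global rank.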

### *Short break with everyone to discuss your results and possible solutions.*

## 3. Run your code in parallel as batch job on *Noctua2* 
Now it's time to actually run your script on *Noctua2*. 
As a normal user, you do not have root rights on a supercomputer. This means you cannot install arbitrary software system-wide; instead, software is made available to you via pre-installed modules. Likewise, you cannot use compute resources at will: they are shared among all users and managed by a so-called job scheduler. 

Job scheduling on a supercomputer is like managing a busy restaurant. Imagine you have a restaurant with many tables, and many customers want to be served at the same time. To serve everyone efficiently, you need a system to manage who gets seated at which table and when. Similarly, many people or organizations typically want to use a supercomputer to run their programs or simulations at the same time. Job scheduling is the process of deciding which jobs (programs or tasks) should run on the supercomputer, when they should start, and how long they can use the resources. 
The job scheduler, like a restaurant manager, looks at the incoming jobs and determines the best order and timing for execution. It takes into account factors such as job priority, estimated runtime, resource availability, and fairness. For example, a critical job that requires a lot of computing power may be given a higher priority and scheduled to run as soon as possible, while a smaller job that can be completed quickly might be scheduled to run in between larger jobs. 
The job scheduler also ensures that the supercomputer's resources, like processors, memory, and storage, are used efficiently. It assigns these resources to different jobs based on their requirements and availability, making sure that multiple jobs can run concurrently without interfering with each other. In this way, a job scheduler on a supercomputer manages the incoming workload, organizes the jobs, and allocates resources effectively to maximize the utilization and performance of the supercomputer, just like a restaurant manager aims to serve all customers in the most efficient way possible.
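
The idea of ordering jobs by priority while letting small jobs slip into free resources can be illustrated with a toy model. This is only a sketch under simplifying assumptions (a single resource type, known runtimes, a greedy policy). The `Job` class and `schedule` function are hypothetical and do not reflect how any real scheduler is implemented.

```python
from dataclasses import dataclass


@dataclass
class Job:
    name: str
    priority: int  # Higher value means more urgent.
    gpus: int      # Number of GPUs requested.
    runtime: int   # Runtime in abstract time steps (at least 1).


def schedule(jobs, total_gpus):
    """Return {job name: start time} for a greedy priority-based toy scheduler."""
    if any(job.gpus > total_gpus or job.runtime < 1 for job in jobs):
        raise ValueError("Each job must fit on the machine and run for >= 1 step.")
    pending = sorted(jobs, key=lambda job: -job.priority)
    running = []      # List of (end time, GPUs in use).
    start_times = {}
    now = 0
    while pending:
        # Release the GPUs of all jobs that have finished by now.
        running = [(end, gpus) for end, gpus in running if end > now]
        free = total_gpus - sum(gpus for _, gpus in running)
        waiting = []
        # Start pending jobs in priority order whenever they fit.
        for job in pending:
            if job.gpus <= free:
                start_times[job.name] = now
                running.append((now + job.runtime, job.gpus))
                free -= job.gpus
            else:
                waiting.append(job)
        pending = waiting
        if pending:
            # Jump to the time at which the next running job finishes.
            now = min(end for end, _ in running)
    return start_times


jobs = [
    Job("A", priority=3, gpus=3, runtime=2),
    Job("B", priority=2, gpus=4, runtime=1),
    Job("C", priority=1, gpus=1, runtime=1),
]
start_times = schedule(jobs, total_gpus=4)
print(start_times)
```

In this example, the small low-priority job `C` starts immediately next to `A` because one GPU is still free, whereas `B` has to wait until the whole machine is available. A purely greedy policy like this one can delay large jobs indefinitely, which is one reason real schedulers use more elaborate policies.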

Like many of the world's supercomputers and compute clusters, *Noctua2* uses the SLURM workload manager, a free and open-source job scheduler for Linux and Unix-like operating systems. 
It provides three key functions:
- Allocating access to compute resources to users for some time so they can perform work,
- Providing a framework for starting, executing, and monitoring work, and
- Arbitrating contention for resources by managing a queue of pending jobs.

SLURM is the workload manager on about 60% of the TOP500 supercomputers.
To run a job on a supercomputer, you need to submit a batch job script to SLURM with the `sbatch` ([doc](https://slurm.schedmd.com/sbatch.html)) command. 
This job script specifies which compute resources you need for how long, along with the actual code to run. 
Below you will find a job script requesting four GPUs on one node. 
It uses the `srun` ([doc](https://slurm.schedmd.com/srun.html)) command to execute your `Python` script in parallel on the requested four GPUs.
Adapt the script to your needs and save it as a separate bash script `submit_4_gpu.sh`.  
To run your code on *Noctua2*, submit the job script to the SLURM workload manager: `sbatch submit_4_gpu.sh`

In [None]:
#!/bin/bash

#SBATCH --job-name=alex4
#SBATCH --partition=gpu
#SBATCH --gres=gpu:a100:4
#SBATCH --time=30:00
#SBATCH --nodes=1
#SBATCH --account=hpc-prf-nhrgs
#SBATCH --ntasks-per-node=4
#SBATCH --output=/scratch/hpc-prf-nhrgs/<your_name>/res/slurm-%j.out
#SBATCH --mail-user=...  # Adjust this to match your email address.
#SBATCH --mail-type=ALL

module purge # Unload all modules.
module load vis/torchvision/0.13.1-foss-2022a-CUDA-11.7.0 # Load required modules.

# Choose a free 5-digit MASTER_PORT. Process group initialization fails if the port is already in use on the master node.
export MASTER_PORT=12340

# Get the first node name as master address.
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR=${MASTER_ADDR}"

export PYDIR=/scratch/hpc-prf-nhrgs/<your_name>/py # Set path to your python scripts.
export RESDIR=/scratch/hpc-prf-nhrgs/<your_name>/res/job_${SLURM_JOB_ID} # Set path to save results for this job.
mkdir -p ${RESDIR} # Create results dir (no error if it already exists).
cd ${RESDIR} # Change to results dir.

srun python -u ${PYDIR}/main.py # Run python script in parallel using srun.
# Each process executes exactly the same script!
mv ../slurm-${SLURM_JOB_ID}.out ${RESDIR} # Move SLURM output file to results dir.
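
With `init_method="env://"`, `PyTorch` reads the master address and port from the environment variables `MASTER_ADDR` and `MASTER_PORT` exported in the job script. A quick check up front can make a misconfigured job script easier to spot. The following is a minimal sketch, assuming you call a helper like the hypothetical `check_rendezvous_env` before `dist.init_process_group`; the host name `node01` is a placeholder.

```python
import os


def check_rendezvous_env(env=None):
    """Validate MASTER_ADDR and MASTER_PORT and return them as (address, port)."""
    env = os.environ if env is None else env
    missing = [name for name in ("MASTER_ADDR", "MASTER_PORT") if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    port = int(env["MASTER_PORT"])
    # Ports below 1024 are privileged; 65535 is the largest valid port number.
    if not 1024 <= port <= 65535:
        raise ValueError(f"MASTER_PORT {port} is outside the range 1024-65535.")
    return env["MASTER_ADDR"], port


# Emulate the variables exported by the job script (placeholder values).
address, port = check_rendezvous_env({"MASTER_ADDR": "node01", "MASTER_PORT": "12340"})
print(f"Rendezvous at {address}:{port}")
```

Called without arguments inside `main()`, the helper inspects the real environment of the process.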

### Congratulations! 
You have successfully trained a distributed data-parallel deep neural network in `PyTorch`. To analyze your results visually, you can now plot the evolution of the loss, training accuracy, and validation accuracy over the course of training, e.g., with `matplotlib.pyplot`.
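
One way to do this is sketched below. It assumes that the three histories returned by `train_model_ddp` are plain lists of numbers (one loss value per logged step, one accuracy value per epoch) available on a single process. This format and the helpers `moving_average` and `plot_histories` are assumptions for illustration, so adapt them to what your `helper_train` module actually returns.

```python
def moving_average(values, window):
    """Smooth a sequence with a trailing moving average of the given window."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            running_sum -= values[i - window]  # Drop the value leaving the window.
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


def plot_histories(loss_history, train_acc_history, valid_acc_history,
                   path="training_curves.png", window=10):
    """Plot loss and accuracy curves side by side and save them to `path`."""
    import matplotlib
    matplotlib.use("Agg")  # Render to file; compute nodes usually have no display.
    import matplotlib.pyplot as plt

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(loss_history, alpha=0.3, label="Loss")
    ax_loss.plot(moving_average(loss_history, window), label="Loss (moving average)")
    ax_loss.set_xlabel("Logged step")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    ax_acc.plot(train_acc_history, label="Training")
    ax_acc.plot(valid_acc_history, label="Validation")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy / %")
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
```

Calling `plot_histories(loss_history, train_acc_history, valid_acc_history)` on one process, for example on rank 0 after training, writes the figure to `training_curves.png` in the current working directory.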
### *Short break with everyone to discuss your results and possible solutions.*