# Single-Node Multi-GPU Training with Torchrun

In the previous notebook we introduced parallel computing at HPC centers and a specific algorithm for distributed training with PyTorch: Distributed Data Parallel (DDP). However, many challenges can arise when using multiple computational resources. One of the biggest is this: what happens when one of the resources fails during training? In this notebook we discuss these issues and show how to set up our parallel implementation so that it keeps running despite intermittent failures of computational resources. We will also combine the material from this tutorial and the previous one and apply it to the MNIST classifier we used previously.

<center>
<img src="https://impanix.com/wp-content/uploads/2023/05/What-is-Fault-Tolerance-Types-and-How-To-Implement-768x461.png" width=400 /><br>
<b>Figure 1.</b> Fault Tolerance 
</center>


Specifically, in this tutorial, we will cover the following material:
- Introduce Fault Tolerance
- Introduce PyTorch's torchrun
- Go over the code modifications needed to launch fault-tolerant distributed code with torchrun
- Implement a script for training the classifier using torchrun 

## Fault Tolerance

Leveraging multiple GPUs to train neural networks comes with many challenges. One of the biggest is what happens when a GPU fails, which can happen for many reasons (overheating, aging hardware, viruses, etc.). **Fault tolerance** is the ability of a system to keep operating despite the failure of one of its components. One way to combat component failure is checkpointing: we periodically save the state of our application (in deep learning, the current weights of our model) so that if a process fails we can resume from the most recent checkpoint (see Figure 2).


<img src="./img/checkpointing.png" />

<b>Figure 2.</b> Illustration of checkpointing. CP marks a point in time when a checkpoint is saved.
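The checkpoint-and-resume idea in Figure 2 can be sketched in plain Python before any PyTorch is involved. This is a minimal sketch, assuming a toy "model" (a dict holding one number) and a pickle file as storage; `save_checkpoint`, `resume_epoch`, and `train` are hypothetical names used only for illustration. Note that the checkpoint records the last *completed* epoch, so training resumes at the epoch after it.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, state, epoch):
    """Save the model state together with the last completed epoch."""
    with open(path, "wb") as f:
        pickle.dump({"epoch": epoch, "state": state}, f)


def resume_epoch(path, default_state):
    """Return (state, first epoch to run); start fresh if no checkpoint exists."""
    if not os.path.exists(path):
        return dict(default_state), 0
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    # the checkpoint stores the last *completed* epoch, so resume at the next one
    return ckpt["state"], ckpt["epoch"] + 1


def train(path, epochs, fail_at=None):
    """Toy training loop (hypothetical): each epoch adds its index to a 'weight'."""
    state, start = resume_epoch(path, {"weight": 0})
    for epoch in range(start, epochs):
        if epoch == fail_at:
            raise RuntimeError(f"simulated failure during epoch {epoch}")
        state["weight"] += epoch              # stand-in for one epoch of training
        save_checkpoint(path, state, epoch)   # checkpoint after every epoch
    return state


ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
try:
    train(ckpt_path, epochs=5, fail_at=3)     # crashes partway through epoch 3
except RuntimeError as err:
    print(err)
final = train(ckpt_path, epochs=5)            # "restart": resumes at epoch 3
print(final)                                  # {'weight': 10}, same as an uninterrupted run
```

Because epochs 0 to 2 were checkpointed before the crash, the restarted run only repeats the interrupted epoch and ends in the same state as a run that never failed.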

Next, we will talk about `torchrun`, PyTorch's tool for launching distributed training. It restarts workers when they fail; combined with checkpointing code in your training script, this lets training resume from the last checkpoint.
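It helps to make the division of labor concrete: torchrun *restarts* the workers (up to the limit set by its `--max-restarts` flag), while *resuming* from the last checkpoint is the job of your script. The toy supervisor below is a simplified, hypothetical model of that restart loop, not torchrun's actual implementation: workers are plain function calls, and the in-memory dict `storage` stands in for the checkpoint file on disk. It does set `TORCHELASTIC_RESTART_COUNT`, the environment variable torchrun uses to tell workers how many restarts have happened so far.

```python
import os


def supervise(worker, max_restarts):
    """Toy model of a launcher: rerun worker() after a failure, up to max_restarts times."""
    for restart_count in range(max_restarts + 1):
        os.environ["TORCHELASTIC_RESTART_COUNT"] = str(restart_count)
        try:
            return worker()
        except RuntimeError as err:
            print(f"attempt {restart_count} failed: {err}")
    raise RuntimeError(f"worker still failing after {max_restarts} restarts")


storage = {}      # stands in for the checkpoint file on disk
epochs_run = []   # records which epochs were actually executed


def worker():
    start = storage.get("epoch", -1) + 1          # resume after the last checkpoint
    for epoch in range(start, 5):
        restarts = int(os.environ["TORCHELASTIC_RESTART_COUNT"])
        if epoch == 2 and restarts == 0:          # fail once, on the first attempt only
            raise RuntimeError("simulated GPU failure")
        epochs_run.append(epoch)
        storage["epoch"] = epoch                  # "checkpoint" after every epoch
    return "done"


print(supervise(worker, max_restarts=3))   # one failure message, then "done"
print(epochs_run)                           # [0, 1, 2, 3, 4]
```

Without the `storage` lookup at the top of `worker`, every restart would begin again at epoch 0: the restart alone does not preserve progress.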

## Using Torchrun 

PyTorch provides a launcher called `torchrun` that, together with checkpointing in your script, gives you fault tolerance. Specifically, `torchrun` offers the following functionality:

-  worker failures are handled gracefully by restarting all workers, which can then resume from the previously saved checkpoint
-  environment variables, like RANK and WORLD_SIZE, are set for you automatically. All environment variables set by torchrun are listed [here](https://pytorch.org/docs/stable/elastic/run.html#environment-variables)
-  the number of nodes being leveraged can vary during training (elasticity)
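Reading these variables can be wrapped in a small helper. This is a minimal sketch: `DistEnv` and `read_torchrun_env` are hypothetical names, and the fallback defaults (a single process with rank 0) are an assumption that lets the same script also run without torchrun. `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `LOCAL_WORLD_SIZE`, and `TORCHELASTIC_RESTART_COUNT` are among the variables documented at the link above.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistEnv:
    rank: int              # global rank across all nodes
    local_rank: int        # rank within this node; used to pick the GPU
    world_size: int        # total number of worker processes
    local_world_size: int  # worker processes on this node
    restart_count: int     # how many times torchrun has restarted the workers


def read_torchrun_env(env=os.environ):
    """Read torchrun's variables, defaulting to a single-process run (an assumption)."""
    def get(name, default):
        return int(env.get(name, default))

    return DistEnv(
        rank=get("RANK", 0),
        local_rank=get("LOCAL_RANK", 0),
        world_size=get("WORLD_SIZE", 1),
        local_world_size=get("LOCAL_WORLD_SIZE", 1),
        restart_count=get("TORCHELASTIC_RESTART_COUNT", 0),
    )


# values torchrun might set for the third of four workers on a single node
print(read_torchrun_env({"RANK": "2", "LOCAL_RANK": "2",
                         "WORLD_SIZE": "4", "LOCAL_WORLD_SIZE": "4"}))
```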

In this notebook we will show how to use the environment variables that torchrun sets automatically, as well as how to use checkpointing. We will not cover elasticity, as it is outside the scope of this course. To explain how torchrun works we will:

1. Cover the code modifications needed, using the MNIST example from the previous notebook.
2. Explain how to launch your script.

Let's get started.

### Code Modifications with MNIST Example

To use torchrun's functionality we need to make some changes to the distributed script we created in the previous notebook. These changes are:

1. Modify the code to use the environment variables set by torchrun:
    1. Remove the code that sets these variables, since torchrun does this for you automatically.
    2. Read the environment variables set by torchrun instead of defining them explicitly.
2. Add code for writing checkpoints and resuming training after a failure:
    1. Create a location to store checkpoints.
    2. Read the checkpoint if one exists and resume training at the epoch after the one in which it was written.
    3. Write checkpoints periodically during training.
3. Replace the `mp.spawn` call that launched the parallel processes with a direct call to `main()`, since torchrun starts the processes for you.
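Change 3 affects only the script's entry point. The sketch below assumes the previous notebook launched `main(rank, world_size)` through `torch.multiprocessing.spawn` (shown only as comments), and the body of `main` here is a hypothetical placeholder. Under torchrun, each of the `--nproc-per-node` processes runs the script from the top, so the entry point simply calls `main()`, which reads its rank from the environment.

```python
import os


def main():
    # placeholder body: the real main() sets up the process group, model, and training loop
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(f"process with LOCAL_RANK={local_rank} starting")
    return local_rank


if __name__ == "__main__":
    # Before (previous notebook, assumed): mp.spawn started one process per GPU
    # and passed each process its rank as the first argument:
    #     world_size = torch.cuda.device_count()
    #     mp.spawn(main, args=(world_size,), nprocs=world_size)
    #
    # After: torchrun has already started one process per GPU,
    # so each process just calls main() once.
    main()
```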

Let's walk through these changes by revisiting the MNIST example from the previous notebook. Apart from the line that launches `main` (change 3), only two functions need to be modified: `init_distributed` and `main`.

#### 1. Modify code for environment variables set by torchrun

To use the environment variables set by torchrun we modify both the `init_distributed` and `main` functions, as highlighted in the code below. In summary, we remove the `local_rank` and `world_size` arguments from `init_distributed` and instead read these values inside the function from the environment variables set by torchrun. We also modify `main` to read the `LOCAL_RANK` environment variable, use it to select the device where our model is stored, and call the modified `init_distributed` function.

```python
import os
import torch
import torch.distributed as dist

##################################################################################
# 1.A Remove the local_rank and world_size arguments: torchrun sets these values for you
def init_distributed():    # previously: init_distributed(local_rank, world_size)

    # 1.B Read the environment variables set by torchrun instead of defining them explicitly
    world_size = int(os.environ['WORLD_SIZE'])   # total number of processes
    rank = int(os.environ['RANK'])               # global rank (equals LOCAL_RANK on a single node)
    local_rank = int(os.environ['LOCAL_RANK'])   # rank on this node, used as the GPU index
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl",
                            rank=rank,
                            world_size=world_size)

def main():
    #####################################################################
    # 1.B Also read LOCAL_RANK in main and call the new init_distributed();
    # local_rank selects the GPU where our model should reside
    local_rank = int(os.environ['LOCAL_RANK'])

    init_distributed()
    ################################################
    # .....
    # instantiate the network (Net is defined in the previous notebook) and move it to the local_rank GPU
    net = Net().to(local_rank)
```
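On a single node, the global rank (`RANK`) and the local rank (`LOCAL_RANK`) are identical, so either one can be passed as the process-group rank. On several nodes they differ: local ranks repeat on every node, while global ranks must be unique. The sketch below assumes every node runs the same number of workers, in which case each node's workers typically receive consecutive global ranks offset by the node's group rank (`GROUP_RANK`); the function name `global_rank` is hypothetical. This is why `RANK` belongs in `init_process_group`, while `LOCAL_RANK` is only used to choose the GPU.

```python
def global_rank(group_rank, nproc_per_node, local_rank):
    """Global rank of a worker, assuming every node runs nproc_per_node workers."""
    return group_rank * nproc_per_node + local_rank


# 2 nodes x 4 GPUs: local ranks 0-3 repeat on each node, global ranks do not
for node in range(2):
    print([global_rank(node, 4, lr) for lr in range(4)])
# [0, 1, 2, 3]
# [4, 5, 6, 7]
```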


#### 2. Add code for writing checkpoints and resuming training after failure

We need to make several modifications to `main` to write checkpoints and to resume from a checkpoint after a process failure. These modifications are highlighted below with rows of `#` and include line-by-line comments explaining why each one is needed.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Net, prepare_data, and train_loop are assumed to be defined as in the previous notebook,
# and init_distributed as modified above

def main():
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])                            # global rank; only rank 0 writes checkpoints
    init_distributed()

    train_dataloader = prepare_data()

    ################################################
    # 2.A. Create location to store checkpoints

    # Create directory for storing checkpointed model
    model_folder_path = os.getcwd()+"/output_model/"          # path to the folder for checkpoints
    os.makedirs(model_folder_path,exist_ok=True)              # create the directory if it does not exist
    # create file name for checkpoint
    checkpoint_file = model_folder_path+"best_model.pt"       # filename for the model checkpoint
    ################################################

    net = Net().to(local_rank)

    #################################################
    # 2.B. Read checkpoints if they exist
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file,
                                map_location=f"cuda:{local_rank}")  # load the checkpoint onto this GPU
        net.load_state_dict(checkpoint['model_state_dict'])   # set weights to those of the last checkpoint
        epoch_start = checkpoint['epoch'] + 1                  # resume after the last completed epoch

    # otherwise we are starting training from the beginning at epoch 0
    else:
        epoch_start = 0
    ################################################

    model = DDP(net,
                device_ids=[local_rank],                      # list of GPUs the model lives on
                output_device=local_rank,                     # GPU where forward-pass outputs are placed
                )

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    save_every = 1
    epochs = 10
    ###########################################################
    # 2.C. Resume training at the epoch after the last checkpoint
    for epoch in range(epoch_start, epochs):                  # note the loop starts at epoch_start from above
    ###########################################################
        train_loop(rank, train_dataloader, model, loss_fn, optimizer)
        ###########################################################
        # 2.D. Write checkpoints periodically during training
        if rank == 0 and epoch%save_every==0:
            print(f"Epoch {epoch+1}\n-------------------------------")
            torch.save({                                      # save model's state_dict and current epoch
                'epoch':epoch,
                'model_state_dict':model.module.state_dict(),  # .module unwraps the DDP wrapper
            }, checkpoint_file)
            print("Finished saving model\n")
        ###########################################################

    dist.destroy_process_group()
```
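One risk in the checkpointing code above is that rank 0 overwrites `checkpoint_file` in place: if the failure happens while `torch.save` is still writing, the only checkpoint is left truncated and the restarted job cannot load it. A common remedy, sketched here as an assumption rather than something torchrun does for you, is to write to a temporary file in the same directory and then swap it in with `os.replace`, which is atomic on POSIX filesystems. The helper name `atomic_save` is hypothetical; `save_fn` can be `torch.save` or `pickle.dump`, since both take `(obj, file)`.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path, save_fn=pickle.dump):
    """Write obj to path so readers see the old or the new file, never a partial one."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")  # same filesystem as path
    try:
        with os.fdopen(fd, "wb") as f:
            save_fn(obj, f)              # e.g. torch.save(checkpoint_dict, f)
            f.flush()
            os.fsync(f.fileno())         # make sure the bytes reach the disk
        os.replace(tmp_path, path)       # swap the finished file into place
    except BaseException:
        os.remove(tmp_path)              # discard the partial temporary file
        raise


ckpt = os.path.join(tempfile.mkdtemp(), "best_model.pt")
atomic_save({"epoch": 3, "model_state_dict": {"w": [0.1, 0.2]}}, ckpt)
with open(ckpt, "rb") as f:
    print(pickle.load(f)["epoch"])       # 3
```

In the MNIST script, the `torch.save({...}, checkpoint_file)` call would become `atomic_save({...}, checkpoint_file, save_fn=torch.save)`.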


You can find the entire modified script with the changes highlighted above in the file `mnist_torchrun.py`.  Next, we will learn how to run this script with `torchrun`.

### Launching jobs with Torchrun

To launch our new `mnist_torchrun.py` script, use the `torchrun` command. There are several arguments you can pass to torchrun, and which ones you need depends on the type of job you are launching; for example, a single-node job needs different arguments than a multinode job. For now, we will cover the arguments needed for a single-node job.

Let's start by introducing three arguments that are helpful when launching a single-node job:
- **--standalone** : Tells torchrun that you are running a single-machine, multi-worker job. It automatically sets up a rendezvous backend, represented by a C10d TCP store on port 29400.
- **--nnodes** : Total number of nodes being used
- **--nproc-per-node** : Number of processes per node; this is typically set to the number of GPUs on each machine

To launch a generic training script (YOUR_TRAINING_SCRIPT.py) on a single node with 4 GPUs, run:

```
torchrun \
    --standalone \
    --nnodes=1 \
    --nproc-per-node=4 \
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
```
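If you prefer to launch jobs from Python (for example, from a notebook cell), the same command can be assembled as an argument list and passed to `subprocess.run`. This is a sketch: the helper `torchrun_command` is hypothetical, while `--max-restarts` is a real torchrun flag, not discussed above, that sets how many times torchrun restarts the workers after a failure. It defaults to 0, so restarts from a checkpoint only happen if you raise it.

```python
import shlex


def torchrun_command(script, nproc_per_node, max_restarts=0, script_args=()):
    """Build the argument list for a single-node torchrun launch (hypothetical helper)."""
    cmd = [
        "torchrun",
        "--standalone",
        "--nnodes=1",
        f"--nproc-per-node={nproc_per_node}",
        f"--max-restarts={max_restarts}",   # worker-group restarts allowed after a failure
        script,
    ]
    return cmd + list(script_args)            # arguments after the script are passed to it


cmd = torchrun_command("mnist_torchrun.py", nproc_per_node=4, max_restarts=3)
print(shlex.join(cmd))
# torchrun --standalone --nnodes=1 --nproc-per-node=4 --max-restarts=3 mnist_torchrun.py
# To actually launch it: subprocess.run(cmd, check=True)
```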

Next, let's run our MNIST training script with torchrun:

```
!torchrun --nproc-per-node=4 mnist_torchrun.py
```

```
loss: 0.203195  [    0/60000]
loss: 0.074431  [ 3200/60000]
loss: 0.189091  [ 6400/60000]
loss: 0.073738  [ 9600/60000]
loss: 0.066978  [12800/60000]
Epoch 3
-------------------------------
Finished saving model

loss: 0.062485  [    0/60000]
loss: 0.078024  [ 3200/60000]
loss: 0.189246  [ 6400/60000]
loss: 0.053398  [ 9600/60000]
loss: 0.036767  [12800/60000]
Epoch 4
-------------------------------
Finished saving model

loss: 0.126662  [    0/60000]
loss: 0.069352  [ 3200/60000]
loss: 0.122379  [ 6400/60000]
loss: 0.059098  [ 9600/60000]
loss: 0.067286  [12800/60000]
Epoch 5
-------------------------------
Finished saving model

loss: 0.137779  [    0/60000]
loss: 0.018163  [ 3200/60000]
loss: 0.138813  [ 6400/60000]
loss: 0.054595  [ 9600/60000]
loss: 0.033671  [12800/60000]
Epoch 6
-------------------------------
Finished saving model

loss: 0.041862  [    0/60000]
loss: 0.011954  [ 3200/60000]
loss: 0.143069  [ 6400/60000]
loss: 0.065292  [ 9600/60000]
loss: 0.022921  [12800/60000
```

## Exercise (optional)

Modify the simple linear regression script you created in the previous tutorial so that it can be launched with torchrun.

```
!torchrun --standalone --nnodes=1 --nproc_per_node=2 mnist_torchrun.py
```

```
Epoch 1
-------------------------------
loss: 24.083977  [    0/  128]
loss: 26.135323  [    0/  128]
loss: 25.212147  [    0/  128]
Epoch 2
-------------------------------
loss: 24.799385  [    0/  128]
loss: 18.444149  [    0/  128]loss: 18.045048  [    0/  128]

Epoch 3
-------------------------------
loss: 16.962395  [    0/  128]
loss: 16.770927  [    0/  128]
loss: 13.576710  [    0/  128]loss: 13.183523  [    0/  128]

Epoch 4
-------------------------------
loss: 11.582039  [    0/  128]
loss: 11.755295  [    0/  128]
loss: 11.668168  [    0/  128]loss: 10.266031  [    0/  128]

Epoch 5
-------------------------------
loss: 8.568253  [    0/  128]
loss: 8.953877  [    0/  128]
loss: 9.378763  [    0/  128]loss: 11.899617  [    0/  128]

loss: 7.603078  [    0/  128]
Epoch 6
-------------------------------
loss: 8.073140  [    0/  128]
loss: 8.063532  [    0/  128]
loss: 9.903101  [    0/  128]
loss: 13.486264  [    0/  128]
Epoch 7
-------------------------------
loss: 8.535467
```
