# CNN Part 4: Single-Node Multi-GPU Training with Torchrun

## Using Torchrun 
The code below modifies the MNIST example so that it can be launched with torchrun.

### Modify code for environment variables set by torchrun

In [None]:
##################################################################################
# A. Remove code that sets environment variables, as this is done for you automatically by torchrun.
def init_distributed():
    """Join the NCCL process group using the environment variables set by torchrun.

    Pins this process to GPU ``LOCAL_RANK`` and joins a group of ``WORLD_SIZE``
    processes as global rank ``RANK`` (falling back to ``LOCAL_RANK`` when
    ``RANK`` is absent).
    """
    # B. Instead of defining them explicitly, read the environment variables set by torchrun.
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    # The process-group rank must be the *global* rank; it equals LOCAL_RANK only on a single node.
    rank = int(os.environ.get('RANK', local_rank))
    # Bind this process to its own GPU on this node.
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl",
                            rank=rank,
                            world_size=world_size)

def main():
    """Skeleton of the torchrun-ready main(): choose this process's GPU and join the group.

    Only the torchrun-specific lines are shown; the rest of main() is elided.
    """
    #####################################################################
    # B. main() also reads LOCAL_RANK and calls the new init_distributed();
    # local_rank is the GPU index on which this process's model will live.
    local_rank = int(os.environ['LOCAL_RANK'])
    init_distributed()
    ################################################
    # .....
    # build the network, then move it onto GPU `local_rank`
    net = Net()
    net = net.to(local_rank)

### Add code for writing checkpoints and resuming training after failure

In [None]:
def main():
    """Train the MNIST model under torchrun, checkpointing to and resuming from $SCRATCH.

    Each process reads LOCAL_RANK (its GPU on this node), joins the process
    group, restores the last checkpoint if one exists, trains the remaining
    epochs and, on global rank 0 only, writes a checkpoint every
    ``save_every`` epochs.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    init_distributed()
    # Global rank: exactly one process in the job should write checkpoints.
    rank = dist.get_rank()

    train_dataloader = prepare_data()

    ################################################
    # A. Create location to store checkpoints

    # Create directory for storing checkpointed model
    model_folder_path = os.path.join(os.environ['SCRATCH'], "cnn4_mnist_output_model")  # folder for checkpoints
    os.makedirs(model_folder_path, exist_ok=True)                                       # create it if it does not exist

    # create file name for checkpoint
    checkpoint_file = os.path.join(model_folder_path, "best_model.pt")
    ################################################

    net = Net().to(local_rank)

    #################################################
    # 2B. Read checkpoints if they exist
    if os.path.exists(checkpoint_file):
        # map the saved tensors onto this process's GPU
        checkpoint = load_checkpoint(checkpoint_file, torch.device("cuda", local_rank))
        # Load into the bare network: it is only wrapped in DDP further below.
        net.load_state_dict(checkpoint['model_state_dict'])
        # The saved epoch had already finished when it was written, so resume at the next one.
        epoch_start = checkpoint['epoch'] + 1
    else:
        # otherwise we are starting training from the beginning at epoch 0
        epoch_start = 0
    ################################################

    model = DDP(net,
                device_ids=[local_rank],     # list of gpu that model lives on
                output_device=local_rank,    # where to output model
                )

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    save_every = 1
    epochs = 10
    ###########################################################
    # 2C. Resume training at the epoch after the last checkpoint
    for epoch in range(epoch_start, epochs):
        # Give a DistributedSampler a new seed so each epoch is shuffled differently.
        if hasattr(train_dataloader.sampler, "set_epoch"):
            train_dataloader.sampler.set_epoch(epoch)
        # NOTE(review): assumes train_loop's first argument is the device / GPU index — confirm.
        train_loop(local_rank, train_dataloader, model, loss_fn, optimizer)
        ###########################################################
        # 2D. Write checkpoints periodically during training (global rank 0 only)
        if rank == 0 and epoch % save_every == 0:
            print(f"Epoch {epoch+1}\n-------------------------------")
            torch.save({                                     # save model's state_dict and current epoch
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
            }, checkpoint_file)
            print("Finished saving model\n")
        ###########################################################

    dist.destroy_process_group()

## Launching jobs with Torchrun

The two cells below help display Python (`.py`) files as syntax-highlighted HTML.

In [1]:
%%javascript
// Turn off the scroll box for long cell outputs so displayed source files appear in full.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [2]:
from pathlib import Path
from IPython import display

Delete the saved checkpoint in `$SCRATCH` (e.g. `$SCRATCH/cnn4_mnist_output_model/best_model.pt`) if you would like to start training from epoch 0.

`display.Code` displays `cnn_part4/mnist_torchrun.py` with Python syntax highlighting as the cell output.

In [3]:
# Show the torchrun version of the MNIST training script.
display.Code("cnn_part4/mnist_torchrun.py")

In [None]:
# Launch 4 processes on this node; each one picks its GPU from its LOCAL_RANK.
!torchrun --nproc-per-node=4 cnn_part4/mnist_torchrun.py

## Additional Exercise

Modify `cnn_part4/simple_linear_regression_parallel.py` so that it can be launched with torchrun.

In [4]:
# Show the exercise's starting script (before the torchrun changes).
display.Code("cnn_part4/simple_linear_regression_parallel.py")

Below is the modified version:

In [5]:
# Show the solution: the torchrun-ready version of the script.
display.Code("cnn_part4/simple_linear_regression_parallel_torchrun.py")

In [None]:
# Run the torchrun version with 4 processes on this node.
!torchrun --nproc-per-node=4 cnn_part4/simple_linear_regression_parallel_torchrun.py

## DesignSafe Classifier 
### Reused code from Part 1 and 2 
Below is a set of import statements and functions that are reused from the earlier parts.

In [None]:
import sys
import os
import numpy as np
import gc

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from datetime import datetime
import warnings
import shutil

In [None]:
# Apply transformations to our data.
# The dataset transformations are the same as the ones from part 2 of this tutorial.
def load_datasets(train_path, val_path, test_path):
    """Build ImageFolder datasets for the train, validation and optional test splits.

    The training split is augmented with AutoAugment before resizing. The
    validation/test splits are only resized and converted to tensors.
    ``test_path`` may be None, in which case no test dataset is returned
    (None in its place).
    """
    # NOTE(review): (244, 244) may be intended as the usual 224x224 ImageNet size — confirm.
    eval_tf = transforms.Compose([
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
    ])
    train_tf = transforms.Compose([
        transforms.AutoAugment(),
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
    ])

    train_ds = datasets.ImageFolder(train_path, transform=train_tf)
    val_ds = datasets.ImageFolder(val_path, transform=eval_tf)
    if test_path is None:
        test_ds = None
    else:
        test_ds = datasets.ImageFolder(test_path, transform=eval_tf)

    return train_ds, val_ds, test_ds

# Building the Neural Network
def getResNet(num_classes=3):
    """Return an ImageNet-pretrained ResNet-34 with a new trainable MLP head.

    Every pretrained parameter is frozen (``requires_grad = False``). The
    replacement ``fc`` head is created afterwards, so it is the only trainable part.

    Args:
        num_classes: number of output logits of the new head (default 3).

    Returns:
        The modified ``torchvision`` ResNet-34 module.
    """
    resnet = models.resnet34(weights='IMAGENET1K_V1')

    # Freeze the pretrained parameters. The attribute must be spelled
    # `requires_grad`: assigning a misspelled name just adds an unused
    # Python attribute and leaves the parameter trainable.
    for param in resnet.parameters():
        param.requires_grad = False

    # get the input dimension of the original final layer
    num_ftrs = resnet.fc.in_features

    # replace the final fully connected layer with a small MLP
    resnet.fc = nn.Sequential(
        nn.Linear(num_ftrs, num_ftrs),
        nn.ReLU(),
        nn.Linear(num_ftrs, num_classes),
    )
    return resnet

# Model evaluation.
@torch.no_grad()
def eval_model(data_loader, model, loss_fn, DEVICE):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Puts ``model`` in eval mode and moves each batch to ``DEVICE``.

    Returns:
        (loss, accuracy): the mean over batches of ``loss_fn(pred, y) / batch_size``
        and of the fraction of correctly classified samples in the batch.

    NOTE(review): if ``loss_fn`` already averages over the batch (e.g. the default
    ``reduction='mean'``), the extra division scales the reported loss by
    1/batch_size. train() reports its loss the same way.
    """
    model.eval()
    num_batches = len(data_loader)
    total_loss, total_acc = 0.0, 0.0

    for inputs, targets in data_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        logits = model(inputs)
        batch_size = len(inputs)
        total_loss += loss_fn(logits, targets) / batch_size
        predicted = logits.argmax(dim=1)
        total_acc += (predicted == targets).sum() / batch_size

    return total_loss / num_batches, total_acc / num_batches

# loading checkpoint
def load_checkpoint(checkpoint_path, DEVICE):
    """Load a checkpoint written by ``torch.save`` and remap its tensors onto ``DEVICE``.

    NOTE(review): ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    return torch.load(checkpoint_path, map_location=DEVICE)

def load_model_fm_checkpoint(checkpoint, primitive_model):
    """Copy ``checkpoint['model_state_dict']`` into ``primitive_model`` and return it.

    The model is updated in place; the same object is returned for convenience.
    """
    weights = checkpoint['model_state_dict']
    primitive_model.load_state_dict(weights)
    return primitive_model

### Setup Process Group (1 and 6)

In [None]:
import torch.distributed as dist

In [None]:
def init_distributed():
    '''
    Set up the process group from the environment variables exported by torchrun.

    Reads WORLD_SIZE, RANK (global rank; falls back to LOCAL_RANK if unset) and
    LOCAL_RANK, pins this process to GPU LOCAL_RANK, then joins the NCCL
    process group through the "env://" init method.
    '''
    # 1, 6 use os to get rank and world size
    dist_url = "env://"
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    # The group rank must be the global rank; it equals LOCAL_RANK only on a single node.
    rank = int(os.environ.get('RANK', local_rank))
    # Select this process's GPU before creating the NCCL group (recommended for NCCL).
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl",  # "nccl" for using GPUs, "gloo" for using CPUs
                            init_method=dist_url,
                            world_size=world_size,
                            rank=rank)

### Create Data DistributedSampler (2)

In [None]:
from torch.utils.data.distributed import DistributedSampler 

In [None]:
def construct_dataloaders(train_set, val_set, test_set, batch_size, shuffle=True):
    """Wrap each dataset in a DataLoader driven by its own DistributedSampler.

    2. The DistributedSampler gives every process a different slice of the data.
    Only the training sampler may shuffle (controlled by ``shuffle``), and only
    the training loader pins memory. All loaders use 4 worker processes.
    ``test_set`` may be None, in which case the test loader is None.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader)
    """
    def _sharded_loader(dataset, shuffle_data, **loader_kwargs):
        # one DistributedSampler per dataset, handed to the DataLoader as its sampler
        sampler = DistributedSampler(dataset=dataset, shuffle=shuffle_data)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                                           num_workers=4, **loader_kwargs)

    train_loader = _sharded_loader(train_set, shuffle, pin_memory=True)
    val_loader = _sharded_loader(val_set, False)
    test_loader = None if test_set is None else _sharded_loader(test_set, False)

    return train_loader, val_loader, test_loader

### Write Checkpoints periodically during training and only from one device (4, 7C)

In [None]:
def train(train_loader, val_loader, model, opt, scheduler, loss_fn, epochs, DEVICE, checkpoint_file, prev_best_val_acc):
    """Run the DDP training loop, validating every epoch and checkpointing the best model.

    Every rank executes this function. Each rank validates only its own shard
    of the validation set, so the validation loss/accuracy are averaged across
    all ranks before use. This lets the LR scheduler step identically on every
    process, and the "best model" decision covers the whole validation set.
    Only global rank 0 prints and writes ``checkpoint_file``.

    Args:
        train_loader: training DataLoader whose sampler has ``set_epoch`` (a DistributedSampler).
        val_loader: validation DataLoader passed to ``eval_model``.
        model: DistributedDataParallel-wrapped model; ``model.module`` is what gets saved.
        opt: optimizer over ``model``'s parameters.
        scheduler: scheduler stepped with the (rank-averaged) validation loss.
        loss_fn: loss used for training and validation.
        epochs: number of epochs to run (counted from 0).
        DEVICE: this rank's device; batches and metrics are placed here.
        checkpoint_file: path rank 0 writes whenever validation accuracy improves.
        prev_best_val_acc: best accuracy from a resumed checkpoint, or None to start from 0.
    """
    n = len(train_loader)
    # Global rank / world size (rank differs from LOCAL_RANK on multi-node jobs).
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    best_val_acc = torch.tensor(0.0).to(DEVICE) if prev_best_val_acc is None else prev_best_val_acc

    for epoch in range(epochs):
        model.train(True)

        # Reseed the DistributedSampler so each epoch is shuffled differently.
        train_loader.sampler.set_epoch(epoch)

        avg_loss, avg_acc = 0.0, 0.0

        start_time = datetime.now()

        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            pred = model(x)
            loss = loss_fn(pred, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            avg_loss += loss.item() / len(x)
            pred_label = torch.argmax(pred, axis=1)
            avg_acc += torch.sum(pred_label == y) / len(x)

        val_loss, val_acc = eval_model(val_loader, model, loss_fn, DEVICE)

        # Average the per-shard validation metrics over all ranks. Without this,
        # each rank would feed a different loss to the scheduler. The ranks could
        # then end up with different learning rates and the DDP replicas would diverge.
        val_stats = torch.tensor([float(val_loss), float(val_acc)], device=DEVICE)
        dist.all_reduce(val_stats, op=dist.ReduceOp.SUM)
        val_stats /= world_size
        val_loss, val_acc = val_stats[0].clone(), val_stats[1].clone()

        end_time = datetime.now()

        total_time = (end_time - start_time).seconds

        # Learning rate reducer takes action (same input on every rank)
        scheduler.step(val_loss)

        avg_loss, avg_acc = avg_loss / n, avg_acc / n

        ###############################################################################
        # 4. Modify Training Loop to write model from one GPU     #####################
        # 7C. Write checkpoints periodically throughout training. #####################
        # Only global rank 0 saves the model and prints the metrics
        if rank == 0:

            # Save the best model that has the highest val accuracy
            if val_acc.item() > best_val_acc.item():
                print(f'lr for this epoch is {scheduler.get_last_lr()}')
                print(f"\nPrev Best Val Acc: {best_val_acc} < Cur Val Acc: {val_acc}")
                print("Saving the new best model...")
                torch.save({
                    'epoch': epoch,
                    'machine': rank,
                    'model_state_dict': model.module.state_dict(),
                    'accuracy': val_acc,
                    'loss': val_loss
                }, checkpoint_file)
                best_val_acc = val_acc
                print("Finished saving model\n")

            # Validation metrics are averaged over all ranks; training metrics are rank 0's own.
            print(f"\n(Epoch {epoch+1}/{epochs}) Time: {total_time}s")
            print(f"(Epoch {epoch+1}/{epochs}) Average train loss: {avg_loss}, Average train accuracy: {avg_acc}")
            print(f"(Epoch {epoch+1}/{epochs}) Val loss: {val_loss}, Val accuracy: {val_acc}")
            print(f"(Epoch {epoch+1}/{epochs}) Current best val acc: {best_val_acc}\n")
        ###############################################################################

### Create Clean Up Function (5)

In [None]:
def cleanup():
    """Destroy the default process group, printing a status line before and after."""
    start_msg = "Cleaning up the distributed environment..."
    done_msg = "Distributed environment has been properly closed"
    print(start_msg)
    dist.destroy_process_group()
    print(done_msg)

### Wrap Model with DDP and put everything together in main function (3, 6B, 7A, 7B)

In [None]:
def main():
    """Train the DesignSafe damage-level classifier with DDP on the GPUs launched by torchrun.

    Every process:
      - joins the process group (unless the caller already has);
      - builds sharded data loaders;
      - restores ``$SCRATCH/cnn4_damagelevel_output_model/best_model.pt`` when present;
      - wraps the model in SyncBatchNorm + DistributedDataParallel and trains it.

    Global rank 0 then reloads and evaluates the best checkpoint. Every rank
    tears down the process group at the end.
    """
    hp = {"lr": 1e-4, "batch_size": 16, "epochs": 5}
    train_path = os.path.join(os.environ['SCRATCH'], "Dataset_2/Train/")
    val_path   = os.path.join(os.environ['SCRATCH'], "Dataset_2/Validation/")
    test_path  = None

    #################################################
    # 6B. Use PyTorch's environment variables.  #####
    local_rank = int(os.environ['LOCAL_RANK'])
    #################################################

    DEVICE = torch.device("cuda", local_rank)

    # DistributedSampler (in construct_dataloaders) and DDP need an initialised
    # process group; join it here unless the launching script already did.
    if not dist.is_initialized():
        init_distributed()
    # Global rank: exactly one process in the whole job has rank 0.
    rank = dist.get_rank()

    ###########################################################
    # 7A. Create location to store checkpoints if it does not exist. ##
    model_folder_path = os.path.join(os.environ['SCRATCH'], "cnn4_damagelevel_output_model")
    os.makedirs(model_folder_path, exist_ok=True)
    ###########################################################

    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1).to(DEVICE)
    train_set, val_set, test_set = load_datasets(train_path, val_path, test_path)
    train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(train_set, val_set, test_set, hp["batch_size"], True)

    model = getResNet().to(DEVICE)

    ######################################################################################
    # 7B. Read the checkpoint if it exists and pass it to train() to resume training ##
    prev_best_val_acc = None
    checkpoint_file = os.path.join(model_folder_path, "best_model.pt")
    if os.path.exists(checkpoint_file):
        checkpoint = load_checkpoint(checkpoint_file, DEVICE)
        prev_best_val_acc = checkpoint['accuracy']
        # Load into the bare model, before SyncBatchNorm conversion and DDP wrapping.
        model = load_model_fm_checkpoint(checkpoint, model)
        epoch_start = checkpoint['epoch']
        if rank == 0:
            print(f"resuming training from epoch {epoch_start}")
    else:
        epoch_start = 0
    # NOTE(review): train() takes no start epoch and always counts epochs from 0;
    # resuming restores the weights and best accuracy, not the epoch counter.
    ######################################################################################

    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    ##########################################################################
    # 3. Wrap model with DDP #################################################
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    ##########################################################################
    opt = torch.optim.Adam(model.parameters(), lr=hp["lr"])

    # same learning rate scheduler as part 2
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=5, min_lr=1e-8)

    train(train_dataloader, val_dataloader, model, opt, scheduler, loss_fn, hp["epochs"], DEVICE, checkpoint_file, prev_best_val_acc)

    # Only global rank 0 loads, evaluates and prints, to avoid duplicate output.
    if rank == 0:
        if os.path.exists(checkpoint_file):
            # store and print info on the best model at the end of training
            primitive_model = getResNet().to(DEVICE)
            checkpoint = load_checkpoint(checkpoint_file, DEVICE)
            best_model = load_model_fm_checkpoint(checkpoint, primitive_model)
            # NOTE(review): val_dataloader uses a DistributedSampler, so this scores only rank 0's shard.
            loss, acc = eval_model(val_dataloader, best_model, loss_fn, DEVICE)
            print(f"\nBest model (val loss: {loss}, val accuracy: {acc}) has been saved to {checkpoint_file}\n")
        else:
            print(f"\nNo checkpoint was saved to {checkpoint_file}\n")

    ###############################
    # 5. Close the process group on every rank, not only rank 0. ######
    cleanup()
    ###############################

Copy the DesignSafe dataset to your `$SCRATCH` directory. If you have already copied it there (`$SCRATCH/Dataset_2`), you do not need to execute the code cell below.

In [None]:
# Copy the dataset archive to $SCRATCH, unpack it there, list the result, then delete the archive.
! cp /scratch1/07980/sli4/training/cnn_course/data/data.tar.gz $SCRATCH
! tar zxf $SCRATCH/data.tar.gz -C $SCRATCH
! ls $SCRATCH/Dataset_2
! rm $SCRATCH/data.tar.gz

Launch the job with torchrun to train the DesignSafe classifier on a single node with 4 GPUs (`--nproc-per-node=4`). Remove the previous model checkpoint from `$SCRATCH` if you would like to start from epoch 0.

In [6]:
# Show the distributed DesignSafe training script launched in the next cell.
display.Code("cnn_part4/torch_train_distributed.py")

In [None]:
# Launch 4 processes on this node; each one picks its GPU from its LOCAL_RANK.
!torchrun --nproc-per-node=4 cnn_part4/torch_train_distributed.py