# Introduction

This notebook demonstrates distributed training of a deep neural network. It follows [this tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html).

# Initial Setup

We will use the functions implemented in `partition.py` and `train.py`.
When `train.py` is run from the terminal, it reads its configuration parameters from `config.yaml`. In our initial setup, we use the following parameters:

In [1]:
!cat config.yaml

size: 2
partition_sizes: [0.5, 0.5]
custom_partition: False
params:
  lr: 0.01
  momentum: 0.5
  use_batch_norm: False
  async_op: False

Running `train.py` performs distributed training of a neural network model across 2 parallel processes, using the MNIST dataset. Here we call the internal functions of `train.py` directly, which makes it easier to see exactly which parameters we modify at each stage.
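
To make the role of `partition_sizes` concrete, here is a minimal sketch of how a list of fractions could be turned into disjoint index shards, in the spirit of the tutorial's data partitioner. The function name `split_indices`, the seed, and the dataset size of 60,000 are assumptions for illustration; the notebook's own partitioning code may differ.

```python
import random


def split_indices(num_samples, fractions, seed=1234):
    """Shuffle sample indices and cut them into consecutive shards.

    Hypothetical helper for illustration, not the notebook's own code.
    """
    indices = list(range(num_samples))
    # Every process must use the same seed so that all ranks agree on the shards.
    random.Random(seed).shuffle(indices)

    shards = []
    start = 0
    for frac in fractions:
        length = int(frac * num_samples)
        shards.append(indices[start:start + length])
        start += length
    return shards


shards = split_indices(60000, [0.5, 0.5])
print([len(s) for s in shards])  # [30000, 30000]
```

Because the shuffle is seeded, each rank can compute the same shards independently and simply pick the one that matches its rank.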

In [6]:
import yaml

from torch.multiprocessing import Process
from train import init_process, run

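config_file = 'config.yaml'
try:
    with open(config_file) as file:
        config_dict = yaml.load(file, Loader=yaml.SafeLoader)
except OSError:
    # Re-raise so the code below never runs with an undefined config_dict.
    print(f'Could not open configuration file {config_file}. Aborting.')
    raise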

size = config_dict['size']
partition_sizes = config_dict['partition_sizes']
custom_partition = config_dict['custom_partition']
params = config_dict['params']

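def distributed_training():
    """Spawn one process per rank and wait for all of them to finish."""
    processes = []
    for rank in range(size):
        # The module-level settings are read at call time, so changes made
        # in later cells take effect on the next call.
        p = Process(target=init_process,
                    args=(run, rank, size, partition_sizes,
                          custom_partition, params))
        p.start()
        processes.append(p)

    # Block until every rank has finished training.
    for p in processes:
        p.join()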

distributed_training()

Rank 0, epoch 0:  Train Loss 1.310  Train Acc. 0.558  --|-- Val. Loss 0.645 Val. Acc. 0.799
Rank 1, epoch 0:  Train Loss 1.306  Train Acc. 0.561  --|-- Val. Loss 0.642 Val. Acc. 0.803
Rank 0, epoch 1:  Train Loss 0.548  Train Acc. 0.838  --|-- Val. Loss 0.422 Val. Acc. 0.869
Rank 1, epoch 1:  Train Loss 0.539  Train Acc. 0.835  --|-- Val. Loss 0.441 Val. Acc. 0.866
Rank 0, epoch 2:  Train Loss 0.428  Train Acc. 0.873  --|-- Val. Loss 0.363 Val. Acc. 0.893
Rank 1, epoch 2:  Train Loss 0.418  Train Acc. 0.873  --|-- Val. Loss 0.375 Val. Acc. 0.891
Rank 1, epoch 3:  Train Loss 0.359  Train Acc. 0.895  --|-- Val. Loss 0.322 Val. Acc. 0.906
Rank 0, epoch 3:  Train Loss 0.369  Train Acc. 0.892  --|-- Val. Loss 0.306 Val. Acc. 0.912
Rank 1, epoch 4:  Train Loss 0.312  Train Acc. 0.909  --|-- Val. Loss 0.288 Val. Acc. 0.913
Rank 0, epoch 4:  Train Loss 0.320  Train Acc. 0.906  --|-- Val. Loss 0.265 Val. Acc. 0.921
Rank 0, epoch 5:  Train Loss 0.290  Train Acc. 0.915  --|-- Val. Loss 0.258 Val.

We observe a decrease in the loss and an increase in accuracy on both training and validation sets.


# Unbalanced Partition

Let's see what happens if we modify the partition ratio to 70% : 30%.
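
One consequence of an uneven split is worth spelling out. Assuming every process uses the same fixed batch size, the ranks run different numbers of optimizer steps per epoch, and therefore issue different numbers of collective calls. The sketch below works through the arithmetic; the training-set size of 60,000 and the batch size of 64 are assumed values, and `batches_per_epoch` is a hypothetical helper rather than part of `train.py`.

```python
import math


def batches_per_epoch(num_samples, fractions, batch_size):
    """Number of mini-batches each rank processes per epoch (illustrative only)."""
    counts = []
    for frac in fractions:
        shard_size = round(frac * num_samples)
        # The last, possibly smaller, batch is kept, hence the ceiling.
        counts.append(math.ceil(shard_size / batch_size))
    return counts


print(batches_per_epoch(60000, [0.5, 0.5], 64))  # [469, 469]
print(batches_per_epoch(60000, [0.7, 0.3], 64))  # [657, 282]
```

Under these assumptions, the rank holding 30% of the data finishes an epoch in well under half the steps of the other rank, so any collective operation that expects both ranks to participate once per step can fall out of lockstep.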

In [7]:
partition_sizes = [0.7, 0.3]
distributed_training()

Rank 1, epoch 0:  Train Loss 1.689  Train Acc. 0.418  --|-- Val. Loss 0.891 Val. Acc. 0.725
Rank 1, epoch 1:  Train Loss 0.711  Train Acc. 0.779  --|-- Val. Loss 0.602 Val. Acc. 0.821
Rank 0, epoch 0:  Train Loss 1.114  Train Acc. 0.631  --|-- Val. Loss 0.526 Val. Acc. 0.837
Rank 1, epoch 2:  Train Loss 0.544  Train Acc. 0.835  --|-- Val. Loss 0.469 Val. Acc. 0.856
Rank 1, epoch 3:  Train Loss 0.457  Train Acc. 0.862  --|-- Val. Loss 0.418 Val. Acc. 0.879
Rank 0, epoch 1:  Train Loss 0.473  Train Acc. 0.857  --|-- Val. Loss 0.371 Val. Acc. 0.888
Rank 1, epoch 4:  Train Loss 0.404  Train Acc. 0.881  --|-- Val. Loss 0.384 Val. Acc. 0.887
Rank 1, epoch 5:  Train Loss 0.363  Train Acc. 0.894  --|-- Val. Loss 0.348 Val. Acc. 0.897
Rank 0, epoch 2:  Train Loss 0.360  Train Acc. 0.896  --|-- Val. Loss 0.314 Val. Acc. 0.907
Rank 1, epoch 6:  Train Loss 0.339  Train Acc. 0.898  --|-- Val. Loss 0.311 Val. Acc. 0.909
Rank 1, epoch 7:  Train Loss 0.312  Train Acc. 0.908  --|-- Val. Loss 0.285 Val.

We can see that the process has failed with an error due to a lack of synchronization between the two processes. This can be solved by setting the `async_op` parameter of the `dist.all_reduce` function to `True`. Let's observe what happens if we do that:

In [8]:
params['async_op'] = True
distributed_training()

Rank 1, epoch 0:  Train Loss 2.027  Train Acc. 0.280  --|-- Val. Loss 1.302 Val. Acc. 0.571
Rank 1, epoch 1:  Train Loss 0.968  Train Acc. 0.694  --|-- Val. Loss 0.761 Val. Acc. 0.767
Rank 0, epoch 0:  Train Loss 1.364  Train Acc. 0.535  --|-- Val. Loss 0.644 Val. Acc. 0.803
Rank 1, epoch 2:  Train Loss 0.687  Train Acc. 0.788  --|-- Val. Loss 0.581 Val. Acc. 0.821
Rank 1, epoch 3:  Train Loss 0.564  Train Acc. 0.829  --|-- Val. Loss 0.512 Val. Acc. 0.848
Rank 0, epoch 1:  Train Loss 0.578  Train Acc. 0.823  --|-- Val. Loss 0.468 Val. Acc. 0.858
Rank 1, epoch 4:  Train Loss 0.495  Train Acc. 0.849  --|-- Val. Loss 0.473 Val. Acc. 0.853
Rank 1, epoch 5:  Train Loss 0.444  Train Acc. 0.867  --|-- Val. Loss 0.433 Val. Acc. 0.875
Rank 1, epoch 6:  Train Loss 0.414  Train Acc. 0.877  --|-- Val. Loss 0.378 Val. Acc. 0.883
Rank 0, epoch 2:  Train Loss 0.449  Train Acc. 0.867  --|-- Val. Loss 0.395 Val. Acc. 0.882
Rank 1, epoch 7:  Train Loss 0.377  Train Acc. 0.886  --|-- Val. Loss 0.356 Val.

This time, the training process has completed smoothly.

# Adding Batch Normalization

Next, let's see how the addition of Batch Normalization affects performance.

In [9]:
params['use_batch_norm'] = True
distributed_training()

Rank 1, epoch 0:  Train Loss 1.483  Train Acc. 0.515  --|-- Val. Loss 0.903 Val. Acc. 0.729
Rank 1, epoch 1:  Train Loss 0.745  Train Acc. 0.777  --|-- Val. Loss 0.617 Val. Acc. 0.818
Rank 0, epoch 0:  Train Loss 1.002  Train Acc. 0.684  --|-- Val. Loss 0.504 Val. Acc. 0.846
Rank 1, epoch 2:  Train Loss 0.548  Train Acc. 0.836  --|-- Val. Loss 0.465 Val. Acc. 0.867
Rank 1, epoch 3:  Train Loss 0.435  Train Acc. 0.872  --|-- Val. Loss 0.416 Val. Acc. 0.876
Rank 0, epoch 1:  Train Loss 0.430  Train Acc. 0.874  --|-- Val. Loss 0.340 Val. Acc. 0.904
Rank 1, epoch 4:  Train Loss 0.418  Train Acc. 0.875  --|-- Val. Loss 0.391 Val. Acc. 0.884
Rank 1, epoch 5:  Train Loss 0.342  Train Acc. 0.901  --|-- Val. Loss 0.320 Val. Acc. 0.911
Rank 0, epoch 2:  Train Loss 0.330  Train Acc. 0.904  --|-- Val. Loss 0.264 Val. Acc. 0.925
Rank 1, epoch 6:  Train Loss 0.325  Train Acc. 0.905  --|-- Val. Loss 0.307 Val. Acc. 0.913
Rank 1, epoch 7:  Train Loss 0.291  Train Acc. 0.915  --|-- Val. Loss 0.282 Val.

We can see that a slight improvement has been achieved, which we attribute to the regularizing effect of batch normalization. For example, the training loss of rank 1 at epoch 7 drops from 0.377 to 0.291.

# Training with Disjoint Subsets of Samples

Let's see what happens if we split the samples such that one process sees only labels 0-4 and the other sees only labels 5-9.
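
A label-based split of this kind can be sketched as follows. `split_by_label` and the toy label list are hypothetical and only illustrate the idea; the notebook's own implementation may differ.

```python
def split_by_label(labels, label_groups):
    """Return one list of sample indices per group of labels."""
    shards = [[] for _ in label_groups]
    for index, label in enumerate(labels):
        for shard, group in zip(shards, label_groups):
            if label in group:
                shard.append(index)
                break
    return shards


# Toy labels standing in for the MNIST targets.
labels = [3, 7, 0, 9, 4, 5, 1, 8]
low, high = split_by_label(labels, [range(0, 5), range(5, 10)])
print(low)   # [0, 2, 4, 6]
print(high)  # [1, 3, 5, 7]
```

Each shard then contains samples from only half of the classes, which is the situation the experiment is designed to probe.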

In [10]:
partition_sizes = [0.5, 0.5]
custom_partition = True
params['async_op'] = True
distributed_training()

Rank 1, epoch 0:  Train Loss 0.744  Train Acc. 0.743  --|-- Val. Loss 3.945 Val. Acc. 0.428
Rank 0, epoch 0:  Train Loss 0.573  Train Acc. 0.825  --|-- Val. Loss 3.998 Val. Acc. 0.484
Rank 1, epoch 1:  Train Loss 0.264  Train Acc. 0.919  --|-- Val. Loss 4.316 Val. Acc. 0.449
Rank 0, epoch 1:  Train Loss 0.191  Train Acc. 0.948  --|-- Val. Loss 4.451 Val. Acc. 0.489
Rank 1, epoch 2:  Train Loss 0.191  Train Acc. 0.943  --|-- Val. Loss 4.682 Val. Acc. 0.462
Rank 0, epoch 2:  Train Loss 0.142  Train Acc. 0.961  --|-- Val. Loss 4.475 Val. Acc. 0.497
Rank 1, epoch 3:  Train Loss 0.165  Train Acc. 0.952  --|-- Val. Loss 4.757 Val. Acc. 0.459
Rank 0, epoch 3:  Train Loss 0.114  Train Acc. 0.970  --|-- Val. Loss 4.814 Val. Acc. 0.497
Rank 1, epoch 4:  Train Loss 0.134  Train Acc. 0.960  --|-- Val. Loss 4.926 Val. Acc. 0.461
Rank 0, epoch 4:  Train Loss 0.100  Train Acc. 0.974  --|-- Val. Loss 4.841 Val. Acc. 0.503
Rank 1, epoch 5:  Train Loss 0.137  Train Acc. 0.961  --|-- Val. Loss 5.196 Val.

This setup clearly results in overfitting, since each process is exposed to only half of the labels: training accuracy reaches about 0.96-0.97, while validation accuracy stays around 0.46-0.50 and validation loss keeps growing. The overfitting on each process cannot be balanced out by simple averaging of the weights, due to the highly non-linear nature of the optimization objective.
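
The effect of averaging under a non-linear objective can be seen in a one-parameter toy problem, unrelated to the MNIST model. The loss below, chosen purely for illustration, has two equally good minima at w = -1 and w = 1; averaging the two solutions gives w = 0, which is worse than either.

```python
def toy_loss(w):
    """Non-convex toy loss with minima at w = -1 and w = 1."""
    return (w ** 2 - 1) ** 2


w_a, w_b = -1.0, 1.0      # two "locally trained" solutions
w_avg = (w_a + w_b) / 2   # naive parameter averaging

print(toy_loss(w_a), toy_loss(w_b))  # 0.0 0.0
print(toy_loss(w_avg))               # 1.0
```

This is only an analogy: it shows that the average of two good parameter values need not be a good parameter value once the objective is not convex.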