
In [None]:
from __future__ import absolute_import
from IPython.display import clear_output
import glob, os, time, shutil
startTime = time.time()
# .get() avoids a bare KeyError so the helpful message below is what the user sees
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
def convertSeconds(start):
    """Prints the time elapsed since `start` (a time.time() timestamp)."""
    import time
    elapse_time = time.time() - start

    if elapse_time >= 3600:
        # An hour or more: report hours, minutes and whole seconds
        hours, remainder = divmod(elapse_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f'runtime: {int(hours)}hours {int(minutes)}mins {int(seconds)}secs')
    else:
        # Under an hour: report minutes and seconds to one decimal place
        minutes, seconds = divmod(elapse_time, 60)
        print(f'runtime: {int(minutes)}min {round(seconds,1)}sec')

In [None]:

# Installs the Cloud TPU client and the PyTorch/XLA 1.7 wheel
# Copy this cell into your own notebooks to use PyTorch on Cloud TPUs
# Warning: this may take a couple of minutes to run
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl  
print()

In [None]:
pth=os.path.realpath('sample_data')
if os.path.exists(pth):
    shutil.rmtree(pth)

### Multicore Training

The following cell defines a map function that trains AlexNet on the Fashion MNIST dataset on all eight available Cloud TPU cores. The function is long and contained in a single cell, so it includes lengthy comments. It does the following:

- **Setup**: every process sets the same random seed and acquires the device assigned to it (via the accelerator context).
- **Dataloading**: each process acquires its own copy of the dataset, but sampling is coordinated so that no two processes draw the same samples.
- **Network creation**: each process has its own copy of the network, but these copies are identical since each process's random seed is the same.
- **Training** and **Evaluation**: Training and evaluation occur as usual but use a ParallelLoader.
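
The non-overlapping sampling described in the list above can be sketched in plain Python. This is a simplified stand-in, not the implementation of `DistributedSampler`: the helper name `shard_indices` and its `seed` parameter are hypothetical, and the sketch assumes the dataset size divides evenly by the number of processes (Fashion MNIST's 60,000 training images split evenly across 8 cores).

```python
import random

def shard_indices(dataset_size, num_replicas, rank, shuffle=True, seed=0):
    """Return the dataset indices that one replica would visit.

    Hypothetical helper for illustration only. It mimics the idea of a
    strided split, not the exact behavior of PyTorch's DistributedSampler.
    """
    indices = list(range(dataset_size))
    if shuffle:
        # Every replica shuffles with the same seed, so all replicas agree
        # on the permutation before each takes its own slice of it.
        random.Random(seed).shuffle(indices)
    # Replica `rank` takes every num_replicas-th index, starting at `rank`.
    return indices[rank::num_replicas]

# One shard per Cloud TPU core
shards = [shard_indices(60000, num_replicas=8, rank=r) for r in range(8)]
```

Because every process applies the same permutation and then takes a different stride offset, the shards are disjoint and together cover the whole dataset.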

Aside from a couple of different classes, like the DistributedSampler and the ParallelLoader, the big difference between single-core and multicore training is behind the scenes. The `xm.optimizer_step()` call now not only applies the optimizer update, but also uses the Cloud TPU context to synchronize gradients across each process's copy of the network. This keeps every process's network copy "in sync" (they all stay identical).
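
To see why synchronizing gradients keeps the copies identical, here is a toy sketch in plain Python. It assumes gradients are averaged across processes and uses a plain SGD update for simplicity (the notebook itself uses Adam). The helpers `all_reduce_mean` and `sgd_step` are hypothetical names for illustration, not PyTorch/XLA functions.

```python
def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise across replicas (toy all-reduce)."""
    num_replicas = len(per_replica_grads)
    return [sum(values) / num_replicas for values in zip(*per_replica_grads)]

def sgd_step(weights, grads, lr=0.5):
    """Apply one plain SGD update; stands in for the real optimizer."""
    return [w - lr * g for w, g in zip(weights, grads)]

# Two replicas start from the same weights (same random seed) ...
replica_weights = [[1.0, 2.0], [1.0, 2.0]]
# ... but see different batches, so their local gradients differ.
local_grads = [[0.5, 1.0], [1.5, -1.0]]

# Every replica receives the same averaged gradient and applies the same update.
shared_grads = all_reduce_mean(local_grads)
replica_weights = [sgd_step(w, shared_grads) for w in replica_weights]
```

Since every replica starts from the same weights and applies the same averaged gradient, the copies are still equal after the step. If each replica applied only its local gradient, they would drift apart after the first batch.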

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torchsummary

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch_xla.distributed.parallel_loader as pl
import time
import datetime
from datetime import timedelta

def map_fn(index, flags):
  
  ## Setup 

  # Sets a common random seed - both for initialization and ensuring graph is the same
  torch.manual_seed(flags['seed'])

  # Acquires the (unique) Cloud TPU core corresponding to this process's index
  device = xm.xla_device()  

  ## Dataloader construction

  # Creates the transform for the raw Torchvision data
  # See https://pytorch.org/docs/stable/torchvision/models.html for normalization
  # Pre-trained TorchVision models expect RGB (3 x H x W) images
  # H and W should be >= 224
  # Loaded into [0, 1] and normalized as follows:
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
  to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
  resize = transforms.Resize((224, 224))
  my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

  # Downloads train and test datasets
  # Note: master goes first and downloads the dataset only once (xm.rendezvous)
  #   all the other workers wait for the master to be done downloading.

  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  train_dataset = datasets.FashionMNIST(
    "/tmp/fashionmnist",
    train=True,
    download=True,
    transform=my_transform)

  test_dataset = datasets.FashionMNIST(
    "/tmp/fashionmnist",
    train=False,
    download=True,
    transform=my_transform)
  
  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')
  
  # Creates the (distributed) train sampler, which lets this process access
  # only its portion of the training dataset.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
  
  test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)
  
  # Creates dataloaders, which load data in batches
  # Note: the test loader also uses a distributed sampler, but it is not shuffled
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags['batch_size'],
      sampler=test_sampler,
      shuffle=False,
      num_workers=flags['num_workers'],
      drop_last=True)
  

  ## Network, optimizer, and loss function creation

  # Creates AlexNet for 10 classes
  # Note: each process has its own identical copy of the model
  #  Even though each model is created independently, they're also
  #  created in the same way.
  net = torchvision.models.alexnet(num_classes=10).to(device).train()

  loss_fn = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters())


  ## Trains
  train_start = time.time()
  for epoch in range(flags['num_epochs']):
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    for batch_num, batch in enumerate(para_train_loader):
      data, targets = batch 

      # Acquires the network's best guesses at each class
      output = net(data)

      # Computes loss
      loss = loss_fn(output, targets)

      # Updates model
      optimizer.zero_grad()
      loss.backward()

      # Note: optimizer_step uses the implicit Cloud TPU context to
      #  coordinate and synchronize gradient updates across processes.
      #  This means that each process's network has the same weights after
      #  this is called.
      # Warning: this coordination requires the actions performed in each 
      #  process are the same. In more technical terms, the graph that
      #  PyTorch/XLA generates must be the same across processes. 
      xm.optimizer_step(optimizer)  # Note: barrier=True not needed when using ParallelLoader 

  elapsed_train_time = time.time() - train_start
  print("\nProcess", index, "finished training. Train time was:", elapsed_train_time) 


  ## Evaluation
  # Sets net to eval and no grad context 
  net.eval()
  eval_start = time.time()
  with torch.no_grad():
    num_correct = 0
    total_guesses = 0

    para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
    for batch_num, batch in enumerate(para_test_loader):
      data, targets = batch

      # Acquires the network's best guesses at each class
      output = net(data)
      best_guesses = torch.argmax(output, 1)

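      # Updates running statistics
      # Counts the samples actually seen, so the total stays correct even if
      # a batch is ever smaller than flags['batch_size']
      num_correct += torch.eq(targets, best_guesses).sum().item()
      total_guesses += targets.size(0)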
  
  elapsed_eval_time = time.time() - eval_start
  print("Process", index, "finished evaluation. time:", elapsed_eval_time)

  print("Process", index, "guessed", num_correct, "of", total_guesses, "correctly for", num_correct/total_guesses * 100, "% accuracy.")


In [None]:
# Configures training (and evaluation) parameters
flags = {}

flags['batch_size'] = 32
flags['num_workers'] = 8
flags['num_epochs'] = 1
flags['seed'] = 4321

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

convertSeconds(start=startTime)

In [None]:

# torchsummary.summary(model=xm, input_size=)

help(torchsummary.summary)

The network should take about 30 seconds to train and about 10 seconds to evaluate on each process. Using an entire Cloud TPU is, as expected, dramatically faster than training and evaluating on a single Cloud TPU core.
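
Each process prints the accuracy for its own shard of the test set only. A minimal sketch of combining the per-process counts into one overall figure is shown below. The helper `global_accuracy` is a hypothetical name and the counts are made-up illustration values, not results from this notebook. PyTorch/XLA offers cross-process reduction helpers such as `xm.mesh_reduce`, but this sketch is plain Python and only shows the arithmetic.

```python
def global_accuracy(per_process_counts):
    """Combine (num_correct, total_guesses) pairs into one accuracy in percent."""
    num_correct = sum(correct for correct, _ in per_process_counts)
    total_guesses = sum(total for _, total in per_process_counts)
    return 100.0 * num_correct / total_guesses

# Hypothetical (num_correct, total_guesses) pairs from four processes
counts = [(90, 100), (80, 100), (85, 100), (95, 100)]
accuracy = global_accuracy(counts)
```

Summing the counts before dividing weights every sample equally, which matters when the processes do not all evaluate the same number of samples.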


## What's Next?

This notebook broke down training AlexNet on the Fashion MNIST dataset using an entire Cloud TPU. A [previous notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb) showed how to train AlexNet on Fashion MNIST using only a single Cloud TPU core, and can be a helpful point of comparison. 

In particular, this notebook showed us how to:

- Define a "map function" that runs in parallel on one process per Cloud TPU core. 
- Run the map function using `spawn`.
- Understand the Cloud TPU context, its benefits, like automatic cross-process coordination, and its limits, like needing each process to perform the same Cloud TPU operations.
- Load and sample the datasets.
- Train and evaluate the network.
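
As an illustration of the cross-process coordination mentioned in the list above, the "download only once" pattern from the map function can be mimicked with the standard library. This is a sketch under stated assumptions: threads stand in for the per-core processes, and `threading.Barrier` stands in for `xm.rendezvous`. The `worker` function and `events` list are hypothetical names for illustration.

```python
import threading

NUM_WORKERS = 4
barrier = threading.Barrier(NUM_WORKERS)
events = []  # records the order in which workers touch the dataset

def worker(rank):
    is_master = rank == 0
    if not is_master:
        # Non-master workers wait here until the master also reaches the barrier.
        barrier.wait()
    # The master "downloads"; everyone else "loads" the already downloaded data.
    events.append(("download" if is_master else "load", rank))
    if is_master:
        # The master reaches the barrier only after its download is recorded,
        # which releases all the waiting workers.
        barrier.wait()

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The barrier only opens once all participants have arrived, and the master arrives last by construction, so its download is always the first recorded event.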

Additional notebooks demonstrating how to run PyTorch on Cloud TPUs can be found [here](https://github.com/pytorch/xla). While Colab provides a free Cloud TPU, training is even faster on [Google Cloud Platform](https://github.com/pytorch/xla#Cloud), especially when using multiple Cloud TPUs in a Cloud TPU pod. Scaling from a single Cloud TPU, like in this notebook, to many Cloud TPUs in a pod is easy, too. You use the same code as this notebook and just spawn more processes.