<a href="https://colab.research.google.com/github/nonelse1101/Ai-learn/blob/master/PyTorch_XLA_TPU_MNIST_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch-XLA TPU MNIST Training
Demo

This colab demo shows how to run distributed training on TPU for MNIST using PyTorch-XLA

In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

KeyError: 'COLAB_TPU_ADDR'

Install TPU compatible PyTorch

In [None]:
!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
# !pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
# !pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
!pip install tensorboardX


Collecting torch-xla==2.0
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl (162.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.9/162.9 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cloud-tpu-client==0.10
  Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)
Collecting torch==2.0.0
  Downloading torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.15.1
  Downloading torchvision-0.15.1-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m79.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting google-api-python-client==1.8.0 (from cloud-tpu-client==0.10)
  Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)
[2K     [90m━━━━━━━━━━━━━━

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m101.7/101.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [None]:
import os
from tensorflow.python.profiler import profiler_client

tpu_profile_service_address = os.environ['COLAB_TPU_ADDR'].replace('8470', '8466')
print(profiler_client.monitor(tpu_profile_service_address, 100, 2))

  Timestamp: 20:49:25
  TPU type: TPU v2
  Utilization of TPU Matrix Units (higher is better): 0.000%




Import Torch and Torch Vision

In [None]:
import os
import sys
import time
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

Define parameters

In [None]:
FLAGS = {}
FLAGS['datadir'] = "/data"
FLAGS['batch_size'] = 256
FLAGS['learning_rate'] = 0.1
FLAGS['momentum'] = 0.5
FLAGS['num_epochs'] = 2
FLAGS['num_workers'] = 4
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['seed'] = 1

Import Torch-XLA dependencies for distributed training

In [None]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.debug.metrics as met
import torch_xla.test.test_utils as test_utils
import torch.distributed as dist

Define the model architecture

In [None]:
class Net(nn.Module):
  def __init__(self):
      super(Net, self).__init__()
      self.conv1 = nn.Conv2d(1, 32, 3, 1)
      self.conv2 = nn.Conv2d(32, 64, 3, 1)
      self.dropout1 = nn.Dropout(0.25)
      self.dropout2 = nn.Dropout(0.5)
      self.fc1 = nn.Linear(9216, 128)
      self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
      x = self.conv1(x)
      x = F.relu(x)
      x = self.conv2(x)
      x = F.relu(x)
      x = F.max_pool2d(x, 2)
      x = self.dropout1(x)
      x = torch.flatten(x, 1)
      x = self.fc1(x)
      x = F.relu(x)
      x = self.dropout2(x)
      x = self.fc2(x)
      output = F.log_softmax(x, dim=1)
      return output

Main Training Function

In [None]:
def maintrain():

  # summary writer
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer('/tmp')

  def _train_update(device, step, loss, tracker, epoch, writer):
    test_utils.print_training_update(
    device,
    step,
    loss.item(),
    tracker.rate(),
    tracker.global_rate(),
    epoch,
    summary_writer=writer)

  torch.manual_seed(FLAGS['seed'])

  # MNIST dataset preparation
  print("preparing MNIST data")
  transform=transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
      ])
  train_dataset = datasets.MNIST(FLAGS['datadir'], train=True, download=True,
                      transform=transform)
  test_dataset = datasets.MNIST(FLAGS['datadir'], train=False,
                      transform=transform)

  train_kwargs = {'batch_size': FLAGS['batch_size'], 'drop_last': True}
  test_kwargs = {'batch_size': FLAGS['batch_size']}
  train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
  test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

  # get the device and port the model to the device (TPU)
  device = xm.xla_device()
  print("device", device)
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)
  mp_device_loader_test = pl.MpDeviceLoader(test_loader, device)
  model = Net().to(device)

  # get loss function, optimizer, and model
  optimizer = optim.SGD(model.parameters(), lr=FLAGS['learning_rate'], momentum=FLAGS['momentum'])
  loss_fn = nn.NLLLoss()

  # define train loop
  def train(model, train_loader, optimizer):
    tracker = xm.RateTracker()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])

      if batch_idx % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
          xm.get_ordinal(), batch_idx, loss.item(), tracker.rate(),
          tracker.global_rate(), time.asctime()), flush=True)

  # define test loop
  def test(model, device, test_loader):
    model.eval()
    correct = 0
    total_samples = 0
    data, pred, target = None, None, None
    with torch.no_grad():
      for data, target in test_loader:
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += data.size()[0]

    # calculate accuracy after testing is done
    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # call train loop and perform training
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
    train(model, mp_device_loader, optimizer)
    xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

  # call test loop and perform testing
  test(model, device, mp_device_loader_test)
  test_utils.close_summary_writer(writer)


In [None]:
# Start training processes
def _mp_fn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  maintrain()

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

preparing MNIST data
device xla:1
Epoch 1 train begin 22:31:22
[xla:0](0) Loss=2.31219 Rate=214.92 GlobalRate=214.91 Time=Wed Oct 25 22:31:23 2023
preparing MNIST data
preparing MNIST data
device xla:0
device xla:0
preparing MNIST data
preparing MNIST data
device devicexla:0 
xla:0
preparing MNIST data
preparing MNIST data
device xla:0
[xla:7](0) Loss=2.31219 Rate=124.78 GlobalRate=124.78 Time=Wed Oct 25 22:31:28 2023
[xla:2](0) Loss=2.31219 Rate=123.43 GlobalRate=123.43 Time=Wed Oct 25 22:31:28 2023
preparing MNIST data
device xla:0
device xla:0
[xla:4](0) Loss=2.31219 Rate=84.48 GlobalRate=84.48 Time=Wed Oct 25 22:31:30 2023
[xla:3](0) Loss=2.31219 Rate=81.53 GlobalRate=81.53 Time=Wed Oct 25 22:31:30 2023
[xla:6](0) Loss=2.31219 Rate=115.08 GlobalRate=115.08 Time=Wed Oct 25 22:31:31 2023
[xla:1](0) Loss=2.31219 Rate=113.26 GlobalRate=113.25 Time=Wed Oct 25 22:31:31 2023
[xla:5](0) Loss=2.31219 Rate=121.01 GlobalRate=121.01 Time=Wed Oct 25 22:31:32 2023
[xla:0](20) Loss=0.92895 Rate=2

Exception in device=TPU:6: local variable 'writer' referenced before assignmentException in device=TPU:4: local variable 'writer' referenced before assignmentException in device=TPU:3: local variable 'writer' referenced before assignment









Exception in device=TPU:2: local variable 'writer' referenced before assignmentTraceback (most recent call last):
Exception in device=TPU:7: local variable 'writer' referenced before assignmentTraceback (most recent call last):
Traceback (most recent call last):


  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 334, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 334, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 328, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 328, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distr

ProcessExitedException: ignored