Copyright 2019 Google LLC.
SPDX-License-Identifier: Apache-2.0

In [0]:
!pip install \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl  \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl

In [0]:
import torch
import torch.nn as nn
import torch_xla

class XlaMulAdd(nn.Module):
  """Toy module that evaluates ``x * y + y`` on its two inputs."""

  def forward(self, x, y):
    """Returns the product of ``x`` and ``y`` with ``y`` added to it."""
    product = x * y
    return product + y

# Inputs and outputs to/from XLA models are always in replicated mode. The
# shapes are [NUM_REPLICAS][NUM_VALUES]. A non-replicated, single-core
# execution will have NUM_REPLICAS == 1, but retain the same shape rank.
# Two random 3x5 operands; they serve both as the tracing example and as the
# real inputs for the XLA and eager runs below.
x, y = torch.rand(3, 5), torch.rand(3, 5)
model = XlaMulAdd()
traced_model = torch.jit.trace(model, (x, y))

# Wrap the traced graph in an XlaModule and feed it XLATensor-wrapped inputs.
xla_model = torch_xla._XLAC.XlaModule(traced_model)
xla_inputs = tuple(torch_xla._XLAC.XLATensor(t) for t in (x, y))
output_xla = xla_model(xla_inputs)

# Eager reference result, printed right after the XLA result (first replica,
# first output value) so the two can be compared by eye.
expected = model(x, y)
print(output_xla[0][0].to_tensor().data)
print(expected.data)


In [0]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm

# Root directory handed to torchvision's MNIST dataset; the training split is
# built with download=True, so the data is fetched into this directory.
datadir = '/tmp/mnist-data'
# Worker count passed as `num_workers` to both the train and test DataLoader.
num_workers = 4

class MNIST(nn.Module):
  """Small convolutional classifier producing log-probabilities for 10 classes.

  Two conv -> max-pool -> ReLU -> batch-norm stages are followed by two fully
  connected layers. The flatten step hard-codes 320 features per sample, which
  is what 1x28x28 inputs yield after the two stages (20 channels of 4x4).
  """

  def __init__(self):
    super().__init__()
    # Registration order is kept aligned with the order of use in forward().
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    """Runs the network and returns log-softmax scores along dimension 1."""
    # Stage 1: convolve, halve the spatial size, rectify, then normalize.
    stage1 = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))
    # Stage 2: same pattern on the 10-channel feature maps.
    stage2 = self.bn2(F.relu(F.max_pool2d(self.conv2(stage1), 2)))
    # Flatten to rows of 320 features for the fully connected head.
    flat = stage2.view(-1, 320)
    logits = self.fc2(F.relu(self.fc1(flat)))
    return F.log_softmax(logits, dim=1)

def train_mnist(num_cores=8,
                lr=0.01,
                momentum=0.5,
                log_interval=5,
                batch_size=256,
                num_epochs=10):
  """Trains the MNIST model through the XLA model wrapper and evaluates it.

  Every argument defaults to the value that used to be hard-coded here, so
  calling ``train_mnist()`` with no arguments runs the same configuration.
  The dataset location and the loader worker count are read from the
  module-level ``datadir`` and ``num_workers``.

  Args:
    num_cores: Number of cores handed to ``xm.XlaModel``; one device string of
      the form ':<n>' is generated per core.
    lr: Learning rate for the SGD optimizer.
    momentum: Momentum for the SGD optimizer.
    log_interval: Forwarded to ``xla_model.train`` as ``log_interval``.
    batch_size: Batch size used by both data loaders and by the dummy tensors
      the model is traced with.
    num_epochs: Number of train-then-test rounds to run.

  Returns:
    The value returned by ``xla_model.test`` for the last epoch (presumably
    the test-set accuracy -- confirm against ``xm.XlaModel.test``), or None
    when ``num_epochs`` is less than 1.
  """
  torch.manual_seed(1)

  # Both splits share one preprocessing pipeline: convert to a tensor, then
  # normalize with mean 0.1307 and std 0.3081.
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,))
  ])

  train_loader = torch.utils.data.DataLoader(
      datasets.MNIST(datadir, train=True, download=True, transform=transform),
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers)
  # The test split is not given download=True; it is loaded from the same
  # `datadir` that the training dataset above was built on.
  test_loader = torch.utils.data.DataLoader(
      datasets.MNIST(datadir, train=False, transform=transform),
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers)

  model = MNIST()

  # Trace the model with dummy inputs/targets of the training batch shape.
  devices = [':{}'.format(n) for n in range(num_cores)]
  inputs = torch.zeros(batch_size, 1, 28, 28)
  target = torch.zeros(batch_size, dtype=torch.int64)
  xla_model = xm.XlaModel(
      model, [inputs],
      loss_fn=F.nll_loss,
      target=target,
      num_cores=num_cores,
      devices=devices)
  optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)

  log_fn = xm.get_log_fn()
  # Stays None only if the loop below never runs.
  accuracy = None
  for epoch in range(1, num_epochs + 1):
    xla_model.train(
        train_loader,
        optimizer,
        batch_size,
        log_interval=log_interval,
        metrics_debug=False,
        log_fn=log_fn)
    accuracy = xla_model.test(
        test_loader,
        xm.category_eval_fn(F.nll_loss),
        batch_size,
        log_fn=log_fn)
  # Previously this result was computed and then dropped; hand it back so the
  # caller can inspect or assert on the final evaluation.
  return accuracy

# Set torch.FloatTensor as the default tensor type before training starts.
torch.set_default_tensor_type('torch.FloatTensor')
# Run the training/evaluation loop defined in train_mnist().
train_mnist()
