<a href="https://colab.research.google.com/github/Ditwoo/batteries/blob/master/examples/xla/example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Installing required packages

In [1]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl > /dev/null
!pip install git+https://github.com/ditwoo/batteries > /dev/null

  Running command git clone -q https://github.com/ditwoo/batteries /tmp/pip-req-build-7off7k8v


In [2]:
import os

os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAX_SIZE"] = "100000000"

import shutil
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed as dist
from torch.utils.data import Dataset, DataLoader

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from torchvision.models import resnet18

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from batteries import (
    seed_all,
    CheckpointManager,
    TensorboardLogger,
    t2d,
    make_checkpoint,
)
from batteries.progress import tqdm



In [3]:
def reduce_fn(vals):
    return sum(vals) / len(vals)

In [4]:
def get_transforms(dataset: str):
    """Get transforms depends from dataset.

    Args:
        dataset (str): dataset type (train or valid)

    Returns:
        dataset transforms
    """
    return Compose([
        ToTensor(),
        # imagenet:
        # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # cifar100:
        Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    ])

In [5]:
def train_fn(
    model,
    loader,
    device,
    loss_fn,
    optimizer,
    scheduler=None,
    accum_steps: int = 1,
):
    model.train()
    losses = []

    to_iter = enumerate(loader)


    for _idx, (bx, by) in enumerate(loader):
        bx, by = t2d((bx, by), device)

        optimizer.zero_grad()

        outputs = model(bx)

        loss = loss_fn(outputs, by)
        _loss = loss.item()
        losses.append(_loss)
        loss.backward()

        if (_idx + 1) % accum_steps == 0:
            # optimizer.step()
            xm.optimizer_step(optimizer)
            if scheduler is not None:
                scheduler.step()

    metrics = {
        "loss": np.mean(losses),
    }
    return metrics

In [6]:
def valid_fn(model, loader, device, loss_fn):
    model.eval()

    losses = []
    num_correct = 0
    total = 0
    with torch.no_grad() as progress:
        to_iter = loader
        for bx, by in loader:
            bx, by = t2d((bx, by), device)
            
            outputs = model(bx)
            
            loss = loss_fn(outputs, by).item()
            losses.append(loss)

            num_correct += torch.eq(
                by.flatten().detach(),
                outputs.argmax(1).flatten().detach()
            ).sum().item()
            total += bx.size(0)

    dataset_acc = num_correct / total
    metrics = {
        "loss": np.mean(losses),
        "accuracy": dataset_acc,
    }
    return metrics

In [7]:
def log_metrics(stage: str, metrics: dict, loader: str, epoch: int) -> dict:
    """Write metrics to tensorboard and stdout.
    Args:
        stage (str): stage name
        metrics (dict): metrics computed during training/validation steps
        loader (str): loader name
        epoch (int): epoch number
    
    Returns:
        dict with reduced metrics
    """
    order = ("loss", "accuracy")
    line = []
    reduced_metrics = {}
    for metric_name in order:
        # loaders have different metrics
        if metric_name in metrics:
            value = xm.mesh_reduce(metric_name, metrics[metric_name], reduce_fn)
            reduced_metrics[metric_name] = value
            line.append(f"{metric_name:>10}: {value:.4f}")
    s = f"{loader}:\n" + "\n".join(line)
    xm.master_print(s)
    return reduced_metrics

### Datasets

In [8]:
# make datasets/dataloaders
transforms = get_transforms("")

train_dataset = CIFAR10(
    "/tmp/CIFAR",
    train=True,
    download=True,
    transform=transforms
)

test_dataset = CIFAR10(
    "/tmp/CIFAR",
    train=False,
    download=True,
    transform=transforms
)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
def experiment(index, flags):
    # prepare for training
    seed_all(flags.get("seed", 1234))
    device = xm.xla_device()

    global train_dataset, test_dataset
    
    # # Downloads train and test datasets
    # # Note: master goes first and downloads the dataset only once (xm.rendezvous)
    # #       all the other workers wait for the master to be done downloading.
    # if not xm.is_master_ordinal():
    #     xm.rendezvous('download_only_once')

    # if xm.is_master_ordinal():
    #     xm.rendezvous('download_only_once')

    train_sampler = dist.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
        seed=1234,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=flags.get("train_batch_size", 64),
        sampler=train_sampler,
        num_workers=flags.get("num_workers", 8),
        drop_last=True,
    )

    test_sampler = dist.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
        seed=1234,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=flags.get("test_batch_size", 64),
        sampler=test_sampler,
        num_workers=flags.get("num_workers", 8),
    )

    # general instructions
    main_metric = "accuracy"
    minimize_metric = False

    stage = "stage_0"
    n_epochs = flags.get("num_epochs", 5)

    checkpointer = CheckpointManager(
        logdir=os.path.join(flags.get("logdir", "."), stage),
        metric=main_metric,
        metric_minimization=minimize_metric,
        save_n_best=3,
        save_fn=xm.save,
    )

    model = resnet18(
        pretrained=False,
        progress=False,
        num_classes=10
    )
    model = model.to(device)
    learning_rate = 1e-3 # * xm.xrt_world_size()
    optimizer = optim.Adam(
        model.parameters(), 
        **flags.get("optimizer", {"lr": 1e-3}),
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(1, n_epochs + 1):
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        xm.master_print(f"[{current_time}]\n[Epoch {epoch}/{n_epochs}]")

        para_train_loader = (
            pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        )
        train_metrics = train_fn(
            model, para_train_loader, device, criterion, optimizer
        )
        reduced_train_metrics = log_metrics(
            stage, train_metrics, "train", epoch
        )
        
        para_test_loader = (
            pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
        )
        valid_metrics = valid_fn(
            model, para_test_loader, device, criterion,
        )
        reduced_valid_metrics = log_metrics(
            stage, valid_metrics, "valid", epoch
        )
        xm.master_print("")

        checkpointer.process(
            metric_value=reduced_valid_metrics[main_metric],
            epoch=epoch,
            checkpoint=make_checkpoint(
                stage, epoch, model, 
                metrics={
                    "train": reduced_train_metrics,
                    "valid": reduced_valid_metrics,
                },
            )
        )

        scheduler.step()

### Training

In [10]:
logdir = "./logs/"

if os.path.isdir(logdir):
    shutil.rmtree(logdir, ignore_errors=True)
    print(f"* Removed existing '{logdir}' directory!")

flags = {
    "logdir": logdir,
    "seed": 321,
    "train_batch_size": 128,
    "test_batch_size": 256,
    "num_workers": 8,
    "num_epochs": 10,
    "optimizer": {
        "lr": 1e-3,
    },
}

xmp.spawn(experiment, args=(flags,), nprocs=8, start_method='fork')

* Removed existing './logs/' directory!
[2020-09-18 11:32:55]
[Epoch 1/10]
train:
      loss: 1.5179
valid:
      loss: 1.3156
  accuracy: 0.5279

[2020-09-18 11:33:29]
[Epoch 2/10]
train:
      loss: 1.0833
valid:
      loss: 1.1822
  accuracy: 0.5891

[2020-09-18 11:33:55]
[Epoch 3/10]
train:
      loss: 0.8676
valid:
      loss: 1.0962
  accuracy: 0.6168

[2020-09-18 11:34:23]
[Epoch 4/10]
train:
      loss: 0.7172
valid:
      loss: 1.2264
  accuracy: 0.5922

[2020-09-18 11:34:49]
[Epoch 5/10]
train:
      loss: 0.5896
valid:
      loss: 1.3979
  accuracy: 0.5920

[2020-09-18 11:35:15]
[Epoch 6/10]
train:
      loss: 0.4582
valid:
      loss: 1.2373
  accuracy: 0.6400

[2020-09-18 11:35:41]
[Epoch 7/10]
train:
      loss: 0.2949
valid:
      loss: 1.2663
  accuracy: 0.6490

[2020-09-18 11:36:08]
[Epoch 8/10]
train:
      loss: 0.1502
valid:
      loss: 1.1846
  accuracy: 0.6789

[2020-09-18 11:36:34]
[Epoch 9/10]
train:
      loss: 0.0699
valid:
      loss: 1.1767
  accuracy: 0.682

In [11]:
!ls -la logs/stage_0/

total 218792
drwxr-xr-x 2 root root     4096 Sep 18 11:37 .
drwxr-xr-x 3 root root     4096 Sep 18 11:33 ..
-rw-r--r-- 1 root root 44802903 Sep 18 11:37 best.pth
-rw-r--r-- 1 root root 44802903 Sep 18 11:37 exp_10.pth
-rw-r--r-- 1 root root 44802903 Sep 18 11:36 exp_8.pth
-rw-r--r-- 1 root root 44802903 Sep 18 11:37 exp_9.pth
-rw-r--r-- 1 root root 44802903 Sep 18 11:37 last.pth
-rw-r--r-- 1 root root      648 Sep 18 11:37 metrics.json


In [12]:
from batteries import load_checkpoint

model = resnet18(
    pretrained=False,
    progress=False,
    num_classes=10
)
load_checkpoint("./logs/stage_0/best.pth", model)

<= Loaded model from './logs/stage_0/best.pth'
Stage: stage_0
Epoch: 10
Checkpoint metrics:
{'train': {'loss': 0.05153878529866536}, 'valid': {'loss': 1.1798828125, 'accuracy': 0.6829999999999999}}
