In [None]:
!pip install git+https://github.com/catalyst-team/catalyst --upgrade

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [3]:
from torch import __version__
print(__version__)

1.9.0+cu102


In [4]:
from catalyst import __version__
print(__version__)



21.08


In [5]:
from catalyst import SETTINGS
print(SETTINGS.xla_required)

True


In [6]:
import os
from datetime import datetime

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.contrib.datasets import CIFAR10
from catalyst.contrib.nn import ResidualBlock
from catalyst.data import transforms

def conv_block(in_channels, out_channels, pool=False):
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if pool:
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)


def resnet9(in_channels: int, num_classes: int, size: int = 16):
    sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
    return nn.Sequential(
        conv_block(in_channels, sz),
        conv_block(sz, sz2, pool=True),
        ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
        conv_block(sz2, sz4, pool=True),
        conv_block(sz4, sz8, pool=True),
        ResidualBlock(nn.Sequential(conv_block(sz8, sz8), conv_block(sz8, sz8))),
        nn.Sequential(
            nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(sz8, num_classes)
        ),
    )

class CustomRunner(dl.IRunner):
    def __init__(self, logdir):
        super().__init__()
        self._logdir = logdir

    def get_engine(self):
        return dl.DistributedXLAEngine()

    def get_loggers(self):
        return {
            "console": dl.ConsoleLogger(),
            "csv": dl.CSVLogger(logdir=self._logdir),
            "tensorboard": dl.TensorboardLogger(logdir=self._logdir),
        }

    @property
    def stages(self):
        return ["train"]

    def get_stage_len(self, stage: str) -> int:
        return 3

    def get_loaders(self, stage: str):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        train_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)
        valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)

        if self.engine.is_ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=True
            )
            valid_sampler = torch.utils.data.distributed.DistributedSampler(
                valid_data,
                num_replicas=self.engine.world_size,
                rank=self.engine.rank,
                shuffle=False
            )
        else:
            train_sampler = valid_sampler = None

        return {
            "train": DataLoader(train_data, batch_size=32, sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler),
        }

    def get_model(self, stage: str):
        model = self.model \
            if self.model is not None \
            else resnet9(in_channels=3, num_classes=10)
        return model

    def get_criterion(self, stage: str):
        return nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        return optim.Adam(model.parameters(), lr=1e-3)

    def get_scheduler(self, stage: str, optimizer):
        return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)

    def get_callbacks(self, stage: str):
        return {
            "criterion": dl.CriterionCallback(
                metric_key="loss", input_key="logits", target_key="targets"
            ),
            "optimizer": dl.OptimizerCallback(metric_key="loss"),
            "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
            "accuracy": dl.AccuracyCallback(
                input_key="logits", target_key="targets", topk_args=(1, 3, 5)
            ),
            "checkpoint": dl.CheckpointCallback(
                self._logdir,
                loader_key="valid",
                metric_key="accuracy",
                minimize=False,
                save_n_best=1,
            ),
            "tqdm": dl.TqdmCallback(),
        }

    def handle_batch(self, batch):
        x, y = batch
        logits = self.model(x)

        self.batch = {
            "features": x,
            "targets": y,
            "logits": logits,
        }

logdir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
runner = CustomRunner(logdir)
runner.run()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting /content/cifar-10-python.tar.gz to /content
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


1/3 * Epoch (train):   0%|          | 0/40 [00:00<?, ?it/s]

train (1/3) accuracy: 0.1721 | accuracy/std: 0.06759295194292164 | accuracy01: 0.1721 | accuracy01/std: 0.06759295194292164 | accuracy03: 0.43000000000000005 | accuracy03/std: 0.10283063788725659 | accuracy05: 0.6356 | accuracy05/std: 0.09907314603800266 | loss: 2.6043901133060454 | loss/mean: 2.6043901133060454 | loss/std: 0.4253420988219556 | lr: 0.001 | momentum: 0.9


1/3 * Epoch (valid):   0%|          | 0/40 [00:00<?, ?it/s]

valid (1/3) accuracy: 0.2018 | accuracy/std: 0.06934377188425647 | accuracy01: 0.2018 | accuracy01/std: 0.06934377188425647 | accuracy03: 0.49369999999999997 | accuracy03/std: 0.08922528586622942 | accuracy05: 0.6991 | accuracy05/std: 0.0779836919827547 | loss: 2.1943360423326492 | loss/mean: 2.1943360423326492 | loss/std: 0.14056085986746458 | lr: 0.001 | momentum: 0.9
* Epoch (1/3) lr: 0.001 | momentum: 0.9


2/3 * Epoch (train):   0%|          | 0/40 [00:00<?, ?it/s]

train (2/3) accuracy: 0.2352 | accuracy/std: 0.07977614820844778 | accuracy01: 0.2352 | accuracy01/std: 0.07977614820844778 | accuracy03: 0.5468 | accuracy03/std: 0.09525751029922952 | accuracy05: 0.7546 | accuracy05/std: 0.07993180264150548 | loss: 2.1500546091318133 | loss/mean: 2.1500546091318133 | loss/std: 0.20461360113834592 | lr: 0.001 | momentum: 0.9


2/3 * Epoch (valid):   0%|          | 0/40 [00:00<?, ?it/s]

valid (2/3) accuracy: 0.3197000000000001 | accuracy/std: 0.0744005043525841 | accuracy01: 0.3197000000000001 | accuracy01/std: 0.0744005043525841 | accuracy03: 0.6707 | accuracy03/std: 0.08216782796171247 | accuracy05: 0.8377000000000001 | accuracy05/std: 0.06329140604060719 | loss: 1.8689769260644913 | loss/mean: 1.8689769260644913 | loss/std: 0.14241773655100376 | lr: 0.001 | momentum: 0.9
* Epoch (2/3) lr: 0.001 | momentum: 0.9


3/3 * Epoch (train):   0%|          | 0/40 [00:00<?, ?it/s]

train (3/3) accuracy: 0.2831 | accuracy/std: 0.07192947569216972 | accuracy01: 0.2831 | accuracy01/std: 0.07192947569216972 | accuracy03: 0.6156 | accuracy03/std: 0.08256173782313522 | accuracy05: 0.8016000000000001 | accuracy05/std: 0.07111062027177223 | loss: 2.0018269671440123 | loss/mean: 2.0018269671440123 | loss/std: 0.17193453746866305 | lr: 0.001 | momentum: 0.9


3/3 * Epoch (valid):   0%|          | 0/40 [00:00<?, ?it/s]

valid (3/3) accuracy: 0.32799999999999996 | accuracy/std: 0.07897211054352626 | accuracy01: 0.32799999999999996 | accuracy01/std: 0.07897211054352626 | accuracy03: 0.6714 | accuracy03/std: 0.08150206452455136 | accuracy05: 0.844 | accuracy05/std: 0.06449319576529647 | loss: 1.866470173597336 | loss/mean: 1.866470173597336 | loss/std: 0.17844182547738147 | lr: 0.001 | momentum: 0.9
* Epoch (3/3) lr: 0.001 | momentum: 0.9
Top best models:
logs/20210820-054740/train.3.pth	0.3280


<__main__.CustomRunner at 0x7f8bcd0dca10>