In [None]:
# Install the Cloud TPU client and a prebuilt torch_xla 1.9 wheel. The wheel's
# `cp37` tag means it only installs on a CPython 3.7 runtime.
# NOTE(review): torch_xla wheels are presumably tied to a matching torch release
# (1.9 here) — confirm the runtime's torch version before changing this line.
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [None]:
!pip install pytorch-lightning
!pip install lightning-bolts
!pip install tokenizers
!pip install einops

In [None]:
!git clone https://github.com/krasserm/perceiver-io

In [None]:
import os
import sys

# Work from inside the cloned repository, so relative paths used later (e.g.
# `default_root_dir='logs'`) resolve there. Also make its top-level packages
# (`perceiver`, `data`, `train`) importable.
REPO_DIR = '/content/perceiver-io'
os.chdir(REPO_DIR)

# Register the absolute path exactly once. A relative '.' entry depends on the
# working directory at import time, and appending on every re-run keeps growing
# sys.path with duplicates. Append (not insert) to keep the original lookup priority.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

os.getcwd()

In [None]:
# Switch to the `wip-tpu` branch, which (judging by its name) holds the TPU
# work-in-progress code this notebook uses.
# NOTE(review): runs relative to the current working directory, so it assumes the
# previous cell has already changed into the cloned repository.
!git checkout wip-tpu

In [None]:
# NOTE(review): the bare `torch_xla` import is not referenced by name in the visible
# code; presumably kept to initialise the XLA backend — confirm before removing.
import torch_xla
import torch_xla.core.xla_model as xm

# XLA device handle (presumably a TPU core, given the cloud-tpu-client install and
# the `tpu_cores` training setting). The model and inputs are moved onto it with
# `.to(device)`.
device = xm.xla_device()

## Inference

In [None]:
from perceiver.adapter import ImageInputAdapter, ClassificationOutputAdapter
from perceiver.model import PerceiverIO, PerceiverEncoder, PerceiverDecoder

# Shape of the latent array shared by encoder and decoder.
# NOTE(review): presumably (num_latents, num_latent_channels) — confirm in perceiver.model.
latent_shape = (32, 128)

# Input side: Fourier-encode the pixel positions of 28x28 single-channel images
# with 32 frequency bands and flatten along the spatial dimensions.
input_adapter = ImageInputAdapter(
    image_shape=(28, 28, 1),
    num_frequency_bands=32,
)

# Output side: project the generic Perceiver decoder output onto 10 classes.
output_adapter = ClassificationOutputAdapter(
    num_classes=10,
    num_output_channels=128,
)

# Task-agnostic encoder; dropout is disabled.
encoder = PerceiverEncoder(
    input_adapter=input_adapter,
    latent_shape=latent_shape,
    num_layers=3,
    num_cross_attention_heads=4,
    num_self_attention_heads=4,
    num_self_attention_layers_per_block=3,
    dropout=0.0,
)

# Task-agnostic decoder with a single cross-attention head; dropout is disabled.
decoder = PerceiverDecoder(
    output_adapter=output_adapter,
    latent_shape=latent_shape,
    num_cross_attention_heads=1,
    dropout=0.0,
)

# Compose encoder and decoder into the MNIST classifier and place it on the XLA
# device (`Module.to` returns the module itself, so chaining is equivalent).
mnist_classifier = PerceiverIO(encoder, decoder).to(device)

In [None]:
import torch

# Smoke test: push a random batch of two 28x28 single-channel images through the
# classifier without gradient tracking, and print the raw model output.
with torch.no_grad():
    smoke_images = torch.rand(2, 28, 28, 1).to(device)
    smoke_output = mnist_classifier(smoke_images)
    print(smoke_output)

## Training

In [None]:
import argparse
import pytorch_lightning as pl

from data import IMDBDataModule
from train.train_mlm import LitMLM, main

In [None]:
def build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line parser that configures MLM training on IMDB.

    Registers the Lightning Trainer options, the IMDB data-module options and the
    LitMLM model options. It then adds a `main` group holding `--experiment`, and
    overrides defaults for what appears to be a short TPU debugging run
    (`tpu_spawn_debug` strategy, 5 train/val batches).
    """
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Each setup hook adds its own arguments and hands the parser back.
    for register_arguments in (pl.Trainer.add_argparse_args,
                               IMDBDataModule.setup_parser,
                               LitMLM.setup_parser):
        cli = register_arguments(cli)

    main_group = cli.add_argument_group('main')
    # Non-empty help so ArgumentDefaultsHelpFormatter prints the default value.
    main_group.add_argument('--experiment', default='mlm', help=' ')

    cli.set_defaults(
        # model size
        num_latents=64,
        num_latent_channels=64,
        num_encoder_layers=3,
        dropout=0.0,
        # optimisation
        weight_decay=0.0,
        learning_rate=3e-3,
        max_seq_len=512,
        max_steps=50000,
        batch_size=64,
        one_cycle_lr=True,
        one_cycle_pct_start=0.1,
        # hardware / debugging run
        tpu_cores=[1],
        limit_train_batches=5,
        limit_val_batches=5,
        log_every_n_steps=5,
        progress_bar_refresh_rate=1,
        strategy='tpu_spawn_debug',
        default_root_dir='logs',
    )
    return cli


parser = build_parser()

In [None]:
# Parse an explicit empty argument list so argparse ignores the Jupyter kernel's own
# command-line arguments and yields only the configured defaults.
training_args = parser.parse_args([])
main(training_args)