A small example, on how to implement an EfficientFormer and train it on Cifar10. Please note that we pull in some code from xformers/examples here, so not all the training code is readily visible, but it's not very far away.

Let's start with the dependencies

In [8]:
!pip install --pre torch
!pip install xformers pytorch_lightning numpy pytorch_lightning lightning-bolts torchmetrics 

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


Now let's import everything we need, and check that the above worked fine

In [4]:
import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy

from examples.cifarViT import Classifier, VisionTransformer
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct
from xformers.components.patch_embedding import PatchEmbeddingConfig  # noqa
from xformers.components.patch_embedding import build_patch_embedding  # noqa
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)



All good ! Now let's write down the EfficientFormer

In [5]:

class EfficientFormer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=1e-2,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):

        super(VisionTransformer, self).__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical Transformer
        # This implements a model close to the L1 suggested in "EfficientFormer" (https://arxiv.org/abs/2206.01191)
        # but a pooling layer has been removed, due to the very small image dimensions in Cifar (32x32)
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism="pooling",
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
                feedforward="Conv2DFeedforward",
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism="pooling",
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
                feedforward="Conv2DFeedforward",
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism="pooling",
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
                feedforward="Conv2DFeedforward",
            ),
            # L1 would have an extra layer here, similar to the above,
            # bringing the sequence length down to HxW / 1024
        ]

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            layernorm_style="pre",
            use_rotary_embeddings=False,
            mlp_multiplier=4,
            in_channels=3,  # 24 if L1, there's another stem prior to the trunk
        )

        # Now instantiate the EfficientFormer trunk
        config = xFormerConfig(xformer_config)
        config.weight_init = "moco"

        self.trunk = xFormer.from_config(config)

        # L1 model
        # # This model requires a pre-stem (a conv prior to going through all the layers above)
        # self.pre_stem = build_patch_embedding(
        #     PatchEmbeddingConfig(
        #         in_channels=3, out_channels=24, kernel_size=3, stride=2, padding=1
        #     )
        # )

        # This model requires a final Attention step
        self.attention = MultiHeadDispatch(
            dim_model=320, num_heads=4, attention=ScaledDotProduct()
        )

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = x.flatten(-2, -1).transpose(-1, -2)  # BCHW to BSE
        # x = self.pre_stem(x) # L1 model
        x = self.trunk(x)
        x = self.attention(x)
        x = self.ln(x)

        x = x.mean(dim=1)  # mean over sequence len
        x = self.head(x)
        return x

Ok, now we need a training script, let's try this out

In [7]:
pl.seed_everything(42)

# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 768
BATCH = 768  # lower if not enough GPU memory

MAX_EPOCHS = 50
NUM_WORKERS = 4
GPUS = 1

torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
dm = CIFAR10DataModule(
    data_dir="data",
    batch_size=BATCH,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

image_size = dm.size(-1)  # 32 for CIFAR
num_classes = dm.num_classes  # 10 for CIFAR

# compute total number of steps
batch_size = BATCH * GPUS
steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
lm = EfficientFormer(
    steps=steps,
    image_size=image_size,
    num_classes=num_classes,
)
trainer = pl.Trainer(
    gpus=GPUS,
    max_epochs=MAX_EPOCHS,
    precision=16,
    accumulate_grad_batches=REF_BATCH // BATCH,
)
trainer.fit(lm, dm)

# check the training
trainer.test(lm, datamodule=dm)

Global seed set to 42
  rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.")
Module Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Module Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Module Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))


{'block_type': 'encoder', 'dim_model': 64, 'use_triton': False, 'layer_norm_style': 'pre', 'multi_head_config': {'num_heads': 1, 'use_rotary_embeddings': False, 'attention': {'name': 'pooling'}}, 'feedforward_config': {'name': 'Conv2DFeedforward', 'activation': 'gelu', 'hidden_layer_multiplier': 4, 'dropout': 0.0}, 'position_encoding_config': {'name': 'learnable', 'seq_len': 256, 'add_class_token': False}, 'patch_embedding_config': {'in_channels': 3, 'kernel_size': 3, 'stride': 2, 'padding': 1}}
{'block_type': 'encoder', 'dim_model': 128, 'use_triton': False, 'layer_norm_style': 'pre', 'multi_head_config': {'num_heads': 1, 'use_rotary_embeddings': False, 'attention': {'name': 'pooling'}}, 'feedforward_config': {'name': 'Conv2DFeedforward', 'activation': 'gelu', 'hidden_layer_multiplier': 4, 'dropout': 0.0}, 'position_encoding_config': {'name': 'learnable', 'seq_len': 64, 'add_class_token': False}, 'patch_embedding_config': {'in_channels': 64, 'kernel_size': 3, 'stride': 2, 'padding': 1

MisconfigurationException: You requested GPUs: [0]
 But your machine only has: []