handling different normalizations + layer repetition
blefaudeux committed Jul 2, 2022
1 parent a8dffca commit e5f22e4
Showing 13 changed files with 205 additions and 115 deletions.
5 changes: 4 additions & 1 deletion HOWTO.md
@@ -30,7 +30,7 @@ Let's present here a couple of code snippets on how to solve a couple of questions
- [Intro](#intro)
- [Transformer](#transformer)
- [In practice](#in-practice)
- [Hierarchical Transformers](#hierarchical-transformers)
- [Hierarchical Transformers](#hierarchical-transformers)


## Understanding the dimension conventions
@@ -778,6 +778,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
feedforward="MLP",
),
BasicLayerConfig(
embedding=128,
@@ -786,6 +787,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
feedforward="MLP",
),
BasicLayerConfig(
embedding=320,
@@ -794,6 +796,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
feedforward="MLP",
),
]

3 changes: 3 additions & 0 deletions docs/source/tutorials/hierarchical.rst
@@ -26,6 +26,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
feedforward="MLP",
),
BasicLayerConfig(
embedding=128,
@@ -34,6 +35,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
feedforward="MLP",
),
BasicLayerConfig(
embedding=320,
@@ -42,6 +44,7 @@ A small helper is provided to make it easier to generate matching configurations
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
feedforward="MLP",
),
]
13 changes: 11 additions & 2 deletions examples/README.md
@@ -20,14 +20,23 @@ and finally an inference example or some test loss and accuracy.
If your current machine does not expose enough RAM and the example reports an `OutOfMemoryError`, please adjust the batch size.


### MicroViT
## NLP: microGPT

This is an homage to [minGPT](https://github.com/karpathy/minGPT), in particular training an autoregressive model on Shakespeare dialogs. The default configuration is that of a standard Transformer, but you can change parts as you see fit. You can get to reasonable results within an hour or so on a single GPU.

## Vision models

You can find a couple of very small examples of models being trained on the CIFAR10 dataset. They can be modified to train on something like ImageNet with minimal changes, but running them out of the box in that case requires a bit more work.


### ViT

This is meant to be an easy introduction to using xformers in practice, closely mirroring [this PyTorch Lightning](https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html) tutorial. The default settings are close to those of the tutorial, which trains an 11M-parameter ResNet on CIFAR10; here we train a 10.6M-parameter ViT on the same dataset. The ViT configuration is not optimal for CIFAR10, since the pictures are very small to begin with and some information is probably lost in the patching. Nevertheless you should be able to reach about 80% accuracy within about an hour on a single GPU.

![Example curves](../docs/assets/microViT.png)
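
As a rough back-of-the-envelope sketch of how little spatial information is left to work with once the picture is patched (the patch size below is a hypothetical value chosen only for illustration, not necessarily the one used in the example):

```python
# Token count for a 32x32 CIFAR10 picture split into non-overlapping patches.
# patch_size = 4 is a hypothetical value, used here only for illustration.
image_size = 32
patch_size = 4

num_patches_per_axis = image_size // patch_size        # 8
seq_len = num_patches_per_axis * num_patches_per_axis  # 64 tokens
print(f"{seq_len} tokens of {patch_size}x{patch_size} pixels each")
```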


### MicroMetaformer
### Metaformer

This is very close to the ViT example above, but this time it illustrates a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)), built through a helper function which generates the required configuration given the pooling parameters. The suggested configuration is about 6.6M parameters (half the size of a ResNet18) and trains to about 86% top-1 accuracy on CIFAR10 within minutes.
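
For reference, here is a minimal sketch of how such a hierarchical model could be assembled with that helper. It is inferred from the snippets in this commit, so double check it against the current API: `feedforward` and `repeat_layer` are the fields touched by this commit, while `attention_mechanism` and `patch_size` are assumed field names and `"pooling"` is one plausible attention choice.

```python
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

image_size = 32  # CIFAR10 pictures are 32x32

# Skeleton of a two-stage hierarchical model
base_hierarchical_configs = [
    BasicLayerConfig(
        embedding=64,
        attention_mechanism="pooling",  # assumed name, see note above
        patch_size=3,                   # assumed name, see note above
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 4,
        feedforward="MLP",
        repeat_layer=1,
    ),
    BasicLayerConfig(
        embedding=128,
        attention_mechanism="pooling",
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 16,
        feedforward="MLP",
        repeat_layer=2,  # this stage is repeated twice
    ),
]

# Expand the skeleton into a full xFormer configuration, then build the model
xformer_config = get_hierarchical_configuration(
    base_hierarchical_configs,
    use_rotary_embeddings=False,
    mlp_multiplier=4,
    dim_head=32,
)
model = xFormer.from_config(xFormerConfig(xformer_config))
```

From there the trunk can be used like any other `torch.nn.Module`, as in the `cifar_MetaFormer.py` example below.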

51 changes: 24 additions & 27 deletions examples/cifarMetaformer.py → examples/cifar_MetaFormer.py
@@ -7,18 +7,25 @@
import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from examples.microViT import Classifier, VisionTransformer
from examples.cifar_ViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
BasicLayerConfig,
get_hierarchical_configuration,
)

# This is very close to the cifarViT example, and reuses a lot of the training code, only the model part is different.
# There are many ways one can use xformers to write down a MetaFormer, for instance by
# picking up the parts from `xformers.components` and implementing the model explicitly,
# or by patching another existing ViT-like implementation.

# This example takes another approach, as we define the whole model configuration in one go (dict structure)
# and then use the xformers factory to generate the model. This obfuscates a lot of the model building
# (though you can inspect the resulting implementation), but makes it trivial to do some hyperparameter search


class MetaVisionTransformer(VisionTransformer):
def __init__(
@@ -44,9 +51,10 @@ def __init__(
self.save_hyperparameters()

# Generate the skeleton of our hierarchical Transformer

# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
# Any other related config would work, and the attention mechanisms don't have to be the same across layers
# - This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
# - Please note that this does not match the L1 configuration in the paper, as this would correspond to repeated
# layers. CIFAR pictures are too small for this config to be directly meaningful (although that would run)
# - Any other related config would work, and the attention mechanisms don't have to be the same across layers
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64,
@@ -55,6 +63,8 @@ def __init__(
stride=2,
padding=1,
seq_len=image_size * image_size // 4,
feedforward=feedforward,
repeat_layer=1,
),
BasicLayerConfig(
embedding=128,
@@ -63,6 +73,8 @@ def __init__(
stride=2,
padding=1,
seq_len=image_size * image_size // 16,
feedforward=feedforward,
repeat_layer=1,
),
BasicLayerConfig(
embedding=320,
@@ -71,6 +83,8 @@ def __init__(
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
feedforward=feedforward,
repeat_layer=1,
),
BasicLayerConfig(
embedding=512,
@@ -79,6 +93,8 @@ def __init__(
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
feedforward=feedforward,
repeat_layer=1,
),
]

@@ -89,7 +105,6 @@ def __init__(
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
feedforward="Conv2DFeedforward",
)

# Now instantiate the metaformer trunk
@@ -131,34 +146,16 @@ def forward(self, x):
torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

train_transforms = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
cifar10_normalization(),
]
)

test_transforms = transforms.Compose(
[
transforms.ToTensor(),
cifar10_normalization(),
]
)

# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
# - Please note that default transforms are being used
dm = CIFAR10DataModule(
data_dir="data",
batch_size=BATCH,
num_workers=NUM_WORKERS,
pin_memory=True,
)
dm.train_transforms = train_transforms
dm.test_transforms = test_transforms
dm.val_transforms = test_transforms

image_size = dm.size(-1) # 32 for CIFAR
num_classes = dm.num_classes # 10 for CIFAR
24 changes: 2 additions & 22 deletions examples/microViT.py → examples/cifar_ViT.py
@@ -19,7 +19,6 @@
from pl_bolts.datamodules import CIFAR10DataModule
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from xformers.factory import xFormer, xFormerConfig

@@ -163,7 +162,6 @@ def forward(self, x):
def training_step(self, batch, _):
x, y = batch
y_hat = self(x)

loss = self.criterion(y_hat, y)

self.logger.log_metrics(
@@ -205,33 +203,15 @@ def test_step(self, batch, _):
NUM_WORKERS = 4
GPUS = 1

train_transforms = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

test_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
# - Please note that default transforms are being used
dm = CIFAR10DataModule(
data_dir="data",
batch_size=BATCH,
num_workers=NUM_WORKERS,
pin_memory=True,
train_transforms=train_transforms,
test_transforms=test_transforms,
val_transforms=test_transforms,
)

image_size = dm.size(-1) # 32 for CIFAR
4 changes: 4 additions & 0 deletions tests/test_hierarchical_transformer.py
@@ -28,6 +28,7 @@ def test_hierarchical_transformer():
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
feedforward="MLP",
),
BasicLayerConfig(
embedding=128,
@@ -36,6 +37,8 @@ def test_hierarchical_transformer():
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
feedforward="MLP",
repeat_layer=2,
),
BasicLayerConfig(
embedding=320,
@@ -44,6 +47,7 @@ def test_hierarchical_transformer():
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
feedforward="MLP",
),
]

15 changes: 11 additions & 4 deletions tests/test_residual.py
@@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from xformers.components import PreNorm
from xformers.components import NormalizationType, PreNorm


class Passthrough(torch.nn.Module):
@@ -17,11 +18,17 @@ def forward(self, *args):
return args


def test_pre_norm():
@pytest.mark.parametrize("normalization", [n.value for n in NormalizationType])
def test_pre_norm(normalization):
# Check that passing the same tensor a bunch of times skips the extra normalizations
x = torch.rand((3, 3))
x = torch.rand((3, 3), requires_grad=True)

wrap = PreNorm(d_model=3, sublayer=Passthrough(), use_triton=False)
wrap = PreNorm(
d_norm=3, sublayer=Passthrough(), normalization=normalization, use_triton=False
)
outputs = wrap(inputs=[x, x, x])

assert id(outputs[0]) == id(outputs[1])

# Check the BW pass
torch.sum(outputs[0]).backward()
3 changes: 2 additions & 1 deletion xformers/components/__init__.py
@@ -17,11 +17,12 @@
from .multi_head_dispatch import MultiHeadDispatchConfig
from .patch_embedding import PatchEmbeddingConfig # noqa
from .patch_embedding import build_patch_embedding # noqa
from .residual import LayerNormStyle # noqa; noqa
from .residual import NormalizationType # noqa
from .residual import PostNorm # noqa
from .residual import PreNorm # noqa
from .residual import RequiresWrappedInputs # noqa
from .residual import Residual # noqa
from .residual import ResidualNormStyle # noqa

# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components")
