Commit

Running, could be nicer with some parameter auto-fill
blefaudeux committed May 5, 2022
1 parent 5f75412 commit 70a3c7d
Showing 11 changed files with 441 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MLP benchmark
- Move all triton kernels to triton v2 [#272]
- Mem efficient attention, BW pass [#281]
- Metaformer support [#294]

## [0.0.10] - 2022-03-14
### Fixed
55 changes: 55 additions & 0 deletions HOWTO.md
@@ -749,3 +749,58 @@ class xFormerStackConfig:
[2]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient Transformer.

[3]: Vaswani et al., Attention is all you need, 2017


### Hierarchical Transformers

The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged across the depth of the model, and this has proven to be very efficient in many domains.

Some domains could however benefit from an architecture more typical of CNNs, where there is a trade-off across the depth of the model between the spatial extent (i.e. the number of tokens) and their expressiveness (i.e. the model or embedding dimension). These architectures are handled in xformers through the "patch_embedding" element, which translates the sequence of tokens from one layer to the next.

A small helper is provided to make it easier to generate matching configurations, as follows. This example presents a truncated version of a small [Metaformer](https://arxiv.org/abs/2111.11418v1).

```python
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)


# We assume square 32x32 inputs here (e.g. CIFAR10); adjust to your data
image_size = 32

base_hierarchical_configs = [
    BasicLayerConfig(
        embedding=64,  # the dimensions just have to match along the layers
        attention_mechanism="scaled_dot_product",  # anything you like
        patch_size=7,
        stride=4,
        padding=2,
        seq_len=image_size * image_size // 16,
    ),
    BasicLayerConfig(
        embedding=128,
        attention_mechanism="scaled_dot_product",
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 64,
    ),
    BasicLayerConfig(
        embedding=320,
        attention_mechanism="scaled_dot_product",
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 256,
    ),
]

# Fill in the gaps in the config
xformer_config = get_hierarchical_configuration(
    base_hierarchical_configs,
    layernorm_style="pre",
    use_rotary_embeddings=False,
    mlp_multiplier=4,
    dim_head=32,
)
config = xFormerConfig(xformer_config)
```
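
From there the complete model can be built and run on image tensors (NCHW layout), mirroring what the repository's unit test does. A minimal sketch, reusing `config` and `image_size` from the snippet above:

```python
import torch

# Instantiate the hierarchical Transformer from the generated configuration
model = xFormer.from_config(config)

# Forward a dummy batch of two 3-channel images
dummy = torch.rand((2, 3, image_size, image_size))
output = model(dummy)
```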
1 change: 1 addition & 0 deletions README.md
@@ -190,6 +190,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
2. transformer block benchmark
3. [LRA](xformers/benchmarks/LRA/README.md), with SLURM support
4. Programmatic and sweep-friendly layer and model construction
   1. Compatible with hierarchical Transformers, like Swin or Metaformer
5. Hackable
   1. Not using monolithic CUDA kernels, composable building blocks
   2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
@@ -11,3 +11,4 @@ Tutorials
pytorch_encoder
reversible
triton
hierarchical
5 changes: 5 additions & 0 deletions examples/README.md
@@ -25,3 +25,8 @@ If your current machine does not expose enough RAM and the example reports an `O
This is meant to be an easy introduction to using xformers in practice, closely mirroring [this PyTorch Lightning](https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html) tutorial. The default settings are close to that tutorial, which trains an 11M-parameter ResNet on CIFAR; here we train a 10.6M-parameter ViT on the same dataset. The ViT configuration is not optimal for CIFAR, since the pictures are very small to begin with and some information is probably lost in the patching. Nevertheless, you should be able to reach about 80% accuracy within about an hour on a single GPU.

![Example curves](../docs/assets/microViT.png)


### MicroMetaformer

This is very close to the MicroViT example above, but this time it illustrates the use of a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)), built through a helper function which generates the required configuration given the pooling parameters.
201 changes: 201 additions & 0 deletions examples/cifarMetaformer.py
@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum

import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from examples.microViT import VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)


class Classifier(str, Enum):
    GAP = "gap"
    TOKEN = "token"


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-4,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        patch_size=2,
        dim=384,
        n_layer=6,
        n_head=6,
        resid_pdrop=0.0,
        attn_pdrop=0.0,
        mlp_pdrop=0.0,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        hidden_layer_multiplier=4,
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier: Classifier = Classifier.TOKEN,
    ):

        super(VisionTransformer, self).__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        assert image_size % patch_size == 0

        # Generate the skeleton of our hierarchical Transformer

        # This is the small metaformer configuration,
        # truncated of the last part since the pictures are too small with CIFAR10 (32x32)
        # Any other related config would work,
        # and the attention mechanisms don't have to be the same across layers
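        # Each seq_len below is the token count after that layer's strided patch embedding:
        # strides of 4, 2 and 2 shrink the 2D grid, dividing the original pixel count by 16, 64 and 256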
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=7,
                stride=4,
                padding=2,
                seq_len=image_size * image_size // 16,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
            ),
            # BasicLayerConfig(
            #     embedding=512,
            #     attention_mechanism=attention,
            #     patch_size=3,
            #     stride=2,
            #     padding=1,
            #     seq_len=image_size * image_size // 1024,
            # ),
        ]

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            layernorm_style=layer_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = self.trunk(x)
        x = self.ln(x)

        if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # only consider the token, we're classifying anyway
        elif self.hparams.classifier == Classifier.GAP:
            x = x.mean(dim=1)  # mean over sequence len

        x = self.head(x)
        return x


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 512
    BATCH = 256

    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    train_transforms = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    test_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    dm.train_transforms = train_transforms
    dm.test_transforms = test_transforms
    dm.val_transforms = test_transforms

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute total number of steps
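    # (one step here is one optimizer update over an effective batch of REF_BATCH samples,
    # since gradients are accumulated below)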
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        use_rotary_embeddings=True,
    )
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # check the training
    trainer.test(lm, datamodule=dm)
2 changes: 1 addition & 1 deletion examples/microViT.py
@@ -171,7 +171,7 @@ def training_step(self, batch, _):
"train_loss": loss.mean(),
"learning_rate": self.lr_schedulers().get_last_lr()[0],
},
step=trainer.global_step,
step=self.global_step,
)

return loss
63 changes: 63 additions & 0 deletions tests/test_hierarchical_transformer.py
@@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

BATCH = 20
SEQ = 512
MODEL = 384


def test_hierarchical_transformer():
    image_size = 32

    base_hierarchical_configs = [
        BasicLayerConfig(
            embedding=64,
            attention_mechanism="scaled_dot_product",
            patch_size=7,
            stride=4,
            padding=2,
            seq_len=image_size * image_size // 16,
        ),
        BasicLayerConfig(
            embedding=128,
            attention_mechanism="scaled_dot_product",
            patch_size=3,
            stride=2,
            padding=1,
            seq_len=image_size * image_size // 64,
        ),
        BasicLayerConfig(
            embedding=320,
            attention_mechanism="scaled_dot_product",
            patch_size=3,
            stride=2,
            padding=1,
            seq_len=image_size * image_size // 256,
        ),
    ]

    # Fill in the gaps in the config
    xformer_config = get_hierarchical_configuration(
        base_hierarchical_configs,
        layernorm_style="pre",
        use_rotary_embeddings=False,
        mlp_multiplier=4,
        dim_head=32,
    )
    config = xFormerConfig(xformer_config)
    hierarchical_xformer = xFormer.from_config(config)

    # Forward some dummy data
    dummy = torch.rand((2, 3, image_size, image_size))
    _ = hierarchical_xformer(dummy)
4 changes: 2 additions & 2 deletions xformers/components/multi_head_dispatch.py
@@ -19,10 +19,10 @@
@dataclass
class MultiHeadDispatchConfig:
    dim_model: int
    residual_dropout: float
    num_heads: int
    attention: Attention
    bias: bool
    residual_dropout: float
    dim_key: Optional[int]
    dim_value: Optional[int]
    in_proj_container: Optional[InProjContainer]
@@ -55,10 +55,10 @@ class MultiHeadDispatch(nn.Module):
    def __init__(
        self,
        dim_model: int,
        residual_dropout: float,
        num_heads: int,
        attention: Attention,
        bias: bool = True,
        residual_dropout: float = 0.0,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        in_proj_container: Optional[InProjContainer] = None,
