Running, could be nicer with some parameter auto-fill
blefaudeux committed May 5, 2022
1 parent fa11d81 commit 93ec973
Showing 8 changed files with 322 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MLP benchmark
- Move all triton kernels to triton v2 [#272]
- Mem efficient attention, BW pass [#281]
- Metaformer support [#294]

## [0.0.10] - 2022-03-14
### Fixed
1 change: 1 addition & 0 deletions README.md
@@ -190,6 +190,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
2. transformer block benchmark
3. [LRA](xformers/benchmarks/LRA/README.md), with SLURM support
4. Programmatic and sweep-friendly layer and model construction
1. Compatible with hierarchical Transformers, like Swin or Metaformer
5. Hackable
1. Not using monolithic CUDA kernels, composable building blocks
2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
5 changes: 5 additions & 0 deletions examples/README.md
@@ -25,3 +25,8 @@ If your current machine does not expose enough RAM and the example reports an `O
This is meant to be an easy introduction to using xformers in practice, closely mirroring [this PyTorch Lightning](https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html) tutorial. The default settings are close to that tutorial, which trains an 11M-parameter ResNet on the CIFAR dataset; we train a 10.6M-parameter ViT on the same dataset. The ViT configuration is not optimal for CIFAR, since the pictures are very small to begin with and some information is probably lost when they are split into patches. Nevertheless, you should be able to reach about 80% accuracy within about an hour on a single GPU.

![Example curves](../docs/assets/microViT.png)


### MicroMetaformer

This is very close to the MicroViT example above, but this time it illustrates the use of a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)), built through a helper function which generates the required configuration given the pooling parameters.
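
For reference, here is a minimal sketch of the helper call in isolation. The stage values below are simply the first two CIFAR10 stages used by `examples/microMetaformer.py`, not a recommended configuration:

```python
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

image_size = 32  # CIFAR10 pictures are 32x32

# One BasicLayerConfig per stage: embedding width, attention mechanism,
# patching/pooling parameters and the resulting sequence length
base_hierarchical_configs = [
    BasicLayerConfig(
        embedding=64,
        attention_mechanism="scaled_dot_product",
        patch_size=7,
        stride=4,
        padding=2,
        seq_len=image_size * image_size // 16,
    ),
    BasicLayerConfig(
        embedding=128,
        attention_mechanism="scaled_dot_product",
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=image_size * image_size // 64,
    ),
]

# The helper fills in the remaining per-layer fields
# (feedforward, position encoding, patch embedding, number of heads, ...)
xformer_config = get_hierarchical_configuration(
    base_hierarchical_configs,
    layernorm_style="pre",
    use_rotary_embeddings=True,
    mlp_multiplier=4,
    dim_head=32,
)
```

The resulting list of layer configurations can then be passed to `xFormerConfig` and `xFormer.from_config`, exactly as done in the example.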
201 changes: 201 additions & 0 deletions examples/microMetaformer.py
@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum

import pytorch_lightning as pl
import torch
from microViT import VisionTransformer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
BasicLayerConfig,
get_hierarchical_configuration,
)


class Classifier(str, Enum):
GAP = "gap"
TOKEN = "token"


class MetaVisionTransformer(VisionTransformer):
def __init__(
self,
steps,
learning_rate=5e-4,
betas=(0.9, 0.99),
weight_decay=0.03,
image_size=32,
num_classes=10,
patch_size=2,
dim=384,
n_layer=6,
n_head=6,
resid_pdrop=0.0,
attn_pdrop=0.0,
mlp_pdrop=0.0,
attention="scaled_dot_product",
layer_norm_style="pre",
hidden_layer_multiplier=4,
use_rotary_embeddings=True,
linear_warmup_ratio=0.1,
classifier: Classifier = Classifier.TOKEN,
):

        # Skip VisionTransformer.__init__ on purpose (it would build the standard ViT trunk):
        # initialize the LightningModule base directly and build a hierarchical trunk below
        super(VisionTransformer, self).__init__()

# all the inputs are saved under self.hparams (hyperparams)
self.save_hyperparameters()

assert image_size % patch_size == 0

# Generate the skeleton of our hierarchical Transformer

        # This is the small metaformer configuration,
        # with the last stage removed since the CIFAR10 pictures (32x32) are too small for it.
        # Any other related config would work,
        # and the attention mechanisms don't have to be the same across layers
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64,
attention_mechanism=attention,
patch_size=7,
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
),
BasicLayerConfig(
embedding=128,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
),
BasicLayerConfig(
embedding=320,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
),
# BasicLayerConfig(
# embedding=512,
# attention_mechanism=attention,
# patch_size=3,
# stride=2,
# padding=1,
# seq_len=image_size * image_size // 1024,
# ),
]

# Fill in the gaps in the config
xformer_config = get_hierarchical_configuration(
base_hierarchical_configs,
layernorm_style=layer_norm_style,
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
)

# Now instantiate the metaformer trunk
config = xFormerConfig(xformer_config)
print(config)
self.trunk = xFormer.from_config(config)
print(self.trunk)

# The classifier head
dim = base_hierarchical_configs[-1].embedding
self.ln = nn.LayerNorm(dim)
self.head = nn.Linear(dim, num_classes)
self.criterion = torch.nn.CrossEntropyLoss()
self.val_accuracy = Accuracy()

def forward(self, x):
x = self.trunk(x)
x = self.ln(x)

if self.hparams.classifier == Classifier.TOKEN:
x = x[:, 0] # only consider the token, we're classifying anyway
elif self.hparams.classifier == Classifier.GAP:
x = x.mean(dim=1) # mean over sequence len

x = self.head(x)
return x


if __name__ == "__main__":
pl.seed_everything(42)

    # Adjust the batch size depending on the memory available on your machine.
    # You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 256

MAX_EPOCHS = 50
NUM_WORKERS = 4
GPUS = 1

train_transforms = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
cifar10_normalization(),
]
)

test_transforms = transforms.Compose(
[
transforms.ToTensor(),
cifar10_normalization(),
]
)

# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
dm = CIFAR10DataModule(
data_dir="data",
batch_size=BATCH,
num_workers=NUM_WORKERS,
pin_memory=True,
)
dm.train_transforms = train_transforms
dm.test_transforms = test_transforms
dm.val_transforms = test_transforms

image_size = dm.size(-1) # 32 for CIFAR
num_classes = dm.num_classes # 10 for CIFAR

# compute total number of steps
batch_size = BATCH * GPUS
steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
lm = MetaVisionTransformer(
steps=steps,
image_size=image_size,
num_classes=num_classes,
attention="scaled_dot_product",
layer_norm_style="pre",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
gpus=GPUS,
max_epochs=MAX_EPOCHS,
precision=16,
accumulate_grad_batches=REF_BATCH // BATCH,
)
trainer.fit(lm, dm)

# check the training
trainer.test(lm, datamodule=dm)
2 changes: 1 addition & 1 deletion examples/microViT.py
@@ -171,7 +171,7 @@ def training_step(self, batch, _):
"train_loss": loss.mean(),
"learning_rate": self.lr_schedulers().get_last_lr()[0],
},
step=trainer.global_step,
step=self.global_step,
)

return loss
4 changes: 2 additions & 2 deletions xformers/components/multi_head_dispatch.py
@@ -19,10 +19,10 @@
@dataclass
class MultiHeadDispatchConfig:
dim_model: int
residual_dropout: float
num_heads: int
attention: Attention
bias: bool
residual_dropout: float
dim_key: Optional[int]
dim_value: Optional[int]
in_proj_container: Optional[InProjContainer]
@@ -55,10 +55,10 @@ class MultiHeadDispatch(nn.Module):
def __init__(
self,
dim_model: int,
residual_dropout: float,
num_heads: int,
attention: Attention,
bias: bool = True,
residual_dropout: float = 0.0,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
in_proj_container: Optional[InProjContainer] = None,
8 changes: 7 additions & 1 deletion xformers/factory/block_configs.py
@@ -130,7 +130,7 @@ def __init__(
patch_embedding_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
# Convenience, fill in duplicated field
# Convenience, fill in duplicated fields
try:
if "dim_model" not in multi_head_config.keys():
multi_head_config["dim_model"] = dim_model
@@ -144,6 +144,12 @@
):
position_encoding_config["dim_model"] = dim_model

if (
patch_embedding_config is not None
and "out_channels" not in patch_embedding_config.keys()
):
patch_embedding_config["out_channels"] = dim_model

except AttributeError:
# A config instance was passed in, this is fine
pass
104 changes: 104 additions & 0 deletions xformers/helpers/hierarchical_configs.py
@@ -0,0 +1,104 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import copy
from dataclasses import dataclass
from typing import Any, Dict, List

from xformers.components.residual import LayerNormStyle


@dataclass
class BasicLayerConfig:
embedding: int
attention_mechanism: str
patch_size: int
stride: int
padding: int
seq_len: int


def get_hierarchical_configuration(
layer_basic_configs: List[BasicLayerConfig],
layernorm_style: LayerNormStyle = LayerNormStyle.Pre,
use_rotary_embeddings: bool = True,
mlp_multiplier: int = 4,
dim_head=32,
):
"""
    A small helper to generate hierarchical xformers configurations,
    which correspond for instance to PoolFormer or Swin architectures.
    Contrary to more "classical" Transformer architectures, which keep the sequence/context
    length constant across layers, hierarchical Transformers trade sequence length for embedding dimension.
    """

base_config: Dict[str, Any] = {
"block_type": "encoder",
"dim_model": 0,
"use_triton": False,
"layer_norm_style": str(layernorm_style),
"multi_head_config": {
"num_heads": 0,
"use_rotary_embeddings": use_rotary_embeddings,
"attention": {
"name": "TBD",
},
},
"feedforward_config": {
"name": "MLP",
"activation": "gelu",
"hidden_layer_multiplier": mlp_multiplier,
"dropout": 0.0,
},
"position_encoding_config": {
"name": "learnable",
"seq_len": 0,
"add_class_token": False,
},
"patch_embedding_config": {
"in_channels": 3,
"kernel_size": 0,
"stride": 0,
"padding": 0,
},
}

xformers_config = []
in_channels = 3

for layer_basic_config in layer_basic_configs:
lc = copy.deepcopy(base_config)

# Fill in the changing model dimensions
lc["dim_model"] = layer_basic_config.embedding

# Update the patches
lc["patch_embedding_config"] = {
"in_channels": in_channels,
"kernel_size": layer_basic_config.patch_size,
"stride": layer_basic_config.stride,
"padding": layer_basic_config.padding,
}

# Update the number of channels for the next layer
in_channels = lc["dim_model"] * 1

lc["position_encoding_config"]["seq_len"] = layer_basic_config.seq_len

# Fill in the number of heads
lc["multi_head_config"]["num_heads"] = layer_basic_config.embedding // dim_head
assert layer_basic_config.embedding % dim_head == 0

# Fill in the attention mechanism
lc["multi_head_config"]["attention"][
"name"
] = layer_basic_config.attention_mechanism

print(lc)
xformers_config.append(lc)

return xformers_config
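
A short, hedged usage sketch for the helper above: a single 64-wide stage on 32x32 inputs (values borrowed from `examples/microMetaformer.py` in this commit), expanded into full layer configurations and handed to the factory, mirroring how the example feeds CIFAR images to its trunk:

```python
import torch

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# A single hierarchical stage; real models would stack several, widening the embedding
stages = [
    BasicLayerConfig(
        embedding=64,
        attention_mechanism="scaled_dot_product",
        patch_size=7,
        stride=4,
        padding=2,
        seq_len=32 * 32 // 16,
    ),
]

# Expand the stage description into complete per-layer xformers configurations
layer_configs = get_hierarchical_configuration(
    stages,
    layernorm_style="pre",
    use_rotary_embeddings=True,
    mlp_multiplier=4,
    dim_head=32,
)

# Build the trunk through the standard factory; it consumes raw images
trunk = xFormer.from_config(xFormerConfig(layer_configs))
features = trunk(torch.rand(2, 3, 32, 32))  # (batch, sequence, embedding) tensor
```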
