-
Notifications
You must be signed in to change notification settings - Fork 550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Adding Metaformer support #294
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
Hierarchical Transformers | ||
========================= | ||
|
||
The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged cross the depth of the model, and this prove to be really efficient in many domains. | ||
|
||
It seems that some domains could however benefit from an architecture more typical from CNN, where there's a tradeoff across the depth of the model in between the spatial extent (ie: number of tokens) and their expressiveness (ie: the model or embedding dimension). These architectures are handled in xformers, through the "patch_embedding" element, which translates the sequence of tokens from one layer to another. | ||
|
||
A small helper is provided to make it easier to generate matching configurations, as follows. We present in this example a truncated version of a small Metaformer_. | ||
|
||
.. _Metaformer: https://arxiv.org/abs/2111.11418v1 | ||
|
||
.. code-block:: python | ||
|
||
from xformers.factory import xFormer, xFormerConfig | ||
from xformers.helpers.hierarchical_configs import ( | ||
BasicLayerConfig, | ||
get_hierarchical_configuration, | ||
) | ||
|
||
|
||
base_hierarchical_configs = [ | ||
BasicLayerConfig( | ||
embedding=64, # the dimensions just have to match along the layers | ||
attention_mechanism="scaled_dot_product", # anything you like | ||
patch_size=7, | ||
stride=4, | ||
padding=2, | ||
seq_len=image_size * image_size // 16, | ||
), | ||
BasicLayerConfig( | ||
embedding=128, | ||
attention_mechanism="scaled_dot_product", | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 64, | ||
), | ||
BasicLayerConfig( | ||
embedding=320, | ||
attention_mechanism="scaled_dot_product", | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 256, | ||
), | ||
] | ||
|
||
# Fill in the gaps in the config | ||
xformer_config = get_hierarchical_configuration( | ||
base_hierarchical_configs, | ||
layernorm_style="pre", | ||
use_rotary_embeddings=False, | ||
mlp_multiplier=4, | ||
dim_head=32, | ||
) | ||
config = xFormerConfig(xformer_config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ Tutorials | |
pytorch_encoder | ||
reversible | ||
triton | ||
hierarchical |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import pytorch_lightning as pl | ||
import torch | ||
from pl_bolts.datamodules import CIFAR10DataModule | ||
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | ||
from torch import nn | ||
from torchmetrics import Accuracy | ||
from torchvision import transforms | ||
|
||
from examples.microViT import Classifier, VisionTransformer | ||
from xformers.factory import xFormer, xFormerConfig | ||
from xformers.helpers.hierarchical_configs import ( | ||
BasicLayerConfig, | ||
get_hierarchical_configuration, | ||
) | ||
|
||
|
||
class MetaVisionTransformer(VisionTransformer): | ||
def __init__( | ||
self, | ||
steps, | ||
learning_rate=5e-3, | ||
betas=(0.9, 0.99), | ||
weight_decay=0.03, | ||
image_size=32, | ||
num_classes=10, | ||
dim=384, | ||
attention="scaled_dot_product", | ||
layer_norm_style="pre", | ||
use_rotary_embeddings=True, | ||
linear_warmup_ratio=0.1, | ||
classifier=Classifier.GAP, | ||
): | ||
|
||
super(VisionTransformer, self).__init__() | ||
|
||
# all the inputs are saved under self.hparams (hyperparams) | ||
self.save_hyperparameters() | ||
|
||
# Generate the skeleton of our hierarchical Transformer | ||
|
||
# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32) | ||
# Any other related config would work, | ||
# and the attention mechanisms don't have to be the same across layers | ||
base_hierarchical_configs = [ | ||
BasicLayerConfig( | ||
embedding=64, | ||
attention_mechanism=attention, | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 4, | ||
), | ||
BasicLayerConfig( | ||
embedding=128, | ||
attention_mechanism=attention, | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 16, | ||
), | ||
BasicLayerConfig( | ||
embedding=320, | ||
attention_mechanism=attention, | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 64, | ||
), | ||
BasicLayerConfig( | ||
embedding=512, | ||
attention_mechanism=attention, | ||
patch_size=3, | ||
stride=2, | ||
padding=1, | ||
seq_len=image_size * image_size // 256, | ||
), | ||
] | ||
|
||
# Fill in the gaps in the config | ||
xformer_config = get_hierarchical_configuration( | ||
base_hierarchical_configs, | ||
layernorm_style=layer_norm_style, | ||
use_rotary_embeddings=use_rotary_embeddings, | ||
mlp_multiplier=4, | ||
dim_head=32, | ||
) | ||
|
||
# Now instantiate the metaformer trunk | ||
config = xFormerConfig(xformer_config) | ||
print(config) | ||
self.trunk = xFormer.from_config(config) | ||
print(self.trunk) | ||
|
||
# The classifier head | ||
dim = base_hierarchical_configs[-1].embedding | ||
self.ln = nn.LayerNorm(dim) | ||
self.head = nn.Linear(dim, num_classes) | ||
self.criterion = torch.nn.CrossEntropyLoss() | ||
self.val_accuracy = Accuracy() | ||
|
||
def forward(self, x): | ||
x = self.trunk(x) | ||
x = self.ln(x) | ||
|
||
if self.hparams.classifier == Classifier.TOKEN: | ||
x = x[:, 0] # only consider the token, we're classifying anyway | ||
elif self.hparams.classifier == Classifier.GAP: | ||
x = x.mean(dim=1) # mean over sequence len | ||
|
||
x = self.head(x) | ||
return x | ||
|
||
|
||
if __name__ == "__main__": | ||
pl.seed_everything(42) | ||
|
||
# Adjust batch depending on the available memory on your machine. | ||
# You can also use reversible layers to save memory | ||
REF_BATCH = 512 | ||
BATCH = 512 # lower if not enough GPU memory | ||
|
||
MAX_EPOCHS = 50 | ||
NUM_WORKERS = 4 | ||
GPUS = 1 | ||
|
||
train_transforms = transforms.Compose( | ||
[ | ||
transforms.RandomCrop(32, padding=4), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
cifar10_normalization(), | ||
] | ||
) | ||
|
||
test_transforms = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
cifar10_normalization(), | ||
] | ||
) | ||
|
||
# We'll use a datamodule here, which already handles dataset/dataloader/sampler | ||
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html | ||
# for a full tutorial | ||
dm = CIFAR10DataModule( | ||
data_dir="data", | ||
batch_size=BATCH, | ||
num_workers=NUM_WORKERS, | ||
pin_memory=True, | ||
) | ||
dm.train_transforms = train_transforms | ||
dm.test_transforms = test_transforms | ||
dm.val_transforms = test_transforms | ||
|
||
image_size = dm.size(-1) # 32 for CIFAR | ||
num_classes = dm.num_classes # 10 for CIFAR | ||
|
||
# compute total number of steps | ||
batch_size = BATCH * GPUS | ||
steps = dm.num_samples // REF_BATCH * MAX_EPOCHS | ||
lm = MetaVisionTransformer( | ||
steps=steps, | ||
image_size=image_size, | ||
num_classes=num_classes, | ||
attention="scaled_dot_product", | ||
layer_norm_style="pre", | ||
use_rotary_embeddings=True, | ||
) | ||
trainer = pl.Trainer( | ||
gpus=GPUS, | ||
max_epochs=MAX_EPOCHS, | ||
precision=16, | ||
accumulate_grad_batches=REF_BATCH // BATCH, | ||
) | ||
trainer.fit(lm, dm) | ||
|
||
# check the training | ||
trainer.test(lm, datamodule=dm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!