[metaformers] handling different normalizations + layer repetition #345
Changes from 1 commit
First changed file: the MetaFormer CIFAR example.
```diff
@@ -7,18 +7,25 @@
 import pytorch_lightning as pl
 import torch
 from pl_bolts.datamodules import CIFAR10DataModule
-from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
 from torch import nn
 from torchmetrics import Accuracy
-from torchvision import transforms

-from examples.microViT import Classifier, VisionTransformer
+from examples.cifar_ViT import Classifier, VisionTransformer
 from xformers.factory import xFormer, xFormerConfig
 from xformers.helpers.hierarchical_configs import (
     BasicLayerConfig,
     get_hierarchical_configuration,
 )

+# This is very close to the cifarViT example, and reuses a lot of the training code; only the model part is different.
+# There are many ways one can use xformers to write down a MetaFormer, for instance by
+# picking up the parts from `xformers.components` and implementing the model explicitly,
+# or by patching another existing ViT-like implementation.
+
+# This example takes another approach: we define the whole model configuration in one go (dict structure)
+# and then use the xformers factory to generate the model. This obfuscates a lot of the model building
+# (though you can inspect the resulting implementation), but makes it trivial to do some hyperparameter search
+

 class MetaVisionTransformer(VisionTransformer):
     def __init__(
```
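To make the config-driven approach in the comment above concrete, here is a minimal sketch of the flow: describe each stage with `BasicLayerConfig`, expand the skeleton with `get_hierarchical_configuration`, and let the factory build the model. The keyword arguments mirror this diff; `attention_mechanism="pooling"` and the `xFormer.from_config` call are assumptions drawn from the surrounding example rather than shown in this PR.

```python
import torch

from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

# Describe each stage of the hierarchical trunk; kwargs mirror the hunks below.
# attention_mechanism="pooling" (PoolFormer-style token mixing) is an assumption.
base_hierarchical_configs = [
    BasicLayerConfig(
        embedding=64,                     # channel width of this stage
        attention_mechanism="pooling",    # token mixing; can differ per stage
        patch_size=3,
        stride=2,
        padding=1,
        seq_len=32 * 32 // 4,             # tokens left after one 2x downsampling
        feedforward="Conv2DFeedforward",  # per-stage feedforward (new in this PR)
        repeat_layer=1,                   # stack count for this stage (new in this PR)
    ),
    # ... more stages with growing embeddings, as in the diff below ...
]

# Expand the per-stage skeleton into a full block-by-block xformers configuration
config = get_hierarchical_configuration(
    base_hierarchical_configs,
    use_rotary_embeddings=False,
    mlp_multiplier=4,
    dim_head=32,
)

# The factory turns the dict structure into a regular nn.Module trunk
model = xFormer.from_config(xFormerConfig(config))
```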
```diff
@@ -44,9 +51,10 @@ def __init__(
         self.save_hyperparameters()

         # Generate the skeleton of our hierarchical Transformer

-        # This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
-        # Any other related config would work, and the attention mechanisms don't have to be the same across layers
+        # - This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
+        # - Please note that this does not match the L1 configuration in the paper, as this would correspond to repeated
+        #   layers. CIFAR pictures are too small for this config to be directly meaningful (although that would run)
+        # - Any other related config would work, and the attention mechanisms don't have to be the same across layers
         base_hierarchical_configs = [
             BasicLayerConfig(
                 embedding=64,
@@ -55,6 +63,8 @@ def __init__(
                 stride=2,
                 padding=1,
                 seq_len=image_size * image_size // 4,
+                feedforward=feedforward,
+                repeat_layer=1,
             ),
             BasicLayerConfig(
                 embedding=128,
@@ -63,6 +73,8 @@ def __init__(
                 stride=2,
                 padding=1,
                 seq_len=image_size * image_size // 16,
+                feedforward=feedforward,
+                repeat_layer=1,
             ),
             BasicLayerConfig(
                 embedding=320,
@@ -71,6 +83,8 @@ def __init__(
                 stride=2,
                 padding=1,
                 seq_len=image_size * image_size // 64,
+                feedforward=feedforward,
+                repeat_layer=1,
             ),
             BasicLayerConfig(
                 embedding=512,
@@ -79,6 +93,8 @@ def __init__(
                 stride=2,
                 padding=1,
                 seq_len=image_size * image_size // 256,
+                feedforward=feedforward,
+                repeat_layer=1,
             ),
         ]
```
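The repeated `feedforward=feedforward, repeat_layer=1` additions above are the point of the PR: the feedforward block is now chosen per stage (the next hunk removes the old global `Conv2DFeedforward` override), and `repeat_layer` lets a stage be stacked several times, which the updated comment notes would be needed to match the paper's L1 configuration. A hedged sketch of a repeated stage; field names follow this diff, while `attention_mechanism` and the exact stacking semantics are assumptions:

```python
from xformers.helpers.hierarchical_configs import BasicLayerConfig

# One PoolFormer-like stage stacked twice: the factory should emit two copies of
# this block back to back instead of requiring a duplicated config entry
deep_stage = BasicLayerConfig(
    embedding=64,
    attention_mechanism="pooling",   # assumption, as in the sketch above
    patch_size=3,
    stride=2,
    padding=1,
    seq_len=224 * 224 // 4,          # an ImageNet-sized input rather than CIFAR
    feedforward="Conv2DFeedforward",
    repeat_layer=2,                  # the new knob: repeat this stage twice
)
```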
```diff
@@ -89,7 +105,6 @@ def __init__(
             use_rotary_embeddings=use_rotary_embeddings,
             mlp_multiplier=4,
             dim_head=32,
-            feedforward="Conv2DFeedforward",
         )

         # Now instantiate the metaformer trunk
@@ -131,34 +146,16 @@ def forward(self, x):
    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)

-    train_transforms = transforms.Compose(
-        [
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            cifar10_normalization(),
-        ]
-    )
-
-    test_transforms = transforms.Compose(
-        [
-            transforms.ToTensor(),
-            cifar10_normalization(),
-        ]
-    )
-
    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
-    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
+    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
+    # - Please note that default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
-    dm.train_transforms = train_transforms
-    dm.test_transforms = test_transforms
-    dm.val_transforms = test_transforms

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR
```

Review comment (on the removed transform pipelines): the lightning-bolts datamodule already does all that.
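Both examples now rely on the datamodule's built-in transforms. If custom augmentation were ever needed again, it can still be injected explicitly; here is a hedged sketch combining the two patterns this PR removes (constructor kwargs as in the ViT example, the bolts normalization helper as in the MetaFormer example):

```python
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision import transforms

# The augmentation pipeline that used to be built by hand in both examples
augmented = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar10_normalization(),
    ]
)

dm = CIFAR10DataModule(
    data_dir="data",
    batch_size=256,          # any batch size; the examples use a BATCH constant
    num_workers=4,           # NUM_WORKERS in the examples
    pin_memory=True,
    train_transforms=augmented,  # overrides the datamodule defaults
)
```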
Second changed file: the CIFAR ViT example.
```diff
@@ -19,7 +19,6 @@
 from pl_bolts.datamodules import CIFAR10DataModule
 from torch import nn
 from torchmetrics import Accuracy
-from torchvision import transforms

 from xformers.factory import xFormer, xFormerConfig
@@ -163,7 +162,6 @@ def forward(self, x):
     def training_step(self, batch, _):
         x, y = batch
         y_hat = self(x)
-
         loss = self.criterion(y_hat, y)

         self.logger.log_metrics(
@@ -205,33 +203,15 @@ def test_step(self, batch, _):
    NUM_WORKERS = 4
    GPUS = 1

-    train_transforms = transforms.Compose(
-        [
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    )
-
-    test_transforms = transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    )
-
    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
-    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
+    # - See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
+    # - Please note that default transforms are being used
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
-        train_transforms=train_transforms,
-        test_transforms=test_transforms,
-        val_transforms=test_transforms,
    )

    image_size = dm.size(-1)  # 32 for CIFAR
```

Review comment (on the removed transform pipelines): same as above, these are actually the default transforms in the lightning-bolts datamodule, not useful.
Third changed file: the PreNorm unit test.
```diff
@@ -4,9 +4,10 @@
 # LICENSE file in the root directory of this source tree.


+import pytest
 import torch

-from xformers.components import PreNorm
+from xformers.components import NormalizationType, PreNorm


 class Passthrough(torch.nn.Module):
@@ -17,11 +18,17 @@ def forward(self, *args):
         return args


-def test_pre_norm():
+@pytest.mark.parametrize("normalization", [n.value for n in NormalizationType])
+def test_pre_norm(normalization):
     # Check that passing the same tensor a bunch of times skips the extra normalizations
-    x = torch.rand((3, 3))
+    x = torch.rand((3, 3), requires_grad=True)

-    wrap = PreNorm(d_model=3, sublayer=Passthrough(), use_triton=False)
+    wrap = PreNorm(
+        d_norm=3, sublayer=Passthrough(), normalization=normalization, use_triton=False
+    )
     outputs = wrap(inputs=[x, x, x])

     assert id(outputs[0]) == id(outputs[1])
+
+    # Check the BW pass
+    torch.sum(outputs[0]).backward()
```

Review comment (on the backward-pass check): better code coverage, and a good idea in any case, I believe.
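For reference, a standalone sketch of what the parametrized test exercises: the same `PreNorm` wrapper instantiated once per supported normalization, with a backward pass to check that gradients flow through each norm implementation. Signatures follow this diff.

```python
import torch

from xformers.components import NormalizationType, PreNorm


class Passthrough(torch.nn.Module):
    def forward(self, *args):
        return args


for norm in NormalizationType:
    # d_norm is the normalized dimension; use_triton=False keeps this on CPU
    wrap = PreNorm(
        d_norm=3, sublayer=Passthrough(), normalization=norm.value, use_triton=False
    )

    x = torch.rand((3, 3), requires_grad=True)
    outputs = wrap(inputs=[x, x, x])

    # The wrapper normalizes the shared input once and reuses the result
    assert id(outputs[0]) == id(outputs[1])

    # Gradients must flow back through the chosen normalization
    torch.sum(outputs[0]).backward()
```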
Review comment (on the switch from examples.microViT to examples.cifar_ViT): min/micro was a follow-up from Karpathy's minGPT and does not really apply here; I figured that cifar_ViT was probably more transparent?