Latent Fourier Vision Transformer (PyTorch implementation) from the paper ...
Small Description...
About key idea...
Warning
This repository is under development, so expect regular changes, but please feel free to explore it and share any feedback or suggestions you may have. 🚧
Install through pip:
$ pip install git+https://github.com/kevinyecs/lf-vit.git
from lf_vit_pytorch import LFViT, Config
config = Config(
n_labels = 100,
downscale_ratio = 4
)
model = LFViT(config)
TODO
- ViT train script on CIFAR-100 (for comparisons)
- LF-ViT train script on CIFAR-100 (for comparisons)
- Add RoPE2D
Transformers have been the dominant architecture in natural language processing for the last 7 years, owing to their ability to scale with large amounts of data and to learn complex relationships between individual tokens. In response to this success, researchers quickly adapted the architecture to computer vision problems, for example in the Vision Transformer (ViT) and the Image Transformer. The main problem with the Transformer architecture lies in the attention mechanism: its cost scales quadratically with the number of tokens, so it is not computationally efficient. Our proposal in this paper is to apply attention alternatives known from NLP, such as replacing multi-head attention with the Fourier transform, as shown in FNet. We extend the architecture with Perceiver IO-style latent vectors and cross-attention to not just make the component faster and more computationally efficient...
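The FNet idea mentioned above replaces the attention sub-layer with an unparameterized Fourier transform over both the hidden and the sequence dimension, keeping only the real part. The pure-Python sketch below illustrates that token-mixing step; it is not taken from this repository, the function names are hypothetical, and it uses a naive O(N²) DFT for readability where a real implementation would use an FFT.

```python
import cmath


def dft(seq):
    """Naive 1D discrete Fourier transform of a list of complex numbers (O(n^2))."""
    n = len(seq)
    return [
        sum(seq[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def fourier_token_mixing(x):
    """FNet-style mixing (hypothetical helper): DFT along the hidden dimension,
    then along the sequence dimension, keeping only the real part.

    x is a list of N tokens, each a list of D real features.
    """
    # 1) DFT over the hidden (feature) dimension of every token
    hidden = [dft([complex(v) for v in token]) for token in x]
    n_tokens, d = len(hidden), len(hidden[0])
    # 2) DFT over the sequence dimension, one feature column at a time
    columns = [dft([hidden[i][j] for i in range(n_tokens)]) for j in range(d)]
    # 3) Keep the real part and transpose back to shape (N, D)
    return [[columns[j][i].real for j in range(d)] for i in range(n_tokens)]


# A single non-zero entry at (0, 0) spreads to every token and feature
print(fourier_token_mixing([[1.0, 0.0], [0.0, 0.0]]))  # [[1.0, 1.0], [1.0, 1.0]]
```

In PyTorch the same mixing is usually written with an FFT, e.g. torch.fft.fft2(x).real for x of shape (batch, tokens, dim), which transforms over the last two dimensions by default. It has no learned weights, and its cost grows roughly as N log N in the number of tokens rather than N² as for attention scores. How LF-ViT combines this with the Perceiver IO-style latents is not shown here.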
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in PyTorch. Its significance is further explained in Yannic Kilcher's video. There's really not much to code here, but we may as well lay it out for everyone so we expedite the attention revolution.
For a PyTorch implementation with pretrained models, please see Ross Wightman's repository here.
The official JAX repository is here.
A TensorFlow 2 translation also exists here, created by research scientist Junho Kim! 🙏
Flax translation by Enrico Shippole!
$ pip install vit-pytorch
import torch
from vit_pytorch import ViT
v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
Parameters
- image_size: int. Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.
- patch_size: int. Size of patches. image_size must be divisible by patch_size. The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16.
- num_classes: int. Number of classes to classify.
- dim: int. Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).
- depth: int. Number of Transformer blocks.
- heads: int. Number of heads in the multi-head attention layer.
- mlp_dim: int. Dimension of the MLP (FeedForward) layer.
- channels: int, default 3. Number of image channels.
- dropout: float between [0, 1], default 0. Dropout rate.
- emb_dropout: float between [0, 1], default 0. Embedding dropout rate.
- pool: string, either cls (CLS token pooling) or mean (mean pooling).
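As a quick check of the patch constraints above, the helpers below (hypothetical, not part of vit-pytorch) compute n for the settings of the ViT example, along with the length of one flattened patch, assuming the standard ViT patch embedding that flattens each patch to channels * patch_size² values before projecting it with nn.Linear(..., dim).

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """n = (image_size // patch_size) ** 2, with the documented constraints checked."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    n = (image_size // patch_size) ** 2
    if n <= 16:
        raise ValueError(f"n = {n} patches, but n must be greater than 16")
    return n


def patch_dim(channels: int, patch_size: int) -> int:
    """Length of one flattened patch fed into nn.Linear(..., dim) (assumed layout)."""
    return channels * patch_size * patch_size


# Settings from the ViT example: image_size=256, patch_size=32, channels=3
n = num_patches(256, 32)  # 8 x 8 grid of patches
p = patch_dim(3, 32)      # 3 * 32 * 32
print(n, p)               # 64 3072
```

With image_size = 128 and the same patch_size, the grid is only 4 x 4 = 16 patches, which violates the n > 16 requirement.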
An update from some of the same authors of the original paper proposes simplifications to ViT that allow it to train faster and better.
These simplifications include 2D sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and the use of RandAugment and MixUp augmentations. They also show that a simple linear layer at the end is not significantly worse than the original MLP head.
You can use it by importing SimpleViT, as shown below:
import torch
from vit_pytorch import SimpleViT
v = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
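The 2D sinusoidal positional embedding used by SimpleViT can be sketched as follows, assuming the common construction that splits dim into four equal parts (sine and cosine of the x coordinate, sine and cosine of the y coordinate). This is a pure-Python illustration with a hypothetical function name; SimpleViT's own implementation may differ in details such as the exact frequency spacing.

```python
import math


def sincos_pos_emb_2d(h: int, w: int, dim: int, temperature: float = 10000.0):
    """Fixed 2D sinusoidal positional embedding for an h x w patch grid (sketch).

    Each position (y, x) gets a dim-length vector made of four quarters:
    sin(x * omega), cos(x * omega), sin(y * omega), cos(y * omega).
    """
    if dim % 4 != 0:
        raise ValueError("dim must be divisible by 4")
    quarter = dim // 4
    # Geometrically spaced frequencies, as in the 1D Transformer encoding
    omegas = [1.0 / temperature ** (k / quarter) for k in range(quarter)]
    table = []
    for y in range(h):  # row-major order matches flattening the patch grid
        for x in range(w):
            emb = [math.sin(x * o) for o in omegas]
            emb += [math.cos(x * o) for o in omegas]
            emb += [math.sin(y * o) for o in omegas]
            emb += [math.cos(y * o) for o in omegas]
            table.append(emb)
    return table  # h * w rows, each of length dim


pe = sincos_pos_emb_2d(8, 8, 16)  # 8 x 8 patch grid (256 / 32), small dim for display
print(len(pe), len(pe[0]))        # 64 16
```

Because the table is fixed rather than learned, it adds no parameters. It is added to the patch tokens before the encoder, and with global average pooling there is no CLS token that would need its own position.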
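NaViT (from "Patch n' Pack"), which packs images of different resolutions and aspect ratios into variable-length sequences, can be used as follows: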
import torch
from vit_pytorch.na_vit import NaViT
v = NaViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
)
# 5 images of different resolutions - List[List[Tensor]]
# for now, you have to place images into batch elements yourself so that no element exceeds the maximum sequence length allowed for self-attention with masking
images = [
[torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
[torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
[torch.randn(3, 64, 256)]
]
preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above
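To see why placement matters, the sketch below counts the tokens each packed batch element holds for patch_size = 32, assuming every image side is divisible by the patch size and ignoring token dropout. The helper is hypothetical and not part of vit-pytorch.

```python
def patches_per_image(height: int, width: int, patch_size: int = 32) -> int:
    """Number of patches for one (3, height, width) image."""
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    return (height // patch_size) * (width // patch_size)


# (height, width) of the images above, grouped by batch element
batch = [
    [(256, 256), (128, 128)],
    [(128, 256), (256, 128)],
    [(64, 256)],
]
seq_lens = [sum(patches_per_image(h, w) for h, w in element) for element in batch]
print(seq_lens)  # [80, 64, 16] tokens per packed sequence, before token dropout
```

Assuming the packed sequences are padded to the longest one for masked self-attention, the first element (80 tokens) sets the padded length here, so balancing images across batch elements keeps that length down.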
Alternatively, if you would rather have the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length:
images = [
torch.randn(3, 256, 256),
torch.randn(3, 128, 128),
torch.randn(3, 128, 256),
torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]
preds = v(
images,
group_images = True,
group_max_seq_len = 64
) # (5, 1000)
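One simple way such grouping could work is greedy, in-order packing: start a new sequence whenever the next image would push the current one past group_max_seq_len. The function below is a hypothetical sketch of that idea and is not necessarily the strategy NaViT in vit-pytorch uses internally.

```python
def group_by_max_seq_len(seq_lens, max_seq_len):
    """Greedily pack items (given by token count) into groups whose total
    token count does not exceed max_seq_len. Returns lists of item indices."""
    groups, current, current_len = [], [], 0
    for idx, n in enumerate(seq_lens):
        if n > max_seq_len:
            raise ValueError(f"item {idx} has {n} tokens > max_seq_len={max_seq_len}")
        if current and current_len + n > max_seq_len:
            groups.append(current)  # current group is full, close it
            current, current_len = [], 0
        current.append(idx)
        current_len += n
    if current:
        groups.append(current)
    return groups


# Token counts (patch_size = 32) for the five images in the example above
tokens = [64, 16, 32, 32, 16]
print(group_by_max_seq_len(tokens, 64))  # [[0], [1, 2], [3, 4]]
```

Greedy in-order packing is easy to reason about but can leave gaps; sorting by length first or a bin-packing heuristic would fill sequences more tightly at the cost of reordering images.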