
An extension of the Vision Transformer architecture.


kevinyecs/lf-vit


LF-ViT - PyTorch ⚡

Latent Fourier Vision Transformer (PyTorch implementation) from the paper ...

Small Description...

About key idea...

Warning

This repository is under active development, so expect regular changes, but please feel free to explore and share any feedback or suggestions you may have. 🚧

Install

Install through pip

pip install git+https://github.com/kevinyecs/lf-vit.git

Usage

from lf_vit_pytorch import LFViT, Config

config = Config(
    n_labels = 100,
    downscale_ratio = 4
)

model = LFViT(config)
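
A minimal sketch of a forward pass, assuming LFViT takes standard (batch, channels, height, width) image tensors and returns one logit per label; the 224x224 input resolution here is an assumption rather than a documented default.

import torch

# Hypothetical input: one 224x224 RGB image (resolution is assumed, check Config for the real default)
img = torch.randn(1, 3, 224, 224)

logits = model(img)  # expected shape: (1, n_labels) -> (1, 100) for the config above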

Roadmap

  • ViT training script on CIFAR-100 (for comparison)
  • LF-ViT training script on CIFAR-100 (for comparison)
  • Add RoPE2D

LF ViT Abstract

Transformers have been the dominant architecture in natural language processing for the last seven years due to their ability to scale with large amounts of data and to learn complex relationships between individual tokens. Following the success of the architecture, researchers quickly adapted it to computer vision problems, for example the Vision Transformer (ViT) and the Image Transformer. The main problem with the Transformer architecture lies in the attention mechanism: it scales quadratically with sequence length and is therefore not computationally efficient. Our proposal in this paper is to apply known attention alternatives from NLP, such as replacing Multi-Head Attention with the Fourier transform as shown in FNet. We also extend the architecture with Perceiver IO-style latent vectors and cross-attention to make the model not only faster but also more computationally efficient.
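
To make the two ideas above concrete, here is a minimal sketch of FNet-style Fourier token mixing combined with Perceiver IO-style latent cross-attention. The module names, latent count, and dimensions are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class FourierMixer(nn.Module):
    # FNet-style mixing: FFT over the hidden dim, then over the token dim, keep the real part
    def forward(self, x):                          # x: (batch, tokens, dim)
        return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real

class LatentCrossAttention(nn.Module):
    # Perceiver IO-style: a small set of learned latents attends to the full token sequence
    def __init__(self, dim, n_latents=64, heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, dim))
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, tokens):                     # tokens: (batch, n_tokens, dim)
        q = self.latents.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        out, _ = self.attn(q, tokens, tokens)      # (batch, n_latents, dim)
        return out

# Toy usage: 196 patch tokens of width 512 are compressed to 64 latents, then Fourier-mixed
tokens = torch.randn(2, 196, 512)
latents = LatentCrossAttention(dim=512)(tokens)
mixed = FourierMixer()(latents)                    # (2, 64, 512)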

Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

For a Pytorch implementation with pretrained models, please see Ross Wightman's repository here.

The official Jax repository is here.

A tensorflow2 translation also exists here, created by research scientist Junho Kim! 🙏

Flax translation by Enrico Shippole!

Install

$ pip install vit-pytorch

Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Parameters

  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  • patch_size: int.
    Size of patches. image_size must be divisible by patch_size.
    The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16 (a worked example follows this parameter list).
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(..., dim).
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • channels: int, default 3.
    Number of image channels.
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either cls token pooling or mean pooling
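
To make the patch-count constraint concrete, here is the arithmetic for the settings in the usage example above (256-pixel images, 32-pixel patches):

image_size = 256
patch_size = 32
n = (image_size // patch_size) ** 2   # (256 // 32) ** 2 = 8 ** 2 = 64 patches
assert image_size % patch_size == 0   # image_size must be divisible by patch_size
assert n > 16                         # 64 > 16, so this configuration is valid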

Simple ViT

An update from some of the same authors of the original paper proposes simplifications to ViT that allow it to train faster and better.

These simplifications include 2D sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear head at the end is not significantly worse than the original MLP head.

You can use it by importing SimpleViT as shown below

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
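
The 2D sinusoidal positional embedding mentioned above can be sketched as follows. This is a rough illustration of the idea under assumed conventions (frequency layout, temperature), not necessarily vit_pytorch's exact code: each patch's (row, column) position is encoded with fixed sine/cosine features and added to the patch embeddings.

import torch

def posemb_sincos_2d(h, w, dim, temperature=10000):
    # Fixed sine/cosine features for an h x w grid of patches; dim must be divisible by 4
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)                     # frequency bands, shape (dim/4,)
    y = y.flatten()[:, None] * omega[None, :]                # (h*w, dim/4)
    x = x.flatten()[:, None] * omega[None, :]
    return torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)  # (h*w, dim)

# For 256x256 images with 32x32 patches: an 8x8 grid -> 64 positional vectors of width 1024,
# which are added to the patch embeddings before the encoder
pe = posemb_sincos_2d(8, 8, 1024)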

LF ViT

LF VIT description

You can use it as follows

import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1  # token dropout of 10% (keep 90% of tokens)
)

# 5 images of different resolutions - List[List[Tensor]]

# for now, you'll have to correctly place images in the same batch element so as not to exceed the maximum allowed sequence length for self-attention with masking

images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
    [torch.randn(3, 64, 256)]
]

preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above

Or, if you would rather have the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length:

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 128, 256),
    torch.randn(3, 256, 128),
    torch.randn(3, 64, 256)
]

preds = v(
    images,
    group_images = True,
    group_max_seq_len = 64
) # (5, 1000)
