xLSTM as Generic Vision Backbone

Licenses: Apache-2.0 (LICENSE_APACHE), MIT (LICENSE_MIT)
NX-AI/vision-lstm

Vision-LSTM (ViL)

[Project Page] [Paper] [Models] [Codebase Demo Video] [BibTeX]

PyTorch implementation and pre-trained models of Vision-LSTM (ViL), an adaptation of xLSTM to computer vision.

[Figure: Vision-LSTM schematic]

License

This project is licensed under the MIT License, except the following folders/files, which are licensed under the Apache-2.0 license:

  • src/vislstm/modules/xlstm
  • vision_lstm/vision_lstm.py
  • vision_lstm/vision_lstm2.py

Get started

This code-base supports simple usage of Vision-LSTM with an "architecture-only" implementation and also a full training pipeline.

Architecture only

The package vision_lstm provides a standalone implementation in the style of timm.

An example of how to train ViL on CIFAR10 using the vision_lstm package is provided here.

If you only need the model architecture, you can load it in a single line via torchhub or copy the vision_lstm folder into your own code-base. Note that for VisionLSTM2 we consider a single block to consist of two subblocks (the first one going from top-left to bottom-right and the second one going from bottom-right to top-left) to ease implementations of layerwise learning rate decay.

import torch

# load ViL-T
model = torch.hub.load("nx-ai/vision-lstm", "VisionLSTM2")
# load your own model
model = torch.hub.load(
    "nx-ai/vision-lstm",
    "VisionLSTM2",  # VisionLSTM2 is an improved version of VisionLSTM
    dim=192,  # latent dimension (192 for ViL-T)
    depth=12,  # number of ViL blocks (1 block consists of 2 subblocks: a forward and a backward subblock)
    patch_size=16,  # patch size (results in 196 patches for 224x224 images)
    input_shape=(3, 224, 224),  # RGB images with resolution 224x224
    output_shape=(1000,),  # classifier with 1000 classes
    drop_path_rate=0.05,  # stochastic depth parameter
)

See below or Appendix A for a list of changes between VisionLSTM and VisionLSTM2. We recommend using VisionLSTM2, as we found it to perform better, but keep VisionLSTM for backward compatibility.
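The note above about layerwise learning rate decay can be illustrated with a small sketch. It uses one common convention (block i of depth blocks is scaled by decay ** (depth - i)), which is an assumption and not necessarily what this code-base does. The function name, the decay value 0.65, and the base learning rate are hypothetical.

```python
from typing import List


def layerwise_lr_scales(depth: int, decay: float) -> List[float]:
    """Return one learning-rate multiplier per block, smallest for the first block.

    Block i (0-indexed) gets decay ** (depth - i). This is a common
    convention, assumed here for illustration only.
    """
    if depth < 1:
        raise ValueError("depth must be >= 1")
    return [decay ** (depth - i) for i in range(depth)]


# VisionLSTM2 with depth=12: each block bundles a forward and a backward
# subblock, so both subblocks of a block share a single multiplier.
scales = layerwise_lr_scales(depth=12, decay=0.65)  # decay value is hypothetical

base_lr = 1e-3  # hypothetical base learning rate
block_lrs = [base_lr * s for s in scales]
```

Treating two subblocks as one block means a 12-block VisionLSTM2 needs only 12 multipliers, and a forward/backward pair is never split across two different learning rates.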

Full training/eval pipeline (architecture, datasets, hyperparameters, classification, segmentation, ...)

If you want to train models with our code-base, follow the setup instructions from SETUP.md. To start runs, follow the instructions from RUN.md.

All configurations/hyperparameters for all training/evaluation runs can be found here.

VTAB-1K evaluations were conducted with this codebase.

Pre-trained models

Pre-trained models on ImageNet-1K can be loaded via torchhub or directly downloaded from here.

# ImageNet-1K pre-trained models
model = torch.hub.load("nx-ai/vision-lstm", "vil2-tiny")               # 78.3%
model = torch.hub.load("nx-ai/vision-lstm", "vil2-small")              # 81.5%
model = torch.hub.load("nx-ai/vision-lstm", "vil2-base")               # 82.4%

# ViL-T trained for only 400 epochs (Appendix B.2)
model = torch.hub.load("nx-ai/vision-lstm", "vil2-tiny-e400")          # 77.2%

Pre-training logs of these models can be found here.

An example of how to use these models can be found in eval.py which evaluates the models on the ImageNet-1K validation set.
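The percentages listed next to the model names are ImageNet-1K accuracies. As a rough illustration of that metric, the following sketch computes top-1 accuracy from per-sample logits in plain Python. It is not taken from eval.py; the function name and the toy data are hypothetical.

```python
def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(logits, labels):
        # index of the largest logit is the predicted class
        pred = max(range(len(row)), key=row.__getitem__)
        correct += int(pred == label)
    return correct / len(labels)


# toy example with 4 samples and 3 classes (hypothetical values)
toy_logits = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 0.9],
    [0.4, 0.3, 0.2],
]
toy_labels = [1, 0, 2, 1]
acc = top1_accuracy(toy_logits, toy_labels)
```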

DeiT-III-T reimplementation models

Checkpoints for our reimplementation of DeiT-III-T are provided as raw checkpoints here and can also be loaded from torchhub. The vision transformer implementation is based on KappaModules, so you need to install it before loading a ViT checkpoint via torchhub by running pip install kappamodules==0.1.76.

model = torch.hub.load("nx-ai/vision-lstm", "deit3-tiny-e400")  # 75.6%
model = torch.hub.load("nx-ai/vision-lstm", "deit3-tiny")       # 76.2%

Version1 pre-trained models

In the first iteration of ViL, models were trained with (i) bilateral_avg pooling instead of bilateral_concat, (ii) a causal conv1d instead of a conv2d before q and k, (iii) no biases in projections and layernorms, and (iv) 224 resolution for the whole training process instead of pre-training at 192 resolution followed by a short fine-tuning at 224 resolution. Together, the changes made in VisionLSTM2 improve the ImageNet-1K accuracy of ViL-T from 77.3% to 78.3%. See Appendix A in the paper for more details. We recommend using VisionLSTM2 instead of VisionLSTM, but keep support for the initial version as-is. Pre-trained models of the first iteration can be loaded as follows:

# ImageNet-1K pre-trained models
model = torch.hub.load("nx-ai/vision-lstm", "vil-tiny")               # 77.3%
model = torch.hub.load("nx-ai/vision-lstm", "vil-tinyplus")           # 78.1%
model = torch.hub.load("nx-ai/vision-lstm", "vil-small")              # 80.7%
model = torch.hub.load("nx-ai/vision-lstm", "vil-smallplus")          # 80.9%
model = torch.hub.load("nx-ai/vision-lstm", "vil-base")               # 81.6%

# long-sequence fine-tuned models
model = torch.hub.load("nx-ai/vision-lstm", "vil-tinyplus-stride8")   # 80.0%
model = torch.hub.load("nx-ai/vision-lstm", "vil-smallplus-stride8")  # 82.2%
model = torch.hub.load("nx-ai/vision-lstm", "vil-base-stride8")       # 82.7%

# tiny models trained for only 400 epochs
model = torch.hub.load("nx-ai/vision-lstm", "vil-tiny-e400")          # 76.1%
model = torch.hub.load("nx-ai/vision-lstm", "vil-tinyplus-e400")      # 77.2%
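To make the pooling difference between the two versions concrete, here is a minimal sketch in plain Python. It assumes, based on the names, that bilateral_avg averages the first and last patch token while bilateral_concat concatenates them. The function bodies and toy tokens are illustrative and are not copied from the vision_lstm package.

```python
def bilateral_avg(tokens):
    """Average the first and last token; output has the token dimension."""
    first, last = tokens[0], tokens[-1]
    return [(a + b) / 2 for a, b in zip(first, last)]


def bilateral_concat(tokens):
    """Concatenate the first and last token; output has twice the token dimension."""
    first, last = tokens[0], tokens[-1]
    return list(first) + list(last)


# toy sequence of 3 tokens with dimension 2 (hypothetical values)
tokens = [[1.0, 2.0], [5.0, 5.0], [3.0, 6.0]]
avg_pooled = bilateral_avg(tokens)
concat_pooled = bilateral_concat(tokens)
```

Under this reading, the classifier input of a VisionLSTM2 model is twice as wide as that of a VisionLSTM model with the same dim.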

Initializing with random weights can be done as follows:

# load ViL-T
model = torch.hub.load("nx-ai/vision-lstm", "VisionLSTM")
# load your own model
model = torch.hub.load(
    "nx-ai/vision-lstm", 
    "VisionLSTM",
    dim=192,  # latent dimension (192 for ViL-T)
    depth=24,  # how many ViL blocks
    patch_size=16,  # patch_size (results in 196 patches for 224x224 images)
    input_shape=(3, 224, 224),  # RGB images with resolution 224x224
    output_shape=(1000,),  # classifier with 1000 classes
    drop_path_rate=0.05,  # stochastic depth parameter
    stride=None,  # set to 8 for long-sequence fine-tuning
)
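The patch counts in the comments above follow from simple arithmetic. The sketch below assumes square images and no padding, so the number of patches per side is (resolution - patch_size) // stride + 1. The helper name is hypothetical and not part of the vision_lstm package.

```python
from typing import Optional


def num_patches(resolution: int, patch_size: int, stride: Optional[int] = None) -> int:
    """Number of patch tokens for a square image, assuming no padding.

    With stride=None the patches do not overlap (stride == patch_size).
    """
    stride = patch_size if stride is None else stride
    if resolution < patch_size:
        raise ValueError("resolution must be at least patch_size")
    per_side = (resolution - patch_size) // stride + 1
    return per_side ** 2


default_patches = num_patches(224, 16)           # non-overlapping 16x16 patches
pretrain_patches = num_patches(192, 16)          # 192 resolution
stride8_patches = num_patches(224, 16, stride=8) # overlapping patches, stride 8
```

Under these assumptions, setting stride=8 increases the sequence length from 196 to 729 tokens at 224 resolution, which is why the stride8 models are described as long-sequence fine-tuned.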

Other

This code-base is an improved version of the one used for MIM-Refiner, for which there is a demo video that walks through the code-base.

Citation

If you like our work, please consider giving it a star ⭐ and citing us:

@article{alkin2024visionlstm,
  title={{Vision-LSTM}: {xLSTM} as Generic Vision Backbone},
  author={Benedikt Alkin and Maximilian Beck and Korbinian P{\"o}ppel and Sepp Hochreiter and Johannes Brandstetter},
  journal={arXiv preprint arXiv:2406.04303},
  year={2024}
}
